Speed up multivector store

Co-authored-by: Arnav Agrawal <aa779@cornell.edu>
This commit is contained in:
Adityavardhan Agrawal 2025-04-20 20:02:42 -07:00
parent 89633ce761
commit bce0e1cfe1

View File

@@ -181,7 +181,8 @@ class MultiVectorStore(BaseVectorStore):
""" """
CREATE OR REPLACE FUNCTION max_sim(document bit[], query bit[]) RETURNS double precision AS $$ CREATE OR REPLACE FUNCTION max_sim(document bit[], query bit[]) RETURNS double precision AS $$
WITH queries AS ( WITH queries AS (
SELECT row_number() OVER () AS query_number, * FROM (SELECT unnest(query) AS query) AS foo SELECT row_number() OVER () AS query_number, *
FROM (SELECT unnest(query) AS query) AS foo
), ),
documents AS ( documents AS (
SELECT unnest(document) AS document SELECT unnest(document) AS document
@@ -189,7 +190,8 @@ class MultiVectorStore(BaseVectorStore):
similarities AS ( similarities AS (
SELECT SELECT
query_number, query_number,
1.0 - (bit_count(document # query)::float / greatest(bit_length(query), 1)::float) AS similarity 1.0 - (bit_count(document # query)::float /
greatest(bit_length(query), 1)::float) AS similarity
FROM queries CROSS JOIN documents FROM queries CROSS JOIN documents
), ),
max_similarities AS ( max_similarities AS (
@@ -217,10 +219,6 @@ class MultiVectorStore(BaseVectorStore):
if isinstance(embeddings, list) and not isinstance(embeddings[0], np.ndarray): if isinstance(embeddings, list) and not isinstance(embeddings[0], np.ndarray):
embeddings = np.array(embeddings) embeddings = np.array(embeddings)
# Add this check to ensure pgvector is registered for the connection
with self.get_connection() as conn:
register_vector(conn)
return [Bit(embedding > 0) for embedding in embeddings] return [Bit(embedding > 0) for embedding in embeddings]
async def store_embeddings(self, chunks: List[DocumentChunk]) -> Tuple[bool, List[str]]: async def store_embeddings(self, chunks: List[DocumentChunk]) -> Tuple[bool, List[str]]:
@@ -230,21 +228,21 @@ class MultiVectorStore(BaseVectorStore):
return True, [] return True, []
stored_ids = [] stored_ids = []
with self.get_connection() as conn:
for chunk in chunks:
# Ensure embeddings exist
if not hasattr(chunk, "embedding") or chunk.embedding is None:
logger.error(f"Missing embeddings for chunk {chunk.document_id}-{chunk.chunk_number}")
continue
for chunk in chunks: # For multi-vector embeddings, we expect a list of vectors
# Ensure embeddings exist embeddings = chunk.embedding
if not hasattr(chunk, "embedding") or chunk.embedding is None:
logger.error(f"Missing embeddings for chunk {chunk.document_id}-{chunk.chunk_number}")
continue
# For multi-vector embeddings, we expect a list of vectors # Create binary representation for each vector
embeddings = chunk.embedding binary_embeddings = self._binary_quantize(embeddings)
# Create binary representation for each vector # Insert into database with retry logic
binary_embeddings = self._binary_quantize(embeddings)
# Insert into database with retry logic
with self.get_connection() as conn:
conn.execute( conn.execute(
""" """
INSERT INTO multi_vector_embeddings INSERT INTO multi_vector_embeddings
@@ -260,7 +258,7 @@ class MultiVectorStore(BaseVectorStore):
), ),
) )
stored_ids.append(f"{chunk.document_id}-{chunk.chunk_number}") stored_ids.append(f"{chunk.document_id}-{chunk.chunk_number}")
logger.debug(f"{len(stored_ids)} vector embeddings added successfully!") logger.debug(f"{len(stored_ids)} vector embeddings added successfully!")
return len(stored_ids) > 0, stored_ids return len(stored_ids) > 0, stored_ids