diff --git a/core/vector_store/multi_vector_store.py b/core/vector_store/multi_vector_store.py index db44960..9d504f0 100644 --- a/core/vector_store/multi_vector_store.py +++ b/core/vector_store/multi_vector_store.py @@ -181,7 +181,8 @@ class MultiVectorStore(BaseVectorStore): """ CREATE OR REPLACE FUNCTION max_sim(document bit[], query bit[]) RETURNS double precision AS $$ WITH queries AS ( - SELECT row_number() OVER () AS query_number, * FROM (SELECT unnest(query) AS query) AS foo + SELECT row_number() OVER () AS query_number, * + FROM (SELECT unnest(query) AS query) AS foo ), documents AS ( SELECT unnest(document) AS document @@ -189,7 +190,8 @@ class MultiVectorStore(BaseVectorStore): similarities AS ( SELECT query_number, - 1.0 - (bit_count(document # query)::float / greatest(bit_length(query), 1)::float) AS similarity + 1.0 - (bit_count(document # query)::float / + greatest(bit_length(query), 1)::float) AS similarity FROM queries CROSS JOIN documents ), max_similarities AS ( @@ -217,10 +219,6 @@ class MultiVectorStore(BaseVectorStore): if isinstance(embeddings, list) and not isinstance(embeddings[0], np.ndarray): embeddings = np.array(embeddings) - # Add this check to ensure pgvector is registered for the connection - with self.get_connection() as conn: - register_vector(conn) - return [Bit(embedding > 0) for embedding in embeddings] async def store_embeddings(self, chunks: List[DocumentChunk]) -> Tuple[bool, List[str]]: @@ -230,21 +228,21 @@ class MultiVectorStore(BaseVectorStore): return True, [] stored_ids = [] + with self.get_connection() as conn: + for chunk in chunks: + # Ensure embeddings exist + if not hasattr(chunk, "embedding") or chunk.embedding is None: + logger.error(f"Missing embeddings for chunk {chunk.document_id}-{chunk.chunk_number}") + continue - for chunk in chunks: - # Ensure embeddings exist - if not hasattr(chunk, "embedding") or chunk.embedding is None: - logger.error(f"Missing embeddings for chunk {chunk.document_id}-{chunk.chunk_number}") - continue + # For multi-vector embeddings, we expect a list of vectors + embeddings = chunk.embedding - # For multi-vector embeddings, we expect a list of vectors - embeddings = chunk.embedding + # Create binary representation for each vector + binary_embeddings = self._binary_quantize(embeddings) - # Create binary representation for each vector - binary_embeddings = self._binary_quantize(embeddings) + # Insert into database with retry logic - # Insert into database with retry logic - with self.get_connection() as conn: conn.execute( """ INSERT INTO multi_vector_embeddings @@ -260,7 +258,7 @@ class MultiVectorStore(BaseVectorStore): ), ) - stored_ids.append(f"{chunk.document_id}-{chunk.chunk_number}") + stored_ids.append(f"{chunk.document_id}-{chunk.chunk_number}") logger.debug(f"{len(stored_ids)} vector embeddings added successfully!") return len(stored_ids) > 0, stored_ids