mirror of
https://github.com/james-m-jordan/morphik-core.git
synced 2025-05-09 19:32:38 +00:00
Speed up multivector store
Co-authored-by: Arnav Agrawal <aa779@cornell.edu>
This commit is contained in:
parent
89633ce761
commit
bce0e1cfe1
@ -181,7 +181,8 @@ class MultiVectorStore(BaseVectorStore):
|
|||||||
"""
|
"""
|
||||||
CREATE OR REPLACE FUNCTION max_sim(document bit[], query bit[]) RETURNS double precision AS $$
|
CREATE OR REPLACE FUNCTION max_sim(document bit[], query bit[]) RETURNS double precision AS $$
|
||||||
WITH queries AS (
|
WITH queries AS (
|
||||||
SELECT row_number() OVER () AS query_number, * FROM (SELECT unnest(query) AS query) AS foo
|
SELECT row_number() OVER () AS query_number, *
|
||||||
|
FROM (SELECT unnest(query) AS query) AS foo
|
||||||
),
|
),
|
||||||
documents AS (
|
documents AS (
|
||||||
SELECT unnest(document) AS document
|
SELECT unnest(document) AS document
|
||||||
@ -189,7 +190,8 @@ class MultiVectorStore(BaseVectorStore):
|
|||||||
similarities AS (
|
similarities AS (
|
||||||
SELECT
|
SELECT
|
||||||
query_number,
|
query_number,
|
||||||
1.0 - (bit_count(document # query)::float / greatest(bit_length(query), 1)::float) AS similarity
|
1.0 - (bit_count(document # query)::float /
|
||||||
|
greatest(bit_length(query), 1)::float) AS similarity
|
||||||
FROM queries CROSS JOIN documents
|
FROM queries CROSS JOIN documents
|
||||||
),
|
),
|
||||||
max_similarities AS (
|
max_similarities AS (
|
||||||
@ -217,10 +219,6 @@ class MultiVectorStore(BaseVectorStore):
|
|||||||
if isinstance(embeddings, list) and not isinstance(embeddings[0], np.ndarray):
|
if isinstance(embeddings, list) and not isinstance(embeddings[0], np.ndarray):
|
||||||
embeddings = np.array(embeddings)
|
embeddings = np.array(embeddings)
|
||||||
|
|
||||||
# Add this check to ensure pgvector is registered for the connection
|
|
||||||
with self.get_connection() as conn:
|
|
||||||
register_vector(conn)
|
|
||||||
|
|
||||||
return [Bit(embedding > 0) for embedding in embeddings]
|
return [Bit(embedding > 0) for embedding in embeddings]
|
||||||
|
|
||||||
async def store_embeddings(self, chunks: List[DocumentChunk]) -> Tuple[bool, List[str]]:
|
async def store_embeddings(self, chunks: List[DocumentChunk]) -> Tuple[bool, List[str]]:
|
||||||
@ -230,21 +228,21 @@ class MultiVectorStore(BaseVectorStore):
|
|||||||
return True, []
|
return True, []
|
||||||
|
|
||||||
stored_ids = []
|
stored_ids = []
|
||||||
|
with self.get_connection() as conn:
|
||||||
|
for chunk in chunks:
|
||||||
|
# Ensure embeddings exist
|
||||||
|
if not hasattr(chunk, "embedding") or chunk.embedding is None:
|
||||||
|
logger.error(f"Missing embeddings for chunk {chunk.document_id}-{chunk.chunk_number}")
|
||||||
|
continue
|
||||||
|
|
||||||
for chunk in chunks:
|
# For multi-vector embeddings, we expect a list of vectors
|
||||||
# Ensure embeddings exist
|
embeddings = chunk.embedding
|
||||||
if not hasattr(chunk, "embedding") or chunk.embedding is None:
|
|
||||||
logger.error(f"Missing embeddings for chunk {chunk.document_id}-{chunk.chunk_number}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# For multi-vector embeddings, we expect a list of vectors
|
# Create binary representation for each vector
|
||||||
embeddings = chunk.embedding
|
binary_embeddings = self._binary_quantize(embeddings)
|
||||||
|
|
||||||
# Create binary representation for each vector
|
# Insert into database with retry logic
|
||||||
binary_embeddings = self._binary_quantize(embeddings)
|
|
||||||
|
|
||||||
# Insert into database with retry logic
|
|
||||||
with self.get_connection() as conn:
|
|
||||||
conn.execute(
|
conn.execute(
|
||||||
"""
|
"""
|
||||||
INSERT INTO multi_vector_embeddings
|
INSERT INTO multi_vector_embeddings
|
||||||
@ -260,7 +258,7 @@ class MultiVectorStore(BaseVectorStore):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
stored_ids.append(f"{chunk.document_id}-{chunk.chunk_number}")
|
stored_ids.append(f"{chunk.document_id}-{chunk.chunk_number}")
|
||||||
|
|
||||||
logger.debug(f"{len(stored_ids)} vector embeddings added successfully!")
|
logger.debug(f"{len(stored_ids)} vector embeddings added successfully!")
|
||||||
return len(stored_ids) > 0, stored_ids
|
return len(stored_ids) > 0, stored_ids
|
||||||
|
Loading…
x
Reference in New Issue
Block a user