Parallelize tasks: query, ingest, search (#104)

Adityavardhan Agrawal 2025-04-20 22:36:27 -07:00 committed by GitHub
parent bce0e1cfe1
commit 4b9869baf7


@@ -155,16 +155,11 @@ class DocumentService:
should_rerank = use_reranking if use_reranking is not None else settings.USE_RERANKING
using_colpali = use_colpali if use_colpali is not None else False
# Get embedding for query
query_embedding_regular = await self.embedding_model.embed_for_query(query)
query_embedding_multivector = (
await self.colpali_embedding_model.embed_for_query(query)
if (using_colpali and self.colpali_embedding_model)
else None
)
logger.info("Generated query embedding")
# Launch embedding queries concurrently
embedding_tasks = [self.embedding_model.embed_for_query(query)]
if using_colpali and self.colpali_embedding_model:
embedding_tasks.append(self.colpali_embedding_model.embed_for_query(query))
# Find authorized documents
# Build system filters for folder_name and end_user_id
system_filters = {}
if folder_name:
@@ -172,7 +167,18 @@ class DocumentService:
if end_user_id:
system_filters["end_user_id"] = end_user_id
doc_ids = await self.db.find_authorized_and_filtered_documents(auth, filters, system_filters)
# Run embeddings and document authorization in parallel
results = await asyncio.gather(
asyncio.gather(*embedding_tasks),
self.db.find_authorized_and_filtered_documents(auth, filters, system_filters),
)
embedding_results, doc_ids = results
query_embedding_regular = embedding_results[0]
query_embedding_multivector = embedding_results[1] if len(embedding_results) > 1 else None
logger.info("Generated query embedding")
if not doc_ids:
logger.info("No authorized documents found")
return []
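The change above replaces two sequential awaits with a nested asyncio.gather so the query embeddings and the authorization lookup overlap. Below is a minimal, self-contained sketch of that pattern, not part of the commit; embed_query, embed_query_multivector and find_authorized_docs are hypothetical stubs standing in for the service's real coroutines.
import asyncio

async def embed_query(query: str) -> list[float]:
    await asyncio.sleep(0.1)  # stands in for the regular embedding model call
    return [0.0, 1.0]

async def embed_query_multivector(query: str) -> list[list[float]]:
    await asyncio.sleep(0.1)  # stands in for the colpali embedding model call
    return [[0.0, 1.0]]

async def find_authorized_docs() -> list[str]:
    await asyncio.sleep(0.1)  # stands in for the database authorization query
    return ["doc-1", "doc-2"]

async def run_query_phase(query: str, use_colpali: bool):
    embedding_tasks = [embed_query(query)]
    if use_colpali:
        embedding_tasks.append(embed_query_multivector(query))
    # The inner gather resolves all embedding tasks; the outer gather overlaps
    # them with the independent authorization lookup.
    embedding_results, doc_ids = await asyncio.gather(
        asyncio.gather(*embedding_tasks), find_authorized_docs()
    )
    regular = embedding_results[0]
    multivector = embedding_results[1] if len(embedding_results) > 1 else None
    return regular, multivector, doc_ids

# asyncio.run(run_query_phase("hello", use_colpali=True))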
@@ -185,22 +191,28 @@ class DocumentService:
# For colpali reranking, we'll handle it in _combine_multi_and_regular_chunks
use_standard_reranker = should_rerank and (not search_multi) and self.reranker is not None
# Search chunks with vector similarity
# Search chunks with vector similarity in parallel
# When using standard reranker, we get more chunks initially to improve reranking quality
chunks = await self.vector_store.query_similar(
query_embedding_regular, k=10 * k if use_standard_reranker else k, doc_ids=doc_ids
)
search_tasks = [
self.vector_store.query_similar(
query_embedding_regular, k=10 * k if use_standard_reranker else k, doc_ids=doc_ids
)
]
chunks_multivector = (
await self.colpali_vector_store.query_similar(query_embedding_multivector, k=k, doc_ids=doc_ids)
if search_multi
else []
)
if search_multi:
search_tasks.append(
self.colpali_vector_store.query_similar(query_embedding_multivector, k=k, doc_ids=doc_ids)
)
search_results = await asyncio.gather(*search_tasks)
chunks = search_results[0]
chunks_multivector = search_results[1] if len(search_results) > 1 else []
logger.debug(f"Found {len(chunks)} similar chunks via regular embedding")
if using_colpali:
logger.debug(
f"Found {len(chunks_multivector)} similar chunks via multivector embedding since we are also using colpali"
f"Found {len(chunks_multivector)} similar chunks via multivector embedding "
f"since we are also using colpali"
)
# Rerank chunks using the standard reranker if enabled and available
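The chunk search uses the same task-list-plus-gather shape: the colpali query is appended only when multivector search is enabled, and results are unpacked by position. The short sketch below is illustrative only and shows why gathering matters: two independent 0.2 s awaits finish together in roughly 0.2 s instead of 0.4 s back to back; fake_search is an invented stand-in for a vector store query.
import asyncio
import time

async def fake_search(name: str) -> str:
    await asyncio.sleep(0.2)  # stands in for a vector store similarity query
    return f"{name} results"

async def main() -> None:
    start = time.perf_counter()
    regular, multivector = await asyncio.gather(
        fake_search("regular"), fake_search("colpali")
    )
    # Both searches overlap, so this prints roughly 0.2s rather than 0.4s.
    print(regular, multivector, f"took {time.perf_counter() - start:.2f}s")

# asyncio.run(main())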
@@ -273,7 +285,8 @@ class DocumentService:
model_name,
torch_dtype=torch.bfloat16,
device_map=device, # "cuda:0", # or "mps" if on Apple Silicon
attn_implementation="eager", # "flash_attention_2" if is_flash_attn_2_available() else None, # or "eager" if "mps"
attn_implementation="eager", # "flash_attention_2" if is_flash_attn_2_available() else None,
# or "eager" if "mps"
).eval()
processor = ColIdefics3Processor.from_pretrained(model_name)
@@ -386,15 +399,24 @@ class DocumentService:
# Create list of (document_id, chunk_number) tuples for vector store query
chunk_identifiers = [(source.document_id, source.chunk_number) for source in authorized_sources]
# Retrieve the chunks from vector store in a single query
chunks = await self.vector_store.get_chunks_by_id(chunk_identifiers)
# Set up vector store retrieval tasks
retrieval_tasks = [self.vector_store.get_chunks_by_id(chunk_identifiers)]
# Check if we should use colpali for image chunks
# Add colpali vector store task if needed
if use_colpali and self.colpali_vector_store:
logger.info("Trying to retrieve chunks from colpali vector store")
try:
# Also try to retrieve from the colpali vector store
colpali_chunks = await self.colpali_vector_store.get_chunks_by_id(chunk_identifiers)
logger.info("Preparing to retrieve chunks from both regular and colpali vector stores")
retrieval_tasks.append(self.colpali_vector_store.get_chunks_by_id(chunk_identifiers))
# Execute vector store retrievals in parallel
try:
vector_results = await asyncio.gather(*retrieval_tasks, return_exceptions=True)
# Process regular chunks
chunks = vector_results[0] if not isinstance(vector_results[0], Exception) else []
# Process colpali chunks if available
if len(vector_results) > 1 and not isinstance(vector_results[1], Exception):
colpali_chunks = vector_results[1]
if colpali_chunks:
# Create a dictionary of (doc_id, chunk_number) -> chunk for fast lookup
@@ -402,9 +424,6 @@ class DocumentService:
logger.debug(f"Found {len(colpali_chunks)} chunks in colpali store")
for colpali_chunk in colpali_chunks:
logger.debug(
f"Found colpali chunk: doc={colpali_chunk.document_id}, chunk={colpali_chunk.chunk_number}"
)
key = (colpali_chunk.document_id, colpali_chunk.chunk_number)
# Replace chunks with colpali chunks when available
chunk_dict[key] = colpali_chunk
@@ -412,10 +431,18 @@ class DocumentService:
# Update chunks list with the combined/replaced chunks
chunks = list(chunk_dict.values())
logger.info(f"Enhanced {len(colpali_chunks)} chunks with colpali/multimodal data")
else:
logger.warning("No chunks found in colpali vector store")
except Exception as e:
logger.error(f"Error retrieving chunks from colpali vector store: {e}", exc_info=True)
# Handle any exceptions that occurred during retrieval
for i, result in enumerate(vector_results):
if isinstance(result, Exception):
store_type = "regular" if i == 0 else "colpali"
logger.error(f"Error retrieving chunks from {store_type} vector store: {result}", exc_info=True)
if i == 0: # If regular store failed, we can't proceed
return []
except Exception as e:
logger.error(f"Error during parallel chunk retrieval: {e}", exc_info=True)
return []
# Convert to chunk results
results = await self._create_chunk_results(auth, chunks)
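The retrieval hunk above relies on asyncio.gather(..., return_exceptions=True), which delivers failures as exception objects in the result list instead of cancelling the sibling task, so a colpali failure does not discard the regular chunks. A hedged sketch of that error-handling shape follows; fetch_regular and fetch_colpali are invented stubs, not the commit's code.
import asyncio

async def fetch_regular() -> list[str]:
    await asyncio.sleep(0.1)  # stands in for the regular vector store lookup
    return ["chunk-1", "chunk-2"]

async def fetch_colpali() -> list[str]:
    raise RuntimeError("colpali store unavailable")  # simulated failure

async def retrieve() -> list[str]:
    # return_exceptions=True returns failures as values instead of raising,
    # so the regular result survives a colpali failure.
    results = await asyncio.gather(fetch_regular(), fetch_colpali(), return_exceptions=True)
    chunks = results[0] if not isinstance(results[0], Exception) else []
    for i, result in enumerate(results):
        if isinstance(result, Exception):
            store_type = "regular" if i == 0 else "colpali"
            print(f"error retrieving chunks from {store_type} store: {result}")
    return chunks

# asyncio.run(retrieve())  # logs the colpali failure, returns the regular chunks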
@@ -933,138 +960,124 @@ class DocumentService:
max_retries = 3
retry_delay = 1.0
# Store chunks in vector store with retry
attempt = 0
success = False
result = None
while attempt < max_retries and not success:
try:
success, result = await self.vector_store.store_embeddings(chunk_objects)
if not success:
raise Exception("Failed to store chunk embeddings")
break
except Exception as e:
attempt += 1
error_msg = str(e)
if "connection was closed" in error_msg or "ConnectionDoesNotExistError" in error_msg:
if attempt < max_retries:
logger.warning(
f"Database connection error during embeddings storage (attempt {attempt}/{max_retries}): {error_msg}. Retrying in {retry_delay}s..."
)
await asyncio.sleep(retry_delay)
# Increase delay for next retry (exponential backoff)
retry_delay *= 2
else:
logger.error(
f"All database connection attempts failed after {max_retries} retries: {error_msg}"
)
raise Exception("Failed to store chunk embeddings after multiple retries")
else:
# For other exceptions, don't retry
logger.error(f"Error storing embeddings: {error_msg}")
raise
logger.debug("Stored chunk embeddings in vector store")
doc.chunk_ids = result
if use_colpali and self.colpali_vector_store and chunk_objects_multivector:
# Reset retry variables for colpali storage
# Helper function to store embeddings with retry
async def store_with_retry(store, objects, store_name="regular"):
attempt = 0
retry_delay = 1.0
success = False
result_multivector = None
result = None
current_retry_delay = retry_delay
while attempt < max_retries and not success:
try:
success, result_multivector = await self.colpali_vector_store.store_embeddings(
chunk_objects_multivector
)
success, result = await store.store_embeddings(objects)
if not success:
raise Exception("Failed to store multivector chunk embeddings")
break
raise Exception(f"Failed to store {store_name} chunk embeddings")
return result
except Exception as e:
attempt += 1
error_msg = str(e)
if "connection was closed" in error_msg or "ConnectionDoesNotExistError" in error_msg:
if attempt < max_retries:
logger.warning(
f"Database connection error during colpali embeddings storage (attempt {attempt}/{max_retries}): {error_msg}. Retrying in {retry_delay}s..."
f"Database connection error during {store_name} embeddings storage "
f"(attempt {attempt}/{max_retries}): {error_msg}. "
f"Retrying in {current_retry_delay}s..."
)
await asyncio.sleep(retry_delay)
await asyncio.sleep(current_retry_delay)
# Increase delay for next retry (exponential backoff)
retry_delay *= 2
current_retry_delay *= 2
else:
logger.error(
f"All colpali database connection attempts failed after {max_retries} retries: {error_msg}"
f"All {store_name} database connection attempts failed "
f"after {max_retries} retries: {error_msg}"
)
raise Exception("Failed to store multivector chunk embeddings after multiple retries")
raise Exception(f"Failed to store {store_name} chunk embeddings after multiple retries")
else:
# For other exceptions, don't retry
logger.error(f"Error storing colpali embeddings: {error_msg}")
logger.error(f"Error storing {store_name} embeddings: {error_msg}")
raise
logger.debug("Stored multivector chunk embeddings in vector store")
doc.chunk_ids += result_multivector
# Store document metadata with retry
attempt = 0
retry_delay = 1.0
success = False
async def store_document_with_retry():
attempt = 0
success = False
current_retry_delay = retry_delay
while attempt < max_retries and not success:
try:
if is_update and auth:
# For updates, use update_document, serialize StorageFileInfo into plain dicts
updates = {
"chunk_ids": doc.chunk_ids,
"metadata": doc.metadata,
"system_metadata": doc.system_metadata,
"filename": doc.filename,
"content_type": doc.content_type,
"storage_info": doc.storage_info,
"storage_files": (
[
(
file.model_dump()
if hasattr(file, "model_dump")
else (file.dict() if hasattr(file, "dict") else file)
)
for file in doc.storage_files
]
if doc.storage_files
else []
),
}
success = await self.db.update_document(doc.external_id, updates, auth)
if not success:
raise Exception("Failed to update document metadata")
else:
# For new documents, use store_document
success = await self.db.store_document(doc)
if not success:
raise Exception("Failed to store document metadata")
break
except Exception as e:
attempt += 1
error_msg = str(e)
if "connection was closed" in error_msg or "ConnectionDoesNotExistError" in error_msg:
if attempt < max_retries:
logger.warning(
f"Database connection error during document metadata storage (attempt {attempt}/{max_retries}): {error_msg}. Retrying in {retry_delay}s..."
)
await asyncio.sleep(retry_delay)
# Increase delay for next retry (exponential backoff)
retry_delay *= 2
while attempt < max_retries and not success:
try:
if is_update and auth:
# For updates, use update_document, serialize StorageFileInfo into plain dicts
updates = {
"chunk_ids": doc.chunk_ids,
"metadata": doc.metadata,
"system_metadata": doc.system_metadata,
"filename": doc.filename,
"content_type": doc.content_type,
"storage_info": doc.storage_info,
"storage_files": (
[
(
file.model_dump()
if hasattr(file, "model_dump")
else (file.dict() if hasattr(file, "dict") else file)
)
for file in doc.storage_files
]
if doc.storage_files
else []
),
}
success = await self.db.update_document(doc.external_id, updates, auth)
if not success:
raise Exception("Failed to update document metadata")
else:
logger.error(
f"All database connection attempts failed after {max_retries} retries: {error_msg}"
)
raise Exception("Failed to store document metadata after multiple retries")
else:
# For other exceptions, don't retry
logger.error(f"Error storing document metadata: {error_msg}")
raise
# For new documents, use store_document
success = await self.db.store_document(doc)
if not success:
raise Exception("Failed to store document metadata")
return success
except Exception as e:
attempt += 1
error_msg = str(e)
if "connection was closed" in error_msg or "ConnectionDoesNotExistError" in error_msg:
if attempt < max_retries:
logger.warning(
f"Database connection error during document metadata storage "
f"(attempt {attempt}/{max_retries}): {error_msg}. "
f"Retrying in {current_retry_delay}s..."
)
await asyncio.sleep(current_retry_delay)
# Increase delay for next retry (exponential backoff)
current_retry_delay *= 2
else:
logger.error(
f"All database connection attempts failed " f"after {max_retries} retries: {error_msg}"
)
raise Exception("Failed to store document metadata after multiple retries")
else:
# For other exceptions, don't retry
logger.error(f"Error storing document metadata: {error_msg}")
raise
# Run storage operations in parallel when possible
storage_tasks = [store_with_retry(self.vector_store, chunk_objects, "regular")]
# Add colpali storage task if needed
if use_colpali and self.colpali_vector_store and chunk_objects_multivector:
storage_tasks.append(store_with_retry(self.colpali_vector_store, chunk_objects_multivector, "colpali"))
# Execute storage tasks concurrently
storage_results = await asyncio.gather(*storage_tasks)
# Combine chunk IDs
regular_chunk_ids = storage_results[0]
colpali_chunk_ids = storage_results[1] if len(storage_results) > 1 else []
doc.chunk_ids = regular_chunk_ids + colpali_chunk_ids
logger.debug(f"Stored chunk embeddings in vector stores: {len(doc.chunk_ids)} chunks total")
# Store document metadata (this must be done after chunk storage)
await store_document_with_retry()
logger.debug("Stored document metadata in database")
logger.debug(f"Chunk IDs stored: {doc.chunk_ids}")
@@ -1073,21 +1086,37 @@ class DocumentService:
async def _create_chunk_results(self, auth: AuthContext, chunks: List[DocumentChunk]) -> List[ChunkResult]:
"""Create ChunkResult objects with document metadata."""
results = []
if not chunks:
logger.info("No chunks provided, returning empty results")
return results
# Collect all unique document IDs from chunks
unique_doc_ids = list({chunk.document_id for chunk in chunks})
# Fetch all required documents in a single batch query
docs = await self.batch_retrieve_documents(unique_doc_ids, auth)
# Create a lookup dictionary of documents by ID
doc_map = {doc.external_id: doc for doc in docs}
logger.debug(f"Retrieved metadata for {len(doc_map)} unique documents in a single batch")
# Generate download URLs for all documents that have storage info
download_urls = {}
for doc_id, doc in doc_map.items():
if doc.storage_info:
download_urls[doc_id] = await self.storage.get_download_url(
doc.storage_info["bucket"], doc.storage_info["key"]
)
logger.debug(f"Generated download URL for document {doc_id}")
# Create chunk results using the lookup dictionaries
for chunk in chunks:
# Get document metadata
doc = await self.db.get_document(chunk.document_id, auth)
doc = doc_map.get(chunk.document_id)
if not doc:
logger.warning(f"Document {chunk.document_id} not found")
continue
logger.debug(f"Retrieved metadata for document {chunk.document_id}")
# Generate download URL if needed
download_url = None
if doc.storage_info:
download_url = await self.storage.get_download_url(doc.storage_info["bucket"], doc.storage_info["key"])
logger.debug(f"Generated download URL for document {chunk.document_id}")
metadata = doc.metadata
metadata = doc.metadata.copy()
metadata["is_image"] = chunk.metadata.get("is_image", False)
results.append(
ChunkResult(
@@ -1098,7 +1127,7 @@ class DocumentService:
metadata=metadata,
content_type=doc.content_type,
filename=doc.filename,
download_url=download_url,
download_url=download_urls.get(chunk.document_id),
)
)
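_create_chunk_results now issues one batch document query and builds lookup dictionaries instead of hitting the database once per chunk. The sketch below is an assumed-shape illustration of that refactor; batch_get_documents and the dictionary-based chunk records are hypothetical simplifications of batch_retrieve_documents and the real models.
import asyncio

async def batch_get_documents(doc_ids: list[str]) -> list[dict]:
    await asyncio.sleep(0.1)  # one round trip regardless of how many IDs are passed
    return [{"external_id": d, "metadata": {"title": d}} for d in doc_ids]

async def attach_metadata(chunks: list[dict]) -> list[dict]:
    unique_ids = list({chunk["document_id"] for chunk in chunks})
    docs = await batch_get_documents(unique_ids)
    doc_map = {doc["external_id"]: doc for doc in docs}
    enriched = []
    for chunk in chunks:
        doc = doc_map.get(chunk["document_id"])
        if doc is None:
            continue  # skip chunks whose parent document was not returned
        enriched.append({**chunk, "metadata": dict(doc["metadata"])})
    return enriched

# asyncio.run(attach_metadata([{"document_id": "doc-1", "content": "..."}]))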
@@ -1107,31 +1136,53 @@ class DocumentService:
async def _create_document_results(self, auth: AuthContext, chunks: List[ChunkResult]) -> Dict[str, DocumentResult]:
"""Group chunks by document and create DocumentResult objects."""
if not chunks:
logger.info("No chunks provided, returning empty results")
return {}
# Group chunks by document and get highest scoring chunk per doc
doc_chunks: Dict[str, ChunkResult] = {}
for chunk in chunks:
if chunk.document_id not in doc_chunks or chunk.score > doc_chunks[chunk.document_id].score:
doc_chunks[chunk.document_id] = chunk
logger.info(f"Grouped chunks into {len(doc_chunks)} documents")
logger.debug(f"Document chunks: {doc_chunks}")
# Get unique document IDs
unique_doc_ids = list(doc_chunks.keys())
# Fetch all documents in a single batch query
docs = await self.batch_retrieve_documents(unique_doc_ids, auth)
# Create a lookup dictionary of documents by ID
doc_map = {doc.external_id: doc for doc in docs}
logger.debug(f"Retrieved metadata for {len(doc_map)} unique documents in a single batch")
# Generate download URLs for non-text documents in a single loop
download_urls = {}
for doc_id, doc in doc_map.items():
if doc.content_type != "text/plain" and doc.storage_info:
download_urls[doc_id] = await self.storage.get_download_url(
doc.storage_info["bucket"], doc.storage_info["key"]
)
logger.debug(f"Generated download URL for document {doc_id}")
# Create document results using the lookup dictionaries
results = {}
for doc_id, chunk in doc_chunks.items():
# Get document metadata
doc = await self.db.get_document(doc_id, auth)
doc = doc_map.get(doc_id)
if not doc:
logger.warning(f"Document {doc_id} not found")
continue
logger.info(f"Retrieved metadata for document {doc_id}")
# Create DocumentContent based on content type
if doc.content_type == "text/plain":
content = DocumentContent(type="string", value=chunk.content, filename=None)
logger.debug(f"Created text content for document {doc_id}")
else:
# Generate download URL for file types
download_url = await self.storage.get_download_url(doc.storage_info["bucket"], doc.storage_info["key"])
content = DocumentContent(type="url", value=download_url, filename=doc.filename)
# Use pre-generated download URL for file types
content = DocumentContent(type="url", value=download_urls.get(doc_id), filename=doc.filename)
logger.debug(f"Created URL content for document {doc_id}")
results[doc_id] = DocumentResult(
score=chunk.score,
document_id=doc_id,
@@ -1287,7 +1338,8 @@ class DocumentService:
else:
updated_content = self._apply_update_strategy(current_content, update_content, update_strategy)
logger.info(
f"Applied update strategy '{update_strategy}': original length={len(current_content)}, new length={len(updated_content)}"
f"Applied update strategy '{update_strategy}': original length={len(current_content)}, "
f"new length={len(updated_content)}"
)
# Always update the content in system_metadata
@@ -1532,9 +1584,7 @@ class DocumentService:
else (
StorageFileInfo(**file.model_dump())
if hasattr(file, "model_dump")
else StorageFileInfo(**file.dict())
if hasattr(file, "dict")
else file
else StorageFileInfo(**file.dict()) if hasattr(file, "dict") else file
)
)
)
@@ -1735,53 +1785,56 @@ class DocumentService:
logger.info(f"Deleted document {document_id} from database")
# Try to delete chunks from vector store if they exist
# Collect storage deletion tasks
storage_deletion_tasks = []
# Collect vector store deletion tasks
vector_deletion_tasks = []
# Add vector store deletion tasks if chunks exist
if hasattr(document, "chunk_ids") and document.chunk_ids:
try:
# Try to delete chunks by document ID
# Note: Some vector stores may not implement this method
if hasattr(self.vector_store, "delete_chunks_by_document_id"):
await self.vector_store.delete_chunks_by_document_id(document_id)
logger.info(f"Deleted chunks for document {document_id} from vector store")
else:
logger.warning("Vector store does not support deleting chunks by document ID")
# Try to delete chunks by document ID
# Note: Some vector stores may not implement this method
if hasattr(self.vector_store, "delete_chunks_by_document_id"):
vector_deletion_tasks.append(self.vector_store.delete_chunks_by_document_id(document_id))
# Try to delete from colpali vector store as well
if self.colpali_vector_store and hasattr(self.colpali_vector_store, "delete_chunks_by_document_id"):
await self.colpali_vector_store.delete_chunks_by_document_id(document_id)
logger.info(f"Deleted chunks for document {document_id} from colpali vector store")
except Exception as e:
logger.error(f"Error deleting chunks for document {document_id}: {e}")
# We continue even if chunk deletion fails - don't block document deletion
# Try to delete from colpali vector store as well
if self.colpali_vector_store and hasattr(self.colpali_vector_store, "delete_chunks_by_document_id"):
vector_deletion_tasks.append(self.colpali_vector_store.delete_chunks_by_document_id(document_id))
# Delete file from storage if it exists
# Collect storage file deletion tasks
if hasattr(document, "storage_info") and document.storage_info:
try:
bucket = document.storage_info.get("bucket")
key = document.storage_info.get("key")
if bucket and key:
# Check if the storage provider supports deletion
if hasattr(self.storage, "delete_file"):
await self.storage.delete_file(bucket, key)
logger.info(
f"Deleted file for document {document_id} from storage (bucket: {bucket}, key: {key})"
)
else:
logger.warning("Storage provider does not support file deletion")
bucket = document.storage_info.get("bucket")
key = document.storage_info.get("key")
if bucket and key and hasattr(self.storage, "delete_file"):
storage_deletion_tasks.append(self.storage.delete_file(bucket, key))
# Also handle the case of multiple file versions in storage_files
if hasattr(document, "storage_files") and document.storage_files:
for file_info in document.storage_files:
bucket = file_info.get("bucket")
key = file_info.get("key")
if bucket and key and hasattr(self.storage, "delete_file"):
storage_deletion_tasks.append(self.storage.delete_file(bucket, key))
# Execute deletion tasks in parallel
if vector_deletion_tasks or storage_deletion_tasks:
try:
# Run all deletion tasks concurrently
all_deletion_results = await asyncio.gather(
*vector_deletion_tasks, *storage_deletion_tasks, return_exceptions=True
)
# Log any errors but continue with deletion
for i, result in enumerate(all_deletion_results):
if isinstance(result, Exception):
# Determine if this was a vector store or storage deletion
task_type = "vector store" if i < len(vector_deletion_tasks) else "storage"
logger.error(f"Error during {task_type} deletion for document {document_id}: {result}")
# Also handle the case of multiple file versions in storage_files
if hasattr(document, "storage_files") and document.storage_files:
for file_info in document.storage_files:
bucket = file_info.get("bucket")
key = file_info.get("key")
if bucket and key and hasattr(self.storage, "delete_file"):
await self.storage.delete_file(bucket, key)
logger.info(
f"Deleted file version for document {document_id} from storage (bucket: {bucket}, key: {key})"
)
except Exception as e:
logger.error(f"Error deleting file for document {document_id}: {e}")
# We continue even if file deletion fails - don't block document deletion
logger.error(f"Error during parallel deletion operations for document {document_id}: {e}")
# We continue even if deletions fail - document is already deleted from DB
logger.info(f"Successfully deleted document {document_id} and all associated data")
return True
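Deletion follows the same best-effort pattern: collect whatever vector store and storage deletion coroutines apply, run them in one gather with return_exceptions=True, log failures, and still report the document as deleted. A minimal sketch under those assumptions closes the section; delete_chunks and delete_file are invented stubs, not the service's API.
import asyncio

async def delete_chunks(document_id: str) -> None:
    await asyncio.sleep(0.05)  # stands in for a vector store deletion

async def delete_file(bucket: str, key: str) -> None:
    raise RuntimeError("object already gone")  # simulated storage failure

async def cleanup(document_id: str) -> bool:
    tasks = [delete_chunks(document_id), delete_file("bucket", f"{document_id}.bin")]
    results = await asyncio.gather(*tasks, return_exceptions=True)
    for i, result in enumerate(results):
        if isinstance(result, Exception):
            task_type = "vector store" if i == 0 else "storage"
            # Cleanup failures are logged but never block the deletion itself.
            print(f"error during {task_type} deletion for {document_id}: {result}")
    return True

# asyncio.run(cleanup("doc-1"))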