From 4b9869baf7c788a473ced13d1013855c7efbea9b Mon Sep 17 00:00:00 2001 From: Adityavardhan Agrawal Date: Sun, 20 Apr 2025 22:36:27 -0700 Subject: [PATCH] Parallelize tasks: query, ingest, search (#104) --- core/services/document_service.py | 465 +++++++++++++++++------------- 1 file changed, 259 insertions(+), 206 deletions(-) diff --git a/core/services/document_service.py b/core/services/document_service.py index 5ebf674..fc4fdae 100644 --- a/core/services/document_service.py +++ b/core/services/document_service.py @@ -155,16 +155,11 @@ class DocumentService: should_rerank = use_reranking if use_reranking is not None else settings.USE_RERANKING using_colpali = use_colpali if use_colpali is not None else False - # Get embedding for query - query_embedding_regular = await self.embedding_model.embed_for_query(query) - query_embedding_multivector = ( - await self.colpali_embedding_model.embed_for_query(query) - if (using_colpali and self.colpali_embedding_model) - else None - ) - logger.info("Generated query embedding") + # Launch embedding queries concurrently + embedding_tasks = [self.embedding_model.embed_for_query(query)] + if using_colpali and self.colpali_embedding_model: + embedding_tasks.append(self.colpali_embedding_model.embed_for_query(query)) - # Find authorized documents # Build system filters for folder_name and end_user_id system_filters = {} if folder_name: @@ -172,7 +167,18 @@ class DocumentService: if end_user_id: system_filters["end_user_id"] = end_user_id - doc_ids = await self.db.find_authorized_and_filtered_documents(auth, filters, system_filters) + # Run embeddings and document authorization in parallel + results = await asyncio.gather( + asyncio.gather(*embedding_tasks), + self.db.find_authorized_and_filtered_documents(auth, filters, system_filters), + ) + + embedding_results, doc_ids = results + query_embedding_regular = embedding_results[0] + query_embedding_multivector = embedding_results[1] if len(embedding_results) > 1 else None + + logger.info("Generated query embedding") + if not doc_ids: logger.info("No authorized documents found") return [] @@ -185,22 +191,28 @@ class DocumentService: # For colpali reranking, we'll handle it in _combine_multi_and_regular_chunks use_standard_reranker = should_rerank and (not search_multi) and self.reranker is not None - # Search chunks with vector similarity + # Search chunks with vector similarity in parallel # When using standard reranker, we get more chunks initially to improve reranking quality - chunks = await self.vector_store.query_similar( - query_embedding_regular, k=10 * k if use_standard_reranker else k, doc_ids=doc_ids - ) + search_tasks = [ + self.vector_store.query_similar( + query_embedding_regular, k=10 * k if use_standard_reranker else k, doc_ids=doc_ids + ) + ] - chunks_multivector = ( - await self.colpali_vector_store.query_similar(query_embedding_multivector, k=k, doc_ids=doc_ids) - if search_multi - else [] - ) + if search_multi: + search_tasks.append( + self.colpali_vector_store.query_similar(query_embedding_multivector, k=k, doc_ids=doc_ids) + ) + + search_results = await asyncio.gather(*search_tasks) + chunks = search_results[0] + chunks_multivector = search_results[1] if len(search_results) > 1 else [] logger.debug(f"Found {len(chunks)} similar chunks via regular embedding") if using_colpali: logger.debug( - f"Found {len(chunks_multivector)} similar chunks via multivector embedding since we are also using colpali" + f"Found {len(chunks_multivector)} similar chunks via multivector embedding " + f"since 
we are also using colpali" ) # Rerank chunks using the standard reranker if enabled and available @@ -273,7 +285,8 @@ class DocumentService: model_name, torch_dtype=torch.bfloat16, device_map=device, # "cuda:0", # or "mps" if on Apple Silicon - attn_implementation="eager", # "flash_attention_2" if is_flash_attn_2_available() else None, # or "eager" if "mps" + attn_implementation="eager", # "flash_attention_2" if is_flash_attn_2_available() else None, + # or "eager" if "mps" ).eval() processor = ColIdefics3Processor.from_pretrained(model_name) @@ -386,15 +399,24 @@ class DocumentService: # Create list of (document_id, chunk_number) tuples for vector store query chunk_identifiers = [(source.document_id, source.chunk_number) for source in authorized_sources] - # Retrieve the chunks from vector store in a single query - chunks = await self.vector_store.get_chunks_by_id(chunk_identifiers) + # Set up vector store retrieval tasks + retrieval_tasks = [self.vector_store.get_chunks_by_id(chunk_identifiers)] - # Check if we should use colpali for image chunks + # Add colpali vector store task if needed if use_colpali and self.colpali_vector_store: - logger.info("Trying to retrieve chunks from colpali vector store") - try: - # Also try to retrieve from the colpali vector store - colpali_chunks = await self.colpali_vector_store.get_chunks_by_id(chunk_identifiers) + logger.info("Preparing to retrieve chunks from both regular and colpali vector stores") + retrieval_tasks.append(self.colpali_vector_store.get_chunks_by_id(chunk_identifiers)) + + # Execute vector store retrievals in parallel + try: + vector_results = await asyncio.gather(*retrieval_tasks, return_exceptions=True) + + # Process regular chunks + chunks = vector_results[0] if not isinstance(vector_results[0], Exception) else [] + + # Process colpali chunks if available + if len(vector_results) > 1 and not isinstance(vector_results[1], Exception): + colpali_chunks = vector_results[1] if colpali_chunks: # Create a dictionary of (doc_id, chunk_number) -> chunk for fast lookup @@ -402,9 +424,6 @@ class DocumentService: logger.debug(f"Found {len(colpali_chunks)} chunks in colpali store") for colpali_chunk in colpali_chunks: - logger.debug( - f"Found colpali chunk: doc={colpali_chunk.document_id}, chunk={colpali_chunk.chunk_number}" - ) key = (colpali_chunk.document_id, colpali_chunk.chunk_number) # Replace chunks with colpali chunks when available chunk_dict[key] = colpali_chunk @@ -412,10 +431,18 @@ class DocumentService: # Update chunks list with the combined/replaced chunks chunks = list(chunk_dict.values()) logger.info(f"Enhanced {len(colpali_chunks)} chunks with colpali/multimodal data") - else: - logger.warning("No chunks found in colpali vector store") - except Exception as e: - logger.error(f"Error retrieving chunks from colpali vector store: {e}", exc_info=True) + + # Handle any exceptions that occurred during retrieval + for i, result in enumerate(vector_results): + if isinstance(result, Exception): + store_type = "regular" if i == 0 else "colpali" + logger.error(f"Error retrieving chunks from {store_type} vector store: {result}", exc_info=True) + if i == 0: # If regular store failed, we can't proceed + return [] + + except Exception as e: + logger.error(f"Error during parallel chunk retrieval: {e}", exc_info=True) + return [] # Convert to chunk results results = await self._create_chunk_results(auth, chunks) @@ -933,138 +960,124 @@ class DocumentService: max_retries = 3 retry_delay = 1.0 - # Store chunks in vector store with retry - 
attempt = 0 - success = False - result = None - - while attempt < max_retries and not success: - try: - success, result = await self.vector_store.store_embeddings(chunk_objects) - if not success: - raise Exception("Failed to store chunk embeddings") - break - except Exception as e: - attempt += 1 - error_msg = str(e) - if "connection was closed" in error_msg or "ConnectionDoesNotExistError" in error_msg: - if attempt < max_retries: - logger.warning( - f"Database connection error during embeddings storage (attempt {attempt}/{max_retries}): {error_msg}. Retrying in {retry_delay}s..." - ) - await asyncio.sleep(retry_delay) - # Increase delay for next retry (exponential backoff) - retry_delay *= 2 - else: - logger.error( - f"All database connection attempts failed after {max_retries} retries: {error_msg}" - ) - raise Exception("Failed to store chunk embeddings after multiple retries") - else: - # For other exceptions, don't retry - logger.error(f"Error storing embeddings: {error_msg}") - raise - - logger.debug("Stored chunk embeddings in vector store") - doc.chunk_ids = result - - if use_colpali and self.colpali_vector_store and chunk_objects_multivector: - # Reset retry variables for colpali storage + # Helper function to store embeddings with retry + async def store_with_retry(store, objects, store_name="regular"): attempt = 0 - retry_delay = 1.0 success = False - result_multivector = None + result = None + current_retry_delay = retry_delay while attempt < max_retries and not success: try: - success, result_multivector = await self.colpali_vector_store.store_embeddings( - chunk_objects_multivector - ) + success, result = await store.store_embeddings(objects) if not success: - raise Exception("Failed to store multivector chunk embeddings") - break + raise Exception(f"Failed to store {store_name} chunk embeddings") + return result except Exception as e: attempt += 1 error_msg = str(e) if "connection was closed" in error_msg or "ConnectionDoesNotExistError" in error_msg: if attempt < max_retries: logger.warning( - f"Database connection error during colpali embeddings storage (attempt {attempt}/{max_retries}): {error_msg}. Retrying in {retry_delay}s..." + f"Database connection error during {store_name} embeddings storage " + f"(attempt {attempt}/{max_retries}): {error_msg}. " + f"Retrying in {current_retry_delay}s..." 
) - await asyncio.sleep(retry_delay) + await asyncio.sleep(current_retry_delay) # Increase delay for next retry (exponential backoff) - retry_delay *= 2 + current_retry_delay *= 2 else: logger.error( - f"All colpali database connection attempts failed after {max_retries} retries: {error_msg}" + f"All {store_name} database connection attempts failed " + f"after {max_retries} retries: {error_msg}" ) - raise Exception("Failed to store multivector chunk embeddings after multiple retries") + raise Exception(f"Failed to store {store_name} chunk embeddings after multiple retries") else: # For other exceptions, don't retry - logger.error(f"Error storing colpali embeddings: {error_msg}") + logger.error(f"Error storing {store_name} embeddings: {error_msg}") raise - logger.debug("Stored multivector chunk embeddings in vector store") - doc.chunk_ids += result_multivector - # Store document metadata with retry - attempt = 0 - retry_delay = 1.0 - success = False + async def store_document_with_retry(): + attempt = 0 + success = False + current_retry_delay = retry_delay - while attempt < max_retries and not success: - try: - if is_update and auth: - # For updates, use update_document, serialize StorageFileInfo into plain dicts - updates = { - "chunk_ids": doc.chunk_ids, - "metadata": doc.metadata, - "system_metadata": doc.system_metadata, - "filename": doc.filename, - "content_type": doc.content_type, - "storage_info": doc.storage_info, - "storage_files": ( - [ - ( - file.model_dump() - if hasattr(file, "model_dump") - else (file.dict() if hasattr(file, "dict") else file) - ) - for file in doc.storage_files - ] - if doc.storage_files - else [] - ), - } - success = await self.db.update_document(doc.external_id, updates, auth) - if not success: - raise Exception("Failed to update document metadata") - else: - # For new documents, use store_document - success = await self.db.store_document(doc) - if not success: - raise Exception("Failed to store document metadata") - break - except Exception as e: - attempt += 1 - error_msg = str(e) - if "connection was closed" in error_msg or "ConnectionDoesNotExistError" in error_msg: - if attempt < max_retries: - logger.warning( - f"Database connection error during document metadata storage (attempt {attempt}/{max_retries}): {error_msg}. Retrying in {retry_delay}s..." 
- ) - await asyncio.sleep(retry_delay) - # Increase delay for next retry (exponential backoff) - retry_delay *= 2 + while attempt < max_retries and not success: + try: + if is_update and auth: + # For updates, use update_document, serialize StorageFileInfo into plain dicts + updates = { + "chunk_ids": doc.chunk_ids, + "metadata": doc.metadata, + "system_metadata": doc.system_metadata, + "filename": doc.filename, + "content_type": doc.content_type, + "storage_info": doc.storage_info, + "storage_files": ( + [ + ( + file.model_dump() + if hasattr(file, "model_dump") + else (file.dict() if hasattr(file, "dict") else file) + ) + for file in doc.storage_files + ] + if doc.storage_files + else [] + ), + } + success = await self.db.update_document(doc.external_id, updates, auth) + if not success: + raise Exception("Failed to update document metadata") else: - logger.error( - f"All database connection attempts failed after {max_retries} retries: {error_msg}" - ) - raise Exception("Failed to store document metadata after multiple retries") - else: - # For other exceptions, don't retry - logger.error(f"Error storing document metadata: {error_msg}") - raise + # For new documents, use store_document + success = await self.db.store_document(doc) + if not success: + raise Exception("Failed to store document metadata") + return success + except Exception as e: + attempt += 1 + error_msg = str(e) + if "connection was closed" in error_msg or "ConnectionDoesNotExistError" in error_msg: + if attempt < max_retries: + logger.warning( + f"Database connection error during document metadata storage " + f"(attempt {attempt}/{max_retries}): {error_msg}. " + f"Retrying in {current_retry_delay}s..." + ) + await asyncio.sleep(current_retry_delay) + # Increase delay for next retry (exponential backoff) + current_retry_delay *= 2 + else: + logger.error( + f"All database connection attempts failed " f"after {max_retries} retries: {error_msg}" + ) + raise Exception("Failed to store document metadata after multiple retries") + else: + # For other exceptions, don't retry + logger.error(f"Error storing document metadata: {error_msg}") + raise + + # Run storage operations in parallel when possible + storage_tasks = [store_with_retry(self.vector_store, chunk_objects, "regular")] + + # Add colpali storage task if needed + if use_colpali and self.colpali_vector_store and chunk_objects_multivector: + storage_tasks.append(store_with_retry(self.colpali_vector_store, chunk_objects_multivector, "colpali")) + + # Execute storage tasks concurrently + storage_results = await asyncio.gather(*storage_tasks) + + # Combine chunk IDs + regular_chunk_ids = storage_results[0] + colpali_chunk_ids = storage_results[1] if len(storage_results) > 1 else [] + doc.chunk_ids = regular_chunk_ids + colpali_chunk_ids + + logger.debug(f"Stored chunk embeddings in vector stores: {len(doc.chunk_ids)} chunks total") + + # Store document metadata (this must be done after chunk storage) + await store_document_with_retry() logger.debug("Stored document metadata in database") logger.debug(f"Chunk IDs stored: {doc.chunk_ids}") @@ -1073,21 +1086,37 @@ class DocumentService: async def _create_chunk_results(self, auth: AuthContext, chunks: List[DocumentChunk]) -> List[ChunkResult]: """Create ChunkResult objects with document metadata.""" results = [] + if not chunks: + logger.info("No chunks provided, returning empty results") + return results + + # Collect all unique document IDs from chunks + unique_doc_ids = list({chunk.document_id for chunk in chunks}) + + # 
Fetch all required documents in a single batch query + docs = await self.batch_retrieve_documents(unique_doc_ids, auth) + + # Create a lookup dictionary of documents by ID + doc_map = {doc.external_id: doc for doc in docs} + logger.debug(f"Retrieved metadata for {len(doc_map)} unique documents in a single batch") + + # Generate download URLs for all documents that have storage info + download_urls = {} + for doc_id, doc in doc_map.items(): + if doc.storage_info: + download_urls[doc_id] = await self.storage.get_download_url( + doc.storage_info["bucket"], doc.storage_info["key"] + ) + logger.debug(f"Generated download URL for document {doc_id}") + + # Create chunk results using the lookup dictionaries for chunk in chunks: - # Get document metadata - doc = await self.db.get_document(chunk.document_id, auth) + doc = doc_map.get(chunk.document_id) if not doc: logger.warning(f"Document {chunk.document_id} not found") continue - logger.debug(f"Retrieved metadata for document {chunk.document_id}") - # Generate download URL if needed - download_url = None - if doc.storage_info: - download_url = await self.storage.get_download_url(doc.storage_info["bucket"], doc.storage_info["key"]) - logger.debug(f"Generated download URL for document {chunk.document_id}") - - metadata = doc.metadata + metadata = doc.metadata.copy() metadata["is_image"] = chunk.metadata.get("is_image", False) results.append( ChunkResult( @@ -1098,7 +1127,7 @@ class DocumentService: metadata=metadata, content_type=doc.content_type, filename=doc.filename, - download_url=download_url, + download_url=download_urls.get(chunk.document_id), ) ) @@ -1107,31 +1136,53 @@ class DocumentService: async def _create_document_results(self, auth: AuthContext, chunks: List[ChunkResult]) -> Dict[str, DocumentResult]: """Group chunks by document and create DocumentResult objects.""" + if not chunks: + logger.info("No chunks provided, returning empty results") + return {} + # Group chunks by document and get highest scoring chunk per doc doc_chunks: Dict[str, ChunkResult] = {} for chunk in chunks: if chunk.document_id not in doc_chunks or chunk.score > doc_chunks[chunk.document_id].score: doc_chunks[chunk.document_id] = chunk logger.info(f"Grouped chunks into {len(doc_chunks)} documents") - logger.debug(f"Document chunks: {doc_chunks}") + + # Get unique document IDs + unique_doc_ids = list(doc_chunks.keys()) + + # Fetch all documents in a single batch query + docs = await self.batch_retrieve_documents(unique_doc_ids, auth) + + # Create a lookup dictionary of documents by ID + doc_map = {doc.external_id: doc for doc in docs} + logger.debug(f"Retrieved metadata for {len(doc_map)} unique documents in a single batch") + + # Generate download URLs for non-text documents in a single loop + download_urls = {} + for doc_id, doc in doc_map.items(): + if doc.content_type != "text/plain" and doc.storage_info: + download_urls[doc_id] = await self.storage.get_download_url( + doc.storage_info["bucket"], doc.storage_info["key"] + ) + logger.debug(f"Generated download URL for document {doc_id}") + + # Create document results using the lookup dictionaries results = {} for doc_id, chunk in doc_chunks.items(): - # Get document metadata - doc = await self.db.get_document(doc_id, auth) + doc = doc_map.get(doc_id) if not doc: logger.warning(f"Document {doc_id} not found") continue - logger.info(f"Retrieved metadata for document {doc_id}") # Create DocumentContent based on content type if doc.content_type == "text/plain": content = DocumentContent(type="string", 
value=chunk.content, filename=None) logger.debug(f"Created text content for document {doc_id}") else: - # Generate download URL for file types - download_url = await self.storage.get_download_url(doc.storage_info["bucket"], doc.storage_info["key"]) - content = DocumentContent(type="url", value=download_url, filename=doc.filename) + # Use pre-generated download URL for file types + content = DocumentContent(type="url", value=download_urls.get(doc_id), filename=doc.filename) logger.debug(f"Created URL content for document {doc_id}") + results[doc_id] = DocumentResult( score=chunk.score, document_id=doc_id, @@ -1287,7 +1338,8 @@ class DocumentService: else: updated_content = self._apply_update_strategy(current_content, update_content, update_strategy) logger.info( - f"Applied update strategy '{update_strategy}': original length={len(current_content)}, new length={len(updated_content)}" + f"Applied update strategy '{update_strategy}': original length={len(current_content)}, " + f"new length={len(updated_content)}" ) # Always update the content in system_metadata @@ -1532,9 +1584,7 @@ class DocumentService: else ( StorageFileInfo(**file.model_dump()) if hasattr(file, "model_dump") - else StorageFileInfo(**file.dict()) - if hasattr(file, "dict") - else file + else StorageFileInfo(**file.dict()) if hasattr(file, "dict") else file ) ) ) @@ -1735,53 +1785,56 @@ class DocumentService: logger.info(f"Deleted document {document_id} from database") - # Try to delete chunks from vector store if they exist + # Collect storage deletion tasks + storage_deletion_tasks = [] + + # Collect vector store deletion tasks + vector_deletion_tasks = [] + + # Add vector store deletion tasks if chunks exist if hasattr(document, "chunk_ids") and document.chunk_ids: - try: - # Try to delete chunks by document ID - # Note: Some vector stores may not implement this method - if hasattr(self.vector_store, "delete_chunks_by_document_id"): - await self.vector_store.delete_chunks_by_document_id(document_id) - logger.info(f"Deleted chunks for document {document_id} from vector store") - else: - logger.warning("Vector store does not support deleting chunks by document ID") + # Try to delete chunks by document ID + # Note: Some vector stores may not implement this method + if hasattr(self.vector_store, "delete_chunks_by_document_id"): + vector_deletion_tasks.append(self.vector_store.delete_chunks_by_document_id(document_id)) - # Try to delete from colpali vector store as well - if self.colpali_vector_store and hasattr(self.colpali_vector_store, "delete_chunks_by_document_id"): - await self.colpali_vector_store.delete_chunks_by_document_id(document_id) - logger.info(f"Deleted chunks for document {document_id} from colpali vector store") - except Exception as e: - logger.error(f"Error deleting chunks for document {document_id}: {e}") - # We continue even if chunk deletion fails - don't block document deletion + # Try to delete from colpali vector store as well + if self.colpali_vector_store and hasattr(self.colpali_vector_store, "delete_chunks_by_document_id"): + vector_deletion_tasks.append(self.colpali_vector_store.delete_chunks_by_document_id(document_id)) - # Delete file from storage if it exists + # Collect storage file deletion tasks if hasattr(document, "storage_info") and document.storage_info: - try: - bucket = document.storage_info.get("bucket") - key = document.storage_info.get("key") - if bucket and key: - # Check if the storage provider supports deletion - if hasattr(self.storage, "delete_file"): - await 
self.storage.delete_file(bucket, key) - logger.info( - f"Deleted file for document {document_id} from storage (bucket: {bucket}, key: {key})" - ) - else: - logger.warning("Storage provider does not support file deletion") + bucket = document.storage_info.get("bucket") + key = document.storage_info.get("key") + if bucket and key and hasattr(self.storage, "delete_file"): + storage_deletion_tasks.append(self.storage.delete_file(bucket, key)) + + # Also handle the case of multiple file versions in storage_files + if hasattr(document, "storage_files") and document.storage_files: + for file_info in document.storage_files: + bucket = file_info.get("bucket") + key = file_info.get("key") + if bucket and key and hasattr(self.storage, "delete_file"): + storage_deletion_tasks.append(self.storage.delete_file(bucket, key)) + + # Execute deletion tasks in parallel + if vector_deletion_tasks or storage_deletion_tasks: + try: + # Run all deletion tasks concurrently + all_deletion_results = await asyncio.gather( + *vector_deletion_tasks, *storage_deletion_tasks, return_exceptions=True + ) + + # Log any errors but continue with deletion + for i, result in enumerate(all_deletion_results): + if isinstance(result, Exception): + # Determine if this was a vector store or storage deletion + task_type = "vector store" if i < len(vector_deletion_tasks) else "storage" + logger.error(f"Error during {task_type} deletion for document {document_id}: {result}") - # Also handle the case of multiple file versions in storage_files - if hasattr(document, "storage_files") and document.storage_files: - for file_info in document.storage_files: - bucket = file_info.get("bucket") - key = file_info.get("key") - if bucket and key and hasattr(self.storage, "delete_file"): - await self.storage.delete_file(bucket, key) - logger.info( - f"Deleted file version for document {document_id} from storage (bucket: {bucket}, key: {key})" - ) except Exception as e: - logger.error(f"Error deleting file for document {document_id}: {e}") - # We continue even if file deletion fails - don't block document deletion + logger.error(f"Error during parallel deletion operations for document {document_id}: {e}") + # We continue even if deletions fail - document is already deleted from DB logger.info(f"Successfully deleted document {document_id} and all associated data") return True
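
A minimal, self-contained sketch of the nested asyncio.gather pattern introduced in the query path above, where the embed_for_query calls and find_authorized_and_filtered_documents now run concurrently instead of back-to-back. The coroutine names below are illustrative stand-ins for the real embedding model and database methods, not the actual DocumentService API; Python 3.9+ is assumed for the builtin generic annotations.

import asyncio


async def embed_regular(query: str) -> list[float]:
    await asyncio.sleep(0.1)  # stands in for a network call to the embedding model
    return [0.1, 0.2, 0.3]


async def embed_multivector(query: str) -> list[list[float]]:
    await asyncio.sleep(0.1)  # stands in for the optional colpali embedding call
    return [[0.1, 0.2], [0.3, 0.4]]


async def find_authorized_doc_ids(filters: dict) -> list[str]:
    await asyncio.sleep(0.1)  # stands in for the authorization query against the database
    return ["doc-1", "doc-2"]


async def main(query: str, use_colpali: bool) -> None:
    embedding_tasks = [embed_regular(query)]
    if use_colpali:
        embedding_tasks.append(embed_multivector(query))

    # Outer gather runs the embedding batch and the auth lookup side by side;
    # the inner gather keeps the one or two embedding results in a single list.
    embedding_results, doc_ids = await asyncio.gather(
        asyncio.gather(*embedding_tasks),
        find_authorized_doc_ids({"folder_name": "demo"}),
    )

    query_embedding = embedding_results[0]
    multivector_embedding = embedding_results[1] if len(embedding_results) > 1 else None
    print(len(query_embedding), multivector_embedding is not None, doc_ids)


asyncio.run(main("what changed in Q3?", use_colpali=True))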
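
A small sketch of the return_exceptions handling used for the parallel get_chunks_by_id retrievals above: the regular vector store is treated as required, the colpali store as best-effort, and a failure in one does not cancel the other. The fetch coroutines here are hypothetical placeholders, not the real vector-store methods.

import asyncio
import logging

logger = logging.getLogger(__name__)


async def fetch_regular_chunks(ids: list[tuple[str, int]]) -> list[str]:
    await asyncio.sleep(0.05)
    return [f"regular:{doc}:{n}" for doc, n in ids]


async def fetch_colpali_chunks(ids: list[tuple[str, int]]) -> list[str]:
    await asyncio.sleep(0.05)
    raise RuntimeError("colpali store unavailable")  # simulate a partial failure


async def retrieve(ids: list[tuple[str, int]], use_colpali: bool) -> list[str]:
    tasks = [fetch_regular_chunks(ids)]
    if use_colpali:
        tasks.append(fetch_colpali_chunks(ids))

    # return_exceptions=True turns individual failures into result values
    # instead of raising, so the successful retrieval is still usable.
    results = await asyncio.gather(*tasks, return_exceptions=True)

    # The regular store is mandatory: bail out if it failed.
    if isinstance(results[0], Exception):
        logger.error("regular vector store failed: %s", results[0])
        return []
    chunks = results[0]

    # The colpali store is best-effort: merge its chunks only on success.
    if len(results) > 1 and not isinstance(results[1], Exception):
        chunks = results[1] + chunks
    elif len(results) > 1:
        logger.error("colpali vector store failed: %s", results[1])
    return chunks


print(asyncio.run(retrieve([("doc-1", 0), ("doc-1", 1)], use_colpali=True)))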
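
A sketch of the store_with_retry refactor: one retry helper with exponential backoff shared by the regular and colpali stores, with both storage calls awaited through a single asyncio.gather. The _FakeStore class and its fail-once behaviour are invented for the demo; only the control flow mirrors the patch.

import asyncio
import logging

logger = logging.getLogger(__name__)

MAX_RETRIES = 3
INITIAL_RETRY_DELAY = 1.0  # seconds, doubled after each transient failure


async def store_with_retry(store, objects, store_name: str = "regular"):
    attempt = 0
    delay = INITIAL_RETRY_DELAY
    while attempt < MAX_RETRIES:
        try:
            ok, chunk_ids = await store.store_embeddings(objects)
            if not ok:
                raise Exception(f"Failed to store {store_name} chunk embeddings")
            return chunk_ids
        except Exception as exc:
            attempt += 1
            transient = "connection was closed" in str(exc)
            if transient and attempt < MAX_RETRIES:
                logger.warning(
                    "%s store attempt %d/%d failed: %s; retrying in %.1fs",
                    store_name, attempt, MAX_RETRIES, exc, delay,
                )
                await asyncio.sleep(delay)
                delay *= 2  # exponential backoff
            else:
                # Non-transient errors and exhausted retries are not retried.
                raise


async def store_all(regular_store, colpali_store, chunks, colpali_chunks):
    tasks = [store_with_retry(regular_store, chunks, "regular")]
    if colpali_store is not None and colpali_chunks:
        tasks.append(store_with_retry(colpali_store, colpali_chunks, "colpali"))
    results = await asyncio.gather(*tasks)
    # Concatenate regular and (optional) colpali chunk IDs, as the patch does.
    return [cid for ids in results for cid in ids]


class _FakeStore:
    """Toy stand-in for a vector store; fails once, then succeeds."""

    def __init__(self):
        self.calls = 0

    async def store_embeddings(self, objects):
        self.calls += 1
        if self.calls == 1:
            raise Exception("connection was closed")
        return True, [f"chunk-{i}" for i in range(len(objects))]


print(asyncio.run(store_all(_FakeStore(), _FakeStore(), ["a", "b"], ["c"])))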
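
A sketch of the batching change in _create_chunk_results and _create_document_results: collect the unique document IDs, fetch their metadata once with a batch call, and build results from an in-memory map, copying metadata per chunk so per-chunk flags do not leak into the shared document object. The dataclasses and the dict-backed "database" are stand-ins for the real models and batch_retrieve_documents.

import asyncio
from dataclasses import dataclass


@dataclass
class Chunk:
    document_id: str
    content: str


@dataclass
class Doc:
    external_id: str
    metadata: dict


FAKE_DB = {
    "doc-1": Doc("doc-1", {"title": "first"}),
    "doc-2": Doc("doc-2", {"title": "second"}),
}


async def batch_retrieve_documents(doc_ids: list[str]) -> list[Doc]:
    await asyncio.sleep(0.05)  # one round trip for the whole batch, not one per chunk
    return [FAKE_DB[d] for d in doc_ids if d in FAKE_DB]


async def create_chunk_results(chunks: list[Chunk]) -> list[dict]:
    if not chunks:
        return []
    unique_doc_ids = list({c.document_id for c in chunks})
    docs = await batch_retrieve_documents(unique_doc_ids)
    doc_map = {d.external_id: d for d in docs}

    results = []
    for chunk in chunks:
        doc = doc_map.get(chunk.document_id)
        if doc is None:
            continue  # unauthorized or deleted document: skip, as the service does
        metadata = dict(doc.metadata)  # copy before adding chunk-specific keys
        metadata["is_image"] = False
        results.append({"content": chunk.content, "metadata": metadata})
    return results


chunks = [Chunk("doc-1", "a"), Chunk("doc-1", "b"), Chunk("doc-2", "c")]
print(asyncio.run(create_chunk_results(chunks)))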
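
A sketch of the parallel deletion at the end of the patch: vector-store and storage deletions are collected into task lists and awaited together with return_exceptions=True, and failures are logged rather than re-raised because the database row has already been removed. The delete coroutines below are placeholders, not the real vector store or storage clients.

import asyncio
import logging

logger = logging.getLogger(__name__)


async def delete_chunks(document_id: str) -> None:
    await asyncio.sleep(0.05)  # stands in for delete_chunks_by_document_id


async def delete_file(bucket: str, key: str) -> None:
    await asyncio.sleep(0.05)
    raise RuntimeError("object already gone")  # simulate a storage failure


async def delete_document_artifacts(document_id: str) -> None:
    vector_tasks = [delete_chunks(document_id)]
    storage_tasks = [delete_file("docs", f"{document_id}.pdf")]

    results = await asyncio.gather(*vector_tasks, *storage_tasks, return_exceptions=True)
    for i, result in enumerate(results):
        if isinstance(result, Exception):
            # Position in the gather tells us which kind of task failed.
            task_type = "vector store" if i < len(vector_tasks) else "storage"
            logger.error("%s deletion failed for %s: %s", task_type, document_id, result)
    # Failures are logged but not re-raised: document deletion still succeeds.


asyncio.run(delete_document_artifacts("doc-1"))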