From 12bf22419111bb9a7daf5f4652193faf4c9d265f Mon Sep 17 00:00:00 2001 From: Adityavardhan Agrawal Date: Wed, 30 Apr 2025 19:42:43 -0700 Subject: [PATCH] Add batching for gpus and performance logging (#124) --- core/embedding/colpali_embedding_model.py | 212 ++++++++++++++++++++-- core/services/document_service.py | 4 +- core/workers/ingestion_worker.py | 83 +++++++++ scripts/migrate_multivector_embeddings.py | 73 ++++++++ sdks/python/pyproject.toml | 2 +- 5 files changed, 351 insertions(+), 23 deletions(-) create mode 100644 scripts/migrate_multivector_embeddings.py diff --git a/core/embedding/colpali_embedding_model.py b/core/embedding/colpali_embedding_model.py index 1b6ac93..56564ff 100644 --- a/core/embedding/colpali_embedding_model.py +++ b/core/embedding/colpali_embedding_model.py @@ -1,14 +1,16 @@ import base64 import io import logging -from typing import List, Union +import time +from typing import List, Tuple, Union import numpy as np import torch -from colpali_engine.models import ColQwen2, ColQwen2Processor +from colpali_engine.models import ColQwen2_5, ColQwen2_5_Processor from PIL.Image import Image from PIL.Image import open as open_image +from core.config import get_settings from core.embedding.base_embedding_model import BaseEmbeddingModel from core.models.chunk import Chunk @@ -18,51 +20,221 @@ logger = logging.getLogger(__name__) class ColpaliEmbeddingModel(BaseEmbeddingModel): def __init__(self): device = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu" - self.model = ColQwen2.from_pretrained( - "vidore/colqwen2-v1.0", + logger.info(f"Initializing ColpaliEmbeddingModel with device: {device}") + start_time = time.time() + self.model = ColQwen2_5.from_pretrained( + "tsystems/colqwen2.5-3b-multilingual-v1.0", torch_dtype=torch.bfloat16, device_map=device, # Automatically detect and use available device attn_implementation="flash_attention_2" if device == "cuda" else "eager", ).eval() - self.processor: ColQwen2Processor = ColQwen2Processor.from_pretrained("vidore/colqwen2-v1.0") + self.processor: ColQwen2_5_Processor = ColQwen2_5_Processor.from_pretrained( + "tsystems/colqwen2.5-3b-multilingual-v1.0" + ) + self.settings = get_settings() + self.mode = self.settings.MODE + # Set batch size based on mode + self.batch_size = 8 if self.mode == "cloud" else 1 + logger.info(f"Colpali running in mode: {self.mode} with batch size: {self.batch_size}") + total_init_time = time.time() - start_time + logger.info(f"Colpali initialization time: {total_init_time:.2f} seconds") async def embed_for_ingestion(self, chunks: Union[Chunk, List[Chunk]]) -> List[np.ndarray]: + job_start_time = time.time() if isinstance(chunks, Chunk): chunks = [chunks] - contents = [] - for chunk in chunks: + if not chunks: + return [] + + logger.info( + f"Processing {len(chunks)} chunks for Colpali embedding in {self.mode} mode (batch size: {self.batch_size})" + ) + + image_items: List[Tuple[int, Image]] = [] + text_items: List[Tuple[int, str]] = [] + sorting_start = time.time() + + for index, chunk in enumerate(chunks): if chunk.metadata.get("is_image"): try: - # Handle data URI format "data:image/png;base64,..." 
content = chunk.content if content.startswith("data:"): - # Extract the base64 part after the comma content = content.split(",", 1)[1] - - # Now decode the base64 string image_bytes = base64.b64decode(content) image = open_image(io.BytesIO(image_bytes)) - contents.append(image) + image_items.append((index, image)) except Exception as e: - logger.error(f"Error processing image: {str(e)}") - # Fall back to using the content as text - contents.append(chunk.content) + logger.error(f"Error processing image chunk {index}: {str(e)}. Falling back to text.") + text_items.append((index, chunk.content)) # Fallback: treat content as text else: - contents.append(chunk.content) + text_items.append((index, chunk.content)) - return [await self.generate_embeddings(content) for content in contents] + sorting_time = time.time() - sorting_start + logger.info( + f"Chunk sorting took {sorting_time:.2f}s - " + f"Found {len(image_items)} images and {len(text_items)} text chunks" + ) + + # Initialize results array to preserve order + results: List[np.ndarray | None] = [None] * len(chunks) + + # Process image batches + if image_items: + img_start = time.time() + indices_to_process = [item[0] for item in image_items] + images_to_process = [item[1] for item in image_items] + for i in range(0, len(images_to_process), self.batch_size): + batch_indices = indices_to_process[i : i + self.batch_size] + batch_images = images_to_process[i : i + self.batch_size] + logger.debug( + f"Processing image batch {i//self.batch_size + 1}/" + f"{(len(images_to_process)-1)//self.batch_size + 1} with {len(batch_images)} images" + ) + batch_start = time.time() + batch_embeddings = await self.generate_embeddings_batch_images(batch_images) + # Place embeddings in the correct position in results + for original_index, embedding in zip(batch_indices, batch_embeddings): + results[original_index] = embedding + batch_time = time.time() - batch_start + logger.debug( + f"Image batch {i//self.batch_size + 1} processing took {batch_time:.2f}s " + f"({batch_time/len(batch_images):.2f}s per image)" + ) + img_time = time.time() - img_start + logger.info(f"All image embedding took {img_time:.2f}s ({img_time/len(images_to_process):.2f}s per image)") + + # Process text batches + if text_items: + text_start = time.time() + indices_to_process = [item[0] for item in text_items] + texts_to_process = [item[1] for item in text_items] + for i in range(0, len(texts_to_process), self.batch_size): + batch_indices = indices_to_process[i : i + self.batch_size] + batch_texts = texts_to_process[i : i + self.batch_size] + logger.debug( + f"Processing text batch {i//self.batch_size + 1}/" + f"{(len(texts_to_process)-1)//self.batch_size + 1} with {len(batch_texts)} texts" + ) + batch_start = time.time() + batch_embeddings = await self.generate_embeddings_batch_texts(batch_texts) + # Place embeddings in the correct position in results + for original_index, embedding in zip(batch_indices, batch_embeddings): + results[original_index] = embedding + batch_time = time.time() - batch_start + logger.debug( + f"Text batch {i//self.batch_size + 1} processing took {batch_time:.2f}s " + f"({batch_time/len(batch_texts):.2f}s per text)" + ) + text_time = time.time() - text_start + logger.info(f"All text embedding took {text_time:.2f}s ({text_time/len(texts_to_process):.2f}s per text)") + + # Ensure all chunks were processed (handle potential None entries if errors occurred, + # though unlikely with fallback) + final_results = [res for res in results if res is not None] + if 
len(final_results) != len(chunks): + logger.warning( + f"Number of embeddings ({len(final_results)}) does not match number of chunks " + f"({len(chunks)}). Some chunks might have failed." + ) + # Fill potential gaps if necessary, though the current logic should cover all chunks + # For safety, let's reconstruct based on successfully processed indices, though it shouldn't be needed + processed_indices = {idx for idx, _ in image_items} | {idx for idx, _ in text_items} + if len(processed_indices) != len(chunks): + logger.error("Mismatch in processed indices vs original chunks count. This indicates a logic error.") + # Assuming results contains embeddings at correct original indices, filter out Nones + final_results = [results[i] for i in range(len(chunks)) if results[i] is not None] + + total_time = time.time() - job_start_time + logger.info( + f"Total Colpali embed_for_ingestion took {total_time:.2f}s for {len(chunks)} chunks " + f"({total_time/len(chunks) if chunks else 0:.2f}s per chunk)" + ) + # Cast is safe because we filter out Nones, though Nones shouldn't occur with the fallback logic + return final_results # type: ignore async def embed_for_query(self, text: str) -> torch.Tensor: - return await self.generate_embeddings(text) + start_time = time.time() + result = await self.generate_embeddings(text) + elapsed = time.time() - start_time + logger.info(f"Colpali query embedding took {elapsed:.2f}s") + return result - async def generate_embeddings(self, content: str | Image) -> np.ndarray: + async def generate_embeddings(self, content: Union[str, Image]) -> np.ndarray: + start_time = time.time() + content_type = "image" if isinstance(content, Image) else "text" + process_start = time.time() if isinstance(content, Image): processed = self.processor.process_images([content]).to(self.model.device) else: processed = self.processor.process_queries([content]).to(self.model.device) + process_time = time.time() - process_start + + model_start = time.time() + with torch.no_grad(): embeddings: torch.Tensor = self.model(**processed) - return embeddings.to(torch.float32).numpy(force=True)[0] + model_time = time.time() - model_start + + convert_start = time.time() + + result = embeddings.to(torch.float32).numpy(force=True)[0] + + convert_time = time.time() - convert_start + + total_time = time.time() - start_time + logger.debug( + f"Generate embeddings ({content_type}): process={process_time:.2f}s, model={model_time:.2f}s, " + f"convert={convert_time:.2f}s, total={total_time:.2f}s" + ) + return result + + # ---- Batch processing methods (only used in 'cloud' mode) ---- + + async def generate_embeddings_batch_images(self, images: List[Image]) -> List[np.ndarray]: + batch_start_time = time.time() + process_start = time.time() + processed_images = self.processor.process_images(images).to(self.model.device) + process_time = time.time() - process_start + + model_start = time.time() + with torch.no_grad(): + image_embeddings = self.model(**processed_images) + model_time = time.time() - model_start + + convert_start = time.time() + image_embeddings_np = image_embeddings.to(torch.float32).numpy(force=True) + result = [emb for emb in image_embeddings_np] + convert_time = time.time() - convert_start + + total_batch_time = time.time() - batch_start_time + logger.debug( + f"Batch images ({len(images)}): process={process_time:.2f}s, model={model_time:.2f}s, " + f"convert={convert_time:.2f}s, total={total_batch_time:.2f}s ({total_batch_time/len(images):.3f}s/image)" + ) + return result + + async def 
generate_embeddings_batch_texts(self, texts: List[str]) -> List[np.ndarray]: + batch_start_time = time.time() + process_start = time.time() + processed_texts = self.processor.process_queries(texts).to(self.model.device) + process_time = time.time() - process_start + + model_start = time.time() + with torch.no_grad(): + text_embeddings = self.model(**processed_texts) + model_time = time.time() - model_start + + convert_start = time.time() + text_embeddings_np = text_embeddings.to(torch.float32).numpy(force=True) + result = [emb for emb in text_embeddings_np] + convert_time = time.time() - convert_start + + total_batch_time = time.time() - batch_start_time + logger.debug( + f"Batch texts ({len(texts)}): process={process_time:.2f}s, model={model_time:.2f}s, " + f"convert={convert_time:.2f}s, total={total_batch_time:.2f}s ({total_batch_time/len(texts):.3f}s/text)" + ) + return result diff --git a/core/services/document_service.py b/core/services/document_service.py index f652b65..97bac92 100644 --- a/core/services/document_service.py +++ b/core/services/document_service.py @@ -1778,8 +1778,8 @@ class DocumentService: # Also handle the case of multiple file versions in storage_files if hasattr(document, "storage_files") and document.storage_files: for file_info in document.storage_files: - bucket = file_info.get("bucket") - key = file_info.get("key") + bucket = file_info.bucket + key = file_info.key if bucket and key and hasattr(self.storage, "delete_file"): storage_deletion_tasks.append(self.storage.delete_file(bucket, key)) diff --git a/core/workers/ingestion_worker.py b/core/workers/ingestion_worker.py index d1b4bc7..8d833ea 100644 --- a/core/workers/ingestion_worker.py +++ b/core/workers/ingestion_worker.py @@ -2,6 +2,7 @@ import asyncio import json import logging import os +import time import urllib.parse as up from datetime import UTC, datetime from typing import Any, Dict, List, Optional @@ -26,6 +27,19 @@ from core.vector_store.pgvector_store import PGVectorStore logger = logging.getLogger(__name__) +# Configure logger for ingestion worker (restored from diff) +settings = get_settings() # Need settings for log level potentially, though INFO used here + +# Create logs directory if it doesn't exist +os.makedirs("logs", exist_ok=True) + +# Set up file handler for worker_ingestion.log +file_handler = logging.FileHandler("logs/worker_ingestion.log") +file_handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")) +logger.addHandler(file_handler) +# Set logger level based on settings (diff used INFO directly) +logger.setLevel(logging.INFO) + async def get_document_with_retry(document_service, document_id, auth, max_retries=3, initial_delay=0.3): """ @@ -116,10 +130,14 @@ async def process_ingestion_job( A dictionary with the document ID and processing status """ try: + # Start performance timer + job_start_time = time.time() + phase_times = {} # 1. Log the start of the job logger.info(f"Starting ingestion job for file: {original_filename}") # 2. 
Deserialize metadata and auth + deserialize_start = time.time() metadata = json.loads(metadata_json) if metadata_json else {} auth = AuthContext( entity_type=EntityType(auth_dict.get("entity_type", "unknown")), @@ -128,23 +146,32 @@ async def process_ingestion_job( permissions=set(auth_dict.get("permissions", ["read"])), user_id=auth_dict.get("user_id", auth_dict.get("entity_id", "")), ) + phase_times["deserialize_auth"] = time.time() - deserialize_start # Get document service from the context document_service: DocumentService = ctx["document_service"] # 3. Download the file from storage logger.info(f"Downloading file from {bucket}/{file_key}") + download_start = time.time() file_content = await document_service.storage.download_file(bucket, file_key) # Ensure file_content is bytes if hasattr(file_content, "read"): file_content = file_content.read() + download_time = time.time() - download_start + phase_times["download_file"] = download_time + logger.info(f"File download took {download_time:.2f}s for {len(file_content)/1024/1024:.2f}MB") # 4. Parse file to text + parse_start = time.time() additional_metadata, text = await document_service.parser.parse_file_to_text(file_content, original_filename) logger.debug(f"Parsed file into text of length {len(text)}") + parse_time = time.time() - parse_start + phase_times["parse_file"] = parse_time # === Apply post_parsing rules === + rules_start = time.time() document_rule_metadata = {} if rules_list: logger.info("Applying post-parsing rules...") @@ -154,8 +181,13 @@ async def process_ingestion_job( metadata.update(document_rule_metadata) # Merge metadata into main doc metadata logger.info(f"Document metadata after post-parsing rules: {metadata}") logger.info(f"Content length after post-parsing rules: {len(text)}") + rules_time = time.time() - rules_start + phase_times["apply_post_parsing_rules"] = rules_time + if rules_list: + logger.info(f"Post-parsing rules processing took {rules_time:.2f}s") # 6. Retrieve the existing document + retrieve_start = time.time() logger.debug(f"Retrieving document with ID: {document_id}") logger.debug( f"Auth context: entity_type={auth.entity_type}, entity_id={auth.entity_id}, permissions={auth.permissions}" @@ -163,6 +195,9 @@ async def process_ingestion_job( # Use the retry helper function with initial delay to handle race conditions doc = await get_document_with_retry(document_service, document_id, auth, max_retries=5, initial_delay=1.0) + retrieve_time = time.time() - retrieve_start + phase_times["retrieve_document"] = retrieve_time + logger.info(f"Document retrieval took {retrieve_time:.2f}s") if not doc: logger.error(f"Document {document_id} not found in database after multiple retries") @@ -193,7 +228,11 @@ async def process_ingestion_job( updates["system_metadata"]["end_user_id"] = end_user_id # Update the document in the database + update_start = time.time() success = await document_service.db.update_document(document_id=document_id, updates=updates, auth=auth) + update_time = time.time() - update_start + phase_times["update_document_parsed"] = update_time + logger.info(f"Initial document update took {update_time:.2f}s") if not success: raise ValueError(f"Failed to update document {document_id}") @@ -203,13 +242,18 @@ async def process_ingestion_job( logger.debug("Updated document in database with parsed content") # 7. 
Split text into chunks + chunking_start = time.time() parsed_chunks = await document_service.parser.split_text(text) if not parsed_chunks: raise ValueError("No content chunks extracted after rules processing") logger.debug(f"Split processed text into {len(parsed_chunks)} chunks") + chunking_time = time.time() - chunking_start + phase_times["split_into_chunks"] = chunking_time + logger.info(f"Text chunking took {chunking_time:.2f}s to create {len(parsed_chunks)} chunks") # 8. Handle ColPali embeddings if enabled - IMPORTANT: Do this BEFORE applying chunk rules # so that image chunks can be processed by rules when use_images=True + colpali_processing_start = time.time() using_colpali = ( use_colpali and document_service.colpali_embedding_model and document_service.colpali_vector_store ) @@ -227,6 +271,10 @@ async def process_ingestion_job( file_type, file_content_base64, file_content, parsed_chunks ) logger.debug(f"Created {len(chunks_multivector)} chunks for multivector embedding") + colpali_create_chunks_time = time.time() - colpali_processing_start + phase_times["colpali_create_chunks"] = colpali_create_chunks_time + if using_colpali: + logger.info(f"Colpali chunk creation took {colpali_create_chunks_time:.2f}s") # 9. Apply post_chunking rules and aggregate metadata processed_chunks = [] @@ -302,14 +350,27 @@ async def process_ingestion_job( processed_chunks_multivector = chunks_multivector # No rules, use original multivector chunks # 10. Generate embeddings for processed chunks + embedding_start = time.time() embeddings = await document_service.embedding_model.embed_for_ingestion(processed_chunks) logger.debug(f"Generated {len(embeddings)} embeddings") + embedding_time = time.time() - embedding_start + phase_times["generate_embeddings"] = embedding_time + embeddings_per_second = len(embeddings) / embedding_time if embedding_time > 0 else 0 + logger.info( + f"Embedding generation took {embedding_time:.2f}s for {len(embeddings)} embeddings " + f"({embeddings_per_second:.2f} embeddings/s)" + ) # 11. Create chunk objects with potentially modified chunk content and metadata + chunk_objects_start = time.time() chunk_objects = document_service._create_chunk_objects(doc.external_id, processed_chunks, embeddings) logger.debug(f"Created {len(chunk_objects)} chunk objects") + chunk_objects_time = time.time() - chunk_objects_start + phase_times["create_chunk_objects"] = chunk_objects_time + logger.debug(f"Creating chunk objects took {chunk_objects_time:.2f}s") # 12. Handle ColPali embeddings + colpali_embed_start = time.time() chunk_objects_multivector = [] if using_colpali: colpali_embeddings = await document_service.colpali_embedding_model.embed_for_ingestion( @@ -320,6 +381,14 @@ async def process_ingestion_job( chunk_objects_multivector = document_service._create_chunk_objects( doc.external_id, processed_chunks_multivector, colpali_embeddings ) + colpali_embed_time = time.time() - colpali_embed_start + phase_times["colpali_generate_embeddings"] = colpali_embed_time + if using_colpali: + embeddings_per_second = len(colpali_embeddings) / colpali_embed_time if colpali_embed_time > 0 else 0 + logger.info( + f"Colpali embedding took {colpali_embed_time:.2f}s for {len(colpali_embeddings)} embeddings " + f"({embeddings_per_second:.2f} embeddings/s)" + ) # === Merge aggregated chunk metadata into document metadata === if aggregated_chunk_metadata: @@ -336,14 +405,28 @@ async def process_ingestion_job( doc.system_metadata["updated_at"] = datetime.now(UTC) # 11. 
Store chunks and update document with is_update=True + store_start = time.time() await document_service._store_chunks_and_doc( chunk_objects, doc, use_colpali, chunk_objects_multivector, is_update=True, auth=auth ) + store_time = time.time() - store_start + phase_times["store_chunks_and_update_doc"] = store_time + logger.info(f"Storing chunks and final document update took {store_time:.2f}s for {len(chunk_objects)} chunks") logger.debug(f"Successfully completed processing for document {doc.external_id}") # 13. Log successful completion logger.info(f"Successfully completed ingestion for {original_filename}, document ID: {doc.external_id}") + # Performance summary + total_time = time.time() - job_start_time + + # Log performance summary + logger.info("=== Ingestion Performance Summary ===") + logger.info(f"Total processing time: {total_time:.2f}s") + for phase, duration in sorted(phase_times.items(), key=lambda x: x[1], reverse=True): + percentage = (duration / total_time) * 100 if total_time > 0 else 0 + logger.info(f" - {phase}: {duration:.2f}s ({percentage:.1f}%)") + logger.info("=====================================") # 14. Return document ID return { diff --git a/scripts/migrate_multivector_embeddings.py b/scripts/migrate_multivector_embeddings.py new file mode 100644 index 0000000..f137631 --- /dev/null +++ b/scripts/migrate_multivector_embeddings.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python3 +# Usage: cd $(dirname "$0")/.. && PYTHONPATH=. python3 $0 + +import asyncio +import json + +import psycopg +from pgvector.psycopg import register_vector + +from core.config import get_settings +from core.embedding.colpali_embedding_model import ColpaliEmbeddingModel +from core.models.chunk import Chunk +from core.vector_store.multi_vector_store import MultiVectorStore + + +async def migrate_multivector_embeddings(): + settings = get_settings() + uri = settings.POSTGRES_URI + # Convert SQLAlchemy URI to psycopg format if needed + if uri.startswith("postgresql+asyncpg://"): + uri = uri.replace("postgresql+asyncpg://", "postgresql://") + mv_store = MultiVectorStore(uri) + if not mv_store.initialize(): + print("Failed to initialize MultiVectorStore") + return + + embedding_model = ColpaliEmbeddingModel() + + conn = psycopg.connect(uri, autocommit=True) + register_vector(conn) + cursor = conn.cursor() + + cursor.execute("SELECT id, document_id, chunk_number, content, chunk_metadata " "FROM multi_vector_embeddings") + rows = cursor.fetchall() + total = len(rows) + print(f"Found {total} multivector records to migrate...") + + for idx, (row_id, doc_id, chunk_num, content, meta_json) in enumerate(rows, start=1): + try: + # Parse metadata (JSON preferred, fallback to Python literal) + try: + metadata = json.loads(meta_json) if meta_json else {} + except json.JSONDecodeError: + import ast + + try: + metadata = ast.literal_eval(meta_json) + except Exception as exc: + print(f"Warning: failed to parse metadata for row {row_id}: {exc}") + metadata = {} + + # Create a chunk and recompute its multivector embedding + chunk = Chunk(content=content, metadata=metadata) + vectors = await embedding_model.embed_for_ingestion([chunk]) + vector = vectors[0] + bits = mv_store._binary_quantize(vector) + + # Update the embeddings in-place + cursor.execute( + "UPDATE multi_vector_embeddings SET embeddings = %s WHERE id = %s", + (bits, row_id), + ) + print(f"[{idx}/{total}] Updated doc={doc_id} chunk={chunk_num}") + except Exception as e: + print(f"Error migrating row {row_id} (doc={doc_id}, chunk={chunk_num}): {e}") + + 
cursor.close() + conn.close() + print("Migration complete.") + + +if __name__ == "__main__": + asyncio.run(migrate_multivector_embeddings()) diff --git a/sdks/python/pyproject.toml b/sdks/python/pyproject.toml index 401cbc4..b687961 100644 --- a/sdks/python/pyproject.toml +++ b/sdks/python/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "morphik" -version = "0.1.4" +version = "0.1.5" authors = [ { name = "Morphik", email = "founders@morphik.ai" }, ]
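
The core of the change in core/embedding/colpali_embedding_model.py is the order-preserving batching in embed_for_ingestion: chunks are split into image and text groups, embedded in fixed-size batches, and the embeddings are written back to each chunk's original position. Below is a minimal, model-free sketch of that pattern; embed_image_batch and embed_text_batch are hypothetical stand-ins for generate_embeddings_batch_images / generate_embeddings_batch_texts, and the stand-in payloads are illustrative only.

    from typing import Callable, List, Sequence, Tuple


    def batched_embed(
        chunks: Sequence[Tuple[str, object]],          # (kind, payload); kind is "image" or "text"
        embed_image_batch: Callable[[List[object]], List[object]],
        embed_text_batch: Callable[[List[object]], List[object]],
        batch_size: int = 8,                           # 8 in "cloud" mode, 1 otherwise in the patch
    ) -> List[object]:
        # Partition by type, remembering each chunk's original position.
        image_items = [(i, p) for i, (k, p) in enumerate(chunks) if k == "image"]
        text_items = [(i, p) for i, (k, p) in enumerate(chunks) if k != "image"]

        results: List[object] = [None] * len(chunks)

        for items, embed_fn in ((image_items, embed_image_batch), (text_items, embed_text_batch)):
            for start in range(0, len(items), batch_size):
                batch = items[start : start + batch_size]
                embeddings = embed_fn([payload for _, payload in batch])
                # Write each embedding back to the slot of the chunk it came from.
                for (original_index, _), emb in zip(batch, embeddings):
                    results[original_index] = emb

        return results


    if __name__ == "__main__":
        out = batched_embed(
            [("text", "a"), ("image", b"\x89PNG"), ("text", "b")],
            embed_image_batch=lambda xs: [f"img:{len(x)}" for x in xs],
            embed_text_batch=lambda xs: [f"txt:{x}" for x in xs],
            batch_size=2,
        )
        print(out)  # embeddings come back in the original chunk order

Because results is indexed by the original chunk position, interleaved image and text chunks keep their ordering even though they are embedded in separate batched passes.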
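
In core/workers/ingestion_worker.py the patch records per-phase durations with inline time.time() calls into a phase_times dict and prints a summary sorted by duration. The same bookkeeping can be expressed as a small context-manager helper; PhaseTimer below is a hypothetical refactor for illustration, not code from the patch.

    import logging
    import time
    from contextlib import contextmanager

    logger = logging.getLogger(__name__)


    class PhaseTimer:
        def __init__(self) -> None:
            self.phase_times: dict[str, float] = {}
            self._job_start = time.time()

        @contextmanager
        def phase(self, name: str):
            # Time the wrapped block and record it under `name`.
            start = time.time()
            try:
                yield
            finally:
                self.phase_times[name] = time.time() - start

        def log_summary(self) -> None:
            total = time.time() - self._job_start
            logger.info("=== Ingestion Performance Summary ===")
            logger.info(f"Total processing time: {total:.2f}s")
            for phase, duration in sorted(self.phase_times.items(), key=lambda x: x[1], reverse=True):
                pct = (duration / total) * 100 if total > 0 else 0
                logger.info(f"  - {phase}: {duration:.2f}s ({pct:.1f}%)")

Usage would mirror the worker's phases, e.g. `with timer.phase("download_file"): ...` for each step, then `timer.log_summary()` once the job finishes, which produces the same kind of summary the patch logs.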
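
One detail worth noting: the worker attaches a FileHandler for logs/worker_ingestion.log at module import time. If the module is ever imported more than once in the same process (a possibility, not something the patch states), the handler would be added repeatedly and each record would be written multiple times. A guarded variant, sketched here under that assumption, only adds the handler when it is not already present.

    import logging
    import os

    logger = logging.getLogger(__name__)

    os.makedirs("logs", exist_ok=True)
    LOG_PATH = "logs/worker_ingestion.log"

    # Attach the file handler only once, even if this module is imported again.
    already_attached = any(
        isinstance(h, logging.FileHandler) and getattr(h, "baseFilename", "").endswith("worker_ingestion.log")
        for h in logger.handlers
    )
    if not already_attached:
        handler = logging.FileHandler(LOG_PATH)
        handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))
        logger.addHandler(handler)
    logger.setLevel(logging.INFO)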