Add batching for gpus and performance logging (#124)

Adityavardhan Agrawal 2025-04-30 19:42:43 -07:00 committed by GitHub
parent c6ec5cc9fb
commit 12bf224191
5 changed files with 351 additions and 23 deletions

View File

@@ -1,14 +1,16 @@
import base64
import io
import logging
from typing import List, Union
import time
from typing import List, Tuple, Union
import numpy as np
import torch
from colpali_engine.models import ColQwen2, ColQwen2Processor
from colpali_engine.models import ColQwen2_5, ColQwen2_5_Processor
from PIL.Image import Image
from PIL.Image import open as open_image
from core.config import get_settings
from core.embedding.base_embedding_model import BaseEmbeddingModel
from core.models.chunk import Chunk
@@ -18,51 +20,221 @@ logger = logging.getLogger(__name__)
class ColpaliEmbeddingModel(BaseEmbeddingModel):
def __init__(self):
device = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"
self.model = ColQwen2.from_pretrained(
"vidore/colqwen2-v1.0",
logger.info(f"Initializing ColpaliEmbeddingModel with device: {device}")
start_time = time.time()
self.model = ColQwen2_5.from_pretrained(
"tsystems/colqwen2.5-3b-multilingual-v1.0",
torch_dtype=torch.bfloat16,
device_map=device, # Automatically detect and use available device
attn_implementation="flash_attention_2" if device == "cuda" else "eager",
).eval()
self.processor: ColQwen2Processor = ColQwen2Processor.from_pretrained("vidore/colqwen2-v1.0")
self.processor: ColQwen2_5_Processor = ColQwen2_5_Processor.from_pretrained(
"tsystems/colqwen2.5-3b-multilingual-v1.0"
)
self.settings = get_settings()
self.mode = self.settings.MODE
# Set batch size based on mode
self.batch_size = 8 if self.mode == "cloud" else 1
logger.info(f"Colpali running in mode: {self.mode} with batch size: {self.batch_size}")
total_init_time = time.time() - start_time
logger.info(f"Colpali initialization time: {total_init_time:.2f} seconds")
async def embed_for_ingestion(self, chunks: Union[Chunk, List[Chunk]]) -> List[np.ndarray]:
job_start_time = time.time()
if isinstance(chunks, Chunk):
chunks = [chunks]
contents = []
for chunk in chunks:
if not chunks:
return []
logger.info(
f"Processing {len(chunks)} chunks for Colpali embedding in {self.mode} mode (batch size: {self.batch_size})"
)
image_items: List[Tuple[int, Image]] = []
text_items: List[Tuple[int, str]] = []
sorting_start = time.time()
for index, chunk in enumerate(chunks):
if chunk.metadata.get("is_image"):
try:
# Handle data URI format "data:image/png;base64,..."
content = chunk.content
if content.startswith("data:"):
# Extract the base64 part after the comma
content = content.split(",", 1)[1]
# Now decode the base64 string
image_bytes = base64.b64decode(content)
image = open_image(io.BytesIO(image_bytes))
contents.append(image)
image_items.append((index, image))
except Exception as e:
logger.error(f"Error processing image: {str(e)}")
# Fall back to using the content as text
contents.append(chunk.content)
logger.error(f"Error processing image chunk {index}: {str(e)}. Falling back to text.")
text_items.append((index, chunk.content)) # Fallback: treat content as text
else:
contents.append(chunk.content)
text_items.append((index, chunk.content))
return [await self.generate_embeddings(content) for content in contents]
sorting_time = time.time() - sorting_start
logger.info(
f"Chunk sorting took {sorting_time:.2f}s - "
f"Found {len(image_items)} images and {len(text_items)} text chunks"
)
# Initialize results array to preserve order
results: List[np.ndarray | None] = [None] * len(chunks)
# Process image batches
if image_items:
img_start = time.time()
indices_to_process = [item[0] for item in image_items]
images_to_process = [item[1] for item in image_items]
for i in range(0, len(images_to_process), self.batch_size):
batch_indices = indices_to_process[i : i + self.batch_size]
batch_images = images_to_process[i : i + self.batch_size]
logger.debug(
f"Processing image batch {i//self.batch_size + 1}/"
f"{(len(images_to_process)-1)//self.batch_size + 1} with {len(batch_images)} images"
)
batch_start = time.time()
batch_embeddings = await self.generate_embeddings_batch_images(batch_images)
# Place embeddings in the correct position in results
for original_index, embedding in zip(batch_indices, batch_embeddings):
results[original_index] = embedding
batch_time = time.time() - batch_start
logger.debug(
f"Image batch {i//self.batch_size + 1} processing took {batch_time:.2f}s "
f"({batch_time/len(batch_images):.2f}s per image)"
)
img_time = time.time() - img_start
logger.info(f"All image embedding took {img_time:.2f}s ({img_time/len(images_to_process):.2f}s per image)")
# Process text batches
if text_items:
text_start = time.time()
indices_to_process = [item[0] for item in text_items]
texts_to_process = [item[1] for item in text_items]
for i in range(0, len(texts_to_process), self.batch_size):
batch_indices = indices_to_process[i : i + self.batch_size]
batch_texts = texts_to_process[i : i + self.batch_size]
logger.debug(
f"Processing text batch {i//self.batch_size + 1}/"
f"{(len(texts_to_process)-1)//self.batch_size + 1} with {len(batch_texts)} texts"
)
batch_start = time.time()
batch_embeddings = await self.generate_embeddings_batch_texts(batch_texts)
# Place embeddings in the correct position in results
for original_index, embedding in zip(batch_indices, batch_embeddings):
results[original_index] = embedding
batch_time = time.time() - batch_start
logger.debug(
f"Text batch {i//self.batch_size + 1} processing took {batch_time:.2f}s "
f"({batch_time/len(batch_texts):.2f}s per text)"
)
text_time = time.time() - text_start
logger.info(f"All text embedding took {text_time:.2f}s ({text_time/len(texts_to_process):.2f}s per text)")
# Ensure all chunks were processed (handle potential None entries if errors occurred,
# though unlikely with fallback)
final_results = [res for res in results if res is not None]
if len(final_results) != len(chunks):
logger.warning(
f"Number of embeddings ({len(final_results)}) does not match number of chunks "
f"({len(chunks)}). Some chunks might have failed."
)
# Fill potential gaps if necessary, though the current logic should cover all chunks
# For safety, let's reconstruct based on successfully processed indices, though it shouldn't be needed
processed_indices = {idx for idx, _ in image_items} | {idx for idx, _ in text_items}
if len(processed_indices) != len(chunks):
logger.error("Mismatch in processed indices vs original chunks count. This indicates a logic error.")
# Assuming results contains embeddings at correct original indices, filter out Nones
final_results = [results[i] for i in range(len(chunks)) if results[i] is not None]
total_time = time.time() - job_start_time
logger.info(
f"Total Colpali embed_for_ingestion took {total_time:.2f}s for {len(chunks)} chunks "
f"({total_time/len(chunks) if chunks else 0:.2f}s per chunk)"
)
# Cast is safe because we filter out Nones, though Nones shouldn't occur with the fallback logic
return final_results # type: ignore
async def embed_for_query(self, text: str) -> torch.Tensor:
return await self.generate_embeddings(text)
start_time = time.time()
result = await self.generate_embeddings(text)
elapsed = time.time() - start_time
logger.info(f"Colpali query embedding took {elapsed:.2f}s")
return result
async def generate_embeddings(self, content: str | Image) -> np.ndarray:
async def generate_embeddings(self, content: Union[str, Image]) -> np.ndarray:
start_time = time.time()
content_type = "image" if isinstance(content, Image) else "text"
process_start = time.time()
if isinstance(content, Image):
processed = self.processor.process_images([content]).to(self.model.device)
else:
processed = self.processor.process_queries([content]).to(self.model.device)
process_time = time.time() - process_start
model_start = time.time()
with torch.no_grad():
embeddings: torch.Tensor = self.model(**processed)
return embeddings.to(torch.float32).numpy(force=True)[0]
model_time = time.time() - model_start
convert_start = time.time()
result = embeddings.to(torch.float32).numpy(force=True)[0]
convert_time = time.time() - convert_start
total_time = time.time() - start_time
logger.debug(
f"Generate embeddings ({content_type}): process={process_time:.2f}s, model={model_time:.2f}s, "
f"convert={convert_time:.2f}s, total={total_time:.2f}s"
)
return result
# ---- Batch processing methods (only used in 'cloud' mode) ----
async def generate_embeddings_batch_images(self, images: List[Image]) -> List[np.ndarray]:
batch_start_time = time.time()
process_start = time.time()
processed_images = self.processor.process_images(images).to(self.model.device)
process_time = time.time() - process_start
model_start = time.time()
with torch.no_grad():
image_embeddings = self.model(**processed_images)
model_time = time.time() - model_start
convert_start = time.time()
image_embeddings_np = image_embeddings.to(torch.float32).numpy(force=True)
result = [emb for emb in image_embeddings_np]
convert_time = time.time() - convert_start
total_batch_time = time.time() - batch_start_time
logger.debug(
f"Batch images ({len(images)}): process={process_time:.2f}s, model={model_time:.2f}s, "
f"convert={convert_time:.2f}s, total={total_batch_time:.2f}s ({total_batch_time/len(images):.3f}s/image)"
)
return result
async def generate_embeddings_batch_texts(self, texts: List[str]) -> List[np.ndarray]:
batch_start_time = time.time()
process_start = time.time()
processed_texts = self.processor.process_queries(texts).to(self.model.device)
process_time = time.time() - process_start
model_start = time.time()
with torch.no_grad():
text_embeddings = self.model(**processed_texts)
model_time = time.time() - model_start
convert_start = time.time()
text_embeddings_np = text_embeddings.to(torch.float32).numpy(force=True)
result = [emb for emb in text_embeddings_np]
convert_time = time.time() - convert_start
total_batch_time = time.time() - batch_start_time
logger.debug(
f"Batch texts ({len(texts)}): process={process_time:.2f}s, model={model_time:.2f}s, "
f"convert={convert_time:.2f}s, total={total_batch_time:.2f}s ({total_batch_time/len(texts):.3f}s/text)"
)
return result
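For reference, the batching added above follows an index-preserving pattern: chunks are partitioned by type (image vs. text), embedded in fixed-size batches, and each embedding is written back at the chunk's original position. A minimal standalone sketch of that pattern, with a toy embed function and batch size standing in for the actual model calls (not part of this commit):

from typing import Callable, List, Tuple

def embed_in_batches(
    items: List[Tuple[int, str]],              # (original_index, payload) pairs
    embed_batch: Callable[[List[str]], List[list]],
    results: List,                             # pre-sized results list shared across types
    batch_size: int = 8,
) -> None:
    """Embed payloads in batches and place each result at its original index."""
    indices = [i for i, _ in items]
    payloads = [p for _, p in items]
    for start in range(0, len(payloads), batch_size):
        batch_idx = indices[start : start + batch_size]
        batch_payloads = payloads[start : start + batch_size]
        for original_index, emb in zip(batch_idx, embed_batch(batch_payloads)):
            results[original_index] = emb

# Toy usage: interleaved "text" and "image" items, output order preserved.
chunks = ["t0", "i1", "t2", "i3", "t4"]
text_items = [(i, c) for i, c in enumerate(chunks) if c.startswith("t")]
image_items = [(i, c) for i, c in enumerate(chunks) if c.startswith("i")]
results = [None] * len(chunks)
fake_embed = lambda batch: [[len(x)] for x in batch]  # stand-in for the model call
embed_in_batches(text_items, fake_embed, results, batch_size=2)
embed_in_batches(image_items, fake_embed, results, batch_size=2)
assert all(r is not None for r in results)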

View File

@@ -1778,8 +1778,8 @@ class DocumentService:
# Also handle the case of multiple file versions in storage_files
if hasattr(document, "storage_files") and document.storage_files:
for file_info in document.storage_files:
bucket = file_info.get("bucket")
key = file_info.get("key")
bucket = file_info.bucket
key = file_info.key
if bucket and key and hasattr(self.storage, "delete_file"):
storage_deletion_tasks.append(self.storage.delete_file(bucket, key))
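This change switches storage_files entries from dict-style .get() access to attribute access, which implies they are typed objects (for example Pydantic models) rather than plain dicts. A hypothetical minimal shape that would satisfy the new code; the real model name and field values are not shown in this diff:

from pydantic import BaseModel

class StorageFileInfo(BaseModel):  # hypothetical name; the actual model isn't shown here
    bucket: str = ""
    key: str = ""

file_info = StorageFileInfo(bucket="documents", key="docs/abc123/v1.pdf")  # illustrative values
bucket, key = file_info.bucket, file_info.key  # attribute access as used above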

View File

@@ -2,6 +2,7 @@ import asyncio
import json
import logging
import os
import time
import urllib.parse as up
from datetime import UTC, datetime
from typing import Any, Dict, List, Optional
@@ -26,6 +27,19 @@ from core.vector_store.pgvector_store import PGVectorStore
logger = logging.getLogger(__name__)
# Configure logger for ingestion worker (restored from diff)
settings = get_settings() # Need settings for log level potentially, though INFO used here
# Create logs directory if it doesn't exist
os.makedirs("logs", exist_ok=True)
# Set up file handler for worker_ingestion.log
file_handler = logging.FileHandler("logs/worker_ingestion.log")
file_handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))
logger.addHandler(file_handler)
# Set logger level based on settings (diff used INFO directly)
logger.setLevel(logging.INFO)
async def get_document_with_retry(document_service, document_id, auth, max_retries=3, initial_delay=0.3):
"""
@@ -116,10 +130,14 @@ async def process_ingestion_job(
A dictionary with the document ID and processing status
"""
try:
# Start performance timer
job_start_time = time.time()
phase_times = {}
# 1. Log the start of the job
logger.info(f"Starting ingestion job for file: {original_filename}")
# 2. Deserialize metadata and auth
deserialize_start = time.time()
metadata = json.loads(metadata_json) if metadata_json else {}
auth = AuthContext(
entity_type=EntityType(auth_dict.get("entity_type", "unknown")),
@@ -128,23 +146,32 @@
permissions=set(auth_dict.get("permissions", ["read"])),
user_id=auth_dict.get("user_id", auth_dict.get("entity_id", "")),
)
phase_times["deserialize_auth"] = time.time() - deserialize_start
# Get document service from the context
document_service: DocumentService = ctx["document_service"]
# 3. Download the file from storage
logger.info(f"Downloading file from {bucket}/{file_key}")
download_start = time.time()
file_content = await document_service.storage.download_file(bucket, file_key)
# Ensure file_content is bytes
if hasattr(file_content, "read"):
file_content = file_content.read()
download_time = time.time() - download_start
phase_times["download_file"] = download_time
logger.info(f"File download took {download_time:.2f}s for {len(file_content)/1024/1024:.2f}MB")
# 4. Parse file to text
parse_start = time.time()
additional_metadata, text = await document_service.parser.parse_file_to_text(file_content, original_filename)
logger.debug(f"Parsed file into text of length {len(text)}")
parse_time = time.time() - parse_start
phase_times["parse_file"] = parse_time
# === Apply post_parsing rules ===
rules_start = time.time()
document_rule_metadata = {}
if rules_list:
logger.info("Applying post-parsing rules...")
@@ -154,8 +181,13 @@
metadata.update(document_rule_metadata) # Merge metadata into main doc metadata
logger.info(f"Document metadata after post-parsing rules: {metadata}")
logger.info(f"Content length after post-parsing rules: {len(text)}")
rules_time = time.time() - rules_start
phase_times["apply_post_parsing_rules"] = rules_time
if rules_list:
logger.info(f"Post-parsing rules processing took {rules_time:.2f}s")
# 6. Retrieve the existing document
retrieve_start = time.time()
logger.debug(f"Retrieving document with ID: {document_id}")
logger.debug(
f"Auth context: entity_type={auth.entity_type}, entity_id={auth.entity_id}, permissions={auth.permissions}"
@@ -163,6 +195,9 @@
# Use the retry helper function with initial delay to handle race conditions
doc = await get_document_with_retry(document_service, document_id, auth, max_retries=5, initial_delay=1.0)
retrieve_time = time.time() - retrieve_start
phase_times["retrieve_document"] = retrieve_time
logger.info(f"Document retrieval took {retrieve_time:.2f}s")
if not doc:
logger.error(f"Document {document_id} not found in database after multiple retries")
@@ -193,7 +228,11 @@
updates["system_metadata"]["end_user_id"] = end_user_id
# Update the document in the database
update_start = time.time()
success = await document_service.db.update_document(document_id=document_id, updates=updates, auth=auth)
update_time = time.time() - update_start
phase_times["update_document_parsed"] = update_time
logger.info(f"Initial document update took {update_time:.2f}s")
if not success:
raise ValueError(f"Failed to update document {document_id}")
@@ -203,13 +242,18 @@
logger.debug("Updated document in database with parsed content")
# 7. Split text into chunks
chunking_start = time.time()
parsed_chunks = await document_service.parser.split_text(text)
if not parsed_chunks:
raise ValueError("No content chunks extracted after rules processing")
logger.debug(f"Split processed text into {len(parsed_chunks)} chunks")
chunking_time = time.time() - chunking_start
phase_times["split_into_chunks"] = chunking_time
logger.info(f"Text chunking took {chunking_time:.2f}s to create {len(parsed_chunks)} chunks")
# 8. Handle ColPali embeddings if enabled - IMPORTANT: Do this BEFORE applying chunk rules
# so that image chunks can be processed by rules when use_images=True
colpali_processing_start = time.time()
using_colpali = (
use_colpali and document_service.colpali_embedding_model and document_service.colpali_vector_store
)
@@ -227,6 +271,10 @@
file_type, file_content_base64, file_content, parsed_chunks
)
logger.debug(f"Created {len(chunks_multivector)} chunks for multivector embedding")
colpali_create_chunks_time = time.time() - colpali_processing_start
phase_times["colpali_create_chunks"] = colpali_create_chunks_time
if using_colpali:
logger.info(f"Colpali chunk creation took {colpali_create_chunks_time:.2f}s")
# 9. Apply post_chunking rules and aggregate metadata
processed_chunks = []
@@ -302,14 +350,27 @@
processed_chunks_multivector = chunks_multivector # No rules, use original multivector chunks
# 10. Generate embeddings for processed chunks
embedding_start = time.time()
embeddings = await document_service.embedding_model.embed_for_ingestion(processed_chunks)
logger.debug(f"Generated {len(embeddings)} embeddings")
embedding_time = time.time() - embedding_start
phase_times["generate_embeddings"] = embedding_time
embeddings_per_second = len(embeddings) / embedding_time if embedding_time > 0 else 0
logger.info(
f"Embedding generation took {embedding_time:.2f}s for {len(embeddings)} embeddings "
f"({embeddings_per_second:.2f} embeddings/s)"
)
# 11. Create chunk objects with potentially modified chunk content and metadata
chunk_objects_start = time.time()
chunk_objects = document_service._create_chunk_objects(doc.external_id, processed_chunks, embeddings)
logger.debug(f"Created {len(chunk_objects)} chunk objects")
chunk_objects_time = time.time() - chunk_objects_start
phase_times["create_chunk_objects"] = chunk_objects_time
logger.debug(f"Creating chunk objects took {chunk_objects_time:.2f}s")
# 12. Handle ColPali embeddings
colpali_embed_start = time.time()
chunk_objects_multivector = []
if using_colpali:
colpali_embeddings = await document_service.colpali_embedding_model.embed_for_ingestion(
@@ -320,6 +381,14 @@
chunk_objects_multivector = document_service._create_chunk_objects(
doc.external_id, processed_chunks_multivector, colpali_embeddings
)
colpali_embed_time = time.time() - colpali_embed_start
phase_times["colpali_generate_embeddings"] = colpali_embed_time
if using_colpali:
embeddings_per_second = len(colpali_embeddings) / colpali_embed_time if colpali_embed_time > 0 else 0
logger.info(
f"Colpali embedding took {colpali_embed_time:.2f}s for {len(colpali_embeddings)} embeddings "
f"({embeddings_per_second:.2f} embeddings/s)"
)
# === Merge aggregated chunk metadata into document metadata ===
if aggregated_chunk_metadata:
@@ -336,14 +405,28 @@
doc.system_metadata["updated_at"] = datetime.now(UTC)
# 11. Store chunks and update document with is_update=True
store_start = time.time()
await document_service._store_chunks_and_doc(
chunk_objects, doc, use_colpali, chunk_objects_multivector, is_update=True, auth=auth
)
store_time = time.time() - store_start
phase_times["store_chunks_and_update_doc"] = store_time
logger.info(f"Storing chunks and final document update took {store_time:.2f}s for {len(chunk_objects)} chunks")
logger.debug(f"Successfully completed processing for document {doc.external_id}")
# 13. Log successful completion
logger.info(f"Successfully completed ingestion for {original_filename}, document ID: {doc.external_id}")
# Performance summary
total_time = time.time() - job_start_time
# Log performance summary
logger.info("=== Ingestion Performance Summary ===")
logger.info(f"Total processing time: {total_time:.2f}s")
for phase, duration in sorted(phase_times.items(), key=lambda x: x[1], reverse=True):
percentage = (duration / total_time) * 100 if total_time > 0 else 0
logger.info(f" - {phase}: {duration:.2f}s ({percentage:.1f}%)")
logger.info("=====================================")
# 14. Return document ID
return {
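The worker above accumulates per-phase durations in a phase_times dict and logs a sorted summary at the end of the job. A small sketch of the same pattern wrapped in a reusable context manager (hypothetical helper, not part of this commit):

import logging
import time
from contextlib import contextmanager
from typing import Dict

logger = logging.getLogger(__name__)
phase_times: Dict[str, float] = {}

@contextmanager
def timed_phase(name: str):
    """Record the wall-clock duration of a phase under the given name."""
    start = time.time()
    try:
        yield
    finally:
        phase_times[name] = time.time() - start

# Usage mirroring the worker's summary logging:
with timed_phase("download_file"):
    time.sleep(0.01)  # stand-in for the actual download
with timed_phase("parse_file"):
    time.sleep(0.02)  # stand-in for parsing

total = sum(phase_times.values())
for phase, duration in sorted(phase_times.items(), key=lambda x: x[1], reverse=True):
    pct = (duration / total) * 100 if total > 0 else 0
    logger.info(f" - {phase}: {duration:.2f}s ({pct:.1f}%)")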

View File

@@ -0,0 +1,73 @@
#!/usr/bin/env python3
# Usage: cd $(dirname "$0")/.. && PYTHONPATH=. python3 $0
import asyncio
import json
import psycopg
from pgvector.psycopg import register_vector
from core.config import get_settings
from core.embedding.colpali_embedding_model import ColpaliEmbeddingModel
from core.models.chunk import Chunk
from core.vector_store.multi_vector_store import MultiVectorStore
async def migrate_multivector_embeddings():
settings = get_settings()
uri = settings.POSTGRES_URI
# Convert SQLAlchemy URI to psycopg format if needed
if uri.startswith("postgresql+asyncpg://"):
uri = uri.replace("postgresql+asyncpg://", "postgresql://")
mv_store = MultiVectorStore(uri)
if not mv_store.initialize():
print("Failed to initialize MultiVectorStore")
return
embedding_model = ColpaliEmbeddingModel()
conn = psycopg.connect(uri, autocommit=True)
register_vector(conn)
cursor = conn.cursor()
cursor.execute("SELECT id, document_id, chunk_number, content, chunk_metadata " "FROM multi_vector_embeddings")
rows = cursor.fetchall()
total = len(rows)
print(f"Found {total} multivector records to migrate...")
for idx, (row_id, doc_id, chunk_num, content, meta_json) in enumerate(rows, start=1):
try:
# Parse metadata (JSON preferred, fallback to Python literal)
try:
metadata = json.loads(meta_json) if meta_json else {}
except json.JSONDecodeError:
import ast
try:
metadata = ast.literal_eval(meta_json)
except Exception as exc:
print(f"Warning: failed to parse metadata for row {row_id}: {exc}")
metadata = {}
# Create a chunk and recompute its multivector embedding
chunk = Chunk(content=content, metadata=metadata)
vectors = await embedding_model.embed_for_ingestion([chunk])
vector = vectors[0]
bits = mv_store._binary_quantize(vector)
# Update the embeddings in-place
cursor.execute(
"UPDATE multi_vector_embeddings SET embeddings = %s WHERE id = %s",
(bits, row_id),
)
print(f"[{idx}/{total}] Updated doc={doc_id} chunk={chunk_num}")
except Exception as e:
print(f"Error migrating row {row_id} (doc={doc_id}, chunk={chunk_num}): {e}")
cursor.close()
conn.close()
print("Migration complete.")
if __name__ == "__main__":
asyncio.run(migrate_multivector_embeddings())
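The migration loop above re-embeds one chunk per call. Since this commit adds a batched embedding path, the same migration could group rows and embed them together; a hedged sketch of that variant, reusing the script's imports and objects and assuming all rows fit in memory (metadata parsing omitted for brevity; not part of this commit):

async def migrate_in_batches(rows, embedding_model, mv_store, cursor, batch_size=8):
    """Re-embed and binary-quantize rows in batches, preserving row order."""
    for start in range(0, len(rows), batch_size):
        batch = rows[start : start + batch_size]
        # Metadata parsing from the per-row loop above is omitted here for brevity.
        chunks = [Chunk(content=content, metadata={}) for (_, _, _, content, _) in batch]
        vectors = await embedding_model.embed_for_ingestion(chunks)  # one embedding per chunk, in order
        for (row_id, _, _, _, _), vector in zip(batch, vectors):
            bits = mv_store._binary_quantize(vector)
            cursor.execute(
                "UPDATE multi_vector_embeddings SET embeddings = %s WHERE id = %s",
                (bits, row_id),
            )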

View File

@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
[project]
name = "morphik"
version = "0.1.4"
version = "0.1.5"
authors = [
{ name = "Morphik", email = "founders@morphik.ai" },
]