Add batching for GPUs and performance logging (#124)
parent c6ec5cc9fb
commit 12bf224191
@@ -1,14 +1,16 @@
import base64
import io
import logging
from typing import List, Union
import time
from typing import List, Tuple, Union

import numpy as np
import torch
from colpali_engine.models import ColQwen2, ColQwen2Processor
from colpali_engine.models import ColQwen2_5, ColQwen2_5_Processor
from PIL.Image import Image
from PIL.Image import open as open_image

from core.config import get_settings
from core.embedding.base_embedding_model import BaseEmbeddingModel
from core.models.chunk import Chunk

@@ -18,51 +20,221 @@ logger = logging.getLogger(__name__)
class ColpaliEmbeddingModel(BaseEmbeddingModel):
    def __init__(self):
        device = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"
        self.model = ColQwen2.from_pretrained(
            "vidore/colqwen2-v1.0",
        logger.info(f"Initializing ColpaliEmbeddingModel with device: {device}")
        start_time = time.time()
        self.model = ColQwen2_5.from_pretrained(
            "tsystems/colqwen2.5-3b-multilingual-v1.0",
            torch_dtype=torch.bfloat16,
            device_map=device,  # Automatically detect and use available device
            attn_implementation="flash_attention_2" if device == "cuda" else "eager",
        ).eval()
        self.processor: ColQwen2Processor = ColQwen2Processor.from_pretrained("vidore/colqwen2-v1.0")
        self.processor: ColQwen2_5_Processor = ColQwen2_5_Processor.from_pretrained(
            "tsystems/colqwen2.5-3b-multilingual-v1.0"
        )
        self.settings = get_settings()
        self.mode = self.settings.MODE
        # Set batch size based on mode
        self.batch_size = 8 if self.mode == "cloud" else 1
        logger.info(f"Colpali running in mode: {self.mode} with batch size: {self.batch_size}")
        total_init_time = time.time() - start_time
        logger.info(f"Colpali initialization time: {total_init_time:.2f} seconds")

    async def embed_for_ingestion(self, chunks: Union[Chunk, List[Chunk]]) -> List[np.ndarray]:
        job_start_time = time.time()
        if isinstance(chunks, Chunk):
            chunks = [chunks]

        contents = []
        for chunk in chunks:
        if not chunks:
            return []

        logger.info(
            f"Processing {len(chunks)} chunks for Colpali embedding in {self.mode} mode (batch size: {self.batch_size})"
        )

        image_items: List[Tuple[int, Image]] = []
        text_items: List[Tuple[int, str]] = []
        sorting_start = time.time()

        for index, chunk in enumerate(chunks):
            if chunk.metadata.get("is_image"):
                try:
                    # Handle data URI format "data:image/png;base64,..."
                    content = chunk.content
                    if content.startswith("data:"):
                        # Extract the base64 part after the comma
                        content = content.split(",", 1)[1]

                    # Now decode the base64 string
                    image_bytes = base64.b64decode(content)
                    image = open_image(io.BytesIO(image_bytes))
                    contents.append(image)
                    image_items.append((index, image))
                except Exception as e:
                    logger.error(f"Error processing image: {str(e)}")
                    # Fall back to using the content as text
                    contents.append(chunk.content)
                    logger.error(f"Error processing image chunk {index}: {str(e)}. Falling back to text.")
                    text_items.append((index, chunk.content))  # Fallback: treat content as text
            else:
                contents.append(chunk.content)
                text_items.append((index, chunk.content))

        return [await self.generate_embeddings(content) for content in contents]
        sorting_time = time.time() - sorting_start
        logger.info(
            f"Chunk sorting took {sorting_time:.2f}s - "
            f"Found {len(image_items)} images and {len(text_items)} text chunks"
        )

        # Initialize results array to preserve order
        results: List[np.ndarray | None] = [None] * len(chunks)

        # Process image batches
        if image_items:
            img_start = time.time()
            indices_to_process = [item[0] for item in image_items]
            images_to_process = [item[1] for item in image_items]
            for i in range(0, len(images_to_process), self.batch_size):
                batch_indices = indices_to_process[i : i + self.batch_size]
                batch_images = images_to_process[i : i + self.batch_size]
                logger.debug(
                    f"Processing image batch {i//self.batch_size + 1}/"
                    f"{(len(images_to_process)-1)//self.batch_size + 1} with {len(batch_images)} images"
                )
                batch_start = time.time()
                batch_embeddings = await self.generate_embeddings_batch_images(batch_images)
                # Place embeddings in the correct position in results
                for original_index, embedding in zip(batch_indices, batch_embeddings):
                    results[original_index] = embedding
                batch_time = time.time() - batch_start
                logger.debug(
                    f"Image batch {i//self.batch_size + 1} processing took {batch_time:.2f}s "
                    f"({batch_time/len(batch_images):.2f}s per image)"
                )
            img_time = time.time() - img_start
            logger.info(f"All image embedding took {img_time:.2f}s ({img_time/len(images_to_process):.2f}s per image)")

        # Process text batches
        if text_items:
            text_start = time.time()
            indices_to_process = [item[0] for item in text_items]
            texts_to_process = [item[1] for item in text_items]
            for i in range(0, len(texts_to_process), self.batch_size):
                batch_indices = indices_to_process[i : i + self.batch_size]
                batch_texts = texts_to_process[i : i + self.batch_size]
                logger.debug(
                    f"Processing text batch {i//self.batch_size + 1}/"
                    f"{(len(texts_to_process)-1)//self.batch_size + 1} with {len(batch_texts)} texts"
                )
                batch_start = time.time()
                batch_embeddings = await self.generate_embeddings_batch_texts(batch_texts)
                # Place embeddings in the correct position in results
                for original_index, embedding in zip(batch_indices, batch_embeddings):
                    results[original_index] = embedding
                batch_time = time.time() - batch_start
                logger.debug(
                    f"Text batch {i//self.batch_size + 1} processing took {batch_time:.2f}s "
                    f"({batch_time/len(batch_texts):.2f}s per text)"
                )
            text_time = time.time() - text_start
            logger.info(f"All text embedding took {text_time:.2f}s ({text_time/len(texts_to_process):.2f}s per text)")

        # Ensure all chunks were processed (handle potential None entries if errors occurred,
        # though unlikely with fallback)
        final_results = [res for res in results if res is not None]
        if len(final_results) != len(chunks):
            logger.warning(
                f"Number of embeddings ({len(final_results)}) does not match number of chunks "
                f"({len(chunks)}). Some chunks might have failed."
            )
            # Fill potential gaps if necessary, though the current logic should cover all chunks
            # For safety, let's reconstruct based on successfully processed indices, though it shouldn't be needed
            processed_indices = {idx for idx, _ in image_items} | {idx for idx, _ in text_items}
            if len(processed_indices) != len(chunks):
                logger.error("Mismatch in processed indices vs original chunks count. This indicates a logic error.")
            # Assuming results contains embeddings at correct original indices, filter out Nones
            final_results = [results[i] for i in range(len(chunks)) if results[i] is not None]

        total_time = time.time() - job_start_time
        logger.info(
            f"Total Colpali embed_for_ingestion took {total_time:.2f}s for {len(chunks)} chunks "
            f"({total_time/len(chunks) if chunks else 0:.2f}s per chunk)"
        )
        # Cast is safe because we filter out Nones, though Nones shouldn't occur with the fallback logic
        return final_results  # type: ignore

    async def embed_for_query(self, text: str) -> torch.Tensor:
        return await self.generate_embeddings(text)
        start_time = time.time()
        result = await self.generate_embeddings(text)
        elapsed = time.time() - start_time
        logger.info(f"Colpali query embedding took {elapsed:.2f}s")
        return result

    async def generate_embeddings(self, content: str | Image) -> np.ndarray:
    async def generate_embeddings(self, content: Union[str, Image]) -> np.ndarray:
        start_time = time.time()
        content_type = "image" if isinstance(content, Image) else "text"
        process_start = time.time()
        if isinstance(content, Image):
            processed = self.processor.process_images([content]).to(self.model.device)
        else:
            processed = self.processor.process_queries([content]).to(self.model.device)

        process_time = time.time() - process_start

        model_start = time.time()

        with torch.no_grad():
            embeddings: torch.Tensor = self.model(**processed)

        return embeddings.to(torch.float32).numpy(force=True)[0]
        model_time = time.time() - model_start

        convert_start = time.time()

        result = embeddings.to(torch.float32).numpy(force=True)[0]

        convert_time = time.time() - convert_start

        total_time = time.time() - start_time
        logger.debug(
            f"Generate embeddings ({content_type}): process={process_time:.2f}s, model={model_time:.2f}s, "
            f"convert={convert_time:.2f}s, total={total_time:.2f}s"
        )
        return result

    # ---- Batch processing methods (only used in 'cloud' mode) ----

    async def generate_embeddings_batch_images(self, images: List[Image]) -> List[np.ndarray]:
        batch_start_time = time.time()
        process_start = time.time()
        processed_images = self.processor.process_images(images).to(self.model.device)
        process_time = time.time() - process_start

        model_start = time.time()
        with torch.no_grad():
            image_embeddings = self.model(**processed_images)
        model_time = time.time() - model_start

        convert_start = time.time()
        image_embeddings_np = image_embeddings.to(torch.float32).numpy(force=True)
        result = [emb for emb in image_embeddings_np]
        convert_time = time.time() - convert_start

        total_batch_time = time.time() - batch_start_time
        logger.debug(
            f"Batch images ({len(images)}): process={process_time:.2f}s, model={model_time:.2f}s, "
            f"convert={convert_time:.2f}s, total={total_batch_time:.2f}s ({total_batch_time/len(images):.3f}s/image)"
        )
        return result

    async def generate_embeddings_batch_texts(self, texts: List[str]) -> List[np.ndarray]:
        batch_start_time = time.time()
        process_start = time.time()
        processed_texts = self.processor.process_queries(texts).to(self.model.device)
        process_time = time.time() - process_start

        model_start = time.time()
        with torch.no_grad():
            text_embeddings = self.model(**processed_texts)
        model_time = time.time() - model_start

        convert_start = time.time()
        text_embeddings_np = text_embeddings.to(torch.float32).numpy(force=True)
        result = [emb for emb in text_embeddings_np]
        convert_time = time.time() - convert_start

        total_batch_time = time.time() - batch_start_time
        logger.debug(
            f"Batch texts ({len(texts)}): process={process_time:.2f}s, model={model_time:.2f}s, "
            f"convert={convert_time:.2f}s, total={total_batch_time:.2f}s ({total_batch_time/len(texts):.3f}s/text)"
        )
        return result

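For reference, a minimal usage sketch of the updated embedding model (illustrative only, not part of the patch; it assumes the colqwen2.5 weights can be downloaded and that MODE is configured in settings):

import asyncio

from core.embedding.colpali_embedding_model import ColpaliEmbeddingModel
from core.models.chunk import Chunk


async def main():
    # Picks mps/cuda/cpu automatically; batch size is 8 in "cloud" mode, otherwise 1.
    model = ColpaliEmbeddingModel()
    # Ingestion path: chunks are sorted into image/text batches, embedded, and returned in original order.
    doc_embeddings = await model.embed_for_ingestion([Chunk(content="hello colpali", metadata={})])
    # Query path: a multi-vector embedding (roughly one vector per query token).
    query_embedding = await model.embed_for_query("what does the chart on page 3 show?")
    print(len(doc_embeddings), query_embedding.shape)


asyncio.run(main())
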
@@ -1778,8 +1778,8 @@ class DocumentService:
        # Also handle the case of multiple file versions in storage_files
        if hasattr(document, "storage_files") and document.storage_files:
            for file_info in document.storage_files:
                bucket = file_info.get("bucket")
                key = file_info.get("key")
                bucket = file_info.bucket
                key = file_info.key
                if bucket and key and hasattr(self.storage, "delete_file"):
                    storage_deletion_tasks.append(self.storage.delete_file(bucket, key))

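The hunk above swaps dict-style lookups for attribute access, which assumes document.storage_files now holds model objects exposing bucket and key fields rather than plain dicts. A rough sketch of that assumption (the class name here is hypothetical; the patch only shows the .bucket/.key access):

from pydantic import BaseModel


class StorageFileInfo(BaseModel):  # hypothetical stand-in for the real storage_files entry type
    bucket: str
    key: str


file_info = StorageFileInfo(bucket="morphik-storage", key="docs/abc123.pdf")
print(file_info.bucket, file_info.key)  # new attribute-style access used in the patch
# file_info.get("bucket")  # old dict-style call; raises AttributeError on a model object
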
@@ -2,6 +2,7 @@ import asyncio
import json
import logging
import os
import time
import urllib.parse as up
from datetime import UTC, datetime
from typing import Any, Dict, List, Optional
@@ -26,6 +27,19 @@ from core.vector_store.pgvector_store import PGVectorStore

logger = logging.getLogger(__name__)

# Configure logger for ingestion worker (restored from diff)
settings = get_settings()  # Need settings for log level potentially, though INFO used here

# Create logs directory if it doesn't exist
os.makedirs("logs", exist_ok=True)

# Set up file handler for worker_ingestion.log
file_handler = logging.FileHandler("logs/worker_ingestion.log")
file_handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))
logger.addHandler(file_handler)
# Set logger level based on settings (diff used INFO directly)
logger.setLevel(logging.INFO)

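The comments above note that the level could come from settings rather than being hard-coded to INFO; one possible variant (LOG_LEVEL is a hypothetical settings field, not something this diff defines):

level_name = getattr(settings, "LOG_LEVEL", "INFO")  # hypothetical field; falls back to INFO
logger.setLevel(getattr(logging, str(level_name).upper(), logging.INFO))
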
async def get_document_with_retry(document_service, document_id, auth, max_retries=3, initial_delay=0.3):
    """
@@ -116,10 +130,14 @@ async def process_ingestion_job(
        A dictionary with the document ID and processing status
    """
    try:
        # Start performance timer
        job_start_time = time.time()
        phase_times = {}
        # 1. Log the start of the job
        logger.info(f"Starting ingestion job for file: {original_filename}")

        # 2. Deserialize metadata and auth
        deserialize_start = time.time()
        metadata = json.loads(metadata_json) if metadata_json else {}
        auth = AuthContext(
            entity_type=EntityType(auth_dict.get("entity_type", "unknown")),
@@ -128,23 +146,32 @@
            permissions=set(auth_dict.get("permissions", ["read"])),
            user_id=auth_dict.get("user_id", auth_dict.get("entity_id", "")),
        )
        phase_times["deserialize_auth"] = time.time() - deserialize_start

        # Get document service from the context
        document_service: DocumentService = ctx["document_service"]

        # 3. Download the file from storage
        logger.info(f"Downloading file from {bucket}/{file_key}")
        download_start = time.time()
        file_content = await document_service.storage.download_file(bucket, file_key)

        # Ensure file_content is bytes
        if hasattr(file_content, "read"):
            file_content = file_content.read()
        download_time = time.time() - download_start
        phase_times["download_file"] = download_time
        logger.info(f"File download took {download_time:.2f}s for {len(file_content)/1024/1024:.2f}MB")

        # 4. Parse file to text
        parse_start = time.time()
        additional_metadata, text = await document_service.parser.parse_file_to_text(file_content, original_filename)
        logger.debug(f"Parsed file into text of length {len(text)}")
        parse_time = time.time() - parse_start
        phase_times["parse_file"] = parse_time

        # === Apply post_parsing rules ===
        rules_start = time.time()
        document_rule_metadata = {}
        if rules_list:
            logger.info("Applying post-parsing rules...")
@@ -154,8 +181,13 @@
            metadata.update(document_rule_metadata)  # Merge metadata into main doc metadata
            logger.info(f"Document metadata after post-parsing rules: {metadata}")
            logger.info(f"Content length after post-parsing rules: {len(text)}")
        rules_time = time.time() - rules_start
        phase_times["apply_post_parsing_rules"] = rules_time
        if rules_list:
            logger.info(f"Post-parsing rules processing took {rules_time:.2f}s")

        # 6. Retrieve the existing document
        retrieve_start = time.time()
        logger.debug(f"Retrieving document with ID: {document_id}")
        logger.debug(
            f"Auth context: entity_type={auth.entity_type}, entity_id={auth.entity_id}, permissions={auth.permissions}"
@@ -163,6 +195,9 @@

        # Use the retry helper function with initial delay to handle race conditions
        doc = await get_document_with_retry(document_service, document_id, auth, max_retries=5, initial_delay=1.0)
        retrieve_time = time.time() - retrieve_start
        phase_times["retrieve_document"] = retrieve_time
        logger.info(f"Document retrieval took {retrieve_time:.2f}s")

        if not doc:
            logger.error(f"Document {document_id} not found in database after multiple retries")
@@ -193,7 +228,11 @@
                updates["system_metadata"]["end_user_id"] = end_user_id

        # Update the document in the database
        update_start = time.time()
        success = await document_service.db.update_document(document_id=document_id, updates=updates, auth=auth)
        update_time = time.time() - update_start
        phase_times["update_document_parsed"] = update_time
        logger.info(f"Initial document update took {update_time:.2f}s")

        if not success:
            raise ValueError(f"Failed to update document {document_id}")
@@ -203,13 +242,18 @@
        logger.debug("Updated document in database with parsed content")

        # 7. Split text into chunks
        chunking_start = time.time()
        parsed_chunks = await document_service.parser.split_text(text)
        if not parsed_chunks:
            raise ValueError("No content chunks extracted after rules processing")
        logger.debug(f"Split processed text into {len(parsed_chunks)} chunks")
        chunking_time = time.time() - chunking_start
        phase_times["split_into_chunks"] = chunking_time
        logger.info(f"Text chunking took {chunking_time:.2f}s to create {len(parsed_chunks)} chunks")

        # 8. Handle ColPali embeddings if enabled - IMPORTANT: Do this BEFORE applying chunk rules
        # so that image chunks can be processed by rules when use_images=True
        colpali_processing_start = time.time()
        using_colpali = (
            use_colpali and document_service.colpali_embedding_model and document_service.colpali_vector_store
        )
@@ -227,6 +271,10 @@
                file_type, file_content_base64, file_content, parsed_chunks
            )
            logger.debug(f"Created {len(chunks_multivector)} chunks for multivector embedding")
        colpali_create_chunks_time = time.time() - colpali_processing_start
        phase_times["colpali_create_chunks"] = colpali_create_chunks_time
        if using_colpali:
            logger.info(f"Colpali chunk creation took {colpali_create_chunks_time:.2f}s")

        # 9. Apply post_chunking rules and aggregate metadata
        processed_chunks = []
@@ -302,14 +350,27 @@
            processed_chunks_multivector = chunks_multivector  # No rules, use original multivector chunks

        # 10. Generate embeddings for processed chunks
        embedding_start = time.time()
        embeddings = await document_service.embedding_model.embed_for_ingestion(processed_chunks)
        logger.debug(f"Generated {len(embeddings)} embeddings")
        embedding_time = time.time() - embedding_start
        phase_times["generate_embeddings"] = embedding_time
        embeddings_per_second = len(embeddings) / embedding_time if embedding_time > 0 else 0
        logger.info(
            f"Embedding generation took {embedding_time:.2f}s for {len(embeddings)} embeddings "
            f"({embeddings_per_second:.2f} embeddings/s)"
        )

        # 11. Create chunk objects with potentially modified chunk content and metadata
        chunk_objects_start = time.time()
        chunk_objects = document_service._create_chunk_objects(doc.external_id, processed_chunks, embeddings)
        logger.debug(f"Created {len(chunk_objects)} chunk objects")
        chunk_objects_time = time.time() - chunk_objects_start
        phase_times["create_chunk_objects"] = chunk_objects_time
        logger.debug(f"Creating chunk objects took {chunk_objects_time:.2f}s")

        # 12. Handle ColPali embeddings
        colpali_embed_start = time.time()
        chunk_objects_multivector = []
        if using_colpali:
            colpali_embeddings = await document_service.colpali_embedding_model.embed_for_ingestion(
@@ -320,6 +381,14 @@
            chunk_objects_multivector = document_service._create_chunk_objects(
                doc.external_id, processed_chunks_multivector, colpali_embeddings
            )
        colpali_embed_time = time.time() - colpali_embed_start
        phase_times["colpali_generate_embeddings"] = colpali_embed_time
        if using_colpali:
            embeddings_per_second = len(colpali_embeddings) / colpali_embed_time if colpali_embed_time > 0 else 0
            logger.info(
                f"Colpali embedding took {colpali_embed_time:.2f}s for {len(colpali_embeddings)} embeddings "
                f"({embeddings_per_second:.2f} embeddings/s)"
            )

        # === Merge aggregated chunk metadata into document metadata ===
        if aggregated_chunk_metadata:
@@ -336,14 +405,28 @@
            doc.system_metadata["updated_at"] = datetime.now(UTC)

        # 11. Store chunks and update document with is_update=True
        store_start = time.time()
        await document_service._store_chunks_and_doc(
            chunk_objects, doc, use_colpali, chunk_objects_multivector, is_update=True, auth=auth
        )
        store_time = time.time() - store_start
        phase_times["store_chunks_and_update_doc"] = store_time
        logger.info(f"Storing chunks and final document update took {store_time:.2f}s for {len(chunk_objects)} chunks")

        logger.debug(f"Successfully completed processing for document {doc.external_id}")

        # 13. Log successful completion
        logger.info(f"Successfully completed ingestion for {original_filename}, document ID: {doc.external_id}")
        # Performance summary
        total_time = time.time() - job_start_time

        # Log performance summary
        logger.info("=== Ingestion Performance Summary ===")
        logger.info(f"Total processing time: {total_time:.2f}s")
        for phase, duration in sorted(phase_times.items(), key=lambda x: x[1], reverse=True):
            percentage = (duration / total_time) * 100 if total_time > 0 else 0
            logger.info(f" - {phase}: {duration:.2f}s ({percentage:.1f}%)")
        logger.info("=====================================")

        # 14. Return document ID
        return {
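The per-phase timing above follows a repeated start/stop pattern; it could equally be wrapped in a small helper like the sketch below (illustrative only, not part of the patch):

import time
from contextlib import contextmanager


@contextmanager
def timed_phase(phase_times: dict, name: str):
    """Record the wall-clock duration of the enclosed block under phase_times[name]."""
    start = time.time()
    try:
        yield
    finally:
        phase_times[name] = time.time() - start


# Usage inside process_ingestion_job, e.g.:
# with timed_phase(phase_times, "parse_file"):
#     additional_metadata, text = await document_service.parser.parse_file_to_text(file_content, original_filename)
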
scripts/migrate_multivector_embeddings.py (new file, 73 lines)
@@ -0,0 +1,73 @@
#!/usr/bin/env python3
# Usage: cd $(dirname "$0")/.. && PYTHONPATH=. python3 $0

import asyncio
import json

import psycopg
from pgvector.psycopg import register_vector

from core.config import get_settings
from core.embedding.colpali_embedding_model import ColpaliEmbeddingModel
from core.models.chunk import Chunk
from core.vector_store.multi_vector_store import MultiVectorStore


async def migrate_multivector_embeddings():
    settings = get_settings()
    uri = settings.POSTGRES_URI
    # Convert SQLAlchemy URI to psycopg format if needed
    if uri.startswith("postgresql+asyncpg://"):
        uri = uri.replace("postgresql+asyncpg://", "postgresql://")
    mv_store = MultiVectorStore(uri)
    if not mv_store.initialize():
        print("Failed to initialize MultiVectorStore")
        return

    embedding_model = ColpaliEmbeddingModel()

    conn = psycopg.connect(uri, autocommit=True)
    register_vector(conn)
    cursor = conn.cursor()

    cursor.execute("SELECT id, document_id, chunk_number, content, chunk_metadata " "FROM multi_vector_embeddings")
    rows = cursor.fetchall()
    total = len(rows)
    print(f"Found {total} multivector records to migrate...")

    for idx, (row_id, doc_id, chunk_num, content, meta_json) in enumerate(rows, start=1):
        try:
            # Parse metadata (JSON preferred, fallback to Python literal)
            try:
                metadata = json.loads(meta_json) if meta_json else {}
            except json.JSONDecodeError:
                import ast

                try:
                    metadata = ast.literal_eval(meta_json)
                except Exception as exc:
                    print(f"Warning: failed to parse metadata for row {row_id}: {exc}")
                    metadata = {}

            # Create a chunk and recompute its multivector embedding
            chunk = Chunk(content=content, metadata=metadata)
            vectors = await embedding_model.embed_for_ingestion([chunk])
            vector = vectors[0]
            bits = mv_store._binary_quantize(vector)

            # Update the embeddings in-place
            cursor.execute(
                "UPDATE multi_vector_embeddings SET embeddings = %s WHERE id = %s",
                (bits, row_id),
            )
            print(f"[{idx}/{total}] Updated doc={doc_id} chunk={chunk_num}")
        except Exception as e:
            print(f"Error migrating row {row_id} (doc={doc_id}, chunk={chunk_num}): {e}")

    cursor.close()
    conn.close()
    print("Migration complete.")


if __name__ == "__main__":
    asyncio.run(migrate_multivector_embeddings())

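After running the migration (from the repo root with PYTHONPATH=., as the usage comment indicates), a quick sanity check is to confirm the table is still fully populated; a sketch reusing the script's own table name and connection handling:

import psycopg

from core.config import get_settings

uri = get_settings().POSTGRES_URI.replace("postgresql+asyncpg://", "postgresql://")
with psycopg.connect(uri) as conn, conn.cursor() as cur:
    cur.execute("SELECT count(*) FROM multi_vector_embeddings")
    print("multivector rows:", cur.fetchone()[0])
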
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "morphik"
version = "0.1.4"
version = "0.1.5"
authors = [
    { name = "Morphik", email = "founders@morphik.ai" },
]