From 3c1195e0015b11aa3a6387b87238f311082727d4 Mon Sep 17 00:00:00 2001 From: Arnav Agrawal <88790414+ArnavAgrawal03@users.noreply.github.com> Date: Tue, 15 Apr 2025 23:31:49 -0700 Subject: [PATCH] add task queue (#87) * add task queue * ensure task queuing is working as expected. * add downstream sdk changes * bugs and pr comments * update docker arq running logic --- core/api.py | 335 ++++++++++++++++++++---- core/config.py | 13 + core/models/documents.py | 1 + core/models/request.py | 6 + core/parser/base_parser.py | 2 +- core/services/document_service.py | 3 +- core/workers/__init__.py | 1 + core/workers/ingestion_worker.py | 411 ++++++++++++++++++++++++++++++ docker-compose.yml | 52 +++- morphik.toml | 4 + requirements.txt | 1 + sdks/python/morphik/async_.py | 70 +++++ sdks/python/morphik/models.py | 54 ++++ sdks/python/morphik/sync.py | 70 +++++ start_server.py | 141 +++++++++- 15 files changed, 1103 insertions(+), 61 deletions(-) create mode 100644 core/workers/__init__.py create mode 100644 core/workers/ingestion_worker.py diff --git a/core/api.py b/core/api.py index 0bcdaa3..6e76589 100644 --- a/core/api.py +++ b/core/api.py @@ -1,5 +1,7 @@ import asyncio import json +import base64 +import uuid from datetime import datetime, UTC, timedelta from pathlib import Path import sys @@ -8,6 +10,7 @@ from fastapi import FastAPI, Form, HTTPException, Depends, Header, UploadFile, F from fastapi.middleware.cors import CORSMiddleware import jwt import logging +import arq from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor from core.limits_utils import check_and_increment_limits from core.models.request import GenerateUriRequest, RetrieveRequest, CompletionQueryRequest, IngestTextRequest, CreateGraphRequest, UpdateGraphRequest, BatchIngestResponse @@ -79,6 +82,8 @@ if not settings.POSTGRES_URI: raise ValueError("PostgreSQL URI is required for PostgreSQL database") database = PostgresDatabase(uri=settings.POSTGRES_URI) +# Redis settings already imported at top of file + @app.on_event("startup") async def initialize_database(): @@ -129,6 +134,37 @@ async def initialize_user_limits_database(): user_limits_db = UserLimitsDatabase(uri=settings.POSTGRES_URI) await user_limits_db.initialize() +@app.on_event("startup") +async def initialize_redis_pool(): + """Initialize the Redis connection pool for background tasks.""" + global redis_pool + logger.info("Initializing Redis connection pool...") + + # Get Redis settings from configuration + redis_host = settings.REDIS_HOST + redis_port = settings.REDIS_PORT + + # Log the Redis connection details + logger.info(f"Connecting to Redis at {redis_host}:{redis_port}") + + redis_settings = arq.connections.RedisSettings( + host=redis_host, + port=redis_port, + ) + + redis_pool = await arq.create_pool(redis_settings) + logger.info("Redis connection pool initialized successfully") + +@app.on_event("shutdown") +async def close_redis_pool(): + """Close the Redis connection pool on application shutdown.""" + global redis_pool + if redis_pool: + logger.info("Closing Redis connection pool...") + redis_pool.close() + await redis_pool.wait_closed() + logger.info("Redis connection pool closed") + # Initialize vector store if not settings.POSTGRES_URI: raise ValueError("PostgreSQL URI is required for pgvector store") @@ -319,6 +355,13 @@ async def ingest_text( raise HTTPException(status_code=403, detail=str(e)) +# Redis pool for background tasks +redis_pool = None + +def get_redis_pool(): + """Get the global Redis connection pool for background tasks.""" + 
return redis_pool + @app.post("/ingest/file", response_model=Document) async def ingest_file( file: UploadFile, @@ -328,9 +371,10 @@ async def ingest_file( use_colpali: Optional[bool] = None, folder_name: Optional[str] = Form(None), end_user_id: Optional[str] = Form(None), + redis: arq.ArqRedis = Depends(get_redis_pool), ) -> Document: """ - Ingest a file document. + Ingest a file document asynchronously. Args: file: File to ingest @@ -342,17 +386,23 @@ async def ingest_file( use_colpali: Whether to use ColPali embedding model folder_name: Optional folder to scope the document to end_user_id: Optional end-user ID to scope the document to + redis: Redis connection pool for background tasks Returns: - Document: Metadata of ingested document + Document with processing status that can be used to check progress """ try: + # Parse metadata and rules metadata_dict = json.loads(metadata) rules_list = json.loads(rules) use_colpali = bool(use_colpali) + + # Ensure user has write permission + if "write" not in auth.permissions: + raise PermissionError("User does not have write permission") async with telemetry.track_operation( - operation_type="ingest_file", + operation_type="queue_ingest_file", user_id=auth.entity_id, metadata={ "filename": file.filename, @@ -364,21 +414,97 @@ async def ingest_file( "end_user_id": end_user_id, }, ): - logger.debug(f"API: Ingesting file with use_colpali: {use_colpali}") + logger.debug(f"API: Queueing file ingestion with use_colpali: {use_colpali}") - return await document_service.ingest_file( - file=file, + # Create a document with processing status + doc = Document( + content_type=file.content_type, + filename=file.filename, metadata=metadata_dict, - auth=auth, - rules=rules_list, + owner={"type": auth.entity_type, "id": auth.entity_id}, + access_control={ + "readers": [auth.entity_id], + "writers": [auth.entity_id], + "admins": [auth.entity_id], + "user_id": [auth.user_id] if auth.user_id else [], + }, + system_metadata={"status": "processing"} + ) + + # Add folder_name and end_user_id to system_metadata if provided + if folder_name: + doc.system_metadata["folder_name"] = folder_name + if end_user_id: + doc.system_metadata["end_user_id"] = end_user_id + + # Set processing status + doc.system_metadata["status"] = "processing" + + # Store the document in the database + success = await database.store_document(doc) + if not success: + raise Exception("Failed to store document metadata") + + # Read file content + file_content = await file.read() + + # Generate a unique key for the file + file_key = f"ingest_uploads/{uuid.uuid4()}/{file.filename}" + + # Store the file in the configured storage + file_content_base64 = base64.b64encode(file_content).decode() + bucket, stored_key = await storage.upload_from_base64( + file_content_base64, + file_key, + file.content_type + ) + logger.debug(f"Stored file in bucket {bucket} with key {stored_key}") + + # Update document with storage info + doc.storage_info = {"bucket": bucket, "key": stored_key} + await database.update_document( + document_id=doc.external_id, + updates={"storage_info": doc.storage_info}, + auth=auth + ) + + # Convert auth context to a dictionary for serialization + auth_dict = { + "entity_type": auth.entity_type.value, + "entity_id": auth.entity_id, + "app_id": auth.app_id, + "permissions": list(auth.permissions), + "user_id": auth.user_id + } + + # Enqueue the background job + job = await redis.enqueue_job( + 'process_ingestion_job', + document_id=doc.external_id, + file_key=stored_key, + bucket=bucket, + 
original_filename=file.filename, + content_type=file.content_type, + metadata_json=metadata, + auth_dict=auth_dict, + rules_list=rules_list, use_colpali=use_colpali, folder_name=folder_name, - end_user_id=end_user_id, + end_user_id=end_user_id ) + + logger.info(f"File ingestion job queued with ID: {job.job_id} for document: {doc.external_id}") + + # Return the document with processing status + return doc + except json.JSONDecodeError as e: raise HTTPException(status_code=400, detail=f"Invalid JSON: {str(e)}") except PermissionError as e: raise HTTPException(status_code=403, detail=str(e)) + except Exception as e: + logger.error(f"Error queueing file ingestion: {str(e)}") + raise HTTPException(status_code=500, detail=f"Error queueing file ingestion: {str(e)}") @app.post("/ingest/files", response_model=BatchIngestResponse) @@ -387,13 +513,14 @@ async def batch_ingest_files( metadata: str = Form("{}"), rules: str = Form("[]"), use_colpali: Optional[bool] = Form(None), - parallel: bool = Form(True), + parallel: Optional[bool] = Form(True), folder_name: Optional[str] = Form(None), end_user_id: Optional[str] = Form(None), auth: AuthContext = Depends(verify_token), + redis: arq.ArqRedis = Depends(get_redis_pool), ) -> BatchIngestResponse: """ - Batch ingest multiple files. + Batch ingest multiple files using the task queue. Args: files: List of files to ingest @@ -402,15 +529,15 @@ async def batch_ingest_files( - A single list of rules to apply to all files - A list of rule lists, one per file use_colpali: Whether to use ColPali-style embedding - parallel: Whether to process files in parallel folder_name: Optional folder to scope the documents to end_user_id: Optional end-user ID to scope the documents to auth: Authentication context + redis: Redis connection pool for background tasks Returns: BatchIngestResponse containing: - - documents: List of successfully ingested documents - - errors: List of errors encountered during ingestion + - documents: List of created documents with processing status + - errors: List of errors that occurred during the batch operation """ if not files: raise HTTPException( @@ -421,8 +548,15 @@ async def batch_ingest_files( try: metadata_value = json.loads(metadata) rules_list = json.loads(rules) + use_colpali = bool(use_colpali) + + # Ensure user has write permission + if "write" not in auth.permissions: + raise PermissionError("User does not have write permission") except json.JSONDecodeError as e: raise HTTPException(status_code=400, detail=f"Invalid JSON: {str(e)}") + except PermissionError as e: + raise HTTPException(status_code=403, detail=str(e)) # Validate metadata if it's a list if isinstance(metadata_value, list) and len(metadata_value) != len(files): @@ -439,13 +573,19 @@ async def batch_ingest_files( detail=f"Number of rule lists ({len(rules_list)}) must match number of files ({len(files)})" ) - documents = [] - errors = [] + # Convert auth context to a dictionary for serialization + auth_dict = { + "entity_type": auth.entity_type.value, + "entity_id": auth.entity_id, + "app_id": auth.app_id, + "permissions": list(auth.permissions), + "user_id": auth.user_id + } - # We'll pass folder_name and end_user_id directly to the ingest_file functions + created_documents = [] async with telemetry.track_operation( - operation_type="batch_ingest", + operation_type="queue_batch_ingest", user_id=auth.entity_id, metadata={ "file_count": len(files), @@ -455,54 +595,98 @@ async def batch_ingest_files( "end_user_id": end_user_id, }, ): - if parallel: - tasks = [] + try: for 
i, file in enumerate(files): + # Get the metadata and rules for this file metadata_item = metadata_value[i] if isinstance(metadata_value, list) else metadata_value file_rules = rules_list[i] if isinstance(rules_list, list) and rules_list and isinstance(rules_list[0], list) else rules_list - task = document_service.ingest_file( - file=file, + + # Create a document with processing status + doc = Document( + content_type=file.content_type, + filename=file.filename, metadata=metadata_item, - auth=auth, - rules=file_rules, + owner={"type": auth.entity_type, "id": auth.entity_id}, + access_control={ + "readers": [auth.entity_id], + "writers": [auth.entity_id], + "admins": [auth.entity_id], + "user_id": [auth.user_id] if auth.user_id else [], + }, + ) + + # Add folder_name and end_user_id to system_metadata if provided + if folder_name: + doc.system_metadata["folder_name"] = folder_name + if end_user_id: + doc.system_metadata["end_user_id"] = end_user_id + + # Set processing status + doc.system_metadata["status"] = "processing" + + # Store the document in the database + success = await database.store_document(doc) + if not success: + raise Exception(f"Failed to store document metadata for {file.filename}") + + # Read file content + file_content = await file.read() + + # Generate a unique key for the file + file_key = f"ingest_uploads/{uuid.uuid4()}/{file.filename}" + + # Store the file in the configured storage + file_content_base64 = base64.b64encode(file_content).decode() + bucket, stored_key = await storage.upload_from_base64( + file_content_base64, + file_key, + file.content_type + ) + logger.debug(f"Stored file in bucket {bucket} with key {stored_key}") + + # Update document with storage info + doc.storage_info = {"bucket": bucket, "key": stored_key} + await database.update_document( + document_id=doc.external_id, + updates={"storage_info": doc.storage_info}, + auth=auth + ) + + # Convert metadata to JSON string for job + metadata_json = json.dumps(metadata_item) + + # Enqueue the background job + job = await redis.enqueue_job( + 'process_ingestion_job', + document_id=doc.external_id, + file_key=stored_key, + bucket=bucket, + original_filename=file.filename, + content_type=file.content_type, + metadata_json=metadata_json, + auth_dict=auth_dict, + rules_list=file_rules, use_colpali=use_colpali, folder_name=folder_name, end_user_id=end_user_id ) - tasks.append(task) + + logger.info(f"File ingestion job queued with ID: {job.job_id} for document: {doc.external_id}") + + # Add document to the list + created_documents.append(doc) + + # Return information about created documents + return BatchIngestResponse( + documents=created_documents, + errors=[] + ) - results = await asyncio.gather(*tasks, return_exceptions=True) - - for i, result in enumerate(results): - if isinstance(result, Exception): - errors.append({ - "filename": files[i].filename, - "error": str(result) - }) - else: - documents.append(result) - else: - for i, file in enumerate(files): - try: - metadata_item = metadata_value[i] if isinstance(metadata_value, list) else metadata_value - file_rules = rules_list[i] if isinstance(rules_list, list) and rules_list and isinstance(rules_list[0], list) else rules_list - doc = await document_service.ingest_file( - file=file, - metadata=metadata_item, - auth=auth, - rules=file_rules, - use_colpali=use_colpali, - folder_name=folder_name, - end_user_id=end_user_id - ) - documents.append(doc) - except Exception as e: - errors.append({ - "filename": file.filename, - "error": str(e) - }) + except 
Exception as e: + logger.error(f"Error queueing batch file ingestion: {str(e)}") + raise HTTPException(status_code=500, detail=f"Error queueing batch file ingestion: {str(e)}") + - return BatchIngestResponse(documents=documents, errors=errors) @app.post("/retrieve/chunks", response_model=List[ChunkResult]) @@ -817,6 +1001,45 @@ async def get_document(document_id: str, auth: AuthContext = Depends(verify_toke except HTTPException as e: logger.error(f"Error getting document: {e}") raise e + +@app.get("/documents/{document_id}/status", response_model=Dict[str, Any]) +async def get_document_status(document_id: str, auth: AuthContext = Depends(verify_token)): + """ + Get the processing status of a document. + + Args: + document_id: ID of the document to check + auth: Authentication context + + Returns: + Dict containing status information for the document + """ + try: + doc = await document_service.db.get_document(document_id, auth) + if not doc: + raise HTTPException(status_code=404, detail="Document not found") + + # Extract status information + status = doc.system_metadata.get("status", "unknown") + + response = { + "document_id": doc.external_id, + "status": status, + "filename": doc.filename, + "created_at": doc.system_metadata.get("created_at"), + "updated_at": doc.system_metadata.get("updated_at"), + } + + # Add error information if failed + if status == "failed": + response["error"] = doc.system_metadata.get("error", "Unknown error") + + return response + except HTTPException: + raise + except Exception as e: + logger.error(f"Error getting document status: {str(e)}") + raise HTTPException(status_code=500, detail=f"Error getting document status: {str(e)}") @app.delete("/documents/{document_id}") diff --git a/core/config.py b/core/config.py index 4ca6dad..037838d 100644 --- a/core/config.py +++ b/core/config.py @@ -95,6 +95,10 @@ class Settings(BaseSettings): # API configuration API_DOMAIN: str = "api.morphik.ai" + # Redis configuration + REDIS_HOST: str = "localhost" + REDIS_PORT: int = 6379 + # Telemetry configuration TELEMETRY_ENABLED: bool = True HONEYCOMB_ENABLED: bool = True @@ -268,6 +272,14 @@ def get_settings() -> Settings: "MODE": config["morphik"].get("mode", "cloud"), # Default to "cloud" mode "API_DOMAIN": config["morphik"].get("api_domain", "api.morphik.ai"), # Default API domain } + + # load redis config + redis_config = {} + if "redis" in config: + redis_config = { + "REDIS_HOST": config["redis"].get("host", "localhost"), + "REDIS_PORT": int(config["redis"].get("port", 6379)), + } # load graph config graph_config = { @@ -309,6 +321,7 @@ def get_settings() -> Settings: vector_store_config, rules_config, morphik_config, + redis_config, graph_config, telemetry_config, openai_config, diff --git a/core/models/documents.py b/core/models/documents.py index 5454ffb..7a40c22 100644 --- a/core/models/documents.py +++ b/core/models/documents.py @@ -46,6 +46,7 @@ class Document(BaseModel): "version": 1, "folder_name": None, "end_user_id": None, + "status": "processing", # Status can be: processing, completed, failed } ) """metadata such as creation date etc.""" diff --git a/core/models/request.py b/core/models/request.py index 6ff3500..84ad4e3 100644 --- a/core/models/request.py +++ b/core/models/request.py @@ -95,6 +95,12 @@ class BatchIngestResponse(BaseModel): """Response model for batch ingestion""" documents: List[Document] errors: List[Dict[str, str]] + +class BatchIngestJobResponse(BaseModel): + """Response model for batch ingestion jobs""" + status: str = Field(..., 
description="Status of the batch operation") + documents: List[Document] = Field(..., description="List of created documents with processing status") + timestamp: str = Field(..., description="ISO-formatted timestamp") class GenerateUriRequest(BaseModel): diff --git a/core/parser/base_parser.py b/core/parser/base_parser.py index daccc93..e129b00 100644 --- a/core/parser/base_parser.py +++ b/core/parser/base_parser.py @@ -8,7 +8,7 @@ class BaseParser(ABC): @abstractmethod async def parse_file_to_text( - self, file: bytes, content_type: str, filename: str + self, file: bytes, filename: str ) -> Tuple[Dict[str, Any], str]: """ Parse file content into text. diff --git a/core/services/document_service.py b/core/services/document_service.py index f55b500..509c602 100644 --- a/core/services/document_service.py +++ b/core/services/document_service.py @@ -485,7 +485,8 @@ class DocumentService: logger.debug(f"Successfully stored text document {doc.external_id}") return doc - + + # TODO: check if it's unused. if so, remove it. async def ingest_file( self, file: UploadFile, diff --git a/core/workers/__init__.py b/core/workers/__init__.py new file mode 100644 index 0000000..aade4a2 --- /dev/null +++ b/core/workers/__init__.py @@ -0,0 +1 @@ +# Workers package for background job processing \ No newline at end of file diff --git a/core/workers/ingestion_worker.py b/core/workers/ingestion_worker.py new file mode 100644 index 0000000..0303e60 --- /dev/null +++ b/core/workers/ingestion_worker.py @@ -0,0 +1,411 @@ +import json +import logging +from typing import Dict, Any, List, Optional +from datetime import datetime, UTC +from pathlib import Path + +import arq +from core.models.auth import AuthContext, EntityType +from core.models.documents import Document +from core.database.postgres_database import PostgresDatabase +from core.vector_store.pgvector_store import PGVectorStore +from core.parser.morphik_parser import MorphikParser +from core.embedding.litellm_embedding import LiteLLMEmbeddingModel +from core.completion.litellm_completion import LiteLLMCompletionModel +from core.storage.local_storage import LocalStorage +from core.storage.s3_storage import S3Storage +from core.embedding.colpali_embedding_model import ColpaliEmbeddingModel +from core.vector_store.multi_vector_store import MultiVectorStore +from core.services.document_service import DocumentService +from core.services.telemetry import TelemetryService +from core.services.rules_processor import RulesProcessor +from core.config import get_settings + +logger = logging.getLogger(__name__) + +async def process_ingestion_job( + ctx: Dict[str, Any], + document_id: str, + file_key: str, + bucket: str, + original_filename: str, + content_type: str, + metadata_json: str, + auth_dict: Dict[str, Any], + rules_list: List[Dict[str, Any]], + use_colpali: bool, + folder_name: Optional[str] = None, + end_user_id: Optional[str] = None +) -> Dict[str, Any]: + """ + Background worker task that processes file ingestion jobs. 
+ + Args: + ctx: The ARQ context dictionary + file_key: The storage key where the file is stored + bucket: The storage bucket name + original_filename: The original file name + content_type: The file's content type/MIME type + metadata_json: JSON string of metadata + auth_dict: Dict representation of AuthContext + rules_list: List of rules to apply (already converted to dictionaries) + use_colpali: Whether to use ColPali embedding model + folder_name: Optional folder to scope the document to + end_user_id: Optional end-user ID to scope the document to + + Returns: + A dictionary with the document ID and processing status + """ + try: + # 1. Log the start of the job + logger.info(f"Starting ingestion job for file: {original_filename}") + + # 2. Deserialize metadata and auth + metadata = json.loads(metadata_json) if metadata_json else {} + auth = AuthContext( + entity_type=EntityType(auth_dict.get("entity_type", "unknown")), + entity_id=auth_dict.get("entity_id", ""), + app_id=auth_dict.get("app_id"), + permissions=set(auth_dict.get("permissions", ["read"])), + user_id=auth_dict.get("user_id", auth_dict.get("entity_id", "")) + ) + + # Get document service from the context + document_service : DocumentService = ctx['document_service'] + + # 3. Download the file from storage + logger.info(f"Downloading file from {bucket}/{file_key}") + file_content = await document_service.storage.download_file(bucket, file_key) + + # Ensure file_content is bytes + if hasattr(file_content, 'read'): + file_content = file_content.read() + + # 4. Parse file to text + additional_metadata, text = await document_service.parser.parse_file_to_text( + file_content, original_filename + ) + logger.debug(f"Parsed file into text of length {len(text)}") + + # 5. Apply rules if provided + if rules_list: + rule_metadata, modified_text = await document_service.rules_processor.process_rules(text, rules_list) + # Update document metadata with extracted metadata from rules + metadata.update(rule_metadata) + + if modified_text: + text = modified_text + logger.info("Updated text with modified content from rules") + + # 6. 
Retrieve the existing document + logger.debug(f"Retrieving document with ID: {document_id}") + logger.debug(f"Auth context: entity_type={auth.entity_type}, entity_id={auth.entity_id}, permissions={auth.permissions}") + doc = await document_service.db.get_document(document_id, auth) + + if not doc: + logger.error(f"Document {document_id} not found in database") + logger.error(f"Details - file: {original_filename}, content_type: {content_type}, bucket: {bucket}, key: {file_key}") + logger.error(f"Auth: entity_type={auth.entity_type}, entity_id={auth.entity_id}, permissions={auth.permissions}") + # Try to get all accessible documents to debug + try: + all_docs = await document_service.db.get_documents(auth, 0, 100) + logger.debug(f"User has access to {len(all_docs)} documents: {[d.external_id for d in all_docs]}") + except Exception as list_err: + logger.error(f"Failed to list user documents: {str(list_err)}") + + raise ValueError(f"Document {document_id} not found in database") + + # Prepare updates for the document + updates = { + "metadata": metadata, + "additional_metadata": additional_metadata, + "system_metadata": {**doc.system_metadata, "content": text} + } + + # Add folder_name and end_user_id to system_metadata if provided + if folder_name: + updates["system_metadata"]["folder_name"] = folder_name + if end_user_id: + updates["system_metadata"]["end_user_id"] = end_user_id + + # Update the document in the database + success = await document_service.db.update_document( + document_id=document_id, + updates=updates, + auth=auth + ) + + if not success: + raise ValueError(f"Failed to update document {document_id}") + + # Refresh document object with updated data + doc = await document_service.db.get_document(document_id, auth) + logger.debug(f"Updated document in database with parsed content") + + # 7. Split text into chunks + chunks = await document_service.parser.split_text(text) + if not chunks: + raise ValueError("No content chunks extracted") + logger.debug(f"Split processed text into {len(chunks)} chunks") + + # 8. Generate embeddings for chunks + embeddings = await document_service.embedding_model.embed_for_ingestion(chunks) + logger.debug(f"Generated {len(embeddings)} embeddings") + + # 9. Create chunk objects + chunk_objects = document_service._create_chunk_objects(doc.external_id, chunks, embeddings) + logger.debug(f"Created {len(chunk_objects)} chunk objects") + + # 10. Handle ColPali embeddings if enabled + chunk_objects_multivector = [] + if use_colpali and document_service.colpali_embedding_model and document_service.colpali_vector_store: + import filetype + file_type = filetype.guess(file_content) + + # For ColPali we need the base64 encoding of the file + import base64 + file_content_base64 = base64.b64encode(file_content).decode() + + chunks_multivector = document_service._create_chunks_multivector( + file_type, file_content_base64, file_content, chunks + ) + logger.debug(f"Created {len(chunks_multivector)} chunks for multivector embedding") + + colpali_embeddings = await document_service.colpali_embedding_model.embed_for_ingestion( + chunks_multivector + ) + logger.debug(f"Generated {len(colpali_embeddings)} embeddings for multivector embedding") + + chunk_objects_multivector = document_service._create_chunk_objects( + doc.external_id, chunks_multivector, colpali_embeddings + ) + + # Update document status to completed before storing + doc.system_metadata["status"] = "completed" + doc.system_metadata["updated_at"] = datetime.now(UTC) + + # 11. 
Store chunks and update document with is_update=True + chunk_ids = await document_service._store_chunks_and_doc( + chunk_objects, doc, use_colpali, chunk_objects_multivector, + is_update=True, auth=auth + ) + + logger.debug(f"Successfully completed processing for document {doc.external_id}") + + # 13. Log successful completion + logger.info(f"Successfully completed ingestion for {original_filename}, document ID: {doc.external_id}") + + # 14. Return document ID + return { + "document_id": doc.external_id, + "status": "completed", + "filename": original_filename, + "content_type": content_type, + "timestamp": datetime.now(UTC).isoformat() + } + + except Exception as e: + logger.error(f"Error processing ingestion job for file {original_filename}: {str(e)}") + + # Update document status to failed if the document exists + try: + # Create AuthContext for database operations + auth_context = AuthContext( + entity_type=EntityType(auth_dict.get("entity_type", "unknown")), + entity_id=auth_dict.get("entity_id", ""), + app_id=auth_dict.get("app_id"), + permissions=set(auth_dict.get("permissions", ["read"])), + user_id=auth_dict.get("user_id", auth_dict.get("entity_id", "")) + ) + + # Get database from context + database = ctx.get('database') + + if database: + # Try to get the document + doc = await database.get_document(document_id, auth_context) + + if doc: + # Update the document status to failed + await database.update_document( + document_id=document_id, + updates={ + "system_metadata": { + **doc.system_metadata, + "status": "failed", + "error": str(e), + "updated_at": datetime.now(UTC) + } + }, + auth=auth_context + ) + logger.info(f"Updated document {document_id} status to failed") + except Exception as inner_e: + logger.error(f"Failed to update document status: {str(inner_e)}") + + # Return error information + return { + "status": "failed", + "filename": original_filename, + "error": str(e), + "timestamp": datetime.now(UTC).isoformat() + } + +async def startup(ctx): + """ + Worker startup: Initialize all necessary services that will be reused across jobs. + + This initialization is similar to what happens in core/api.py during app startup, + but adapted for the worker context. + """ + logger.info("Worker starting up. 
Initializing services...") + + # Get settings + settings = get_settings() + + # Initialize database + logger.info("Initializing database...") + database = PostgresDatabase(uri=settings.POSTGRES_URI) + success = await database.initialize() + if success: + logger.info("Database initialization successful") + else: + logger.error("Database initialization failed") + ctx['database'] = database + + # Initialize vector store + logger.info("Initializing primary vector store...") + vector_store = PGVectorStore(uri=settings.POSTGRES_URI) + success = await vector_store.initialize() + if success: + logger.info("Primary vector store initialization successful") + else: + logger.error("Primary vector store initialization failed") + ctx['vector_store'] = vector_store + + # Initialize storage + if settings.STORAGE_PROVIDER == "local": + storage = LocalStorage(storage_path=settings.STORAGE_PATH) + elif settings.STORAGE_PROVIDER == "aws-s3": + storage = S3Storage( + aws_access_key=settings.AWS_ACCESS_KEY, + aws_secret_key=settings.AWS_SECRET_ACCESS_KEY, + region_name=settings.AWS_REGION, + default_bucket=settings.S3_BUCKET, + ) + else: + raise ValueError(f"Unsupported storage provider: {settings.STORAGE_PROVIDER}") + ctx['storage'] = storage + + # Initialize parser + parser = MorphikParser( + chunk_size=settings.CHUNK_SIZE, + chunk_overlap=settings.CHUNK_OVERLAP, + use_unstructured_api=settings.USE_UNSTRUCTURED_API, + unstructured_api_key=settings.UNSTRUCTURED_API_KEY, + assemblyai_api_key=settings.ASSEMBLYAI_API_KEY, + anthropic_api_key=settings.ANTHROPIC_API_KEY, + use_contextual_chunking=settings.USE_CONTEXTUAL_CHUNKING, + ) + ctx['parser'] = parser + + # Initialize embedding model + embedding_model = LiteLLMEmbeddingModel(model_key=settings.EMBEDDING_MODEL) + logger.info(f"Initialized LiteLLM embedding model with model key: {settings.EMBEDDING_MODEL}") + ctx['embedding_model'] = embedding_model + + # Initialize completion model + completion_model = LiteLLMCompletionModel(model_key=settings.COMPLETION_MODEL) + logger.info(f"Initialized LiteLLM completion model with model key: {settings.COMPLETION_MODEL}") + ctx['completion_model'] = completion_model + + # Initialize reranker + reranker = None + if settings.USE_RERANKING: + if settings.RERANKER_PROVIDER == "flag": + from core.reranker.flag_reranker import FlagReranker + reranker = FlagReranker( + model_name=settings.RERANKER_MODEL, + device=settings.RERANKER_DEVICE, + use_fp16=settings.RERANKER_USE_FP16, + query_max_length=settings.RERANKER_QUERY_MAX_LENGTH, + passage_max_length=settings.RERANKER_PASSAGE_MAX_LENGTH, + ) + else: + logger.warning(f"Unsupported reranker provider: {settings.RERANKER_PROVIDER}") + ctx['reranker'] = reranker + + # Initialize ColPali embedding model and vector store if enabled + colpali_embedding_model = None + colpali_vector_store = None + if settings.ENABLE_COLPALI: + logger.info("Initializing ColPali components...") + colpali_embedding_model = ColpaliEmbeddingModel() + colpali_vector_store = MultiVectorStore(uri=settings.POSTGRES_URI) + _ = colpali_vector_store.initialize() + ctx['colpali_embedding_model'] = colpali_embedding_model + ctx['colpali_vector_store'] = colpali_vector_store + + # Initialize cache factory for DocumentService (may not be used for ingestion) + from core.cache.llama_cache_factory import LlamaCacheFactory + cache_factory = LlamaCacheFactory(Path(settings.STORAGE_PATH)) + ctx['cache_factory'] = cache_factory + + # Initialize rules processor + rules_processor = RulesProcessor() + ctx['rules_processor'] = 
rules_processor + + # Initialize telemetry service + telemetry = TelemetryService() + ctx['telemetry'] = telemetry + + # Create the document service using all initialized components + document_service = DocumentService( + storage=storage, + database=database, + vector_store=vector_store, + embedding_model=embedding_model, + completion_model=completion_model, + parser=parser, + reranker=reranker, + cache_factory=cache_factory, + enable_colpali=settings.ENABLE_COLPALI, + colpali_embedding_model=colpali_embedding_model, + colpali_vector_store=colpali_vector_store, + ) + ctx['document_service'] = document_service + + logger.info("Worker startup complete. All services initialized.") + +async def shutdown(ctx): + """ + Worker shutdown: Clean up resources. + + Properly close connections and cleanup resources to prevent leaks. + """ + logger.info("Worker shutting down. Cleaning up resources...") + + # Close database connections + if 'database' in ctx and hasattr(ctx['database'], 'engine'): + logger.info("Closing database connections...") + await ctx['database'].engine.dispose() + + # Close any other open connections or resources that need cleanup + logger.info("Worker shutdown complete.") + +# ARQ Worker Settings +class WorkerSettings: + """ + ARQ Worker settings for the ingestion worker. + + This defines the functions available to the worker, startup and shutdown handlers, + and any specific Redis settings. + """ + functions = [process_ingestion_job] + on_startup = startup + on_shutdown = shutdown + # Redis settings will be loaded from environment variables by default + # Other optional settings: + # redis_settings = arq.connections.RedisSettings(host='localhost', port=6379) + keep_result_ms = 24 * 60 * 60 * 1000 # Keep results for 24 hours (24 * 60 * 60 * 1000 ms) + max_jobs = 10 # Maximum number of jobs to run concurrently \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml index 410a07d..121e21a 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -26,6 +26,8 @@ services: - HOST=0.0.0.0 - PORT=8000 - LOG_LEVEL=DEBUG + - REDIS_HOST=redis + - REDIS_PORT=6379 volumes: - ./storage:/app/storage - ./logs:/app/logs @@ -34,6 +36,8 @@ services: depends_on: postgres: condition: service_healthy + redis: + condition: service_healthy config-check: condition: service_completed_successfully ollama: @@ -43,6 +47,51 @@ services: - morphik-network env_file: - .env + + worker: + build: . 
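+    # Builds from the repository Dockerfile and shares the storage/log volumes;
+    # runs the arq ingestion worker against the redis service defined below.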
+ command: arq core.workers.ingestion_worker.WorkerSettings + environment: + - JWT_SECRET_KEY=${JWT_SECRET_KEY:-your-secret-key-here} + - POSTGRES_URI=postgresql+asyncpg://morphik:morphik@postgres:5432/morphik + - PGPASSWORD=morphik + - LOG_LEVEL=DEBUG + - REDIS_HOST=redis + - REDIS_PORT=6379 + volumes: + - ./storage:/app/storage + - ./logs:/app/logs + - ./morphik.toml:/app/morphik.toml + - huggingface_cache:/root/.cache/huggingface + depends_on: + postgres: + condition: service_healthy + redis: + condition: service_healthy + config-check: + condition: service_completed_successfully + ollama: + condition: service_started + required: false + networks: + - morphik-network + env_file: + - .env + + redis: + image: redis:7-alpine + ports: + - "6379:6379" + volumes: + - redis_data:/data + command: redis-server --appendonly yes + healthcheck: + test: ["CMD", "redis-cli", "ping"] + interval: 5s + timeout: 3s + retries: 5 + networks: + - morphik-network postgres: build: @@ -87,4 +136,5 @@ networks: volumes: postgres_data: ollama_data: - huggingface_cache: + huggingface_cache: + redis_data: diff --git a/morphik.toml b/morphik.toml index 71b86a0..8f099cd 100644 --- a/morphik.toml +++ b/morphik.toml @@ -100,6 +100,10 @@ enable_colpali = true mode = "self_hosted" # "cloud" or "self_hosted" api_domain = "api.morphik.ai" # API domain for cloud URIs +[redis] +host = "localhost" +port = 6379 + [graph] model = "ollama_llama" enable_entity_resolution = true diff --git a/requirements.txt b/requirements.txt index 7800ea7..02518c3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -340,3 +340,4 @@ zlib-state==0.1.9 zstandard==0.23.0 litellm==1.65.4.post1 instructor==1.7.9 +arq==0.25.0 diff --git a/sdks/python/morphik/async_.py b/sdks/python/morphik/async_.py index 7b14816..a73d1b7 100644 --- a/sdks/python/morphik/async_.py +++ b/sdks/python/morphik/async_.py @@ -1490,6 +1490,76 @@ class AsyncMorphik: doc = self._logic._parse_document_response(response) doc._client = self return doc + + async def get_document_status(self, document_id: str) -> Dict[str, Any]: + """ + Get the current processing status of a document. + + Args: + document_id: ID of the document to check + + Returns: + Dict[str, Any]: Status information including current status, potential errors, and other metadata + + Example: + ```python + status = await db.get_document_status("doc_123") + if status["status"] == "completed": + print("Document processing complete") + elif status["status"] == "failed": + print(f"Processing failed: {status['error']}") + else: + print("Document still processing...") + ``` + """ + response = await self._request("GET", f"documents/{document_id}/status") + return response + + async def wait_for_document_completion(self, document_id: str, timeout_seconds=300, check_interval_seconds=2) -> Document: + """ + Wait for a document's processing to complete. + + Args: + document_id: ID of the document to wait for + timeout_seconds: Maximum time to wait for completion (default: 300 seconds) + check_interval_seconds: Time between status checks (default: 2 seconds) + + Returns: + Document: Updated document with the latest status + + Raises: + TimeoutError: If processing doesn't complete within the timeout period + ValueError: If processing fails with an error + + Example: + ```python + # Upload a file and wait for processing to complete + doc = await db.ingest_file("large_document.pdf") + try: + completed_doc = await db.wait_for_document_completion(doc.external_id) + print(f"Processing complete! 
Document has {len(completed_doc.chunk_ids)} chunks") + except TimeoutError: + print("Processing is taking too long") + except ValueError as e: + print(f"Processing failed: {e}") + ``` + """ + import asyncio + start_time = asyncio.get_event_loop().time() + + while (asyncio.get_event_loop().time() - start_time) < timeout_seconds: + status = await self.get_document_status(document_id) + + if status["status"] == "completed": + # Get the full document now that it's complete + return await self.get_document(document_id) + elif status["status"] == "failed": + raise ValueError(f"Document processing failed: {status.get('error', 'Unknown error')}") + + # Wait before checking again + await asyncio.sleep(check_interval_seconds) + + raise TimeoutError(f"Document processing did not complete within {timeout_seconds} seconds") async def get_document_by_filename(self, filename: str) -> Document: """ diff --git a/sdks/python/morphik/models.py b/sdks/python/morphik/models.py index 2ef449f..8b26618 100644 --- a/sdks/python/morphik/models.py +++ b/sdks/python/morphik/models.py @@ -24,6 +24,60 @@ class Document(BaseModel): # Client reference for update methods _client = None + + @property + def status(self) -> Dict[str, Any]: + """Get the latest processing status of the document from the API. + + Returns: + Dict[str, Any]: Status information including current status, potential errors, and other metadata + """ + if self._client is None: + raise ValueError( + "Document instance not connected to a client. Use a document returned from a Morphik client method." + ) + return self._client.get_document_status(self.external_id) + + @property + def is_processing(self) -> bool: + """Check if the document is still being processed.""" + return self.status.get("status") == "processing" + + @property + def is_ingested(self) -> bool: + """Check if the document has completed processing.""" + return self.status.get("status") == "completed" + + @property + def is_failed(self) -> bool: + """Check if document processing has failed.""" + return self.status.get("status") == "failed" + + @property + def error(self) -> Optional[str]: + """Get the error message if processing failed.""" + status_info = self.status + return status_info.get("error") if status_info.get("status") == "failed" else None + + def wait_for_completion(self, timeout_seconds=300, check_interval_seconds=2): + """Wait for document processing to complete. + + Args: + timeout_seconds: Maximum time to wait for completion (default: 300 seconds) + check_interval_seconds: Time between status checks (default: 2 seconds) + + Returns: + Document: Updated document with the latest status + + Raises: + TimeoutError: If processing doesn't complete within the timeout period + ValueError: If processing fails with an error + """ + if self._client is None: + raise ValueError( + "Document instance not connected to a client. Use a document returned from a Morphik client method." + ) + return self._client.wait_for_document_completion(self.external_id, timeout_seconds, check_interval_seconds) def update_with_text( self, diff --git a/sdks/python/morphik/sync.py b/sdks/python/morphik/sync.py index a084983..89e8371 100644 --- a/sdks/python/morphik/sync.py +++ b/sdks/python/morphik/sync.py @@ -1618,6 +1618,76 @@ class Morphik: doc = self._logic._parse_document_response(response) doc._client = self return doc + + def get_document_status(self, document_id: str) -> Dict[str, Any]: + """ + Get the current processing status of a document. 
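+
+        The response mirrors GET /documents/{document_id}/status and contains
+        document_id, status, filename, created_at, updated_at, and (only when the
+        status is "failed") an error message.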
+ + Args: + document_id: ID of the document to check + + Returns: + Dict[str, Any]: Status information including current status, potential errors, and other metadata + + Example: + ```python + status = db.get_document_status("doc_123") + if status["status"] == "completed": + print("Document processing complete") + elif status["status"] == "failed": + print(f"Processing failed: {status['error']}") + else: + print("Document still processing...") + ``` + """ + response = self._request("GET", f"documents/{document_id}/status") + return response + + def wait_for_document_completion(self, document_id: str, timeout_seconds=300, check_interval_seconds=2) -> Document: + """ + Wait for a document's processing to complete. + + Args: + document_id: ID of the document to wait for + timeout_seconds: Maximum time to wait for completion (default: 300 seconds) + check_interval_seconds: Time between status checks (default: 2 seconds) + + Returns: + Document: Updated document with the latest status + + Raises: + TimeoutError: If processing doesn't complete within the timeout period + ValueError: If processing fails with an error + + Example: + ```python + # Upload a file and wait for processing to complete + doc = db.ingest_file("large_document.pdf") + try: + completed_doc = db.wait_for_document_completion(doc.external_id) + print(f"Processing complete! Document has {len(completed_doc.chunk_ids)} chunks") + except TimeoutError: + print("Processing is taking too long") + except ValueError as e: + print(f"Processing failed: {e}") + ``` + """ + import time + start_time = time.time() + + while (time.time() - start_time) < timeout_seconds: + status = self.get_document_status(document_id) + + if status["status"] == "completed": + # Get the full document now that it's complete + return self.get_document(document_id) + elif status["status"] == "failed": + raise ValueError(f"Document processing failed: {status.get('error', 'Unknown error')}") + + # Wait before checking again + time.sleep(check_interval_seconds) + + raise TimeoutError(f"Document processing did not complete within {timeout_seconds} seconds") def get_document_by_filename(self, filename: str) -> Document: """ diff --git a/start_server.py b/start_server.py index f98ff51..7e54318 100644 --- a/start_server.py +++ b/start_server.py @@ -4,10 +4,140 @@ import sys import tomli import requests import logging +import subprocess +import signal +import os +import atexit from dotenv import load_dotenv from core.config import get_settings from core.logging_config import setup_logging +# Global variable to store the worker process +worker_process = None + +def check_and_start_redis(): + """Check if the Redis container is running, start if necessary.""" + try: + # Check if container exists and is running + check_running_cmd = ["docker", "ps", "-q", "-f", "name=morphik-redis"] + running_container = subprocess.check_output(check_running_cmd).strip() + + if running_container: + logging.info("Redis container (morphik-redis) is already running.") + return + + # Check if container exists but is stopped + check_exists_cmd = ["docker", "ps", "-a", "-q", "-f", "name=morphik-redis"] + existing_container = subprocess.check_output(check_exists_cmd).strip() + + if existing_container: + logging.info("Starting existing Redis container (morphik-redis)...") + subprocess.run(["docker", "start", "morphik-redis"], check=True, capture_output=True) + logging.info("Redis container started.") + else: + logging.info("Creating and starting Redis container (morphik-redis)...") + subprocess.run( 
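+                # Run the official redis image detached, publishing the default
+                # port (6379) that REDIS_HOST/REDIS_PORT point at out of the box.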
+ ["docker", "run", "-d", "--name", "morphik-redis", "-p", "6379:6379", "redis"], + check=True, + capture_output=True, + ) + logging.info("Redis container created and started.") + + except subprocess.CalledProcessError as e: + logging.error(f"Failed to manage Redis container: {e}") + logging.error(f"Stderr: {e.stderr.decode() if e.stderr else 'N/A'}") + sys.exit(1) + except FileNotFoundError: + logging.error("Docker command not found. Please ensure Docker is installed and in PATH.") + sys.exit(1) + + +def start_arq_worker(): + """Start the ARQ worker as a subprocess.""" + global worker_process + try: + logging.info("Starting ARQ worker...") + + # Ensure logs directory exists + log_dir = os.path.join(os.getcwd(), "logs") + os.makedirs(log_dir, exist_ok=True) + + # Worker log file paths + worker_log_path = os.path.join(log_dir, "worker.log") + + # Open log files + worker_log = open(worker_log_path, "a") + + # Add timestamp to log + timestamp = subprocess.check_output(["date"]).decode().strip() + worker_log.write(f"\n\n--- Worker started at {timestamp} ---\n\n") + worker_log.flush() + + # Use sys.executable to ensure the same Python environment is used + worker_cmd = [sys.executable, "-m", "arq", "core.workers.ingestion_worker.WorkerSettings"] + + # Start the worker with output redirected to log files + worker_process = subprocess.Popen( + worker_cmd, + stdout=worker_log, + stderr=worker_log, + env=dict(os.environ, PYTHONUNBUFFERED="1") # Ensure unbuffered output + ) + logging.info(f"ARQ worker started with PID: {worker_process.pid}") + logging.info(f"Worker logs available at: {worker_log_path}") + except Exception as e: + logging.error(f"Failed to start ARQ worker: {e}") + sys.exit(1) + + +def cleanup_processes(): + """Stop the ARQ worker process on exit.""" + global worker_process + if worker_process and worker_process.poll() is None: # Check if process is still running + logging.info(f"Stopping ARQ worker (PID: {worker_process.pid})...") + + # Log the worker termination + try: + log_dir = os.path.join(os.getcwd(), "logs") + worker_log_path = os.path.join(log_dir, "worker.log") + + with open(worker_log_path, "a") as worker_log: + timestamp = subprocess.check_output(["date"]).decode().strip() + worker_log.write(f"\n\n--- Worker stopping at {timestamp} ---\n\n") + except Exception as e: + logging.warning(f"Could not write worker stop message to log: {e}") + + # Send SIGTERM first for graceful shutdown + worker_process.terminate() + try: + # Wait a bit for graceful shutdown + worker_process.wait(timeout=5) + logging.info("ARQ worker stopped gracefully.") + except subprocess.TimeoutExpired: + logging.warning("ARQ worker did not terminate gracefully, sending SIGKILL.") + worker_process.kill() # Force kill if it doesn't stop + logging.info("ARQ worker killed.") + + # Close any open file descriptors for the process + if hasattr(worker_process, 'stdout') and worker_process.stdout: + worker_process.stdout.close() + if hasattr(worker_process, 'stderr') and worker_process.stderr: + worker_process.stderr.close() + + # Optional: Add Redis container stop logic here if desired + # try: + # logging.info("Stopping Redis container...") + # subprocess.run(["docker", "stop", "morphik-redis"], check=False, capture_output=True) + # except Exception as e: + # logging.warning(f"Could not stop Redis container: {e}") + + +# Register the cleanup function to be called on script exit +atexit.register(cleanup_processes) +# Also register for SIGINT (Ctrl+C) and SIGTERM +signal.signal(signal.SIGINT, lambda sig, frame: 
sys.exit(0)) +signal.signal(signal.SIGTERM, lambda sig, frame: sys.exit(0)) + def check_ollama_running(base_url): """Check if Ollama is running and accessible at the given URL.""" @@ -98,6 +228,9 @@ def main(): # Set up logging first with specified level setup_logging(log_level=args.log.upper()) + # Check and start Redis container + check_and_start_redis() + # Load environment variables from .env file load_dotenv() @@ -133,14 +266,18 @@ def main(): # Load settings (this will validate all required env vars) settings = get_settings() - # Start server + # Start ARQ worker in the background + start_arq_worker() + + # Start server (this is blocking) + logging.info("Starting Uvicorn server...") uvicorn.run( "core.api:app", host=settings.HOST, port=settings.PORT, loop="asyncio", log_level=args.log, - # reload=settings.RELOAD + # reload=settings.RELOAD # Reload might interfere with subprocess management )
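
With this patch, POST /ingest/file and /ingest/files return immediately with documents whose status is "processing", while the arq worker finishes parsing, embedding, and chunk storage in the background. A minimal client-side sketch of that flow using the synchronous SDK additions above — the connection URI and filename are placeholders, and the constructor call assumes the SDK's standard Morphik(uri) entry point:

```python
from morphik import Morphik

# Placeholder URI; point this at your own deployment.
db = Morphik("morphik://owner_id:token@localhost:8000")

# /ingest/file now returns right away with the document still in "processing".
doc = db.ingest_file("large_document.pdf")

# Poll the new status endpoint...
status = db.get_document_status(doc.external_id)
print(status["status"])  # "processing", "completed", or "failed"

# ...or block until the worker finishes. Raises ValueError if ingestion failed,
# TimeoutError if it takes longer than timeout_seconds.
completed = db.wait_for_document_completion(doc.external_id, timeout_seconds=300)
print(f"{completed.filename} ingested")
```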