add task queue (#87)

* add task queue

* ensure task queuing is working as expected.

* add downstream sdk changes

* fix bugs and address PR comments

* update docker arq running logic
Arnav Agrawal 2025-04-15 23:31:49 -07:00 committed by GitHub
parent 7fe2e19a81
commit 3c1195e001
15 changed files with 1103 additions and 61 deletions


@ -1,5 +1,7 @@
import asyncio
import json
import base64
import uuid
from datetime import datetime, UTC, timedelta
from pathlib import Path
import sys
@ -8,6 +10,7 @@ from fastapi import FastAPI, Form, HTTPException, Depends, Header, UploadFile, F
from fastapi.middleware.cors import CORSMiddleware
import jwt
import logging
import arq
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
from core.limits_utils import check_and_increment_limits
from core.models.request import GenerateUriRequest, RetrieveRequest, CompletionQueryRequest, IngestTextRequest, CreateGraphRequest, UpdateGraphRequest, BatchIngestResponse
@ -79,6 +82,8 @@ if not settings.POSTGRES_URI:
raise ValueError("PostgreSQL URI is required for PostgreSQL database")
database = PostgresDatabase(uri=settings.POSTGRES_URI)
# Redis settings already imported at top of file
@app.on_event("startup")
async def initialize_database():
@ -129,6 +134,37 @@ async def initialize_user_limits_database():
user_limits_db = UserLimitsDatabase(uri=settings.POSTGRES_URI)
await user_limits_db.initialize()
@app.on_event("startup")
async def initialize_redis_pool():
"""Initialize the Redis connection pool for background tasks."""
global redis_pool
logger.info("Initializing Redis connection pool...")
# Get Redis settings from configuration
redis_host = settings.REDIS_HOST
redis_port = settings.REDIS_PORT
# Log the Redis connection details
logger.info(f"Connecting to Redis at {redis_host}:{redis_port}")
redis_settings = arq.connections.RedisSettings(
host=redis_host,
port=redis_port,
)
redis_pool = await arq.create_pool(redis_settings)
logger.info("Redis connection pool initialized successfully")
@app.on_event("shutdown")
async def close_redis_pool():
"""Close the Redis connection pool on application shutdown."""
global redis_pool
if redis_pool:
logger.info("Closing Redis connection pool...")
redis_pool.close()
await redis_pool.wait_closed()
logger.info("Redis connection pool closed")
# Initialize vector store
if not settings.POSTGRES_URI:
raise ValueError("PostgreSQL URI is required for pgvector store")
@ -319,6 +355,13 @@ async def ingest_text(
raise HTTPException(status_code=403, detail=str(e))
# Redis pool for background tasks
redis_pool = None
def get_redis_pool():
"""Get the global Redis connection pool for background tasks."""
return redis_pool
@app.post("/ingest/file", response_model=Document)
async def ingest_file(
file: UploadFile,
@ -328,9 +371,10 @@ async def ingest_file(
use_colpali: Optional[bool] = None,
folder_name: Optional[str] = Form(None),
end_user_id: Optional[str] = Form(None),
redis: arq.ArqRedis = Depends(get_redis_pool),
) -> Document:
"""
Ingest a file document.
Ingest a file document asynchronously.
Args:
file: File to ingest
@ -342,17 +386,23 @@ async def ingest_file(
use_colpali: Whether to use ColPali embedding model
folder_name: Optional folder to scope the document to
end_user_id: Optional end-user ID to scope the document to
redis: Redis connection pool for background tasks
Returns:
Document: Metadata of ingested document
Document with processing status that can be used to check progress
"""
try:
# Parse metadata and rules
metadata_dict = json.loads(metadata)
rules_list = json.loads(rules)
use_colpali = bool(use_colpali)
# Ensure user has write permission
if "write" not in auth.permissions:
raise PermissionError("User does not have write permission")
async with telemetry.track_operation(
operation_type="ingest_file",
operation_type="queue_ingest_file",
user_id=auth.entity_id,
metadata={
"filename": file.filename,
@ -364,21 +414,97 @@ async def ingest_file(
"end_user_id": end_user_id,
},
):
logger.debug(f"API: Ingesting file with use_colpali: {use_colpali}")
logger.debug(f"API: Queueing file ingestion with use_colpali: {use_colpali}")
return await document_service.ingest_file(
file=file,
# Create a document with processing status
doc = Document(
content_type=file.content_type,
filename=file.filename,
metadata=metadata_dict,
auth=auth,
rules=rules_list,
owner={"type": auth.entity_type, "id": auth.entity_id},
access_control={
"readers": [auth.entity_id],
"writers": [auth.entity_id],
"admins": [auth.entity_id],
"user_id": [auth.user_id] if auth.user_id else [],
},
system_metadata={"status": "processing"}
)
# Add folder_name and end_user_id to system_metadata if provided
if folder_name:
doc.system_metadata["folder_name"] = folder_name
if end_user_id:
doc.system_metadata["end_user_id"] = end_user_id
# Set processing status
doc.system_metadata["status"] = "processing"
# Store the document in the database
success = await database.store_document(doc)
if not success:
raise Exception("Failed to store document metadata")
# Read file content
file_content = await file.read()
# Generate a unique key for the file
file_key = f"ingest_uploads/{uuid.uuid4()}/{file.filename}"
# Store the file in the configured storage
file_content_base64 = base64.b64encode(file_content).decode()
bucket, stored_key = await storage.upload_from_base64(
file_content_base64,
file_key,
file.content_type
)
logger.debug(f"Stored file in bucket {bucket} with key {stored_key}")
# Update document with storage info
doc.storage_info = {"bucket": bucket, "key": stored_key}
await database.update_document(
document_id=doc.external_id,
updates={"storage_info": doc.storage_info},
auth=auth
)
# Convert auth context to a dictionary for serialization
auth_dict = {
"entity_type": auth.entity_type.value,
"entity_id": auth.entity_id,
"app_id": auth.app_id,
"permissions": list(auth.permissions),
"user_id": auth.user_id
}
# Enqueue the background job
job = await redis.enqueue_job(
'process_ingestion_job',
document_id=doc.external_id,
file_key=stored_key,
bucket=bucket,
original_filename=file.filename,
content_type=file.content_type,
metadata_json=metadata,
auth_dict=auth_dict,
rules_list=rules_list,
use_colpali=use_colpali,
folder_name=folder_name,
end_user_id=end_user_id,
end_user_id=end_user_id
)
logger.info(f"File ingestion job queued with ID: {job.job_id} for document: {doc.external_id}")
# Return the document with processing status
return doc
except json.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON: {str(e)}")
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
except Exception as e:
logger.error(f"Error queueing file ingestion: {str(e)}")
raise HTTPException(status_code=500, detail=f"Error queueing file ingestion: {str(e)}")
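Since `/ingest/file` now returns as soon as the job is queued, callers poll the new status endpoint instead of waiting on the response. A minimal client-side sketch of that flow follows; the base URL, bearer token, and file name are placeholders, not values from this change.

```python
import time

import requests  # assumes the requests package is available

BASE_URL = "http://localhost:8000"             # placeholder deployment
HEADERS = {"Authorization": "Bearer <token>"}  # placeholder token

# Upload the file; the response is a Document whose system_metadata.status
# starts out as "processing".
with open("report.pdf", "rb") as f:
    resp = requests.post(
        f"{BASE_URL}/ingest/file",
        headers=HEADERS,
        files={"file": ("report.pdf", f, "application/pdf")},
        data={"metadata": "{}", "rules": "[]"},
    )
resp.raise_for_status()
doc = resp.json()

# Poll the status endpoint added in this change until the worker finishes.
while True:
    status = requests.get(
        f"{BASE_URL}/documents/{doc['external_id']}/status", headers=HEADERS
    ).json()
    if status["status"] in ("completed", "failed"):
        break
    time.sleep(2)

print(status)
```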
@app.post("/ingest/files", response_model=BatchIngestResponse)
@ -387,13 +513,14 @@ async def batch_ingest_files(
metadata: str = Form("{}"),
rules: str = Form("[]"),
use_colpali: Optional[bool] = Form(None),
parallel: bool = Form(True),
parallel: Optional[bool] = Form(True),
folder_name: Optional[str] = Form(None),
end_user_id: Optional[str] = Form(None),
auth: AuthContext = Depends(verify_token),
redis: arq.ArqRedis = Depends(get_redis_pool),
) -> BatchIngestResponse:
"""
Batch ingest multiple files.
Batch ingest multiple files using the task queue.
Args:
files: List of files to ingest
@ -402,15 +529,15 @@ async def batch_ingest_files(
- A single list of rules to apply to all files
- A list of rule lists, one per file
use_colpali: Whether to use ColPali-style embedding
parallel: Whether to process files in parallel
folder_name: Optional folder to scope the documents to
end_user_id: Optional end-user ID to scope the documents to
auth: Authentication context
redis: Redis connection pool for background tasks
Returns:
BatchIngestResponse containing:
- documents: List of successfully ingested documents
- errors: List of errors encountered during ingestion
- documents: List of created documents with processing status
- errors: List of errors that occurred during the batch operation
"""
if not files:
raise HTTPException(
@ -421,8 +548,15 @@ async def batch_ingest_files(
try:
metadata_value = json.loads(metadata)
rules_list = json.loads(rules)
use_colpali = bool(use_colpali)
# Ensure user has write permission
if "write" not in auth.permissions:
raise PermissionError("User does not have write permission")
except json.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON: {str(e)}")
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
# Validate metadata if it's a list
if isinstance(metadata_value, list) and len(metadata_value) != len(files):
@ -439,13 +573,19 @@ async def batch_ingest_files(
detail=f"Number of rule lists ({len(rules_list)}) must match number of files ({len(files)})"
)
documents = []
errors = []
# Convert auth context to a dictionary for serialization
auth_dict = {
"entity_type": auth.entity_type.value,
"entity_id": auth.entity_id,
"app_id": auth.app_id,
"permissions": list(auth.permissions),
"user_id": auth.user_id
}
# We'll pass folder_name and end_user_id directly to the ingest_file functions
created_documents = []
async with telemetry.track_operation(
operation_type="batch_ingest",
operation_type="queue_batch_ingest",
user_id=auth.entity_id,
metadata={
"file_count": len(files),
@ -455,54 +595,98 @@ async def batch_ingest_files(
"end_user_id": end_user_id,
},
):
if parallel:
tasks = []
try:
for i, file in enumerate(files):
# Get the metadata and rules for this file
metadata_item = metadata_value[i] if isinstance(metadata_value, list) else metadata_value
file_rules = rules_list[i] if isinstance(rules_list, list) and rules_list and isinstance(rules_list[0], list) else rules_list
task = document_service.ingest_file(
file=file,
# Create a document with processing status
doc = Document(
content_type=file.content_type,
filename=file.filename,
metadata=metadata_item,
auth=auth,
rules=file_rules,
owner={"type": auth.entity_type, "id": auth.entity_id},
access_control={
"readers": [auth.entity_id],
"writers": [auth.entity_id],
"admins": [auth.entity_id],
"user_id": [auth.user_id] if auth.user_id else [],
},
)
# Add folder_name and end_user_id to system_metadata if provided
if folder_name:
doc.system_metadata["folder_name"] = folder_name
if end_user_id:
doc.system_metadata["end_user_id"] = end_user_id
# Set processing status
doc.system_metadata["status"] = "processing"
# Store the document in the database
success = await database.store_document(doc)
if not success:
raise Exception(f"Failed to store document metadata for {file.filename}")
# Read file content
file_content = await file.read()
# Generate a unique key for the file
file_key = f"ingest_uploads/{uuid.uuid4()}/{file.filename}"
# Store the file in the configured storage
file_content_base64 = base64.b64encode(file_content).decode()
bucket, stored_key = await storage.upload_from_base64(
file_content_base64,
file_key,
file.content_type
)
logger.debug(f"Stored file in bucket {bucket} with key {stored_key}")
# Update document with storage info
doc.storage_info = {"bucket": bucket, "key": stored_key}
await database.update_document(
document_id=doc.external_id,
updates={"storage_info": doc.storage_info},
auth=auth
)
# Convert metadata to JSON string for job
metadata_json = json.dumps(metadata_item)
# Enqueue the background job
job = await redis.enqueue_job(
'process_ingestion_job',
document_id=doc.external_id,
file_key=stored_key,
bucket=bucket,
original_filename=file.filename,
content_type=file.content_type,
metadata_json=metadata_json,
auth_dict=auth_dict,
rules_list=file_rules,
use_colpali=use_colpali,
folder_name=folder_name,
end_user_id=end_user_id
)
tasks.append(task)
logger.info(f"File ingestion job queued with ID: {job.job_id} for document: {doc.external_id}")
# Add document to the list
created_documents.append(doc)
# Return information about created documents
return BatchIngestResponse(
documents=created_documents,
errors=[]
)
results = await asyncio.gather(*tasks, return_exceptions=True)
for i, result in enumerate(results):
if isinstance(result, Exception):
errors.append({
"filename": files[i].filename,
"error": str(result)
})
else:
documents.append(result)
else:
for i, file in enumerate(files):
try:
metadata_item = metadata_value[i] if isinstance(metadata_value, list) else metadata_value
file_rules = rules_list[i] if isinstance(rules_list, list) and rules_list and isinstance(rules_list[0], list) else rules_list
doc = await document_service.ingest_file(
file=file,
metadata=metadata_item,
auth=auth,
rules=file_rules,
use_colpali=use_colpali,
folder_name=folder_name,
end_user_id=end_user_id
)
documents.append(doc)
except Exception as e:
errors.append({
"filename": file.filename,
"error": str(e)
})
except Exception as e:
logger.error(f"Error queueing batch file ingestion: {str(e)}")
raise HTTPException(status_code=500, detail=f"Error queueing batch file ingestion: {str(e)}")
return BatchIngestResponse(documents=documents, errors=errors)
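For the batch endpoint, `metadata` and `rules` may be either a single value applied to every file or one entry per file; the handler validates that list lengths match the number of uploads. A hypothetical request shaped that way (file names and rule contents are illustrative only):

```python
import json

import requests  # assumes the requests package is available

BASE_URL = "http://localhost:8000"             # placeholder deployment
HEADERS = {"Authorization": "Bearer <token>"}  # placeholder token

files = [
    ("files", ("a.pdf", open("a.pdf", "rb"), "application/pdf")),
    ("files", ("b.txt", open("b.txt", "rb"), "text/plain")),
]
data = {
    # one metadata dict and one rules list per file
    "metadata": json.dumps([{"source": "crm"}, {"source": "wiki"}]),
    "rules": json.dumps([[], []]),  # rule contents omitted for brevity
}

resp = requests.post(f"{BASE_URL}/ingest/files", headers=HEADERS, files=files, data=data)
resp.raise_for_status()
batch = resp.json()
print([d["external_id"] for d in batch["documents"]], batch["errors"])
```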
@app.post("/retrieve/chunks", response_model=List[ChunkResult])
@ -817,6 +1001,45 @@ async def get_document(document_id: str, auth: AuthContext = Depends(verify_toke
except HTTPException as e:
logger.error(f"Error getting document: {e}")
raise e
@app.get("/documents/{document_id}/status", response_model=Dict[str, Any])
async def get_document_status(document_id: str, auth: AuthContext = Depends(verify_token)):
"""
Get the processing status of a document.
Args:
document_id: ID of the document to check
auth: Authentication context
Returns:
Dict containing status information for the document
"""
try:
doc = await document_service.db.get_document(document_id, auth)
if not doc:
raise HTTPException(status_code=404, detail="Document not found")
# Extract status information
status = doc.system_metadata.get("status", "unknown")
response = {
"document_id": doc.external_id,
"status": status,
"filename": doc.filename,
"created_at": doc.system_metadata.get("created_at"),
"updated_at": doc.system_metadata.get("updated_at"),
}
# Add error information if failed
if status == "failed":
response["error"] = doc.system_metadata.get("error", "Unknown error")
return response
except HTTPException:
raise
except Exception as e:
logger.error(f"Error getting document status: {str(e)}")
raise HTTPException(status_code=500, detail=f"Error getting document status: {str(e)}")
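For reference, the status payload assembled above looks roughly like this (values are made up; `error` only appears when the status is `failed`):

```python
example_status = {
    "document_id": "doc_9f2c1a",
    "status": "failed",  # one of "processing", "completed", "failed"
    "filename": "report.pdf",
    "created_at": "2025-04-15T23:31:49+00:00",
    "updated_at": "2025-04-15T23:35:02+00:00",
    "error": "No content chunks extracted",
}
```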
@app.delete("/documents/{document_id}")


@ -95,6 +95,10 @@ class Settings(BaseSettings):
# API configuration
API_DOMAIN: str = "api.morphik.ai"
# Redis configuration
REDIS_HOST: str = "localhost"
REDIS_PORT: int = 6379
# Telemetry configuration
TELEMETRY_ENABLED: bool = True
HONEYCOMB_ENABLED: bool = True
@ -268,6 +272,14 @@ def get_settings() -> Settings:
"MODE": config["morphik"].get("mode", "cloud"), # Default to "cloud" mode
"API_DOMAIN": config["morphik"].get("api_domain", "api.morphik.ai"), # Default API domain
}
# load redis config
redis_config = {}
if "redis" in config:
redis_config = {
"REDIS_HOST": config["redis"].get("host", "localhost"),
"REDIS_PORT": int(config["redis"].get("port", 6379)),
}
# load graph config
graph_config = {
@ -309,6 +321,7 @@ def get_settings() -> Settings:
vector_store_config,
rules_config,
morphik_config,
redis_config,
graph_config,
telemetry_config,
openai_config,


@ -46,6 +46,7 @@ class Document(BaseModel):
"version": 1,
"folder_name": None,
"end_user_id": None,
"status": "processing", # Status can be: processing, completed, failed
}
)
"""metadata such as creation date etc."""


@ -95,6 +95,12 @@ class BatchIngestResponse(BaseModel):
"""Response model for batch ingestion"""
documents: List[Document]
errors: List[Dict[str, str]]
class BatchIngestJobResponse(BaseModel):
"""Response model for batch ingestion jobs"""
status: str = Field(..., description="Status of the batch operation")
documents: List[Document] = Field(..., description="List of created documents with processing status")
timestamp: str = Field(..., description="ISO-formatted timestamp")
class GenerateUriRequest(BaseModel):


@ -8,7 +8,7 @@ class BaseParser(ABC):
@abstractmethod
async def parse_file_to_text(
self, file: bytes, content_type: str, filename: str
self, file: bytes, filename: str
) -> Tuple[Dict[str, Any], str]:
"""
Parse file content into text.


@ -485,7 +485,8 @@ class DocumentService:
logger.debug(f"Successfully stored text document {doc.external_id}")
return doc
# TODO: check whether this method is still used; if not, remove it.
async def ingest_file(
self,
file: UploadFile,

core/workers/__init__.py (new file)

@ -0,0 +1 @@
# Workers package for background job processing


@ -0,0 +1,411 @@
import json
import logging
from typing import Dict, Any, List, Optional
from datetime import datetime, UTC
from pathlib import Path
import arq
from core.models.auth import AuthContext, EntityType
from core.models.documents import Document
from core.database.postgres_database import PostgresDatabase
from core.vector_store.pgvector_store import PGVectorStore
from core.parser.morphik_parser import MorphikParser
from core.embedding.litellm_embedding import LiteLLMEmbeddingModel
from core.completion.litellm_completion import LiteLLMCompletionModel
from core.storage.local_storage import LocalStorage
from core.storage.s3_storage import S3Storage
from core.embedding.colpali_embedding_model import ColpaliEmbeddingModel
from core.vector_store.multi_vector_store import MultiVectorStore
from core.services.document_service import DocumentService
from core.services.telemetry import TelemetryService
from core.services.rules_processor import RulesProcessor
from core.config import get_settings
logger = logging.getLogger(__name__)
async def process_ingestion_job(
ctx: Dict[str, Any],
document_id: str,
file_key: str,
bucket: str,
original_filename: str,
content_type: str,
metadata_json: str,
auth_dict: Dict[str, Any],
rules_list: List[Dict[str, Any]],
use_colpali: bool,
folder_name: Optional[str] = None,
end_user_id: Optional[str] = None
) -> Dict[str, Any]:
"""
Background worker task that processes file ingestion jobs.
Args:
ctx: The ARQ context dictionary
document_id: ID of the document created during upload
file_key: The storage key where the file is stored
bucket: The storage bucket name
original_filename: The original file name
content_type: The file's content type/MIME type
metadata_json: JSON string of metadata
auth_dict: Dict representation of AuthContext
rules_list: List of rules to apply (already converted to dictionaries)
use_colpali: Whether to use ColPali embedding model
folder_name: Optional folder to scope the document to
end_user_id: Optional end-user ID to scope the document to
Returns:
A dictionary with the document ID and processing status
"""
try:
# 1. Log the start of the job
logger.info(f"Starting ingestion job for file: {original_filename}")
# 2. Deserialize metadata and auth
metadata = json.loads(metadata_json) if metadata_json else {}
auth = AuthContext(
entity_type=EntityType(auth_dict.get("entity_type", "unknown")),
entity_id=auth_dict.get("entity_id", ""),
app_id=auth_dict.get("app_id"),
permissions=set(auth_dict.get("permissions", ["read"])),
user_id=auth_dict.get("user_id", auth_dict.get("entity_id", ""))
)
# Get document service from the context
document_service: DocumentService = ctx['document_service']
# 3. Download the file from storage
logger.info(f"Downloading file from {bucket}/{file_key}")
file_content = await document_service.storage.download_file(bucket, file_key)
# Ensure file_content is bytes
if hasattr(file_content, 'read'):
file_content = file_content.read()
# 4. Parse file to text
additional_metadata, text = await document_service.parser.parse_file_to_text(
file_content, original_filename
)
logger.debug(f"Parsed file into text of length {len(text)}")
# 5. Apply rules if provided
if rules_list:
rule_metadata, modified_text = await document_service.rules_processor.process_rules(text, rules_list)
# Update document metadata with extracted metadata from rules
metadata.update(rule_metadata)
if modified_text:
text = modified_text
logger.info("Updated text with modified content from rules")
# 6. Retrieve the existing document
logger.debug(f"Retrieving document with ID: {document_id}")
logger.debug(f"Auth context: entity_type={auth.entity_type}, entity_id={auth.entity_id}, permissions={auth.permissions}")
doc = await document_service.db.get_document(document_id, auth)
if not doc:
logger.error(f"Document {document_id} not found in database")
logger.error(f"Details - file: {original_filename}, content_type: {content_type}, bucket: {bucket}, key: {file_key}")
logger.error(f"Auth: entity_type={auth.entity_type}, entity_id={auth.entity_id}, permissions={auth.permissions}")
# Try to get all accessible documents to debug
try:
all_docs = await document_service.db.get_documents(auth, 0, 100)
logger.debug(f"User has access to {len(all_docs)} documents: {[d.external_id for d in all_docs]}")
except Exception as list_err:
logger.error(f"Failed to list user documents: {str(list_err)}")
raise ValueError(f"Document {document_id} not found in database")
# Prepare updates for the document
updates = {
"metadata": metadata,
"additional_metadata": additional_metadata,
"system_metadata": {**doc.system_metadata, "content": text}
}
# Add folder_name and end_user_id to system_metadata if provided
if folder_name:
updates["system_metadata"]["folder_name"] = folder_name
if end_user_id:
updates["system_metadata"]["end_user_id"] = end_user_id
# Update the document in the database
success = await document_service.db.update_document(
document_id=document_id,
updates=updates,
auth=auth
)
if not success:
raise ValueError(f"Failed to update document {document_id}")
# Refresh document object with updated data
doc = await document_service.db.get_document(document_id, auth)
logger.debug(f"Updated document in database with parsed content")
# 7. Split text into chunks
chunks = await document_service.parser.split_text(text)
if not chunks:
raise ValueError("No content chunks extracted")
logger.debug(f"Split processed text into {len(chunks)} chunks")
# 8. Generate embeddings for chunks
embeddings = await document_service.embedding_model.embed_for_ingestion(chunks)
logger.debug(f"Generated {len(embeddings)} embeddings")
# 9. Create chunk objects
chunk_objects = document_service._create_chunk_objects(doc.external_id, chunks, embeddings)
logger.debug(f"Created {len(chunk_objects)} chunk objects")
# 10. Handle ColPali embeddings if enabled
chunk_objects_multivector = []
if use_colpali and document_service.colpali_embedding_model and document_service.colpali_vector_store:
import filetype
file_type = filetype.guess(file_content)
# For ColPali we need the base64 encoding of the file
import base64
file_content_base64 = base64.b64encode(file_content).decode()
chunks_multivector = document_service._create_chunks_multivector(
file_type, file_content_base64, file_content, chunks
)
logger.debug(f"Created {len(chunks_multivector)} chunks for multivector embedding")
colpali_embeddings = await document_service.colpali_embedding_model.embed_for_ingestion(
chunks_multivector
)
logger.debug(f"Generated {len(colpali_embeddings)} embeddings for multivector embedding")
chunk_objects_multivector = document_service._create_chunk_objects(
doc.external_id, chunks_multivector, colpali_embeddings
)
# Update document status to completed before storing
doc.system_metadata["status"] = "completed"
doc.system_metadata["updated_at"] = datetime.now(UTC)
# 11. Store chunks and update document with is_update=True
chunk_ids = await document_service._store_chunks_and_doc(
chunk_objects, doc, use_colpali, chunk_objects_multivector,
is_update=True, auth=auth
)
logger.debug(f"Successfully completed processing for document {doc.external_id}")
# 12. Log successful completion
logger.info(f"Successfully completed ingestion for {original_filename}, document ID: {doc.external_id}")
# 13. Return document ID
return {
"document_id": doc.external_id,
"status": "completed",
"filename": original_filename,
"content_type": content_type,
"timestamp": datetime.now(UTC).isoformat()
}
except Exception as e:
logger.error(f"Error processing ingestion job for file {original_filename}: {str(e)}")
# Update document status to failed if the document exists
try:
# Create AuthContext for database operations
auth_context = AuthContext(
entity_type=EntityType(auth_dict.get("entity_type", "unknown")),
entity_id=auth_dict.get("entity_id", ""),
app_id=auth_dict.get("app_id"),
permissions=set(auth_dict.get("permissions", ["read"])),
user_id=auth_dict.get("user_id", auth_dict.get("entity_id", ""))
)
# Get database from context
database = ctx.get('database')
if database:
# Try to get the document
doc = await database.get_document(document_id, auth_context)
if doc:
# Update the document status to failed
await database.update_document(
document_id=document_id,
updates={
"system_metadata": {
**doc.system_metadata,
"status": "failed",
"error": str(e),
"updated_at": datetime.now(UTC)
}
},
auth=auth_context
)
logger.info(f"Updated document {document_id} status to failed")
except Exception as inner_e:
logger.error(f"Failed to update document status: {str(inner_e)}")
# Return error information
return {
"status": "failed",
"filename": original_filename,
"error": str(e),
"timestamp": datetime.now(UTC).isoformat()
}
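Because `WorkerSettings.keep_result_ms` below retains job results, the dictionary returned here can also be read back through arq's job API, independently of the document status endpoint. A small sketch, assuming a Redis instance on localhost and an already-known job ID:

```python
import asyncio

import arq
from arq.jobs import Job


async def check_job(job_id: str) -> None:
    # Connect to the same Redis the worker uses (host/port are assumptions).
    redis = await arq.create_pool(arq.connections.RedisSettings(host="localhost", port=6379))
    job = Job(job_id, redis=redis)
    print(await job.status())   # e.g. JobStatus.complete
    print(await job.result())   # the dict returned by process_ingestion_job


# asyncio.run(check_job("<job-id returned by enqueue_job>"))
```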
async def startup(ctx):
"""
Worker startup: Initialize all necessary services that will be reused across jobs.
This initialization is similar to what happens in core/api.py during app startup,
but adapted for the worker context.
"""
logger.info("Worker starting up. Initializing services...")
# Get settings
settings = get_settings()
# Initialize database
logger.info("Initializing database...")
database = PostgresDatabase(uri=settings.POSTGRES_URI)
success = await database.initialize()
if success:
logger.info("Database initialization successful")
else:
logger.error("Database initialization failed")
ctx['database'] = database
# Initialize vector store
logger.info("Initializing primary vector store...")
vector_store = PGVectorStore(uri=settings.POSTGRES_URI)
success = await vector_store.initialize()
if success:
logger.info("Primary vector store initialization successful")
else:
logger.error("Primary vector store initialization failed")
ctx['vector_store'] = vector_store
# Initialize storage
if settings.STORAGE_PROVIDER == "local":
storage = LocalStorage(storage_path=settings.STORAGE_PATH)
elif settings.STORAGE_PROVIDER == "aws-s3":
storage = S3Storage(
aws_access_key=settings.AWS_ACCESS_KEY,
aws_secret_key=settings.AWS_SECRET_ACCESS_KEY,
region_name=settings.AWS_REGION,
default_bucket=settings.S3_BUCKET,
)
else:
raise ValueError(f"Unsupported storage provider: {settings.STORAGE_PROVIDER}")
ctx['storage'] = storage
# Initialize parser
parser = MorphikParser(
chunk_size=settings.CHUNK_SIZE,
chunk_overlap=settings.CHUNK_OVERLAP,
use_unstructured_api=settings.USE_UNSTRUCTURED_API,
unstructured_api_key=settings.UNSTRUCTURED_API_KEY,
assemblyai_api_key=settings.ASSEMBLYAI_API_KEY,
anthropic_api_key=settings.ANTHROPIC_API_KEY,
use_contextual_chunking=settings.USE_CONTEXTUAL_CHUNKING,
)
ctx['parser'] = parser
# Initialize embedding model
embedding_model = LiteLLMEmbeddingModel(model_key=settings.EMBEDDING_MODEL)
logger.info(f"Initialized LiteLLM embedding model with model key: {settings.EMBEDDING_MODEL}")
ctx['embedding_model'] = embedding_model
# Initialize completion model
completion_model = LiteLLMCompletionModel(model_key=settings.COMPLETION_MODEL)
logger.info(f"Initialized LiteLLM completion model with model key: {settings.COMPLETION_MODEL}")
ctx['completion_model'] = completion_model
# Initialize reranker
reranker = None
if settings.USE_RERANKING:
if settings.RERANKER_PROVIDER == "flag":
from core.reranker.flag_reranker import FlagReranker
reranker = FlagReranker(
model_name=settings.RERANKER_MODEL,
device=settings.RERANKER_DEVICE,
use_fp16=settings.RERANKER_USE_FP16,
query_max_length=settings.RERANKER_QUERY_MAX_LENGTH,
passage_max_length=settings.RERANKER_PASSAGE_MAX_LENGTH,
)
else:
logger.warning(f"Unsupported reranker provider: {settings.RERANKER_PROVIDER}")
ctx['reranker'] = reranker
# Initialize ColPali embedding model and vector store if enabled
colpali_embedding_model = None
colpali_vector_store = None
if settings.ENABLE_COLPALI:
logger.info("Initializing ColPali components...")
colpali_embedding_model = ColpaliEmbeddingModel()
colpali_vector_store = MultiVectorStore(uri=settings.POSTGRES_URI)
_ = colpali_vector_store.initialize()
ctx['colpali_embedding_model'] = colpali_embedding_model
ctx['colpali_vector_store'] = colpali_vector_store
# Initialize cache factory for DocumentService (may not be used for ingestion)
from core.cache.llama_cache_factory import LlamaCacheFactory
cache_factory = LlamaCacheFactory(Path(settings.STORAGE_PATH))
ctx['cache_factory'] = cache_factory
# Initialize rules processor
rules_processor = RulesProcessor()
ctx['rules_processor'] = rules_processor
# Initialize telemetry service
telemetry = TelemetryService()
ctx['telemetry'] = telemetry
# Create the document service using all initialized components
document_service = DocumentService(
storage=storage,
database=database,
vector_store=vector_store,
embedding_model=embedding_model,
completion_model=completion_model,
parser=parser,
reranker=reranker,
cache_factory=cache_factory,
enable_colpali=settings.ENABLE_COLPALI,
colpali_embedding_model=colpali_embedding_model,
colpali_vector_store=colpali_vector_store,
)
ctx['document_service'] = document_service
logger.info("Worker startup complete. All services initialized.")
async def shutdown(ctx):
"""
Worker shutdown: Clean up resources.
Properly close connections and cleanup resources to prevent leaks.
"""
logger.info("Worker shutting down. Cleaning up resources...")
# Close database connections
if 'database' in ctx and hasattr(ctx['database'], 'engine'):
logger.info("Closing database connections...")
await ctx['database'].engine.dispose()
# Close any other open connections or resources that need cleanup
logger.info("Worker shutdown complete.")
# ARQ Worker Settings
class WorkerSettings:
"""
ARQ Worker settings for the ingestion worker.
This defines the functions available to the worker, startup and shutdown handlers,
and any specific Redis settings.
"""
functions = [process_ingestion_job]
on_startup = startup
on_shutdown = shutdown
# Redis settings will be loaded from environment variables by default
# Other optional settings:
# redis_settings = arq.connections.RedisSettings(host='localhost', port=6379)
keep_result_ms = 24 * 60 * 60 * 1000 # Keep results for 24 hours (24 * 60 * 60 * 1000 ms)
max_jobs = 10 # Maximum number of jobs to run concurrently
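One possible refinement (not part of this commit) would be to point the worker at the same `[redis]` configuration the API reads, rather than relying on arq's environment-variable defaults. A sketch, assuming `get_settings()` can be called in the worker's environment:

```python
from arq.connections import RedisSettings

from core.config import get_settings
from core.workers.ingestion_worker import WorkerSettings

_settings = get_settings()


class ConfiguredWorkerSettings(WorkerSettings):
    # arq reads redis_settings as a class attribute when the worker starts.
    redis_settings = RedisSettings(host=_settings.REDIS_HOST, port=_settings.REDIS_PORT)
```

The worker would then be launched with `arq <module>.ConfiguredWorkerSettings` instead of the path used in docker-compose.yml.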


@ -26,6 +26,8 @@ services:
- HOST=0.0.0.0
- PORT=8000
- LOG_LEVEL=DEBUG
- REDIS_HOST=redis
- REDIS_PORT=6379
volumes:
- ./storage:/app/storage
- ./logs:/app/logs
@ -34,6 +36,8 @@ services:
depends_on:
postgres:
condition: service_healthy
redis:
condition: service_healthy
config-check:
condition: service_completed_successfully
ollama:
@ -43,6 +47,51 @@ services:
- morphik-network
env_file:
- .env
worker:
build: .
command: arq core.workers.ingestion_worker.WorkerSettings
environment:
- JWT_SECRET_KEY=${JWT_SECRET_KEY:-your-secret-key-here}
- POSTGRES_URI=postgresql+asyncpg://morphik:morphik@postgres:5432/morphik
- PGPASSWORD=morphik
- LOG_LEVEL=DEBUG
- REDIS_HOST=redis
- REDIS_PORT=6379
volumes:
- ./storage:/app/storage
- ./logs:/app/logs
- ./morphik.toml:/app/morphik.toml
- huggingface_cache:/root/.cache/huggingface
depends_on:
postgres:
condition: service_healthy
redis:
condition: service_healthy
config-check:
condition: service_completed_successfully
ollama:
condition: service_started
required: false
networks:
- morphik-network
env_file:
- .env
redis:
image: redis:7-alpine
ports:
- "6379:6379"
volumes:
- redis_data:/data
command: redis-server --appendonly yes
healthcheck:
test: ["CMD", "redis-cli", "ping"]
interval: 5s
timeout: 3s
retries: 5
networks:
- morphik-network
postgres:
build:
@ -87,4 +136,5 @@ networks:
volumes:
postgres_data:
ollama_data:
huggingface_cache:
huggingface_cache:
redis_data:


@ -100,6 +100,10 @@ enable_colpali = true
mode = "self_hosted" # "cloud" or "self_hosted"
api_domain = "api.morphik.ai" # API domain for cloud URIs
[redis]
host = "localhost"
port = 6379
[graph]
model = "ollama_llama"
enable_entity_resolution = true


@ -340,3 +340,4 @@ zlib-state==0.1.9
zstandard==0.23.0
litellm==1.65.4.post1
instructor==1.7.9
arq==0.25.0


@ -1490,6 +1490,76 @@ class AsyncMorphik:
doc = self._logic._parse_document_response(response)
doc._client = self
return doc
async def get_document_status(self, document_id: str) -> Dict[str, Any]:
"""
Get the current processing status of a document.
Args:
document_id: ID of the document to check
Returns:
Dict[str, Any]: Status information including current status, potential errors, and other metadata
Example:
```python
status = await db.get_document_status("doc_123")
if status["status"] == "completed":
print("Document processing complete")
elif status["status"] == "failed":
print(f"Processing failed: {status['error']}")
else:
print("Document still processing...")
```
"""
response = await self._request("GET", f"documents/{document_id}/status")
return response
async def wait_for_document_completion(self, document_id: str, timeout_seconds=300, check_interval_seconds=2) -> Document:
"""
Wait for a document's processing to complete.
Args:
document_id: ID of the document to wait for
timeout_seconds: Maximum time to wait for completion (default: 300 seconds)
check_interval_seconds: Time between status checks (default: 2 seconds)
Returns:
Document: Updated document with the latest status
Raises:
TimeoutError: If processing doesn't complete within the timeout period
ValueError: If processing fails with an error
Example:
```python
# Upload a file and wait for processing to complete
doc = await db.ingest_file("large_document.pdf")
try:
completed_doc = await db.wait_for_document_completion(doc.external_id)
print(f"Processing complete! Document has {len(completed_doc.chunk_ids)} chunks")
except TimeoutError:
print("Processing is taking too long")
except ValueError as e:
print(f"Processing failed: {e}")
```
"""
import asyncio
start_time = asyncio.get_event_loop().time()
while (asyncio.get_event_loop().time() - start_time) < timeout_seconds:
status = await self.get_document_status(document_id)
if status["status"] == "completed":
# Get the full document now that it's complete
return await self.get_document(document_id)
elif status["status"] == "failed":
raise ValueError(f"Document processing failed: {status.get('error', 'Unknown error')}")
# Wait before checking again
await asyncio.sleep(check_interval_seconds)
raise TimeoutError(f"Document processing did not complete within {timeout_seconds} seconds")
async def get_document_by_filename(self, filename: str) -> Document:
"""


@ -24,6 +24,60 @@ class Document(BaseModel):
# Client reference for update methods
_client = None
@property
def status(self) -> Dict[str, Any]:
"""Get the latest processing status of the document from the API.
Returns:
Dict[str, Any]: Status information including current status, potential errors, and other metadata
"""
if self._client is None:
raise ValueError(
"Document instance not connected to a client. Use a document returned from a Morphik client method."
)
return self._client.get_document_status(self.external_id)
@property
def is_processing(self) -> bool:
"""Check if the document is still being processed."""
return self.status.get("status") == "processing"
@property
def is_ingested(self) -> bool:
"""Check if the document has completed processing."""
return self.status.get("status") == "completed"
@property
def is_failed(self) -> bool:
"""Check if document processing has failed."""
return self.status.get("status") == "failed"
@property
def error(self) -> Optional[str]:
"""Get the error message if processing failed."""
status_info = self.status
return status_info.get("error") if status_info.get("status") == "failed" else None
def wait_for_completion(self, timeout_seconds=300, check_interval_seconds=2):
"""Wait for document processing to complete.
Args:
timeout_seconds: Maximum time to wait for completion (default: 300 seconds)
check_interval_seconds: Time between status checks (default: 2 seconds)
Returns:
Document: Updated document with the latest status
Raises:
TimeoutError: If processing doesn't complete within the timeout period
ValueError: If processing fails with an error
"""
if self._client is None:
raise ValueError(
"Document instance not connected to a client. Use a document returned from a Morphik client method."
)
return self._client.wait_for_document_completion(self.external_id, timeout_seconds, check_interval_seconds)
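A hypothetical end-to-end use of these helpers with the synchronous client; the import path, connection URI, and file name are assumptions, not part of this diff:

```python
from morphik import Morphik  # assumed import path for the SDK client

db = Morphik("morphik://<owner>:<token>@localhost:8000")  # placeholder URI

doc = db.ingest_file("contract.pdf")  # returns immediately with status "processing"

if doc.is_processing:
    doc = doc.wait_for_completion(timeout_seconds=120)

if doc.is_failed:
    print(f"Ingestion failed: {doc.error}")
else:
    print(f"Document {doc.external_id} ingested")
```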
def update_with_text(
self,


@ -1618,6 +1618,76 @@ class Morphik:
doc = self._logic._parse_document_response(response)
doc._client = self
return doc
def get_document_status(self, document_id: str) -> Dict[str, Any]:
"""
Get the current processing status of a document.
Args:
document_id: ID of the document to check
Returns:
Dict[str, Any]: Status information including current status, potential errors, and other metadata
Example:
```python
status = db.get_document_status("doc_123")
if status["status"] == "completed":
print("Document processing complete")
elif status["status"] == "failed":
print(f"Processing failed: {status['error']}")
else:
print("Document still processing...")
```
"""
response = self._request("GET", f"documents/{document_id}/status")
return response
def wait_for_document_completion(self, document_id: str, timeout_seconds=300, check_interval_seconds=2) -> Document:
"""
Wait for a document's processing to complete.
Args:
document_id: ID of the document to wait for
timeout_seconds: Maximum time to wait for completion (default: 300 seconds)
check_interval_seconds: Time between status checks (default: 2 seconds)
Returns:
Document: Updated document with the latest status
Raises:
TimeoutError: If processing doesn't complete within the timeout period
ValueError: If processing fails with an error
Example:
```python
# Upload a file and wait for processing to complete
doc = db.ingest_file("large_document.pdf")
try:
completed_doc = db.wait_for_document_completion(doc.external_id)
print(f"Processing complete! Document has {len(completed_doc.chunk_ids)} chunks")
except TimeoutError:
print("Processing is taking too long")
except ValueError as e:
print(f"Processing failed: {e}")
```
"""
import time
start_time = time.time()
while (time.time() - start_time) < timeout_seconds:
status = self.get_document_status(document_id)
if status["status"] == "completed":
# Get the full document now that it's complete
return self.get_document(document_id)
elif status["status"] == "failed":
raise ValueError(f"Document processing failed: {status.get('error', 'Unknown error')}")
# Wait before checking again
time.sleep(check_interval_seconds)
raise TimeoutError(f"Document processing did not complete within {timeout_seconds} seconds")
def get_document_by_filename(self, filename: str) -> Document:
"""


@ -4,10 +4,140 @@ import sys
import tomli
import requests
import logging
import subprocess
import signal
import os
import atexit
from dotenv import load_dotenv
from core.config import get_settings
from core.logging_config import setup_logging
# Global variable to store the worker process
worker_process = None
def check_and_start_redis():
"""Check if the Redis container is running, start if necessary."""
try:
# Check if container exists and is running
check_running_cmd = ["docker", "ps", "-q", "-f", "name=morphik-redis"]
running_container = subprocess.check_output(check_running_cmd).strip()
if running_container:
logging.info("Redis container (morphik-redis) is already running.")
return
# Check if container exists but is stopped
check_exists_cmd = ["docker", "ps", "-a", "-q", "-f", "name=morphik-redis"]
existing_container = subprocess.check_output(check_exists_cmd).strip()
if existing_container:
logging.info("Starting existing Redis container (morphik-redis)...")
subprocess.run(["docker", "start", "morphik-redis"], check=True, capture_output=True)
logging.info("Redis container started.")
else:
logging.info("Creating and starting Redis container (morphik-redis)...")
subprocess.run(
["docker", "run", "-d", "--name", "morphik-redis", "-p", "6379:6379", "redis"],
check=True,
capture_output=True,
)
logging.info("Redis container created and started.")
except subprocess.CalledProcessError as e:
logging.error(f"Failed to manage Redis container: {e}")
logging.error(f"Stderr: {e.stderr.decode() if e.stderr else 'N/A'}")
sys.exit(1)
except FileNotFoundError:
logging.error("Docker command not found. Please ensure Docker is installed and in PATH.")
sys.exit(1)
def start_arq_worker():
"""Start the ARQ worker as a subprocess."""
global worker_process
try:
logging.info("Starting ARQ worker...")
# Ensure logs directory exists
log_dir = os.path.join(os.getcwd(), "logs")
os.makedirs(log_dir, exist_ok=True)
# Worker log file paths
worker_log_path = os.path.join(log_dir, "worker.log")
# Open log files
worker_log = open(worker_log_path, "a")
# Add timestamp to log
timestamp = subprocess.check_output(["date"]).decode().strip()
worker_log.write(f"\n\n--- Worker started at {timestamp} ---\n\n")
worker_log.flush()
# Use sys.executable to ensure the same Python environment is used
worker_cmd = [sys.executable, "-m", "arq", "core.workers.ingestion_worker.WorkerSettings"]
# Start the worker with output redirected to log files
worker_process = subprocess.Popen(
worker_cmd,
stdout=worker_log,
stderr=worker_log,
env=dict(os.environ, PYTHONUNBUFFERED="1") # Ensure unbuffered output
)
logging.info(f"ARQ worker started with PID: {worker_process.pid}")
logging.info(f"Worker logs available at: {worker_log_path}")
except Exception as e:
logging.error(f"Failed to start ARQ worker: {e}")
sys.exit(1)
def cleanup_processes():
"""Stop the ARQ worker process on exit."""
global worker_process
if worker_process and worker_process.poll() is None: # Check if process is still running
logging.info(f"Stopping ARQ worker (PID: {worker_process.pid})...")
# Log the worker termination
try:
log_dir = os.path.join(os.getcwd(), "logs")
worker_log_path = os.path.join(log_dir, "worker.log")
with open(worker_log_path, "a") as worker_log:
timestamp = subprocess.check_output(["date"]).decode().strip()
worker_log.write(f"\n\n--- Worker stopping at {timestamp} ---\n\n")
except Exception as e:
logging.warning(f"Could not write worker stop message to log: {e}")
# Send SIGTERM first for graceful shutdown
worker_process.terminate()
try:
# Wait a bit for graceful shutdown
worker_process.wait(timeout=5)
logging.info("ARQ worker stopped gracefully.")
except subprocess.TimeoutExpired:
logging.warning("ARQ worker did not terminate gracefully, sending SIGKILL.")
worker_process.kill() # Force kill if it doesn't stop
logging.info("ARQ worker killed.")
# Close any open file descriptors for the process
if hasattr(worker_process, 'stdout') and worker_process.stdout:
worker_process.stdout.close()
if hasattr(worker_process, 'stderr') and worker_process.stderr:
worker_process.stderr.close()
# Optional: Add Redis container stop logic here if desired
# try:
# logging.info("Stopping Redis container...")
# subprocess.run(["docker", "stop", "morphik-redis"], check=False, capture_output=True)
# except Exception as e:
# logging.warning(f"Could not stop Redis container: {e}")
# Register the cleanup function to be called on script exit
atexit.register(cleanup_processes)
# Also register for SIGINT (Ctrl+C) and SIGTERM
signal.signal(signal.SIGINT, lambda sig, frame: sys.exit(0))
signal.signal(signal.SIGTERM, lambda sig, frame: sys.exit(0))
def check_ollama_running(base_url):
"""Check if Ollama is running and accessible at the given URL."""
@ -98,6 +228,9 @@ def main():
# Set up logging first with specified level
setup_logging(log_level=args.log.upper())
# Check and start Redis container
check_and_start_redis()
# Load environment variables from .env file
load_dotenv()
@ -133,14 +266,18 @@ def main():
# Load settings (this will validate all required env vars)
settings = get_settings()
# Start server
# Start ARQ worker in the background
start_arq_worker()
# Start server (this is blocking)
logging.info("Starting Uvicorn server...")
uvicorn.run(
"core.api:app",
host=settings.HOST,
port=settings.PORT,
loop="asyncio",
log_level=args.log,
# reload=settings.RELOAD
# reload=settings.RELOAD # Reload might interfere with subprocess management
)