Add pooling changes for scalable ingestion (#90)

Arnav Agrawal 2025-04-16 21:45:34 -07:00 committed by GitHub
parent 2b1c253bc1
commit 4ffe0b3bac
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 327 additions and 47 deletions

View File

@@ -396,7 +396,9 @@ async def ingest_file(
     # Parse metadata and rules
     metadata_dict = json.loads(metadata)
     rules_list = json.loads(rules)
-    use_colpali = bool(use_colpali)
+    # Fix bool conversion: ensure string "false" is properly converted to False
+    def str2bool(v): return v if isinstance(v, bool) else str(v).lower() in {"true", "1", "yes"}
+    use_colpali = str2bool(use_colpali)

     # Ensure user has write permission
     if "write" not in auth.permissions:

@@ -422,7 +424,7 @@ async def ingest_file(
            content_type=file.content_type,
            filename=file.filename,
            metadata=metadata_dict,
-           owner={"type": auth.entity_type, "id": auth.entity_id},
+           owner={"type": auth.entity_type.value, "id": auth.entity_id},
            access_control={
                "readers": [auth.entity_id],
                "writers": [auth.entity_id],

@@ -558,7 +560,9 @@ async def batch_ingest_files(
     try:
         metadata_value = json.loads(metadata)
         rules_list = json.loads(rules)
-        use_colpali = bool(use_colpali)
+        # Fix bool conversion: ensure string "false" is properly converted to False
+        def str2bool(v): return str(v).lower() in {"true", "1", "yes"}
+        use_colpali = str2bool(use_colpali)

         # Ensure user has write permission
         if "write" not in auth.permissions:

@@ -616,7 +620,7 @@ async def batch_ingest_files(
            content_type=file.content_type,
            filename=file.filename,
            metadata=metadata_item,
-           owner={"type": auth.entity_type, "id": auth.entity_id},
+           owner={"type": auth.entity_type.value, "id": auth.entity_id},
            access_control={
                "readers": [auth.entity_id],
                "writers": [auth.entity_id],

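Side note on why the old bool(use_colpali) call misbehaved: multipart form fields arrive as strings, and any non-empty string is truthy in Python, so the literal string "false" silently became True. A minimal standalone check (not from this diff) illustrates the behaviour the new helper fixes:

# Any non-empty string is truthy, so the old conversion could never yield False.
assert bool("false") is True

def str2bool(v):
    # Same idea as the helper added in ingest_file above: compare the lowered string explicitly.
    return v if isinstance(v, bool) else str(v).lower() in {"true", "1", "yes"}

assert str2bool("false") is False
assert str2bool("true") is True
assert str2bool(True) is True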
View File

@@ -43,6 +43,14 @@ class Settings(BaseSettings):
     # Database configuration
     DATABASE_PROVIDER: Literal["postgres"]
     DATABASE_NAME: Optional[str] = None
+    # Database connection pool settings
+    DB_POOL_SIZE: int = 20
+    DB_MAX_OVERFLOW: int = 30
+    DB_POOL_RECYCLE: int = 3600
+    DB_POOL_TIMEOUT: int = 10
+    DB_POOL_PRE_PING: bool = True
+    DB_MAX_RETRIES: int = 3
+    DB_RETRY_DELAY: float = 1.0

     # Embedding configuration
     EMBEDDING_PROVIDER: Literal["litellm"] = "litellm"

@@ -164,7 +172,18 @@ def get_settings() -> Settings:
         completion_config["COMPLETION_MODEL"] = config["completion"]["model"]

     # load database config
-    database_config = {"DATABASE_PROVIDER": config["database"]["provider"]}
+    database_config = {
+        "DATABASE_PROVIDER": config["database"]["provider"],
+        "DATABASE_NAME": config["database"].get("name", None),
+        # Add database connection pool settings
+        "DB_POOL_SIZE": config["database"].get("pool_size", 20),
+        "DB_MAX_OVERFLOW": config["database"].get("max_overflow", 30),
+        "DB_POOL_RECYCLE": config["database"].get("pool_recycle", 3600),
+        "DB_POOL_TIMEOUT": config["database"].get("pool_timeout", 10),
+        "DB_POOL_PRE_PING": config["database"].get("pool_pre_ping", True),
+        "DB_MAX_RETRIES": config["database"].get("max_retries", 3),
+        "DB_RETRY_DELAY": config["database"].get("retry_delay", 1.0),
+    }
     if database_config["DATABASE_PROVIDER"] != "postgres":
         prov = database_config["DATABASE_PROVIDER"]
         raise ValueError(f"Unknown database provider selected: '{prov}'")

View File

@@ -109,7 +109,36 @@ class PostgresDatabase(BaseDatabase):
         uri: str,
     ):
         """Initialize PostgreSQL connection for document storage."""
-        self.engine = create_async_engine(uri)
+        # Load settings from config
+        from core.config import get_settings
+        settings = get_settings()
+
+        # Get database pool settings from config with defaults
+        pool_size = getattr(settings, "DB_POOL_SIZE", 20)
+        max_overflow = getattr(settings, "DB_MAX_OVERFLOW", 30)
+        pool_recycle = getattr(settings, "DB_POOL_RECYCLE", 3600)
+        pool_timeout = getattr(settings, "DB_POOL_TIMEOUT", 10)
+        pool_pre_ping = getattr(settings, "DB_POOL_PRE_PING", True)
+
+        logger.info(f"Initializing PostgreSQL connection pool with size={pool_size}, "
+                    f"max_overflow={max_overflow}, pool_recycle={pool_recycle}s")
+
+        # Create async engine with explicit pool settings
+        self.engine = create_async_engine(
+            uri,
+            # Prevent connection timeouts by keeping connections alive
+            pool_pre_ping=pool_pre_ping,
+            # Increase pool size to handle concurrent operations
+            pool_size=pool_size,
+            # Maximum overflow connections allowed beyond pool_size
+            max_overflow=max_overflow,
+            # Keep connections in the pool for up to 60 minutes
+            pool_recycle=pool_recycle,
+            # Time to wait for a connection from the pool (10 seconds)
+            pool_timeout=pool_timeout,
+            # Echo SQL for debugging (set to False in production)
+            echo=False,
+        )
         self.async_session = sessionmaker(self.engine, class_=AsyncSession, expire_on_commit=False)
         self._initialized = False
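Not part of this commit, but useful when tuning these numbers: SQLAlchemy exposes the live pool state on the engine, so a quick sketch like the following (the URI is a placeholder; asyncpg driver assumed) can confirm the settings actually took effect:

import asyncio
from sqlalchemy.ext.asyncio import create_async_engine

async def show_pool_status(uri: str) -> None:
    # Mirror the pool arguments above; the values here are the shipped defaults.
    engine = create_async_engine(uri, pool_size=20, max_overflow=30,
                                 pool_recycle=3600, pool_timeout=10, pool_pre_ping=True)
    # QueuePool.status() reports pool size, checked-in/out connections and current overflow.
    print(engine.sync_engine.pool.status())
    await engine.dispose()

# asyncio.run(show_pool_status("postgresql+asyncpg://user:pass@localhost/db"))  # hypothetical URI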

View File

@@ -39,6 +39,7 @@ import pdf2image
 from PIL.Image import Image
 import tempfile
 import os
+import asyncio

 logger = logging.getLogger(__name__)

 IMAGE = {im.mime for im in IMAGE}

@@ -845,42 +846,120 @@ class DocumentService:
         auth: Optional[AuthContext] = None,
     ) -> List[str]:
         """Helper to store chunks and document"""
-        # Store chunks in vector store
-        success, result = await self.vector_store.store_embeddings(chunk_objects)
-        if not success:
-            raise Exception("Failed to store chunk embeddings")
+        # Add retry logic for vector store operations
+        max_retries = 3
+        retry_delay = 1.0
+
+        # Store chunks in vector store with retry
+        attempt = 0
+        success = False
+        result = None
+
+        while attempt < max_retries and not success:
+            try:
+                success, result = await self.vector_store.store_embeddings(chunk_objects)
+                if not success:
+                    raise Exception("Failed to store chunk embeddings")
+                break
+            except Exception as e:
+                attempt += 1
+                error_msg = str(e)
+                if "connection was closed" in error_msg or "ConnectionDoesNotExistError" in error_msg:
+                    if attempt < max_retries:
+                        logger.warning(f"Database connection error during embeddings storage (attempt {attempt}/{max_retries}): {error_msg}. Retrying in {retry_delay}s...")
+                        await asyncio.sleep(retry_delay)
+                        # Increase delay for next retry (exponential backoff)
+                        retry_delay *= 2
+                    else:
+                        logger.error(f"All database connection attempts failed after {max_retries} retries: {error_msg}")
+                        raise Exception("Failed to store chunk embeddings after multiple retries")
+                else:
+                    # For other exceptions, don't retry
+                    logger.error(f"Error storing embeddings: {error_msg}")
+                    raise
+
         logger.debug("Stored chunk embeddings in vector store")
         doc.chunk_ids = result

         if use_colpali and self.colpali_vector_store and chunk_objects_multivector:
-            success, result_multivector = await self.colpali_vector_store.store_embeddings(
-                chunk_objects_multivector
-            )
-            if not success:
-                raise Exception("Failed to store multivector chunk embeddings")
+            # Reset retry variables for colpali storage
+            attempt = 0
+            retry_delay = 1.0
+            success = False
+            result_multivector = None
+
+            while attempt < max_retries and not success:
+                try:
+                    success, result_multivector = await self.colpali_vector_store.store_embeddings(
+                        chunk_objects_multivector
+                    )
+                    if not success:
+                        raise Exception("Failed to store multivector chunk embeddings")
+                    break
+                except Exception as e:
+                    attempt += 1
+                    error_msg = str(e)
+                    if "connection was closed" in error_msg or "ConnectionDoesNotExistError" in error_msg:
+                        if attempt < max_retries:
+                            logger.warning(f"Database connection error during colpali embeddings storage (attempt {attempt}/{max_retries}): {error_msg}. Retrying in {retry_delay}s...")
+                            await asyncio.sleep(retry_delay)
+                            # Increase delay for next retry (exponential backoff)
+                            retry_delay *= 2
+                        else:
+                            logger.error(f"All colpali database connection attempts failed after {max_retries} retries: {error_msg}")
+                            raise Exception("Failed to store multivector chunk embeddings after multiple retries")
+                    else:
+                        # For other exceptions, don't retry
+                        logger.error(f"Error storing colpali embeddings: {error_msg}")
+                        raise
+
             logger.debug("Stored multivector chunk embeddings in vector store")
             doc.chunk_ids += result_multivector

-        # Store document metadata
-        if is_update and auth:
-            # For updates, use update_document
-            updates = {
-                "chunk_ids": doc.chunk_ids,
-                "metadata": doc.metadata,
-                "system_metadata": doc.system_metadata,
-                "filename": doc.filename,
-                "content_type": doc.content_type,
-                "storage_info": doc.storage_info,
-            }
-            if not await self.db.update_document(doc.external_id, updates, auth):
-                raise Exception("Failed to update document metadata")
-            logger.debug("Updated document metadata in database")
-        else:
-            # For new documents, use store_document
-            if not await self.db.store_document(doc):
-                raise Exception("Failed to store document metadata")
-            logger.debug("Stored document metadata in database")
+        # Store document metadata with retry
+        attempt = 0
+        retry_delay = 1.0
+        success = False
+
+        while attempt < max_retries and not success:
+            try:
+                if is_update and auth:
+                    # For updates, use update_document
+                    updates = {
+                        "chunk_ids": doc.chunk_ids,
+                        "metadata": doc.metadata,
+                        "system_metadata": doc.system_metadata,
+                        "filename": doc.filename,
+                        "content_type": doc.content_type,
+                        "storage_info": doc.storage_info,
+                    }
+                    success = await self.db.update_document(doc.external_id, updates, auth)
+                    if not success:
+                        raise Exception("Failed to update document metadata")
+                else:
+                    # For new documents, use store_document
+                    success = await self.db.store_document(doc)
+                    if not success:
+                        raise Exception("Failed to store document metadata")
+                break
+            except Exception as e:
+                attempt += 1
+                error_msg = str(e)
+                if "connection was closed" in error_msg or "ConnectionDoesNotExistError" in error_msg:
+                    if attempt < max_retries:
+                        logger.warning(f"Database connection error during document metadata storage (attempt {attempt}/{max_retries}): {error_msg}. Retrying in {retry_delay}s...")
+                        await asyncio.sleep(retry_delay)
+                        # Increase delay for next retry (exponential backoff)
+                        retry_delay *= 2
+                    else:
+                        logger.error(f"All database connection attempts failed after {max_retries} retries: {error_msg}")
+                        raise Exception("Failed to store document metadata after multiple retries")
+                else:
+                    # For other exceptions, don't retry
+                    logger.error(f"Error storing document metadata: {error_msg}")
+                    raise
+
+        logger.debug("Stored document metadata in database")

         logger.debug(f"Chunk IDs stored: {doc.chunk_ids}")
         return doc.chunk_ids
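The three retry loops above are deliberately explicit, but they are near-identical; if this pattern grows, one option (a sketch only, not part of this change) is to factor the connection-error detection and backoff into a small helper and pass each operation in as a coroutine factory:

import asyncio
import logging

logger = logging.getLogger(__name__)

async def retry_on_connection_error(operation, description, max_retries=3, retry_delay=1.0):
    """Run an async operation, retrying only on closed-connection errors, with exponential backoff."""
    attempt = 0
    while True:
        try:
            return await operation()
        except Exception as e:
            attempt += 1
            msg = str(e)
            transient = "connection was closed" in msg or "ConnectionDoesNotExistError" in msg
            if not transient or attempt >= max_retries:
                logger.error(f"{description} failed after {attempt} attempt(s): {msg}")
                raise
            logger.warning(f"{description} attempt {attempt}/{max_retries} failed: {msg}. Retrying in {retry_delay}s...")
            await asyncio.sleep(retry_delay)
            retry_delay *= 2

# Hypothetical usage inside the helper above:
#   success, result = await retry_on_connection_error(
#       lambda: self.vector_store.store_embeddings(chunk_objects), "embedding storage")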

View File

@@ -287,10 +287,13 @@ class MultiVectorStore(BaseVectorStore):
         params = [binary_query_embeddings]

-        # Add document filter if needed
+        # Add document filter if needed with proper parameterization
         if doc_ids:
-            doc_ids_str = "', '".join(doc_ids)
-            query += f" WHERE document_id IN ('{doc_ids_str}')"
+            # Use placeholders for each document ID
+            placeholders = ', '.join(['%s'] * len(doc_ids))
+            query += f" WHERE document_id IN ({placeholders})"
+            # Add document IDs to params
+            params.extend(doc_ids)

         # Add ordering and limit
         query += " ORDER BY similarity DESC LIMIT %s"
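For context on why the filter above switched to placeholders: interpolating IDs into the SQL string breaks (and is injectable) as soon as an ID contains a quote, whereas placeholders are bound by the driver. A minimal sketch of the same pattern with a psycopg-style cursor; table and column names here are illustrative, not taken from this file:

def build_similarity_query(doc_ids, limit):
    query = "SELECT document_id, similarity FROM multi_vector_embeddings"  # hypothetical table
    params = []
    if doc_ids:
        placeholders = ", ".join(["%s"] * len(doc_ids))
        query += f" WHERE document_id IN ({placeholders})"
        params.extend(doc_ids)
    query += " ORDER BY similarity DESC LIMIT %s"
    params.append(limit)
    return query, params

# query, params = build_similarity_query(["doc-1", "doc'2"], 10)
# cursor.execute(query, params)  # the driver binds each %s value safely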

View File

@@ -83,15 +83,41 @@ class PGVectorStore(BaseVectorStore):
             max_retries: Maximum number of connection retry attempts
             retry_delay: Delay in seconds between retry attempts
         """
+        # Load settings from config
+        from core.config import get_settings
+        settings = get_settings()
+
+        # Get database pool settings from config with defaults
+        pool_size = getattr(settings, "DB_POOL_SIZE", 20)
+        max_overflow = getattr(settings, "DB_MAX_OVERFLOW", 30)
+        pool_recycle = getattr(settings, "DB_POOL_RECYCLE", 3600)
+        pool_timeout = getattr(settings, "DB_POOL_TIMEOUT", 10)
+        pool_pre_ping = getattr(settings, "DB_POOL_PRE_PING", True)
+
         # Use the URI exactly as provided without any modifications
         # This ensures compatibility with Supabase and other PostgreSQL providers
-        logger.info(f"Initializing database engine with provided URI")
+        logger.info(f"Initializing vector store database engine with pool size={pool_size}, "
+                    f"max_overflow={max_overflow}, pool_recycle={pool_recycle}s")

-        # Create the engine with the URI as is
-        self.engine = create_async_engine(uri)
+        # Create the engine with the URI as is and improved connection pool settings
+        self.engine = create_async_engine(
+            uri,
+            # Prevent connection timeouts by keeping connections alive
+            pool_pre_ping=pool_pre_ping,
+            # Increase pool size to handle concurrent operations
+            pool_size=pool_size,
+            # Maximum overflow connections allowed beyond pool_size
+            max_overflow=max_overflow,
+            # Keep connections in the pool for up to 60 minutes
+            pool_recycle=pool_recycle,
+            # Time to wait for a connection from the pool (10 seconds)
+            pool_timeout=pool_timeout,
+            # Echo SQL for debugging (set to False in production)
+            echo=False,
+        )

         # Log success
-        logger.info("Created database engine successfully")
+        logger.info("Created vector store database engine successfully")
         self.async_session = sessionmaker(self.engine, class_=AsyncSession, expire_on_commit=False)
         self.max_retries = max_retries
         self.retry_delay = retry_delay
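PostgresDatabase and PGVectorStore now read the same five pool settings with the same defaults; a possible follow-up (sketch only, not in this commit) is to centralise that mapping so the two engines cannot drift apart:

from core.config import get_settings

def engine_pool_kwargs() -> dict:
    """Collect the shared SQLAlchemy pool arguments from Settings (defaults mirror the ones above)."""
    settings = get_settings()
    return {
        "pool_size": getattr(settings, "DB_POOL_SIZE", 20),
        "max_overflow": getattr(settings, "DB_MAX_OVERFLOW", 30),
        "pool_recycle": getattr(settings, "DB_POOL_RECYCLE", 3600),
        "pool_timeout": getattr(settings, "DB_POOL_TIMEOUT", 10),
        "pool_pre_ping": getattr(settings, "DB_POOL_PRE_PING", True),
    }

# Hypothetical usage in both stores:
#   self.engine = create_async_engine(uri, echo=False, **engine_pool_kwargs())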

View File

@@ -3,6 +3,7 @@ import logging
 from typing import Dict, Any, List, Optional
 from datetime import datetime, UTC
 from pathlib import Path
+import asyncio

 import arq
 from core.models.auth import AuthContext, EntityType

@@ -20,9 +21,58 @@
 from core.services.document_service import DocumentService
 from core.services.telemetry import TelemetryService
 from core.services.rules_processor import RulesProcessor
 from core.config import get_settings
+from sqlalchemy import text

 logger = logging.getLogger(__name__)

+
+async def get_document_with_retry(document_service, document_id, auth, max_retries=3, initial_delay=0.3):
+    """
+    Helper function to get a document with retries to handle race conditions.
+
+    Args:
+        document_service: The document service instance
+        document_id: ID of the document to retrieve
+        auth: Authentication context
+        max_retries: Maximum number of retry attempts
+        initial_delay: Initial delay before first attempt in seconds
+
+    Returns:
+        Document if found and accessible, None otherwise
+    """
+    attempt = 0
+    retry_delay = initial_delay
+
+    # Add initial delay to allow transaction to commit
+    if initial_delay > 0:
+        await asyncio.sleep(initial_delay)
+
+    while attempt < max_retries:
+        try:
+            doc = await document_service.db.get_document(document_id, auth)
+            if doc:
+                logger.debug(f"Successfully retrieved document {document_id} on attempt {attempt+1}")
+                return doc
+
+            # Document not found but no exception raised
+            attempt += 1
+            if attempt < max_retries:
+                logger.warning(f"Document {document_id} not found on attempt {attempt}/{max_retries}. Retrying in {retry_delay}s...")
+                await asyncio.sleep(retry_delay)
+                retry_delay *= 1.5
+        except Exception as e:
+            attempt += 1
+            error_msg = str(e)
+            if attempt < max_retries:
+                logger.warning(f"Error retrieving document on attempt {attempt}/{max_retries}: {error_msg}. Retrying in {retry_delay}s...")
+                await asyncio.sleep(retry_delay)
+                retry_delay *= 1.5
+            else:
+                logger.error(f"Failed to retrieve document after {max_retries} attempts: {error_msg}")
+                return None
+
+    return None
+
+
 async def process_ingestion_job(
     ctx: Dict[str, Any],
     document_id: str,

@@ -100,10 +150,12 @@ async def process_ingestion_job(
         # 6. Retrieve the existing document
         logger.debug(f"Retrieving document with ID: {document_id}")
         logger.debug(f"Auth context: entity_type={auth.entity_type}, entity_id={auth.entity_id}, permissions={auth.permissions}")
-        doc = await document_service.db.get_document(document_id, auth)
+
+        # Use the retry helper function with initial delay to handle race conditions
+        doc = await get_document_with_retry(document_service, document_id, auth, max_retries=5, initial_delay=1.0)

         if not doc:
-            logger.error(f"Document {document_id} not found in database")
+            logger.error(f"Document {document_id} not found in database after multiple retries")
             logger.error(f"Details - file: {original_filename}, content_type: {content_type}, bucket: {bucket}, key: {file_key}")
             logger.error(f"Auth: entity_type={auth.entity_type}, entity_id={auth.entity_id}, permissions={auth.permissions}")
             # Try to get all accessible documents to debug

@@ -113,7 +165,7 @@ async def process_ingestion_job(
             except Exception as list_err:
                 logger.error(f"Failed to list user documents: {str(list_err)}")

-            raise ValueError(f"Document {document_id} not found in database")
+            raise ValueError(f"Document {document_id} not found in database after multiple retries")

         # Prepare updates for the document
         updates = {

@@ -342,7 +394,13 @@ async def startup(ctx):
         logger.info("Initializing ColPali components...")
         colpali_embedding_model = ColpaliEmbeddingModel()
         colpali_vector_store = MultiVectorStore(uri=settings.POSTGRES_URI)
-        _ = colpali_vector_store.initialize()
+        # Properly await the initialization to ensure indexes are ready
+        # MultiVectorStore.initialize is synchronous, so we need to run it in a thread
+        success = await asyncio.to_thread(colpali_vector_store.initialize)
+        if success:
+            logger.info("ColPali vector store initialization successful")
+        else:
+            logger.error("ColPali vector store initialization failed")
         ctx['colpali_embedding_model'] = colpali_embedding_model
         ctx['colpali_vector_store'] = colpali_vector_store

@@ -390,6 +448,16 @@ async def shutdown(ctx):
     logger.info("Closing database connections...")
     await ctx['database'].engine.dispose()

+    # Close vector store connections if they exist
+    if 'vector_store' in ctx and hasattr(ctx['vector_store'], 'engine'):
+        logger.info("Closing vector store connections...")
+        await ctx['vector_store'].engine.dispose()
+
+    # Close colpali vector store connections if they exist
+    if 'colpali_vector_store' in ctx and hasattr(ctx['colpali_vector_store'], 'engine'):
+        logger.info("Closing colpali vector store connections...")
+        await ctx['colpali_vector_store'].engine.dispose()
+
     # Close any other open connections or resources that need cleanup
     logger.info("Worker shutdown complete.")

@@ -408,4 +476,48 @@ class WorkerSettings:
     # Other optional settings:
     # redis_settings = arq.connections.RedisSettings(host='localhost', port=6379)
     keep_result_ms = 24 * 60 * 60 * 1000  # Keep results for 24 hours (24 * 60 * 60 * 1000 ms)
-    max_jobs = 10  # Maximum number of jobs to run concurrently
+    max_jobs = 5  # Reduce concurrent jobs to prevent connection pool exhaustion
+    health_check_interval = 300  # Check worker health every 5 minutes instead of 30 seconds to reduce connection overhead
+    job_timeout = 3600  # 1 hour timeout for jobs
+    max_tries = 3  # Retry failed jobs up to 3 times
+    poll_delay = 0.5  # Poll delay to prevent excessive Redis queries
+
+    # Log Redis and connection pool information for debugging
+    @staticmethod
+    async def health_check(ctx):
+        """Periodic health check to log connection status and job stats."""
+        database = ctx.get('database')
+        vector_store = ctx.get('vector_store')
+        job_stats = ctx.get('job_stats', {})
+
+        redis_info = await ctx['redis'].info()
+        logger.info(f"Health check: Redis v{redis_info.get('redis_version', 'unknown')} "
+                    f"mem_usage={redis_info.get('used_memory_human', 'unknown')} "
+                    f"clients_connected={redis_info.get('connected_clients', 'unknown')} "
+                    f"db_keys={redis_info.get('db0', {}).get('keys', 0)}"
+        )
+
+        # Log job statistics
+        logger.info(f"Job stats: completed={job_stats.get('complete', 0)} "
+                    f"failed={job_stats.get('failed', 0)} "
+                    f"retried={job_stats.get('retried', 0)} "
+                    f"ongoing={job_stats.get('ongoing', 0)} "
+                    f"queued={job_stats.get('queued', 0)}"
+        )
+
+        # Test database connectivity
+        if database and hasattr(database, 'async_session'):
+            try:
+                async with database.async_session() as session:
+                    await session.execute(text("SELECT 1"))
+                logger.debug("Database connection is healthy")
+            except Exception as e:
+                logger.error(f"Database connection test failed: {str(e)}")
+
+        # Test vector store connectivity if available
+        if vector_store and hasattr(vector_store, 'async_session'):
+            try:
+                async with vector_store.get_session_with_retry() as session:
+                    logger.debug("Vector store connection is healthy")
+            except Exception as e:
+                logger.error(f"Vector store connection test failed: {str(e)}")

View File

@@ -53,6 +53,14 @@ default_temperature = 0.5

 [database]
 provider = "postgres"
+# Connection pool settings
+pool_size = 10          # Maximum number of connections in the pool
+max_overflow = 15       # Maximum number of connections that can be created beyond pool_size
+pool_recycle = 3600     # Time in seconds after which a connection is recycled (1 hour)
+pool_timeout = 10       # Seconds to wait for a connection from the pool
+pool_pre_ping = true    # Check connection viability before using it from the pool
+max_retries = 3         # Number of retries for database operations
+retry_delay = 1.0       # Initial delay between retries in seconds

 [embedding]
 model = "ollama_embedding"  # Reference to registered model
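One sizing note on the values above: with SQLAlchemy, each engine can open up to pool_size + max_overflow connections, and every process that creates engines (API server, ingestion worker) holds its own pools. A rough back-of-the-envelope check, under the stated assumptions only, looks like this in Python:

pool_size = 10            # from the [database] section above
max_overflow = 15
per_engine_cap = pool_size + max_overflow   # at most 25 connections per engine

engines_per_process = 2   # assumption: document DB engine + pgvector engine
processes = 2             # assumption: API server + ingestion worker

print(processes * engines_per_process * per_engine_cap)  # 100 -- right at Postgres' default max_connections

If the worst case approaches the server's max_connections, either lower these pool values or raise max_connections on the Postgres side.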