Add pooling changes for scalable ingestion (#90)

Arnav Agrawal 2025-04-16 21:45:34 -07:00 committed by GitHub
parent 2b1c253bc1
commit 4ffe0b3bac
GPG Key ID: B5690EEEBB952194 (no known key found for this signature in database)
8 changed files with 327 additions and 47 deletions

View File

@@ -396,7 +396,9 @@ async def ingest_file(
# Parse metadata and rules
metadata_dict = json.loads(metadata)
rules_list = json.loads(rules)
use_colpali = bool(use_colpali)
# Fix bool conversion: ensure string "false" is properly converted to False
def str2bool(v): return v if isinstance(v, bool) else str(v).lower() in {"true", "1", "yes"}
use_colpali = str2bool(use_colpali)
# Ensure user has write permission
if "write" not in auth.permissions:
@@ -422,7 +424,7 @@ async def ingest_file(
content_type=file.content_type,
filename=file.filename,
metadata=metadata_dict,
owner={"type": auth.entity_type, "id": auth.entity_id},
owner={"type": auth.entity_type.value, "id": auth.entity_id},
access_control={
"readers": [auth.entity_id],
"writers": [auth.entity_id],
@@ -558,7 +560,9 @@ async def batch_ingest_files(
try:
metadata_value = json.loads(metadata)
rules_list = json.loads(rules)
use_colpali = bool(use_colpali)
# Fix bool conversion: ensure string "false" is properly converted to False
def str2bool(v): return str(v).lower() in {"true", "1", "yes"}
use_colpali = str2bool(use_colpali)
# Ensure user has write permission
if "write" not in auth.permissions:
@@ -616,7 +620,7 @@ async def batch_ingest_files(
content_type=file.content_type,
filename=file.filename,
metadata=metadata_item,
owner={"type": auth.entity_type, "id": auth.entity_id},
owner={"type": auth.entity_type.value, "id": auth.entity_id},
access_control={
"readers": [auth.entity_id],
"writers": [auth.entity_id],

View File

@@ -43,6 +43,14 @@ class Settings(BaseSettings):
# Database configuration
DATABASE_PROVIDER: Literal["postgres"]
DATABASE_NAME: Optional[str] = None
# Database connection pool settings
DB_POOL_SIZE: int = 20
DB_MAX_OVERFLOW: int = 30
DB_POOL_RECYCLE: int = 3600
DB_POOL_TIMEOUT: int = 10
DB_POOL_PRE_PING: bool = True
DB_MAX_RETRIES: int = 3
DB_RETRY_DELAY: float = 1.0
# Embedding configuration
EMBEDDING_PROVIDER: Literal["litellm"] = "litellm"
@@ -164,7 +172,18 @@ def get_settings() -> Settings:
completion_config["COMPLETION_MODEL"] = config["completion"]["model"]
# load database config
database_config = {"DATABASE_PROVIDER": config["database"]["provider"]}
database_config = {
"DATABASE_PROVIDER": config["database"]["provider"],
"DATABASE_NAME": config["database"].get("name", None),
# Add database connection pool settings
"DB_POOL_SIZE": config["database"].get("pool_size", 20),
"DB_MAX_OVERFLOW": config["database"].get("max_overflow", 30),
"DB_POOL_RECYCLE": config["database"].get("pool_recycle", 3600),
"DB_POOL_TIMEOUT": config["database"].get("pool_timeout", 10),
"DB_POOL_PRE_PING": config["database"].get("pool_pre_ping", True),
"DB_MAX_RETRIES": config["database"].get("max_retries", 3),
"DB_RETRY_DELAY": config["database"].get("retry_delay", 1.0),
}
if database_config["DATABASE_PROVIDER"] != "postgres":
prov = database_config["DATABASE_PROVIDER"]
raise ValueError(f"Unknown database provider selected: '{prov}'")

View File

@@ -109,7 +109,36 @@ class PostgresDatabase(BaseDatabase):
uri: str,
):
"""Initialize PostgreSQL connection for document storage."""
self.engine = create_async_engine(uri)
# Load settings from config
from core.config import get_settings
settings = get_settings()
# Get database pool settings from config with defaults
pool_size = getattr(settings, "DB_POOL_SIZE", 20)
max_overflow = getattr(settings, "DB_MAX_OVERFLOW", 30)
pool_recycle = getattr(settings, "DB_POOL_RECYCLE", 3600)
pool_timeout = getattr(settings, "DB_POOL_TIMEOUT", 10)
pool_pre_ping = getattr(settings, "DB_POOL_PRE_PING", True)
logger.info(f"Initializing PostgreSQL connection pool with size={pool_size}, "
f"max_overflow={max_overflow}, pool_recycle={pool_recycle}s")
# Create async engine with explicit pool settings
self.engine = create_async_engine(
uri,
# Test connections before checkout so stale ones are replaced instead of failing mid-request
pool_pre_ping=pool_pre_ping,
# Number of persistent connections kept open to handle concurrent operations
pool_size=pool_size,
# Maximum overflow connections allowed beyond pool_size
max_overflow=max_overflow,
# Recycle connections older than pool_recycle seconds (default: 1 hour)
pool_recycle=pool_recycle,
# Seconds to wait for a free connection from the pool (default: 10)
pool_timeout=pool_timeout,
# Echo SQL for debugging (set to False in production)
echo=False,
)
self.async_session = sessionmaker(self.engine, class_=AsyncSession, expire_on_commit=False)
self._initialized = False
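
Sizing note (a rough sketch, assuming SQLAlchemy's standard QueuePool behaviour): each engine keeps up to pool_size persistent connections and may open up to max_overflow temporary ones on top, so the worst case per engine is their sum; callers beyond that wait up to pool_timeout seconds for a free connection and then error out.

def max_connections_per_engine(pool_size: int, max_overflow: int) -> int:
    """Worst-case simultaneous connections a single SQLAlchemy engine can open."""
    return pool_size + max_overflow

# With the defaults introduced in this commit (pool_size=20, max_overflow=30):
assert max_connections_per_engine(20, 30) == 50
# Keep the per-engine sum, multiplied across processes, below PostgreSQL's max_connections.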

View File

@@ -39,6 +39,7 @@ import pdf2image
from PIL.Image import Image
import tempfile
import os
import asyncio
logger = logging.getLogger(__name__)
IMAGE = {im.mime for im in IMAGE}
@@ -845,42 +846,120 @@ class DocumentService:
auth: Optional[AuthContext] = None,
) -> List[str]:
"""Helper to store chunks and document"""
# Store chunks in vector store
success, result = await self.vector_store.store_embeddings(chunk_objects)
if not success:
raise Exception("Failed to store chunk embeddings")
# Add retry logic for vector store operations
max_retries = 3
retry_delay = 1.0
# Store chunks in vector store with retry
attempt = 0
success = False
result = None
while attempt < max_retries and not success:
try:
success, result = await self.vector_store.store_embeddings(chunk_objects)
if not success:
raise Exception("Failed to store chunk embeddings")
break
except Exception as e:
attempt += 1
error_msg = str(e)
if "connection was closed" in error_msg or "ConnectionDoesNotExistError" in error_msg:
if attempt < max_retries:
logger.warning(f"Database connection error during embeddings storage (attempt {attempt}/{max_retries}): {error_msg}. Retrying in {retry_delay}s...")
await asyncio.sleep(retry_delay)
# Increase delay for next retry (exponential backoff)
retry_delay *= 2
else:
logger.error(f"All database connection attempts failed after {max_retries} retries: {error_msg}")
raise Exception("Failed to store chunk embeddings after multiple retries")
else:
# For other exceptions, don't retry
logger.error(f"Error storing embeddings: {error_msg}")
raise
logger.debug("Stored chunk embeddings in vector store")
doc.chunk_ids = result
if use_colpali and self.colpali_vector_store and chunk_objects_multivector:
success, result_multivector = await self.colpali_vector_store.store_embeddings(
chunk_objects_multivector
)
if not success:
raise Exception("Failed to store multivector chunk embeddings")
# Reset retry variables for colpali storage
attempt = 0
retry_delay = 1.0
success = False
result_multivector = None
while attempt < max_retries and not success:
try:
success, result_multivector = await self.colpali_vector_store.store_embeddings(
chunk_objects_multivector
)
if not success:
raise Exception("Failed to store multivector chunk embeddings")
break
except Exception as e:
attempt += 1
error_msg = str(e)
if "connection was closed" in error_msg or "ConnectionDoesNotExistError" in error_msg:
if attempt < max_retries:
logger.warning(f"Database connection error during colpali embeddings storage (attempt {attempt}/{max_retries}): {error_msg}. Retrying in {retry_delay}s...")
await asyncio.sleep(retry_delay)
# Increase delay for next retry (exponential backoff)
retry_delay *= 2
else:
logger.error(f"All colpali database connection attempts failed after {max_retries} retries: {error_msg}")
raise Exception("Failed to store multivector chunk embeddings after multiple retries")
else:
# For other exceptions, don't retry
logger.error(f"Error storing colpali embeddings: {error_msg}")
raise
logger.debug("Stored multivector chunk embeddings in vector store")
doc.chunk_ids += result_multivector
# Store document metadata
if is_update and auth:
# For updates, use update_document
updates = {
"chunk_ids": doc.chunk_ids,
"metadata": doc.metadata,
"system_metadata": doc.system_metadata,
"filename": doc.filename,
"content_type": doc.content_type,
"storage_info": doc.storage_info,
}
if not await self.db.update_document(doc.external_id, updates, auth):
raise Exception("Failed to update document metadata")
logger.debug("Updated document metadata in database")
else:
# For new documents, use store_document
if not await self.db.store_document(doc):
raise Exception("Failed to store document metadata")
logger.debug("Stored document metadata in database")
# Store document metadata with retry
attempt = 0
retry_delay = 1.0
success = False
while attempt < max_retries and not success:
try:
if is_update and auth:
# For updates, use update_document
updates = {
"chunk_ids": doc.chunk_ids,
"metadata": doc.metadata,
"system_metadata": doc.system_metadata,
"filename": doc.filename,
"content_type": doc.content_type,
"storage_info": doc.storage_info,
}
success = await self.db.update_document(doc.external_id, updates, auth)
if not success:
raise Exception("Failed to update document metadata")
else:
# For new documents, use store_document
success = await self.db.store_document(doc)
if not success:
raise Exception("Failed to store document metadata")
break
except Exception as e:
attempt += 1
error_msg = str(e)
if "connection was closed" in error_msg or "ConnectionDoesNotExistError" in error_msg:
if attempt < max_retries:
logger.warning(f"Database connection error during document metadata storage (attempt {attempt}/{max_retries}): {error_msg}. Retrying in {retry_delay}s...")
await asyncio.sleep(retry_delay)
# Increase delay for next retry (exponential backoff)
retry_delay *= 2
else:
logger.error(f"All database connection attempts failed after {max_retries} retries: {error_msg}")
raise Exception("Failed to store document metadata after multiple retries")
else:
# For other exceptions, don't retry
logger.error(f"Error storing document metadata: {error_msg}")
raise
logger.debug("Stored document metadata in database")
logger.debug(f"Chunk IDs stored: {doc.chunk_ids}")
return doc.chunk_ids
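
The three retry loops above share one pattern: retry only on connection-level errors, back off exponentially, and re-raise everything else immediately. A compact self-contained sketch of that pattern (a hypothetical helper, not code from this commit):

import asyncio
import logging

logger = logging.getLogger(__name__)

async def retry_on_connection_error(operation, *, max_retries=3, retry_delay=1.0, what="operation"):
    """Run an async callable, retrying with exponential backoff on connection errors only."""
    attempt = 0
    while True:
        try:
            return await operation()
        except Exception as e:
            attempt += 1
            error_msg = str(e)
            transient = "connection was closed" in error_msg or "ConnectionDoesNotExistError" in error_msg
            if not transient:
                logger.error(f"Error during {what}: {error_msg}")
                raise
            if attempt >= max_retries:
                logger.error(f"All attempts failed after {max_retries} retries: {error_msg}")
                raise Exception(f"Failed to complete {what} after multiple retries")
            logger.warning(f"Connection error during {what} (attempt {attempt}/{max_retries}): {error_msg}. Retrying in {retry_delay}s...")
            await asyncio.sleep(retry_delay)
            retry_delay *= 2  # exponential backoff

Called as, for example, await retry_on_connection_error(lambda: self.vector_store.store_embeddings(chunk_objects), what="embeddings storage"), it would collapse each inline loop; the commit keeps the loops inline so each call site can log its own wording.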

View File

@@ -287,10 +287,13 @@ class MultiVectorStore(BaseVectorStore):
params = [binary_query_embeddings]
# Add document filter if needed
# Add document filter if needed with proper parameterization
if doc_ids:
doc_ids_str = "', '".join(doc_ids)
query += f" WHERE document_id IN ('{doc_ids_str}')"
# Use placeholders for each document ID
placeholders = ', '.join(['%s'] * len(doc_ids))
query += f" WHERE document_id IN ({placeholders})"
# Add document IDs to params
params.extend(doc_ids)
# Add ordering and limit
query += " ORDER BY similarity DESC LIMIT %s"

View File

@@ -83,15 +83,41 @@ class PGVectorStore(BaseVectorStore):
max_retries: Maximum number of connection retry attempts
retry_delay: Delay in seconds between retry attempts
"""
# Load settings from config
from core.config import get_settings
settings = get_settings()
# Get database pool settings from config with defaults
pool_size = getattr(settings, "DB_POOL_SIZE", 20)
max_overflow = getattr(settings, "DB_MAX_OVERFLOW", 30)
pool_recycle = getattr(settings, "DB_POOL_RECYCLE", 3600)
pool_timeout = getattr(settings, "DB_POOL_TIMEOUT", 10)
pool_pre_ping = getattr(settings, "DB_POOL_PRE_PING", True)
# Use the URI exactly as provided without any modifications
# This ensures compatibility with Supabase and other PostgreSQL providers
logger.info(f"Initializing database engine with provided URI")
logger.info(f"Initializing vector store database engine with pool size={pool_size}, "
f"max_overflow={max_overflow}, pool_recycle={pool_recycle}s")
# Create the engine with the URI as is
self.engine = create_async_engine(uri)
# Create the engine with the URI as is and improved connection pool settings
self.engine = create_async_engine(
uri,
# Test connections before checkout so stale ones are replaced instead of failing mid-request
pool_pre_ping=pool_pre_ping,
# Number of persistent connections kept open to handle concurrent operations
pool_size=pool_size,
# Maximum overflow connections allowed beyond pool_size
max_overflow=max_overflow,
# Recycle connections older than pool_recycle seconds (default: 1 hour)
pool_recycle=pool_recycle,
# Seconds to wait for a free connection from the pool (default: 10)
pool_timeout=pool_timeout,
# Echo SQL for debugging (set to False in production)
echo=False,
)
# Log success
logger.info("Created database engine successfully")
logger.info("Created vector store database engine successfully")
self.async_session = sessionmaker(self.engine, class_=AsyncSession, expire_on_commit=False)
self.max_retries = max_retries
self.retry_delay = retry_delay

View File

@@ -3,6 +3,7 @@ import logging
from typing import Dict, Any, List, Optional
from datetime import datetime, UTC
from pathlib import Path
import asyncio
import arq
from core.models.auth import AuthContext, EntityType
@@ -20,9 +21,58 @@ from core.services.document_service import DocumentService
from core.services.telemetry import TelemetryService
from core.services.rules_processor import RulesProcessor
from core.config import get_settings
from sqlalchemy import text
logger = logging.getLogger(__name__)
async def get_document_with_retry(document_service, document_id, auth, max_retries=3, initial_delay=0.3):
"""
Helper function to get a document with retries to handle race conditions.
Args:
document_service: The document service instance
document_id: ID of the document to retrieve
auth: Authentication context
max_retries: Maximum number of retry attempts
initial_delay: Initial delay before first attempt in seconds
Returns:
Document if found and accessible, None otherwise
"""
attempt = 0
retry_delay = initial_delay
# Add initial delay to allow transaction to commit
if initial_delay > 0:
await asyncio.sleep(initial_delay)
while attempt < max_retries:
try:
doc = await document_service.db.get_document(document_id, auth)
if doc:
logger.debug(f"Successfully retrieved document {document_id} on attempt {attempt+1}")
return doc
# Document not found but no exception raised
attempt += 1
if attempt < max_retries:
logger.warning(f"Document {document_id} not found on attempt {attempt}/{max_retries}. Retrying in {retry_delay}s...")
await asyncio.sleep(retry_delay)
retry_delay *= 1.5
except Exception as e:
attempt += 1
error_msg = str(e)
if attempt < max_retries:
logger.warning(f"Error retrieving document on attempt {attempt}/{max_retries}: {error_msg}. Retrying in {retry_delay}s...")
await asyncio.sleep(retry_delay)
retry_delay *= 1.5
else:
logger.error(f"Failed to retrieve document after {max_retries} attempts: {error_msg}")
return None
return None
async def process_ingestion_job(
ctx: Dict[str, Any],
document_id: str,
@@ -100,10 +150,12 @@ async def process_ingestion_job(
# 6. Retrieve the existing document
logger.debug(f"Retrieving document with ID: {document_id}")
logger.debug(f"Auth context: entity_type={auth.entity_type}, entity_id={auth.entity_id}, permissions={auth.permissions}")
doc = await document_service.db.get_document(document_id, auth)
# Use the retry helper function with initial delay to handle race conditions
doc = await get_document_with_retry(document_service, document_id, auth, max_retries=5, initial_delay=1.0)
if not doc:
logger.error(f"Document {document_id} not found in database")
logger.error(f"Document {document_id} not found in database after multiple retries")
logger.error(f"Details - file: {original_filename}, content_type: {content_type}, bucket: {bucket}, key: {file_key}")
logger.error(f"Auth: entity_type={auth.entity_type}, entity_id={auth.entity_id}, permissions={auth.permissions}")
# Try to get all accessible documents to debug
@@ -113,7 +165,7 @@ async def process_ingestion_job(
except Exception as list_err:
logger.error(f"Failed to list user documents: {str(list_err)}")
raise ValueError(f"Document {document_id} not found in database")
raise ValueError(f"Document {document_id} not found in database after multiple retries")
# Prepare updates for the document
updates = {
@@ -342,7 +394,13 @@ async def startup(ctx):
logger.info("Initializing ColPali components...")
colpali_embedding_model = ColpaliEmbeddingModel()
colpali_vector_store = MultiVectorStore(uri=settings.POSTGRES_URI)
_ = colpali_vector_store.initialize()
# Properly await the initialization to ensure indexes are ready
# MultiVectorStore.initialize is synchronous, so we need to run it in a thread
success = await asyncio.to_thread(colpali_vector_store.initialize)
if success:
logger.info("ColPali vector store initialization successful")
else:
logger.error("ColPali vector store initialization failed")
ctx['colpali_embedding_model'] = colpali_embedding_model
ctx['colpali_vector_store'] = colpali_vector_store
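
The startup change above swaps a fire-and-forget call for asyncio.to_thread, which runs the synchronous MultiVectorStore.initialize in a worker thread so its result can be awaited without blocking the event loop. A self-contained sketch of the pattern (blocking_initialize is a stand-in, not the real initializer):

import asyncio
import time

def blocking_initialize() -> bool:
    """Stand-in for a synchronous setup step such as index creation."""
    time.sleep(0.1)
    return True

async def startup() -> None:
    # Offload the blocking call to a thread; the event loop stays responsive while it runs.
    success = await asyncio.to_thread(blocking_initialize)
    print("initialization successful" if success else "initialization failed")

asyncio.run(startup())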
@@ -390,6 +448,16 @@ async def shutdown(ctx):
logger.info("Closing database connections...")
await ctx['database'].engine.dispose()
# Close vector store connections if they exist
if 'vector_store' in ctx and hasattr(ctx['vector_store'], 'engine'):
logger.info("Closing vector store connections...")
await ctx['vector_store'].engine.dispose()
# Close colpali vector store connections if they exist
if 'colpali_vector_store' in ctx and hasattr(ctx['colpali_vector_store'], 'engine'):
logger.info("Closing colpali vector store connections...")
await ctx['colpali_vector_store'].engine.dispose()
# Close any other open connections or resources that need cleanup
logger.info("Worker shutdown complete.")
@@ -408,4 +476,48 @@ class WorkerSettings:
# Other optional settings:
# redis_settings = arq.connections.RedisSettings(host='localhost', port=6379)
keep_result_ms = 24 * 60 * 60 * 1000 # Keep results for 24 hours (24 * 60 * 60 * 1000 ms)
max_jobs = 10 # Maximum number of jobs to run concurrently
max_jobs = 5 # Reduce concurrent jobs to prevent connection pool exhaustion
health_check_interval = 300 # Check worker health every 5 minutes instead of 30 seconds to reduce connection overhead
job_timeout = 3600 # 1 hour timeout for jobs
max_tries = 3 # Retry failed jobs up to 3 times
poll_delay = 0.5 # Poll delay to prevent excessive Redis queries
# Log Redis and connection pool information for debugging
@staticmethod
async def health_check(ctx):
"""Periodic health check to log connection status and job stats."""
database = ctx.get('database')
vector_store = ctx.get('vector_store')
job_stats = ctx.get('job_stats', {})
redis_info = await ctx['redis'].info()
logger.info(f"Health check: Redis v{redis_info.get('redis_version', 'unknown')} "
f"mem_usage={redis_info.get('used_memory_human', 'unknown')} "
f"clients_connected={redis_info.get('connected_clients', 'unknown')} "
f"db_keys={redis_info.get('db0', {}).get('keys', 0)}"
)
# Log job statistics
logger.info(f"Job stats: completed={job_stats.get('complete', 0)} "
f"failed={job_stats.get('failed', 0)} "
f"retried={job_stats.get('retried', 0)} "
f"ongoing={job_stats.get('ongoing', 0)} "
f"queued={job_stats.get('queued', 0)}"
)
# Test database connectivity
if database and hasattr(database, 'async_session'):
try:
async with database.async_session() as session:
await session.execute(text("SELECT 1"))
logger.debug("Database connection is healthy")
except Exception as e:
logger.error(f"Database connection test failed: {str(e)}")
# Test vector store connectivity if available
if vector_store and hasattr(vector_store, 'async_session'):
try:
async with vector_store.get_session_with_retry() as session:
logger.debug("Vector store connection is healthy")
except Exception as e:
logger.error(f"Vector store connection test failed: {str(e)}")

View File

@@ -53,6 +53,14 @@ default_temperature = 0.5
[database]
provider = "postgres"
# Connection pool settings
pool_size = 10 # Maximum number of connections in the pool
max_overflow = 15 # Maximum number of connections that can be created beyond pool_size
pool_recycle = 3600 # Time in seconds after which a connection is recycled (1 hour)
pool_timeout = 10 # Seconds to wait for a connection from the pool
pool_pre_ping = true # Check connection viability before using it from the pool
max_retries = 3 # Number of retries for database operations
retry_delay = 1.0 # Initial delay between retries in seconds
[embedding]
model = "ollama_embedding" # Reference to registered model