From 5ae396cddbfcf9d319f81f8bf89c95412f018ea0 Mon Sep 17 00:00:00 2001
From: Adityavardhan Agrawal
Date: Thu, 3 Apr 2025 12:54:54 -0700
Subject: [PATCH] Squashed changes from hosted-service

---
 core/api.py                             |  13 ++
 core/services/document_service.py       |   2 +
 core/vector_store/multi_vector_store.py | 283 ++++++++++++++----------
 core/vector_store/pgvector_store.py     | 210 ++++++++++++++++--
 4 files changed, 373 insertions(+), 135 deletions(-)
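Usage notes (this region, like the diffstat above, is ignored by `git am`): both
stores now route database work through retrying helpers. A minimal usage sketch
of the two new entry points, assuming only what this patch introduces; the
connection URIs are placeholders:

    import asyncio
    from sqlalchemy import text

    from core.vector_store.multi_vector_store import MultiVectorStore
    from core.vector_store.pgvector_store import PGVectorStore

    # Sync psycopg path: each block opens a fresh connection, retried up to
    # max_retries times with retry_delay seconds between attempts.
    store = MultiVectorStore("postgresql://user:pass@localhost:5432/db",
                             max_retries=3, retry_delay=1.0)
    with store.get_connection() as conn:
        conn.execute("SELECT 1")

    async def main() -> None:
        # Async SQLAlchemy path: each session is validated with SELECT 1
        # before being yielded; OperationalError triggers a retry.
        pg_store = PGVectorStore("postgresql+asyncpg://user:pass@localhost:5432/db")
        await pg_store.initialize()
        async with pg_store.get_session_with_retry() as session:
            await session.execute(text("SELECT 1"))

    asyncio.run(main())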
diff --git a/core/api.py b/core/api.py
index aed9b96..819e11a 100644
--- a/core/api.py
+++ b/core/api.py
@@ -110,6 +110,19 @@ async def initialize_database():
     # We don't raise an exception here to allow the app to continue starting
     # even if there are initialization errors
 
+@app.on_event("startup")
+async def initialize_vector_store():
+    """Initialize vector store tables and indexes on application startup."""
+    logger.info("Initializing vector store...")
+    if hasattr(vector_store, 'initialize'):
+        success = await vector_store.initialize()
+        if success:
+            logger.info("Vector store initialization successful")
+        else:
+            logger.error("Vector store initialization failed")
+    else:
+        logger.warning("Vector store does not have an initialize method")
+
 # Initialize vector store
 match settings.VECTOR_STORE_PROVIDER:
     case "mongodb":
diff --git a/core/services/document_service.py b/core/services/document_service.py
index 7c41c76..b5219c4 100644
--- a/core/services/document_service.py
+++ b/core/services/document_service.py
@@ -151,6 +151,8 @@ class DocumentService:
         # by filtering for only the chunks which weren't ingested via colpali...
         if len(chunks_multivector) == 0:
             return chunks
+        if len(chunks) == 0:
+            return chunks_multivector
 
         # TODO: this is duct tape, fix it properly later
diff --git a/core/vector_store/multi_vector_store.py b/core/vector_store/multi_vector_store.py
index d10ff7a..56a9e9f 100644
--- a/core/vector_store/multi_vector_store.py
+++ b/core/vector_store/multi_vector_store.py
@@ -1,7 +1,9 @@
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union, ContextManager
 import logging
 import torch
 import numpy as np
+import time
+from contextlib import contextmanager
 import psycopg
 from pgvector.psycopg import Bit, register_vector
 from core.models.chunk import DocumentChunk
@@ -16,135 +18,183 @@ class MultiVectorStore(BaseVectorStore):
     def __init__(
         self,
         uri: str,
+        max_retries: int = 3,
+        retry_delay: float = 1.0,
     ):
         """Initialize PostgreSQL connection for multi-vector storage.
 
         Args:
             uri: PostgreSQL connection URI
+            max_retries: Maximum number of connection retry attempts
+            retry_delay: Delay in seconds between retry attempts
         """
         # Convert SQLAlchemy URI to psycopg format if needed
         if uri.startswith("postgresql+asyncpg://"):
             uri = uri.replace("postgresql+asyncpg://", "postgresql://")
         self.uri = uri
-        # self.conn = psycopg.connect(self.uri, autocommit=True)
         self.conn = None
+        self.max_retries = max_retries
+        self.retry_delay = retry_delay
         self.initialize()
+
+    @contextmanager
+    def get_connection(self):
+        """Get a PostgreSQL connection with retry logic.
+
+        Yields:
+            A PostgreSQL connection object
+
+        Raises:
+            psycopg.OperationalError: If all connection attempts fail
+        """
+        attempt = 0
+        last_error = None
+
+        # Try to establish a new connection with retries
+        while attempt < self.max_retries:
+            try:
+                # Always create a fresh connection for critical operations
+                conn = psycopg.connect(self.uri, autocommit=True)
+                # Register vector extension on every new connection
+                register_vector(conn)
+
+                try:
+                    yield conn
+                    return
+                finally:
+                    # Always close connections after use
+                    try:
+                        conn.close()
+                    except Exception:
+                        pass
+            except psycopg.OperationalError as e:
+                last_error = e
+                attempt += 1
+                if attempt < self.max_retries:
+                    logger.warning(f"Connection attempt {attempt} failed: {str(e)}. Retrying in {self.retry_delay} seconds...")
+                    time.sleep(self.retry_delay)
+
+        # If we get here, all retries failed
+        logger.error(f"All connection attempts failed after {self.max_retries} retries: {str(last_error)}")
+        raise last_error
 
     def initialize(self):
         """Initialize database tables and max_sim function."""
         try:
-            # Connect to database
-            self.conn = psycopg.connect(self.uri, autocommit=True)
-
-            # Register vector extension
-            self.conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
-            register_vector(self.conn)
-
-            # First check if the table exists and if it has the required columns
-            check_table = self.conn.execute(
-                """
-                SELECT EXISTS (
-                    SELECT FROM information_schema.tables
-                    WHERE table_name = 'multi_vector_embeddings'
-                );
-                """
-            ).fetchone()[0]
-
-            if check_table:
-                # Check if document_id column exists
-                has_document_id = self.conn.execute(
-                    """
-                    SELECT EXISTS (
-                        SELECT FROM information_schema.columns
-                        WHERE table_name = 'multi_vector_embeddings' AND column_name = 'document_id'
-                    );
-                    """
-                ).fetchone()[0]
-
-                # If the table exists but doesn't have document_id, we need to add the required columns
-                if not has_document_id:
-                    logger.info("Updating multi_vector_embeddings table with required columns")
-                    self.conn.execute(
-                        """
-                        ALTER TABLE multi_vector_embeddings
-                        ADD COLUMN document_id TEXT,
-                        ADD COLUMN chunk_number INTEGER,
-                        ADD COLUMN content TEXT,
-                        ADD COLUMN chunk_metadata TEXT
-                        """
-                    )
-                    self.conn.execute(
-                        """
-                        ALTER TABLE multi_vector_embeddings
-                        ALTER COLUMN document_id SET NOT NULL
-                        """
-                    )
-
-                    # Add a commit to ensure changes are applied
-                    self.conn.commit()
-            else:
-                # Create table if it doesn't exist with all required columns
-                self.conn.execute(
-                    """
-                    CREATE TABLE IF NOT EXISTS multi_vector_embeddings (
-                        id BIGSERIAL PRIMARY KEY,
-                        document_id TEXT NOT NULL,
-                        chunk_number INTEGER NOT NULL,
-                        content TEXT NOT NULL,
-                        chunk_metadata TEXT,
-                        embeddings BIT(128)[]
-                    )
-                    """
-                )
-
-                # Add a commit to ensure table creation is complete
-                self.conn.commit()
+            # Use the connection with retry logic
+            with self.get_connection() as conn:
+                # Register vector extension
+                conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
+                register_vector(conn)
+
+            # First check if the table exists and if it has the required columns
+            with self.get_connection() as conn:
+                check_table = conn.execute(
+                    """
+                    SELECT EXISTS (
+                        SELECT FROM information_schema.tables
+                        WHERE table_name = 'multi_vector_embeddings'
+                    );
+                    """
+                ).fetchone()[0]
+
+                if check_table:
+                    # Check if document_id column exists
+                    has_document_id = conn.execute(
+                        """
+                        SELECT EXISTS (
+                            SELECT FROM information_schema.columns
+                            WHERE table_name = 'multi_vector_embeddings' AND column_name = 'document_id'
+                        );
+                        """
+                    ).fetchone()[0]
+
+                    # If the table exists but doesn't have document_id, we need to add the required columns
+                    if not has_document_id:
+                        logger.info("Updating multi_vector_embeddings table with required columns")
+                        conn.execute(
+                            """
+                            ALTER TABLE multi_vector_embeddings
+                            ADD COLUMN document_id TEXT,
+                            ADD COLUMN chunk_number INTEGER,
+                            ADD COLUMN content TEXT,
+                            ADD COLUMN chunk_metadata TEXT
+                            """
+                        )
+                        conn.execute(
+                            """
+                            ALTER TABLE multi_vector_embeddings
+                            ALTER COLUMN document_id SET NOT NULL
+                            """
+                        )
+
+                        # Add a commit to ensure changes are applied
+                        conn.commit()
+                else:
+                    # Create table if it doesn't exist with all required columns
+                    conn.execute(
+                        """
+                        CREATE TABLE IF NOT EXISTS multi_vector_embeddings (
+                            id BIGSERIAL PRIMARY KEY,
+                            document_id TEXT NOT NULL,
+                            chunk_number INTEGER NOT NULL,
+                            content TEXT NOT NULL,
+                            chunk_metadata TEXT,
+                            embeddings BIT(128)[]
+                        )
+                        """
+                    )
+
+                    # Add a commit to ensure table creation is complete
+                    conn.commit()
 
             try:
                 # Create index on document_id
-                self.conn.execute(
-                    """
-                    CREATE INDEX IF NOT EXISTS idx_multi_vector_document_id
-                    ON multi_vector_embeddings (document_id)
-                    """
-                )
+                with self.get_connection() as conn:
+                    conn.execute(
+                        """
+                        CREATE INDEX IF NOT EXISTS idx_multi_vector_document_id
+                        ON multi_vector_embeddings (document_id)
+                        """
+                    )
            except Exception as e:
                 # Log index creation failure but continue
                 logger.warning(f"Failed to create index: {str(e)}")
 
             try:
                 # First, try to drop the existing function if it exists
-                self.conn.execute(
-                    """
-                    DROP FUNCTION IF EXISTS max_sim(bit[], bit[])
-                    """
-                )
-                logger.info("Dropped existing max_sim function")
-
-                # Create max_sim function
-                self.conn.execute(
-                    """
-                    CREATE OR REPLACE FUNCTION max_sim(document bit[], query bit[]) RETURNS double precision AS $$
-                    WITH queries AS (
-                        SELECT row_number() OVER () AS query_number, * FROM (SELECT unnest(query) AS query) AS foo
-                    ),
-                    documents AS (
-                        SELECT unnest(document) AS document
-                    ),
-                    similarities AS (
-                        SELECT
-                            query_number,
-                            1.0 - (bit_count(document # query)::float / greatest(bit_length(query), 1)::float) AS similarity
-                        FROM queries CROSS JOIN documents
-                    ),
-                    max_similarities AS (
-                        SELECT MAX(similarity) AS max_similarity FROM similarities GROUP BY query_number
-                    )
-                    SELECT SUM(max_similarity) FROM max_similarities
-                    $$ LANGUAGE SQL
-                    """
-                )
-                logger.info("Created max_sim function successfully")
+                with self.get_connection() as conn:
+                    conn.execute(
+                        """
+                        DROP FUNCTION IF EXISTS max_sim(bit[], bit[])
+                        """
+                    )
+                    logger.info("Dropped existing max_sim function")
+
+                    # Create max_sim function
+                    conn.execute(
+                        """
+                        CREATE OR REPLACE FUNCTION max_sim(document bit[], query bit[]) RETURNS double precision AS $$
+                        WITH queries AS (
+                            SELECT row_number() OVER () AS query_number, * FROM (SELECT unnest(query) AS query) AS foo
+                        ),
+                        documents AS (
+                            SELECT unnest(document) AS document
+                        ),
+                        similarities AS (
+                            SELECT
+                                query_number,
+                                1.0 - (bit_count(document # query)::float / greatest(bit_length(query), 1)::float) AS similarity
+                            FROM queries CROSS JOIN documents
+                        ),
+                        max_similarities AS (
+                            SELECT MAX(similarity) AS max_similarity FROM similarities GROUP BY query_number
+                        )
+                        SELECT SUM(max_similarity) FROM max_similarities
+                        $$ LANGUAGE SQL
+                        """
+                    )
+                    logger.info("Created max_sim function successfully")
             except Exception as e:
                 logger.error(f"Error creating max_sim function: {str(e)}")
                 # Continue even if function creation fails - it might already exist and be usable
@@ -161,11 +211,12 @@ class MultiVectorStore(BaseVectorStore):
             embeddings = embeddings.cpu().numpy()
         if isinstance(embeddings, list) and not isinstance(embeddings[0], np.ndarray):
             embeddings = np.array(embeddings)
-        # try:
+
+        # Add this check to ensure pgvector is registered for the connection
+        with self.get_connection() as conn:
+            register_vector(conn)
+
         return [Bit(embedding > 0) for embedding in embeddings]
-        # except Exception as e:
-        #     logger.error(f"Error quantizing embeddings: {str(e)}")
-        #     raise e
 
     async def store_embeddings(self, chunks: List[DocumentChunk]) -> Tuple[bool, List[str]]:
         """Store document chunks with their multi-vector embeddings."""
@@ -189,21 +240,22 @@ class MultiVectorStore(BaseVectorStore):
             # Create binary representation for each vector
             binary_embeddings = self._binary_quantize(embeddings)
 
-            # Insert into database
-            self.conn.execute(
-                """
-                INSERT INTO multi_vector_embeddings
-                (document_id, chunk_number, content, chunk_metadata, embeddings)
-                VALUES (%s, %s, %s, %s, %s)
-                """,
-                (
-                    chunk.document_id,
-                    chunk.chunk_number,
-                    chunk.content,
-                    str(chunk.metadata),
-                    binary_embeddings,
-                ),
-            )
+            # Insert into database with retry logic
+            with self.get_connection() as conn:
+                conn.execute(
+                    """
+                    INSERT INTO multi_vector_embeddings
+                    (document_id, chunk_number, content, chunk_metadata, embeddings)
+                    VALUES (%s, %s, %s, %s, %s)
+                    """,
+                    (
+                        chunk.document_id,
+                        chunk.chunk_number,
+                        chunk.content,
+                        str(chunk.metadata),
+                        binary_embeddings,
+                    ),
+                )
 
             stored_ids.append(f"{chunk.document_id}-{chunk.chunk_number}")
@@ -244,8 +296,9 @@ class MultiVectorStore(BaseVectorStore):
         query += " ORDER BY similarity DESC LIMIT %s"
         params.append(k)
 
-        # Execute query
-        result = self.conn.execute(query, params).fetchall()
+        # Execute query with retry logic
+        with self.get_connection() as conn:
+            result = conn.execute(query, params).fetchall()
 
         # Convert to DocumentChunks
         chunks = []
@@ -305,7 +358,8 @@ class MultiVectorStore(BaseVectorStore):
 
         logger.debug(f"Batch retrieving {len(chunk_identifiers)} chunks from multi-vector store")
 
-        result = self.conn.execute(query).fetchall()
+        with self.get_connection() as conn:
+            result = conn.execute(query).fetchall()
 
         # Convert to DocumentChunks
         chunks = []
@@ -339,9 +393,10 @@ class MultiVectorStore(BaseVectorStore):
             bool: True if the operation was successful, False otherwise
         """
         try:
-            # Delete all chunks for the specified document
+            # Delete all chunks for the specified document with retry logic
             query = f"DELETE FROM multi_vector_embeddings WHERE document_id = '{document_id}'"
-            self.conn.execute(query)
+            with self.get_connection() as conn:
+                conn.execute(query)
 
             logger.info(f"Deleted all chunks for document {document_id} from multi-vector store")
             return True
@@ -353,4 +408,8 @@ class MultiVectorStore(BaseVectorStore):
     def close(self):
         """Close the database connection."""
         if self.conn:
-            self.conn.close()
+            try:
+                self.conn.close()
+                self.conn = None
+            except Exception as e:
+                logger.error(f"Error closing connection: {str(e)}")
diff --git a/core/vector_store/pgvector_store.py b/core/vector_store/pgvector_store.py
index d3890ff..c6dd1a8 100644
--- a/core/vector_store/pgvector_store.py
+++ b/core/vector_store/pgvector_store.py
@@ -1,10 +1,14 @@
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, AsyncContextManager, Any
 import logging
+import time
+import asyncio
+from contextlib import asynccontextmanager
 from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
 from sqlalchemy.orm import sessionmaker, declarative_base
 from sqlalchemy import Column, String, Integer, Index, select, text
 from sqlalchemy.sql.expression import func
 from sqlalchemy.types import UserDefinedType
+from sqlalchemy.exc import OperationalError
 
 from .base_vector_store import BaseVectorStore
 from core.models.chunk import DocumentChunk
@@ -68,34 +72,194 @@ class PGVectorStore(BaseVectorStore):
     def __init__(
         self,
         uri: str,
+        max_retries: int = 3,
+        retry_delay: float = 1.0,
     ):
-        """Initialize PostgreSQL connection for vector storage."""
+        """Initialize PostgreSQL connection for vector storage.
+
+        Args:
+            uri: PostgreSQL connection URI
+            max_retries: Maximum number of connection retry attempts
+            retry_delay: Delay in seconds between retry attempts
+        """
+        # Use the URI exactly as provided without any modifications
+        # This ensures compatibility with Supabase and other PostgreSQL providers
+        logger.info(f"Initializing database engine with provided URI")
+
+        # Create the engine with the URI as is
         self.engine = create_async_engine(uri)
+
+        # Log success
+        logger.info("Created database engine successfully")
         self.async_session = sessionmaker(self.engine, class_=AsyncSession, expire_on_commit=False)
+        self.max_retries = max_retries
+        self.retry_delay = retry_delay
+
+    @asynccontextmanager
+    async def get_session_with_retry(self) -> AsyncContextManager[AsyncSession]:
+        """Get a SQLAlchemy async session with retry logic.
+
+        Yields:
+            AsyncSession: A SQLAlchemy async session
+
+        Raises:
+            OperationalError: If all connection attempts fail
+        """
+        attempt = 0
+        last_error = None
+
+        while attempt < self.max_retries:
+            try:
+                async with self.async_session() as session:
+                    # Test if the connection is valid with a simple query
+                    await session.execute(text("SELECT 1"))
+                    yield session
+                    return
+            except OperationalError as e:
+                last_error = e
+                attempt += 1
+                if attempt < self.max_retries:
+                    logger.warning(f"Database connection attempt {attempt} failed: {str(e)}. Retrying in {self.retry_delay} seconds...")
+                    await asyncio.sleep(self.retry_delay)
+
+        # If we get here, all retries failed
+        logger.error(f"All database connection attempts failed after {self.max_retries} retries: {str(last_error)}")
+        raise last_error
 
     async def initialize(self):
         """Initialize database tables and vector extension."""
         try:
+            # Import config to get vector dimensions
+            from core.config import get_settings
+            settings = get_settings()
+            dimensions = settings.VECTOR_DIMENSIONS
+
+            logger.info(f"Initializing PGVector store with {dimensions} dimensions")
+
+            # Use retry logic for initialization
+            attempt = 0
+            last_error = None
+
+            while attempt < self.max_retries:
+                try:
+                    async with self.engine.begin() as conn:
+                        # Enable pgvector extension
+                        await conn.execute(
+                            text("CREATE EXTENSION IF NOT EXISTS vector")
+                        )
+                        logger.info("Enabled pgvector extension")
+
+                    # Rest of initialization code follows
+                    break  # Success, exit the retry loop
+                except OperationalError as e:
+                    last_error = e
+                    attempt += 1
+                    if attempt < self.max_retries:
+                        logger.warning(f"Database initialization attempt {attempt} failed: {str(e)}. Retrying in {self.retry_delay} seconds...")
+                        await asyncio.sleep(self.retry_delay)
+                    else:
+                        logger.error(f"All database initialization attempts failed after {self.max_retries} retries: {str(last_error)}")
+                        raise last_error
+
+            # Continue with the rest of the initialization
             async with self.engine.begin() as conn:
-                # Enable pgvector extension
-                await conn.execute(
-                    func.create_extension("vector", schema="public", if_not_exists=True)
-                )
-
-                # Create tables and indexes
-                await conn.run_sync(Base.metadata.create_all)
-
-                # Create vector index
-                await conn.execute(
-                    text(
-                        """
-                        CREATE INDEX IF NOT EXISTS vector_idx
-                        ON vector_embeddings
-                        USING ivfflat (embedding vector_cosine_ops)
-                        WITH (lists = 100);
-                        """
-                    )
-                )
+
+                # Check if vector_embeddings table exists
+                check_table_sql = """
+                SELECT EXISTS (
+                    SELECT FROM information_schema.tables
+                    WHERE table_name = 'vector_embeddings'
+                );
+                """
+                result = await conn.execute(text(check_table_sql))
+                table_exists = result.scalar()
+
+                if table_exists:
+                    # Check current vector dimensions
+                    check_dim_sql = """
+                    SELECT atttypmod - 4 AS dimensions
+                    FROM pg_attribute a
+                    JOIN pg_class c ON a.attrelid = c.oid
+                    JOIN pg_type t ON a.atttypid = t.oid
+                    WHERE c.relname = 'vector_embeddings'
+                    AND a.attname = 'embedding'
+                    AND t.typname = 'vector';
+                    """
+                    result = await conn.execute(text(check_dim_sql))
+                    current_dim = result.scalar()
+
+                    if current_dim != dimensions:
+                        logger.info(f"Vector dimensions changed from {current_dim} to {dimensions}, recreating table")
+
+                        # Drop existing vector index if it exists
+                        await conn.execute(text("DROP INDEX IF EXISTS vector_idx;"))
+
+                        # Drop existing vector embeddings table
+                        await conn.execute(text("DROP TABLE IF EXISTS vector_embeddings;"))
+
+                        # Create vector embeddings table with proper vector column
+                        create_table_sql = f"""
+                        CREATE TABLE vector_embeddings (
+                            id SERIAL PRIMARY KEY,
+                            document_id VARCHAR(255) NOT NULL,
+                            chunk_number INTEGER NOT NULL,
+                            content TEXT NOT NULL,
+                            chunk_metadata TEXT,
+                            embedding vector({dimensions}) NOT NULL,
+                            created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
+                        );
+                        """
+                        await conn.execute(text(create_table_sql))
+                        logger.info(f"Created vector_embeddings table with vector({dimensions})")
+
+                        # Create indexes
+                        await conn.execute(text("CREATE INDEX idx_document_id ON vector_embeddings(document_id);"))
+
+                        # Create vector index
+                        await conn.execute(
+                            text(
+                                f"""
+                                CREATE INDEX vector_idx
+                                ON vector_embeddings
+                                USING ivfflat (embedding vector_cosine_ops)
+                                WITH (lists = 100);
+                                """
+                            )
+                        )
+                        logger.info("Created IVFFlat index on vector_embeddings")
+                    else:
+                        logger.info(f"Vector dimensions unchanged ({dimensions}), using existing table")
+                else:
+                    # Create tables and indexes if they don't exist
+                    create_table_sql = f"""
+                    CREATE TABLE vector_embeddings (
+                        id SERIAL PRIMARY KEY,
+                        document_id VARCHAR(255) NOT NULL,
+                        chunk_number INTEGER NOT NULL,
+                        content TEXT NOT NULL,
+                        chunk_metadata TEXT,
+                        embedding vector({dimensions}) NOT NULL,
+                        created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
+                    );
+                    """
+                    await conn.execute(text(create_table_sql))
+                    logger.info(f"Created vector_embeddings table with vector({dimensions})")
+
+                    # Create indexes
+                    await conn.execute(text("CREATE INDEX idx_document_id ON vector_embeddings(document_id);"))
+
+                    # Create vector index
+                    await conn.execute(
+                        text(
+                            f"""
+                            CREATE INDEX vector_idx
+                            ON vector_embeddings
+                            USING ivfflat (embedding vector_cosine_ops)
+                            WITH (lists = 100);
+                            """
+                        )
+                    )
+                    logger.info("Created IVFFlat index on vector_embeddings")
 
             logger.info("PGVector store initialized successfully")
             return True
@@ -109,7 +273,7 @@ class PGVectorStore(BaseVectorStore):
         if not chunks:
             return True, []
 
-        async with self.async_session() as session:
+        async with self.get_session_with_retry() as session:
             stored_ids = []
             for chunk in chunks:
                 if not chunk.embedding:
@@ -143,7 +307,7 @@ class PGVectorStore(BaseVectorStore):
     ) -> List[DocumentChunk]:
         """Find similar chunks using cosine similarity."""
         try:
-            async with self.async_session() as session:
+            async with self.get_session_with_retry() as session:
                 # Build query
                 query = select(VectorEmbedding).order_by(
                     VectorEmbedding.embedding.op("<->")(query_embedding)
@@ -196,7 +360,7 @@ class PGVectorStore(BaseVectorStore):
         if not chunk_identifiers:
             return []
 
-        async with self.async_session() as session:
+        async with self.get_session_with_retry() as session:
             # Create a list of OR conditions for the query
             conditions = []
             for doc_id, chunk_num in chunk_identifiers:
@@ -253,7 +417,7 @@ class PGVectorStore(BaseVectorStore):
             bool: True if the operation was successful, False otherwise
         """
         try:
-            async with self.async_session() as session:
+            async with self.get_session_with_retry() as session:
                 # Delete all chunks for the specified document
                 query = text(f"DELETE FROM vector_embeddings WHERE document_id = :doc_id")
                 await session.execute(query, {"doc_id": document_id})