Squashed changes from hosted-service

Adityavardhan Agrawal 2025-04-03 12:54:54 -07:00
parent bf7c90164f
commit 5ae396cddb
4 changed files with 373 additions and 135 deletions

View File

@@ -110,6 +110,19 @@ async def initialize_database():
    # We don't raise an exception here to allow the app to continue starting
    # even if there are initialization errors


@app.on_event("startup")
async def initialize_vector_store():
    """Initialize vector store tables and indexes on application startup."""
    logger.info("Initializing vector store...")
    if hasattr(vector_store, 'initialize'):
        success = await vector_store.initialize()
        if success:
            logger.info("Vector store initialization successful")
        else:
            logger.error("Vector store initialization failed")
    else:
        logger.warning("Vector store does not have an initialize method")


# Initialize vector store
match settings.VECTOR_STORE_PROVIDER:
    case "mongodb":

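Note: the hook assumes the configured store exposes an async initialize() that reports success with a boolean. A minimal sketch of that contract, with an illustrative class name that is not part of this commit:

class ExampleVectorStore:
    async def initialize(self) -> bool:
        """Create any required tables and indexes; return True on success."""
        try:
            # run DDL here (CREATE TABLE / CREATE INDEX ...)
            return True
        except Exception:
            return False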
View File

@@ -151,6 +151,8 @@ class DocumentService:
        # by filtering for only the chunks which weren't ingested via colpali...
        if len(chunks_multivector) == 0:
            return chunks
        if len(chunks) == 0:
            return chunks_multivector

        # TODO: this is duct tape, fix it properly later

View File

@@ -1,7 +1,9 @@
from typing import List, Optional, Tuple, Union, ContextManager
import logging
import torch
import numpy as np
import time
from contextlib import contextmanager
import psycopg
from pgvector.psycopg import Bit, register_vector
from core.models.chunk import DocumentChunk
@@ -16,135 +18,183 @@ class MultiVectorStore(BaseVectorStore):
    def __init__(
        self,
        uri: str,
        max_retries: int = 3,
        retry_delay: float = 1.0,
    ):
        """Initialize PostgreSQL connection for multi-vector storage.

        Args:
            uri: PostgreSQL connection URI
            max_retries: Maximum number of connection retry attempts
            retry_delay: Delay in seconds between retry attempts
        """
        # Convert SQLAlchemy URI to psycopg format if needed
        if uri.startswith("postgresql+asyncpg://"):
            uri = uri.replace("postgresql+asyncpg://", "postgresql://")
        self.uri = uri
        # self.conn = psycopg.connect(self.uri, autocommit=True)
        self.conn = None
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.initialize()

    @contextmanager
    def get_connection(self):
        """Get a PostgreSQL connection with retry logic.

        Yields:
            A PostgreSQL connection object

        Raises:
            psycopg.OperationalError: If all connection attempts fail
        """
        attempt = 0
        last_error = None

        # Try to establish a new connection with retries
        while attempt < self.max_retries:
            try:
                # Always create a fresh connection for critical operations
                conn = psycopg.connect(self.uri, autocommit=True)
                # Register vector extension on every new connection
                register_vector(conn)
                try:
                    yield conn
                    return
                finally:
                    # Always close connections after use
                    try:
                        conn.close()
                    except Exception:
                        pass
            except psycopg.OperationalError as e:
                last_error = e
                attempt += 1
                if attempt < self.max_retries:
                    logger.warning(f"Connection attempt {attempt} failed: {str(e)}. Retrying in {self.retry_delay} seconds...")
                    time.sleep(self.retry_delay)

        # If we get here, all retries failed
        logger.error(f"All connection attempts failed after {self.max_retries} retries: {str(last_error)}")
        raise last_error

    def initialize(self):
        """Initialize database tables and max_sim function."""
        try:
            # Use the connection with retry logic
            with self.get_connection() as conn:
                # Register vector extension
                conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
                register_vector(conn)

            # First check if the table exists and if it has the required columns
            with self.get_connection() as conn:
                check_table = conn.execute(
                    """
                    SELECT EXISTS (
                        SELECT FROM information_schema.tables
                        WHERE table_name = 'multi_vector_embeddings'
                    );
                    """
                ).fetchone()[0]

                if check_table:
                    # Check if document_id column exists
                    has_document_id = conn.execute(
                        """
                        SELECT EXISTS (
                            SELECT FROM information_schema.columns
                            WHERE table_name = 'multi_vector_embeddings' AND column_name = 'document_id'
                        );
                        """
                    ).fetchone()[0]

                    # If the table exists but doesn't have document_id, we need to add the required columns
                    if not has_document_id:
                        logger.info("Updating multi_vector_embeddings table with required columns")
                        conn.execute(
                            """
                            ALTER TABLE multi_vector_embeddings
                            ADD COLUMN document_id TEXT,
                            ADD COLUMN chunk_number INTEGER,
                            ADD COLUMN content TEXT,
                            ADD COLUMN chunk_metadata TEXT
                            """
                        )
                        conn.execute(
                            """
                            ALTER TABLE multi_vector_embeddings
                            ALTER COLUMN document_id SET NOT NULL
                            """
                        )
                        # Add a commit to ensure changes are applied
                        conn.commit()
                else:
                    # Create table if it doesn't exist with all required columns
                    conn.execute(
                        """
                        CREATE TABLE IF NOT EXISTS multi_vector_embeddings (
                            id BIGSERIAL PRIMARY KEY,
                            document_id TEXT NOT NULL,
                            chunk_number INTEGER NOT NULL,
                            content TEXT NOT NULL,
                            chunk_metadata TEXT,
                            embeddings BIT(128)[]
                        )
                        """
                    )
                    # Add a commit to ensure table creation is complete
                    conn.commit()

            try:
                # Create index on document_id
                with self.get_connection() as conn:
                    conn.execute(
                        """
                        CREATE INDEX IF NOT EXISTS idx_multi_vector_document_id
                        ON multi_vector_embeddings (document_id)
                        """
                    )
            except Exception as e:
                # Log index creation failure but continue
                logger.warning(f"Failed to create index: {str(e)}")

            try:
                with self.get_connection() as conn:
                    # First, try to drop the existing function if it exists
                    conn.execute(
                        """
                        DROP FUNCTION IF EXISTS max_sim(bit[], bit[])
                        """
                    )
                    logger.info("Dropped existing max_sim function")

                    # Create max_sim function
                    conn.execute(
                        """
                        CREATE OR REPLACE FUNCTION max_sim(document bit[], query bit[]) RETURNS double precision AS $$
                            WITH queries AS (
                                SELECT row_number() OVER () AS query_number, * FROM (SELECT unnest(query) AS query) AS foo
                            ),
                            documents AS (
                                SELECT unnest(document) AS document
                            ),
                            similarities AS (
                                SELECT
                                    query_number,
                                    1.0 - (bit_count(document # query)::float / greatest(bit_length(query), 1)::float) AS similarity
                                FROM queries CROSS JOIN documents
                            ),
                            max_similarities AS (
                                SELECT MAX(similarity) AS max_similarity FROM similarities GROUP BY query_number
                            )
                            SELECT SUM(max_similarity) FROM max_similarities
                        $$ LANGUAGE SQL
                        """
                    )
                    logger.info("Created max_sim function successfully")
            except Exception as e:
                logger.error(f"Error creating max_sim function: {str(e)}")
                # Continue even if function creation fails - it might already exist and be usable
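For context, a minimal sketch of how the max_sim function defined above can be used to score stored rows against a binarized query. The SQL and the `store`/`query_bits` names are illustrative; the actual query_similar implementation is not part of this hunk:

# query_bits would come from _binary_quantize(query_embeddings)
with store.get_connection() as conn:
    rows = conn.execute(
        """
        SELECT document_id, chunk_number, max_sim(embeddings, %s) AS similarity
        FROM multi_vector_embeddings
        ORDER BY similarity DESC
        LIMIT %s
        """,
        (query_bits, 10),
    ).fetchall()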
@@ -161,11 +211,12 @@ class MultiVectorStore(BaseVectorStore):
            embeddings = embeddings.cpu().numpy()
        if isinstance(embeddings, list) and not isinstance(embeddings[0], np.ndarray):
            embeddings = np.array(embeddings)

        # Add this check to ensure pgvector is registered for the connection
        with self.get_connection() as conn:
            register_vector(conn)

        return [Bit(embedding > 0) for embedding in embeddings]

    async def store_embeddings(self, chunks: List[DocumentChunk]) -> Tuple[bool, List[str]]:
        """Store document chunks with their multi-vector embeddings."""
@@ -189,21 +240,22 @@ class MultiVectorStore(BaseVectorStore):
                # Create binary representation for each vector
                binary_embeddings = self._binary_quantize(embeddings)

                # Insert into database with retry logic
                with self.get_connection() as conn:
                    conn.execute(
                        """
                        INSERT INTO multi_vector_embeddings
                        (document_id, chunk_number, content, chunk_metadata, embeddings)
                        VALUES (%s, %s, %s, %s, %s)
                        """,
                        (
                            chunk.document_id,
                            chunk.chunk_number,
                            chunk.content,
                            str(chunk.metadata),
                            binary_embeddings,
                        ),
                    )

                stored_ids.append(f"{chunk.document_id}-{chunk.chunk_number}")
@@ -244,8 +296,9 @@ class MultiVectorStore(BaseVectorStore):
            query += " ORDER BY similarity DESC LIMIT %s"
            params.append(k)

            # Execute query with retry logic
            with self.get_connection() as conn:
                result = conn.execute(query, params).fetchall()

            # Convert to DocumentChunks
            chunks = []
@@ -305,7 +358,8 @@ class MultiVectorStore(BaseVectorStore):
        logger.debug(f"Batch retrieving {len(chunk_identifiers)} chunks from multi-vector store")

        with self.get_connection() as conn:
            result = conn.execute(query).fetchall()

        # Convert to DocumentChunks
        chunks = []
@@ -339,9 +393,10 @@ class MultiVectorStore(BaseVectorStore):
            bool: True if the operation was successful, False otherwise
        """
        try:
            # Delete all chunks for the specified document with retry logic
            query = f"DELETE FROM multi_vector_embeddings WHERE document_id = '{document_id}'"
            with self.get_connection() as conn:
                conn.execute(query)

            logger.info(f"Deleted all chunks for document {document_id} from multi-vector store")
            return True
@@ -353,4 +408,8 @@ class MultiVectorStore(BaseVectorStore):
    def close(self):
        """Close the database connection."""
        if self.conn:
            try:
                self.conn.close()
                self.conn = None
            except Exception as e:
                logger.error(f"Error closing connection: {str(e)}")

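A minimal usage sketch of the retry-wrapped connection helper introduced in this file; the URI is a placeholder:

store = MultiVectorStore(uri="postgresql://user:pass@localhost:5432/db", max_retries=3, retry_delay=1.0)
# Each call opens a fresh autocommit connection, registers pgvector, retries up to
# max_retries times on psycopg.OperationalError, and always closes the connection.
with store.get_connection() as conn:
    total = conn.execute("SELECT count(*) FROM multi_vector_embeddings").fetchone()[0]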
View File

@@ -1,10 +1,14 @@
from typing import List, Optional, Tuple, AsyncContextManager, Any
import logging
import time
import asyncio
from contextlib import asynccontextmanager
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker, declarative_base
from sqlalchemy import Column, String, Integer, Index, select, text
from sqlalchemy.sql.expression import func
from sqlalchemy.types import UserDefinedType
from sqlalchemy.exc import OperationalError
from .base_vector_store import BaseVectorStore
from core.models.chunk import DocumentChunk
@@ -68,34 +72,194 @@ class PGVectorStore(BaseVectorStore):
    def __init__(
        self,
        uri: str,
        max_retries: int = 3,
        retry_delay: float = 1.0,
    ):
        """Initialize PostgreSQL connection for vector storage.

        Args:
            uri: PostgreSQL connection URI
            max_retries: Maximum number of connection retry attempts
            retry_delay: Delay in seconds between retry attempts
        """
        # Use the URI exactly as provided without any modifications
        # This ensures compatibility with Supabase and other PostgreSQL providers
        logger.info("Initializing database engine with provided URI")

        # Create the engine with the URI as is
        self.engine = create_async_engine(uri)

        # Log success
        logger.info("Created database engine successfully")

        self.async_session = sessionmaker(self.engine, class_=AsyncSession, expire_on_commit=False)
        self.max_retries = max_retries
        self.retry_delay = retry_delay

    @asynccontextmanager
    async def get_session_with_retry(self) -> AsyncContextManager[AsyncSession]:
        """Get a SQLAlchemy async session with retry logic.

        Yields:
            AsyncSession: A SQLAlchemy async session

        Raises:
            OperationalError: If all connection attempts fail
        """
        attempt = 0
        last_error = None

        while attempt < self.max_retries:
            try:
                async with self.async_session() as session:
                    # Test if the connection is valid with a simple query
                    await session.execute(text("SELECT 1"))
                    yield session
                    return
            except OperationalError as e:
                last_error = e
                attempt += 1
                if attempt < self.max_retries:
                    logger.warning(f"Database connection attempt {attempt} failed: {str(e)}. Retrying in {self.retry_delay} seconds...")
                    await asyncio.sleep(self.retry_delay)

        # If we get here, all retries failed
        logger.error(f"All database connection attempts failed after {self.max_retries} retries: {str(last_error)}")
        raise last_error

    async def initialize(self):
        """Initialize database tables and vector extension."""
        try:
            # Import config to get vector dimensions
            from core.config import get_settings

            settings = get_settings()
            dimensions = settings.VECTOR_DIMENSIONS
            logger.info(f"Initializing PGVector store with {dimensions} dimensions")

            # Use retry logic for initialization
            attempt = 0
            last_error = None

            while attempt < self.max_retries:
                try:
                    async with self.engine.begin() as conn:
                        # Enable pgvector extension
                        await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
                        logger.info("Enabled pgvector extension")
                        # Rest of initialization code follows
                    break  # Success, exit the retry loop
                except OperationalError as e:
                    last_error = e
                    attempt += 1
                    if attempt < self.max_retries:
                        logger.warning(f"Database initialization attempt {attempt} failed: {str(e)}. Retrying in {self.retry_delay} seconds...")
                        await asyncio.sleep(self.retry_delay)
                    else:
                        logger.error(f"All database initialization attempts failed after {self.max_retries} retries: {str(last_error)}")
                        raise last_error

            # Continue with the rest of the initialization
            async with self.engine.begin() as conn:
                # Check if vector_embeddings table exists
                check_table_sql = """
                SELECT EXISTS (
                    SELECT FROM information_schema.tables
                    WHERE table_name = 'vector_embeddings'
                );
                """
                result = await conn.execute(text(check_table_sql))
                table_exists = result.scalar()

                if table_exists:
                    # Check current vector dimensions
                    check_dim_sql = """
                    SELECT atttypmod - 4 AS dimensions
                    FROM pg_attribute a
                    JOIN pg_class c ON a.attrelid = c.oid
                    JOIN pg_type t ON a.atttypid = t.oid
                    WHERE c.relname = 'vector_embeddings'
                    AND a.attname = 'embedding'
                    AND t.typname = 'vector';
                    """
                    result = await conn.execute(text(check_dim_sql))
                    current_dim = result.scalar()

                    if current_dim != dimensions:
                        logger.info(f"Vector dimensions changed from {current_dim} to {dimensions}, recreating table")
                        # Drop existing vector index if it exists
                        await conn.execute(text("DROP INDEX IF EXISTS vector_idx;"))
                        # Drop existing vector embeddings table
                        await conn.execute(text("DROP TABLE IF EXISTS vector_embeddings;"))
                        # Create vector embeddings table with proper vector column
                        create_table_sql = f"""
                        CREATE TABLE vector_embeddings (
                            id SERIAL PRIMARY KEY,
                            document_id VARCHAR(255) NOT NULL,
                            chunk_number INTEGER NOT NULL,
                            content TEXT NOT NULL,
                            chunk_metadata TEXT,
                            embedding vector({dimensions}) NOT NULL,
                            created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
                        );
                        """
                        await conn.execute(text(create_table_sql))
                        logger.info(f"Created vector_embeddings table with vector({dimensions})")

                        # Create indexes
                        await conn.execute(text("CREATE INDEX idx_document_id ON vector_embeddings(document_id);"))

                        # Create vector index
                        await conn.execute(
                            text(
                                """
                                CREATE INDEX vector_idx
                                ON vector_embeddings
                                USING ivfflat (embedding vector_cosine_ops)
                                WITH (lists = 100);
                                """
                            )
                        )
                        logger.info("Created IVFFlat index on vector_embeddings")
                    else:
                        logger.info(f"Vector dimensions unchanged ({dimensions}), using existing table")
                else:
                    # Create tables and indexes if they don't exist
                    create_table_sql = f"""
                    CREATE TABLE vector_embeddings (
                        id SERIAL PRIMARY KEY,
                        document_id VARCHAR(255) NOT NULL,
                        chunk_number INTEGER NOT NULL,
                        content TEXT NOT NULL,
                        chunk_metadata TEXT,
                        embedding vector({dimensions}) NOT NULL,
                        created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
                    );
                    """
                    await conn.execute(text(create_table_sql))
                    logger.info(f"Created vector_embeddings table with vector({dimensions})")

                    # Create indexes
                    await conn.execute(text("CREATE INDEX idx_document_id ON vector_embeddings(document_id);"))

                    # Create vector index
                    await conn.execute(
                        text(
                            """
                            CREATE INDEX vector_idx
                            ON vector_embeddings
                            USING ivfflat (embedding vector_cosine_ops)
                            WITH (lists = 100);
                            """
                        )
                    )
                    logger.info("Created IVFFlat index on vector_embeddings")

            logger.info("PGVector store initialized successfully")
            return True
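A minimal sketch of how the async store is expected to be brought up; the URI is a placeholder, and the `await store.initialize()` call mirrors what the new startup hook in the first file performs:

import asyncio

async def main():
    store = PGVectorStore(uri="postgresql+asyncpg://user:pass@localhost:5432/db")
    ok = await store.initialize()  # creates or recreates vector_embeddings to match settings.VECTOR_DIMENSIONS
    if ok:
        async with store.get_session_with_retry() as session:
            # the session is only yielded after a successful SELECT 1 probe
            ...

asyncio.run(main())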
@@ -109,7 +273,7 @@ class PGVectorStore(BaseVectorStore):
            if not chunks:
                return True, []

            async with self.get_session_with_retry() as session:
                stored_ids = []
                for chunk in chunks:
                    if not chunk.embedding:
@@ -143,7 +307,7 @@ class PGVectorStore(BaseVectorStore):
    ) -> List[DocumentChunk]:
        """Find similar chunks using cosine similarity."""
        try:
            async with self.get_session_with_retry() as session:
                # Build query
                query = select(VectorEmbedding).order_by(
                    VectorEmbedding.embedding.op("<->")(query_embedding)
@@ -196,7 +360,7 @@ class PGVectorStore(BaseVectorStore):
            if not chunk_identifiers:
                return []

            async with self.get_session_with_retry() as session:
                # Create a list of OR conditions for the query
                conditions = []
                for doc_id, chunk_num in chunk_identifiers:
@@ -253,7 +417,7 @@ class PGVectorStore(BaseVectorStore):
            bool: True if the operation was successful, False otherwise
        """
        try:
            async with self.get_session_with_retry() as session:
                # Delete all chunks for the specified document
                query = text(f"DELETE FROM vector_embeddings WHERE document_id = :doc_id")
                await session.execute(query, {"doc_id": document_id})