Mirror of https://github.com/james-m-jordan/morphik-core.git, synced 2025-05-09 19:32:38 +00:00.

Commit 5ae396cddb (parent bf7c90164f): Squashed changes from hosted-service

Diff stat: core/api.py | 13
core/api.py
@@ -110,6 +110,19 @@ async def initialize_database():
         # We don't raise an exception here to allow the app to continue starting
         # even if there are initialization errors

+@app.on_event("startup")
+async def initialize_vector_store():
+    """Initialize vector store tables and indexes on application startup."""
+    logger.info("Initializing vector store...")
+    if hasattr(vector_store, 'initialize'):
+        success = await vector_store.initialize()
+        if success:
+            logger.info("Vector store initialization successful")
+        else:
+            logger.error("Vector store initialization failed")
+    else:
+        logger.warning("Vector store does not have an initialize method")
+
 # Initialize vector store
 match settings.VECTOR_STORE_PROVIDER:
     case "mongodb":
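For reference, the handler above only needs an app, a logger, and a vector_store object with an async initialize() in scope. A minimal self-contained sketch of the same startup pattern, where DummyVectorStore is a hypothetical stand-in rather than Morphik's real store:

    import logging

    from fastapi import FastAPI

    logger = logging.getLogger(__name__)
    app = FastAPI()


    class DummyVectorStore:
        """Hypothetical stand-in; the real store exposes the same async initialize()."""

        async def initialize(self) -> bool:
            return True


    vector_store = DummyVectorStore()


    @app.on_event("startup")
    async def initialize_vector_store():
        """Create vector store tables/indexes once, when the app boots."""
        logger.info("Initializing vector store...")
        if hasattr(vector_store, "initialize"):
            success = await vector_store.initialize()
            if success:
                logger.info("Vector store initialization successful")
            else:
                logger.error("Vector store initialization failed")
        else:
            logger.warning("Vector store does not have an initialize method")

Running the hook on every boot relies on initialize() being idempotent, which the stores aim for with IF NOT EXISTS checks.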
@@ -151,6 +151,8 @@ class DocumentService:
         # by filtering for only the chunks which weren't ingested via colpali...
         if len(chunks_multivector) == 0:
             return chunks
+        if len(chunks) == 0:
+            return chunks_multivector

         # TODO: this is duct tape, fix it properly later
@@ -1,7 +1,9 @@
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union, ContextManager
 import logging
 import torch
 import numpy as np
+import time
+from contextlib import contextmanager
 import psycopg
 from pgvector.psycopg import Bit, register_vector
 from core.models.chunk import DocumentChunk
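The two new imports exist to support the retry-wrapped connection helper added in the next hunk. A standalone sketch of that pattern, assuming only psycopg and a reachable connection URI (connection_with_retry is an illustrative name, not the method the class actually adds):

    import logging
    import time
    from contextlib import contextmanager

    import psycopg

    logger = logging.getLogger(__name__)


    @contextmanager
    def connection_with_retry(uri: str, max_retries: int = 3, retry_delay: float = 1.0):
        """Yield a fresh autocommit connection, retrying transient failures."""
        last_error = None
        for attempt in range(1, max_retries + 1):
            try:
                conn = psycopg.connect(uri, autocommit=True)
                try:
                    yield conn
                    return
                finally:
                    conn.close()  # always release the connection after use
            except psycopg.OperationalError as e:
                last_error = e
                if attempt < max_retries:
                    logger.warning("Attempt %d failed: %s; retrying in %.1fs", attempt, e, retry_delay)
                    time.sleep(retry_delay)
        raise last_error

Opening a fresh connection per operation trades pooling efficiency for resilience against dropped connections, which is the same trade-off the commit makes.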
@@ -16,135 +18,183 @@ class MultiVectorStore(BaseVectorStore):
     def __init__(
         self,
         uri: str,
+        max_retries: int = 3,
+        retry_delay: float = 1.0,
     ):
         """Initialize PostgreSQL connection for multi-vector storage.

         Args:
             uri: PostgreSQL connection URI
+            max_retries: Maximum number of connection retry attempts
+            retry_delay: Delay in seconds between retry attempts
         """
         # Convert SQLAlchemy URI to psycopg format if needed
         if uri.startswith("postgresql+asyncpg://"):
             uri = uri.replace("postgresql+asyncpg://", "postgresql://")
         self.uri = uri
-        # self.conn = psycopg.connect(self.uri, autocommit=True)
         self.conn = None
+        self.max_retries = max_retries
+        self.retry_delay = retry_delay
         self.initialize()

+    @contextmanager
+    def get_connection(self):
+        """Get a PostgreSQL connection with retry logic.
+
+        Yields:
+            A PostgreSQL connection object
+
+        Raises:
+            psycopg.OperationalError: If all connection attempts fail
+        """
+        attempt = 0
+        last_error = None
+
+        # Try to establish a new connection with retries
+        while attempt < self.max_retries:
+            try:
+                # Always create a fresh connection for critical operations
+                conn = psycopg.connect(self.uri, autocommit=True)
+                # Register vector extension on every new connection
+                register_vector(conn)
+
+                try:
+                    yield conn
+                    return
+                finally:
+                    # Always close connections after use
+                    try:
+                        conn.close()
+                    except Exception:
+                        pass
+            except psycopg.OperationalError as e:
+                last_error = e
+                attempt += 1
+                if attempt < self.max_retries:
+                    logger.warning(f"Connection attempt {attempt} failed: {str(e)}. Retrying in {self.retry_delay} seconds...")
+                    time.sleep(self.retry_delay)
+
+        # If we get here, all retries failed
+        logger.error(f"All connection attempts failed after {self.max_retries} retries: {str(last_error)}")
+        raise last_error
+
     def initialize(self):
         """Initialize database tables and max_sim function."""
         try:
-            # Connect to database
-            self.conn = psycopg.connect(self.uri, autocommit=True)
-            # Register vector extension
-            self.conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
-            register_vector(self.conn)
+            # Use the connection with retry logic
+            with self.get_connection() as conn:
+                # Register vector extension
+                conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
+                register_vector(conn)

             # First check if the table exists and if it has the required columns
-            check_table = self.conn.execute(
-                """
-                SELECT EXISTS (
-                    SELECT FROM information_schema.tables
-                    WHERE table_name = 'multi_vector_embeddings'
-                );
-                """
-            ).fetchone()[0]
-
-            if check_table:
-                # Check if document_id column exists
-                has_document_id = self.conn.execute(
-                    """
-                    SELECT EXISTS (
-                        SELECT FROM information_schema.columns
-                        WHERE table_name = 'multi_vector_embeddings' AND column_name = 'document_id'
-                    );
-                    """
-                ).fetchone()[0]
-
-                # If the table exists but doesn't have document_id, we need to add the required columns
-                if not has_document_id:
-                    logger.info("Updating multi_vector_embeddings table with required columns")
-                    self.conn.execute(
-                        """
-                        ALTER TABLE multi_vector_embeddings
-                        ADD COLUMN document_id TEXT,
-                        ADD COLUMN chunk_number INTEGER,
-                        ADD COLUMN content TEXT,
-                        ADD COLUMN chunk_metadata TEXT
-                        """
-                    )
-                    self.conn.execute(
-                        """
-                        ALTER TABLE multi_vector_embeddings
-                        ALTER COLUMN document_id SET NOT NULL
-                        """
-                    )
-                    # Add a commit to ensure changes are applied
-                    self.conn.commit()
-            else:
-                # Create table if it doesn't exist with all required columns
-                self.conn.execute(
-                    """
-                    CREATE TABLE IF NOT EXISTS multi_vector_embeddings (
-                        id BIGSERIAL PRIMARY KEY,
-                        document_id TEXT NOT NULL,
-                        chunk_number INTEGER NOT NULL,
-                        content TEXT NOT NULL,
-                        chunk_metadata TEXT,
-                        embeddings BIT(128)[]
-                    )
-                    """
-                )
-                # Add a commit to ensure table creation is complete
-                self.conn.commit()
+            with self.get_connection() as conn:
+                check_table = conn.execute(
+                    """
+                    SELECT EXISTS (
+                        SELECT FROM information_schema.tables
+                        WHERE table_name = 'multi_vector_embeddings'
+                    );
+                    """
+                ).fetchone()[0]
+
+                if check_table:
+                    # Check if document_id column exists
+                    has_document_id = conn.execute(
+                        """
+                        SELECT EXISTS (
+                            SELECT FROM information_schema.columns
+                            WHERE table_name = 'multi_vector_embeddings' AND column_name = 'document_id'
+                        );
+                        """
+                    ).fetchone()[0]
+
+                    # If the table exists but doesn't have document_id, we need to add the required columns
+                    if not has_document_id:
+                        logger.info("Updating multi_vector_embeddings table with required columns")
+                        conn.execute(
+                            """
+                            ALTER TABLE multi_vector_embeddings
+                            ADD COLUMN document_id TEXT,
+                            ADD COLUMN chunk_number INTEGER,
+                            ADD COLUMN content TEXT,
+                            ADD COLUMN chunk_metadata TEXT
+                            """
+                        )
+                        conn.execute(
+                            """
+                            ALTER TABLE multi_vector_embeddings
+                            ALTER COLUMN document_id SET NOT NULL
+                            """
+                        )
+                        # Add a commit to ensure changes are applied
+                        conn.commit()
+                else:
+                    # Create table if it doesn't exist with all required columns
+                    conn.execute(
+                        """
+                        CREATE TABLE IF NOT EXISTS multi_vector_embeddings (
+                            id BIGSERIAL PRIMARY KEY,
+                            document_id TEXT NOT NULL,
+                            chunk_number INTEGER NOT NULL,
+                            content TEXT NOT NULL,
+                            chunk_metadata TEXT,
+                            embeddings BIT(128)[]
+                        )
+                        """
+                    )
+                    # Add a commit to ensure table creation is complete
+                    conn.commit()

             try:
                 # Create index on document_id
-                self.conn.execute(
-                    """
-                    CREATE INDEX IF NOT EXISTS idx_multi_vector_document_id
-                    ON multi_vector_embeddings (document_id)
-                    """
-                )
+                with self.get_connection() as conn:
+                    conn.execute(
+                        """
+                        CREATE INDEX IF NOT EXISTS idx_multi_vector_document_id
+                        ON multi_vector_embeddings (document_id)
+                        """
+                    )
             except Exception as e:
                 # Log index creation failure but continue
                 logger.warning(f"Failed to create index: {str(e)}")

             try:
                 # First, try to drop the existing function if it exists
-                self.conn.execute(
-                    """
-                    DROP FUNCTION IF EXISTS max_sim(bit[], bit[])
-                    """
-                )
-                logger.info("Dropped existing max_sim function")
-
-                # Create max_sim function
-                self.conn.execute(
-                    """
-                    CREATE OR REPLACE FUNCTION max_sim(document bit[], query bit[]) RETURNS double precision AS $$
-                    WITH queries AS (
-                        SELECT row_number() OVER () AS query_number, * FROM (SELECT unnest(query) AS query) AS foo
-                    ),
-                    documents AS (
-                        SELECT unnest(document) AS document
-                    ),
-                    similarities AS (
-                        SELECT
-                            query_number,
-                            1.0 - (bit_count(document # query)::float / greatest(bit_length(query), 1)::float) AS similarity
-                        FROM queries CROSS JOIN documents
-                    ),
-                    max_similarities AS (
-                        SELECT MAX(similarity) AS max_similarity FROM similarities GROUP BY query_number
-                    )
-                    SELECT SUM(max_similarity) FROM max_similarities
-                    $$ LANGUAGE SQL
-                    """
-                )
-                logger.info("Created max_sim function successfully")
+                with self.get_connection() as conn:
+                    conn.execute(
+                        """
+                        DROP FUNCTION IF EXISTS max_sim(bit[], bit[])
+                        """
+                    )
+                    logger.info("Dropped existing max_sim function")
+
+                    # Create max_sim function
+                    conn.execute(
+                        """
+                        CREATE OR REPLACE FUNCTION max_sim(document bit[], query bit[]) RETURNS double precision AS $$
+                        WITH queries AS (
+                            SELECT row_number() OVER () AS query_number, * FROM (SELECT unnest(query) AS query) AS foo
+                        ),
+                        documents AS (
+                            SELECT unnest(document) AS document
+                        ),
+                        similarities AS (
+                            SELECT
+                                query_number,
+                                1.0 - (bit_count(document # query)::float / greatest(bit_length(query), 1)::float) AS similarity
+                            FROM queries CROSS JOIN documents
+                        ),
+                        max_similarities AS (
+                            SELECT MAX(similarity) AS max_similarity FROM similarities GROUP BY query_number
+                        )
+                        SELECT SUM(max_similarity) FROM max_similarities
+                        $$ LANGUAGE SQL
+                        """
+                    )
+                    logger.info("Created max_sim function successfully")
             except Exception as e:
                 logger.error(f"Error creating max_sim function: {str(e)}")
                 # Continue even if function creation fails - it might already exist and be usable
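The max_sim function above implements late-interaction (ColBERT/ColPali-style MaxSim) scoring over binary embeddings: each query vector is matched to its best document vector by a Hamming-based similarity, and the per-query maxima are summed. A rough NumPy equivalent for intuition only; it is not used by the store:

    import numpy as np


    def max_sim(document_bits: np.ndarray, query_bits: np.ndarray) -> float:
        """document_bits: (n_doc, dim) bool array; query_bits: (n_query, dim) bool array."""
        dim = query_bits.shape[1]
        # Same similarity as the SQL: 1 - popcount(xor) / bit_length(query)
        xor = document_bits[None, :, :] ^ query_bits[:, None, :]   # (n_query, n_doc, dim)
        similarity = 1.0 - xor.sum(axis=-1) / max(dim, 1)          # (n_query, n_doc)
        # Best document vector per query vector, summed over query vectors
        return float(similarity.max(axis=1).sum())


    rng = np.random.default_rng(0)
    doc_vectors = rng.integers(0, 2, size=(3, 8)).astype(bool)     # 3 document vectors of 8 bits
    query_vectors = rng.integers(0, 2, size=(2, 8)).astype(bool)   # 2 query vectors of 8 bits
    print(max_sim(doc_vectors, query_vectors))                     # score in [0, n_query]

The score grows with the number of query vectors, which is fine for ranking chunks against a fixed query.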
@@ -161,11 +211,12 @@ class MultiVectorStore(BaseVectorStore):
             embeddings = embeddings.cpu().numpy()
         if isinstance(embeddings, list) and not isinstance(embeddings[0], np.ndarray):
             embeddings = np.array(embeddings)
-        # try:
+        # Add this check to ensure pgvector is registered for the connection
+        with self.get_connection() as conn:
+            register_vector(conn)
+
         return [Bit(embedding > 0) for embedding in embeddings]
-        # except Exception as e:
-        #     logger.error(f"Error quantizing embeddings: {str(e)}")
-        #     raise e

     async def store_embeddings(self, chunks: List[DocumentChunk]) -> Tuple[bool, List[str]]:
         """Store document chunks with their multi-vector embeddings."""
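For intuition, Bit(embedding > 0) thresholds every dimension at zero and packs the resulting booleans into a pgvector bit value, one per token vector. A small standalone sketch with made-up numbers:

    import numpy as np
    from pgvector.psycopg import Bit

    # One multi-vector embedding: 3 token vectors of 8 dimensions each
    embeddings = np.array([
        [ 0.3, -0.1,  0.7, -0.9,  0.0,  0.2, -0.4,  0.5],
        [-0.2,  0.8, -0.6,  0.1,  0.4, -0.3,  0.9, -0.7],
        [ 0.6, -0.5,  0.2, -0.1, -0.8,  0.7,  0.3, -0.2],
    ])

    # The same expression the store uses: sign-threshold each dimension
    binary = [Bit(vec > 0) for vec in embeddings]
    print(binary[0])  # renders as a bit string, e.g. 10100101 for the first vector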
@@ -189,21 +240,22 @@ class MultiVectorStore(BaseVectorStore):
                 # Create binary representation for each vector
                 binary_embeddings = self._binary_quantize(embeddings)

-                # Insert into database
-                self.conn.execute(
-                    """
-                    INSERT INTO multi_vector_embeddings
-                    (document_id, chunk_number, content, chunk_metadata, embeddings)
-                    VALUES (%s, %s, %s, %s, %s)
-                    """,
-                    (
-                        chunk.document_id,
-                        chunk.chunk_number,
-                        chunk.content,
-                        str(chunk.metadata),
-                        binary_embeddings,
-                    ),
-                )
+                # Insert into database with retry logic
+                with self.get_connection() as conn:
+                    conn.execute(
+                        """
+                        INSERT INTO multi_vector_embeddings
+                        (document_id, chunk_number, content, chunk_metadata, embeddings)
+                        VALUES (%s, %s, %s, %s, %s)
+                        """,
+                        (
+                            chunk.document_id,
+                            chunk.chunk_number,
+                            chunk.content,
+                            str(chunk.metadata),
+                            binary_embeddings,
+                        ),
+                    )

                 stored_ids.append(f"{chunk.document_id}-{chunk.chunk_number}")
@@ -244,8 +296,9 @@ class MultiVectorStore(BaseVectorStore):
             query += " ORDER BY similarity DESC LIMIT %s"
             params.append(k)

-            # Execute query
-            result = self.conn.execute(query, params).fetchall()
+            # Execute query with retry logic
+            with self.get_connection() as conn:
+                result = conn.execute(query, params).fetchall()

             # Convert to DocumentChunks
             chunks = []
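The query string executed here is assembled in earlier lines that this hunk does not show; it presumably ranks rows with the max_sim function defined above. A purely hypothetical reconstruction of a query of that shape, for illustration only (column list and filter handling are guesses):

    # Hypothetical shape of the query this hunk executes; not the actual source.
    binary_query = self._binary_quantize(query_embeddings)
    query = (
        "SELECT document_id, chunk_number, content, chunk_metadata, "
        "max_sim(embeddings, %s) AS similarity FROM multi_vector_embeddings"
    )
    params = [binary_query]
    # ...optional metadata filters would be appended here...
    query += " ORDER BY similarity DESC LIMIT %s"
    params.append(k)
    with self.get_connection() as conn:
        result = conn.execute(query, params).fetchall()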
@@ -305,7 +358,8 @@ class MultiVectorStore(BaseVectorStore):

             logger.debug(f"Batch retrieving {len(chunk_identifiers)} chunks from multi-vector store")

-            result = self.conn.execute(query).fetchall()
+            with self.get_connection() as conn:
+                result = conn.execute(query).fetchall()

             # Convert to DocumentChunks
             chunks = []
@@ -339,9 +393,10 @@ class MultiVectorStore(BaseVectorStore):
             bool: True if the operation was successful, False otherwise
         """
         try:
-            # Delete all chunks for the specified document
+            # Delete all chunks for the specified document with retry logic
             query = f"DELETE FROM multi_vector_embeddings WHERE document_id = '{document_id}'"
-            self.conn.execute(query)
+            with self.get_connection() as conn:
+                conn.execute(query)

             logger.info(f"Deleted all chunks for document {document_id} from multi-vector store")
             return True
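Note that the delete statement interpolates document_id directly into the SQL string. A bound-parameter variant, offered here as an editor's suggestion rather than part of the commit, would let psycopg handle quoting:

    # Suggested alternative (not in the commit): bind the value instead of
    # formatting it into the SQL string.
    with self.get_connection() as conn:
        conn.execute(
            "DELETE FROM multi_vector_embeddings WHERE document_id = %s",
            (document_id,),
        )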
@@ -353,4 +408,8 @@ class MultiVectorStore(BaseVectorStore):
     def close(self):
         """Close the database connection."""
         if self.conn:
-            self.conn.close()
+            try:
+                self.conn.close()
+                self.conn = None
+            except Exception as e:
+                logger.error(f"Error closing connection: {str(e)}")
@@ -1,10 +1,14 @@
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, AsyncContextManager, Any
 import logging
+import time
+import asyncio
+from contextlib import asynccontextmanager
 from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
 from sqlalchemy.orm import sessionmaker, declarative_base
 from sqlalchemy import Column, String, Integer, Index, select, text
 from sqlalchemy.sql.expression import func
 from sqlalchemy.types import UserDefinedType
+from sqlalchemy.exc import OperationalError

 from .base_vector_store import BaseVectorStore
 from core.models.chunk import DocumentChunk
@@ -68,34 +72,194 @@ class PGVectorStore(BaseVectorStore):
     def __init__(
         self,
         uri: str,
+        max_retries: int = 3,
+        retry_delay: float = 1.0,
     ):
-        """Initialize PostgreSQL connection for vector storage."""
+        """Initialize PostgreSQL connection for vector storage.
+
+        Args:
+            uri: PostgreSQL connection URI
+            max_retries: Maximum number of connection retry attempts
+            retry_delay: Delay in seconds between retry attempts
+        """
+        # Use the URI exactly as provided without any modifications
+        # This ensures compatibility with Supabase and other PostgreSQL providers
+        logger.info(f"Initializing database engine with provided URI")
+
+        # Create the engine with the URI as is
         self.engine = create_async_engine(uri)
+
+        # Log success
+        logger.info("Created database engine successfully")
         self.async_session = sessionmaker(self.engine, class_=AsyncSession, expire_on_commit=False)
+        self.max_retries = max_retries
+        self.retry_delay = retry_delay
+
+    @asynccontextmanager
+    async def get_session_with_retry(self) -> AsyncContextManager[AsyncSession]:
+        """Get a SQLAlchemy async session with retry logic.
+
+        Yields:
+            AsyncSession: A SQLAlchemy async session
+
+        Raises:
+            OperationalError: If all connection attempts fail
+        """
+        attempt = 0
+        last_error = None
+
+        while attempt < self.max_retries:
+            try:
+                async with self.async_session() as session:
+                    # Test if the connection is valid with a simple query
+                    await session.execute(text("SELECT 1"))
+                    yield session
+                    return
+            except OperationalError as e:
+                last_error = e
+                attempt += 1
+                if attempt < self.max_retries:
+                    logger.warning(f"Database connection attempt {attempt} failed: {str(e)}. Retrying in {self.retry_delay} seconds...")
+                    await asyncio.sleep(self.retry_delay)
+
+        # If we get here, all retries failed
+        logger.error(f"All database connection attempts failed after {self.max_retries} retries: {str(last_error)}")
+        raise last_error

     async def initialize(self):
         """Initialize database tables and vector extension."""
         try:
+            # Import config to get vector dimensions
+            from core.config import get_settings
+            settings = get_settings()
+            dimensions = settings.VECTOR_DIMENSIONS
+
+            logger.info(f"Initializing PGVector store with {dimensions} dimensions")
+
+            # Use retry logic for initialization
+            attempt = 0
+            last_error = None
+
+            while attempt < self.max_retries:
+                try:
+                    async with self.engine.begin() as conn:
+                        # Enable pgvector extension
+                        await conn.execute(
+                            text("CREATE EXTENSION IF NOT EXISTS vector")
+                        )
+                        logger.info("Enabled pgvector extension")
+
+                    # Rest of initialization code follows
+                    break  # Success, exit the retry loop
+                except OperationalError as e:
+                    last_error = e
+                    attempt += 1
+                    if attempt < self.max_retries:
+                        logger.warning(f"Database initialization attempt {attempt} failed: {str(e)}. Retrying in {self.retry_delay} seconds...")
+                        await asyncio.sleep(self.retry_delay)
+                    else:
+                        logger.error(f"All database initialization attempts failed after {self.max_retries} retries: {str(last_error)}")
+                        raise last_error
+
+            # Continue with the rest of the initialization
             async with self.engine.begin() as conn:
-                # Enable pgvector extension
-                await conn.execute(
-                    func.create_extension("vector", schema="public", if_not_exists=True)
-                )
-
-                # Create tables and indexes
-                await conn.run_sync(Base.metadata.create_all)
-
-                # Create vector index
-                await conn.execute(
-                    text(
-                        """
-                        CREATE INDEX IF NOT EXISTS vector_idx
-                        ON vector_embeddings
-                        USING ivfflat (embedding vector_cosine_ops)
-                        WITH (lists = 100);
-                        """
-                    )
-                )
+                # Check if vector_embeddings table exists
+                check_table_sql = """
+                SELECT EXISTS (
+                    SELECT FROM information_schema.tables
+                    WHERE table_name = 'vector_embeddings'
+                );
+                """
+                result = await conn.execute(text(check_table_sql))
+                table_exists = result.scalar()
+
+                if table_exists:
+                    # Check current vector dimensions
+                    check_dim_sql = """
+                    SELECT atttypmod - 4 AS dimensions
+                    FROM pg_attribute a
+                    JOIN pg_class c ON a.attrelid = c.oid
+                    JOIN pg_type t ON a.atttypid = t.oid
+                    WHERE c.relname = 'vector_embeddings'
+                    AND a.attname = 'embedding'
+                    AND t.typname = 'vector';
+                    """
+                    result = await conn.execute(text(check_dim_sql))
+                    current_dim = result.scalar()
+
+                    if current_dim != dimensions:
+                        logger.info(f"Vector dimensions changed from {current_dim} to {dimensions}, recreating table")
+
+                        # Drop existing vector index if it exists
+                        await conn.execute(text("DROP INDEX IF EXISTS vector_idx;"))
+
+                        # Drop existing vector embeddings table
+                        await conn.execute(text("DROP TABLE IF EXISTS vector_embeddings;"))
+
+                        # Create vector embeddings table with proper vector column
+                        create_table_sql = f"""
+                        CREATE TABLE vector_embeddings (
+                            id SERIAL PRIMARY KEY,
+                            document_id VARCHAR(255) NOT NULL,
+                            chunk_number INTEGER NOT NULL,
+                            content TEXT NOT NULL,
+                            chunk_metadata TEXT,
+                            embedding vector({dimensions}) NOT NULL,
+                            created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
+                        );
+                        """
+                        await conn.execute(text(create_table_sql))
+                        logger.info(f"Created vector_embeddings table with vector({dimensions})")
+
+                        # Create indexes
+                        await conn.execute(text("CREATE INDEX idx_document_id ON vector_embeddings(document_id);"))
+
+                        # Create vector index
+                        await conn.execute(
+                            text(
+                                f"""
+                                CREATE INDEX vector_idx
+                                ON vector_embeddings
+                                USING ivfflat (embedding vector_cosine_ops)
+                                WITH (lists = 100);
+                                """
+                            )
+                        )
+                        logger.info("Created IVFFlat index on vector_embeddings")
+                    else:
+                        logger.info(f"Vector dimensions unchanged ({dimensions}), using existing table")
+                else:
+                    # Create tables and indexes if they don't exist
+                    create_table_sql = f"""
+                    CREATE TABLE vector_embeddings (
+                        id SERIAL PRIMARY KEY,
+                        document_id VARCHAR(255) NOT NULL,
+                        chunk_number INTEGER NOT NULL,
+                        content TEXT NOT NULL,
+                        chunk_metadata TEXT,
+                        embedding vector({dimensions}) NOT NULL,
+                        created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
+                    );
+                    """
+                    await conn.execute(text(create_table_sql))
+                    logger.info(f"Created vector_embeddings table with vector({dimensions})")
+
+                    # Create indexes
+                    await conn.execute(text("CREATE INDEX idx_document_id ON vector_embeddings(document_id);"))
+
+                    # Create vector index
+                    await conn.execute(
+                        text(
+                            f"""
+                            CREATE INDEX vector_idx
+                            ON vector_embeddings
+                            USING ivfflat (embedding vector_cosine_ops)
+                            WITH (lists = 100);
+                            """
+                        )
+                    )
+                    logger.info("Created IVFFlat index on vector_embeddings")

             logger.info("PGVector store initialized successfully")
             return True
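A minimal way to exercise the new initialization path and retrying sessions outside the FastAPI app. This is a sketch: the connection string, database name, and the import location of PGVectorStore are assumptions:

    import asyncio

    from sqlalchemy import text

    # Hypothetical local URI; the real value comes from Morphik's settings
    DATABASE_URI = "postgresql+asyncpg://postgres:postgres@localhost:5432/morphik"


    async def main():
        # Assumes PGVectorStore has been imported from its vector store module
        store = PGVectorStore(uri=DATABASE_URI, max_retries=3, retry_delay=1.0)
        ok = await store.initialize()  # creates or migrates vector_embeddings
        print("initialized:", ok)

        # Later operations acquire sessions that retry transient connection errors
        async with store.get_session_with_retry() as session:
            count = await session.execute(text("SELECT count(*) FROM vector_embeddings"))
            print("stored chunks:", count.scalar())


    asyncio.run(main())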
@@ -109,7 +273,7 @@ class PGVectorStore(BaseVectorStore):
             if not chunks:
                 return True, []

-            async with self.async_session() as session:
+            async with self.get_session_with_retry() as session:
                 stored_ids = []
                 for chunk in chunks:
                     if not chunk.embedding:
@@ -143,7 +307,7 @@ class PGVectorStore(BaseVectorStore):
     ) -> List[DocumentChunk]:
         """Find similar chunks using cosine similarity."""
         try:
-            async with self.async_session() as session:
+            async with self.get_session_with_retry() as session:
                 # Build query
                 query = select(VectorEmbedding).order_by(
                     VectorEmbedding.embedding.op("<->")(query_embedding)
@@ -196,7 +360,7 @@ class PGVectorStore(BaseVectorStore):
             if not chunk_identifiers:
                 return []

-            async with self.async_session() as session:
+            async with self.get_session_with_retry() as session:
                 # Create a list of OR conditions for the query
                 conditions = []
                 for doc_id, chunk_num in chunk_identifiers:
@@ -253,7 +417,7 @@ class PGVectorStore(BaseVectorStore):
             bool: True if the operation was successful, False otherwise
         """
         try:
-            async with self.async_session() as session:
+            async with self.get_session_with_retry() as session:
                 # Delete all chunks for the specified document
                 query = text(f"DELETE FROM vector_embeddings WHERE document_id = :doc_id")
                 await session.execute(query, {"doc_id": document_id})