Mirror of https://github.com/james-m-jordan/morphik-core.git (synced 2025-05-09 19:32:38 +00:00)
Squashed changes from hosted-service

parent bf7c90164f
commit 5ae396cddb

core/api.py: 13 lines changed
@@ -110,6 +110,19 @@ async def initialize_database():
        # We don't raise an exception here to allow the app to continue starting
        # even if there are initialization errors


@app.on_event("startup")
async def initialize_vector_store():
    """Initialize vector store tables and indexes on application startup."""
    logger.info("Initializing vector store...")
    if hasattr(vector_store, 'initialize'):
        success = await vector_store.initialize()
        if success:
            logger.info("Vector store initialization successful")
        else:
            logger.error("Vector store initialization failed")
    else:
        logger.warning("Vector store does not have an initialize method")


# Initialize vector store
match settings.VECTOR_STORE_PROVIDER:
    case "mongodb":
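The handler above only runs when the application actually starts, so a quick way to exercise it locally is to enter FastAPI's TestClient as a context manager, which fires the startup and shutdown events. A minimal sketch, assuming the FastAPI app object is importable as core.api.app (the import path is an assumption, not taken from this commit):

from fastapi.testclient import TestClient

from core.api import app  # assumed import path for the FastAPI app


def test_startup_runs_vector_store_init():
    # Entering the client as a context manager fires @app.on_event("startup")
    # handlers, so initialize_vector_store() runs (or logs a warning if the
    # configured vector store has no initialize method).
    with TestClient(app):
        pass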
@@ -151,6 +151,8 @@ class DocumentService:
        # by filtering for only the chunks which weren't ingested via colpali...
        if len(chunks_multivector) == 0:
            return chunks
        if len(chunks) == 0:
            return chunks_multivector

        # TODO: this is duct tape, fix it properly later
@@ -1,7 +1,9 @@
from typing import List, Optional, Tuple, Union, ContextManager
import logging
import torch
import numpy as np
import time
from contextlib import contextmanager
import psycopg
from pgvector.psycopg import Bit, register_vector
from core.models.chunk import DocumentChunk
@@ -16,135 +18,183 @@ class MultiVectorStore(BaseVectorStore):
    def __init__(
        self,
        uri: str,
        max_retries: int = 3,
        retry_delay: float = 1.0,
    ):
        """Initialize PostgreSQL connection for multi-vector storage.

        Args:
            uri: PostgreSQL connection URI
            max_retries: Maximum number of connection retry attempts
            retry_delay: Delay in seconds between retry attempts
        """
        # Convert SQLAlchemy URI to psycopg format if needed
        if uri.startswith("postgresql+asyncpg://"):
            uri = uri.replace("postgresql+asyncpg://", "postgresql://")
        self.uri = uri
        # self.conn = psycopg.connect(self.uri, autocommit=True)
        self.conn = None
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.initialize()

    @contextmanager
    def get_connection(self):
        """Get a PostgreSQL connection with retry logic.

        Yields:
            A PostgreSQL connection object

        Raises:
            psycopg.OperationalError: If all connection attempts fail
        """
        attempt = 0
        last_error = None

        # Try to establish a new connection with retries
        while attempt < self.max_retries:
            try:
                # Always create a fresh connection for critical operations
                conn = psycopg.connect(self.uri, autocommit=True)
                # Register vector extension on every new connection
                register_vector(conn)

                try:
                    yield conn
                    return
                finally:
                    # Always close connections after use
                    try:
                        conn.close()
                    except Exception:
                        pass
            except psycopg.OperationalError as e:
                last_error = e
                attempt += 1
                if attempt < self.max_retries:
                    logger.warning(f"Connection attempt {attempt} failed: {str(e)}. Retrying in {self.retry_delay} seconds...")
                    time.sleep(self.retry_delay)

        # If we get here, all retries failed
        logger.error(f"All connection attempts failed after {self.max_retries} retries: {str(last_error)}")
        raise last_error

    def initialize(self):
        """Initialize database tables and max_sim function."""
        try:
            # Use the connection with retry logic
            with self.get_connection() as conn:
                # Register vector extension
                conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
                register_vector(conn)

            # First check if the table exists and if it has the required columns
            with self.get_connection() as conn:
                check_table = conn.execute(
                    """
                    SELECT EXISTS (
                        SELECT FROM information_schema.tables
                        WHERE table_name = 'multi_vector_embeddings'
                    );
                    """
                ).fetchone()[0]

                if check_table:
                    # Check if document_id column exists
                    has_document_id = conn.execute(
                        """
                        SELECT EXISTS (
                            SELECT FROM information_schema.columns
                            WHERE table_name = 'multi_vector_embeddings' AND column_name = 'document_id'
                        );
                        """
                    ).fetchone()[0]

                    # If the table exists but doesn't have document_id, we need to add the required columns
                    if not has_document_id:
                        logger.info("Updating multi_vector_embeddings table with required columns")
                        conn.execute(
                            """
                            ALTER TABLE multi_vector_embeddings
                            ADD COLUMN document_id TEXT,
                            ADD COLUMN chunk_number INTEGER,
                            ADD COLUMN content TEXT,
                            ADD COLUMN chunk_metadata TEXT
                            """
                        )
                        conn.execute(
                            """
                            ALTER TABLE multi_vector_embeddings
                            ALTER COLUMN document_id SET NOT NULL
                            """
                        )

                    # Add a commit to ensure changes are applied
                    conn.commit()
                else:
                    # Create table if it doesn't exist with all required columns
                    conn.execute(
                        """
                        CREATE TABLE IF NOT EXISTS multi_vector_embeddings (
                            id BIGSERIAL PRIMARY KEY,
                            document_id TEXT NOT NULL,
                            chunk_number INTEGER NOT NULL,
                            content TEXT NOT NULL,
                            chunk_metadata TEXT,
                            embeddings BIT(128)[]
                        )
                        """
                    )

                    # Add a commit to ensure table creation is complete
                    conn.commit()

            try:
                # Create index on document_id
                with self.get_connection() as conn:
                    conn.execute(
                        """
                        CREATE INDEX IF NOT EXISTS idx_multi_vector_document_id
                        ON multi_vector_embeddings (document_id)
                        """
                    )
            except Exception as e:
                # Log index creation failure but continue
                logger.warning(f"Failed to create index: {str(e)}")

            try:
                # First, try to drop the existing function if it exists
                with self.get_connection() as conn:
                    conn.execute(
                        """
                        DROP FUNCTION IF EXISTS max_sim(bit[], bit[])
                        """
                    )
                    logger.info("Dropped existing max_sim function")

                    # Create max_sim function
                    conn.execute(
                        """
                        CREATE OR REPLACE FUNCTION max_sim(document bit[], query bit[]) RETURNS double precision AS $$
                        WITH queries AS (
                            SELECT row_number() OVER () AS query_number, * FROM (SELECT unnest(query) AS query) AS foo
                        ),
                        documents AS (
                            SELECT unnest(document) AS document
                        ),
                        similarities AS (
                            SELECT
                                query_number,
                                1.0 - (bit_count(document # query)::float / greatest(bit_length(query), 1)::float) AS similarity
                            FROM queries CROSS JOIN documents
                        ),
                        max_similarities AS (
                            SELECT MAX(similarity) AS max_similarity FROM similarities GROUP BY query_number
                        )
                        SELECT SUM(max_similarity) FROM max_similarities
                        $$ LANGUAGE SQL
                        """
                    )
                    logger.info("Created max_sim function successfully")
            except Exception as e:
                logger.error(f"Error creating max_sim function: {str(e)}")
                # Continue even if function creation fails - it might already exist and be usable
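For reference, a minimal sketch of how the retrying context manager and the max_sim function fit together at query time. The connection URI, the query-embedding shape, and the exact SELECT are placeholders inferred from this hunk, not code from the commit; it assumes MultiVectorStore is imported from its module:

import numpy as np
from pgvector.psycopg import Bit

# Placeholder URI, not a value from this commit.
store = MultiVectorStore(uri="postgresql://user:pass@localhost:5432/morphik")

# Binary-quantize a (num_query_vectors x 128) query embedding the same way
# stored embeddings are quantized, then rank rows by summed maximum similarity.
query_embeddings = np.random.randn(4, 128)
query_bits = [Bit(v > 0) for v in query_embeddings]

with store.get_connection() as conn:
    rows = conn.execute(
        """
        SELECT document_id, chunk_number,
               max_sim(embeddings, %s) AS similarity
        FROM multi_vector_embeddings
        ORDER BY similarity DESC
        LIMIT 5
        """,
        (query_bits,),
    ).fetchall()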
@@ -161,11 +211,12 @@ class MultiVectorStore(BaseVectorStore):
            embeddings = embeddings.cpu().numpy()
        if isinstance(embeddings, list) and not isinstance(embeddings[0], np.ndarray):
            embeddings = np.array(embeddings)

        # Add this check to ensure pgvector is registered for the connection
        with self.get_connection() as conn:
            register_vector(conn)

        return [Bit(embedding > 0) for embedding in embeddings]

    async def store_embeddings(self, chunks: List[DocumentChunk]) -> Tuple[bool, List[str]]:
        """Store document chunks with their multi-vector embeddings."""
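The quantization step above maps each float dimension to a single bit by sign thresholding, which is what Bit(embedding > 0) does per vector. A small self-contained sketch of the same idea in plain NumPy (names and the example values are illustrative only):

import numpy as np


def sign_threshold_bits(embeddings: np.ndarray) -> np.ndarray:
    """Map each float dimension to True if positive, else False (one bool array per vector)."""
    return embeddings > 0


vectors = np.array([[0.3, -1.2, 0.0, 2.1],
                    [-0.5, 0.7, -0.1, 0.2]])
print(sign_threshold_bits(vectors).astype(int))
# [[1 0 0 1]
#  [0 1 0 1]]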
@@ -189,21 +240,22 @@ class MultiVectorStore(BaseVectorStore):
                # Create binary representation for each vector
                binary_embeddings = self._binary_quantize(embeddings)

                # Insert into database with retry logic
                with self.get_connection() as conn:
                    conn.execute(
                        """
                        INSERT INTO multi_vector_embeddings
                        (document_id, chunk_number, content, chunk_metadata, embeddings)
                        VALUES (%s, %s, %s, %s, %s)
                        """,
                        (
                            chunk.document_id,
                            chunk.chunk_number,
                            chunk.content,
                            str(chunk.metadata),
                            binary_embeddings,
                        ),
                    )

                stored_ids.append(f"{chunk.document_id}-{chunk.chunk_number}")
@@ -244,8 +296,9 @@ class MultiVectorStore(BaseVectorStore):
            query += " ORDER BY similarity DESC LIMIT %s"
            params.append(k)

            # Execute query with retry logic
            with self.get_connection() as conn:
                result = conn.execute(query, params).fetchall()

            # Convert to DocumentChunks
            chunks = []
@@ -305,7 +358,8 @@ class MultiVectorStore(BaseVectorStore):
            logger.debug(f"Batch retrieving {len(chunk_identifiers)} chunks from multi-vector store")

            with self.get_connection() as conn:
                result = conn.execute(query).fetchall()

            # Convert to DocumentChunks
            chunks = []
@@ -339,9 +393,10 @@ class MultiVectorStore(BaseVectorStore):
            bool: True if the operation was successful, False otherwise
        """
        try:
            # Delete all chunks for the specified document with retry logic
            query = f"DELETE FROM multi_vector_embeddings WHERE document_id = '{document_id}'"
            with self.get_connection() as conn:
                conn.execute(query)

            logger.info(f"Deleted all chunks for document {document_id} from multi-vector store")
            return True
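One observation on the delete path: document_id is interpolated directly into the SQL string. Since get_connection() yields a regular psycopg connection, the same statement can be issued with a bound parameter. A hedged alternative sketch of that method body, not what this commit does:

# Alternative sketch (not part of this commit): bind document_id as a parameter.
query = "DELETE FROM multi_vector_embeddings WHERE document_id = %s"
with self.get_connection() as conn:
    conn.execute(query, (document_id,))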
@@ -353,4 +408,8 @@ class MultiVectorStore(BaseVectorStore):
    def close(self):
        """Close the database connection."""
        if self.conn:
            try:
                self.conn.close()
                self.conn = None
            except Exception as e:
                logger.error(f"Error closing connection: {str(e)}")
@@ -1,10 +1,14 @@
from typing import List, Optional, Tuple, AsyncContextManager, Any
import logging
import time
import asyncio
from contextlib import asynccontextmanager
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker, declarative_base
from sqlalchemy import Column, String, Integer, Index, select, text
from sqlalchemy.sql.expression import func
from sqlalchemy.types import UserDefinedType
from sqlalchemy.exc import OperationalError

from .base_vector_store import BaseVectorStore
from core.models.chunk import DocumentChunk
@@ -68,34 +72,194 @@ class PGVectorStore(BaseVectorStore):
    def __init__(
        self,
        uri: str,
        max_retries: int = 3,
        retry_delay: float = 1.0,
    ):
        """Initialize PostgreSQL connection for vector storage.

        Args:
            uri: PostgreSQL connection URI
            max_retries: Maximum number of connection retry attempts
            retry_delay: Delay in seconds between retry attempts
        """
        # Use the URI exactly as provided without any modifications
        # This ensures compatibility with Supabase and other PostgreSQL providers
        logger.info(f"Initializing database engine with provided URI")

        # Create the engine with the URI as is
        self.engine = create_async_engine(uri)

        # Log success
        logger.info("Created database engine successfully")
        self.async_session = sessionmaker(self.engine, class_=AsyncSession, expire_on_commit=False)
        self.max_retries = max_retries
        self.retry_delay = retry_delay

    @asynccontextmanager
    async def get_session_with_retry(self) -> AsyncContextManager[AsyncSession]:
        """Get a SQLAlchemy async session with retry logic.

        Yields:
            AsyncSession: A SQLAlchemy async session

        Raises:
            OperationalError: If all connection attempts fail
        """
        attempt = 0
        last_error = None

        while attempt < self.max_retries:
            try:
                async with self.async_session() as session:
                    # Test if the connection is valid with a simple query
                    await session.execute(text("SELECT 1"))
                    yield session
                    return
            except OperationalError as e:
                last_error = e
                attempt += 1
                if attempt < self.max_retries:
                    logger.warning(f"Database connection attempt {attempt} failed: {str(e)}. Retrying in {self.retry_delay} seconds...")
                    await asyncio.sleep(self.retry_delay)

        # If we get here, all retries failed
        logger.error(f"All database connection attempts failed after {self.max_retries} retries: {str(last_error)}")
        raise last_error

    async def initialize(self):
        """Initialize database tables and vector extension."""
        try:
            # Import config to get vector dimensions
            from core.config import get_settings
            settings = get_settings()
            dimensions = settings.VECTOR_DIMENSIONS

            logger.info(f"Initializing PGVector store with {dimensions} dimensions")

            # Use retry logic for initialization
            attempt = 0
            last_error = None

            while attempt < self.max_retries:
                try:
                    async with self.engine.begin() as conn:
                        # Enable pgvector extension
                        await conn.execute(
                            text("CREATE EXTENSION IF NOT EXISTS vector")
                        )
                        logger.info("Enabled pgvector extension")

                    # Rest of initialization code follows
                    break  # Success, exit the retry loop
                except OperationalError as e:
                    last_error = e
                    attempt += 1
                    if attempt < self.max_retries:
                        logger.warning(f"Database initialization attempt {attempt} failed: {str(e)}. Retrying in {self.retry_delay} seconds...")
                        await asyncio.sleep(self.retry_delay)
                    else:
                        logger.error(f"All database initialization attempts failed after {self.max_retries} retries: {str(last_error)}")
                        raise last_error

            # Continue with the rest of the initialization
            async with self.engine.begin() as conn:
                # Check if vector_embeddings table exists
                check_table_sql = """
                SELECT EXISTS (
                    SELECT FROM information_schema.tables
                    WHERE table_name = 'vector_embeddings'
                );
                """
                result = await conn.execute(text(check_table_sql))
                table_exists = result.scalar()

                if table_exists:
                    # Check current vector dimensions
                    check_dim_sql = """
                    SELECT atttypmod - 4 AS dimensions
                    FROM pg_attribute a
                    JOIN pg_class c ON a.attrelid = c.oid
                    JOIN pg_type t ON a.atttypid = t.oid
                    WHERE c.relname = 'vector_embeddings'
                    AND a.attname = 'embedding'
                    AND t.typname = 'vector';
                    """
                    result = await conn.execute(text(check_dim_sql))
                    current_dim = result.scalar()

                    if current_dim != dimensions:
                        logger.info(f"Vector dimensions changed from {current_dim} to {dimensions}, recreating table")

                        # Drop existing vector index if it exists
                        await conn.execute(text("DROP INDEX IF EXISTS vector_idx;"))

                        # Drop existing vector embeddings table
                        await conn.execute(text("DROP TABLE IF EXISTS vector_embeddings;"))

                        # Create vector embeddings table with proper vector column
                        create_table_sql = f"""
                        CREATE TABLE vector_embeddings (
                            id SERIAL PRIMARY KEY,
                            document_id VARCHAR(255) NOT NULL,
                            chunk_number INTEGER NOT NULL,
                            content TEXT NOT NULL,
                            chunk_metadata TEXT,
                            embedding vector({dimensions}) NOT NULL,
                            created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
                        );
                        """
                        await conn.execute(text(create_table_sql))
                        logger.info(f"Created vector_embeddings table with vector({dimensions})")

                        # Create indexes
                        await conn.execute(text("CREATE INDEX idx_document_id ON vector_embeddings(document_id);"))

                        # Create vector index
                        await conn.execute(
                            text(
                                f"""
                                CREATE INDEX vector_idx
                                ON vector_embeddings
                                USING ivfflat (embedding vector_cosine_ops)
                                WITH (lists = 100);
                                """
                            )
                        )
                        logger.info("Created IVFFlat index on vector_embeddings")
                    else:
                        logger.info(f"Vector dimensions unchanged ({dimensions}), using existing table")
                else:
                    # Create tables and indexes if they don't exist
                    create_table_sql = f"""
                    CREATE TABLE vector_embeddings (
                        id SERIAL PRIMARY KEY,
                        document_id VARCHAR(255) NOT NULL,
                        chunk_number INTEGER NOT NULL,
                        content TEXT NOT NULL,
                        chunk_metadata TEXT,
                        embedding vector({dimensions}) NOT NULL,
                        created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
                    );
                    """
                    await conn.execute(text(create_table_sql))
                    logger.info(f"Created vector_embeddings table with vector({dimensions})")

                    # Create indexes
                    await conn.execute(text("CREATE INDEX idx_document_id ON vector_embeddings(document_id);"))

                    # Create vector index
                    await conn.execute(
                        text(
                            f"""
                            CREATE INDEX vector_idx
                            ON vector_embeddings
                            USING ivfflat (embedding vector_cosine_ops)
                            WITH (lists = 100);
                            """
                        )
                    )
                    logger.info("Created IVFFlat index on vector_embeddings")

            logger.info("PGVector store initialized successfully")
            return True
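A minimal sketch of calling the retrying async session from application code. The URI is a placeholder, the table is assumed to already exist, and it assumes PGVectorStore is imported from its module:

import asyncio

from sqlalchemy import text


async def main():
    # Placeholder URI; the async engine expects an asyncpg-style connection string.
    store = PGVectorStore(uri="postgresql+asyncpg://user:pass@localhost:5432/morphik")
    await store.initialize()

    # Each entry into the context manager validates the connection with SELECT 1
    # and retries up to max_retries times before raising OperationalError.
    async with store.get_session_with_retry() as session:
        count = await session.execute(text("SELECT COUNT(*) FROM vector_embeddings"))
        print(count.scalar())


asyncio.run(main())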
@@ -109,7 +273,7 @@ class PGVectorStore(BaseVectorStore):
            if not chunks:
                return True, []

            async with self.get_session_with_retry() as session:
                stored_ids = []
                for chunk in chunks:
                    if not chunk.embedding:
@@ -143,7 +307,7 @@ class PGVectorStore(BaseVectorStore):
    ) -> List[DocumentChunk]:
        """Find similar chunks using cosine similarity."""
        try:
            async with self.get_session_with_retry() as session:
                # Build query
                query = select(VectorEmbedding).order_by(
                    VectorEmbedding.embedding.op("<->")(query_embedding)
@@ -196,7 +360,7 @@ class PGVectorStore(BaseVectorStore):
            if not chunk_identifiers:
                return []

            async with self.get_session_with_retry() as session:
                # Create a list of OR conditions for the query
                conditions = []
                for doc_id, chunk_num in chunk_identifiers:
@@ -253,7 +417,7 @@ class PGVectorStore(BaseVectorStore):
            bool: True if the operation was successful, False otherwise
        """
        try:
            async with self.get_session_with_retry() as session:
                # Delete all chunks for the specified document
                query = text(f"DELETE FROM vector_embeddings WHERE document_id = :doc_id")
                await session.execute(query, {"doc_id": document_id})