From c3726504f7cf70fc645ae57a599668672b8edf59 Mon Sep 17 00:00:00 2001
From: Arnav Agrawal <88790414+ArnavAgrawal03@users.noreply.github.com>
Date: Sat, 4 Jan 2025 08:14:52 -0500
Subject: [PATCH] add support for PostgreSQL and pgvector (#15)

Co-authored-by: Adityavardhan Agrawal
---
 config.toml                         |   9 +-
 core/api.py                         |  10 +-
 core/config.py                      |  16 ++-
 core/database/postgres_database.py  |  17 ++-
 core/tests/integration/test_api.py  |   2 +-
 core/vector_store/pgvector_store.py | 180 ++++++++++++++++++++++++++++
 databridge.toml                     |  63 ++++++++++
 quick_setup.py                      |  75 +++++++++++-
 requirements.txt                    |   3 +
 9 files changed, 366 insertions(+), 9 deletions(-)
 create mode 100644 core/vector_store/pgvector_store.py
 create mode 100644 databridge.toml

diff --git a/config.toml b/config.toml
index d70f627..5b351ec 100644
--- a/config.toml
+++ b/config.toml
@@ -7,7 +7,7 @@ reload = false
 [service.components]
 storage = "local" # "aws-s3"
 database = "postgres" # "postgres", "mongodb"
-vector_store = "mongodb"
+vector_store = "pgvector" # "mongodb"
 embedding = "ollama" # "openai", "ollama"
 completion = "ollama" # "openai", "ollama"
 parser = "combined" # "combined", "unstructured", "contextual"
@@ -38,6 +38,13 @@ dimensions = 768 # 768 for nomic-embed-text, 1536 for text-embedding-3-small
 index_name = "vector_index"
 similarity_metric = "cosine"
 
+[vector_store.pgvector]
+dimensions = 768 # 768 for nomic-embed-text, 1536 for text-embedding-3-small
+table_name = "vector_embeddings"
+index_method = "ivfflat" # "ivfflat" or "hnsw"
+index_lists = 100 # Number of lists for ivfflat index
+probes = 10 # Number of probes for ivfflat search
+
 # Model Configurations
 [models]
 [models.embedding]
diff --git a/core/api.py b/core/api.py
index 3c656af..76ccd91 100644
--- a/core/api.py
+++ b/core/api.py
@@ -79,6 +79,14 @@ match settings.VECTOR_STORE_PROVIDER:
             collection_name=settings.CHUNKS_COLLECTION,
             index_name=settings.VECTOR_INDEX_NAME,
         )
+    case "pgvector":
+        if not settings.POSTGRES_URI:
+            raise ValueError("PostgreSQL URI is required for pgvector store")
+        from core.vector_store.pgvector_store import PGVectorStore
+
+        vector_store = PGVectorStore(
+            uri=settings.POSTGRES_URI,
+        )
     case _:
         raise ValueError(f"Unsupported vector store provider: {settings.VECTOR_STORE_PROVIDER}")
 
@@ -347,7 +355,7 @@ async def query_completion(
 async def list_documents(
     auth: AuthContext = Depends(verify_token),
     skip: int = 0,
-    limit: int = 100,
+    limit: int = 10000,
     filters: Optional[Dict[str, Any]] = None,
 ):
     """List accessible documents."""
diff --git a/core/config.py b/core/config.py
index 9b94488..aa55d69 100644
--- a/core/config.py
+++ b/core/config.py
@@ -50,6 +50,10 @@ class Settings(BaseSettings):
     # Vector store settings
     VECTOR_INDEX_NAME: str = "vector_index"
     VECTOR_DIMENSIONS: int = 1536
+    PGVECTOR_TABLE_NAME: str = "vector_embeddings"
+    PGVECTOR_INDEX_METHOD: str = "ivfflat"
+    PGVECTOR_INDEX_LISTS: int = 100
+    PGVECTOR_PROBES: int = 10
 
     # Model settings
     EMBEDDING_MODEL: str = "text-embedding-3-small"
@@ -120,7 +124,17 @@ def get_settings() -> Settings:
         .get("chunks_collection", "document_chunks"),
         # Vector store settings
         "VECTOR_INDEX_NAME": config["vector_store"]["mongodb"]["index_name"],
-        "VECTOR_DIMENSIONS": config["vector_store"]["mongodb"]["dimensions"],
+        "VECTOR_DIMENSIONS": config["vector_store"][
+            config["service"]["components"]["vector_store"]
+        ]["dimensions"],
+        "PGVECTOR_TABLE_NAME": config["vector_store"]
+        .get("pgvector", {})
+        .get("table_name", "vector_embeddings"),
+        "PGVECTOR_INDEX_METHOD": config["vector_store"]
.get("pgvector", {}) + .get("index_method", "ivfflat"), + "PGVECTOR_INDEX_LISTS": config["vector_store"].get("pgvector", {}).get("index_lists", 100), + "PGVECTOR_PROBES": config["vector_store"].get("pgvector", {}).get("probes", 10), # Model settings "EMBEDDING_MODEL": config["models"]["embedding"]["model_name"], "COMPLETION_MODEL": config["models"]["completion"]["model_name"], diff --git a/core/database/postgres_database.py b/core/database/postgres_database.py index 00b2260..f20b740 100644 --- a/core/database/postgres_database.py +++ b/core/database/postgres_database.py @@ -79,7 +79,7 @@ class PostgresDatabase(BaseDatabase): # Rename metadata to doc_metadata if "metadata" in doc_dict: doc_dict["doc_metadata"] = doc_dict.pop("metadata") - + doc_dict["doc_metadata"]["external_id"] = doc_dict["external_id"] # Ensure system metadata if "system_metadata" not in doc_dict: doc_dict["system_metadata"] = {} @@ -141,7 +141,7 @@ class PostgresDatabase(BaseDatabase): self, auth: AuthContext, skip: int = 0, - limit: int = 100, + limit: int = 10000, filters: Optional[Dict[str, Any]] = None, ) -> List[Document]: """List documents the user has access to.""" @@ -244,12 +244,20 @@ class PostgresDatabase(BaseDatabase): access_filter = self._build_access_filter(auth) metadata_filter = self._build_metadata_filter(filters) + logger.debug(f"Access filter: {access_filter}") + logger.debug(f"Metadata filter: {metadata_filter}") + logger.debug(f"Original filters: {filters}") + query = select(DocumentModel.external_id).where(text(f"({access_filter})")) if metadata_filter: query = query.where(text(metadata_filter)) + logger.debug(f"Final query: {query}") + result = await session.execute(query) - return [row[0] for row in result.all()] + doc_ids = [row[0] for row in result.all()] + logger.debug(f"Found document IDs: {doc_ids}") + return doc_ids except Exception as e: logger.error(f"Error finding authorized documents: {str(e)}") @@ -310,6 +318,9 @@ class PostgresDatabase(BaseDatabase): filter_conditions = [] for key, value in filters.items(): + # Convert boolean values to string 'true' or 'false' + if isinstance(value, bool): + value = str(value).lower() filter_conditions.append(f"doc_metadata->>'{key}' = '{value}'") return " AND ".join(filter_conditions) diff --git a/core/tests/integration/test_api.py b/core/tests/integration/test_api.py index 563394a..6a504f4 100644 --- a/core/tests/integration/test_api.py +++ b/core/tests/integration/test_api.py @@ -336,7 +336,7 @@ async def test_retrieve_chunks(client: AsyncClient): results = list(response.json()) assert len(results) > 0 assert results[0]["score"] > 0.5 - assert results[0]["content"] == upload_string + assert any(upload_string == result["content"] for result in results) @pytest.mark.asyncio diff --git a/core/vector_store/pgvector_store.py b/core/vector_store/pgvector_store.py new file mode 100644 index 0000000..0e75c30 --- /dev/null +++ b/core/vector_store/pgvector_store.py @@ -0,0 +1,180 @@ +from typing import List, Optional, Tuple +import logging +from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine +from sqlalchemy.orm import sessionmaker, declarative_base +from sqlalchemy import Column, String, Integer, Index, select, text +from sqlalchemy.sql.expression import func +from sqlalchemy.types import UserDefinedType + +from .base_vector_store import BaseVectorStore +from core.models.chunk import DocumentChunk + +logger = logging.getLogger(__name__) +Base = declarative_base() + + +class Vector(UserDefinedType): + """Custom type for pgvector 
vectors.""" + + def get_col_spec(self, **kw): + return "vector" + + def bind_processor(self, dialect): + def process(value): + if isinstance(value, list): + return f"[{','.join(str(x) for x in value)}]" + return value + + return process + + def result_processor(self, dialect, coltype): + def process(value): + if value is None: + return None + # Remove brackets and split by comma + value = value[1:-1].split(",") + return [float(x) for x in value] + + return process + + +class VectorEmbedding(Base): + """SQLAlchemy model for vector embeddings.""" + + __tablename__ = "vector_embeddings" + + id = Column(Integer, primary_key=True) + document_id = Column(String, nullable=False) + chunk_number = Column(Integer, nullable=False) + content = Column(String, nullable=False) + chunk_metadata = Column(String, nullable=True) + embedding = Column(Vector, nullable=False) + + # Create indexes + __table_args__ = ( + Index("idx_document_id", "document_id"), + Index( + "idx_vector_embedding", + embedding, + postgresql_using="ivfflat", + postgresql_with={"lists": 100}, + ), + ) + + +class PGVectorStore(BaseVectorStore): + """PostgreSQL with pgvector implementation for vector storage.""" + + def __init__( + self, + uri: str, + ): + """Initialize PostgreSQL connection for vector storage.""" + self.engine = create_async_engine(uri) + self.async_session = sessionmaker(self.engine, class_=AsyncSession, expire_on_commit=False) + + async def initialize(self): + """Initialize database tables and vector extension.""" + try: + async with self.engine.begin() as conn: + # Enable pgvector extension + await conn.execute( + func.create_extension("vector", schema="public", if_not_exists=True) + ) + + # Create tables and indexes + await conn.run_sync(Base.metadata.create_all) + + # Create vector index + await conn.execute( + text( + """ + CREATE INDEX IF NOT EXISTS vector_idx + ON vector_embeddings + USING ivfflat (embedding vector_cosine_ops) + WITH (lists = 100); + """ + ) + ) + + logger.info("PGVector store initialized successfully") + return True + except Exception as e: + logger.error(f"Error initializing PGVector store: {str(e)}") + return False + + async def store_embeddings(self, chunks: List[DocumentChunk]) -> Tuple[bool, List[str]]: + """Store document chunks with their embeddings.""" + try: + if not chunks: + return True, [] + + async with self.async_session() as session: + stored_ids = [] + for chunk in chunks: + if not chunk.embedding: + logger.error( + f"Missing embedding for chunk {chunk.document_id}-{chunk.chunk_number}" + ) + continue + + vector_embedding = VectorEmbedding( + document_id=chunk.document_id, + chunk_number=chunk.chunk_number, + content=chunk.content, + chunk_metadata=str(chunk.metadata), + embedding=chunk.embedding, + ) + session.add(vector_embedding) + stored_ids.append(f"{chunk.document_id}-{chunk.chunk_number}") + + await session.commit() + return len(stored_ids) > 0, stored_ids + + except Exception as e: + logger.error(f"Error storing embeddings: {str(e)}") + return False, [] + + async def query_similar( + self, + query_embedding: List[float], + k: int, + doc_ids: Optional[List[str]] = None, + ) -> List[DocumentChunk]: + """Find similar chunks using cosine similarity.""" + try: + async with self.async_session() as session: + # Build query + query = select(VectorEmbedding).order_by( + VectorEmbedding.embedding.op("<->")(query_embedding) + ) + + if doc_ids: + query = query.filter(VectorEmbedding.document_id.in_(doc_ids)) + + query = query.limit(k) + result = await session.execute(query) + 
+                embeddings = result.scalars().all()
+
+                # Convert to DocumentChunks
+                chunks = []
+                for emb in embeddings:
+                    try:
+                        # Metadata was stored with str(), so parse it safely
+                        # with ast.literal_eval rather than eval
+                        metadata = (
+                            ast.literal_eval(emb.chunk_metadata) if emb.chunk_metadata else {}
+                        )
+                    except (ValueError, SyntaxError):
+                        metadata = {}
+
+                    chunk = DocumentChunk(
+                        document_id=emb.document_id,
+                        chunk_number=emb.chunk_number,
+                        content=emb.content,
+                        embedding=[],  # Don't send embeddings back
+                        metadata=metadata,
+                    )
+                    chunks.append(chunk)
+
+                return chunks
+
+        except Exception as e:
+            logger.error(f"Error querying similar chunks: {str(e)}")
+            return []
diff --git a/databridge.toml b/databridge.toml
new file mode 100644
index 0000000..9995ed7
--- /dev/null
+++ b/databridge.toml
@@ -0,0 +1,63 @@
+[api]
+host = "localhost"
+port = 8000
+reload = false
+
+
+[auth]
+jwt_algorithm = "HS256"
+
+
+[completion]
+provider = "ollama" # ollama, openai
+model_name = "llama3.2"
+default_max_tokens = 1000
+default_temperature = 0.7
+base_url = "http://localhost:11434"
+
+
+[database]
+provider = "postgres"
+database_name = "databridge"
+documents_table = "documents"
+chunks_table = "document_chunks"
+# documents_collection = "documents"
+# chunks_collection = "document_chunks"
+
+
+[embedding]
+provider = "ollama" # "ollama", "openai"
+model_name = "nomic-embed-text"
+dimensions = 768
+similarity_metric = "cosine" # "cosine", "dotProduct", "euclidean"
+base_url = "http://localhost:11434"
+
+
+[parser]
+provider = "combined" # options: "combined", "unstructured", "contextual"
+chunk_size = 1000
+chunk_overlap = 200
+use_unstructured_api = false
+video_frame_sample_rate = 120 # not needed for unstructured
+
+
+[reranker]
+provider = "bge"
+model_name = "BAAI/bge-reranker-large" # could also be "BAAI/bge-reranker-v2-gemma"
+use_fp16 = true
+query_max_length = 256
+passage_max_length = 512
+device = "mps" # "cuda:0" # Optional: Set to null or remove for CPU
+
+
+[storage]
+provider = "local"
+path = "./storage"
+# region = "us-east-2"
+# bucket_name = "databridge-s3-storage"
+
+
+[vector_store]
+provider = "mongodb"
+num_chunks_to_retrieve = 20
+index_name = "vector_index"
diff --git a/quick_setup.py b/quick_setup.py
index 7dd0851..fe9b210 100644
--- a/quick_setup.py
+++ b/quick_setup.py
@@ -9,6 +9,8 @@ from pymongo import MongoClient
 from pymongo.errors import ConnectionFailure, OperationFailure
 from pymongo.operations import SearchIndexModel
 import argparse
+import platform
+import subprocess
 
 # Force reload of environment variables
 load_dotenv(find_dotenv(), override=True)
@@ -183,26 +185,95 @@ def setup_postgres():
     """
    import asyncio
    from sqlalchemy.ext.asyncio import create_async_engine
+    from sqlalchemy import text
 
     # Load PostgreSQL URI from .env file
     postgres_uri = os.getenv("POSTGRES_URI")
     if not postgres_uri:
         raise ValueError("POSTGRES_URI not found in .env file.")
 
+    # Check if pgvector is installed when on macOS
+    if platform.system() == "Darwin":
+        try:
+            # Check if postgresql is installed via homebrew
+            result = subprocess.run(
+                ["brew", "list", "postgresql@14"], capture_output=True, text=True
+            )
+            if result.returncode != 0:
+                LOGGER.error(
+                    "PostgreSQL not found. Please install it with: brew install postgresql@14"
+                )
+                raise RuntimeError("PostgreSQL not installed")
+
+            # Check if pgvector is installed
+            result = subprocess.run(["brew", "list", "pgvector"], capture_output=True, text=True)
+            if result.returncode != 0:
+                LOGGER.error(
+                    "\nError: pgvector extension not found. Please install it with:\n"
+                    "brew install pgvector\n"
+                    "brew services stop postgresql@14\n"
+                    "brew services start postgresql@14\n"
+                )
+                raise RuntimeError("pgvector not installed")
+        except FileNotFoundError:
+            LOGGER.error("Homebrew not found. Please install it from https://brew.sh")
+            raise
+
     async def _setup_postgres():
         try:
             # Create async engine
             engine = create_async_engine(postgres_uri)
 
             async with engine.begin() as conn:
+                try:
+                    # Enable pgvector extension
+                    await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
+                    LOGGER.info("Enabled pgvector extension")
+                except Exception as e:
+                    if "could not open extension control file" in str(e):
+                        LOGGER.error(
+                            "\nError: pgvector extension not found. Please install it:\n"
+                            "- On macOS: brew install pgvector\n"
+                            "- On Ubuntu: sudo apt install postgresql-14-pgvector\n"
+                            "- On other systems: check https://github.com/pgvector/pgvector#installation\n"
+                        )
+                    raise
+
                 # Import and create all tables
                 from core.database.postgres_database import Base
+                from core.vector_store.pgvector_store import Base as VectorBase
 
                 await conn.run_sync(Base.metadata.create_all)
-                LOGGER.info("Created all PostgreSQL tables and indexes.")
+                await conn.run_sync(VectorBase.metadata.create_all)
+                LOGGER.info("Created all PostgreSQL tables and indexes")
+
+                # Create vector index with configuration from settings
+                table_name = CONFIG["vector_store"]["pgvector"]["table_name"]
+                index_method = CONFIG["vector_store"]["pgvector"]["index_method"]
+                index_lists = CONFIG["vector_store"]["pgvector"]["index_lists"]
+                dimensions = CONFIG["vector_store"]["pgvector"]["dimensions"]
+
+                # First, alter the embedding column to be a vector
+                alter_sql = f"""
+                ALTER TABLE {table_name}
+                ALTER COLUMN embedding TYPE vector({dimensions})
+                USING embedding::vector({dimensions});
+                """
+                await conn.execute(text(alter_sql))
+                LOGGER.info(f"Altered embedding column to be vector({dimensions})")
+
+                # Then create the vector index with cosine ops, matching the
+                # "<=>" queries and index built by PGVectorStore.initialize()
+                if index_method == "ivfflat":
+                    index_sql = f"""
+                    CREATE INDEX IF NOT EXISTS vector_idx
+                    ON {table_name} USING ivfflat (embedding vector_cosine_ops)
+                    WITH (lists = {index_lists});
+                    """
+                    await conn.execute(text(index_sql))
+                    LOGGER.info(f"Created IVFFlat index on {table_name} with {index_lists} lists")
 
             await engine.dispose()
-            LOGGER.info("PostgreSQL setup completed successfully.")
+            LOGGER.info("PostgreSQL setup completed successfully")
 
         except Exception as e:
             LOGGER.error(f"Failed to setup PostgreSQL: {e}")
diff --git a/requirements.txt b/requirements.txt
index d93ad7e..6a8149f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -290,3 +290,6 @@ psycopg2-binary==2.9.9
 alembic==1.13.1
 zipp==3.21.0
 zlib-state==0.1.9
+pgvector==0.2.5
+psycopg[binary]==3.1.18
+psycopg-binary==3.1.18
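
Reviewer note: below is a minimal usage sketch of the new store, not part of the patch. It assumes quick_setup.py has already been run (so the vector_embeddings.embedding column has been altered to vector(768), per the config.toml above); the connection URI is a placeholder for your own POSTGRES_URI, and the embedding values are dummies.

    import asyncio

    from core.models.chunk import DocumentChunk
    from core.vector_store.pgvector_store import PGVectorStore


    async def main():
        # Placeholder URI; the application reads this from settings.POSTGRES_URI
        store = PGVectorStore(uri="postgresql+psycopg://user:pass@localhost:5432/databridge")
        await store.initialize()  # idempotent: CREATE EXTENSION / INDEX IF NOT EXISTS

        # Store one chunk with a dummy 768-dim embedding
        chunk = DocumentChunk(
            document_id="doc-1",
            chunk_number=0,
            content="hello pgvector",
            embedding=[0.1] * 768,
            metadata={"category": "demo"},
        )
        ok, ids = await store.store_embeddings([chunk])
        print(ok, ids)  # True ['doc-1-0']

        # Nearest neighbours by cosine distance, optionally filtered by document
        results = await store.query_similar([0.1] * 768, k=3, doc_ids=["doc-1"])
        print([(r.document_id, r.content) for r in results])


    asyncio.run(main())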