add support for PostgreSQL and pgvector (#15)

Co-authored-by: Adityavardhan Agrawal <aa729@cornell.edu>
Authored by Arnav Agrawal on 2025-01-04 08:14:52 -05:00; committed by GitHub
parent 273dfcc5e7
commit c3726504f7
9 changed files with 366 additions and 9 deletions


@@ -7,7 +7,7 @@ reload = false
[service.components]
storage = "local" # "aws-s3"
database = "postgres" # "postgres", "mongodb"
vector_store = "mongodb"
vector_store = "pgvector" # "mongodb"
embedding = "ollama" # "openai", "ollama"
completion = "ollama" # "openai", "ollama"
parser = "combined" # "combined", "unstructured", "contextual"
@@ -38,6 +38,13 @@ dimensions = 768 # 768 for nomic-embed-text, 1536 for text-embedding-3-small
index_name = "vector_index"
similarity_metric = "cosine"
[vector_store.pgvector]
dimensions = 768 # 768 for nomic-embed-text, 1536 for text-embedding-3-small
table_name = "vector_embeddings"
index_method = "ivfflat" # "ivfflat" or "hnsw"
index_lists = 100 # Number of lists for ivfflat index
probes = 10 # Number of probes for ivfflat search
# Model Configurations
[models]
[models.embedding]
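The new [vector_store.pgvector] block maps directly onto pgvector's IVFFlat tuning knobs: index_lists partitions the index at build time, while probes sets how many partitions are scanned per query, trading speed for recall. A minimal sketch of applying probes at query time, assuming an asyncpg-style URI (the diff loads the value into PGVECTOR_PROBES below, but the new store does not apply it yet):

from sqlalchemy import text
from sqlalchemy.ext.asyncio import create_async_engine

async def set_ivfflat_probes(uri: str, probes: int = 10) -> None:
    # pgvector exposes probes as a GUC; setting it per session widens the
    # number of IVFFlat lists scanned by subsequent vector queries
    engine = create_async_engine(uri)
    async with engine.begin() as conn:
        await conn.execute(text(f"SET ivfflat.probes = {probes}"))
    await engine.dispose()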


@@ -79,6 +79,14 @@ match settings.VECTOR_STORE_PROVIDER:
collection_name=settings.CHUNKS_COLLECTION,
index_name=settings.VECTOR_INDEX_NAME,
)
case "pgvector":
if not settings.POSTGRES_URI:
raise ValueError("PostgreSQL URI is required for pgvector store")
from core.vector_store.pgvector_store import PGVectorStore
vector_store = PGVectorStore(
uri=settings.POSTGRES_URI,
)
case _:
raise ValueError(f"Unsupported vector store provider: {settings.VECTOR_STORE_PROVIDER}")
@@ -347,7 +355,7 @@ async def query_completion(
async def list_documents(
auth: AuthContext = Depends(verify_token),
skip: int = 0,
-    limit: int = 100,
+    limit: int = 10000,
filters: Optional[Dict[str, Any]] = None,
):
"""List accessible documents."""


@@ -50,6 +50,10 @@ class Settings(BaseSettings):
# Vector store settings
VECTOR_INDEX_NAME: str = "vector_index"
VECTOR_DIMENSIONS: int = 1536
PGVECTOR_TABLE_NAME: str = "vector_embeddings"
PGVECTOR_INDEX_METHOD: str = "ivfflat"
PGVECTOR_INDEX_LISTS: int = 100
PGVECTOR_PROBES: int = 10
# Model settings
EMBEDDING_MODEL: str = "text-embedding-3-small"
@@ -120,7 +124,17 @@ def get_settings() -> Settings:
.get("chunks_collection", "document_chunks"),
# Vector store settings
"VECTOR_INDEX_NAME": config["vector_store"]["mongodb"]["index_name"],
"VECTOR_DIMENSIONS": config["vector_store"]["mongodb"]["dimensions"],
"VECTOR_DIMENSIONS": config["vector_store"][
config["service"]["components"]["vector_store"]
]["dimensions"],
"PGVECTOR_TABLE_NAME": config["vector_store"]
.get("pgvector", {})
.get("table_name", "vector_embeddings"),
"PGVECTOR_INDEX_METHOD": config["vector_store"]
.get("pgvector", {})
.get("index_method", "ivfflat"),
"PGVECTOR_INDEX_LISTS": config["vector_store"].get("pgvector", {}).get("index_lists", 100),
"PGVECTOR_PROBES": config["vector_store"].get("pgvector", {}).get("probes", 10),
# Model settings
"EMBEDDING_MODEL": config["models"]["embedding"]["model_name"],
"COMPLETION_MODEL": config["models"]["completion"]["model_name"],


@@ -79,7 +79,7 @@ class PostgresDatabase(BaseDatabase):
# Rename metadata to doc_metadata
if "metadata" in doc_dict:
doc_dict["doc_metadata"] = doc_dict.pop("metadata")
doc_dict["doc_metadata"]["external_id"] = doc_dict["external_id"]
# Ensure system metadata
if "system_metadata" not in doc_dict:
doc_dict["system_metadata"] = {}
@@ -141,7 +141,7 @@ class PostgresDatabase(BaseDatabase):
self,
auth: AuthContext,
skip: int = 0,
-        limit: int = 100,
+        limit: int = 10000,
filters: Optional[Dict[str, Any]] = None,
) -> List[Document]:
"""List documents the user has access to."""
@@ -244,12 +244,20 @@ class PostgresDatabase(BaseDatabase):
access_filter = self._build_access_filter(auth)
metadata_filter = self._build_metadata_filter(filters)
logger.debug(f"Access filter: {access_filter}")
logger.debug(f"Metadata filter: {metadata_filter}")
logger.debug(f"Original filters: {filters}")
query = select(DocumentModel.external_id).where(text(f"({access_filter})"))
if metadata_filter:
query = query.where(text(metadata_filter))
logger.debug(f"Final query: {query}")
result = await session.execute(query)
-                return [row[0] for row in result.all()]
+                doc_ids = [row[0] for row in result.all()]
+                logger.debug(f"Found document IDs: {doc_ids}")
+                return doc_ids
except Exception as e:
logger.error(f"Error finding authorized documents: {str(e)}")
@@ -310,6 +318,9 @@ class PostgresDatabase(BaseDatabase):
filter_conditions = []
for key, value in filters.items():
# Convert boolean values to string 'true' or 'false'
if isinstance(value, bool):
value = str(value).lower()
filter_conditions.append(f"doc_metadata->>'{key}' = '{value}'")
return " AND ".join(filter_conditions)


@@ -336,7 +336,7 @@ async def test_retrieve_chunks(client: AsyncClient):
results = list(response.json())
assert len(results) > 0
assert results[0]["score"] > 0.5
-    assert results[0]["content"] == upload_string
+    assert any(upload_string == result["content"] for result in results)
@pytest.mark.asyncio


@@ -0,0 +1,180 @@
from typing import List, Optional, Tuple
import ast
import logging
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker, declarative_base
from sqlalchemy import Column, String, Integer, Index, select, text
from sqlalchemy.sql.expression import func
from sqlalchemy.types import UserDefinedType
from .base_vector_store import BaseVectorStore
from core.models.chunk import DocumentChunk
logger = logging.getLogger(__name__)
Base = declarative_base()
class Vector(UserDefinedType):
"""Custom type for pgvector vectors."""
def get_col_spec(self, **kw):
return "vector"
def bind_processor(self, dialect):
def process(value):
if isinstance(value, list):
return f"[{','.join(str(x) for x in value)}]"
return value
return process
def result_processor(self, dialect, coltype):
def process(value):
if value is None:
return None
# Remove brackets and split by comma
value = value[1:-1].split(",")
return [float(x) for x in value]
return process
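# Illustrative round-trip through the Vector type above (values assumed):
#   Python -> SQL bind:    [0.1, 0.2]   becomes the literal string "[0.1,0.2]"
#   SQL -> Python result:  "[0.1,0.2]"  parses back into [0.1, 0.2]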
class VectorEmbedding(Base):
"""SQLAlchemy model for vector embeddings."""
__tablename__ = "vector_embeddings"
id = Column(Integer, primary_key=True)
document_id = Column(String, nullable=False)
chunk_number = Column(Integer, nullable=False)
content = Column(String, nullable=False)
chunk_metadata = Column(String, nullable=True)
embedding = Column(Vector, nullable=False)
# Create indexes
__table_args__ = (
Index("idx_document_id", "document_id"),
Index(
"idx_vector_embedding",
embedding,
postgresql_using="ivfflat",
postgresql_with={"lists": 100},
),
)
class PGVectorStore(BaseVectorStore):
"""PostgreSQL with pgvector implementation for vector storage."""
def __init__(
self,
uri: str,
):
"""Initialize PostgreSQL connection for vector storage."""
self.engine = create_async_engine(uri)
self.async_session = sessionmaker(self.engine, class_=AsyncSession, expire_on_commit=False)
async def initialize(self):
"""Initialize database tables and vector extension."""
try:
async with self.engine.begin() as conn:
# Enable pgvector extension
                await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
# Create tables and indexes
await conn.run_sync(Base.metadata.create_all)
# Create vector index
await conn.execute(
text(
"""
CREATE INDEX IF NOT EXISTS vector_idx
ON vector_embeddings
USING ivfflat (embedding vector_cosine_ops)
WITH (lists = 100);
"""
)
)
logger.info("PGVector store initialized successfully")
return True
except Exception as e:
logger.error(f"Error initializing PGVector store: {str(e)}")
return False
async def store_embeddings(self, chunks: List[DocumentChunk]) -> Tuple[bool, List[str]]:
"""Store document chunks with their embeddings."""
try:
if not chunks:
return True, []
async with self.async_session() as session:
stored_ids = []
for chunk in chunks:
if not chunk.embedding:
logger.error(
f"Missing embedding for chunk {chunk.document_id}-{chunk.chunk_number}"
)
continue
vector_embedding = VectorEmbedding(
document_id=chunk.document_id,
chunk_number=chunk.chunk_number,
content=chunk.content,
chunk_metadata=str(chunk.metadata),
embedding=chunk.embedding,
)
session.add(vector_embedding)
stored_ids.append(f"{chunk.document_id}-{chunk.chunk_number}")
await session.commit()
return len(stored_ids) > 0, stored_ids
except Exception as e:
logger.error(f"Error storing embeddings: {str(e)}")
return False, []
async def query_similar(
self,
query_embedding: List[float],
k: int,
doc_ids: Optional[List[str]] = None,
) -> List[DocumentChunk]:
"""Find similar chunks using cosine similarity."""
try:
async with self.async_session() as session:
# Build query
                # <=> is cosine distance, matching the vector_cosine_ops index
                # built in initialize(); <-> would be L2 and bypass that index
                query = select(VectorEmbedding).order_by(
                    VectorEmbedding.embedding.op("<=>")(query_embedding)
                )
if doc_ids:
query = query.filter(VectorEmbedding.document_id.in_(doc_ids))
query = query.limit(k)
result = await session.execute(query)
embeddings = result.scalars().all()
# Convert to DocumentChunks
chunks = []
for emb in embeddings:
try:
                        # chunk_metadata was stored via str(dict); ast.literal_eval
                        # parses that repr without eval's arbitrary-code risk
                        metadata = ast.literal_eval(emb.chunk_metadata) if emb.chunk_metadata else {}
except (ValueError, SyntaxError):
metadata = {}
chunk = DocumentChunk(
document_id=emb.document_id,
chunk_number=emb.chunk_number,
content=emb.content,
embedding=[], # Don't send embeddings back
metadata=metadata,
)
chunks.append(chunk)
return chunks
except Exception as e:
logger.error(f"Error querying similar chunks: {str(e)}")
return []
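A minimal, hypothetical end-to-end use of the new store; the URI and embedding values are placeholders, and the DocumentChunk fields follow the attributes the class reads above:

import asyncio
from core.models.chunk import DocumentChunk
from core.vector_store.pgvector_store import PGVectorStore

async def main() -> None:
    store = PGVectorStore(uri="postgresql+asyncpg://user:pass@localhost/databridge")
    await store.initialize()  # extension, tables, IVFFlat index
    chunk = DocumentChunk(
        document_id="doc-1",
        chunk_number=0,
        content="hello pgvector",
        embedding=[0.0] * 768,  # must match the configured dimensions
        metadata={"category": "demo"},
    )
    ok, ids = await store.store_embeddings([chunk])
    hits = await store.query_similar([0.0] * 768, k=5)

asyncio.run(main())

One caveat: the ORM-level IVFFlat index in __table_args__ targets an untyped vector column, which IVFFlat cannot index; the setup script below alters the column to vector(N) and rebuilds the index via raw SQL.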

databridge.toml (new file)

@@ -0,0 +1,63 @@
[api]
host = "localhost"
port = 8000
reload = false
[auth]
jwt_algorithm = "HS256"
[completion]
provider = "ollama" # ollama, openai
model_name = "llama3.2"
default_max_tokens = 1000
default_temperature = 0.7
base_url = "http://localhost:11434"
[database]
provider = "postgres"
database_name = "databridge"
documents_table = "documents"
chunks_table = "document_chunks"
# documents_collection = "documents"
# chunks_collection = "document_chunks"
[embedding]
provider = "ollama" # "ollama", "openai"
model_name = "nomic-embed-text"
dimensions = 768
similarity_metric = "cosine" # "cosine", "dotProduct", "euclidean"
base_url = "http://localhost:11434"
[parser]
provider = "combined" # options: "combined", "unstructured", "contextual"
chunk_size = 1000
chunk_overlap = 200
use_unstructured_api = false
video_frame_sample_rate = 120 # not needed for unstructured
[reranker]
provider = "bge"
model_name = "BAAI/bge-reranker-large" # could also be "BAAI/bge-reranker-v2-gemma"
use_fp16 = true
query_max_length = 256
passage_max_length = 512
device = "mps" # "cuda:0" # Optional: Set to null or remove for CPU
[storage]
provider = "local"
path = "./storage"
# region = "us-east-2"
# bucket_name = "databridge-s3-storage"
[vector_store]
provider = "mongodb"
num_chunks_to_retrieve = 20
index_name = "vector_index"


@@ -9,6 +9,8 @@ from pymongo import MongoClient
from pymongo.errors import ConnectionFailure, OperationFailure
from pymongo.operations import SearchIndexModel
import argparse
import platform
import subprocess
# Force reload of environment variables
load_dotenv(find_dotenv(), override=True)
@@ -183,26 +185,95 @@ def setup_postgres():
"""
import asyncio
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy import text
# Load PostgreSQL URI from .env file
postgres_uri = os.getenv("POSTGRES_URI")
if not postgres_uri:
raise ValueError("POSTGRES_URI not found in .env file.")
# Check if pgvector is installed when on macOS
if platform.system() == "Darwin":
try:
# Check if postgresql is installed via homebrew
result = subprocess.run(
["brew", "list", "postgresql@14"], capture_output=True, text=True
)
if result.returncode != 0:
LOGGER.error(
"PostgreSQL not found. Please install it with: brew install postgresql@14"
)
raise RuntimeError("PostgreSQL not installed")
# Check if pgvector is installed
result = subprocess.run(["brew", "list", "pgvector"], capture_output=True, text=True)
if result.returncode != 0:
LOGGER.error(
"\nError: pgvector extension not found. Please install it with:\n"
"brew install pgvector\n"
"brew services stop postgresql@14\n"
"brew services start postgresql@14\n"
)
raise RuntimeError("pgvector not installed")
except FileNotFoundError:
LOGGER.error("Homebrew not found. Please install it from https://brew.sh")
raise
async def _setup_postgres():
try:
# Create async engine
engine = create_async_engine(postgres_uri)
async with engine.begin() as conn:
try:
# Enable pgvector extension
await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
LOGGER.info("Enabled pgvector extension")
except Exception as e:
if "could not open extension control file" in str(e):
LOGGER.error(
"\nError: pgvector extension not found. Please install it:\n"
"- On macOS: brew install pgvector\n"
"- On Ubuntu: sudo apt install postgresql-14-pgvector\n"
"- On other systems: check https://github.com/pgvector/pgvector#installation\n"
)
raise
# Import and create all tables
from core.database.postgres_database import Base
from core.vector_store.pgvector_store import Base as VectorBase
await conn.run_sync(Base.metadata.create_all)
LOGGER.info("Created all PostgreSQL tables and indexes.")
await conn.run_sync(VectorBase.metadata.create_all)
LOGGER.info("Created all PostgreSQL tables and indexes")
# Create vector index with configuration from settings
table_name = CONFIG["vector_store"]["pgvector"]["table_name"]
index_method = CONFIG["vector_store"]["pgvector"]["index_method"]
index_lists = CONFIG["vector_store"]["pgvector"]["index_lists"]
dimensions = CONFIG["vector_store"]["pgvector"]["dimensions"]
# First, alter the embedding column to be a vector
alter_sql = f"""
ALTER TABLE {table_name}
ALTER COLUMN embedding TYPE vector({dimensions})
USING embedding::vector({dimensions});
"""
await conn.execute(text(alter_sql))
LOGGER.info(f"Altered embedding column to be vector({dimensions})")
# Then create the vector index
                if index_method == "ivfflat":
                    # vector_cosine_ops matches the configured "cosine" similarity
                    # metric and the index PGVectorStore.initialize() builds
                    index_sql = f"""
                    CREATE INDEX IF NOT EXISTS vector_idx
                    ON {table_name} USING ivfflat (embedding vector_cosine_ops)
                    WITH (lists = {index_lists});
                    """
await conn.execute(text(index_sql))
LOGGER.info(f"Created IVFFlat index on {table_name} with {index_lists} lists")
await engine.dispose()
LOGGER.info("PostgreSQL setup completed successfully.")
LOGGER.info("PostgreSQL setup completed successfully")
except Exception as e:
LOGGER.error(f"Failed to setup PostgreSQL: {e}")


@@ -290,3 +290,6 @@ psycopg2-binary==2.9.9
alembic==1.13.1
zipp==3.21.0
zlib-state==0.1.9
pgvector==0.2.5
psycopg[binary]==3.1.18
psycopg-binary==3.1.18