Config improvements (#17)

This commit is contained in:
Arnav Agrawal 2025-01-07 01:42:10 -05:00 committed by GitHub
parent f5666155c1
commit f72f6f0249
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 348 additions and 255 deletions

View File

@ -1,84 +0,0 @@
# Core Service Components
[service]
host = "localhost"
port = 8000
reload = false
[service.components]
storage = "local" # "aws-s3"
database = "postgres" # "postgres", "mongodb"
vector_store = "pgvector" # "mongodb"
embedding = "ollama" # "openai", "ollama"
completion = "ollama" # "openai", "ollama"
parser = "combined" # "combined", "unstructured", "contextual"
reranker = "bge" # "bge"
# Storage Configuration
[storage.local]
path = "./storage"
[storage.aws]
region = "us-east-2"
bucket_name = "databridge-s3-storage"
# Database Configuration
[database.postgres]
database_name = "databridge"
documents_table = "documents"
chunks_table = "document_chunks"
[database.mongodb]
database_name = "databridge"
documents_collection = "documents"
chunks_collection = "document_chunks"
# Vector Store Configuration
[vector_store.mongodb]
dimensions = 768 # 768 for nomic-embed-text, 1536 for text-embedding-3-small
index_name = "vector_index"
similarity_metric = "cosine"
[vector_store.pgvector]
dimensions = 768 # 768 for nomic-embed-text, 1536 for text-embedding-3-small
table_name = "vector_embeddings"
index_method = "ivfflat" # "ivfflat" or "hnsw"
index_lists = 100 # Number of lists for ivfflat index
probes = 10 # Number of probes for ivfflat search
# Model Configurations
[models]
[models.embedding]
model_name = "nomic-embed-text" # "text-embedding-3-small", "nomic-embed-text"
[models.completion]
model_name = "llama3.2" # "gpt-4o-mini", "llama3.1", etc.
default_max_tokens = 1000
default_temperature = 0.7
[models.ollama]
base_url = "http://localhost:11434"
[models.reranker]
model_name = "BAAI/bge-reranker-large" # "BAAI/bge-reranker-v2-gemma", "BAAI/bge-reranker-large"
device = "mps" # "cuda:0" # Optional: Set to null or remove for CPU
use_fp16 = true
query_max_length = 256
passage_max_length = 512
# Document Processing
[processing]
[processing.text]
chunk_size = 1000
chunk_overlap = 200
default_k = 4
use_reranking = true # Whether to use reranking by default
[processing.video]
frame_sample_rate = 120
[processing.unstructured]
use_api = false
# Authentication
[auth]
jwt_algorithm = "HS256"

View File

@ -29,7 +29,7 @@ from core.storage.local_storage import LocalStorage
from core.embedding.openai_embedding_model import OpenAIEmbeddingModel
from core.completion.ollama_completion import OllamaCompletionModel
from core.parser.contextual_parser import ContextualParser
from core.reranker.bge_reranker import BGEReranker
from core.reranker.flag_reranker import FlagReranker
# Initialize FastAPI app
app = FastAPI(title="DataBridge API")
@ -145,7 +145,7 @@ match settings.PARSER_PROVIDER:
match settings.EMBEDDING_PROVIDER:
case "ollama":
embedding_model = OllamaEmbeddingModel(
base_url=settings.OLLAMA_BASE_URL,
base_url=settings.EMBEDDING_OLLAMA_BASE_URL,
model_name=settings.EMBEDDING_MODEL,
)
case "openai":
@ -163,7 +163,7 @@ match settings.COMPLETION_PROVIDER:
case "ollama":
completion_model = OllamaCompletionModel(
model_name=settings.COMPLETION_MODEL,
base_url=settings.OLLAMA_BASE_URL,
base_url=settings.COMPLETION_OLLAMA_BASE_URL,
)
case "openai":
if not settings.OPENAI_API_KEY:
@ -176,8 +176,8 @@ match settings.COMPLETION_PROVIDER:
# Initialize reranker
match settings.RERANKER_PROVIDER:
case "bge":
reranker = BGEReranker(
case "flag":
reranker = FlagReranker(
model_name=settings.RERANKER_MODEL,
device=settings.RERANKER_DEVICE,
use_fp16=settings.RERANKER_USE_FP16,

View File

@ -1,5 +1,5 @@
from typing import Optional
from pydantic import Field
import os
from typing import Literal, Optional
from pydantic_settings import BaseSettings
from functools import lru_cache
import tomli
@ -9,74 +9,70 @@ from dotenv import load_dotenv
class Settings(BaseSettings):
"""DataBridge configuration settings."""
# Required environment variables (referenced in config.toml)
JWT_SECRET_KEY: str = Field(..., env="JWT_SECRET_KEY")
MONGODB_URI: Optional[str] = Field(None, env="MONGODB_URI")
POSTGRES_URI: Optional[str] = Field(None, env="POSTGRES_URI")
# environment variables:
JWT_SECRET_KEY: str
POSTGRES_URI: Optional[str] = None
MONGODB_URI: Optional[str] = None
UNSTRUCTURED_API_KEY: Optional[str] = None
AWS_ACCESS_KEY: Optional[str] = None
AWS_SECRET_ACCESS_KEY: Optional[str] = None
OPENAI_API_KEY: Optional[str] = None
ANTHROPIC_API_KEY: Optional[str] = None
UNSTRUCTURED_API_KEY: Optional[str] = Field(None, env="UNSTRUCTURED_API_KEY")
AWS_ACCESS_KEY: Optional[str] = Field(None, env="AWS_ACCESS_KEY")
AWS_SECRET_ACCESS_KEY: Optional[str] = Field(None, env="AWS_SECRET_ACCESS_KEY")
ASSEMBLYAI_API_KEY: Optional[str] = Field(None, env="ASSEMBLYAI_API_KEY")
OPENAI_API_KEY: Optional[str] = Field(None, env="OPENAI_API_KEY")
ANTHROPIC_API_KEY: Optional[str] = Field(None, env="ANTHROPIC_API_KEY")
# configuration variables:
## api:
HOST: str
PORT: int
RELOAD: bool
# Service settings
HOST: str = "localhost"
PORT: int = 8000
RELOAD: bool = False
## auth:
JWT_ALGORITHM: str
# Component selection
STORAGE_PROVIDER: str = "local"
DATABASE_PROVIDER: str = "mongodb"
VECTOR_STORE_PROVIDER: str = "mongodb"
EMBEDDING_PROVIDER: str = "openai"
COMPLETION_PROVIDER: str = "ollama"
PARSER_PROVIDER: str = "combined"
RERANKER_PROVIDER: str = "bge"
## completion:
COMPLETION_PROVIDER: Literal["ollama", "openai"]
COMPLETION_MODEL: str
COMPLETION_MAX_TOKENS: Optional[str] = None
COMPLETION_TEMPERATURE: Optional[float] = None
COMPLETION_OLLAMA_BASE_URL: Optional[str] = None
# Storage settings
STORAGE_PATH: str = "./storage"
AWS_REGION: str = "us-east-2"
S3_BUCKET: str = "databridge-s3-storage"
## database
DATABASE_PROVIDER: Literal["postgres", "mongodb"]
DATABASE_NAME: Optional[str] = None
DOCUMENTS_COLLECTION: Optional[str] = None
# Database settings
DATABRIDGE_DB: str = "DataBridgeTest"
DOCUMENTS_TABLE: str = "documents"
CHUNKS_TABLE: str = "document_chunks"
DOCUMENTS_COLLECTION: str = "documents"
CHUNKS_COLLECTION: str = "document_chunks"
## embedding
EMBEDDING_PROVIDER: Literal["ollama", "openai"]
EMBEDDING_MODEL: str
VECTOR_DIMENSIONS: int
EMBEDDING_SIMILARITY_METRIC: Literal["cosine", "dotProduct"]
EMBEDDING_OLLAMA_BASE_URL: Optional[str] = None
# Vector store settings
VECTOR_INDEX_NAME: str = "vector_index"
VECTOR_DIMENSIONS: int = 1536
PGVECTOR_TABLE_NAME: str = "vector_embeddings"
PGVECTOR_INDEX_METHOD: str = "ivfflat"
PGVECTOR_INDEX_LISTS: int = 100
PGVECTOR_PROBES: int = 10
## parser
PARSER_PROVIDER: Literal["unstructured", "combined", "contextual"]
CHUNK_SIZE: int
CHUNK_OVERLAP: int
USE_UNSTRUCTURED_API: bool
FRAME_SAMPLE_RATE: Optional[int] = None
# Model settings
EMBEDDING_MODEL: str = "text-embedding-3-small"
COMPLETION_MODEL: str = "llama3.1"
COMPLETION_MAX_TOKENS: int = 1000
COMPLETION_TEMPERATURE: float = 0.7
OLLAMA_BASE_URL: str = "http://localhost:11434"
RERANKER_MODEL: str = "BAAI/bge-reranker-v2-gemma"
## reranker
USE_RERANKING: bool
RERANKER_PROVIDER: Optional[Literal["flag"]] = None
RERANKER_MODEL: Optional[str] = None
RERANKER_QUERY_MAX_LENGTH: Optional[int] = None
RERANKER_PASSAGE_MAX_LENGTH: Optional[int] = None
RERANKER_USE_FP16: Optional[bool] = None
RERANKER_DEVICE: Optional[str] = None
RERANKER_USE_FP16: bool = True
RERANKER_QUERY_MAX_LENGTH: int = 256
RERANKER_PASSAGE_MAX_LENGTH: int = 512
# Processing settings
CHUNK_SIZE: int = 1000
CHUNK_OVERLAP: int = 200
DEFAULT_K: int = 4
FRAME_SAMPLE_RATE: int = 120
USE_UNSTRUCTURED_API: bool = False
USE_RERANKING: bool = True
## storage
STORAGE_PROVIDER: Literal["local", "aws-s3"]
STORAGE_PATH: Optional[str] = None
AWS_REGION: Optional[str] = None
S3_BUCKET: Optional[str] = None
# Auth settings
JWT_ALGORITHM: str = "HS256"
## vector store
VECTOR_STORE_PROVIDER: Literal["pgvector", "mongodb"]
VECTOR_STORE_DATABASE_NAME: Optional[str] = None
VECTOR_STORE_COLLECTION_NAME: Optional[str] = None
@lru_cache()
@ -85,76 +81,176 @@ def get_settings() -> Settings:
load_dotenv(override=True)
# Load config.toml
with open("config.toml", "rb") as f:
with open("databridge.toml", "rb") as f:
config = tomli.load(f)
# Map config.toml values to settings
settings_dict = {
# Service settings
"HOST": config["service"]["host"],
"PORT": config["service"]["port"],
"RELOAD": config["service"]["reload"],
# Component selection
"STORAGE_PROVIDER": config["service"]["components"]["storage"],
"DATABASE_PROVIDER": config["service"]["components"]["database"],
"VECTOR_STORE_PROVIDER": config["service"]["components"]["vector_store"],
"EMBEDDING_PROVIDER": config["service"]["components"]["embedding"],
"COMPLETION_PROVIDER": config["service"]["components"]["completion"],
"PARSER_PROVIDER": config["service"]["components"]["parser"],
"RERANKER_PROVIDER": config["service"]["components"]["reranker"],
# Storage settings
"STORAGE_PATH": config["storage"]["local"]["path"],
"AWS_REGION": config["storage"]["aws"]["region"],
"S3_BUCKET": config["storage"]["aws"]["bucket_name"],
# Database settings
"DATABRIDGE_DB": config["database"][config["service"]["components"]["database"]][
"database_name"
],
"DOCUMENTS_TABLE": config["database"]
.get("postgres", {})
.get("documents_table", "documents"),
"CHUNKS_TABLE": config["database"]
.get("postgres", {})
.get("chunks_table", "document_chunks"),
"DOCUMENTS_COLLECTION": config["database"]
.get("mongodb", {})
.get("documents_collection", "documents"),
"CHUNKS_COLLECTION": config["database"]
.get("mongodb", {})
.get("chunks_collection", "document_chunks"),
# Vector store settings
"VECTOR_INDEX_NAME": config["vector_store"]["mongodb"]["index_name"],
"VECTOR_DIMENSIONS": config["vector_store"][
config["service"]["components"]["vector_store"]
]["dimensions"],
"PGVECTOR_TABLE_NAME": config["vector_store"]
.get("pgvector", {})
.get("table_name", "vector_embeddings"),
"PGVECTOR_INDEX_METHOD": config["vector_store"]
.get("pgvector", {})
.get("index_method", "ivfflat"),
"PGVECTOR_INDEX_LISTS": config["vector_store"].get("pgvector", {}).get("index_lists", 100),
"PGVECTOR_PROBES": config["vector_store"].get("pgvector", {}).get("probes", 10),
# Model settings
"EMBEDDING_MODEL": config["models"]["embedding"]["model_name"],
"COMPLETION_MODEL": config["models"]["completion"]["model_name"],
"COMPLETION_MAX_TOKENS": config["models"]["completion"]["default_max_tokens"],
"COMPLETION_TEMPERATURE": config["models"]["completion"]["default_temperature"],
"OLLAMA_BASE_URL": config["models"]["ollama"]["base_url"],
"RERANKER_MODEL": config["models"]["reranker"]["model_name"],
"RERANKER_DEVICE": config["models"]["reranker"].get("device"),
"RERANKER_USE_FP16": config["models"]["reranker"].get("use_fp16", True),
"RERANKER_QUERY_MAX_LENGTH": config["models"]["reranker"].get("query_max_length", 256),
"RERANKER_PASSAGE_MAX_LENGTH": config["models"]["reranker"].get("passage_max_length", 512),
# Processing settings
"CHUNK_SIZE": config["processing"]["text"]["chunk_size"],
"CHUNK_OVERLAP": config["processing"]["text"]["chunk_overlap"],
"DEFAULT_K": config["processing"]["text"]["default_k"],
"USE_RERANKING": config["processing"]["text"]["use_reranking"],
"FRAME_SAMPLE_RATE": config["processing"]["video"]["frame_sample_rate"],
"USE_UNSTRUCTURED_API": config["processing"]["unstructured"]["use_api"],
# Auth settings
"JWT_ALGORITHM": config["auth"]["jwt_algorithm"],
em = "'{missing_value}' needed if '{field}' is set to '{value}'"
# load api config
api_config = {
"HOST": config["api"]["host"],
"PORT": int(config["api"]["port"]),
"RELOAD": bool(config["api"]["reload"]),
}
# load auth config
auth_config = {
"JWT_ALGORITHM": config["auth"]["jwt_algorithm"],
"JWT_SECRET_KEY": os.environ["JWT_SECRET_KEY"],
}
# load completion config
completion_config = {
"COMPLETION_PROVIDER": config["completion"]["provider"],
"COMPLETION_MODEL": config["completion"]["model_name"],
}
match completion_config["COMPLETION_PROVIDER"]:
case "openai" if "OPENAI_API_KEY" in os.environ:
completion_config.update({"OPENAI_API_KEY": os.environ["OPENAI_API_KEY"]})
case "openai":
msg = em.format(
missing_value="OPENAI_API_KEY", field="completion.provider", value="openai"
)
raise ValueError(msg)
case "ollama" if "base_url" in config["completion"]:
completion_config.update(
{"COMPLETION_OLLAMA_BASE_URL": config["completion"]["base_url"]}
)
case "ollama":
msg = em.format(missing_value="base_url", field="completion.provider", value="ollama")
raise ValueError(msg)
case _:
prov = completion_config["COMPLETION_PROVIDER"]
raise ValueError(f"Unknown completion provider selected: '{prov}'")
# load database config
database_config = {"DATABASE_PROVIDER": config["database"]["provider"]}
match database_config["DATABASE_PROVIDER"]:
case "mongodb":
database_config.update(
{
"DATABASE_NAME": config["database"]["database_name"],
"COLLECTION_NAME": config["database"]["collection_name"],
}
)
case "postgres" if "POSTGRES_URI" in os.environ:
database_config.update({"POSTGRES_URI": os.environ["POSTGRES_URI"]})
case "postgres":
msg = em.format(
missing_value="POSTGRES_URI", field="database.provider", value="postgres"
)
raise ValueError(msg)
case _:
prov = database_config["DATABASE_PROVIDER"]
raise ValueError(f"Unknown database provider selected: '{prov}'")
# load embedding config
embedding_config = {
"EMBEDDING_PROVIDER": config["embedding"]["provider"],
"EMBEDDING_MODEL": config["embedding"]["model_name"],
"VECTOR_DIMENSIONS": config["embedding"]["dimensions"],
"EMBEDDING_SIMILARITY_METRIC": config["embedding"]["similarity_metric"],
}
match embedding_config["EMBEDDING_PROVIDER"]:
case "openai" if "OPENAI_API_KEY" in os.environ:
embedding_config.update({"OPENAI_API_KEY": os.environ["OPENAI_API_KEY"]})
case "openai":
msg = em.format(
missing_value="OPENAI_API_KEY", field="embedding.provider", value="openai"
)
raise ValueError(msg)
case "ollama" if "base_url" in config["embedding"]:
embedding_config.update({"EMBEDDING_OLLAMA_BASE_URL": config["embedding"]["base_url"]})
case "ollama":
msg = em.format(missing_value="base_url", field="embedding.provider", value="ollama")
raise ValueError(msg)
case _:
prov = embedding_config["EMBEDDING_PROVIDER"]
raise ValueError(f"Unknown embedding provider selected: '{prov}'")
# load parser config
parser_config = {
"PARSER_PROVIDER": config["parser"]["provider"],
"CHUNK_SIZE": config["parser"]["chunk_size"],
"CHUNK_OVERLAP": config["parser"]["chunk_overlap"],
"USE_UNSTRUCTURED_API": config["parser"]["use_unstructured_api"],
}
if parser_config["USE_UNSTRUCTURED_API"] and "UNSTRUCTURED_API_KEY" not in os.environ:
msg = em.format(
missing_value="UNSTRUCTURED_API_KEY", field="parser.use_unstructured_api", value="true"
)
raise ValueError(msg)
elif parser_config["USE_UNSTRUCTURED_API"]:
parser_config.update({"UNSTRUCTURED_API_KEY": os.environ["UNSTRUCTURED_API_KEY"]})
# load reranker config
reranker_config = {"USE_RERANKING": config["reranker"]["use_reranker"]}
if reranker_config["USE_RERANKING"]:
reranker_config.update(
{
"RERANKER_PROVIDER": config["reranker"]["provider"],
"RERANKER_MODEL": config["reranker"]["model_name"],
"RERANKER_QUERY_MAX_LENGTH": config["reranker"]["query_max_length"],
"RERANKER_PASSAGE_MAX_LENGTH": config["reranker"]["passage_max_length"],
"RERANKER_USE_FP16": config["reranker"]["use_fp16"],
"RERANKER_DEVICE": config["reranker"]["device"],
}
)
# load storage config
storage_config = {"STORAGE_PROVIDER": config["storage"]["provider"]}
match storage_config["STORAGE_PROVIDER"]:
case "local":
storage_config.update({"STORAGE_PATH": config["storage"]["storage_path"]})
case "aws-s3" if all(
key in os.environ for key in ["AWS_ACCESS_KEY", "AWS_SECRET_ACCESS_KEY"]
):
storage_config.update(
{
"AWS_REGION": config["storage"]["region"],
"S3_BUCKET": config["storage"]["bucket_name"],
"AWS_ACCESS_KEY": os.environ["AWS_ACCESS_KEY"],
"AWS_SECRET_ACCESS_KEY": os.environ["AWS_SECRET_ACCESS_KEY"],
}
)
case "aws-s3":
msg = em.format(
missing_value="AWS credentials", field="storage.provider", value="aws-s3"
)
raise ValueError(msg)
case _:
prov = storage_config["STORAGE_PROVIDER"]
raise ValueError(f"Unknown storage provider selected: '{prov}'")
# load vector store config
vector_store_config = {"VECTOR_STORE_PROVIDER": config["vector_store"]["provider"]}
match vector_store_config["VECTOR_STORE_PROVIDER"]:
case "mongodb":
vector_store_config.update(
{
"VECTOR_STORE_DATABASE_NAME": config["vector_store"]["database_name"],
"VECTOR_STORE_COLLECTION_NAME": config["vector_store"]["collection_name"],
}
)
case "pgvector":
if "POSTGRES_URI" not in os.environ:
msg = em.format(
missing_value="POSTGRES_URI", field="vector_store.provider", value="pgvector"
)
raise ValueError(msg)
case _:
prov = vector_store_config["VECTOR_STORE_PROVIDER"]
raise ValueError(f"Unknown vector store provider selected: '{prov}'")
settings_dict = {}
settings_dict.update(
**api_config,
**auth_config,
**completion_config,
**database_config,
**embedding_config,
**parser_config,
**reranker_config,
**storage_config,
**vector_store_config,
)
return Settings(**settings_dict)

View File

@ -5,8 +5,8 @@ from core.models.chunk import DocumentChunk
from core.reranker.base_reranker import BaseReranker
class BGEReranker(BaseReranker):
"""BGE reranker implementation using FlagEmbedding"""
class FlagReranker(BaseReranker):
"""Reranker implementation using FlagEmbedding"""
def __init__(
self,
@ -16,7 +16,7 @@ class BGEReranker(BaseReranker):
use_fp16: bool = True,
device: Optional[str] = None,
):
"""Initialize BGE reranker"""
"""Initialize flag reranker"""
devices = [device] if device else None
self.reranker = FlagAutoReranker.from_finetuned(
model_name_or_path=model_name,

View File

@ -2,7 +2,7 @@ import pytest
from typing import List
from core.models.chunk import DocumentChunk
from core.reranker.bge_reranker import BGEReranker
from core.reranker.flag_reranker import FlagReranker
from core.config import get_settings
@ -35,9 +35,9 @@ def sample_chunks() -> List[DocumentChunk]:
@pytest.fixture
def reranker():
"""Fixture to create and reuse a BGE reranker instance"""
"""Fixture to create and reuse a flag reranker instance"""
settings = get_settings()
return BGEReranker(
return FlagReranker(
model_name=settings.RERANKER_MODEL,
device=settings.RERANKER_DEVICE,
use_fp16=settings.RERANKER_USE_FP16,

84
databridge.toml Normal file
View File

@ -0,0 +1,84 @@
[api]
host = "localhost"
port = 8000
reload = true
[auth]
jwt_algorithm = "HS256"
[completion]
provider = "ollama"
model_name = "llama3.2"
default_max_tokens = 1000
default_temperature = 0.7
base_url = "http://localhost:11434"
# [completion]
# provider = "openai"
# model_name = "gpt-4o-mini"
# default_max_tokens = 1000
# default_temperature = 0.7
[database]
provider = "postgres"
# [database]
# provider = "mongodb"
# database_name = "databridge"
# collection_name = "documents"
[embedding]
provider = "ollama"
model_name = "nomic-embed-text"
dimensions = 768
similarity_metric = "cosine"
base_url = "http://localhost:11434"
# [embedding]
# provider = "openai"
# model_name = "text-embedding-3-small"
# dimensions = 1536
# similarity_metric = "dotProduct"
[parser]
provider = "unstructured"
chunk_size = 1000
chunk_overlap = 200
use_unstructured_api = false
# [parser]
# provider = "combined" | "contextual"
# chunk_size = 1000
# chunk_overlap = 200
# use_unstructured_api = false
# frame_sample_rate = 120
[reranker]
use_reranker = true
provider = "flag"
model_name = "BAAI/bge-reranker-large"
query_max_length = 256
passage_max_length = 512
use_fp16 = true
device = "mps"
# [reranker]
# use_reranker = false
[storage]
provider = "local"
storage_path = "../storage"
# [storage]
# provider = "aws-s3"
# region = "us-east-2"
# bucket_name = "databridge-s3-storage"
[vector_store]
provider = "pgvector"
# [vector_store]
# provider = "mongodb"
# database_name = "databridge"
# collection_name = "document_chunks"

View File

@ -37,33 +37,33 @@ formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(messag
console_handler.setFormatter(formatter)
LOGGER.addHandler(console_handler)
# Load configuration from config.toml
config_path = Path("config.toml")
# Load configuration from databridge.toml
config_path = Path("databridge.toml")
with open(config_path, "rb") as f:
CONFIG = tomli.load(f)
LOGGER.info("Loaded configuration from config.toml")
LOGGER.info("Loaded configuration from databridge.toml")
# Extract configuration values
STORAGE_PROVIDER = CONFIG["service"]["components"]["storage"]
DATABASE_PROVIDER = CONFIG["service"]["components"]["database"]
DATABASE_NAME = CONFIG["database"][DATABASE_PROVIDER]["database_name"]
STORAGE_PROVIDER = CONFIG["storage"]["provider"]
DATABASE_PROVIDER = CONFIG["database"]["provider"]
# MongoDB specific config
DOCUMENTS_COLLECTION = CONFIG["database"]["mongodb"]["documents_collection"]
CHUNKS_COLLECTION = CONFIG["database"]["mongodb"]["chunks_collection"]
VECTOR_DIMENSIONS = CONFIG["vector_store"]["mongodb"]["dimensions"]
VECTOR_INDEX_NAME = CONFIG["vector_store"]["mongodb"]["index_name"]
SIMILARITY_METRIC = CONFIG["vector_store"]["mongodb"]["similarity_metric"]
# PostgreSQL specific config
DOCUMENTS_TABLE = CONFIG["database"]["postgres"]["documents_table"]
CHUNKS_TABLE = CONFIG["database"]["postgres"]["chunks_table"]
if "mongodb" in CONFIG["database"]:
DATABASE_NAME = CONFIG["database"]["mongodb"]["database_name"]
DOCUMENTS_COLLECTION = "documents"
CHUNKS_COLLECTION = "document_chunks"
if "mongodb" in CONFIG["vector_store"]:
VECTOR_DIMENSIONS = CONFIG["embedding"]["dimensions"]
VECTOR_INDEX_NAME = "vector_index"
SIMILARITY_METRIC = CONFIG["embedding"]["similarity_metric"]
# Extract storage-specific configuration
DEFAULT_REGION = CONFIG["storage"]["aws"]["region"] if STORAGE_PROVIDER == "aws-s3" else None
DEFAULT_BUCKET_NAME = (
CONFIG["storage"]["aws"]["bucket_name"] if STORAGE_PROVIDER == "aws-s3" else None
)
if STORAGE_PROVIDER == "aws-s3":
DEFAULT_REGION = CONFIG["storage"]["region"]
DEFAULT_BUCKET_NAME = CONFIG["storage"]["bucket_name"]
else:
DEFAULT_REGION = None
DEFAULT_BUCKET_NAME = None
def create_s3_bucket(bucket_name, region=DEFAULT_REGION):
@ -248,10 +248,8 @@ def setup_postgres():
LOGGER.info("Created all PostgreSQL tables and indexes")
# Create vector index with configuration from settings
table_name = CONFIG["vector_store"]["pgvector"]["table_name"]
index_method = CONFIG["vector_store"]["pgvector"]["index_method"]
index_lists = CONFIG["vector_store"]["pgvector"]["index_lists"]
dimensions = CONFIG["vector_store"]["pgvector"]["dimensions"]
table_name = "document_chunks" # Default table name for pgvector
dimensions = CONFIG["embedding"]["dimensions"]
# First, alter the embedding column to be a vector
alter_sql = f"""
@ -262,15 +260,14 @@ def setup_postgres():
await conn.execute(text(alter_sql))
LOGGER.info(f"Altered embedding column to be vector({dimensions})")
# Then create the vector index
if index_method == "ivfflat":
index_sql = f"""
CREATE INDEX IF NOT EXISTS vector_idx
ON {table_name} USING ivfflat (embedding vector_l2_ops)
WITH (lists = {index_lists});
"""
await conn.execute(text(index_sql))
LOGGER.info(f"Created IVFFlat index on {table_name} with {index_lists} lists")
# Create the vector index
index_sql = f"""
CREATE INDEX IF NOT EXISTS vector_idx
ON {table_name} USING ivfflat (embedding vector_l2_ops)
WITH (lists = 100);
"""
await conn.execute(text(index_sql))
LOGGER.info(f"Created IVFFlat index on {table_name}")
await engine.dispose()
LOGGER.info("PostgreSQL setup completed successfully")