diff --git a/config.toml b/config.toml deleted file mode 100644 index 5b351ec..0000000 --- a/config.toml +++ /dev/null @@ -1,84 +0,0 @@ -# Core Service Components -[service] -host = "localhost" -port = 8000 -reload = false - -[service.components] -storage = "local" # "aws-s3" -database = "postgres" # "postgres", "mongodb" -vector_store = "pgvector" # "mongodb" -embedding = "ollama" # "openai", "ollama" -completion = "ollama" # "openai", "ollama" -parser = "combined" # "combined", "unstructured", "contextual" -reranker = "bge" # "bge" - -# Storage Configuration -[storage.local] -path = "./storage" - -[storage.aws] -region = "us-east-2" -bucket_name = "databridge-s3-storage" - -# Database Configuration -[database.postgres] -database_name = "databridge" -documents_table = "documents" -chunks_table = "document_chunks" - -[database.mongodb] -database_name = "databridge" -documents_collection = "documents" -chunks_collection = "document_chunks" - -# Vector Store Configuration -[vector_store.mongodb] -dimensions = 768 # 768 for nomic-embed-text, 1536 for text-embedding-3-small -index_name = "vector_index" -similarity_metric = "cosine" - -[vector_store.pgvector] -dimensions = 768 # 768 for nomic-embed-text, 1536 for text-embedding-3-small -table_name = "vector_embeddings" -index_method = "ivfflat" # "ivfflat" or "hnsw" -index_lists = 100 # Number of lists for ivfflat index -probes = 10 # Number of probes for ivfflat search - -# Model Configurations -[models] -[models.embedding] -model_name = "nomic-embed-text" # "text-embedding-3-small", "nomic-embed-text" - -[models.completion] -model_name = "llama3.2" # "gpt-4o-mini", "llama3.1", etc. 
-default_max_tokens = 1000 -default_temperature = 0.7 - -[models.ollama] -base_url = "http://localhost:11434" - -[models.reranker] -model_name = "BAAI/bge-reranker-large" # "BAAI/bge-reranker-v2-gemma", "BAAI/bge-reranker-large" -device = "mps" # "cuda:0" # Optional: Set to null or remove for CPU -use_fp16 = true -query_max_length = 256 -passage_max_length = 512 - -# Document Processing -[processing] -[processing.text] -chunk_size = 1000 -chunk_overlap = 200 -default_k = 4 -use_reranking = true # Whether to use reranking by default - -[processing.video] -frame_sample_rate = 120 - -[processing.unstructured] -use_api = false - -# Authentication -[auth] -jwt_algorithm = "HS256" diff --git a/core/api.py b/core/api.py index 76ccd91..b9df0e0 100644 --- a/core/api.py +++ b/core/api.py @@ -29,7 +29,7 @@ from core.storage.local_storage import LocalStorage from core.embedding.openai_embedding_model import OpenAIEmbeddingModel from core.completion.ollama_completion import OllamaCompletionModel from core.parser.contextual_parser import ContextualParser -from core.reranker.bge_reranker import BGEReranker +from core.reranker.flag_reranker import FlagReranker # Initialize FastAPI app app = FastAPI(title="DataBridge API") @@ -145,7 +145,7 @@ match settings.PARSER_PROVIDER: match settings.EMBEDDING_PROVIDER: case "ollama": embedding_model = OllamaEmbeddingModel( - base_url=settings.OLLAMA_BASE_URL, + base_url=settings.EMBEDDING_OLLAMA_BASE_URL, model_name=settings.EMBEDDING_MODEL, ) case "openai": @@ -163,7 +163,7 @@ match settings.COMPLETION_PROVIDER: case "ollama": completion_model = OllamaCompletionModel( model_name=settings.COMPLETION_MODEL, - base_url=settings.OLLAMA_BASE_URL, + base_url=settings.COMPLETION_OLLAMA_BASE_URL, ) case "openai": if not settings.OPENAI_API_KEY: @@ -176,8 +176,8 @@ match settings.COMPLETION_PROVIDER: # Initialize reranker match settings.RERANKER_PROVIDER: - case "bge": - reranker = BGEReranker( + case "flag": + reranker = FlagReranker( 
model_name=settings.RERANKER_MODEL, device=settings.RERANKER_DEVICE, use_fp16=settings.RERANKER_USE_FP16, diff --git a/core/config.py b/core/config.py index aa55d69..f0731a5 100644 --- a/core/config.py +++ b/core/config.py @@ -1,5 +1,5 @@ -from typing import Optional -from pydantic import Field +import os +from typing import Literal, Optional from pydantic_settings import BaseSettings from functools import lru_cache import tomli @@ -9,74 +9,70 @@ from dotenv import load_dotenv class Settings(BaseSettings): """DataBridge configuration settings.""" - # Required environment variables (referenced in config.toml) - JWT_SECRET_KEY: str = Field(..., env="JWT_SECRET_KEY") - MONGODB_URI: Optional[str] = Field(None, env="MONGODB_URI") - POSTGRES_URI: Optional[str] = Field(None, env="POSTGRES_URI") + # environment variables: + JWT_SECRET_KEY: str + POSTGRES_URI: Optional[str] = None + MONGODB_URI: Optional[str] = None + UNSTRUCTURED_API_KEY: Optional[str] = None + AWS_ACCESS_KEY: Optional[str] = None + AWS_SECRET_ACCESS_KEY: Optional[str] = None + OPENAI_API_KEY: Optional[str] = None + ANTHROPIC_API_KEY: Optional[str] = None - UNSTRUCTURED_API_KEY: Optional[str] = Field(None, env="UNSTRUCTURED_API_KEY") - AWS_ACCESS_KEY: Optional[str] = Field(None, env="AWS_ACCESS_KEY") - AWS_SECRET_ACCESS_KEY: Optional[str] = Field(None, env="AWS_SECRET_ACCESS_KEY") - ASSEMBLYAI_API_KEY: Optional[str] = Field(None, env="ASSEMBLYAI_API_KEY") - OPENAI_API_KEY: Optional[str] = Field(None, env="OPENAI_API_KEY") - ANTHROPIC_API_KEY: Optional[str] = Field(None, env="ANTHROPIC_API_KEY") + # configuration variables: + ## api: + HOST: str + PORT: int + RELOAD: bool - # Service settings - HOST: str = "localhost" - PORT: int = 8000 - RELOAD: bool = False + ## auth: + JWT_ALGORITHM: str - # Component selection - STORAGE_PROVIDER: str = "local" - DATABASE_PROVIDER: str = "mongodb" - VECTOR_STORE_PROVIDER: str = "mongodb" - EMBEDDING_PROVIDER: str = "openai" - COMPLETION_PROVIDER: str = "ollama" - 
PARSER_PROVIDER: str = "combined" - RERANKER_PROVIDER: str = "bge" + ## completion: + COMPLETION_PROVIDER: Literal["ollama", "openai"] + COMPLETION_MODEL: str + COMPLETION_MAX_TOKENS: Optional[str] = None + COMPLETION_TEMPERATURE: Optional[float] = None + COMPLETION_OLLAMA_BASE_URL: Optional[str] = None - # Storage settings - STORAGE_PATH: str = "./storage" - AWS_REGION: str = "us-east-2" - S3_BUCKET: str = "databridge-s3-storage" + ## database + DATABASE_PROVIDER: Literal["postgres", "mongodb"] + DATABASE_NAME: Optional[str] = None + DOCUMENTS_COLLECTION: Optional[str] = None - # Database settings - DATABRIDGE_DB: str = "DataBridgeTest" - DOCUMENTS_TABLE: str = "documents" - CHUNKS_TABLE: str = "document_chunks" - DOCUMENTS_COLLECTION: str = "documents" - CHUNKS_COLLECTION: str = "document_chunks" + ## embedding + EMBEDDING_PROVIDER: Literal["ollama", "openai"] + EMBEDDING_MODEL: str + VECTOR_DIMENSIONS: int + EMBEDDING_SIMILARITY_METRIC: Literal["cosine", "dotProduct"] + EMBEDDING_OLLAMA_BASE_URL: Optional[str] = None - # Vector store settings - VECTOR_INDEX_NAME: str = "vector_index" - VECTOR_DIMENSIONS: int = 1536 - PGVECTOR_TABLE_NAME: str = "vector_embeddings" - PGVECTOR_INDEX_METHOD: str = "ivfflat" - PGVECTOR_INDEX_LISTS: int = 100 - PGVECTOR_PROBES: int = 10 + ## parser + PARSER_PROVIDER: Literal["unstructured", "combined", "contextual"] + CHUNK_SIZE: int + CHUNK_OVERLAP: int + USE_UNSTRUCTURED_API: bool + FRAME_SAMPLE_RATE: Optional[int] = None - # Model settings - EMBEDDING_MODEL: str = "text-embedding-3-small" - COMPLETION_MODEL: str = "llama3.1" - COMPLETION_MAX_TOKENS: int = 1000 - COMPLETION_TEMPERATURE: float = 0.7 - OLLAMA_BASE_URL: str = "http://localhost:11434" - RERANKER_MODEL: str = "BAAI/bge-reranker-v2-gemma" + ## reranker + USE_RERANKING: bool + RERANKER_PROVIDER: Optional[Literal["flag"]] = None + RERANKER_MODEL: Optional[str] = None + RERANKER_QUERY_MAX_LENGTH: Optional[int] = None + RERANKER_PASSAGE_MAX_LENGTH: Optional[int] = None + 
RERANKER_USE_FP16: Optional[bool] = None RERANKER_DEVICE: Optional[str] = None - RERANKER_USE_FP16: bool = True - RERANKER_QUERY_MAX_LENGTH: int = 256 - RERANKER_PASSAGE_MAX_LENGTH: int = 512 - # Processing settings - CHUNK_SIZE: int = 1000 - CHUNK_OVERLAP: int = 200 - DEFAULT_K: int = 4 - FRAME_SAMPLE_RATE: int = 120 - USE_UNSTRUCTURED_API: bool = False - USE_RERANKING: bool = True + ## storage + STORAGE_PROVIDER: Literal["local", "aws-s3"] + STORAGE_PATH: Optional[str] = None + AWS_REGION: Optional[str] = None + S3_BUCKET: Optional[str] = None - # Auth settings - JWT_ALGORITHM: str = "HS256" + ## vector store + VECTOR_STORE_PROVIDER: Literal["pgvector", "mongodb"] + VECTOR_STORE_DATABASE_NAME: Optional[str] = None + VECTOR_STORE_COLLECTION_NAME: Optional[str] = None @lru_cache() @@ -85,76 +81,176 @@ def get_settings() -> Settings: load_dotenv(override=True) # Load config.toml - with open("config.toml", "rb") as f: + with open("databridge.toml", "rb") as f: config = tomli.load(f) - # Map config.toml values to settings - settings_dict = { - # Service settings - "HOST": config["service"]["host"], - "PORT": config["service"]["port"], - "RELOAD": config["service"]["reload"], - # Component selection - "STORAGE_PROVIDER": config["service"]["components"]["storage"], - "DATABASE_PROVIDER": config["service"]["components"]["database"], - "VECTOR_STORE_PROVIDER": config["service"]["components"]["vector_store"], - "EMBEDDING_PROVIDER": config["service"]["components"]["embedding"], - "COMPLETION_PROVIDER": config["service"]["components"]["completion"], - "PARSER_PROVIDER": config["service"]["components"]["parser"], - "RERANKER_PROVIDER": config["service"]["components"]["reranker"], - # Storage settings - "STORAGE_PATH": config["storage"]["local"]["path"], - "AWS_REGION": config["storage"]["aws"]["region"], - "S3_BUCKET": config["storage"]["aws"]["bucket_name"], - # Database settings - "DATABRIDGE_DB": config["database"][config["service"]["components"]["database"]][ - 
"database_name" - ], - "DOCUMENTS_TABLE": config["database"] - .get("postgres", {}) - .get("documents_table", "documents"), - "CHUNKS_TABLE": config["database"] - .get("postgres", {}) - .get("chunks_table", "document_chunks"), - "DOCUMENTS_COLLECTION": config["database"] - .get("mongodb", {}) - .get("documents_collection", "documents"), - "CHUNKS_COLLECTION": config["database"] - .get("mongodb", {}) - .get("chunks_collection", "document_chunks"), - # Vector store settings - "VECTOR_INDEX_NAME": config["vector_store"]["mongodb"]["index_name"], - "VECTOR_DIMENSIONS": config["vector_store"][ - config["service"]["components"]["vector_store"] - ]["dimensions"], - "PGVECTOR_TABLE_NAME": config["vector_store"] - .get("pgvector", {}) - .get("table_name", "vector_embeddings"), - "PGVECTOR_INDEX_METHOD": config["vector_store"] - .get("pgvector", {}) - .get("index_method", "ivfflat"), - "PGVECTOR_INDEX_LISTS": config["vector_store"].get("pgvector", {}).get("index_lists", 100), - "PGVECTOR_PROBES": config["vector_store"].get("pgvector", {}).get("probes", 10), - # Model settings - "EMBEDDING_MODEL": config["models"]["embedding"]["model_name"], - "COMPLETION_MODEL": config["models"]["completion"]["model_name"], - "COMPLETION_MAX_TOKENS": config["models"]["completion"]["default_max_tokens"], - "COMPLETION_TEMPERATURE": config["models"]["completion"]["default_temperature"], - "OLLAMA_BASE_URL": config["models"]["ollama"]["base_url"], - "RERANKER_MODEL": config["models"]["reranker"]["model_name"], - "RERANKER_DEVICE": config["models"]["reranker"].get("device"), - "RERANKER_USE_FP16": config["models"]["reranker"].get("use_fp16", True), - "RERANKER_QUERY_MAX_LENGTH": config["models"]["reranker"].get("query_max_length", 256), - "RERANKER_PASSAGE_MAX_LENGTH": config["models"]["reranker"].get("passage_max_length", 512), - # Processing settings - "CHUNK_SIZE": config["processing"]["text"]["chunk_size"], - "CHUNK_OVERLAP": config["processing"]["text"]["chunk_overlap"], - "DEFAULT_K": 
config["processing"]["text"]["default_k"], - "USE_RERANKING": config["processing"]["text"]["use_reranking"], - "FRAME_SAMPLE_RATE": config["processing"]["video"]["frame_sample_rate"], - "USE_UNSTRUCTURED_API": config["processing"]["unstructured"]["use_api"], - # Auth settings - "JWT_ALGORITHM": config["auth"]["jwt_algorithm"], + em = "'{missing_value}' needed if '{field}' is set to '{value}'" + # load api config + api_config = { + "HOST": config["api"]["host"], + "PORT": int(config["api"]["port"]), + "RELOAD": bool(config["api"]["reload"]), } + # load auth config + auth_config = { + "JWT_ALGORITHM": config["auth"]["jwt_algorithm"], + "JWT_SECRET_KEY": os.environ["JWT_SECRET_KEY"], + } + + # load completion config + completion_config = { + "COMPLETION_PROVIDER": config["completion"]["provider"], + "COMPLETION_MODEL": config["completion"]["model_name"], + } + match completion_config["COMPLETION_PROVIDER"]: + case "openai" if "OPENAI_API_KEY" in os.environ: + completion_config.update({"OPENAI_API_KEY": os.environ["OPENAI_API_KEY"]}) + case "openai": + msg = em.format( + missing_value="OPENAI_API_KEY", field="completion.provider", value="openai" + ) + raise ValueError(msg) + case "ollama" if "base_url" in config["completion"]: + completion_config.update( + {"COMPLETION_OLLAMA_BASE_URL": config["completion"]["base_url"]} + ) + case "ollama": + msg = em.format(missing_value="base_url", field="completion.provider", value="ollama") + raise ValueError(msg) + case _: + prov = completion_config["COMPLETION_PROVIDER"] + raise ValueError(f"Unknown completion provider selected: '{prov}'") + + # load database config + database_config = {"DATABASE_PROVIDER": config["database"]["provider"]} + match database_config["DATABASE_PROVIDER"]: + case "mongodb": + database_config.update( + { + "DATABASE_NAME": config["database"]["database_name"], + "DOCUMENTS_COLLECTION": config["database"]["collection_name"], + } + ) + case "postgres" if "POSTGRES_URI" in os.environ: +
database_config.update({"POSTGRES_URI": os.environ["POSTGRES_URI"]}) + case "postgres": + msg = em.format( + missing_value="POSTGRES_URI", field="database.provider", value="postgres" + ) + raise ValueError(msg) + case _: + prov = database_config["DATABASE_PROVIDER"] + raise ValueError(f"Unknown database provider selected: '{prov}'") + + # load embedding config + embedding_config = { + "EMBEDDING_PROVIDER": config["embedding"]["provider"], + "EMBEDDING_MODEL": config["embedding"]["model_name"], + "VECTOR_DIMENSIONS": config["embedding"]["dimensions"], + "EMBEDDING_SIMILARITY_METRIC": config["embedding"]["similarity_metric"], + } + match embedding_config["EMBEDDING_PROVIDER"]: + case "openai" if "OPENAI_API_KEY" in os.environ: + embedding_config.update({"OPENAI_API_KEY": os.environ["OPENAI_API_KEY"]}) + case "openai": + msg = em.format( + missing_value="OPENAI_API_KEY", field="embedding.provider", value="openai" + ) + raise ValueError(msg) + case "ollama" if "base_url" in config["embedding"]: + embedding_config.update({"EMBEDDING_OLLAMA_BASE_URL": config["embedding"]["base_url"]}) + case "ollama": + msg = em.format(missing_value="base_url", field="embedding.provider", value="ollama") + raise ValueError(msg) + case _: + prov = embedding_config["EMBEDDING_PROVIDER"] + raise ValueError(f"Unknown embedding provider selected: '{prov}'") + + # load parser config + parser_config = { + "PARSER_PROVIDER": config["parser"]["provider"], + "CHUNK_SIZE": config["parser"]["chunk_size"], + "CHUNK_OVERLAP": config["parser"]["chunk_overlap"], + "USE_UNSTRUCTURED_API": config["parser"]["use_unstructured_api"], + } + if parser_config["USE_UNSTRUCTURED_API"] and "UNSTRUCTURED_API_KEY" not in os.environ: + msg = em.format( + missing_value="UNSTRUCTURED_API_KEY", field="parser.use_unstructured_api", value="true" + ) + raise ValueError(msg) + elif parser_config["USE_UNSTRUCTURED_API"]: + parser_config.update({"UNSTRUCTURED_API_KEY": os.environ["UNSTRUCTURED_API_KEY"]}) + + # load reranker 
config + reranker_config = {"USE_RERANKING": config["reranker"]["use_reranker"]} + if reranker_config["USE_RERANKING"]: + reranker_config.update( + { + "RERANKER_PROVIDER": config["reranker"]["provider"], + "RERANKER_MODEL": config["reranker"]["model_name"], + "RERANKER_QUERY_MAX_LENGTH": config["reranker"]["query_max_length"], + "RERANKER_PASSAGE_MAX_LENGTH": config["reranker"]["passage_max_length"], + "RERANKER_USE_FP16": config["reranker"]["use_fp16"], + "RERANKER_DEVICE": config["reranker"]["device"], + } + ) + + # load storage config + storage_config = {"STORAGE_PROVIDER": config["storage"]["provider"]} + match storage_config["STORAGE_PROVIDER"]: + case "local": + storage_config.update({"STORAGE_PATH": config["storage"]["storage_path"]}) + case "aws-s3" if all( + key in os.environ for key in ["AWS_ACCESS_KEY", "AWS_SECRET_ACCESS_KEY"] + ): + storage_config.update( + { + "AWS_REGION": config["storage"]["region"], + "S3_BUCKET": config["storage"]["bucket_name"], + "AWS_ACCESS_KEY": os.environ["AWS_ACCESS_KEY"], + "AWS_SECRET_ACCESS_KEY": os.environ["AWS_SECRET_ACCESS_KEY"], + } + ) + case "aws-s3": + msg = em.format( + missing_value="AWS credentials", field="storage.provider", value="aws-s3" + ) + raise ValueError(msg) + case _: + prov = storage_config["STORAGE_PROVIDER"] + raise ValueError(f"Unknown storage provider selected: '{prov}'") + + # load vector store config + vector_store_config = {"VECTOR_STORE_PROVIDER": config["vector_store"]["provider"]} + match vector_store_config["VECTOR_STORE_PROVIDER"]: + case "mongodb": + vector_store_config.update( + { + "VECTOR_STORE_DATABASE_NAME": config["vector_store"]["database_name"], + "VECTOR_STORE_COLLECTION_NAME": config["vector_store"]["collection_name"], + } + ) + case "pgvector": + if "POSTGRES_URI" not in os.environ: + msg = em.format( + missing_value="POSTGRES_URI", field="vector_store.provider", value="pgvector" + ) + raise ValueError(msg) + case _: + prov = vector_store_config["VECTOR_STORE_PROVIDER"] + 
raise ValueError(f"Unknown vector store provider selected: '{prov}'") + + settings_dict = {} + settings_dict.update( + **api_config, + **auth_config, + **completion_config, + **database_config, + **embedding_config, + **parser_config, + **reranker_config, + **storage_config, + **vector_store_config, + ) return Settings(**settings_dict) diff --git a/core/reranker/bge_reranker.py b/core/reranker/flag_reranker.py similarity index 93% rename from core/reranker/bge_reranker.py rename to core/reranker/flag_reranker.py index 1b6618d..c5c6af4 100644 --- a/core/reranker/bge_reranker.py +++ b/core/reranker/flag_reranker.py @@ -5,8 +5,8 @@ from core.models.chunk import DocumentChunk from core.reranker.base_reranker import BaseReranker -class BGEReranker(BaseReranker): - """BGE reranker implementation using FlagEmbedding""" +class FlagReranker(BaseReranker): + """Reranker implementation using FlagEmbedding""" def __init__( self, @@ -16,7 +16,7 @@ class BGEReranker(BaseReranker): use_fp16: bool = True, device: Optional[str] = None, ): - """Initialize BGE reranker""" + """Initialize flag reranker""" devices = [device] if device else None self.reranker = FlagAutoReranker.from_finetuned( model_name_or_path=model_name, diff --git a/core/tests/unit/test_reranker.py b/core/tests/unit/test_reranker.py index 656e232..fed68d3 100644 --- a/core/tests/unit/test_reranker.py +++ b/core/tests/unit/test_reranker.py @@ -2,7 +2,7 @@ import pytest from typing import List from core.models.chunk import DocumentChunk -from core.reranker.bge_reranker import BGEReranker +from core.reranker.flag_reranker import FlagReranker from core.config import get_settings @@ -35,9 +35,9 @@ def sample_chunks() -> List[DocumentChunk]: @pytest.fixture def reranker(): - """Fixture to create and reuse a BGE reranker instance""" + """Fixture to create and reuse a flag reranker instance""" settings = get_settings() - return BGEReranker( + return FlagReranker( model_name=settings.RERANKER_MODEL, 
device=settings.RERANKER_DEVICE, use_fp16=settings.RERANKER_USE_FP16, diff --git a/databridge.toml b/databridge.toml new file mode 100644 index 0000000..a591604 --- /dev/null +++ b/databridge.toml @@ -0,0 +1,84 @@ +[api] +host = "localhost" +port = 8000 +reload = true + +[auth] +jwt_algorithm = "HS256" + +[completion] +provider = "ollama" +model_name = "llama3.2" +default_max_tokens = "1000" +default_temperature = 0.7 +base_url = "http://localhost:11434" + +# [completion] +# provider = "openai" +# model_name = "gpt4o-mini" +# default_max_tokens = "1000" +# default_temperature = 0.7 + +[database] +provider = "postgres" + +# [database] +# provider = "mongodb" +# database_name = "databridge" +# collection_name = "documents" + +[embedding] +provider = "ollama" +model_name = "nomic-embed-text" +dimensions = 768 +similarity_metric = "cosine" +base_url = "http://localhost:11434" + +# [embedding] +# provider = "openai" +# model_name = "text-embedding-3-small" +# dimensions = 1536 +# similarity_metric = "dotProduct" + + +[parser] +provider = "unstructured" +chunk_size = 1000 +chunk_overlap = 200 +use_unstructured_api = false + +# [parser] +# provider = "combined" | "contextual" +# chunk_size = 1000 +# chunk_overlap = 200 +# use_unstructured_api = false +# frame_sample_rate = 120 + +[reranker] +use_reranker = true +provider = "flag" +model_name = "BAAI/bge-reranker-large" +query_max_length = 256 +passage_max_length = 512 +use_fp16 = true +device = "mps" + +# [reranker] +# use_reranker = false + +[storage] +provider = "local" +storage_path = "../storage" + +# [storage] +# provider = "aws-s3" +# region = "us-east-2" +# bucket_name = "databridge-s3-storage" + +[vector_store] +provider = "pgvector" + +# [vector_store] +# provider = "mongodb" +# database_name = "databridge" +# collection_name = "document_chunks" diff --git a/quick_setup.py b/quick_setup.py index fe9b210..bd4aec0 100644 --- a/quick_setup.py +++ b/quick_setup.py @@ -37,33 +37,33 @@ formatter = 
logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(messag console_handler.setFormatter(formatter) LOGGER.addHandler(console_handler) -# Load configuration from config.toml -config_path = Path("config.toml") +# Load configuration from databridge.toml +config_path = Path("databridge.toml") with open(config_path, "rb") as f: CONFIG = tomli.load(f) - LOGGER.info("Loaded configuration from config.toml") + LOGGER.info("Loaded configuration from databridge.toml") # Extract configuration values -STORAGE_PROVIDER = CONFIG["service"]["components"]["storage"] -DATABASE_PROVIDER = CONFIG["service"]["components"]["database"] -DATABASE_NAME = CONFIG["database"][DATABASE_PROVIDER]["database_name"] +STORAGE_PROVIDER = CONFIG["storage"]["provider"] +DATABASE_PROVIDER = CONFIG["database"]["provider"] # MongoDB specific config -DOCUMENTS_COLLECTION = CONFIG["database"]["mongodb"]["documents_collection"] -CHUNKS_COLLECTION = CONFIG["database"]["mongodb"]["chunks_collection"] -VECTOR_DIMENSIONS = CONFIG["vector_store"]["mongodb"]["dimensions"] -VECTOR_INDEX_NAME = CONFIG["vector_store"]["mongodb"]["index_name"] -SIMILARITY_METRIC = CONFIG["vector_store"]["mongodb"]["similarity_metric"] - -# PostgreSQL specific config -DOCUMENTS_TABLE = CONFIG["database"]["postgres"]["documents_table"] -CHUNKS_TABLE = CONFIG["database"]["postgres"]["chunks_table"] +if DATABASE_PROVIDER == "mongodb": + DATABASE_NAME = CONFIG["database"]["database_name"] + DOCUMENTS_COLLECTION = "documents" + CHUNKS_COLLECTION = "document_chunks" + if CONFIG["vector_store"]["provider"] == "mongodb": + VECTOR_DIMENSIONS = CONFIG["embedding"]["dimensions"] + VECTOR_INDEX_NAME = "vector_index" + SIMILARITY_METRIC = CONFIG["embedding"]["similarity_metric"] # Extract storage-specific configuration -DEFAULT_REGION = CONFIG["storage"]["aws"]["region"] if STORAGE_PROVIDER == "aws-s3" else None -DEFAULT_BUCKET_NAME = ( - CONFIG["storage"]["aws"]["bucket_name"] if STORAGE_PROVIDER == "aws-s3" else None -) +if
STORAGE_PROVIDER == "aws-s3": + DEFAULT_REGION = CONFIG["storage"]["region"] + DEFAULT_BUCKET_NAME = CONFIG["storage"]["bucket_name"] +else: + DEFAULT_REGION = None + DEFAULT_BUCKET_NAME = None def create_s3_bucket(bucket_name, region=DEFAULT_REGION): @@ -248,10 +248,8 @@ def setup_postgres(): LOGGER.info("Created all PostgreSQL tables and indexes") # Create vector index with configuration from settings - table_name = CONFIG["vector_store"]["pgvector"]["table_name"] - index_method = CONFIG["vector_store"]["pgvector"]["index_method"] - index_lists = CONFIG["vector_store"]["pgvector"]["index_lists"] - dimensions = CONFIG["vector_store"]["pgvector"]["dimensions"] + table_name = "document_chunks" # Default table name for pgvector + dimensions = CONFIG["embedding"]["dimensions"] # First, alter the embedding column to be a vector alter_sql = f""" @@ -262,15 +260,14 @@ def setup_postgres(): await conn.execute(text(alter_sql)) LOGGER.info(f"Altered embedding column to be vector({dimensions})") - # Then create the vector index - if index_method == "ivfflat": - index_sql = f""" - CREATE INDEX IF NOT EXISTS vector_idx - ON {table_name} USING ivfflat (embedding vector_l2_ops) - WITH (lists = {index_lists}); - """ - await conn.execute(text(index_sql)) - LOGGER.info(f"Created IVFFlat index on {table_name} with {index_lists} lists") + # Create the vector index + index_sql = f""" + CREATE INDEX IF NOT EXISTS vector_idx + ON {table_name} USING ivfflat (embedding vector_l2_ops) + WITH (lists = 100); + """ + await conn.execute(text(index_sql)) + LOGGER.info(f"Created IVFFlat index on {table_name}") await engine.dispose() LOGGER.info("PostgreSQL setup completed successfully")