import os from typing import Literal, Optional from pydantic_settings import BaseSettings from functools import lru_cache import tomli from dotenv import load_dotenv class Settings(BaseSettings): """DataBridge configuration settings.""" # environment variables: JWT_SECRET_KEY: str POSTGRES_URI: Optional[str] = None MONGODB_URI: Optional[str] = None UNSTRUCTURED_API_KEY: Optional[str] = None AWS_ACCESS_KEY: Optional[str] = None AWS_SECRET_ACCESS_KEY: Optional[str] = None OPENAI_API_KEY: Optional[str] = None ANTHROPIC_API_KEY: Optional[str] = None # configuration variables: ## api: HOST: str PORT: int RELOAD: bool ## auth: JWT_ALGORITHM: str ## completion: COMPLETION_PROVIDER: Literal["ollama", "openai"] COMPLETION_MODEL: str COMPLETION_MAX_TOKENS: Optional[str] = None COMPLETION_TEMPERATURE: Optional[float] = None COMPLETION_OLLAMA_BASE_URL: Optional[str] = None ## database DATABASE_PROVIDER: Literal["postgres", "mongodb"] DATABASE_NAME: Optional[str] = None DOCUMENTS_COLLECTION: Optional[str] = None ## embedding EMBEDDING_PROVIDER: Literal["ollama", "openai"] EMBEDDING_MODEL: str VECTOR_DIMENSIONS: int EMBEDDING_SIMILARITY_METRIC: Literal["cosine", "dotProduct"] EMBEDDING_OLLAMA_BASE_URL: Optional[str] = None ## parser PARSER_PROVIDER: Literal["unstructured", "combined", "contextual"] CHUNK_SIZE: int CHUNK_OVERLAP: int USE_UNSTRUCTURED_API: bool FRAME_SAMPLE_RATE: Optional[int] = None ## reranker USE_RERANKING: bool RERANKER_PROVIDER: Optional[Literal["flag"]] = None RERANKER_MODEL: Optional[str] = None RERANKER_QUERY_MAX_LENGTH: Optional[int] = None RERANKER_PASSAGE_MAX_LENGTH: Optional[int] = None RERANKER_USE_FP16: Optional[bool] = None RERANKER_DEVICE: Optional[str] = None ## storage STORAGE_PROVIDER: Literal["local", "aws-s3"] STORAGE_PATH: Optional[str] = None AWS_REGION: Optional[str] = None S3_BUCKET: Optional[str] = None ## vector store VECTOR_STORE_PROVIDER: Literal["pgvector", "mongodb"] VECTOR_STORE_DATABASE_NAME: Optional[str] = None VECTOR_STORE_COLLECTION_NAME: Optional[str] = None @lru_cache() def get_settings() -> Settings: """Get cached settings instance.""" load_dotenv(override=True) # Load config.toml with open("databridge.toml", "rb") as f: config = tomli.load(f) em = "'{missing_value}' needed if '{field}' is set to '{value}'" # load api config api_config = { "HOST": config["api"]["host"], "PORT": int(config["api"]["port"]), "RELOAD": bool(config["api"]["reload"]), } # load auth config auth_config = { "JWT_ALGORITHM": config["auth"]["jwt_algorithm"], "JWT_SECRET_KEY": os.environ["JWT_SECRET_KEY"], } # load completion config completion_config = { "COMPLETION_PROVIDER": config["completion"]["provider"], "COMPLETION_MODEL": config["completion"]["model_name"], } match completion_config["COMPLETION_PROVIDER"]: case "openai" if "OPENAI_API_KEY" in os.environ: completion_config.update({"OPENAI_API_KEY": os.environ["OPENAI_API_KEY"]}) case "openai": msg = em.format( missing_value="OPENAI_API_KEY", field="completion.provider", value="openai" ) raise ValueError(msg) case "ollama" if "base_url" in config["completion"]: completion_config.update( {"COMPLETION_OLLAMA_BASE_URL": config["completion"]["base_url"]} ) case "ollama": msg = em.format(missing_value="base_url", field="completion.provider", value="ollama") raise ValueError(msg) case _: prov = completion_config["COMPLETION_PROVIDER"] raise ValueError(f"Unknown completion provider selected: '{prov}'") # load database config database_config = {"DATABASE_PROVIDER": config["database"]["provider"]} match database_config["DATABASE_PROVIDER"]: case "mongodb": database_config.update( { "DATABASE_NAME": config["database"]["database_name"], "COLLECTION_NAME": config["database"]["collection_name"], } ) case "postgres" if "POSTGRES_URI" in os.environ: database_config.update({"POSTGRES_URI": os.environ["POSTGRES_URI"]}) case "postgres": msg = em.format( missing_value="POSTGRES_URI", field="database.provider", value="postgres" ) raise ValueError(msg) case _: prov = database_config["DATABASE_PROVIDER"] raise ValueError(f"Unknown database provider selected: '{prov}'") # load embedding config embedding_config = { "EMBEDDING_PROVIDER": config["embedding"]["provider"], "EMBEDDING_MODEL": config["embedding"]["model_name"], "VECTOR_DIMENSIONS": config["embedding"]["dimensions"], "EMBEDDING_SIMILARITY_METRIC": config["embedding"]["similarity_metric"], } match embedding_config["EMBEDDING_PROVIDER"]: case "openai" if "OPENAI_API_KEY" in os.environ: embedding_config.update({"OPENAI_API_KEY": os.environ["OPENAI_API_KEY"]}) case "openai": msg = em.format( missing_value="OPENAI_API_KEY", field="embedding.provider", value="openai" ) raise ValueError(msg) case "ollama" if "base_url" in config["embedding"]: embedding_config.update({"EMBEDDING_OLLAMA_BASE_URL": config["embedding"]["base_url"]}) case "ollama": msg = em.format(missing_value="base_url", field="embedding.provider", value="ollama") raise ValueError(msg) case _: prov = embedding_config["EMBEDDING_PROVIDER"] raise ValueError(f"Unknown embedding provider selected: '{prov}'") # load parser config parser_config = { "PARSER_PROVIDER": config["parser"]["provider"], "CHUNK_SIZE": config["parser"]["chunk_size"], "CHUNK_OVERLAP": config["parser"]["chunk_overlap"], "USE_UNSTRUCTURED_API": config["parser"]["use_unstructured_api"], } if parser_config["USE_UNSTRUCTURED_API"] and "UNSTRUCTURED_API_KEY" not in os.environ: msg = em.format( missing_value="UNSTRUCTURED_API_KEY", field="parser.use_unstructured_api", value="true" ) raise ValueError(msg) elif parser_config["USE_UNSTRUCTURED_API"]: parser_config.update({"UNSTRUCTURED_API_KEY": os.environ["UNSTRUCTURED_API_KEY"]}) # load reranker config reranker_config = {"USE_RERANKING": config["reranker"]["use_reranker"]} if reranker_config["USE_RERANKING"]: reranker_config.update( { "RERANKER_PROVIDER": config["reranker"]["provider"], "RERANKER_MODEL": config["reranker"]["model_name"], "RERANKER_QUERY_MAX_LENGTH": config["reranker"]["query_max_length"], "RERANKER_PASSAGE_MAX_LENGTH": config["reranker"]["passage_max_length"], "RERANKER_USE_FP16": config["reranker"]["use_fp16"], "RERANKER_DEVICE": config["reranker"]["device"], } ) # load storage config storage_config = {"STORAGE_PROVIDER": config["storage"]["provider"]} match storage_config["STORAGE_PROVIDER"]: case "local": storage_config.update({"STORAGE_PATH": config["storage"]["storage_path"]}) case "aws-s3" if all( key in os.environ for key in ["AWS_ACCESS_KEY", "AWS_SECRET_ACCESS_KEY"] ): storage_config.update( { "AWS_REGION": config["storage"]["region"], "S3_BUCKET": config["storage"]["bucket_name"], "AWS_ACCESS_KEY": os.environ["AWS_ACCESS_KEY"], "AWS_SECRET_ACCESS_KEY": os.environ["AWS_SECRET_ACCESS_KEY"], } ) case "aws-s3": msg = em.format( missing_value="AWS credentials", field="storage.provider", value="aws-s3" ) raise ValueError(msg) case _: prov = storage_config["STORAGE_PROVIDER"] raise ValueError(f"Unknown storage provider selected: '{prov}'") # load vector store config vector_store_config = {"VECTOR_STORE_PROVIDER": config["vector_store"]["provider"]} match vector_store_config["VECTOR_STORE_PROVIDER"]: case "mongodb": vector_store_config.update( { "VECTOR_STORE_DATABASE_NAME": config["vector_store"]["database_name"], "VECTOR_STORE_COLLECTION_NAME": config["vector_store"]["collection_name"], } ) case "pgvector": if "POSTGRES_URI" not in os.environ: msg = em.format( missing_value="POSTGRES_URI", field="vector_store.provider", value="pgvector" ) raise ValueError(msg) case _: prov = vector_store_config["VECTOR_STORE_PROVIDER"] raise ValueError(f"Unknown vector store provider selected: '{prov}'") settings_dict = {} settings_dict.update( **api_config, **auth_config, **completion_config, **database_config, **embedding_config, **parser_config, **reranker_config, **storage_config, **vector_store_config, ) return Settings(**settings_dict)