update configuration style to support easy model editing

This commit is contained in:
Arnav Agrawal 2024-12-27 11:19:07 +05:30
parent 36ccca0332
commit 418054e9a3
4 changed files with 200 additions and 92 deletions

View File

@ -1,32 +1,53 @@
[aws] # Core Service Components
default_region = "us-east-2" [service]
default_bucket_name = "databridge-s3-storage"
[mongodb]
database_name = "DataBridgeTest"
documents_collection = "documents"
chunks_collection = "document_chunks"
[mongodb.vector]
dimensions = 1536
index_name = "vector_index"
[model]
embedding_model = "text-embedding-3-small"
completion_model = "gpt-3.5-turbo"
[document_processing]
chunk_size = 1000
chunk_overlap = 200
default_k = 4
[video_processing]
frame_sample_rate = 120
[server]
host = "localhost" host = "localhost"
port = 8000 port = 8000
reload = false reload = false
[service.components]
storage = "aws-s3"
database = "mongodb"
vector_store = "mongodb"
embedding = "openai"
completion = "openai"
parser = "combined"
# Storage Configuration
[storage.aws]
region = "us-east-2"
bucket_name = "databridge-s3-storage"
# Database Configuration
[database.mongodb]
database_name = "DataBridgeTest"
documents_collection = "documents"
chunks_collection = "document_chunks"
# Vector Store Configuration
[vector_store.mongodb]
dimensions = 1536
index_name = "vector_index"
# Model Configurations
[models]
[models.embedding]
model_name = "text-embedding-3-small"
[models.completion]
model_name = "gpt-4o-mini"
default_max_tokens = 1000
default_temperature = 0.7
# Document Processing
[processing]
[processing.text]
chunk_size = 1000
chunk_overlap = 200
default_k = 4
[processing.video]
frame_sample_rate = 120
# Authentication
[auth] [auth]
jwt_algorithm = "HS256" jwt_algorithm = "HS256"

View File

@ -5,6 +5,7 @@ from fastapi import FastAPI, Form, HTTPException, Depends, Header, UploadFile
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
import jwt import jwt
import logging import logging
from core.completion.openai_completion import OpenAICompletionModel
from core.models.request import ( from core.models.request import (
IngestTextRequest, IngestTextRequest,
RetrieveRequest, RetrieveRequest,
@ -14,14 +15,14 @@ from core.models.documents import Document, DocumentResult, ChunkResult
from core.models.auth import AuthContext, EntityType from core.models.auth import AuthContext, EntityType
from core.parser.combined_parser import CombinedParser from core.parser.combined_parser import CombinedParser
from core.completion.base_completion import CompletionResponse from core.completion.base_completion import CompletionResponse
from core.parser.unstructured_parser import UnstructuredAPIParser
from core.services.document_service import DocumentService from core.services.document_service import DocumentService
from core.config import get_settings from core.config import get_settings
from core.database.mongo_database import MongoDatabase from core.database.mongo_database import MongoDatabase
from core.vector_store.mongo_vector_store import MongoDBAtlasVectorStore from core.vector_store.mongo_vector_store import MongoDBAtlasVectorStore
from core.storage.s3_storage import S3Storage from core.storage.s3_storage import S3Storage
from core.parser.unstructured_parser import UnstructuredAPIParser from core.embedding.openai_embedding_model import OpenAIEmbeddingModel
from core.embedding_model.openai_embedding_model import OpenAIEmbeddingModel from core.completion.ollama_completion import OllamaCompletionModel
from core.completion.openai_completion import OpenAICompletionModel
# Initialize FastAPI app # Initialize FastAPI app
app = FastAPI(title="DataBridge API") app = FastAPI(title="DataBridge API")
@ -39,42 +40,90 @@ app.add_middleware(
# Initialize service # Initialize service
settings = get_settings() settings = get_settings()
# Initialize components # Initialize database
database = MongoDatabase( match settings.DATABASE_PROVIDER:
uri=settings.MONGODB_URI, case "mongodb":
db_name=settings.DATABRIDGE_DB, database = MongoDatabase(
collection_name=settings.DOCUMENTS_COLLECTION, uri=settings.MONGODB_URI,
) db_name=settings.DATABRIDGE_DB,
collection_name=settings.DOCUMENTS_COLLECTION,
)
case _:
raise ValueError(f"Unsupported database provider: {settings.DATABASE_PROVIDER}")
vector_store = MongoDBAtlasVectorStore( # Initialize vector store
uri=settings.MONGODB_URI, match settings.VECTOR_STORE_PROVIDER:
database_name=settings.DATABRIDGE_DB, case "mongodb":
collection_name=settings.CHUNKS_COLLECTION, vector_store = MongoDBAtlasVectorStore(
index_name=settings.VECTOR_INDEX_NAME, uri=settings.MONGODB_URI,
) database_name=settings.DATABRIDGE_DB,
collection_name=settings.CHUNKS_COLLECTION,
index_name=settings.VECTOR_INDEX_NAME,
)
case _:
raise ValueError(
f"Unsupported vector store provider: {settings.VECTOR_STORE_PROVIDER}"
)
storage = S3Storage( # Initialize storage
aws_access_key=settings.AWS_ACCESS_KEY, match settings.STORAGE_PROVIDER:
aws_secret_key=settings.AWS_SECRET_ACCESS_KEY, case "aws-s3":
region_name=settings.AWS_REGION, storage = S3Storage(
default_bucket=settings.S3_BUCKET, aws_access_key=settings.AWS_ACCESS_KEY,
) aws_secret_key=settings.AWS_SECRET_ACCESS_KEY,
region_name=settings.AWS_REGION,
default_bucket=settings.S3_BUCKET,
)
case _:
raise ValueError(f"Unsupported storage provider: {settings.STORAGE_PROVIDER}")
parser = CombinedParser( # Initialize parser
unstructured_api_key=settings.UNSTRUCTURED_API_KEY, match settings.PARSER_PROVIDER:
assemblyai_api_key=settings.ASSEMBLYAI_API_KEY, case "combined":
chunk_size=settings.CHUNK_SIZE, parser = CombinedParser(
chunk_overlap=settings.CHUNK_OVERLAP, unstructured_api_key=settings.UNSTRUCTURED_API_KEY,
frame_sample_rate=settings.FRAME_SAMPLE_RATE, assemblyai_api_key=settings.ASSEMBLYAI_API_KEY,
) chunk_size=settings.CHUNK_SIZE,
chunk_overlap=settings.CHUNK_OVERLAP,
frame_sample_rate=settings.FRAME_SAMPLE_RATE,
)
case "unstructured":
parser = UnstructuredAPIParser(
unstructured_api_key=settings.UNSTRUCTURED_API_KEY,
chunk_size=settings.CHUNK_SIZE,
chunk_overlap=settings.CHUNK_OVERLAP,
)
case _:
raise ValueError(f"Unsupported parser provider: {settings.PARSER_PROVIDER}")
embedding_model = OpenAIEmbeddingModel( # Initialize embedding model
api_key=settings.OPENAI_API_KEY, model_name=settings.EMBEDDING_MODEL match settings.EMBEDDING_PROVIDER:
) case "openai":
embedding_model = OpenAIEmbeddingModel(
api_key=settings.OPENAI_API_KEY,
model_name=settings.EMBEDDING_MODEL,
)
case _:
raise ValueError(
f"Unsupported embedding provider: {settings.EMBEDDING_PROVIDER}"
)
completion_model = OpenAICompletionModel(model_name=settings.COMPLETION_MODEL) # Initialize completion model
match settings.COMPLETION_PROVIDER:
case "ollama":
completion_model = OllamaCompletionModel(
model_name=settings.COMPLETION_MODEL,
)
case "openai":
completion_model = OpenAICompletionModel(
model_name=settings.COMPLETION_MODEL,
)
case _:
raise ValueError(
f"Unsupported completion provider: {settings.COMPLETION_PROVIDER}"
)
# Initialize document service # Initialize document service with configured components
document_service = DocumentService( document_service = DocumentService(
database=database, database=database,
vector_store=vector_store, vector_store=vector_store,

View File

@ -8,7 +8,7 @@ from dotenv import load_dotenv
class Settings(BaseSettings): class Settings(BaseSettings):
"""DataBridge configuration settings.""" """DataBridge configuration settings."""
# Required environment variables # Required environment variables (referenced in config.toml)
MONGODB_URI: str = Field(..., env="MONGODB_URI") MONGODB_URI: str = Field(..., env="MONGODB_URI")
OPENAI_API_KEY: str = Field(..., env="OPENAI_API_KEY") OPENAI_API_KEY: str = Field(..., env="OPENAI_API_KEY")
UNSTRUCTURED_API_KEY: str = Field(..., env="UNSTRUCTURED_API_KEY") UNSTRUCTURED_API_KEY: str = Field(..., env="UNSTRUCTURED_API_KEY")
@ -17,25 +17,47 @@ class Settings(BaseSettings):
AWS_SECRET_ACCESS_KEY: str = Field(..., env="AWS_SECRET_ACCESS_KEY") AWS_SECRET_ACCESS_KEY: str = Field(..., env="AWS_SECRET_ACCESS_KEY")
JWT_SECRET_KEY: str = Field(..., env="JWT_SECRET_KEY") JWT_SECRET_KEY: str = Field(..., env="JWT_SECRET_KEY")
# Values from config.toml with defaults # Service settings
AWS_REGION: str = "us-east-2"
S3_BUCKET: str = "databridge-s3-storage"
DATABRIDGE_DB: str = "databridge"
DOCUMENTS_COLLECTION: str = "documents"
CHUNKS_COLLECTION: str = "document_chunks"
VECTOR_INDEX_NAME: str = "vector_index"
VECTOR_DIMENSIONS: int = 1536
EMBEDDING_MODEL: str = "text-embedding-3-small"
COMPLETION_MODEL: str = "gpt-3.5-turbo"
CHUNK_SIZE: int = 1000
CHUNK_OVERLAP: int = 200
DEFAULT_K: int = 4
HOST: str = "localhost" HOST: str = "localhost"
PORT: int = 8000 PORT: int = 8000
RELOAD: bool = False RELOAD: bool = False
JWT_ALGORITHM: str = "HS256"
# Component selection
STORAGE_PROVIDER: str = "aws-s3"
DATABASE_PROVIDER: str = "mongodb"
VECTOR_STORE_PROVIDER: str = "mongodb"
EMBEDDING_PROVIDER: str = "openai"
COMPLETION_PROVIDER: str = "ollama"
PARSER_PROVIDER: str = "combined"
# Storage settings
AWS_REGION: str = "us-east-2"
S3_BUCKET: str = "databridge-s3-storage"
# Database settings
DATABRIDGE_DB: str = "DataBridgeTest"
DOCUMENTS_COLLECTION: str = "documents"
CHUNKS_COLLECTION: str = "document_chunks"
# Vector store settings
VECTOR_INDEX_NAME: str = "vector_index"
VECTOR_DIMENSIONS: int = 1536
# Model settings
EMBEDDING_MODEL: str = "text-embedding-3-small"
COMPLETION_MODEL: str = "llama3.1"
COMPLETION_MAX_TOKENS: int = 1000
COMPLETION_TEMPERATURE: float = 0.7
# Processing settings
CHUNK_SIZE: int = 1000
CHUNK_OVERLAP: int = 200
DEFAULT_K: int = 4
FRAME_SAMPLE_RATE: int = 120 FRAME_SAMPLE_RATE: int = 120
# Auth settings
JWT_ALGORITHM: str = "HS256"
@lru_cache() @lru_cache()
def get_settings() -> Settings: def get_settings() -> Settings:
@ -48,23 +70,39 @@ def get_settings() -> Settings:
# Map config.toml values to settings # Map config.toml values to settings
settings_dict = { settings_dict = {
"AWS_REGION": config["aws"]["default_region"], # Service settings
"S3_BUCKET": config["aws"]["default_bucket_name"], "HOST": config["service"]["host"],
"DATABRIDGE_DB": config["mongodb"]["database_name"], "PORT": config["service"]["port"],
"DOCUMENTS_COLLECTION": config["mongodb"]["documents_collection"], "RELOAD": config["service"]["reload"],
"CHUNKS_COLLECTION": config["mongodb"]["chunks_collection"], # Component selection
"VECTOR_INDEX_NAME": config["mongodb"]["vector"]["index_name"], "STORAGE_PROVIDER": config["service"]["components"]["storage"],
"VECTOR_DIMENSIONS": config["mongodb"]["vector"]["dimensions"], "DATABASE_PROVIDER": config["service"]["components"]["database"],
"EMBEDDING_MODEL": config["model"]["embedding_model"], "VECTOR_STORE_PROVIDER": config["service"]["components"]["vector_store"],
"COMPLETION_MODEL": config["model"]["completion_model"], "EMBEDDING_PROVIDER": config["service"]["components"]["embedding"],
"CHUNK_SIZE": config["document_processing"]["chunk_size"], "COMPLETION_PROVIDER": config["service"]["components"]["completion"],
"CHUNK_OVERLAP": config["document_processing"]["chunk_overlap"], "PARSER_PROVIDER": config["service"]["components"]["parser"],
"DEFAULT_K": config["document_processing"]["default_k"], # Storage settings
"HOST": config["server"]["host"], "AWS_REGION": config["storage"]["aws"]["region"],
"PORT": config["server"]["port"], "S3_BUCKET": config["storage"]["aws"]["bucket_name"],
"RELOAD": config["server"]["reload"], # Database settings
"DATABRIDGE_DB": config["database"]["mongodb"]["database_name"],
"DOCUMENTS_COLLECTION": config["database"]["mongodb"]["documents_collection"],
"CHUNKS_COLLECTION": config["database"]["mongodb"]["chunks_collection"],
# Vector store settings
"VECTOR_INDEX_NAME": config["vector_store"]["mongodb"]["index_name"],
"VECTOR_DIMENSIONS": config["vector_store"]["mongodb"]["dimensions"],
# Model settings
"EMBEDDING_MODEL": config["models"]["embedding"]["model_name"],
"COMPLETION_MODEL": config["models"]["completion"]["model_name"],
"COMPLETION_MAX_TOKENS": config["models"]["completion"]["default_max_tokens"],
"COMPLETION_TEMPERATURE": config["models"]["completion"]["default_temperature"],
# Processing settings
"CHUNK_SIZE": config["processing"]["text"]["chunk_size"],
"CHUNK_OVERLAP": config["processing"]["text"]["chunk_overlap"],
"DEFAULT_K": config["processing"]["text"]["default_k"],
"FRAME_SAMPLE_RATE": config["processing"]["video"]["frame_sample_rate"],
# Auth settings
"JWT_ALGORITHM": config["auth"]["jwt_algorithm"], "JWT_ALGORITHM": config["auth"]["jwt_algorithm"],
"FRAME_SAMPLE_RATE": config["video_processing"]["frame_sample_rate"],
} }
return Settings(**settings_dict) return Settings(**settings_dict)

View File

@ -14,7 +14,7 @@ from ..models.auth import AuthContext
from core.database.base_database import BaseDatabase from core.database.base_database import BaseDatabase
from core.storage.base_storage import BaseStorage from core.storage.base_storage import BaseStorage
from core.vector_store.base_vector_store import BaseVectorStore from core.vector_store.base_vector_store import BaseVectorStore
from core.embedding_model.base_embedding_model import BaseEmbeddingModel from core.embedding.base_embedding_model import BaseEmbeddingModel
from core.parser.base_parser import BaseParser from core.parser.base_parser import BaseParser
from core.completion.base_completion import BaseCompletionModel from core.completion.base_completion import BaseCompletionModel
from core.completion.base_completion import CompletionRequest, CompletionResponse from core.completion.base_completion import CompletionRequest, CompletionResponse