mirror of
https://github.com/james-m-jordan/morphik-core.git
synced 2025-05-09 19:32:38 +00:00
Config improvements (#17)
This commit is contained in:
parent
f5666155c1
commit
f72f6f0249
84
config.toml
84
config.toml
@ -1,84 +0,0 @@
|
||||
# Core Service Components
|
||||
[service]
|
||||
host = "localhost"
|
||||
port = 8000
|
||||
reload = false
|
||||
|
||||
[service.components]
|
||||
storage = "local" # "aws-s3"
|
||||
database = "postgres" # "postgres", "mongodb"
|
||||
vector_store = "pgvector" # "mongodb"
|
||||
embedding = "ollama" # "openai", "ollama"
|
||||
completion = "ollama" # "openai", "ollama"
|
||||
parser = "combined" # "combined", "unstructured", "contextual"
|
||||
reranker = "bge" # "bge"
|
||||
|
||||
# Storage Configuration
|
||||
[storage.local]
|
||||
path = "./storage"
|
||||
|
||||
[storage.aws]
|
||||
region = "us-east-2"
|
||||
bucket_name = "databridge-s3-storage"
|
||||
|
||||
# Database Configuration
|
||||
[database.postgres]
|
||||
database_name = "databridge"
|
||||
documents_table = "documents"
|
||||
chunks_table = "document_chunks"
|
||||
|
||||
[database.mongodb]
|
||||
database_name = "databridge"
|
||||
documents_collection = "documents"
|
||||
chunks_collection = "document_chunks"
|
||||
|
||||
# Vector Store Configuration
|
||||
[vector_store.mongodb]
|
||||
dimensions = 768 # 768 for nomic-embed-text, 1536 for text-embedding-3-small
|
||||
index_name = "vector_index"
|
||||
similarity_metric = "cosine"
|
||||
|
||||
[vector_store.pgvector]
|
||||
dimensions = 768 # 768 for nomic-embed-text, 1536 for text-embedding-3-small
|
||||
table_name = "vector_embeddings"
|
||||
index_method = "ivfflat" # "ivfflat" or "hnsw"
|
||||
index_lists = 100 # Number of lists for ivfflat index
|
||||
probes = 10 # Number of probes for ivfflat search
|
||||
|
||||
# Model Configurations
|
||||
[models]
|
||||
[models.embedding]
|
||||
model_name = "nomic-embed-text" # "text-embedding-3-small", "nomic-embed-text"
|
||||
|
||||
[models.completion]
|
||||
model_name = "llama3.2" # "gpt-4o-mini", "llama3.1", etc.
|
||||
default_max_tokens = 1000
|
||||
default_temperature = 0.7
|
||||
|
||||
[models.ollama]
|
||||
base_url = "http://localhost:11434"
|
||||
|
||||
[models.reranker]
|
||||
model_name = "BAAI/bge-reranker-large" # "BAAI/bge-reranker-v2-gemma", "BAAI/bge-reranker-large"
|
||||
device = "mps" # "cuda:0"; optional: remove this key to run on CPU (TOML has no null value)
|
||||
use_fp16 = true
|
||||
query_max_length = 256
|
||||
passage_max_length = 512
|
||||
|
||||
# Document Processing
|
||||
[processing]
|
||||
[processing.text]
|
||||
chunk_size = 1000
|
||||
chunk_overlap = 200
|
||||
default_k = 4
|
||||
use_reranking = true # Whether to use reranking by default
|
||||
|
||||
[processing.video]
|
||||
frame_sample_rate = 120
|
||||
|
||||
[processing.unstructured]
|
||||
use_api = false
|
||||
|
||||
# Authentication
|
||||
[auth]
|
||||
jwt_algorithm = "HS256"
|
10
core/api.py
10
core/api.py
@ -29,7 +29,7 @@ from core.storage.local_storage import LocalStorage
|
||||
from core.embedding.openai_embedding_model import OpenAIEmbeddingModel
|
||||
from core.completion.ollama_completion import OllamaCompletionModel
|
||||
from core.parser.contextual_parser import ContextualParser
|
||||
from core.reranker.bge_reranker import BGEReranker
|
||||
from core.reranker.flag_reranker import FlagReranker
|
||||
|
||||
# Initialize FastAPI app
|
||||
app = FastAPI(title="DataBridge API")
|
||||
@ -145,7 +145,7 @@ match settings.PARSER_PROVIDER:
|
||||
match settings.EMBEDDING_PROVIDER:
|
||||
case "ollama":
|
||||
embedding_model = OllamaEmbeddingModel(
|
||||
base_url=settings.OLLAMA_BASE_URL,
|
||||
base_url=settings.EMBEDDING_OLLAMA_BASE_URL,
|
||||
model_name=settings.EMBEDDING_MODEL,
|
||||
)
|
||||
case "openai":
|
||||
@ -163,7 +163,7 @@ match settings.COMPLETION_PROVIDER:
|
||||
case "ollama":
|
||||
completion_model = OllamaCompletionModel(
|
||||
model_name=settings.COMPLETION_MODEL,
|
||||
base_url=settings.OLLAMA_BASE_URL,
|
||||
base_url=settings.COMPLETION_OLLAMA_BASE_URL,
|
||||
)
|
||||
case "openai":
|
||||
if not settings.OPENAI_API_KEY:
|
||||
@ -176,8 +176,8 @@ match settings.COMPLETION_PROVIDER:
|
||||
|
||||
# Initialize reranker
|
||||
match settings.RERANKER_PROVIDER:
|
||||
case "bge":
|
||||
reranker = BGEReranker(
|
||||
case "flag":
|
||||
reranker = FlagReranker(
|
||||
model_name=settings.RERANKER_MODEL,
|
||||
device=settings.RERANKER_DEVICE,
|
||||
use_fp16=settings.RERANKER_USE_FP16,
|
||||
|
352
core/config.py
352
core/config.py
@ -1,5 +1,5 @@
|
||||
from typing import Optional
|
||||
from pydantic import Field
|
||||
import os
|
||||
from typing import Literal, Optional
|
||||
from pydantic_settings import BaseSettings
|
||||
from functools import lru_cache
|
||||
import tomli
|
||||
@ -9,74 +9,70 @@ from dotenv import load_dotenv
|
||||
class Settings(BaseSettings):
|
||||
"""DataBridge configuration settings."""
|
||||
|
||||
# Required environment variables (referenced in config.toml)
|
||||
JWT_SECRET_KEY: str = Field(..., env="JWT_SECRET_KEY")
|
||||
MONGODB_URI: Optional[str] = Field(None, env="MONGODB_URI")
|
||||
POSTGRES_URI: Optional[str] = Field(None, env="POSTGRES_URI")
|
||||
# environment variables:
|
||||
JWT_SECRET_KEY: str
|
||||
POSTGRES_URI: Optional[str] = None
|
||||
MONGODB_URI: Optional[str] = None
|
||||
UNSTRUCTURED_API_KEY: Optional[str] = None
|
||||
AWS_ACCESS_KEY: Optional[str] = None
|
||||
AWS_SECRET_ACCESS_KEY: Optional[str] = None
|
||||
OPENAI_API_KEY: Optional[str] = None
|
||||
ANTHROPIC_API_KEY: Optional[str] = None
|
||||
|
||||
UNSTRUCTURED_API_KEY: Optional[str] = Field(None, env="UNSTRUCTURED_API_KEY")
|
||||
AWS_ACCESS_KEY: Optional[str] = Field(None, env="AWS_ACCESS_KEY")
|
||||
AWS_SECRET_ACCESS_KEY: Optional[str] = Field(None, env="AWS_SECRET_ACCESS_KEY")
|
||||
ASSEMBLYAI_API_KEY: Optional[str] = Field(None, env="ASSEMBLYAI_API_KEY")
|
||||
OPENAI_API_KEY: Optional[str] = Field(None, env="OPENAI_API_KEY")
|
||||
ANTHROPIC_API_KEY: Optional[str] = Field(None, env="ANTHROPIC_API_KEY")
|
||||
# configuration variables:
|
||||
## api:
|
||||
HOST: str
|
||||
PORT: int
|
||||
RELOAD: bool
|
||||
|
||||
# Service settings
|
||||
HOST: str = "localhost"
|
||||
PORT: int = 8000
|
||||
RELOAD: bool = False
|
||||
## auth:
|
||||
JWT_ALGORITHM: str
|
||||
|
||||
# Component selection
|
||||
STORAGE_PROVIDER: str = "local"
|
||||
DATABASE_PROVIDER: str = "mongodb"
|
||||
VECTOR_STORE_PROVIDER: str = "mongodb"
|
||||
EMBEDDING_PROVIDER: str = "openai"
|
||||
COMPLETION_PROVIDER: str = "ollama"
|
||||
PARSER_PROVIDER: str = "combined"
|
||||
RERANKER_PROVIDER: str = "bge"
|
||||
## completion:
|
||||
COMPLETION_PROVIDER: Literal["ollama", "openai"]
|
||||
COMPLETION_MODEL: str
|
||||
COMPLETION_MAX_TOKENS: Optional[str] = None
|
||||
COMPLETION_TEMPERATURE: Optional[float] = None
|
||||
COMPLETION_OLLAMA_BASE_URL: Optional[str] = None
|
||||
|
||||
# Storage settings
|
||||
STORAGE_PATH: str = "./storage"
|
||||
AWS_REGION: str = "us-east-2"
|
||||
S3_BUCKET: str = "databridge-s3-storage"
|
||||
## database
|
||||
DATABASE_PROVIDER: Literal["postgres", "mongodb"]
|
||||
DATABASE_NAME: Optional[str] = None
|
||||
DOCUMENTS_COLLECTION: Optional[str] = None
|
||||
|
||||
# Database settings
|
||||
DATABRIDGE_DB: str = "DataBridgeTest"
|
||||
DOCUMENTS_TABLE: str = "documents"
|
||||
CHUNKS_TABLE: str = "document_chunks"
|
||||
DOCUMENTS_COLLECTION: str = "documents"
|
||||
CHUNKS_COLLECTION: str = "document_chunks"
|
||||
## embedding
|
||||
EMBEDDING_PROVIDER: Literal["ollama", "openai"]
|
||||
EMBEDDING_MODEL: str
|
||||
VECTOR_DIMENSIONS: int
|
||||
EMBEDDING_SIMILARITY_METRIC: Literal["cosine", "dotProduct"]
|
||||
EMBEDDING_OLLAMA_BASE_URL: Optional[str] = None
|
||||
|
||||
# Vector store settings
|
||||
VECTOR_INDEX_NAME: str = "vector_index"
|
||||
VECTOR_DIMENSIONS: int = 1536
|
||||
PGVECTOR_TABLE_NAME: str = "vector_embeddings"
|
||||
PGVECTOR_INDEX_METHOD: str = "ivfflat"
|
||||
PGVECTOR_INDEX_LISTS: int = 100
|
||||
PGVECTOR_PROBES: int = 10
|
||||
## parser
|
||||
PARSER_PROVIDER: Literal["unstructured", "combined", "contextual"]
|
||||
CHUNK_SIZE: int
|
||||
CHUNK_OVERLAP: int
|
||||
USE_UNSTRUCTURED_API: bool
|
||||
FRAME_SAMPLE_RATE: Optional[int] = None
|
||||
|
||||
# Model settings
|
||||
EMBEDDING_MODEL: str = "text-embedding-3-small"
|
||||
COMPLETION_MODEL: str = "llama3.1"
|
||||
COMPLETION_MAX_TOKENS: int = 1000
|
||||
COMPLETION_TEMPERATURE: float = 0.7
|
||||
OLLAMA_BASE_URL: str = "http://localhost:11434"
|
||||
RERANKER_MODEL: str = "BAAI/bge-reranker-v2-gemma"
|
||||
## reranker
|
||||
USE_RERANKING: bool
|
||||
RERANKER_PROVIDER: Optional[Literal["flag"]] = None
|
||||
RERANKER_MODEL: Optional[str] = None
|
||||
RERANKER_QUERY_MAX_LENGTH: Optional[int] = None
|
||||
RERANKER_PASSAGE_MAX_LENGTH: Optional[int] = None
|
||||
RERANKER_USE_FP16: Optional[bool] = None
|
||||
RERANKER_DEVICE: Optional[str] = None
|
||||
RERANKER_USE_FP16: bool = True
|
||||
RERANKER_QUERY_MAX_LENGTH: int = 256
|
||||
RERANKER_PASSAGE_MAX_LENGTH: int = 512
|
||||
|
||||
# Processing settings
|
||||
CHUNK_SIZE: int = 1000
|
||||
CHUNK_OVERLAP: int = 200
|
||||
DEFAULT_K: int = 4
|
||||
FRAME_SAMPLE_RATE: int = 120
|
||||
USE_UNSTRUCTURED_API: bool = False
|
||||
USE_RERANKING: bool = True
|
||||
## storage
|
||||
STORAGE_PROVIDER: Literal["local", "aws-s3"]
|
||||
STORAGE_PATH: Optional[str] = None
|
||||
AWS_REGION: Optional[str] = None
|
||||
S3_BUCKET: Optional[str] = None
|
||||
|
||||
# Auth settings
|
||||
JWT_ALGORITHM: str = "HS256"
|
||||
## vector store
|
||||
VECTOR_STORE_PROVIDER: Literal["pgvector", "mongodb"]
|
||||
VECTOR_STORE_DATABASE_NAME: Optional[str] = None
|
||||
VECTOR_STORE_COLLECTION_NAME: Optional[str] = None
|
||||
|
||||
|
||||
@lru_cache()
|
||||
@ -85,76 +81,176 @@ def get_settings() -> Settings:
|
||||
load_dotenv(override=True)
|
||||
|
||||
# Load the TOML configuration file
|
||||
with open("config.toml", "rb") as f:
|
||||
with open("databridge.toml", "rb") as f:
|
||||
config = tomli.load(f)
|
||||
|
||||
# Map config.toml values to settings
|
||||
settings_dict = {
|
||||
# Service settings
|
||||
"HOST": config["service"]["host"],
|
||||
"PORT": config["service"]["port"],
|
||||
"RELOAD": config["service"]["reload"],
|
||||
# Component selection
|
||||
"STORAGE_PROVIDER": config["service"]["components"]["storage"],
|
||||
"DATABASE_PROVIDER": config["service"]["components"]["database"],
|
||||
"VECTOR_STORE_PROVIDER": config["service"]["components"]["vector_store"],
|
||||
"EMBEDDING_PROVIDER": config["service"]["components"]["embedding"],
|
||||
"COMPLETION_PROVIDER": config["service"]["components"]["completion"],
|
||||
"PARSER_PROVIDER": config["service"]["components"]["parser"],
|
||||
"RERANKER_PROVIDER": config["service"]["components"]["reranker"],
|
||||
# Storage settings
|
||||
"STORAGE_PATH": config["storage"]["local"]["path"],
|
||||
"AWS_REGION": config["storage"]["aws"]["region"],
|
||||
"S3_BUCKET": config["storage"]["aws"]["bucket_name"],
|
||||
# Database settings
|
||||
"DATABRIDGE_DB": config["database"][config["service"]["components"]["database"]][
|
||||
"database_name"
|
||||
],
|
||||
"DOCUMENTS_TABLE": config["database"]
|
||||
.get("postgres", {})
|
||||
.get("documents_table", "documents"),
|
||||
"CHUNKS_TABLE": config["database"]
|
||||
.get("postgres", {})
|
||||
.get("chunks_table", "document_chunks"),
|
||||
"DOCUMENTS_COLLECTION": config["database"]
|
||||
.get("mongodb", {})
|
||||
.get("documents_collection", "documents"),
|
||||
"CHUNKS_COLLECTION": config["database"]
|
||||
.get("mongodb", {})
|
||||
.get("chunks_collection", "document_chunks"),
|
||||
# Vector store settings
|
||||
"VECTOR_INDEX_NAME": config["vector_store"]["mongodb"]["index_name"],
|
||||
"VECTOR_DIMENSIONS": config["vector_store"][
|
||||
config["service"]["components"]["vector_store"]
|
||||
]["dimensions"],
|
||||
"PGVECTOR_TABLE_NAME": config["vector_store"]
|
||||
.get("pgvector", {})
|
||||
.get("table_name", "vector_embeddings"),
|
||||
"PGVECTOR_INDEX_METHOD": config["vector_store"]
|
||||
.get("pgvector", {})
|
||||
.get("index_method", "ivfflat"),
|
||||
"PGVECTOR_INDEX_LISTS": config["vector_store"].get("pgvector", {}).get("index_lists", 100),
|
||||
"PGVECTOR_PROBES": config["vector_store"].get("pgvector", {}).get("probes", 10),
|
||||
# Model settings
|
||||
"EMBEDDING_MODEL": config["models"]["embedding"]["model_name"],
|
||||
"COMPLETION_MODEL": config["models"]["completion"]["model_name"],
|
||||
"COMPLETION_MAX_TOKENS": config["models"]["completion"]["default_max_tokens"],
|
||||
"COMPLETION_TEMPERATURE": config["models"]["completion"]["default_temperature"],
|
||||
"OLLAMA_BASE_URL": config["models"]["ollama"]["base_url"],
|
||||
"RERANKER_MODEL": config["models"]["reranker"]["model_name"],
|
||||
"RERANKER_DEVICE": config["models"]["reranker"].get("device"),
|
||||
"RERANKER_USE_FP16": config["models"]["reranker"].get("use_fp16", True),
|
||||
"RERANKER_QUERY_MAX_LENGTH": config["models"]["reranker"].get("query_max_length", 256),
|
||||
"RERANKER_PASSAGE_MAX_LENGTH": config["models"]["reranker"].get("passage_max_length", 512),
|
||||
# Processing settings
|
||||
"CHUNK_SIZE": config["processing"]["text"]["chunk_size"],
|
||||
"CHUNK_OVERLAP": config["processing"]["text"]["chunk_overlap"],
|
||||
"DEFAULT_K": config["processing"]["text"]["default_k"],
|
||||
"USE_RERANKING": config["processing"]["text"]["use_reranking"],
|
||||
"FRAME_SAMPLE_RATE": config["processing"]["video"]["frame_sample_rate"],
|
||||
"USE_UNSTRUCTURED_API": config["processing"]["unstructured"]["use_api"],
|
||||
# Auth settings
|
||||
"JWT_ALGORITHM": config["auth"]["jwt_algorithm"],
|
||||
em = "'{missing_value}' needed if '{field}' is set to '{value}'"
|
||||
# load api config
|
||||
api_config = {
|
||||
"HOST": config["api"]["host"],
|
||||
"PORT": int(config["api"]["port"]),
|
||||
"RELOAD": bool(config["api"]["reload"]),
|
||||
}
|
||||
|
||||
# load auth config
|
||||
auth_config = {
|
||||
"JWT_ALGORITHM": config["auth"]["jwt_algorithm"],
|
||||
"JWT_SECRET_KEY": os.environ["JWT_SECRET_KEY"],
|
||||
}
|
||||
|
||||
# load completion config
|
||||
completion_config = {
|
||||
"COMPLETION_PROVIDER": config["completion"]["provider"],
|
||||
"COMPLETION_MODEL": config["completion"]["model_name"],
|
||||
}
|
||||
match completion_config["COMPLETION_PROVIDER"]:
|
||||
case "openai" if "OPENAI_API_KEY" in os.environ:
|
||||
completion_config.update({"OPENAI_API_KEY": os.environ["OPENAI_API_KEY"]})
|
||||
case "openai":
|
||||
msg = em.format(
|
||||
missing_value="OPENAI_API_KEY", field="completion.provider", value="openai"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
case "ollama" if "base_url" in config["completion"]:
|
||||
completion_config.update(
|
||||
{"COMPLETION_OLLAMA_BASE_URL": config["completion"]["base_url"]}
|
||||
)
|
||||
case "ollama":
|
||||
msg = em.format(missing_value="base_url", field="completion.provider", value="ollama")
|
||||
raise ValueError(msg)
|
||||
case _:
|
||||
prov = completion_config["COMPLETION_PROVIDER"]
|
||||
raise ValueError(f"Unknown completion provider selected: '{prov}'")
|
||||
|
||||
# load database config
|
||||
database_config = {"DATABASE_PROVIDER": config["database"]["provider"]}
|
||||
match database_config["DATABASE_PROVIDER"]:
|
||||
case "mongodb":
|
||||
database_config.update(
|
||||
{
|
||||
"DATABASE_NAME": config["database"]["database_name"],
|
||||
"COLLECTION_NAME": config["database"]["collection_name"],
|
||||
}
|
||||
)
|
||||
case "postgres" if "POSTGRES_URI" in os.environ:
|
||||
database_config.update({"POSTGRES_URI": os.environ["POSTGRES_URI"]})
|
||||
case "postgres":
|
||||
msg = em.format(
|
||||
missing_value="POSTGRES_URI", field="database.provider", value="postgres"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
case _:
|
||||
prov = database_config["DATABASE_PROVIDER"]
|
||||
raise ValueError(f"Unknown database provider selected: '{prov}'")
|
||||
|
||||
# load embedding config
|
||||
embedding_config = {
|
||||
"EMBEDDING_PROVIDER": config["embedding"]["provider"],
|
||||
"EMBEDDING_MODEL": config["embedding"]["model_name"],
|
||||
"VECTOR_DIMENSIONS": config["embedding"]["dimensions"],
|
||||
"EMBEDDING_SIMILARITY_METRIC": config["embedding"]["similarity_metric"],
|
||||
}
|
||||
match embedding_config["EMBEDDING_PROVIDER"]:
|
||||
case "openai" if "OPENAI_API_KEY" in os.environ:
|
||||
embedding_config.update({"OPENAI_API_KEY": os.environ["OPENAI_API_KEY"]})
|
||||
case "openai":
|
||||
msg = em.format(
|
||||
missing_value="OPENAI_API_KEY", field="embedding.provider", value="openai"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
case "ollama" if "base_url" in config["embedding"]:
|
||||
embedding_config.update({"EMBEDDING_OLLAMA_BASE_URL": config["embedding"]["base_url"]})
|
||||
case "ollama":
|
||||
msg = em.format(missing_value="base_url", field="embedding.provider", value="ollama")
|
||||
raise ValueError(msg)
|
||||
case _:
|
||||
prov = embedding_config["EMBEDDING_PROVIDER"]
|
||||
raise ValueError(f"Unknown embedding provider selected: '{prov}'")
|
||||
|
||||
# load parser config
|
||||
parser_config = {
|
||||
"PARSER_PROVIDER": config["parser"]["provider"],
|
||||
"CHUNK_SIZE": config["parser"]["chunk_size"],
|
||||
"CHUNK_OVERLAP": config["parser"]["chunk_overlap"],
|
||||
"USE_UNSTRUCTURED_API": config["parser"]["use_unstructured_api"],
|
||||
}
|
||||
if parser_config["USE_UNSTRUCTURED_API"] and "UNSTRUCTURED_API_KEY" not in os.environ:
|
||||
msg = em.format(
|
||||
missing_value="UNSTRUCTURED_API_KEY", field="parser.use_unstructured_api", value="true"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
elif parser_config["USE_UNSTRUCTURED_API"]:
|
||||
parser_config.update({"UNSTRUCTURED_API_KEY": os.environ["UNSTRUCTURED_API_KEY"]})
|
||||
|
||||
# load reranker config
|
||||
reranker_config = {"USE_RERANKING": config["reranker"]["use_reranker"]}
|
||||
if reranker_config["USE_RERANKING"]:
|
||||
reranker_config.update(
|
||||
{
|
||||
"RERANKER_PROVIDER": config["reranker"]["provider"],
|
||||
"RERANKER_MODEL": config["reranker"]["model_name"],
|
||||
"RERANKER_QUERY_MAX_LENGTH": config["reranker"]["query_max_length"],
|
||||
"RERANKER_PASSAGE_MAX_LENGTH": config["reranker"]["passage_max_length"],
|
||||
"RERANKER_USE_FP16": config["reranker"]["use_fp16"],
|
||||
"RERANKER_DEVICE": config["reranker"]["device"],
|
||||
}
|
||||
)
|
||||
|
||||
# load storage config
|
||||
storage_config = {"STORAGE_PROVIDER": config["storage"]["provider"]}
|
||||
match storage_config["STORAGE_PROVIDER"]:
|
||||
case "local":
|
||||
storage_config.update({"STORAGE_PATH": config["storage"]["storage_path"]})
|
||||
case "aws-s3" if all(
|
||||
key in os.environ for key in ["AWS_ACCESS_KEY", "AWS_SECRET_ACCESS_KEY"]
|
||||
):
|
||||
storage_config.update(
|
||||
{
|
||||
"AWS_REGION": config["storage"]["region"],
|
||||
"S3_BUCKET": config["storage"]["bucket_name"],
|
||||
"AWS_ACCESS_KEY": os.environ["AWS_ACCESS_KEY"],
|
||||
"AWS_SECRET_ACCESS_KEY": os.environ["AWS_SECRET_ACCESS_KEY"],
|
||||
}
|
||||
)
|
||||
case "aws-s3":
|
||||
msg = em.format(
|
||||
missing_value="AWS credentials", field="storage.provider", value="aws-s3"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
case _:
|
||||
prov = storage_config["STORAGE_PROVIDER"]
|
||||
raise ValueError(f"Unknown storage provider selected: '{prov}'")
|
||||
|
||||
# load vector store config
|
||||
vector_store_config = {"VECTOR_STORE_PROVIDER": config["vector_store"]["provider"]}
|
||||
match vector_store_config["VECTOR_STORE_PROVIDER"]:
|
||||
case "mongodb":
|
||||
vector_store_config.update(
|
||||
{
|
||||
"VECTOR_STORE_DATABASE_NAME": config["vector_store"]["database_name"],
|
||||
"VECTOR_STORE_COLLECTION_NAME": config["vector_store"]["collection_name"],
|
||||
}
|
||||
)
|
||||
case "pgvector":
|
||||
if "POSTGRES_URI" not in os.environ:
|
||||
msg = em.format(
|
||||
missing_value="POSTGRES_URI", field="vector_store.provider", value="pgvector"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
case _:
|
||||
prov = vector_store_config["VECTOR_STORE_PROVIDER"]
|
||||
raise ValueError(f"Unknown vector store provider selected: '{prov}'")
|
||||
|
||||
settings_dict = {}
|
||||
settings_dict.update(
|
||||
**api_config,
|
||||
**auth_config,
|
||||
**completion_config,
|
||||
**database_config,
|
||||
**embedding_config,
|
||||
**parser_config,
|
||||
**reranker_config,
|
||||
**storage_config,
|
||||
**vector_store_config,
|
||||
)
|
||||
return Settings(**settings_dict)
|
||||
|
@ -5,8 +5,8 @@ from core.models.chunk import DocumentChunk
|
||||
from core.reranker.base_reranker import BaseReranker
|
||||
|
||||
|
||||
class BGEReranker(BaseReranker):
|
||||
"""BGE reranker implementation using FlagEmbedding"""
|
||||
class FlagReranker(BaseReranker):
|
||||
"""Reranker implementation using FlagEmbedding"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -16,7 +16,7 @@ class BGEReranker(BaseReranker):
|
||||
use_fp16: bool = True,
|
||||
device: Optional[str] = None,
|
||||
):
|
||||
"""Initialize BGE reranker"""
|
||||
"""Initialize flag reranker"""
|
||||
devices = [device] if device else None
|
||||
self.reranker = FlagAutoReranker.from_finetuned(
|
||||
model_name_or_path=model_name,
|
@ -2,7 +2,7 @@ import pytest
|
||||
from typing import List
|
||||
|
||||
from core.models.chunk import DocumentChunk
|
||||
from core.reranker.bge_reranker import BGEReranker
|
||||
from core.reranker.flag_reranker import FlagReranker
|
||||
from core.config import get_settings
|
||||
|
||||
|
||||
@ -35,9 +35,9 @@ def sample_chunks() -> List[DocumentChunk]:
|
||||
|
||||
@pytest.fixture
|
||||
def reranker():
|
||||
"""Fixture to create and reuse a BGE reranker instance"""
|
||||
"""Fixture to create and reuse a flag reranker instance"""
|
||||
settings = get_settings()
|
||||
return BGEReranker(
|
||||
return FlagReranker(
|
||||
model_name=settings.RERANKER_MODEL,
|
||||
device=settings.RERANKER_DEVICE,
|
||||
use_fp16=settings.RERANKER_USE_FP16,
|
||||
|
84
databridge.toml
Normal file
84
databridge.toml
Normal file
@ -0,0 +1,84 @@
|
||||
[api]
|
||||
host = "localhost"
|
||||
port = 8000
|
||||
reload = true
|
||||
|
||||
[auth]
|
||||
jwt_algorithm = "HS256"
|
||||
|
||||
[completion]
|
||||
provider = "ollama"
|
||||
model_name = "llama3.2"
|
||||
default_max_tokens = "1000"
|
||||
default_temperature = 0.7
|
||||
base_url = "http://localhost:11434"
|
||||
|
||||
# [completion]
|
||||
# provider = "openai"
|
||||
# model_name = "gpt-4o-mini"
|
||||
# default_max_tokens = "1000"
|
||||
# default_temperature = 0.7
|
||||
|
||||
[database]
|
||||
provider = "postgres"
|
||||
|
||||
# [database]
|
||||
# provider = "mongodb"
|
||||
# database_name = "databridge"
|
||||
# collection_name = "documents"
|
||||
|
||||
[embedding]
|
||||
provider = "ollama"
|
||||
model_name = "nomic-embed-text"
|
||||
dimensions = 768
|
||||
similarity_metric = "cosine"
|
||||
base_url = "http://localhost:11434"
|
||||
|
||||
# [embedding]
|
||||
# provider = "openai"
|
||||
# model_name = "text-embedding-3-small"
|
||||
# dimensions = 1536
|
||||
# similarity_metric = "dotProduct"
|
||||
|
||||
|
||||
[parser]
|
||||
provider = "unstructured"
|
||||
chunk_size = 1000
|
||||
chunk_overlap = 200
|
||||
use_unstructured_api = false
|
||||
|
||||
# [parser]
|
||||
# provider = "combined" | "contextual"
|
||||
# chunk_size = 1000
|
||||
# chunk_overlap = 200
|
||||
# use_unstructured_api = false
|
||||
# frame_sample_rate = 120
|
||||
|
||||
[reranker]
|
||||
use_reranker = true
|
||||
provider = "flag"
|
||||
model_name = "BAAI/bge-reranker-large"
|
||||
query_max_length = 256
|
||||
passage_max_length = 512
|
||||
use_fp16 = true
|
||||
device = "mps"
|
||||
|
||||
# [reranker]
|
||||
# use_reranker = false
|
||||
|
||||
[storage]
|
||||
provider = "local"
|
||||
storage_path = "../storage"
|
||||
|
||||
# [storage]
|
||||
# provider = "aws-s3"
|
||||
# region = "us-east-2"
|
||||
# bucket_name = "databridge-s3-storage"
|
||||
|
||||
[vector_store]
|
||||
provider = "pgvector"
|
||||
|
||||
# [vector_store]
|
||||
# provider = "mongodb"
|
||||
# database_name = "databridge"
|
||||
# collection_name = "document_chunks"
|
@ -37,33 +37,33 @@ formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(messag
|
||||
console_handler.setFormatter(formatter)
|
||||
LOGGER.addHandler(console_handler)
|
||||
|
||||
# Load configuration from config.toml
|
||||
config_path = Path("config.toml")
|
||||
# Load configuration from databridge.toml
|
||||
config_path = Path("databridge.toml")
|
||||
with open(config_path, "rb") as f:
|
||||
CONFIG = tomli.load(f)
|
||||
LOGGER.info("Loaded configuration from config.toml")
|
||||
LOGGER.info("Loaded configuration from databridge.toml")
|
||||
|
||||
# Extract configuration values
|
||||
STORAGE_PROVIDER = CONFIG["service"]["components"]["storage"]
|
||||
DATABASE_PROVIDER = CONFIG["service"]["components"]["database"]
|
||||
DATABASE_NAME = CONFIG["database"][DATABASE_PROVIDER]["database_name"]
|
||||
STORAGE_PROVIDER = CONFIG["storage"]["provider"]
|
||||
DATABASE_PROVIDER = CONFIG["database"]["provider"]
|
||||
|
||||
# MongoDB specific config
|
||||
DOCUMENTS_COLLECTION = CONFIG["database"]["mongodb"]["documents_collection"]
|
||||
CHUNKS_COLLECTION = CONFIG["database"]["mongodb"]["chunks_collection"]
|
||||
VECTOR_DIMENSIONS = CONFIG["vector_store"]["mongodb"]["dimensions"]
|
||||
VECTOR_INDEX_NAME = CONFIG["vector_store"]["mongodb"]["index_name"]
|
||||
SIMILARITY_METRIC = CONFIG["vector_store"]["mongodb"]["similarity_metric"]
|
||||
|
||||
# PostgreSQL specific config
|
||||
DOCUMENTS_TABLE = CONFIG["database"]["postgres"]["documents_table"]
|
||||
CHUNKS_TABLE = CONFIG["database"]["postgres"]["chunks_table"]
|
||||
if "mongodb" in CONFIG["database"]:
|
||||
DATABASE_NAME = CONFIG["database"]["mongodb"]["database_name"]
|
||||
DOCUMENTS_COLLECTION = "documents"
|
||||
CHUNKS_COLLECTION = "document_chunks"
|
||||
if "mongodb" in CONFIG["vector_store"]:
|
||||
VECTOR_DIMENSIONS = CONFIG["embedding"]["dimensions"]
|
||||
VECTOR_INDEX_NAME = "vector_index"
|
||||
SIMILARITY_METRIC = CONFIG["embedding"]["similarity_metric"]
|
||||
|
||||
# Extract storage-specific configuration
|
||||
DEFAULT_REGION = CONFIG["storage"]["aws"]["region"] if STORAGE_PROVIDER == "aws-s3" else None
|
||||
DEFAULT_BUCKET_NAME = (
|
||||
CONFIG["storage"]["aws"]["bucket_name"] if STORAGE_PROVIDER == "aws-s3" else None
|
||||
)
|
||||
if STORAGE_PROVIDER == "aws-s3":
|
||||
DEFAULT_REGION = CONFIG["storage"]["region"]
|
||||
DEFAULT_BUCKET_NAME = CONFIG["storage"]["bucket_name"]
|
||||
else:
|
||||
DEFAULT_REGION = None
|
||||
DEFAULT_BUCKET_NAME = None
|
||||
|
||||
|
||||
def create_s3_bucket(bucket_name, region=DEFAULT_REGION):
|
||||
@ -248,10 +248,8 @@ def setup_postgres():
|
||||
LOGGER.info("Created all PostgreSQL tables and indexes")
|
||||
|
||||
# Create vector index with configuration from settings
|
||||
table_name = CONFIG["vector_store"]["pgvector"]["table_name"]
|
||||
index_method = CONFIG["vector_store"]["pgvector"]["index_method"]
|
||||
index_lists = CONFIG["vector_store"]["pgvector"]["index_lists"]
|
||||
dimensions = CONFIG["vector_store"]["pgvector"]["dimensions"]
|
||||
table_name = "document_chunks" # Default table name for pgvector
|
||||
dimensions = CONFIG["embedding"]["dimensions"]
|
||||
|
||||
# First, alter the embedding column to be a vector
|
||||
alter_sql = f"""
|
||||
@ -262,15 +260,14 @@ def setup_postgres():
|
||||
await conn.execute(text(alter_sql))
|
||||
LOGGER.info(f"Altered embedding column to be vector({dimensions})")
|
||||
|
||||
# Then create the vector index
|
||||
if index_method == "ivfflat":
|
||||
index_sql = f"""
|
||||
CREATE INDEX IF NOT EXISTS vector_idx
|
||||
ON {table_name} USING ivfflat (embedding vector_l2_ops)
|
||||
WITH (lists = {index_lists});
|
||||
"""
|
||||
await conn.execute(text(index_sql))
|
||||
LOGGER.info(f"Created IVFFlat index on {table_name} with {index_lists} lists")
|
||||
# Create the vector index
|
||||
index_sql = f"""
|
||||
CREATE INDEX IF NOT EXISTS vector_idx
|
||||
ON {table_name} USING ivfflat (embedding vector_l2_ops)
|
||||
WITH (lists = 100);
|
||||
"""
|
||||
await conn.execute(text(index_sql))
|
||||
LOGGER.info(f"Created IVFFlat index on {table_name}")
|
||||
|
||||
await engine.dispose()
|
||||
LOGGER.info("PostgreSQL setup completed successfully")
|
||||
|
Loading…
x
Reference in New Issue
Block a user