diff --git a/config.toml b/config.toml
index b2e70ce..2cd8f6a 100644
--- a/config.toml
+++ b/config.toml
@@ -25,8 +25,9 @@ chunks_collection = "document_chunks"
 
 # Vector Store Configuration
 [vector_store.mongodb]
-dimensions = 1536
+dimensions = 1536 # 768 for nomic-embed-text
 index_name = "vector_index"
+similarity_metric = "dotProduct"
 
 # Model Configurations
 [models]
@@ -38,6 +39,9 @@ model_name = "gpt-4o-mini"
 default_max_tokens = 1000
 default_temperature = 0.7
 
+[models.ollama]
+base_url = "http://localhost:11434"
+
 # Document Processing
 [processing]
 [processing.text]
diff --git a/core/api.py b/core/api.py
index bc07811..dff94f1 100644
--- a/core/api.py
+++ b/core/api.py
@@ -6,6 +6,7 @@ from fastapi.middleware.cors import CORSMiddleware
 import jwt
 import logging
 from core.completion.openai_completion import OpenAICompletionModel
+from core.embedding.ollama_embedding_model import OllamaEmbeddingModel
 from core.models.request import (
     IngestTextRequest,
     RetrieveRequest,
@@ -103,6 +104,11 @@ match settings.EMBEDDING_PROVIDER:
             api_key=settings.OPENAI_API_KEY,
             model_name=settings.EMBEDDING_MODEL,
         )
+    case "ollama":
+        embedding_model = OllamaEmbeddingModel(
+            model_name=settings.EMBEDDING_MODEL,
+            base_url=settings.OLLAMA_BASE_URL,
+        )
     case _:
         raise ValueError(
             f"Unsupported embedding provider: {settings.EMBEDDING_PROVIDER}"
@@ -113,6 +119,7 @@ match settings.COMPLETION_PROVIDER:
     case "ollama":
         completion_model = OllamaCompletionModel(
             model_name=settings.COMPLETION_MODEL,
+            base_url=settings.OLLAMA_BASE_URL,
         )
     case "openai":
         completion_model = OpenAICompletionModel(
diff --git a/core/completion/ollama_completion.py b/core/completion/ollama_completion.py
index b061ac2..2a79923 100644
--- a/core/completion/ollama_completion.py
+++ b/core/completion/ollama_completion.py
@@ -3,14 +3,15 @@ from core.completion.base_completion import (
     CompletionRequest,
     CompletionResponse,
 )
-import ollama
+from ollama import AsyncClient
 
 
 class OllamaCompletionModel(BaseCompletionModel):
     """Ollama completion model implementation"""
 
-    def __init__(self, model_name: str):
+    def __init__(self, model_name: str, base_url: str):
         self.model_name = model_name
+        self.client = AsyncClient(host=base_url)
 
     async def complete(self, request: CompletionRequest) -> CompletionResponse:
         """Generate completion using Ollama API"""
@@ -24,7 +25,7 @@
 Question: {request.query}"""
 
         # Call Ollama API
-        response = ollama.chat(
+        response = await self.client.chat(
             model=self.model_name,
             messages=[{"role": "user", "content": prompt}],
             options={
@@ -36,7 +37,6 @@ Question: {request.query}"""
         # Ollama doesn't provide token usage info, so we'll estimate based on characters
         completion_text = response["message"]["content"]
         char_to_token_ratio = 4 # Rough estimate
-        total_chars = len(prompt) + len(completion_text)
         estimated_prompt_tokens = len(prompt) // char_to_token_ratio
         estimated_completion_tokens = len(completion_text) // char_to_token_ratio
 
diff --git a/core/config.py b/core/config.py
index cfe0c95..352e1d0 100644
--- a/core/config.py
+++ b/core/config.py
@@ -48,6 +48,7 @@ class Settings(BaseSettings):
     COMPLETION_MODEL: str = "llama3.1"
     COMPLETION_MAX_TOKENS: int = 1000
     COMPLETION_TEMPERATURE: float = 0.7
+    OLLAMA_BASE_URL: str = "http://localhost:11434"
 
     # Processing settings
     CHUNK_SIZE: int = 1000
@@ -96,6 +97,7 @@ def get_settings() -> Settings:
         "COMPLETION_MODEL": config["models"]["completion"]["model_name"],
         "COMPLETION_MAX_TOKENS": config["models"]["completion"]["default_max_tokens"],
         "COMPLETION_TEMPERATURE": config["models"]["completion"]["default_temperature"],
+        "OLLAMA_BASE_URL": config["models"]["ollama"]["base_url"],
         # Processing settings
         "CHUNK_SIZE": config["processing"]["text"]["chunk_size"],
         "CHUNK_OVERLAP": config["processing"]["text"]["chunk_overlap"],
diff --git a/core/embedding/ollama_embedding_model.py b/core/embedding/ollama_embedding_model.py
new file mode 100644
index 0000000..95cf936
--- /dev/null
+++ b/core/embedding/ollama_embedding_model.py
@@ -0,0 +1,27 @@
+from typing import List, Union
+from ollama import AsyncClient
+from core.embedding.base_embedding_model import BaseEmbeddingModel
+
+
+class OllamaEmbeddingModel(BaseEmbeddingModel):
+    def __init__(self, model_name, base_url: str = "http://localhost:11434"):
+        self.model_name = model_name
+        self.client = AsyncClient(host=base_url)
+
+    async def embed_for_ingestion(
+        self, text: Union[str, List[str]]
+    ) -> List[List[float]]:
+        if isinstance(text, str):
+            text = [text]
+
+        embeddings: List[List[float]] = []
+        for t in text:
+            response = await self.client.embeddings(model=self.model_name, prompt=t)
+            embedding = list(response["embedding"])
+            embeddings.append(embedding)
+
+        return embeddings
+
+    async def embed_for_query(self, text: str) -> List[float]:
+        response = await self.client.embeddings(model=self.model_name, prompt=text)
+        return response["embedding"]
diff --git a/quick_setup.py b/quick_setup.py
index c36ebf4..85e56ce 100644
--- a/quick_setup.py
+++ b/quick_setup.py
@@ -43,13 +43,14 @@ with open(config_path, "rb") as f:
 LOGGER.info("Loaded configuration from config.toml")
 
 # Extract configuration values
-DEFAULT_REGION = CONFIG["aws"]["default_region"]
-DEFAULT_BUCKET_NAME = CONFIG["aws"]["default_bucket_name"]
-DATABASE_NAME = CONFIG["mongodb"]["database_name"]
-DOCUMENTS_COLLECTION = CONFIG["mongodb"]["documents_collection"]
-CHUNKS_COLLECTION = CONFIG["mongodb"]["chunks_collection"]
-VECTOR_DIMENSIONS = CONFIG["mongodb"]["vector"]["dimensions"]
-VECTOR_INDEX_NAME = CONFIG["mongodb"]["vector"]["index_name"]
+DEFAULT_REGION = CONFIG["storage"]["aws"]["region"]
+DEFAULT_BUCKET_NAME = CONFIG["storage"]["aws"]["bucket_name"]
+DATABASE_NAME = CONFIG["database"]["mongodb"]["database_name"]
+DOCUMENTS_COLLECTION = CONFIG["database"]["mongodb"]["documents_collection"]
+CHUNKS_COLLECTION = CONFIG["database"]["mongodb"]["chunks_collection"]
+VECTOR_DIMENSIONS = CONFIG["vector_store"]["mongodb"]["dimensions"]
+VECTOR_INDEX_NAME = CONFIG["vector_store"]["mongodb"]["index_name"]
+SIMILARITY_METRIC = CONFIG["vector_store"]["mongodb"]["similarity_metric"]
 
 
 def create_s3_bucket(bucket_name, region=DEFAULT_REGION):
@@ -142,7 +143,7 @@ def setup_mongodb():
             {
                 "numDimensions": VECTOR_DIMENSIONS,
                 "path": "embedding",
-                "similarity": "dotProduct",
+                "similarity": SIMILARITY_METRIC,
                 "type": "vector",
             },
             {"path": "document_id", "type": "filter"},