add ollama embeddings and test them out

Arnav Agrawal 2024-12-27 12:17:16 +05:30
parent 124ce195c0
commit b883f52a11
6 changed files with 54 additions and 13 deletions

View File

@@ -25,8 +25,9 @@ chunks_collection = "document_chunks"
 # Vector Store Configuration
 [vector_store.mongodb]
-dimensions = 1536
+dimensions = 1536  # 768 for nomic-embed-text
 index_name = "vector_index"
+similarity_metric = "dotProduct"
 
 # Model Configurations
 [models]
@@ -38,6 +39,9 @@ model_name = "gpt-4o-mini"
 default_max_tokens = 1000
 default_temperature = 0.7
 
+[models.ollama]
+base_url = "http://localhost:11434"
+
 # Document Processing
 [processing]
 [processing.text]
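
The `dimensions` value has to agree with whichever embedding model is configured: OpenAI's text-embedding-3-small emits 1536-dimensional vectors, while nomic-embed-text emits 768. A small startup sanity check can catch a mismatch before the vector index is built; this is an illustrative sketch, not part of the commit, and the `config["models"]["embedding"]["model_name"]` key path is an assumption mirroring the completion key path read in core config:

import tomllib  # stdlib TOML parser, Python 3.11+

# Known output sizes for the two models this commit mentions (extend as needed).
KNOWN_EMBEDDING_DIMS = {
    "text-embedding-3-small": 1536,
    "nomic-embed-text": 768,
}

with open("config.toml", "rb") as f:
    config = tomllib.load(f)

model = config["models"]["embedding"]["model_name"]  # key path assumed
dims = config["vector_store"]["mongodb"]["dimensions"]

# Fail fast: a wrong dimension count otherwise only surfaces at query time.
if model in KNOWN_EMBEDDING_DIMS and KNOWN_EMBEDDING_DIMS[model] != dims:
    raise ValueError(f"dimensions is {dims}, but {model} produces {KNOWN_EMBEDDING_DIMS[model]}")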

View File

@@ -6,6 +6,7 @@ from fastapi.middleware.cors import CORSMiddleware
 import jwt
 import logging
 from core.completion.openai_completion import OpenAICompletionModel
+from core.embedding.ollama_embedding_model import OllamaEmbeddingModel
 from core.models.request import (
     IngestTextRequest,
     RetrieveRequest,
@@ -103,6 +104,11 @@ match settings.EMBEDDING_PROVIDER:
             api_key=settings.OPENAI_API_KEY,
             model_name=settings.EMBEDDING_MODEL,
         )
+    case "ollama":
+        embedding_model = OllamaEmbeddingModel(
+            model_name=settings.EMBEDDING_MODEL,
+            base_url=settings.OLLAMA_BASE_URL,
+        )
     case _:
         raise ValueError(
             f"Unsupported embedding provider: {settings.EMBEDDING_PROVIDER}"
@@ -113,6 +119,7 @@ match settings.COMPLETION_PROVIDER:
     case "ollama":
         completion_model = OllamaCompletionModel(
             model_name=settings.COMPLETION_MODEL,
+            base_url=settings.OLLAMA_BASE_URL,
         )
     case "openai":
         completion_model = OpenAICompletionModel(
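
Both providers now share one dispatch, so switching to local inference is purely a settings change, but the Ollama paths also assume a reachable server at OLLAMA_BASE_URL. A hedged reachability probe one could run at startup (not in this commit; uses httpx, an extra dependency, and Ollama's GET /api/tags model-listing route):

import httpx

async def ollama_is_up(base_url: str = "http://localhost:11434") -> bool:
    # /api/tags lists locally pulled models; any 200 response means the
    # Ollama server is reachable. Purely illustrative, not part of the diff.
    try:
        async with httpx.AsyncClient(base_url=base_url, timeout=2.0) as client:
            return (await client.get("/api/tags")).status_code == 200
    except httpx.HTTPError:
        return False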

View File

@@ -3,14 +3,15 @@ from core.completion.base_completion import (
     CompletionRequest,
     CompletionResponse,
 )
-import ollama
+from ollama import AsyncClient
 
 
 class OllamaCompletionModel(BaseCompletionModel):
     """Ollama completion model implementation"""
 
-    def __init__(self, model_name: str):
+    def __init__(self, model_name: str, base_url: str):
         self.model_name = model_name
+        self.client = AsyncClient(host=base_url)
 
     async def complete(self, request: CompletionRequest) -> CompletionResponse:
         """Generate completion using Ollama API"""
@@ -24,7 +25,7 @@ Context:
 Question: {request.query}"""
 
         # Call Ollama API
-        response = ollama.chat(
+        response = await self.client.chat(
             model=self.model_name,
             messages=[{"role": "user", "content": prompt}],
             options={
@@ -36,7 +37,6 @@ Question: {request.query}"""
 
         # Ollama doesn't provide token usage info, so we'll estimate based on characters
        completion_text = response["message"]["content"]
        char_to_token_ratio = 4  # Rough estimate
-        total_chars = len(prompt) + len(completion_text)
        estimated_prompt_tokens = len(prompt) // char_to_token_ratio
        estimated_completion_tokens = len(completion_text) // char_to_token_ratio
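
The four-characters-per-token ratio is a rough average for English text; real tokenizer counts vary by model, so these usage numbers are estimates only. The heuristic in isolation, as a standalone sketch:

def estimate_usage(prompt: str, completion: str, chars_per_token: int = 4) -> dict:
    # Ollama's chat response carries no usage block, so counts are
    # approximated from character length (~4 chars per English token).
    return {
        "prompt_tokens": len(prompt) // chars_per_token,
        "completion_tokens": len(completion) // chars_per_token,
        "total_tokens": (len(prompt) + len(completion)) // chars_per_token,
    }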

View File

@@ -48,6 +48,7 @@ class Settings(BaseSettings):
     COMPLETION_MODEL: str = "llama3.1"
     COMPLETION_MAX_TOKENS: int = 1000
     COMPLETION_TEMPERATURE: float = 0.7
+    OLLAMA_BASE_URL: str = "http://localhost:11434"
 
     # Processing settings
     CHUNK_SIZE: int = 1000
@@ -96,6 +97,7 @@ def get_settings() -> Settings:
         "COMPLETION_MODEL": config["models"]["completion"]["model_name"],
         "COMPLETION_MAX_TOKENS": config["models"]["completion"]["default_max_tokens"],
         "COMPLETION_TEMPERATURE": config["models"]["completion"]["default_temperature"],
+        "OLLAMA_BASE_URL": config["models"]["ollama"]["base_url"],
         # Processing settings
         "CHUNK_SIZE": config["processing"]["text"]["chunk_size"],
         "CHUNK_OVERLAP": config["processing"]["text"]["chunk_overlap"],

View File

@@ -0,0 +1,27 @@
+from typing import List, Union
+from ollama import AsyncClient
+from core.embedding.base_embedding_model import BaseEmbeddingModel
+
+
+class OllamaEmbeddingModel(BaseEmbeddingModel):
+    def __init__(self, model_name: str, base_url: str = "http://localhost:11434"):
+        self.model_name = model_name
+        self.client = AsyncClient(host=base_url)
+
+    async def embed_for_ingestion(
+        self, text: Union[str, List[str]]
+    ) -> List[List[float]]:
+        if isinstance(text, str):
+            text = [text]
+
+        embeddings: List[List[float]] = []
+        for t in text:
+            response = await self.client.embeddings(model=self.model_name, prompt=t)
+            embedding = list(response["embedding"])
+            embeddings.append(embedding)
+
+        return embeddings
+
+    async def embed_for_query(self, text: str) -> List[float]:
+        response = await self.client.embeddings(model=self.model_name, prompt=text)
+        return response["embedding"]
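
A quick way to exercise the new model end to end, matching the commit's "test them out" (assumes a local Ollama server with nomic-embed-text already pulled, e.g. via `ollama pull nomic-embed-text`):

import asyncio

from core.embedding.ollama_embedding_model import OllamaEmbeddingModel

async def main() -> None:
    model = OllamaEmbeddingModel(model_name="nomic-embed-text")
    chunks = await model.embed_for_ingestion(["first chunk", "second chunk"])
    query = await model.embed_for_query("what was in the first chunk?")
    print(len(chunks), len(chunks[0]), len(query))  # 2 768 768

asyncio.run(main())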

View File

@@ -43,13 +43,14 @@ with open(config_path, "rb") as f:
 LOGGER.info("Loaded configuration from config.toml")
 
 # Extract configuration values
-DEFAULT_REGION = CONFIG["aws"]["default_region"]
-DEFAULT_BUCKET_NAME = CONFIG["aws"]["default_bucket_name"]
-DATABASE_NAME = CONFIG["mongodb"]["database_name"]
-DOCUMENTS_COLLECTION = CONFIG["mongodb"]["documents_collection"]
-CHUNKS_COLLECTION = CONFIG["mongodb"]["chunks_collection"]
-VECTOR_DIMENSIONS = CONFIG["mongodb"]["vector"]["dimensions"]
-VECTOR_INDEX_NAME = CONFIG["mongodb"]["vector"]["index_name"]
+DEFAULT_REGION = CONFIG["storage"]["aws"]["region"]
+DEFAULT_BUCKET_NAME = CONFIG["storage"]["aws"]["bucket_name"]
+DATABASE_NAME = CONFIG["database"]["mongodb"]["database_name"]
+DOCUMENTS_COLLECTION = CONFIG["database"]["mongodb"]["documents_collection"]
+CHUNKS_COLLECTION = CONFIG["database"]["mongodb"]["chunks_collection"]
+VECTOR_DIMENSIONS = CONFIG["vector_store"]["mongodb"]["dimensions"]
+VECTOR_INDEX_NAME = CONFIG["vector_store"]["mongodb"]["index_name"]
+SIMILARITY_METRIC = CONFIG["vector_store"]["mongodb"]["similarity_metric"]
 
 
 def create_s3_bucket(bucket_name, region=DEFAULT_REGION):
@@ -142,7 +143,7 @@ def setup_mongodb():
                {
                    "numDimensions": VECTOR_DIMENSIONS,
                    "path": "embedding",
-                   "similarity": "dotProduct",
+                   "similarity": SIMILARITY_METRIC,
                    "type": "vector",
                },
                {"path": "document_id", "type": "filter"},