add Ollama embeddings and test them out

Arnav Agrawal 2024-12-27 12:17:16 +05:30
parent 124ce195c0
commit b883f52a11
6 changed files with 54 additions and 13 deletions

View File

@@ -25,8 +25,9 @@ chunks_collection = "document_chunks"
 # Vector Store Configuration
 [vector_store.mongodb]
-dimensions = 1536
+dimensions = 1536 # 768 for nomic-embed-text
 index_name = "vector_index"
+similarity_metric = "dotProduct"

 # Model Configurations
 [models]
@@ -38,6 +39,9 @@ model_name = "gpt-4o-mini"
 default_max_tokens = 1000
 default_temperature = 0.7

+[models.ollama]
+base_url = "http://localhost:11434"
+
 # Document Processing
 [processing]
 [processing.text]
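The comment on `dimensions` is the important coupling here: the vector index width must match whatever embedding model is configured (1536 for OpenAI's embedding models, 768 for `nomic-embed-text` on Ollama). Below is a minimal sketch of a startup check for that invariant; the `EXPECTED_DIMENSIONS` table and `validate_config` helper are hypothetical, and it assumes the embedding model name lives under `models.embedding.model_name` the way the completion settings do.

# Hypothetical startup check; the config keys mirror config.toml above,
# but EXPECTED_DIMENSIONS and validate_config are illustrative only.
import tomllib

EXPECTED_DIMENSIONS = {
    "text-embedding-3-small": 1536,  # OpenAI
    "nomic-embed-text": 768,         # Ollama
}

def validate_config(path: str = "config.toml") -> None:
    with open(path, "rb") as f:
        config = tomllib.load(f)
    model = config["models"]["embedding"]["model_name"]  # assumed key
    dims = config["vector_store"]["mongodb"]["dimensions"]
    expected = EXPECTED_DIMENSIONS.get(model)
    if expected is not None and expected != dims:
        raise ValueError(
            f"{model} produces {expected}-dim vectors, but the index expects {dims}"
        )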

View File

@@ -6,6 +6,7 @@ from fastapi.middleware.cors import CORSMiddleware
 import jwt
 import logging
 from core.completion.openai_completion import OpenAICompletionModel
+from core.embedding.ollama_embedding_model import OllamaEmbeddingModel
 from core.models.request import (
     IngestTextRequest,
     RetrieveRequest,
@@ -103,6 +104,11 @@ match settings.EMBEDDING_PROVIDER:
             api_key=settings.OPENAI_API_KEY,
             model_name=settings.EMBEDDING_MODEL,
         )
+    case "ollama":
+        embedding_model = OllamaEmbeddingModel(
+            model_name=settings.EMBEDDING_MODEL,
+            base_url=settings.OLLAMA_BASE_URL,
+        )
     case _:
         raise ValueError(
             f"Unsupported embedding provider: {settings.EMBEDDING_PROVIDER}"
@@ -113,6 +119,7 @@ match settings.COMPLETION_PROVIDER:
     case "ollama":
         completion_model = OllamaCompletionModel(
             model_name=settings.COMPLETION_MODEL,
+            base_url=settings.OLLAMA_BASE_URL,
         )
     case "openai":
         completion_model = OpenAICompletionModel(
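Both `match` blocks now read the shared `OLLAMA_BASE_URL` setting, so the embedding and completion models can point at the same local server. A sketch of the same dispatch factored into a testable helper follows; the OpenAI embedding model's module path is an assumption, since this hunk only shows its constructor.

# Hypothetical factory wrapping the dispatch above; the service itself
# does this inline at module import time.
from core.embedding.ollama_embedding_model import OllamaEmbeddingModel

def build_embedding_model(settings):
    match settings.EMBEDDING_PROVIDER:
        case "openai":
            # assumed module path; only the constructor appears in the diff
            from core.embedding.openai_embedding_model import OpenAIEmbeddingModel
            return OpenAIEmbeddingModel(
                api_key=settings.OPENAI_API_KEY,
                model_name=settings.EMBEDDING_MODEL,
            )
        case "ollama":
            return OllamaEmbeddingModel(
                model_name=settings.EMBEDDING_MODEL,
                base_url=settings.OLLAMA_BASE_URL,
            )
        case _:
            raise ValueError(
                f"Unsupported embedding provider: {settings.EMBEDDING_PROVIDER}"
            )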

View File

@@ -3,14 +3,15 @@ from core.completion.base_completion import (
     CompletionRequest,
     CompletionResponse,
 )
-import ollama
+from ollama import AsyncClient


 class OllamaCompletionModel(BaseCompletionModel):
     """Ollama completion model implementation"""

-    def __init__(self, model_name: str):
+    def __init__(self, model_name: str, base_url: str):
         self.model_name = model_name
+        self.client = AsyncClient(host=base_url)

     async def complete(self, request: CompletionRequest) -> CompletionResponse:
         """Generate completion using Ollama API"""
@@ -24,7 +25,7 @@ Context:
 Question: {request.query}"""

         # Call Ollama API
-        response = ollama.chat(
+        response = await self.client.chat(
             model=self.model_name,
             messages=[{"role": "user", "content": prompt}],
             options={
@@ -36,7 +37,6 @@ Question: {request.query}"""
         # Ollama doesn't provide token usage info, so we'll estimate based on characters
         completion_text = response["message"]["content"]
         char_to_token_ratio = 4  # Rough estimate
-        total_chars = len(prompt) + len(completion_text)
         estimated_prompt_tokens = len(prompt) // char_to_token_ratio
         estimated_completion_tokens = len(completion_text) // char_to_token_ratio
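Switching from the module-level `ollama.chat` to a per-instance `AsyncClient` does two things: the call no longer blocks the event loop, and the host becomes configurable instead of implicitly `localhost:11434`. A quick way to exercise the model end to end is sketched below; the `core.completion.ollama_completion` module path and the `context_chunks` and `.completion` field names are assumptions based on the surrounding code, not confirmed by this diff.

# Illustrative smoke test; run `ollama pull llama3.1` first.
import asyncio
from core.completion.base_completion import CompletionRequest
from core.completion.ollama_completion import OllamaCompletionModel  # assumed path

async def main() -> None:
    model = OllamaCompletionModel(
        model_name="llama3.1", base_url="http://localhost:11434"
    )
    request = CompletionRequest(
        query="What is in the context?",
        context_chunks=["The context contains a single test sentence."],  # assumed field
    )
    response = await model.complete(request)
    print(response.completion)  # assumed field

asyncio.run(main())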

View File

@@ -48,6 +48,7 @@ class Settings(BaseSettings):
     COMPLETION_MODEL: str = "llama3.1"
     COMPLETION_MAX_TOKENS: int = 1000
     COMPLETION_TEMPERATURE: float = 0.7
+    OLLAMA_BASE_URL: str = "http://localhost:11434"

     # Processing settings
     CHUNK_SIZE: int = 1000
@@ -96,6 +97,7 @@ def get_settings() -> Settings:
         "COMPLETION_MODEL": config["models"]["completion"]["model_name"],
         "COMPLETION_MAX_TOKENS": config["models"]["completion"]["default_max_tokens"],
         "COMPLETION_TEMPERATURE": config["models"]["completion"]["default_temperature"],
+        "OLLAMA_BASE_URL": config["models"]["ollama"]["base_url"],
         # Processing settings
         "CHUNK_SIZE": config["processing"]["text"]["chunk_size"],
         "CHUNK_OVERLAP": config["processing"]["text"]["chunk_overlap"],

View File

@@ -0,0 +1,27 @@
+from typing import List, Union
+
+from ollama import AsyncClient
+
+from core.embedding.base_embedding_model import BaseEmbeddingModel
+
+
+class OllamaEmbeddingModel(BaseEmbeddingModel):
+    def __init__(self, model_name, base_url: str = "http://localhost:11434"):
+        self.model_name = model_name
+        self.client = AsyncClient(host=base_url)
+
+    async def embed_for_ingestion(
+        self, text: Union[str, List[str]]
+    ) -> List[List[float]]:
+        if isinstance(text, str):
+            text = [text]
+
+        embeddings: List[List[float]] = []
+        for t in text:
+            response = await self.client.embeddings(model=self.model_name, prompt=t)
+            embedding = list(response["embedding"])
+            embeddings.append(embedding)
+
+        return embeddings
+
+    async def embed_for_query(self, text: str) -> List[float]:
+        response = await self.client.embeddings(model=self.model_name, prompt=text)
+        return response["embedding"]
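This new file is what the commit title refers to. Note that `embed_for_ingestion` makes one Ollama API round trip per chunk, which is fine locally but linear in batch size. A smoke test matching the "test them out" half of the commit message; the script is illustrative and assumes the model has been pulled with `ollama pull nomic-embed-text`:

# Illustrative smoke test for OllamaEmbeddingModel against a local server.
import asyncio
from core.embedding.ollama_embedding_model import OllamaEmbeddingModel

async def main() -> None:
    model = OllamaEmbeddingModel(model_name="nomic-embed-text")
    chunks = await model.embed_for_ingestion(["first chunk", "second chunk"])
    query = await model.embed_for_query("which chunk came first?")
    assert len(chunks) == 2
    assert len(query) == len(chunks[0])  # both should be 768-dimensional
    print(f"embedded {len(chunks)} chunks at {len(query)} dimensions")

asyncio.run(main())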

View File

@@ -43,13 +43,14 @@ with open(config_path, "rb") as f:
     LOGGER.info("Loaded configuration from config.toml")

 # Extract configuration values
-DEFAULT_REGION = CONFIG["aws"]["default_region"]
-DEFAULT_BUCKET_NAME = CONFIG["aws"]["default_bucket_name"]
-DATABASE_NAME = CONFIG["mongodb"]["database_name"]
-DOCUMENTS_COLLECTION = CONFIG["mongodb"]["documents_collection"]
-CHUNKS_COLLECTION = CONFIG["mongodb"]["chunks_collection"]
-VECTOR_DIMENSIONS = CONFIG["mongodb"]["vector"]["dimensions"]
-VECTOR_INDEX_NAME = CONFIG["mongodb"]["vector"]["index_name"]
+DEFAULT_REGION = CONFIG["storage"]["aws"]["region"]
+DEFAULT_BUCKET_NAME = CONFIG["storage"]["aws"]["bucket_name"]
+DATABASE_NAME = CONFIG["database"]["mongodb"]["database_name"]
+DOCUMENTS_COLLECTION = CONFIG["database"]["mongodb"]["documents_collection"]
+CHUNKS_COLLECTION = CONFIG["database"]["mongodb"]["chunks_collection"]
+VECTOR_DIMENSIONS = CONFIG["vector_store"]["mongodb"]["dimensions"]
+VECTOR_INDEX_NAME = CONFIG["vector_store"]["mongodb"]["index_name"]
+SIMILARITY_METRIC = CONFIG["vector_store"]["mongodb"]["similarity_metric"]


 def create_s3_bucket(bucket_name, region=DEFAULT_REGION):
@@ -142,7 +143,7 @@ def setup_mongodb():
                 {
                     "numDimensions": VECTOR_DIMENSIONS,
                     "path": "embedding",
-                    "similarity": "dotProduct",
+                    "similarity": SIMILARITY_METRIC,
                     "type": "vector",
                 },
                 {"path": "document_id", "type": "filter"},