mirror of
https://github.com/james-m-jordan/morphik-core.git
synced 2025-05-09 19:32:38 +00:00
add ollama embeddings and test them out
This commit is contained in:
parent
124ce195c0
commit
b883f52a11
@ -25,8 +25,9 @@ chunks_collection = "document_chunks"
|
||||
|
||||
# Vector Store Configuration
|
||||
[vector_store.mongodb]
|
||||
dimensions = 1536
|
||||
dimensions = 1536 # 768 for nomic-embed-text
|
||||
index_name = "vector_index"
|
||||
similarity_metric = "dotProduct"
|
||||
|
||||
# Model Configurations
|
||||
[models]
|
||||
@ -38,6 +39,9 @@ model_name = "gpt-4o-mini"
|
||||
default_max_tokens = 1000
|
||||
default_temperature = 0.7
|
||||
|
||||
[models.ollama]
|
||||
base_url = "http://localhost:11434"
|
||||
|
||||
# Document Processing
|
||||
[processing]
|
||||
[processing.text]
|
||||
|
@ -6,6 +6,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||
import jwt
|
||||
import logging
|
||||
from core.completion.openai_completion import OpenAICompletionModel
|
||||
from core.embedding.ollama_embedding_model import OllamaEmbeddingModel
|
||||
from core.models.request import (
|
||||
IngestTextRequest,
|
||||
RetrieveRequest,
|
||||
@ -103,6 +104,11 @@ match settings.EMBEDDING_PROVIDER:
|
||||
api_key=settings.OPENAI_API_KEY,
|
||||
model_name=settings.EMBEDDING_MODEL,
|
||||
)
|
||||
case "ollama":
|
||||
embedding_model = OllamaEmbeddingModel(
|
||||
model_name=settings.EMBEDDING_MODEL,
|
||||
base_url=settings.OLLAMA_BASE_URL,
|
||||
)
|
||||
case _:
|
||||
raise ValueError(
|
||||
f"Unsupported embedding provider: {settings.EMBEDDING_PROVIDER}"
|
||||
@ -113,6 +119,7 @@ match settings.COMPLETION_PROVIDER:
|
||||
case "ollama":
|
||||
completion_model = OllamaCompletionModel(
|
||||
model_name=settings.COMPLETION_MODEL,
|
||||
base_url=settings.OLLAMA_BASE_URL,
|
||||
)
|
||||
case "openai":
|
||||
completion_model = OpenAICompletionModel(
|
||||
|
@ -3,14 +3,15 @@ from core.completion.base_completion import (
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
)
|
||||
import ollama
|
||||
from ollama import AsyncClient
|
||||
|
||||
|
||||
class OllamaCompletionModel(BaseCompletionModel):
|
||||
"""Ollama completion model implementation"""
|
||||
|
||||
def __init__(self, model_name: str):
|
||||
def __init__(self, model_name: str, base_url: str):
|
||||
self.model_name = model_name
|
||||
self.client = AsyncClient(host=base_url)
|
||||
|
||||
async def complete(self, request: CompletionRequest) -> CompletionResponse:
|
||||
"""Generate completion using Ollama API"""
|
||||
@ -24,7 +25,7 @@ Context:
|
||||
Question: {request.query}"""
|
||||
|
||||
# Call Ollama API
|
||||
response = ollama.chat(
|
||||
response = await self.client.chat(
|
||||
model=self.model_name,
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
options={
|
||||
@ -36,7 +37,6 @@ Question: {request.query}"""
|
||||
# Ollama doesn't provide token usage info, so we'll estimate based on characters
|
||||
completion_text = response["message"]["content"]
|
||||
char_to_token_ratio = 4 # Rough estimate
|
||||
total_chars = len(prompt) + len(completion_text)
|
||||
estimated_prompt_tokens = len(prompt) // char_to_token_ratio
|
||||
estimated_completion_tokens = len(completion_text) // char_to_token_ratio
|
||||
|
||||
|
@ -48,6 +48,7 @@ class Settings(BaseSettings):
|
||||
COMPLETION_MODEL: str = "llama3.1"
|
||||
COMPLETION_MAX_TOKENS: int = 1000
|
||||
COMPLETION_TEMPERATURE: float = 0.7
|
||||
OLLAMA_BASE_URL: str = "http://localhost:11434"
|
||||
|
||||
# Processing settings
|
||||
CHUNK_SIZE: int = 1000
|
||||
@ -96,6 +97,7 @@ def get_settings() -> Settings:
|
||||
"COMPLETION_MODEL": config["models"]["completion"]["model_name"],
|
||||
"COMPLETION_MAX_TOKENS": config["models"]["completion"]["default_max_tokens"],
|
||||
"COMPLETION_TEMPERATURE": config["models"]["completion"]["default_temperature"],
|
||||
"OLLAMA_BASE_URL": config["models"]["ollama"]["base_url"],
|
||||
# Processing settings
|
||||
"CHUNK_SIZE": config["processing"]["text"]["chunk_size"],
|
||||
"CHUNK_OVERLAP": config["processing"]["text"]["chunk_overlap"],
|
||||
|
27
core/embedding/ollama_embedding_model.py
Normal file
27
core/embedding/ollama_embedding_model.py
Normal file
@ -0,0 +1,27 @@
|
||||
from typing import List, Union
|
||||
from ollama import AsyncClient
|
||||
from core.embedding.base_embedding_model import BaseEmbeddingModel
|
||||
|
||||
|
||||
class OllamaEmbeddingModel(BaseEmbeddingModel):
    """Embedding model that obtains vectors from an Ollama server.

    Each text is sent to the server through ``AsyncClient.embeddings`` using
    the model name supplied at construction time.
    """

    def __init__(self, model_name: str, base_url: str = "http://localhost:11434"):
        """Store the model name and create the async Ollama client.

        Args:
            model_name: Value passed as ``model`` on every embeddings request.
            base_url: Value passed as ``host`` to ``AsyncClient``.
        """
        self.model_name = model_name
        self.client = AsyncClient(host=base_url)

    async def _embed_one(self, text: str) -> List[float]:
        """Request the embedding of a single text and return it as a list."""
        response = await self.client.embeddings(model=self.model_name, prompt=text)
        # Copy into a plain list so both public methods return the annotated
        # type, whatever sequence type the client hands back.
        return list(response["embedding"])

    async def embed_for_ingestion(
        self, text: Union[str, List[str]]
    ) -> List[List[float]]:
        """Embed one or more texts for storage.

        Args:
            text: A single string or a list of strings. A single string is
                treated as a one-element list.

        Returns:
            One embedding per input text, in input order. An empty input
            list yields an empty list.
        """
        if isinstance(text, str):
            text = [text]

        # Requests are awaited one after another, so output order matches
        # input order.
        return [await self._embed_one(t) for t in text]

    async def embed_for_query(self, text: str) -> List[float]:
        """Embed a single query string and return its vector as a list."""
        return await self._embed_one(text)
|
@ -43,13 +43,14 @@ with open(config_path, "rb") as f:
|
||||
LOGGER.info("Loaded configuration from config.toml")
|
||||
|
||||
# Extract configuration values
|
||||
DEFAULT_REGION = CONFIG["aws"]["default_region"]
|
||||
DEFAULT_BUCKET_NAME = CONFIG["aws"]["default_bucket_name"]
|
||||
DATABASE_NAME = CONFIG["mongodb"]["database_name"]
|
||||
DOCUMENTS_COLLECTION = CONFIG["mongodb"]["documents_collection"]
|
||||
CHUNKS_COLLECTION = CONFIG["mongodb"]["chunks_collection"]
|
||||
VECTOR_DIMENSIONS = CONFIG["mongodb"]["vector"]["dimensions"]
|
||||
VECTOR_INDEX_NAME = CONFIG["mongodb"]["vector"]["index_name"]
|
||||
DEFAULT_REGION = CONFIG["storage"]["aws"]["region"]
|
||||
DEFAULT_BUCKET_NAME = CONFIG["storage"]["aws"]["bucket_name"]
|
||||
DATABASE_NAME = CONFIG["database"]["mongodb"]["database_name"]
|
||||
DOCUMENTS_COLLECTION = CONFIG["database"]["mongodb"]["documents_collection"]
|
||||
CHUNKS_COLLECTION = CONFIG["database"]["mongodb"]["chunks_collection"]
|
||||
VECTOR_DIMENSIONS = CONFIG["vector_store"]["mongodb"]["dimensions"]
|
||||
VECTOR_INDEX_NAME = CONFIG["vector_store"]["mongodb"]["index_name"]
|
||||
SIMILARITY_METRIC = CONFIG["vector_store"]["mongodb"]["similarity_metric"]
|
||||
|
||||
|
||||
def create_s3_bucket(bucket_name, region=DEFAULT_REGION):
|
||||
@ -142,7 +143,7 @@ def setup_mongodb():
|
||||
{
|
||||
"numDimensions": VECTOR_DIMENSIONS,
|
||||
"path": "embedding",
|
||||
"similarity": "dotProduct",
|
||||
"similarity": SIMILARITY_METRIC,
|
||||
"type": "vector",
|
||||
},
|
||||
{"path": "document_id", "type": "filter"},
|
||||
|
Loading…
x
Reference in New Issue
Block a user