mirror of https://github.com/james-m-jordan/morphik-core.git (synced 2025-05-09 19:32:38 +00:00)

add ollama embeddings and test them out

This commit is contained in: commit b883f52a11 (parent 124ce195c0)
@@ -25,8 +25,9 @@ chunks_collection = "document_chunks"

 # Vector Store Configuration
 [vector_store.mongodb]
-dimensions = 1536
+dimensions = 1536  # 768 for nomic-embed-text
 index_name = "vector_index"
+similarity_metric = "dotProduct"

 # Model Configurations
 [models]
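Note on the dimensions comment: the value must match the output size of whichever embedding model is active (1536 for OpenAI's small embedding models, 768 for nomic-embed-text), or the MongoDB vector index will reject or mis-rank vectors. A minimal sanity check, assuming the ollama Python client and a locally pulled nomic-embed-text model:

import asyncio

from ollama import AsyncClient


async def check_dimensions(expected: int = 768) -> None:
    # Embed a trivial string and compare the vector length to the
    # configured vector_store.mongodb dimensions value.
    client = AsyncClient(host="http://localhost:11434")
    response = await client.embeddings(model="nomic-embed-text", prompt="ping")
    actual = len(response["embedding"])
    assert actual == expected, f"config says {expected}, model returns {actual}"


asyncio.run(check_dimensions())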
@@ -38,6 +39,9 @@ model_name = "gpt-4o-mini"
 default_max_tokens = 1000
 default_temperature = 0.7

+[models.ollama]
+base_url = "http://localhost:11434"
+
 # Document Processing
 [processing]
 [processing.text]
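The new [models.ollama] table is consumed by get_settings() further down in this diff. A quick way to confirm the key path resolves, assuming Python 3.11+ for the stdlib tomllib parser and a config.toml in the working directory:

import tomllib

with open("config.toml", "rb") as f:
    config = tomllib.load(f)

# Mirrors the lookup added to get_settings() below.
print(config["models"]["ollama"]["base_url"])  # http://localhost:11434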
@@ -6,6 +6,7 @@ from fastapi.middleware.cors import CORSMiddleware
 import jwt
 import logging
 from core.completion.openai_completion import OpenAICompletionModel
+from core.embedding.ollama_embedding_model import OllamaEmbeddingModel
 from core.models.request import (
     IngestTextRequest,
     RetrieveRequest,
@@ -103,6 +104,11 @@ match settings.EMBEDDING_PROVIDER:
            api_key=settings.OPENAI_API_KEY,
            model_name=settings.EMBEDDING_MODEL,
        )
+    case "ollama":
+        embedding_model = OllamaEmbeddingModel(
+            model_name=settings.EMBEDDING_MODEL,
+            base_url=settings.OLLAMA_BASE_URL,
+        )
    case _:
        raise ValueError(
            f"Unsupported embedding provider: {settings.EMBEDDING_PROVIDER}"
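This provider dispatch uses structural pattern matching, so it requires Python 3.10+. A stripped-down sketch of the same pattern (hypothetical return values, not the real constructors) shows how the wildcard arm turns a typo in config into a loud failure instead of a silent default:

def select_provider(name: str) -> str:
    match name:
        case "openai":
            return "OpenAIEmbeddingModel"
        case "ollama":
            return "OllamaEmbeddingModel"
        case _:
            raise ValueError(f"Unsupported embedding provider: {name}")


try:
    select_provider("olama")  # typo: misconfigured provider name
except ValueError as exc:
    print(exc)  # Unsupported embedding provider: olama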
@@ -113,6 +119,7 @@ match settings.COMPLETION_PROVIDER:
    case "ollama":
        completion_model = OllamaCompletionModel(
            model_name=settings.COMPLETION_MODEL,
+            base_url=settings.OLLAMA_BASE_URL,
        )
    case "openai":
        completion_model = OpenAICompletionModel(
@@ -3,14 +3,15 @@ from core.completion.base_completion import (
     CompletionRequest,
     CompletionResponse,
 )
-import ollama
+from ollama import AsyncClient


 class OllamaCompletionModel(BaseCompletionModel):
     """Ollama completion model implementation"""

-    def __init__(self, model_name: str):
+    def __init__(self, model_name: str, base_url: str):
         self.model_name = model_name
+        self.client = AsyncClient(host=base_url)

     async def complete(self, request: CompletionRequest) -> CompletionResponse:
         """Generate completion using Ollama API"""
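Swapping the module-level ollama.chat call for an AsyncClient bound to a configurable host means the request no longer blocks the FastAPI event loop and can target a non-default Ollama server. A minimal standalone sketch of the same call, assuming llama3.1 is pulled locally:

import asyncio

from ollama import AsyncClient


async def main() -> None:
    client = AsyncClient(host="http://localhost:11434")
    response = await client.chat(
        model="llama3.1",
        messages=[{"role": "user", "content": "Say hello in one word."}],
        options={"temperature": 0.7, "num_predict": 100},
    )
    print(response["message"]["content"])


asyncio.run(main())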
@@ -24,7 +25,7 @@ Context:
 Question: {request.query}"""

         # Call Ollama API
-        response = ollama.chat(
+        response = await self.client.chat(
             model=self.model_name,
             messages=[{"role": "user", "content": prompt}],
             options={
@@ -36,7 +37,6 @@ Question: {request.query}"""
         # Ollama doesn't provide token usage info, so we'll estimate based on characters
         completion_text = response["message"]["content"]
         char_to_token_ratio = 4  # Rough estimate
-        total_chars = len(prompt) + len(completion_text)
         estimated_prompt_tokens = len(prompt) // char_to_token_ratio
         estimated_completion_tokens = len(completion_text) // char_to_token_ratio

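The removed total_chars variable was never read afterwards, so dropping it is a pure cleanup. The remaining heuristic assumes roughly 4 characters per token for English text, so a 400-character prompt is estimated at 100 prompt tokens. As a self-contained helper:

def estimate_tokens(text: str, char_to_token_ratio: int = 4) -> int:
    # Rough estimate: ~4 characters per token for English text.
    # Ollama does not report token usage, so this stands in for real counts.
    return len(text) // char_to_token_ratio


print(estimate_tokens("x" * 400))  # 100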
@@ -48,6 +48,7 @@ class Settings(BaseSettings):
     COMPLETION_MODEL: str = "llama3.1"
     COMPLETION_MAX_TOKENS: int = 1000
     COMPLETION_TEMPERATURE: float = 0.7
+    OLLAMA_BASE_URL: str = "http://localhost:11434"

     # Processing settings
     CHUNK_SIZE: int = 1000
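Because Settings extends pydantic's BaseSettings, the new field can also be overridden per-environment without editing config.toml; by default BaseSettings lets a matching environment variable win over the field default. A self-contained sketch using a minimal stand-in class rather than the real Settings:

import os

from pydantic import BaseSettings  # pydantic v1-style; pydantic_settings in v2


class DemoSettings(BaseSettings):
    OLLAMA_BASE_URL: str = "http://localhost:11434"


os.environ["OLLAMA_BASE_URL"] = "http://gpu-box:11434"
print(DemoSettings().OLLAMA_BASE_URL)  # http://gpu-box:11434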
@@ -96,6 +97,7 @@ def get_settings() -> Settings:
         "COMPLETION_MODEL": config["models"]["completion"]["model_name"],
         "COMPLETION_MAX_TOKENS": config["models"]["completion"]["default_max_tokens"],
         "COMPLETION_TEMPERATURE": config["models"]["completion"]["default_temperature"],
+        "OLLAMA_BASE_URL": config["models"]["ollama"]["base_url"],
         # Processing settings
         "CHUNK_SIZE": config["processing"]["text"]["chunk_size"],
         "CHUNK_OVERLAP": config["processing"]["text"]["chunk_overlap"],
core/embedding/ollama_embedding_model.py (new file, 27 lines)
@@ -0,0 +1,27 @@
+from typing import List, Union
+from ollama import AsyncClient
+from core.embedding.base_embedding_model import BaseEmbeddingModel
+
+
+class OllamaEmbeddingModel(BaseEmbeddingModel):
+    def __init__(self, model_name, base_url: str = "http://localhost:11434"):
+        self.model_name = model_name
+        self.client = AsyncClient(host=base_url)
+
+    async def embed_for_ingestion(
+        self, text: Union[str, List[str]]
+    ) -> List[List[float]]:
+        if isinstance(text, str):
+            text = [text]
+
+        embeddings: List[List[float]] = []
+        for t in text:
+            response = await self.client.embeddings(model=self.model_name, prompt=t)
+            embedding = list(response["embedding"])
+            embeddings.append(embedding)
+
+        return embeddings
+
+    async def embed_for_query(self, text: str) -> List[float]:
+        response = await self.client.embeddings(model=self.model_name, prompt=text)
+        return response["embedding"]
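A quick smoke test of the new model, in the spirit of the commit's "test them out": assuming an Ollama server on the default port with nomic-embed-text pulled, this exercises both the ingestion path (string or list in, list of vectors out) and the query path (single vector out):

import asyncio

from core.embedding.ollama_embedding_model import OllamaEmbeddingModel


async def main() -> None:
    model = OllamaEmbeddingModel(model_name="nomic-embed-text")

    # embed_for_ingestion normalizes a single string into a one-element batch.
    vectors = await model.embed_for_ingestion(["first chunk", "second chunk"])
    print(len(vectors), len(vectors[0]))  # 2 768

    query_vector = await model.embed_for_query("what is in the first chunk?")
    print(len(query_vector))  # 768


asyncio.run(main())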
@@ -43,13 +43,14 @@ with open(config_path, "rb") as f:
     LOGGER.info("Loaded configuration from config.toml")

 # Extract configuration values
-DEFAULT_REGION = CONFIG["aws"]["default_region"]
-DEFAULT_BUCKET_NAME = CONFIG["aws"]["default_bucket_name"]
-DATABASE_NAME = CONFIG["mongodb"]["database_name"]
-DOCUMENTS_COLLECTION = CONFIG["mongodb"]["documents_collection"]
-CHUNKS_COLLECTION = CONFIG["mongodb"]["chunks_collection"]
-VECTOR_DIMENSIONS = CONFIG["mongodb"]["vector"]["dimensions"]
-VECTOR_INDEX_NAME = CONFIG["mongodb"]["vector"]["index_name"]
+DEFAULT_REGION = CONFIG["storage"]["aws"]["region"]
+DEFAULT_BUCKET_NAME = CONFIG["storage"]["aws"]["bucket_name"]
+DATABASE_NAME = CONFIG["database"]["mongodb"]["database_name"]
+DOCUMENTS_COLLECTION = CONFIG["database"]["mongodb"]["documents_collection"]
+CHUNKS_COLLECTION = CONFIG["database"]["mongodb"]["chunks_collection"]
+VECTOR_DIMENSIONS = CONFIG["vector_store"]["mongodb"]["dimensions"]
+VECTOR_INDEX_NAME = CONFIG["vector_store"]["mongodb"]["index_name"]
+SIMILARITY_METRIC = CONFIG["vector_store"]["mongodb"]["similarity_metric"]


 def create_s3_bucket(bucket_name, region=DEFAULT_REGION):
@@ -142,7 +143,7 @@ def setup_mongodb():
             {
                 "numDimensions": VECTOR_DIMENSIONS,
                 "path": "embedding",
-                "similarity": "dotProduct",
+                "similarity": SIMILARITY_METRIC,
                 "type": "vector",
             },
             {"path": "document_id", "type": "filter"},
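One caveat worth flagging: dotProduct similarity in Atlas vector search only ranks equivalently to cosine similarity when vectors are unit-length, and raw Ollama embeddings are not guaranteed to be normalized. A dependency-free normalization sketch, in case that assumption matters for this index:

import math
from typing import List


def normalize(vector: List[float]) -> List[float]:
    # Scale to unit length so dot product equals cosine similarity.
    norm = math.sqrt(sum(x * x for x in vector))
    return [x / norm for x in vector] if norm else vector


print(normalize([3.0, 4.0]))  # [0.6, 0.8]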