# Mirror of https://github.com/james-m-jordan/morphik-core.git
# (synced 2025-05-09 19:32:38 +00:00; 125 lines, 4.1 KiB, Python)
import logging
|
|
from typing import List, Union
|
|
import litellm
|
|
|
|
from core.embedding.base_embedding_model import BaseEmbeddingModel
|
|
from core.models.chunk import Chunk
|
|
from core.config import get_settings
|
|
|
|
# Maximum embedding dimension supported by pgvector; used to cap the number of
# dimensions requested from embedding providers.
# NOTE(review): pgvector's HNSW/IVFFlat indexes are limited to 2,000 dimensions
# (the column type itself allows more) - confirm this is the limit being targeted.
PGVECTOR_MAX_DIMENSIONS = 2000

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class LiteLLMEmbeddingModel(BaseEmbeddingModel):
    """
    LiteLLM embedding model implementation that provides unified access to various embedding providers.

    Uses registered models from the config file (``settings.REGISTERED_MODELS``).
    """

    # Model names for which an explicit ``dimensions`` request parameter is sent,
    # so the provider returns vectors sized to the configured (capped) dimension.
    _DIMENSION_CONFIGURABLE_MODELS = frozenset(
        {
            "text-embedding-3-large",
            "openai/text-embedding-3-large",
            "azure/text-embedding-3-large",
        }
    )

    # Substrings of config keys whose values are masked before the config is logged
    # (e.g. "api_key"), so credentials never end up in log output.
    _SENSITIVE_CONFIG_MARKERS = ("key", "secret", "token", "password")

    @classmethod
    def _redacted_config(cls, config) -> dict:
        """Return a copy of ``config`` with credential-like values replaced by ``"***"``."""
        return {
            k: "***" if any(m in str(k).lower() for m in cls._SENSITIVE_CONFIG_MARKERS) else v
            for k, v in config.items()
        }

    def __init__(self, model_key: str):
        """
        Initialize LiteLLM embedding model with a model key from registered_models.

        Args:
            model_key: The key of the model in the registered_models config

        Raises:
            ValueError: If the settings have no ``REGISTERED_MODELS`` (or it is ``None``),
                or ``model_key`` is not one of its entries.
        """
        settings = get_settings()
        self.model_key = model_key

        # Get the model configuration from registered_models. getattr with a default
        # also covers REGISTERED_MODELS being present but unset (None).
        registered_models = getattr(settings, "REGISTERED_MODELS", None)
        if registered_models is None or model_key not in registered_models:
            raise ValueError(f"Model '{model_key}' not found in registered_models configuration")

        self.model_config = registered_models[model_key]
        # Never use more dimensions than pgvector supports.
        self.dimensions = min(settings.VECTOR_DIMENSIONS, PGVECTOR_MAX_DIMENSIONS)
        logger.info(
            f"Initialized LiteLLM embedding model with model_key={model_key}, "
            f"config={self._redacted_config(self.model_config)}"
        )

    async def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """
        Generate embeddings for a list of documents using LiteLLM.

        Args:
            texts: List of text documents to embed

        Returns:
            List of embedding vectors (one per document); ``[]`` when ``texts`` is empty.

        Raises:
            Exception: Any error from building the request or from ``litellm.aembedding``
                is logged and re-raised unchanged.
        """
        if not texts:
            return []

        try:
            model_name = self.model_config["model_name"]
            model_params = {"model": model_name}
            if model_name in self._DIMENSION_CONFIGURABLE_MODELS:
                # Request exactly the configured (pgvector-capped) dimension rather than
                # a fixed maximum, so a smaller VECTOR_DIMENSIONS still yields vectors
                # that match the storage column.
                model_params["dimensions"] = self.dimensions

            # Add all model-specific parameters from the config. Explicit config values
            # (including "dimensions") override the defaults set above.
            model_params.update(
                {key: value for key, value in self.model_config.items() if key != "model_name"}
            )

            # Call LiteLLM
            response = await litellm.aembedding(input=texts, **model_params)

            embeddings = [data["embedding"] for data in response.data]

            # Validate dimensions (only the first vector is checked).
            if embeddings and len(embeddings[0]) != self.dimensions:
                logger.warning(
                    f"Embedding dimension mismatch: got {len(embeddings[0])}, expected {self.dimensions}. "
                    f"Please update your VECTOR_DIMENSIONS setting to match the actual dimension."
                )

            return embeddings
        except Exception as e:
            logger.error(f"Error generating embeddings with LiteLLM: {e}")
            raise

    async def embed_query(self, text: str) -> List[float]:
        """
        Generate an embedding for a single query using LiteLLM.

        Args:
            text: Query text to embed

        Returns:
            Embedding vector. If the provider returns no embeddings at all, a zero
            vector of length ``self.dimensions`` is returned instead.

        Raises:
            Exception: Errors from ``embed_documents`` propagate unchanged.
        """
        result = await self.embed_documents([text])
        if not result:
            # embed_documents re-raises provider errors, so this path is reached only
            # when the provider responded with an empty data list. Keep the zero-vector
            # fallback for compatibility, but make it visible.
            logger.warning("LiteLLM returned no embedding for query; falling back to a zero vector")
            return [0.0] * self.dimensions
        return result[0]

    async def embed_for_ingestion(self, chunks: Union[Chunk, List[Chunk]]) -> List[List[float]]:
        """
        Generate embeddings for chunks to be ingested into the vector store.

        Args:
            chunks: Single chunk or list of chunks to embed

        Returns:
            List of embedding vectors (one per chunk, in input order)
        """
        if isinstance(chunks, Chunk):
            chunks = [chunks]

        texts = [chunk.content for chunk in chunks]
        return await self.embed_documents(texts)

    async def embed_for_query(self, text: str) -> List[float]:
        """
        Generate embedding for a query.

        Args:
            text: Query text to embed

        Returns:
            Embedding vector (delegates to ``embed_query``)
        """
        return await self.embed_query(text)