morphik-core/core/embedding/litellm_embedding.py
2025-04-08 20:06:18 -07:00

125 lines
4.1 KiB
Python

import logging
from typing import List, Union
import litellm
from core.embedding.base_embedding_model import BaseEmbeddingModel
from core.models.chunk import Chunk
from core.config import get_settings
logger = logging.getLogger(__name__)
PGVECTOR_MAX_DIMENSIONS = 2000 # Maximum dimensions for pgvector
class LiteLLMEmbeddingModel(BaseEmbeddingModel):
"""
LiteLLM embedding model implementation that provides unified access to various embedding providers.
Uses registered models from the config file.
"""
def __init__(self, model_key: str):
"""
Initialize LiteLLM embedding model with a model key from registered_models.
Args:
model_key: The key of the model in the registered_models config
"""
settings = get_settings()
self.model_key = model_key
# Get the model configuration from registered_models
if (
not hasattr(settings, "REGISTERED_MODELS")
or model_key not in settings.REGISTERED_MODELS
):
raise ValueError(f"Model '{model_key}' not found in registered_models configuration")
self.model_config = settings.REGISTERED_MODELS[model_key]
self.dimensions = min(settings.VECTOR_DIMENSIONS, 2000)
logger.info(
f"Initialized LiteLLM embedding model with model_key={model_key}, config={self.model_config}"
)
async def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""
Generate embeddings for a list of documents using LiteLLM.
Args:
texts: List of text documents to embed
Returns:
List of embedding vectors (one per document)
"""
if not texts:
return []
try:
model_params = {"model": self.model_config["model_name"]}
if self.model_config["model_name"] in ["text-embedding-3-large", "azure/text-embedding-3-large"]:
model_params["dimensions"] = PGVECTOR_MAX_DIMENSIONS
# Add all model-specific parameters from the config
for key, value in self.model_config.items():
if key != "model_name": # Skip as we've already handled it
model_params[key] = value
# Call LiteLLM
response = await litellm.aembedding(input=texts, **model_params)
embeddings = [data["embedding"] for data in response.data]
# Validate dimensions
if embeddings and len(embeddings[0]) != self.dimensions:
logger.warning(
f"Embedding dimension mismatch: got {len(embeddings[0])}, expected {self.dimensions}. "
f"Please update your VECTOR_DIMENSIONS setting to match the actual dimension."
)
return embeddings
except Exception as e:
logger.error(f"Error generating embeddings with LiteLLM: {e}")
raise
async def embed_query(self, text: str) -> List[float]:
"""
Generate an embedding for a single query using LiteLLM.
Args:
text: Query text to embed
Returns:
Embedding vector
"""
result = await self.embed_documents([text])
if not result:
# In case of error, return zero vector
return [0.0] * self.dimensions
return result[0]
async def embed_for_ingestion(self, chunks: Union[Chunk, List[Chunk]]) -> List[List[float]]:
"""
Generate embeddings for chunks to be ingested into the vector store.
Args:
chunks: Single chunk or list of chunks to embed
Returns:
List of embedding vectors (one per chunk)
"""
if isinstance(chunks, Chunk):
chunks = [chunks]
texts = [chunk.content for chunk in chunks]
return await self.embed_documents(texts)
async def embed_for_query(self, text: str) -> List[float]:
"""
Generate embedding for a query.
Args:
text: Query text to embed
Returns:
Embedding vector
"""
return await self.embed_query(text)