import logging
import re  # Import re for parsing model name
from typing import Any, Dict, List, Optional, Tuple, Union

import litellm

try:
    import ollama
except ImportError:
    ollama = None  # Make ollama import optional

from pydantic import BaseModel

from core.config import get_settings
from core.models.completion import CompletionRequest, CompletionResponse

from .base_completion import BaseCompletionModel

logger = logging.getLogger(__name__)


def get_system_message() -> Dict[str, str]:
    """Return the standard system message for Morphik's query agent."""
    return {
        "role": "system",
        "content": """You are Morphik's powerful query agent. Your role is to:

1. Analyze the provided context chunks from documents carefully
2. Use the context to answer questions accurately and comprehensively
3. Be clear and concise in your answers
4. When relevant, cite specific parts of the context to support your answers
5. For image-based queries, analyze the visual content in conjunction with any text context provided

Remember: Your primary goal is to provide accurate, context-aware responses that help users
understand and utilize the information in their documents effectively.""",
    }


def process_context_chunks(context_chunks: List[str], is_ollama: bool) -> Tuple[List[str], List[str], List[str]]:
    """
    Process context chunks and separate text from images.

    Args:
        context_chunks: List of context chunks which may include images
        is_ollama: Whether we're using Ollama (affects image processing)

    Returns:
        Tuple of (context_text, image_urls, ollama_image_data)
    """
    context_text = []
    image_urls = []  # For non-Ollama models (full data URI)
    ollama_image_data = []  # For Ollama models (raw base64)

    for chunk in context_chunks:
        if chunk.startswith("data:image/"):
            if is_ollama:
                # For Ollama, strip the data URI prefix and just keep the base64 data
                try:
                    base64_data = chunk.split(",", 1)[1]
                    ollama_image_data.append(base64_data)
                except IndexError:
                    logger.warning(f"Could not parse base64 data from image chunk: {chunk[:50]}...")
            else:
                image_urls.append(chunk)
        else:
            context_text.append(chunk)

    return context_text, image_urls, ollama_image_data


def format_user_content(context_text: List[str], query: str, prompt_template: Optional[str] = None) -> str:
    """
    Format the user content based on context and query.

    Args:
        context_text: List of context text chunks
        query: The user query
        prompt_template: Optional template to format the content

    Returns:
        Formatted user content string
    """
    context = "\n" + "\n\n".join(context_text) + "\n\n" if context_text else ""

    if prompt_template:
        return prompt_template.format(
            context=context,
            question=query,
            query=query,
        )
    elif context_text:
        return f"Context: {context} Question: {query}"
    else:
        return query


def create_dynamic_model_from_schema(schema: Union[type, Dict]) -> Optional[type]:
    """
    Create a dynamic Pydantic model from a schema definition.

    Args:
        schema: Either a Pydantic BaseModel class or a JSON schema dict

    Returns:
        A Pydantic model class or None if schema format is not recognized
    """
    from pydantic import create_model

    if isinstance(schema, type) and issubclass(schema, BaseModel):
        return schema
    elif isinstance(schema, dict) and "properties" in schema:
        # Create a dynamic model from JSON schema
        field_definitions = {}
        schema_dict = schema

        for field_name, field_info in schema_dict.get("properties", {}).items():
            if isinstance(field_info, dict) and "type" in field_info:
                field_type = field_info.get("type")

                # Convert schema types to Python types
                if field_type == "string":
                    field_definitions[field_name] = (str, None)
                elif field_type == "number":
                    field_definitions[field_name] = (float, None)
                elif field_type == "integer":
                    field_definitions[field_name] = (int, None)
                elif field_type == "boolean":
                    field_definitions[field_name] = (bool, None)
                elif field_type == "array":
                    field_definitions[field_name] = (list, None)
                elif field_type == "object":
                    field_definitions[field_name] = (dict, None)
                else:
                    # Default to Any for unknown types
                    field_definitions[field_name] = (Any, None)

        # Create the dynamic model
        return create_model("DynamicQueryModel", **field_definitions)
    else:
        logger.warning(f"Unrecognized schema format: {schema}")
        return None


class LiteLLMCompletionModel(BaseCompletionModel):
    """
    LiteLLM completion model implementation that provides unified access to various LLM providers.
    Uses registered models from the config file. Can optionally use direct Ollama client.
    """

    def __init__(self, model_key: str):
        """
        Initialize LiteLLM completion model with a model key from registered_models.

        Args:
            model_key: The key of the model in the registered_models config
        """
        settings = get_settings()
        self.model_key = model_key

        # Get the model configuration from registered_models
        if not hasattr(settings, "REGISTERED_MODELS") or model_key not in settings.REGISTERED_MODELS:
            raise ValueError(f"Model '{model_key}' not found in registered_models configuration")

        self.model_config = settings.REGISTERED_MODELS[model_key]

        # Check if it's an Ollama model for potential direct usage
        self.is_ollama = "ollama" in self.model_config.get("model_name", "").lower()
        self.ollama_api_base = None
        self.ollama_base_model_name = None

        if self.is_ollama:
            if ollama is None:
                logger.warning("Ollama model selected, but 'ollama' library not installed. Falling back to LiteLLM.")
                self.is_ollama = False  # Fall back to LiteLLM if the library is missing
            else:
                self.ollama_api_base = self.model_config.get("api_base")
                if not self.ollama_api_base:
                    logger.warning(
                        f"Ollama model {self.model_key} selected for direct use, "
                        "but 'api_base' is missing in config. Falling back to LiteLLM."
                    )
                    self.is_ollama = False  # Fall back if api_base is missing
                else:
                    # Extract base model name (e.g., 'llama3.2' from 'ollama_chat/llama3.2')
                    match = re.search(r"[^/]+$", self.model_config["model_name"])
                    if match:
                        self.ollama_base_model_name = match.group(0)
                    else:
                        logger.warning(
                            f"Could not parse base model name from Ollama model "
                            f"{self.model_config['model_name']}. Falling back to LiteLLM."
                        )
                        self.is_ollama = False  # Fall back if name parsing fails

        logger.info(
            f"Initialized LiteLLM completion model with model_key={model_key}, "
            f"config={self.model_config}, is_ollama_direct={self.is_ollama}"
        )

    async def _handle_structured_ollama(
        self,
        dynamic_model: type,
        system_message: Dict[str, str],
        user_content: str,
        ollama_image_data: List[str],
        request: CompletionRequest,
    ) -> Optional[CompletionResponse]:
        """Handle structured output generation with Ollama. Returns None on failure so the caller can fall back."""
        try:
            client = ollama.AsyncClient(host=self.ollama_api_base)

            # Add images directly to content if available
            content_data = user_content
            if ollama_image_data and len(ollama_image_data) > 0:
                # Ollama image handling is limited; we can use only the first image
                content_data = {"content": user_content, "images": [ollama_image_data[0]]}

            # Create messages for Ollama
            messages = [system_message, {"role": "user", "content": content_data}]

            # Get the JSON schema from the dynamic model
            format_schema = dynamic_model.model_json_schema()

            # Call Ollama directly with format parameter
            response = await client.chat(
                model=self.ollama_base_model_name,
                messages=messages,
                format=format_schema,
                options={
                    "temperature": request.temperature or 0.1,  # Lower temperature for structured output
                    "num_predict": request.max_tokens,
                },
            )

            # Parse the response into the dynamic model
            parsed_response = dynamic_model.model_validate_json(response["message"]["content"])

            # Extract token usage information
            usage = {
                "prompt_tokens": response.get("prompt_eval_count", 0),
                "completion_tokens": response.get("eval_count", 0),
                "total_tokens": response.get("prompt_eval_count", 0) + response.get("eval_count", 0),
            }

            return CompletionResponse(
                completion=parsed_response,
                usage=usage,
                finish_reason=response.get("done_reason", "stop"),
            )
        except Exception as e:
            logger.error(f"Error using Ollama for structured output: {e}")
            # Fall back to standard completion if structured output fails
            logger.warning("Falling back to standard Ollama completion without structured output")
            return None

    async def _handle_structured_litellm(
        self,
        dynamic_model: type,
        system_message: Dict[str, str],
        user_content: str,
        image_urls: List[str],
        request: CompletionRequest,
    ) -> Optional[CompletionResponse]:
        """Handle structured output generation with LiteLLM. Returns None on failure so the caller can fall back."""
        import instructor
        from instructor import Mode

        try:
            # Use instructor with litellm
            client = instructor.from_litellm(litellm.acompletion, mode=Mode.JSON)

            # Create content list with text and images
            content_list = [{"type": "text", "text": user_content}]

            # Add images if available
            if image_urls:
                NUM_IMAGES = min(3, len(image_urls))
                for img_url in image_urls[:NUM_IMAGES]:
                    content_list.append({"type": "image_url", "image_url": {"url": img_url}})

            # Create messages for instructor
            messages = [system_message, {"role": "user", "content": content_list}]

            # Extract model configuration
            model = self.model_config.get("model_name")
            model_kwargs = {k: v for k, v in self.model_config.items() if k != "model_name"}

            # Override with completion request parameters
            if request.temperature is not None:
                model_kwargs["temperature"] = request.temperature
            if request.max_tokens is not None:
                model_kwargs["max_tokens"] = request.max_tokens

            # Add format forcing for structured output
            model_kwargs["response_format"] = {"type": "json_object"}

            # Call instructor with litellm
            response = await client.chat.completions.create(
                model=model,
                messages=messages,
                response_model=dynamic_model,
                **model_kwargs,
            )

            # Instructor returns the parsed model rather than the raw completion, so token
            # counts are not available here; these lookups fall back to 0.
            completion_tokens = model_kwargs.get("response_tokens", 0)
            prompt_tokens = model_kwargs.get("prompt_tokens", 0)

            return CompletionResponse(
                completion=response,
                usage={
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": completion_tokens,
                    "total_tokens": prompt_tokens + completion_tokens,
                },
                finish_reason="stop",
            )
        except Exception as e:
            logger.error(f"Error using instructor with LiteLLM: {e}")
            # Fall back to standard completion if instructor fails
            logger.warning("Falling back to standard LiteLLM completion without structured output")
            return None

    async def _handle_standard_ollama(
        self, user_content: str, ollama_image_data: List[str], request: CompletionRequest
    ) -> CompletionResponse:
        """Handle standard (non-structured) output generation with Ollama."""
        logger.debug(f"Using direct Ollama client for model: {self.ollama_base_model_name}")
        client = ollama.AsyncClient(host=self.ollama_api_base)

        # Construct Ollama messages
        system_message = {"role": "system", "content": get_system_message()["content"]}
        user_message_data = {"role": "user", "content": user_content}

        # Add images directly to the user message if available
        if ollama_image_data:
            if len(ollama_image_data) > 1:
                logger.warning(
                    f"Ollama model {self.model_config['model_name']} only supports one image per message. "
                    "Using the first image and ignoring others."
                )
            # Add 'images' key inside the user message dictionary
            user_message_data["images"] = [ollama_image_data[0]]

        ollama_messages = [system_message, user_message_data]

        # Construct Ollama options
        options = {
            "temperature": request.temperature,
            "num_predict": (
                request.max_tokens if request.max_tokens is not None else -1
            ),  # Default to model's default if None
        }

        try:
            response = await client.chat(model=self.ollama_base_model_name, messages=ollama_messages, options=options)

            # Map Ollama response to CompletionResponse
            prompt_tokens = response.get("prompt_eval_count", 0)
            completion_tokens = response.get("eval_count", 0)

            return CompletionResponse(
                completion=response["message"]["content"],
                usage={
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": completion_tokens,
                    "total_tokens": prompt_tokens + completion_tokens,
                },
                finish_reason=response.get("done_reason", "unknown"),  # Map done_reason if available
            )
        except Exception as e:
            logger.error(f"Error during direct Ollama call: {e}")
            raise

    async def _handle_standard_litellm(
        self, user_content: str, image_urls: List[str], request: CompletionRequest
    ) -> CompletionResponse:
        """Handle standard (non-structured) output generation with LiteLLM."""
        logger.debug(f"Using LiteLLM for model: {self.model_config['model_name']}")

        # Build messages for LiteLLM
        content_list = [{"type": "text", "text": user_content}]
        include_images = image_urls  # Use the collected full data URIs

        if include_images:
            NUM_IMAGES = min(3, len(image_urls))
            for img_url in image_urls[:NUM_IMAGES]:
                content_list.append({"type": "image_url", "image_url": {"url": img_url}})

        # LiteLLM uses list content format
        user_message = {"role": "user", "content": content_list}

        # Use the system prompt defined earlier
        litellm_messages = [get_system_message(), user_message]

        # Prepare LiteLLM parameters
        model_params = {
            "model": self.model_config["model_name"],
            "messages": litellm_messages,
            "max_tokens": request.max_tokens,
            "temperature": request.temperature,
            "num_retries": 3,
        }

        for key, value in self.model_config.items():
            if key != "model_name":
                model_params[key] = value

        logger.debug(f"Calling LiteLLM with params: {model_params}")
        response = await litellm.acompletion(**model_params)

        return CompletionResponse(
            completion=response.choices[0].message.content,
            usage={
                "prompt_tokens": response.usage.prompt_tokens,
                "completion_tokens": response.usage.completion_tokens,
                "total_tokens": response.usage.total_tokens,
            },
            finish_reason=response.choices[0].finish_reason,
        )

    async def complete(self, request: CompletionRequest) -> CompletionResponse:
        """
        Generate completion using LiteLLM or direct Ollama client if configured.

        Args:
            request: CompletionRequest object containing query, context, and parameters

        Returns:
            CompletionResponse object with the generated text and usage statistics
        """
        # Process context chunks and handle images
        context_text, image_urls, ollama_image_data = process_context_chunks(request.context_chunks, self.is_ollama)

        # Format user content
        user_content = format_user_content(context_text, request.query, request.prompt_template)

        # Check if structured output is requested
        structured_output = request.schema is not None

        # If structured output is requested, use instructor to handle it
        if structured_output:
            # Get dynamic model from schema
            dynamic_model = create_dynamic_model_from_schema(request.schema)

            # If schema format is not recognized, log warning and fall back to text completion
            if not dynamic_model:
                logger.warning(f"Unrecognized schema format: {request.schema}. Falling back to text completion.")
                structured_output = False
            else:
                logger.info(f"Using structured output with model: {dynamic_model.__name__}")

                # Create system and user messages with enhanced instructions for structured output
                system_message = {
                    "role": "system",
                    "content": get_system_message()["content"]
                    + "\n\nYou MUST format your response according to the required schema.",
                }

                # Create enhanced user message that includes schema information
                enhanced_user_content = (
                    user_content + "\n\nPlease format your response according to the required schema."
                )

                # Try structured output based on model type
                if self.is_ollama:
                    response = await self._handle_structured_ollama(
                        dynamic_model, system_message, enhanced_user_content, ollama_image_data, request
                    )
                    if response:
                        return response
                    structured_output = False  # Fall back if structured output failed
                else:
                    response = await self._handle_structured_litellm(
                        dynamic_model, system_message, enhanced_user_content, image_urls, request
                    )
                    if response:
                        return response
                    structured_output = False  # Fall back if structured output failed

        # If we're here, either structured output wasn't requested or instructor failed.
        # Proceed with standard completion based on model type.
        if self.is_ollama:
            return await self._handle_standard_ollama(user_content, ollama_image_data, request)
        else:
            return await self._handle_standard_litellm(user_content, image_urls, request)
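

# A minimal usage sketch, kept behind a __main__ guard so importing the module stays
# side-effect free. The "openai_gpt4" model key is hypothetical and must exist in your
# registered_models config; CompletionRequest is assumed to accept the fields this module
# reads (query, context_chunks, prompt_template, schema, temperature, max_tokens), with the
# unused ones defaulting to None.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        # Instantiate against a registered model key (assumption: "openai_gpt4" is configured).
        model = LiteLLMCompletionModel(model_key="openai_gpt4")

        # Build a request with plain-text context; image chunks would be passed as data URIs.
        request = CompletionRequest(
            query="What does the agreement say about termination?",
            context_chunks=["Section 12: Either party may terminate with 30 days written notice."],
            temperature=0.2,
            max_tokens=256,
        )

        response = await model.complete(request)
        print(response.completion)
        print(response.usage)

    asyncio.run(_demo())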