import logging
from abc import ABC, abstractmethod
from typing import Any, Dict, Literal, Optional

import litellm
from pydantic import BaseModel

from core.config import get_settings

logger = logging.getLogger(__name__)
settings = get_settings()

class BaseRule(BaseModel, ABC):
    """Base model for all rules"""

    type: str
    stage: Literal["post_parsing", "post_chunking"]

    @abstractmethod
    async def apply(self, content: str, metadata: Optional[Dict[str, Any]] = None) -> tuple[Dict[str, Any], str]:
        """
        Apply the rule to the content.

        Args:
            content: The content to apply the rule to
            metadata: Optional existing metadata that may be used or modified by the rule

        Returns:
            tuple[Dict[str, Any], str]: (metadata, modified_content)
        """
        pass
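
# Illustrative only (not part of this module): a minimal concrete rule showing the
# BaseRule contract. The hypothetical LowercaseRule demonstrates that apply() returns
# a (metadata, modified_content) tuple and may leave either part unchanged.
#
#     class LowercaseRule(BaseRule):
#         type: Literal["lowercase"] = "lowercase"
#
#         async def apply(self, content, metadata=None):
#             # No metadata extracted; content is transformed in place
#             return {}, content.lower()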

class MetadataOutput(BaseModel):
    """Model for metadata extraction results"""

    # This model will be dynamically extended based on the schema

class MetadataExtractionRule(BaseRule):
    """Rule for extracting metadata using a schema"""

    type: Literal["metadata_extraction"]
    schema: Dict[str, Any]
    use_images: bool = False

    async def apply(self, content: str, metadata: Optional[Dict[str, Any]] = None) -> tuple[Dict[str, Any], str]:
        """Extract metadata according to schema"""
        import instructor
        from pydantic import create_model

        # Create a dynamic Pydantic model based on the schema
        # This allows instructor to validate the output against our schema
        field_definitions = {}
        for field_name, field_info in self.schema.items():
            if isinstance(field_info, dict) and "type" in field_info:
                field_type = field_info.get("type")

                # Convert schema types to Python types
                if field_type == "string":
                    field_definitions[field_name] = (str, None)
                elif field_type == "number":
                    field_definitions[field_name] = (float, None)
                elif field_type == "integer":
                    field_definitions[field_name] = (int, None)
                elif field_type == "boolean":
                    field_definitions[field_name] = (bool, None)
                elif field_type == "array":
                    field_definitions[field_name] = (list, None)
                elif field_type == "object":
                    field_definitions[field_name] = (dict, None)
                else:
                    # Default to Any for unknown types
                    field_definitions[field_name] = (Any, None)
            else:
                # Default to Any if no type specified
                field_definitions[field_name] = (Any, None)

        # Create the dynamic model
        DynamicMetadataModel = create_model("DynamicMetadataModel", **field_definitions)
        # Create a more explicit instruction that clearly shows the expected output format
        schema_descriptions = []
        for field_name, field_config in self.schema.items():
            field_type = field_config.get("type", "string") if isinstance(field_config, dict) else "string"
            description = (
                field_config.get("description", "No description") if isinstance(field_config, dict) else field_config
            )
            schema_descriptions.append(f"- {field_name}: {description} (type: {field_type})")

        schema_text = "\n".join(schema_descriptions)

        # Adjust prompt based on whether it's a chunk or full document and whether it's an image
        if self.use_images:
            prompt_context = "image" if self.stage == "post_chunking" else "document with images"
        else:
            prompt_context = "chunk of text" if self.stage == "post_chunking" else "text"

        prompt = f"""
        Extract metadata from the following {prompt_context} according to this schema:
        {schema_text}

        {"Image to analyze:" if self.use_images else "Text to extract from:"}
        {content}

        Follow these guidelines:
        1. Extract all requested information as simple strings, numbers, or booleans
           (not as objects or nested structures)
        2. If information is not present, indicate this with null instead of making something up
        3. Answer directly with the requested information - don't include explanations or reasoning
        4. Be concise but accurate in your extractions
        """
        # Get the model configuration from registered_models
        model_config = settings.REGISTERED_MODELS.get(settings.RULES_MODEL, {})
        if not model_config:
            raise ValueError(f"Model '{settings.RULES_MODEL}' not found in registered_models configuration")
        # Prepare base64 data for vision model if this is an image rule
        vision_messages = []
        if self.use_images:
            try:
                # Image content may arrive as a raw base64 string or as a data URI such as
                # "data:image/png;base64,...". Strip the prefix and keep the declared media
                # type when one is provided, defaulting to JPEG otherwise.
                media_type = "image/jpeg"
                if content.startswith("data:"):
                    prefix, content = content.split(";base64,", 1)
                    media_type = prefix.removeprefix("data:") or media_type

                # User message with image content
                vision_messages = [
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": prompt},
                            {"type": "image_url", "image_url": {"url": f"data:{media_type};base64,{content}"}},
                        ],
                    }
                ]
            except Exception as e:
                logger.error(f"Error preparing image content for vision model: {str(e)}")
                # Fall back to text-only if image processing fails
                vision_messages = []
        system_message = {
            "role": "system",
            "content": (
                "You are a metadata extraction assistant. Extract structured metadata "
                f"from {'images' if self.use_images else 'text'} "
                "precisely following the provided schema. Always return the metadata as direct values "
                "(strings, numbers, booleans), not as objects with additional properties."
            ),
        }
        # If we have vision messages, use those, otherwise use a standard text message
        if vision_messages and self.use_images:
            messages = [system_message] + vision_messages
        else:
            user_message = {"role": "user", "content": prompt}
            messages = [system_message, user_message]

        # Use instructor with litellm to get structured responses
        client = instructor.from_litellm(litellm.acompletion, mode=instructor.Mode.JSON)

        try:
            # Extract the model name for the instructor call
            model = model_config.get("model_name")

            # Prepare additional parameters from the model config
            model_kwargs = {k: v for k, v in model_config.items() if k != "model_name"}

            # Use instructor's client to create a structured response
            response = await client.chat.completions.create(
                model=model,
                messages=messages,
                response_model=DynamicMetadataModel,
                **model_kwargs,
            )

            # Convert pydantic model to dict
            extracted_metadata = response.model_dump()
        except Exception as e:
            logger.error(f"Error in instructor metadata extraction: {str(e)}")
            extracted_metadata = {}

        # Metadata extraction doesn't modify content
        return extracted_metadata, content
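
# Illustrative usage sketch (assumptions: RULES_MODEL is configured in settings and
# `document_text` holds parsed document content; neither is defined in this module):
#
#     rule = MetadataExtractionRule(
#         type="metadata_extraction",
#         stage="post_parsing",
#         schema={
#             "title": {"type": "string", "description": "Document title"},
#             "page_count": {"type": "integer", "description": "Number of pages"},
#         },
#     )
#     extracted, unchanged_content = await rule.apply(document_text)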

class TransformationOutput(BaseModel):
    """Model for text transformation results"""

    transformed_text: str

class NaturalLanguageRule(BaseRule):
    """Rule for transforming content using natural language"""

    type: Literal["natural_language"]
    prompt: str

    async def apply(self, content: str, metadata: Optional[Dict[str, Any]] = None) -> tuple[Dict[str, Any], str]:
        """Transform content according to prompt"""
        import instructor

        # Adjust prompt based on whether it's a chunk or full document
        prompt_context = "chunk of text" if self.stage == "post_chunking" else "text"

        prompt = f"""
        Your task is to transform the following {prompt_context} according to this instruction:
        {self.prompt}

        Text to transform:
        {content}

        Perform the transformation and return only the transformed text.
        """
        # Get the model configuration from registered_models
        model_config = settings.REGISTERED_MODELS.get(settings.RULES_MODEL, {})
        if not model_config:
            raise ValueError(f"Model '{settings.RULES_MODEL}' not found in registered_models configuration")

        system_message = {
            "role": "system",
            "content": (
                "You are a text transformation assistant. Transform text precisely following "
                "the provided instructions."
            ),
        }
        user_message = {"role": "user", "content": prompt}

        # Use instructor with litellm to get structured responses
        client = instructor.from_litellm(litellm.acompletion, mode=instructor.Mode.JSON)

        try:
            # Extract the model name for the instructor call
            model = model_config.get("model_name")

            # Prepare additional parameters from the model config
            model_kwargs = {k: v for k, v in model_config.items() if k != "model_name"}

            # Use instructor's client to create a structured response
            response = await client.chat.completions.create(
                model=model,
                messages=[system_message, user_message],
                response_model=TransformationOutput,
                **model_kwargs,
            )

            # Extract the transformed text from the response model
            transformed_text = response.transformed_text
        except Exception as e:
            logger.error(f"Error in instructor text transformation: {str(e)}")
            transformed_text = content  # Return original content on error

        # Natural language rules modify content, they don't add metadata directly
        return {}, transformed_text
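
# Illustrative usage sketch (assumptions: RULES_MODEL is configured and `chunk_text`
# holds a chunk produced earlier in the pipeline; neither is defined in this module):
#
#     rule = NaturalLanguageRule(
#         type="natural_language",
#         stage="post_chunking",
#         prompt="Rewrite the text to remove any personally identifiable information",
#     )
#     _, redacted_chunk = await rule.apply(chunk_text)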