morphik-core/core/services/entity_resolution.py

270 lines
11 KiB
Python
Raw Normal View History

2025-03-29 19:25:01 -07:00
import json
import logging
from typing import Any, Dict, List, Optional, Tuple

from pydantic import BaseModel, ConfigDict

from core.config import get_settings
from core.models.graph import Entity
from core.models.prompts import EntityResolutionPromptOverride
# Logger named after this module (``__name__``).
logger = logging.getLogger(name=__name__)
# Pydantic models for the structured LLM output. They are passed to instructor as
# ``response_model`` in EntityResolver._resolve_with_llm, so they deliberately carry
# no docstrings: pydantic would copy a docstring into the generated JSON schema.
class EntityGroup(BaseModel):
    # Preferred (canonical) name chosen for the group.
    canonical: str
    # Every spelling grouped under ``canonical``; the prompt asks for the canonical form to be included.
    variants: List[str]
    # Reject any keys other than the two declared fields.
    model_config = ConfigDict(extra="forbid")
class EntityResolutionResult(BaseModel):
    # Top-level response object; per the resolution prompt, only groups of two or
    # more synonymous labels are expected here.
    entity_groups: List[EntityGroup]
    # Reject any keys other than ``entity_groups``.
    model_config = ConfigDict(extra="forbid")
class EntityResolver:
    """
    Resolves and normalizes entities by identifying different variants of the same entity.

    Handles cases like "Trump" vs "Donald J Trump" or "JFK" vs "Kennedy".
    """

    def __init__(self):
        """Initialize the entity resolver"""
        self.settings = get_settings()

    async def resolve_entities(
        self,
        entities: List[Entity],
        prompt_overrides: Optional[EntityResolutionPromptOverride] = None,
    ) -> Tuple[List[Entity], Dict[str, str]]:
        """
        Resolves entities by identifying and grouping entities that refer to the same real-world entity.

        Entities are merged IN PLACE: for every canonical label the first input entity that
        maps to it is kept, its label is rewritten to the canonical form, and the document
        IDs, chunk sources and missing properties of later duplicates are merged into it.

        Args:
            entities: List of extracted entities
            prompt_overrides: Optional EntityResolutionPromptOverride with customizations for entity resolution

        Returns:
            Tuple containing:
            - List of normalized entities
            - Dictionary mapping original entity text to canonical entity text. Only labels
              the LLM grouped are included (plus the identity entry when a single entity
              is passed); labels absent from it map to themselves.
        """
        if not entities:
            return [], {}

        # A single entity has nothing to be merged with, so skip the LLM round-trip.
        if len(entities) == 1:
            only_label = entities[0].label
            return entities, {only_label: only_label}

        # Several extracted entities often share a label; send each distinct label to the
        # LLM only once (first-seen order) so the prompt does not repeat itself.
        unique_labels = list(dict.fromkeys(entity.label for entity in entities))

        # Prompt customisations (prompt_template / examples, ...) are forwarded as keyword args.
        er_overrides = prompt_overrides.model_dump(exclude_none=True) if prompt_overrides else {}

        resolved_groups = await self._resolve_with_llm(unique_labels, **er_overrides)
        entity_mapping = self._build_entity_mapping(resolved_groups, unique_labels)
        unique_entities = self._apply_entity_mapping(entities, entity_mapping)
        return unique_entities, entity_mapping

    @staticmethod
    def _build_entity_mapping(resolved_groups: List[Dict[str, Any]], known_labels: List[str]) -> Dict[str, str]:
        """
        Map each grouped variant to its group's canonical label.

        Variants that are not among ``known_labels`` are dropped: they cannot match any
        input entity and would otherwise leak into the returned mapping and the
        ``aliases`` property as labels that never occurred in the input.
        """
        known = set(known_labels)
        mapping: Dict[str, str] = {}
        for group in resolved_groups:
            canonical = group["canonical"]
            for variant in group["variants"]:
                if variant in known:
                    mapping[variant] = canonical
        return mapping

    @classmethod
    def _apply_entity_mapping(cls, entities: List[Entity], entity_mapping: Dict[str, str]) -> List[Entity]:
        """
        Collapse entities onto their canonical labels, merging duplicates in place.

        The first entity seen for each canonical label survives and absorbs the later ones.
        Each survivor's label is set to the canonical form, and any other variants mapped to
        it are added to an order-preserving, de-duplicated ``aliases`` property.

        Returns:
            The surviving entities, in first-seen order of their canonical labels.
        """
        survivors: Dict[str, Entity] = {}
        for entity in entities:
            canonical_label = entity_mapping.get(entity.label, entity.label)
            if canonical_label in survivors:
                cls._merge_entity_into(survivors[canonical_label], entity)
            else:
                survivors[canonical_label] = entity

        # Invert the mapping once (canonical -> variants) instead of rescanning it per entity.
        variants_by_canonical: Dict[str, List[str]] = {}
        for variant, canonical in entity_mapping.items():
            if variant != canonical:
                variants_by_canonical.setdefault(canonical, []).append(variant)

        for canonical_label, entity in survivors.items():
            entity.label = canonical_label
            variants = variants_by_canonical.get(canonical_label)
            if variants:
                # Build a fresh list (the existing one may be shared with a merged-away
                # entity) and de-duplicate with dict.fromkeys, which keeps first-seen
                # order; a set would make the alias order nondeterministic.
                existing_aliases = entity.properties.get("aliases", [])
                entity.properties["aliases"] = list(dict.fromkeys([*existing_aliases, *variants]))

        return list(survivors.values())

    @staticmethod
    def _merge_entity_into(target: Entity, source: Entity) -> None:
        """Merge ``source``'s document IDs, chunk sources and missing properties into ``target``."""
        for doc_id in source.document_ids:
            if doc_id not in target.document_ids:
                target.document_ids.append(doc_id)

        for doc_id, chunk_numbers in source.chunk_sources.items():
            if doc_id not in target.chunk_sources:
                target.chunk_sources[doc_id] = list(chunk_numbers)
                continue
            existing_chunks = target.chunk_sources[doc_id]
            for chunk_num in chunk_numbers:
                if chunk_num not in existing_chunks:
                    existing_chunks.append(chunk_num)

        # Properties already present on the target win; only missing keys are copied over.
        for key, value in source.properties.items():
            target.properties.setdefault(key, value)

    @staticmethod
    def _identity_groups(entity_labels: List[str]) -> List[Dict[str, Any]]:
        """Fallback result: treat every label as its own single-member group."""
        return [{"canonical": label, "variants": [label]} for label in entity_labels]

    async def _resolve_with_llm(
        self, entity_labels: List[str], prompt_template=None, examples=None, **options
    ) -> List[Dict[str, Any]]:
        """
        Uses LLM to identify and group similar entities using litellm.

        If the configured graph model is not registered, or the LLM call raises, the error
        is logged and every label is returned as its own group.

        Args:
            entity_labels: List of entity labels to resolve
            prompt_template: Optional custom prompt template
            examples: Optional custom examples for entity resolution
            **options: Additional override keys. They are not used here; presumably they are
                accepted so that every field of the dumped prompt-override model can be
                passed through by resolve_entities without a TypeError.

        Returns:
            List of entity groups, where each group is a dict with:
            - "canonical": The canonical form of the entity
            - "variants": List of variant forms of the entity
        """
        # Import these here to avoid circular imports
        import instructor
        import litellm

        prompt = self._create_entity_resolution_prompt(
            entity_labels, prompt_template=prompt_template, examples=examples
        )

        # Get the model configuration from registered_models
        model_config = self.settings.REGISTERED_MODELS.get(self.settings.GRAPH_MODEL, {})
        if not model_config:
            logger.error("Model '%s' not found in registered_models configuration", self.settings.GRAPH_MODEL)
            return self._identity_groups(entity_labels)

        system_message = {
            "role": "system",
            "content": "You are an entity resolution expert. Your task is to identify and group different representations of the same real-world entity.",
        }
        user_message = {"role": "user", "content": prompt}

        try:
            # Use instructor with litellm for structured output
            client = instructor.from_litellm(litellm.acompletion, mode=instructor.Mode.JSON)
            # Everything except the model name is forwarded to litellm as call kwargs.
            model = model_config.get("model_name")
            model_kwargs = {k: v for k, v in model_config.items() if k != "model_name"}
            response = await client.chat.completions.create(
                model=model,
                messages=[system_message, user_message],
                response_model=EntityResolutionResult,
                **model_kwargs,
            )
            return [group.model_dump() for group in response.entity_groups]
        except Exception as e:
            # Best effort: keep the traceback in the log, then treat each entity as unique.
            logger.exception("Error during entity resolution with litellm: %s", e)
            return self._identity_groups(entity_labels)

    def _create_entity_resolution_prompt(self, entity_labels: List[str], prompt_template=None, examples=None) -> str:
        """
        Creates a prompt for the LLM to resolve entities.

        Args:
            entity_labels: List of entity labels to resolve
            prompt_template: Optional custom prompt template. It is rendered with
                ``str.format(entities_str=..., examples_json=...)``, so any other literal
                braces in it must be doubled.
            examples: Optional custom examples for entity resolution: a list of dicts, or a
                list of objects exposing ``model_dump()`` (checked on the first element)

        Returns:
            Prompt string for the LLM
        """
        entities_str = "\n".join(f"- {label}" for label in entity_labels)

        if examples is not None:
            # Ensure proper serialization for both dict and Pydantic model examples
            if isinstance(examples, list) and examples and hasattr(examples[0], "model_dump"):
                serialized_examples = [example.model_dump() for example in examples]
            else:
                serialized_examples = examples
            entities_example_dict = {"entity_groups": serialized_examples}
        else:
            entities_example_dict = {
                "entity_groups": [
                    {
                        "canonical": "John F. Kennedy",
                        "variants": ["John F. Kennedy", "JFK", "Kennedy"],
                    },
                    {
                        "canonical": "United States of America",
                        "variants": ["United States of America", "USA", "United States"],
                    },
                ]
            }

        # The example is shown to the model as the required output shape, so it must be real
        # JSON (double-quoted), not Python's repr. default=str is a deliberate safety net so a
        # non-JSON-serializable custom example degrades to text instead of aborting resolution.
        examples_json = json.dumps(entities_example_dict, indent=2, ensure_ascii=False, default=str)

        # If a custom template is provided, use it
        if prompt_template:
            return prompt_template.format(entities_str=entities_str, examples_json=examples_json)

        # Otherwise use the default prompt
        prompt = f"""
Below is a list of entities extracted from a document:
{entities_str}
Your task is to identify entities in this list that refer to the EXACT SAME real-world concept or object, differing ONLY by name, abbreviation, or minor spelling variations. Group these synonymous entities together.
**CRITICAL RULES:**
1. **Synonymy ONLY:** Only group entities if they are truly synonymous (e.g., "JFK", "John F. Kennedy", "Kennedy").
2. **DO NOT Group Related Concepts:** DO NOT group distinct entities even if they are related. For example:
* A company and its products (e.g., "Apple" and "iPhone" must remain separate).
* An organization and its specific projects or vehicles (e.g., "SpaceX", "Falcon 9", and "Starship" must remain separate).
* A person and their title (e.g. "Elon Musk" and "CEO" must remain separate unless the list only contained variations like "CEO Musk").
3. **Canonical Form:** For each group of synonyms, choose the most complete and formal name as the "canonical" form.
4. **Omit Unique Entities:** If an entity has no synonyms in the provided list, DO NOT include it in the output JSON. The output should only contain groups of two or more synonymous entities.
**Output Format:**
Format your response as a JSON object containing a single key "entity_groups", which is an array. Each element in the array represents a group of synonyms and must have:
- "canonical": The chosen standard form.
- "variants": An array of all synonymous variants found in the input list (including the canonical form).
```json
{examples_json}
```
"""
        return prompt