2025-03-29 19:25:01 -07:00
|
|
|
import logging
|
2025-04-20 16:34:29 -07:00
|
|
|
from typing import Any, Dict, List, Optional, Tuple
|
|
|
|
|
2025-03-29 19:25:01 -07:00
|
|
|
from pydantic import BaseModel, ConfigDict
|
|
|
|
|
|
|
|
from core.config import get_settings
|
|
|
|
from core.models.graph import Entity
|
2025-04-20 16:34:29 -07:00
|
|
|
from core.models.prompts import EntityResolutionPromptOverride
|
2025-03-29 19:25:01 -07:00
|
|
|
|
|
|
|
# Module-level logger (standard ``logging.getLogger(__name__)`` idiom); this module uses it
# to report model-configuration problems and failed LLM calls during entity resolution.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
# Define Pydantic models for structured output
|
|
|
|
# One group of synonymous entity labels, as returned by the LLM.
# NOTE(review): documented with comments rather than a docstring on purpose — Pydantic turns a
# class docstring into the JSON-schema "description", and this model is passed to instructor as
# part of the response schema, so a docstring would presumably change what the LLM is shown.
class EntityGroup(BaseModel):
    # Reject keys beyond the declared fields when validating the LLM's response.
    model_config = ConfigDict(extra="forbid")

    # The chosen standard form of the entity.
    canonical: str
    # Every label in the group (the resolution prompt asks for the canonical form to be included).
    variants: List[str]
|
|
|
|
|
|
|
|
|
|
|
|
# Top-level structured response expected from the LLM; used as instructor's ``response_model``.
# NOTE(review): no docstring on purpose — Pydantic would copy it into the JSON-schema
# "description", which presumably ends up in what the LLM is shown; verify before adding one.
class EntityResolutionResult(BaseModel):
    # Reject keys beyond the declared fields when validating the LLM's response.
    model_config = ConfigDict(extra="forbid")

    # Groups of synonymous labels; labels with no synonyms are expected to be omitted.
    entity_groups: List[EntityGroup]
|
|
|
|
|
|
|
|
|
|
|
|
class EntityResolver:
    """
    Resolves and normalizes entities by identifying different variants of the same entity.

    Handles cases like "Trump" vs "Donald J Trump" or "JFK" vs "Kennedy".

    Grouping of synonymous labels is delegated to the LLM registered under
    ``settings.GRAPH_MODEL``; merging of the grouped ``Entity`` objects is done locally.
    """

    def __init__(self):
        """Initialize the entity resolver"""
        self.settings = get_settings()

    async def resolve_entities(
        self,
        entities: List[Entity],
        prompt_overrides: Optional[EntityResolutionPromptOverride] = None,
    ) -> Tuple[List[Entity], Dict[str, str]]:
        """
        Resolves entities by identifying and grouping entities that refer to the same real-world entity.

        Entities that map to the same canonical label are merged into the first such entity
        (document IDs, chunk sources and missing properties are combined). That entity's label is
        set to the canonical form, and the other variant labels are recorded in
        ``properties["aliases"]``. The input ``Entity`` objects are mutated in place.

        Args:
            entities: List of extracted entities
            prompt_overrides: Optional EntityResolutionPromptOverride with customizations for entity resolution

        Returns:
            Tuple containing:
            - List of normalized entities, ordered by first occurrence of their canonical label
            - Dictionary mapping original entity text to canonical entity text. Labels that do
              not appear in the mapping are their own canonical form.
        """
        if not entities:
            return [], {}

        # A single entity has nothing to be grouped with.
        if len(entities) == 1:
            only_label = entities[0].label
            return entities, {only_label: only_label}

        # Send each distinct label to the LLM only once (order-preserving). Entities that share a
        # label are still merged below, because they share the same canonical key.
        unique_labels = list(dict.fromkeys(entity.label for entity in entities))

        groups: List[Dict[str, Any]] = []
        if len(unique_labels) > 1:
            # Convert EntityResolutionPromptOverride to keyword arguments for the LLM request
            er_overrides = prompt_overrides.model_dump(exclude_none=True) if prompt_overrides else {}
            groups = await self._resolve_with_llm(unique_labels, **er_overrides)

        entity_mapping = self._build_entity_mapping(groups)
        return self._merge_entities(entities, entity_mapping), entity_mapping

    @staticmethod
    def _build_entity_mapping(groups: List[Dict[str, Any]]) -> Dict[str, str]:
        """Map every variant label in ``groups`` to its group's canonical label (later groups win)."""
        entity_mapping: Dict[str, str] = {}
        for group in groups:
            canonical = group["canonical"]
            for variant in group["variants"]:
                entity_mapping[variant] = canonical
        return entity_mapping

    @staticmethod
    def _merge_entities(entities: List[Entity], entity_mapping: Dict[str, str]) -> List[Entity]:
        """
        Collapse entities that share a canonical label into one entity per canonical label.

        The first entity seen for a canonical label is kept and mutated. Document IDs and chunk
        numbers from later entities are appended if not already present, and their properties are
        copied only for keys the kept entity lacks. The kept entity's label is then set to the
        canonical label, and non-canonical variants from ``entity_mapping`` are merged into its
        ``"aliases"`` property (deduplicated, order preserved).

        Args:
            entities: Entities to merge; mutated in place.
            entity_mapping: Variant label -> canonical label; unmapped labels map to themselves.

        Returns:
            One entity per canonical label, in order of first occurrence.
        """
        label_to_entity_map: Dict[str, Entity] = {}

        for entity in entities:
            canonical_label = entity_mapping.get(entity.label, entity.label)
            existing_entity = label_to_entity_map.get(canonical_label)
            if existing_entity is None:
                label_to_entity_map[canonical_label] = entity
                continue

            # Merge document IDs
            for doc_id in entity.document_ids:
                if doc_id not in existing_entity.document_ids:
                    existing_entity.document_ids.append(doc_id)

            # Merge chunk sources (copy the list so the two entities never share it)
            for doc_id, chunk_numbers in entity.chunk_sources.items():
                if doc_id not in existing_entity.chunk_sources:
                    existing_entity.chunk_sources[doc_id] = list(chunk_numbers)
                    continue
                existing_chunks = existing_entity.chunk_sources[doc_id]
                for chunk_num in chunk_numbers:
                    if chunk_num not in existing_chunks:
                        existing_chunks.append(chunk_num)

            # Merge properties: values already on the kept entity take precedence
            for key, value in entity.properties.items():
                existing_entity.properties.setdefault(key, value)

        # Reverse index built once: canonical label -> its non-canonical variants (mapping order).
        aliases_by_canonical: Dict[str, List[str]] = {}
        for variant, canonical in entity_mapping.items():
            if variant != canonical:
                aliases_by_canonical.setdefault(canonical, []).append(variant)

        unique_entities: List[Entity] = []
        for canonical_label, entity in label_to_entity_map.items():
            # Update the entity label to the canonical form
            entity.label = canonical_label

            variants = aliases_by_canonical.get(canonical_label)
            if variants:
                existing_aliases = entity.properties.get("aliases")
                if existing_aliases is None:
                    existing_aliases = []
                elif isinstance(existing_aliases, str):
                    existing_aliases = [existing_aliases]
                # dict.fromkeys dedupes while keeping order (set() ordering is not deterministic)
                entity.properties["aliases"] = list(dict.fromkeys([*existing_aliases, *variants]))

            unique_entities.append(entity)

        return unique_entities

    @staticmethod
    def _identity_groups(entity_labels: List[str]) -> List[Dict[str, Any]]:
        """Fallback grouping: every label is its own single-member group."""
        return [{"canonical": label, "variants": [label]} for label in entity_labels]

    async def _resolve_with_llm(
        self, entity_labels: List[str], prompt_template=None, examples=None, **options
    ) -> List[Dict[str, Any]]:
        """
        Uses LLM to identify and group similar entities using litellm.

        The model is looked up in ``settings.REGISTERED_MODELS`` under ``settings.GRAPH_MODEL``;
        its ``model_name`` entry is the litellm model and every other entry is forwarded as a
        completion kwarg. The response is parsed into ``EntityResolutionResult`` via instructor
        in JSON mode. If the model is not configured or the call fails, every label is returned
        as its own group so resolution degrades to a no-op.

        Args:
            entity_labels: List of entity labels to resolve
            prompt_template: Optional custom prompt template
            examples: Optional custom examples for entity resolution
            **options: Additional override fields; accepted for compatibility but currently unused

        Returns:
            List of entity groups, where each group is a dict with:
            - "canonical": The canonical form of the entity
            - "variants": List of variant forms of the entity
        """
        # Import these here to avoid circular imports
        import instructor
        import litellm

        prompt = self._create_entity_resolution_prompt(
            entity_labels, prompt_template=prompt_template, examples=examples
        )

        # Get the model configuration from registered_models
        model_config = self.settings.REGISTERED_MODELS.get(self.settings.GRAPH_MODEL, {})
        if not model_config:
            logger.error(
                "Model '%s' not found in registered_models configuration", self.settings.GRAPH_MODEL
            )
            return self._identity_groups(entity_labels)

        model = model_config.get("model_name")
        if not model:
            logger.error(
                "Model '%s' has no 'model_name' in registered_models configuration",
                self.settings.GRAPH_MODEL,
            )
            return self._identity_groups(entity_labels)
        model_kwargs = {key: value for key, value in model_config.items() if key != "model_name"}

        system_message = {
            "role": "system",
            "content": "You are an entity resolution expert. Your task is to identify and group different representations of the same real-world entity.",
        }
        user_message = {"role": "user", "content": prompt}

        try:
            # Use instructor with litellm for structured output
            client = instructor.from_litellm(litellm.acompletion, mode=instructor.Mode.JSON)
            response = await client.chat.completions.create(
                model=model,
                messages=[system_message, user_message],
                response_model=EntityResolutionResult,
                **model_kwargs,
            )
        except Exception as e:
            # Best effort by design: a failed LLM call must not abort graph building.
            logger.exception("Error during entity resolution with litellm: %s", e)
            # Fallback: treat each entity as unique
            return self._identity_groups(entity_labels)

        # Extract entity groups from the Pydantic model
        return [group.model_dump() for group in response.entity_groups]

    def _create_entity_resolution_prompt(self, entity_labels: List[str], prompt_template=None, examples=None) -> str:
        """
        Creates a prompt for the LLM to resolve entities.

        Examples are embedded as JSON (``{"entity_groups": [...]}``). A custom ``prompt_template``
        is filled with ``str.format`` using the placeholders ``{entities_str}`` (one ``- label``
        line per entity) and ``{examples_json}``.

        Args:
            entity_labels: List of entity labels to resolve
            prompt_template: Optional custom prompt template
            examples: Optional custom examples for entity resolution (Pydantic objects with
                ``model_dump`` and/or plain dicts)

        Returns:
            Prompt string for the LLM
        """
        import json

        entities_str = "\n".join(f"- {label}" for label in entity_labels)

        # Use custom examples if provided, otherwise use defaults
        if examples is not None:
            if isinstance(examples, (list, tuple)):
                # Serialize each element on its own so mixed Pydantic/dict lists work.
                serialized_examples = [
                    example.model_dump() if hasattr(example, "model_dump") else example for example in examples
                ]
            else:
                serialized_examples = examples
            entities_example_dict = {"entity_groups": serialized_examples}
        else:
            entities_example_dict = {
                "entity_groups": [
                    {
                        "canonical": "John F. Kennedy",
                        "variants": ["John F. Kennedy", "JFK", "Kennedy"],
                    },
                    {
                        "canonical": "United States of America",
                        "variants": ["United States of America", "USA", "United States"],
                    },
                ]
            }

        # Render the examples as real JSON: the prompt labels them ```json and asks for JSON back,
        # and a Python repr (single quotes) is not valid JSON.
        try:
            examples_json = json.dumps(entities_example_dict, indent=2, ensure_ascii=False)
        except (TypeError, ValueError):
            # Custom examples holding values json cannot encode: fall back to the plain repr.
            examples_json = str(entities_example_dict)

        # If a custom template is provided, use it
        if prompt_template:
            return prompt_template.format(entities_str=entities_str, examples_json=examples_json)

        # Otherwise use the default prompt
        prompt = f"""
Below is a list of entities extracted from a document:

{entities_str}

Your task is to identify entities in this list that refer to the EXACT SAME real-world concept or object, differing ONLY by name, abbreviation, or minor spelling variations. Group these synonymous entities together.

**CRITICAL RULES:**
1. **Synonymy ONLY:** Only group entities if they are truly synonymous (e.g., "JFK", "John F. Kennedy", "Kennedy").
2. **DO NOT Group Related Concepts:** DO NOT group distinct entities even if they are related. For example:
    * A company and its products (e.g., "Apple" and "iPhone" must remain separate).
    * An organization and its specific projects or vehicles (e.g., "SpaceX", "Falcon 9", and "Starship" must remain separate).
    * A person and their title (e.g. "Elon Musk" and "CEO" must remain separate unless the list only contained variations like "CEO Musk").
3. **Canonical Form:** For each group of synonyms, choose the most complete and formal name as the "canonical" form.
4. **Omit Unique Entities:** If an entity has no synonyms in the provided list, DO NOT include it in the output JSON. The output should only contain groups of two or more synonymous entities.

**Output Format:**
Format your response as a JSON object containing a single key "entity_groups", which is an array. Each element in the array represents a group of synonyms and must have:
- "canonical": The chosen standard form.
- "variants": An array of all synonymous variants found in the input list (including the canonical form).

```json
{examples_json}
```
"""
        return prompt
|