mirror of
https://github.com/james-m-jordan/morphik-core.git
synced 2025-05-09 19:32:38 +00:00
Add entity resolution (#58)
This commit is contained in:
parent
9ce0507616
commit
2436a11f29
@ -67,6 +67,7 @@ class Settings(BaseSettings):
|
|||||||
# Graph configuration
|
# Graph configuration
|
||||||
GRAPH_PROVIDER: Literal["ollama", "openai"]
|
GRAPH_PROVIDER: Literal["ollama", "openai"]
|
||||||
GRAPH_MODEL: str
|
GRAPH_MODEL: str
|
||||||
|
ENABLE_ENTITY_RESOLUTION: bool = True
|
||||||
|
|
||||||
# Reranker configuration
|
# Reranker configuration
|
||||||
USE_RERANKING: bool
|
USE_RERANKING: bool
|
||||||
@ -301,6 +302,7 @@ def get_settings() -> Settings:
|
|||||||
graph_config = {
|
graph_config = {
|
||||||
"GRAPH_PROVIDER": config["graph"]["provider"],
|
"GRAPH_PROVIDER": config["graph"]["provider"],
|
||||||
"GRAPH_MODEL": config["graph"]["model_name"],
|
"GRAPH_MODEL": config["graph"]["model_name"],
|
||||||
|
"ENABLE_ENTITY_RESOLUTION": config["graph"].get("enable_entity_resolution", True),
|
||||||
}
|
}
|
||||||
|
|
||||||
# load telemetry config
|
# load telemetry config
|
||||||
|
344
core/services/entity_resolution.py
Normal file
344
core/services/entity_resolution.py
Normal file
@ -0,0 +1,344 @@
|
|||||||
|
import httpx
|
||||||
|
import logging
|
||||||
|
import json
|
||||||
|
from typing import List, Dict, Any, Tuple
|
||||||
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
|
from core.config import get_settings
|
||||||
|
from core.models.graph import Entity
|
||||||
|
from openai import AsyncOpenAI
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Define Pydantic models for structured output
|
||||||
|
class EntityGroup(BaseModel):
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
canonical: str
|
||||||
|
variants: List[str]
|
||||||
|
|
||||||
|
|
||||||
|
class EntityResolutionResult(BaseModel):
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
entity_groups: List[EntityGroup]
|
||||||
|
|
||||||
|
|
||||||
|
class EntityResolver:
|
||||||
|
"""
|
||||||
|
Resolves and normalizes entities by identifying different variants of the same entity.
|
||||||
|
Handles cases like "Trump" vs "Donald J Trump" or "JFK" vs "Kennedy".
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize the entity resolver"""
|
||||||
|
self.settings = get_settings()
|
||||||
|
|
||||||
|
async def resolve_entities(self, entities: List[Entity]) -> Tuple[List[Entity], Dict[str, str]]:
|
||||||
|
"""
|
||||||
|
Resolves entities by identifying and grouping entities that refer to the same real-world entity.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
entities: List of extracted entities
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple containing:
|
||||||
|
- List of normalized entities
|
||||||
|
- Dictionary mapping original entity text to canonical entity text
|
||||||
|
"""
|
||||||
|
if not entities:
|
||||||
|
return [], {}
|
||||||
|
|
||||||
|
# Extract entity labels to deduplicate
|
||||||
|
entity_labels = [e.label for e in entities]
|
||||||
|
|
||||||
|
# If there's only one entity, no need to resolve
|
||||||
|
if len(entity_labels) <= 1:
|
||||||
|
return entities, {entity_labels[0]: entity_labels[0]} if entity_labels else {}
|
||||||
|
|
||||||
|
# Use LLM to identify and group similar entities
|
||||||
|
resolved_entities = await self._resolve_with_llm(entity_labels)
|
||||||
|
|
||||||
|
# Create mapping from original to canonical forms
|
||||||
|
entity_mapping = {}
|
||||||
|
for group in resolved_entities:
|
||||||
|
canonical = group["canonical"]
|
||||||
|
for variant in group["variants"]:
|
||||||
|
entity_mapping[variant] = canonical
|
||||||
|
|
||||||
|
# Deduplicate entities based on mapping
|
||||||
|
unique_entities = []
|
||||||
|
label_to_entity_map = {}
|
||||||
|
|
||||||
|
# First, create a map from labels to entities
|
||||||
|
for entity in entities:
|
||||||
|
canonical_label = entity_mapping.get(entity.label, entity.label)
|
||||||
|
|
||||||
|
# If we haven't seen this canonical label yet, add the entity to our map
|
||||||
|
if canonical_label not in label_to_entity_map:
|
||||||
|
label_to_entity_map[canonical_label] = entity
|
||||||
|
else:
|
||||||
|
# If we have seen it, merge chunk sources
|
||||||
|
existing_entity = label_to_entity_map[canonical_label]
|
||||||
|
|
||||||
|
# Merge document IDs
|
||||||
|
for doc_id in entity.document_ids:
|
||||||
|
if doc_id not in existing_entity.document_ids:
|
||||||
|
existing_entity.document_ids.append(doc_id)
|
||||||
|
|
||||||
|
# Merge chunk sources
|
||||||
|
for doc_id, chunk_numbers in entity.chunk_sources.items():
|
||||||
|
if doc_id not in existing_entity.chunk_sources:
|
||||||
|
existing_entity.chunk_sources[doc_id] = list(chunk_numbers)
|
||||||
|
else:
|
||||||
|
for chunk_num in chunk_numbers:
|
||||||
|
if chunk_num not in existing_entity.chunk_sources[doc_id]:
|
||||||
|
existing_entity.chunk_sources[doc_id].append(chunk_num)
|
||||||
|
|
||||||
|
# Merge properties (optional)
|
||||||
|
for key, value in entity.properties.items():
|
||||||
|
if key not in existing_entity.properties:
|
||||||
|
existing_entity.properties[key] = value
|
||||||
|
|
||||||
|
# Now, update the labels and create unique entities list
|
||||||
|
for canonical_label, entity in label_to_entity_map.items():
|
||||||
|
# Update the entity label to the canonical form
|
||||||
|
entity.label = canonical_label
|
||||||
|
|
||||||
|
# Add aliases property if there are variants
|
||||||
|
variants = [variant for variant, canon in entity_mapping.items()
|
||||||
|
if canon == canonical_label and variant != canonical_label]
|
||||||
|
if variants:
|
||||||
|
if 'aliases' not in entity.properties:
|
||||||
|
entity.properties['aliases'] = []
|
||||||
|
entity.properties['aliases'].extend(variants)
|
||||||
|
# Deduplicate aliases
|
||||||
|
entity.properties['aliases'] = list(set(entity.properties['aliases']))
|
||||||
|
|
||||||
|
unique_entities.append(entity)
|
||||||
|
|
||||||
|
return unique_entities, entity_mapping
|
||||||
|
|
||||||
|
async def _resolve_with_llm(self, entity_labels: List[str]) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Uses LLM to identify and group similar entities.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
entity_labels: List of entity labels to resolve
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of entity groups, where each group is a dict with:
|
||||||
|
- "canonical": The canonical form of the entity
|
||||||
|
- "variants": List of variant forms of the entity
|
||||||
|
"""
|
||||||
|
prompt = self._create_entity_resolution_prompt(entity_labels)
|
||||||
|
|
||||||
|
if self.settings.GRAPH_PROVIDER == "openai":
|
||||||
|
return await self._resolve_with_openai(prompt, entity_labels)
|
||||||
|
elif self.settings.GRAPH_PROVIDER == "ollama":
|
||||||
|
return await self._resolve_with_ollama(prompt, entity_labels)
|
||||||
|
else:
|
||||||
|
logger.error(f"Unsupported graph provider: {self.settings.GRAPH_PROVIDER}")
|
||||||
|
# Fallback: treat each entity as unique
|
||||||
|
return [{"canonical": label, "variants": [label]} for label in entity_labels]
|
||||||
|
|
||||||
|
async def _resolve_with_openai(self, prompt: str, entity_labels: List[str]) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Resolves entities using OpenAI.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: Prompt for entity resolution
|
||||||
|
entity_labels: Original entity labels
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of entity groups
|
||||||
|
"""
|
||||||
|
# Define the schema directly instead of converting from Pydantic model
|
||||||
|
json_schema = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"entity_groups": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"canonical": {"type": "string"},
|
||||||
|
"variants": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["canonical", "variants"],
|
||||||
|
"additionalProperties": False
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["entity_groups"],
|
||||||
|
"additionalProperties": False
|
||||||
|
}
|
||||||
|
|
||||||
|
# System and user messages
|
||||||
|
system_message = {
|
||||||
|
"role": "system",
|
||||||
|
"content": "You are an entity resolution expert. Your task is to identify and group different representations of the same real-world entity."
|
||||||
|
}
|
||||||
|
|
||||||
|
user_message = {
|
||||||
|
"role": "user",
|
||||||
|
"content": prompt
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
client = AsyncOpenAI(api_key=self.settings.OPENAI_API_KEY)
|
||||||
|
|
||||||
|
response = await client.responses.create(
|
||||||
|
model=self.settings.GRAPH_MODEL,
|
||||||
|
input=[system_message, user_message],
|
||||||
|
text={
|
||||||
|
"format": {
|
||||||
|
"type": "json_schema",
|
||||||
|
"name": "entity_resolution",
|
||||||
|
"schema": json_schema,
|
||||||
|
"strict": True
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process response based on current API structure
|
||||||
|
if hasattr(response, 'output_text') and response.output_text:
|
||||||
|
content = response.output_text
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Parse with Pydantic model for validation
|
||||||
|
parsed_data = json.loads(content)
|
||||||
|
|
||||||
|
# Validate with Pydantic model
|
||||||
|
resolution_result = EntityResolutionResult.model_validate(parsed_data)
|
||||||
|
|
||||||
|
# Extract entity groups
|
||||||
|
return [group.model_dump() for group in resolution_result.entity_groups]
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error parsing entity resolution response: %r", e)
|
||||||
|
|
||||||
|
# Fallback to direct JSON parsing if Pydantic validation fails
|
||||||
|
try:
|
||||||
|
parsed_data = json.loads(content)
|
||||||
|
return parsed_data.get("entity_groups", [])
|
||||||
|
except Exception:
|
||||||
|
logger.error("JSON parsing failed, falling back to default resolution")
|
||||||
|
|
||||||
|
elif hasattr(response, 'refusal') and response.refusal:
|
||||||
|
logger.warning("OpenAI refused to resolve entities")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error during entity resolution with OpenAI: %r", e)
|
||||||
|
|
||||||
|
# Fallback: treat each entity as unique
|
||||||
|
logger.info("Falling back to treating each entity as unique")
|
||||||
|
return [{"canonical": label, "variants": [label]} for label in entity_labels]
|
||||||
|
|
||||||
|
async def _resolve_with_ollama(self, prompt: str, entity_labels: List[str]) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Resolves entities using Ollama.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: Prompt for entity resolution
|
||||||
|
entity_labels: Original entity labels
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of entity groups
|
||||||
|
"""
|
||||||
|
system_message = {
|
||||||
|
"role": "system",
|
||||||
|
"content": "You are an entity resolution expert. Your task is to identify and group different representations of the same real-world entity."
|
||||||
|
}
|
||||||
|
|
||||||
|
user_message = {
|
||||||
|
"role": "user",
|
||||||
|
"content": prompt
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||||
|
# Create the schema for structured output
|
||||||
|
format_schema = EntityResolutionResult.model_json_schema()
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
f"{self.settings.EMBEDDING_OLLAMA_BASE_URL}/api/chat",
|
||||||
|
json={
|
||||||
|
"model": self.settings.GRAPH_MODEL,
|
||||||
|
"messages": [system_message, user_message],
|
||||||
|
"stream": False,
|
||||||
|
"format": format_schema
|
||||||
|
},
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
result = response.json()
|
||||||
|
|
||||||
|
# Parse with Pydantic model for validation
|
||||||
|
resolution_result = EntityResolutionResult.model_validate_json(result["message"]["content"])
|
||||||
|
return [group.model_dump() for group in resolution_result.entity_groups]
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error during entity resolution with Ollama: %r", e)
|
||||||
|
|
||||||
|
# Fallback: treat each entity as unique
|
||||||
|
return [{"canonical": label, "variants": [label]} for label in entity_labels]
|
||||||
|
|
||||||
|
def _create_entity_resolution_prompt(self, entity_labels: List[str]) -> str:
|
||||||
|
"""
|
||||||
|
Creates a prompt for the LLM to resolve entities.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
entity_labels: List of entity labels to resolve
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Prompt string for the LLM
|
||||||
|
"""
|
||||||
|
entities_str = "\n".join([f"- {label}" for label in entity_labels])
|
||||||
|
entities_example_dict = {
|
||||||
|
"entity_groups": [
|
||||||
|
{
|
||||||
|
"canonical": "John F. Kennedy",
|
||||||
|
"variants": ["John F. Kennedy", "JFK", "Kennedy"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"canonical": "United States of America",
|
||||||
|
"variants": ["United States of America", "USA", "United States"]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt = f"""
|
||||||
|
Below is a list of entities extracted from a document:
|
||||||
|
|
||||||
|
{entities_str}
|
||||||
|
|
||||||
|
Some of these entities may refer to the same real-world entity but with different names or spellings.
|
||||||
|
For example, "JFK" and "John F. Kennedy" refer to the same person.
|
||||||
|
|
||||||
|
Please analyze this list and group entities that refer to the same real-world entity.
|
||||||
|
For each group, provide:
|
||||||
|
1. A canonical (standard) form of the entity
|
||||||
|
2. All variant forms found in the list
|
||||||
|
|
||||||
|
Format your response as a JSON object with an "entity_groups" array, where each item in the array is an object with:
|
||||||
|
- "canonical": The canonical form (choose the most complete and formal name)
|
||||||
|
- "variants": Array of all variants (including the canonical form)
|
||||||
|
|
||||||
|
The exact format of the JSON structure should be:
|
||||||
|
```json
|
||||||
|
{str(entities_example_dict)}
|
||||||
|
```
|
||||||
|
|
||||||
|
Only include entities in your response that have multiple variants or are grouped with other entities.
|
||||||
|
If an entity has no variants and doesn't belong to any group, don't include it in your response.
|
||||||
|
|
||||||
|
Focus on identifying:
|
||||||
|
- Different names for the same person (e.g., full names vs. nicknames)
|
||||||
|
- Different forms of the same organization
|
||||||
|
- The same concept expressed differently
|
||||||
|
- Abbreviations and their full forms
|
||||||
|
- Spelling variations and typos
|
||||||
|
"""
|
||||||
|
return prompt
|
@ -12,6 +12,7 @@ from core.completion.base_completion import BaseCompletionModel
|
|||||||
from core.database.base_database import BaseDatabase
|
from core.database.base_database import BaseDatabase
|
||||||
from core.models.documents import Document, ChunkResult
|
from core.models.documents import Document, ChunkResult
|
||||||
from core.config import get_settings
|
from core.config import get_settings
|
||||||
|
from core.services.entity_resolution import EntityResolver
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -49,6 +50,7 @@ class GraphService:
|
|||||||
self.db = db
|
self.db = db
|
||||||
self.embedding_model = embedding_model
|
self.embedding_model = embedding_model
|
||||||
self.completion_model = completion_model
|
self.completion_model = completion_model
|
||||||
|
self.entity_resolver = EntityResolver()
|
||||||
|
|
||||||
async def create_graph(
|
async def create_graph(
|
||||||
self,
|
self,
|
||||||
@ -140,6 +142,8 @@ class GraphService:
|
|||||||
entities = {}
|
entities = {}
|
||||||
# List to collect all relationships
|
# List to collect all relationships
|
||||||
relationships = []
|
relationships = []
|
||||||
|
# List to collect all extracted entities for resolution
|
||||||
|
all_entities = []
|
||||||
|
|
||||||
# Collect all chunk sources from documents.
|
# Collect all chunk sources from documents.
|
||||||
chunk_sources = [
|
chunk_sources = [
|
||||||
@ -162,11 +166,12 @@ class GraphService:
|
|||||||
chunk.chunk_number
|
chunk.chunk_number
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add entities to the collection, avoiding duplicates
|
# Add entities to the collection, avoiding duplicates based on exact label match
|
||||||
for entity in chunk_entities:
|
for entity in chunk_entities:
|
||||||
if entity.label not in entities:
|
if entity.label not in entities:
|
||||||
# For new entities, initialize chunk_sources with the current chunk
|
# For new entities, initialize chunk_sources with the current chunk
|
||||||
entities[entity.label] = entity
|
entities[entity.label] = entity
|
||||||
|
all_entities.append(entity)
|
||||||
else:
|
else:
|
||||||
# If entity already exists, add this chunk source if not already present
|
# If entity already exists, add this chunk source if not already present
|
||||||
existing_entity = entities[entity.label]
|
existing_entity = entities[entity.label]
|
||||||
@ -197,6 +202,58 @@ class GraphService:
|
|||||||
logger.error(f"Fatal error processing chunk {chunk.chunk_number} in document {chunk.document_id}: {e}")
|
logger.error(f"Fatal error processing chunk {chunk.chunk_number} in document {chunk.document_id}: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
# Check if entity resolution is enabled in settings
|
||||||
|
settings = get_settings()
|
||||||
|
|
||||||
|
# Resolve entities to handle variations like "Trump" vs "Donald J Trump"
|
||||||
|
if settings.ENABLE_ENTITY_RESOLUTION:
|
||||||
|
logger.info("Resolving %d entities using LLM...", len(all_entities))
|
||||||
|
resolved_entities, entity_mapping = await self.entity_resolver.resolve_entities(all_entities)
|
||||||
|
logger.info("Entity resolution completed successfully")
|
||||||
|
else:
|
||||||
|
logger.info("Entity resolution is disabled in settings.")
|
||||||
|
# Return identity mapping (each entity maps to itself)
|
||||||
|
entity_mapping = {entity.label: entity.label for entity in all_entities}
|
||||||
|
resolved_entities = all_entities
|
||||||
|
|
||||||
|
if entity_mapping:
|
||||||
|
logger.info("Entity resolution complete. Found %d mappings.", len(entity_mapping))
|
||||||
|
# Create a new entities dictionary with resolved entities
|
||||||
|
resolved_entities_dict = {}
|
||||||
|
# Build new entities dictionary with canonical labels
|
||||||
|
for entity in resolved_entities:
|
||||||
|
resolved_entities_dict[entity.label] = entity
|
||||||
|
# Update relationships to use canonical entity labels
|
||||||
|
updated_relationships = []
|
||||||
|
|
||||||
|
# Create an entity index by ID for efficient lookups
|
||||||
|
entity_by_id = {entity.id: entity for entity in all_entities}
|
||||||
|
|
||||||
|
for relationship in relationships:
|
||||||
|
# Lookup entities by ID directly from the index
|
||||||
|
source_entity = entity_by_id.get(relationship.source_id)
|
||||||
|
target_entity = entity_by_id.get(relationship.target_id)
|
||||||
|
|
||||||
|
if source_entity and target_entity:
|
||||||
|
# Get canonical labels
|
||||||
|
source_canonical = entity_mapping.get(source_entity.label, source_entity.label)
|
||||||
|
target_canonical = entity_mapping.get(target_entity.label, target_entity.label)
|
||||||
|
# Get canonical entities
|
||||||
|
canonical_source = resolved_entities_dict.get(source_canonical)
|
||||||
|
canonical_target = resolved_entities_dict.get(target_canonical)
|
||||||
|
if canonical_source and canonical_target:
|
||||||
|
# Update relationship to point to canonical entities
|
||||||
|
relationship.source_id = canonical_source.id
|
||||||
|
relationship.target_id = canonical_target.id
|
||||||
|
updated_relationships.append(relationship)
|
||||||
|
else:
|
||||||
|
# Skip relationships that can't be properly mapped
|
||||||
|
logger.warning("Skipping relationship between '%s' and '%s' - canonical entities not found", source_entity.label, target_entity.label)
|
||||||
|
else:
|
||||||
|
# Keep relationship as is if we can't find the entities
|
||||||
|
updated_relationships.append(relationship)
|
||||||
|
return resolved_entities_dict, updated_relationships
|
||||||
|
# If no entity resolution occurred, return original entities and relationships
|
||||||
return entities, relationships
|
return entities, relationships
|
||||||
|
|
||||||
async def extract_entities_from_text(
|
async def extract_entities_from_text(
|
||||||
@ -328,7 +385,7 @@ class GraphService:
|
|||||||
result = response.json()
|
result = response.json()
|
||||||
|
|
||||||
# Log the raw response for debugging
|
# Log the raw response for debugging
|
||||||
logger.info(f"Raw Ollama response for entity extraction: {result['message']['content']}")
|
logger.debug(f"Raw Ollama response for entity extraction: {result['message']['content']}")
|
||||||
|
|
||||||
# Parse the JSON response - Pydantic will handle validation
|
# Parse the JSON response - Pydantic will handle validation
|
||||||
extraction_result = ExtractionResult.model_validate_json(result["message"]["content"])
|
extraction_result = ExtractionResult.model_validate_json(result["message"]["content"])
|
||||||
@ -459,14 +516,36 @@ class GraphService:
|
|||||||
# Find similar entities using embedding similarity
|
# Find similar entities using embedding similarity
|
||||||
top_entities = await self._find_similar_entities(query, graph.entities, k)
|
top_entities = await self._find_similar_entities(query, graph.entities, k)
|
||||||
else:
|
else:
|
||||||
# Use extracted entities directly
|
# Use entity resolution to handle variants of the same entity
|
||||||
entity_map = {entity.label.lower(): entity for entity in graph.entities}
|
settings = get_settings()
|
||||||
|
|
||||||
|
# First, create combined list of query entities and graph entities for resolution
|
||||||
|
combined_entities = query_entities + graph.entities
|
||||||
|
|
||||||
|
# Resolve entities to identify variants if enabled
|
||||||
|
if settings.ENABLE_ENTITY_RESOLUTION:
|
||||||
|
logger.info(f"Resolving {len(combined_entities)} entities from query and graph...")
|
||||||
|
resolved_entities, entity_mapping = await self.entity_resolver.resolve_entities(combined_entities)
|
||||||
|
else:
|
||||||
|
logger.info("Entity resolution is disabled in settings.")
|
||||||
|
# Return identity mapping (each entity maps to itself)
|
||||||
|
entity_mapping = {entity.label: entity.label for entity in combined_entities}
|
||||||
|
resolved_entities = combined_entities
|
||||||
|
|
||||||
|
# Create a mapping of resolved entity labels to graph entities
|
||||||
|
entity_map = {}
|
||||||
|
for entity in graph.entities:
|
||||||
|
# Get canonical form for this entity
|
||||||
|
canonical_label = entity_mapping.get(entity.label, entity.label)
|
||||||
|
entity_map[canonical_label.lower()] = entity
|
||||||
|
|
||||||
matched_entities = []
|
matched_entities = []
|
||||||
|
# Match extracted entities with graph entities using canonical labels
|
||||||
# Match extracted entities with graph entities
|
|
||||||
for query_entity in query_entities:
|
for query_entity in query_entities:
|
||||||
if query_entity.label.lower() in entity_map:
|
# Get canonical form for this query entity
|
||||||
matched_entities.append(entity_map[query_entity.label.lower()])
|
canonical_query = entity_mapping.get(query_entity.label, query_entity.label)
|
||||||
|
if canonical_query.lower() in entity_map:
|
||||||
|
matched_entities.append(entity_map[canonical_query.lower()])
|
||||||
|
|
||||||
# If no matches, fallback to embedding similarity
|
# If no matches, fallback to embedding similarity
|
||||||
if matched_entities:
|
if matched_entities:
|
||||||
|
@ -105,10 +105,12 @@ mode = "self_hosted" # "cloud" or "self_hosted"
|
|||||||
[graph]
|
[graph]
|
||||||
provider = "ollama"
|
provider = "ollama"
|
||||||
model_name = "llama3.2"
|
model_name = "llama3.2"
|
||||||
|
enable_entity_resolution = true
|
||||||
|
|
||||||
# [graph]
|
# [graph]
|
||||||
# provider = "openai"
|
# provider = "openai"
|
||||||
# model_name = "gpt-4o-mini"
|
# model_name = "gpt-4o-mini"
|
||||||
|
# enable_entity_resolution = true
|
||||||
|
|
||||||
[telemetry]
|
[telemetry]
|
||||||
enabled = true
|
enabled = true
|
||||||
|
@ -12,4 +12,4 @@ __all__ = [
|
|||||||
"Document",
|
"Document",
|
||||||
]
|
]
|
||||||
|
|
||||||
__version__ = "0.2.6"
|
__version__ = "0.2.7"
|
||||||
|
@ -4,7 +4,7 @@ build-backend = "hatchling.build"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "databridge-client"
|
name = "databridge-client"
|
||||||
version = "0.2.6"
|
version = "0.2.7"
|
||||||
authors = [
|
authors = [
|
||||||
{ name = "DataBridge", email = "databridgesuperuser@gmail.com" },
|
{ name = "DataBridge", email = "databridgesuperuser@gmail.com" },
|
||||||
]
|
]
|
||||||
|
@ -92,12 +92,10 @@ const ForceGraphComponent: React.FC<ForceGraphComponentProps> = ({
|
|||||||
.linkDirectionalArrowLength(3)
|
.linkDirectionalArrowLength(3)
|
||||||
.linkDirectionalArrowRelPos(1);
|
.linkDirectionalArrowRelPos(1);
|
||||||
|
|
||||||
// Add node labels if enabled
|
// Always use nodeCanvasObject to have consistent rendering regardless of label visibility
|
||||||
if (showNodeLabels && graph.nodeCanvasObject) {
|
if (graph.nodeCanvasObject) {
|
||||||
graph.nodeCanvasObject((node: NodeObject, ctx: CanvasRenderingContext2D, globalScale: number) => {
|
graph.nodeCanvasObject((node: NodeObject, ctx: CanvasRenderingContext2D, globalScale: number) => {
|
||||||
// Draw the node circle
|
// Draw the node circle
|
||||||
const label = node.label;
|
|
||||||
const fontSize = 12/globalScale;
|
|
||||||
const nodeR = 5;
|
const nodeR = 5;
|
||||||
|
|
||||||
if (typeof node.x !== 'number' || typeof node.y !== 'number') return;
|
if (typeof node.x !== 'number' || typeof node.y !== 'number') return;
|
||||||
@ -110,31 +108,35 @@ const ForceGraphComponent: React.FC<ForceGraphComponentProps> = ({
|
|||||||
ctx.fillStyle = node.color;
|
ctx.fillStyle = node.color;
|
||||||
ctx.fill();
|
ctx.fill();
|
||||||
|
|
||||||
// Draw the text label
|
// Only draw the text label if showNodeLabels is true
|
||||||
ctx.font = `${fontSize}px Sans-Serif`;
|
if (showNodeLabels) {
|
||||||
ctx.textAlign = 'center';
|
const label = node.label;
|
||||||
ctx.textBaseline = 'middle';
|
const fontSize = 12/globalScale;
|
||||||
ctx.fillStyle = 'black';
|
|
||||||
|
ctx.font = `${fontSize}px Sans-Serif`;
|
||||||
// Add a background for better readability
|
ctx.textAlign = 'center';
|
||||||
const textWidth = ctx.measureText(label).width;
|
ctx.textBaseline = 'middle';
|
||||||
const bckgDimensions = [textWidth, fontSize].map(n => n + fontSize * 0.2);
|
|
||||||
|
// Add a background for better readability
|
||||||
ctx.fillStyle = 'rgba(255, 255, 255, 0.8)';
|
const textWidth = ctx.measureText(label).width;
|
||||||
ctx.fillRect(
|
const bckgDimensions = [textWidth, fontSize].map(n => n + fontSize * 0.2);
|
||||||
x - bckgDimensions[0] / 2,
|
|
||||||
y - bckgDimensions[1] / 2,
|
ctx.fillStyle = 'rgba(255, 255, 255, 0.8)';
|
||||||
bckgDimensions[0],
|
ctx.fillRect(
|
||||||
bckgDimensions[1]
|
x - bckgDimensions[0] / 2,
|
||||||
);
|
y - bckgDimensions[1] / 2,
|
||||||
|
bckgDimensions[0],
|
||||||
ctx.fillStyle = 'black';
|
bckgDimensions[1]
|
||||||
ctx.fillText(label, x, y);
|
);
|
||||||
|
|
||||||
|
ctx.fillStyle = 'black';
|
||||||
|
ctx.fillText(label, x, y);
|
||||||
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add link labels if enabled
|
// Always use linkCanvasObject for consistent rendering
|
||||||
if (showLinkLabels && graph.linkCanvasObject) {
|
if (graph.linkCanvasObject) {
|
||||||
graph.linkCanvasObject((link: LinkObject, ctx: CanvasRenderingContext2D, globalScale: number) => {
|
graph.linkCanvasObject((link: LinkObject, ctx: CanvasRenderingContext2D, globalScale: number) => {
|
||||||
// Draw the link line
|
// Draw the link line
|
||||||
const start = link.source as NodeObject;
|
const start = link.source as NodeObject;
|
||||||
@ -155,55 +157,59 @@ const ForceGraphComponent: React.FC<ForceGraphComponentProps> = ({
|
|||||||
ctx.lineWidth = 1;
|
ctx.lineWidth = 1;
|
||||||
ctx.stroke();
|
ctx.stroke();
|
||||||
|
|
||||||
// Draw the label at the middle of the link
|
// Draw arrowhead regardless of label visibility
|
||||||
const label = link.type;
|
const arrowLength = 5;
|
||||||
if (label) {
|
const dx = endX - startX;
|
||||||
const fontSize = 10/globalScale;
|
const dy = endY - startY;
|
||||||
ctx.font = `${fontSize}px Sans-Serif`;
|
const angle = Math.atan2(dy, dx);
|
||||||
|
|
||||||
// Calculate middle point
|
// Calculate a position near the target for the arrow
|
||||||
const middleX = startX + (endX - startX) / 2;
|
const arrowDistance = 15; // Distance from target node
|
||||||
const middleY = startY + (endY - startY) / 2;
|
const arrowX = endX - Math.cos(angle) * arrowDistance;
|
||||||
|
const arrowY = endY - Math.sin(angle) * arrowDistance;
|
||||||
// Add a background for better readability
|
|
||||||
const textWidth = ctx.measureText(label).width;
|
ctx.beginPath();
|
||||||
const bckgDimensions = [textWidth, fontSize].map(n => n + fontSize * 0.2);
|
ctx.moveTo(arrowX, arrowY);
|
||||||
|
ctx.lineTo(
|
||||||
ctx.fillStyle = 'rgba(255, 255, 255, 0.8)';
|
arrowX - arrowLength * Math.cos(angle - Math.PI / 6),
|
||||||
ctx.fillRect(
|
arrowY - arrowLength * Math.sin(angle - Math.PI / 6)
|
||||||
middleX - bckgDimensions[0] / 2,
|
);
|
||||||
middleY - bckgDimensions[1] / 2,
|
ctx.lineTo(
|
||||||
bckgDimensions[0],
|
arrowX - arrowLength * Math.cos(angle + Math.PI / 6),
|
||||||
bckgDimensions[1]
|
arrowY - arrowLength * Math.sin(angle + Math.PI / 6)
|
||||||
);
|
);
|
||||||
|
ctx.closePath();
|
||||||
ctx.textAlign = 'center';
|
ctx.fillStyle = 'rgba(0, 0, 0, 0.5)';
|
||||||
ctx.textBaseline = 'middle';
|
ctx.fill();
|
||||||
ctx.fillStyle = 'black';
|
|
||||||
ctx.fillText(label, middleX, middleY);
|
// Only draw label if showLinkLabels is true
|
||||||
|
if (showLinkLabels) {
|
||||||
// Draw arrow
|
const label = link.type;
|
||||||
const arrowLength = 5;
|
if (label) {
|
||||||
const dx = endX - startX;
|
const fontSize = 10/globalScale;
|
||||||
const dy = endY - startY;
|
ctx.font = `${fontSize}px Sans-Serif`;
|
||||||
const angle = Math.atan2(dy, dx);
|
|
||||||
|
// Calculate middle point
|
||||||
const arrowX = middleX + Math.cos(angle) * arrowLength;
|
const middleX = startX + (endX - startX) / 2;
|
||||||
const arrowY = middleY + Math.sin(angle) * arrowLength;
|
const middleY = startY + (endY - startY) / 2;
|
||||||
|
|
||||||
ctx.beginPath();
|
// Add a background for better readability
|
||||||
ctx.moveTo(arrowX, arrowY);
|
const textWidth = ctx.measureText(label).width;
|
||||||
ctx.lineTo(
|
const bckgDimensions = [textWidth, fontSize].map(n => n + fontSize * 0.2);
|
||||||
arrowX - arrowLength * Math.cos(angle - Math.PI / 6),
|
|
||||||
arrowY - arrowLength * Math.sin(angle - Math.PI / 6)
|
ctx.fillStyle = 'rgba(255, 255, 255, 0.8)';
|
||||||
);
|
ctx.fillRect(
|
||||||
ctx.lineTo(
|
middleX - bckgDimensions[0] / 2,
|
||||||
arrowX - arrowLength * Math.cos(angle + Math.PI / 6),
|
middleY - bckgDimensions[1] / 2,
|
||||||
arrowY - arrowLength * Math.sin(angle + Math.PI / 6)
|
bckgDimensions[0],
|
||||||
);
|
bckgDimensions[1]
|
||||||
ctx.closePath();
|
);
|
||||||
ctx.fillStyle = 'rgba(0, 0, 0, 0.5)';
|
|
||||||
ctx.fill();
|
ctx.textAlign = 'center';
|
||||||
|
ctx.textBaseline = 'middle';
|
||||||
|
ctx.fillStyle = 'black';
|
||||||
|
ctx.fillText(label, middleX, middleY);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
@ -98,11 +98,17 @@ const GraphSection: React.FC<GraphSectionProps> = ({ apiBaseUrl }) => {
|
|||||||
color: entityTypeColors[entity.type.toLowerCase()] || entityTypeColors.default
|
color: entityTypeColors[entity.type.toLowerCase()] || entityTypeColors.default
|
||||||
}));
|
}));
|
||||||
|
|
||||||
const links = graph.relationships.map(rel => ({
|
// Create a Set of all entity IDs for faster lookups
|
||||||
source: rel.source_id,
|
const nodeIdSet = new Set(graph.entities.map(entity => entity.id));
|
||||||
target: rel.target_id,
|
|
||||||
type: rel.type
|
// Filter relationships to only include those where both source and target nodes exist
|
||||||
}));
|
const links = graph.relationships
|
||||||
|
.filter(rel => nodeIdSet.has(rel.source_id) && nodeIdSet.has(rel.target_id))
|
||||||
|
.map(rel => ({
|
||||||
|
source: rel.source_id,
|
||||||
|
target: rel.target_id,
|
||||||
|
type: rel.type
|
||||||
|
}));
|
||||||
|
|
||||||
return { nodes, links };
|
return { nodes, links };
|
||||||
}, []);
|
}, []);
|
||||||
@ -251,7 +257,7 @@ const GraphSection: React.FC<GraphSectionProps> = ({ apiBaseUrl }) => {
|
|||||||
initializeGraph();
|
initializeGraph();
|
||||||
}, 100);
|
}, 100);
|
||||||
}
|
}
|
||||||
}, [selectedGraph, activeTab, initializeGraph]);
|
}, [selectedGraph, activeTab, initializeGraph, showNodeLabels, showLinkLabels]);
|
||||||
|
|
||||||
// Handle tab change
|
// Handle tab change
|
||||||
const handleTabChange = (value: string) => {
|
const handleTabChange = (value: string) => {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user