Mirror of https://github.com/james-m-jordan/morphik-core.git (synced 2025-05-09 19:32:38 +00:00)

Add entity resolution (#58)

parent 9ce0507616 · commit 2436a11f29
core/config.py

@@ -67,6 +67,7 @@ class Settings(BaseSettings):
     # Graph configuration
     GRAPH_PROVIDER: Literal["ollama", "openai"]
     GRAPH_MODEL: str
+    ENABLE_ENTITY_RESOLUTION: bool = True
 
     # Reranker configuration
     USE_RERANKING: bool

@@ -301,6 +302,7 @@ def get_settings() -> Settings:
     graph_config = {
         "GRAPH_PROVIDER": config["graph"]["provider"],
         "GRAPH_MODEL": config["graph"]["model_name"],
+        "ENABLE_ENTITY_RESOLUTION": config["graph"].get("enable_entity_resolution", True),
     }
 
     # load telemetry config
core/services/entity_resolution.py (new file, 344 lines)
@@ -0,0 +1,344 @@
import httpx
import logging
import json
from typing import List, Dict, Any, Tuple
from pydantic import BaseModel, ConfigDict

from core.config import get_settings
from core.models.graph import Entity
from openai import AsyncOpenAI

logger = logging.getLogger(__name__)


# Define Pydantic models for structured output
class EntityGroup(BaseModel):
    model_config = ConfigDict(extra="forbid")
    canonical: str
    variants: List[str]


class EntityResolutionResult(BaseModel):
    model_config = ConfigDict(extra="forbid")
    entity_groups: List[EntityGroup]


class EntityResolver:
    """
    Resolves and normalizes entities by identifying different variants of the same entity.
    Handles cases like "Trump" vs "Donald J Trump" or "JFK" vs "Kennedy".
    """

    def __init__(self):
        """Initialize the entity resolver"""
        self.settings = get_settings()

    async def resolve_entities(self, entities: List[Entity]) -> Tuple[List[Entity], Dict[str, str]]:
        """
        Resolves entities by identifying and grouping entities that refer to the same real-world entity.

        Args:
            entities: List of extracted entities

        Returns:
            Tuple containing:
            - List of normalized entities
            - Dictionary mapping original entity text to canonical entity text
        """
        if not entities:
            return [], {}

        # Extract entity labels to deduplicate
        entity_labels = [e.label for e in entities]

        # If there's only one entity, no need to resolve
        if len(entity_labels) <= 1:
            return entities, {entity_labels[0]: entity_labels[0]} if entity_labels else {}

        # Use LLM to identify and group similar entities
        resolved_entities = await self._resolve_with_llm(entity_labels)

        # Create mapping from original to canonical forms
        entity_mapping = {}
        for group in resolved_entities:
            canonical = group["canonical"]
            for variant in group["variants"]:
                entity_mapping[variant] = canonical

        # Deduplicate entities based on mapping
        unique_entities = []
        label_to_entity_map = {}

        # First, create a map from labels to entities
        for entity in entities:
            canonical_label = entity_mapping.get(entity.label, entity.label)

            # If we haven't seen this canonical label yet, add the entity to our map
            if canonical_label not in label_to_entity_map:
                label_to_entity_map[canonical_label] = entity
            else:
                # If we have seen it, merge chunk sources
                existing_entity = label_to_entity_map[canonical_label]

                # Merge document IDs
                for doc_id in entity.document_ids:
                    if doc_id not in existing_entity.document_ids:
                        existing_entity.document_ids.append(doc_id)

                # Merge chunk sources
                for doc_id, chunk_numbers in entity.chunk_sources.items():
                    if doc_id not in existing_entity.chunk_sources:
                        existing_entity.chunk_sources[doc_id] = list(chunk_numbers)
                    else:
                        for chunk_num in chunk_numbers:
                            if chunk_num not in existing_entity.chunk_sources[doc_id]:
                                existing_entity.chunk_sources[doc_id].append(chunk_num)

                # Merge properties (optional)
                for key, value in entity.properties.items():
                    if key not in existing_entity.properties:
                        existing_entity.properties[key] = value

        # Now, update the labels and create unique entities list
        for canonical_label, entity in label_to_entity_map.items():
            # Update the entity label to the canonical form
            entity.label = canonical_label

            # Add aliases property if there are variants
            variants = [variant for variant, canon in entity_mapping.items()
                        if canon == canonical_label and variant != canonical_label]
            if variants:
                if 'aliases' not in entity.properties:
                    entity.properties['aliases'] = []
                entity.properties['aliases'].extend(variants)
                # Deduplicate aliases
                entity.properties['aliases'] = list(set(entity.properties['aliases']))

            unique_entities.append(entity)

        return unique_entities, entity_mapping

    async def _resolve_with_llm(self, entity_labels: List[str]) -> List[Dict[str, Any]]:
        """
        Uses LLM to identify and group similar entities.

        Args:
            entity_labels: List of entity labels to resolve

        Returns:
            List of entity groups, where each group is a dict with:
            - "canonical": The canonical form of the entity
            - "variants": List of variant forms of the entity
        """
        prompt = self._create_entity_resolution_prompt(entity_labels)

        if self.settings.GRAPH_PROVIDER == "openai":
            return await self._resolve_with_openai(prompt, entity_labels)
        elif self.settings.GRAPH_PROVIDER == "ollama":
            return await self._resolve_with_ollama(prompt, entity_labels)
        else:
            logger.error(f"Unsupported graph provider: {self.settings.GRAPH_PROVIDER}")
            # Fallback: treat each entity as unique
            return [{"canonical": label, "variants": [label]} for label in entity_labels]

    async def _resolve_with_openai(self, prompt: str, entity_labels: List[str]) -> List[Dict[str, Any]]:
        """
        Resolves entities using OpenAI.

        Args:
            prompt: Prompt for entity resolution
            entity_labels: Original entity labels

        Returns:
            List of entity groups
        """
        # Define the schema directly instead of converting from Pydantic model
        json_schema = {
            "type": "object",
            "properties": {
                "entity_groups": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": {
                            "canonical": {"type": "string"},
                            "variants": {
                                "type": "array",
                                "items": {"type": "string"}
                            }
                        },
                        "required": ["canonical", "variants"],
                        "additionalProperties": False
                    }
                }
            },
            "required": ["entity_groups"],
            "additionalProperties": False
        }

        # System and user messages
        system_message = {
            "role": "system",
            "content": "You are an entity resolution expert. Your task is to identify and group different representations of the same real-world entity."
        }

        user_message = {
            "role": "user",
            "content": prompt
        }

        try:
            client = AsyncOpenAI(api_key=self.settings.OPENAI_API_KEY)

            response = await client.responses.create(
                model=self.settings.GRAPH_MODEL,
                input=[system_message, user_message],
                text={
                    "format": {
                        "type": "json_schema",
                        "name": "entity_resolution",
                        "schema": json_schema,
                        "strict": True
                    }
                }
            )

            # Process response based on current API structure
            if hasattr(response, 'output_text') and response.output_text:
                content = response.output_text

                try:
                    # Parse with Pydantic model for validation
                    parsed_data = json.loads(content)

                    # Validate with Pydantic model
                    resolution_result = EntityResolutionResult.model_validate(parsed_data)

                    # Extract entity groups
                    return [group.model_dump() for group in resolution_result.entity_groups]

                except Exception as e:
                    logger.error("Error parsing entity resolution response: %r", e)

                    # Fallback to direct JSON parsing if Pydantic validation fails
                    try:
                        parsed_data = json.loads(content)
                        return parsed_data.get("entity_groups", [])
                    except Exception:
                        logger.error("JSON parsing failed, falling back to default resolution")

            elif hasattr(response, 'refusal') and response.refusal:
                logger.warning("OpenAI refused to resolve entities")

        except Exception as e:
            logger.error("Error during entity resolution with OpenAI: %r", e)

        # Fallback: treat each entity as unique
        logger.info("Falling back to treating each entity as unique")
        return [{"canonical": label, "variants": [label]} for label in entity_labels]

    async def _resolve_with_ollama(self, prompt: str, entity_labels: List[str]) -> List[Dict[str, Any]]:
        """
        Resolves entities using Ollama.

        Args:
            prompt: Prompt for entity resolution
            entity_labels: Original entity labels

        Returns:
            List of entity groups
        """
        system_message = {
            "role": "system",
            "content": "You are an entity resolution expert. Your task is to identify and group different representations of the same real-world entity."
        }

        user_message = {
            "role": "user",
            "content": prompt
        }

        try:
            async with httpx.AsyncClient(timeout=60.0) as client:
                # Create the schema for structured output
                format_schema = EntityResolutionResult.model_json_schema()

                response = await client.post(
                    f"{self.settings.EMBEDDING_OLLAMA_BASE_URL}/api/chat",
                    json={
                        "model": self.settings.GRAPH_MODEL,
                        "messages": [system_message, user_message],
                        "stream": False,
                        "format": format_schema
                    },
                )
                response.raise_for_status()
                result = response.json()

                # Parse with Pydantic model for validation
                resolution_result = EntityResolutionResult.model_validate_json(result["message"]["content"])
                return [group.model_dump() for group in resolution_result.entity_groups]

        except Exception as e:
            logger.error("Error during entity resolution with Ollama: %r", e)

            # Fallback: treat each entity as unique
            return [{"canonical": label, "variants": [label]} for label in entity_labels]

    def _create_entity_resolution_prompt(self, entity_labels: List[str]) -> str:
        """
        Creates a prompt for the LLM to resolve entities.

        Args:
            entity_labels: List of entity labels to resolve

        Returns:
            Prompt string for the LLM
        """
        entities_str = "\n".join([f"- {label}" for label in entity_labels])
        entities_example_dict = {
            "entity_groups": [
                {
                    "canonical": "John F. Kennedy",
                    "variants": ["John F. Kennedy", "JFK", "Kennedy"]
                },
                {
                    "canonical": "United States of America",
                    "variants": ["United States of America", "USA", "United States"]
                }
            ]
        }

        prompt = f"""
Below is a list of entities extracted from a document:

{entities_str}

Some of these entities may refer to the same real-world entity but with different names or spellings.
For example, "JFK" and "John F. Kennedy" refer to the same person.

Please analyze this list and group entities that refer to the same real-world entity.
For each group, provide:
1. A canonical (standard) form of the entity
2. All variant forms found in the list

Format your response as a JSON object with an "entity_groups" array, where each item in the array is an object with:
- "canonical": The canonical form (choose the most complete and formal name)
- "variants": Array of all variants (including the canonical form)

The exact format of the JSON structure should be:
```json
{str(entities_example_dict)}
```

Only include entities in your response that have multiple variants or are grouped with other entities.
If an entity has no variants and doesn't belong to any group, don't include it in your response.

Focus on identifying:
- Different names for the same person (e.g., full names vs. nicknames)
- Different forms of the same organization
- The same concept expressed differently
- Abbreviations and their full forms
- Spelling variations and typos
"""
        return prompt
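A minimal usage sketch of the resolver above. The `Entity` keyword arguments are assumptions inferred from the fields the resolver reads (`label`, `document_ids`, `chunk_sources`, `properties`); the exact constructor may differ, and a configured graph provider must be reachable.

import asyncio

from core.models.graph import Entity
from core.services.entity_resolution import EntityResolver


async def demo():
    # Hypothetical entities; field values are illustrative only.
    entities = [
        Entity(label="JFK", document_ids=["doc1"], chunk_sources={"doc1": [0]}, properties={}),
        Entity(label="John F. Kennedy", document_ids=["doc2"], chunk_sources={"doc2": [3]}, properties={}),
    ]
    resolver = EntityResolver()
    unique_entities, mapping = await resolver.resolve_entities(entities)
    # mapping sends each variant to its canonical label, e.g.
    # {"JFK": "John F. Kennedy", "John F. Kennedy": "John F. Kennedy"};
    # the surviving entity carries merged document_ids/chunk_sources and an "aliases" property.
    for entity in unique_entities:
        print(entity.label, entity.properties.get("aliases", []))


asyncio.run(demo())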
core/services/graph_service.py

@@ -12,6 +12,7 @@ from core.completion.base_completion import BaseCompletionModel
 from core.database.base_database import BaseDatabase
 from core.models.documents import Document, ChunkResult
 from core.config import get_settings
+from core.services.entity_resolution import EntityResolver
 from openai import AsyncOpenAI
 
 logger = logging.getLogger(__name__)
@@ -49,6 +50,7 @@ class GraphService:
         self.db = db
         self.embedding_model = embedding_model
         self.completion_model = completion_model
+        self.entity_resolver = EntityResolver()
 
     async def create_graph(
         self,
@@ -140,6 +142,8 @@ class GraphService:
         entities = {}
         # List to collect all relationships
         relationships = []
+        # List to collect all extracted entities for resolution
+        all_entities = []
 
         # Collect all chunk sources from documents.
         chunk_sources = [
@@ -162,11 +166,12 @@
                         chunk.chunk_number
                     )
 
-                # Add entities to the collection, avoiding duplicates
+                # Add entities to the collection, avoiding duplicates based on exact label match
                 for entity in chunk_entities:
                     if entity.label not in entities:
                         # For new entities, initialize chunk_sources with the current chunk
                         entities[entity.label] = entity
+                        all_entities.append(entity)
                     else:
                         # If entity already exists, add this chunk source if not already present
                         existing_entity = entities[entity.label]
@@ -197,6 +202,58 @@
                 logger.error(f"Fatal error processing chunk {chunk.chunk_number} in document {chunk.document_id}: {e}")
                 raise
 
+        # Check if entity resolution is enabled in settings
+        settings = get_settings()
+
+        # Resolve entities to handle variations like "Trump" vs "Donald J Trump"
+        if settings.ENABLE_ENTITY_RESOLUTION:
+            logger.info("Resolving %d entities using LLM...", len(all_entities))
+            resolved_entities, entity_mapping = await self.entity_resolver.resolve_entities(all_entities)
+            logger.info("Entity resolution completed successfully")
+        else:
+            logger.info("Entity resolution is disabled in settings.")
+            # Return identity mapping (each entity maps to itself)
+            entity_mapping = {entity.label: entity.label for entity in all_entities}
+            resolved_entities = all_entities
+
+        if entity_mapping:
+            logger.info("Entity resolution complete. Found %d mappings.", len(entity_mapping))
+            # Create a new entities dictionary with resolved entities
+            resolved_entities_dict = {}
+            # Build new entities dictionary with canonical labels
+            for entity in resolved_entities:
+                resolved_entities_dict[entity.label] = entity
+            # Update relationships to use canonical entity labels
+            updated_relationships = []
+
+            # Create an entity index by ID for efficient lookups
+            entity_by_id = {entity.id: entity for entity in all_entities}
+
+            for relationship in relationships:
+                # Lookup entities by ID directly from the index
+                source_entity = entity_by_id.get(relationship.source_id)
+                target_entity = entity_by_id.get(relationship.target_id)
+
+                if source_entity and target_entity:
+                    # Get canonical labels
+                    source_canonical = entity_mapping.get(source_entity.label, source_entity.label)
+                    target_canonical = entity_mapping.get(target_entity.label, target_entity.label)
+                    # Get canonical entities
+                    canonical_source = resolved_entities_dict.get(source_canonical)
+                    canonical_target = resolved_entities_dict.get(target_canonical)
+                    if canonical_source and canonical_target:
+                        # Update relationship to point to canonical entities
+                        relationship.source_id = canonical_source.id
+                        relationship.target_id = canonical_target.id
+                        updated_relationships.append(relationship)
+                    else:
+                        # Skip relationships that can't be properly mapped
+                        logger.warning("Skipping relationship between '%s' and '%s' - canonical entities not found", source_entity.label, target_entity.label)
+                else:
+                    # Keep relationship as is if we can't find the entities
+                    updated_relationships.append(relationship)
+            return resolved_entities_dict, updated_relationships
+        # If no entity resolution occurred, return original entities and relationships
         return entities, relationships
 
     async def extract_entities_from_text(
@@ -328,7 +385,7 @@
             result = response.json()
 
             # Log the raw response for debugging
-            logger.info(f"Raw Ollama response for entity extraction: {result['message']['content']}")
+            logger.debug(f"Raw Ollama response for entity extraction: {result['message']['content']}")
 
             # Parse the JSON response - Pydantic will handle validation
             extraction_result = ExtractionResult.model_validate_json(result["message"]["content"])
@@ -459,14 +516,36 @@
             # Find similar entities using embedding similarity
             top_entities = await self._find_similar_entities(query, graph.entities, k)
         else:
-            # Use extracted entities directly
-            entity_map = {entity.label.lower(): entity for entity in graph.entities}
+            # Use entity resolution to handle variants of the same entity
+            settings = get_settings()
+
+            # First, create combined list of query entities and graph entities for resolution
+            combined_entities = query_entities + graph.entities
+
+            # Resolve entities to identify variants if enabled
+            if settings.ENABLE_ENTITY_RESOLUTION:
+                logger.info(f"Resolving {len(combined_entities)} entities from query and graph...")
+                resolved_entities, entity_mapping = await self.entity_resolver.resolve_entities(combined_entities)
+            else:
+                logger.info("Entity resolution is disabled in settings.")
+                # Return identity mapping (each entity maps to itself)
+                entity_mapping = {entity.label: entity.label for entity in combined_entities}
+                resolved_entities = combined_entities
+
+            # Create a mapping of resolved entity labels to graph entities
+            entity_map = {}
+            for entity in graph.entities:
+                # Get canonical form for this entity
+                canonical_label = entity_mapping.get(entity.label, entity.label)
+                entity_map[canonical_label.lower()] = entity
 
             matched_entities = []
 
-            # Match extracted entities with graph entities
+            # Match extracted entities with graph entities using canonical labels
             for query_entity in query_entities:
-                if query_entity.label.lower() in entity_map:
-                    matched_entities.append(entity_map[query_entity.label.lower()])
+                # Get canonical form for this query entity
+                canonical_query = entity_mapping.get(query_entity.label, query_entity.label)
+                if canonical_query.lower() in entity_map:
+                    matched_entities.append(entity_map[canonical_query.lower()])
 
             # If no matches, fallback to embedding similarity
             if matched_entities:
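The matching above goes variant to canonical label to lower-cased lookup. A toy sketch with hypothetical data (plain strings stand in for graph entity objects):

# entity_mapping comes from resolve_entities(); entity_map keys are lower-cased canonical labels.
entity_mapping = {"Trump": "Donald J Trump", "Donald J Trump": "Donald J Trump"}
entity_map = {"donald j trump": "graph-entity-placeholder"}

query_label = "Trump"
canonical_query = entity_mapping.get(query_label, query_label)
print(canonical_query.lower() in entity_map)  # True: the variant matches via its canonical form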
morphik.toml

@@ -105,10 +105,12 @@ mode = "self_hosted" # "cloud" or "self_hosted"
 [graph]
 provider = "ollama"
 model_name = "llama3.2"
+enable_entity_resolution = true
 
 # [graph]
 # provider = "openai"
 # model_name = "gpt-4o-mini"
+# enable_entity_resolution = true
 
 [telemetry]
 enabled = true
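The flag is optional in morphik.toml and defaults to enabled, since get_settings() reads it with config["graph"].get("enable_entity_resolution", True) (see the core/config.py hunk above). A one-line sketch of that default with a hypothetical config dict:

# A [graph] section that omits the flag still gets entity resolution.
graph_cfg = {"provider": "ollama", "model_name": "llama3.2"}
print(graph_cfg.get("enable_entity_resolution", True))  # True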
Python SDK __init__.py

@@ -12,4 +12,4 @@ __all__ = [
     "Document",
 ]
 
-__version__ = "0.2.6"
+__version__ = "0.2.7"
pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "databridge-client"
-version = "0.2.6"
+version = "0.2.7"
 authors = [
     { name = "DataBridge", email = "databridgesuperuser@gmail.com" },
 ]
|
||||
.linkDirectionalArrowLength(3)
|
||||
.linkDirectionalArrowRelPos(1);
|
||||
|
||||
// Add node labels if enabled
|
||||
if (showNodeLabels && graph.nodeCanvasObject) {
|
||||
// Always use nodeCanvasObject to have consistent rendering regardless of label visibility
|
||||
if (graph.nodeCanvasObject) {
|
||||
graph.nodeCanvasObject((node: NodeObject, ctx: CanvasRenderingContext2D, globalScale: number) => {
|
||||
// Draw the node circle
|
||||
const label = node.label;
|
||||
const fontSize = 12/globalScale;
|
||||
const nodeR = 5;
|
||||
|
||||
if (typeof node.x !== 'number' || typeof node.y !== 'number') return;
|
||||
@@ -110,31 +108,35 @@ const ForceGraphComponent: React.FC<ForceGraphComponentProps> = ({
         ctx.fillStyle = node.color;
         ctx.fill();
 
-        // Draw the text label
-        ctx.font = `${fontSize}px Sans-Serif`;
-        ctx.textAlign = 'center';
-        ctx.textBaseline = 'middle';
-        ctx.fillStyle = 'black';
-
-        // Add a background for better readability
-        const textWidth = ctx.measureText(label).width;
-        const bckgDimensions = [textWidth, fontSize].map(n => n + fontSize * 0.2);
-
-        ctx.fillStyle = 'rgba(255, 255, 255, 0.8)';
-        ctx.fillRect(
-          x - bckgDimensions[0] / 2,
-          y - bckgDimensions[1] / 2,
-          bckgDimensions[0],
-          bckgDimensions[1]
-        );
-
-        ctx.fillStyle = 'black';
-        ctx.fillText(label, x, y);
+        // Only draw the text label if showNodeLabels is true
+        if (showNodeLabels) {
+          const label = node.label;
+          const fontSize = 12/globalScale;
+
+          ctx.font = `${fontSize}px Sans-Serif`;
+          ctx.textAlign = 'center';
+          ctx.textBaseline = 'middle';
+
+          // Add a background for better readability
+          const textWidth = ctx.measureText(label).width;
+          const bckgDimensions = [textWidth, fontSize].map(n => n + fontSize * 0.2);
+
+          ctx.fillStyle = 'rgba(255, 255, 255, 0.8)';
+          ctx.fillRect(
+            x - bckgDimensions[0] / 2,
+            y - bckgDimensions[1] / 2,
+            bckgDimensions[0],
+            bckgDimensions[1]
+          );
+
+          ctx.fillStyle = 'black';
+          ctx.fillText(label, x, y);
+        }
       });
     }
 
-    // Add link labels if enabled
-    if (showLinkLabels && graph.linkCanvasObject) {
+    // Always use linkCanvasObject for consistent rendering
+    if (graph.linkCanvasObject) {
       graph.linkCanvasObject((link: LinkObject, ctx: CanvasRenderingContext2D, globalScale: number) => {
         // Draw the link line
         const start = link.source as NodeObject;
@@ -155,55 +157,59 @@ const ForceGraphComponent: React.FC<ForceGraphComponentProps> = ({
         ctx.lineWidth = 1;
         ctx.stroke();
 
-        // Draw the label at the middle of the link
-        const label = link.type;
-        if (label) {
-          const fontSize = 10/globalScale;
-          ctx.font = `${fontSize}px Sans-Serif`;
-
-          // Calculate middle point
-          const middleX = startX + (endX - startX) / 2;
-          const middleY = startY + (endY - startY) / 2;
-
-          // Add a background for better readability
-          const textWidth = ctx.measureText(label).width;
-          const bckgDimensions = [textWidth, fontSize].map(n => n + fontSize * 0.2);
-
-          ctx.fillStyle = 'rgba(255, 255, 255, 0.8)';
-          ctx.fillRect(
-            middleX - bckgDimensions[0] / 2,
-            middleY - bckgDimensions[1] / 2,
-            bckgDimensions[0],
-            bckgDimensions[1]
-          );
-
-          ctx.textAlign = 'center';
-          ctx.textBaseline = 'middle';
-          ctx.fillStyle = 'black';
-          ctx.fillText(label, middleX, middleY);
-
-          // Draw arrow
-          const arrowLength = 5;
-          const dx = endX - startX;
-          const dy = endY - startY;
-          const angle = Math.atan2(dy, dx);
-
-          const arrowX = middleX + Math.cos(angle) * arrowLength;
-          const arrowY = middleY + Math.sin(angle) * arrowLength;
-
-          ctx.beginPath();
-          ctx.moveTo(arrowX, arrowY);
-          ctx.lineTo(
-            arrowX - arrowLength * Math.cos(angle - Math.PI / 6),
-            arrowY - arrowLength * Math.sin(angle - Math.PI / 6)
-          );
-          ctx.lineTo(
-            arrowX - arrowLength * Math.cos(angle + Math.PI / 6),
-            arrowY - arrowLength * Math.sin(angle + Math.PI / 6)
-          );
-          ctx.closePath();
-          ctx.fillStyle = 'rgba(0, 0, 0, 0.5)';
-          ctx.fill();
+        // Draw arrowhead regardless of label visibility
+        const arrowLength = 5;
+        const dx = endX - startX;
+        const dy = endY - startY;
+        const angle = Math.atan2(dy, dx);
+
+        // Calculate a position near the target for the arrow
+        const arrowDistance = 15; // Distance from target node
+        const arrowX = endX - Math.cos(angle) * arrowDistance;
+        const arrowY = endY - Math.sin(angle) * arrowDistance;
+
+        ctx.beginPath();
+        ctx.moveTo(arrowX, arrowY);
+        ctx.lineTo(
+          arrowX - arrowLength * Math.cos(angle - Math.PI / 6),
+          arrowY - arrowLength * Math.sin(angle - Math.PI / 6)
+        );
+        ctx.lineTo(
+          arrowX - arrowLength * Math.cos(angle + Math.PI / 6),
+          arrowY - arrowLength * Math.sin(angle + Math.PI / 6)
+        );
+        ctx.closePath();
+        ctx.fillStyle = 'rgba(0, 0, 0, 0.5)';
+        ctx.fill();
+
+        // Only draw label if showLinkLabels is true
+        if (showLinkLabels) {
+          const label = link.type;
+          if (label) {
+            const fontSize = 10/globalScale;
+            ctx.font = `${fontSize}px Sans-Serif`;
+
+            // Calculate middle point
+            const middleX = startX + (endX - startX) / 2;
+            const middleY = startY + (endY - startY) / 2;
+
+            // Add a background for better readability
+            const textWidth = ctx.measureText(label).width;
+            const bckgDimensions = [textWidth, fontSize].map(n => n + fontSize * 0.2);
+
+            ctx.fillStyle = 'rgba(255, 255, 255, 0.8)';
+            ctx.fillRect(
+              middleX - bckgDimensions[0] / 2,
+              middleY - bckgDimensions[1] / 2,
+              bckgDimensions[0],
+              bckgDimensions[1]
+            );
+
+            ctx.textAlign = 'center';
+            ctx.textBaseline = 'middle';
+            ctx.fillStyle = 'black';
+            ctx.fillText(label, middleX, middleY);
+          }
+        }
       });
     }
GraphSection.tsx

@@ -98,11 +98,17 @@ const GraphSection: React.FC<GraphSectionProps> = ({ apiBaseUrl }) => {
       color: entityTypeColors[entity.type.toLowerCase()] || entityTypeColors.default
     }));
 
-    const links = graph.relationships.map(rel => ({
-      source: rel.source_id,
-      target: rel.target_id,
-      type: rel.type
-    }));
+    // Create a Set of all entity IDs for faster lookups
+    const nodeIdSet = new Set(graph.entities.map(entity => entity.id));
+
+    // Filter relationships to only include those where both source and target nodes exist
+    const links = graph.relationships
+      .filter(rel => nodeIdSet.has(rel.source_id) && nodeIdSet.has(rel.target_id))
+      .map(rel => ({
+        source: rel.source_id,
+        target: rel.target_id,
+        type: rel.type
+      }));
 
     return { nodes, links };
   }, []);
@@ -251,7 +257,7 @@ const GraphSection: React.FC<GraphSectionProps> = ({ apiBaseUrl }) => {
         initializeGraph();
       }, 100);
     }
-  }, [selectedGraph, activeTab, initializeGraph]);
+  }, [selectedGraph, activeTab, initializeGraph, showNodeLabels, showLinkLabels]);
 
   // Handle tab change
   const handleTabChange = (value: string) => {