Add entity resolution (#58)

This commit is contained in:
Adityavardhan Agrawal 2025-03-29 19:25:01 -07:00 committed by GitHub
parent 9ce0507616
commit 2436a11f29
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 530 additions and 91 deletions

View File

@ -67,6 +67,7 @@ class Settings(BaseSettings):
# Graph configuration # Graph configuration
GRAPH_PROVIDER: Literal["ollama", "openai"] GRAPH_PROVIDER: Literal["ollama", "openai"]
GRAPH_MODEL: str GRAPH_MODEL: str
ENABLE_ENTITY_RESOLUTION: bool = True
# Reranker configuration # Reranker configuration
USE_RERANKING: bool USE_RERANKING: bool
@ -301,6 +302,7 @@ def get_settings() -> Settings:
graph_config = { graph_config = {
"GRAPH_PROVIDER": config["graph"]["provider"], "GRAPH_PROVIDER": config["graph"]["provider"],
"GRAPH_MODEL": config["graph"]["model_name"], "GRAPH_MODEL": config["graph"]["model_name"],
"ENABLE_ENTITY_RESOLUTION": config["graph"].get("enable_entity_resolution", True),
} }
# load telemetry config # load telemetry config

View File

@ -0,0 +1,344 @@
import httpx
import logging
import json
from typing import List, Dict, Any, Tuple
from pydantic import BaseModel, ConfigDict
from core.config import get_settings
from core.models.graph import Entity
from openai import AsyncOpenAI
logger = logging.getLogger(__name__)
# Define Pydantic models for structured output
class EntityGroup(BaseModel):
model_config = ConfigDict(extra="forbid")
canonical: str
variants: List[str]
class EntityResolutionResult(BaseModel):
model_config = ConfigDict(extra="forbid")
entity_groups: List[EntityGroup]
class EntityResolver:
"""
Resolves and normalizes entities by identifying different variants of the same entity.
Handles cases like "Trump" vs "Donald J Trump" or "JFK" vs "Kennedy".
"""
def __init__(self):
"""Initialize the entity resolver"""
self.settings = get_settings()
async def resolve_entities(self, entities: List[Entity]) -> Tuple[List[Entity], Dict[str, str]]:
"""
Resolves entities by identifying and grouping entities that refer to the same real-world entity.
Args:
entities: List of extracted entities
Returns:
Tuple containing:
- List of normalized entities
- Dictionary mapping original entity text to canonical entity text
"""
if not entities:
return [], {}
# Extract entity labels to deduplicate
entity_labels = [e.label for e in entities]
# If there's only one entity, no need to resolve
if len(entity_labels) <= 1:
return entities, {entity_labels[0]: entity_labels[0]} if entity_labels else {}
# Use LLM to identify and group similar entities
resolved_entities = await self._resolve_with_llm(entity_labels)
# Create mapping from original to canonical forms
entity_mapping = {}
for group in resolved_entities:
canonical = group["canonical"]
for variant in group["variants"]:
entity_mapping[variant] = canonical
# Deduplicate entities based on mapping
unique_entities = []
label_to_entity_map = {}
# First, create a map from labels to entities
for entity in entities:
canonical_label = entity_mapping.get(entity.label, entity.label)
# If we haven't seen this canonical label yet, add the entity to our map
if canonical_label not in label_to_entity_map:
label_to_entity_map[canonical_label] = entity
else:
# If we have seen it, merge chunk sources
existing_entity = label_to_entity_map[canonical_label]
# Merge document IDs
for doc_id in entity.document_ids:
if doc_id not in existing_entity.document_ids:
existing_entity.document_ids.append(doc_id)
# Merge chunk sources
for doc_id, chunk_numbers in entity.chunk_sources.items():
if doc_id not in existing_entity.chunk_sources:
existing_entity.chunk_sources[doc_id] = list(chunk_numbers)
else:
for chunk_num in chunk_numbers:
if chunk_num not in existing_entity.chunk_sources[doc_id]:
existing_entity.chunk_sources[doc_id].append(chunk_num)
# Merge properties (optional)
for key, value in entity.properties.items():
if key not in existing_entity.properties:
existing_entity.properties[key] = value
# Now, update the labels and create unique entities list
for canonical_label, entity in label_to_entity_map.items():
# Update the entity label to the canonical form
entity.label = canonical_label
# Add aliases property if there are variants
variants = [variant for variant, canon in entity_mapping.items()
if canon == canonical_label and variant != canonical_label]
if variants:
if 'aliases' not in entity.properties:
entity.properties['aliases'] = []
entity.properties['aliases'].extend(variants)
# Deduplicate aliases
entity.properties['aliases'] = list(set(entity.properties['aliases']))
unique_entities.append(entity)
return unique_entities, entity_mapping
async def _resolve_with_llm(self, entity_labels: List[str]) -> List[Dict[str, Any]]:
"""
Uses LLM to identify and group similar entities.
Args:
entity_labels: List of entity labels to resolve
Returns:
List of entity groups, where each group is a dict with:
- "canonical": The canonical form of the entity
- "variants": List of variant forms of the entity
"""
prompt = self._create_entity_resolution_prompt(entity_labels)
if self.settings.GRAPH_PROVIDER == "openai":
return await self._resolve_with_openai(prompt, entity_labels)
elif self.settings.GRAPH_PROVIDER == "ollama":
return await self._resolve_with_ollama(prompt, entity_labels)
else:
logger.error(f"Unsupported graph provider: {self.settings.GRAPH_PROVIDER}")
# Fallback: treat each entity as unique
return [{"canonical": label, "variants": [label]} for label in entity_labels]
async def _resolve_with_openai(self, prompt: str, entity_labels: List[str]) -> List[Dict[str, Any]]:
"""
Resolves entities using OpenAI.
Args:
prompt: Prompt for entity resolution
entity_labels: Original entity labels
Returns:
List of entity groups
"""
# Define the schema directly instead of converting from Pydantic model
json_schema = {
"type": "object",
"properties": {
"entity_groups": {
"type": "array",
"items": {
"type": "object",
"properties": {
"canonical": {"type": "string"},
"variants": {
"type": "array",
"items": {"type": "string"}
}
},
"required": ["canonical", "variants"],
"additionalProperties": False
}
}
},
"required": ["entity_groups"],
"additionalProperties": False
}
# System and user messages
system_message = {
"role": "system",
"content": "You are an entity resolution expert. Your task is to identify and group different representations of the same real-world entity."
}
user_message = {
"role": "user",
"content": prompt
}
try:
client = AsyncOpenAI(api_key=self.settings.OPENAI_API_KEY)
response = await client.responses.create(
model=self.settings.GRAPH_MODEL,
input=[system_message, user_message],
text={
"format": {
"type": "json_schema",
"name": "entity_resolution",
"schema": json_schema,
"strict": True
}
}
)
# Process response based on current API structure
if hasattr(response, 'output_text') and response.output_text:
content = response.output_text
try:
# Parse with Pydantic model for validation
parsed_data = json.loads(content)
# Validate with Pydantic model
resolution_result = EntityResolutionResult.model_validate(parsed_data)
# Extract entity groups
return [group.model_dump() for group in resolution_result.entity_groups]
except Exception as e:
logger.error("Error parsing entity resolution response: %r", e)
# Fallback to direct JSON parsing if Pydantic validation fails
try:
parsed_data = json.loads(content)
return parsed_data.get("entity_groups", [])
except Exception:
logger.error("JSON parsing failed, falling back to default resolution")
elif hasattr(response, 'refusal') and response.refusal:
logger.warning("OpenAI refused to resolve entities")
except Exception as e:
logger.error("Error during entity resolution with OpenAI: %r", e)
# Fallback: treat each entity as unique
logger.info("Falling back to treating each entity as unique")
return [{"canonical": label, "variants": [label]} for label in entity_labels]
async def _resolve_with_ollama(self, prompt: str, entity_labels: List[str]) -> List[Dict[str, Any]]:
"""
Resolves entities using Ollama.
Args:
prompt: Prompt for entity resolution
entity_labels: Original entity labels
Returns:
List of entity groups
"""
system_message = {
"role": "system",
"content": "You are an entity resolution expert. Your task is to identify and group different representations of the same real-world entity."
}
user_message = {
"role": "user",
"content": prompt
}
try:
async with httpx.AsyncClient(timeout=60.0) as client:
# Create the schema for structured output
format_schema = EntityResolutionResult.model_json_schema()
response = await client.post(
f"{self.settings.EMBEDDING_OLLAMA_BASE_URL}/api/chat",
json={
"model": self.settings.GRAPH_MODEL,
"messages": [system_message, user_message],
"stream": False,
"format": format_schema
},
)
response.raise_for_status()
result = response.json()
# Parse with Pydantic model for validation
resolution_result = EntityResolutionResult.model_validate_json(result["message"]["content"])
return [group.model_dump() for group in resolution_result.entity_groups]
except Exception as e:
logger.error("Error during entity resolution with Ollama: %r", e)
# Fallback: treat each entity as unique
return [{"canonical": label, "variants": [label]} for label in entity_labels]
def _create_entity_resolution_prompt(self, entity_labels: List[str]) -> str:
"""
Creates a prompt for the LLM to resolve entities.
Args:
entity_labels: List of entity labels to resolve
Returns:
Prompt string for the LLM
"""
entities_str = "\n".join([f"- {label}" for label in entity_labels])
entities_example_dict = {
"entity_groups": [
{
"canonical": "John F. Kennedy",
"variants": ["John F. Kennedy", "JFK", "Kennedy"]
},
{
"canonical": "United States of America",
"variants": ["United States of America", "USA", "United States"]
}
]
}
prompt = f"""
Below is a list of entities extracted from a document:
{entities_str}
Some of these entities may refer to the same real-world entity but with different names or spellings.
For example, "JFK" and "John F. Kennedy" refer to the same person.
Please analyze this list and group entities that refer to the same real-world entity.
For each group, provide:
1. A canonical (standard) form of the entity
2. All variant forms found in the list
Format your response as a JSON object with an "entity_groups" array, where each item in the array is an object with:
- "canonical": The canonical form (choose the most complete and formal name)
- "variants": Array of all variants (including the canonical form)
The exact format of the JSON structure should be:
```json
{str(entities_example_dict)}
```
Only include entities in your response that have multiple variants or are grouped with other entities.
If an entity has no variants and doesn't belong to any group, don't include it in your response.
Focus on identifying:
- Different names for the same person (e.g., full names vs. nicknames)
- Different forms of the same organization
- The same concept expressed differently
- Abbreviations and their full forms
- Spelling variations and typos
"""
return prompt

View File

@ -12,6 +12,7 @@ from core.completion.base_completion import BaseCompletionModel
from core.database.base_database import BaseDatabase from core.database.base_database import BaseDatabase
from core.models.documents import Document, ChunkResult from core.models.documents import Document, ChunkResult
from core.config import get_settings from core.config import get_settings
from core.services.entity_resolution import EntityResolver
from openai import AsyncOpenAI from openai import AsyncOpenAI
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -49,6 +50,7 @@ class GraphService:
self.db = db self.db = db
self.embedding_model = embedding_model self.embedding_model = embedding_model
self.completion_model = completion_model self.completion_model = completion_model
self.entity_resolver = EntityResolver()
async def create_graph( async def create_graph(
self, self,
@ -140,6 +142,8 @@ class GraphService:
entities = {} entities = {}
# List to collect all relationships # List to collect all relationships
relationships = [] relationships = []
# List to collect all extracted entities for resolution
all_entities = []
# Collect all chunk sources from documents. # Collect all chunk sources from documents.
chunk_sources = [ chunk_sources = [
@ -162,11 +166,12 @@ class GraphService:
chunk.chunk_number chunk.chunk_number
) )
# Add entities to the collection, avoiding duplicates # Add entities to the collection, avoiding duplicates based on exact label match
for entity in chunk_entities: for entity in chunk_entities:
if entity.label not in entities: if entity.label not in entities:
# For new entities, initialize chunk_sources with the current chunk # For new entities, initialize chunk_sources with the current chunk
entities[entity.label] = entity entities[entity.label] = entity
all_entities.append(entity)
else: else:
# If entity already exists, add this chunk source if not already present # If entity already exists, add this chunk source if not already present
existing_entity = entities[entity.label] existing_entity = entities[entity.label]
@ -197,6 +202,58 @@ class GraphService:
logger.error(f"Fatal error processing chunk {chunk.chunk_number} in document {chunk.document_id}: {e}") logger.error(f"Fatal error processing chunk {chunk.chunk_number} in document {chunk.document_id}: {e}")
raise raise
# Check if entity resolution is enabled in settings
settings = get_settings()
# Resolve entities to handle variations like "Trump" vs "Donald J Trump"
if settings.ENABLE_ENTITY_RESOLUTION:
logger.info("Resolving %d entities using LLM...", len(all_entities))
resolved_entities, entity_mapping = await self.entity_resolver.resolve_entities(all_entities)
logger.info("Entity resolution completed successfully")
else:
logger.info("Entity resolution is disabled in settings.")
# Return identity mapping (each entity maps to itself)
entity_mapping = {entity.label: entity.label for entity in all_entities}
resolved_entities = all_entities
if entity_mapping:
logger.info("Entity resolution complete. Found %d mappings.", len(entity_mapping))
# Create a new entities dictionary with resolved entities
resolved_entities_dict = {}
# Build new entities dictionary with canonical labels
for entity in resolved_entities:
resolved_entities_dict[entity.label] = entity
# Update relationships to use canonical entity labels
updated_relationships = []
# Create an entity index by ID for efficient lookups
entity_by_id = {entity.id: entity for entity in all_entities}
for relationship in relationships:
# Lookup entities by ID directly from the index
source_entity = entity_by_id.get(relationship.source_id)
target_entity = entity_by_id.get(relationship.target_id)
if source_entity and target_entity:
# Get canonical labels
source_canonical = entity_mapping.get(source_entity.label, source_entity.label)
target_canonical = entity_mapping.get(target_entity.label, target_entity.label)
# Get canonical entities
canonical_source = resolved_entities_dict.get(source_canonical)
canonical_target = resolved_entities_dict.get(target_canonical)
if canonical_source and canonical_target:
# Update relationship to point to canonical entities
relationship.source_id = canonical_source.id
relationship.target_id = canonical_target.id
updated_relationships.append(relationship)
else:
# Skip relationships that can't be properly mapped
logger.warning("Skipping relationship between '%s' and '%s' - canonical entities not found", source_entity.label, target_entity.label)
else:
# Keep relationship as is if we can't find the entities
updated_relationships.append(relationship)
return resolved_entities_dict, updated_relationships
# If no entity resolution occurred, return original entities and relationships
return entities, relationships return entities, relationships
async def extract_entities_from_text( async def extract_entities_from_text(
@ -328,7 +385,7 @@ class GraphService:
result = response.json() result = response.json()
# Log the raw response for debugging # Log the raw response for debugging
logger.info(f"Raw Ollama response for entity extraction: {result['message']['content']}") logger.debug(f"Raw Ollama response for entity extraction: {result['message']['content']}")
# Parse the JSON response - Pydantic will handle validation # Parse the JSON response - Pydantic will handle validation
extraction_result = ExtractionResult.model_validate_json(result["message"]["content"]) extraction_result = ExtractionResult.model_validate_json(result["message"]["content"])
@ -459,14 +516,36 @@ class GraphService:
# Find similar entities using embedding similarity # Find similar entities using embedding similarity
top_entities = await self._find_similar_entities(query, graph.entities, k) top_entities = await self._find_similar_entities(query, graph.entities, k)
else: else:
# Use extracted entities directly # Use entity resolution to handle variants of the same entity
entity_map = {entity.label.lower(): entity for entity in graph.entities} settings = get_settings()
# First, create combined list of query entities and graph entities for resolution
combined_entities = query_entities + graph.entities
# Resolve entities to identify variants if enabled
if settings.ENABLE_ENTITY_RESOLUTION:
logger.info(f"Resolving {len(combined_entities)} entities from query and graph...")
resolved_entities, entity_mapping = await self.entity_resolver.resolve_entities(combined_entities)
else:
logger.info("Entity resolution is disabled in settings.")
# Return identity mapping (each entity maps to itself)
entity_mapping = {entity.label: entity.label for entity in combined_entities}
resolved_entities = combined_entities
# Create a mapping of resolved entity labels to graph entities
entity_map = {}
for entity in graph.entities:
# Get canonical form for this entity
canonical_label = entity_mapping.get(entity.label, entity.label)
entity_map[canonical_label.lower()] = entity
matched_entities = [] matched_entities = []
# Match extracted entities with graph entities using canonical labels
# Match extracted entities with graph entities
for query_entity in query_entities: for query_entity in query_entities:
if query_entity.label.lower() in entity_map: # Get canonical form for this query entity
matched_entities.append(entity_map[query_entity.label.lower()]) canonical_query = entity_mapping.get(query_entity.label, query_entity.label)
if canonical_query.lower() in entity_map:
matched_entities.append(entity_map[canonical_query.lower()])
# If no matches, fallback to embedding similarity # If no matches, fallback to embedding similarity
if matched_entities: if matched_entities:

View File

@ -105,10 +105,12 @@ mode = "self_hosted" # "cloud" or "self_hosted"
[graph] [graph]
provider = "ollama" provider = "ollama"
model_name = "llama3.2" model_name = "llama3.2"
enable_entity_resolution = true
# [graph] # [graph]
# provider = "openai" # provider = "openai"
# model_name = "gpt-4o-mini" # model_name = "gpt-4o-mini"
# enable_entity_resolution = true
[telemetry] [telemetry]
enabled = true enabled = true

View File

@ -12,4 +12,4 @@ __all__ = [
"Document", "Document",
] ]
__version__ = "0.2.6" __version__ = "0.2.7"

View File

@ -4,7 +4,7 @@ build-backend = "hatchling.build"
[project] [project]
name = "databridge-client" name = "databridge-client"
version = "0.2.6" version = "0.2.7"
authors = [ authors = [
{ name = "DataBridge", email = "databridgesuperuser@gmail.com" }, { name = "DataBridge", email = "databridgesuperuser@gmail.com" },
] ]

View File

@ -92,12 +92,10 @@ const ForceGraphComponent: React.FC<ForceGraphComponentProps> = ({
.linkDirectionalArrowLength(3) .linkDirectionalArrowLength(3)
.linkDirectionalArrowRelPos(1); .linkDirectionalArrowRelPos(1);
// Add node labels if enabled // Always use nodeCanvasObject to have consistent rendering regardless of label visibility
if (showNodeLabels && graph.nodeCanvasObject) { if (graph.nodeCanvasObject) {
graph.nodeCanvasObject((node: NodeObject, ctx: CanvasRenderingContext2D, globalScale: number) => { graph.nodeCanvasObject((node: NodeObject, ctx: CanvasRenderingContext2D, globalScale: number) => {
// Draw the node circle // Draw the node circle
const label = node.label;
const fontSize = 12/globalScale;
const nodeR = 5; const nodeR = 5;
if (typeof node.x !== 'number' || typeof node.y !== 'number') return; if (typeof node.x !== 'number' || typeof node.y !== 'number') return;
@ -110,31 +108,35 @@ const ForceGraphComponent: React.FC<ForceGraphComponentProps> = ({
ctx.fillStyle = node.color; ctx.fillStyle = node.color;
ctx.fill(); ctx.fill();
// Draw the text label // Only draw the text label if showNodeLabels is true
ctx.font = `${fontSize}px Sans-Serif`; if (showNodeLabels) {
ctx.textAlign = 'center'; const label = node.label;
ctx.textBaseline = 'middle'; const fontSize = 12/globalScale;
ctx.fillStyle = 'black';
ctx.font = `${fontSize}px Sans-Serif`;
// Add a background for better readability ctx.textAlign = 'center';
const textWidth = ctx.measureText(label).width; ctx.textBaseline = 'middle';
const bckgDimensions = [textWidth, fontSize].map(n => n + fontSize * 0.2);
// Add a background for better readability
ctx.fillStyle = 'rgba(255, 255, 255, 0.8)'; const textWidth = ctx.measureText(label).width;
ctx.fillRect( const bckgDimensions = [textWidth, fontSize].map(n => n + fontSize * 0.2);
x - bckgDimensions[0] / 2,
y - bckgDimensions[1] / 2, ctx.fillStyle = 'rgba(255, 255, 255, 0.8)';
bckgDimensions[0], ctx.fillRect(
bckgDimensions[1] x - bckgDimensions[0] / 2,
); y - bckgDimensions[1] / 2,
bckgDimensions[0],
ctx.fillStyle = 'black'; bckgDimensions[1]
ctx.fillText(label, x, y); );
ctx.fillStyle = 'black';
ctx.fillText(label, x, y);
}
}); });
} }
// Add link labels if enabled // Always use linkCanvasObject for consistent rendering
if (showLinkLabels && graph.linkCanvasObject) { if (graph.linkCanvasObject) {
graph.linkCanvasObject((link: LinkObject, ctx: CanvasRenderingContext2D, globalScale: number) => { graph.linkCanvasObject((link: LinkObject, ctx: CanvasRenderingContext2D, globalScale: number) => {
// Draw the link line // Draw the link line
const start = link.source as NodeObject; const start = link.source as NodeObject;
@ -155,55 +157,59 @@ const ForceGraphComponent: React.FC<ForceGraphComponentProps> = ({
ctx.lineWidth = 1; ctx.lineWidth = 1;
ctx.stroke(); ctx.stroke();
// Draw the label at the middle of the link // Draw arrowhead regardless of label visibility
const label = link.type; const arrowLength = 5;
if (label) { const dx = endX - startX;
const fontSize = 10/globalScale; const dy = endY - startY;
ctx.font = `${fontSize}px Sans-Serif`; const angle = Math.atan2(dy, dx);
// Calculate middle point // Calculate a position near the target for the arrow
const middleX = startX + (endX - startX) / 2; const arrowDistance = 15; // Distance from target node
const middleY = startY + (endY - startY) / 2; const arrowX = endX - Math.cos(angle) * arrowDistance;
const arrowY = endY - Math.sin(angle) * arrowDistance;
// Add a background for better readability
const textWidth = ctx.measureText(label).width; ctx.beginPath();
const bckgDimensions = [textWidth, fontSize].map(n => n + fontSize * 0.2); ctx.moveTo(arrowX, arrowY);
ctx.lineTo(
ctx.fillStyle = 'rgba(255, 255, 255, 0.8)'; arrowX - arrowLength * Math.cos(angle - Math.PI / 6),
ctx.fillRect( arrowY - arrowLength * Math.sin(angle - Math.PI / 6)
middleX - bckgDimensions[0] / 2, );
middleY - bckgDimensions[1] / 2, ctx.lineTo(
bckgDimensions[0], arrowX - arrowLength * Math.cos(angle + Math.PI / 6),
bckgDimensions[1] arrowY - arrowLength * Math.sin(angle + Math.PI / 6)
); );
ctx.closePath();
ctx.textAlign = 'center'; ctx.fillStyle = 'rgba(0, 0, 0, 0.5)';
ctx.textBaseline = 'middle'; ctx.fill();
ctx.fillStyle = 'black';
ctx.fillText(label, middleX, middleY); // Only draw label if showLinkLabels is true
if (showLinkLabels) {
// Draw arrow const label = link.type;
const arrowLength = 5; if (label) {
const dx = endX - startX; const fontSize = 10/globalScale;
const dy = endY - startY; ctx.font = `${fontSize}px Sans-Serif`;
const angle = Math.atan2(dy, dx);
// Calculate middle point
const arrowX = middleX + Math.cos(angle) * arrowLength; const middleX = startX + (endX - startX) / 2;
const arrowY = middleY + Math.sin(angle) * arrowLength; const middleY = startY + (endY - startY) / 2;
ctx.beginPath(); // Add a background for better readability
ctx.moveTo(arrowX, arrowY); const textWidth = ctx.measureText(label).width;
ctx.lineTo( const bckgDimensions = [textWidth, fontSize].map(n => n + fontSize * 0.2);
arrowX - arrowLength * Math.cos(angle - Math.PI / 6),
arrowY - arrowLength * Math.sin(angle - Math.PI / 6) ctx.fillStyle = 'rgba(255, 255, 255, 0.8)';
); ctx.fillRect(
ctx.lineTo( middleX - bckgDimensions[0] / 2,
arrowX - arrowLength * Math.cos(angle + Math.PI / 6), middleY - bckgDimensions[1] / 2,
arrowY - arrowLength * Math.sin(angle + Math.PI / 6) bckgDimensions[0],
); bckgDimensions[1]
ctx.closePath(); );
ctx.fillStyle = 'rgba(0, 0, 0, 0.5)';
ctx.fill(); ctx.textAlign = 'center';
ctx.textBaseline = 'middle';
ctx.fillStyle = 'black';
ctx.fillText(label, middleX, middleY);
}
} }
}); });
} }

View File

@ -98,11 +98,17 @@ const GraphSection: React.FC<GraphSectionProps> = ({ apiBaseUrl }) => {
color: entityTypeColors[entity.type.toLowerCase()] || entityTypeColors.default color: entityTypeColors[entity.type.toLowerCase()] || entityTypeColors.default
})); }));
const links = graph.relationships.map(rel => ({ // Create a Set of all entity IDs for faster lookups
source: rel.source_id, const nodeIdSet = new Set(graph.entities.map(entity => entity.id));
target: rel.target_id,
type: rel.type // Filter relationships to only include those where both source and target nodes exist
})); const links = graph.relationships
.filter(rel => nodeIdSet.has(rel.source_id) && nodeIdSet.has(rel.target_id))
.map(rel => ({
source: rel.source_id,
target: rel.target_id,
type: rel.type
}));
return { nodes, links }; return { nodes, links };
}, []); }, []);
@ -251,7 +257,7 @@ const GraphSection: React.FC<GraphSectionProps> = ({ apiBaseUrl }) => {
initializeGraph(); initializeGraph();
}, 100); }, 100);
} }
}, [selectedGraph, activeTab, initializeGraph]); }, [selectedGraph, activeTab, initializeGraph, showNodeLabels, showLinkLabels]);
// Handle tab change // Handle tab change
const handleTabChange = (value: string) => { const handleTabChange = (value: string) => {