mirror of
https://github.com/james-m-jordan/morphik-core.git
synced 2025-05-09 19:32:38 +00:00
272 lines
10 KiB
Python
272 lines
10 KiB
Python
"""Knowledge graph query and management tools."""
|
|
|
|
import json
|
|
import logging
|
|
from typing import List, Literal, Optional
|
|
|
|
from core.models.auth import AuthContext
|
|
from core.models.graph import Entity
|
|
from core.services.document_service import DocumentService
|
|
from core.tools.document_tools import ToolError
|
|
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
async def knowledge_graph_query(
|
|
query_type: Literal["list_entities", "entity", "path", "subgraph"],
|
|
start_nodes: List[str],
|
|
max_depth: int = 3,
|
|
graph_name: Optional[str] = None,
|
|
document_service: DocumentService = None,
|
|
auth: AuthContext = None,
|
|
end_user_id: Optional[str] = None,
|
|
) -> str:
|
|
"""
|
|
Query the knowledge graph for entities, relationships, and connections.
|
|
|
|
Args:
|
|
query_type: Type of knowledge graph query
|
|
start_nodes: Starting entity/entities for the query
|
|
max_depth: Maximum path length/depth to explore
|
|
graph_name: Name of the graph to query (optional)
|
|
document_service: DocumentService instance
|
|
auth: Authentication context
|
|
end_user_id: Optional end-user ID
|
|
|
|
Returns:
|
|
Knowledge graph query results as a string
|
|
"""
|
|
if document_service is None:
|
|
raise ToolError("Document service not provided")
|
|
|
|
try:
|
|
# Get graph service from document service
|
|
graph_service = document_service.graph_service
|
|
|
|
# Set up system filters for proper scoping
|
|
system_filters = {}
|
|
if end_user_id:
|
|
system_filters["end_user_id"] = end_user_id
|
|
|
|
# If no graph name provided, try to find an appropriate graph
|
|
if not graph_name:
|
|
# Get all available graphs for the user
|
|
available_graphs = await graph_service.db.list_graphs(auth, system_filters=system_filters)
|
|
|
|
if not available_graphs:
|
|
raise ToolError("No graphs found. Please create a graph first.")
|
|
|
|
# If user has only one graph, use that
|
|
if len(available_graphs) == 1:
|
|
graph_name = available_graphs[0].name
|
|
logger.info(f"Using the only available graph: {graph_name}")
|
|
else:
|
|
# Use "graph_main" as default if available
|
|
main_graph = next((g for g in available_graphs if g.name == "graph_main"), None)
|
|
if main_graph:
|
|
graph_name = "graph_main"
|
|
else:
|
|
# Otherwise use the first available graph
|
|
graph_name = available_graphs[0].name
|
|
|
|
logger.info(f"Multiple graphs available, using {graph_name}")
|
|
|
|
# Get the graph for the authorized user
|
|
graph = await graph_service.db.get_graph(graph_name, auth, system_filters=system_filters)
|
|
if not graph:
|
|
raise ToolError(f"Graph '{graph_name}' not found or not accessible")
|
|
|
|
# Create entity map for faster lookups
|
|
entity_map = {entity.id: entity for entity in graph.entities}
|
|
entity_by_label = {entity.label.lower(): entity for entity in graph.entities}
|
|
|
|
match query_type:
|
|
case "list_entities":
|
|
entities = await graph_service._find_similar_entities(start_nodes[0], graph.entities, 10)
|
|
results = [
|
|
{
|
|
"id": entity.id,
|
|
"label": entity.label,
|
|
"type": entity.type,
|
|
"properties": entity.properties,
|
|
"similarity_score": score,
|
|
}
|
|
for entity, score in entities
|
|
]
|
|
return json.dumps(results, indent=2)
|
|
|
|
case "entity":
|
|
if not start_nodes or len(start_nodes) == 0:
|
|
raise ToolError("Entity ID or label is required for entity query")
|
|
|
|
# Try to find entity by ID first, then by label
|
|
entity = entity_map.get(start_nodes[0])
|
|
|
|
# If not found by ID, try looking up by label (case-insensitive)
|
|
if not entity:
|
|
entity = entity_by_label.get(start_nodes[0].lower())
|
|
|
|
if not entity:
|
|
raise ToolError(f"Entity '{start_nodes[0]}' not found in the knowledge graph")
|
|
|
|
# Return entity details
|
|
return json.dumps(
|
|
{
|
|
"id": entity.id,
|
|
"label": entity.label,
|
|
"type": entity.type,
|
|
"properties": entity.properties,
|
|
"document_ids": entity.document_ids,
|
|
},
|
|
indent=2,
|
|
)
|
|
|
|
case "path":
|
|
if not start_nodes or len(start_nodes) < 2:
|
|
raise ToolError("Exactly two entity IDs/labels are required for path query")
|
|
|
|
# Find source and target entities
|
|
source = entity_map.get(start_nodes[0])
|
|
if not source:
|
|
source = entity_by_label.get(start_nodes[0].lower())
|
|
|
|
target = entity_map.get(start_nodes[1])
|
|
if not target:
|
|
target = entity_by_label.get(start_nodes[1].lower())
|
|
|
|
if not source or not target:
|
|
missing = []
|
|
if not source:
|
|
missing.append(start_nodes[0])
|
|
if not target:
|
|
missing.append(start_nodes[1])
|
|
raise ToolError(f"Entities not found: {', '.join(missing)}")
|
|
|
|
# Find all paths between the two entities
|
|
paths = graph_service._find_relationship_paths(graph, [source], max_depth)
|
|
|
|
# Filter paths that end with the target entity
|
|
target_paths = []
|
|
for path in paths:
|
|
if path and path[-1] == target.label:
|
|
target_paths.append(path)
|
|
|
|
if not target_paths:
|
|
return f"No path found between '{source.label}' and '{target.label}' within {max_depth} hops"
|
|
|
|
# Format results as human-readable paths
|
|
formatted_paths = []
|
|
for path in target_paths:
|
|
formatted_paths.append(" -> ".join(path))
|
|
|
|
return json.dumps({"source": source.label, "target": target.label, "paths": formatted_paths}, indent=2)
|
|
|
|
case "subgraph":
|
|
if not start_nodes or len(start_nodes) == 0:
|
|
raise ToolError("Entity ID or label is required for subgraph query")
|
|
|
|
# Find central entity
|
|
central = entity_map.get(start_nodes[0])
|
|
if not central:
|
|
central = entity_by_label.get(start_nodes[0].lower())
|
|
|
|
if not central:
|
|
raise ToolError(f"Entity '{start_nodes[0]}' not found in the knowledge graph")
|
|
|
|
# Get related entities through graph traversal
|
|
related_entities = graph_service._expand_entities(graph, [central], max_depth)
|
|
|
|
# Create a set of entity IDs in the subgraph
|
|
subgraph_entity_ids = {entity.id for entity in related_entities}
|
|
|
|
# Find relationships between these entities
|
|
subgraph_relationships = [
|
|
rel
|
|
for rel in graph.relationships
|
|
if rel.source_id in subgraph_entity_ids and rel.target_id in subgraph_entity_ids
|
|
]
|
|
|
|
# Format results
|
|
nodes = []
|
|
for entity in related_entities:
|
|
nodes.append(
|
|
{"id": entity.id, "label": entity.label, "type": entity.type, "properties": entity.properties}
|
|
)
|
|
|
|
relationships = []
|
|
for rel in subgraph_relationships:
|
|
source_label = entity_map.get(rel.source_id, Entity(label="Unknown")).label
|
|
target_label = entity_map.get(rel.target_id, Entity(label="Unknown")).label
|
|
|
|
relationships.append(
|
|
{
|
|
"id": rel.id,
|
|
"source": rel.source_id,
|
|
"source_label": source_label,
|
|
"target": rel.target_id,
|
|
"target_label": target_label,
|
|
"type": rel.type,
|
|
}
|
|
)
|
|
|
|
return json.dumps({"nodes": nodes, "relationships": relationships}, indent=2)
|
|
|
|
except Exception as e:
|
|
raise ToolError(f"Error querying knowledge graph: {str(e)}")
|
|
|
|
|
|
async def list_graphs(
    document_service: DocumentService = None, auth: AuthContext = None, end_user_id: Optional[str] = None
) -> str:
    """
    Summarise every knowledge graph visible to the authorized user.

    For each graph the summary holds its name, the number of documents,
    entities and relationships, and ISO-formatted creation/update timestamps
    (``None`` when the attribute is absent or empty).

    Args:
        document_service: DocumentService instance
        auth: Authentication context
        end_user_id: Optional end-user ID for scoping

    Returns:
        JSON string with a "message" and a "graphs" list

    Raises:
        ToolError: if no document service is given or the listing fails
    """
    if document_service is None:
        raise ToolError("Document service not provided")

    def _iso_or_none(graph_obj, attr_name):
        # Missing attribute and falsy value both map to None.
        stamp = getattr(graph_obj, attr_name, None)
        return stamp.isoformat() if stamp else None

    try:
        service = document_service.graph_service

        # Scope the listing to the end user only when one was supplied
        scope = {"end_user_id": end_user_id} if end_user_id else {}

        found = await service.db.list_graphs(auth, system_filters=scope)
        if not found:
            return json.dumps({"message": "No graphs found", "graphs": []})

        summaries = [
            {
                "name": item.name,
                "document_count": len(item.document_ids),
                "entity_count": len(item.entities),
                "relationship_count": len(item.relationships),
                "created_at": _iso_or_none(item, "created_at"),
                "updated_at": _iso_or_none(item, "updated_at"),
            }
            for item in found
        ]

        payload = {"message": f"Found {len(summaries)} graph(s)", "graphs": summaries}
        return json.dumps(payload, indent=2)

    except Exception as e:
        raise ToolError(f"Error listing graphs: {str(e)}")
|