morphik-core/core/services/graph_service.py

import logging
import numpy as np
import json
from typing import Dict, Any, List, Optional, Tuple, Set
from pydantic import BaseModel
from datetime import datetime, timezone
from core.models.prompts import (
GraphPromptOverrides,
QueryPromptOverrides,
EntityExtractionPromptOverride
)
from core.models.completion import ChunkSource, CompletionResponse, CompletionRequest
from core.models.graph import Graph, Entity, Relationship
from core.models.auth import AuthContext
from core.embedding.base_embedding_model import BaseEmbeddingModel
from core.completion.base_completion import BaseCompletionModel
from core.database.base_database import BaseDatabase
from core.models.documents import Document, ChunkResult
from core.config import get_settings
from core.services.entity_resolution import EntityResolver
logger = logging.getLogger(__name__)
class EntityExtraction(BaseModel):
"""Model for entity extraction results"""
label: str
type: str
properties: Dict[str, Any] = {}
class RelationshipExtraction(BaseModel):
"""Model for relationship extraction results"""
source: str
target: str
relationship: str
class ExtractionResult(BaseModel):
"""Model for structured extraction from LLM"""
entities: List[EntityExtraction] = []
relationships: List[RelationshipExtraction] = []
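    # Illustrative (hypothetical) payload the extraction LLM is expected to return,
    # which pydantic parses into this model:
    # {
    #   "entities": [{"label": "Ada Lovelace", "type": "PERSON", "properties": {}}],
    #   "relationships": [{"source": "Ada Lovelace", "target": "Analytical Engine", "relationship": "worked on"}]
    # }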
class GraphService:
"""Service for managing knowledge graphs and graph-based operations"""
def __init__(
self,
db: BaseDatabase,
embedding_model: BaseEmbeddingModel,
completion_model: BaseCompletionModel,
):
self.db = db
self.embedding_model = embedding_model
self.completion_model = completion_model
self.entity_resolver = EntityResolver()
async def update_graph(
self,
name: str,
auth: AuthContext,
document_service, # Passed in to avoid circular import
additional_filters: Optional[Dict[str, Any]] = None,
additional_documents: Optional[List[str]] = None,
prompt_overrides: Optional[GraphPromptOverrides] = None,
system_filters: Optional[Dict[str, Any]] = None,
) -> Graph:
"""Update an existing graph with new documents.
This function processes additional documents matching the original or new filters,
extracts entities and relationships, and updates the graph with new information.
Args:
name: Name of the graph to update
auth: Authentication context
document_service: DocumentService instance for retrieving documents and chunks
additional_filters: Optional additional metadata filters to determine which new documents to include
additional_documents: Optional list of specific additional document IDs to include
prompt_overrides: Optional GraphPromptOverrides with customizations for prompts
system_filters: Optional system metadata filters (e.g. folder_name, end_user_id) to determine which documents to include
Returns:
Graph: The updated graph
"""
# Initialize system_filters if None
if system_filters is None:
system_filters = {}
if "write" not in auth.permissions:
raise PermissionError("User does not have write permission")
# Get the existing graph
existing_graph = await self.db.get_graph(name, auth)
if not existing_graph:
raise ValueError(f"Graph '{name}' not found")
# Track explicitly added documents to ensure they're included in the final graph
# even if they don't have new entities or relationships
explicit_doc_ids = set(additional_documents or [])
# Find new documents to process
document_ids = await self._get_new_document_ids(
auth, existing_graph, additional_filters, additional_documents, system_filters
)
if not document_ids and not explicit_doc_ids:
# No new documents to add
return existing_graph
# Create a set for all document IDs that should be included in the updated graph
# Includes existing document IDs, explicitly added document IDs, and documents found via filters
all_doc_ids = set(existing_graph.document_ids).union(document_ids).union(explicit_doc_ids)
logger.info(f"Total document IDs to include in updated graph: {len(all_doc_ids)}")
# Batch retrieve all document IDs (both regular and explicit) in a single call
all_ids_to_retrieve = list(document_ids)
# Add explicit document IDs if not already included
if explicit_doc_ids and additional_documents:
# Add any missing IDs to the list
for doc_id in additional_documents:
if doc_id not in document_ids:
all_ids_to_retrieve.append(doc_id)
# Batch retrieve all documents in a single call
document_objects = await document_service.batch_retrieve_documents(
all_ids_to_retrieve, auth, system_filters.get("folder_name", None), system_filters.get("end_user_id", None)
)
# Process explicit documents if needed
if explicit_doc_ids and additional_documents:
# Extract authorized explicit IDs from the retrieved documents
authorized_explicit_ids = {
doc.external_id for doc in document_objects if doc.external_id in explicit_doc_ids
}
logger.info(
f"Authorized explicit document IDs: {len(authorized_explicit_ids)} out of {len(explicit_doc_ids)}"
)
# Update document_ids and all_doc_ids
document_ids.update(authorized_explicit_ids)
all_doc_ids.update(authorized_explicit_ids)
# If we have additional filters, make sure we include the document IDs from filter matches
# even if they don't have new entities or relationships
if additional_filters:
filtered_docs = await document_service.batch_retrieve_documents(
[
doc_id
for doc_id in all_doc_ids
if doc_id not in {d.external_id for d in document_objects}
],
auth,
system_filters.get("folder_name", None),
system_filters.get("end_user_id", None)
)
logger.info(f"Additional filtered documents to include: {len(filtered_docs)}")
document_objects.extend(filtered_docs)
if not document_objects:
# No authorized new documents
return existing_graph
# Validation is now handled by type annotations
# Extract entities and relationships from new documents
new_entities_dict, new_relationships = await self._process_documents_for_entities(
document_objects, auth, document_service, prompt_overrides
)
# Track document IDs that need to be included even without entities/relationships
additional_doc_ids = {doc.external_id for doc in document_objects}
# Merge new entities and relationships with existing ones
existing_graph = self._merge_graph_data(
existing_graph,
new_entities_dict,
new_relationships,
all_doc_ids,
additional_filters,
additional_doc_ids,
)
# Store the updated graph in the database
if not await self.db.update_graph(existing_graph):
raise Exception("Failed to update graph")
return existing_graph
async def _get_new_document_ids(
self,
auth: AuthContext,
existing_graph: Graph,
additional_filters: Optional[Dict[str, Any]] = None,
additional_documents: Optional[List[str]] = None,
system_filters: Optional[Dict[str, Any]] = None,
) -> Set[str]:
"""Get IDs of new documents to add to the graph."""
# Initialize system_filters if None
if system_filters is None:
system_filters = {}
# Initialize with explicitly specified documents, ensuring it's a set
document_ids = set(additional_documents or [])
# Process documents matching additional filters
if additional_filters or system_filters:
filtered_docs = await self.db.get_documents(auth, filters=additional_filters, system_filters=system_filters)
filter_doc_ids = {doc.external_id for doc in filtered_docs}
logger.info(f"Found {len(filter_doc_ids)} documents matching additional filters and system filters")
document_ids.update(filter_doc_ids)
# Process documents matching the original filters
if existing_graph.filters:
# Original filters shouldn't include system filters, as we're applying them separately
filtered_docs = await self.db.get_documents(auth, filters=existing_graph.filters, system_filters=system_filters)
orig_filter_doc_ids = {doc.external_id for doc in filtered_docs}
logger.info(f"Found {len(orig_filter_doc_ids)} documents matching original filters and system filters")
document_ids.update(orig_filter_doc_ids)
# Get only the document IDs that are not already in the graph
new_doc_ids = document_ids - set(existing_graph.document_ids)
logger.info(
f"Found {len(new_doc_ids)} new documents to add to graph '{existing_graph.name}'"
)
return new_doc_ids
def _merge_graph_data(
self,
existing_graph: Graph,
new_entities_dict: Dict[str, Entity],
new_relationships: List[Relationship],
document_ids: Set[str],
additional_filters: Optional[Dict[str, Any]] = None,
additional_doc_ids: Optional[Set[str]] = None,
) -> Graph:
"""Merge new entities and relationships with existing graph data."""
# Create a mapping of existing entities by label for merging
existing_entities_dict = {entity.label: entity for entity in existing_graph.entities}
# Merge entities
merged_entities = self._merge_entities(existing_entities_dict, new_entities_dict)
# Create a mapping of entity labels to IDs for new relationships
entity_id_map = {entity.label: entity.id for entity in merged_entities.values()}
# Merge relationships
merged_relationships = self._merge_relationships(
existing_graph.relationships, new_relationships, new_entities_dict, entity_id_map
)
# Update the graph
existing_graph.entities = list(merged_entities.values())
existing_graph.relationships = merged_relationships
# Ensure we include all necessary document IDs:
# 1. All document IDs from document_ids parameter
# 2. All document IDs that have authorized documents (from additional_doc_ids)
final_doc_ids = document_ids.copy()
if additional_doc_ids:
final_doc_ids.update(additional_doc_ids)
logger.info(f"Final document count in graph: {len(final_doc_ids)}")
existing_graph.document_ids = list(final_doc_ids)
existing_graph.updated_at = datetime.now(timezone.utc)
# Update filters if additional filters were provided
if additional_filters and existing_graph.filters:
# Smarter filter merging
self._smart_merge_filters(existing_graph.filters, additional_filters)
return existing_graph
def _smart_merge_filters(
self, existing_filters: Dict[str, Any], additional_filters: Dict[str, Any]
):
"""Merge filters with more intelligence to handle different data types and filter values."""
for key, value in additional_filters.items():
# If the key doesn't exist in existing filters, just add it
if key not in existing_filters:
existing_filters[key] = value
continue
existing_value = existing_filters[key]
# Handle list values - merge them
if isinstance(existing_value, list) and isinstance(value, list):
# Union the lists without duplicates
existing_filters[key] = list(set(existing_value + value))
# Handle dict values - recursively merge them
elif isinstance(existing_value, dict) and isinstance(value, dict):
# Recursive merge for nested dictionaries
self._smart_merge_filters(existing_value, value)
# Default to overwriting with the new value
else:
existing_filters[key] = value
def _merge_entities(
self, existing_entities: Dict[str, Entity], new_entities: Dict[str, Entity]
) -> Dict[str, Entity]:
"""Merge new entities with existing entities."""
merged_entities = existing_entities.copy()
for label, new_entity in new_entities.items():
if label in merged_entities:
# Entity exists, merge chunk sources and document IDs
existing_entity = merged_entities[label]
# Merge document IDs
for doc_id in new_entity.document_ids:
if doc_id not in existing_entity.document_ids:
existing_entity.document_ids.append(doc_id)
# Merge chunk sources
for doc_id, chunk_numbers in new_entity.chunk_sources.items():
if doc_id not in existing_entity.chunk_sources:
existing_entity.chunk_sources[doc_id] = chunk_numbers
else:
for chunk_num in chunk_numbers:
if chunk_num not in existing_entity.chunk_sources[doc_id]:
existing_entity.chunk_sources[doc_id].append(chunk_num)
else:
# Add new entity
merged_entities[label] = new_entity
return merged_entities
def _merge_relationships(
self,
existing_relationships: List[Relationship],
new_relationships: List[Relationship],
new_entities_dict: Dict[str, Entity],
entity_id_map: Dict[str, str],
) -> List[Relationship]:
"""Merge new relationships with existing ones."""
merged_relationships = list(existing_relationships)
# Create reverse mappings for entity IDs to labels for efficient lookup
entity_id_to_label = {entity.id: label for label, entity in new_entities_dict.items()}
for rel in new_relationships:
# Look up entity labels using the reverse mapping
source_label = entity_id_to_label.get(rel.source_id)
target_label = entity_id_to_label.get(rel.target_id)
if source_label in entity_id_map and target_label in entity_id_map:
# Update relationship to use existing entity IDs
rel.source_id = entity_id_map[source_label]
rel.target_id = entity_id_map[target_label]
# Check if this relationship already exists
is_duplicate = False
for existing_rel in existing_relationships:
if (
existing_rel.source_id == rel.source_id
and existing_rel.target_id == rel.target_id
and existing_rel.type == rel.type
):
# Found a duplicate, merge the chunk sources
is_duplicate = True
self._merge_relationship_sources(existing_rel, rel)
break
if not is_duplicate:
merged_relationships.append(rel)
return merged_relationships
def _merge_relationship_sources(
self, existing_rel: Relationship, new_rel: Relationship
) -> None:
"""Merge chunk sources and document IDs from new relationship into existing one."""
# Merge chunk sources
for doc_id, chunk_numbers in new_rel.chunk_sources.items():
if doc_id not in existing_rel.chunk_sources:
existing_rel.chunk_sources[doc_id] = chunk_numbers
else:
for chunk_num in chunk_numbers:
if chunk_num not in existing_rel.chunk_sources[doc_id]:
existing_rel.chunk_sources[doc_id].append(chunk_num)
# Merge document IDs
for doc_id in new_rel.document_ids:
if doc_id not in existing_rel.document_ids:
existing_rel.document_ids.append(doc_id)
async def create_graph(
self,
name: str,
auth: AuthContext,
document_service, # Passed in to avoid circular import
filters: Optional[Dict[str, Any]] = None,
documents: Optional[List[str]] = None,
prompt_overrides: Optional[GraphPromptOverrides] = None,
system_filters: Optional[Dict[str, Any]] = None,
) -> Graph:
"""Create a graph from documents.
This function processes documents matching filters or specific document IDs,
extracts entities and relationships from document chunks, and saves them as a graph.
Args:
name: Name of the graph to create
auth: Authentication context
document_service: DocumentService instance for retrieving documents and chunks
filters: Optional metadata filters to determine which documents to include
documents: Optional list of specific document IDs to include
prompt_overrides: Optional GraphPromptOverrides with customizations for prompts
system_filters: Optional system metadata filters (e.g. folder_name, end_user_id) to determine which documents to include
Returns:
Graph: The created graph
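
        Example (illustrative; names like ``graph_service`` and the filter values are hypothetical):
            graph = await graph_service.create_graph(
                name="research-papers",
                auth=auth,
                document_service=document_service,
                filters={"category": "research"},
            )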
"""
# Initialize system_filters if None
if system_filters is None:
system_filters = {}
if "write" not in auth.permissions:
raise PermissionError("User does not have write permission")
# Find documents to process based on filters and/or specific document IDs
document_ids = set(documents or [])
# If filters were provided, get matching documents
if filters or system_filters:
filtered_docs = await self.db.get_documents(auth, filters=filters, system_filters=system_filters)
document_ids.update(doc.external_id for doc in filtered_docs)
if not document_ids:
raise ValueError("No documents found matching criteria")
# Convert system_filters for document retrieval
folder_name = system_filters.get("folder_name") if system_filters else None
end_user_id = system_filters.get("end_user_id") if system_filters else None
# Batch retrieve documents for authorization check
document_objects = await document_service.batch_retrieve_documents(
list(document_ids),
auth,
folder_name,
end_user_id
)
# Log for debugging
logger.info(f"Graph creation with folder_name={folder_name}, end_user_id={end_user_id}")
logger.info(f"Documents retrieved: {len(document_objects)} out of {len(document_ids)} requested")
if not document_objects:
raise ValueError("No authorized documents found matching criteria")
# Validation is now handled by type annotations
# Create a new graph with authorization info
graph = Graph(
name=name,
document_ids=[doc.external_id for doc in document_objects],
filters=filters,
owner={"type": auth.entity_type, "id": auth.entity_id},
access_control={
"readers": [auth.entity_id],
"writers": [auth.entity_id],
"admins": [auth.entity_id],
},
)
# Add folder_name and end_user_id to system_metadata if provided
if system_filters:
if "folder_name" in system_filters:
graph.system_metadata["folder_name"] = system_filters["folder_name"]
if "end_user_id" in system_filters:
graph.system_metadata["end_user_id"] = system_filters["end_user_id"]
# Extract entities and relationships
entities, relationships = await self._process_documents_for_entities(
document_objects, auth, document_service, prompt_overrides
)
# Add entities and relationships to the graph
graph.entities = list(entities.values())
graph.relationships = relationships
# Store the graph in the database
if not await self.db.store_graph(graph):
raise Exception("Failed to store graph")
return graph
async def _process_documents_for_entities(
self,
documents: List[Document],
auth: AuthContext,
document_service,
prompt_overrides: Optional[GraphPromptOverrides] = None,
) -> Tuple[Dict[str, Entity], List[Relationship]]:
"""Process documents to extract entities and relationships.
Args:
documents: List of documents to process
auth: Authentication context
document_service: DocumentService instance for retrieving chunks
            prompt_overrides: Optional GraphPromptOverrides with prompt customizations, for example:
{
"entity_resolution": {
"prompt_template": "Custom template...",
"examples": [{"canonical": "...", "variants": [...]}]
}
}
Returns:
Tuple of (entities_dict, relationships_list)
"""
# Dictionary to collect entities by label (to avoid duplicates)
entities = {}
# List to collect all relationships
relationships = []
# List to collect all extracted entities for resolution
all_entities = []
# Track all initial entities with their labels to fix relationship mapping
initial_entities = []
# Collect all chunk sources from documents.
chunk_sources = [
ChunkSource(document_id=doc.external_id, chunk_number=i)
for doc in documents
for i, _ in enumerate(doc.chunk_ids)
]
# Batch retrieve chunks
chunks = await document_service.batch_retrieve_chunks(chunk_sources, auth)
logger.info(f"Retrieved {len(chunks)} chunks for processing")
# Process each chunk individually
for chunk in chunks:
try:
# Get entity_extraction override if provided
extraction_overrides = None
if prompt_overrides:
# Get entity_extraction from the model
extraction_overrides = prompt_overrides.entity_extraction
# Extract entities and relationships from the chunk
chunk_entities, chunk_relationships = await self.extract_entities_from_text(
chunk.content, chunk.document_id, chunk.chunk_number, extraction_overrides
)
# Store all initially extracted entities to track their IDs
initial_entities.extend(chunk_entities)
# Add entities to the collection, avoiding duplicates based on exact label match
for entity in chunk_entities:
if entity.label not in entities:
# For new entities, initialize chunk_sources with the current chunk
entities[entity.label] = entity
all_entities.append(entity)
else:
# If entity already exists, add this chunk source if not already present
existing_entity = entities[entity.label]
# Add to chunk_sources dictionary
if chunk.document_id not in existing_entity.chunk_sources:
existing_entity.chunk_sources[chunk.document_id] = [chunk.chunk_number]
elif (
chunk.chunk_number
not in existing_entity.chunk_sources[chunk.document_id]
):
existing_entity.chunk_sources[chunk.document_id].append(
chunk.chunk_number
)
# Add the current chunk source to each relationship
for relationship in chunk_relationships:
# Add to chunk_sources dictionary
if chunk.document_id not in relationship.chunk_sources:
relationship.chunk_sources[chunk.document_id] = [chunk.chunk_number]
elif chunk.chunk_number not in relationship.chunk_sources[chunk.document_id]:
relationship.chunk_sources[chunk.document_id].append(chunk.chunk_number)
# Add relationships to the collection
relationships.extend(chunk_relationships)
except ValueError as e:
# Handle specific extraction errors we've wrapped
logger.warning(
f"Skipping chunk {chunk.chunk_number} in document {chunk.document_id}: {e}"
)
continue
except Exception as e:
# For other errors, log and re-raise to abort graph creation
logger.error(
f"Fatal error processing chunk {chunk.chunk_number} in document {chunk.document_id}: {e}"
)
raise
# Build a mapping from entity ID to label for ALL initially extracted entities
original_entity_id_to_label = {entity.id: entity.label for entity in initial_entities}
# Check if entity resolution is enabled in settings
settings = get_settings()
# Resolve entities to handle variations like "Trump" vs "Donald J Trump"
if settings.ENABLE_ENTITY_RESOLUTION:
logger.info("Resolving %d entities using LLM...", len(all_entities))
# Extract entity_resolution part if this is a structured override
resolution_overrides = None
if prompt_overrides:
if hasattr(prompt_overrides, "entity_resolution"):
# Get from Pydantic model
resolution_overrides = prompt_overrides.entity_resolution
elif isinstance(prompt_overrides, dict) and "entity_resolution" in prompt_overrides:
# Get from dict
resolution_overrides = prompt_overrides["entity_resolution"]
else:
# Otherwise pass as-is
resolution_overrides = prompt_overrides
resolved_entities, entity_mapping = await self.entity_resolver.resolve_entities(
all_entities, resolution_overrides
)
logger.info("Entity resolution completed successfully")
else:
logger.info("Entity resolution is disabled in settings.")
            # Fall back to an identity mapping (each entity maps to itself)
entity_mapping = {entity.label: entity.label for entity in all_entities}
resolved_entities = all_entities
if entity_mapping:
logger.info("Entity resolution complete. Found %d mappings.", len(entity_mapping))
# Create a new entities dictionary with resolved entities
resolved_entities_dict = {}
# Build new entities dictionary with canonical labels
for entity in resolved_entities:
resolved_entities_dict[entity.label] = entity
# Update relationships to use canonical entity labels
updated_relationships = []
# Remap relationships using original entity ID to label mapping
remapped_count = 0
skipped_count = 0
for relationship in relationships:
# Use original_entity_id_to_label to get the labels for relationship endpoints
original_source_label = original_entity_id_to_label.get(relationship.source_id)
original_target_label = original_entity_id_to_label.get(relationship.target_id)
if not original_source_label or not original_target_label:
logger.warning(
f"Skipping relationship with type '{relationship.type}' - could not find original entity labels"
)
skipped_count += 1
continue
# Find canonical labels using the mapping from the resolver
source_canonical = entity_mapping.get(original_source_label, original_source_label)
target_canonical = entity_mapping.get(original_target_label, original_target_label)
# Find the final unique Entity objects using the canonical labels
canonical_source = resolved_entities_dict.get(source_canonical)
canonical_target = resolved_entities_dict.get(target_canonical)
if canonical_source and canonical_target:
# Successfully found the final entities, update the relationship's IDs
relationship.source_id = canonical_source.id
relationship.target_id = canonical_target.id
updated_relationships.append(relationship)
remapped_count += 1
else:
# Could not map to final entities, log and skip
logger.warning(
f"Skipping relationship between '{original_source_label}' and '{original_target_label}' - "
f"canonical entities not found after resolution"
)
skipped_count += 1
logger.info(f"Remapped {remapped_count} relationships, skipped {skipped_count} relationships")
# Deduplicate relationships (same source, target, type)
final_relationships_map = {}
for rel in updated_relationships:
key = (rel.source_id, rel.target_id, rel.type)
if key not in final_relationships_map:
final_relationships_map[key] = rel
else:
# Merge sources into the existing relationship
existing_rel = final_relationships_map[key]
self._merge_relationship_sources(existing_rel, rel)
final_relationships = list(final_relationships_map.values())
logger.info(f"Deduplicated to {len(final_relationships)} unique relationships")
return resolved_entities_dict, final_relationships
        # No entity mappings were produced (e.g. no entities were extracted); return the original entities and relationships
return entities, relationships
async def extract_entities_from_text(
self,
content: str,
doc_id: str,
chunk_number: int,
prompt_overrides: Optional[EntityExtractionPromptOverride] = None,
) -> Tuple[List[Entity], List[Relationship]]:
"""
Extract entities and relationships from text content using the LLM.
Args:
content: Text content to process
doc_id: Document ID
            chunk_number: Chunk number within the document
            prompt_overrides: Optional EntityExtractionPromptOverride with a custom prompt template and examples

        Returns:
            Tuple of (entities, relationships)
"""
settings = get_settings()
# Limit text length to avoid token limits
content_limited = content[: min(len(content), 5000)]
# We'll use the Pydantic model directly when calling litellm
# No need to generate JSON schema separately
# Get entity extraction overrides if available
extraction_overrides = {}
# Convert prompt_overrides to dict for processing
if prompt_overrides:
# If it's already an EntityExtractionPromptOverride, convert to dict
extraction_overrides = prompt_overrides.model_dump(exclude_none=True)
# Check for custom prompt template
custom_prompt = extraction_overrides.get("prompt_template")
custom_examples = extraction_overrides.get("examples")
# Prepare examples if provided
examples_str = ""
if custom_examples:
# Ensure proper serialization for both dict and Pydantic model examples
if (
isinstance(custom_examples, list)
and custom_examples
and hasattr(custom_examples[0], "model_dump")
):
# List of Pydantic model objects
serialized_examples = [example.model_dump() for example in custom_examples]
else:
# List of dictionaries
serialized_examples = custom_examples
examples_json = {"entities": serialized_examples}
examples_str = f"\nHere are some examples of the kind of entities to extract:\n```json\n{json.dumps(examples_json, indent=2)}\n```\n"
        # System message instructing the model how to extract entities and relationships
system_message = {
"role": "system",
"content": (
"You are an entity extraction and relationship extraction assistant. Extract entities and their relationships from text precisely and thoroughly, extract as many entities and relationships as possible. "
"For entities, include entity label and type (some examples: PERSON, ORGANIZATION, LOCATION, CONCEPT, etc.). If the user has given examples, use those, these are just suggestions"
"For relationships, use a simple format with source, target, and relationship fields. Be very through, there are many relationships that are not obvious"
"IMPORTANT: The source and target fields must be simple strings representing entity labels. For example: "
"if you extract entities 'Entity A' and 'Entity B', a relationship would have source: 'Entity A', target: 'Entity B', relationship: 'relates to'. "
"Respond directly in json format, without any additional text or explanations. "
),
}
# Use custom prompt if provided, otherwise use default
if custom_prompt:
user_message = {
"role": "user",
"content": custom_prompt.format(content=content_limited, examples=examples_str),
}
else:
user_message = {
"role": "user",
"content": (
"Extract named entities and their relationships from the following text. "
"For entities, include entity label and type (PERSON, ORGANIZATION, LOCATION, CONCEPT, etc.). "
"For relationships, specify the source entity, target entity, and the relationship between them. "
"The source and target must be simple strings matching the entity labels, not objects. "
f"{examples_str}"
"Sample relationship format: {\"source\": \"Entity A\", \"target\": \"Entity B\", \"relationship\": \"works for\"}\n\n"
"Return your response as valid JSON:\n\n" + content_limited
),
}
# Get the model configuration from registered_models
model_config = settings.REGISTERED_MODELS.get(settings.GRAPH_MODEL, {})
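        # A registered model entry is expected to look roughly like this
        # (hypothetical values; all keys besides model_name are forwarded to the completion call):
        # {"model_name": "gpt-4o", "temperature": 0.0}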
if not model_config:
raise ValueError(
f"Model '{settings.GRAPH_MODEL}' not found in registered_models configuration"
)
# Prepare the completion request parameters
model_params = {
"model": model_config.get("model_name"),
"messages": [system_message, user_message],
"response_format": ExtractionResult,
}
# Add all model-specific parameters from the config
for key, value in model_config.items():
if key != "model_name": # Skip as we've already handled it
model_params[key] = value
import litellm
import instructor
# Use instructor with litellm to get structured responses
client = instructor.from_litellm(litellm.acompletion, mode=instructor.Mode.JSON)
try:
# Use LiteLLM with instructor for structured completion
logger.debug(f"Calling LiteLLM with instructor and params: {model_params}")
# Extract the model and messages from model_params
model = model_params.pop("model")
messages = model_params.pop("messages")
# Use instructor's chat.completions.create with response_model
response = await client.chat.completions.create(
model=model,
messages=messages,
response_model=ExtractionResult,
**model_params
)
try:
logger.info(f"Extraction result type: {type(response)}")
extraction_result = response # The response is already our Pydantic model
# Make sure the extraction_result has the expected properties
if not hasattr(extraction_result, "entities"):
extraction_result.entities = []
if not hasattr(extraction_result, "relationships"):
extraction_result.relationships = []
except AttributeError as e:
logger.error(f"Invalid response format from LiteLLM: {e}")
logger.debug(f"Raw response structure: {response.choices[0]}")
return [], []
except Exception as e:
logger.error(f"Error during entity extraction with LiteLLM: {str(e)}")
# Enable this for more verbose debugging
# litellm.set_verbose = True
return [], []
# Process extraction results
entities, relationships = self._process_extraction_results(
extraction_result, doc_id, chunk_number
)
logger.info(
f"Extracted {len(entities)} entities and {len(relationships)} relationships from document {doc_id}, chunk {chunk_number}"
)
return entities, relationships
def _process_extraction_results(
self, extraction_result: ExtractionResult, doc_id: str, chunk_number: int
) -> Tuple[List[Entity], List[Relationship]]:
"""Process extraction results into entity and relationship objects."""
# Initialize chunk_sources with the current chunk - reused across entities
chunk_sources = {doc_id: [chunk_number]}
# Convert extracted data to entity objects using list comprehension
entities = [
Entity(
label=entity.label,
type=entity.type,
properties=entity.properties,
chunk_sources=chunk_sources.copy(), # Need to copy to avoid shared reference
document_ids=[doc_id],
)
for entity in extraction_result.entities
]
# Create a mapping of entity labels to IDs
entity_mapping = {entity.label: entity.id for entity in entities}
# Convert to relationship objects using list comprehension with filtering
relationships = [
Relationship(
source_id=entity_mapping[rel.source],
target_id=entity_mapping[rel.target],
type=rel.relationship,
chunk_sources=chunk_sources.copy(), # Need to copy to avoid shared reference
document_ids=[doc_id],
)
for rel in extraction_result.relationships
if rel.source in entity_mapping and rel.target in entity_mapping
]
return entities, relationships
async def query_with_graph(
self,
query: str,
graph_name: str,
auth: AuthContext,
document_service, # Passed to avoid circular import
filters: Optional[Dict[str, Any]] = None,
k: int = 20,
min_score: float = 0.0,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
use_reranking: Optional[bool] = None,
use_colpali: Optional[bool] = None,
hop_depth: int = 1,
include_paths: bool = False,
prompt_overrides: Optional[QueryPromptOverrides] = None,
folder_name: Optional[str] = None,
end_user_id: Optional[str] = None,
) -> CompletionResponse:
"""Generate completion using knowledge graph-enhanced retrieval.
This method enhances retrieval by:
1. Extracting entities from the query
2. Finding similar entities in the graph
3. Traversing the graph to find related entities
4. Retrieving chunks containing these entities
5. Combining with traditional vector search results
6. Generating a completion with enhanced context
Args:
query: The query text
graph_name: Name of the graph to use
auth: Authentication context
document_service: DocumentService instance for retrieving documents
filters: Optional metadata filters
k: Number of chunks to retrieve
min_score: Minimum similarity score
max_tokens: Maximum tokens for completion
temperature: Temperature for completion
use_reranking: Whether to use reranking
use_colpali: Whether to use colpali embedding
hop_depth: Number of relationship hops to traverse (1-3)
include_paths: Whether to include relationship paths in response
prompt_overrides: Optional QueryPromptOverrides with customizations for prompts
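            folder_name: Optional folder scope to restrict retrieval
            end_user_id: Optional end-user scope to restrict retrieval

        Returns:
            CompletionResponse: The generated completion; graph metadata is attached when include_paths is True

        Example (illustrative; the graph name, query, and variable names are hypothetical):
            response = await graph_service.query_with_graph(
                query="Who founded Acme Corp?",
                graph_name="kb",
                auth=auth,
                document_service=document_service,
                hop_depth=2,
                include_paths=True,
            )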
"""
logger.info(f"Querying with graph: {graph_name}, hop depth: {hop_depth}")
# Validation is now handled by type annotations
# Build system filters for scoping
system_filters = {}
if folder_name:
system_filters["folder_name"] = folder_name
if end_user_id:
system_filters["end_user_id"] = end_user_id
logger.info(f"Querying graph with system_filters: {system_filters}")
graph = await self.db.get_graph(graph_name, auth, system_filters=system_filters)
if not graph:
logger.warning(f"Graph '{graph_name}' not found or not accessible")
# Fall back to standard retrieval if graph not found
return await document_service.query(
query=query,
auth=auth,
filters=filters,
k=k,
min_score=min_score,
max_tokens=max_tokens,
temperature=temperature,
use_reranking=use_reranking,
use_colpali=use_colpali,
graph_name=None,
folder_name=folder_name,
end_user_id=end_user_id,
)
        # Run two complementary retrieval strategies and merge the results
# 1. Standard vector search
vector_chunks = await document_service.retrieve_chunks(
query, auth, filters, k, min_score, use_reranking, use_colpali, folder_name, end_user_id
)
logger.info(f"Vector search retrieved {len(vector_chunks)} chunks")
# 2. Graph-based retrieval
# First extract entities from the query
query_entities = await self._extract_entities_from_query(query, prompt_overrides)
logger.info(
f"Extracted {len(query_entities)} entities from query: {', '.join(e.label for e in query_entities)}"
)
# If no entities extracted, fallback to embedding similarity
if not query_entities:
# Find similar entities using embedding similarity
top_entities = await self._find_similar_entities(query, graph.entities, k)
else:
# Use entity resolution to handle variants of the same entity
settings = get_settings()
# First, create combined list of query entities and graph entities for resolution
combined_entities = query_entities + graph.entities
# Resolve entities to identify variants if enabled
if settings.ENABLE_ENTITY_RESOLUTION:
logger.info(f"Resolving {len(combined_entities)} entities from query and graph...")
# Get the entity_resolution override if provided
resolution_overrides = None
if prompt_overrides:
# Get just the entity_resolution part
resolution_overrides = prompt_overrides.entity_resolution
resolved_entities, entity_mapping = await self.entity_resolver.resolve_entities(
combined_entities, prompt_overrides=resolution_overrides
)
else:
logger.info("Entity resolution is disabled in settings.")
                # Fall back to an identity mapping (each entity maps to itself)
entity_mapping = {entity.label: entity.label for entity in combined_entities}
# Create a mapping of resolved entity labels to graph entities
entity_map = {}
for entity in graph.entities:
# Get canonical form for this entity
canonical_label = entity_mapping.get(entity.label, entity.label)
entity_map[canonical_label.lower()] = entity
matched_entities = []
# Match extracted entities with graph entities using canonical labels
for query_entity in query_entities:
# Get canonical form for this query entity
canonical_query = entity_mapping.get(query_entity.label, query_entity.label)
if canonical_query.lower() in entity_map:
matched_entities.append(entity_map[canonical_query.lower()])
# If no matches, fallback to embedding similarity
if matched_entities:
top_entities = [
(entity, 1.0) for entity in matched_entities
] # Score 1.0 for direct matches
else:
top_entities = await self._find_similar_entities(query, graph.entities, k)
logger.info(f"Found {len(top_entities)} relevant entities in graph")
# Traverse the graph to find related entities
expanded_entities = self._expand_entities(graph, [e[0] for e in top_entities], hop_depth)
logger.info(f"Expanded to {len(expanded_entities)} entities after traversal")
# Get specific chunks containing these entities
graph_chunks = await self._retrieve_entity_chunks(
expanded_entities, auth, filters, document_service, folder_name, end_user_id
)
logger.info(f"Retrieved {len(graph_chunks)} chunks containing relevant entities")
# Calculate paths if requested
paths = []
if include_paths:
paths = self._find_relationship_paths(graph, [e[0] for e in top_entities], hop_depth)
logger.info(f"Found {len(paths)} relationship paths")
# Combine vector and graph results
combined_chunks = self._combine_chunk_results(vector_chunks, graph_chunks, k)
# Generate completion with enhanced context
completion_response = await self._generate_completion(
query,
combined_chunks,
document_service,
max_tokens,
temperature,
include_paths,
paths,
auth,
graph_name,
prompt_overrides,
folder_name=folder_name,
end_user_id=end_user_id,
)
return completion_response
async def _extract_entities_from_query(
self, query: str, prompt_overrides: Optional[QueryPromptOverrides] = None
) -> List[Entity]:
"""Extract entities from the query text using the LLM."""
try:
# Get entity_extraction override if provided
extraction_overrides = None
if prompt_overrides:
# Get the entity_extraction part
extraction_overrides = prompt_overrides.entity_extraction
# Extract entities from the query using the same extraction function
# but with a simplified prompt specific for queries
entities, _ = await self.extract_entities_from_text(
content=query,
doc_id="query", # Use "query" as doc_id
chunk_number=0, # Use 0 as chunk_number
prompt_overrides=extraction_overrides,
)
return entities
except Exception as e:
# If extraction fails, log and return empty list to fall back to embedding similarity
logger.warning(f"Failed to extract entities from query: {e}")
return []
async def _find_similar_entities(
self, query: str, entities: List[Entity], k: int
) -> List[Tuple[Entity, float]]:
"""Find entities similar to the query based on embedding similarity."""
if not entities:
return []
# Get embedding for query
query_embedding = await self.embedding_model.embed_for_query(query)
# Create entity text representations and get embeddings for all entities
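        # e.g. an entity Entity(label="Marie Curie", type="PERSON", properties={"field": "physics"})
        # becomes the text "Marie Curie PERSON field: physics" (hypothetical example)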
entity_texts = [
f"{entity.label} {entity.type} "
+ " ".join(f"{key}: {value}" for key, value in entity.properties.items())
for entity in entities
]
# Get embeddings for all entity texts
entity_embeddings = await self._batch_get_embeddings(entity_texts)
# Calculate similarities and pair with entities
entity_similarities = [
(entity, self._calculate_cosine_similarity(query_embedding, embedding))
for entity, embedding in zip(entities, entity_embeddings)
]
# Sort by similarity and take top k
entity_similarities.sort(key=lambda x: x[1], reverse=True)
return entity_similarities[: min(k, len(entity_similarities))]
async def _batch_get_embeddings(self, texts: List[str]) -> List[List[float]]:
"""Get embeddings for a batch of texts efficiently."""
# This could be implemented with proper batch embedding if the embedding model supports it
# For now, we'll just map over the texts and get embeddings one by one
return [await self.embedding_model.embed_for_query(text) for text in texts]
def _expand_entities(
self, graph: Graph, seed_entities: List[Entity], hop_depth: int
) -> List[Entity]:
"""Expand entities by traversing relationships."""
if hop_depth <= 1:
return seed_entities
# Create a set of entity IDs we've seen
seen_entity_ids = {entity.id for entity in seed_entities}
all_entities = list(seed_entities)
# Create a map for fast entity lookup
entity_map = {entity.id: entity for entity in graph.entities}
# For each hop
for _ in range(hop_depth - 1):
new_entities = []
# For each entity we've found so far
for entity in all_entities:
# Find connected entities through relationships
connected_ids = self._get_connected_entity_ids(
graph.relationships, entity.id, seen_entity_ids
)
# Add new connected entities
for entity_id in connected_ids:
if target_entity := entity_map.get(entity_id):
new_entities.append(target_entity)
seen_entity_ids.add(entity_id)
# Add new entities to our list
all_entities.extend(new_entities)
# Stop if no new entities found
if not new_entities:
break
return all_entities
def _get_connected_entity_ids(
self, relationships: List[Relationship], entity_id: str, seen_ids: Set[str]
) -> Set[str]:
"""Get IDs of entities connected to the given entity that haven't been seen yet."""
connected_ids = set()
for relationship in relationships:
# Check outgoing relationships
if relationship.source_id == entity_id and relationship.target_id not in seen_ids:
connected_ids.add(relationship.target_id)
# Check incoming relationships
elif relationship.target_id == entity_id and relationship.source_id not in seen_ids:
connected_ids.add(relationship.source_id)
return connected_ids
async def _retrieve_entity_chunks(
self,
entities: List[Entity],
auth: AuthContext,
filters: Optional[Dict[str, Any]],
document_service,
folder_name: Optional[str] = None,
end_user_id: Optional[str] = None,
) -> List[ChunkResult]:
"""Retrieve chunks containing the specified entities."""
# Initialize filters if None
if filters is None:
filters = {}
if not entities:
return []
# Collect all chunk sources from entities using set comprehension
entity_chunk_sources = {
(doc_id, chunk_num)
for entity in entities
for doc_id, chunk_numbers in entity.chunk_sources.items()
for chunk_num in chunk_numbers
}
# Get unique document IDs for authorization check
doc_ids = {doc_id for doc_id, _ in entity_chunk_sources}
# Check document authorization with system filters
documents = await document_service.batch_retrieve_documents(list(doc_ids), auth, folder_name, end_user_id)
# Apply filters if needed
authorized_doc_ids = {
doc.external_id
for doc in documents
if not filters or all(doc.metadata.get(k) == v for k, v in filters.items())
}
# Filter chunk sources to only those from authorized documents
chunk_sources = [
ChunkSource(document_id=doc_id, chunk_number=chunk_num)
for doc_id, chunk_num in entity_chunk_sources
if doc_id in authorized_doc_ids
]
# Retrieve and return chunks if we have any valid sources
return (
await document_service.batch_retrieve_chunks(chunk_sources, auth, folder_name=folder_name, end_user_id=end_user_id)
if chunk_sources
else []
)
def _combine_chunk_results(
self, vector_chunks: List[ChunkResult], graph_chunks: List[ChunkResult], k: int
) -> List[ChunkResult]:
"""Combine and deduplicate chunk results from vector search and graph search."""
# Create dictionary with vector chunks first
all_chunks = {f"{chunk.document_id}_{chunk.chunk_number}": chunk for chunk in vector_chunks}
# Process and add graph chunks with a boost
for chunk in graph_chunks:
chunk_key = f"{chunk.document_id}_{chunk.chunk_number}"
# Set default score if missing and apply boost (5%)
chunk.score = min(1.0, (getattr(chunk, "score", 0.7) or 0.7) * 1.05)
# Keep the higher-scored version
if chunk_key not in all_chunks or chunk.score > (getattr(all_chunks.get(chunk_key), "score", 0) or 0):
all_chunks[chunk_key] = chunk
# Convert to list, sort by score, and return top k
return sorted(all_chunks.values(), key=lambda x: getattr(x, "score", 0), reverse=True)[:k]
def _find_relationship_paths(
self, graph: Graph, seed_entities: List[Entity], hop_depth: int
) -> List[List[str]]:
"""Find meaningful paths in the graph starting from seed entities."""
paths = []
entity_map = {entity.id: entity for entity in graph.entities}
# For each seed entity
for start_entity in seed_entities:
# Start BFS from this entity
queue = [(start_entity.id, [start_entity.label])]
visited = set([start_entity.id])
while queue:
entity_id, path = queue.pop(0)
# If path is already at max length, record it but don't expand
if len(path) >= hop_depth * 2: # *2 because path includes relationship types
paths.append(path)
continue
# Find connected relationships
for relationship in graph.relationships:
# Process both outgoing and incoming relationships
if relationship.source_id == entity_id:
target_id = relationship.target_id
if target_id in visited:
continue
target_entity = entity_map.get(target_id)
if not target_entity:
continue
# Check for common chunks
common_chunks = self._find_common_chunks(
entity_map[entity_id], target_entity, relationship
)
# Only include relationships where entities co-occur
if common_chunks:
visited.add(target_id)
# Create path with relationship info
rel_context = (
f"({relationship.type}, {len(common_chunks)} shared chunks)"
)
new_path = path + [rel_context, target_entity.label]
queue.append((target_id, new_path))
paths.append(new_path)
elif relationship.target_id == entity_id:
source_id = relationship.source_id
if source_id in visited:
continue
source_entity = entity_map.get(source_id)
if not source_entity:
continue
# Check for common chunks
common_chunks = self._find_common_chunks(
entity_map[entity_id], source_entity, relationship
)
# Only include relationships where entities co-occur
if common_chunks:
visited.add(source_id)
# Create path with relationship info (note reverse direction)
rel_context = (
f"(is {relationship.type} of, {len(common_chunks)} shared chunks)"
)
new_path = path + [rel_context, source_entity.label]
queue.append((source_id, new_path))
paths.append(new_path)
return paths
def _find_common_chunks(
self, entity1: Entity, entity2: Entity, relationship: Relationship
) -> Set[Tuple[str, int]]:
"""Find chunks that contain both entities and their relationship."""
# Get chunk locations for each element
entity1_chunks = set()
for doc_id, chunk_numbers in entity1.chunk_sources.items():
for chunk_num in chunk_numbers:
entity1_chunks.add((doc_id, chunk_num))
entity2_chunks = set()
for doc_id, chunk_numbers in entity2.chunk_sources.items():
for chunk_num in chunk_numbers:
entity2_chunks.add((doc_id, chunk_num))
rel_chunks = set()
for doc_id, chunk_numbers in relationship.chunk_sources.items():
for chunk_num in chunk_numbers:
rel_chunks.add((doc_id, chunk_num))
# Return intersection
return entity1_chunks.intersection(entity2_chunks).intersection(rel_chunks)
def _calculate_cosine_similarity(self, vec1: List[float], vec2: List[float]) -> float:
"""Calculate cosine similarity between two vectors."""
# Convert to numpy arrays and calculate in one go
vec1_np, vec2_np = np.array(vec1), np.array(vec2)
# Get magnitudes
magnitude1, magnitude2 = np.linalg.norm(vec1_np), np.linalg.norm(vec2_np)
# Avoid division by zero and calculate similarity
return (
0
if magnitude1 == 0 or magnitude2 == 0
else np.dot(vec1_np, vec2_np) / (magnitude1 * magnitude2)
)
async def _generate_completion(
self,
query: str,
chunks: List[ChunkResult],
document_service,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
include_paths: bool = False,
paths: Optional[List[List[str]]] = None,
auth: Optional[AuthContext] = None,
graph_name: Optional[str] = None,
prompt_overrides: Optional[QueryPromptOverrides] = None,
folder_name: Optional[str] = None,
end_user_id: Optional[str] = None,
) -> CompletionResponse:
"""Generate completion using the retrieved chunks and optional path information."""
if not chunks:
chunks = [] # Ensure chunks is a list even if empty
# Create document results for context augmentation
documents = await document_service._create_document_results(auth, chunks)
# Create augmented chunk contents
chunk_contents = [
chunk.augmented_content(documents[chunk.document_id])
for chunk in chunks
if chunk.document_id in documents
]
# Include graph context in prompt if paths are requested
if include_paths and paths:
# Create a readable representation of the paths
paths_text = "Knowledge Graph Context:\n"
# Limit to 5 paths to avoid token limits
for path in paths[:5]:
paths_text += " -> ".join(path) + "\n"
# Add to the first chunk or create a new first chunk if none
if chunk_contents:
chunk_contents[0] = paths_text + "\n\n" + chunk_contents[0]
else:
chunk_contents = [paths_text]
# Generate completion with prompt override if provided
custom_prompt_template = None
if prompt_overrides and prompt_overrides.query:
custom_prompt_template = prompt_overrides.query.prompt_template
request = CompletionRequest(
query=query,
context_chunks=chunk_contents,
max_tokens=max_tokens,
temperature=temperature,
prompt_template=custom_prompt_template,
folder_name=folder_name,
end_user_id=end_user_id,
)
# Get completion from model
response = await document_service.completion_model.complete(request)
# Add sources information
response.sources = [
ChunkSource(
document_id=chunk.document_id,
chunk_number=chunk.chunk_number,
score=getattr(chunk, "score", 0),
)
for chunk in chunks
]
# Include graph metadata if paths were requested
if include_paths:
# Initialize metadata if it doesn't exist
if not hasattr(response, "metadata") or response.metadata is None:
response.metadata = {}
# Extract unique entities from paths (items that don't start with "(")
unique_entities = set()
if paths:
for path in paths[:5]:
for item in path:
if not item.startswith("("):
unique_entities.add(item)
# Add graph-specific metadata
response.metadata["graph"] = {
"name": graph_name,
"relevant_entities": list(unique_entities),
"paths": [" -> ".join(path) for path in paths[:5]] if paths else [],
}
return response