From 40fc8fab6b9af65dbdf208e1ef38e3c1e1303fa3 Mon Sep 17 00:00:00 2001 From: Adityavardhan Agrawal Date: Sat, 12 Apr 2025 00:24:38 -0700 Subject: [PATCH] SciER Science graph evals (#81) --- core/services/graph_service.py | 97 +- evaluations/Science graphs (SciER)/README.md | 87 ++ .../Science graphs (SciER)/count_entities.py | 18 + .../Science graphs (SciER)/data_loader.py | 22 + .../Science graphs (SciER)/evaluate_result.py | 1116 +++++++++++++++++ .../scier_evaluation.py | 289 +++++ evaluations/hotpot_ragas_eval.py | 259 ++++ 7 files changed, 1856 insertions(+), 32 deletions(-) create mode 100644 evaluations/Science graphs (SciER)/README.md create mode 100644 evaluations/Science graphs (SciER)/count_entities.py create mode 100644 evaluations/Science graphs (SciER)/data_loader.py create mode 100644 evaluations/Science graphs (SciER)/evaluate_result.py create mode 100644 evaluations/Science graphs (SciER)/scier_evaluation.py create mode 100644 evaluations/hotpot_ragas_eval.py diff --git a/core/services/graph_service.py b/core/services/graph_service.py index 3d11c21..2368267 100644 --- a/core/services/graph_service.py +++ b/core/services/graph_service.py @@ -480,6 +480,8 @@ class GraphService: relationships = [] # List to collect all extracted entities for resolution all_entities = [] + # Track all initial entities with their labels to fix relationship mapping + initial_entities = [] # Collect all chunk sources from documents. chunk_sources = [ @@ -505,6 +507,9 @@ class GraphService: chunk_entities, chunk_relationships = await self.extract_entities_from_text( chunk.content, chunk.document_id, chunk.chunk_number, extraction_overrides ) + + # Store all initially extracted entities to track their IDs + initial_entities.extend(chunk_entities) # Add entities to the collection, avoiding duplicates based on exact label match for entity in chunk_entities: @@ -551,6 +556,9 @@ class GraphService: ) raise + # Build a mapping from entity ID to label for ALL initially extracted entities + original_entity_id_to_label = {entity.id: entity.label for entity in initial_entities} + # Check if entity resolution is enabled in settings settings = get_settings() @@ -591,37 +599,62 @@ class GraphService: # Update relationships to use canonical entity labels updated_relationships = [] - # Create an entity index by ID for efficient lookups - entity_by_id = {entity.id: entity for entity in all_entities} - + # Remap relationships using original entity ID to label mapping + remapped_count = 0 + skipped_count = 0 + for relationship in relationships: - # Lookup entities by ID directly from the index - source_entity = entity_by_id.get(relationship.source_id) - target_entity = entity_by_id.get(relationship.target_id) - - if source_entity and target_entity: - # Get canonical labels - source_canonical = entity_mapping.get(source_entity.label, source_entity.label) - target_canonical = entity_mapping.get(target_entity.label, target_entity.label) - # Get canonical entities - canonical_source = resolved_entities_dict.get(source_canonical) - canonical_target = resolved_entities_dict.get(target_canonical) - if canonical_source and canonical_target: - # Update relationship to point to canonical entities - relationship.source_id = canonical_source.id - relationship.target_id = canonical_target.id - updated_relationships.append(relationship) - else: - # Skip relationships that can't be properly mapped - logger.warning( - "Skipping relationship between '%s' and '%s' - canonical entities not found", - source_entity.label, - 
target_entity.label, - ) - else: - # Keep relationship as is if we can't find the entities + # Use original_entity_id_to_label to get the labels for relationship endpoints + original_source_label = original_entity_id_to_label.get(relationship.source_id) + original_target_label = original_entity_id_to_label.get(relationship.target_id) + + if not original_source_label or not original_target_label: + logger.warning( + f"Skipping relationship with type '{relationship.type}' - could not find original entity labels" + ) + skipped_count += 1 + continue + + # Find canonical labels using the mapping from the resolver + source_canonical = entity_mapping.get(original_source_label, original_source_label) + target_canonical = entity_mapping.get(original_target_label, original_target_label) + + # Find the final unique Entity objects using the canonical labels + canonical_source = resolved_entities_dict.get(source_canonical) + canonical_target = resolved_entities_dict.get(target_canonical) + + if canonical_source and canonical_target: + # Successfully found the final entities, update the relationship's IDs + relationship.source_id = canonical_source.id + relationship.target_id = canonical_target.id updated_relationships.append(relationship) - return resolved_entities_dict, updated_relationships + remapped_count += 1 + else: + # Could not map to final entities, log and skip + logger.warning( + f"Skipping relationship between '{original_source_label}' and '{original_target_label}' - " + f"canonical entities not found after resolution" + ) + skipped_count += 1 + + logger.info(f"Remapped {remapped_count} relationships, skipped {skipped_count} relationships") + + # Deduplicate relationships (same source, target, type) + final_relationships_map = {} + for rel in updated_relationships: + key = (rel.source_id, rel.target_id, rel.type) + if key not in final_relationships_map: + final_relationships_map[key] = rel + else: + # Merge sources into the existing relationship + existing_rel = final_relationships_map[key] + self._merge_relationship_sources(existing_rel, rel) + + final_relationships = list(final_relationships_map.values()) + logger.info(f"Deduplicated to {len(final_relationships)} unique relationships") + + return resolved_entities_dict, final_relationships + # If no entity resolution occurred, return original entities and relationships return entities, relationships @@ -685,9 +718,9 @@ class GraphService: system_message = { "role": "system", "content": ( - "You are an entity extraction assistant. Extract entities and their relationships from text precisely and thoroughly. " - "For entities, include entity label and type (PERSON, ORGANIZATION, LOCATION, CONCEPT, etc.). " - "For relationships, use a simple format with source, target, and relationship fields. " + "You are an entity extraction and relationship extraction assistant. Extract entities and their relationships from text precisely and thoroughly; extract as many entities and relationships as possible. " + "For entities, include entity label and type (some examples: PERSON, ORGANIZATION, LOCATION, CONCEPT, etc.). If the user has given examples, use those; these are just suggestions. " + "For relationships, use a simple format with source, target, and relationship fields. Be very thorough; there are many relationships that are not obvious. " + "IMPORTANT: The source and target fields must be simple strings representing entity labels. 
For example: " "if you extract entities 'Entity A' and 'Entity B', a relationship would have source: 'Entity A', target: 'Entity B', relationship: 'relates to'. " "Respond directly in json format, without any additional text or explanations. " diff --git a/evaluations/Science graphs (SciER)/README.md b/evaluations/Science graphs (SciER)/README.md new file mode 100644 index 0000000..0281db5 --- /dev/null +++ b/evaluations/Science graphs (SciER)/README.md @@ -0,0 +1,87 @@ +# SciER Evaluation for Morphik + +This directory contains scripts for evaluating different language models' entity and relation extraction capabilities against the [SciER (Scientific Entity Recognition) dataset](https://github.com/allenai/SciERC). + +## Overview + +The evaluation workflow is split into two parts: +1. **Graph Creation**: Generate a knowledge graph using your configured model in `morphik.toml` +2. **Graph Evaluation**: Evaluate the created graph against the ground truth annotations + +This separation allows you to test different models by changing the configuration in `morphik.toml` between runs. + +The evaluation uses the SciER test dataset which can be found at [https://github.com/edzq/SciER/blob/main/SciER/LLM/test.jsonl](https://github.com/edzq/SciER/blob/main/SciER/LLM/test.jsonl). + +## Setup + +1. Make sure you have a local Morphik server running +2. Set up your OpenAI API key for evaluation: + ```bash + export OPENAI_API_KEY=your_key_here + ``` + +## Running the Evaluation + +### Step 1: Configure Your Model + +Edit `morphik.toml` to specify the model you want to test: + +```toml +[graph] +model = "openai_gpt4o" # Reference to a key in registered_models +enable_entity_resolution = true +``` + +### Step 2: Create a Knowledge Graph + +Run the graph creation script: + +```bash +python scier_evaluation.py --model-name gpt4o +``` + +The script will output a graph name like `scier_gpt4o_12345678` when complete. + +### Step 3: Evaluate the Knowledge Graph + +Evaluate the created graph: + +```bash +python evaluate_result.py --graph-name scier_gpt4o_12345678 +``` + +### Step 4: Test Different Models + +To compare models: +1. Change the model in `morphik.toml` +2. Repeat steps 2-3 with a different `--model-name` +3. 
Compare the resulting metrics and visualizations + +## Command Arguments + +### scier_evaluation.py +- `--model-name`: Name to identify this model in results (default: "default_model") +- `--limit`: Maximum number of documents to process (default: 57) +- `--run-id`: Unique identifier for the run (default: auto-generated) + +### evaluate_result.py +- `--graph-name`: Name of the graph to evaluate (**required**) +- `--model-name`: Name for the evaluation results (default: "existing_model_openai") +- `--similarity-threshold`: Threshold for semantic similarity matching (default: 0.70) +- `--embedding-model`: OpenAI embedding model to use (default: "text-embedding-3-small") + +## Results + +The evaluation generates: +- CSV files with precision, recall, and F1 metrics +- Visualizations comparing entity and relation extraction performance +- Entity and relation count comparisons + +Results are saved in a directory with the format: `scier_results_{model_name}_{run_id}/` + +## Tips for Model Comparison + +- Use descriptive model names in the `--model-name` parameter +- Keep the same similarity threshold across evaluations +- Compare models using the generated visualization charts +- Look at both entity and relation extraction metrics diff --git a/evaluations/Science graphs (SciER)/count_entities.py b/evaluations/Science graphs (SciER)/count_entities.py new file mode 100644 index 0000000..c53eab3 --- /dev/null +++ b/evaluations/Science graphs (SciER)/count_entities.py @@ -0,0 +1,18 @@ +import json + +entity_count = 0 +relation_count = 0 +document_count = 0 + +with open("test.jsonl", "r", encoding="utf-8") as f: + for line in f: + document_count += 1 + data = json.loads(line) + if "ner" in data: + entity_count += len(data["ner"]) + if "rel" in data: + relation_count += len(data["rel"]) + +print(f"Total documents: {document_count}") +print(f"Total entities: {entity_count}") +print(f"Total relationships: {relation_count}") diff --git a/evaluations/Science graphs (SciER)/data_loader.py b/evaluations/Science graphs (SciER)/data_loader.py new file mode 100644 index 0000000..9c7ea8a --- /dev/null +++ b/evaluations/Science graphs (SciER)/data_loader.py @@ -0,0 +1,22 @@ +import json + + +def load_jsonl(jsonl_path: str) -> list: + data = [] + with open(jsonl_path) as f: + for line in f: + data.append(json.loads(line)) + return data + + +# ========= usage example for load_jsonl function with LLM input ======= +# dataset = load_jsonl('./SciER/LLM/test.jsonl') +# sent = dataset[0] +# print(sent.keys()) +# print(sent['sentence']) +# print('----------------') +# print(sent['ner']) +# print('----------------') +# print(sent['rel']) +# print('----------------') +# print(sent['rel_plus']) diff --git a/evaluations/Science graphs (SciER)/evaluate_result.py b/evaluations/Science graphs (SciER)/evaluate_result.py new file mode 100644 index 0000000..d95f495 --- /dev/null +++ b/evaluations/Science graphs (SciER)/evaluate_result.py @@ -0,0 +1,1116 @@ +#!/usr/bin/env python3 +""" +SciER Evaluation Script for Morphik - OpenAI Embeddings Version + +This script evaluates an existing Morphik graph against the SciER dataset +using OpenAI embeddings for semantic similarity calculations. 
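+
+Typical invocation (the graph name below is illustrative; use the name printed by
+scier_evaluation.py when it creates the graph):
+
+    python evaluate_result.py --graph-name scier_gpt4o_12345678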
+""" + +import os +import json +import uuid +import argparse +from pathlib import Path +from typing import Dict, List, Tuple, Any, Set +from collections import defaultdict +import pandas as pd +import numpy as np +from dotenv import load_dotenv +from tqdm import tqdm +import matplotlib.pyplot as plt +import requests +from scipy.spatial.distance import cosine +import multiprocessing +from concurrent.futures import ThreadPoolExecutor +import time + +from morphik import Morphik + +# Import SciER data loader +from data_loader import load_jsonl + +# Load environment variables +load_dotenv() + + +class OpenAIEmbedding: + """ + OpenAI embedding similarity calculator. + A faster alternative to SciBERT for computing semantic similarity. + Uses text-embedding-3-small model. + """ + + def __init__( + self, + model_name="text-embedding-3-small", + threshold=0.70, + api_base="https://api.openai.com/v1", + cache_size=10000, + batch_size=20, + ): + self.model_name = model_name + self.threshold = threshold + self.api_base = api_base + self.embedding_cache = {} # Cache to store embeddings + self.cache_size = cache_size + self.batch_size = batch_size + + # Get OpenAI API key from environment + self.api_key = os.getenv("OPENAI_API_KEY") + if not self.api_key: + raise ValueError("OPENAI_API_KEY not found in environment variables") + + # Rate limiting parameters + self.requests_per_minute = 3500 # Adjust based on your OpenAI rate limits + self.min_time_between_requests = 60.0 / self.requests_per_minute + self.last_request_time = 0 + + def get_embedding(self, text): + """Get embeddings for a text string using OpenAI API.""" + if not text.strip(): + # Return a zero vector for empty text + return np.zeros(1536) # text-embedding-3-small dimension + + # Check cache first + if text in self.embedding_cache: + return self.embedding_cache[text] + + # Rate limiting + current_time = time.time() + time_since_last_request = current_time - self.last_request_time + if time_since_last_request < self.min_time_between_requests: + wait_time = self.min_time_between_requests - time_since_last_request + time.sleep(wait_time) + + try: + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {self.api_key}", + } + + response = requests.post( + f"{self.api_base}/embeddings", + headers=headers, + json={"model": self.model_name, "input": text, "encoding_format": "float"}, + ) + + self.last_request_time = time.time() + + if response.status_code != 200: + print(f"Error from OpenAI API: {response.text}") + return np.zeros(1536) + + data = response.json() + embedding = np.array(data["data"][0]["embedding"]) + + # Cache the embedding + if len(self.embedding_cache) < self.cache_size: + self.embedding_cache[text] = embedding + + return embedding + + except Exception as e: + print(f"Exception when calling OpenAI API: {e}") + return np.zeros(1536) + + def get_embeddings_batch(self, texts): + """Get embeddings for multiple texts in batch.""" + embeddings = [] + texts_to_process = [] + indices_to_process = [] + + # First check cache + for i, text in enumerate(texts): + if not text.strip(): + embeddings.append(np.zeros(1536)) + elif text in self.embedding_cache: + embeddings.append(self.embedding_cache[text]) + else: + texts_to_process.append(text) + indices_to_process.append(i) + embeddings.append(None) # Placeholder + + # Process uncached texts in smaller batches + if texts_to_process: + for i in range(0, len(texts_to_process), self.batch_size): + batch = texts_to_process[i : i + self.batch_size] + batch_indices = 
indices_to_process[i : i + self.batch_size] + + # OpenAI supports batching directly in the API + try: + # Rate limiting + current_time = time.time() + time_since_last_request = current_time - self.last_request_time + if time_since_last_request < self.min_time_between_requests: + wait_time = self.min_time_between_requests - time_since_last_request + time.sleep(wait_time) + + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {self.api_key}", + } + + response = requests.post( + f"{self.api_base}/embeddings", + headers=headers, + json={"model": self.model_name, "input": batch, "encoding_format": "float"}, + ) + + self.last_request_time = time.time() + + if response.status_code != 200: + print(f"Error from OpenAI API: {response.text}") + batch_embeddings = [np.zeros(1536) for _ in batch] + else: + data = response.json() + # Sort by index as OpenAI returns in the same order as input + batch_embeddings = [np.array(item["embedding"]) for item in data["data"]] + + # Update embeddings and cache + for j, embedding in enumerate(batch_embeddings): + idx = batch_indices[j] + text = batch[j] + embeddings[idx] = embedding + if len(self.embedding_cache) < self.cache_size: + self.embedding_cache[text] = embedding + + except Exception as e: + print(f"Exception when batch calling OpenAI API: {e}") + # Fall back to individual API calls + with ThreadPoolExecutor(max_workers=min(len(batch), 5)) as executor: + batch_embeddings = list(executor.map(self.get_embedding, batch)) + + # Update embeddings and cache + for j, embedding in enumerate(batch_embeddings): + idx = batch_indices[j] + text = batch[j] + embeddings[idx] = embedding + if len(self.embedding_cache) < self.cache_size: + self.embedding_cache[text] = embedding + + return embeddings + + def compute_similarity(self, text1, text2): + """Compute cosine similarity between two texts.""" + embedding1 = self.get_embedding(text1) + embedding2 = self.get_embedding(text2) + + # Compute cosine similarity (1 - cosine distance) + similarity = 1 - cosine(embedding1, embedding2) + return similarity + + def compute_similarity_with_embeddings(self, embedding1, embedding2): + """Compute cosine similarity between two pre-computed embeddings.""" + # Compute cosine similarity (1 - cosine distance) + similarity = 1 - cosine(embedding1, embedding2) + return similarity + + def are_semantically_similar(self, text1, text2): + """Check if two texts are semantically similar based on threshold.""" + similarity = self.compute_similarity(text1, text2) + return similarity >= self.threshold, similarity + + def compute_similarities_batch(self, text_pairs): + """ + Compute similarities for multiple text pairs in parallel. 
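+
+        Illustrative usage (the object name is a placeholder; similarity values
+        depend on the embedding model and threshold):
+
+            calculator = OpenAIEmbedding(threshold=0.70)
+            results = calculator.compute_similarities_batch([("LSTM", "long short-term memory")])
+            # results -> [(is_similar, similarity)], one tuple per input pair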
+ + Args: + text_pairs: List of (text1, text2) tuples + + Returns: + List of (is_similar, similarity) tuples + """ + # Extract all unique texts + all_texts = [] + text_indices = {} + + for text1, text2 in text_pairs: + if text1 not in text_indices: + text_indices[text1] = len(all_texts) + all_texts.append(text1) + if text2 not in text_indices: + text_indices[text2] = len(all_texts) + all_texts.append(text2) + + # Get embeddings for all texts in batch + all_embeddings = self.get_embeddings_batch(all_texts) + + # Compute similarities + results = [] + for text1, text2 in text_pairs: + embedding1 = all_embeddings[text_indices[text1]] + embedding2 = all_embeddings[text_indices[text2]] + similarity = self.compute_similarity_with_embeddings(embedding1, embedding2) + is_similar = similarity >= self.threshold + results.append((is_similar, similarity)) + + return results + + +def setup_morphik_client() -> Morphik: + """Initialize and return a Morphik client.""" + return Morphik(timeout=300000, is_local=True) + + +def load_scier_data(dataset_path: str = "test.jsonl", limit: int = None) -> List[Dict]: + """ + Load SciER dataset from the specified JSONL file. + + Args: + dataset_path: Path to the JSONL file + limit: Maximum number of records to load (None for all) + + Returns: + List of dataset records + """ + data = load_jsonl(dataset_path) + if limit: + data = data[:limit] + return data + + +def prepare_text_for_evaluation(records: List[Dict]) -> List[Dict]: + """ + Prepare SciER records for evaluation. + + Args: + records: List of SciER records + + Returns: + List of dictionaries with text and ground truth + """ + documents = [] + + # Group records by doc_id to create complete documents + doc_groups = defaultdict(list) + for record in records: + doc_groups[record["doc_id"]].append(record) + + # Convert grouped records to documents + for doc_id, records in doc_groups.items(): + text = "\n".join([record["sentence"] for record in records]) + + # Collect all entities and relations for ground truth + all_entities = [] + all_relations = [] + for record in records: + all_entities.extend(record["ner"]) + all_relations.extend(record["rel"]) + + documents.append( + { + "text": text, + "metadata": { + "doc_id": doc_id, + "ground_truth_entities": all_entities, + "ground_truth_relations": all_relations, + }, + } + ) + + return documents + + +def load_existing_graph(db: Morphik, graph_name: str) -> Dict: + """ + Load an existing graph by name. + + Args: + db: Morphik client + graph_name: Name of the graph to load + + Returns: + Dictionary containing the graph + """ + print(f"Loading existing graph: {graph_name}") + + # List all graphs and find the one with the specified name + graphs = db.list_graphs() + + target_graph = None + for graph in graphs: + if graph.name == graph_name: + target_graph = graph + break + + if target_graph is None: + raise ValueError( + f"Graph '{graph_name}' not found. Available graphs: {[g.name for g in graphs]}" + ) + + print( + f"Found graph with {len(target_graph.entities)} entities and {len(target_graph.relationships)} relationships" + ) + + return {"graph": target_graph} + + +def evaluate_entity_extraction( + ground_truth: List[List], + extracted_entities: List[Dict], + similarity_calculator, + entity_type_match_required: bool = True, + batch_size: int = 100, + max_workers: int = None, +) -> Dict[str, float]: + """ + Evaluate entity extraction performance using semantic similarity with parallelization. 
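+
+    Expected input shapes (per the SciER annotations and the Morphik graph
+    objects; example values are illustrative):
+
+        ground_truth:        [["BERT", "Method"], ["SQuAD", "Dataset"], ...]
+        extracted_entities:  objects exposing `.label`, `.type`, and `.id`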
+ + Args: + ground_truth: List of ground truth entities from SciER + extracted_entities: List of entities extracted by Morphik + similarity_calculator: Similarity calculator instance + entity_type_match_required: Whether entity types must match + batch_size: Size of batches for parallel processing + max_workers: Maximum number of worker processes (defaults to CPU count) + + Returns: + Dict with precision, recall, and F1 metrics + """ + # Set default max_workers if not provided + if max_workers is None: + max_workers = max(1, multiprocessing.cpu_count() - 1) + + # Process ground truth + gt_entities = [] + for entity in ground_truth: + entity_text = entity[0].lower() + entity_type = entity[1] + gt_entities.append((entity_text, entity_type)) + + # Process extracted entities + extracted = [] + for entity in extracted_entities: + entity_text = entity.label.lower() + entity_type = entity.type + extracted.append((entity_text, entity_type)) + + print( + f"Processing {len(gt_entities)} ground truth entities against {len(extracted)} extracted entities" + ) + + # Prepare all valid entity comparisons based on type matching + comparisons = [] + for i, gt_entity in enumerate(gt_entities): + gt_text, gt_type = gt_entity + for j, ext_entity in enumerate(extracted): + ext_text, ext_type = ext_entity + # Skip if entity type doesn't match (if required) + if entity_type_match_required and gt_type != ext_type: + continue + comparisons.append((i, j, gt_text, ext_text)) + + print(f"Generated {len(comparisons)} comparisons to process") + + # Process in batches + gt_best_matches = {} # gt_index -> (ext_index, score) + + for i in tqdm(range(0, len(comparisons), batch_size), desc="Processing entity batches"): + batch = comparisons[i : i + batch_size] + + # Extract text pairs for batch processing + text_pairs = [(comp[2], comp[3]) for comp in batch] + + # Compute similarities in batch + similarity_results = similarity_calculator.compute_similarities_batch(text_pairs) + + # Process results + for idx, ((gt_idx, ext_idx, _, _), (is_similar, similarity)) in enumerate( + zip(batch, similarity_results) + ): + if is_similar: + # Update best match if it's better than the current one + if gt_idx not in gt_best_matches or similarity > gt_best_matches[gt_idx][1]: + gt_best_matches[gt_idx] = (ext_idx, similarity) + + # Extract matched entities + ext_matched = set(match[0] for match in gt_best_matches.values()) + + # Count matches + true_positives = len(gt_best_matches) + false_negatives = len(gt_entities) - true_positives + false_positives = len(extracted) - len(ext_matched) + + # Calculate metrics + precision = ( + true_positives / (true_positives + false_positives) + if (true_positives + false_positives) > 0 + else 0 + ) + recall = ( + true_positives / (true_positives + false_negatives) + if (true_positives + false_negatives) > 0 + else 0 + ) + f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0 + + # Create a mapping of ground truth texts to extracted entity IDs for relation evaluation + entity_match_map = {} + for i, (ext_idx, _) in gt_best_matches.items(): + gt_text, _ = gt_entities[i] + ext_id = ( + extracted_entities[ext_idx].id if hasattr(extracted_entities[ext_idx], "id") else None + ) + if ext_id: + entity_match_map[gt_text] = ext_id + + return { + "precision": precision, + "recall": recall, + "f1": f1, + "true_positives": true_positives, + "false_positives": false_positives, + "false_negatives": false_negatives, + "ground_truth_count": len(gt_entities), + "extracted_count": 
len(extracted), + "entity_match_map": entity_match_map, + } + + +def evaluate_relation_extraction( + ground_truth: List[List], + extracted_relations: List, + entity_match_map: Dict[str, str], + similarity_calculator, + relation_type_match_required: bool = True, + batch_size: int = 100, + max_workers: int = None, +) -> Dict[str, float]: + """ + Evaluate relation extraction performance using semantic similarity with parallelization. + + Args: + ground_truth: List of ground truth relations from SciER + extracted_relations: List of relations extracted by Morphik + entity_match_map: Mapping from ground truth entity text to extracted entity ID + similarity_calculator: Similarity calculator instance + relation_type_match_required: Whether relation types must match + batch_size: Size of batches for parallel processing + max_workers: Maximum number of worker processes (defaults to CPU count) + + Returns: + Dict with precision, recall, and F1 metrics + """ + # Set default max_workers if not provided + if max_workers is None: + max_workers = max(1, multiprocessing.cpu_count() - 1) + + # Process ground truth + gt_relations = [] + for relation in ground_truth: + source = relation[0].lower() + relation_type = relation[1] + target = relation[2].lower() + gt_relations.append((source, relation_type, target)) + + # Reverse map from entity ID to ground truth text + id_to_gt_text = {entity_id: gt_text for gt_text, entity_id in entity_match_map.items()} + + # Debug output to understand entity mappings + print(f"Entity mapping size: {len(entity_match_map)}") + print(f"Entity reverse mapping size: {len(id_to_gt_text)}") + + # Process extracted relations + extracted_rel_tuples = [] + skipped_relations = 0 + + # Get unique entity IDs from relationships to see what might be missing + all_source_ids = set() + all_target_ids = set() + + for relation in extracted_relations: + try: + if hasattr(relation, "source_id") and hasattr(relation, "target_id"): + source_id = relation.source_id + target_id = relation.target_id + relation_type = relation.type + + # Track all source and target IDs for debugging + all_source_ids.add(source_id) + all_target_ids.add(target_id) + + # Don't skip missing entities, we'll use direct text comparison instead + source_text = id_to_gt_text.get(source_id, None) + target_text = id_to_gt_text.get(target_id, None) + + # If we have direct mappings, use them + if source_text is not None and target_text is not None: + extracted_rel_tuples.append((source_text, relation_type, target_text)) + # Otherwise, try to get the source and target directly from the entity object + else: + try: + # Try to get the source and target texts from their labels + source_entity = next((e for e in relation.source_entity if e), None) + target_entity = next((e for e in relation.target_entity if e), None) + + if ( + source_entity + and target_entity + and hasattr(source_entity, "label") + and hasattr(target_entity, "label") + ): + source_label = source_entity.label.lower() + target_label = target_entity.label.lower() + extracted_rel_tuples.append((source_label, relation_type, target_label)) + else: + skipped_relations += 1 + except Exception as inner_e: + skipped_relations += 1 + print(f"Error extracting entity labels: {inner_e}") + + except (AttributeError, KeyError, TypeError) as e: + print(f"Error processing relation: {relation}. 
Error: {e}") + + print( + f"Processing {len(gt_relations)} ground truth relations against {len(extracted_rel_tuples)} extracted relations" + ) + print(f"Skipped {skipped_relations} relations due to missing entity mappings") + print(f"Total unique source IDs: {len(all_source_ids)}, target IDs: {len(all_target_ids)}") + + # Debug: Print some sample ground truth relations + if gt_relations: + print("Sample ground truth relations:") + for i, rel in enumerate(gt_relations[:5]): + print(f" {i}: {rel}") + + # Debug: Print some sample extracted relations + if extracted_rel_tuples: + print("Sample extracted relations:") + for i, rel in enumerate(extracted_rel_tuples[:5]): + print(f" {i}: {rel}") + + # Prepare relation comparisons, but don't filter by entity map membership + comparisons = [] + for i, gt_relation in enumerate(gt_relations): + gt_source, gt_rel_type, gt_target = gt_relation + + for j, ext_relation in enumerate(extracted_rel_tuples): + ext_source, ext_rel_type, ext_target = ext_relation + + # Skip if relation type doesn't match (if required) + if relation_type_match_required and gt_rel_type != ext_rel_type: + continue + + comparisons.append((i, j, gt_source, ext_source, gt_target, ext_target)) + + print(f"Generated {len(comparisons)} relation comparisons to process") + + # Process in batches + matched_gt_indices = set() + matched_ext_indices = set() + best_scores = {} # gt_idx -> (ext_idx, score) + + for i in tqdm(range(0, len(comparisons), batch_size), desc="Processing relation batches"): + batch = comparisons[i : i + batch_size] + + # Create batches for source and target comparisons + source_pairs = [(comp[2], comp[3]) for comp in batch] + target_pairs = [(comp[4], comp[5]) for comp in batch] + + # Compute source similarities in batch + source_results = similarity_calculator.compute_similarities_batch(source_pairs) + + # Compute target similarities in batch + target_results = similarity_calculator.compute_similarities_batch(target_pairs) + + # Process results + for idx, ( + (gt_idx, ext_idx, _, _, _, _), + (source_is_similar, source_sim), + (target_is_similar, target_sim), + ) in enumerate(zip(batch, source_results, target_results)): + + if source_is_similar and target_is_similar: + match_score = (source_sim + target_sim) / 2 + + # Update best match if it's better than current one + if gt_idx not in best_scores or match_score > best_scores[gt_idx][1]: + best_scores[gt_idx] = (ext_idx, match_score) + + # Extract matched indices + for gt_idx, (ext_idx, _) in best_scores.items(): + matched_gt_indices.add(gt_idx) + matched_ext_indices.add(ext_idx) + + # Count matches + true_positives = len(matched_gt_indices) + false_negatives = len(gt_relations) - true_positives + false_positives = len(extracted_rel_tuples) - len(matched_ext_indices) + + # Calculate metrics + precision = ( + true_positives / (true_positives + false_positives) + if (true_positives + false_positives) > 0 + else 0 + ) + recall = ( + true_positives / (true_positives + false_negatives) + if (true_positives + false_negatives) > 0 + else 0 + ) + f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0 + + return { + "precision": precision, + "recall": recall, + "f1": f1, + "true_positives": true_positives, + "false_positives": false_positives, + "false_negatives": false_negatives, + "ground_truth_count": len(gt_relations), + "extracted_count": len(extracted_rel_tuples), + } + + +def evaluate_graph( + db: Morphik, + documents: List[Dict], + graph: Any, + model: str, + similarity_calculator, + 
batch_size: int = 100, + max_workers: int = None, +) -> Dict[str, Dict]: + """ + Evaluate graph extraction against ground truth using semantic similarity. + + Args: + db: Morphik client + documents: Original documents with ground truth + graph: Graph to evaluate + model: Model name used (for tracking) + similarity_calculator: The initialized similarity calculator instance + batch_size: Size of batches for parallel processing + max_workers: Maximum number of worker processes + + Returns: + Dictionary of evaluation results + """ + print(f"\n=== Evaluating graph for {model} ===") + + # Get all entities and relationships from the graph + entities = graph.entities + relationships = graph.relationships + + # Debug information about the graph + print(f"Total entities in graph: {len(entities)}") + print(f"Total relationships in graph: {len(relationships)}") + + # Debug Entity structure if available + if len(entities) > 0: + entity = entities[0] + print(f"DEBUG - Entity structure: {entity}") + print(f"DEBUG - Entity type: {type(entity)}") + print(f"DEBUG - Entity attributes: {dir(entity)[:20]}") + + # Debug Relationship structure if available + if len(relationships) > 0: + relationship = relationships[0] + print(f"DEBUG - Relationship structure: {relationship}") + print(f"DEBUG - Relationship type: {type(relationship)}") + print(f"DEBUG - Relationship attributes: {dir(relationship)[:20]}") + + # Aggregate ground truth from all documents + all_gt_entities = [] + all_gt_relations = [] + + for doc in documents: + all_gt_entities.extend(doc["metadata"]["ground_truth_entities"]) + all_gt_relations.extend(doc["metadata"]["ground_truth_relations"]) + + print(f"Total ground truth entities: {len(all_gt_entities)}") + print(f"Total ground truth relations: {len(all_gt_relations)}") + + # Print some sample ground truth relations + print("Sample ground truth relations:") + for i, rel in enumerate(all_gt_relations[:5]): + print(f" {i}: {rel}") + + # Evaluate entity extraction with parallelization parameters + entity_metrics = evaluate_entity_extraction( + all_gt_entities, + entities, + similarity_calculator=similarity_calculator, + batch_size=batch_size, + max_workers=max_workers, + ) + + # Get entity mapping for semantic relation evaluation + entity_match_map = entity_metrics.pop("entity_match_map", {}) + + # Evaluate relation extraction with parallelization parameters + relation_metrics = evaluate_relation_extraction( + all_gt_relations, + relationships, + entity_match_map=entity_match_map, + similarity_calculator=similarity_calculator, + batch_size=batch_size, + max_workers=max_workers, + ) + + # Store results + results = { + "model": model, + "test_type": "evaluation", + "entity_metrics": entity_metrics, + "relation_metrics": relation_metrics, + "evaluation_method": "openai_embeddings", + } + + # Print summary + print( + f"Entity Extraction (openai_embeddings) - Precision: {entity_metrics['precision']:.4f}, " + f"Recall: {entity_metrics['recall']:.4f}, " + f"F1: {entity_metrics['f1']:.4f}" + ) + print( + f"Relation Extraction (openai_embeddings) - Precision: {relation_metrics['precision']:.4f}, " + f"Recall: {relation_metrics['recall']:.4f}, " + f"F1: {relation_metrics['f1']:.4f}" + ) + + return results + + +def save_results(results: Dict[str, Dict], model_name: str, run_id: str) -> str: + """ + Save evaluation results to CSV and generate visualizations. 
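+
+    Output layout (file names are built below from model_name and run_id):
+
+        scier_results_{model_name}_{run_id}/
+            {model_name}_evaluation_results.csv
+            {model_name}_metrics_comparison.png
+            {model_name}_count_comparison.png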
+ + Args: + results: Evaluation results dictionary + model_name: Name of the model used + run_id: Unique run identifier + + Returns: + Path to the saved results directory + """ + # Create results directory + results_dir = f"scier_results_{model_name}_{run_id}" + os.makedirs(results_dir, exist_ok=True) + + # Prepare data for DataFrame + rows = [] + model = results["model"] + test_type = results["test_type"] + evaluation_method = results.get("evaluation_method", "openai_embeddings") + + # Entity metrics + entity_metrics = results["entity_metrics"] + rows.append( + { + "model": model, + "test_type": test_type, + "extraction_type": "entity", + "evaluation_method": evaluation_method, + "precision": entity_metrics["precision"], + "recall": entity_metrics["recall"], + "f1": entity_metrics["f1"], + "true_positives": entity_metrics["true_positives"], + "false_positives": entity_metrics["false_positives"], + "false_negatives": entity_metrics["false_negatives"], + "ground_truth_count": entity_metrics["ground_truth_count"], + "extracted_count": entity_metrics["extracted_count"], + } + ) + + # Relation metrics + relation_metrics = results["relation_metrics"] + rows.append( + { + "model": model, + "test_type": test_type, + "extraction_type": "relation", + "evaluation_method": evaluation_method, + "precision": relation_metrics["precision"], + "recall": relation_metrics["recall"], + "f1": relation_metrics["f1"], + "true_positives": relation_metrics["true_positives"], + "false_positives": relation_metrics["false_positives"], + "false_negatives": relation_metrics["false_negatives"], + "ground_truth_count": relation_metrics["ground_truth_count"], + "extracted_count": relation_metrics["extracted_count"], + } + ) + + # Create DataFrame and save to CSV + df = pd.DataFrame(rows) + csv_path = os.path.join(results_dir, f"{model_name}_evaluation_results.csv") + df.to_csv(csv_path, index=False) + print(f"Results saved to {csv_path}") + + # Generate visualizations + generate_visualizations(df, results_dir, model_name) + + return results_dir + + +def generate_visualizations(df: pd.DataFrame, output_dir: str, model_name: str) -> None: + """ + Generate visualization charts for the results. 
+ + Args: + df: DataFrame with evaluation results + output_dir: Directory to save visualizations + model_name: Name of the model used + """ + # Precision, recall, and F1 + plt.figure(figsize=(15, 10)) + + # Set up the data + entity_df = df[df["extraction_type"] == "entity"] + relation_df = df[df["extraction_type"] == "relation"] + + # Get evaluation method if available + evaluation_method = ( + df["evaluation_method"].iloc[0] + if "evaluation_method" in df.columns + else "openai_embeddings" + ) + + # Set up the plots + metrics = ["precision", "recall", "f1"] + positions = [0, 1, 2, 4, 5, 6] + width = 0.35 + + # Plot entity data + plt.bar( + positions[:3], + entity_df[metrics].values[0], + width, + label="Entity", + ) + + # Plot relation data + plt.bar( + positions[3:], + relation_df[metrics].values[0], + width, + label="Relation", + ) + + # Add labels and formatting + plt.xticks( + positions, + ["P", "R", "F1", "P", "R", "F1"], + ) + plt.ylim(0, 1.0) + plt.ylabel("Score") + plt.title(f"{model_name} Performance ({evaluation_method})") + plt.legend() + plt.tight_layout() + + # Add text annotations for entity metrics + for i, p in enumerate(positions[:3]): + plt.text( + p, + entity_df[metrics[i]].values[0] + 0.02, + f"{entity_df[metrics[i]].values[0]:.3f}", + ha="center", + va="bottom", + rotation=0, + ) + + # Add text annotations for relation metrics + for i, p in enumerate(positions[3:]): + plt.text( + p, + relation_df[metrics[i]].values[0] + 0.02, + f"{relation_df[metrics[i]].values[0]:.3f}", + ha="center", + va="bottom", + rotation=0, + ) + + # Add vertical separator between entity and relation metrics + plt.axvline(x=3.5, color="gray", linestyle="--", alpha=0.5) + plt.text(1.5, 0.95, "Entity Extraction", ha="center", va="top", fontsize=12) + plt.text(5.5, 0.95, "Relation Extraction", ha="center", va="top", fontsize=12) + + # Save the plot + plt.savefig(os.path.join(output_dir, f"{model_name}_metrics_comparison.png")) + plt.close() + + # Save ground truth vs extracted counts comparison + plt.figure(figsize=(10, 6)) + + # Entity counts + plt.subplot(1, 2, 1) + + gt_count = entity_df["ground_truth_count"].values[0] + extracted_count = entity_df["extracted_count"].values[0] + + counts = [gt_count, extracted_count] + labels = ["Ground Truth", "Extracted"] + + x_positions = list(range(len(counts))) + plt.bar(x_positions, counts, width=0.6) + plt.xticks(x_positions, labels) + plt.ylabel("Count") + plt.title("Entity Counts") + + # Add count labels + for i, count in enumerate(counts): + plt.text(i, count + 2, str(count), ha="center", va="bottom") + + # Relation counts + plt.subplot(1, 2, 2) + + gt_count_rel = relation_df["ground_truth_count"].values[0] + extracted_count_rel = relation_df["extracted_count"].values[0] + + counts_rel = [gt_count_rel, extracted_count_rel] + labels_rel = ["Ground Truth", "Extracted"] + + x_positions_rel = list(range(len(counts_rel))) + plt.bar(x_positions_rel, counts_rel, width=0.6) + plt.xticks(x_positions_rel, labels_rel) + plt.ylabel("Count") + plt.title("Relation Counts") + + # Add count labels + for i, count in enumerate(counts_rel): + plt.text(i, count + 2, str(count), ha="center", va="bottom") + + plt.tight_layout() + plt.savefig(os.path.join(output_dir, f"{model_name}_count_comparison.png")) + plt.close() + + +def main(): + """Main function to run the evaluation.""" + parser = argparse.ArgumentParser( + description="SciER Evaluation Script for Morphik using OpenAI embeddings" + ) + parser.add_argument( + "--limit", type=int, default=57, help="Maximum number of 
documents to process (default: 57)" + ) + parser.add_argument( + "--run-id", type=str, default=None, help="Unique run identifier (default: auto-generated)" + ) + parser.add_argument( + "--graph-name", + type=str, + required=True, + help="Name of the existing graph to evaluate (e.g., 'scier_gpt4o_12345678')", + ) + parser.add_argument( + "--model-name", + type=str, + default="existing_model_openai", + help="Name for this evaluation run (default: existing_model_openai)", + ) + parser.add_argument( + "--similarity-threshold", + type=float, + default=0.70, + help="Threshold for semantic similarity matching (default: 0.70)", + ) + parser.add_argument( + "--embedding-model", + type=str, + default="text-embedding-3-small", + help="OpenAI embedding model to use (default: text-embedding-3-small)", + ) + parser.add_argument( + "--api-base", + type=str, + default="https://api.openai.com/v1", + help="OpenAI API base URL (default: https://api.openai.com/v1)", + ) + parser.add_argument( + "--batch-size", + type=int, + default=100, + help="Batch size for parallel processing (default: 100)", + ) + parser.add_argument( + "--max-workers", + type=int, + default=None, + help="Maximum number of worker processes (default: CPU count - 1)", + ) + parser.add_argument( + "--cache-size", + type=int, + default=10000, + help="Size of embedding cache (default: 10000)", + ) + parser.add_argument( + "--embedding-batch-size", + type=int, + default=20, + help="Batch size for embedding API calls (default: 20)", + ) + parser.add_argument( + "--requests-per-minute", + type=int, + default=3500, + help="Rate limit for OpenAI API requests per minute (default: 3500)", + ) + + args = parser.parse_args() + + # Generate run ID if not provided + run_id = args.run_id or str(uuid.uuid4())[:8] + + print(f"Running evaluation for model: {args.model_name}") + print(f"Run ID: {run_id}") + print(f"Using OpenAI embeddings with model: {args.embedding_model}") + print(f"Similarity threshold: {args.similarity_threshold}") + print(f"API base URL: {args.api_base}") + print(f"Batch size: {args.batch_size}") + print(f"Max workers: {args.max_workers or 'CPU count - 1'}") + print(f"Embedding cache size: {args.cache_size}") + print(f"Embedding batch size: {args.embedding_batch_size}") + print(f"Requests per minute: {args.requests_per_minute}") + + # Check if OPENAI_API_KEY is set + if not os.getenv("OPENAI_API_KEY"): + print("Error: OPENAI_API_KEY environment variable is not set.") + print("Please set it using: export OPENAI_API_KEY=your_api_key") + return + + # Initialize Morphik client + db = setup_morphik_client() + + # Initialize OpenAI embedding with optimized parameters + similarity_calculator = OpenAIEmbedding( + model_name=args.embedding_model, + threshold=args.similarity_threshold, + api_base=args.api_base, + cache_size=args.cache_size, + batch_size=args.embedding_batch_size, + ) + # Update rate limit if specified + similarity_calculator.requests_per_minute = args.requests_per_minute + similarity_calculator.min_time_between_requests = 60.0 / args.requests_per_minute + + # Load SciER dataset + scier_data = load_scier_data("test.jsonl", limit=args.limit) + print(f"Loaded {len(scier_data)} records") + + # Prepare documents + documents = prepare_text_for_evaluation(scier_data) + print(f"Prepared {len(documents)} documents for evaluation") + + # Load existing graph + graph_data = load_existing_graph(db, args.graph_name) + + # Evaluate graph + results = evaluate_graph( + db=db, + documents=documents, + graph=graph_data["graph"], + 
model=args.model_name, + similarity_calculator=similarity_calculator, + batch_size=args.batch_size, + max_workers=args.max_workers, + ) + + # Save results using the input graph name as requested + results_dir = save_results(results, f"{args.graph_name}_result", run_id) + print(f"\nEvaluation complete! Results saved to {results_dir}") + + +if __name__ == "__main__": + main() diff --git a/evaluations/Science graphs (SciER)/scier_evaluation.py b/evaluations/Science graphs (SciER)/scier_evaluation.py new file mode 100644 index 0000000..48d1ac9 --- /dev/null +++ b/evaluations/Science graphs (SciER)/scier_evaluation.py @@ -0,0 +1,289 @@ +#!/usr/bin/env python3 +""" +SciER Graph Creation Script for Morphik + +This script creates a knowledge graph from the SciER dataset. +It ingests the documents and creates a graph using custom prompt overrides. +""" + +import os +import uuid +import argparse +from typing import Dict, List, Tuple, Any +from collections import defaultdict +from tqdm import tqdm +from dotenv import load_dotenv + +from morphik import Morphik +from morphik.models import ( + EntityExtractionPromptOverride, + EntityExtractionExample, + GraphPromptOverrides, +) + +# Import SciER data loader +from data_loader import load_jsonl + +# Load environment variables +load_dotenv() + + +def setup_morphik_client() -> Morphik: + """Initialize and return a Morphik client.""" + # Connect to Morphik (adjust parameters as needed) + return Morphik(timeout=300000, is_local=True) + + +def load_scier_data(dataset_path: str = "test.jsonl", limit: int = None) -> List[Dict]: + """ + Load SciER dataset from the specified JSONL file. + + Args: + dataset_path: Path to the JSONL file + limit: Maximum number of records to load (None for all) + + Returns: + List of dataset records + """ + data = load_jsonl(dataset_path) + if limit: + data = data[:limit] + return data + + +def prepare_text_for_ingestion(records: List[Dict]) -> List[Dict]: + """ + Prepare SciER records for ingestion into Morphik. + + Args: + records: List of SciER records + + Returns: + List of dictionaries ready for ingestion + """ + documents = [] + + # Group records by doc_id to create complete documents + doc_groups = defaultdict(list) + for record in records: + doc_groups[record["doc_id"]].append(record) + + # Convert grouped records to documents + for doc_id, records in doc_groups.items(): + text = "\n".join([record["sentence"] for record in records]) + + # Collect all entities and relations for ground truth + all_entities = [] + all_relations = [] + for record in records: + all_entities.extend(record["ner"]) + all_relations.extend(record["rel"]) + + documents.append( + { + "text": text, + "metadata": { + "doc_id": doc_id, + "ground_truth_entities": all_entities, + "ground_truth_relations": all_relations, + }, + } + ) + + return documents + + +def create_graph_extraction_override(entity_types: List[str]) -> EntityExtractionPromptOverride: + """ + Create graph extraction prompt override with examples for both entities and relations. 
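+
+    Example (mirrors the call made in create_graph below):
+
+        override = create_graph_extraction_override(["Dataset", "Method", "Task"])
+        overrides = GraphPromptOverrides(entity_extraction=override)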
+ + Args: + entity_types: List of entity types (Dataset, Method, Task) + + Returns: + EntityExtractionPromptOverride object + """ + examples = [] + + if "Dataset" in entity_types: + examples.extend( + [ + EntityExtractionExample(label="ImageNet", type="Dataset"), + EntityExtractionExample(label="CIFAR-10", type="Dataset"), + EntityExtractionExample(label="MNIST", type="Dataset"), + EntityExtractionExample(label="Penn TreeBank", type="Dataset"), + EntityExtractionExample(label="SQuAD", type="Dataset"), + EntityExtractionExample(label="MultiNLI", type="Dataset"), + ] + ) + + if "Method" in entity_types: + examples.extend( + [ + # General models + EntityExtractionExample(label="Convolutional Neural Network", type="Method"), + EntityExtractionExample(label="Random Forest", type="Method"), + # Architecture-specific models from SciER + EntityExtractionExample(label="BERT", type="Method"), + EntityExtractionExample(label="Transformer", type="Method"), + EntityExtractionExample(label="LSTM", type="Method"), + EntityExtractionExample(label="Bidirectional LSTM", type="Method"), + EntityExtractionExample(label="self-attentive models", type="Method"), + EntityExtractionExample(label="seq2seq", type="Method"), + # Components + EntityExtractionExample(label="attention mechanism", type="Method"), + EntityExtractionExample(label="feature extraction mechanisms", type="Method"), + ] + ) + + if "Task" in entity_types: + examples.extend( + [ + # General tasks + EntityExtractionExample(label="Image Classification", type="Task"), + EntityExtractionExample(label="Named Entity Recognition", type="Task"), + # NLP tasks from SciER + EntityExtractionExample(label="Machine Translation", type="Task"), + EntityExtractionExample(label="neural machine translation", type="Task"), + EntityExtractionExample(label="sentiment analysis", type="Task"), + EntityExtractionExample(label="entailment", type="Task"), + EntityExtractionExample(label="text classification", type="Task"), + EntityExtractionExample(label="natural language processing", type="Task"), + EntityExtractionExample(label="sequence-to-sequence problems", type="Task"), + EntityExtractionExample(label="NLP", type="Task"), + ] + ) + + # Simplest version - bare standard placeholders + prompt_template = """ +Your task is to carefully read the following scientific text and extract specific information. +You need to extract: +1. **Entities:** Identify any mentions of Datasets, Methods, and Tasks. Use the entity examples provided below to understand what to look for, that is a very small list, there are many many entities. +2. **Relationships:** Identify relationships *between the extracted entities* based on the information stated in the text. Use only the relationship types defined below. 
+ +**Entity Examples (this is a very brief list, there are many many entities):** +{examples} + +**Relationship Information:** +Desired Relationship Types (only extract these relationships, nothing else, there are a lot of relationships, be nuanced and careful, think hard about how entities relate to each other): +- Used-For: [Method/Dataset] is used for [Task] +- Feature-Of: [Feature] is a feature of [Method/Task] +- Hyponym-Of: [Specific] is a type of [General] +- Part-Of: [Component] is part of [System] +- Compare: [Entity A] is compared to [Entity B] +- Evaluate-For: [Method] is evaluated for [Metric/Task] +- Conjunction: [Entity A] is mentioned together with [Entity B] without a specific relation +- Evaluate-On: [Method] is evaluated on [Dataset] +- Synonym-Of: [Entity A] is the same as [Entity B] + +**Instructions:** +- Extract entities first, identifying their label (the text mention) and type (Dataset, Method, or Task). +- Then, extract relationships between the entities you found. The 'source' and 'target' of the relationship MUST be the exact entity labels you extracted. +- Only extract information explicitly mentioned in the text. Do not infer or add outside knowledge. +- Format your entire output as a single JSON object containing two keys: "entities" (a list of entity objects) and "relationships" (a list of relationship objects). + +**Text to analyze:** +{content} +""" + + return EntityExtractionPromptOverride(prompt_template=prompt_template, examples=examples) + + +def create_graph( + db: Morphik, documents: List[Dict], model_name: str, run_id: str +) -> Tuple[List[str], Dict]: + """ + Create a knowledge graph from the documents. + + Args: + db: Morphik client + documents: List of documents to ingest + model_name: Name of the model being used (for tracking) + run_id: Unique identifier for this run + + Returns: + Tuple of (list of document IDs, graphs dict) + """ + print(f"\n=== Creating graph with {model_name} model ===") + + # Ingest documents + doc_ids = [] + for doc in tqdm(documents, desc="Ingesting documents"): + # Add metadata for tracking + doc["metadata"]["evaluation_run_id"] = run_id + doc["metadata"]["model"] = model_name + + # Ingest the document + result = db.ingest_text(doc["text"], metadata=doc["metadata"]) + doc_ids.append(result.external_id) + + # Create graph extraction override (which includes both entity and relationship instructions) + entity_extraction_override = create_graph_extraction_override(["Dataset", "Method", "Task"]) + + # Wrap the combined override correctly for the API + graph_overrides = GraphPromptOverrides(entity_extraction=entity_extraction_override) + + # Create a knowledge graph with overrides + print("Creating knowledge graph with prompt overrides...") + graph = db.create_graph( + name=f"scier_{model_name}_{run_id}", documents=doc_ids, prompt_overrides=graph_overrides + ) + + print( + f"Created graph with {len(graph.entities)} entities and {len(graph.relationships)} relationships" + ) + + return doc_ids, {"graph": graph} + + +def main(): + """Main function to create a graph from the SciER dataset.""" + parser = argparse.ArgumentParser(description="SciER Graph Creation Script for Morphik") + parser.add_argument( + "--limit", type=int, default=57, help="Maximum number of documents to process (default: 57)" + ) + parser.add_argument( + "--run-id", type=str, default=None, help="Unique run identifier (default: auto-generated)" + ) + parser.add_argument( + "--model-name", + type=str, + default=None, + help="Name of the currently configured 
model (default: auto-detected)", + ) + + args = parser.parse_args() + + # Generate run ID if not provided + run_id = args.run_id or str(uuid.uuid4())[:8] + + # Auto-detect or use provided model name + model_name = args.model_name or "default_model" + + print(f"Running graph creation for model: {model_name}") + print(f"Run ID: {run_id}") + + # Initialize Morphik client + db = setup_morphik_client() + + # Load SciER dataset + scier_data = load_scier_data("test.jsonl", limit=args.limit) + print(f"Loaded {len(scier_data)} records") + + # Prepare documents for ingestion + documents = prepare_text_for_ingestion(scier_data) + print(f"Prepared {len(documents)} documents for ingestion") + + # Create the graph + doc_ids, graphs = create_graph(db, documents, model_name, run_id) + + # Print graph name for evaluation + graph_name = f"scier_{model_name}_{run_id}" + print(f"\nGraph creation complete! Created graph: {graph_name}") + print( + f"To evaluate this graph, run: python evaluate_result.py --graph-name {graph_name}" + ) + + +if __name__ == "__main__": + main() diff --git a/evaluations/hotpot_ragas_eval.py b/evaluations/hotpot_ragas_eval.py new file mode 100644 index 0000000..70ebbbe --- /dev/null +++ b/evaluations/hotpot_ragas_eval.py @@ -0,0 +1,259 @@ +import os +import sys +import uuid +from pathlib import Path + +# Add the SDK path to the Python path +sdk_path = str(Path(__file__).parent.parent / "sdks" / "python") +sys.path.insert(0, sdk_path) + +from datasets import load_dataset, Dataset +from ragas import evaluate +from ragas.metrics import faithfulness, answer_correctness, context_precision +import pandas as pd +from dotenv import load_dotenv +from morphik import Morphik +from tqdm import tqdm +import argparse + +# Load environment variables +load_dotenv() + +# Connect to Morphik +db = Morphik(timeout=10000, is_local=True) + +# Generate a run identifier +def generate_run_id(): + """Generate a unique run identifier""" + return str(uuid.uuid4()) + + +def load_hotpotqa_dataset(num_samples=10, split="validation"): + """Load and prepare the HotpotQA dataset""" + dataset = load_dataset("hotpot_qa", "distractor", split=split, trust_remote_code=True) + + # Sample a subset + dataset = dataset.select(range(min(num_samples, len(dataset)))) + + return dataset + + +def process_with_morphik(dataset, run_id=None): + """ + Process dataset with Morphik and prepare data for RAGAS evaluation + + Args: + dataset: The dataset to process + run_id: Unique identifier for this evaluation run + """ + # Generate a run_id if not provided + if run_id is None: + run_id = generate_run_id() + + print(f"Using run identifier: {run_id}") + + data_samples = { + "question": [], + "answer": [], + "contexts": [], + "ground_truth": [], + "run_id": [] # Store run_id for each sample + } + + for i, item in enumerate(tqdm(dataset, desc="Processing documents")): + try: + # Extract question and ground truth + question = item["question"].strip() + ground_truth = item["answer"].strip() + + if not question or not ground_truth: + print(f"Skipping item {i}: Empty question or answer") + continue + + # Ingest the document's context into Morphik + context = "" + for title, sentences in zip(item["context"]["title"], item["context"]["sentences"]): + paragraph = " ".join(sentences) + context += f"{title}:\n{paragraph}\n\n" + + # Handle a potentially longer context + # if len(context) > 10000: + # print(f"Warning: Long context ({len(context)} chars), truncating...") + # context = context[:10000] + + # Ingest text with run_id in metadata + doc_id = 
db.ingest_text( + context, + metadata={ + "source": "hotpotqa", + "question_id": item.get("_id", ""), + "item_index": i, + "evaluation_run_id": run_id # Add run_id to metadata + }, + use_colpali=False + ).external_id + + # Query Morphik for the answer with concise prompt override + prompt_override = { + "query": { + "prompt_template": "Answer the following question based on the provided context. Your answer should be as concise as possible. If a yes/no answer is appropriate, just respond with 'Yes' or 'No'. Do not provide explanations or additional context unless absolutely necessary.\n\nQuestion: {question}\n\nContext: {context}" + } + } + response = db.query( + question, + use_colpali=False, + k=10, + filters={"evaluation_run_id": run_id}, + prompt_overrides=prompt_override + ) + answer = response.completion + + if not answer: + print(f"Warning: Empty answer for question: {question[:50]}...") + answer = "No answer provided" + + # Get retrieved chunks for context with filter by run_id + chunks = db.retrieve_chunks( + query=question, + k=10, + filters={"evaluation_run_id": run_id} # Filter by run_id + ) + context_texts = [chunk.content for chunk in chunks] + + if not context_texts: + print(f"Warning: No contexts retrieved for question: {question[:50]}...") + context_texts = ["No context retrieved"] + + # Add to our dataset + data_samples["question"].append(question) + data_samples["answer"].append(answer) + data_samples["contexts"].append(context_texts) + data_samples["ground_truth"].append(ground_truth) + data_samples["run_id"].append(run_id) + + except Exception as e: + import traceback + + print(f"Error processing item {i}:") + print(f"Question: {item.get('question', 'N/A')[:50]}...") + print(f"Error: {e}") + traceback.print_exc() + continue + + return data_samples, run_id + + +def run_evaluation(num_samples=5, output_file="ragas_results.csv", run_id=None): + """ + Run the full evaluation pipeline + + Args: + num_samples: Number of samples to use from the dataset + output_file: Path to save the results CSV + run_id: Optional run identifier. If None, a new one will be generated + """ + try: + # Load dataset + print("Loading HotpotQA dataset...") + hotpot_dataset = load_hotpotqa_dataset(num_samples=num_samples) + print(f"Loaded {len(hotpot_dataset)} samples from HotpotQA") + + # Process with Morphik + print("Processing with Morphik...") + data_samples, run_id = process_with_morphik(hotpot_dataset, run_id=run_id) + + # Check if we have enough samples + if len(data_samples["question"]) == 0: + print("Error: No samples were successfully processed. 
Exiting.") + return + + print(f"Successfully processed {len(data_samples['question'])} samples") + + # Convert to RAGAS format + ragas_dataset = Dataset.from_dict(data_samples) + + # Run RAGAS evaluation + print("Running RAGAS evaluation...") + metrics = [faithfulness, answer_correctness, context_precision] + + result = evaluate(ragas_dataset, metrics=metrics) + + # Convert results to DataFrame and save + df_result = result.to_pandas() + + # Add run_id to the results DataFrame + df_result['run_id'] = run_id + + print("\nRAGAS Evaluation Results:") + print(df_result) + + # Add more detailed analysis + print("\nDetailed Metric Analysis:") + # First ensure all metric columns are numeric + for column in ["faithfulness", "answer_correctness", "context_precision"]: + if column in df_result.columns: + try: + # Convert column to numeric, errors='coerce' will replace non-numeric values with NaN + df_result[column] = pd.to_numeric(df_result[column], errors="coerce") + # Calculate and print mean, ignoring NaN values + mean_value = df_result[column].mean(skipna=True) + if pd.notna(mean_value): # Check if mean is not NaN + print(f"{column}: {mean_value:.4f}") + else: + print(f"{column}: No valid numeric values found") + except Exception as e: + print(f"Error processing {column}: {e}") + print(f"Values: {df_result[column].head().tolist()}") + + # Include run_id in the output filename if not explicitly provided + if output_file == "ragas_results.csv": + # Get just the filename without extension + base_name = output_file.rsplit('.', 1)[0] + output_file = f"{base_name}_{run_id}.csv" + + # Save results + df_result.to_csv(output_file, index=False) + print(f"Results saved to {output_file}") + + return df_result, run_id + + except Exception as e: + import traceback + + print(f"Error in evaluation: {e}") + traceback.print_exc() + print("Exiting due to error.") + return None + + +def main(): + """Command-line entry point""" + parser = argparse.ArgumentParser( + description="Run RAGAS evaluation on Morphik using HotpotQA dataset" + ) + parser.add_argument( + "--samples", type=int, default=5, help="Number of samples to use (default: 5)" + ) + parser.add_argument( + "--output", + type=str, + default="ragas_results.csv", + help="Output file for results (default: ragas_results.csv)", + ) + parser.add_argument( + "--run-id", + type=str, + default=None, + help="Specific run identifier to use (default: auto-generated UUID)", + ) + args = parser.parse_args() + + run_evaluation( + num_samples=args.samples, + output_file=args.output, + run_id=args.run_id + ) + + +if __name__ == "__main__": + main()