Mirror of https://github.com/james-m-jordan/morphik-core.git (synced 2025-05-09 19:32:38 +00:00)

SciER Science graph evals (#81)

Parent: 95369635e2
Commit: 40fc8fab6b
```diff
@@ -480,6 +480,8 @@ class GraphService:
         relationships = []
         # List to collect all extracted entities for resolution
         all_entities = []
+        # Track all initial entities with their labels to fix relationship mapping
+        initial_entities = []
 
         # Collect all chunk sources from documents.
         chunk_sources = [
```
```diff
@@ -505,6 +507,9 @@ class GraphService:
                 chunk_entities, chunk_relationships = await self.extract_entities_from_text(
                     chunk.content, chunk.document_id, chunk.chunk_number, extraction_overrides
                 )
 
+                # Store all initially extracted entities to track their IDs
+                initial_entities.extend(chunk_entities)
+
                 # Add entities to the collection, avoiding duplicates based on exact label match
                 for entity in chunk_entities:
```
```diff
@@ -551,6 +556,9 @@ class GraphService:
                 )
                 raise
 
+        # Build a mapping from entity ID to label for ALL initially extracted entities
+        original_entity_id_to_label = {entity.id: entity.label for entity in initial_entities}
+
         # Check if entity resolution is enabled in settings
         settings = get_settings()
 
```
```diff
@@ -591,37 +599,62 @@ class GraphService:
             # Update relationships to use canonical entity labels
             updated_relationships = []
 
-            # Create an entity index by ID for efficient lookups
-            entity_by_id = {entity.id: entity for entity in all_entities}
+            # Remap relationships using original entity ID to label mapping
+            remapped_count = 0
+            skipped_count = 0
 
             for relationship in relationships:
-                # Lookup entities by ID directly from the index
-                source_entity = entity_by_id.get(relationship.source_id)
-                target_entity = entity_by_id.get(relationship.target_id)
-
-                if source_entity and target_entity:
-                    # Get canonical labels
-                    source_canonical = entity_mapping.get(source_entity.label, source_entity.label)
-                    target_canonical = entity_mapping.get(target_entity.label, target_entity.label)
-                    # Get canonical entities
-                    canonical_source = resolved_entities_dict.get(source_canonical)
-                    canonical_target = resolved_entities_dict.get(target_canonical)
-                    if canonical_source and canonical_target:
-                        # Update relationship to point to canonical entities
-                        relationship.source_id = canonical_source.id
-                        relationship.target_id = canonical_target.id
-                        updated_relationships.append(relationship)
-                    else:
-                        # Skip relationships that can't be properly mapped
-                        logger.warning(
-                            "Skipping relationship between '%s' and '%s' - canonical entities not found",
-                            source_entity.label,
-                            target_entity.label,
-                        )
-                else:
-                    # Keep relationship as is if we can't find the entities
-
-            return resolved_entities_dict, updated_relationships
+                # Use original_entity_id_to_label to get the labels for relationship endpoints
+                original_source_label = original_entity_id_to_label.get(relationship.source_id)
+                original_target_label = original_entity_id_to_label.get(relationship.target_id)
+
+                if not original_source_label or not original_target_label:
+                    logger.warning(
+                        f"Skipping relationship with type '{relationship.type}' - could not find original entity labels"
+                    )
+                    skipped_count += 1
+                    continue
+
+                # Find canonical labels using the mapping from the resolver
+                source_canonical = entity_mapping.get(original_source_label, original_source_label)
+                target_canonical = entity_mapping.get(original_target_label, original_target_label)
+
+                # Find the final unique Entity objects using the canonical labels
+                canonical_source = resolved_entities_dict.get(source_canonical)
+                canonical_target = resolved_entities_dict.get(target_canonical)
+
+                if canonical_source and canonical_target:
+                    # Successfully found the final entities, update the relationship's IDs
+                    relationship.source_id = canonical_source.id
+                    relationship.target_id = canonical_target.id
+                    updated_relationships.append(relationship)
+                    remapped_count += 1
+                else:
+                    # Could not map to final entities, log and skip
+                    logger.warning(
+                        f"Skipping relationship between '{original_source_label}' and '{original_target_label}' - "
+                        f"canonical entities not found after resolution"
+                    )
+                    skipped_count += 1
+
+            logger.info(f"Remapped {remapped_count} relationships, skipped {skipped_count} relationships")
+
+            # Deduplicate relationships (same source, target, type)
+            final_relationships_map = {}
+            for rel in updated_relationships:
+                key = (rel.source_id, rel.target_id, rel.type)
+                if key not in final_relationships_map:
+                    final_relationships_map[key] = rel
+                else:
+                    # Merge sources into the existing relationship
+                    existing_rel = final_relationships_map[key]
+                    self._merge_relationship_sources(existing_rel, rel)
+
+            final_relationships = list(final_relationships_map.values())
+            logger.info(f"Deduplicated to {len(final_relationships)} unique relationships")
+
+            return resolved_entities_dict, final_relationships
 
         # If no entity resolution occurred, return original entities and relationships
         return entities, relationships
```
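To see why the snapshot of initial entity IDs matters, here is a minimal, self-contained sketch of the remapping idea (toy dicts and a toy `Rel` class, not Morphik's actual Entity/Relationship types): a relationship can point at an entity ID that no longer exists after deduplication and resolution, and the ID-to-label snapshot is what recovers it.

```python
from dataclasses import dataclass


@dataclass
class Rel:
    source_id: str
    target_id: str
    type: str


# Labels recorded for every entity at extraction time, before duplicates were dropped
original_id_to_label = {"e1": "BERT", "e2": "bert model", "e3": "NLP"}

# The resolver maps both surface forms to one canonical label...
entity_mapping = {"BERT": "BERT", "bert model": "BERT", "NLP": "NLP"}

# ...and keeps a single canonical entity per label (here just label -> surviving ID)
resolved_ids = {"BERT": "e1", "NLP": "e3"}

rel = Rel(source_id="e2", target_id="e3", type="Used-For")

# "e2" survives nowhere after resolution, so an ID-only lookup would drop this
# relationship; the snapshot recovers its label and re-points it at the canonical ID.
src_label = entity_mapping[original_id_to_label[rel.source_id]]
tgt_label = entity_mapping[original_id_to_label[rel.target_id]]
rel.source_id, rel.target_id = resolved_ids[src_label], resolved_ids[tgt_label]

assert (rel.source_id, rel.target_id) == ("e1", "e3")
```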
```diff
@@ -685,9 +718,9 @@ class GraphService:
         system_message = {
             "role": "system",
             "content": (
-                "You are an entity extraction assistant. Extract entities and their relationships from text precisely and thoroughly. "
-                "For entities, include entity label and type (PERSON, ORGANIZATION, LOCATION, CONCEPT, etc.). "
-                "For relationships, use a simple format with source, target, and relationship fields. "
+                "You are an entity extraction and relationship extraction assistant. Extract entities and their relationships from text precisely and thoroughly; extract as many entities and relationships as possible. "
+                "For entities, include entity label and type (some examples: PERSON, ORGANIZATION, LOCATION, CONCEPT, etc.). If the user has given examples, use those; these are just suggestions. "
+                "For relationships, use a simple format with source, target, and relationship fields. Be very thorough; there are many relationships that are not obvious. "
                 "IMPORTANT: The source and target fields must be simple strings representing entity labels. For example: "
                 "if you extract entities 'Entity A' and 'Entity B', a relationship would have source: 'Entity A', target: 'Entity B', relationship: 'relates to'. "
                 "Respond directly in json format, without any additional text or explanations. "
```
evaluations/Science graphs (SciER)/README.md (new file, 87 lines)

# SciER Evaluation for Morphik

This directory contains scripts for evaluating different language models' entity and relation extraction capabilities against the [SciER (Scientific Entity Recognition) dataset](https://github.com/allenai/SciERC).

## Overview

The evaluation workflow is split into two parts:

1. **Graph Creation**: Generate a knowledge graph using the model configured in `morphik.toml`
2. **Graph Evaluation**: Evaluate the created graph against the ground-truth annotations

This separation allows you to test different models by changing the configuration in `morphik.toml` between runs.

The evaluation uses the SciER test dataset, which can be found at [https://github.com/edzq/SciER/blob/main/SciER/LLM/test.jsonl](https://github.com/edzq/SciER/blob/main/SciER/LLM/test.jsonl).
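Each line of `test.jsonl` is one sentence-level record. Roughly, a record looks like the following (the values are illustrative; records also carry a `rel_plus` field, as the usage example in `data_loader.py` below shows):

```json
{
  "doc_id": "D001",
  "sentence": "We evaluate BERT on SQuAD for question answering.",
  "ner": [["BERT", "Method"], ["SQuAD", "Dataset"], ["question answering", "Task"]],
  "rel": [["BERT", "Evaluate-On", "SQuAD"], ["BERT", "Used-For", "question answering"]]
}
```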
## Setup

1. Make sure you have a local Morphik server running
2. Set up your OpenAI API key for evaluation:

```bash
export OPENAI_API_KEY=your_key_here
```

## Running the Evaluation

### Step 1: Configure Your Model

Edit `morphik.toml` to specify the model you want to test:

```toml
[graph]
model = "openai_gpt4o"  # Reference to a key in registered_models
enable_entity_resolution = true
```
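Here, `openai_gpt4o` must match an entry under `registered_models` in the same `morphik.toml`. As a rough sketch only (treat the exact key names as assumptions and consult your own config for the real schema):

```toml
[registered_models.openai_gpt4o]
model_name = "gpt-4o"
```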
### Step 2: Create a Knowledge Graph

Run the graph creation script:

```bash
python scier_evaluation.py --model-name gpt4o
```

The script will output a graph name like `scier_gpt4o_12345678` when complete.

### Step 3: Evaluate the Knowledge Graph

Evaluate the created graph:

```bash
python evaluate_result.py --graph-name scier_gpt4o_12345678
```

### Step 4: Test Different Models

To compare models:

1. Change the model in `morphik.toml`
2. Repeat steps 2-3 with a different `--model-name`
3. Compare the resulting metrics and visualizations

## Command Arguments

Both scripts accept the flags below; a combined example invocation follows the lists.

### scier_evaluation.py

- `--model-name`: Name to identify this model in results (default: "default_model")
- `--limit`: Maximum number of documents to process (default: 57)
- `--run-id`: Unique identifier for the run (default: auto-generated)

### evaluate_result.py

- `--graph-name`: Name of the graph to evaluate (**required**)
- `--model-name`: Name for the evaluation results (default: "existing_model_openai")
- `--similarity-threshold`: Threshold for semantic similarity matching (default: 0.70)
- `--embedding-model`: OpenAI embedding model to use (default: "text-embedding-3-small")
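For example, a full create-then-evaluate cycle for a hypothetical `gpt4o` configuration might look like this (the `scier_gpt4o_trial1` graph name follows the `scier_{model_name}_{run_id}` pattern printed by the creation script):

```bash
python scier_evaluation.py --model-name gpt4o --limit 57 --run-id trial1
python evaluate_result.py --graph-name scier_gpt4o_trial1 --similarity-threshold 0.70
```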
## Results

The evaluation generates:

- CSV files with precision, recall, and F1 metrics
- Visualizations comparing entity and relation extraction performance
- Entity and relation count comparisons

Results are saved in a directory with the format: `scier_results_{model_name}_{run_id}/`

## Tips for Model Comparison

- Use descriptive model names in the `--model-name` parameter
- Keep the same similarity threshold across evaluations
- Compare models using the generated visualization charts
- Look at both entity and relation extraction metrics
evaluations/Science graphs (SciER)/count_entities.py (new file, 18 lines)

```python
import json

# Tally ground-truth annotations across the SciER test set.
# Each line of test.jsonl is one record with "ner" and "rel" lists.
entity_count = 0
relation_count = 0
document_count = 0

with open("test.jsonl", "r", encoding="utf-8") as f:
    for line in f:
        document_count += 1
        data = json.loads(line)
        if "ner" in data:
            entity_count += len(data["ner"])
        if "rel" in data:
            relation_count += len(data["rel"])

print(f"Total documents: {document_count}")
print(f"Total entities: {entity_count}")
print(f"Total relationships: {relation_count}")
```
evaluations/Science graphs (SciER)/data_loader.py (new file, 22 lines)

```python
import json


def load_jsonl(jsonl_path: str) -> list:
    """Load a JSONL file into a list of dicts, one per line."""
    data = []
    with open(jsonl_path) as f:
        for line in f:
            data.append(json.loads(line))
    return data


# ========= usage example for load_jsonl with LLM input =========
# dataset = load_jsonl('./SciER/LLM/test.jsonl')
# sent = dataset[0]
# print(sent.keys())
# print(sent['sentence'])
# print('----------------')
# print(sent['ner'])
# print('----------------')
# print(sent['rel'])
# print('----------------')
# print(sent['rel_plus'])
```
evaluations/Science graphs (SciER)/evaluate_result.py (new file, 1116 lines)

File diff suppressed because it is too large.
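Since that diff is suppressed, here is a minimal sketch of what the threshold-based semantic matching described by the README's `--similarity-threshold` and `--embedding-model` flags can look like, assuming OpenAI embeddings and cosine similarity. The function names and exact scoring logic are illustrative, not necessarily what `evaluate_result.py` does:

```python
import numpy as np
from openai import OpenAI

client = OpenAI()  # uses OPENAI_API_KEY from the environment


def embed(texts: list[str], model: str = "text-embedding-3-small") -> np.ndarray:
    """Embed a batch of strings and return an (n, d) array."""
    resp = client.embeddings.create(model=model, input=texts)
    return np.array([item.embedding for item in resp.data])


def soft_match_f1(predicted: list[str], gold: list[str], threshold: float = 0.70):
    """Count a prediction as correct if its best cosine similarity to any
    gold label clears the threshold, then compute precision/recall/F1."""
    if not predicted or not gold:
        return 0.0, 0.0, 0.0
    p, g = embed(predicted), embed(gold)
    p /= np.linalg.norm(p, axis=1, keepdims=True)
    g /= np.linalg.norm(g, axis=1, keepdims=True)
    sims = p @ g.T  # cosine similarities, shape (len(predicted), len(gold))
    precision = int((sims.max(axis=1) >= threshold).sum()) / len(predicted)
    recall = int((sims.max(axis=0) >= threshold).sum()) / len(gold)
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1
```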
evaluations/Science graphs (SciER)/scier_evaluation.py (new file, 289 lines)

```python
#!/usr/bin/env python3
"""
SciER Graph Creation Script for Morphik

This script creates a knowledge graph from the SciER dataset.
It ingests the documents and creates a graph using custom prompt overrides.
"""

import argparse
import uuid
from collections import defaultdict
from typing import Dict, List, Tuple

from dotenv import load_dotenv
from tqdm import tqdm

from morphik import Morphik
from morphik.models import (
    EntityExtractionExample,
    EntityExtractionPromptOverride,
    GraphPromptOverrides,
)

# Import SciER data loader
from data_loader import load_jsonl

# Load environment variables
load_dotenv()


def setup_morphik_client() -> Morphik:
    """Initialize and return a Morphik client."""
    # Connect to Morphik (adjust parameters as needed)
    return Morphik(timeout=300000, is_local=True)


def load_scier_data(dataset_path: str = "test.jsonl", limit: int = None) -> List[Dict]:
    """
    Load the SciER dataset from the specified JSONL file.

    Args:
        dataset_path: Path to the JSONL file
        limit: Maximum number of records to load (None for all)

    Returns:
        List of dataset records
    """
    data = load_jsonl(dataset_path)
    if limit:
        data = data[:limit]
    return data


def prepare_text_for_ingestion(records: List[Dict]) -> List[Dict]:
    """
    Prepare SciER records for ingestion into Morphik.

    Args:
        records: List of SciER records

    Returns:
        List of dictionaries ready for ingestion
    """
    documents = []

    # Group records by doc_id to create complete documents
    doc_groups = defaultdict(list)
    for record in records:
        doc_groups[record["doc_id"]].append(record)

    # Convert grouped records to documents
    for doc_id, doc_records in doc_groups.items():
        text = "\n".join(record["sentence"] for record in doc_records)

        # Collect all entities and relations for ground truth
        all_entities = []
        all_relations = []
        for record in doc_records:
            all_entities.extend(record["ner"])
            all_relations.extend(record["rel"])

        documents.append(
            {
                "text": text,
                "metadata": {
                    "doc_id": doc_id,
                    "ground_truth_entities": all_entities,
                    "ground_truth_relations": all_relations,
                },
            }
        )

    return documents


def create_graph_extraction_override(entity_types: List[str]) -> EntityExtractionPromptOverride:
    """
    Create a graph extraction prompt override with examples for both entities and relations.

    Args:
        entity_types: List of entity types (Dataset, Method, Task)

    Returns:
        EntityExtractionPromptOverride object
    """
    examples = []

    if "Dataset" in entity_types:
        examples.extend(
            [
                EntityExtractionExample(label="ImageNet", type="Dataset"),
                EntityExtractionExample(label="CIFAR-10", type="Dataset"),
                EntityExtractionExample(label="MNIST", type="Dataset"),
                EntityExtractionExample(label="Penn TreeBank", type="Dataset"),
                EntityExtractionExample(label="SQuAD", type="Dataset"),
                EntityExtractionExample(label="MultiNLI", type="Dataset"),
            ]
        )

    if "Method" in entity_types:
        examples.extend(
            [
                # General models
                EntityExtractionExample(label="Convolutional Neural Network", type="Method"),
                EntityExtractionExample(label="Random Forest", type="Method"),
                # Architecture-specific models from SciER
                EntityExtractionExample(label="BERT", type="Method"),
                EntityExtractionExample(label="Transformer", type="Method"),
                EntityExtractionExample(label="LSTM", type="Method"),
                EntityExtractionExample(label="Bidirectional LSTM", type="Method"),
                EntityExtractionExample(label="self-attentive models", type="Method"),
                EntityExtractionExample(label="seq2seq", type="Method"),
                # Components
                EntityExtractionExample(label="attention mechanism", type="Method"),
                EntityExtractionExample(label="feature extraction mechanisms", type="Method"),
            ]
        )

    if "Task" in entity_types:
        examples.extend(
            [
                # General tasks
                EntityExtractionExample(label="Image Classification", type="Task"),
                EntityExtractionExample(label="Named Entity Recognition", type="Task"),
                # NLP tasks from SciER
                EntityExtractionExample(label="Machine Translation", type="Task"),
                EntityExtractionExample(label="neural machine translation", type="Task"),
                EntityExtractionExample(label="sentiment analysis", type="Task"),
                EntityExtractionExample(label="entailment", type="Task"),
                EntityExtractionExample(label="text classification", type="Task"),
                EntityExtractionExample(label="natural language processing", type="Task"),
                EntityExtractionExample(label="sequence-to-sequence problems", type="Task"),
                EntityExtractionExample(label="NLP", type="Task"),
            ]
        )

    # Simplest version - bare standard placeholders
    prompt_template = """
Your task is to carefully read the following scientific text and extract specific information.
You need to extract:
1. **Entities:** Identify any mentions of Datasets, Methods, and Tasks. Use the entity examples provided below to understand what to look for; that is a very small list, and there are many more entities.
2. **Relationships:** Identify relationships *between the extracted entities* based on the information stated in the text. Use only the relationship types defined below.

**Entity Examples (this is a very brief list; there are many more entities):**
{examples}

**Relationship Information:**
Desired Relationship Types (only extract these relationships, nothing else; there are a lot of relationships, so be nuanced and careful, and think hard about how entities relate to each other):
- Used-For: [Method/Dataset] is used for [Task]
- Feature-Of: [Feature] is a feature of [Method/Task]
- Hyponym-Of: [Specific] is a type of [General]
- Part-Of: [Component] is part of [System]
- Compare: [Entity A] is compared to [Entity B]
- Evaluate-For: [Method] is evaluated for [Metric/Task]
- Conjunction: [Entity A] is mentioned together with [Entity B] without a specific relation
- Evaluate-On: [Method] is evaluated on [Dataset]
- Synonym-Of: [Entity A] is the same as [Entity B]

**Instructions:**
- Extract entities first, identifying their label (the text mention) and type (Dataset, Method, or Task).
- Then, extract relationships between the entities you found. The 'source' and 'target' of the relationship MUST be the exact entity labels you extracted.
- Only extract information explicitly mentioned in the text. Do not infer or add outside knowledge.
- Format your entire output as a single JSON object containing two keys: "entities" (a list of entity objects) and "relationships" (a list of relationship objects).

**Text to analyze:**
{content}
"""

    return EntityExtractionPromptOverride(prompt_template=prompt_template, examples=examples)


def create_graph(
    db: Morphik, documents: List[Dict], model_name: str, run_id: str
) -> Tuple[List[str], Dict]:
    """
    Create a knowledge graph from the documents.

    Args:
        db: Morphik client
        documents: List of documents to ingest
        model_name: Name of the model being used (for tracking)
        run_id: Unique identifier for this run

    Returns:
        Tuple of (list of document IDs, graphs dict)
    """
    print(f"\n=== Creating graph with {model_name} model ===")

    # Ingest documents
    doc_ids = []
    for doc in tqdm(documents, desc="Ingesting documents"):
        # Add metadata for tracking
        doc["metadata"]["evaluation_run_id"] = run_id
        doc["metadata"]["model"] = model_name

        # Ingest the document
        result = db.ingest_text(doc["text"], metadata=doc["metadata"])
        doc_ids.append(result.external_id)

    # Create the graph extraction override (which includes both entity and relationship instructions)
    entity_extraction_override = create_graph_extraction_override(["Dataset", "Method", "Task"])

    # Wrap the combined override correctly for the API
    graph_overrides = GraphPromptOverrides(entity_extraction=entity_extraction_override)

    # Create a knowledge graph with overrides
    print("Creating knowledge graph with prompt overrides...")
    graph = db.create_graph(
        name=f"scier_{model_name}_{run_id}", documents=doc_ids, prompt_overrides=graph_overrides
    )

    print(
        f"Created graph with {len(graph.entities)} entities and {len(graph.relationships)} relationships"
    )

    return doc_ids, {"graph": graph}


def main():
    """Main function to create a graph from the SciER dataset."""
    parser = argparse.ArgumentParser(description="SciER Graph Creation Script for Morphik")
    parser.add_argument(
        "--limit", type=int, default=57, help="Maximum number of documents to process (default: 57)"
    )
    parser.add_argument(
        "--run-id", type=str, default=None, help="Unique run identifier (default: auto-generated)"
    )
    parser.add_argument(
        "--model-name",
        type=str,
        default=None,
        help="Name of the currently configured model (default: auto-detected)",
    )

    args = parser.parse_args()

    # Generate a run ID if not provided
    run_id = args.run_id or str(uuid.uuid4())[:8]

    # Auto-detect or use the provided model name
    model_name = args.model_name or "default_model"

    print(f"Running graph creation for model: {model_name}")
    print(f"Run ID: {run_id}")

    # Initialize the Morphik client
    db = setup_morphik_client()

    # Load the SciER dataset
    scier_data = load_scier_data("test.jsonl", limit=args.limit)
    print(f"Loaded {len(scier_data)} records")

    # Prepare documents for ingestion
    documents = prepare_text_for_ingestion(scier_data)
    print(f"Prepared {len(documents)} documents for ingestion")

    # Create the graph
    doc_ids, graphs = create_graph(db, documents, model_name, run_id)

    # Print the graph name for evaluation
    graph_name = f"scier_{model_name}_{run_id}"
    print(f"\nGraph creation complete! Created graph: {graph_name}")
    print(f"To evaluate this graph, run: python evaluate_result.py --graph-name {graph_name}")


if __name__ == "__main__":
    main()
```
evaluations/hotpot_ragas_eval.py (new file, 259 lines)

```python
import argparse
import sys
import uuid
from pathlib import Path

# Add the SDK path to the Python path
sdk_path = str(Path(__file__).parent.parent / "sdks" / "python")
sys.path.insert(0, sdk_path)

import pandas as pd
from datasets import Dataset, load_dataset
from dotenv import load_dotenv
from ragas import evaluate
from ragas.metrics import answer_correctness, context_precision, faithfulness
from tqdm import tqdm

from morphik import Morphik

# Load environment variables
load_dotenv()

# Connect to Morphik
db = Morphik(timeout=10000, is_local=True)


def generate_run_id():
    """Generate a unique run identifier."""
    return str(uuid.uuid4())


def load_hotpotqa_dataset(num_samples=10, split="validation"):
    """Load and prepare the HotpotQA dataset."""
    dataset = load_dataset("hotpot_qa", "distractor", split=split, trust_remote_code=True)

    # Sample a subset
    dataset = dataset.select(range(min(num_samples, len(dataset))))

    return dataset


def process_with_morphik(dataset, run_id=None):
    """
    Process the dataset with Morphik and prepare data for RAGAS evaluation.

    Args:
        dataset: The dataset to process
        run_id: Unique identifier for this evaluation run
    """
    # Generate a run_id if not provided
    if run_id is None:
        run_id = generate_run_id()

    print(f"Using run identifier: {run_id}")

    data_samples = {
        "question": [],
        "answer": [],
        "contexts": [],
        "ground_truth": [],
        "run_id": [],  # Store run_id for each sample
    }

    for i, item in enumerate(tqdm(dataset, desc="Processing documents")):
        try:
            # Extract question and ground truth
            question = item["question"].strip()
            ground_truth = item["answer"].strip()

            if not question or not ground_truth:
                print(f"Skipping item {i}: Empty question or answer")
                continue

            # Build the document's context for ingestion into Morphik
            context = ""
            for title, sentences in zip(item["context"]["title"], item["context"]["sentences"]):
                paragraph = " ".join(sentences)
                context += f"{title}:\n{paragraph}\n\n"

            # Handle a potentially longer context
            # if len(context) > 10000:
            #     print(f"Warning: Long context ({len(context)} chars), truncating...")
            #     context = context[:10000]

            # Ingest the text with run_id in metadata
            doc_id = db.ingest_text(
                context,
                metadata={
                    "source": "hotpotqa",
                    "question_id": item.get("_id", ""),
                    "item_index": i,
                    "evaluation_run_id": run_id,  # Add run_id to metadata
                },
                use_colpali=False,
            ).external_id

            # Query Morphik for the answer, with a prompt override for concise answers
            prompt_override = {
                "query": {
                    "prompt_template": "Answer the following question based on the provided context. Your answer should be as concise as possible. If a yes/no answer is appropriate, just respond with 'Yes' or 'No'. Do not provide explanations or additional context unless absolutely necessary.\n\nQuestion: {question}\n\nContext: {context}"
                }
            }
            response = db.query(
                question,
                use_colpali=False,
                k=10,
                filters={"evaluation_run_id": run_id},
                prompt_overrides=prompt_override,
            )
            answer = response.completion

            if not answer:
                print(f"Warning: Empty answer for question: {question[:50]}...")
                answer = "No answer provided"

            # Get retrieved chunks for context, filtered by run_id
            chunks = db.retrieve_chunks(
                query=question,
                k=10,
                filters={"evaluation_run_id": run_id},  # Filter by run_id
            )
            context_texts = [chunk.content for chunk in chunks]

            if not context_texts:
                print(f"Warning: No contexts retrieved for question: {question[:50]}...")
                context_texts = ["No context retrieved"]

            # Add to our dataset
            data_samples["question"].append(question)
            data_samples["answer"].append(answer)
            data_samples["contexts"].append(context_texts)
            data_samples["ground_truth"].append(ground_truth)
            data_samples["run_id"].append(run_id)

        except Exception as e:
            import traceback

            print(f"Error processing item {i}:")
            print(f"Question: {item.get('question', 'N/A')[:50]}...")
            print(f"Error: {e}")
            traceback.print_exc()
            continue

    return data_samples, run_id


def run_evaluation(num_samples=5, output_file="ragas_results.csv", run_id=None):
    """
    Run the full evaluation pipeline.

    Args:
        num_samples: Number of samples to use from the dataset
        output_file: Path to save the results CSV
        run_id: Optional run identifier. If None, a new one will be generated
    """
    try:
        # Load dataset
        print("Loading HotpotQA dataset...")
        hotpot_dataset = load_hotpotqa_dataset(num_samples=num_samples)
        print(f"Loaded {len(hotpot_dataset)} samples from HotpotQA")

        # Process with Morphik
        print("Processing with Morphik...")
        data_samples, run_id = process_with_morphik(hotpot_dataset, run_id=run_id)

        # Check that we have enough samples
        if len(data_samples["question"]) == 0:
            print("Error: No samples were successfully processed. Exiting.")
            return

        print(f"Successfully processed {len(data_samples['question'])} samples")

        # Convert to RAGAS format
        ragas_dataset = Dataset.from_dict(data_samples)

        # Run RAGAS evaluation
        print("Running RAGAS evaluation...")
        metrics = [faithfulness, answer_correctness, context_precision]

        result = evaluate(ragas_dataset, metrics=metrics)

        # Convert results to a DataFrame and save
        df_result = result.to_pandas()

        # Add run_id to the results DataFrame
        df_result["run_id"] = run_id

        print("\nRAGAS Evaluation Results:")
        print(df_result)

        # Add more detailed analysis
        print("\nDetailed Metric Analysis:")
        # First ensure all metric columns are numeric
        for column in ["faithfulness", "answer_correctness", "context_precision"]:
            if column in df_result.columns:
                try:
                    # Convert the column to numeric; errors='coerce' replaces non-numeric values with NaN
                    df_result[column] = pd.to_numeric(df_result[column], errors="coerce")
                    # Calculate and print the mean, ignoring NaN values
                    mean_value = df_result[column].mean(skipna=True)
                    if pd.notna(mean_value):  # Check that the mean is not NaN
                        print(f"{column}: {mean_value:.4f}")
                    else:
                        print(f"{column}: No valid numeric values found")
                except Exception as e:
                    print(f"Error processing {column}: {e}")
                    print(f"Values: {df_result[column].head().tolist()}")

        # Include run_id in the output filename if one was not explicitly provided
        if output_file == "ragas_results.csv":
            # Get just the filename without the extension
            base_name = output_file.rsplit(".", 1)[0]
            output_file = f"{base_name}_{run_id}.csv"

        # Save results
        df_result.to_csv(output_file, index=False)
        print(f"Results saved to {output_file}")

        return df_result, run_id

    except Exception as e:
        import traceback

        print(f"Error in evaluation: {e}")
        traceback.print_exc()
        print("Exiting due to error.")
        return None


def main():
    """Command-line entry point."""
    parser = argparse.ArgumentParser(
        description="Run RAGAS evaluation on Morphik using the HotpotQA dataset"
    )
    parser.add_argument(
        "--samples", type=int, default=5, help="Number of samples to use (default: 5)"
    )
    parser.add_argument(
        "--output",
        type=str,
        default="ragas_results.csv",
        help="Output file for results (default: ragas_results.csv)",
    )
    parser.add_argument(
        "--run-id",
        type=str,
        default=None,
        help="Specific run identifier to use (default: auto-generated UUID)",
    )
    args = parser.parse_args()

    run_evaluation(num_samples=args.samples, output_file=args.output, run_id=args.run_id)


if __name__ == "__main__":
    main()
```