mirror of
https://github.com/the-jordan-lab/docs.git
synced 2025-05-09 21:32:38 +00:00
382 lines
16 KiB
Python
382 lines
16 KiB
Python
![]() |
import json
import logging
import os
import shutil
import subprocess
from datetime import datetime
from typing import Dict, Any

import chromadb
import git
import yaml
from Bio import Entrez
from chromadb.config import Settings
from chromadb.utils import embedding_functions
|
||
|
|
||
|
class AgentTaskRunner:
    """
    Main class for running agent-driven lab management actions.

    Handles loading function call JSON, dispatching to handlers, and logging results.

    All paths are relative to the current working directory, which is
    expected to be the root of the lab repository (Templates/, Protocols/,
    Experiments/, Data/, CHANGELOG.md, TASKS.md).
    """

    # Repository-relative directories used by the handlers.
    DATA_DIR = "Data"
    TEMPLATES_DIR = "Templates"
    PROTOCOLS_DIR = "Protocols"
    EXPERIMENTS_DIR = "Experiments"

    def __init__(self, actions_path: str = "actions.json"):
        """
        Set up logging, the ChromaDB collections and the Smart-Fill index.

        Args:
            actions_path: Path of the JSON file that describes pending actions.
        """
        self.actions_path = actions_path
        self.logger = self._setup_logger()
        # Initialize ChromaDB client for Smart-Fill
        self.chroma_client = chromadb.Client(Settings(persist_directory="Data/chroma_index"))
        if "OPENAI_API_KEY" not in os.environ:
            # The placeholder key keeps construction working, but embedding
            # calls will be rejected by the API until a real key is supplied.
            self.logger.warning("OPENAI_API_KEY is not set; Smart-Fill embeddings will not work.")
        self.embedding_fn = embedding_functions.OpenAIEmbeddingFunction(
            api_key=os.getenv("OPENAI_API_KEY", "sk-..."),
            model_name="text-embedding-ada-002"
        )
        # Build or load embedding index
        self.protocol_collection = self.chroma_client.get_or_create_collection(
            "protocols", embedding_function=self.embedding_fn
        )
        self.experiment_collection = self.chroma_client.get_or_create_collection(
            "experiments", embedding_function=self.embedding_fn
        )
        self.index_repo_content()
        # TODO: Initialize other state as needed (e.g., repo paths, templates)

    def _setup_logger(self):
        """
        Return the shared "AgentTaskRunner" logger at INFO level.

        The stream handler is attached only once: logging.getLogger returns
        the same object on every call, so adding a handler per instance
        would duplicate every log line.
        """
        logger = logging.getLogger("AgentTaskRunner")
        logger.setLevel(logging.INFO)
        if not logger.handlers:
            handler = logging.StreamHandler()
            handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s: %(message)s"))
            logger.addHandler(handler)
        return logger

    def load_actions(self) -> Dict[str, Any]:
        """
        Load the next action(s) for the agent to execute (e.g., from a JSON file or IPC).

        Currently a stub that always returns an empty dict.
        """
        # TODO: Implement loading logic
        return {}

    def run(self):
        """
        Main loop: load actions, dispatch to handler, log results.

        For now this only dispatches a hard-coded demo "create_protocol" action.
        """
        self.logger.info("AgentTaskRunner started.")
        # DEMO: Simulate loading a sample action
        sample_action = {
            "function": "create_protocol",
            "arguments": {
                "name": "Sample Protocol",
                "id": "PROT-1234",
                "version": "1.0",
                "description": "A sample protocol for demonstration.",
                "author": "Demo User",
                "created": datetime.now().strftime("%Y-%m-%d"),
                "materials": [
                    {"Material": "Buffer A (10 mL)"},
                    {"Material": "Reagent X (5 mg)"}
                ],
                "steps": [
                    "Mix Buffer A and Reagent X.",
                    "Incubate for 10 minutes."
                ],
                "notes": "This is a demo protocol."
            }
        }
        # Dispatch demo action
        if sample_action["function"] == "create_protocol":
            self.handle_create_protocol(sample_action["arguments"])
        # TODO: Replace with real action loading and dispatch loop
        # TODO: Implement Smart-Fill (suggest_metadata) and GitHub integration in future bursts

    # ------------------------------------------------------------------
    # User profiles
    # ------------------------------------------------------------------

    def _user_profiles_path(self) -> str:
        """Return the path of the JSON file that stores user profiles."""
        return os.path.join(self.DATA_DIR, "user_profiles.json")

    def _load_user_profiles(self):
        """Return the stored user profiles, or an empty dict if the file is absent."""
        try:
            with open(self._user_profiles_path(), "r", encoding="utf-8") as f:
                return json.load(f)
        except FileNotFoundError:
            return {}

    def _save_user_profiles(self, profiles):
        """Write all user profiles to disk, creating Data/ on first use."""
        path = self._user_profiles_path()
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, "w", encoding="utf-8") as f:
            json.dump(profiles, f, indent=2)

    def _update_user_profile(self, username, key, value):
        """
        Record value under the list-valued profile field key for username.

        Duplicates are not appended. A missing or non-list field is replaced
        by a one-element list. The user's "last_active" date is refreshed.
        """
        profiles = self._load_user_profiles()
        user = profiles.get(username, {})
        existing = user.get(key)
        if isinstance(existing, list):
            if value not in existing:
                existing.append(value)
        else:
            user[key] = [value]
        user["last_active"] = datetime.now().strftime("%Y-%m-%d")
        profiles[username] = user
        self._save_user_profiles(profiles)

    # ------------------------------------------------------------------
    # YAML / filename helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _safe_component(value) -> str:
        """Return value as text with path separators replaced by underscores."""
        # Prevents a name such as "../x" from escaping the target directory.
        return str(value).replace("/", "_").replace("\\", "_")

    @staticmethod
    def _unique_path(directory: str, stem: str, suffix: str = ".yaml") -> str:
        """Return directory/stem+suffix, adding _1, _2, ... until the path is unused."""
        candidate = os.path.join(directory, f"{stem}{suffix}")
        counter = 1
        while os.path.exists(candidate):
            candidate = os.path.join(directory, f"{stem}_{counter}{suffix}")
            counter += 1
        return candidate

    def _load_template(self, filename: str) -> Dict[str, Any]:
        """Load a YAML template from Templates/; an empty file yields an empty dict."""
        template_path = os.path.join(self.TEMPLATES_DIR, filename)
        with open(template_path, "r", encoding="utf-8") as f:
            return yaml.safe_load(f) or {}

    @staticmethod
    def _write_yaml(path: str, data) -> None:
        """Dump data to path as YAML in insertion order, creating parent directories."""
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        with open(path, "w", encoding="utf-8") as f:
            yaml.dump(data, f, sort_keys=False)

    # ------------------------------------------------------------------
    # Protocol / experiment handlers
    # ------------------------------------------------------------------

    def handle_create_protocol(self, args: Dict[str, Any]):
        """
        Handle creation of a new protocol YAML in Protocols/.

        The protocol template is filled with every key in args and written
        to Protocols/<name>_v<version>.yaml (with a numeric suffix if that
        file already exists). The protocol name is then recorded in the
        author's "frequent_protocols" profile field.
        """
        record = self._load_template("protocol_template.yaml")
        # Fill in template with args
        record.update(args)
        # Generate filename
        protocol_name = args.get("name", "unnamed_protocol")
        name = self._safe_component(protocol_name).lower().replace(" ", "_")
        version = self._safe_component(args.get("version", "1.0"))
        out_path = self._unique_path(self.PROTOCOLS_DIR, f"{name}_v{version}")
        self._write_yaml(out_path, record)
        self.logger.info(f"Created protocol: {out_path}")
        # Update user profile
        author = args.get("author", "unknown").replace(" ", "_").lower()
        self._update_user_profile(author, "frequent_protocols", protocol_name)

    def handle_create_experiment(self, args: Dict[str, Any]):
        """
        Handle creation of a new experiment YAML in Experiments/.

        The experiment template is filled with every key in args and written
        to Experiments/<date>_<title>_<researcher>.yaml (with a numeric
        suffix if that file already exists). The researcher's profile gets
        the new filename under "recent_experiments" and, when supplied, the
        cell line under "frequent_cell_lines".
        """
        record = self._load_template("experiment_template.yaml")
        # Fill in template with args
        record.update(args)
        # Generate filename
        date = self._safe_component(args.get("date", datetime.now().strftime("%Y-%m-%d")))
        researcher = self._safe_component(args.get("researcher", "unknown")).replace(" ", "_")
        title = self._safe_component(args.get("title", "untitled")).lower().replace(" ", "_")
        out_path = self._unique_path(self.EXPERIMENTS_DIR, f"{date}_{title}_{researcher}")
        filename = os.path.basename(out_path)
        self._write_yaml(out_path, record)
        self.logger.info(f"Created experiment: {out_path}")
        # Update user profile
        profile_user = args.get("researcher", "unknown").replace(" ", "_").lower()
        if "cell_line" in args:
            self._update_user_profile(profile_user, "frequent_cell_lines", args["cell_line"])
        self._update_user_profile(profile_user, "recent_experiments", filename)

    def _find_experiment_file(self, experiment_id):
        """
        Return the path of the experiment file matching experiment_id, or None.

        An exact filename or stem match wins. Otherwise the alphabetically
        first filename containing the id is used, so the result does not
        depend on directory listing order.
        """
        directory = self.EXPERIMENTS_DIR
        if not os.path.isdir(directory):
            return None
        needle = str(experiment_id)
        names = sorted(os.listdir(directory))
        for fname in names:
            if fname == needle or os.path.splitext(fname)[0] == needle:
                return os.path.join(directory, fname)
        for fname in names:
            if needle in fname:
                return os.path.join(directory, fname)
        return None

    def handle_update_experiment(self, args: Dict[str, Any]):
        """
        Handle updating an existing experiment YAML (e.g., add results, change status).
        Args should include: experiment_id (filename or unique id), updates (dict of fields to update)

        If the updated record has status "completed" but lacks "results" or
        "conclusion", a GitHub issue is opened. Every successful update is
        appended to CHANGELOG.md and matching open tasks in TASKS.md are
        ticked off.
        """
        experiment_id = args.get("experiment_id")
        updates = args.get("updates", {})
        if not experiment_id or not updates:
            self.logger.error("Missing experiment_id or updates for update_experiment.")
            return
        # Find experiment file
        exp_file = self._find_experiment_file(experiment_id)
        if exp_file is None:
            self.logger.error(f"Experiment file not found for id: {experiment_id}")
            return
        # Load, update, and save YAML
        with open(exp_file, "r", encoding="utf-8") as f:
            exp = yaml.safe_load(f) or {}  # an empty file parses to None
        exp.update(updates)
        self._write_yaml(exp_file, exp)
        self.logger.info(f"Updated experiment: {exp_file}")
        # Validate if status is completed ("status: null" must not crash .lower())
        status = str(exp.get("status") or "").lower()
        if status == "completed":
            missing = [field for field in ("results", "conclusion") if not exp.get(field)]
            if missing:
                issue_title = f"Experiment {experiment_id} missing required fields"
                issue_body = f"The following required fields are missing: {', '.join(missing)}. Please update the experiment record."
                self.handle_open_issue({"title": issue_title, "body": issue_body})
        # Log to changelog
        self.append_changelog(f"Updated experiment {experiment_id}: {list(updates.keys())}")
        # Optionally update tasks
        self.mark_task_done_for_experiment(experiment_id)

    def append_changelog(self, entry: str):
        """Append a timestamped bullet line for entry to CHANGELOG.md."""
        with open("CHANGELOG.md", "a", encoding="utf-8") as f:
            f.write(f"- {datetime.now().strftime('%Y-%m-%d %H:%M')}: {entry}\n")

    def mark_task_done_for_experiment(self, experiment_id: str):
        """Tick off ("- [ ]" -> "- [x]") every open task in TASKS.md that mentions experiment_id."""
        # Mark a lab task as done in TASKS.md if it references this experiment
        if not os.path.exists("TASKS.md"):
            return
        with open("TASKS.md", "r", encoding="utf-8") as f:
            lines = f.readlines()
        updated = False
        for i, line in enumerate(lines):
            if experiment_id in line and line.strip().startswith("- [ ]"):
                lines[i] = line.replace("- [ ]", "- [x]", 1)
                updated = True
        # Only rewrite the file when something actually changed.
        if updated:
            with open("TASKS.md", "w", encoding="utf-8") as f:
                f.writelines(lines)

    # ------------------------------------------------------------------
    # Smart-Fill
    # ------------------------------------------------------------------

    def handle_suggest_metadata(self, args: Dict[str, Any]):
        """
        Handle metadata suggestion using embeddings or external sources.
        Args should include: field (str), context (str)

        An optional "user" argument selects the profile used for
        personalized suggestions.

        Returns:
            A list of {"source": ..., "text": ...} dicts: user-profile
            entries first, then protocol and experiment matches. PubMed
            results are added only when everything else is empty or blank.
        """
        field = args.get("field")
        context = args.get("context", "")
        username = args.get("user", "unknown").replace(" ", "_").lower()
        # Personalized suggestions from user profile
        user = self._load_user_profiles().get(username, {})
        profile_keys = {"cell_line": "frequent_cell_lines", "protocol": "frequent_protocols"}
        suggestions = [
            {"source": "user_profile", "text": item}
            for item in user.get(profile_keys.get(field), [])
        ]
        # Query both protocol and experiment collections for similar content
        sources = (
            ("protocol", self.protocol_collection),
            ("experiment", self.experiment_collection),
        )
        for source, collection in sources:
            hits = collection.query(query_texts=[context], n_results=3)
            documents = hits.get("documents") or [[]]
            suggestions.extend(
                {"source": source, "text": doc} for doc in documents[0] if doc is not None
            )
        # If suggestions are weak or empty, query PubMed
        if not any(str(s["text"]).strip() for s in suggestions):
            for s in self.query_pubmed(context):
                suggestions.append({"source": "pubmed", "text": s})
        self.logger.info(f"Smart-Fill suggestions for '{field}' (user: {username}): {suggestions}")
        return suggestions

    @staticmethod
    def _parse_medline(text: str) -> list:
        """
        Extract "title - snippet" strings from MEDLINE-formatted text.

        Records are separated by blank lines. Field tags occupy the first six
        characters of a line ("TI  - ", "AB  - ") and wrapped values continue
        on lines indented by six spaces. The full title is kept; the snippet
        is the first line of the abstract. Records without a title are skipped.
        """
        results = []
        for record in text.split("\n\n"):
            lines = record.split("\n")
            title = ""
            snippet = ""
            for idx, line in enumerate(lines):
                if not title and line.startswith("TI  - "):
                    parts = [line[6:].strip()]
                    for continuation in lines[idx + 1:]:
                        if not continuation.startswith("      "):
                            break
                        parts.append(continuation.strip())
                    title = " ".join(parts)
                elif not snippet and line.startswith("AB  - "):
                    snippet = line[6:].strip()
            if title:
                results.append(f"{title} - {snippet}" if snippet else title)
        return results

    def query_pubmed(self, query: str, max_results: int = 3):
        """
        Query PubMed for relevant articles using Biopython Entrez.
        Returns a list of article titles and snippets.

        Any failure (network, parsing) is logged and yields an empty list,
        so Smart-Fill degrades gracefully.
        """
        Entrez.email = os.getenv("ENTREZ_EMAIL", "labagent@example.com")
        try:
            handle = Entrez.esearch(db="pubmed", term=query, retmax=max_results)
            try:
                record = Entrez.read(handle)
            finally:
                handle.close()
            ids = record["IdList"]
            if not ids:
                return []
            handle = Entrez.efetch(db="pubmed", id=",".join(ids), rettype="medline", retmode="text")
            try:
                medline_text = handle.read()
            finally:
                handle.close()
            return self._parse_medline(medline_text)[:max_results]
        except Exception as e:
            # Deliberate best-effort boundary: PubMed is an optional source.
            self.logger.error(f"PubMed query failed: {e}")
            return []

    # ------------------------------------------------------------------
    # Git / GitHub integration
    # ------------------------------------------------------------------

    def handle_commit_and_push(self, args: Dict[str, Any]):
        """
        Handle committing and pushing changes to GitHub.
        Args should include: message (str)

        All changes in the working directory are staged. A commit is created
        only when there is something to commit; the push to "origin" always
        runs so earlier unpushed commits are still delivered.
        """
        repo = git.Repo(os.getcwd())
        message = args.get("message", "Lab agent commit")
        # Stage all changes
        repo.git.add(A=True)
        # Commit (skip when the tree is clean to avoid empty commits)
        if repo.is_dirty(untracked_files=True):
            repo.index.commit(message)
            self.logger.info(f"Committed changes: {message}")
        else:
            self.logger.info("No changes to commit.")
        # Push
        origin = repo.remote(name='origin')
        origin.push()
        self.logger.info("Pushed changes to origin.")

    def _run_gh(self, gh_args) -> bool:
        """
        Run the GitHub CLI with gh_args and report whether it exited with 0.

        Arguments are passed as a list (no shell), so quotes or shell
        metacharacters in titles and bodies are delivered verbatim.
        """
        try:
            completed = subprocess.run(["gh", *gh_args], check=False)
        except OSError as exc:
            # Typically the gh executable is not installed or not on PATH.
            self.logger.error(f"Could not run gh CLI: {exc}")
            return False
        return completed.returncode == 0

    def handle_open_issue(self, args: Dict[str, Any]):
        """
        Handle opening a GitHub issue via CLI (gh).
        Args should include: title (str), body (str)
        """
        title = args.get("title", "Lab Agent Issue")
        body = args.get("body", "Created by lab agent.")
        if self._run_gh(["issue", "create", "--title", str(title), "--body", str(body)]):
            self.logger.info(f"Opened GitHub issue: {title}")
        else:
            self.logger.error(f"Failed to open GitHub issue: {title}")

    def handle_open_pr(self, args: Dict[str, Any]):
        """
        Handle opening a draft pull request via CLI (gh).
        Args should include: title (str), body (str), base (str, optional)
        """
        title = args.get("title", "Lab Agent Draft PR")
        body = args.get("body", "Draft PR created by lab agent.")
        base = args.get("base", "main")
        gh_args = [
            "pr", "create",
            "--title", str(title),
            "--body", str(body),
            "--base", str(base),
            "--draft",
        ]
        if self._run_gh(gh_args):
            self.logger.info(f"Opened draft PR: {title}")
        else:
            self.logger.error(f"Failed to open draft PR: {title}")

    def _make_commit_message(self, action_type: str, details: Dict[str, Any]) -> str:
        """
        Helper to generate a structured commit message.

        Unknown action types fall back to the generic "Lab agent commit".
        """
        if action_type == "create_protocol":
            return f"Protocol: Created {details.get('name', 'unnamed protocol')} v{details.get('version', '')}"
        elif action_type == "create_experiment":
            return f"Experiment: Created {details.get('title', 'untitled experiment')} ({details.get('experiment_id', '')})"
        elif action_type == "update_experiment":
            return f"Experiment: Updated {details.get('experiment_id', '')}"
        elif action_type == "open_issue":
            return f"Issue: {details.get('title', 'No title')}"
        elif action_type == "open_pr":
            return f"PR: {details.get('title', 'No title')}"
        return "Lab agent commit"

    # ------------------------------------------------------------------
    # Indexing
    # ------------------------------------------------------------------

    def _index_directory(self, collection, directory: str, id_prefix: str, fields) -> None:
        """
        Upsert one document per *.yaml file in directory into collection.

        The document text is the space-joined values of fields. A missing
        directory or an unparsable file is logged and skipped instead of
        aborting, because this runs from the constructor.
        """
        if not os.path.isdir(directory):
            self.logger.warning(f"Directory not found, skipping index: {directory}")
            return
        for fname in sorted(os.listdir(directory)):
            if not fname.endswith(".yaml"):
                continue
            try:
                with open(os.path.join(directory, fname), "r", encoding="utf-8") as f:
                    doc = yaml.safe_load(f) or {}
            except yaml.YAMLError as exc:
                self.logger.warning(f"Skipping unparsable YAML {fname}: {exc}")
                continue
            if not isinstance(doc, dict):
                self.logger.warning(f"Skipping non-mapping YAML {fname}")
                continue
            content = " ".join(str(doc.get(field, "")) for field in fields)
            collection.upsert(
                ids=[f"{id_prefix}_{fname}"],
                documents=[content],
                metadatas=[{"file": fname}]
            )

    def index_repo_content(self):
        """
        Index all Protocols/ and Experiments/ YAMLs for Smart-Fill suggestions.
        """
        # Index protocols
        self._index_directory(
            self.protocol_collection,
            self.PROTOCOLS_DIR,
            "protocol",
            ("name", "description", "materials", "steps"),
        )
        # Index experiments
        self._index_directory(
            self.experiment_collection,
            self.EXPERIMENTS_DIR,
            "experiment",
            ("title", "protocol", "materials", "parameters", "results"),
        )
        self.logger.info("Indexed protocols and experiments for Smart-Fill.")
|
||
|
|
||
|
if __name__ == "__main__":
    # Script entry point: construct the runner, then start its main loop.
    runner = AgentTaskRunner()
    runner.run()
|