mirror of
https://github.com/james-m-jordan/openai-cookbook.git
synced 2025-05-09 19:32:38 +00:00
785 lines
72 KiB
Plaintext
{
|
||
|
"cells": [
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"id": "xs8_Q1nPwxlK"
|
||
|
},
|
||
|
"source": [
|
||
|
"# Structured Outputs for Multi-Agent Systems\n",
|
||
|
"\n",
|
||
|
"In this cookbook, we will explore how to use Structured Outputs to build multi-agent systems.\n",
|
||
|
"\n",
|
||
|
"Structured Outputs is a new capability that builds upon JSON mode and function calling to enforce a strict schema in a model output.\n",
|
||
|
"\n",
|
||
|
"By using the new parameter `strict: true`, we are able to guarantee the response abides by a provided schema.\n",
|
||
|
"\n",
|
||
|
"To demonstrate the power of this feature, we will use it to build a multi-agent system.\n",
|
||
|
"\n",
|
||
|
"### Why build a Multi-Agent System?\n",
|
||
|
"\n",
|
||
|
"When using function calling, if the number of functions (or tools) increases, the performance may suffer.\n",
|
||
|
"\n",
|
||
|
"To mitigate this, we can logically group the tools together and have specialized \"agents\" that are able to solve specific tasks or sub-tasks, which will increase the overall system performance."
|
||
|
]
|
||
|
},
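As a rough illustration of what strict mode expects from a schema, the sketch below checks two documented requirements: every object must set `additionalProperties` to false, and every property must be listed in `required`. The helper name `strict_schema_problems` and both sample schemas are our own (they are not part of the OpenAI SDK), and the check is deliberately incomplete.

```python
def strict_schema_problems(schema, path="parameters"):
    """Return human-readable problems that would conflict with strict mode."""
    problems = []
    if schema.get("type") == "object":
        properties = schema.get("properties", {})
        # Strict mode requires objects to forbid undeclared keys.
        if schema.get("additionalProperties") is not False:
            problems.append(f"{path}: 'additionalProperties' must be False")
        # Strict mode requires every declared property to be required.
        missing = sorted(set(properties) - set(schema.get("required", [])))
        if missing:
            problems.append(f"{path}: not listed in 'required': {missing}")
        for name, sub_schema in properties.items():
            problems.extend(strict_schema_problems(sub_schema, f"{path}.{name}"))
    elif schema.get("type") == "array":
        problems.extend(strict_schema_problems(schema.get("items", {}), f"{path}[]"))
    return problems


good_schema = {
    "type": "object",
    "properties": {"data": {"type": "string"}},
    "required": ["data"],
    "additionalProperties": False,
}

bad_schema = {
    "type": "object",
    "properties": {
        "agents": {"type": "array", "items": {"type": "string"}},
        "query": {"type": "string"},
    },
    "required": ["agents"],
}

print(strict_schema_problems(good_schema))
print(strict_schema_problems(bad_schema))
```

Running a check like this before sending a request is optional, since the API validates the schema itself, but it can make schema mistakes easier to spot locally.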
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"id": "7SLTnfLRKVnP"
|
||
|
},
|
||
|
"source": [
|
||
|
"## Environment set up"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 52,
|
||
|
"metadata": {
|
||
|
"id": "UCySx7jT6T7Y"
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"from openai import OpenAI\n",
|
||
|
"from IPython.display import Image\n",
|
||
|
"import json\n",
|
||
|
"import pandas as pd\n",
|
||
|
"import matplotlib.pyplot as plt\n",
|
||
|
"from io import StringIO\n",
|
||
|
"import numpy as np\n",
|
||
|
"client = OpenAI()"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 53,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"MODEL = \"gpt-4o-2024-08-06\""
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"id": "X34-4ZYyK6-S"
|
||
|
},
|
||
|
"source": [
|
||
|
"## Agents set up\n",
|
||
|
"\n",
|
||
|
"The use case we will tackle is a data analysis task.\n",
|
||
|
"\n",
|
||
"Let's first set up our four-agent system:\n",
|
||
|
"\n",
|
||
"1. **Triaging Agent:** Decides which agent(s) to call\n",
|
||
|
"2. **Data pre-processing Agent:** Prepares data for analysis - for example by cleaning it up\n",
|
||
|
"3. **Data Analysis Agent:** Performs analysis on the data\n",
|
||
|
"4. **Data Visualization Agent:** Visualizes the output of the analysis to extract insights\n",
|
||
|
"\n",
|
||
|
"We will start by defining the system prompts for each of these agents."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 54,
|
||
|
"metadata": {
|
||
|
"id": "CewlAQuhKUIe"
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"triaging_system_prompt = \"\"\"You are a Triaging Agent. Your role is to assess the user's query and route it to the relevant agents. The agents available are:\n",
|
||
|
"- Data Processing Agent: Cleans, transforms, and aggregates data.\n",
|
||
|
"- Analysis Agent: Performs statistical, correlation, and regression analysis.\n",
|
||
|
"- Visualization Agent: Creates bar charts, line charts, and pie charts.\n",
|
||
|
"\n",
|
||
"Use the send_query_to_agents tool to forward the user's query to the relevant agents.\"\"\"\n",
|
||
|
"\n",
|
||
|
"processing_system_prompt = \"\"\"You are a Data Processing Agent. Your role is to clean, transform, and aggregate data using the following tools:\n",
|
||
|
"- clean_data\n",
|
||
|
"- transform_data\n",
|
||
|
"- aggregate_data\"\"\"\n",
|
||
|
"\n",
|
||
|
"analysis_system_prompt = \"\"\"You are an Analysis Agent. Your role is to perform statistical, correlation, and regression analysis using the following tools:\n",
|
||
|
"- stat_analysis\n",
|
||
|
"- correlation_analysis\n",
|
||
|
"- regression_analysis\"\"\"\n",
|
||
|
"\n",
|
||
|
"visualization_system_prompt = \"\"\"You are a Visualization Agent. Your role is to create bar charts, line charts, and pie charts using the following tools:\n",
|
||
|
"- create_bar_chart\n",
|
||
|
"- create_line_chart\n",
|
||
|
"- create_pie_chart\"\"\""
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"id": "vkpZ409POhiS"
|
||
|
},
|
||
|
"source": [
|
||
|
"We will then define the tools for each agent.\n",
|
||
|
"\n",
|
||
|
"Apart from the triaging agent, each agent will be equipped with tools specific to their role:\n",
|
||
|
"\n",
|
||
|
"#### Data pre-processing agent\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"1. Clean data\n",
|
||
|
"2. Transform data\n",
|
||
|
"3. Aggregate data\n",
|
||
|
"\n",
|
||
|
"#### Data analysis agent\n",
|
||
|
"\n",
|
||
|
"1. Statistical analysis\n",
|
||
|
"2. Correlation analysis\n",
|
||
|
"3. Regression Analysis\n",
|
||
|
"\n",
|
||
|
"#### Data visualization agent\n",
|
||
|
"\n",
|
||
|
"1. Create bar chart\n",
|
||
|
"2. Create line chart\n",
|
||
|
"3. Create pie chart"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 55,
|
||
|
"metadata": {
|
||
|
"id": "MzBvgBliOc9Y"
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"triage_tools = [\n",
|
||
|
" {\n",
|
||
|
" \"type\": \"function\",\n",
|
||
|
" \"function\": {\n",
|
||
|
" \"name\": \"send_query_to_agents\",\n",
|
||
|
" \"description\": \"Sends the user query to relevant agents based on their capabilities.\",\n",
|
||
|
" \"parameters\": {\n",
|
||
|
" \"type\": \"object\",\n",
|
||
|
" \"properties\": {\n",
|
||
|
" \"agents\": {\n",
|
||
|
" \"type\": \"array\",\n",
|
||
|
" \"items\": {\"type\": \"string\"},\n",
|
||
|
" \"description\": \"An array of agent names to send the query to.\"\n",
|
||
|
" },\n",
|
||
|
" \"query\": {\n",
|
||
|
" \"type\": \"string\",\n",
|
||
|
" \"description\": \"The user query to send.\"\n",
|
||
|
" }\n",
|
||
|
" },\n",
|
||
" \"required\": [\"agents\", \"query\"],\n",
" \"additionalProperties\": False\n",
" },\n",
" \"strict\": True\n",
" }\n",
|
||
|
" }\n",
|
||
|
"]\n",
|
||
|
"\n",
|
||
|
"preprocess_tools = [\n",
|
||
|
" {\n",
|
||
|
" \"type\": \"function\",\n",
|
||
|
" \"function\": {\n",
|
||
|
" \"name\": \"clean_data\",\n",
|
||
|
" \"description\": \"Cleans the provided data by removing duplicates and handling missing values.\",\n",
|
||
|
" \"parameters\": {\n",
|
||
|
" \"type\": \"object\",\n",
|
||
|
" \"properties\": {\n",
|
||
|
" \"data\": {\n",
|
||
|
" \"type\": \"string\",\n",
|
||
|
" \"description\": \"The dataset to clean. Should be in a suitable format such as JSON or CSV.\"\n",
|
||
|
" }\n",
|
||
|
" },\n",
|
||
|
" \"required\": [\"data\"],\n",
|
||
|
" \"additionalProperties\": False\n",
|
||
" },\n",
" \"strict\": True\n",
" }\n",
|
||
|
" },\n",
|
||
|
" {\n",
|
||
|
" \"type\": \"function\",\n",
|
||
|
" \"function\": {\n",
|
||
|
" \"name\": \"transform_data\",\n",
|
||
|
" \"description\": \"Transforms data based on specified rules.\",\n",
|
||
|
" \"parameters\": {\n",
|
||
|
" \"type\": \"object\",\n",
|
||
|
" \"properties\": {\n",
|
||
|
" \"data\": {\n",
|
||
|
" \"type\": \"string\",\n",
|
||
|
" \"description\": \"The data to transform. Should be in a suitable format such as JSON or CSV.\"\n",
|
||
|
" },\n",
|
||
|
" \"rules\": {\n",
|
||
|
" \"type\": \"string\",\n",
|
||
|
" \"description\": \"Transformation rules to apply, specified in a structured format.\"\n",
|
||
|
" }\n",
|
||
|
" },\n",
|
||
|
" \"required\": [\"data\", \"rules\"],\n",
|
||
|
" \"additionalProperties\": False\n",
|
||
" },\n",
" \"strict\": True\n",
" }\n",
|
||
|
" },\n",
|
||
|
" {\n",
|
||
|
" \"type\": \"function\",\n",
|
||
|
" \"function\": {\n",
|
||
|
" \"name\": \"aggregate_data\",\n",
|
||
|
" \"description\": \"Aggregates data by specified columns and operations.\",\n",
|
||
|
" \"parameters\": {\n",
|
||
|
" \"type\": \"object\",\n",
|
||
|
" \"properties\": {\n",
|
||
|
" \"data\": {\n",
|
||
|
" \"type\": \"string\",\n",
|
||
|
" \"description\": \"The data to aggregate. Should be in a suitable format such as JSON or CSV.\"\n",
|
||
|
" },\n",
|
||
|
" \"group_by\": {\n",
|
||
|
" \"type\": \"array\",\n",
|
||
|
" \"items\": {\"type\": \"string\"},\n",
|
||
|
" \"description\": \"Columns to group by.\"\n",
|
||
|
" },\n",
|
||
|
" \"operations\": {\n",
|
||
|
" \"type\": \"string\",\n",
|
||
|
" \"description\": \"Aggregation operations to perform, specified in a structured format.\"\n",
|
||
|
" }\n",
|
||
|
" },\n",
|
||
|
" \"required\": [\"data\", \"group_by\", \"operations\"],\n",
|
||
|
" \"additionalProperties\": False\n",
|
||
" },\n",
" \"strict\": True\n",
" }\n",
|
||
|
" }\n",
|
||
|
"]\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"analysis_tools = [\n",
|
||
|
" {\n",
|
||
|
" \"type\": \"function\",\n",
|
||
|
" \"function\": {\n",
|
||
|
" \"name\": \"stat_analysis\",\n",
|
||
|
" \"description\": \"Performs statistical analysis on the given dataset.\",\n",
|
||
|
" \"parameters\": {\n",
|
||
|
" \"type\": \"object\",\n",
|
||
|
" \"properties\": {\n",
|
||
|
" \"data\": {\n",
|
||
|
" \"type\": \"string\",\n",
|
||
|
" \"description\": \"The dataset to analyze. Should be in a suitable format such as JSON or CSV.\"\n",
|
||
|
" }\n",
|
||
|
" },\n",
|
||
|
" \"required\": [\"data\"],\n",
|
||
|
" \"additionalProperties\": False\n",
|
||
" },\n",
" \"strict\": True\n",
" }\n",
|
||
|
" },\n",
|
||
|
" {\n",
|
||
|
" \"type\": \"function\",\n",
|
||
|
" \"function\": {\n",
|
||
|
" \"name\": \"correlation_analysis\",\n",
|
||
|
" \"description\": \"Calculates correlation coefficients between variables in the dataset.\",\n",
|
||
|
" \"parameters\": {\n",
|
||
|
" \"type\": \"object\",\n",
|
||
|
" \"properties\": {\n",
|
||
|
" \"data\": {\n",
|
||
|
" \"type\": \"string\",\n",
|
||
|
" \"description\": \"The dataset to analyze. Should be in a suitable format such as JSON or CSV.\"\n",
|
||
|
" },\n",
|
||
|
" \"variables\": {\n",
|
||
|
" \"type\": \"array\",\n",
|
||
|
" \"items\": {\"type\": \"string\"},\n",
|
||
|
" \"description\": \"List of variables to calculate correlations for.\"\n",
|
||
|
" }\n",
|
||
|
" },\n",
|
||
|
" \"required\": [\"data\", \"variables\"],\n",
|
||
|
" \"additionalProperties\": False\n",
|
||
" },\n",
" \"strict\": True\n",
" }\n",
|
||
|
" },\n",
|
||
|
" {\n",
|
||
|
" \"type\": \"function\",\n",
|
||
|
" \"function\": {\n",
|
||
|
" \"name\": \"regression_analysis\",\n",
|
||
|
" \"description\": \"Performs regression analysis on the dataset.\",\n",
|
||
|
" \"parameters\": {\n",
|
||
|
" \"type\": \"object\",\n",
|
||
|
" \"properties\": {\n",
|
||
|
" \"data\": {\n",
|
||
|
" \"type\": \"string\",\n",
|
||
|
" \"description\": \"The dataset to analyze. Should be in a suitable format such as JSON or CSV.\"\n",
|
||
|
" },\n",
|
||
|
" \"dependent_var\": {\n",
|
||
|
" \"type\": \"string\",\n",
|
||
|
" \"description\": \"The dependent variable for regression.\"\n",
|
||
|
" },\n",
|
||
|
" \"independent_vars\": {\n",
|
||
|
" \"type\": \"array\",\n",
|
||
|
" \"items\": {\"type\": \"string\"},\n",
|
||
|
" \"description\": \"List of independent variables.\"\n",
|
||
|
" }\n",
|
||
|
" },\n",
|
||
|
" \"required\": [\"data\", \"dependent_var\", \"independent_vars\"],\n",
|
||
|
" \"additionalProperties\": False\n",
|
||
" },\n",
" \"strict\": True\n",
" }\n",
|
||
|
" }\n",
|
||
|
"]\n",
|
||
|
"\n",
|
||
|
"visualization_tools = [\n",
|
||
|
" {\n",
|
||
|
" \"type\": \"function\",\n",
|
||
|
" \"function\": {\n",
|
||
|
" \"name\": \"create_bar_chart\",\n",
|
||
|
" \"description\": \"Creates a bar chart from the provided data.\",\n",
|
||
|
" \"parameters\": {\n",
|
||
|
" \"type\": \"object\",\n",
|
||
|
" \"properties\": {\n",
|
||
|
" \"data\": {\n",
|
||
|
" \"type\": \"string\",\n",
|
||
|
" \"description\": \"The data for the bar chart. Should be in a suitable format such as JSON or CSV.\"\n",
|
||
|
" },\n",
|
||
|
" \"x\": {\n",
|
||
|
" \"type\": \"string\",\n",
|
||
|
" \"description\": \"Column for the x-axis.\"\n",
|
||
|
" },\n",
|
||
|
" \"y\": {\n",
|
||
|
" \"type\": \"string\",\n",
|
||
|
" \"description\": \"Column for the y-axis.\"\n",
|
||
|
" }\n",
|
||
|
" },\n",
|
||
|
" \"required\": [\"data\", \"x\", \"y\"],\n",
|
||
|
" \"additionalProperties\": False\n",
|
||
" },\n",
" \"strict\": True\n",
" }\n",
|
||
|
" },\n",
|
||
|
" {\n",
|
||
|
" \"type\": \"function\",\n",
|
||
|
" \"function\": {\n",
|
||
|
" \"name\": \"create_line_chart\",\n",
|
||
|
" \"description\": \"Creates a line chart from the provided data.\",\n",
|
||
|
" \"parameters\": {\n",
|
||
|
" \"type\": \"object\",\n",
|
||
|
" \"properties\": {\n",
|
||
|
" \"data\": {\n",
|
||
|
" \"type\": \"string\",\n",
|
||
|
" \"description\": \"The data for the line chart. Should be in a suitable format such as JSON or CSV.\"\n",
|
||
|
" },\n",
|
||
|
" \"x\": {\n",
|
||
|
" \"type\": \"string\",\n",
|
||
|
" \"description\": \"Column for the x-axis.\"\n",
|
||
|
" },\n",
|
||
|
" \"y\": {\n",
|
||
|
" \"type\": \"string\",\n",
|
||
|
" \"description\": \"Column for the y-axis.\"\n",
|
||
|
" }\n",
|
||
|
" },\n",
|
||
|
" \"required\": [\"data\", \"x\", \"y\"],\n",
|
||
|
" \"additionalProperties\": False\n",
|
||
" },\n",
" \"strict\": True\n",
" }\n",
|
||
|
" },\n",
|
||
|
" {\n",
|
||
|
" \"type\": \"function\",\n",
|
||
|
" \"function\": {\n",
|
||
|
" \"name\": \"create_pie_chart\",\n",
|
||
|
" \"description\": \"Creates a pie chart from the provided data.\",\n",
|
||
|
" \"parameters\": {\n",
|
||
|
" \"type\": \"object\",\n",
|
||
|
" \"properties\": {\n",
|
||
|
" \"data\": {\n",
|
||
|
" \"type\": \"string\",\n",
|
||
|
" \"description\": \"The data for the pie chart. Should be in a suitable format such as JSON or CSV.\"\n",
|
||
|
" },\n",
|
||
|
" \"labels\": {\n",
|
||
|
" \"type\": \"string\",\n",
|
||
|
" \"description\": \"Column for the labels.\"\n",
|
||
|
" },\n",
|
||
|
" \"values\": {\n",
|
||
|
" \"type\": \"string\",\n",
|
||
|
" \"description\": \"Column for the values.\"\n",
|
||
|
" }\n",
|
||
|
" },\n",
|
||
|
" \"required\": [\"data\", \"labels\", \"values\"],\n",
|
||
|
" \"additionalProperties\": False\n",
|
||
" },\n",
" \"strict\": True\n",
" }\n",
|
||
|
" }\n",
|
||
|
"]"
|
||
|
]
|
||
|
},
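The tool definitions above repeat the same boilerplate for every function. A minimal sketch of a helper that generates it is shown below, assuming every parameter should be required (which strict mode expects anyway). The name `make_strict_tool` is hypothetical, and the sketch places `strict` inside the `function` object, which is where the Chat Completions API reference describes it.

```python
def make_strict_tool(name, description, properties):
    """Build one Chat Completions tool definition with strict mode enabled."""
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "strict": True,
            "parameters": {
                "type": "object",
                "properties": properties,
                # Strict mode expects every property to be required...
                "required": list(properties),
                # ...and no undeclared properties.
                "additionalProperties": False,
            },
        },
    }


# Shared property used by several tools in this notebook.
data_property = {
    "type": "string",
    "description": "The dataset, in a suitable format such as JSON or CSV.",
}

example_tools = [
    make_strict_tool(
        "clean_data",
        "Cleans the provided data by removing duplicates and handling missing values.",
        {"data": data_property},
    ),
    make_strict_tool(
        "transform_data",
        "Transforms data based on specified rules.",
        {
            "data": data_property,
            "rules": {"type": "string", "description": "Transformation rules to apply."},
        },
    ),
]

print([tool["function"]["name"] for tool in example_tools])
```

Generating the definitions this way keeps `required` and `properties` in sync, which is an easy detail to get wrong when the schemas are written by hand.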
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"id": "yh8tRZHkQJVv"
|
||
|
},
|
||
|
"source": [
|
||
|
"## Tool execution\n",
|
||
|
"\n",
|
||
|
"We need to write the code logic to:\n",
|
||
|
"- handle passing the user query to the multi-agent system\n",
|
||
|
"- handle the internal workings of the multi-agent system\n",
|
||
|
"- execute the tool calls\n",
|
||
|
"\n",
|
||
|
"For the sake of brevity, we will only define the logic for tools that are relevant to the user query."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 56,
|
||
|
"metadata": {
|
||
|
"id": "dwM_0mHZ5pXx"
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Example query\n",
|
||
|
"\n",
|
||
|
"user_query = \"\"\"\n",
|
||
|
"Below is some data. I want you to first remove the duplicates then analyze the statistics of the data as well as plot a line chart.\n",
|
||
|
"\n",
|
||
|
"house_size (m3), house_price ($)\n",
|
||
|
"90, 100\n",
|
||
|
"80, 90\n",
|
||
|
"100, 120\n",
|
||
|
"90, 100\n",
|
||
|
"\"\"\"\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
"From the user query, we can infer that the tools we would need to call are `clean_data`, `stat_analysis` and `create_line_chart`.\n",
|
||
|
"\n",
|
||
|
"We will first define the execution function which runs tool calls.\n",
|
||
|
"\n",
|
||
|
"This maps a tool call to the corresponding function. It then appends the output of the function to the conversation history."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 57,
|
||
|
"metadata": {
|
||
|
"id": "XH6wgrATUA_l"
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
"def clean_data(data):\n",
"    data_io = StringIO(data)\n",
"    df = pd.read_csv(data_io, sep=\",\")\n",
"    df_deduplicated = df.drop_duplicates()\n",
"    return df_deduplicated\n",
"\n",
"def stat_analysis(data):\n",
"    data_io = StringIO(data)\n",
"    df = pd.read_csv(data_io, sep=\",\")\n",
"    return df.describe()\n",
"\n",
"def plot_line_chart(data):\n",
"    data_io = StringIO(data)\n",
"    df = pd.read_csv(data_io, sep=\",\")\n",
"\n",
"    x = df.iloc[:, 0]\n",
"    y = df.iloc[:, 1]\n",
"\n",
"    coefficients = np.polyfit(x, y, 1)\n",
"    polynomial = np.poly1d(coefficients)\n",
"    y_fit = polynomial(x)\n",
"\n",
"    plt.figure(figsize=(10, 6))\n",
"    plt.plot(x, y, 'o', label='Data Points')\n",
"    plt.plot(x, y_fit, '-', label='Best Fit Line')\n",
"    plt.title('Line Chart with Best Fit Line')\n",
"    plt.xlabel(df.columns[0])\n",
"    plt.ylabel(df.columns[1])\n",
"    plt.legend()\n",
"    plt.grid(True)\n",
"    plt.show()\n",
|
||
|
"\n",
|
||
"# Define the function to execute the tools\n",
"def execute_tool(tool_calls, messages):\n",
"    for tool_call in tool_calls:\n",
"        tool_name = tool_call.function.name\n",
"        tool_arguments = json.loads(tool_call.function.arguments)\n",
"\n",
"        if tool_name == 'clean_data':\n",
"            # Simulate data cleaning\n",
"            cleaned_df = clean_data(tool_arguments['data'])\n",
"            cleaned_data = {\"cleaned_data\": cleaned_df.to_dict()}\n",
"            messages.append({\"role\": \"tool\", \"name\": tool_name, \"content\": json.dumps(cleaned_data)})\n",
"            print('Cleaned data: ', cleaned_df)\n",
"        elif tool_name == 'transform_data':\n",
"            # Simulate data transformation\n",
"            transformed_data = {\"transformed_data\": \"sample_transformed_data\"}\n",
"            messages.append({\"role\": \"tool\", \"name\": tool_name, \"content\": json.dumps(transformed_data)})\n",
"        elif tool_name == 'aggregate_data':\n",
"            # Simulate data aggregation\n",
"            aggregated_data = {\"aggregated_data\": \"sample_aggregated_data\"}\n",
"            messages.append({\"role\": \"tool\", \"name\": tool_name, \"content\": json.dumps(aggregated_data)})\n",
"        elif tool_name == 'stat_analysis':\n",
"            # Simulate statistical analysis\n",
"            stats_df = stat_analysis(tool_arguments['data'])\n",
"            stats = {\"stats\": stats_df.to_dict()}\n",
"            messages.append({\"role\": \"tool\", \"name\": tool_name, \"content\": json.dumps(stats)})\n",
"            print('Statistical Analysis: ', stats_df)\n",
"        elif tool_name == 'correlation_analysis':\n",
"            # Simulate correlation analysis\n",
"            correlations = {\"correlations\": \"sample_correlations\"}\n",
"            messages.append({\"role\": \"tool\", \"name\": tool_name, \"content\": json.dumps(correlations)})\n",
"        elif tool_name == 'regression_analysis':\n",
"            # Simulate regression analysis\n",
"            regression_results = {\"regression_results\": \"sample_regression_results\"}\n",
"            messages.append({\"role\": \"tool\", \"name\": tool_name, \"content\": json.dumps(regression_results)})\n",
"        elif tool_name == 'create_bar_chart':\n",
"            # Simulate bar chart creation\n",
"            bar_chart = {\"bar_chart\": \"sample_bar_chart\"}\n",
"            messages.append({\"role\": \"tool\", \"name\": tool_name, \"content\": json.dumps(bar_chart)})\n",
"        elif tool_name == 'create_line_chart':\n",
"            # Simulate line chart creation\n",
"            line_chart = {\"line_chart\": \"sample_line_chart\"}\n",
"            messages.append({\"role\": \"tool\", \"name\": tool_name, \"content\": json.dumps(line_chart)})\n",
"            plot_line_chart(tool_arguments['data'])\n",
"        elif tool_name == 'create_pie_chart':\n",
"            # Simulate pie chart creation\n",
"            pie_chart = {\"pie_chart\": \"sample_pie_chart\"}\n",
"            messages.append({\"role\": \"tool\", \"name\": tool_name, \"content\": json.dumps(pie_chart)})\n",
"    return messages"
|
||
|
]
|
||
|
},
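The mapping from tool name to function described above can also be written as a dispatch table instead of a chain of `elif` branches. The sketch below is one possible shape, not the notebook's implementation: `count_rows` is a hypothetical stand-in for a real tool, the tool call is faked with `SimpleNamespace` so the example runs without an API key, and each result carries the `tool_call_id` that the Chat Completions API expects on tool messages when they are sent back to the model.

```python
import json
from types import SimpleNamespace


def count_rows(data):
    """Hypothetical tool: count the data rows in a CSV string (header excluded)."""
    return {"rows": len(data.strip().splitlines()) - 1}


# Map each tool name to the Python function that implements it.
TOOL_HANDLERS = {"count_rows": count_rows}


def execute_tools(tool_calls, messages):
    """Run each tool call and append its JSON result to the message list."""
    for tool_call in tool_calls:
        name = tool_call.function.name
        arguments = json.loads(tool_call.function.arguments)
        handler = TOOL_HANDLERS.get(name)
        if handler is None:
            result = {"error": f"unknown tool: {name}"}
        else:
            result = handler(**arguments)
        messages.append({
            "role": "tool",
            "tool_call_id": tool_call.id,
            "content": json.dumps(result),
        })
    return messages


# A fake tool call shaped like the objects returned by the SDK.
fake_call = SimpleNamespace(
    id="call_0",
    function=SimpleNamespace(
        name="count_rows",
        arguments=json.dumps({"data": "a,b\n1,2\n3,4"}),
    ),
)

history = execute_tools([fake_call], [])
print(history)
```

With this layout, adding a tool means adding one entry to `TOOL_HANDLERS` rather than another branch.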
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"Next, we will create the tool handlers for each of the sub-agents.\n",
|
||
|
"\n",
|
||
|
"These have a unique prompt and tool set passed to the model. \n",
|
||
|
"\n",
|
||
|
"The output is then passed to an execution function which runs the tool calls.\n",
|
||
|
"\n",
|
||
|
"We will also append the messages to the conversation history."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 58,
|
||
|
"metadata": {
|
||
|
"id": "EcOGJ0AZTmkp"
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Define the functions to handle each agent's processing\n",
|
||
"def handle_data_processing_agent(query, conversation_messages):\n",
"    messages = [{\"role\": \"system\", \"content\": processing_system_prompt}]\n",
"    messages.append({\"role\": \"user\", \"content\": query})\n",
"\n",
"    response = client.chat.completions.create(\n",
"        model=MODEL,\n",
"        messages=messages,\n",
"        temperature=0,\n",
"        tools=preprocess_tools,\n",
"    )\n",
"\n",
"    conversation_messages.append([tool_call.function for tool_call in response.choices[0].message.tool_calls])\n",
"    execute_tool(response.choices[0].message.tool_calls, conversation_messages)\n",
"\n",
"def handle_analysis_agent(query, conversation_messages):\n",
"    messages = [{\"role\": \"system\", \"content\": analysis_system_prompt}]\n",
"    messages.append({\"role\": \"user\", \"content\": query})\n",
"\n",
"    response = client.chat.completions.create(\n",
"        model=MODEL,\n",
"        messages=messages,\n",
"        temperature=0,\n",
"        tools=analysis_tools,\n",
"    )\n",
"\n",
"    conversation_messages.append([tool_call.function for tool_call in response.choices[0].message.tool_calls])\n",
"    execute_tool(response.choices[0].message.tool_calls, conversation_messages)\n",
"\n",
"def handle_visualization_agent(query, conversation_messages):\n",
"    messages = [{\"role\": \"system\", \"content\": visualization_system_prompt}]\n",
"    messages.append({\"role\": \"user\", \"content\": query})\n",
"\n",
"    response = client.chat.completions.create(\n",
"        model=MODEL,\n",
"        messages=messages,\n",
"        temperature=0,\n",
"        tools=visualization_tools,\n",
"    )\n",
"\n",
"    conversation_messages.append([tool_call.function for tool_call in response.choices[0].message.tool_calls])\n",
"    execute_tool(response.choices[0].message.tool_calls, conversation_messages)\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
"Finally, we create the overarching function that processes the user query.\n",
|
||
|
"\n",
|
||
|
"This function takes the user query, gets a response from the model and handles passing it to the other agents to execute. In addition to this, we will keep the state of the ongoing conversation."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 59,
|
||
|
"metadata": {
|
||
|
"id": "4skE5-KYI9Tw"
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Function to handle user input and triaging\n",
|
||
"def handle_user_message(user_query, conversation_messages=None):\n",
"    # Default to None rather than []: a list default would be shared across calls.\n",
"    if conversation_messages is None:\n",
"        conversation_messages = []\n",
"\n",
"    user_message = {\"role\": \"user\", \"content\": user_query}\n",
"    conversation_messages.append(user_message)\n",
"\n",
"    messages = [{\"role\": \"system\", \"content\": triaging_system_prompt}]\n",
"    messages.extend(conversation_messages)\n",
"\n",
"    response = client.chat.completions.create(\n",
"        model=MODEL,\n",
"        messages=messages,\n",
"        temperature=0,\n",
"        tools=triage_tools,\n",
"    )\n",
"\n",
"    # The model may reply in plain text; treat that as no tool calls.\n",
"    tool_calls = response.choices[0].message.tool_calls or []\n",
"    conversation_messages.append([tool_call.function for tool_call in tool_calls])\n",
"\n",
"    for tool_call in tool_calls:\n",
"        if tool_call.function.name == 'send_query_to_agents':\n",
"            # Parse the JSON arguments once and reuse them.\n",
"            arguments = json.loads(tool_call.function.arguments)\n",
"            agents = arguments['agents']\n",
"            query = arguments['query']\n",
"            for agent in agents:\n",
"                if agent == \"Data Processing Agent\":\n",
"                    handle_data_processing_agent(query, conversation_messages)\n",
"                elif agent == \"Analysis Agent\":\n",
"                    handle_analysis_agent(query, conversation_messages)\n",
"                elif agent == \"Visualization Agent\":\n",
"                    handle_visualization_agent(query, conversation_messages)\n",
"\n",
"    return conversation_messages"
|
||
|
]
|
||
|
},
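Because this function keeps the state of the ongoing conversation in a list, it matters how that list is created. In Python, a list used as a default argument is built once, when the function is defined, and is then shared by every call that omits the argument. The sketch below illustrates the difference with two hypothetical helpers, `append_shared` and `append_fresh`, which are not part of the notebook.

```python
def append_shared(message, history=[]):
    # The default list is created once, at function definition time.
    history.append(message)
    return history


def append_fresh(message, history=None):
    # A new list is created on every call that omits `history`.
    if history is None:
        history = []
    history.append(message)
    return history


append_shared("first query")
shared = append_shared("second query")

append_fresh("first query")
fresh = append_fresh("second query")

print(shared)  # ['first query', 'second query']
print(fresh)   # ['second query']
```

For a conversation handler, the second pattern means unrelated queries do not leak into each other's history unless the caller passes the same list in on purpose.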
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"id": "jzQAwIW_WL3k"
|
||
|
},
|
||
|
"source": [
|
||
|
"## Multi-agent system execution\n",
|
||
|
"\n",
|
||
|
"Finally, we run the overarching `handle_user_message` function on the user query and view the output."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 60,
|
||
|
"metadata": {
|
||
|
"colab": {
|
||
|
"base_uri": "https://localhost:8080/"
|
||
|
},
|
||
|
"id": "a0h10s_W49ct",
|
||
|
"outputId": "7e340af9-dc3d-44ba-aa0c-e613fbdcc153"
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Cleaned data: house_size (m3) house_price ($)\n",
|
||
|
"0 90 100\n",
|
||
|
"1 80 90\n",
|
||
|
"2 100 120\n",
|
||
|
"Statistical Analysis: house_size house_price\n",
|
||
|
"count 4.000000 4.000000\n",
|
||
|
"mean 90.000000 102.500000\n",
|
||
|
"std 8.164966 12.583057\n",
|
||
|
"min 80.000000 90.000000\n",
|
||
|
"25% 87.500000 97.500000\n",
|
||
|
"50% 90.000000 100.000000\n",
|
||
|
"75% 92.500000 105.000000\n",
|
||
|
"max 100.000000 120.000000\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
"image/png": "[base64-encoded PNG omitted: 'Line Chart with Best Fit Line', house_size vs. house_price]",
|
||
|
"text/plain": [
|
||
|
"<Figure size 1000x600 with 1 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"[{'role': 'user',\n",
|
||
|
" 'content': '\\nBelow is some data. I want you to first remove the duplicates then analyze the statistics of the data as well as plot a line chart.\\n\\nhouse_size (m3), house_price ($)\\n90, 100\\n80, 90\\n100, 120\\n90, 100\\n'},\n",
|
||
|
" [Function(arguments='{\"agents\": [\"Data Processing Agent\"], \"query\": \"Remove duplicates from the data: house_size (m3), house_price ($)\\\\n90, 100\\\\n80, 90\\\\n100, 120\\\\n90, 100\"}', name='send_query_to_agents'),\n",
|
||
|
" Function(arguments='{\"agents\": [\"Analysis Agent\"], \"query\": \"Analyze the statistics of the data: house_size (m3), house_price ($)\\\\n90, 100\\\\n80, 90\\\\n100, 120\\\\n90, 100\"}', name='send_query_to_agents'),\n",
|
||
|
" Function(arguments='{\"agents\": [\"Visualization Agent\"], \"query\": \"Plot a line chart for the data: house_size (m3), house_price ($)\\\\n90, 100\\\\n80, 90\\\\n100, 120\\\\n90, 100\"}', name='send_query_to_agents')],\n",
|
||
|
" [Function(arguments='{\"data\":\"house_size (m3), house_price ($)\\\\n90, 100\\\\n80, 90\\\\n100, 120\\\\n90, 100\"}', name='clean_data')],\n",
|
||
|
" {'role': 'tool',\n",
|
||
|
" 'name': 'clean_data',\n",
|
||
|
" 'content': '{\"cleaned_data\": {\"house_size (m3)\": {\"0\": 90, \"1\": 80, \"2\": 100}, \" house_price ($)\": {\"0\": 100, \"1\": 90, \"2\": 120}}}'},\n",
|
||
|
" [Function(arguments='{\"data\":\"house_size,house_price\\\\n90,100\\\\n80,90\\\\n100,120\\\\n90,100\"}', name='stat_analysis')],\n",
|
||
|
" {'role': 'tool',\n",
|
||
|
" 'name': 'stat_analysis',\n",
|
||
|
" 'content': '{\"stats\": {\"house_size\": {\"count\": 4.0, \"mean\": 90.0, \"std\": 8.16496580927726, \"min\": 80.0, \"25%\": 87.5, \"50%\": 90.0, \"75%\": 92.5, \"max\": 100.0}, \"house_price\": {\"count\": 4.0, \"mean\": 102.5, \"std\": 12.583057392117917, \"min\": 90.0, \"25%\": 97.5, \"50%\": 100.0, \"75%\": 105.0, \"max\": 120.0}}}'},\n",
|
||
|
" [Function(arguments='{\"data\":\"house_size,house_price\\\\n90,100\\\\n80,90\\\\n100,120\\\\n90,100\",\"x\":\"house_size\",\"y\":\"house_price\"}', name='create_line_chart')],\n",
|
||
|
" {'role': 'tool',\n",
|
||
|
" 'name': 'create_line_chart',\n",
|
||
|
" 'content': '{\"line_chart\": \"sample_line_chart\"}'}]"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 60,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"handle_user_message(user_query)"
|
||
|
]
|
||
|
},
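In the output above, `stat_analysis` reports a count of 4 even though cleaning left 3 rows, because the triaging agent forwarded the original data to each agent separately. One way to make later steps see the cleaned data is to chain the steps so that each receives the previous step's output. The sketch below shows the idea using only the standard library (`csv` and `statistics` instead of pandas); the helper names `clean_rows` and `column_stats` are hypothetical.

```python
import csv
import statistics
from io import StringIO

raw = "house_size,house_price\n90,100\n80,90\n100,120\n90,100"


def clean_rows(csv_text):
    """Parse CSV text and drop duplicate rows, keeping first occurrences."""
    reader = csv.reader(StringIO(csv_text))
    header = [name.strip() for name in next(reader)]
    seen, rows = set(), []
    for row in reader:
        key = tuple(cell.strip() for cell in row)
        if key not in seen:
            seen.add(key)
            rows.append(key)
    return header, rows


def column_stats(header, rows):
    """Compute count, mean, and sample standard deviation per column."""
    stats = {}
    for index, name in enumerate(header):
        values = [float(row[index]) for row in rows]
        stats[name] = {
            "count": len(values),
            "mean": statistics.mean(values),
            "std": statistics.stdev(values),
        }
    return stats


# Chain the steps: the analysis runs on the cleaned rows, not the raw text.
header, rows = clean_rows(raw)
summary = column_stats(header, rows)
print(summary)
```

In the multi-agent setting, the same effect could be approached by passing each agent the previous agent's tool output rather than the original query, though that changes the routing logic shown in this notebook.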
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"## Conclusion\n",
|
||
|
"\n",
|
||
|
"In this cookbook, we've explored how to leverage Structured Outputs to build more robust multi-agent systems.\n",
|
||
|
"\n",
|
||
"Using this feature ensures that tool calls follow the specified schema, so you do not have to handle malformed arguments or validate them yourself.\n",
|
||
|
"\n",
|
||
|
"This can be applied to many more use cases, and we hope you can take inspiration from this to build your own use case!"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"metadata": {
|
||
|
"colab": {
|
||
|
"provenance": []
|
||
|
},
|
||
|
"kernelspec": {
|
||
|
"display_name": "Python 3 (ipykernel)",
|
||
|
"language": "python",
|
||
|
"name": "python3"
|
||
|
},
|
||
|
"language_info": {
|
||
|
"codemirror_mode": {
|
||
|
"name": "ipython",
|
||
|
"version": 3
|
||
|
},
|
||
|
"file_extension": ".py",
|
||
|
"mimetype": "text/x-python",
|
||
|
"name": "python",
|
||
|
"nbconvert_exporter": "python",
|
||
|
"pygments_lexer": "ipython3",
|
||
|
"version": "3.11.9"
|
||
|
}
|
||
|
},
|
||
|
"nbformat": 4,
|
||
|
"nbformat_minor": 1
|
||
|
}
|