{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"id": "3e67f200",
"metadata": {},
"source": [
"# How to use functions with a knowledge base\n",
"\n",
"This notebook builds on the concepts in the [argument generation](How_to_call_functions_with_chat_models.ipynb) notebook by creating an agent with access to a knowledge base and two functions that it can call based on the user's requirements.\n",
"\n",
"We'll create an agent that uses data from arXiv to answer questions about academic subjects. It has two functions at its disposal:\n",
"- **get_articles**: A function that retrieves arXiv articles on a subject and summarizes them for the user, with links.\n",
"- **read_article_and_summarize**: A function that takes one of the previously searched articles, reads it in its entirety, and summarizes its core argument, evidence, and conclusions.\n",
"\n",
"This will get you comfortable with a multi-function workflow that can choose from multiple services, and where some of the data from the first function is persisted to be used by the second.\n",
"\n",
"## Walkthrough\n",
"\n",
"This cookbook takes you through the following workflow:\n",
"\n",
"- **Search utilities:** Creating the two functions that access arXiv for answers.\n",
"- **Configure Agent:** Building up the agent behaviour that will assess the need for a function and, if one is required, call that function and present the results back to the model.\n",
"- **arXiv conversation:** Putting all of this together in a live conversation.\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "80e71f33",
"metadata": {
"pycharm": {
"is_executing": true
}
},
"outputs": [],
"source": [
"!pip install scipy --quiet\n",
"!pip install tenacity --quiet\n",
"!pip install tiktoken==0.3.3 --quiet\n",
"!pip install termcolor --quiet\n",
"!pip install openai --quiet\n",
"!pip install arxiv --quiet\n",
"!pip install pandas --quiet\n",
"!pip install PyPDF2 --quiet\n",
"!pip install tqdm --quiet"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "dab872c5",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import arxiv\n",
"import ast\n",
"import concurrent.futures\n",
"import json\n",
"import pandas as pd\n",
"import tiktoken\n",
"from csv import writer\n",
"from IPython.display import display, Markdown\n",
"from openai import OpenAI\n",
"from PyPDF2 import PdfReader\n",
"from scipy import spatial\n",
"from tenacity import retry, wait_random_exponential, stop_after_attempt\n",
"from tqdm import tqdm\n",
"from termcolor import colored\n",
"\n",
"GPT_MODEL = \"gpt-4o-mini\"\n",
"EMBEDDING_MODEL = \"text-embedding-ada-002\"\n",
"client = OpenAI()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "f2e47962",
"metadata": {},
"source": [
"## Search utilities\n",
"\n",
"We'll first set up some utilities that will underpin our two functions.\n",
"\n",
"Downloaded papers will be stored in a directory (we use ```./data/papers``` here). We create a file ```arxiv_library.csv``` to store the embeddings and file paths of downloaded papers, so that ```summarize_text``` can retrieve against them."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "2de5d32d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Directory './data/papers' already exists.\n"
]
}
],
"source": [
"directory = './data/papers'\n",
"\n",
"# Check if the directory already exists\n",
"if not os.path.exists(directory):\n",
"    # If the directory doesn't exist, create it and any necessary intermediate directories\n",
"    os.makedirs(directory)\n",
"    print(f\"Directory '{directory}' created successfully.\")\n",
"else:\n",
"    # If the directory already exists, print a message indicating it\n",
"    print(f\"Directory '{directory}' already exists.\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "ae5cb7a1",
"metadata": {},
"outputs": [],
"source": [
"# Set a directory to store downloaded papers\n",
"data_dir = os.path.join(os.curdir, \"data\", \"papers\")\n",
"paper_dir_filepath = \"./data/papers/arxiv_library.csv\"\n",
"\n",
"# Generate a blank dataframe where we can store downloaded files\n",
"df = pd.DataFrame()\n",
"df.to_csv(paper_dir_filepath)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "57217b9d",
"metadata": {},
"outputs": [],
"source": [
"@retry(wait=wait_random_exponential(min=1, max=40), stop=stop_after_attempt(3))\n",
"def embedding_request(text):\n",
"    response = client.embeddings.create(input=text, model=EMBEDDING_MODEL)\n",
"    return response\n",
"\n",
"\n",
"@retry(wait=wait_random_exponential(min=1, max=40), stop=stop_after_attempt(3))\n",
"def get_articles(query, library=paper_dir_filepath, top_k=10):\n",
"    \"\"\"This function gets the top_k articles based on a user's query, sorted by relevance.\n",
"    It also downloads the files and stores them in arxiv_library.csv to be retrieved by the read_article_and_summarize function.\n",
"    \"\"\"\n",
"    arxiv_client = arxiv.Client()  # named to avoid shadowing the global OpenAI client\n",
"    search = arxiv.Search(\n",
"        query=query,\n",
"        max_results=top_k\n",
"    )\n",
"    result_list = []\n",
"    for result in arxiv_client.results(search):\n",
"        result_dict = {}\n",
"        result_dict.update({\"title\": result.title})\n",
"        result_dict.update({\"summary\": result.summary})\n",
"\n",
"        # Take the first two links provided: the abstract page and the PDF\n",
"        result_dict.update({\"article_url\": [x.href for x in result.links][0]})\n",
"        result_dict.update({\"pdf_url\": [x.href for x in result.links][1]})\n",
"        result_list.append(result_dict)\n",
"\n",
"        # Store references in the library file: title, downloaded filepath, and title embedding\n",
"        response = embedding_request(text=result.title)\n",
"        file_reference = [\n",
"            result.title,\n",
"            result.download_pdf(data_dir),\n",
"            response.data[0].embedding,\n",
"        ]\n",
"\n",
"        # Write to file (the with block closes the file automatically)\n",
"        with open(library, \"a\") as f_object:\n",
"            writer_object = writer(f_object)\n",
"            writer_object.writerow(file_reference)\n",
"    return result_list\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "dda02bdb",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'title': 'Proximal Policy Optimization and its Dynamic Version for Sequence Generation',\n",
" 'summary': 'In sequence generation task, many works use policy gradient for model\\noptimization to tackle the intractable backpropagation issue when maximizing\\nthe non-differentiable evaluation metrics or fooling the discriminator in\\nadversarial learning. In this paper, we replace policy gradient with proximal\\npolicy optimization (PPO), which is a proved more efficient reinforcement\\nlearning algorithm, and propose a dynamic approach for PPO (PPO-dynamic). We\\ndemonstrate the efficacy of PPO and PPO-dynamic on conditional sequence\\ngeneration tasks including synthetic experiment and chit-chat chatbot. The\\nresults show that PPO and PPO-dynamic can beat policy gradient by stability and\\nperformance.',\n",
" 'article_url': 'http://arxiv.org/abs/1808.07982v1',\n",
" 'pdf_url': 'http://arxiv.org/pdf/1808.07982v1'}"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Test that the search is working\n",
"result_output = get_articles(\"ppo reinforcement learning\")\n",
"result_output[0]\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "11675627",
"metadata": {},
"outputs": [],
"source": [
"def strings_ranked_by_relatedness(\n",
"    query: str,\n",
"    df: pd.DataFrame,\n",
"    relatedness_fn=lambda x, y: 1 - spatial.distance.cosine(x, y),\n",
"    top_n: int = 100,\n",
") -> list[str]:\n",
"    \"\"\"Returns a list of filepaths, sorted from most related to least related to the query.\"\"\"\n",
"    query_embedding_response = embedding_request(query)\n",
"    query_embedding = query_embedding_response.data[0].embedding\n",
"    strings_and_relatednesses = [\n",
"        (row[\"filepath\"], relatedness_fn(query_embedding, row[\"embedding\"]))\n",
"        for i, row in df.iterrows()\n",
"    ]\n",
"    strings_and_relatednesses.sort(key=lambda x: x[1], reverse=True)\n",
"    strings, relatednesses = zip(*strings_and_relatednesses)\n",
"    return list(strings[:top_n])\n"
]
},
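{
"attachments": {},
"cell_type": "markdown",
"id": "a7c31f05",
"metadata": {},
"source": [
"As a quick sanity check, we can rank the stored papers against a query directly. This is a minimal sketch that assumes ```arxiv_library.csv``` has already been populated by the ```get_articles``` call above; it reuses the same loading steps that ```summarize_text``` performs below."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b2f9e4d1",
"metadata": {},
"outputs": [],
"source": [
"# Minimal sanity check: assumes get_articles has already populated arxiv_library.csv\n",
"library_df = pd.read_csv(paper_dir_filepath).reset_index()\n",
"library_df.columns = [\"title\", \"filepath\", \"embedding\"]\n",
"# Embeddings are stored as strings in the CSV, so parse them back into lists\n",
"library_df[\"embedding\"] = library_df[\"embedding\"].apply(ast.literal_eval)\n",
"# Filepaths of the three papers most related to the query\n",
"strings_ranked_by_relatedness(\"PPO sequence generation\", library_df, top_n=3)"
]
},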
{
"cell_type": "code",
"execution_count": 8,
"id": "7211df2c",
"metadata": {},
"outputs": [],
"source": [
"def read_pdf(filepath):\n",
"    \"\"\"Takes a filepath to a PDF and returns a string of the PDF's contents\"\"\"\n",
"    # Create a PDF reader object and concatenate the text of each page\n",
"    reader = PdfReader(filepath)\n",
"    pdf_text = \"\"\n",
"    page_number = 0\n",
"    for page in reader.pages:\n",
"        page_number += 1\n",
"        pdf_text += page.extract_text() + f\"\\nPage Number: {page_number}\"\n",
"    return pdf_text\n",
"\n",
"\n",
"# Split a text into smaller chunks of size n, preferably ending at the end of a sentence\n",
"def create_chunks(text, n, tokenizer):\n",
"    \"\"\"Yields successive n-sized token chunks from the provided text.\"\"\"\n",
"    tokens = tokenizer.encode(text)\n",
"    i = 0\n",
"    while i < len(tokens):\n",
"        # Find the nearest end of sentence within a range of 0.5 * n and 1.5 * n tokens\n",
"        j = min(i + int(1.5 * n), len(tokens))\n",
"        while j > i + int(0.5 * n):\n",
"            # Decode the tokens and check for full stop or newline\n",
"            chunk = tokenizer.decode(tokens[i:j])\n",
"            if chunk.endswith(\".\") or chunk.endswith(\"\\n\"):\n",
"                break\n",
"            j -= 1\n",
"        # If no end of sentence found, use n tokens as the chunk size\n",
"        if j == i + int(0.5 * n):\n",
"            j = min(i + n, len(tokens))\n",
"        yield tokens[i:j]\n",
"        i = j\n",
"\n",
"\n",
"def extract_chunk(content, template_prompt):\n",
"    \"\"\"This function applies a prompt to some input content. In this case it returns a summarized chunk of text\"\"\"\n",
"    prompt = template_prompt + content\n",
"    response = client.chat.completions.create(\n",
"        model=GPT_MODEL, messages=[{\"role\": \"user\", \"content\": prompt}], temperature=0\n",
"    )\n",
"    return response.choices[0].message.content\n",
"\n",
"\n",
"def summarize_text(query):\n",
"    \"\"\"This function does the following:\n",
"    - Reads in the arxiv_library.csv file, including the embeddings\n",
"    - Finds the closest file to the user's query\n",
"    - Scrapes the text out of the file and chunks it\n",
"    - Summarizes each chunk in parallel\n",
"    - Does one final summary and returns this to the user\"\"\"\n",
"\n",
"    # A prompt to dictate how the recursive summarizations should approach the input paper\n",
"    summary_prompt = \"\"\"Summarize this text from an academic paper. Extract any key points with reasoning.\\n\\nContent:\"\"\"\n",
"\n",
"    # If the library is empty (no searches have been performed yet), we perform one and download the results\n",
"    library_df = pd.read_csv(paper_dir_filepath).reset_index()\n",
"    if len(library_df) == 0:\n",
"        print(\"No papers searched yet, downloading first.\")\n",
"        get_articles(query)\n",
"        print(\"Papers downloaded, continuing\")\n",
"        library_df = pd.read_csv(paper_dir_filepath).reset_index()\n",
"    else:\n",
"        print(\"Existing papers found... Articles:\", len(library_df))\n",
"    library_df.columns = [\"title\", \"filepath\", \"embedding\"]\n",
"    library_df[\"embedding\"] = library_df[\"embedding\"].apply(ast.literal_eval)\n",
"    strings = strings_ranked_by_relatedness(query, library_df, top_n=1)\n",
"    print(\"Chunking text from paper\")\n",
"    pdf_text = read_pdf(strings[0])\n",
"\n",
"    # Initialise tokenizer\n",
"    tokenizer = tiktoken.get_encoding(\"cl100k_base\")\n",
"    results = \"\"\n",
"\n",
"    # Chunk up the document into 1500-token chunks\n",
"    chunks = create_chunks(pdf_text, 1500, tokenizer)\n",
"    text_chunks = [tokenizer.decode(chunk) for chunk in chunks]\n",
"    print(\"Summarizing each chunk of text\")\n",
"\n",
"    # Process the chunk summaries in parallel, one thread per chunk\n",
"    with concurrent.futures.ThreadPoolExecutor(\n",
"        max_workers=len(text_chunks)\n",
"    ) as executor:\n",
"        futures = [\n",
"            executor.submit(extract_chunk, chunk, summary_prompt)\n",
"            for chunk in text_chunks\n",
"        ]\n",
"        with tqdm(total=len(text_chunks)) as pbar:\n",
"            for _ in concurrent.futures.as_completed(futures):\n",
"                pbar.update(1)\n",
"        for future in futures:\n",
"            data = future.result()\n",
"            results += data\n",
"\n",
"    # Final summary\n",
"    print(\"Summarizing into overall summary\")\n",
"    response = client.chat.completions.create(\n",
"        model=GPT_MODEL,\n",
"        messages=[\n",
"            {\n",
"                \"role\": \"user\",\n",
"                \"content\": f\"\"\"Write a summary collated from this collection of key points extracted from an academic paper.\n",
"                        The summary should highlight the core argument, conclusions and evidence, and answer the user's query.\n",
"                        User query: {query}\n",
"                        The summary should be structured in bulleted lists following the headings Core Argument, Evidence, and Conclusions.\n",
"                        Key points:\\n{results}\\nSummary:\\n\"\"\",\n",
"            }\n",
"        ],\n",
"        temperature=0,\n",
"    )\n",
"    return response\n"
]
},
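{
"attachments": {},
"cell_type": "markdown",
"id": "c4d8a917",
"metadata": {},
"source": [
"To see how ```create_chunks``` behaves, here is a small, self-contained sketch that splits an illustrative sample string into roughly 20-token chunks, preferring sentence boundaries. The sample text is purely hypothetical."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d5e6f7a8",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative example: chunk a short sample text into ~20-token pieces\n",
"sample_text = (\n",
"    \"Proximal Policy Optimization is a policy gradient method. \"\n",
"    \"It clips each policy update to keep training stable. \"\n",
"    \"This makes it a popular default in reinforcement learning.\"\n",
")\n",
"sample_tokenizer = tiktoken.get_encoding(\"cl100k_base\")\n",
"for chunk_tokens in create_chunks(sample_text, 20, sample_tokenizer):\n",
"    # Each chunk is a list of token ids; decode it back to text to inspect the boundaries\n",
"    print(repr(sample_tokenizer.decode(chunk_tokens)))"
]
},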
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"id": "898b94d4",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Existing papers found... Articles: 10\n",
|
|
"Chunking text from paper\n",
|
|
"Summarizing each chunk of text\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"100%|███████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:05<00:00, 1.40s/it]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Summarizing into overall summary\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Test the summarize_text function works\n",
|
|
"chat_test_response = summarize_text(\"PPO reinforcement learning sequence generation\")\n"
|
|
]
|
|
},
|
|
{
"cell_type": "code",
"execution_count": 10,
"id": "c715f60d",
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
"### Core Argument\n",
"- The paper argues that Proximal Policy Optimization (PPO) and its dynamic variant (PPO-dynamic) significantly improve sequence generation tasks, particularly for chit-chat chatbots, by addressing the instability and suboptimal performance associated with traditional policy gradient methods.\n",
"\n",
"### Evidence\n",
"- **Challenges with Traditional Methods**: Traditional policy gradient methods, like REINFORCE, suffer from unstable training and poor performance due to large updates and similar action tendencies, especially in non-differentiable evaluation contexts (e.g., BLEU scores).\n",
"- **PPO Advantages**: PPO regularizes policy updates, enhancing training stability and enabling the generation of coherent and diverse chatbot responses.\n",
"- **Dynamic PPO Approach**: PPO-dynamic introduces adaptive constraints on KL-divergence, allowing for dynamic adjustments based on action probabilities, which leads to improved training performance.\n",
"- **Experimental Validation**: The authors conducted experiments on synthetic counting tasks and real-world chit-chat scenarios, demonstrating that PPO and PPO-dynamic outperform traditional methods like REINFORCE and SeqGAN in terms of stability and performance metrics (e.g., BLEU-2 scores).\n",
"- **Results**: PPO-dynamic showed faster convergence and higher precision in the counting task, and it achieved the best performance in the chit-chat task, indicating its effectiveness in generating diverse and contextually appropriate responses.\n",
"\n",
"### Conclusions\n",
"- The introduction of PPO and PPO-dynamic enhances the training stability and output diversity in sequence generation tasks, making them more suitable for applications like chatbots.\n",
"- The dynamic variant of PPO not only improves performance but also accelerates convergence, addressing the limitations of traditional policy gradient methods and providing a robust framework for reinforcement learning in sequence generation."
],
"text/plain": [
"<IPython.core.display.Markdown object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"display(Markdown(chat_test_response.choices[0].message.content))\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "dab07e98",
"metadata": {},
"source": [
"## Configure Agent\n",
"\n",
"We'll create our agent in this step, including a ```Conversation``` class to support multiple turns with the API, and some Python functions to enable interaction between the ```ChatCompletion``` API and our knowledge base functions."
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "77a6fb4f",
"metadata": {},
"outputs": [],
"source": [
"@retry(wait=wait_random_exponential(min=1, max=40), stop=stop_after_attempt(3))\n",
"def chat_completion_request(messages, functions=None, model=GPT_MODEL):\n",
"    try:\n",
"        response = client.chat.completions.create(\n",
"            model=model,\n",
"            messages=messages,\n",
"            functions=functions,\n",
"        )\n",
"        return response\n",
"    except Exception as e:\n",
"        print(\"Unable to generate ChatCompletion response\")\n",
"        print(f\"Exception: {e}\")\n",
"        return e\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "73f7672d",
"metadata": {},
"outputs": [],
"source": [
"class Conversation:\n",
"    def __init__(self):\n",
"        self.conversation_history = []\n",
"\n",
"    def add_message(self, role, content):\n",
"        message = {\"role\": role, \"content\": content}\n",
"        self.conversation_history.append(message)\n",
"\n",
"    def display_conversation(self, detailed=False):\n",
"        role_to_color = {\n",
"            \"system\": \"red\",\n",
"            \"user\": \"green\",\n",
"            \"assistant\": \"blue\",\n",
"            \"function\": \"magenta\",\n",
"        }\n",
"        for message in self.conversation_history:\n",
"            print(\n",
"                colored(\n",
"                    f\"{message['role']}: {message['content']}\\n\\n\",\n",
"                    role_to_color[message[\"role\"]],\n",
"                )\n",
"            )"
]
},
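{
"attachments": {},
"cell_type": "markdown",
"id": "e8f9a0b1",
"metadata": {},
"source": [
"As a quick illustration (with hypothetical messages), we can add a couple of turns and render them with ```display_conversation```, which colours each message by role:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f1a2b3c4",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative usage of the Conversation class with hypothetical messages\n",
"demo_conversation = Conversation()\n",
"demo_conversation.add_message(\"system\", \"You are a helpful assistant.\")\n",
"demo_conversation.add_message(\"user\", \"What is PPO?\")\n",
"demo_conversation.display_conversation()"
]
},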
{
"cell_type": "code",
"execution_count": 13,
"id": "978b7877",
"metadata": {},
"outputs": [],
"source": [
"# Define the function specifications for get_articles and read_article_and_summarize\n",
"arxiv_functions = [\n",
"    {\n",
"        \"name\": \"get_articles\",\n",
"        \"description\": \"\"\"Use this function to get academic papers from arXiv to answer user questions.\"\"\",\n",
"        \"parameters\": {\n",
"            \"type\": \"object\",\n",
"            \"properties\": {\n",
"                \"query\": {\n",
"                    \"type\": \"string\",\n",
"                    \"description\": \"\"\"\n",
"                            User query in JSON. Responses should be summarized and should include the article URL reference\n",
"                            \"\"\",\n",
"                }\n",
"            },\n",
"            \"required\": [\"query\"],\n",
"        },\n",
"    },\n",
"    {\n",
"        \"name\": \"read_article_and_summarize\",\n",
"        \"description\": \"\"\"Use this function to read whole papers and provide a summary for users.\n",
"        You should NEVER call this function before get_articles has been called in the conversation.\"\"\",\n",
"        \"parameters\": {\n",
"            \"type\": \"object\",\n",
"            \"properties\": {\n",
"                \"query\": {\n",
"                    \"type\": \"string\",\n",
"                    \"description\": \"\"\"\n",
"                            Description of the article in plain text based on the user's query\n",
"                            \"\"\",\n",
"                }\n",
"            },\n",
"            \"required\": [\"query\"],\n",
"        },\n",
"    }\n",
"]\n"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "0c88ae15",
"metadata": {},
"outputs": [],
"source": [
"def chat_completion_with_function_execution(messages, functions=None):\n",
"    \"\"\"This function makes a ChatCompletion API call with the option of adding functions\"\"\"\n",
"    response = chat_completion_request(messages, functions)\n",
"    full_message = response.choices[0]\n",
"    if full_message.finish_reason == \"function_call\":\n",
"        print(\"Function generation requested, calling function\")\n",
"        return call_arxiv_function(messages, full_message)\n",
"    else:\n",
"        print(\"Function not required, responding to user\")\n",
"        return response\n",
"\n",
"\n",
"def call_arxiv_function(messages, full_message):\n",
"    \"\"\"Executes function calls when the model believes one is necessary.\n",
"    Extend this if statement with additional clauses to support more functions.\"\"\"\n",
"\n",
"    if full_message.message.function_call.name == \"get_articles\":\n",
"        try:\n",
"            parsed_output = json.loads(\n",
"                full_message.message.function_call.arguments\n",
"            )\n",
"            print(\"Getting search results\")\n",
"            results = get_articles(parsed_output[\"query\"])\n",
"        except Exception as e:\n",
"            # Print the raw arguments; parsed_output may not exist if json.loads failed\n",
"            print(full_message.message.function_call.arguments)\n",
"            print(\"Function execution failed\")\n",
"            print(f\"Error message: {e}\")\n",
"        messages.append(\n",
"            {\n",
"                \"role\": \"function\",\n",
"                \"name\": full_message.message.function_call.name,\n",
"                \"content\": str(results),\n",
"            }\n",
"        )\n",
"        try:\n",
"            print(\"Got search results, summarizing content\")\n",
"            response = chat_completion_request(messages)\n",
"            return response\n",
"        except Exception as e:\n",
"            print(type(e))\n",
"            raise Exception(\"Function chat request failed\")\n",
"\n",
"    elif (\n",
"        full_message.message.function_call.name == \"read_article_and_summarize\"\n",
"    ):\n",
"        parsed_output = json.loads(\n",
"            full_message.message.function_call.arguments\n",
"        )\n",
"        print(\"Finding and reading paper\")\n",
"        summary = summarize_text(parsed_output[\"query\"])\n",
"        return summary\n",
"\n",
"    else:\n",
"        raise Exception(\"Function does not exist and cannot be called\")\n"
]
},
|
|
{
|
|
"attachments": {},
|
|
"cell_type": "markdown",
|
|
"id": "dd3e7868",
|
|
"metadata": {},
|
|
"source": [
|
|
"## arXiv conversation\n",
|
|
"\n",
|
|
"Let's put this all together by testing our functions out in conversation."
|
|
]
|
|
},
|
|
{
"cell_type": "code",
"execution_count": 15,
"id": "c39a1d80",
"metadata": {},
"outputs": [],
"source": [
"# Start with a system message\n",
"paper_system_message = \"\"\"You are arXivGPT, a helpful assistant that pulls academic papers to answer user questions.\n",
"You summarize the papers clearly so the user can decide which to read to answer their question.\n",
"You always provide the article_url and title so the user can understand the name of the paper and click through to access it.\n",
"Begin!\"\"\"\n",
"paper_conversation = Conversation()\n",
"paper_conversation.add_message(\"system\", paper_system_message)\n"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "253fd0f7",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Function generation requested, calling function\n",
"Getting search results\n",
"Got search results, summarizing content\n"
]
},
{
"data": {
"text/markdown": [
"Here are some recent papers that discuss Proximal Policy Optimization (PPO) in reinforcement learning, explaining its mechanics and various enhancements:\n",
"\n",
"1. **[Proximal Policy Optimization and its Dynamic Version for Sequence Generation](http://arxiv.org/abs/1808.07982v1)**  \n",
"   - *Summary:* This paper applies PPO to sequence generation tasks, demonstrating that it outperforms traditional policy gradient methods in terms of stability and performance. It introduces a dynamic version of PPO for these tasks.\n",
"   - [PDF](http://arxiv.org/pdf/1808.07982v1)\n",
"\n",
"2. **[CIM-PPO: Proximal Policy Optimization with Liu-Correntropy Induced Metric](http://arxiv.org/abs/2110.10522v3)**  \n",
"   - *Summary:* This work investigates the asymmetry in KL divergence in PPO-KL and proposes PPO-CIM as an enhanced version with lower computation costs and improved policy updates, validated through experiments on continuous-action tasks.\n",
"   - [PDF](http://arxiv.org/pdf/2110.10522v3)\n",
"\n",
"3. **[A2C is a special case of PPO](http://arxiv.org/abs/2205.09123v1)**  \n",
"   - *Summary:* This paper shows that A2C can be viewed as a special case of PPO, providing theoretical justifications and empirical evidence demonstrating their equivalence under controlled conditions.\n",
"   - [PDF](http://arxiv.org/pdf/2205.09123v1)\n",
"\n",
"4. **[Proximal Policy Optimization via Enhanced Exploration Efficiency](http://arxiv.org/abs/2011.05525v1)**  \n",
"   - *Summary:* This paper enhances the PPO algorithm by improving exploration strategies, proposing IEM-PPO, which shows better sample efficiency and rewards than standard methods in complex environments.\n",
"   - [PDF](http://arxiv.org/pdf/2011.05525v1)\n",
"\n",
"5. **[ReMax: A Simple, Effective, and Efficient Reinforcement Learning Method for Aligning Large Language Models](http://arxiv.org/abs/2310.10505v4)**  \n",
"   - *Summary:* The ReMax method is proposed as an alternative to PPO for training large language models, reducing hyper-parameter tuning complexities and enhancing training efficiency.\n",
"   - [PDF](http://arxiv.org/pdf/2310.10505v4)\n",
"\n",
"6. **[Reward Scale Robustness for Proximal Policy Optimization via DreamerV3 Tricks](http://arxiv.org/abs/2310.17805v1)**  \n",
"   - *Summary:* This work examines the applicability of DreamerV3's tricks to PPO, revealing mixed outcomes and providing insights into the clipping mechanism in PPO's performance.\n",
"   - [PDF](http://arxiv.org/pdf/2310.17805v1)\n",
"\n",
"7. **[Neural PPO-Clip Attains Global Optimality: A Hinge Loss Perspective](http://arxiv.org/abs/2110.13799v4)**  \n",
"   - *Summary:* This paper establishes a theoretical grounding for PPO-Clip and introduces new interpretive frameworks for its mechanics, showing improved convergence properties.\n",
"   - [PDF](http://arxiv.org/pdf/2110.13799v4)\n",
"\n",
"8. **[Colored Noise in PPO: Improved Exploration and Performance through Correlated Action Sampling](http://dx.doi.org/10.1609/aaai.v38i11.29139)**  \n",
"   - *Summary:* This study proposes a variant of PPO using correlated noise for improved exploration, demonstrating enhanced performance over traditional approaches.\n",
"   - [PDF](http://arxiv.org/abs/2312.11091v2)\n",
"\n",
"9. **[A dynamical clipping approach with task feedback for Proximal Policy Optimization](http://arxiv.org/abs/2312.07624v3)**  \n",
"   - *Summary:* The paper presents Pb-PPO, which dynamically adjusts the clipping bounds in PPO to enhance returns, showing improved performance across various tasks.\n",
"   - [PDF](http://arxiv.org/pdf/2312.07624v3)\n",
"\n",
"10. **[PPO-UE: Proximal Policy Optimization via Uncertainty-Aware Exploration](http://arxiv.org/abs/2212.06343v1)**  \n",
"    - *Summary:* Introducing PPO-UE, which incorporates uncertainty-aware exploration, this paper shows improvements in convergence speed and performance compared to standard PPO.\n",
"    - [PDF](http://arxiv.org/pdf/2212.06343v1)\n",
"\n",
"These papers provide a comprehensive view of the developments and enhancements in PPO and how it operates within the reinforcement learning framework. You can click on the titles to access the full articles."
],
"text/plain": [
"<IPython.core.display.Markdown object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Add a user message\n",
"paper_conversation.add_message(\"user\", \"Hi, how does PPO reinforcement learning work?\")\n",
"chat_response = chat_completion_with_function_execution(\n",
"    paper_conversation.conversation_history, functions=arxiv_functions\n",
")\n",
"assistant_message = chat_response.choices[0].message.content\n",
"paper_conversation.add_message(\"assistant\", assistant_message)\n",
"display(Markdown(assistant_message))\n"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "3ca3e18a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Function generation requested, calling function\n",
"Finding and reading paper\n",
"Existing papers found... Articles: 20\n",
"Chunking text from paper\n",
"Summarizing each chunk of text\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|███████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:04<00:00, 1.21s/it]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Summarizing into overall summary\n"
]
},
{
"data": {
"text/markdown": [
"### Core Argument\n",
"- The paper argues for the adoption of Proximal Policy Optimization (PPO) and its dynamic variant (PPO-dynamic) as superior methods for sequence generation tasks, particularly in the context of chit-chat chatbots, compared to traditional policy gradient methods.\n",
"- It highlights the instability and suboptimal performance of traditional policy gradient methods, such as REINFORCE, and presents PPO as a more stable and efficient alternative.\n",
"\n",
"### Evidence\n",
"- **Challenges with Policy Gradient**: Traditional methods lead to unstable training and poor performance due to large updates and similar action tendencies, especially in non-differentiable evaluation metrics like BLEU scores.\n",
"- **PPO Advantages**: PPO regularizes policy updates, enhancing stability and coherence in chatbot responses.\n",
"- **Dynamic PPO Approach**: PPO-dynamic introduces dynamic adjustments to the KL-divergence bounds, allowing for more flexible and effective training.\n",
"- **Experimental Validation**: Experiments on synthetic tasks and real-world chit-chat scenarios demonstrate that PPO and PPO-dynamic outperform REINFORCE and other algorithms (like MIXER and SeqGAN) in terms of stability and performance metrics, including BLEU-2 scores.\n",
"- **Results**: PPO-dynamic showed significant improvements in precision on counting tasks and achieved the highest BLEU-2 score for chatbot responses, indicating better performance in generating diverse and accurate outputs.\n",
"\n",
"### Conclusions\n",
"- The paper concludes that replacing traditional policy gradient methods with PPO, particularly the dynamic version, leads to more stable training and faster convergence in sequence generation tasks.\n",
"- The proposed PPO-dynamic method enhances the training process by dynamically adjusting constraints, resulting in improved performance and efficiency in generating human-like conversational agents.\n",
"- Future research directions are suggested to further explore the potential of PPO and its adaptations in natural language processing applications."
],
"text/plain": [
"<IPython.core.display.Markdown object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Add another user message to induce our system to use the second tool\n",
"paper_conversation.add_message(\n",
"    \"user\",\n",
"    \"Can you read the PPO sequence generation paper for me and give me a summary\",\n",
")\n",
"updated_response = chat_completion_with_function_execution(\n",
"    paper_conversation.conversation_history, functions=arxiv_functions\n",
")\n",
"display(Markdown(updated_response.choices[0].message.content))\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}