openai-cookbook/examples/How_to_call_functions_for_knowledge_retrieval.ipynb
2025-02-03 13:43:26 -08:00

{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"id": "3e67f200",
"metadata": {},
"source": [
"# How to use functions with a knowledge base\n",
"\n",
"This notebook builds on the concepts in the [argument generation](How_to_call_functions_with_chat_models.ipynb) notebook, by creating an agent with access to a knowledge base and two functions that it can call based on the user requirement.\n",
"\n",
"We'll create an agent that uses data from arXiv to answer questions about academic subjects. It has two functions at its disposal:\n",
"- **get_articles**: A function that gets arXiv articles on a subject and summarizes them for the user with links.\n",
"- **read_article_and_summarize**: This function takes one of the previously searched articles, reads it in its entirety and summarizes the core argument, evidence and conclusions.\n",
"\n",
"This will get you comfortable with a multi-function workflow that can choose from multiple services, and where some of the data from the first function is persisted to be used by the second.\n",
"\n",
"## Walkthrough\n",
"\n",
"This cookbook takes you through the following workflow:\n",
"\n",
"- **Search utilities:** Creating the two functions that access arXiv for answers.\n",
"- **Configure Agent:** Building up the Agent behaviour that will assess the need for a function and, if one is required, call that function and present results back to the agent.\n",
"- **arXiv conversation:** Put all of this together in live conversation.\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "80e71f33",
"metadata": {
"pycharm": {
"is_executing": true
}
},
"outputs": [],
"source": [
"!pip install scipy --quiet\n",
"!pip install tenacity --quiet\n",
"!pip install tiktoken==0.3.3 --quiet\n",
"!pip install termcolor --quiet\n",
"!pip install openai --quiet\n",
"!pip install arxiv --quiet\n",
"!pip install pandas --quiet\n",
"!pip install PyPDF2 --quiet\n",
"!pip install tqdm --quiet"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "dab872c5",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import arxiv\n",
"import ast\n",
"import concurrent\n",
"import json\n",
"import os\n",
"import pandas as pd\n",
"import tiktoken\n",
"from csv import writer\n",
"from IPython.display import display, Markdown, Latex\n",
"from openai import OpenAI\n",
"from PyPDF2 import PdfReader\n",
"from scipy import spatial\n",
"from tenacity import retry, wait_random_exponential, stop_after_attempt\n",
"from tqdm import tqdm\n",
"from termcolor import colored\n",
"\n",
"GPT_MODEL = \"gpt-4o-mini\"\n",
"EMBEDDING_MODEL = \"text-embedding-ada-002\"\n",
"client = OpenAI()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "f2e47962",
"metadata": {},
"source": [
"## Search utilities\n",
"\n",
"We'll first set up some utilities that will underpin our two functions.\n",
"\n",
"Downloaded papers will be stored in a directory (we use ```./data/papers``` here). We create a file ```arxiv_library.csv``` to store the embeddings and details for downloaded papers to retrieve against using ```summarize_text```."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "2de5d32d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Directory './data/papers' already exists.\n"
]
}
],
"source": [
"directory = './data/papers'\n",
"\n",
"# Check if the directory already exists\n",
"if not os.path.exists(directory):\n",
" # If the directory doesn't exist, create it and any necessary intermediate directories\n",
" os.makedirs(directory)\n",
" print(f\"Directory '{directory}' created successfully.\")\n",
"else:\n",
" # If the directory already exists, print a message indicating it\n",
" print(f\"Directory '{directory}' already exists.\")"
]
},
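{
"attachments": {},
"cell_type": "markdown",
"id": "5a1e0c01",
"metadata": {},
"source": [
"The check-then-create logic above can also be collapsed into one call, since ```os.makedirs``` accepts ```exist_ok=True``` and then skips the error for an existing directory. A minimal sketch, assuming a temporary directory so that the example does not touch ```./data```:\n",
"\n",
"```python\n",
"import os\n",
"import tempfile\n",
"\n",
"with tempfile.TemporaryDirectory() as tmp:\n",
"    papers_dir = os.path.join(tmp, 'data', 'papers')\n",
"    # Creates 'data' and 'papers' in one call\n",
"    os.makedirs(papers_dir, exist_ok=True)\n",
"    # A second call is a no-op instead of raising FileExistsError\n",
"    os.makedirs(papers_dir, exist_ok=True)\n",
"    created = os.path.isdir(papers_dir)\n",
"```"
]
},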
{
"cell_type": "code",
"execution_count": 4,
"id": "ae5cb7a1",
"metadata": {},
"outputs": [],
"source": [
"# Set a directory to store downloaded papers\n",
"data_dir = os.path.join(os.curdir, \"data\", \"papers\")\n",
"paper_dir_filepath = \"./data/papers/arxiv_library.csv\"\n",
"\n",
"# Generate a blank dataframe where we can store downloaded files\n",
"df = pd.DataFrame(list())\n",
"df.to_csv(paper_dir_filepath)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "57217b9d",
"metadata": {},
"outputs": [],
"source": [
"@retry(wait=wait_random_exponential(min=1, max=40), stop=stop_after_attempt(3))\n",
"def embedding_request(text):\n",
" response = client.embeddings.create(input=text, model=EMBEDDING_MODEL)\n",
" return response\n",
"\n",
"\n",
"@retry(wait=wait_random_exponential(min=1, max=40), stop=stop_after_attempt(3))\n",
"def get_articles(query, library=paper_dir_filepath, top_k=10):\n",
" \"\"\"This function gets the top_k articles based on a user's query, sorted by relevance.\n",
" It also downloads the files and stores them in arxiv_library.csv to be retrieved by the read_article_and_summarize.\n",
" \"\"\"\n",
" client = arxiv.Client()\n",
" search = arxiv.Search(\n",
" query = query,\n",
" max_results = top_k\n",
" )\n",
" result_list = []\n",
" for result in client.results(search):\n",
" result_dict = {}\n",
" result_dict.update({\"title\": result.title})\n",
" result_dict.update({\"summary\": result.summary})\n",
"\n",
" # Taking the first url provided\n",
" result_dict.update({\"article_url\": [x.href for x in result.links][0]})\n",
" result_dict.update({\"pdf_url\": [x.href for x in result.links][1]})\n",
" result_list.append(result_dict)\n",
"\n",
" # Store references in library file\n",
" response = embedding_request(text=result.title)\n",
" file_reference = [\n",
" result.title,\n",
" result.download_pdf(data_dir),\n",
" response.data[0].embedding,\n",
" ]\n",
"\n",
" # Write to file\n",
" with open(library, \"a\") as f_object:\n",
" writer_object = writer(f_object)\n",
" writer_object.writerow(file_reference)\n",
" f_object.close()\n",
" return result_list\n"
]
},
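{
"attachments": {},
"cell_type": "markdown",
"id": "5a1e0c02",
"metadata": {},
"source": [
"Each row that ```get_articles``` appends to ```arxiv_library.csv``` holds a title, the local PDF path and the title's embedding. ```csv.writer``` stringifies the embedding list with ```str()```, which is why ```summarize_text``` later parses it back with ```ast.literal_eval```. A minimal sketch of that round trip, assuming an in-memory buffer and a made-up three-value embedding (real ```text-embedding-ada-002``` vectors have 1536 dimensions):\n",
"\n",
"```python\n",
"import ast\n",
"import csv\n",
"import io\n",
"\n",
"# Hypothetical library row: title, local PDF path, embedding (shortened for illustration)\n",
"row = ['Example paper', './data/papers/example.pdf', [0.1, -0.2, 0.3]]\n",
"\n",
"buffer = io.StringIO()\n",
"csv.writer(buffer).writerow(row)  # the list is stringified with str() and quoted\n",
"\n",
"buffer.seek(0)\n",
"title, filepath, embedding_text = next(csv.reader(buffer))\n",
"embedding = ast.literal_eval(embedding_text)  # parse the string back into a list of floats\n",
"```"
]
},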
{
"cell_type": "code",
"execution_count": 6,
"id": "dda02bdb",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'title': 'Proximal Policy Optimization and its Dynamic Version for Sequence Generation',\n",
" 'summary': 'In sequence generation task, many works use policy gradient for model\\noptimization to tackle the intractable backpropagation issue when maximizing\\nthe non-differentiable evaluation metrics or fooling the discriminator in\\nadversarial learning. In this paper, we replace policy gradient with proximal\\npolicy optimization (PPO), which is a proved more efficient reinforcement\\nlearning algorithm, and propose a dynamic approach for PPO (PPO-dynamic). We\\ndemonstrate the efficacy of PPO and PPO-dynamic on conditional sequence\\ngeneration tasks including synthetic experiment and chit-chat chatbot. The\\nresults show that PPO and PPO-dynamic can beat policy gradient by stability and\\nperformance.',\n",
" 'article_url': 'http://arxiv.org/abs/1808.07982v1',\n",
" 'pdf_url': 'http://arxiv.org/pdf/1808.07982v1'}"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Test that the search is working\n",
"result_output = get_articles(\"ppo reinforcement learning\")\n",
"result_output[0]\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "11675627",
"metadata": {},
"outputs": [],
"source": [
"def strings_ranked_by_relatedness(\n",
" query: str,\n",
" df: pd.DataFrame,\n",
" relatedness_fn=lambda x, y: 1 - spatial.distance.cosine(x, y),\n",
" top_n: int = 100,\n",
") -> list[str]:\n",
" \"\"\"Returns a list of strings and relatednesses, sorted from most related to least.\"\"\"\n",
" query_embedding_response = embedding_request(query)\n",
" query_embedding = query_embedding_response.data[0].embedding\n",
" strings_and_relatednesses = [\n",
" (row[\"filepath\"], relatedness_fn(query_embedding, row[\"embedding\"]))\n",
" for i, row in df.iterrows()\n",
" ]\n",
" strings_and_relatednesses.sort(key=lambda x: x[1], reverse=True)\n",
" strings, relatednesses = zip(*strings_and_relatednesses)\n",
" return strings[:top_n]\n"
]
},
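{
"attachments": {},
"cell_type": "markdown",
"id": "5a1e0c03",
"metadata": {},
"source": [
"```strings_ranked_by_relatedness``` scores every stored paper by the cosine similarity between its title embedding and the query embedding, then sorts in descending order. The sketch below reproduces that ranking without any API calls. It assumes hypothetical three-dimensional vectors and uses a pure-Python cosine similarity, which should agree with ```1 - spatial.distance.cosine(x, y)``` for non-zero vectors:\n",
"\n",
"```python\n",
"import math\n",
"\n",
"\n",
"def cosine_similarity(x, y):\n",
"    # Cosine similarity of two equal-length, non-zero vectors\n",
"    dot = sum(a * b for a, b in zip(x, y))\n",
"    return dot / (math.sqrt(sum(a * a for a in x)) * math.sqrt(sum(b * b for b in y)))\n",
"\n",
"\n",
"def rank_filepaths(query_embedding, library, top_n=2):\n",
"    # Score every stored embedding, then keep the top_n most related filepaths\n",
"    scored = [(path, cosine_similarity(query_embedding, emb)) for path, emb in library.items()]\n",
"    scored.sort(key=lambda pair: pair[1], reverse=True)\n",
"    return [path for path, _ in scored[:top_n]]\n",
"\n",
"\n",
"# Hypothetical title embeddings keyed by filepath\n",
"library = {\n",
"    'ppo.pdf': [0.9, 0.1, 0.0],\n",
"    'gan.pdf': [0.0, 1.0, 0.0],\n",
"    'a2c.pdf': [0.7, 0.3, 0.1],\n",
"}\n",
"top = rank_filepaths([1.0, 0.0, 0.0], library)\n",
"```"
]
},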
{
"cell_type": "code",
"execution_count": 8,
"id": "7211df2c",
"metadata": {},
"outputs": [],
"source": [
"def read_pdf(filepath):\n",
" \"\"\"Takes a filepath to a PDF and returns a string of the PDF's contents\"\"\"\n",
" # creating a pdf reader object\n",
" reader = PdfReader(filepath)\n",
" pdf_text = \"\"\n",
" page_number = 0\n",
" for page in reader.pages:\n",
" page_number += 1\n",
" pdf_text += page.extract_text() + f\"\\nPage Number: {page_number}\"\n",
" return pdf_text\n",
"\n",
"\n",
"# Split a text into smaller chunks of size n, preferably ending at the end of a sentence\n",
"def create_chunks(text, n, tokenizer):\n",
" \"\"\"Returns successive n-sized chunks from provided text.\"\"\"\n",
" tokens = tokenizer.encode(text)\n",
" i = 0\n",
" while i < len(tokens):\n",
" # Find the nearest end of sentence within a range of 0.5 * n and 1.5 * n tokens\n",
" j = min(i + int(1.5 * n), len(tokens))\n",
" while j > i + int(0.5 * n):\n",
" # Decode the tokens and check for full stop or newline\n",
" chunk = tokenizer.decode(tokens[i:j])\n",
" if chunk.endswith(\".\") or chunk.endswith(\"\\n\"):\n",
" break\n",
" j -= 1\n",
" # If no end of sentence found, use n tokens as the chunk size\n",
" if j == i + int(0.5 * n):\n",
" j = min(i + n, len(tokens))\n",
" yield tokens[i:j]\n",
" i = j\n",
"\n",
"\n",
"def extract_chunk(content, template_prompt):\n",
" \"\"\"This function applies a prompt to some input content. In this case it returns a summarized chunk of text\"\"\"\n",
" prompt = template_prompt + content\n",
" response = client.chat.completions.create(\n",
" model=GPT_MODEL, messages=[{\"role\": \"user\", \"content\": prompt}], temperature=0\n",
" )\n",
" return response.choices[0].message.content\n",
"\n",
"\n",
"def summarize_text(query):\n",
" \"\"\"This function does the following:\n",
" - Reads in the arxiv_library.csv file in including the embeddings\n",
" - Finds the closest file to the user's query\n",
" - Scrapes the text out of the file and chunks it\n",
" - Summarizes each chunk in parallel\n",
" - Does one final summary and returns this to the user\"\"\"\n",
"\n",
" # A prompt to dictate how the recursive summarizations should approach the input paper\n",
" summary_prompt = \"\"\"Summarize this text from an academic paper. Extract any key points with reasoning.\\n\\nContent:\"\"\"\n",
"\n",
" # If the library is empty (no searches have been performed yet), we perform one and download the results\n",
" library_df = pd.read_csv(paper_dir_filepath).reset_index()\n",
" if len(library_df) == 0:\n",
" print(\"No papers searched yet, downloading first.\")\n",
" get_articles(query)\n",
" print(\"Papers downloaded, continuing\")\n",
" library_df = pd.read_csv(paper_dir_filepath).reset_index()\n",
" else:\n",
" print(\"Existing papers found... Articles:\", len(library_df))\n",
" library_df.columns = [\"title\", \"filepath\", \"embedding\"]\n",
" library_df[\"embedding\"] = library_df[\"embedding\"].apply(ast.literal_eval)\n",
" strings = strings_ranked_by_relatedness(query, library_df, top_n=1)\n",
" print(\"Chunking text from paper\")\n",
" pdf_text = read_pdf(strings[0])\n",
"\n",
" # Initialise tokenizer\n",
" tokenizer = tiktoken.get_encoding(\"cl100k_base\")\n",
" results = \"\"\n",
"\n",
" # Chunk up the document into 1500 token chunks\n",
" chunks = create_chunks(pdf_text, 1500, tokenizer)\n",
" text_chunks = [tokenizer.decode(chunk) for chunk in chunks]\n",
" print(\"Summarizing each chunk of text\")\n",
"\n",
" # Parallel process the summaries\n",
" with concurrent.futures.ThreadPoolExecutor(\n",
" max_workers=len(text_chunks)\n",
" ) as executor:\n",
" futures = [\n",
" executor.submit(extract_chunk, chunk, summary_prompt)\n",
" for chunk in text_chunks\n",
" ]\n",
" with tqdm(total=len(text_chunks)) as pbar:\n",
" for _ in concurrent.futures.as_completed(futures):\n",
" pbar.update(1)\n",
" for future in futures:\n",
" data = future.result()\n",
" results += data\n",
"\n",
" # Final summary\n",
" print(\"Summarizing into overall summary\")\n",
" response = client.chat.completions.create(\n",
" model=GPT_MODEL,\n",
" messages=[\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": f\"\"\"Write a summary collated from this collection of key points extracted from an academic paper.\n",
" The summary should highlight the core argument, conclusions and evidence, and answer the user's query.\n",
" User query: {query}\n",
" The summary should be structured in bulleted lists following the headings Core Argument, Evidence, and Conclusions.\n",
" Key points:\\n{results}\\nSummary:\\n\"\"\",\n",
" }\n",
" ],\n",
" temperature=0,\n",
" )\n",
" return response\n"
]
},
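{
"attachments": {},
"cell_type": "markdown",
"id": "5a1e0c04",
"metadata": {},
"source": [
"```create_chunks``` works on token ids but checks sentence boundaries on decoded text. Starting 1.5 * n tokens ahead, it steps back until the chunk ends with a full stop or newline. If no chunk longer than 0.5 * n tokens ends on such a boundary, it falls back to exactly n tokens. To see the boundaries without ```tiktoken```, the sketch below copies the same logic and pairs it with a hypothetical character-level tokenizer (one token per character). That tokenizer is an assumption made purely for illustration:\n",
"\n",
"```python\n",
"class CharTokenizer:\n",
"    # Hypothetical tokenizer for illustration: one token per character\n",
"\n",
"    def encode(self, text):\n",
"        return list(text)\n",
"\n",
"    def decode(self, tokens):\n",
"        return ''.join(tokens)\n",
"\n",
"\n",
"def create_chunks(text, n, tokenizer):\n",
"    # Same logic as create_chunks above: roughly n-token chunks ending at sentence boundaries\n",
"    tokens = tokenizer.encode(text)\n",
"    i = 0\n",
"    while i < len(tokens):\n",
"        j = min(i + int(1.5 * n), len(tokens))\n",
"        while j > i + int(0.5 * n):\n",
"            chunk = tokenizer.decode(tokens[i:j])\n",
"            if chunk.endswith('.') or chunk.endswith('\\n'):\n",
"                break\n",
"            j -= 1\n",
"        if j == i + int(0.5 * n):\n",
"            j = min(i + n, len(tokens))\n",
"        yield tokens[i:j]\n",
"        i = j\n",
"\n",
"\n",
"tokenizer = CharTokenizer()\n",
"chunks = [tokenizer.decode(c) for c in create_chunks('One two. Three four. Five.', 10, tokenizer)]\n",
"```\n",
"\n",
"With n = 10, each chunk here ends at a full stop rather than at a fixed 10-character offset."
]
},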
{
"cell_type": "code",
"execution_count": 9,
"id": "898b94d4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Existing papers found... Articles: 10\n",
"Chunking text from paper\n",
"Summarizing each chunk of text\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|███████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:05<00:00, 1.40s/it]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Summarizing into overall summary\n"
]
}
],
"source": [
"# Test the summarize_text function works\n",
"chat_test_response = summarize_text(\"PPO reinforcement learning sequence generation\")\n"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "c715f60d",
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
"### Core Argument\n",
"- The paper argues that Proximal Policy Optimization (PPO) and its dynamic variant (PPO-dynamic) significantly improve sequence generation tasks, particularly for chit-chat chatbots, by addressing the instability and suboptimal performance associated with traditional policy gradient methods.\n",
"\n",
"### Evidence\n",
"- **Challenges with Traditional Methods**: Traditional policy gradient methods, like REINFORCE, suffer from unstable training and poor performance due to large updates and similar action tendencies, especially in non-differentiable evaluation contexts (e.g., BLEU scores).\n",
"- **PPO Advantages**: PPO regularizes policy updates, enhancing training stability and enabling the generation of coherent and diverse chatbot responses.\n",
"- **Dynamic PPO Approach**: PPO-dynamic introduces adaptive constraints on KL-divergence, allowing for dynamic adjustments based on action probabilities, which leads to improved training performance.\n",
"- **Experimental Validation**: The authors conducted experiments on synthetic counting tasks and real-world chit-chat scenarios, demonstrating that PPO and PPO-dynamic outperform traditional methods like REINFORCE and SeqGAN in terms of stability and performance metrics (e.g., BLEU-2 scores).\n",
"- **Results**: PPO-dynamic showed faster convergence and higher precision in the counting task, and it achieved the best performance in the chit-chat task, indicating its effectiveness in generating diverse and contextually appropriate responses.\n",
"\n",
"### Conclusions\n",
"- The introduction of PPO and PPO-dynamic enhances the training stability and output diversity in sequence generation tasks, making them more suitable for applications like chatbots.\n",
"- The dynamic variant of PPO not only improves performance but also accelerates convergence, addressing the limitations of traditional policy gradient methods and providing a robust framework for reinforcement learning in sequence generation."
],
"text/plain": [
"<IPython.core.display.Markdown object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"display(Markdown(chat_test_response.choices[0].message.content))\n"
]
},
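{
"attachments": {},
"cell_type": "markdown",
"id": "5a1e0c05",
"metadata": {},
"source": [
"```summarize_text``` works as a two-step map-reduce. First, each chunk is summarized in parallel. Then the concatenated key points are summarized once more. ```as_completed``` yields futures as they finish and is used only for the progress bar. Iterating over ```futures``` in submission order keeps the key points in the paper's order. A minimal sketch, assuming a hypothetical stand-in for ```extract_chunk``` that sleeps instead of calling the API:\n",
"\n",
"```python\n",
"import concurrent.futures\n",
"import time\n",
"\n",
"\n",
"def fake_extract_chunk(chunk, delay):\n",
"    # Hypothetical stand-in for extract_chunk: sleeps instead of calling the API\n",
"    time.sleep(delay)\n",
"    return f'summary of {chunk}'\n",
"\n",
"\n",
"chunks = ['part 1', 'part 2', 'part 3']\n",
"delays = [0.03, 0.02, 0.01]  # later chunks tend to finish first\n",
"\n",
"with concurrent.futures.ThreadPoolExecutor(max_workers=len(chunks)) as executor:\n",
"    futures = [executor.submit(fake_extract_chunk, c, d) for c, d in zip(chunks, delays)]\n",
"    # Completion order is not guaranteed, so as_completed is only used to count progress\n",
"    finished = sum(1 for _ in concurrent.futures.as_completed(futures))\n",
"    # Reading futures in submission order restores document order\n",
"    ordered = [future.result() for future in futures]\n",
"\n",
"combined = '\\n'.join(ordered)\n",
"```"
]
},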
{
"attachments": {},
"cell_type": "markdown",
"id": "dab07e98",
"metadata": {},
"source": [
"## Configure Agent\n",
"\n",
"We'll create our agent in this step, including a ```Conversation``` class to support multiple turns with the API, and some Python functions to enable interaction between the ```ChatCompletion``` API and our knowledge base functions."
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "77a6fb4f",
"metadata": {},
"outputs": [],
"source": [
"@retry(wait=wait_random_exponential(min=1, max=40), stop=stop_after_attempt(3))\n",
"def chat_completion_request(messages, functions=None, model=GPT_MODEL):\n",
" try:\n",
" response = client.chat.completions.create(\n",
" model=model,\n",
" messages=messages,\n",
" functions=functions,\n",
" )\n",
" return response\n",
" except Exception as e:\n",
" print(\"Unable to generate ChatCompletion response\")\n",
" print(f\"Exception: {e}\")\n",
" return e\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "73f7672d",
"metadata": {},
"outputs": [],
"source": [
"class Conversation:\n",
" def __init__(self):\n",
" self.conversation_history = []\n",
"\n",
" def add_message(self, role, content):\n",
" message = {\"role\": role, \"content\": content}\n",
" self.conversation_history.append(message)\n",
"\n",
" def display_conversation(self, detailed=False):\n",
" role_to_color = {\n",
" \"system\": \"red\",\n",
" \"user\": \"green\",\n",
" \"assistant\": \"blue\",\n",
" \"function\": \"magenta\",\n",
" }\n",
" for message in self.conversation_history:\n",
" print(\n",
" colored(\n",
" f\"{message['role']}: {message['content']}\\n\\n\",\n",
" role_to_color[message[\"role\"]],\n",
" )\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "978b7877",
"metadata": {},
"outputs": [],
"source": [
"# Initiate our get_articles and read_article_and_summarize functions\n",
"arxiv_functions = [\n",
" {\n",
" \"name\": \"get_articles\",\n",
" \"description\": \"\"\"Use this function to get academic papers from arXiv to answer user questions.\"\"\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"query\": {\n",
" \"type\": \"string\",\n",
" \"description\": f\"\"\"\n",
" User query in JSON. Responses should be summarized and should include the article URL reference\n",
" \"\"\",\n",
" }\n",
" },\n",
" \"required\": [\"query\"],\n",
" },\n",
" },\n",
" {\n",
" \"name\": \"read_article_and_summarize\",\n",
" \"description\": \"\"\"Use this function to read whole papers and provide a summary for users.\n",
" You should NEVER call this function before get_articles has been called in the conversation.\"\"\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"query\": {\n",
" \"type\": \"string\",\n",
" \"description\": f\"\"\"\n",
" Description of the article in plain text based on the user's query\n",
" \"\"\",\n",
" }\n",
" },\n",
" \"required\": [\"query\"],\n",
" },\n",
" }\n",
"]\n"
]
},
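{
"attachments": {},
"cell_type": "markdown",
"id": "5a1e0c06",
"metadata": {},
"source": [
"This notebook uses the legacy ```functions``` parameter and checks for a ```function_call``` finish reason. The Chat Completions API has since deprecated ```functions```/```function_call``` in favour of ```tools```/```tool_calls```, where each schema is wrapped as ```{'type': 'function', 'function': schema}```. A minimal sketch of that conversion, assuming a shortened copy of the ```get_articles``` schema and a hypothetical helper named ```to_tools```:\n",
"\n",
"```python\n",
"# Shortened, hypothetical copy of one schema from arxiv_functions\n",
"legacy_functions = [\n",
"    {\n",
"        'name': 'get_articles',\n",
"        'description': 'Use this function to get academic papers from arXiv to answer user questions.',\n",
"        'parameters': {\n",
"            'type': 'object',\n",
"            'properties': {'query': {'type': 'string'}},\n",
"            'required': ['query'],\n",
"        },\n",
"    }\n",
"]\n",
"\n",
"\n",
"def to_tools(functions):\n",
"    # Wrap each legacy function schema in the tools format\n",
"    return [{'type': 'function', 'function': schema} for schema in functions]\n",
"\n",
"\n",
"arxiv_tools = to_tools(legacy_functions)\n",
"```\n",
"\n",
"The resulting list would be passed as ```tools=arxiv_tools```. The model's requests would then be read from ```message.tool_calls``` when ```finish_reason``` is ```tool_calls```."
]
},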
{
"cell_type": "code",
"execution_count": 14,
"id": "0c88ae15",
"metadata": {},
"outputs": [],
"source": [
"def chat_completion_with_function_execution(messages, functions=[None]):\n",
" \"\"\"This function makes a ChatCompletion API call with the option of adding functions\"\"\"\n",
" response = chat_completion_request(messages, functions)\n",
" full_message = response.choices[0]\n",
" if full_message.finish_reason == \"function_call\":\n",
" print(f\"Function generation requested, calling function\")\n",
" return call_arxiv_function(messages, full_message)\n",
" else:\n",
" print(f\"Function not required, responding to user\")\n",
" return response\n",
"\n",
"\n",
"def call_arxiv_function(messages, full_message):\n",
" \"\"\"Function calling function which executes function calls when the model believes it is necessary.\n",
" Currently extended by adding clauses to this if statement.\"\"\"\n",
"\n",
" if full_message.message.function_call.name == \"get_articles\":\n",
" try:\n",
" parsed_output = json.loads(\n",
" full_message.message.function_call.arguments\n",
" )\n",
" print(\"Getting search results\")\n",
" results = get_articles(parsed_output[\"query\"])\n",
" except Exception as e:\n",
" print(parsed_output)\n",
" print(f\"Function execution failed\")\n",
" print(f\"Error message: {e}\")\n",
" messages.append(\n",
" {\n",
" \"role\": \"function\",\n",
" \"name\": full_message.message.function_call.name,\n",
" \"content\": str(results),\n",
" }\n",
" )\n",
" try:\n",
" print(\"Got search results, summarizing content\")\n",
" response = chat_completion_request(messages)\n",
" return response\n",
" except Exception as e:\n",
" print(type(e))\n",
" raise Exception(\"Function chat request failed\")\n",
"\n",
" elif (\n",
" full_message.message.function_call.name == \"read_article_and_summarize\"\n",
" ):\n",
" parsed_output = json.loads(\n",
" full_message.message.function_call.arguments\n",
" )\n",
" print(\"Finding and reading paper\")\n",
" summary = summarize_text(parsed_output[\"query\"])\n",
" return summary\n",
"\n",
" else:\n",
" raise Exception(\"Function does not exist and cannot be called\")\n"
]
},
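{
"attachments": {},
"cell_type": "markdown",
"id": "5a1e0c07",
"metadata": {},
"source": [
"```call_arxiv_function``` grows by one ```elif``` clause per function. An alternative is a dictionary that maps each function name the model may request to a Python callable, so adding a function only means adding an entry. A minimal sketch, assuming hypothetical stand-ins for ```get_articles``` and ```summarize_text``` so that it runs without network access. It uses ```SimpleNamespace``` to mimic the ```name``` and ```arguments``` fields of the model's ```function_call```:\n",
"\n",
"```python\n",
"import json\n",
"from types import SimpleNamespace\n",
"\n",
"\n",
"def fake_get_articles(query):\n",
"    # Hypothetical stand-in for get_articles (no network access)\n",
"    return [{'title': f'Result for {query}'}]\n",
"\n",
"\n",
"def fake_summarize_text(query):\n",
"    # Hypothetical stand-in for summarize_text (no network access)\n",
"    return f'Summary of {query}'\n",
"\n",
"\n",
"# Map each function name the model may request to a local Python callable\n",
"FUNCTION_REGISTRY = {\n",
"    'get_articles': fake_get_articles,\n",
"    'read_article_and_summarize': fake_summarize_text,\n",
"}\n",
"\n",
"\n",
"def dispatch(function_call):\n",
"    # Look up the requested function; unknown names are rejected explicitly\n",
"    handler = FUNCTION_REGISTRY.get(function_call.name)\n",
"    if handler is None:\n",
"        raise ValueError(f'Function {function_call.name!r} does not exist and cannot be called')\n",
"    # The model sends its arguments as a JSON string\n",
"    arguments = json.loads(function_call.arguments)\n",
"    return handler(**arguments)\n",
"\n",
"\n",
"# Mimic the name/arguments fields of full_message.message.function_call\n",
"call = SimpleNamespace(name='read_article_and_summarize', arguments=json.dumps({'query': 'PPO'}))\n",
"result = dispatch(call)\n",
"```"
]
},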
{
"attachments": {},
"cell_type": "markdown",
"id": "dd3e7868",
"metadata": {},
"source": [
"## arXiv conversation\n",
"\n",
"Let's put this all together by testing our functions out in conversation."
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "c39a1d80",
"metadata": {},
"outputs": [],
"source": [
"# Start with a system message\n",
"paper_system_message = \"\"\"You are arXivGPT, a helpful assistant pulls academic papers to answer user questions.\n",
"You summarize the papers clearly so the customer can decide which to read to answer their question.\n",
"You always provide the article_url and title so the user can understand the name of the paper and click through to access it.\n",
"Begin!\"\"\"\n",
"paper_conversation = Conversation()\n",
"paper_conversation.add_message(\"system\", paper_system_message)\n"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "253fd0f7",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Function generation requested, calling function\n",
"Getting search results\n",
"Got search results, summarizing content\n"
]
},
{
"data": {
"text/markdown": [
"Here are some recent papers that discuss Proximal Policy Optimization (PPO) in reinforcement learning, explaining its mechanics and various enhancements:\n",
"\n",
"1. **[Proximal Policy Optimization and its Dynamic Version for Sequence Generation](http://arxiv.org/abs/1808.07982v1)** \n",
" - *Summary:* This paper applies PPO to sequence generation tasks, demonstrating that it outperforms traditional policy gradient methods in terms of stability and performance. It introduces a dynamic version of PPO for these tasks.\n",
" - [PDF](http://arxiv.org/pdf/1808.07982v1)\n",
"\n",
"2. **[CIM-PPO: Proximal Policy Optimization with Liu-Correntropy Induced Metric](http://arxiv.org/abs/2110.10522v3)** \n",
" - *Summary:* This work investigates the asymmetry in KL divergence in PPO-KL and proposes PPO-CIM as an enhanced version with lower computation costs and improved policy updates, validated through experiments on continuous-action tasks.\n",
" - [PDF](http://arxiv.org/pdf/2110.10522v3)\n",
"\n",
"3. **[A2C is a special case of PPO](http://arxiv.org/abs/2205.09123v1)** \n",
" - *Summary:* This paper shows that A2C can be viewed as a special case of PPO, providing theoretical justifications and empirical evidence demonstrating their equivalence under controlled conditions.\n",
" - [PDF](http://arxiv.org/pdf/2205.09123v1)\n",
"\n",
"4. **[Proximal Policy Optimization via Enhanced Exploration Efficiency](http://arxiv.org/abs/2011.05525v1)** \n",
" - *Summary:* This paper enhances the PPO algorithm by improving exploration strategies, proposing IEM-PPO, which shows better sample efficiency and rewards than standard methods in complex environments.\n",
" - [PDF](http://arxiv.org/pdf/2011.05525v1)\n",
"\n",
"5. **[ReMax: A Simple, Effective, and Efficient Reinforcement Learning Method for Aligning Large Language Models](http://arxiv.org/abs/2310.10505v4)** \n",
" - *Summary:* The ReMax method is proposed as an alternative to PPO for training large language models, reducing hyper-parameter tuning complexities and enhancing training efficiency.\n",
" - [PDF](http://arxiv.org/pdf/2310.10505v4)\n",
"\n",
"6. **[Reward Scale Robustness for Proximal Policy Optimization via DreamerV3 Tricks](http://arxiv.org/abs/2310.17805v1)** \n",
" - *Summary:* This work examines the applicability of DreamerV3's tricks to PPO, revealing mixed outcomes and providing insights into the clipping mechanism in PPO's performance.\n",
" - [PDF](http://arxiv.org/pdf/2310.17805v1)\n",
"\n",
"7. **[Neural PPO-Clip Attains Global Optimality: A Hinge Loss Perspective](http://arxiv.org/abs/2110.13799v4)** \n",
" - *Summary:* This paper establishes a theoretical grounding for PPO-Clip and introduces new interpretive frameworks for its mechanics, showing improved convergence properties.\n",
" - [PDF](http://arxiv.org/pdf/2110.13799v4)\n",
"\n",
"8. **[Colored Noise in PPO: Improved Exploration and Performance through Correlated Action Sampling](http://dx.doi.org/10.1609/aaai.v38i11.29139)** \n",
" - *Summary:* This study proposes a variant of PPO using correlated noise for improved exploration, demonstrating enhanced performance over traditional approaches.\n",
" - [PDF](http://arxiv.org/abs/2312.11091v2)\n",
"\n",
"9. **[A dynamical clipping approach with task feedback for Proximal Policy Optimization](http://arxiv.org/abs/2312.07624v3)** \n",
" - *Summary:* The paper presents Pb-PPO, which dynamically adjusts the clipping bounds in PPO to enhance returns, showing improved performance across various tasks.\n",
" - [PDF](http://arxiv.org/pdf/2312.07624v3)\n",
"\n",
"10. **[PPO-UE: Proximal Policy Optimization via Uncertainty-Aware Exploration](http://arxiv.org/abs/2212.06343v1)** \n",
" - *Summary:* Introducing PPO-UE, which incorporates uncertainty-aware exploration, this paper shows improvements in convergence speed and performance compared to standard PPO.\n",
" - [PDF](http://arxiv.org/pdf/2212.06343v1)\n",
"\n",
"These papers provide a comprehensive view of the developments and enhancements in PPO and how it operates within the reinforcement learning framework. You can click on the titles to access the full articles."
],
"text/plain": [
"<IPython.core.display.Markdown object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Add a user message\n",
"paper_conversation.add_message(\"user\", \"Hi, how does PPO reinforcement learning work?\")\n",
"chat_response = chat_completion_with_function_execution(\n",
" paper_conversation.conversation_history, functions=arxiv_functions\n",
")\n",
"assistant_message = chat_response.choices[0].message.content\n",
"paper_conversation.add_message(\"assistant\", assistant_message)\n",
"display(Markdown(assistant_message))\n"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "3ca3e18a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Function generation requested, calling function\n",
"Finding and reading paper\n",
"Existing papers found... Articles: 20\n",
"Chunking text from paper\n",
"Summarizing each chunk of text\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|███████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:04<00:00, 1.21s/it]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Summarizing into overall summary\n"
]
},
{
"data": {
"text/markdown": [
"### Core Argument\n",
"- The paper argues for the adoption of Proximal Policy Optimization (PPO) and its dynamic variant (PPO-dynamic) as superior methods for sequence generation tasks, particularly in the context of chit-chat chatbots, compared to traditional policy gradient methods.\n",
"- It highlights the instability and suboptimal performance of traditional policy gradient methods, such as REINFORCE, and presents PPO as a more stable and efficient alternative.\n",
"\n",
"### Evidence\n",
"- **Challenges with Policy Gradient**: Traditional methods lead to unstable training and poor performance due to large updates and similar action tendencies, especially in non-differentiable evaluation metrics like BLEU scores.\n",
"- **PPO Advantages**: PPO regularizes policy updates, enhancing stability and coherence in chatbot responses.\n",
"- **Dynamic PPO Approach**: PPO-dynamic introduces dynamic adjustments to the KL-divergence bounds, allowing for more flexible and effective training.\n",
"- **Experimental Validation**: Experiments on synthetic tasks and real-world chit-chat scenarios demonstrate that PPO and PPO-dynamic outperform REINFORCE and other algorithms (like MIXER and SeqGAN) in terms of stability and performance metrics, including BLEU-2 scores.\n",
"- **Results**: PPO-dynamic showed significant improvements in precision on counting tasks and achieved the highest BLEU-2 score for chatbot responses, indicating better performance in generating diverse and accurate outputs.\n",
"\n",
"### Conclusions\n",
"- The paper concludes that replacing traditional policy gradient methods with PPO, particularly the dynamic version, leads to more stable training and faster convergence in sequence generation tasks.\n",
"- The proposed PPO-dynamic method enhances the training process by dynamically adjusting constraints, resulting in improved performance and efficiency in generating human-like conversational agents.\n",
"- Future research directions are suggested to further explore the potential of PPO and its adaptations in natural language processing applications."
],
"text/plain": [
"<IPython.core.display.Markdown object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Add another user message to induce our system to use the second tool\n",
"paper_conversation.add_message(\n",
" \"user\",\n",
" \"Can you read the PPO sequence generation paper for me and give me a summary\",\n",
")\n",
"updated_response = chat_completion_with_function_execution(\n",
" paper_conversation.conversation_history, functions=arxiv_functions\n",
")\n",
"display(Markdown(updated_response.choices[0].message.content))\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}