Update FT'ing for function calling notebook to match new python SDK (#866)

This commit is contained in:
jhills20 2023-11-27 16:46:21 -05:00 committed by GitHub
parent bcbb505a4a
commit 16e5a1c2f7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -71,7 +71,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 1,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -82,7 +82,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 2,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -90,12 +90,12 @@
"import numpy as np\n", "import numpy as np\n",
"import json\n", "import json\n",
"import os\n", "import os\n",
"import openai\n", "from openai import OpenAI\n",
"import itertools\n", "import itertools\n",
"from tenacity import retry, wait_random_exponential, stop_after_attempt\n", "from tenacity import retry, wait_random_exponential, stop_after_attempt\n",
"from typing import Any, Dict, List, Generator\n", "from typing import Any, Dict, List, Generator\n",
"import ast\n", "import ast\n",
"openai.api_key = os.getenv('OPENAI_API_KEY')\n" "client = OpenAI()\n"
] ]
}, },
{ {
@ -114,7 +114,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 3,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -137,7 +137,7 @@
" if functions:\n", " if functions:\n",
" params['functions'] = functions\n", " params['functions'] = functions\n",
"\n", "\n",
" completion = openai.ChatCompletion.create(**params)\n", " completion = client.chat.completions.create(**params)\n",
" return completion.choices[0].message\n" " return completion.choices[0].message\n"
] ]
}, },
@ -159,7 +159,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 4,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -177,7 +177,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 5,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -430,7 +430,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 6,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -442,9 +442,28 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 7,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Land the drone at the home base\n",
"FunctionCall(arguments='{\\n \"location\": \"home_base\"\\n}', name='land_drone') \n",
"\n",
"Take off the drone to 50 meters\n",
"FunctionCall(arguments='{\\n \"altitude\": 50\\n}', name='takeoff_drone') \n",
"\n",
"change speed to 15 kilometers per hour\n",
"FunctionCall(arguments='{\\n \"speed\": 15\\n}', name='set_drone_speed') \n",
"\n",
"turn into an elephant!\n",
"FunctionCall(arguments='{}', name='reject_request') \n",
"\n"
]
}
],
"source": [ "source": [
"for prompt in straightforward_prompts:\n", "for prompt in straightforward_prompts:\n",
" messages = []\n", " messages = []\n",
@ -464,7 +483,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 8,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -477,9 +496,36 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 9,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Play pre-recorded audio message\n",
"FunctionCall(arguments='{}', name='reject_request')\n",
"\n",
"\n",
"Initiate live-streaming on social media\n",
"FunctionCall(arguments='{\\n\"mode\": \"video\",\\n\"duration\": 0\\n}', name='control_camera')\n",
"\n",
"\n",
"Scan environment for heat signatures\n",
"None\n",
"\n",
"\n",
"Enable stealth mode\n",
"FunctionCall(arguments='{\\n \"mode\": \"off\"\\n}', name='set_drone_lighting')\n",
"\n",
"\n",
"Change drone's paint job color\n",
"FunctionCall(arguments='{\\n \"pattern\": \"solid\",\\n \"color\": \"blue\"\\n}', name='configure_led_display')\n",
"\n",
"\n"
]
}
],
"source": [ "source": [
"for prompt in challenging_prompts:\n", "for prompt in challenging_prompts:\n",
" messages = []\n", " messages = []\n",
@ -537,7 +583,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 10,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -557,7 +603,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 11,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -664,7 +710,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 12,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -734,7 +780,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 13,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -784,7 +830,7 @@
" request_prompt = COMMAND_GENERATION_PROMPT.format(invocation=invocation)\n", " request_prompt = COMMAND_GENERATION_PROMPT.format(invocation=invocation)\n",
"\n", "\n",
" messages = [{\"role\": \"user\", \"content\": f\"{request_prompt}\"}]\n", " messages = [{\"role\": \"user\", \"content\": f\"{request_prompt}\"}]\n",
" completion = get_chat_completion(messages,temperature=0.8)\n", " completion = get_chat_completion(messages,temperature=0.8).content\n",
" command_dict = {\n", " command_dict = {\n",
" \"Input\": invocation,\n", " \"Input\": invocation,\n",
" \"Prompt\": completion\n", " \"Prompt\": completion\n",
@ -926,13 +972,13 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"if __name__ == \"__main__\":\n", "if __name__ == \"__main__\":\n",
" file = openai.File.create(\n", " file = client.files.create(\n",
" file=open(training_file, \"rb\"),\n", " file=open(training_file, \"rb\"),\n",
" purpose=\"fine-tune\",\n", " purpose=\"fine-tune\",\n",
" )\n", " )\n",
" file_id = file.id\n", " file_id = file.id\n",
" print(file_id)\n", " print(file_id)\n",
" ft = openai.FineTuningJob.create(\n", " ft = client.fine_tuning.jobs.create(\n",
" # model=\"gpt-4-0613\",\n", " # model=\"gpt-4-0613\",\n",
" model=\"gpt-3.5-turbo\",\n", " model=\"gpt-3.5-turbo\",\n",
" training_file=file_id,\n", " training_file=file_id,\n",
@ -980,7 +1026,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"### Conclustion" "### Conclusion"
] ]
}, },
{ {