mirror of
https://github.com/james-m-jordan/openai-cookbook.git
synced 2025-05-09 19:32:38 +00:00
Minor bug fix (#1278)
This commit is contained in:
parent
f854f6f0ca
commit
a254b498a4
@ -51,13 +51,15 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "dab872c5",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-05-15T17:45:23.563149Z",
|
||||
"start_time": "2024-05-15T17:45:22.925978Z"
|
||||
"end_time": "2024-07-12T22:41:58.148850Z",
|
||||
"start_time": "2024-07-12T22:41:58.133412Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import json\n",
|
||||
"from openai import OpenAI\n",
|
||||
@ -66,9 +68,7 @@
|
||||
"\n",
|
||||
"GPT_MODEL = \"gpt-4o\"\n",
|
||||
"client = OpenAI()"
|
||||
],
|
||||
"outputs": [],
|
||||
"execution_count": 2
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
@ -83,13 +83,15 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "745ceec5",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-05-15T17:45:28.816345Z",
|
||||
"start_time": "2024-05-15T17:45:28.814155Z"
|
||||
"end_time": "2024-07-12T22:41:59.531820Z",
|
||||
"start_time": "2024-07-12T22:41:59.529870Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@retry(wait=wait_random_exponential(multiplier=1, max=40), stop=stop_after_attempt(3))\n",
|
||||
"def chat_completion_request(messages, tools=None, tool_choice=None, model=GPT_MODEL):\n",
|
||||
@ -105,19 +107,19 @@
|
||||
" print(\"Unable to generate ChatCompletion response\")\n",
|
||||
" print(f\"Exception: {e}\")\n",
|
||||
" return e\n"
|
||||
],
|
||||
"outputs": [],
|
||||
"execution_count": 3
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "c4d1c99f",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-05-15T17:45:30.003910Z",
|
||||
"start_time": "2024-05-15T17:45:30.001259Z"
|
||||
"end_time": "2024-07-12T22:42:00.463896Z",
|
||||
"start_time": "2024-07-12T22:42:00.461258Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def pretty_print_conversation(messages):\n",
|
||||
" role_to_color = {\n",
|
||||
@ -138,9 +140,7 @@
|
||||
" print(colored(f\"assistant: {message['content']}\\n\", role_to_color[message[\"role\"]]))\n",
|
||||
" elif message[\"role\"] == \"function\":\n",
|
||||
" print(colored(f\"function ({message['name']}): {message['content']}\\n\", role_to_color[message[\"role\"]]))\n"
|
||||
],
|
||||
"outputs": [],
|
||||
"execution_count": 4
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
@ -155,13 +155,15 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "d2e25069",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-05-15T17:45:31.794879Z",
|
||||
"start_time": "2024-05-15T17:45:31.792617Z"
|
||||
"end_time": "2024-07-12T22:42:01.676606Z",
|
||||
"start_time": "2024-07-12T22:42:01.674348Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"tools = [\n",
|
||||
" {\n",
|
||||
@ -213,9 +215,7 @@
|
||||
" }\n",
|
||||
" },\n",
|
||||
"]"
|
||||
],
|
||||
"outputs": [],
|
||||
"execution_count": 5
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
@ -228,13 +228,26 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "518d6827",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-05-15T17:45:35.282310Z",
|
||||
"start_time": "2024-05-15T17:45:33.861496Z"
|
||||
"end_time": "2024-07-12T22:42:03.726604Z",
|
||||
"start_time": "2024-07-12T22:42:03.154689Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"ChatCompletionMessage(content='Sure, can you please provide me with the name of your city and state?', role='assistant', function_call=None, tool_calls=None)"
|
||||
]
|
||||
},
|
||||
"execution_count": 60,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"messages = []\n",
|
||||
"messages.append({\"role\": \"system\", \"content\": \"Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.\"})\n",
|
||||
@ -245,21 +258,8 @@
|
||||
"assistant_message = chat_response.choices[0].message\n",
|
||||
"messages.append(assistant_message)\n",
|
||||
"assistant_message\n"
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"ChatCompletionMessage(content=\"I need to know your location to provide you with the current weather. Could you please specify the city and state (or country) you're in?\", role='assistant', function_call=None, tool_calls=None)"
|
||||
]
|
||||
},
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"execution_count": 6
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
@ -271,13 +271,26 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "23c42a6e",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-05-15T17:45:43.553403Z",
|
||||
"start_time": "2024-05-15T17:45:42.205590Z"
|
||||
"end_time": "2024-07-12T22:42:05.778263Z",
|
||||
"start_time": "2024-07-12T22:42:05.277346Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_xb7QwwNnx90LkmhtlW0YrgP2', function=Function(arguments='{\"location\":\"Glasgow, Scotland\",\"format\":\"celsius\"}', name='get_current_weather'), type='function')])"
|
||||
]
|
||||
},
|
||||
"execution_count": 61,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"messages.append({\"role\": \"user\", \"content\": \"I'm in Glasgow, Scotland.\"})\n",
|
||||
"chat_response = chat_completion_request(\n",
|
||||
@ -286,21 +299,8 @@
|
||||
"assistant_message = chat_response.choices[0].message\n",
|
||||
"messages.append(assistant_message)\n",
|
||||
"assistant_message\n"
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_Dn2RJJSxzDm49vlVTehseJ0k', function=Function(arguments='{\"location\":\"Glasgow, Scotland\",\"format\":\"celsius\"}', name='get_current_weather'), type='function')])"
|
||||
]
|
||||
},
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"execution_count": 7
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
@ -312,13 +312,26 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "fa232e54",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-05-15T17:45:47.090638Z",
|
||||
"start_time": "2024-05-15T17:45:46.302475Z"
|
||||
"end_time": "2024-07-12T22:42:07.575820Z",
|
||||
"start_time": "2024-07-12T22:42:07.018764Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"ChatCompletionMessage(content='To provide you with the weather forecast for Glasgow, Scotland, could you please specify the number of days you would like the forecast for?', role='assistant', function_call=None, tool_calls=None)"
|
||||
]
|
||||
},
|
||||
"execution_count": 62,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"messages = []\n",
|
||||
"messages.append({\"role\": \"system\", \"content\": \"Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.\"})\n",
|
||||
@ -329,21 +342,8 @@
|
||||
"assistant_message = chat_response.choices[0].message\n",
|
||||
"messages.append(assistant_message)\n",
|
||||
"assistant_message\n"
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"ChatCompletionMessage(content='Please specify the number of days (x) for which you want the weather forecast for Glasgow, Scotland.', role='assistant', function_call=None, tool_calls=None)"
|
||||
]
|
||||
},
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"execution_count": 8
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
@ -355,34 +355,34 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "c7d8a543",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-05-15T17:45:49.790820Z",
|
||||
"start_time": "2024-05-15T17:45:48.847752Z"
|
||||
"end_time": "2024-07-12T22:42:09.587530Z",
|
||||
"start_time": "2024-07-12T22:42:08.666795Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"Choice(finish_reason='tool_calls', index=0, logprobs=None, message=ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_34PBraFdNN6KR95uD5rHF8Aw', function=Function(arguments='{\"location\":\"Glasgow, Scotland\",\"format\":\"celsius\",\"num_days\":5}', name='get_n_day_weather_forecast'), type='function')]))"
|
||||
]
|
||||
},
|
||||
"execution_count": 63,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"messages.append({\"role\": \"user\", \"content\": \"5 days\"})\n",
|
||||
"chat_response = chat_completion_request(\n",
|
||||
" messages, tools=tools\n",
|
||||
")\n",
|
||||
"chat_response.choices[0]\n"
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"Choice(finish_reason='tool_calls', index=0, logprobs=None, message=ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_Yg5ydH9lHhLjjYQyXbNvh004', function=Function(arguments='{\"location\":\"Glasgow, Scotland\",\"format\":\"celsius\",\"num_days\":5}', name='get_n_day_weather_forecast'), type='function')]))"
|
||||
]
|
||||
},
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"execution_count": 9
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
@ -403,13 +403,26 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "559371b7",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-05-15T17:45:54.194255Z",
|
||||
"start_time": "2024-05-15T17:45:52.975746Z"
|
||||
"end_time": "2024-07-12T22:42:12.216712Z",
|
||||
"start_time": "2024-07-12T22:42:11.714246Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_FImGxrLowOAOszCaaQqQWmEN', function=Function(arguments='{\"location\":\"Toronto, Canada\",\"format\":\"celsius\",\"num_days\":7}', name='get_n_day_weather_forecast'), type='function')])"
|
||||
]
|
||||
},
|
||||
"execution_count": 64,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# in this cell we force the model to use get_n_day_weather_forecast\n",
|
||||
"messages = []\n",
|
||||
@ -419,30 +432,30 @@
|
||||
" messages, tools=tools, tool_choice={\"type\": \"function\", \"function\": {\"name\": \"get_n_day_weather_forecast\"}}\n",
|
||||
")\n",
|
||||
"chat_response.choices[0].message"
|
||||
],
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"id": "a7ab0f58",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-07-12T22:42:14.264601Z",
|
||||
"start_time": "2024-07-12T22:42:13.001306Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_aP8ZEtGcyseL0btTMYxTCKbk', function=Function(arguments='{\"location\":\"Toronto, Canada\",\"format\":\"celsius\",\"num_days\":1}', name='get_n_day_weather_forecast'), type='function')])"
|
||||
"ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_n84kYFqjNFDPNGDEnjnrd2KC', function=Function(arguments='{\"location\": \"Toronto, Canada\", \"format\": \"celsius\"}', name='get_current_weather'), type='function'), ChatCompletionMessageToolCall(id='call_AEs3AFhJc9pn42hWSbHTaIDh', function=Function(arguments='{\"location\": \"Toronto, Canada\", \"format\": \"celsius\", \"num_days\": 3}', name='get_n_day_weather_forecast'), type='function')])"
|
||||
]
|
||||
},
|
||||
"execution_count": 10,
|
||||
"execution_count": 65,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"execution_count": 10
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"id": "a7ab0f58",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-05-15T17:45:56.841233Z",
|
||||
"start_time": "2024-05-15T17:45:55.433397Z"
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"# if we don't force the model to use get_n_day_weather_forecast it may not\n",
|
||||
"messages = []\n",
|
||||
@ -452,21 +465,8 @@
|
||||
" messages, tools=tools\n",
|
||||
")\n",
|
||||
"chat_response.choices[0].message"
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_5HqCVRaAoBuU0uTlO3MUwaWX', function=Function(arguments='{\"location\": \"Toronto, Canada\", \"format\": \"celsius\"}', name='get_current_weather'), type='function'), ChatCompletionMessageToolCall(id='call_C9kCha28xHEsxYl4PxZ1l5LI', function=Function(arguments='{\"location\": \"Toronto, Canada\", \"format\": \"celsius\", \"num_days\": 3}', name='get_n_day_weather_forecast'), type='function')])"
|
||||
]
|
||||
},
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"execution_count": 11
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
@ -478,13 +478,26 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"id": "acfe54e6",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-05-15T17:45:59.800346Z",
|
||||
"start_time": "2024-05-15T17:45:59.289603Z"
|
||||
"end_time": "2024-07-12T22:42:16.928643Z",
|
||||
"start_time": "2024-07-12T22:42:16.295006Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"ChatCompletionMessage(content=\"Sure, I'll get the current weather for Toronto, Canada in Celsius.\", role='assistant', function_call=None, tool_calls=None)"
|
||||
]
|
||||
},
|
||||
"execution_count": 66,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"messages = []\n",
|
||||
"messages.append({\"role\": \"system\", \"content\": \"Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.\"})\n",
|
||||
@ -493,21 +506,8 @@
|
||||
" messages, tools=tools, tool_choice=\"none\"\n",
|
||||
")\n",
|
||||
"chat_response.choices[0].message\n"
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"ChatCompletionMessage(content=\"I'll get the current weather for Toronto, Canada in Celsius.\", role='assistant', function_call=None, tool_calls=None)"
|
||||
]
|
||||
},
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"execution_count": 12
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b616353b",
|
||||
@ -520,13 +520,27 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"id": "380eeb68",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-05-15T17:46:04.048553Z",
|
||||
"start_time": "2024-05-15T17:46:01.273501Z"
|
||||
"end_time": "2024-07-12T22:42:18.988762Z",
|
||||
"start_time": "2024-07-12T22:42:18.041914Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[ChatCompletionMessageToolCall(id='call_ObhLiJwaHwc3U1KyB4Pdpx8y', function=Function(arguments='{\"location\": \"San Francisco, CA\", \"format\": \"fahrenheit\", \"num_days\": 4}', name='get_n_day_weather_forecast'), type='function'),\n",
|
||||
" ChatCompletionMessageToolCall(id='call_5YRgeZ0MGBMFKE3hZiLouwg7', function=Function(arguments='{\"location\": \"Glasgow, SCT\", \"format\": \"celsius\", \"num_days\": 4}', name='get_n_day_weather_forecast'), type='function')]"
|
||||
]
|
||||
},
|
||||
"execution_count": 67,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"messages = []\n",
|
||||
"messages.append({\"role\": \"system\", \"content\": \"Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.\"})\n",
|
||||
@ -537,22 +551,8 @@
|
||||
"\n",
|
||||
"assistant_message = chat_response.choices[0].message.tool_calls\n",
|
||||
"assistant_message"
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[ChatCompletionMessageToolCall(id='call_pFdKcCu5taDTtOOfX14vEDRp', function=Function(arguments='{\"location\": \"San Francisco, CA\", \"format\": \"fahrenheit\", \"num_days\": 4}', name='get_n_day_weather_forecast'), type='function'),\n",
|
||||
" ChatCompletionMessageToolCall(id='call_Veeyp2hYJOKp0wT7ODxmTjaS', function=Function(arguments='{\"location\": \"Glasgow, UK\", \"format\": \"celsius\", \"num_days\": 4}', name='get_n_day_weather_forecast'), type='function')]"
|
||||
]
|
||||
},
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"execution_count": 13
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
@ -579,19 +579,14 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"id": "30f6b60e",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-05-15T17:46:07.270851Z",
|
||||
"start_time": "2024-05-15T17:46:07.265545Z"
|
||||
"end_time": "2024-07-12T22:42:20.742187Z",
|
||||
"start_time": "2024-07-12T22:42:20.737751Z"
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"import sqlite3\n",
|
||||
"\n",
|
||||
"conn = sqlite3.connect(\"data/Chinook.db\")\n",
|
||||
"print(\"Opened database successfully\")"
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
@ -601,17 +596,24 @@
|
||||
]
|
||||
}
|
||||
],
|
||||
"execution_count": 14
|
||||
"source": [
|
||||
"import sqlite3\n",
|
||||
"\n",
|
||||
"conn = sqlite3.connect(\"data/Chinook.db\")\n",
|
||||
"print(\"Opened database successfully\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"id": "abec0214",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-05-15T17:46:09.345308Z",
|
||||
"start_time": "2024-05-15T17:46:09.342998Z"
|
||||
"end_time": "2024-07-12T22:42:21.370623Z",
|
||||
"start_time": "2024-07-12T22:42:21.368246Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def get_table_names(conn):\n",
|
||||
" \"\"\"Return a list of table names.\"\"\"\n",
|
||||
@ -638,9 +640,7 @@
|
||||
" columns_names = get_column_names(conn, table_name)\n",
|
||||
" table_dicts.append({\"table_name\": table_name, \"column_names\": columns_names})\n",
|
||||
" return table_dicts\n"
|
||||
],
|
||||
"outputs": [],
|
||||
"execution_count": 15
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
@ -653,13 +653,15 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"id": "0c0104cd",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-05-15T17:46:11.303746Z",
|
||||
"start_time": "2024-05-15T17:46:11.301210Z"
|
||||
"end_time": "2024-07-12T22:42:22.668456Z",
|
||||
"start_time": "2024-07-12T22:42:22.665839Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"database_schema_dict = get_database_info(conn)\n",
|
||||
"database_schema_string = \"\\n\".join(\n",
|
||||
@ -668,9 +670,7 @@
|
||||
" for table in database_schema_dict\n",
|
||||
" ]\n",
|
||||
")"
|
||||
],
|
||||
"outputs": [],
|
||||
"execution_count": 16
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
@ -683,13 +683,15 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"id": "0258813a",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-05-15T17:46:16.569530Z",
|
||||
"start_time": "2024-05-15T17:46:16.567801Z"
|
||||
"end_time": "2024-07-12T22:42:24.156291Z",
|
||||
"start_time": "2024-07-12T22:42:24.154372Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"tools = [\n",
|
||||
" {\n",
|
||||
@ -715,9 +717,7 @@
|
||||
" }\n",
|
||||
" }\n",
|
||||
"]"
|
||||
],
|
||||
"outputs": [],
|
||||
"execution_count": 17
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
@ -732,13 +732,15 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"id": "65585e74",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-05-15T17:46:19.198723Z",
|
||||
"start_time": "2024-05-15T17:46:19.197043Z"
|
||||
"end_time": "2024-07-12T22:42:25.444734Z",
|
||||
"start_time": "2024-07-12T22:42:25.442757Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def ask_database(conn, query):\n",
|
||||
" \"\"\"Function to query SQLite database with a provided SQL query.\"\"\"\n",
|
||||
@ -747,9 +749,7 @@
|
||||
" except Exception as e:\n",
|
||||
" results = f\"query failed with error: {e}\"\n",
|
||||
" return results"
|
||||
],
|
||||
"outputs": [],
|
||||
"execution_count": 18
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
@ -767,13 +767,23 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"id": "e8b7cb9cdc7a7616",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-05-15T17:46:25.725379Z",
|
||||
"start_time": "2024-05-15T17:46:24.255505Z"
|
||||
"end_time": "2024-07-12T22:42:28.395683Z",
|
||||
"start_time": "2024-07-12T22:42:27.415626Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_wDN8uLjq2ofuU6rVx1k8Gw0e', function=Function(arguments='{\"query\":\"SELECT Album.Title, COUNT(Track.TrackId) AS TrackCount FROM Album INNER JOIN Track ON Album.AlbumId = Track.AlbumId GROUP BY Album.Title ORDER BY TrackCount DESC LIMIT 1;\"}', name='ask_database'), type='function')])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Step #1: Prompt with content that may result in function call. In this case the model can identify the information requested by the user is potentially available in the database schema passed to the model in Tools description. \n",
|
||||
"messages = [{\n",
|
||||
@ -793,27 +803,27 @@
|
||||
"messages.append(response_message)\n",
|
||||
"\n",
|
||||
"print(response_message)"
|
||||
],
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"id": "351c39def3417776",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-07-12T22:42:30.439519Z",
|
||||
"start_time": "2024-07-12T22:42:29.799492Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_bXMf903yO78sdsMZble4yu90', function=Function(arguments='{\"query\":\"SELECT A.Title, COUNT(T.TrackId) AS TrackCount FROM Album A JOIN Track T ON A.AlbumId = T.AlbumId GROUP BY A.Title ORDER BY TrackCount DESC LIMIT 1;\"}', name='ask_database'), type='function')])\n"
|
||||
"The album with the most tracks is titled \"Greatest Hits,\" which contains 57 tracks.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"execution_count": 19
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"id": "351c39def3417776",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-05-15T17:46:30.346444Z",
|
||||
"start_time": "2024-05-15T17:46:29.699046Z"
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"# Step 2: determine if the response from the model includes a tool call. \n",
|
||||
"tool_calls = response_message.tool_calls\n",
|
||||
@ -821,8 +831,8 @@
|
||||
" # If true the model will return the name of the tool / function to call and the argument(s) \n",
|
||||
" tool_call_id = tool_calls[0].id\n",
|
||||
" tool_function_name = tool_calls[0].function.name\n",
|
||||
" tool_query_string = eval(tool_calls[0].function.arguments)['query']\n",
|
||||
" \n",
|
||||
" tool_query_string = json.loads(tool_calls[0].function.arguments)['query']\n",
|
||||
"\n",
|
||||
" # Step 3: Call the function and retrieve results. Append the results to the messages list. \n",
|
||||
" if tool_function_name == 'ask_database':\n",
|
||||
" results = ask_database(conn, tool_query_string)\n",
|
||||
@ -846,17 +856,7 @@
|
||||
"else: \n",
|
||||
" # Model did not identify a function to call, result can be returned to the user \n",
|
||||
" print(response_message.content) "
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"The album with the most tracks is titled \"Greatest Hits,\" and it contains 57 tracks.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"execution_count": 20
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
|
Loading…
x
Reference in New Issue
Block a user