{
"cells": [
{
"cell_type": "markdown",
"id": "c6806af9-68ae-4714-851c-9a967aee0e23",
"metadata": {},
"source": [
"# Leveraging model distillation to fine-tune a model\n",
"\n",
"OpenAI recently released **Distillation** which allows to leverage the outputs of a (large) model to fine-tune another (smaller) model. This can significantly reduce the price and the latency for specific tasks as you move to a smaller model. In this cookbook we'll look at a dataset, distill the output of gpt-4o to gpt-4o-mini and show how we can get significantly better results than on a generic, non-distilled, 4o-mini.\n",
"\n",
"We'll also leverage **Structured Outputs** for a classification problem using a list of enum. We'll see how fine-tuned model can benefit from structured output and how it will impact the performance. We'll show that **Structured Ouputs** work with all of those models, including the distilled one.\n",
"\n",
"We'll first analyze the dataset, get the output of both 4o and 4o mini, highlighting the difference in performance of both models, then proceed to the distillation and analyze the performance of this distilled model."
]
},
{
"cell_type": "markdown",
"id": "5dd8fd2f-dfdf-47c2-9627-02acbe3fb7a2",
"metadata": {},
"source": [
"## Prerequisites\n",
"\n",
"Let's install and load dependencies.\n",
"Make sure your OpenAI API key is defined in your environment as \"OPENAI_API_KEY\" and it'll be loaded by the client directly."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "e16ed9ef-0220-4f23-a8eb-40813eacf210",
"metadata": {},
"outputs": [],
"source": [
"! pip install openai tiktoken numpy pandas tqdm --quiet"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "7b643798-3b2b-43e4-bfb5-ebcf74066253",
"metadata": {},
"outputs": [],
"source": [
"import openai\n",
"import json\n",
"import tiktoken\n",
"from tqdm import tqdm\n",
"from openai import OpenAI\n",
"import numpy as np\n",
"import concurrent.futures\n",
"import pandas as pd\n",
"\n",
"client = OpenAI()"
]
},
{
"cell_type": "markdown",
"id": "246364b6-2fed-4b54-b540-09569a197a6b",
"metadata": {},
"source": [
"## Loading and understanding the dataset\n",
"\n",
"For this cookbook, we'll load the data from the following Kaggle challenge: [https://www.kaggle.com/datasets/zynicide/wine-reviews](https://www.kaggle.com/datasets/zynicide/wine-reviews).\n",
"\n",
"This dataset has a large number of rows and you're free to run this cookbook on the whole data, but as a biaised french wine-lover, I'll narrow down the dataset to only French wine to focus on less rows and grape varieties.\n",
"\n",
"We're looking at a classification problem where we'd like to guess the grape variety based on all other criterias available, including description, subregion and province that we'll include in the prompt. It gives a lot of information to the model, you're free to also remove some information that can help significantly the model such as the region in which it was produced to see if it does a good job at finding the grape.\n",
"\n",
"Let's filter the grape varieties that have less than 5 occurences in reviews.\n",
"\n",
"Let's proceed with a subset of 500 random rows from this dataset."
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "759d1705-2213-443a-9fc3-050bc00177e6",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Unnamed: 0</th>\n",
" <th>country</th>\n",
" <th>description</th>\n",
" <th>designation</th>\n",
" <th>points</th>\n",
" <th>price</th>\n",
" <th>province</th>\n",
" <th>region_1</th>\n",
" <th>region_2</th>\n",
" <th>taster_name</th>\n",
" <th>taster_twitter_handle</th>\n",
" <th>title</th>\n",
" <th>variety</th>\n",
" <th>winery</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>95206</th>\n",
" <td>95206</td>\n",
" <td>France</td>\n",
" <td>Full, fat, ripe, perfumed wine that is full of...</td>\n",
" <td>Château de Mercey Premier Cru</td>\n",
" <td>91</td>\n",
" <td>35.0</td>\n",
" <td>Burgundy</td>\n",
" <td>Mercurey</td>\n",
" <td>NaN</td>\n",
" <td>Roger Voss</td>\n",
" <td>@vossroger</td>\n",
" <td>Antonin Rodet 2010 Château de Mercey Premier C...</td>\n",
" <td>Pinot Noir</td>\n",
" <td>Antonin Rodet</td>\n",
" </tr>\n",
" <tr>\n",
" <th>66403</th>\n",
" <td>66403</td>\n",
" <td>France</td>\n",
" <td>For simple Chablis, this is impressive, rich, ...</td>\n",
" <td>Domaine</td>\n",
" <td>89</td>\n",
" <td>26.0</td>\n",
" <td>Burgundy</td>\n",
" <td>Chablis</td>\n",
" <td>NaN</td>\n",
" <td>Roger Voss</td>\n",
" <td>@vossroger</td>\n",
" <td>William Fèvre 2005 Domaine (Chablis)</td>\n",
" <td>Chardonnay</td>\n",
" <td>William Fèvre</td>\n",
" </tr>\n",
" <tr>\n",
" <th>71277</th>\n",
" <td>71277</td>\n",
" <td>France</td>\n",
" <td>This 50-50 blend of Marselan and Merlot opens ...</td>\n",
" <td>La Remise</td>\n",
" <td>84</td>\n",
" <td>13.0</td>\n",
" <td>France Other</td>\n",
" <td>Vin de France</td>\n",
" <td>NaN</td>\n",
" <td>Lauren Buzzeo</td>\n",
" <td>@laurbuzz</td>\n",
" <td>Domaine de la Mordorée 2014 La Remise Red (Vin...</td>\n",
" <td>Red Blend</td>\n",
" <td>Domaine de la Mordorée</td>\n",
" </tr>\n",
" <tr>\n",
" <th>27484</th>\n",
" <td>27484</td>\n",
" <td>France</td>\n",
" <td>The medium-intense nose of this solid and easy...</td>\n",
" <td>Authentic &amp; Chic</td>\n",
" <td>86</td>\n",
" <td>10.0</td>\n",
" <td>France Other</td>\n",
" <td>Vin de France</td>\n",
" <td>NaN</td>\n",
" <td>Lauren Buzzeo</td>\n",
" <td>@laurbuzz</td>\n",
" <td>Romantic 2014 Authentic &amp; Chic Cabernet Sauvig...</td>\n",
" <td>Cabernet Sauvignon</td>\n",
" <td>Romantic</td>\n",
" </tr>\n",
" <tr>\n",
" <th>124917</th>\n",
" <td>124917</td>\n",
" <td>France</td>\n",
" <td>Fresh, pure notes of Conference pear peel enti...</td>\n",
" <td>NaN</td>\n",
" <td>89</td>\n",
" <td>30.0</td>\n",
" <td>Alsace</td>\n",
" <td>Alsace</td>\n",
" <td>NaN</td>\n",
" <td>Anne Krebiehl MW</td>\n",
" <td>@AnneInVino</td>\n",
" <td>Domaine Vincent Stoeffler 2015 Pinot Gris (Als...</td>\n",
" <td>Pinot Gris</td>\n",
" <td>Domaine Vincent Stoeffler</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Unnamed: 0 country description \\\n",
"95206 95206 France Full, fat, ripe, perfumed wine that is full of... \n",
"66403 66403 France For simple Chablis, this is impressive, rich, ... \n",
"71277 71277 France This 50-50 blend of Marselan and Merlot opens ... \n",
"27484 27484 France The medium-intense nose of this solid and easy... \n",
"124917 124917 France Fresh, pure notes of Conference pear peel enti... \n",
"\n",
" designation points price province \\\n",
"95206 Château de Mercey Premier Cru 91 35.0 Burgundy \n",
"66403 Domaine 89 26.0 Burgundy \n",
"71277 La Remise 84 13.0 France Other \n",
"27484 Authentic & Chic 86 10.0 France Other \n",
"124917 NaN 89 30.0 Alsace \n",
"\n",
" region_1 region_2 taster_name taster_twitter_handle \\\n",
"95206 Mercurey NaN Roger Voss @vossroger \n",
"66403 Chablis NaN Roger Voss @vossroger \n",
"71277 Vin de France NaN Lauren Buzzeo @laurbuzz \n",
"27484 Vin de France NaN Lauren Buzzeo @laurbuzz \n",
"124917 Alsace NaN Anne Krebiehl MW @AnneInVino \n",
"\n",
" title variety \\\n",
"95206 Antonin Rodet 2010 Château de Mercey Premier C... Pinot Noir \n",
"66403 William Fèvre 2005 Domaine (Chablis) Chardonnay \n",
"71277 Domaine de la Mordorée 2014 La Remise Red (Vin... Red Blend \n",
"27484 Romantic 2014 Authentic & Chic Cabernet Sauvig... Cabernet Sauvignon \n",
"124917 Domaine Vincent Stoeffler 2015 Pinot Gris (Als... Pinot Gris \n",
"\n",
" winery \n",
"95206 Antonin Rodet \n",
"66403 William Fèvre \n",
"71277 Domaine de la Mordorée \n",
"27484 Romantic \n",
"124917 Domaine Vincent Stoeffler "
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df = pd.read_csv('data/winemag/winemag-data-130k-v2.csv')\n",
"df_france = df[df['country'] == 'France']\n",
"\n",
"# Let's also filter out wines that have less than 5 references with their grape variety even though we'd like to find those\n",
"# they're outliers that we don't want to optimize for that would make our enum list be too long\n",
"# and they could also add noise for the rest of the dataset on which we'd like to guess, eventually reducing our accuracy.\n",
"\n",
"varieties_less_than_five_list = df_france['variety'].value_counts()[df_france['variety'].value_counts() < 5].index.tolist()\n",
"df_france = df_france[~df_france['variety'].isin(varieties_less_than_five_list)]\n",
"\n",
"df_france_subset = df_france.sample(n=500)\n",
"df_france_subset.head()"
]
},
{
"cell_type": "markdown",
"id": "b96cd12f-cbdf-46af-958f-3d553598be1d",
"metadata": {},
"source": [
"Let's retrieve all grape varieties to include them in the prompt and in our structured outputs enum list."
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "06f5dbea-549a-455d-9b6e-051de9d38723",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array(['Gewürztraminer', 'Pinot Gris', 'Gamay',\n",
" 'Bordeaux-style White Blend', 'Champagne Blend', 'Chardonnay',\n",
" 'Petit Manseng', 'Riesling', 'White Blend', 'Pinot Blanc',\n",
" 'Alsace white blend', 'Bordeaux-style Red Blend', 'Malbec',\n",
" 'Tannat-Cabernet', 'Rhône-style Red Blend', 'Ugni Blanc-Colombard',\n",
" 'Savagnin', 'Pinot Noir', 'Rosé', 'Melon',\n",
" 'Rhône-style White Blend', 'Pinot Noir-Gamay', 'Colombard',\n",
" 'Chenin Blanc', 'Sylvaner', 'Sauvignon Blanc', 'Red Blend',\n",
" 'Chenin Blanc-Chardonnay', 'Cabernet Sauvignon', 'Cabernet Franc',\n",
" 'Syrah', 'Sparkling Blend', 'Duras', 'Provence red blend',\n",
" 'Tannat', 'Merlot', 'Malbec-Merlot', 'Chardonnay-Viognier',\n",
" 'Cabernet Franc-Cabernet Sauvignon', 'Muscat', 'Viognier',\n",
" 'Picpoul', 'Altesse', 'Provence white blend', 'Mondeuse',\n",
" 'Grenache-Syrah', 'G-S-M', 'Pinot Meunier', 'Cabernet-Syrah',\n",
" 'Vermentino', 'Marsanne', 'Colombard-Sauvignon Blanc',\n",
" 'Gros and Petit Manseng', 'Jacquère', 'Negrette', 'Mauzac',\n",
" 'Pinot Auxerrois', 'Grenache', 'Roussanne', 'Gros Manseng',\n",
" 'Tannat-Merlot', 'Aligoté', 'Chasselas', \"Loin de l'Oeil\",\n",
" 'Malbec-Tannat', 'Carignan', 'Colombard-Ugni Blanc', 'Sémillon',\n",
" 'Syrah-Grenache', 'Sciaccerellu', 'Auxerrois', 'Mourvèdre',\n",
" 'Tannat-Cabernet Franc', 'Braucol', 'Trousseau',\n",
" 'Merlot-Cabernet Sauvignon'], dtype='<U33')"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"varieties = np.array(df_france['variety'].unique()).astype('str')\n",
"varieties"
]
},
{
"cell_type": "markdown",
"id": "6612e612-698d-4f1a-8008-70fb5ff263a0",
"metadata": {},
"source": [
"## Generating the prompt\n",
"\n",
"Let's build out a function to generate our prompt and try it for the first wine of our list."
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "c3ec2fba-9c99-4cb7-bf56-f13e3816e559",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'\\n Based on this wine review, guess the grape variety:\\n This wine is produced by Trimbach in the Alsace region of France.\\n It was grown in Alsace. It is described as: \"This dry and restrained wine offers spice in profusion. Balanced with acidity and a firm texture, it\\'s very much for food.\".\\n The wine has been reviewed by Roger Voss and received 87 points.\\n The price is 24.0.\\n\\n Here is a list of possible grape varieties to choose from: Gewürztraminer, Pinot Gris, Gamay, Bordeaux-style White Blend, Champagne Blend, Chardonnay, Petit Manseng, Riesling, White Blend, Pinot Blanc, Alsace white blend, Bordeaux-style Red Blend, Malbec, Tannat-Cabernet, Rhône-style Red Blend, Ugni Blanc-Colombard, Savagnin, Pinot Noir, Rosé, Melon, Rhône-style White Blend, Pinot Noir-Gamay, Colombard, Chenin Blanc, Sylvaner, Sauvignon Blanc, Red Blend, Chenin Blanc-Chardonnay, Cabernet Sauvignon, Cabernet Franc, Syrah, Sparkling Blend, Duras, Provence red blend, Tannat, Merlot, Malbec-Merlot, Chardonnay-Viognier, Cabernet Franc-Cabernet Sauvignon, Muscat, Viognier, Picpoul, Altesse, Provence white blend, Mondeuse, Grenache-Syrah, G-S-M, Pinot Meunier, Cabernet-Syrah, Vermentino, Marsanne, Colombard-Sauvignon Blanc, Gros and Petit Manseng, Jacquère, Negrette, Mauzac, Pinot Auxerrois, Grenache, Roussanne, Gros Manseng, Tannat-Merlot, Aligoté, Chasselas, Loin de l\\'Oeil, Malbec-Tannat, Carignan, Colombard-Ugni Blanc, Sémillon, Syrah-Grenache, Sciaccerellu, Auxerrois, Mourvèdre, Tannat-Cabernet Franc, Braucol, Trousseau, Merlot-Cabernet Sauvignon.\\n \\n What is the likely grape variety? Answer only with the grape variety name or blend from the list.\\n '"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def generate_prompt(row, varieties):\n",
" # Format the varieties list as a comma-separated string\n",
" variety_list = ', '.join(varieties)\n",
" \n",
" prompt = f\"\"\"\n",
" Based on this wine review, guess the grape variety:\n",
" This wine is produced by {row['winery']} in the {row['province']} region of {row['country']}.\n",
" It was grown in {row['region_1']}. It is described as: \"{row['description']}\".\n",
" The wine has been reviewed by {row['taster_name']} and received {row['points']} points.\n",
" The price is {row['price']}.\n",
"\n",
" Here is a list of possible grape varieties to choose from: {variety_list}.\n",
" \n",
" What is the likely grape variety? Answer only with the grape variety name or blend from the list.\n",
" \"\"\"\n",
" return prompt\n",
"\n",
"# Example usage with a specific row\n",
"prompt = generate_prompt(df_france.iloc[0], varieties)\n",
"prompt"
]
},
{
"cell_type": "markdown",
"id": "fdf208f5-0a70-41cb-b2c0-c340ee55d6d1",
"metadata": {},
"source": [
"To get a understanding of the cost before running the queries, you can leverage tiktoken to understand the number of tokens we'll send and the cost associated to run this. This will only give you an estimate for to run the completions, not the fine-tuning process (used later in this cookbook when running the distillation), which depends on other factors such as the number of epochs, training set etc."
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "cf43b5f4-dbfb-4c89-a7f8-2bd7b03c49c7",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total number of tokens in the dataset: 245439\n",
"Total number of prompts: 500\n"
]
}
],
"source": [
"# Load encoding for the GPT-4 model\n",
"enc = tiktoken.encoding_for_model(\"gpt-4o\")\n",
"\n",
"# Initialize a variable to store the total number of tokens\n",
"total_tokens = 0\n",
"\n",
"for index, row in df_france_subset.iterrows():\n",
" prompt = generate_prompt(row, varieties)\n",
" \n",
" # Tokenize the input text and count tokens\n",
" tokens = enc.encode(prompt)\n",
" token_count = len(tokens)\n",
" \n",
" # Add the token count to the total\n",
" total_tokens += token_count\n",
"\n",
"print(f\"Total number of tokens in the dataset: {total_tokens}\")\n",
"print(f\"Total number of prompts: {len(df_france_subset)}\")"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "362e1ddd-2a87-433e-bcba-534ae946c2ff",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.6135975\n",
"0.03681585\n"
]
}
],
"source": [
"# outputing cost in $ as of 2024/10/16\n",
"\n",
"gpt4o_token_price = 2.50 / 1_000_000 # $2.50 per 1M tokens\n",
"gpt4o_mini_token_price = 0.150 / 1_000_000 # $0.15 per 1M tokens\n",
"\n",
"total_gpt4o_cost = gpt4o_token_price*total_tokens\n",
"total_gpt4o_mini_cost = gpt4o_mini_token_price*total_tokens\n",
"\n",
"print(total_gpt4o_cost)\n",
"print(total_gpt4o_mini_cost)"
]
},
{
"cell_type": "markdown",
"id": "6ff1181a-6b06-4ad3-8d68-7afb64518010",
"metadata": {},
"source": [
"## Preparing functions to Store Completions\n",
"\n",
"As we're looking at a limited list of response (enumerate list of grape varieties), let's leverage structured outputs so we make sure the model will answer from this list. This also allows us to compare the model's answer directly with the grape variety and have a deterministic answer (compared to a model that could answer \"I think the grape is Pinot Noir\" instead of just \"Pinot noir\"), on top of improving the performance to avoid grape varieties not in our dataset.\n",
"\n",
"If you want to know more on Structured Outputs you can read this [cookbook](https://cookbook.openai.com/examples/structured_outputs_intro) and this [documentation guide](https://platform.openai.com/docs/guides/structured-outputs/introduction)."
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "4153ca1a-c195-4102-a6c3-f570dd6ae372",
"metadata": {},
"outputs": [],
"source": [
"response_format = {\n",
" \"type\": \"json_schema\",\n",
" \"json_schema\": {\n",
" \"name\": \"grape-variety\",\n",
" \"schema\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"variety\": {\n",
" \"type\": \"string\",\n",
" \"enum\": varieties.tolist()\n",
" }\n",
" },\n",
" \"additionalProperties\": False,\n",
" \"required\": [\"variety\"],\n",
" },\n",
" \"strict\": True\n",
" }\n",
"}"
]
},
{
"cell_type": "markdown",
"id": "0fedb40e-72e5-4aba-8bc2-b244111b4b43",
"metadata": {},
"source": [
"To distill a model, you need to store all completions from a model, allowing you to give it as a reference to the smaller model to fine-tune it. We're therefore adding a `store=True` parameter to our `client.chat.completions.create` method so we can store those completions from gpt-4o.\n",
"\n",
"We're going to store all completions (even 4o-mini and our future fine-tuned model) so we are able to run [Evals](https://platform.openai.com/docs/guides/evals) from OpenAI platform directly.\n",
"\n",
"When storing those completions, it's useful to store them with a metadata tag, that will allow filtering from the OpenAI platform to run distillation & evals on the specific set of completions you'd like to run those."
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "85cb4cc7-077a-4afe-aefc-85b768adfc1f",
"metadata": {},
"outputs": [],
"source": [
"# Initialize the progress index\n",
"metadata_value = \"wine-distillation\" # that's a funny metadata tag :-)\n",
"\n",
"# Function to call the API and process the result for a single model (blocking call in this case)\n",
"def call_model(model, prompt):\n",
" response = client.chat.completions.create(\n",
" model=model,\n",
" store=True,\n",
" metadata={\n",
" \"distillation\": metadata_value,\n",
" },\n",
" messages=[\n",
" {\n",
" \"role\": \"system\",\n",
" \"content\": \"You're a sommelier expert and you know everything about wine. You answer precisely with the name of the variety/blend.\"\n",
" },\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": prompt\n",
" }\n",
" ],\n",
" response_format=response_format\n",
" )\n",
" return json.loads(response.choices[0].message.content.strip())['variety']"
]
},
{
"cell_type": "markdown",
"id": "cd4ae8e1-20cb-43b4-a7b1-9998e09f4394",
"metadata": {},
"source": [
"## Parallel processing\n",
"\n",
"As we'll run this on a large number of rows, let's make sure we run those completions in parallel and use concurrent futures for this. We'll iterate on our dataframe and output progress every 20 rows. We'll store the completion from the model we run the completion for in the same dataframe using the column name `{model}-variety`."
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "81ddec71-c2a6-411a-aecd-6556cddc36c8",
"metadata": {},
"outputs": [],
"source": [
"def process_example(index, row, model, df, progress_bar):\n",
" global progress_index\n",
"\n",
" try:\n",
" # Generate the prompt using the row\n",
" prompt = generate_prompt(row, varieties)\n",
"\n",
" df.at[index, model + \"-variety\"] = call_model(model, prompt)\n",
" \n",
" # Update the progress bar\n",
" progress_bar.update(1)\n",
" \n",
" progress_index += 1\n",
" except Exception as e:\n",
" print(f\"Error processing model {model}: {str(e)}\")\n",
"\n",
"def process_dataframe(df, model):\n",
" global progress_index\n",
" progress_index = 1 # Reset progress index\n",
"\n",
" # Create a tqdm progress bar\n",
" with tqdm(total=len(df), desc=\"Processing rows\") as progress_bar:\n",
" # Process each example concurrently using ThreadPoolExecutor\n",
" with concurrent.futures.ThreadPoolExecutor() as executor:\n",
" futures = {executor.submit(process_example, index, row, model, df, progress_bar): index for index, row in df.iterrows()}\n",
" \n",
" for future in concurrent.futures.as_completed(futures):\n",
" try:\n",
" future.result() # Wait for each example to be processed\n",
" except Exception as e:\n",
" print(f\"Error processing example: {str(e)}\")\n",
"\n",
" return df"
]
},
{
"cell_type": "markdown",
"id": "680587d1-bcb8-45bd-bfe8-1f771ae1a5bd",
"metadata": {},
"source": [
"Let's try out our call model function before processing the whole dataframe and check the output."
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "95915fd2-24f6-4908-a54f-9ce59073233d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Pinot Noir'"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"answer = call_model('gpt-4o', generate_prompt(df_france_subset.iloc[0], varieties))\n",
"answer"
]
},
{
"cell_type": "markdown",
"id": "941b9027-eb22-4eed-84de-8b962fe4fee2",
"metadata": {},
"source": [
"Great! We confirmed we can get a grape variety as an output, let's now process the dataset with both `gpt-4o` and `gpt-4o-mini` and compare the results. "
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "1cedf978-e60b-4f56-9c8c-da6b0a722320",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processing rows: 100%|███████████████████████████████████████████████| 500/500 [00:41<00:00, 12.09it/s]\n"
]
}
],
"source": [
"df_france_subset = process_dataframe(df_france_subset, \"gpt-4o\")"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "93a96539-e25d-4d15-adc0-ef8cbacd70e9",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processing rows: 100%|███████████████████████████████████████████████| 500/500 [01:31<00:00, 5.45it/s]\n"
]
}
],
"source": [
"df_france_subset = process_dataframe(df_france_subset, \"gpt-4o-mini\")"
]
},
{
"cell_type": "markdown",
"id": "f74869ee-3361-4d08-bd46-126239d6008c",
"metadata": {},
"source": [
"## Comparing gpt-4o and gpt-4o-mini\n",
"\n",
"Now that we've got all chat completions for those two models ; let's compare them against the expected grape variety and assess their accuracy at finding it. We'll do this directly in python here as we've got a simple string check to run, but if your task involves more complex evals you can leverage OpenAI Evals or our open-source eval framework."
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "cce147d2-32e4-4005-9f79-11394c76c2e4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"gpt-4o accuracy: 81.80%\n",
"gpt-4o-mini accuracy: 69.00%\n"
]
}
],
"source": [
"models = ['gpt-4o', 'gpt-4o-mini']\n",
"\n",
"def get_accuracy(model, df):\n",
" return np.mean(df['variety'] == df[model + '-variety'])\n",
"\n",
"for model in models:\n",
" print(f\"{model} accuracy: {get_accuracy(model, df_france_subset) * 100:.2f}%\")"
]
},
{
"cell_type": "markdown",
"id": "0a6ae1c1-22e5-4a61-a055-7b98360de523",
"metadata": {},
"source": [
"We can see that gpt-4o is better a finding grape variety than 4o-mini (12.80% higher or almost 20% relatively to 4o-mini!). Now I'm wondering if we're making gpt-4o drink wine during training!\n",
"\n",
"## Distilling gpt-4o outputs to gpt-4o-mini\n",
"\n",
"Let's assume we'd like to run this prediction often, we want completions to be faster and cheaper, but keep that level of accuracy. That'd be great to be able to distill 4o accuracy to 4o-mini, wouldn't it? Let's do it!\n",
"\n",
"We'll now go to OpenAI Stored completions page: [https://platform.openai.com/chat-completions](https://platform.openai.com/chat-completions).\n",
"\n",
"Let's select the model gpt-4o (make sure to do this, you don't want to distill the outputs of 4o-mini that we ran). Let's also select the metadata `distillation: wine-distillation` to get only stored completions ran from this cookbook.\n",
"\n",
"![Filtering out completions](../images/filtering-out-completions.png)\n",
"\n",
"Once you've selected completions, you can click on \"Distill\" on the top right corner to fine-tune a model based on those completions. Once we've done that, a file to run the fine-tuning process will automatically be created. Let's then select `gpt-4o-mini` as the base model, keep the default parameters (but you're free to change them or iterate with it to improve performance).\n",
"\n",
"![Distilling modal](../images/distilling.png)\n",
"\n",
"Once the fine-tuning job is starting, you can retrieve the fine tuning job ID from the fine-tuning page, we'll use it to monitor status of the fine-tuned job as well as retrieving the fine-tuned model id once done.\n",
"\n",
"![Fine tuning job](../images/fine-tuning-job.png)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "78a1b24e-3956-48f0-ae56-399f729a6c24",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"finetuned model: ft:gpt-4o-mini-2024-07-18:distillation-test:wine-distillation:AIZntSyE\n"
]
}
],
"source": [
"# copy paste your fine-tune job ID below\n",
"finetune_job = client.fine_tuning.jobs.retrieve(\"ftjob-pRyNWzUItmHpxmJ1TX7FOaWe\")\n",
"\n",
"if finetune_job.status == 'succeeded':\n",
" fine_tuned_model = finetune_job.fine_tuned_model\n",
" print('finetuned model: ' + fine_tuned_model)\n",
"else:\n",
" print('finetuned job status: ' + finetune_job.status)"
]
},
{
"cell_type": "markdown",
"id": "54eda7a4-3817-40df-84bc-db4555366078",
"metadata": {},
"source": [
"## Running completions for the distilled model\n",
"\n",
"Now that we've got our model fine-tuned, we can use this model to run completions and compare accuracy with both gpt4o and gpt4o-mini.\n",
"Let's grab a different subset of french wines (as we restricted the outputs to french grape varieties, without outliers, we'll need to focus our validation dataset to this too). Let's run this on 300 entries for each models."
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "603dc1e9-58dc-4b01-b25a-4433ebd88546",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processing rows: 100%|███████████████████████████████████████████████| 300/300 [00:20<00:00, 14.69it/s]\n",
"Processing rows: 100%|███████████████████████████████████████████████| 300/300 [00:27<00:00, 10.99it/s]\n",
"Processing rows: 100%|███████████████████████████████████████████████| 300/300 [00:37<00:00, 8.08it/s]\n"
]
}
],
"source": [
"validation_dataset = df_france.sample(n=300)\n",
"\n",
"models.append(fine_tuned_model)\n",
"\n",
"for model in models:\n",
" another_subset = process_dataframe(validation_dataset, model)"
]
},
{
"cell_type": "markdown",
"id": "6eba9c5d-d805-4610-86fe-e18cc20b8cd9",
"metadata": {},
"source": [
"Let's compare accuracy of models"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "cad42199-a4cb-4589-859b-32b39ad02eb2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"gpt-4o accuracy: 79.67%\n",
"gpt-4o-mini accuracy: 64.67%\n",
"ft:gpt-4o-mini-2024-07-18:distillation-test:wine-distillation:AIZntSyE accuracy: 79.33%\n"
]
}
],
"source": [
"for model in models:\n",
" print(f\"{model} accuracy: {get_accuracy(model, another_subset) * 100:.2f}%\")"
]
},
{
"cell_type": "markdown",
"id": "2b5bdb3c-c7ee-493a-9343-5950d7c36638",
"metadata": {},
"source": [
"That's almost a **22% relative improvement over the non-distilled gpt-4o-mini! 🎉**\n",
"\n",
"Our fine-tuned model performs way better than gpt-4o-mini, while having the same base model. We'll be able to use this model to run inferences at a lower cost and lower latency for future grape variety prediction."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}