From d4604f1006f768f595aa67e6bdc688c0400b57cf Mon Sep 17 00:00:00 2001 From: vishnu-oai Date: Fri, 11 Apr 2025 09:19:39 -0700 Subject: [PATCH] =?UTF-8?q?Update=20the=20multiclass=20classification=20co?= =?UTF-8?q?okbook=20to=20use=20the=20new=20API=20and=20=E2=80=A6=20(#1764)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ...lass_classification_for_transactions.ipynb | 1277 +++++------------ 1 file changed, 357 insertions(+), 920 deletions(-) diff --git a/examples/Multiclass_classification_for_transactions.ipynb b/examples/Multiclass_classification_for_transactions.ipynb index d9c4bef..c5ee388 100644 --- a/examples/Multiclass_classification_for_transactions.ipynb +++ b/examples/Multiclass_classification_for_transactions.ipynb @@ -25,18 +25,18 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload\n", - "%pip install openai 'openai[datalib]' 'openai[embeddings]' transformers\n" + "%pip install openai 'openai[datalib]' 'openai[embeddings]' transformers scikit-learn matplotlib plotly pandas scipy\n" ] }, { "cell_type": "code", - "execution_count": 311, + "execution_count": 56, "metadata": {}, "outputs": [], "source": [ @@ -47,8 +47,8 @@ "import os\n", "\n", "COMPLETIONS_MODEL = \"gpt-4\"\n", - "\n", - "client = openai.OpenAI(api_key=os.environ.get(\"OPENAI_API_KEY\", \"\"))" + "os.environ[\"OPENAI_API_KEY\"] = \"\"\n", + "client = openai.OpenAI()" ] }, { @@ -70,184 +70,34 @@ }, { "cell_type": "code", - "execution_count": 312, + "execution_count": 152, "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "359" - ] - }, - "execution_count": 312, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of transactions: 359\n", + " Date Supplier Description \\\n", + "0 21/04/2016 M & J Ballantyne Ltd George IV Bridge Work \n", + "1 26/04/2016 Private Sale Literary & Archival Items \n", + "2 30/04/2016 City Of Edinburgh Council Non Domestic Rates \n", + "3 09/05/2016 Computacenter Uk Kelvin Hall \n", + "4 09/05/2016 John Graham Construction Ltd Causewayside Refurbishment \n", + "\n", + " Transaction value (£) \n", + "0 35098.0 \n", + "1 30000.0 \n", + "2 40800.0 \n", + "3 72835.0 \n", + "4 64361.0 \n" + ] } ], "source": [ "transactions = pd.read_csv('./data/25000_spend_dataset_current.csv', encoding= 'unicode_escape')\n", - "len(transactions)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 313, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
DateSupplierDescriptionTransaction value (£)
021/04/2016M & J Ballantyne LtdGeorge IV Bridge Work35098.0
126/04/2016Private SaleLiterary & Archival Items30000.0
230/04/2016City Of Edinburgh CouncilNon Domestic Rates40800.0
309/05/2016Computacenter UkKelvin Hall72835.0
409/05/2016John Graham Construction LtdCausewayside Refurbishment64361.0
\n", - "
" - ], - "text/plain": [ - " Date Supplier Description \\\n", - "0 21/04/2016 M & J Ballantyne Ltd George IV Bridge Work \n", - "1 26/04/2016 Private Sale Literary & Archival Items \n", - "2 30/04/2016 City Of Edinburgh Council Non Domestic Rates \n", - "3 09/05/2016 Computacenter Uk Kelvin Hall \n", - "4 09/05/2016 John Graham Construction Ltd Causewayside Refurbishment \n", - "\n", - " Transaction value (£) \n", - "0 35098.0 \n", - "1 30000.0 \n", - "2 40800.0 \n", - "3 72835.0 \n", - "4 64361.0 " - ] - }, - "execution_count": 313, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "transactions.head()\n" - ] - }, - { - "cell_type": "code", - "execution_count": 335, - "metadata": {}, - "outputs": [], - "source": [ - "def request_completion(prompt):\n", - "\n", - " completion_response = openai.chat.completions.create(\n", - " prompt=prompt,\n", - " temperature=0,\n", - " max_tokens=5,\n", - " top_p=1,\n", - " frequency_penalty=0,\n", - " presence_penalty=0,\n", - " model=COMPLETIONS_MODEL)\n", - "\n", - " return completion_response\n", - "\n", - "def classify_transaction(transaction,prompt):\n", - "\n", - " prompt = prompt.replace('SUPPLIER_NAME',transaction['Supplier'])\n", - " prompt = prompt.replace('DESCRIPTION_TEXT',transaction['Description'])\n", - " prompt = prompt.replace('TRANSACTION_VALUE',str(transaction['Transaction value (£)']))\n", - "\n", - " classification = request_completion(prompt).choices[0].message.content.replace('\\n','')\n", - "\n", - " return classification\n", - "\n", - "# This function takes your training and validation outputs from the prepare_data function of the Finetuning API, and\n", - "# confirms that each have the same number of classes.\n", - "# If they do not have the same number of classes the fine-tune will fail and return an error\n", - "\n", - "def check_finetune_classes(train_file,valid_file):\n", - "\n", - " train_classes = set()\n", - " valid_classes = set()\n", - " with open(train_file, 'r') as json_file:\n", - " json_list = list(json_file)\n", - " print(len(json_list))\n", - "\n", - " for json_str in json_list:\n", - " result = json.loads(json_str)\n", - " train_classes.add(result['completion'])\n", - " #print(f\"result: {result['completion']}\")\n", - " #print(isinstance(result, dict))\n", - "\n", - " with open(valid_file, 'r') as json_file:\n", - " json_list = list(json_file)\n", - " print(len(json_list))\n", - "\n", - " for json_str in json_list:\n", - " result = json.loads(json_str)\n", - " valid_classes.add(result['completion'])\n", - " #print(f\"result: {result['completion']}\")\n", - " #print(isinstance(result, dict))\n", - "\n", - " if len(train_classes) == len(valid_classes):\n", - " print('All good')\n", - "\n", - " else:\n", - " print('Classes do not match, please prepare data again')\n" + "print(f\"Number of transactions: {len(transactions)}\")\n", + "print(transactions.head())\n" ] }, { @@ -262,7 +112,7 @@ }, { "cell_type": "code", - "execution_count": 277, + "execution_count": 154, "metadata": {}, "outputs": [], "source": [ @@ -273,38 +123,54 @@ "\n", "Transaction:\n", "\n", - "Supplier: SUPPLIER_NAME\n", - "Description: DESCRIPTION_TEXT\n", - "Value: TRANSACTION_VALUE\n", + "Supplier: {}\n", + "Description: {}\n", + "Value: {}\n", "\n", - "The classification is:'''\n" + "The classification is:'''\n", + "\n", + "def format_prompt(transaction):\n", + " return zero_shot_prompt.format(transaction['Supplier'], transaction['Description'], transaction['Transaction value (£)'])\n", + "\n", + "def classify_transaction(transaction):\n", + "\n", + " \n", + " prompt = format_prompt(transaction)\n", + " messages = [\n", + " {\"role\": \"system\", \"content\": prompt},\n", + " ]\n", + " completion_response = openai.chat.completions.create(\n", + " messages=messages,\n", + " temperature=0,\n", + " max_tokens=5,\n", + " top_p=1,\n", + " frequency_penalty=0,\n", + " presence_penalty=0,\n", + " model=COMPLETIONS_MODEL)\n", + " label = completion_response.choices[0].message.content.replace('\\n','')\n", + " return label\n" ] }, { "cell_type": "code", - "execution_count": 315, + "execution_count": 155, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - " Building Improvement\n" + "Transaction: M & J Ballantyne Ltd George IV Bridge Work 35098.0\n", + "Classification: Building Improvement\n" ] } ], "source": [ "# Get a test transaction\n", "transaction = transactions.iloc[0]\n", - "\n", - "# Interpolate the values into the prompt\n", - "prompt = zero_shot_prompt.replace('SUPPLIER_NAME',transaction['Supplier'])\n", - "prompt = prompt.replace('DESCRIPTION_TEXT',transaction['Description'])\n", - "prompt = prompt.replace('TRANSACTION_VALUE',str(transaction['Transaction value (£)']))\n", - "\n", "# Use our completion function to return a prediction\n", - "completion_response = request_completion(prompt)\n", - "print(completion_response.choices[0].text)\n" + "print(f\"Transaction: {transaction['Supplier']} {transaction['Description']} {transaction['Transaction value (£)']}\")\n", + "print(f\"Classification: {classify_transaction(transaction)}\")\n" ] }, { @@ -319,44 +185,45 @@ }, { "cell_type": "code", - "execution_count": 291, + "execution_count": 156, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/ipykernel_launcher.py:2: SettingWithCopyWarning: \n", + "/var/folders/3n/79rgh27s6l7_l91b9shw0_nr0000gp/T/ipykernel_81921/2775604370.py:2: SettingWithCopyWarning: \n", "A value is trying to be set on a copy of a slice from a DataFrame.\n", "Try using .loc[row_indexer,col_indexer] = value instead\n", "\n", "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", - " \n" + " test_transactions['Classification'] = test_transactions.apply(lambda x: classify_transaction(x),axis=1)\n" ] } ], "source": [ "test_transactions = transactions.iloc[:25]\n", - "test_transactions['Classification'] = test_transactions.apply(lambda x: classify_transaction(x,zero_shot_prompt),axis=1)\n" + "test_transactions['Classification'] = test_transactions.apply(lambda x: classify_transaction(x),axis=1)\n" ] }, { "cell_type": "code", - "execution_count": 292, + "execution_count": 157, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - " Building Improvement 14\n", - " Could not classify 5\n", - " Literature & Archive 3\n", - " Software/IT 2\n", - " Utility Bills 1\n", - "Name: Classification, dtype: int64" + "Classification\n", + "Building Improvement 17\n", + "Literature & Archive 3\n", + "Software/IT 2\n", + "Could not classify 2\n", + "Utility Bills 1\n", + "Name: count, dtype: int64" ] }, - "execution_count": 292, + "execution_count": 157, "metadata": {}, "output_type": "execute_result" } @@ -367,7 +234,7 @@ }, { "cell_type": "code", - "execution_count": 293, + "execution_count": 158, "metadata": {}, "outputs": [ { @@ -493,7 +360,7 @@ " Wavetek Ltd\n", " Kelvin Hall\n", " 87589.0\n", - " Could not classify\n", + " Building Improvement\n", " \n", " \n", " 12\n", @@ -525,7 +392,7 @@ " Wavetek Ltd\n", " Kelvin Hall\n", " 65692.0\n", - " Could not classify\n", + " Building Improvement\n", " \n", " \n", " 16\n", @@ -581,7 +448,7 @@ " Creative Video Productions Ltd\n", " Kelvin Hall\n", " 26866.0\n", - " Could not classify\n", + " Building Improvement\n", " \n", " \n", " 23\n", @@ -631,35 +498,35 @@ "23 15/08/2016 John Graham Construction Ltd Causewayside Refurbishment \n", "24 24/08/2016 ECG Facilities Service Facilities Management Charge \n", "\n", - " Transaction value (£) Classification \n", - "0 35098.0 Building Improvement \n", - "1 30000.0 Literature & Archive \n", - "2 40800.0 Utility Bills \n", - "3 72835.0 Software/IT \n", - "4 64361.0 Building Improvement \n", - "5 53690.0 Building Improvement \n", - "6 365344.0 Building Improvement \n", - "7 26506.0 Software/IT \n", - "8 32777.0 Building Improvement \n", - "9 32777.0 Building Improvement \n", - "10 32317.0 Could not classify \n", - "11 87589.0 Could not classify \n", - "12 381803.0 Building Improvement \n", - "13 32832.0 Building Improvement \n", - "14 1700000.0 Building Improvement \n", - "15 65692.0 Could not classify \n", - "16 139845.0 Building Improvement \n", - "17 28500.0 Literature & Archive \n", - "18 33800.0 Literature & Archive \n", - "19 30113.0 Building Improvement \n", - "20 32317.0 Could not classify \n", - "21 32795.0 Building Improvement \n", - "22 26866.0 Could not classify \n", - "23 196807.0 Building Improvement \n", - "24 32795.0 Building Improvement " + " Transaction value (£) Classification \n", + "0 35098.0 Building Improvement \n", + "1 30000.0 Literature & Archive \n", + "2 40800.0 Utility Bills \n", + "3 72835.0 Software/IT \n", + "4 64361.0 Building Improvement \n", + "5 53690.0 Building Improvement \n", + "6 365344.0 Building Improvement \n", + "7 26506.0 Software/IT \n", + "8 32777.0 Building Improvement \n", + "9 32777.0 Building Improvement \n", + "10 32317.0 Could not classify \n", + "11 87589.0 Building Improvement \n", + "12 381803.0 Building Improvement \n", + "13 32832.0 Building Improvement \n", + "14 1700000.0 Building Improvement \n", + "15 65692.0 Building Improvement \n", + "16 139845.0 Building Improvement \n", + "17 28500.0 Literature & Archive \n", + "18 33800.0 Literature & Archive \n", + "19 30113.0 Building Improvement \n", + "20 32317.0 Could not classify \n", + "21 32795.0 Building Improvement \n", + "22 26866.0 Building Improvement \n", + "23 196807.0 Building Improvement \n", + "24 32795.0 Building Improvement " ] }, - "execution_count": 293, + "execution_count": 158, "metadata": {}, "output_type": "execute_result" } @@ -692,7 +559,7 @@ }, { "cell_type": "code", - "execution_count": 317, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -784,19 +651,19 @@ "4 27926 Building Improvement " ] }, - "execution_count": 317, + "execution_count": 159, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = pd.read_csv('./data/labelled_transactions.csv')\n", - "df.head()\n" + "df.head()" ] }, { "cell_type": "code", - "execution_count": 318, + "execution_count": 160, "metadata": {}, "outputs": [ { @@ -865,19 +732,19 @@ "1 Supplier: John Graham Construction Ltd; Descri... " ] }, - "execution_count": 318, + "execution_count": 160, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df['combined'] = \"Supplier: \" + df['Supplier'].str.strip() + \"; Description: \" + df['Description'].str.strip() + \"; Value: \" + str(df['Transaction value (£)']).strip()\n", - "df.head(2)\n" + "df.head(2)" ] }, { "cell_type": "code", - "execution_count": 319, + "execution_count": 161, "metadata": {}, "outputs": [ { @@ -886,7 +753,7 @@ "101" ] }, - "execution_count": 319, + "execution_count": 161, "metadata": {}, "output_type": "execute_result" } @@ -896,28 +763,27 @@ "tokenizer = GPT2TokenizerFast.from_pretrained(\"gpt2\")\n", "\n", "df['n_tokens'] = df.combined.apply(lambda x: len(tokenizer.encode(x)))\n", - "len(df)\n" + "len(df)" ] }, { "cell_type": "code", - "execution_count": 320, + "execution_count": 162, "metadata": {}, "outputs": [], "source": [ - "embedding_path = './data/transactions_with_embeddings_100.csv'\n" + "embedding_path = './data/transactions_with_embeddings_100.csv'" ] }, { "cell_type": "code", - "execution_count": 321, + "execution_count": 163, "metadata": {}, "outputs": [], "source": [ "from utils.embeddings_utils import get_embedding\n", - "\n", - "df['babbage_similarity'] = df.combined.apply(lambda x: get_embedding(x, model='gpt-4'))\n", - "df['babbage_search'] = df.combined.apply(lambda x: get_embedding(x, model='gpt-4'))\n", + "df['babbage_similarity'] = df.combined.apply(lambda x: get_embedding(x))\n", + "df['babbage_search'] = df.combined.apply(lambda x: get_embedding(x))\n", "df.to_csv(embedding_path)\n" ] }, @@ -935,7 +801,7 @@ }, { "cell_type": "code", - "execution_count": 309, + "execution_count": 164, "metadata": {}, "outputs": [ { @@ -982,8 +848,8 @@ " Other\n", " Supplier: Creative Video Productions Ltd; Desc...\n", " 136\n", - " [-0.009802100248634815, 0.022551486268639565, ...\n", - " [-0.00232666521333158, 0.019198870286345482, 0...\n", + " [-0.02898375503718853, -0.02881557121872902, 0...\n", + " [-0.02879939414560795, -0.02867320366203785, 0...\n", " \n", " \n", " 1\n", @@ -995,8 +861,8 @@ " Building Improvement\n", " Supplier: John Graham Construction Ltd; Descri...\n", " 140\n", - " [-0.009065819904208183, 0.012094118632376194, ...\n", - " [0.005169447045773268, 0.00473341578617692, -0...\n", + " [-0.024112487211823463, -0.02881261520087719, ...\n", + " [-0.024112487211823463, -0.02881261520087719, ...\n", " \n", " \n", " 2\n", @@ -1008,8 +874,8 @@ " Building Improvement\n", " Supplier: Morris & Spottiswood Ltd; Descriptio...\n", " 141\n", - " [-0.009000026620924473, 0.02405017428100109, -...\n", - " [0.0028343256562948227, 0.021166473627090454, ...\n", + " [0.013581369072198868, -0.003978211898356676, ...\n", + " [0.013593776151537895, -0.0037341134157031775,...\n", " \n", " \n", " 3\n", @@ -1021,8 +887,8 @@ " Building Improvement\n", " Supplier: John Graham Construction Ltd; Descri...\n", " 140\n", - " [-0.009065819904208183, 0.012094118632376194, ...\n", - " [0.005169447045773268, 0.00473341578617692, -0...\n", + " [-0.024112487211823463, -0.02881261520087719, ...\n", + " [-0.024112487211823463, -0.02881261520087719, ...\n", " \n", " \n", " 4\n", @@ -1034,8 +900,8 @@ " Building Improvement\n", " Supplier: John Graham Construction Ltd; Descri...\n", " 140\n", - " [-0.009065819904208183, 0.012094118632376194, ...\n", - " [0.005169447045773268, 0.00473341578617692, -0...\n", + " [-0.02408558875322342, -0.02881370671093464, 0...\n", + " [-0.024109570309519768, -0.02880912832915783, ...\n", " \n", " \n", "\n", @@ -1064,21 +930,21 @@ "4 Supplier: John Graham Construction Ltd; Descri... 140 \n", "\n", " babbage_similarity \\\n", - "0 [-0.009802100248634815, 0.022551486268639565, ... \n", - "1 [-0.009065819904208183, 0.012094118632376194, ... \n", - "2 [-0.009000026620924473, 0.02405017428100109, -... \n", - "3 [-0.009065819904208183, 0.012094118632376194, ... \n", - "4 [-0.009065819904208183, 0.012094118632376194, ... \n", + "0 [-0.02898375503718853, -0.02881557121872902, 0... \n", + "1 [-0.024112487211823463, -0.02881261520087719, ... \n", + "2 [0.013581369072198868, -0.003978211898356676, ... \n", + "3 [-0.024112487211823463, -0.02881261520087719, ... \n", + "4 [-0.02408558875322342, -0.02881370671093464, 0... \n", "\n", " babbage_search \n", - "0 [-0.00232666521333158, 0.019198870286345482, 0... \n", - "1 [0.005169447045773268, 0.00473341578617692, -0... \n", - "2 [0.0028343256562948227, 0.021166473627090454, ... \n", - "3 [0.005169447045773268, 0.00473341578617692, -0... \n", - "4 [0.005169447045773268, 0.00473341578617692, -0... " + "0 [-0.02879939414560795, -0.02867320366203785, 0... \n", + "1 [-0.024112487211823463, -0.02881261520087719, ... \n", + "2 [0.013593776151537895, -0.0037341134157031775,... \n", + "3 [-0.024112487211823463, -0.02881261520087719, ... \n", + "4 [-0.024109570309519768, -0.02880912832915783, ... " ] }, - "execution_count": 309, + "execution_count": 164, "metadata": {}, "output_type": "execute_result" } @@ -1086,7 +952,7 @@ "source": [ "from sklearn.ensemble import RandomForestClassifier\n", "from sklearn.model_selection import train_test_split\n", - "from sklearn.metrics import classification_report, accuracy_score\n", + "from sklearn.metrics import classification_report\n", "from ast import literal_eval\n", "\n", "fs_df = pd.read_csv(embedding_path)\n", @@ -1096,7 +962,7 @@ }, { "cell_type": "code", - "execution_count": 310, + "execution_count": 165, "metadata": {}, "outputs": [ { @@ -1121,12 +987,12 @@ "name": "stderr", "output_type": "stream", "text": [ - "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n", - " _warn_prf(average, modifier, msg_start, len(result))\n", - "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n", - " _warn_prf(average, modifier, msg_start, len(result))\n", - "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n", - " _warn_prf(average, modifier, msg_start, len(result))\n" + "/Users/vishnu/code/openai-cookbook/cookbook_env/lib/python3.11/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n", + " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n", + "/Users/vishnu/code/openai-cookbook/cookbook_env/lib/python3.11/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n", + " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n", + "/Users/vishnu/code/openai-cookbook/cookbook_env/lib/python3.11/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n", + " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n" ] } ], @@ -1172,14 +1038,13 @@ "### Building Fine-tuned Classifier\n", "\n", "We'll need to do some data prep first to get our data ready. This will take the following steps:\n", - "- First we'll list out our classes and replace them with numeric identifiers. Making the model predict a single token rather than multiple consecutive ones like 'Building Improvement' should give us better results\n", - "- We also need to add a common prefix and suffix to each example to aid the model in making predictions - in our case our text is already started with 'Supplier' and we'll add a suffix of '\\n\\n###\\n\\n'\n", - "- Lastly we'll aid a leading whitespace onto each of our target classes for classification, again to aid the model" + "- To prepare our training and validation sets, we'll create a set of message sequences. The first message for each will be the user prompt formatted with the details of the transaction, and the final message will be the expected classification response from the model\n", + "- Our test set will contain the initial user prompt for each transaction, along with the corresponding expected class label. We will then use the fine-tuned model to generate the actual classification for each transaction." ] }, { "cell_type": "code", - "execution_count": 210, + "execution_count": 64, "metadata": {}, "outputs": [ { @@ -1188,7 +1053,7 @@ "101" ] }, - "execution_count": 210, + "execution_count": 64, "metadata": {}, "output_type": "execute_result" } @@ -1200,7 +1065,7 @@ }, { "cell_type": "code", - "execution_count": 211, + "execution_count": 65, "metadata": {}, "outputs": [ { @@ -1246,9 +1111,9 @@ " 26866\n", " Other\n", " Supplier: Creative Video Productions Ltd; Desc...\n", - " 12\n", - " [-0.009630300104618073, 0.009887108579277992, ...\n", - " [-0.008217384107410908, 0.025170527398586273, ...\n", + " 136\n", + " [-0.028885245323181152, -0.028660893440246582,...\n", + " [-0.02879939414560795, -0.02867320366203785, 0...\n", " \n", " \n", " 1\n", @@ -1259,9 +1124,9 @@ " 74806\n", " Building Improvement\n", " Supplier: John Graham Construction Ltd; Descri...\n", - " 16\n", - " [-0.006144719664007425, -0.0018709596479311585...\n", - " [-0.007424891460686922, 0.008475713431835175, ...\n", + " 140\n", + " [-0.024112487211823463, -0.02881261520087719, ...\n", + " [-0.02414606139063835, -0.02883070334792137, 0...\n", " \n", " \n", " 2\n", @@ -1272,9 +1137,9 @@ " 56448\n", " Building Improvement\n", " Supplier: Morris & Spottiswood Ltd; Descriptio...\n", - " 17\n", - " [-0.005225738976150751, 0.015156379900872707, ...\n", - " [-0.007611643522977829, 0.030322374776005745, ...\n", + " 141\n", + " [0.013593776151537895, -0.0037341134157031775,...\n", + " [0.013561442494392395, -0.004199974238872528, ...\n", " \n", " \n", " 3\n", @@ -1285,9 +1150,9 @@ " 164691\n", " Building Improvement\n", " Supplier: John Graham Construction Ltd; Descri...\n", - " 16\n", - " [-0.006144719664007425, -0.0018709596479311585...\n", - " [-0.007424891460686922, 0.008475713431835175, ...\n", + " 140\n", + " [-0.024112487211823463, -0.02881261520087719, ...\n", + " [-0.024112487211823463, -0.02881261520087719, ...\n", " \n", " \n", " 4\n", @@ -1298,9 +1163,9 @@ " 27926\n", " Building Improvement\n", " Supplier: John Graham Construction Ltd; Descri...\n", - " 16\n", - " [-0.006144719664007425, -0.0018709596479311585...\n", - " [-0.007424891460686922, 0.008475713431835175, ...\n", + " 140\n", + " [-0.024112487211823463, -0.02881261520087719, ...\n", + " [-0.024112487211823463, -0.02881261520087719, ...\n", " \n", " \n", "\n", @@ -1322,28 +1187,28 @@ "4 Causewayside Refurbishment 27926 Building Improvement \n", "\n", " combined n_tokens \\\n", - "0 Supplier: Creative Video Productions Ltd; Desc... 12 \n", - "1 Supplier: John Graham Construction Ltd; Descri... 16 \n", - "2 Supplier: Morris & Spottiswood Ltd; Descriptio... 17 \n", - "3 Supplier: John Graham Construction Ltd; Descri... 16 \n", - "4 Supplier: John Graham Construction Ltd; Descri... 16 \n", + "0 Supplier: Creative Video Productions Ltd; Desc... 136 \n", + "1 Supplier: John Graham Construction Ltd; Descri... 140 \n", + "2 Supplier: Morris & Spottiswood Ltd; Descriptio... 141 \n", + "3 Supplier: John Graham Construction Ltd; Descri... 140 \n", + "4 Supplier: John Graham Construction Ltd; Descri... 140 \n", "\n", " babbage_similarity \\\n", - "0 [-0.009630300104618073, 0.009887108579277992, ... \n", - "1 [-0.006144719664007425, -0.0018709596479311585... \n", - "2 [-0.005225738976150751, 0.015156379900872707, ... \n", - "3 [-0.006144719664007425, -0.0018709596479311585... \n", - "4 [-0.006144719664007425, -0.0018709596479311585... \n", + "0 [-0.028885245323181152, -0.028660893440246582,... \n", + "1 [-0.024112487211823463, -0.02881261520087719, ... \n", + "2 [0.013593776151537895, -0.0037341134157031775,... \n", + "3 [-0.024112487211823463, -0.02881261520087719, ... \n", + "4 [-0.024112487211823463, -0.02881261520087719, ... \n", "\n", " babbage_search \n", - "0 [-0.008217384107410908, 0.025170527398586273, ... \n", - "1 [-0.007424891460686922, 0.008475713431835175, ... \n", - "2 [-0.007611643522977829, 0.030322374776005745, ... \n", - "3 [-0.007424891460686922, 0.008475713431835175, ... \n", - "4 [-0.007424891460686922, 0.008475713431835175, ... " + "0 [-0.02879939414560795, -0.02867320366203785, 0... \n", + "1 [-0.02414606139063835, -0.02883070334792137, 0... \n", + "2 [0.013561442494392395, -0.004199974238872528, ... \n", + "3 [-0.024112487211823463, -0.02881261520087719, ... \n", + "4 [-0.024112487211823463, -0.02881261520087719, ... " ] }, - "execution_count": 211, + "execution_count": 65, "metadata": {}, "output_type": "execute_result" } @@ -1354,22 +1219,22 @@ }, { "cell_type": "code", - "execution_count": 212, + "execution_count": 96, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "( class_id class\n", - " 0 0 Literature & Archive\n", - " 1 1 Utility Bills\n", - " 2 2 Building Improvement\n", - " 3 3 Software/IT\n", - " 4 4 Other,\n", + " 0 0 Other\n", + " 1 1 Literature & Archive\n", + " 2 2 Software/IT\n", + " 3 3 Utility Bills\n", + " 4 4 Building Improvement,\n", " 5)" ] }, - "execution_count": 212, + "execution_count": 96, "metadata": {}, "output_type": "execute_result" } @@ -1383,7 +1248,7 @@ }, { "cell_type": "code", - "execution_count": 215, + "execution_count": 181, "metadata": {}, "outputs": [ { @@ -1407,145 +1272,50 @@ " \n", " \n", " \n", - " Unnamed: 0\n", - " Date\n", - " Supplier\n", - " Description\n", - " Transaction value (£)\n", - " Classification\n", - " combined\n", - " n_tokens\n", - " babbage_similarity\n", - " babbage_search\n", - " class_id\n", - " prompt\n", + " messages\n", + " class\n", " \n", " \n", " \n", " \n", " 0\n", - " 0\n", - " 15/08/2016\n", - " Creative Video Productions Ltd\n", - " Kelvin Hall\n", - " 26866\n", + " [{'role': 'user', 'content': 'You are a data e...\n", " Other\n", - " Supplier: Creative Video Productions Ltd; Desc...\n", - " 12\n", - " [-0.009630300104618073, 0.009887108579277992, ...\n", - " [-0.008217384107410908, 0.025170527398586273, ...\n", - " 4\n", - " Supplier: Creative Video Productions Ltd; Desc...\n", " \n", " \n", " 1\n", - " 51\n", - " 31/03/2017\n", - " NLS Foundation\n", - " Grant Payment\n", - " 177500\n", - " Other\n", - " Supplier: NLS Foundation; Description: Grant P...\n", - " 11\n", - " [-0.022305507212877274, 0.008543581701815128, ...\n", - " [-0.020519884303212166, 0.01993306167423725, -...\n", - " 4\n", - " Supplier: NLS Foundation; Description: Grant P...\n", + " [{'role': 'user', 'content': 'You are a data e...\n", + " Building Improvement\n", " \n", " \n", " 2\n", - " 70\n", - " 26/06/2017\n", - " British Library\n", - " Legal Deposit Services\n", - " 50056\n", - " Other\n", - " Supplier: British Library; Description: Legal ...\n", - " 11\n", - " [-0.01019938476383686, 0.015277703292667866, -...\n", - " [-0.01843327097594738, 0.03343546763062477, -0...\n", - " 4\n", - " Supplier: British Library; Description: Legal ...\n", + " [{'role': 'user', 'content': 'You are a data e...\n", + " Building Improvement\n", " \n", " \n", " 3\n", - " 71\n", - " 24/07/2017\n", - " ALDL\n", - " Legal Deposit Services\n", - " 27067\n", - " Other\n", - " Supplier: ALDL; Description: Legal Deposit Ser...\n", - " 11\n", - " [-0.008471488021314144, 0.004098685923963785, ...\n", - " [-0.012966590002179146, 0.01299362163990736, 0...\n", - " 4\n", - " Supplier: ALDL; Description: Legal Deposit Ser...\n", + " [{'role': 'user', 'content': 'You are a data e...\n", + " Building Improvement\n", " \n", " \n", " 4\n", - " 100\n", - " 24/07/2017\n", - " AM Phillip\n", - " Vehicle Purchase\n", - " 26604\n", - " Other\n", - " Supplier: AM Phillip; Description: Vehicle Pur...\n", - " 10\n", - " [-0.003459023078903556, 0.004626389592885971, ...\n", - " [-0.0010945454705506563, 0.008626140654087067,...\n", - " 4\n", - " Supplier: AM Phillip; Description: Vehicle Pur...\n", + " [{'role': 'user', 'content': 'You are a data e...\n", + " Building Improvement\n", " \n", " \n", "\n", "" ], "text/plain": [ - " Unnamed: 0 Date Supplier \\\n", - "0 0 15/08/2016 Creative Video Productions Ltd \n", - "1 51 31/03/2017 NLS Foundation \n", - "2 70 26/06/2017 British Library \n", - "3 71 24/07/2017 ALDL \n", - "4 100 24/07/2017 AM Phillip \n", - "\n", - " Description Transaction value (£) Classification \\\n", - "0 Kelvin Hall 26866 Other \n", - "1 Grant Payment 177500 Other \n", - "2 Legal Deposit Services 50056 Other \n", - "3 Legal Deposit Services 27067 Other \n", - "4 Vehicle Purchase 26604 Other \n", - "\n", - " combined n_tokens \\\n", - "0 Supplier: Creative Video Productions Ltd; Desc... 12 \n", - "1 Supplier: NLS Foundation; Description: Grant P... 11 \n", - "2 Supplier: British Library; Description: Legal ... 11 \n", - "3 Supplier: ALDL; Description: Legal Deposit Ser... 11 \n", - "4 Supplier: AM Phillip; Description: Vehicle Pur... 10 \n", - "\n", - " babbage_similarity \\\n", - "0 [-0.009630300104618073, 0.009887108579277992, ... \n", - "1 [-0.022305507212877274, 0.008543581701815128, ... \n", - "2 [-0.01019938476383686, 0.015277703292667866, -... \n", - "3 [-0.008471488021314144, 0.004098685923963785, ... \n", - "4 [-0.003459023078903556, 0.004626389592885971, ... \n", - "\n", - " babbage_search class_id \\\n", - "0 [-0.008217384107410908, 0.025170527398586273, ... 4 \n", - "1 [-0.020519884303212166, 0.01993306167423725, -... 4 \n", - "2 [-0.01843327097594738, 0.03343546763062477, -0... 4 \n", - "3 [-0.012966590002179146, 0.01299362163990736, 0... 4 \n", - "4 [-0.0010945454705506563, 0.008626140654087067,... 4 \n", - "\n", - " prompt \n", - "0 Supplier: Creative Video Productions Ltd; Desc... \n", - "1 Supplier: NLS Foundation; Description: Grant P... \n", - "2 Supplier: British Library; Description: Legal ... \n", - "3 Supplier: ALDL; Description: Legal Deposit Ser... \n", - "4 Supplier: AM Phillip; Description: Vehicle Pur... " + " messages class\n", + "0 [{'role': 'user', 'content': 'You are a data e... Other\n", + "1 [{'role': 'user', 'content': 'You are a data e... Building Improvement\n", + "2 [{'role': 'user', 'content': 'You are a data e... Building Improvement\n", + "3 [{'role': 'user', 'content': 'You are a data e... Building Improvement\n", + "4 [{'role': 'user', 'content': 'You are a data e... Building Improvement" ] }, - "execution_count": 215, + "execution_count": 181, "metadata": {}, "output_type": "execute_result" } @@ -1553,110 +1323,40 @@ "source": [ "ft_df_with_class = ft_prep_df.merge(class_df,left_on='Classification',right_on='class',how='inner')\n", "\n", - "# Adding a leading whitespace onto each completion to help the model\n", - "ft_df_with_class['class_id'] = ft_df_with_class.apply(lambda x: ' ' + str(x['class_id']),axis=1)\n", - "ft_df_with_class = ft_df_with_class.drop('class', axis=1)\n", - "\n", - "# Adding a common separator onto the end of each prompt so the model knows when a prompt is terminating\n", - "ft_df_with_class['prompt'] = ft_df_with_class.apply(lambda x: x['combined'] + '\\n\\n###\\n\\n',axis=1)\n", - "ft_df_with_class.head()\n" + "# Creating a list of messages for the fine-tuning job. The user message is the prompt, and the assistant message is the response from the model\n", + "ft_df_with_class['messages'] = ft_df_with_class.apply(lambda x: [{\"role\": \"user\", \"content\": format_prompt(x)}, {\"role\": \"assistant\", \"content\": x['class']}],axis=1)\n", + "ft_df_with_class[['messages', 'class']].head()\n" ] }, { "cell_type": "code", - "execution_count": 236, + "execution_count": 169, "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
promptcompletion
ordering
0Supplier: Sothebys; Description: Literary & Ar...0
1Supplier: Sotheby'S; Description: Literary & A...0
2Supplier: City Of Edinburgh Council; Descripti...1
2Supplier: John Graham Construction Ltd; Descri...2
3Supplier: John Graham Construction Ltd; Descri...2
\n", - "
" - ], - "text/plain": [ - " prompt completion\n", - "ordering \n", - "0 Supplier: Sothebys; Description: Literary & Ar... 0\n", - "1 Supplier: Sotheby'S; Description: Literary & A... 0\n", - "2 Supplier: City Of Edinburgh Council; Descripti... 1\n", - "2 Supplier: John Graham Construction Ltd; Descri... 2\n", - "3 Supplier: John Graham Construction Ltd; Descri... 2" - ] - }, - "execution_count": 236, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "# This step is unnecessary if you have a number of observations in each class\n", - "# In our case we don't, so we shuffle the data to give us a better chance of getting equal classes in our train and validation sets\n", - "# Our fine-tuned model will error if we have less classes in the validation set, so this is a necessary step\n", + "# Create train/validation split\n", + "samples = ft_df_with_class[\"messages\"].tolist()\n", + "train_df, valid_df = train_test_split(samples, test_size=0.2, random_state=42)\n", "\n", - "import random\n", - "\n", - "labels = [x for x in ft_df_with_class['class_id']]\n", - "text = [x for x in ft_df_with_class['prompt']]\n", - "ft_df = pd.DataFrame(zip(text, labels), columns = ['prompt','class_id']) #[:300]\n", - "ft_df.columns = ['prompt','completion']\n", - "ft_df['ordering'] = ft_df.apply(lambda x: random.randint(0,len(ft_df)), axis = 1)\n", - "ft_df.set_index('ordering',inplace=True)\n", - "ft_df_sorted = ft_df.sort_index(ascending=True)\n", - "ft_df_sorted.head()\n" + "def write_to_jsonl(list_of_messages, filename):\n", + " with open(filename, \"w+\") as f:\n", + " for messages in list_of_messages:\n", + " object = { \n", + " \"messages\": messages\n", + " }\n", + " f.write(json.dumps(object) + \"\\n\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 186, + "metadata": {}, + "outputs": [], + "source": [ + "# Write the train/validation split to jsonl files\n", + "train_file_name, valid_file_name = \"transactions_grouped_train.jsonl\", \"transactions_grouped_valid.jsonl\"\n", + "write_to_jsonl(train_df, train_file_name)\n", + "write_to_jsonl(valid_df, valid_file_name)\n" ] }, { @@ -1665,57 +1365,40 @@ "metadata": {}, "outputs": [], "source": [ - "# This step is to remove any existing files if we've already produced training/validation sets for this classifier\n", - "#!rm transactions_grouped*\n", - "\n", - "# We output our shuffled dataframe to a .jsonl file and run the prepare_data function to get us our input files\n", - "ft_df_sorted.to_json(\"transactions_grouped.jsonl\", orient='records', lines=True)\n", - "!openai tools fine_tunes.prepare_data -f transactions_grouped.jsonl -q\n" + "# Upload the files to OpenAI\n", + "train_file = client.files.create(file=open(train_file_name, \"rb\"), purpose=\"fine-tune\")\n", + "valid_file = client.files.create(file=open(valid_file_name, \"rb\"), purpose=\"fine-tune\")" ] }, { "cell_type": "code", - "execution_count": 322, + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create the fine-tuning job\n", + "fine_tuning_job = client.fine_tuning.jobs.create(training_file=train_file.id, validation_file=valid_file.id, model=\"gpt-4o-2024-08-06\")\n", + "# Get the fine-tuning job status and model name\n", + "status = client.fine_tuning.jobs.retrieve(fine_tuning_job.id)" + ] + }, + { + "cell_type": "code", + "execution_count": 209, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "31\n", - "8\n", - "All good\n" + "Fine tuned model id: ft:gpt-4o-2024-08-06:openai::BKr3Xy8U\n" ] } ], "source": [ - "# This functions checks that your classes all appear in both prepared files\n", - "# If they don't, the fine-tuned model creation will fail\n", - "check_finetune_classes('transactions_grouped_prepared_train.jsonl','transactions_grouped_prepared_valid.jsonl')\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# This step creates your model\n", - "!openai api fine_tunes.create -t \"transactions_grouped_prepared_train.jsonl\" -v \"transactions_grouped_prepared_valid.jsonl\" --compute_classification_metrics --classification_n_classes 5 -m curie\n", - "\n", - "# You can use following command to get fine tuning job status and model name, replace the job name with your job\n", - "#!openai api fine_tunes.get -i ft-YBIc01t4hxYBC7I5qhRF3Qdx\n" - ] - }, - { - "cell_type": "code", - "execution_count": 323, - "metadata": {}, - "outputs": [], - "source": [ - "# Congrats, you've got a fine-tuned model!\n", - "# Copy/paste the name provided into the variable below and we'll take it for a spin\n", - "fine_tuned_model = 'curie:ft-personal-2022-10-20-10-42-56'\n" + "# Once the fine-tuning job is complete, you can retrieve the model name from the job status\n", + "fine_tuned_model = client.fine_tuning.jobs.retrieve(fine_tuning_job.id).fine_tuned_model\n", + "print(f\"Fine tuned model id: {fine_tuned_model}\")" ] }, { @@ -1730,7 +1413,7 @@ }, { "cell_type": "code", - "execution_count": 324, + "execution_count": 210, "metadata": {}, "outputs": [ { @@ -1754,113 +1437,64 @@ " \n", " \n", " \n", - " prompt\n", - " completion\n", + " messages\n", + " expected_class\n", " \n", " \n", " \n", " \n", " 0\n", - " Supplier: Wavetek Ltd; Description: Kelvin Hal...\n", - " 2\n", + " [{'role': 'user', 'content': 'You are a data e...\n", + " Utility Bills\n", " \n", " \n", " 1\n", - " Supplier: ECG Facilities Service; Description:...\n", - " 1\n", + " [{'role': 'user', 'content': 'You are a data e...\n", + " Literature & Archive\n", " \n", " \n", " 2\n", - " Supplier: M & J Ballantyne Ltd; Description: G...\n", - " 2\n", + " [{'role': 'user', 'content': 'You are a data e...\n", + " Literature & Archive\n", " \n", " \n", " 3\n", - " Supplier: Private Sale; Description: Literary ...\n", - " 0\n", + " [{'role': 'user', 'content': 'You are a data e...\n", + " Literature & Archive\n", " \n", " \n", " 4\n", - " Supplier: Ex Libris; Description: IT equipment...\n", - " 3\n", + " [{'role': 'user', 'content': 'You are a data e...\n", + " Building Improvement\n", " \n", " \n", "\n", "" ], "text/plain": [ - " prompt completion\n", - "0 Supplier: Wavetek Ltd; Description: Kelvin Hal... 2\n", - "1 Supplier: ECG Facilities Service; Description:... 1\n", - "2 Supplier: M & J Ballantyne Ltd; Description: G... 2\n", - "3 Supplier: Private Sale; Description: Literary ... 0\n", - "4 Supplier: Ex Libris; Description: IT equipment... 3" + " messages expected_class\n", + "0 [{'role': 'user', 'content': 'You are a data e... Utility Bills\n", + "1 [{'role': 'user', 'content': 'You are a data e... Literature & Archive\n", + "2 [{'role': 'user', 'content': 'You are a data e... Literature & Archive\n", + "3 [{'role': 'user', 'content': 'You are a data e... Literature & Archive\n", + "4 [{'role': 'user', 'content': 'You are a data e... Building Improvement" ] }, - "execution_count": 324, + "execution_count": 210, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "test_set = pd.read_json('transactions_grouped_prepared_valid.jsonl', lines=True)\n", - "test_set.head()\n" + "# Create a test set with the expected class labels\n", + "test_set = pd.read_json(valid_file_name, lines=True)\n", + "test_set['expected_class'] = test_set.apply(lambda x: x['messages'][-1]['content'], axis=1)\n", + "test_set.head()" ] }, { "cell_type": "code", - "execution_count": 325, - "metadata": {}, - "outputs": [], - "source": [ - "test_set['predicted_class'] = test_set.apply(lambda x: openai.chat.completions.create(model=fine_tuned_model, prompt=x['prompt'], max_tokens=1, temperature=0, logprobs=5),axis=1)\n", - "test_set['pred'] = test_set.apply(lambda x : x['predicted_class']['choices'][0]['text'],axis=1)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 326, - "metadata": {}, - "outputs": [], - "source": [ - "test_set['result'] = test_set.apply(lambda x: str(x['pred']).strip() == str(x['completion']).strip(), axis = 1)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 327, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True 4\n", - "False 4\n", - "Name: result, dtype: int64" - ] - }, - "execution_count": 327, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "test_set['result'].value_counts()\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Performance is not great - unfortunately this is expected. With only a few examples of each class, the above approach with embeddings and a traditional classifier worked better.\n", - "\n", - "A fine-tuned model works best with a great number of labelled observations. If we had a few hundred or thousand we may get better results, but lets do one last test on a holdout set to confirm that it doesn't generalise well to a new set of observations" - ] - }, - { - "cell_type": "code", - "execution_count": 330, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1884,312 +1518,115 @@ " \n", " \n", " \n", - " Date\n", - " Supplier\n", - " Description\n", - " Transaction value (£)\n", + " messages\n", + " expected_class\n", + " response\n", + " predicted_class\n", " \n", " \n", " \n", " \n", - " 101\n", - " 23/10/2017\n", - " City Building LLP\n", - " Causewayside Refurbishment\n", - " 53147.0\n", + " 0\n", + " [{'role': 'user', 'content': 'You are a data e...\n", + " Utility Bills\n", + " ChatCompletion(id='chatcmpl-BKrC0S1wQSfM9ZQfcC...\n", + " Utility Bills\n", " \n", " \n", - " 102\n", - " 30/10/2017\n", - " ECG Facilities Service\n", - " Facilities Management Charge\n", - " 35758.0\n", + " 1\n", + " [{'role': 'user', 'content': 'You are a data e...\n", + " Literature & Archive\n", + " ChatCompletion(id='chatcmpl-BKrC1BTr0DagbDkC2s...\n", + " Literature & Archive\n", " \n", " \n", - " 103\n", - " 30/10/2017\n", - " ECG Facilities Service\n", - " Facilities Management Charge\n", - " 35758.0\n", + " 2\n", + " [{'role': 'user', 'content': 'You are a data e...\n", + " Literature & Archive\n", + " ChatCompletion(id='chatcmpl-BKrC1H3ZeIW5cz2Owr...\n", + " Literature & Archive\n", " \n", " \n", - " 104\n", - " 06/11/2017\n", - " John Graham Construction Ltd\n", - " Causewayside Refurbishment\n", - " 134208.0\n", + " 3\n", + " [{'role': 'user', 'content': 'You are a data e...\n", + " Literature & Archive\n", + " ChatCompletion(id='chatcmpl-BKrC1wdhaMP0Q7YmYx...\n", + " Literature & Archive\n", " \n", " \n", - " 105\n", - " 06/11/2017\n", - " ALDL\n", - " Legal Deposit Services\n", - " 27067.0\n", + " 4\n", + " [{'role': 'user', 'content': 'You are a data e...\n", + " Building Improvement\n", + " ChatCompletion(id='chatcmpl-BKrC20c5pkpngy1xDu...\n", + " Building Improvement\n", " \n", " \n", "\n", "" ], "text/plain": [ - " Date Supplier Description \\\n", - "101 23/10/2017 City Building LLP Causewayside Refurbishment \n", - "102 30/10/2017 ECG Facilities Service Facilities Management Charge \n", - "103 30/10/2017 ECG Facilities Service Facilities Management Charge \n", - "104 06/11/2017 John Graham Construction Ltd Causewayside Refurbishment \n", - "105 06/11/2017 ALDL Legal Deposit Services \n", + " messages expected_class \\\n", + "0 [{'role': 'user', 'content': 'You are a data e... Utility Bills \n", + "1 [{'role': 'user', 'content': 'You are a data e... Literature & Archive \n", + "2 [{'role': 'user', 'content': 'You are a data e... Literature & Archive \n", + "3 [{'role': 'user', 'content': 'You are a data e... Literature & Archive \n", + "4 [{'role': 'user', 'content': 'You are a data e... Building Improvement \n", "\n", - " Transaction value (£) \n", - "101 53147.0 \n", - "102 35758.0 \n", - "103 35758.0 \n", - "104 134208.0 \n", - "105 27067.0 " + " response predicted_class \n", + "0 ChatCompletion(id='chatcmpl-BKrC0S1wQSfM9ZQfcC... Utility Bills \n", + "1 ChatCompletion(id='chatcmpl-BKrC1BTr0DagbDkC2s... Literature & Archive \n", + "2 ChatCompletion(id='chatcmpl-BKrC1H3ZeIW5cz2Owr... Literature & Archive \n", + "3 ChatCompletion(id='chatcmpl-BKrC1wdhaMP0Q7YmYx... Literature & Archive \n", + "4 ChatCompletion(id='chatcmpl-BKrC20c5pkpngy1xDu... Building Improvement " ] }, - "execution_count": 330, + "execution_count": 211, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "holdout_df = transactions.copy().iloc[101:]\n", - "holdout_df.head()\n" - ] - }, - { - "cell_type": "code", - "execution_count": 332, - "metadata": {}, - "outputs": [], - "source": [ - "holdout_df['combined'] = \"Supplier: \" + holdout_df['Supplier'].str.strip() + \"; Description: \" + holdout_df['Description'].str.strip() + '\\n\\n###\\n\\n' # + \"; Value: \" + str(df['Transaction value (£)']).strip()\n", - "holdout_df['prediction_result'] = holdout_df.apply(lambda x: openai.chat.completions.create(model=fine_tuned_model, prompt=x['combined'], max_tokens=1, temperature=0, logprobs=5),axis=1)\n", - "holdout_df['pred'] = holdout_df.apply(lambda x : x['prediction_result']['choices'][0]['text'],axis=1)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 333, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
DateSupplierDescriptionTransaction value (£)combinedprediction_resultpred
10123/10/2017City Building LLPCausewayside Refurbishment53147.0Supplier: City Building LLP; Description: Caus...{'id': 'cmpl-63YDadbYLo8xKsGY2vReOFCMgTOvG', '...2
10230/10/2017ECG Facilities ServiceFacilities Management Charge35758.0Supplier: ECG Facilities Service; Description:...{'id': 'cmpl-63YDbNK1D7UikDc3xi5ATihg5kQEt', '...2
10330/10/2017ECG Facilities ServiceFacilities Management Charge35758.0Supplier: ECG Facilities Service; Description:...{'id': 'cmpl-63YDbwfiHjkjMWsfTKNt6naeqPzOe', '...2
10406/11/2017John Graham Construction LtdCausewayside Refurbishment134208.0Supplier: John Graham Construction Ltd; Descri...{'id': 'cmpl-63YDbWAndtsRqPTi2ZHZtPodZvOwr', '...2
10506/11/2017ALDLLegal Deposit Services27067.0Supplier: ALDL; Description: Legal Deposit Ser...{'id': 'cmpl-63YDbDu7WM3svYWsRAMdDUKtSFDBu', '...2
10627/11/2017Maggs Bros LtdLiterary & Archival Items26500.0Supplier: Maggs Bros Ltd; Description: Literar...{'id': 'cmpl-63YDbxNNI8ZH5CJJNxQ0IF9Zf925C', '...0
10730/11/2017Glasgow City CouncilKelvin Hall42345.0Supplier: Glasgow City Council; Description: K...{'id': 'cmpl-63YDb8R1FWu4bjwM2xE775rouwneV', '...2
10811/12/2017ECG Facilities ServiceFacilities Management Charge35758.0Supplier: ECG Facilities Service; Description:...{'id': 'cmpl-63YDcAPsp37WhbPs9kwfUX0kBk7Hv', '...2
10911/12/2017John Graham Construction LtdCausewayside Refurbishment159275.0Supplier: John Graham Construction Ltd; Descri...{'id': 'cmpl-63YDcML2welrC3wF0nuKgcNmVu1oQ', '...2
11008/01/2018ECG Facilities ServiceFacilities Management Charge35758.0Supplier: ECG Facilities Service; Description:...{'id': 'cmpl-63YDc95SSdOHnIliFB2cjMEEm7Z2u', '...2
\n", - "
" - ], - "text/plain": [ - " Date Supplier Description \\\n", - "101 23/10/2017 City Building LLP Causewayside Refurbishment \n", - "102 30/10/2017 ECG Facilities Service Facilities Management Charge \n", - "103 30/10/2017 ECG Facilities Service Facilities Management Charge \n", - "104 06/11/2017 John Graham Construction Ltd Causewayside Refurbishment \n", - "105 06/11/2017 ALDL Legal Deposit Services \n", - "106 27/11/2017 Maggs Bros Ltd Literary & Archival Items \n", - "107 30/11/2017 Glasgow City Council Kelvin Hall \n", - "108 11/12/2017 ECG Facilities Service Facilities Management Charge \n", - "109 11/12/2017 John Graham Construction Ltd Causewayside Refurbishment \n", - "110 08/01/2018 ECG Facilities Service Facilities Management Charge \n", - "\n", - " Transaction value (£) combined \\\n", - "101 53147.0 Supplier: City Building LLP; Description: Caus... \n", - "102 35758.0 Supplier: ECG Facilities Service; Description:... \n", - "103 35758.0 Supplier: ECG Facilities Service; Description:... \n", - "104 134208.0 Supplier: John Graham Construction Ltd; Descri... \n", - "105 27067.0 Supplier: ALDL; Description: Legal Deposit Ser... \n", - "106 26500.0 Supplier: Maggs Bros Ltd; Description: Literar... \n", - "107 42345.0 Supplier: Glasgow City Council; Description: K... \n", - "108 35758.0 Supplier: ECG Facilities Service; Description:... \n", - "109 159275.0 Supplier: John Graham Construction Ltd; Descri... \n", - "110 35758.0 Supplier: ECG Facilities Service; Description:... \n", - "\n", - " prediction_result pred \n", - "101 {'id': 'cmpl-63YDadbYLo8xKsGY2vReOFCMgTOvG', '... 2 \n", - "102 {'id': 'cmpl-63YDbNK1D7UikDc3xi5ATihg5kQEt', '... 2 \n", - "103 {'id': 'cmpl-63YDbwfiHjkjMWsfTKNt6naeqPzOe', '... 2 \n", - "104 {'id': 'cmpl-63YDbWAndtsRqPTi2ZHZtPodZvOwr', '... 2 \n", - "105 {'id': 'cmpl-63YDbDu7WM3svYWsRAMdDUKtSFDBu', '... 2 \n", - "106 {'id': 'cmpl-63YDbxNNI8ZH5CJJNxQ0IF9Zf925C', '... 0 \n", - "107 {'id': 'cmpl-63YDb8R1FWu4bjwM2xE775rouwneV', '... 2 \n", - "108 {'id': 'cmpl-63YDcAPsp37WhbPs9kwfUX0kBk7Hv', '... 2 \n", - "109 {'id': 'cmpl-63YDcML2welrC3wF0nuKgcNmVu1oQ', '... 2 \n", - "110 {'id': 'cmpl-63YDc95SSdOHnIliFB2cjMEEm7Z2u', '... 2 " - ] - }, - "execution_count": 333, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "holdout_df.head(10)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 334, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - " 2 231\n", - " 0 27\n", - "Name: pred, dtype: int64" - ] - }, - "execution_count": 334, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "holdout_df['pred'].value_counts()\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Well those results were similarly underwhelming - so we've learned that with a dataset with a small number of labelled observations, either zero-shot classification or traditional classification with embeddings return better results than a fine-tuned model.\n", + "# Apply the fine-tuned model to the test set\n", + "test_set['response'] = test_set.apply(lambda x: openai.chat.completions.create(model=fine_tuned_model, messages=x['messages'][:-1], temperature=0),axis=1)\n", + "test_set['predicted_class'] = test_set.apply(lambda x: x['response'].choices[0].message.content, axis=1)\n", "\n", - "A fine-tuned model is still a great tool, but is more effective when you have a larger number of labelled examples for each class that you're looking to classify" + "test_set.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 212, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "result\n", + "True 20\n", + "False 1\n", + "Name: count, dtype: int64\n", + "F1 Score: 0.9296066252587991\n", + "Raw Accuracy: 0.9523809523809523\n" + ] + } + ], + "source": [ + "# Calculate the accuracy of the predictions\n", + "from sklearn.metrics import f1_score\n", + "test_set['result'] = test_set.apply(lambda x: str(x['predicted_class']).strip() == str(x['expected_class']).strip(), axis = 1)\n", + "test_set['result'].value_counts()\n", + "\n", + "print(test_set['result'].value_counts())\n", + "\n", + "print(\"F1 Score: \", f1_score(test_set['expected_class'], test_set['predicted_class'], average=\"weighted\"))\n", + "print(\"Raw Accuracy: \", test_set['result'].value_counts()[True] / len(test_set))\n" ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "cookbook_env", "language": "python", "name": "python3" }, @@ -2203,7 +1640,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.3" + "version": "3.11.8" } }, "nbformat": 4,