mirror of
https://github.com/james-m-jordan/openai-cookbook.git
synced 2025-05-09 19:32:38 +00:00
Added embeddings_utils.py to utils directory, and updated references (#841)
This commit is contained in:
parent
ac4cb3b163
commit
7044eecb98
File diff suppressed because one or more lines are too long
@ -34,7 +34,7 @@
|
|||||||
"import pandas as pd\n",
|
"import pandas as pd\n",
|
||||||
"import tiktoken\n",
|
"import tiktoken\n",
|
||||||
"\n",
|
"\n",
|
||||||
"from openai.embeddings_utils import get_embedding\n"
|
"from utils.embeddings_utils import get_embedding\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -212,7 +212,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.9.9 (main, Dec 7 2021, 18:04:56) \n[Clang 13.0.0 (clang-1300.0.29.3)]"
|
"version": "3.9.16"
|
||||||
},
|
},
|
||||||
"orig_nbformat": 4,
|
"orig_nbformat": 4,
|
||||||
"vscode": {
|
"vscode": {
|
||||||
|
@ -30,8 +30,8 @@
|
|||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"%load_ext autoreload\n",
|
"%load_ext autoreload\n",
|
||||||
"%autoreload \n",
|
"%autoreload\n",
|
||||||
"%pip install openai 'openai[datalib]' 'openai[embeddings]' transformers"
|
"%pip install openai 'openai[datalib]' 'openai[embeddings]' transformers\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -47,7 +47,7 @@
|
|||||||
"import os\n",
|
"import os\n",
|
||||||
"\n",
|
"\n",
|
||||||
"openai.api_key = os.getenv(\"OPENAI_API_KEY\")\n",
|
"openai.api_key = os.getenv(\"OPENAI_API_KEY\")\n",
|
||||||
"COMPLETIONS_MODEL = \"text-davinci-002\""
|
"COMPLETIONS_MODEL = \"text-davinci-002\"\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -85,7 +85,7 @@
|
|||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"transactions = pd.read_csv('./data/25000_spend_dataset_current.csv', encoding= 'unicode_escape')\n",
|
"transactions = pd.read_csv('./data/25000_spend_dataset_current.csv', encoding= 'unicode_escape')\n",
|
||||||
"len(transactions)"
|
"len(transactions)\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -182,7 +182,7 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"transactions.head()"
|
"transactions.head()\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -192,7 +192,7 @@
|
|||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"def request_completion(prompt):\n",
|
"def request_completion(prompt):\n",
|
||||||
" \n",
|
"\n",
|
||||||
" completion_response = openai.Completion.create(\n",
|
" completion_response = openai.Completion.create(\n",
|
||||||
" prompt=prompt,\n",
|
" prompt=prompt,\n",
|
||||||
" temperature=0,\n",
|
" temperature=0,\n",
|
||||||
@ -202,17 +202,17 @@
|
|||||||
" presence_penalty=0,\n",
|
" presence_penalty=0,\n",
|
||||||
" model=COMPLETIONS_MODEL\n",
|
" model=COMPLETIONS_MODEL\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
" \n",
|
"\n",
|
||||||
" return completion_response\n",
|
" return completion_response\n",
|
||||||
"\n",
|
"\n",
|
||||||
"def classify_transaction(transaction,prompt):\n",
|
"def classify_transaction(transaction,prompt):\n",
|
||||||
" \n",
|
"\n",
|
||||||
" prompt = prompt.replace('SUPPLIER_NAME',transaction['Supplier'])\n",
|
" prompt = prompt.replace('SUPPLIER_NAME',transaction['Supplier'])\n",
|
||||||
" prompt = prompt.replace('DESCRIPTION_TEXT',transaction['Description'])\n",
|
" prompt = prompt.replace('DESCRIPTION_TEXT',transaction['Description'])\n",
|
||||||
" prompt = prompt.replace('TRANSACTION_VALUE',str(transaction['Transaction value (£)']))\n",
|
" prompt = prompt.replace('TRANSACTION_VALUE',str(transaction['Transaction value (£)']))\n",
|
||||||
" \n",
|
"\n",
|
||||||
" classification = request_completion(prompt)['choices'][0]['text'].replace('\\n','')\n",
|
" classification = request_completion(prompt)['choices'][0]['text'].replace('\\n','')\n",
|
||||||
" \n",
|
"\n",
|
||||||
" return classification\n",
|
" return classification\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# This function takes your training and validation outputs from the prepare_data function of the Finetuning API, and\n",
|
"# This function takes your training and validation outputs from the prepare_data function of the Finetuning API, and\n",
|
||||||
@ -242,12 +242,12 @@
|
|||||||
" valid_classes.add(result['completion'])\n",
|
" valid_classes.add(result['completion'])\n",
|
||||||
" #print(f\"result: {result['completion']}\")\n",
|
" #print(f\"result: {result['completion']}\")\n",
|
||||||
" #print(isinstance(result, dict))\n",
|
" #print(isinstance(result, dict))\n",
|
||||||
" \n",
|
"\n",
|
||||||
" if len(train_classes) == len(valid_classes):\n",
|
" if len(train_classes) == len(valid_classes):\n",
|
||||||
" print('All good')\n",
|
" print('All good')\n",
|
||||||
" \n",
|
"\n",
|
||||||
" else:\n",
|
" else:\n",
|
||||||
" print('Classes do not match, please prepare data again')"
|
" print('Classes do not match, please prepare data again')\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -266,18 +266,18 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"zero_shot_prompt = '''You are a data expert working for the National Library of Scotland. \n",
|
"zero_shot_prompt = '''You are a data expert working for the National Library of Scotland.\n",
|
||||||
"You are analysing all transactions over £25,000 in value and classifying them into one of five categories.\n",
|
"You are analysing all transactions over £25,000 in value and classifying them into one of five categories.\n",
|
||||||
"The five categories are Building Improvement, Literature & Archive, Utility Bills, Professional Services and Software/IT.\n",
|
"The five categories are Building Improvement, Literature & Archive, Utility Bills, Professional Services and Software/IT.\n",
|
||||||
"If you can't tell what it is, say Could not classify\n",
|
"If you can't tell what it is, say Could not classify\n",
|
||||||
" \n",
|
"\n",
|
||||||
"Transaction:\n",
|
"Transaction:\n",
|
||||||
" \n",
|
"\n",
|
||||||
"Supplier: SUPPLIER_NAME\n",
|
"Supplier: SUPPLIER_NAME\n",
|
||||||
"Description: DESCRIPTION_TEXT\n",
|
"Description: DESCRIPTION_TEXT\n",
|
||||||
"Value: TRANSACTION_VALUE\n",
|
"Value: TRANSACTION_VALUE\n",
|
||||||
" \n",
|
"\n",
|
||||||
"The classification is:'''"
|
"The classification is:'''\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -304,7 +304,7 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"# Use our completion function to return a prediction\n",
|
"# Use our completion function to return a prediction\n",
|
||||||
"completion_response = request_completion(prompt)\n",
|
"completion_response = request_completion(prompt)\n",
|
||||||
"print(completion_response['choices'][0]['text'])"
|
"print(completion_response['choices'][0]['text'])\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -337,7 +337,7 @@
|
|||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"test_transactions = transactions.iloc[:25]\n",
|
"test_transactions = transactions.iloc[:25]\n",
|
||||||
"test_transactions['Classification'] = test_transactions.apply(lambda x: classify_transaction(x,zero_shot_prompt),axis=1)"
|
"test_transactions['Classification'] = test_transactions.apply(lambda x: classify_transaction(x,zero_shot_prompt),axis=1)\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -362,7 +362,7 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"test_transactions['Classification'].value_counts()"
|
"test_transactions['Classification'].value_counts()\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -665,7 +665,7 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"test_transactions.head(25)"
|
"test_transactions.head(25)\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -791,7 +791,7 @@
|
|||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"df = pd.read_csv('./data/labelled_transactions.csv')\n",
|
"df = pd.read_csv('./data/labelled_transactions.csv')\n",
|
||||||
"df.head()"
|
"df.head()\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -872,7 +872,7 @@
|
|||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"df['combined'] = \"Supplier: \" + df['Supplier'].str.strip() + \"; Description: \" + df['Description'].str.strip() + \"; Value: \" + str(df['Transaction value (£)']).strip()\n",
|
"df['combined'] = \"Supplier: \" + df['Supplier'].str.strip() + \"; Description: \" + df['Description'].str.strip() + \"; Value: \" + str(df['Transaction value (£)']).strip()\n",
|
||||||
"df.head(2)"
|
"df.head(2)\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -896,7 +896,7 @@
|
|||||||
"tokenizer = GPT2TokenizerFast.from_pretrained(\"gpt2\")\n",
|
"tokenizer = GPT2TokenizerFast.from_pretrained(\"gpt2\")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"df['n_tokens'] = df.combined.apply(lambda x: len(tokenizer.encode(x)))\n",
|
"df['n_tokens'] = df.combined.apply(lambda x: len(tokenizer.encode(x)))\n",
|
||||||
"len(df)"
|
"len(df)\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -905,7 +905,7 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"embedding_path = './data/transactions_with_embeddings_100.csv'"
|
"embedding_path = './data/transactions_with_embeddings_100.csv'\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -914,11 +914,11 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"from openai.embeddings_utils import get_embedding\n",
|
"from utils.embeddings_utils import get_embedding\n",
|
||||||
"\n",
|
"\n",
|
||||||
"df['babbage_similarity'] = df.combined.apply(lambda x: get_embedding(x, engine='text-similarity-babbage-001'))\n",
|
"df['babbage_similarity'] = df.combined.apply(lambda x: get_embedding(x, engine='text-similarity-babbage-001'))\n",
|
||||||
"df['babbage_search'] = df.combined.apply(lambda x: get_embedding(x, engine='text-search-babbage-doc-001'))\n",
|
"df['babbage_search'] = df.combined.apply(lambda x: get_embedding(x, engine='text-search-babbage-doc-001'))\n",
|
||||||
"df.to_csv(embedding_path)"
|
"df.to_csv(embedding_path)\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -1091,7 +1091,7 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"fs_df = pd.read_csv(embedding_path)\n",
|
"fs_df = pd.read_csv(embedding_path)\n",
|
||||||
"fs_df[\"babbage_similarity\"] = fs_df.babbage_similarity.apply(literal_eval).apply(np.array)\n",
|
"fs_df[\"babbage_similarity\"] = fs_df.babbage_similarity.apply(literal_eval).apply(np.array)\n",
|
||||||
"fs_df.head()"
|
"fs_df.head()\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -1141,7 +1141,7 @@
|
|||||||
"probas = clf.predict_proba(X_test)\n",
|
"probas = clf.predict_proba(X_test)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"report = classification_report(y_test, preds)\n",
|
"report = classification_report(y_test, preds)\n",
|
||||||
"print(report)"
|
"print(report)\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -1195,7 +1195,7 @@
|
|||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"ft_prep_df = fs_df.copy()\n",
|
"ft_prep_df = fs_df.copy()\n",
|
||||||
"len(ft_prep_df)"
|
"len(ft_prep_df)\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -1349,7 +1349,7 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"ft_prep_df.head()"
|
"ft_prep_df.head()\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -1378,7 +1378,7 @@
|
|||||||
"classes = list(set(ft_prep_df['Classification']))\n",
|
"classes = list(set(ft_prep_df['Classification']))\n",
|
||||||
"class_df = pd.DataFrame(classes).reset_index()\n",
|
"class_df = pd.DataFrame(classes).reset_index()\n",
|
||||||
"class_df.columns = ['class_id','class']\n",
|
"class_df.columns = ['class_id','class']\n",
|
||||||
"class_df , len(class_df)"
|
"class_df , len(class_df)\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -1559,7 +1559,7 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"# Adding a common separator onto the end of each prompt so the model knows when a prompt is terminating\n",
|
"# Adding a common separator onto the end of each prompt so the model knows when a prompt is terminating\n",
|
||||||
"ft_df_with_class['prompt'] = ft_df_with_class.apply(lambda x: x['combined'] + '\\n\\n###\\n\\n',axis=1)\n",
|
"ft_df_with_class['prompt'] = ft_df_with_class.apply(lambda x: x['combined'] + '\\n\\n###\\n\\n',axis=1)\n",
|
||||||
"ft_df_with_class.head()"
|
"ft_df_with_class.head()\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -1647,7 +1647,7 @@
|
|||||||
"# In our case we don't, so we shuffle the data to give us a better chance of getting equal classes in our train and validation sets\n",
|
"# In our case we don't, so we shuffle the data to give us a better chance of getting equal classes in our train and validation sets\n",
|
||||||
"# Our fine-tuned model will error if we have less classes in the validation set, so this is a necessary step\n",
|
"# Our fine-tuned model will error if we have less classes in the validation set, so this is a necessary step\n",
|
||||||
"\n",
|
"\n",
|
||||||
"import random \n",
|
"import random\n",
|
||||||
"\n",
|
"\n",
|
||||||
"labels = [x for x in ft_df_with_class['class_id']]\n",
|
"labels = [x for x in ft_df_with_class['class_id']]\n",
|
||||||
"text = [x for x in ft_df_with_class['prompt']]\n",
|
"text = [x for x in ft_df_with_class['prompt']]\n",
|
||||||
@ -1656,7 +1656,7 @@
|
|||||||
"ft_df['ordering'] = ft_df.apply(lambda x: random.randint(0,len(ft_df)), axis = 1)\n",
|
"ft_df['ordering'] = ft_df.apply(lambda x: random.randint(0,len(ft_df)), axis = 1)\n",
|
||||||
"ft_df.set_index('ordering',inplace=True)\n",
|
"ft_df.set_index('ordering',inplace=True)\n",
|
||||||
"ft_df_sorted = ft_df.sort_index(ascending=True)\n",
|
"ft_df_sorted = ft_df.sort_index(ascending=True)\n",
|
||||||
"ft_df_sorted.head()"
|
"ft_df_sorted.head()\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -1670,7 +1670,7 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"# We output our shuffled dataframe to a .jsonl file and run the prepare_data function to get us our input files\n",
|
"# We output our shuffled dataframe to a .jsonl file and run the prepare_data function to get us our input files\n",
|
||||||
"ft_df_sorted.to_json(\"transactions_grouped.jsonl\", orient='records', lines=True)\n",
|
"ft_df_sorted.to_json(\"transactions_grouped.jsonl\", orient='records', lines=True)\n",
|
||||||
"!openai tools fine_tunes.prepare_data -f transactions_grouped.jsonl -q"
|
"!openai tools fine_tunes.prepare_data -f transactions_grouped.jsonl -q\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -1691,7 +1691,7 @@
|
|||||||
"source": [
|
"source": [
|
||||||
"# This functions checks that your classes all appear in both prepared files\n",
|
"# This functions checks that your classes all appear in both prepared files\n",
|
||||||
"# If they don't, the fine-tuned model creation will fail\n",
|
"# If they don't, the fine-tuned model creation will fail\n",
|
||||||
"check_finetune_classes('transactions_grouped_prepared_train.jsonl','transactions_grouped_prepared_valid.jsonl')"
|
"check_finetune_classes('transactions_grouped_prepared_train.jsonl','transactions_grouped_prepared_valid.jsonl')\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -1704,7 +1704,7 @@
|
|||||||
"!openai api fine_tunes.create -t \"transactions_grouped_prepared_train.jsonl\" -v \"transactions_grouped_prepared_valid.jsonl\" --compute_classification_metrics --classification_n_classes 5 -m curie\n",
|
"!openai api fine_tunes.create -t \"transactions_grouped_prepared_train.jsonl\" -v \"transactions_grouped_prepared_valid.jsonl\" --compute_classification_metrics --classification_n_classes 5 -m curie\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# You can use following command to get fine tuning job status and model name, replace the job name with your job\n",
|
"# You can use following command to get fine tuning job status and model name, replace the job name with your job\n",
|
||||||
"#!openai api fine_tunes.get -i ft-YBIc01t4hxYBC7I5qhRF3Qdx"
|
"#!openai api fine_tunes.get -i ft-YBIc01t4hxYBC7I5qhRF3Qdx\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -1715,7 +1715,7 @@
|
|||||||
"source": [
|
"source": [
|
||||||
"# Congrats, you've got a fine-tuned model!\n",
|
"# Congrats, you've got a fine-tuned model!\n",
|
||||||
"# Copy/paste the name provided into the variable below and we'll take it for a spin\n",
|
"# Copy/paste the name provided into the variable below and we'll take it for a spin\n",
|
||||||
"fine_tuned_model = 'curie:ft-personal-2022-10-20-10-42-56'"
|
"fine_tuned_model = 'curie:ft-personal-2022-10-20-10-42-56'\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -1804,7 +1804,7 @@
|
|||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"test_set = pd.read_json('transactions_grouped_prepared_valid.jsonl', lines=True)\n",
|
"test_set = pd.read_json('transactions_grouped_prepared_valid.jsonl', lines=True)\n",
|
||||||
"test_set.head()"
|
"test_set.head()\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -1814,7 +1814,7 @@
|
|||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"test_set['predicted_class'] = test_set.apply(lambda x: openai.Completion.create(model=fine_tuned_model, prompt=x['prompt'], max_tokens=1, temperature=0, logprobs=5),axis=1)\n",
|
"test_set['predicted_class'] = test_set.apply(lambda x: openai.Completion.create(model=fine_tuned_model, prompt=x['prompt'], max_tokens=1, temperature=0, logprobs=5),axis=1)\n",
|
||||||
"test_set['pred'] = test_set.apply(lambda x : x['predicted_class']['choices'][0]['text'],axis=1)"
|
"test_set['pred'] = test_set.apply(lambda x : x['predicted_class']['choices'][0]['text'],axis=1)\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -1823,7 +1823,7 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"test_set['result'] = test_set.apply(lambda x: str(x['pred']).strip() == str(x['completion']).strip(), axis = 1)"
|
"test_set['result'] = test_set.apply(lambda x: str(x['pred']).strip() == str(x['completion']).strip(), axis = 1)\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -1845,7 +1845,7 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"test_set['result'].value_counts()"
|
"test_set['result'].value_counts()\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -1953,7 +1953,7 @@
|
|||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"holdout_df = transactions.copy().iloc[101:]\n",
|
"holdout_df = transactions.copy().iloc[101:]\n",
|
||||||
"holdout_df.head()"
|
"holdout_df.head()\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -1964,7 +1964,7 @@
|
|||||||
"source": [
|
"source": [
|
||||||
"holdout_df['combined'] = \"Supplier: \" + holdout_df['Supplier'].str.strip() + \"; Description: \" + holdout_df['Description'].str.strip() + '\\n\\n###\\n\\n' # + \"; Value: \" + str(df['Transaction value (£)']).strip()\n",
|
"holdout_df['combined'] = \"Supplier: \" + holdout_df['Supplier'].str.strip() + \"; Description: \" + holdout_df['Description'].str.strip() + '\\n\\n###\\n\\n' # + \"; Value: \" + str(df['Transaction value (£)']).strip()\n",
|
||||||
"holdout_df['prediction_result'] = holdout_df.apply(lambda x: openai.Completion.create(model=fine_tuned_model, prompt=x['combined'], max_tokens=1, temperature=0, logprobs=5),axis=1)\n",
|
"holdout_df['prediction_result'] = holdout_df.apply(lambda x: openai.Completion.create(model=fine_tuned_model, prompt=x['combined'], max_tokens=1, temperature=0, logprobs=5),axis=1)\n",
|
||||||
"holdout_df['pred'] = holdout_df.apply(lambda x : x['prediction_result']['choices'][0]['text'],axis=1)"
|
"holdout_df['pred'] = holdout_df.apply(lambda x : x['prediction_result']['choices'][0]['text'],axis=1)\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -2151,7 +2151,7 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"holdout_df.head(10)"
|
"holdout_df.head(10)\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -2173,7 +2173,7 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"holdout_df['pred'].value_counts()"
|
"holdout_df['pred'].value_counts()\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -37,7 +37,7 @@
|
|||||||
"import pandas as pd\n",
|
"import pandas as pd\n",
|
||||||
"import pickle\n",
|
"import pickle\n",
|
||||||
"\n",
|
"\n",
|
||||||
"from openai.embeddings_utils import (\n",
|
"from utils.embeddings_utils import (\n",
|
||||||
" get_embedding,\n",
|
" get_embedding,\n",
|
||||||
" distances_from_embeddings,\n",
|
" distances_from_embeddings,\n",
|
||||||
" tsne_components_from_embeddings,\n",
|
" tsne_components_from_embeddings,\n",
|
||||||
@ -46,7 +46,7 @@
|
|||||||
")\n",
|
")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# constants\n",
|
"# constants\n",
|
||||||
"EMBEDDING_MODEL = \"text-embedding-ada-002\""
|
"EMBEDDING_MODEL = \"text-embedding-ada-002\"\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -158,7 +158,7 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"# print dataframe\n",
|
"# print dataframe\n",
|
||||||
"n_examples = 5\n",
|
"n_examples = 5\n",
|
||||||
"df.head(n_examples)"
|
"df.head(n_examples)\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -206,7 +206,7 @@
|
|||||||
" print(\"\")\n",
|
" print(\"\")\n",
|
||||||
" print(f\"Title: {row['title']}\")\n",
|
" print(f\"Title: {row['title']}\")\n",
|
||||||
" print(f\"Description: {row['description']}\")\n",
|
" print(f\"Description: {row['description']}\")\n",
|
||||||
" print(f\"Label: {row['label']}\")"
|
" print(f\"Label: {row['label']}\")\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -252,7 +252,7 @@
|
|||||||
" embedding_cache[(string, model)] = get_embedding(string, model)\n",
|
" embedding_cache[(string, model)] = get_embedding(string, model)\n",
|
||||||
" with open(embedding_cache_path, \"wb\") as embedding_cache_file:\n",
|
" with open(embedding_cache_path, \"wb\") as embedding_cache_file:\n",
|
||||||
" pickle.dump(embedding_cache, embedding_cache_file)\n",
|
" pickle.dump(embedding_cache, embedding_cache_file)\n",
|
||||||
" return embedding_cache[(string, model)]"
|
" return embedding_cache[(string, model)]\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -285,7 +285,7 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"# print the first 10 dimensions of the embedding\n",
|
"# print the first 10 dimensions of the embedding\n",
|
||||||
"example_embedding = embedding_from_string(example_string)\n",
|
"example_embedding = embedding_from_string(example_string)\n",
|
||||||
"print(f\"\\nExample embedding: {example_embedding[:10]}...\")"
|
"print(f\"\\nExample embedding: {example_embedding[:10]}...\")\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -317,9 +317,9 @@
|
|||||||
" embeddings = [embedding_from_string(string, model=model) for string in strings]\n",
|
" embeddings = [embedding_from_string(string, model=model) for string in strings]\n",
|
||||||
" # get the embedding of the source string\n",
|
" # get the embedding of the source string\n",
|
||||||
" query_embedding = embeddings[index_of_source_string]\n",
|
" query_embedding = embeddings[index_of_source_string]\n",
|
||||||
" # get distances between the source embedding and other embeddings (function from embeddings_utils.py)\n",
|
" # get distances between the source embedding and other embeddings (function from utils.embeddings_utils.py)\n",
|
||||||
" distances = distances_from_embeddings(query_embedding, embeddings, distance_metric=\"cosine\")\n",
|
" distances = distances_from_embeddings(query_embedding, embeddings, distance_metric=\"cosine\")\n",
|
||||||
" # get indices of nearest neighbors (function from embeddings_utils.py)\n",
|
" # get indices of nearest neighbors (function from utils.utils.embeddings_utils.py)\n",
|
||||||
" indices_of_nearest_neighbors = indices_of_nearest_neighbors_from_distances(distances)\n",
|
" indices_of_nearest_neighbors = indices_of_nearest_neighbors_from_distances(distances)\n",
|
||||||
"\n",
|
"\n",
|
||||||
" # print out source string\n",
|
" # print out source string\n",
|
||||||
@ -344,7 +344,7 @@
|
|||||||
" Distance: {distances[i]:0.3f}\"\"\"\n",
|
" Distance: {distances[i]:0.3f}\"\"\"\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
"\n",
|
"\n",
|
||||||
" return indices_of_nearest_neighbors"
|
" return indices_of_nearest_neighbors\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -396,7 +396,7 @@
|
|||||||
" strings=article_descriptions, # let's base similarity off of the article description\n",
|
" strings=article_descriptions, # let's base similarity off of the article description\n",
|
||||||
" index_of_source_string=0, # let's look at articles similar to the first one about Tony Blair\n",
|
" index_of_source_string=0, # let's look at articles similar to the first one about Tony Blair\n",
|
||||||
" k_nearest_neighbors=5, # let's look at the 5 most similar articles\n",
|
" k_nearest_neighbors=5, # let's look at the 5 most similar articles\n",
|
||||||
")"
|
")\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -452,7 +452,7 @@
|
|||||||
" strings=article_descriptions, # let's base similarity off of the article description\n",
|
" strings=article_descriptions, # let's base similarity off of the article description\n",
|
||||||
" index_of_source_string=1, # let's look at articles similar to the second one about a more secure chipset\n",
|
" index_of_source_string=1, # let's look at articles similar to the second one about a more secure chipset\n",
|
||||||
" k_nearest_neighbors=5, # let's look at the 5 most similar articles\n",
|
" k_nearest_neighbors=5, # let's look at the 5 most similar articles\n",
|
||||||
")"
|
")\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -11456,7 +11456,7 @@
|
|||||||
" width=600,\n",
|
" width=600,\n",
|
||||||
" height=500,\n",
|
" height=500,\n",
|
||||||
" title=\"t-SNE components of article descriptions\",\n",
|
" title=\"t-SNE components of article descriptions\",\n",
|
||||||
")"
|
")\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -11498,7 +11498,7 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"tony_blair_labels = nearest_neighbor_labels(tony_blair_articles, k_nearest_neighbors=5)\n",
|
"tony_blair_labels = nearest_neighbor_labels(tony_blair_articles, k_nearest_neighbors=5)\n",
|
||||||
"chipset_security_labels = nearest_neighbor_labels(chipset_security_articles, k_nearest_neighbors=5\n",
|
"chipset_security_labels = nearest_neighbor_labels(chipset_security_articles, k_nearest_neighbors=5\n",
|
||||||
")"
|
")\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -22443,7 +22443,7 @@
|
|||||||
" height=500,\n",
|
" height=500,\n",
|
||||||
" title=\"Nearest neighbors of the Tony Blair article\",\n",
|
" title=\"Nearest neighbors of the Tony Blair article\",\n",
|
||||||
" category_orders={\"label\": [\"Other\", \"Nearest neighbor (top 5)\", \"Source\"]},\n",
|
" category_orders={\"label\": [\"Other\", \"Nearest neighbor (top 5)\", \"Source\"]},\n",
|
||||||
")"
|
")\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -33396,7 +33396,7 @@
|
|||||||
" height=500,\n",
|
" height=500,\n",
|
||||||
" title=\"Nearest neighbors of the chipset security article\",\n",
|
" title=\"Nearest neighbors of the chipset security article\",\n",
|
||||||
" category_orders={\"label\": [\"Other\", \"Nearest neighbor (top 5)\", \"Source\"]},\n",
|
" category_orders={\"label\": [\"Other\", \"Nearest neighbor (top 5)\", \"Source\"]},\n",
|
||||||
")"
|
")\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -53,7 +53,7 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"from openai.embeddings_utils import get_embedding, cosine_similarity\n",
|
"from utils.embeddings_utils import get_embedding, cosine_similarity\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# search through the reviews for a specific product\n",
|
"# search through the reviews for a specific product\n",
|
||||||
"def search_reviews(df, product_description, n=3, pprint=True):\n",
|
"def search_reviews(df, product_description, n=3, pprint=True):\n",
|
||||||
@ -98,7 +98,7 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"results = search_reviews(df, \"whole wheat pasta\", n=3)"
|
"results = search_reviews(df, \"whole wheat pasta\", n=3)\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -124,7 +124,7 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"results = search_reviews(df, \"bad delivery\", n=1)"
|
"results = search_reviews(df, \"bad delivery\", n=1)\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -150,7 +150,7 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"results = search_reviews(df, \"spoilt\", n=1)"
|
"results = search_reviews(df, \"spoilt\", n=1)\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -170,7 +170,7 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"results = search_reviews(df, \"pet food\", n=2)"
|
"results = search_reviews(df, \"pet food\", n=2)\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
@ -75,7 +75,7 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"from openai.embeddings_utils import cosine_similarity\n",
|
"from utils.embeddings_utils import cosine_similarity\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# evaluate embeddings as recommendations on X_test\n",
|
"# evaluate embeddings as recommendations on X_test\n",
|
||||||
"def evaluate_single_match(row):\n",
|
"def evaluate_single_match(row):\n",
|
||||||
@ -140,7 +140,7 @@
|
|||||||
"X_test.boxplot(column='percentile_cosine_similarity', by='Score')\n",
|
"X_test.boxplot(column='percentile_cosine_similarity', by='Score')\n",
|
||||||
"plt.title('')\n",
|
"plt.title('')\n",
|
||||||
"plt.show()\n",
|
"plt.show()\n",
|
||||||
"plt.close()"
|
"plt.close()\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -126,7 +126,7 @@
|
|||||||
"samples = pd.read_json(\"data/dbpedia_samples.jsonl\", lines=True)\n",
|
"samples = pd.read_json(\"data/dbpedia_samples.jsonl\", lines=True)\n",
|
||||||
"categories = sorted(samples[\"category\"].unique())\n",
|
"categories = sorted(samples[\"category\"].unique())\n",
|
||||||
"print(\"Categories of DBpedia samples:\", samples[\"category\"].value_counts())\n",
|
"print(\"Categories of DBpedia samples:\", samples[\"category\"].value_counts())\n",
|
||||||
"samples.head()"
|
"samples.head()\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -136,9 +136,9 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"from openai.embeddings_utils import get_embeddings\n",
|
"from utils.embeddings_utils import get_embeddings\n",
|
||||||
"# NOTE: The following code will send a query of batch size 200 to /embeddings\n",
|
"# NOTE: The following code will send a query of batch size 200 to /embeddings\n",
|
||||||
"matrix = get_embeddings(samples[\"text\"].to_list(), engine=\"text-embedding-ada-002\")"
|
"matrix = get_embeddings(samples[\"text\"].to_list(), engine=\"text-embedding-ada-002\")\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -159,7 +159,7 @@
|
|||||||
"from sklearn.decomposition import PCA\n",
|
"from sklearn.decomposition import PCA\n",
|
||||||
"pca = PCA(n_components=3)\n",
|
"pca = PCA(n_components=3)\n",
|
||||||
"vis_dims = pca.fit_transform(matrix)\n",
|
"vis_dims = pca.fit_transform(matrix)\n",
|
||||||
"samples[\"embed_vis\"] = vis_dims.tolist()"
|
"samples[\"embed_vis\"] = vis_dims.tolist()\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -233,7 +233,7 @@
|
|||||||
"ax.set_xlabel('x')\n",
|
"ax.set_xlabel('x')\n",
|
||||||
"ax.set_ylabel('y')\n",
|
"ax.set_ylabel('y')\n",
|
||||||
"ax.set_zlabel('z')\n",
|
"ax.set_zlabel('z')\n",
|
||||||
"ax.legend(bbox_to_anchor=(1.1, 1))"
|
"ax.legend(bbox_to_anchor=(1.1, 1))\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
@ -86,11 +86,11 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"from openai.embeddings_utils import cosine_similarity, get_embedding\n",
|
"from utils.embeddings_utils import cosine_similarity, get_embedding\n",
|
||||||
"from sklearn.metrics import PrecisionRecallDisplay\n",
|
"from sklearn.metrics import PrecisionRecallDisplay\n",
|
||||||
"\n",
|
"\n",
|
||||||
"def evaluate_embeddings_approach(\n",
|
"def evaluate_embeddings_approach(\n",
|
||||||
" labels = ['negative', 'positive'], \n",
|
" labels = ['negative', 'positive'],\n",
|
||||||
" model = EMBEDDING_MODEL,\n",
|
" model = EMBEDDING_MODEL,\n",
|
||||||
"):\n",
|
"):\n",
|
||||||
" label_embeddings = [get_embedding(label, engine=model) for label in labels]\n",
|
" label_embeddings = [get_embedding(label, engine=model) for label in labels]\n",
|
||||||
@ -107,7 +107,7 @@
|
|||||||
" display = PrecisionRecallDisplay.from_predictions(df.sentiment, probas, pos_label='positive')\n",
|
" display = PrecisionRecallDisplay.from_predictions(df.sentiment, probas, pos_label='positive')\n",
|
||||||
" _ = display.ax_.set_title(\"2-class Precision-Recall curve\")\n",
|
" _ = display.ax_.set_title(\"2-class Precision-Recall curve\")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"evaluate_embeddings_approach(labels=['negative', 'positive'], model=EMBEDDING_MODEL)"
|
"evaluate_embeddings_approach(labels=['negative', 'positive'], model=EMBEDDING_MODEL)\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -152,7 +152,7 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"evaluate_embeddings_approach(labels=['An Amazon review with a negative sentiment.', 'An Amazon review with a positive sentiment.'])"
|
"evaluate_embeddings_approach(labels=['An Amazon review with a negative sentiment.', 'An Amazon review with a positive sentiment.'])\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -197,7 +197,7 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"evaluate_embeddings_approach(labels=['An Amazon review with a negative sentiment.', 'An Amazon review with a positive sentiment.'])"
|
"evaluate_embeddings_approach(labels=['An Amazon review with a negative sentiment.', 'An Amazon review with a positive sentiment.'])\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
252
examples/utils/embeddings_utils.py
Normal file
252
examples/utils/embeddings_utils.py
Normal file
@ -0,0 +1,252 @@
|
|||||||
|
import textwrap as tr
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import plotly.express as px
|
||||||
|
from scipy import spatial
|
||||||
|
from sklearn.decomposition import PCA
|
||||||
|
from sklearn.manifold import TSNE
|
||||||
|
from sklearn.metrics import average_precision_score, precision_recall_curve
|
||||||
|
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
||||||
|
|
||||||
|
import openai
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
|
||||||
|
@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
|
||||||
|
def get_embedding(text: str, engine="text-similarity-davinci-001", **kwargs) -> List[float]:
|
||||||
|
|
||||||
|
# replace newlines, which can negatively affect performance.
|
||||||
|
text = text.replace("\n", " ")
|
||||||
|
|
||||||
|
return openai.Embedding.create(input=[text], engine=engine, **kwargs)["data"][0]["embedding"]
|
||||||
|
|
||||||
|
|
||||||
|
@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
|
||||||
|
async def aget_embedding(
|
||||||
|
text: str, engine="text-similarity-davinci-001", **kwargs
|
||||||
|
) -> List[float]:
|
||||||
|
|
||||||
|
# replace newlines, which can negatively affect performance.
|
||||||
|
text = text.replace("\n", " ")
|
||||||
|
|
||||||
|
return (await openai.Embedding.acreate(input=[text], engine=engine, **kwargs))["data"][0][
|
||||||
|
"embedding"
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
|
||||||
|
def get_embeddings(
|
||||||
|
list_of_text: List[str], engine="text-similarity-babbage-001", **kwargs
|
||||||
|
) -> List[List[float]]:
|
||||||
|
assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."
|
||||||
|
|
||||||
|
# replace newlines, which can negatively affect performance.
|
||||||
|
list_of_text = [text.replace("\n", " ") for text in list_of_text]
|
||||||
|
|
||||||
|
data = openai.Embedding.create(input=list_of_text, engine=engine, **kwargs).data
|
||||||
|
return [d["embedding"] for d in data]
|
||||||
|
|
||||||
|
|
||||||
|
@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
|
||||||
|
async def aget_embeddings(
|
||||||
|
list_of_text: List[str], engine="text-similarity-babbage-001", **kwargs
|
||||||
|
) -> List[List[float]]:
|
||||||
|
assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."
|
||||||
|
|
||||||
|
# replace newlines, which can negatively affect performance.
|
||||||
|
list_of_text = [text.replace("\n", " ") for text in list_of_text]
|
||||||
|
|
||||||
|
data = (await openai.Embedding.acreate(input=list_of_text, engine=engine, **kwargs)).data
|
||||||
|
return [d["embedding"] for d in data]
|
||||||
|
|
||||||
|
|
||||||
|
def cosine_similarity(a, b):
|
||||||
|
return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
|
||||||
|
|
||||||
|
|
||||||
|
def plot_multiclass_precision_recall(
|
||||||
|
y_score, y_true_untransformed, class_list, classifier_name
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Precision-Recall plotting for a multiclass problem. It plots average precision-recall, per class precision recall and reference f1 contours.
|
||||||
|
|
||||||
|
Code slightly modified, but heavily based on https://scikit-learn.org/stable/auto_examples/model_selection/plot_precision_recall.html
|
||||||
|
"""
|
||||||
|
n_classes = len(class_list)
|
||||||
|
y_true = pd.concat(
|
||||||
|
[(y_true_untransformed == class_list[i]) for i in range(n_classes)], axis=1
|
||||||
|
).values
|
||||||
|
|
||||||
|
# For each class
|
||||||
|
precision = dict()
|
||||||
|
recall = dict()
|
||||||
|
average_precision = dict()
|
||||||
|
for i in range(n_classes):
|
||||||
|
precision[i], recall[i], _ = precision_recall_curve(y_true[:, i], y_score[:, i])
|
||||||
|
average_precision[i] = average_precision_score(y_true[:, i], y_score[:, i])
|
||||||
|
|
||||||
|
# A "micro-average": quantifying score on all classes jointly
|
||||||
|
precision_micro, recall_micro, _ = precision_recall_curve(
|
||||||
|
y_true.ravel(), y_score.ravel()
|
||||||
|
)
|
||||||
|
average_precision_micro = average_precision_score(y_true, y_score, average="micro")
|
||||||
|
print(
|
||||||
|
str(classifier_name)
|
||||||
|
+ " - Average precision score over all classes: {0:0.2f}".format(
|
||||||
|
average_precision_micro
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# setup plot details
|
||||||
|
plt.figure(figsize=(9, 10))
|
||||||
|
f_scores = np.linspace(0.2, 0.8, num=4)
|
||||||
|
lines = []
|
||||||
|
labels = []
|
||||||
|
for f_score in f_scores:
|
||||||
|
x = np.linspace(0.01, 1)
|
||||||
|
y = f_score * x / (2 * x - f_score)
|
||||||
|
(l,) = plt.plot(x[y >= 0], y[y >= 0], color="gray", alpha=0.2)
|
||||||
|
plt.annotate("f1={0:0.1f}".format(f_score), xy=(0.9, y[45] + 0.02))
|
||||||
|
|
||||||
|
lines.append(l)
|
||||||
|
labels.append("iso-f1 curves")
|
||||||
|
(l,) = plt.plot(recall_micro, precision_micro, color="gold", lw=2)
|
||||||
|
lines.append(l)
|
||||||
|
labels.append(
|
||||||
|
"average Precision-recall (auprc = {0:0.2f})" "".format(average_precision_micro)
|
||||||
|
)
|
||||||
|
|
||||||
|
for i in range(n_classes):
|
||||||
|
(l,) = plt.plot(recall[i], precision[i], lw=2)
|
||||||
|
lines.append(l)
|
||||||
|
labels.append(
|
||||||
|
"Precision-recall for class `{0}` (auprc = {1:0.2f})"
|
||||||
|
"".format(class_list[i], average_precision[i])
|
||||||
|
)
|
||||||
|
|
||||||
|
fig = plt.gcf()
|
||||||
|
fig.subplots_adjust(bottom=0.25)
|
||||||
|
plt.xlim([0.0, 1.0])
|
||||||
|
plt.ylim([0.0, 1.05])
|
||||||
|
plt.xlabel("Recall")
|
||||||
|
plt.ylabel("Precision")
|
||||||
|
plt.title(f"{classifier_name}: Precision-Recall curve for each class")
|
||||||
|
plt.legend(lines, labels)
|
||||||
|
|
||||||
|
|
||||||
|
def distances_from_embeddings(
|
||||||
|
query_embedding: List[float],
|
||||||
|
embeddings: List[List[float]],
|
||||||
|
distance_metric="cosine",
|
||||||
|
) -> List[List]:
|
||||||
|
"""Return the distances between a query embedding and a list of embeddings."""
|
||||||
|
distance_metrics = {
|
||||||
|
"cosine": spatial.distance.cosine,
|
||||||
|
"L1": spatial.distance.cityblock,
|
||||||
|
"L2": spatial.distance.euclidean,
|
||||||
|
"Linf": spatial.distance.chebyshev,
|
||||||
|
}
|
||||||
|
distances = [
|
||||||
|
distance_metrics[distance_metric](query_embedding, embedding)
|
||||||
|
for embedding in embeddings
|
||||||
|
]
|
||||||
|
return distances
|
||||||
|
|
||||||
|
|
||||||
|
def indices_of_nearest_neighbors_from_distances(distances) -> np.ndarray:
|
||||||
|
"""Return a list of indices of nearest neighbors from a list of distances."""
|
||||||
|
return np.argsort(distances)
|
||||||
|
|
||||||
|
|
||||||
|
def pca_components_from_embeddings(
|
||||||
|
embeddings: List[List[float]], n_components=2
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""Return the PCA components of a list of embeddings."""
|
||||||
|
pca = PCA(n_components=n_components)
|
||||||
|
array_of_embeddings = np.array(embeddings)
|
||||||
|
return pca.fit_transform(array_of_embeddings)
|
||||||
|
|
||||||
|
|
||||||
|
def tsne_components_from_embeddings(
|
||||||
|
embeddings: List[List[float]], n_components=2, **kwargs
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""Returns t-SNE components of a list of embeddings."""
|
||||||
|
# use better defaults if not specified
|
||||||
|
if "init" not in kwargs.keys():
|
||||||
|
kwargs["init"] = "pca"
|
||||||
|
if "learning_rate" not in kwargs.keys():
|
||||||
|
kwargs["learning_rate"] = "auto"
|
||||||
|
tsne = TSNE(n_components=n_components, **kwargs)
|
||||||
|
array_of_embeddings = np.array(embeddings)
|
||||||
|
return tsne.fit_transform(array_of_embeddings)
|
||||||
|
|
||||||
|
|
||||||
|
def chart_from_components(
|
||||||
|
components: np.ndarray,
|
||||||
|
labels: Optional[List[str]] = None,
|
||||||
|
strings: Optional[List[str]] = None,
|
||||||
|
x_title="Component 0",
|
||||||
|
y_title="Component 1",
|
||||||
|
mark_size=5,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""Return an interactive 2D chart of embedding components."""
|
||||||
|
empty_list = ["" for _ in components]
|
||||||
|
data = pd.DataFrame(
|
||||||
|
{
|
||||||
|
x_title: components[:, 0],
|
||||||
|
y_title: components[:, 1],
|
||||||
|
"label": labels if labels else empty_list,
|
||||||
|
"string": ["<br>".join(tr.wrap(string, width=30)) for string in strings]
|
||||||
|
if strings
|
||||||
|
else empty_list,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
chart = px.scatter(
|
||||||
|
data,
|
||||||
|
x=x_title,
|
||||||
|
y=y_title,
|
||||||
|
color="label" if labels else None,
|
||||||
|
symbol="label" if labels else None,
|
||||||
|
hover_data=["string"] if strings else None,
|
||||||
|
**kwargs,
|
||||||
|
).update_traces(marker=dict(size=mark_size))
|
||||||
|
return chart
|
||||||
|
|
||||||
|
|
||||||
|
def chart_from_components_3D(
|
||||||
|
components: np.ndarray,
|
||||||
|
labels: Optional[List[str]] = None,
|
||||||
|
strings: Optional[List[str]] = None,
|
||||||
|
x_title: str = "Component 0",
|
||||||
|
y_title: str = "Component 1",
|
||||||
|
z_title: str = "Compontent 2",
|
||||||
|
mark_size: int = 5,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""Return an interactive 3D chart of embedding components."""
|
||||||
|
empty_list = ["" for _ in components]
|
||||||
|
data = pd.DataFrame(
|
||||||
|
{
|
||||||
|
x_title: components[:, 0],
|
||||||
|
y_title: components[:, 1],
|
||||||
|
z_title: components[:, 2],
|
||||||
|
"label": labels if labels else empty_list,
|
||||||
|
"string": ["<br>".join(tr.wrap(string, width=30)) for string in strings]
|
||||||
|
if strings
|
||||||
|
else empty_list,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
chart = px.scatter_3d(
|
||||||
|
data,
|
||||||
|
x=x_title,
|
||||||
|
y=y_title,
|
||||||
|
z=z_title,
|
||||||
|
color="label" if labels else None,
|
||||||
|
symbol="label" if labels else None,
|
||||||
|
hover_data=["string"] if strings else None,
|
||||||
|
**kwargs,
|
||||||
|
).update_traces(marker=dict(size=mark_size))
|
||||||
|
return chart
|
@ -34,7 +34,7 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"!pip install openai --quiet"
|
"!pip install openai --quiet\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -48,7 +48,7 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"# models\n",
|
"# models\n",
|
||||||
"EMBEDDING_MODEL = \"text-embedding-ada-002\"\n",
|
"EMBEDDING_MODEL = \"text-embedding-ada-002\"\n",
|
||||||
"GPT_MODEL = \"gpt-3.5-turbo\""
|
"GPT_MODEL = \"gpt-3.5-turbo\"\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -84,7 +84,7 @@
|
|||||||
" ]\n",
|
" ]\n",
|
||||||
")\n",
|
")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"print(response['choices'][0]['message']['content'])"
|
"print(response['choices'][0]['message']['content'])\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -120,7 +120,7 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"!pip install matplotlib plotly.express scikit-learn tabulate tiktoken wget --quiet"
|
"!pip install matplotlib plotly.express scikit-learn tabulate tiktoken wget --quiet\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -133,7 +133,7 @@
|
|||||||
"import pandas as pd\n",
|
"import pandas as pd\n",
|
||||||
"import os\n",
|
"import os\n",
|
||||||
"import wget\n",
|
"import wget\n",
|
||||||
"import ast"
|
"import ast\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -168,7 +168,7 @@
|
|||||||
" wget.download(embeddings_path, file_path)\n",
|
" wget.download(embeddings_path, file_path)\n",
|
||||||
" print(\"File downloaded successfully.\")\n",
|
" print(\"File downloaded successfully.\")\n",
|
||||||
"else:\n",
|
"else:\n",
|
||||||
" print(\"File already exists in the local file system.\")"
|
" print(\"File already exists in the local file system.\")\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -183,7 +183,7 @@
|
|||||||
")\n",
|
")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# convert embeddings from CSV str type back to list type\n",
|
"# convert embeddings from CSV str type back to list type\n",
|
||||||
"df['embedding'] = df['embedding'].apply(ast.literal_eval)"
|
"df['embedding'] = df['embedding'].apply(ast.literal_eval)\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -193,7 +193,7 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"df"
|
"df\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -219,7 +219,7 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"df.info(show_counts=True)"
|
"df.info(show_counts=True)\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -241,7 +241,7 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"conn = s2.connect(\"<user>:<Password>@<host>:3306/\")\n",
|
"conn = s2.connect(\"<user>:<Password>@<host>:3306/\")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"cur = conn.cursor()"
|
"cur = conn.cursor()\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -267,7 +267,7 @@
|
|||||||
" CREATE DATABASE IF NOT EXISTS winter_wikipedia2;\n",
|
" CREATE DATABASE IF NOT EXISTS winter_wikipedia2;\n",
|
||||||
"\"\"\"\n",
|
"\"\"\"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"cur.execute(stmt)"
|
"cur.execute(stmt)\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -296,7 +296,7 @@
|
|||||||
" embedding BLOB\n",
|
" embedding BLOB\n",
|
||||||
");\"\"\"\n",
|
");\"\"\"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"cur.execute(stmt)"
|
"cur.execute(stmt)\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -349,7 +349,7 @@
|
|||||||
"for i in range(0, len(record_arr), batch_size):\n",
|
"for i in range(0, len(record_arr), batch_size):\n",
|
||||||
" batch = record_arr[i:i+batch_size]\n",
|
" batch = record_arr[i:i+batch_size]\n",
|
||||||
" values = [(row[0], row[1], str(row[2])) for row in batch]\n",
|
" values = [(row[0], row[1], str(row[2])) for row in batch]\n",
|
||||||
" cur.executemany(stmt, values)"
|
" cur.executemany(stmt, values)\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -367,7 +367,7 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"from openai.embeddings_utils import get_embedding\n",
|
"from utils.embeddings_utils import get_embedding\n",
|
||||||
"\n",
|
"\n",
|
||||||
"def strings_ranked_by_relatedness(\n",
|
"def strings_ranked_by_relatedness(\n",
|
||||||
" query: str,\n",
|
" query: str,\n",
|
||||||
@ -395,7 +395,7 @@
|
|||||||
"\n",
|
"\n",
|
||||||
" # Fetch the results\n",
|
" # Fetch the results\n",
|
||||||
" results = cur.fetchall()\n",
|
" results = cur.fetchall()\n",
|
||||||
" \n",
|
"\n",
|
||||||
" strings = []\n",
|
" strings = []\n",
|
||||||
" relatednesses = []\n",
|
" relatednesses = []\n",
|
||||||
"\n",
|
"\n",
|
||||||
@ -404,7 +404,7 @@
|
|||||||
" relatednesses.append(row[1])\n",
|
" relatednesses.append(row[1])\n",
|
||||||
"\n",
|
"\n",
|
||||||
" # Return the results.\n",
|
" # Return the results.\n",
|
||||||
" return strings[:top_n], relatednesses[:top_n]"
|
" return strings[:top_n], relatednesses[:top_n]\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -424,7 +424,7 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"for string, relatedness in zip(strings, relatednesses):\n",
|
"for string, relatedness in zip(strings, relatednesses):\n",
|
||||||
" print(f\"{relatedness=:.3f}\")\n",
|
" print(f\"{relatedness=:.3f}\")\n",
|
||||||
" print(tabulate([[string]], headers=['Result'], tablefmt='fancy_grid'))"
|
" print(tabulate([[string]], headers=['Result'], tablefmt='fancy_grid'))\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -494,7 +494,7 @@
|
|||||||
" temperature=0\n",
|
" temperature=0\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
" response_message = response[\"choices\"][0][\"message\"][\"content\"]\n",
|
" response_message = response[\"choices\"][0][\"message\"][\"content\"]\n",
|
||||||
" return response_message"
|
" return response_message\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -531,7 +531,7 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"answer = ask('Who won the gold medal for curling in Olymics 2022?')\n",
|
"answer = ask('Who won the gold medal for curling in Olymics 2022?')\n",
|
||||||
"\n",
|
"\n",
|
||||||
"pprint(answer)"
|
"pprint(answer)\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
@ -87,7 +87,7 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"! pip install redis pandas openai"
|
"! pip install redis pandas openai\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -303,7 +303,7 @@
|
|||||||
"import numpy as np\n",
|
"import numpy as np\n",
|
||||||
"from typing import List\n",
|
"from typing import List\n",
|
||||||
"\n",
|
"\n",
|
||||||
"from openai.embeddings_utils import (\n",
|
"from utils.embeddings_utils import (\n",
|
||||||
" get_embeddings,\n",
|
" get_embeddings,\n",
|
||||||
" distances_from_embeddings,\n",
|
" distances_from_embeddings,\n",
|
||||||
" tsne_components_from_embeddings,\n",
|
" tsne_components_from_embeddings,\n",
|
||||||
@ -322,7 +322,7 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"# print dataframe\n",
|
"# print dataframe\n",
|
||||||
"n_examples = 5\n",
|
"n_examples = 5\n",
|
||||||
"df.head(n_examples)"
|
"df.head(n_examples)\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -360,7 +360,7 @@
|
|||||||
"df[\"product_text\"] = df.apply(lambda row: f\"name {row['productDisplayName']} category {row['masterCategory']} subcategory {row['subCategory']} color {row['baseColour']} gender {row['gender']}\".lower(), axis=1)\n",
|
"df[\"product_text\"] = df.apply(lambda row: f\"name {row['productDisplayName']} category {row['masterCategory']} subcategory {row['subCategory']} color {row['baseColour']} gender {row['gender']}\".lower(), axis=1)\n",
|
||||||
"df.rename({\"id\":\"product_id\"}, inplace=True, axis=1)\n",
|
"df.rename({\"id\":\"product_id\"}, inplace=True, axis=1)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"df.info()"
|
"df.info()\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -382,7 +382,7 @@
|
|||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"# check out one of the texts we will use to create semantic embeddings\n",
|
"# check out one of the texts we will use to create semantic embeddings\n",
|
||||||
"df[\"product_text\"][0]"
|
"df[\"product_text\"][0]\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -437,7 +437,7 @@
|
|||||||
" port=REDIS_PORT,\n",
|
" port=REDIS_PORT,\n",
|
||||||
" password=REDIS_PASSWORD\n",
|
" password=REDIS_PASSWORD\n",
|
||||||
")\n",
|
")\n",
|
||||||
"redis_client.ping()"
|
"redis_client.ping()\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -467,7 +467,7 @@
|
|||||||
"INDEX_NAME = \"product_embeddings\" # name of the search index\n",
|
"INDEX_NAME = \"product_embeddings\" # name of the search index\n",
|
||||||
"PREFIX = \"doc\" # prefix for the document keys\n",
|
"PREFIX = \"doc\" # prefix for the document keys\n",
|
||||||
"DISTANCE_METRIC = \"L2\" # distance metric for the vectors (ex. COSINE, IP, L2)\n",
|
"DISTANCE_METRIC = \"L2\" # distance metric for the vectors (ex. COSINE, IP, L2)\n",
|
||||||
"NUMBER_OF_VECTORS = len(df)"
|
"NUMBER_OF_VECTORS = len(df)\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -492,7 +492,7 @@
|
|||||||
" \"INITIAL_CAP\": NUMBER_OF_VECTORS,\n",
|
" \"INITIAL_CAP\": NUMBER_OF_VECTORS,\n",
|
||||||
" }\n",
|
" }\n",
|
||||||
")\n",
|
")\n",
|
||||||
"fields = [name, category, articleType, gender, season, year, text_embedding]"
|
"fields = [name, category, articleType, gender, season, year, text_embedding]\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -511,7 +511,7 @@
|
|||||||
" redis_client.ft(INDEX_NAME).create_index(\n",
|
" redis_client.ft(INDEX_NAME).create_index(\n",
|
||||||
" fields = fields,\n",
|
" fields = fields,\n",
|
||||||
" definition = IndexDefinition(prefix=[PREFIX], index_type=IndexType.HASH)\n",
|
" definition = IndexDefinition(prefix=[PREFIX], index_type=IndexType.HASH)\n",
|
||||||
")"
|
")\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -538,7 +538,7 @@
|
|||||||
" product_vectors = []\n",
|
" product_vectors = []\n",
|
||||||
" docs = []\n",
|
" docs = []\n",
|
||||||
" batchsize = 1000\n",
|
" batchsize = 1000\n",
|
||||||
" \n",
|
"\n",
|
||||||
" for idx,doc in enumerate(records,start=1):\n",
|
" for idx,doc in enumerate(records,start=1):\n",
|
||||||
" # create byte vectors\n",
|
" # create byte vectors\n",
|
||||||
" docs.append(doc[\"product_text\"])\n",
|
" docs.append(doc[\"product_text\"])\n",
|
||||||
@ -548,7 +548,7 @@
|
|||||||
" print(\"Vectors processed \", len(product_vectors), end='\\r')\n",
|
" print(\"Vectors processed \", len(product_vectors), end='\\r')\n",
|
||||||
" product_vectors += get_embeddings(docs, EMBEDDING_MODEL)\n",
|
" product_vectors += get_embeddings(docs, EMBEDDING_MODEL)\n",
|
||||||
" print(\"Vectors processed \", len(product_vectors), end='\\r')\n",
|
" print(\"Vectors processed \", len(product_vectors), end='\\r')\n",
|
||||||
" return product_vectors"
|
" return product_vectors\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -562,7 +562,7 @@
|
|||||||
" product_vectors = embeddings_batch_request(documents)\n",
|
" product_vectors = embeddings_batch_request(documents)\n",
|
||||||
" records = documents.to_dict(\"records\")\n",
|
" records = documents.to_dict(\"records\")\n",
|
||||||
" batchsize = 500\n",
|
" batchsize = 500\n",
|
||||||
" \n",
|
"\n",
|
||||||
" # Use Redis pipelines to batch calls and save on round trip network communication\n",
|
" # Use Redis pipelines to batch calls and save on round trip network communication\n",
|
||||||
" pipe = client.pipeline()\n",
|
" pipe = client.pipeline()\n",
|
||||||
" for idx,doc in enumerate(records,start=1):\n",
|
" for idx,doc in enumerate(records,start=1):\n",
|
||||||
@ -570,14 +570,14 @@
|
|||||||
"\n",
|
"\n",
|
||||||
" # create byte vectors\n",
|
" # create byte vectors\n",
|
||||||
" text_embedding = np.array((product_vectors[idx-1]), dtype=np.float32).tobytes()\n",
|
" text_embedding = np.array((product_vectors[idx-1]), dtype=np.float32).tobytes()\n",
|
||||||
" \n",
|
"\n",
|
||||||
" # replace list of floats with byte vectors\n",
|
" # replace list of floats with byte vectors\n",
|
||||||
" doc[\"product_vector\"] = text_embedding\n",
|
" doc[\"product_vector\"] = text_embedding\n",
|
||||||
"\n",
|
"\n",
|
||||||
" pipe.hset(key, mapping = doc)\n",
|
" pipe.hset(key, mapping = doc)\n",
|
||||||
" if idx % batchsize == 0:\n",
|
" if idx % batchsize == 0:\n",
|
||||||
" pipe.execute()\n",
|
" pipe.execute()\n",
|
||||||
" pipe.execute()"
|
" pipe.execute()\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -600,7 +600,7 @@
|
|||||||
"source": [
|
"source": [
|
||||||
"%%time\n",
|
"%%time\n",
|
||||||
"index_documents(redis_client, PREFIX, df)\n",
|
"index_documents(redis_client, PREFIX, df)\n",
|
||||||
"print(f\"Loaded {redis_client.info()['db0']['keys']} documents in Redis search index with name: {INDEX_NAME}\")"
|
"print(f\"Loaded {redis_client.info()['db0']['keys']} documents in Redis search index with name: {INDEX_NAME}\")\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -646,14 +646,14 @@
|
|||||||
" .dialect(2)\n",
|
" .dialect(2)\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
" params_dict = {\"vector\": np.array(embedded_query).astype(dtype=np.float32).tobytes()}\n",
|
" params_dict = {\"vector\": np.array(embedded_query).astype(dtype=np.float32).tobytes()}\n",
|
||||||
" \n",
|
"\n",
|
||||||
" # perform vector search\n",
|
" # perform vector search\n",
|
||||||
" results = redis_client.ft(index_name).search(query, params_dict)\n",
|
" results = redis_client.ft(index_name).search(query, params_dict)\n",
|
||||||
" if print_results:\n",
|
" if print_results:\n",
|
||||||
" for i, product in enumerate(results.docs):\n",
|
" for i, product in enumerate(results.docs):\n",
|
||||||
" score = 1 - float(product.vector_score)\n",
|
" score = 1 - float(product.vector_score)\n",
|
||||||
" print(f\"{i}. {product.productDisplayName} (Score: {round(score ,3) })\")\n",
|
" print(f\"{i}. {product.productDisplayName} (Score: {round(score ,3) })\")\n",
|
||||||
" return results.docs"
|
" return results.docs\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -681,7 +681,7 @@
|
|||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"# Execute a simple vector search in Redis\n",
|
"# Execute a simple vector search in Redis\n",
|
||||||
"results = search_redis(redis_client, 'man blue jeans', k=10)"
|
"results = search_redis(redis_client, 'man blue jeans', k=10)\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -724,7 +724,7 @@
|
|||||||
" vector_field=\"product_vector\",\n",
|
" vector_field=\"product_vector\",\n",
|
||||||
" k=10,\n",
|
" k=10,\n",
|
||||||
" hybrid_fields='@productDisplayName:\"blue jeans\"'\n",
|
" hybrid_fields='@productDisplayName:\"blue jeans\"'\n",
|
||||||
" )"
|
" )\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -755,7 +755,7 @@
|
|||||||
" vector_field=\"product_vector\",\n",
|
" vector_field=\"product_vector\",\n",
|
||||||
" k=10,\n",
|
" k=10,\n",
|
||||||
" hybrid_fields='@productDisplayName:\"slim fit\"'\n",
|
" hybrid_fields='@productDisplayName:\"slim fit\"'\n",
|
||||||
" )"
|
" )\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -788,7 +788,7 @@
|
|||||||
" vector_field=\"product_vector\",\n",
|
" vector_field=\"product_vector\",\n",
|
||||||
" k=10,\n",
|
" k=10,\n",
|
||||||
" hybrid_fields='@masterCategory:{Accessories}'\n",
|
" hybrid_fields='@masterCategory:{Accessories}'\n",
|
||||||
" )"
|
" )\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -821,7 +821,7 @@
|
|||||||
" vector_field=\"product_vector\",\n",
|
" vector_field=\"product_vector\",\n",
|
||||||
" k=10,\n",
|
" k=10,\n",
|
||||||
" hybrid_fields='@year:[2011 2012]'\n",
|
" hybrid_fields='@year:[2011 2012]'\n",
|
||||||
" )"
|
" )\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -854,7 +854,7 @@
|
|||||||
" vector_field=\"product_vector\",\n",
|
" vector_field=\"product_vector\",\n",
|
||||||
" k=10,\n",
|
" k=10,\n",
|
||||||
" hybrid_fields='(@year:[2011 2012] @season:{Summer})'\n",
|
" hybrid_fields='(@year:[2011 2012] @season:{Summer})'\n",
|
||||||
" )"
|
" )\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -883,7 +883,7 @@
|
|||||||
" vector_field=\"product_vector\",\n",
|
" vector_field=\"product_vector\",\n",
|
||||||
" k=10,\n",
|
" k=10,\n",
|
||||||
" hybrid_fields='(@year:[2012 2012] @articleType:{Shirts | Belts} @productDisplayName:\"Wrangler\")'\n",
|
" hybrid_fields='(@year:[2012 2012] @articleType:{Shirts | Belts} @productDisplayName:\"Wrangler\")'\n",
|
||||||
" )"
|
" )\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
Loading…
x
Reference in New Issue
Block a user