Added embeddings_utils.py to utils directory, and updated references (#841)

This commit is contained in:
jhills20 2023-11-10 09:07:25 -08:00 committed by GitHub
parent ac4cb3b163
commit 7044eecb98
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 386 additions and 136 deletions

File diff suppressed because one or more lines are too long

View File

@ -34,7 +34,7 @@
"import pandas as pd\n", "import pandas as pd\n",
"import tiktoken\n", "import tiktoken\n",
"\n", "\n",
"from openai.embeddings_utils import get_embedding\n" "from utils.embeddings_utils import get_embedding\n"
] ]
}, },
{ {
@ -212,7 +212,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.9.9 (main, Dec 7 2021, 18:04:56) \n[Clang 13.0.0 (clang-1300.0.29.3)]" "version": "3.9.16"
}, },
"orig_nbformat": 4, "orig_nbformat": 4,
"vscode": { "vscode": {

View File

@ -30,8 +30,8 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"%load_ext autoreload\n", "%load_ext autoreload\n",
"%autoreload \n", "%autoreload\n",
"%pip install openai 'openai[datalib]' 'openai[embeddings]' transformers" "%pip install openai 'openai[datalib]' 'openai[embeddings]' transformers\n"
] ]
}, },
{ {
@ -47,7 +47,7 @@
"import os\n", "import os\n",
"\n", "\n",
"openai.api_key = os.getenv(\"OPENAI_API_KEY\")\n", "openai.api_key = os.getenv(\"OPENAI_API_KEY\")\n",
"COMPLETIONS_MODEL = \"text-davinci-002\"" "COMPLETIONS_MODEL = \"text-davinci-002\"\n"
] ]
}, },
{ {
@ -85,7 +85,7 @@
], ],
"source": [ "source": [
"transactions = pd.read_csv('./data/25000_spend_dataset_current.csv', encoding= 'unicode_escape')\n", "transactions = pd.read_csv('./data/25000_spend_dataset_current.csv', encoding= 'unicode_escape')\n",
"len(transactions)" "len(transactions)\n"
] ]
}, },
{ {
@ -182,7 +182,7 @@
} }
], ],
"source": [ "source": [
"transactions.head()" "transactions.head()\n"
] ]
}, },
{ {
@ -192,7 +192,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"def request_completion(prompt):\n", "def request_completion(prompt):\n",
" \n", "\n",
" completion_response = openai.Completion.create(\n", " completion_response = openai.Completion.create(\n",
" prompt=prompt,\n", " prompt=prompt,\n",
" temperature=0,\n", " temperature=0,\n",
@ -202,17 +202,17 @@
" presence_penalty=0,\n", " presence_penalty=0,\n",
" model=COMPLETIONS_MODEL\n", " model=COMPLETIONS_MODEL\n",
" )\n", " )\n",
" \n", "\n",
" return completion_response\n", " return completion_response\n",
"\n", "\n",
"def classify_transaction(transaction,prompt):\n", "def classify_transaction(transaction,prompt):\n",
" \n", "\n",
" prompt = prompt.replace('SUPPLIER_NAME',transaction['Supplier'])\n", " prompt = prompt.replace('SUPPLIER_NAME',transaction['Supplier'])\n",
" prompt = prompt.replace('DESCRIPTION_TEXT',transaction['Description'])\n", " prompt = prompt.replace('DESCRIPTION_TEXT',transaction['Description'])\n",
" prompt = prompt.replace('TRANSACTION_VALUE',str(transaction['Transaction value (£)']))\n", " prompt = prompt.replace('TRANSACTION_VALUE',str(transaction['Transaction value (£)']))\n",
" \n", "\n",
" classification = request_completion(prompt)['choices'][0]['text'].replace('\\n','')\n", " classification = request_completion(prompt)['choices'][0]['text'].replace('\\n','')\n",
" \n", "\n",
" return classification\n", " return classification\n",
"\n", "\n",
"# This function takes your training and validation outputs from the prepare_data function of the Finetuning API, and\n", "# This function takes your training and validation outputs from the prepare_data function of the Finetuning API, and\n",
@ -242,12 +242,12 @@
" valid_classes.add(result['completion'])\n", " valid_classes.add(result['completion'])\n",
" #print(f\"result: {result['completion']}\")\n", " #print(f\"result: {result['completion']}\")\n",
" #print(isinstance(result, dict))\n", " #print(isinstance(result, dict))\n",
" \n", "\n",
" if len(train_classes) == len(valid_classes):\n", " if len(train_classes) == len(valid_classes):\n",
" print('All good')\n", " print('All good')\n",
" \n", "\n",
" else:\n", " else:\n",
" print('Classes do not match, please prepare data again')" " print('Classes do not match, please prepare data again')\n"
] ]
}, },
{ {
@ -266,18 +266,18 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"zero_shot_prompt = '''You are a data expert working for the National Library of Scotland. \n", "zero_shot_prompt = '''You are a data expert working for the National Library of Scotland.\n",
"You are analysing all transactions over £25,000 in value and classifying them into one of five categories.\n", "You are analysing all transactions over £25,000 in value and classifying them into one of five categories.\n",
"The five categories are Building Improvement, Literature & Archive, Utility Bills, Professional Services and Software/IT.\n", "The five categories are Building Improvement, Literature & Archive, Utility Bills, Professional Services and Software/IT.\n",
"If you can't tell what it is, say Could not classify\n", "If you can't tell what it is, say Could not classify\n",
" \n", "\n",
"Transaction:\n", "Transaction:\n",
" \n", "\n",
"Supplier: SUPPLIER_NAME\n", "Supplier: SUPPLIER_NAME\n",
"Description: DESCRIPTION_TEXT\n", "Description: DESCRIPTION_TEXT\n",
"Value: TRANSACTION_VALUE\n", "Value: TRANSACTION_VALUE\n",
" \n", "\n",
"The classification is:'''" "The classification is:'''\n"
] ]
}, },
{ {
@ -304,7 +304,7 @@
"\n", "\n",
"# Use our completion function to return a prediction\n", "# Use our completion function to return a prediction\n",
"completion_response = request_completion(prompt)\n", "completion_response = request_completion(prompt)\n",
"print(completion_response['choices'][0]['text'])" "print(completion_response['choices'][0]['text'])\n"
] ]
}, },
{ {
@ -337,7 +337,7 @@
], ],
"source": [ "source": [
"test_transactions = transactions.iloc[:25]\n", "test_transactions = transactions.iloc[:25]\n",
"test_transactions['Classification'] = test_transactions.apply(lambda x: classify_transaction(x,zero_shot_prompt),axis=1)" "test_transactions['Classification'] = test_transactions.apply(lambda x: classify_transaction(x,zero_shot_prompt),axis=1)\n"
] ]
}, },
{ {
@ -362,7 +362,7 @@
} }
], ],
"source": [ "source": [
"test_transactions['Classification'].value_counts()" "test_transactions['Classification'].value_counts()\n"
] ]
}, },
{ {
@ -665,7 +665,7 @@
} }
], ],
"source": [ "source": [
"test_transactions.head(25)" "test_transactions.head(25)\n"
] ]
}, },
{ {
@ -791,7 +791,7 @@
], ],
"source": [ "source": [
"df = pd.read_csv('./data/labelled_transactions.csv')\n", "df = pd.read_csv('./data/labelled_transactions.csv')\n",
"df.head()" "df.head()\n"
] ]
}, },
{ {
@ -872,7 +872,7 @@
], ],
"source": [ "source": [
"df['combined'] = \"Supplier: \" + df['Supplier'].str.strip() + \"; Description: \" + df['Description'].str.strip() + \"; Value: \" + str(df['Transaction value (£)']).strip()\n", "df['combined'] = \"Supplier: \" + df['Supplier'].str.strip() + \"; Description: \" + df['Description'].str.strip() + \"; Value: \" + str(df['Transaction value (£)']).strip()\n",
"df.head(2)" "df.head(2)\n"
] ]
}, },
{ {
@ -896,7 +896,7 @@
"tokenizer = GPT2TokenizerFast.from_pretrained(\"gpt2\")\n", "tokenizer = GPT2TokenizerFast.from_pretrained(\"gpt2\")\n",
"\n", "\n",
"df['n_tokens'] = df.combined.apply(lambda x: len(tokenizer.encode(x)))\n", "df['n_tokens'] = df.combined.apply(lambda x: len(tokenizer.encode(x)))\n",
"len(df)" "len(df)\n"
] ]
}, },
{ {
@ -905,7 +905,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"embedding_path = './data/transactions_with_embeddings_100.csv'" "embedding_path = './data/transactions_with_embeddings_100.csv'\n"
] ]
}, },
{ {
@ -914,11 +914,11 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"from openai.embeddings_utils import get_embedding\n", "from utils.embeddings_utils import get_embedding\n",
"\n", "\n",
"df['babbage_similarity'] = df.combined.apply(lambda x: get_embedding(x, engine='text-similarity-babbage-001'))\n", "df['babbage_similarity'] = df.combined.apply(lambda x: get_embedding(x, engine='text-similarity-babbage-001'))\n",
"df['babbage_search'] = df.combined.apply(lambda x: get_embedding(x, engine='text-search-babbage-doc-001'))\n", "df['babbage_search'] = df.combined.apply(lambda x: get_embedding(x, engine='text-search-babbage-doc-001'))\n",
"df.to_csv(embedding_path)" "df.to_csv(embedding_path)\n"
] ]
}, },
{ {
@ -1091,7 +1091,7 @@
"\n", "\n",
"fs_df = pd.read_csv(embedding_path)\n", "fs_df = pd.read_csv(embedding_path)\n",
"fs_df[\"babbage_similarity\"] = fs_df.babbage_similarity.apply(literal_eval).apply(np.array)\n", "fs_df[\"babbage_similarity\"] = fs_df.babbage_similarity.apply(literal_eval).apply(np.array)\n",
"fs_df.head()" "fs_df.head()\n"
] ]
}, },
{ {
@ -1141,7 +1141,7 @@
"probas = clf.predict_proba(X_test)\n", "probas = clf.predict_proba(X_test)\n",
"\n", "\n",
"report = classification_report(y_test, preds)\n", "report = classification_report(y_test, preds)\n",
"print(report)" "print(report)\n"
] ]
}, },
{ {
@ -1195,7 +1195,7 @@
], ],
"source": [ "source": [
"ft_prep_df = fs_df.copy()\n", "ft_prep_df = fs_df.copy()\n",
"len(ft_prep_df)" "len(ft_prep_df)\n"
] ]
}, },
{ {
@ -1349,7 +1349,7 @@
} }
], ],
"source": [ "source": [
"ft_prep_df.head()" "ft_prep_df.head()\n"
] ]
}, },
{ {
@ -1378,7 +1378,7 @@
"classes = list(set(ft_prep_df['Classification']))\n", "classes = list(set(ft_prep_df['Classification']))\n",
"class_df = pd.DataFrame(classes).reset_index()\n", "class_df = pd.DataFrame(classes).reset_index()\n",
"class_df.columns = ['class_id','class']\n", "class_df.columns = ['class_id','class']\n",
"class_df , len(class_df)" "class_df , len(class_df)\n"
] ]
}, },
{ {
@ -1559,7 +1559,7 @@
"\n", "\n",
"# Adding a common separator onto the end of each prompt so the model knows when a prompt is terminating\n", "# Adding a common separator onto the end of each prompt so the model knows when a prompt is terminating\n",
"ft_df_with_class['prompt'] = ft_df_with_class.apply(lambda x: x['combined'] + '\\n\\n###\\n\\n',axis=1)\n", "ft_df_with_class['prompt'] = ft_df_with_class.apply(lambda x: x['combined'] + '\\n\\n###\\n\\n',axis=1)\n",
"ft_df_with_class.head()" "ft_df_with_class.head()\n"
] ]
}, },
{ {
@ -1647,7 +1647,7 @@
"# In our case we don't, so we shuffle the data to give us a better chance of getting equal classes in our train and validation sets\n", "# In our case we don't, so we shuffle the data to give us a better chance of getting equal classes in our train and validation sets\n",
"# Our fine-tuned model will error if we have less classes in the validation set, so this is a necessary step\n", "# Our fine-tuned model will error if we have less classes in the validation set, so this is a necessary step\n",
"\n", "\n",
"import random \n", "import random\n",
"\n", "\n",
"labels = [x for x in ft_df_with_class['class_id']]\n", "labels = [x for x in ft_df_with_class['class_id']]\n",
"text = [x for x in ft_df_with_class['prompt']]\n", "text = [x for x in ft_df_with_class['prompt']]\n",
@ -1656,7 +1656,7 @@
"ft_df['ordering'] = ft_df.apply(lambda x: random.randint(0,len(ft_df)), axis = 1)\n", "ft_df['ordering'] = ft_df.apply(lambda x: random.randint(0,len(ft_df)), axis = 1)\n",
"ft_df.set_index('ordering',inplace=True)\n", "ft_df.set_index('ordering',inplace=True)\n",
"ft_df_sorted = ft_df.sort_index(ascending=True)\n", "ft_df_sorted = ft_df.sort_index(ascending=True)\n",
"ft_df_sorted.head()" "ft_df_sorted.head()\n"
] ]
}, },
{ {
@ -1670,7 +1670,7 @@
"\n", "\n",
"# We output our shuffled dataframe to a .jsonl file and run the prepare_data function to get us our input files\n", "# We output our shuffled dataframe to a .jsonl file and run the prepare_data function to get us our input files\n",
"ft_df_sorted.to_json(\"transactions_grouped.jsonl\", orient='records', lines=True)\n", "ft_df_sorted.to_json(\"transactions_grouped.jsonl\", orient='records', lines=True)\n",
"!openai tools fine_tunes.prepare_data -f transactions_grouped.jsonl -q" "!openai tools fine_tunes.prepare_data -f transactions_grouped.jsonl -q\n"
] ]
}, },
{ {
@ -1691,7 +1691,7 @@
"source": [ "source": [
"# This functions checks that your classes all appear in both prepared files\n", "# This functions checks that your classes all appear in both prepared files\n",
"# If they don't, the fine-tuned model creation will fail\n", "# If they don't, the fine-tuned model creation will fail\n",
"check_finetune_classes('transactions_grouped_prepared_train.jsonl','transactions_grouped_prepared_valid.jsonl')" "check_finetune_classes('transactions_grouped_prepared_train.jsonl','transactions_grouped_prepared_valid.jsonl')\n"
] ]
}, },
{ {
@ -1704,7 +1704,7 @@
"!openai api fine_tunes.create -t \"transactions_grouped_prepared_train.jsonl\" -v \"transactions_grouped_prepared_valid.jsonl\" --compute_classification_metrics --classification_n_classes 5 -m curie\n", "!openai api fine_tunes.create -t \"transactions_grouped_prepared_train.jsonl\" -v \"transactions_grouped_prepared_valid.jsonl\" --compute_classification_metrics --classification_n_classes 5 -m curie\n",
"\n", "\n",
"# You can use following command to get fine tuning job status and model name, replace the job name with your job\n", "# You can use following command to get fine tuning job status and model name, replace the job name with your job\n",
"#!openai api fine_tunes.get -i ft-YBIc01t4hxYBC7I5qhRF3Qdx" "#!openai api fine_tunes.get -i ft-YBIc01t4hxYBC7I5qhRF3Qdx\n"
] ]
}, },
{ {
@ -1715,7 +1715,7 @@
"source": [ "source": [
"# Congrats, you've got a fine-tuned model!\n", "# Congrats, you've got a fine-tuned model!\n",
"# Copy/paste the name provided into the variable below and we'll take it for a spin\n", "# Copy/paste the name provided into the variable below and we'll take it for a spin\n",
"fine_tuned_model = 'curie:ft-personal-2022-10-20-10-42-56'" "fine_tuned_model = 'curie:ft-personal-2022-10-20-10-42-56'\n"
] ]
}, },
{ {
@ -1804,7 +1804,7 @@
], ],
"source": [ "source": [
"test_set = pd.read_json('transactions_grouped_prepared_valid.jsonl', lines=True)\n", "test_set = pd.read_json('transactions_grouped_prepared_valid.jsonl', lines=True)\n",
"test_set.head()" "test_set.head()\n"
] ]
}, },
{ {
@ -1814,7 +1814,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"test_set['predicted_class'] = test_set.apply(lambda x: openai.Completion.create(model=fine_tuned_model, prompt=x['prompt'], max_tokens=1, temperature=0, logprobs=5),axis=1)\n", "test_set['predicted_class'] = test_set.apply(lambda x: openai.Completion.create(model=fine_tuned_model, prompt=x['prompt'], max_tokens=1, temperature=0, logprobs=5),axis=1)\n",
"test_set['pred'] = test_set.apply(lambda x : x['predicted_class']['choices'][0]['text'],axis=1)" "test_set['pred'] = test_set.apply(lambda x : x['predicted_class']['choices'][0]['text'],axis=1)\n"
] ]
}, },
{ {
@ -1823,7 +1823,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"test_set['result'] = test_set.apply(lambda x: str(x['pred']).strip() == str(x['completion']).strip(), axis = 1)" "test_set['result'] = test_set.apply(lambda x: str(x['pred']).strip() == str(x['completion']).strip(), axis = 1)\n"
] ]
}, },
{ {
@ -1845,7 +1845,7 @@
} }
], ],
"source": [ "source": [
"test_set['result'].value_counts()" "test_set['result'].value_counts()\n"
] ]
}, },
{ {
@ -1953,7 +1953,7 @@
], ],
"source": [ "source": [
"holdout_df = transactions.copy().iloc[101:]\n", "holdout_df = transactions.copy().iloc[101:]\n",
"holdout_df.head()" "holdout_df.head()\n"
] ]
}, },
{ {
@ -1964,7 +1964,7 @@
"source": [ "source": [
"holdout_df['combined'] = \"Supplier: \" + holdout_df['Supplier'].str.strip() + \"; Description: \" + holdout_df['Description'].str.strip() + '\\n\\n###\\n\\n' # + \"; Value: \" + str(df['Transaction value (£)']).strip()\n", "holdout_df['combined'] = \"Supplier: \" + holdout_df['Supplier'].str.strip() + \"; Description: \" + holdout_df['Description'].str.strip() + '\\n\\n###\\n\\n' # + \"; Value: \" + str(df['Transaction value (£)']).strip()\n",
"holdout_df['prediction_result'] = holdout_df.apply(lambda x: openai.Completion.create(model=fine_tuned_model, prompt=x['combined'], max_tokens=1, temperature=0, logprobs=5),axis=1)\n", "holdout_df['prediction_result'] = holdout_df.apply(lambda x: openai.Completion.create(model=fine_tuned_model, prompt=x['combined'], max_tokens=1, temperature=0, logprobs=5),axis=1)\n",
"holdout_df['pred'] = holdout_df.apply(lambda x : x['prediction_result']['choices'][0]['text'],axis=1)" "holdout_df['pred'] = holdout_df.apply(lambda x : x['prediction_result']['choices'][0]['text'],axis=1)\n"
] ]
}, },
{ {
@ -2151,7 +2151,7 @@
} }
], ],
"source": [ "source": [
"holdout_df.head(10)" "holdout_df.head(10)\n"
] ]
}, },
{ {
@ -2173,7 +2173,7 @@
} }
], ],
"source": [ "source": [
"holdout_df['pred'].value_counts()" "holdout_df['pred'].value_counts()\n"
] ]
}, },
{ {

View File

@ -37,7 +37,7 @@
"import pandas as pd\n", "import pandas as pd\n",
"import pickle\n", "import pickle\n",
"\n", "\n",
"from openai.embeddings_utils import (\n", "from utils.embeddings_utils import (\n",
" get_embedding,\n", " get_embedding,\n",
" distances_from_embeddings,\n", " distances_from_embeddings,\n",
" tsne_components_from_embeddings,\n", " tsne_components_from_embeddings,\n",
@ -46,7 +46,7 @@
")\n", ")\n",
"\n", "\n",
"# constants\n", "# constants\n",
"EMBEDDING_MODEL = \"text-embedding-ada-002\"" "EMBEDDING_MODEL = \"text-embedding-ada-002\"\n"
] ]
}, },
{ {
@ -158,7 +158,7 @@
"\n", "\n",
"# print dataframe\n", "# print dataframe\n",
"n_examples = 5\n", "n_examples = 5\n",
"df.head(n_examples)" "df.head(n_examples)\n"
] ]
}, },
{ {
@ -206,7 +206,7 @@
" print(\"\")\n", " print(\"\")\n",
" print(f\"Title: {row['title']}\")\n", " print(f\"Title: {row['title']}\")\n",
" print(f\"Description: {row['description']}\")\n", " print(f\"Description: {row['description']}\")\n",
" print(f\"Label: {row['label']}\")" " print(f\"Label: {row['label']}\")\n"
] ]
}, },
{ {
@ -252,7 +252,7 @@
" embedding_cache[(string, model)] = get_embedding(string, model)\n", " embedding_cache[(string, model)] = get_embedding(string, model)\n",
" with open(embedding_cache_path, \"wb\") as embedding_cache_file:\n", " with open(embedding_cache_path, \"wb\") as embedding_cache_file:\n",
" pickle.dump(embedding_cache, embedding_cache_file)\n", " pickle.dump(embedding_cache, embedding_cache_file)\n",
" return embedding_cache[(string, model)]" " return embedding_cache[(string, model)]\n"
] ]
}, },
{ {
@ -285,7 +285,7 @@
"\n", "\n",
"# print the first 10 dimensions of the embedding\n", "# print the first 10 dimensions of the embedding\n",
"example_embedding = embedding_from_string(example_string)\n", "example_embedding = embedding_from_string(example_string)\n",
"print(f\"\\nExample embedding: {example_embedding[:10]}...\")" "print(f\"\\nExample embedding: {example_embedding[:10]}...\")\n"
] ]
}, },
{ {
@ -317,9 +317,9 @@
" embeddings = [embedding_from_string(string, model=model) for string in strings]\n", " embeddings = [embedding_from_string(string, model=model) for string in strings]\n",
" # get the embedding of the source string\n", " # get the embedding of the source string\n",
" query_embedding = embeddings[index_of_source_string]\n", " query_embedding = embeddings[index_of_source_string]\n",
" # get distances between the source embedding and other embeddings (function from embeddings_utils.py)\n", " # get distances between the source embedding and other embeddings (function from utils.embeddings_utils.py)\n",
" distances = distances_from_embeddings(query_embedding, embeddings, distance_metric=\"cosine\")\n", " distances = distances_from_embeddings(query_embedding, embeddings, distance_metric=\"cosine\")\n",
" # get indices of nearest neighbors (function from embeddings_utils.py)\n", " # get indices of nearest neighbors (function from utils.embeddings_utils.py)\n",
" indices_of_nearest_neighbors = indices_of_nearest_neighbors_from_distances(distances)\n", " indices_of_nearest_neighbors = indices_of_nearest_neighbors_from_distances(distances)\n",
"\n", "\n",
" # print out source string\n", " # print out source string\n",
@ -344,7 +344,7 @@
" Distance: {distances[i]:0.3f}\"\"\"\n", " Distance: {distances[i]:0.3f}\"\"\"\n",
" )\n", " )\n",
"\n", "\n",
" return indices_of_nearest_neighbors" " return indices_of_nearest_neighbors\n"
] ]
}, },
{ {
@ -396,7 +396,7 @@
" strings=article_descriptions, # let's base similarity off of the article description\n", " strings=article_descriptions, # let's base similarity off of the article description\n",
" index_of_source_string=0, # let's look at articles similar to the first one about Tony Blair\n", " index_of_source_string=0, # let's look at articles similar to the first one about Tony Blair\n",
" k_nearest_neighbors=5, # let's look at the 5 most similar articles\n", " k_nearest_neighbors=5, # let's look at the 5 most similar articles\n",
")" ")\n"
] ]
}, },
{ {
@ -452,7 +452,7 @@
" strings=article_descriptions, # let's base similarity off of the article description\n", " strings=article_descriptions, # let's base similarity off of the article description\n",
" index_of_source_string=1, # let's look at articles similar to the second one about a more secure chipset\n", " index_of_source_string=1, # let's look at articles similar to the second one about a more secure chipset\n",
" k_nearest_neighbors=5, # let's look at the 5 most similar articles\n", " k_nearest_neighbors=5, # let's look at the 5 most similar articles\n",
")" ")\n"
] ]
}, },
{ {
@ -11456,7 +11456,7 @@
" width=600,\n", " width=600,\n",
" height=500,\n", " height=500,\n",
" title=\"t-SNE components of article descriptions\",\n", " title=\"t-SNE components of article descriptions\",\n",
")" ")\n"
] ]
}, },
{ {
@ -11498,7 +11498,7 @@
"\n", "\n",
"tony_blair_labels = nearest_neighbor_labels(tony_blair_articles, k_nearest_neighbors=5)\n", "tony_blair_labels = nearest_neighbor_labels(tony_blair_articles, k_nearest_neighbors=5)\n",
"chipset_security_labels = nearest_neighbor_labels(chipset_security_articles, k_nearest_neighbors=5\n", "chipset_security_labels = nearest_neighbor_labels(chipset_security_articles, k_nearest_neighbors=5\n",
")" ")\n"
] ]
}, },
{ {
@ -22443,7 +22443,7 @@
" height=500,\n", " height=500,\n",
" title=\"Nearest neighbors of the Tony Blair article\",\n", " title=\"Nearest neighbors of the Tony Blair article\",\n",
" category_orders={\"label\": [\"Other\", \"Nearest neighbor (top 5)\", \"Source\"]},\n", " category_orders={\"label\": [\"Other\", \"Nearest neighbor (top 5)\", \"Source\"]},\n",
")" ")\n"
] ]
}, },
{ {
@ -33396,7 +33396,7 @@
" height=500,\n", " height=500,\n",
" title=\"Nearest neighbors of the chipset security article\",\n", " title=\"Nearest neighbors of the chipset security article\",\n",
" category_orders={\"label\": [\"Other\", \"Nearest neighbor (top 5)\", \"Source\"]},\n", " category_orders={\"label\": [\"Other\", \"Nearest neighbor (top 5)\", \"Source\"]},\n",
")" ")\n"
] ]
}, },
{ {

View File

@ -53,7 +53,7 @@
} }
], ],
"source": [ "source": [
"from openai.embeddings_utils import get_embedding, cosine_similarity\n", "from utils.embeddings_utils import get_embedding, cosine_similarity\n",
"\n", "\n",
"# search through the reviews for a specific product\n", "# search through the reviews for a specific product\n",
"def search_reviews(df, product_description, n=3, pprint=True):\n", "def search_reviews(df, product_description, n=3, pprint=True):\n",
@ -98,7 +98,7 @@
} }
], ],
"source": [ "source": [
"results = search_reviews(df, \"whole wheat pasta\", n=3)" "results = search_reviews(df, \"whole wheat pasta\", n=3)\n"
] ]
}, },
{ {
@ -124,7 +124,7 @@
} }
], ],
"source": [ "source": [
"results = search_reviews(df, \"bad delivery\", n=1)" "results = search_reviews(df, \"bad delivery\", n=1)\n"
] ]
}, },
{ {
@ -150,7 +150,7 @@
} }
], ],
"source": [ "source": [
"results = search_reviews(df, \"spoilt\", n=1)" "results = search_reviews(df, \"spoilt\", n=1)\n"
] ]
}, },
{ {
@ -170,7 +170,7 @@
} }
], ],
"source": [ "source": [
"results = search_reviews(df, \"pet food\", n=2)" "results = search_reviews(df, \"pet food\", n=2)\n"
] ]
} }
], ],

View File

@ -75,7 +75,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"from openai.embeddings_utils import cosine_similarity\n", "from utils.embeddings_utils import cosine_similarity\n",
"\n", "\n",
"# evaluate embeddings as recommendations on X_test\n", "# evaluate embeddings as recommendations on X_test\n",
"def evaluate_single_match(row):\n", "def evaluate_single_match(row):\n",
@ -140,7 +140,7 @@
"X_test.boxplot(column='percentile_cosine_similarity', by='Score')\n", "X_test.boxplot(column='percentile_cosine_similarity', by='Score')\n",
"plt.title('')\n", "plt.title('')\n",
"plt.show()\n", "plt.show()\n",
"plt.close()" "plt.close()\n"
] ]
}, },
{ {

View File

@ -126,7 +126,7 @@
"samples = pd.read_json(\"data/dbpedia_samples.jsonl\", lines=True)\n", "samples = pd.read_json(\"data/dbpedia_samples.jsonl\", lines=True)\n",
"categories = sorted(samples[\"category\"].unique())\n", "categories = sorted(samples[\"category\"].unique())\n",
"print(\"Categories of DBpedia samples:\", samples[\"category\"].value_counts())\n", "print(\"Categories of DBpedia samples:\", samples[\"category\"].value_counts())\n",
"samples.head()" "samples.head()\n"
] ]
}, },
{ {
@ -136,9 +136,9 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"from openai.embeddings_utils import get_embeddings\n", "from utils.embeddings_utils import get_embeddings\n",
"# NOTE: The following code will send a query of batch size 200 to /embeddings\n", "# NOTE: The following code will send a query of batch size 200 to /embeddings\n",
"matrix = get_embeddings(samples[\"text\"].to_list(), engine=\"text-embedding-ada-002\")" "matrix = get_embeddings(samples[\"text\"].to_list(), engine=\"text-embedding-ada-002\")\n"
] ]
}, },
{ {
@ -159,7 +159,7 @@
"from sklearn.decomposition import PCA\n", "from sklearn.decomposition import PCA\n",
"pca = PCA(n_components=3)\n", "pca = PCA(n_components=3)\n",
"vis_dims = pca.fit_transform(matrix)\n", "vis_dims = pca.fit_transform(matrix)\n",
"samples[\"embed_vis\"] = vis_dims.tolist()" "samples[\"embed_vis\"] = vis_dims.tolist()\n"
] ]
}, },
{ {
@ -233,7 +233,7 @@
"ax.set_xlabel('x')\n", "ax.set_xlabel('x')\n",
"ax.set_ylabel('y')\n", "ax.set_ylabel('y')\n",
"ax.set_zlabel('z')\n", "ax.set_zlabel('z')\n",
"ax.legend(bbox_to_anchor=(1.1, 1))" "ax.legend(bbox_to_anchor=(1.1, 1))\n"
] ]
} }
], ],

View File

@ -86,11 +86,11 @@
} }
], ],
"source": [ "source": [
"from openai.embeddings_utils import cosine_similarity, get_embedding\n", "from utils.embeddings_utils import cosine_similarity, get_embedding\n",
"from sklearn.metrics import PrecisionRecallDisplay\n", "from sklearn.metrics import PrecisionRecallDisplay\n",
"\n", "\n",
"def evaluate_embeddings_approach(\n", "def evaluate_embeddings_approach(\n",
" labels = ['negative', 'positive'], \n", " labels = ['negative', 'positive'],\n",
" model = EMBEDDING_MODEL,\n", " model = EMBEDDING_MODEL,\n",
"):\n", "):\n",
" label_embeddings = [get_embedding(label, engine=model) for label in labels]\n", " label_embeddings = [get_embedding(label, engine=model) for label in labels]\n",
@ -107,7 +107,7 @@
" display = PrecisionRecallDisplay.from_predictions(df.sentiment, probas, pos_label='positive')\n", " display = PrecisionRecallDisplay.from_predictions(df.sentiment, probas, pos_label='positive')\n",
" _ = display.ax_.set_title(\"2-class Precision-Recall curve\")\n", " _ = display.ax_.set_title(\"2-class Precision-Recall curve\")\n",
"\n", "\n",
"evaluate_embeddings_approach(labels=['negative', 'positive'], model=EMBEDDING_MODEL)" "evaluate_embeddings_approach(labels=['negative', 'positive'], model=EMBEDDING_MODEL)\n"
] ]
}, },
{ {
@ -152,7 +152,7 @@
} }
], ],
"source": [ "source": [
"evaluate_embeddings_approach(labels=['An Amazon review with a negative sentiment.', 'An Amazon review with a positive sentiment.'])" "evaluate_embeddings_approach(labels=['An Amazon review with a negative sentiment.', 'An Amazon review with a positive sentiment.'])\n"
] ]
}, },
{ {
@ -197,7 +197,7 @@
} }
], ],
"source": [ "source": [
"evaluate_embeddings_approach(labels=['An Amazon review with a negative sentiment.', 'An Amazon review with a positive sentiment.'])" "evaluate_embeddings_approach(labels=['An Amazon review with a negative sentiment.', 'An Amazon review with a positive sentiment.'])\n"
] ]
}, },
{ {

View File

@ -0,0 +1,252 @@
import textwrap as tr
from typing import List, Optional
import matplotlib.pyplot as plt
import plotly.express as px
from scipy import spatial
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.metrics import average_precision_score, precision_recall_curve
from tenacity import retry, stop_after_attempt, wait_random_exponential
import openai
import numpy as np
import pandas as pd
@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
def get_embedding(text: str, engine="text-similarity-davinci-001", **kwargs) -> List[float]:
    """Fetch the embedding vector for a single piece of text.

    Newlines in ``text`` are replaced with spaces before the request is sent.
    The whole call is wrapped by the ``retry`` decorator above (random
    exponential wait, giving up after 6 attempts).

    Args:
        text: The string to embed.
        engine: Name of the embedding engine passed to the API.
        **kwargs: Forwarded unchanged to ``openai.Embedding.create``.

    Returns:
        The ``"embedding"`` entry of the first item in the response's
        ``"data"`` list.
    """
    # Newlines can negatively affect performance, so flatten them to spaces.
    single_line = text.replace("\n", " ")
    response = openai.Embedding.create(input=[single_line], engine=engine, **kwargs)
    first_item = response["data"][0]
    return first_item["embedding"]
@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
async def aget_embedding(
    text: str, engine="text-similarity-davinci-001", **kwargs
) -> List[float]:
    """Asynchronously fetch the embedding vector for a single piece of text.

    Newlines in ``text`` are replaced with spaces before the request is sent.
    The whole coroutine is wrapped by the ``retry`` decorator above (random
    exponential wait, giving up after 6 attempts).

    Args:
        text: The string to embed.
        engine: Name of the embedding engine passed to the API.
        **kwargs: Forwarded unchanged to ``openai.Embedding.acreate``.

    Returns:
        The ``"embedding"`` entry of the first item in the response's
        ``"data"`` list.
    """
    # Newlines can negatively affect performance, so flatten them to spaces.
    single_line = text.replace("\n", " ")
    response = await openai.Embedding.acreate(input=[single_line], engine=engine, **kwargs)
    first_item = response["data"][0]
    return first_item["embedding"]
@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
def _request_embeddings(texts: List[str], engine: str, request_kwargs: dict) -> List[List[float]]:
    """Send one batched embeddings request; only this network call is retried."""
    data = openai.Embedding.create(input=texts, engine=engine, **request_kwargs).data
    return [d["embedding"] for d in data]


def get_embeddings(
    list_of_text: List[str], engine="text-similarity-babbage-001", **kwargs
) -> List[List[float]]:
    """Fetch embedding vectors for a batch of texts in a single request.

    Input validation and newline cleaning happen *before* the retried
    request. Previously they lived inside the ``@retry``-decorated body, so
    an oversize batch (or a non-string element) was retried with backoff and
    only surfaced after all attempts were exhausted, instead of failing
    immediately.

    Args:
        list_of_text: The strings to embed; at most 2048 items.
        engine: Name of the embedding engine passed to the API.
        **kwargs: Forwarded unchanged to ``openai.Embedding.create``.

    Returns:
        One embedding per item of the response's ``data``, in response order.

    Raises:
        ValueError: If ``list_of_text`` contains more than 2048 items.
    """
    # Explicit raise rather than ``assert`` so the check survives ``python -O``.
    if len(list_of_text) > 2048:
        raise ValueError("The batch size should not be larger than 2048.")

    # replace newlines, which can negatively affect performance.
    cleaned = [text.replace("\n", " ") for text in list_of_text]

    return _request_embeddings(cleaned, engine, kwargs)
@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
async def _arequest_embeddings(
    texts: List[str], engine: str, request_kwargs: dict
) -> List[List[float]]:
    """Send one batched async embeddings request; only this network call is retried."""
    data = (await openai.Embedding.acreate(input=texts, engine=engine, **request_kwargs)).data
    return [d["embedding"] for d in data]


async def aget_embeddings(
    list_of_text: List[str], engine="text-similarity-babbage-001", **kwargs
) -> List[List[float]]:
    """Asynchronously fetch embedding vectors for a batch of texts.

    Input validation and newline cleaning happen *before* the retried
    request. Previously they lived inside the ``@retry``-decorated body, so
    an oversize batch (or a non-string element) was retried with backoff and
    only surfaced after all attempts were exhausted, instead of failing
    immediately.

    Args:
        list_of_text: The strings to embed; at most 2048 items.
        engine: Name of the embedding engine passed to the API.
        **kwargs: Forwarded unchanged to ``openai.Embedding.acreate``.

    Returns:
        One embedding per item of the response's ``data``, in response order.

    Raises:
        ValueError: If ``list_of_text`` contains more than 2048 items.
    """
    # Explicit raise rather than ``assert`` so the check survives ``python -O``.
    if len(list_of_text) > 2048:
        raise ValueError("The batch size should not be larger than 2048.")

    # replace newlines, which can negatively affect performance.
    cleaned = [text.replace("\n", " ") for text in list_of_text]

    return await _arequest_embeddings(cleaned, engine, kwargs)
def cosine_similarity(a, b):
    """Return the cosine similarity of ``a`` and ``b``.

    Computed as the dot product of the two inputs divided by the product of
    their norms (``np.linalg.norm``). No special handling is applied when
    either norm is zero.
    """
    inner_product = np.dot(a, b)
    norm_product = np.linalg.norm(a) * np.linalg.norm(b)
    return inner_product / norm_product
def plot_multiclass_precision_recall(
    y_score, y_true_untransformed, class_list, classifier_name
):
    """
    Plot precision-recall curves for a multiclass problem on a new matplotlib figure.

    Draws the micro-averaged precision-recall curve, one curve per class and
    reference iso-f1 contours, and prints the micro-averaged average precision.
    Code slightly modified, but heavily based on https://scikit-learn.org/stable/auto_examples/model_selection/plot_precision_recall.html

    Args:
        y_score: Scores indexed as ``y_score[:, i]`` for the i-th class
            (presumably a 2D numpy array - confirm with callers).
        y_true_untransformed: True labels; compared against each class with
            ``==`` and combined with ``pd.concat`` (presumably a pandas Series).
        class_list: The classes, indexed by position.
        classifier_name: Used in the printed summary and in the plot title.
    """
    n_classes = len(class_list)
    class_positions = range(n_classes)

    # One boolean indicator column per class.
    indicator_columns = [
        (y_true_untransformed == class_list[k]) for k in class_positions
    ]
    y_true = pd.concat(indicator_columns, axis=1).values

    # Per-class curves and average precision, keyed by class position.
    precision = {}
    recall = {}
    average_precision = {}
    for k in class_positions:
        truth_k = y_true[:, k]
        score_k = y_score[:, k]
        precision[k], recall[k], _ = precision_recall_curve(truth_k, score_k)
        average_precision[k] = average_precision_score(truth_k, score_k)

    # A "micro-average": quantifying score on all classes jointly
    precision_micro, recall_micro, _ = precision_recall_curve(
        y_true.ravel(), y_score.ravel()
    )
    average_precision_micro = average_precision_score(y_true, y_score, average="micro")
    summary = " - Average precision score over all classes: {0:0.2f}".format(
        average_precision_micro
    )
    print(str(classifier_name) + summary)

    # setup plot details
    plt.figure(figsize=(9, 10))
    legend_handles = []
    legend_labels = []

    # Reference iso-f1 contours.
    for f_score in np.linspace(0.2, 0.8, num=4):
        x = np.linspace(0.01, 1)
        y = f_score * x / (2 * x - f_score)
        non_negative = y >= 0
        (curve,) = plt.plot(x[non_negative], y[non_negative], color="gray", alpha=0.2)
        plt.annotate("f1={0:0.1f}".format(f_score), xy=(0.9, y[45] + 0.02))

    # Only the last iso-f1 contour is registered, so the legend shows one entry.
    legend_handles.append(curve)
    legend_labels.append("iso-f1 curves")

    (curve,) = plt.plot(recall_micro, precision_micro, color="gold", lw=2)
    legend_handles.append(curve)
    legend_labels.append(
        "average Precision-recall (auprc = {0:0.2f})".format(average_precision_micro)
    )

    for k in class_positions:
        (curve,) = plt.plot(recall[k], precision[k], lw=2)
        legend_handles.append(curve)
        legend_labels.append(
            "Precision-recall for class `{0}` (auprc = {1:0.2f})".format(
                class_list[k], average_precision[k]
            )
        )

    plt.gcf().subplots_adjust(bottom=0.25)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.title(f"{classifier_name}: Precision-Recall curve for each class")
    plt.legend(legend_handles, legend_labels)
def distances_from_embeddings(
    query_embedding: List[float],
    embeddings: List[List[float]],
    distance_metric="cosine",
) -> List[float]:
    """Return the distances between a query embedding and a list of embeddings.

    Args:
        query_embedding: The vector every embedding is compared against.
        embeddings: The vectors to measure, one distance is produced per item.
        distance_metric: One of ``"cosine"``, ``"L1"``, ``"L2"`` or ``"Linf"``,
            mapped to the ``scipy.spatial.distance`` functions ``cosine``,
            ``cityblock``, ``euclidean`` and ``chebyshev`` respectively.

    Returns:
        A flat list with one distance per entry of ``embeddings``, in order.

    Raises:
        KeyError: If ``distance_metric`` is not a supported name. This is
            raised even when ``embeddings`` is empty, so a misspelled metric
            is never silently accepted.
    """
    distance_metrics = {
        "cosine": spatial.distance.cosine,
        "L1": spatial.distance.cityblock,
        "L2": spatial.distance.euclidean,
        "Linf": spatial.distance.chebyshev,
    }
    # Resolve the metric once, up front, with an actionable error message.
    try:
        metric = distance_metrics[distance_metric]
    except KeyError:
        raise KeyError(
            f"Unsupported distance_metric {distance_metric!r}; "
            f"expected one of {sorted(distance_metrics)}"
        ) from None

    return [metric(query_embedding, embedding) for embedding in embeddings]
def indices_of_nearest_neighbors_from_distances(distances) -> np.ndarray:
    """Return the indices that order ``distances`` from smallest to largest.

    The first index therefore points at the nearest neighbor.
    """
    nearest_first = np.argsort(distances)
    return nearest_first
def pca_components_from_embeddings(
    embeddings: List[List[float]], n_components=2
) -> np.ndarray:
    """Project a list of embeddings onto their first PCA components.

    Args:
        embeddings: The vectors to reduce; converted with ``np.array``.
        n_components: Number of components passed to ``PCA``.

    Returns:
        The result of ``PCA.fit_transform`` on the embedding array.
    """
    reducer = PCA(n_components=n_components)
    return reducer.fit_transform(np.array(embeddings))
def tsne_components_from_embeddings(
    embeddings: List[List[float]], n_components=2, **kwargs
) -> np.ndarray:
    """Reduce a list of embeddings to their t-SNE components.

    Args:
        embeddings: The vectors to reduce; converted with ``np.array``.
        n_components: Number of components passed to ``TSNE``.
        **kwargs: Extra keyword arguments forwarded to ``TSNE``. When absent,
            ``init`` defaults to ``"pca"`` and ``learning_rate`` to ``"auto"``.

    Returns:
        The result of ``TSNE.fit_transform`` on the embedding array.
    """
    # use better defaults if not specified; explicit caller values always win
    kwargs.setdefault("init", "pca")
    kwargs.setdefault("learning_rate", "auto")
    reducer = TSNE(n_components=n_components, **kwargs)
    return reducer.fit_transform(np.array(embeddings))
def chart_from_components(
    components: np.ndarray,
    labels: Optional[List[str]] = None,
    strings: Optional[List[str]] = None,
    x_title="Component 0",
    y_title="Component 1",
    mark_size=5,
    **kwargs,
):
    """Return an interactive 2D chart of embedding components.

    Args:
        components: Array whose columns 0 and 1 supply the x and y values.
        labels: Optional per-point labels; when given (and non-empty) they
            drive both marker color and marker symbol.
        strings: Optional per-point texts; when given (and non-empty) they are
            wrapped with ``tr.wrap(..., width=30)``, joined with ``<br>`` and
            shown as hover data.
        x_title: Column name / axis title for the first component.
        y_title: Column name / axis title for the second component.
        mark_size: Marker size applied to every trace.
        **kwargs: Extra keyword arguments forwarded to ``px.scatter``.

    Returns:
        The figure returned by ``px.scatter(...).update_traces(...)``.
    """
    # Placeholder column used whenever labels or strings are not supplied.
    blanks = ["" for _ in components]

    x_values = components[:, 0]
    y_values = components[:, 1]
    label_values = labels if labels else blanks
    if strings:
        hover_values = ["<br>".join(tr.wrap(string, width=30)) for string in strings]
    else:
        hover_values = blanks

    data = pd.DataFrame(
        {
            x_title: x_values,
            y_title: y_values,
            "label": label_values,
            "string": hover_values,
        }
    )

    label_column = "label" if labels else None
    hover_columns = ["string"] if strings else None
    figure = px.scatter(
        data,
        x=x_title,
        y=y_title,
        color=label_column,
        symbol=label_column,
        hover_data=hover_columns,
        **kwargs,
    )
    return figure.update_traces(marker=dict(size=mark_size))
def chart_from_components_3D(
    components: np.ndarray,
    labels: Optional[List[str]] = None,
    strings: Optional[List[str]] = None,
    x_title: str = "Component 0",
    y_title: str = "Component 1",
    z_title: str = "Component 2",
    mark_size: int = 5,
    **kwargs,
):
    """Return an interactive 3D chart of embedding components.

    Args:
        components: Array whose columns 0, 1 and 2 supply the x, y and z values.
        labels: Optional per-point labels; when given (and non-empty) they
            drive both marker color and marker symbol.
        strings: Optional per-point texts; when given (and non-empty) they are
            wrapped with ``tr.wrap(..., width=30)``, joined with ``<br>`` and
            shown as hover data.
        x_title: Column name / axis title for the first component.
        y_title: Column name / axis title for the second component.
        z_title: Column name / axis title for the third component. The default
            was previously misspelled ``"Compontent 2"``, which showed up on
            the rendered z axis.
        mark_size: Marker size applied to every trace.
        **kwargs: Extra keyword arguments forwarded to ``px.scatter_3d``.

    Returns:
        The figure returned by ``px.scatter_3d(...).update_traces(...)``.
    """
    # Placeholder column used whenever labels or strings are not supplied.
    empty_list = ["" for _ in components]
    data = pd.DataFrame(
        {
            x_title: components[:, 0],
            y_title: components[:, 1],
            z_title: components[:, 2],
            "label": labels if labels else empty_list,
            "string": ["<br>".join(tr.wrap(string, width=30)) for string in strings]
            if strings
            else empty_list,
        }
    )
    chart = px.scatter_3d(
        data,
        x=x_title,
        y=y_title,
        z=z_title,
        color="label" if labels else None,
        symbol="label" if labels else None,
        hover_data=["string"] if strings else None,
        **kwargs,
    ).update_traces(marker=dict(size=mark_size))
    return chart

View File

@ -34,7 +34,7 @@
} }
], ],
"source": [ "source": [
"!pip install openai --quiet" "!pip install openai --quiet\n"
] ]
}, },
{ {
@ -48,7 +48,7 @@
"\n", "\n",
"# models\n", "# models\n",
"EMBEDDING_MODEL = \"text-embedding-ada-002\"\n", "EMBEDDING_MODEL = \"text-embedding-ada-002\"\n",
"GPT_MODEL = \"gpt-3.5-turbo\"" "GPT_MODEL = \"gpt-3.5-turbo\"\n"
] ]
}, },
{ {
@ -84,7 +84,7 @@
" ]\n", " ]\n",
")\n", ")\n",
"\n", "\n",
"print(response['choices'][0]['message']['content'])" "print(response['choices'][0]['message']['content'])\n"
] ]
}, },
{ {
@ -120,7 +120,7 @@
} }
], ],
"source": [ "source": [
"!pip install matplotlib plotly.express scikit-learn tabulate tiktoken wget --quiet" "!pip install matplotlib plotly.express scikit-learn tabulate tiktoken wget --quiet\n"
] ]
}, },
{ {
@ -133,7 +133,7 @@
"import pandas as pd\n", "import pandas as pd\n",
"import os\n", "import os\n",
"import wget\n", "import wget\n",
"import ast" "import ast\n"
] ]
}, },
{ {
@ -168,7 +168,7 @@
" wget.download(embeddings_path, file_path)\n", " wget.download(embeddings_path, file_path)\n",
" print(\"File downloaded successfully.\")\n", " print(\"File downloaded successfully.\")\n",
"else:\n", "else:\n",
" print(\"File already exists in the local file system.\")" " print(\"File already exists in the local file system.\")\n"
] ]
}, },
{ {
@ -183,7 +183,7 @@
")\n", ")\n",
"\n", "\n",
"# convert embeddings from CSV str type back to list type\n", "# convert embeddings from CSV str type back to list type\n",
"df['embedding'] = df['embedding'].apply(ast.literal_eval)" "df['embedding'] = df['embedding'].apply(ast.literal_eval)\n"
] ]
}, },
{ {
@ -193,7 +193,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"df" "df\n"
] ]
}, },
{ {
@ -219,7 +219,7 @@
} }
], ],
"source": [ "source": [
"df.info(show_counts=True)" "df.info(show_counts=True)\n"
] ]
}, },
{ {
@ -241,7 +241,7 @@
"\n", "\n",
"conn = s2.connect(\"<user>:<Password>@<host>:3306/\")\n", "conn = s2.connect(\"<user>:<Password>@<host>:3306/\")\n",
"\n", "\n",
"cur = conn.cursor()" "cur = conn.cursor()\n"
] ]
}, },
{ {
@ -267,7 +267,7 @@
" CREATE DATABASE IF NOT EXISTS winter_wikipedia2;\n", " CREATE DATABASE IF NOT EXISTS winter_wikipedia2;\n",
"\"\"\"\n", "\"\"\"\n",
"\n", "\n",
"cur.execute(stmt)" "cur.execute(stmt)\n"
] ]
}, },
{ {
@ -296,7 +296,7 @@
" embedding BLOB\n", " embedding BLOB\n",
");\"\"\"\n", ");\"\"\"\n",
"\n", "\n",
"cur.execute(stmt)" "cur.execute(stmt)\n"
] ]
}, },
{ {
@ -349,7 +349,7 @@
"for i in range(0, len(record_arr), batch_size):\n", "for i in range(0, len(record_arr), batch_size):\n",
" batch = record_arr[i:i+batch_size]\n", " batch = record_arr[i:i+batch_size]\n",
" values = [(row[0], row[1], str(row[2])) for row in batch]\n", " values = [(row[0], row[1], str(row[2])) for row in batch]\n",
" cur.executemany(stmt, values)" " cur.executemany(stmt, values)\n"
] ]
}, },
{ {
@ -367,7 +367,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"from openai.embeddings_utils import get_embedding\n", "from utils.embeddings_utils import get_embedding\n",
"\n", "\n",
"def strings_ranked_by_relatedness(\n", "def strings_ranked_by_relatedness(\n",
" query: str,\n", " query: str,\n",
@ -395,7 +395,7 @@
"\n", "\n",
" # Fetch the results\n", " # Fetch the results\n",
" results = cur.fetchall()\n", " results = cur.fetchall()\n",
" \n", "\n",
" strings = []\n", " strings = []\n",
" relatednesses = []\n", " relatednesses = []\n",
"\n", "\n",
@ -404,7 +404,7 @@
" relatednesses.append(row[1])\n", " relatednesses.append(row[1])\n",
"\n", "\n",
" # Return the results.\n", " # Return the results.\n",
" return strings[:top_n], relatednesses[:top_n]" " return strings[:top_n], relatednesses[:top_n]\n"
] ]
}, },
{ {
@ -424,7 +424,7 @@
"\n", "\n",
"for string, relatedness in zip(strings, relatednesses):\n", "for string, relatedness in zip(strings, relatednesses):\n",
" print(f\"{relatedness=:.3f}\")\n", " print(f\"{relatedness=:.3f}\")\n",
" print(tabulate([[string]], headers=['Result'], tablefmt='fancy_grid'))" " print(tabulate([[string]], headers=['Result'], tablefmt='fancy_grid'))\n"
] ]
}, },
{ {
@ -494,7 +494,7 @@
" temperature=0\n", " temperature=0\n",
" )\n", " )\n",
" response_message = response[\"choices\"][0][\"message\"][\"content\"]\n", " response_message = response[\"choices\"][0][\"message\"][\"content\"]\n",
" return response_message" " return response_message\n"
] ]
}, },
{ {
@ -531,7 +531,7 @@
"\n", "\n",
"answer = ask('Who won the gold medal for curling in Olymics 2022?')\n", "answer = ask('Who won the gold medal for curling in Olymics 2022?')\n",
"\n", "\n",
"pprint(answer)" "pprint(answer)\n"
] ]
} }
], ],

View File

@ -87,7 +87,7 @@
} }
], ],
"source": [ "source": [
"! pip install redis pandas openai" "! pip install redis pandas openai\n"
] ]
}, },
{ {
@ -303,7 +303,7 @@
"import numpy as np\n", "import numpy as np\n",
"from typing import List\n", "from typing import List\n",
"\n", "\n",
"from openai.embeddings_utils import (\n", "from utils.embeddings_utils import (\n",
" get_embeddings,\n", " get_embeddings,\n",
" distances_from_embeddings,\n", " distances_from_embeddings,\n",
" tsne_components_from_embeddings,\n", " tsne_components_from_embeddings,\n",
@ -322,7 +322,7 @@
"\n", "\n",
"# print dataframe\n", "# print dataframe\n",
"n_examples = 5\n", "n_examples = 5\n",
"df.head(n_examples)" "df.head(n_examples)\n"
] ]
}, },
{ {
@ -360,7 +360,7 @@
"df[\"product_text\"] = df.apply(lambda row: f\"name {row['productDisplayName']} category {row['masterCategory']} subcategory {row['subCategory']} color {row['baseColour']} gender {row['gender']}\".lower(), axis=1)\n", "df[\"product_text\"] = df.apply(lambda row: f\"name {row['productDisplayName']} category {row['masterCategory']} subcategory {row['subCategory']} color {row['baseColour']} gender {row['gender']}\".lower(), axis=1)\n",
"df.rename({\"id\":\"product_id\"}, inplace=True, axis=1)\n", "df.rename({\"id\":\"product_id\"}, inplace=True, axis=1)\n",
"\n", "\n",
"df.info()" "df.info()\n"
] ]
}, },
{ {
@ -382,7 +382,7 @@
], ],
"source": [ "source": [
"# check out one of the texts we will use to create semantic embeddings\n", "# check out one of the texts we will use to create semantic embeddings\n",
"df[\"product_text\"][0]" "df[\"product_text\"][0]\n"
] ]
}, },
{ {
@ -437,7 +437,7 @@
" port=REDIS_PORT,\n", " port=REDIS_PORT,\n",
" password=REDIS_PASSWORD\n", " password=REDIS_PASSWORD\n",
")\n", ")\n",
"redis_client.ping()" "redis_client.ping()\n"
] ]
}, },
{ {
@ -467,7 +467,7 @@
"INDEX_NAME = \"product_embeddings\" # name of the search index\n", "INDEX_NAME = \"product_embeddings\" # name of the search index\n",
"PREFIX = \"doc\" # prefix for the document keys\n", "PREFIX = \"doc\" # prefix for the document keys\n",
"DISTANCE_METRIC = \"L2\" # distance metric for the vectors (ex. COSINE, IP, L2)\n", "DISTANCE_METRIC = \"L2\" # distance metric for the vectors (ex. COSINE, IP, L2)\n",
"NUMBER_OF_VECTORS = len(df)" "NUMBER_OF_VECTORS = len(df)\n"
] ]
}, },
{ {
@ -492,7 +492,7 @@
" \"INITIAL_CAP\": NUMBER_OF_VECTORS,\n", " \"INITIAL_CAP\": NUMBER_OF_VECTORS,\n",
" }\n", " }\n",
")\n", ")\n",
"fields = [name, category, articleType, gender, season, year, text_embedding]" "fields = [name, category, articleType, gender, season, year, text_embedding]\n"
] ]
}, },
{ {
@ -511,7 +511,7 @@
" redis_client.ft(INDEX_NAME).create_index(\n", " redis_client.ft(INDEX_NAME).create_index(\n",
" fields = fields,\n", " fields = fields,\n",
" definition = IndexDefinition(prefix=[PREFIX], index_type=IndexType.HASH)\n", " definition = IndexDefinition(prefix=[PREFIX], index_type=IndexType.HASH)\n",
")" ")\n"
] ]
}, },
{ {
@ -538,7 +538,7 @@
" product_vectors = []\n", " product_vectors = []\n",
" docs = []\n", " docs = []\n",
" batchsize = 1000\n", " batchsize = 1000\n",
" \n", "\n",
" for idx,doc in enumerate(records,start=1):\n", " for idx,doc in enumerate(records,start=1):\n",
" # create byte vectors\n", " # create byte vectors\n",
" docs.append(doc[\"product_text\"])\n", " docs.append(doc[\"product_text\"])\n",
@ -548,7 +548,7 @@
" print(\"Vectors processed \", len(product_vectors), end='\\r')\n", " print(\"Vectors processed \", len(product_vectors), end='\\r')\n",
" product_vectors += get_embeddings(docs, EMBEDDING_MODEL)\n", " product_vectors += get_embeddings(docs, EMBEDDING_MODEL)\n",
" print(\"Vectors processed \", len(product_vectors), end='\\r')\n", " print(\"Vectors processed \", len(product_vectors), end='\\r')\n",
" return product_vectors" " return product_vectors\n"
] ]
}, },
{ {
@ -562,7 +562,7 @@
" product_vectors = embeddings_batch_request(documents)\n", " product_vectors = embeddings_batch_request(documents)\n",
" records = documents.to_dict(\"records\")\n", " records = documents.to_dict(\"records\")\n",
" batchsize = 500\n", " batchsize = 500\n",
" \n", "\n",
" # Use Redis pipelines to batch calls and save on round trip network communication\n", " # Use Redis pipelines to batch calls and save on round trip network communication\n",
" pipe = client.pipeline()\n", " pipe = client.pipeline()\n",
" for idx,doc in enumerate(records,start=1):\n", " for idx,doc in enumerate(records,start=1):\n",
@ -570,14 +570,14 @@
"\n", "\n",
" # create byte vectors\n", " # create byte vectors\n",
" text_embedding = np.array((product_vectors[idx-1]), dtype=np.float32).tobytes()\n", " text_embedding = np.array((product_vectors[idx-1]), dtype=np.float32).tobytes()\n",
" \n", "\n",
" # replace list of floats with byte vectors\n", " # replace list of floats with byte vectors\n",
" doc[\"product_vector\"] = text_embedding\n", " doc[\"product_vector\"] = text_embedding\n",
"\n", "\n",
" pipe.hset(key, mapping = doc)\n", " pipe.hset(key, mapping = doc)\n",
" if idx % batchsize == 0:\n", " if idx % batchsize == 0:\n",
" pipe.execute()\n", " pipe.execute()\n",
" pipe.execute()" " pipe.execute()\n"
] ]
}, },
{ {
@ -600,7 +600,7 @@
"source": [ "source": [
"%%time\n", "%%time\n",
"index_documents(redis_client, PREFIX, df)\n", "index_documents(redis_client, PREFIX, df)\n",
"print(f\"Loaded {redis_client.info()['db0']['keys']} documents in Redis search index with name: {INDEX_NAME}\")" "print(f\"Loaded {redis_client.info()['db0']['keys']} documents in Redis search index with name: {INDEX_NAME}\")\n"
] ]
}, },
{ {
@ -646,14 +646,14 @@
" .dialect(2)\n", " .dialect(2)\n",
" )\n", " )\n",
" params_dict = {\"vector\": np.array(embedded_query).astype(dtype=np.float32).tobytes()}\n", " params_dict = {\"vector\": np.array(embedded_query).astype(dtype=np.float32).tobytes()}\n",
" \n", "\n",
" # perform vector search\n", " # perform vector search\n",
" results = redis_client.ft(index_name).search(query, params_dict)\n", " results = redis_client.ft(index_name).search(query, params_dict)\n",
" if print_results:\n", " if print_results:\n",
" for i, product in enumerate(results.docs):\n", " for i, product in enumerate(results.docs):\n",
" score = 1 - float(product.vector_score)\n", " score = 1 - float(product.vector_score)\n",
" print(f\"{i}. {product.productDisplayName} (Score: {round(score ,3) })\")\n", " print(f\"{i}. {product.productDisplayName} (Score: {round(score ,3) })\")\n",
" return results.docs" " return results.docs\n"
] ]
}, },
{ {
@ -681,7 +681,7 @@
], ],
"source": [ "source": [
"# Execute a simple vector search in Redis\n", "# Execute a simple vector search in Redis\n",
"results = search_redis(redis_client, 'man blue jeans', k=10)" "results = search_redis(redis_client, 'man blue jeans', k=10)\n"
] ]
}, },
{ {
@ -724,7 +724,7 @@
" vector_field=\"product_vector\",\n", " vector_field=\"product_vector\",\n",
" k=10,\n", " k=10,\n",
" hybrid_fields='@productDisplayName:\"blue jeans\"'\n", " hybrid_fields='@productDisplayName:\"blue jeans\"'\n",
" )" " )\n"
] ]
}, },
{ {
@ -755,7 +755,7 @@
" vector_field=\"product_vector\",\n", " vector_field=\"product_vector\",\n",
" k=10,\n", " k=10,\n",
" hybrid_fields='@productDisplayName:\"slim fit\"'\n", " hybrid_fields='@productDisplayName:\"slim fit\"'\n",
" )" " )\n"
] ]
}, },
{ {
@ -788,7 +788,7 @@
" vector_field=\"product_vector\",\n", " vector_field=\"product_vector\",\n",
" k=10,\n", " k=10,\n",
" hybrid_fields='@masterCategory:{Accessories}'\n", " hybrid_fields='@masterCategory:{Accessories}'\n",
" )" " )\n"
] ]
}, },
{ {
@ -821,7 +821,7 @@
" vector_field=\"product_vector\",\n", " vector_field=\"product_vector\",\n",
" k=10,\n", " k=10,\n",
" hybrid_fields='@year:[2011 2012]'\n", " hybrid_fields='@year:[2011 2012]'\n",
" )" " )\n"
] ]
}, },
{ {
@ -854,7 +854,7 @@
" vector_field=\"product_vector\",\n", " vector_field=\"product_vector\",\n",
" k=10,\n", " k=10,\n",
" hybrid_fields='(@year:[2011 2012] @season:{Summer})'\n", " hybrid_fields='(@year:[2011 2012] @season:{Summer})'\n",
" )" " )\n"
] ]
}, },
{ {
@ -883,7 +883,7 @@
" vector_field=\"product_vector\",\n", " vector_field=\"product_vector\",\n",
" k=10,\n", " k=10,\n",
" hybrid_fields='(@year:[2012 2012] @articleType:{Shirts | Belts} @productDisplayName:\"Wrangler\")'\n", " hybrid_fields='(@year:[2012 2012] @articleType:{Shirts | Belts} @productDisplayName:\"Wrangler\")'\n",
" )" " )\n"
] ]
} }
], ],