openai-cookbook/examples/Clustering.ipynb

2022-03-10 18:08:53 -08:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Clustering\n",
"\n",
"We use a simple k-means algorithm to demonstrate how clustering can be done. Clustering can help discover valuable, hidden groupings within the data. The dataset is created in the [Obtain_dataset Notebook](Obtain_dataset.ipynb)."
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(1000, 2048)"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"\n",
"\n",
"df = pd.read_csv('output/embedded_1k_reviews.csv')\n",
"# Parse the stored embedding strings in place so that np.vstack below receives arrays, not strings.\n",
"df['babbage_similarity'] = df.babbage_similarity.apply(eval).apply(np.array)\n",
"matrix = np.vstack(df.babbage_similarity.values)\n",
"matrix.shape"
]
},
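{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell above parses each stored embedding string with `eval`, which executes whatever the string contains. As a minimal sketch, assuming each value in the CSV is a plain Python list literal such as `'[0.1, -0.2]'`, `ast.literal_eval` can parse the same strings without running arbitrary code. The helper name `parse_embedding` is hypothetical and not part of the original notebook.\n",
"\n",
"```python\n",
"import ast\n",
"\n",
"\n",
"def parse_embedding(cell):\n",
"    # literal_eval only accepts Python literals, so a malformed or\n",
"    # malicious cell raises an error instead of being executed.\n",
"    values = ast.literal_eval(cell)\n",
"    return [float(v) for v in values]\n",
"\n",
"\n",
"# Illustrative input only; the real embeddings here have 2048 dimensions.\n",
"example_vector = parse_embedding('[0.25, -0.5, 1.0]')\n",
"```\n",
"\n",
"With pandas this could be applied as `df.babbage_similarity.apply(parse_embedding)` before stacking the rows with `np.vstack`."
]
},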
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 1. Find the clusters using K-means"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We show the simplest use of K-means. You can pick the number of clusters that fits your use case best."
]
},
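{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the algorithm behind K-means concrete, here is a toy sketch of its core loop (Lloyd's algorithm) in plain Python. It is not how scikit-learn implements `KMeans`: the function name `toy_kmeans`, the evenly spaced starting centroids (instead of k-means++), and the fixed iteration count are simplifying assumptions made for illustration.\n",
"\n",
"```python\n",
"def toy_kmeans(points, k, n_iter=10):\n",
"    # Assumption: start from k evenly spaced input points (not k-means++).\n",
"    centroids = [points[i * len(points) // k] for i in range(k)]\n",
"    labels = [0] * len(points)\n",
"    for _ in range(n_iter):\n",
"        # Assignment step: attach each point to its nearest centroid\n",
"        # (squared Euclidean distance).\n",
"        labels = [\n",
"            min(range(k), key=lambda c: sum((a - b) ** 2 for a, b in zip(p, centroids[c])))\n",
"            for p in points\n",
"        ]\n",
"        # Update step: move each centroid to the mean of its members.\n",
"        for c in range(k):\n",
"            members = [p for p, lab in zip(points, labels) if lab == c]\n",
"            if members:  # keep the old centroid if a cluster is empty\n",
"                centroids[c] = tuple(sum(dim) / len(members) for dim in zip(*members))\n",
"    return labels, centroids\n",
"\n",
"\n",
"# Six illustrative 2-D points forming two well-separated groups.\n",
"toy_points = [(0.0, 0.1), (0.2, 0.0), (0.1, 0.2), (5.0, 5.1), (5.2, 4.9), (4.9, 5.0)]\n",
"toy_labels, toy_centroids = toy_kmeans(toy_points, k=2)\n",
"```\n",
"\n",
"On these six toy points, the groups around (0, 0) and (5, 5) end up in separate clusters. The embeddings in this notebook are handled the same way in principle, only in 2048 dimensions."
]
},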
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Cluster\n",
"2 2.543478\n",
"3 4.374046\n",
"0 4.709402\n",
"1 4.832099\n",
"Name: Score, dtype: float64"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.cluster import KMeans\n",
"\n",
"n_clusters = 4\n",
"\n",
"kmeans = KMeans(n_clusters = n_clusters,init='k-means++',random_state=42)\n",
"kmeans.fit(matrix)\n",
"labels = kmeans.labels_\n",
"df['Cluster'] = labels\n",
"\n",
"df.groupby('Cluster').Score.mean().sort_values()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Cluster 2 has by far the lowest average score (about 2.5), so it appears to group the negative reviews, while clusters 0 and 1 (average scores of about 4.7 and 4.8) group positive reviews."
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Text(0.5, 1.0, 'Clusters identified visualized in language 2d using t-SNE')"
]
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"from sklearn.manifold import TSNE\n",
"import matplotlib\n",
"import matplotlib.pyplot as plt\n",
"\n",
"tsne = TSNE(n_components=2, perplexity=15, random_state=42, init='random', learning_rate=200)\n",
"vis_dims2 = tsne.fit_transform(matrix)\n",
"\n",
"x = [x for x,y in vis_dims2]\n",
"y = [y for x,y in vis_dims2]\n",
"\n",
"for category, color in enumerate(['purple', 'green', 'red', 'blue']):\n",
"    xs = np.array(x)[df.Cluster==category]\n",
"    ys = np.array(y)[df.Cluster==category]\n",
"    plt.scatter(xs, ys, color=color, alpha=0.3)\n",
"\n",
"    # Mark the mean position of each cluster with a larger cross.\n",
"    avg_x = xs.mean()\n",
"    avg_y = ys.mean()\n",
"\n",
"    plt.scatter(avg_x, avg_y, marker='x', color=color, s=100)\n",
"plt.title(\"Clusters identified visualized in language 2d using t-SNE\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The plot shows the clusters in a 2D t-SNE projection, with each cluster's mean position marked by a cross. The red cluster (cluster 2) clearly represents negative reviews. The blue cluster (cluster 3) seems quite different from the others. Let's see a few samples from each cluster."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2. Text samples in the clusters & naming the clusters\n",
"\n",
"Let's show random samples from each cluster. We'll use davinci-instruct-beta-v3 to name the clusters, based on a random sample of 3 reviews from that cluster."
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Cluster 0 Theme: All of the customer reviews mention the great flavor of the product.\n",
"5, French Vanilla Cappuccino: Great price. Really love the the flavor. No need to add anything to \n",
"5, great coffee: A bit pricey once you add the S & H but this is one of the best flavor\n",
"5, Love It: First let me say I'm new to drinking tea. So you're not getting a well\n",
"----------------------------------------------------------------------------------------------------\n",
"Cluster 1 Theme: All three reviews mention the quality of the product.\n",
"5, Beautiful: I don't plan to grind these, have plenty other peppers for that. I go\n",
"5, Awesome: I can't find this in the stores and thought I would like it. So I bou\n",
"5, Came as expected: It was tasty and fresh. The other one I bought was old and tasted mold\n",
"----------------------------------------------------------------------------------------------------\n",
"Cluster 2 Theme: All reviews are about customer's disappointment.\n",
"1, Disappointed...: I should read the fine print, I guess. I mostly went by the picture a\n",
"5, Excellent but Price?: I first heard about this on America's Test Kitchen where it won a blin\n",
"1, Disappointed: I received the offer from Amazon and had never tried this brand before\n",
"----------------------------------------------------------------------------------------------------\n",
"Cluster 3 Theme: The reviews for these products have in common that the customers' dogs love them.\n",
"5, My Dog's Favorite Snack!: I was first introduced to this snack at my dog's training classes at p\n",
"4, Fruitables Crunchy Dog Treats: My lab goes wild for these and I am almost tempted to have a go at som\n",
"5, Happy with the product: My dog was suffering with itchy skin. He had been eating Natural Choi\n",
"----------------------------------------------------------------------------------------------------\n"
]
}
],
"source": [
"import openai\n",
"\n",
"# Print a generated theme and a few sample reviews for each cluster.\n",
"rev_per_cluster = 3\n",
"\n",
"for i in range(n_clusters):\n",
"    print(f\"Cluster {i} Theme:\", end=\" \")\n",
"\n",
"    # Join the sampled reviews as \"title: content\" lines for the prompt.\n",
"    reviews = \"\\n\".join(df[df.Cluster == i].combined.str.replace(\"Title: \", \"\").str.replace(\"\\n\\nContent: \", \": \").sample(rev_per_cluster, random_state=42).values)\n",
"    response = openai.Completion.create(\n",
"        engine=\"davinci-instruct-beta-v3\",\n",
"        prompt=f\"What do the following customer reviews have in common?\\n\\nCustomer reviews:\\n\\\"\\\"\\\"\\n{reviews}\\n\\\"\\\"\\\"\\n\\nTheme:\",\n",
"        temperature=0,\n",
"        max_tokens=64,\n",
"        top_p=1,\n",
"        frequency_penalty=0,\n",
"        presence_penalty=0\n",
"    )\n",
"    print(response[\"choices\"][0][\"text\"].replace('\\n',''))\n",
"\n",
"    # The same random_state selects the same rows that were sent in the prompt.\n",
"    sample_cluster_rows = df[df.Cluster == i].sample(rev_per_cluster, random_state=42)\n",
"    for j in range(rev_per_cluster):\n",
"        print(sample_cluster_rows.Score.values[j], end=\", \")\n",
"        print(sample_cluster_rows.Summary.values[j], end=\": \")\n",
"        print(sample_cluster_rows.Text.str[:70].values[j])\n",
"\n",
"    print(\"-\" * 100)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Based on the average rating per cluster, we can see that cluster 2 contains mostly negative reviews. Clusters 0 and 1 contain mostly positive reviews, while cluster 3 appears to contain reviews about dog products."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It's important to note that clusters will not necessarily match what you intend to use them for. A larger number of clusters will focus on more specific patterns, whereas a small number of clusters will usually focus on the largest discrepancies in the data."
]
}
],
"metadata": {
"interpreter": {
"hash": "be4b5d5b73a21c599de40d6deb1129796d12dc1cc33a738f7bac13269cfcafe8"
},
"kernelspec": {
"display_name": "Python 3.7.3 64-bit ('base': conda)",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.9"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}