{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Clustering\n",
"\n",
"We use a simple k-means algorithm to demonstrate how clustering can be done. Clustering can help discover valuable, hidden groupings within the data. The dataset is created in the [Obtain_dataset Notebook](Obtain_dataset.ipynb)."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(1000, 1536)"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"from ast import literal_eval\n",
"\n",
"# If you have not run the \"Obtain_dataset.ipynb\" notebook, you can download the datafile from here: https://cdn.openai.com/API/examples/data/fine_food_reviews_with_embeddings_1k.csv\n",
"datafile_path = \"./data/fine_food_reviews_with_embeddings_1k.csv\"\n",
"\n",
"df = pd.read_csv(datafile_path)\n",
"# Each embedding is stored in the CSV as the string form of a list of floats; literal_eval parses it without executing arbitrary code\n",
"df[\"ada_similarity\"] = df.ada_similarity.apply(literal_eval).apply(np.array)\n",
"matrix = np.vstack(df.ada_similarity.values)\n",
"matrix.shape\n"
]
},
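{
"cell_type": "markdown",
"metadata": {},
"source": [
"K-means groups points by Euclidean distance, while embeddings are usually compared with cosine similarity. For unit-length vectors the two are directly related, $\\lVert a - b \\rVert^2 = 2 - 2\\cos(a, b)$, so ranking pairs of embeddings by distance gives the same order as ranking them by cosine similarity. The optional cell below is a quick sanity check on `matrix` and assumes the previous cell has been run. It prints the range of row norms and checks the identity on the first two rows after normalizing them."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Length of each embedding row; values close to 1.0 mean the rows are already unit-normalized\n",
"norms = np.linalg.norm(matrix, axis=1)\n",
"print(\"min norm:\", norms.min(), \"max norm:\", norms.max())\n",
"\n",
"# Normalize two rows explicitly, then compare ||a - b||^2 with 2 - 2 * cos(a, b)\n",
"a = matrix[0] / norms[0]\n",
"b = matrix[1] / norms[1]\n",
"squared_distance = np.sum((a - b) ** 2)\n",
"cosine = np.dot(a, b)\n",
"print(np.isclose(squared_distance, 2 - 2 * cosine))\n"
]
},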
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 1. Find the clusters using K-means"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We show the simplest use of K-means. You can pick the number of clusters that fits your use case best."
]
},
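{
"cell_type": "markdown",
"metadata": {},
"source": [
"One common heuristic for comparing candidate numbers of clusters is the silhouette score. It measures how much closer each point is to its own cluster than to the next-nearest cluster. It ranges from -1 to 1, and higher values mean better-separated clusters. The cell below is an optional exploration and was not part of the original walkthrough. The range of `k` values is an arbitrary choice, and a good score does not guarantee clusters that are useful for your task."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.cluster import KMeans\n",
"from sklearn.metrics import silhouette_score\n",
"\n",
"# Fit k-means for several candidate cluster counts and compare the results\n",
"for k in range(2, 9):\n",
"    candidate = KMeans(n_clusters=k, init=\"k-means++\", random_state=42, n_init=\"auto\")\n",
"    candidate_labels = candidate.fit_predict(matrix)\n",
"    # Silhouette: higher is better; inertia: within-cluster sum of squared distances, which typically decreases as k grows\n",
"    score = silhouette_score(matrix, candidate_labels)\n",
"    print(f\"k={k}: silhouette={score:.3f}, inertia={candidate.inertia_:.1f}\")\n"
]
},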
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Cluster\n",
"0 4.105691\n",
"1 4.191176\n",
"2 4.215613\n",
"3 4.306590\n",
"Name: Score, dtype: float64"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.cluster import KMeans\n",
"\n",
"n_clusters = 4\n",
"\n",
"# k-means++ seeding; the fixed random_state makes the result reproducible\n",
"kmeans = KMeans(n_clusters=n_clusters, init=\"k-means++\", random_state=42, n_init=\"auto\")\n",
"kmeans.fit(matrix)\n",
"labels = kmeans.labels_\n",
"df[\"Cluster\"] = labels\n",
"\n",
"# Average review score (1 to 5) per cluster, sorted from lowest to highest\n",
"df.groupby(\"Cluster\").Score.mean().sort_values()\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The average scores are close across clusters (about 4.1 to 4.3 stars), so no cluster stands out as mostly negative. Cluster 0 has the lowest average score and cluster 3 the highest."
]
},
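{
"cell_type": "markdown",
"metadata": {},
"source": [
"Under the hood, k-means alternates two steps (Lloyd's algorithm). First, every point is assigned to its nearest centroid. Then each centroid moves to the mean of the points assigned to it. The two steps repeat until the centroids stop moving. The hypothetical `simple_kmeans` helper below is a minimal NumPy sketch of that loop, for illustration only. scikit-learn's `KMeans` adds k-means++ seeding, multiple restarts and faster distance computations. Cluster ids are arbitrary, so the comparison with the scikit-learn labels uses the adjusted Rand index. It is 1.0 for identical partitions however the clusters are numbered."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.metrics import adjusted_rand_score\n",
"\n",
"\n",
"def simple_kmeans(X, k, n_iter=100, seed=42):\n",
"    \"\"\"Hypothetical helper: plain Lloyd's algorithm with randomly chosen initial centroids.\"\"\"\n",
"    rng = np.random.default_rng(seed)\n",
"    # Start from k distinct data points picked at random\n",
"    centroids = X[rng.choice(len(X), size=k, replace=False)]\n",
"    for _ in range(n_iter):\n",
"        # Assignment step: squared Euclidean distance from every point to every centroid\n",
"        distances = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)\n",
"        assignments = distances.argmin(axis=1)\n",
"        # Update step: each centroid moves to the mean of its points (an empty cluster keeps its old centroid)\n",
"        new_centroids = np.array(\n",
"            [\n",
"                X[assignments == c].mean(axis=0) if np.any(assignments == c) else centroids[c]\n",
"                for c in range(k)\n",
"            ]\n",
"        )\n",
"        if np.allclose(new_centroids, centroids):\n",
"            break\n",
"        centroids = new_centroids\n",
"    return assignments, centroids\n",
"\n",
"\n",
"simple_labels, _ = simple_kmeans(matrix, n_clusters)\n",
"# 1.0 means the same grouping as scikit-learn's KMeans; a different random start can reach a different local optimum\n",
"print(adjusted_rand_score(df.Cluster, simple_labels))\n"
]
},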
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Text(0.5, 1.0, 'Clusters identified visualized in language 2d using t-SNE')"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from sklearn.manifold import TSNE\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Project the 1536-dimensional embeddings down to 2 dimensions for plotting\n",
"tsne = TSNE(\n",
"    n_components=2, perplexity=15, random_state=42, init=\"random\", learning_rate=200\n",
")\n",
"vis_dims2 = tsne.fit_transform(matrix)\n",
"\n",
"x = [x for x, y in vis_dims2]\n",
"y = [y for x, y in vis_dims2]\n",
"\n",
"for category, color in enumerate([\"purple\", \"green\", \"red\", \"blue\"]):\n",
"    xs = np.array(x)[df.Cluster == category]\n",
"    ys = np.array(y)[df.Cluster == category]\n",
"    plt.scatter(xs, ys, color=color, alpha=0.3)\n",
"\n",
"    # Mark the mean 2D position of each cluster with an x\n",
"    avg_x = xs.mean()\n",
"    avg_y = ys.mean()\n",
"\n",
"    plt.scatter(avg_x, avg_y, marker=\"x\", color=color, s=100)\n",
"plt.title(\"Clusters identified visualized in language 2d using t-SNE\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
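"Visualization of the clusters in a 2D t-SNE projection. Each color is one cluster (purple = 0, green = 1, red = 2, blue = 3), and each `x` marks the mean position of that cluster's points in the projection. The blue cluster seems quite different from the others. Let's look at a few samples from each cluster."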
]
},
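{
"cell_type": "markdown",
"metadata": {},
"source": [
"The mean score per cluster hides how ratings are spread inside each cluster. A cross-tabulation of cluster against star rating shows two things: how large each cluster is, and how many 1- to 5-star reviews it contains. That makes it easier to check whether a cluster really leans negative or positive."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Rows: cluster ids; columns: star rating; cells: number of reviews (the \"All\" row and column hold totals)\n",
"pd.crosstab(df.Cluster, df.Score, margins=True)\n"
]
},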
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2. Text samples in the clusters & naming the clusters\n",
"\n",
"Let's show random samples from each cluster. We'll use davinci-instruct-beta-v3 to name each cluster, based on a random sample of 3 reviews from that cluster."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Cluster 0 Theme: All reviews mention the good flavor of the bars.\n",
"5, Loved these gluten free healthy bars, saved $$ ordering on Amazon: These Kind Bars are so good and healthy & gluten free. My daughter ca\n",
"1, Should advertise coconut as an ingredient more prominently: First, these should be called Mac - Coconut bars, as Coconut is the #2\n",
"5, very good!!: just like the runts<br />great flavor, def worth getting<br />I even o\n",
"----------------------------------------------------------------------------------------------------\n",
"Cluster 1 Theme: The customer reviews have in common that the customers are not happy with the product. The customers do not think that the product is good for their cats, that the gravy is not good, and that the product is not good for their hands.\n",
"2, Messy and apparently undelicious: My cat is not a huge fan. Sure, she'll lap up the gravy, but leaves th\n",
"4, The cats like it: My 7 cats like this food but it is a little yucky for the human. Piece\n",
"5, cant get enough of it!!!: Our lil shih tzu puppy cannot get enough of it. Everytime she sees the\n",
"----------------------------------------------------------------------------------------------------\n",
"Cluster 2 Theme: The reviews have in common that they are all about Fog Chaser Coffee.\n",
"5, Fog Chaser Coffee: This coffee has a full body and a rich taste. The price is far below t\n",
"5, Excellent taste: This is to me a great coffee, once you try it you will enjoy it, this \n",
"4, Good, but not Wolfgang Puck good: Honestly, I have to admit that I expected a little better. That's not \n",
"----------------------------------------------------------------------------------------------------\n",
"Cluster 3 Theme: All reviews mention the product's content.\n",
"5, Wonderful alternative to soda pop: This is a wonderful alternative to soda pop. It's carbonated for thos\n",
"5, So convenient, for so little!: I needed two vanilla beans for the Love Goddess cake that my husbands \n",
"2, bot very cheesy: Got this about a month ago.first of all it smells horrible...it tastes\n",
"----------------------------------------------------------------------------------------------------\n"
]
}
],
"source": [
"import openai\n",
"\n",
"# Read a few reviews that belong to each cluster.\n",
"rev_per_cluster = 3\n",
"\n",
"for i in range(n_clusters):\n",
"    print(f\"Cluster {i} Theme:\", end=\" \")\n",
"\n",
"    reviews = \"\\n\".join(\n",
"        df[df.Cluster == i]\n",
"        .combined.str.replace(\"Title: \", \"\")\n",
"        .str.replace(\"\\n\\nContent: \", \": \")\n",
"        .sample(rev_per_cluster, random_state=42)\n",
"        .values\n",
"    )\n",
"    response = openai.Completion.create(\n",
"        engine=\"davinci-instruct-beta-v3\",\n",
"        prompt=f'What do the following customer reviews have in common?\\n\\nCustomer reviews:\\n\"\"\"\\n{reviews}\\n\"\"\"\\n\\nTheme:',\n",
"        temperature=0,\n",
"        max_tokens=64,\n",
"        top_p=1,\n",
"        frequency_penalty=0,\n",
"        presence_penalty=0,\n",
"    )\n",
"    print(response[\"choices\"][0][\"text\"].replace(\"\\n\", \"\"))\n",
"\n",
"    sample_cluster_rows = df[df.Cluster == i].sample(rev_per_cluster, random_state=42)\n",
"    for j in range(rev_per_cluster):\n",
"        print(sample_cluster_rows.Score.values[j], end=\", \")\n",
"        print(sample_cluster_rows.Summary.values[j], end=\": \")\n",
"        print(sample_cluster_rows.Text.str[:70].values[j])\n",
"\n",
"    print(\"-\" * 100)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
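"Based on the average ratings, the clusters do not split cleanly by sentiment: all four average between about 4.1 and 4.3 stars. They appear to be organized more by topic. Judging from the three sampled reviews per cluster, Cluster 0 centers on snack bars, Cluster 1 on pet food, Cluster 2 on coffee, and Cluster 3 on a mix of other food and drink products."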
]
},
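{
"cell_type": "markdown",
"metadata": {},
"source": [
"The naming cell above uses the legacy `openai.Completion` interface, which was removed in version 1.0 of the `openai` Python package, and the `davinci-instruct-beta-v3` model is no longer available. The cell below is a minimal sketch of the same naming step using the 1.x client and the Chat Completions API. It assumes `OPENAI_API_KEY` is set in your environment. The model name `gpt-4o-mini` is an assumption; substitute any chat model you have access to. Its themes will differ from the ones shown above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from openai import OpenAI\n",
"\n",
"client = OpenAI()  # reads the OPENAI_API_KEY environment variable\n",
"\n",
"for i in range(n_clusters):\n",
"    # Same sampling and formatting as the naming cell above\n",
"    reviews = \"\\n\".join(\n",
"        df[df.Cluster == i]\n",
"        .combined.str.replace(\"Title: \", \"\")\n",
"        .str.replace(\"\\n\\nContent: \", \": \")\n",
"        .sample(rev_per_cluster, random_state=42)\n",
"        .values\n",
"    )\n",
"    response = client.chat.completions.create(\n",
"        model=\"gpt-4o-mini\",  # assumption: any available chat model can be used here\n",
"        messages=[\n",
"            {\n",
"                \"role\": \"user\",\n",
"                \"content\": f'What do the following customer reviews have in common?\\n\\nCustomer reviews:\\n\"\"\"\\n{reviews}\\n\"\"\"\\n\\nTheme:',\n",
"            }\n",
"        ],\n",
"        temperature=0,\n",
"        max_tokens=64,\n",
"    )\n",
"    print(f\"Cluster {i} Theme:\", response.choices[0].message.content.strip())\n"
]
},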
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It's important to note that clusters will not necessarily match what you intend to use them for. A larger number of clusters will focus on more specific patterns, whereas a small number of clusters will usually focus on the largest discrepancies in the data."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "openai-cookbook",
"language": "python",
"name": "openai-cookbook"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.6"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "365536dcbde60510dc9073d6b991cd35db2d9bac356a11f5b64279a5e6708b97"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}