{
"cells": [
{
"cell_type": "markdown",
"id": "983ef639-fbf4-4912-b593-9cf08aeb11cd",
"metadata": {},
"source": [
"# Visualizing embeddings in 3D"
]
},
{
"cell_type": "markdown",
"id": "9c9ea9a8-675d-4e3a-a8f7-6f4563df84ad",
"metadata": {},
"source": [
"The example uses [PCA](https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.PCA.html) to reduce the dimensionality of the embeddings from 2048 to 3, so that the data points can be visualized in a 3D plot. The small dataset `dbpedia_samples.jsonl` was curated by randomly sampling 200 samples from the [DBpedia validation dataset](https://www.kaggle.com/danofer/dbpedia-classes?select=DBPEDIA_val.csv)."
]
},
{
"cell_type": "markdown",
"id": "8df5f2c3-ddbb-4cc4-9205-4c0af1670562",
"metadata": {},
"source": [
"### 1. Load the dataset and query embeddings"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "133dfc2a-9dbd-4a5a-96fa-477272f7af5a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Categories of DBpedia samples: Artist 21\n",
"Plant 19\n",
"Film 18\n",
"OfficeHolder 18\n",
"Company 17\n",
"NaturalPlace 16\n",
"Athlete 16\n",
"Village 12\n",
"Building 12\n",
"WrittenWork 11\n",
"Album 11\n",
"Animal 11\n",
"EducationalInstitution 10\n",
"MeanOfTransportation 8\n",
"Name: completion, dtype: int64\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>prompt</th>\n",
" <th>completion</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Morada Limited is a textile company based in ...</td>\n",
" <td>Company</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>The Armenian Mirror-Spectator is a newspaper ...</td>\n",
" <td>WrittenWork</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>Mt. Kinka (金華山 Kinka-zan) also known as Kinka...</td>\n",
" <td>NaturalPlace</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>Planning the Play of a Bridge Hand is a book ...</td>\n",
" <td>WrittenWork</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>Wang Yuanping (born 8 December 1976) is a ret...</td>\n",
" <td>Athlete</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" prompt completion\n",
"0 Morada Limited is a textile company based in ... Company\n",
"1 The Armenian Mirror-Spectator is a newspaper ... WrittenWork\n",
"2 Mt. Kinka (金華山 Kinka-zan) also known as Kinka... NaturalPlace\n",
"3 Planning the Play of a Bridge Hand is a book ... WrittenWork\n",
"4 Wang Yuanping (born 8 December 1976) is a ret... Athlete"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"samples = pd.read_json(\"data/dbpedia_samples.jsonl\", lines=True)\n",
"categories = sorted(samples[\"completion\"].unique())\n",
"print(\"Categories of DBpedia samples:\", samples[\"completion\"].value_counts())\n",
"samples.head()"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "19874e3e-a216-48cc-a27b-acb73854d832",
"metadata": {},
"outputs": [],
"source": [
"from openai.embeddings_utils import get_embeddings\n",
"# NOTE: The following code sends a single batch of 200 inputs to the /embeddings endpoint, costing about $0.2\n",
"matrix = get_embeddings(samples[\"prompt\"].to_list(), engine=\"text-similarity-babbage-001\")"
]
},
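{
"cell_type": "markdown",
"id": "cache-embeddings-sketch",
"metadata": {},
"source": [
"Because the embeddings call above is billed per request, it can help to cache the resulting matrix on disk so that re-running the notebook does not query the API again. The sketch below is a minimal illustration under stated assumptions, not part of the original example: the helper `load_or_compute` and the dummy `fake_compute` function are hypothetical, and only `numpy.save`/`numpy.load` are real library APIs.\n",
"\n",
"```python\n",
"import os\n",
"import tempfile\n",
"\n",
"import numpy as np\n",
"\n",
"\n",
"def load_or_compute(path, compute):\n",
"    \"\"\"Hypothetical helper: return the array cached at `path`, or call `compute()` and cache its result.\"\"\"\n",
"    if os.path.exists(path):\n",
"        return np.load(path)\n",
"    arr = np.asarray(compute())\n",
"    np.save(path, arr)  # `path` should end in .npy, otherwise np.save appends it\n",
"    return arr\n",
"\n",
"\n",
"# Dummy stand-in for the paid API call; it records how often it runs.\n",
"calls = []\n",
"\n",
"\n",
"def fake_compute():\n",
"    calls.append(1)\n",
"    return [[0.1, 0.2], [0.3, 0.4]]\n",
"\n",
"\n",
"with tempfile.TemporaryDirectory() as tmp:\n",
"    path = os.path.join(tmp, \"embeddings.npy\")\n",
"    first = load_or_compute(path, fake_compute)   # computes and writes the cache\n",
"    second = load_or_compute(path, fake_compute)  # served from the cache\n",
"\n",
"print(len(calls))  # 1\n",
"```\n",
"\n",
"In this notebook one could, for example, wrap the call as `load_or_compute(\"data/dbpedia_embeddings.npy\", lambda: get_embeddings(...))`, where the cache path is a hypothetical choice."
]
},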
{
"cell_type": "markdown",
"id": "d410c268-d8a7-4979-887c-45b1d382dda9",
"metadata": {},
"source": [
"### 2. Reduce the embedding dimensionality"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "f5410068-f3da-490c-8576-48e84a8728de",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.decomposition import PCA\n",
"pca = PCA(n_components=3)\n",
"vis_dims = pca.fit_transform(matrix)\n",
"samples[\"embed_vis\"] = vis_dims.tolist()"
]
},
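{
"cell_type": "markdown",
"id": "pca-variance-sketch",
"metadata": {},
"source": [
"To judge how faithful a 3D picture can be, it helps to check how much of the total variance the three principal components retain, via scikit-learn's `PCA.explained_variance_ratio_` attribute. The sketch below runs the same `PCA(n_components=3)` step on a random 200 x 2048 matrix, a hypothetical stand-in for the real embeddings, so it runs without API access; the printed numbers will differ from those of the real data.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn.decomposition import PCA\n",
"\n",
"# Hypothetical stand-in for `matrix`: 200 random 2048-dimensional vectors.\n",
"rng = np.random.default_rng(0)\n",
"fake_matrix = rng.normal(size=(200, 2048))\n",
"\n",
"pca = PCA(n_components=3)\n",
"vis_dims = pca.fit_transform(fake_matrix)\n",
"\n",
"print(vis_dims.shape)  # (200, 3)\n",
"# Fraction of the total variance kept by each component, largest first.\n",
"print(pca.explained_variance_ratio_)\n",
"print(pca.explained_variance_ratio_.sum())\n",
"```\n",
"\n",
"On real embeddings, a low total suggests that categories which overlap in the 3D plot may still be separated in the full embedding space."
]
},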
{
"cell_type": "markdown",
"id": "b6565f57-59c6-4d36-a094-3cbbd9ddeb4c",
"metadata": {},
"source": [
"### 3. Plot the embeddings of lower dimensionality"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "b17caad3-f0de-4115-83eb-55434a132acc",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x153a5f190>"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "04e0ac33bf214db4863b980f63a13706",
"version_major": 2,
"version_minor": 0
},
"text/html": [
"\n",
" <div style=\"display: inline-block;\">\n",
" <div class=\"jupyter-widgets widget-label\" style=\"text-align: center;\">\n",
" Figure\n",
" </div>\n",
" </div>\n",
" "
],
"text/plain": [
"Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"%matplotlib widget\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"fig = plt.figure(figsize=(10, 5))\n",
"ax = fig.add_subplot(projection='3d')\n",
"cmap = plt.get_cmap(\"tab20\")\n",
"\n",
"# Plot each sample category individually such that we can set label name.\n",
"for i, cat in enumerate(categories):\n",
" sub_matrix = np.array(samples[samples[\"completion\"] == cat][\"embed_vis\"].to_list())\n",
" x=sub_matrix[:, 0]\n",
" y=sub_matrix[:, 1]\n",
" z=sub_matrix[:, 2]\n",
" colors = [cmap(i/len(categories))] * len(sub_matrix)\n",
" ax.scatter(x, y, zs=z, zdir='z', c=colors, label=cat)\n",
"\n",
"ax.set_xlabel('x')\n",
"ax.set_ylabel('y')\n",
"ax.set_zlabel('z')\n",
"ax.legend(bbox_to_anchor=(1.1, 1))"
]
},
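{
"cell_type": "markdown",
"id": "groupby-categories-sketch",
"metadata": {},
"source": [
"The plotting loop above selects each category with a boolean mask. An equivalent way to collect the 3D points per category is pandas `DataFrame.groupby`, which visits the groups in sorted key order by default, matching `sorted(samples[\"completion\"].unique())`. The sketch below uses a tiny hypothetical frame in place of `samples`, so it runs without the dataset or the API.\n",
"\n",
"```python\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"# Hypothetical toy data shaped like `samples` after step 2.\n",
"samples = pd.DataFrame(\n",
"    {\n",
"        \"completion\": [\"Film\", \"Plant\", \"Film\", \"Album\"],\n",
"        \"embed_vis\": [[0.0, 1.0, 2.0], [1.0, 1.0, 1.0], [2.0, 0.0, 1.0], [3.0, 3.0, 3.0]],\n",
"    }\n",
")\n",
"\n",
"# One (n_points, 3) array per category; groupby sorts the keys by default.\n",
"points_by_category = {\n",
"    cat: np.array(group[\"embed_vis\"].to_list())\n",
"    for cat, group in samples.groupby(\"completion\")\n",
"}\n",
"\n",
"for cat, pts in points_by_category.items():\n",
"    print(cat, pts.shape)\n",
"```\n",
"\n",
"Each array could then be passed to `ax.scatter(pts[:, 0], pts[:, 1], zs=pts[:, 2], ...)` exactly as in the loop above; grouping once avoids re-scanning the whole frame for every category."
]
},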
{
"cell_type": "code",
"execution_count": null,
"id": "163d0fc4",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "scratch-venv",
"language": "python",
"name": "scratch-venv"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
},
"vscode": {
"interpreter": {
"hash": "365536dcbde60510dc9073d6b991cd35db2d9bac356a11f5b64279a5e6708b97"
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}