{
"cells": [
{
"cell_type": "markdown",
"id": "induced-zoning",
"metadata": {},
"source": [
"# Fused Softmax"
]
},
{
"cell_type": "markdown",
"id": "median-malaysia",
"metadata": {},
"source": [
"Custom GPU kernels for elementwise additions are educationally valuable but won't get you very far in practice. Let us consider instead the case of a simple (numerically stabilized) softmax operation:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "precise-professor",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# Compute the row-wise softmax of x \\in R^{M \\times N}\n",
"def naive_softmax(x):\n",
" # read MN elements ; write M elements\n",
" x_max = torch.max(x, axis=1)[0]\n",
" # read 2MN elements ; write MN elements\n",
" z = x - x_max[:, None]\n",
" # read MN elements ; write MN elements\n",
" numerator = torch.exp(x)\n",
" # read MN elements ; write M elements\n",
" denominator = torch.sum(numerator, axis=1)\n",
" # read 2MN elements ; write MN elements\n",
" ret = numerator / denominator[:, None]\n",
" # in total: read 7MN elements ; wrote 3MN + 2M elements\n",
" return ret"
]
},
{
"cell_type": "markdown",
"id": "gorgeous-monday",
"metadata": {},
"source": [
"When implemented naively in pytorch, computing $y$ requires reading $7MN$ elements from DRAM and writing back $3MN + 2M$ elements.\n",
"\n",
"Instead, we want to write a custom \"fused\" pytorch operators that only reads X once and does all the necessary computations on-chip. This would require reading and writing back only $MN$ bytes, so we could expect a theoretical speed-up of 5x. In practice, though, we expect less because our kernel will spend some time computing exponentials and moving data around in shared memory."
]
},
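{
"cell_type": "markdown",
"id": "graceful-arithmetic",
"metadata": {},
"source": [
"Concretely, using the element counts from the code above: the naive version moves $7MN + 3MN + 2M \\approx 10MN$ elements through DRAM, while the fused version moves only $2MN$ (one read and one write of the whole matrix), hence the theoretical speed-up of roughly $10MN / 2MN = 5$x."
]
},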
{
"cell_type": "markdown",
"id": "identical-conditions",
"metadata": {},
"source": [
"## Writing the Compute Kernel"
]
},
{
"cell_type": "markdown",
"id": "prepared-apparatus",
"metadata": {},
"source": [
"Our softmax kernel works as follows: each program loads a row of X and writes back a normalized row of Y. Note that one important limitation of Triton is that each block must have a power-of-two number of elements, which means that we need to guard the memory operations properly if we want to handle any possible input shapes:\n",
"\n",
"```c\n",
"__global__ void softmax(float* Y, float* X, int stride_xm, int stride_ym, int M, int N){\n",
" // row index\n",
" int m = get_program_id(0);\n",
" // column indices\n",
" int n [BLOCK] = 0 ... BLOCK;\n",
" // the memory address of all the elements\n",
" // that we want to load can be computed as follows\n",
" float* px [BLOCK] = X + m*stride_xm + n;\n",
" // because BLOCK has to be a power of two\n",
" // (per Triton-C specs), it is important\n",
" // to guard each memory operation with predicates\n",
" // or we will read out of bounds\n",
" bool check[BLOCK] = n < N;\n",
" float x [BLOCK] = check ? *px : -F32_INFINITY;\n",
" // syntax for reduction in Triton is:\n",
" // x[..., OPERATOR, ...]\n",
" // ^\n",
" // index\n",
" // The operators currently supported are {min, max, +}\n",
" float z [BLOCK] = x - x[max];\n",
" // The exponential in Triton is fast but approximate \n",
" // (i.e., like __expf in CUDA)\n",
" float num [BLOCK] = exp(z);\n",
" float denom = num[+];\n",
" // The result of the reduction is now stored in y\n",
" float y [BLOCK] = num / denom;\n",
" // We write it back\n",
" float* py [BLOCK] = Y + m*stride_ym + n;\n",
" *?(check)py = y; \n",
"}\n",
"```"
]
},
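{
"cell_type": "markdown",
"id": "masked-emulation",
"metadata": {},
"source": [
"To see why padding with $-\\infty$ is safe: $\\exp(-\\infty) = 0$, so the out-of-bounds lanes contribute nothing to either reduction. Below is a minimal sketch, in plain PyTorch rather than Triton, of what a single program instance computes for one row (the function name `emulate_program` is just illustrative):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "masked-emulation-demo",
"metadata": {},
"outputs": [],
"source": [
"# Plain-PyTorch emulation of one program instance of the kernel above\n",
"def emulate_program(x_row, BLOCK):\n",
"    N = x_row.numel()\n",
"    # guarded load: lanes with n >= N are filled with -inf\n",
"    x = torch.full((BLOCK,), float('-inf'))\n",
"    x[:N] = x_row\n",
"    z = x - x.max()     # the x[max] reduction\n",
"    num = torch.exp(z)  # exp(-inf) == 0, so padded lanes vanish\n",
"    denom = num.sum()   # the num[+] reduction\n",
"    y = num / denom\n",
"    # guarded store: only the first N lanes are written back\n",
"    return y[:N]\n",
"\n",
"row = torch.randn(781)\n",
"print(torch.allclose(emulate_program(row, 1024), torch.softmax(row, dim=0)))"
]
},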
{
"cell_type": "markdown",
"id": "forbidden-wednesday",
"metadata": {},
"source": [
"## Writing the Torch bindings"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "former-pottery",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import triton\n",
"\n",
"# source-code for Triton compute kernel\n",
"_src = \"\"\"\n",
"__global__ void softmax(float* Y, float* X, int stride_ym, int stride_xm, int M, int N){\n",
" int m = get_program_id(0);\n",
" int n [BLOCK] = 0 ... BLOCK;\n",
" float* px [BLOCK] = X + m*stride_xm + n;\n",
" bool check[BLOCK] = n < N;\n",
" float x [BLOCK] = check ? *px : -F32_INFINITY;\n",
" float z [BLOCK] = x - x[max];\n",
" float num [BLOCK] = exp(z);\n",
" float denom = num[+];\n",
" float y [BLOCK] = num / denom;\n",
" float* py [BLOCK] = Y + m*stride_ym + n;\n",
" *?(check)py = y; \n",
"}\n",
"\"\"\"\n",
"\n",
"# We need to make sure that BLOCK is the smallest power of two\n",
"# greater than the number of rows N of the input matrix.\n",
"# Different values of BLOCK will result in different kernels\n",
"def next_power_of_2(n):\n",
" n -= 1\n",
" n |= n >> 1\n",
" n |= n >> 2\n",
" n |= n >> 4\n",
" n |= n >> 8\n",
" n |= n >> 16\n",
" n += 1\n",
" return n\n",
"\n",
"_kernels = dict()\n",
"def make_kernel(N, device):\n",
" BLOCK = next_power_of_2(N)\n",
" key = (BLOCK, device)\n",
" if key not in _kernels:\n",
" defines = {'BLOCK': BLOCK}\n",
" _kernels[key] = triton.kernel(_src, device=device, defines=defines)\n",
" return _kernels[key]\n",
"\n",
"class _softmax(torch.autograd.Function):\n",
" \n",
" @staticmethod\n",
" def forward(ctx, x):\n",
" # constraints of the op\n",
" assert x.dtype == torch.float32\n",
" y = torch.empty_like(x)\n",
" # *create launch grid*:\n",
" # here we just launch a grid of M programs\n",
" M, N = y.shape\n",
" grid = lambda opt: (M, )\n",
" # *launch kernel*:\n",
" kernel = make_kernel(N, y.device)\n",
" kernel(y.data_ptr(), x.data_ptr(), y.stride(0), x.stride(0), M, N, grid = grid)\n",
" return y\n",
" \n",
"softmax = _softmax.apply"
]
},
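{
"cell_type": "markdown",
"id": "rounded-check",
"metadata": {},
"source": [
"Note that `make_kernel` caches one compiled kernel per `(BLOCK, device)` pair, so all column counts that round up to the same power of two share a single compiled kernel. A quick sanity check of the rounding helper (illustrative asserts, assuming the cell above has been run):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "rounded-check-demo",
"metadata": {},
"outputs": [],
"source": [
"# next_power_of_2 rounds up to the nearest power of two;\n",
"# exact powers of two map to themselves\n",
"assert next_power_of_2(781) == 1024\n",
"assert next_power_of_2(1024) == 1024\n",
"assert next_power_of_2(1025) == 2048"
]
},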
{
"cell_type": "markdown",
"id": "exclusive-salvation",
"metadata": {},
"source": [
"## Writing a Unit Test"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "pretty-prospect",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[0.0004, 0.0006, 0.0004, ..., 0.0005, 0.0004, 0.0010],\n",
" [0.0003, 0.0029, 0.0004, ..., 0.0007, 0.0017, 0.0004],\n",
" [0.0002, 0.0006, 0.0005, ..., 0.0028, 0.0009, 0.0003],\n",
" ...,\n",
" [0.0017, 0.0005, 0.0010, ..., 0.0006, 0.0004, 0.0001],\n",
" [0.0010, 0.0006, 0.0001, ..., 0.0006, 0.0017, 0.0014],\n",
" [0.0037, 0.0012, 0.0006, ..., 0.0003, 0.0005, 0.0003]],\n",
" device='cuda:0')\n",
"tensor([[0.0004, 0.0006, 0.0004, ..., 0.0005, 0.0004, 0.0010],\n",
" [0.0003, 0.0029, 0.0004, ..., 0.0007, 0.0017, 0.0004],\n",
" [0.0002, 0.0006, 0.0005, ..., 0.0028, 0.0009, 0.0003],\n",
" ...,\n",
" [0.0017, 0.0005, 0.0010, ..., 0.0006, 0.0004, 0.0001],\n",
" [0.0010, 0.0006, 0.0001, ..., 0.0006, 0.0017, 0.0014],\n",
" [0.0037, 0.0012, 0.0006, ..., 0.0003, 0.0005, 0.0003]],\n",
" device='cuda:0')\n",
"True\n"
]
}
],
"source": [
"x = torch.randn(1823, 781, device='cuda')\n",
"y_tri = softmax(x)\n",
"y_ref = torch.softmax(x, axis=1)\n",
"print(y_tri)\n",
"print(y_ref)\n",
"print(torch.allclose(y_tri, y_ref))"
]
},
{
"cell_type": "markdown",
"id": "regular-andrew",
"metadata": {},
"source": [
"Seems to work!"
]
},
{
"cell_type": "markdown",
"id": "polished-australia",
"metadata": {},
"source": [
"## Writing a Benchmark"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "chubby-audit",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"M = 4096\n",
"Ns = [128*i for i in range(2, 50)]\n",
"tri_ms = []\n",
"ref_ms = []\n",
"def_ms = []\n",
"for N in Ns:\n",
" x = torch.randn(M, N, device='cuda', dtype=torch.float32)\n",
" gbps = lambda ms: x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)\n",
" tri_ms += [gbps(triton.testing.do_bench(lambda: softmax(x)))]\n",
" ref_ms += [gbps(triton.testing.do_bench(lambda: torch.softmax(x, axis=1)))]\n",
" def_ms += [gbps(triton.testing.do_bench(lambda: naive_softmax(x)))]\n",
"plt.xlabel('N')\n",
"plt.ylabel('Bandwidth (GB/s)')\n",
"plt.plot(Ns, tri_ms, label = 'Triton')\n",
"plt.plot(Ns, ref_ms, label = 'Torch')\n",
"plt.plot(Ns, def_ms, label = 'Naive')\n",
"plt.legend()\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}