triton/python/tutorials/02-fused-softmax.ipynb


{
"cells": [
{
"cell_type": "markdown",
"id": "induced-zoning",
"metadata": {},
"source": [
"# Getting Started"
]
},
{
"cell_type": "markdown",
"id": "median-malaysia",
"metadata": {},
"source": [
"Custom GPU kernels for elementwise additions are educationally valuable but won't get you very far in practice. Let us consider instead the case of a simple (numerically stabilized) softmax operation:"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "precise-professor",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# Compute the row-wise softmax of x \\in R^{M \\times N}\n",
"def naive_softmax(x):\n",
"    # read MN elements ; write M elements\n",
"    x_max = torch.max(x, axis=1)[0]\n",
"    # read MN + M elements ; write MN elements\n",
"    z = x - x_max[:, None]\n",
"    # read MN elements ; write MN elements\n",
"    numerator = torch.exp(z)\n",
"    # read MN elements ; write M elements\n",
"    denominator = torch.sum(numerator, axis=1)\n",
"    # read MN + M elements ; write MN elements\n",
"    ret = numerator / denominator[:, None]\n",
"    # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements\n",
"    return ret"
]
},
{
"cell_type": "markdown",
"id": "gorgeous-monday",
"metadata": {},
"source": [
"When implemented naively in PyTorch, computing $y$ requires reading $5MN + 2M$ elements from DRAM and writing back $3MN + 2M$ elements.\n",
"\n",
"Instead, we want to write a custom \"fused\" PyTorch operator that reads X only once and does all the necessary computations on-chip. This would require reading only $MN$ elements and writing back only $MN$ elements, so we could expect a theoretical speed-up of about 4x, i.e. $(8MN + 4M) / 2MN$. In practice, though, we expect less because our kernel will spend some time computing exponentials and moving data around in shared memory."
]
},
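{
"cell_type": "markdown",
"id": "stable-softmax-formula",
"metadata": {},
"source": [
"For reference, the stabilized softmax can be written out explicitly. This is the standard textbook definition rather than anything specific to Triton, and the symbol $m_i$ for the row maximum is introduced here only for illustration:\n",
"\n",
"$$y_{ij} = \\frac{\\exp(x_{ij} - m_i)}{\\sum_{k=1}^{N} \\exp(x_{ik} - m_i)}, \\qquad m_i = \\max_{1 \\le k \\le N} x_{ik}$$\n",
"\n",
"Subtracting $m_i$ does not change the result, because the common factor $e^{-m_i}$ cancels between numerator and denominator:\n",
"\n",
"$$\\frac{e^{x_{ij} - m_i}}{\\sum_{k} e^{x_{ik} - m_i}} = \\frac{e^{-m_i} e^{x_{ij}}}{e^{-m_i} \\sum_{k} e^{x_{ik}}} = \\frac{e^{x_{ij}}}{\\sum_{k} e^{x_{ik}}}$$\n",
"\n",
"What it does change is the range of the intermediate values: every exponent is at most $0$, so every exponential is at most $1$ and the denominator is at least $1$. Without the shift, a single-precision $e^{x}$ already overflows to infinity for $x$ slightly above $88$."
]
},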
{
"cell_type": "markdown",
"id": "identical-conditions",
"metadata": {},
"source": [
"# Writing the Compute Kernel"
]
},
{
"cell_type": "markdown",
"id": "prepared-apparatus",
"metadata": {},
"source": [
"Our softmax kernel works as follows: each program loads a row of X and writes back a normalized row of Y. One important limitation of Triton is that each block must have a power-of-two number of elements, so we need to guard the memory operations with a mask if we want to handle arbitrary input shapes:\n",
"\n",
"```c\n",
"__global__ void softmax(float* Y, float* X, int stride_ym, int stride_xm, int M, int N){\n",
"    // row index\n",
"    int m = get_program_id(0);\n",
"    // column indices\n",
"    int n [BLOCK] = 0 ... BLOCK;\n",
"    // the memory address of all the elements\n",
"    // that we want to load can be computed as follows\n",
"    float* px [BLOCK] = X + m*stride_xm + n;\n",
"    // because BLOCK has to be a power of two\n",
"    // (per Triton-C specs), it is important\n",
"    // to guard each memory operation with predicates\n",
"    // or we will read out of bounds\n",
"    bool check[BLOCK] = n < N;\n",
"    float x [BLOCK] = check ? *px : -F32_INFINITY;\n",
"    // syntax for reduction in Triton is:\n",
"    // x[..., OPERATOR, ...]\n",
"    //            ^\n",
"    //           index\n",
"    // The operators currently supported are {min, max, +}\n",
"    float z [BLOCK] = x - x[max];\n",
"    // The exponential in Triton is fast but approximate\n",
"    // (i.e., like __expf in CUDA)\n",
"    float num [BLOCK] = exp(z);\n",
"    float denom = num[+];\n",
"    // Dividing by the reduced sum gives the normalized row y\n",
"    float y [BLOCK] = num / denom;\n",
"    // We write it back\n",
"    float* py [BLOCK] = Y + m*stride_ym + n;\n",
"    *?(check)py = y;\n",
"}\n",
"```"
]
},
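{
"cell_type": "markdown",
"id": "masking-worked-example",
"metadata": {},
"source": [
"To see why padding with $-\\infty$ is safe, consider a hypothetical row with $N = 781$ columns and a power-of-two block that covers it, $\\text{BLOCK} = 1024$. The last $1024 - 781 = 243$ lanes fail the `check` predicate and hold $-\\infty$ instead of loaded data. Assuming the row contains at least one finite value, neither reduction is affected. Writing $m$ for the row maximum (a symbol introduced here for illustration):\n",
"\n",
"$$\\max(x_1, \\dots, x_{781}, -\\infty, \\dots, -\\infty) = \\max(x_1, \\dots, x_{781}) = m$$\n",
"\n",
"$$\\sum_{n=1}^{781} e^{x_n - m} + 243 \\cdot e^{-\\infty - m} = \\sum_{n=1}^{781} e^{x_n - m}$$\n",
"\n",
"The masked lanes still produce a value ($0$) for `y`, but the predicated store `*?(check)py = y` never writes it back."
]
},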
{
"cell_type": "markdown",
"id": "forbidden-wednesday",
"metadata": {},
"source": [
"# Writing the Torch Bindings"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "former-pottery",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import triton\n",
"\n",
"# source code for the Triton compute kernel\n",
"_src = \"\"\"\n",
"__global__ void softmax(float* Y, float* X, int stride_ym, int stride_xm, int M, int N){\n",
"    int m = get_program_id(0);\n",
"    int n [BLOCK] = 0 ... BLOCK;\n",
"    float* px [BLOCK] = X + m*stride_xm + n;\n",
"    bool check[BLOCK] = n < N;\n",
"    float x [BLOCK] = check ? *px : -F32_INFINITY;\n",
"    float z [BLOCK] = x - x[max];\n",
"    float num [BLOCK] = exp(z);\n",
"    float denom = num[+];\n",
"    float y [BLOCK] = num / denom;\n",
"    float* py [BLOCK] = Y + m*stride_ym + n;\n",
"    *?(check)py = y;\n",
"}\n",
"\"\"\"\n",
"\n",
"# BLOCK must be the smallest power of two greater than or\n",
"# equal to the number of columns N of the input matrix.\n",
"# Different values of BLOCK will result in different kernels\n",
"def next_power_of_2(n):\n",
"    n -= 1\n",
"    n |= n >> 1\n",
"    n |= n >> 2\n",
"    n |= n >> 4\n",
"    n |= n >> 8\n",
"    n |= n >> 16\n",
"    n += 1\n",
"    return n\n",
"\n",
"# cache of compiled kernels, keyed by (BLOCK, device)\n",
"_kernels = dict()\n",
"def make_kernel(N, device):\n",
"    BLOCK = next_power_of_2(N)\n",
"    key = (BLOCK, device)\n",
"    if key not in _kernels:\n",
"        defines = {'BLOCK': BLOCK}\n",
"        _kernels[key] = triton.kernel(_src, device=device, defines=defines)\n",
"    return _kernels[key]\n",
"\n",
"class _softmax(torch.autograd.Function):\n",
"\n",
"    @staticmethod\n",
"    def forward(ctx, x):\n",
"        # constraints of the op\n",
"        assert x.dtype == torch.float32\n",
"        y = torch.empty_like(x)\n",
"        # *create launch grid*:\n",
"        # here we just launch a grid of M programs\n",
"        M, N = y.shape\n",
"        grid = lambda opt: (M, )\n",
"        # *launch kernel*:\n",
"        # the argument order must match the kernel signature (Y, X, stride_ym, stride_xm, M, N)\n",
"        kernel = make_kernel(N, y.device)\n",
"        kernel(y.data_ptr(), x.data_ptr(), y.stride(0), x.stride(0), M, N, grid = grid)\n",
"        return y\n",
"\n",
"softmax = _softmax.apply"
]
},
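{
"cell_type": "markdown",
"id": "next-power-of-two-trace",
"metadata": {},
"source": [
"The bit trick in `next_power_of_2` is easier to follow on a concrete value. The trace below is only an illustration, using $n = 781$ as an example input. Each OR-shift step smears the highest set bit into the positions below it, so after the shifts every bit under the leading one is set, and adding $1$ rolls over to the next power of two:\n",
"\n",
"$$\\begin{aligned}\n",
"n - 1 &= 780 = 1100001100_2 \\\\\n",
"n \\mid (n \\gg 1) &= 1110001110_2 = 910 \\\\\n",
"n \\mid (n \\gg 2) &= 1111101111_2 = 1007 \\\\\n",
"n \\mid (n \\gg 4) &= 1111111111_2 = 1023 \\\\\n",
"n \\mid (n \\gg 8) &= 1023, \\qquad n \\mid (n \\gg 16) = 1023 \\\\\n",
"n + 1 &= 1024\n",
"\\end{aligned}$$\n",
"\n",
"The initial decrement is what keeps exact powers of two fixed: $1024$ becomes $1023$, the shifts leave it unchanged, and the final increment returns $1024$. With shifts of $1, 2, 4, 8$ and $16$, the smearing covers values of up to 32 bits."
]
},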
{
"cell_type": "markdown",
"id": "exclusive-salvation",
"metadata": {},
"source": [
"# Writing a Unit Test"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "pretty-prospect",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[0.0054, 0.0004, 0.0007, ..., 0.0004, 0.0019, 0.0018],\n",
" [0.0008, 0.0014, 0.0006, ..., 0.0023, 0.0019, 0.0012],\n",
" [0.0009, 0.0003, 0.0001, ..., 0.0010, 0.0001, 0.0012],\n",
" ...,\n",
" [0.0003, 0.0003, 0.0036, ..., 0.0002, 0.0003, 0.0013],\n",
" [0.0025, 0.0008, 0.0004, ..., 0.0016, 0.0007, 0.0005],\n",
" [0.0003, 0.0026, 0.0004, ..., 0.0005, 0.0009, 0.0005]],\n",
" device='cuda:0')\n",
"tensor([[0.0054, 0.0004, 0.0007, ..., 0.0004, 0.0019, 0.0018],\n",
" [0.0008, 0.0014, 0.0006, ..., 0.0023, 0.0019, 0.0012],\n",
" [0.0009, 0.0003, 0.0001, ..., 0.0010, 0.0001, 0.0012],\n",
" ...,\n",
" [0.0003, 0.0003, 0.0036, ..., 0.0002, 0.0003, 0.0013],\n",
" [0.0025, 0.0008, 0.0004, ..., 0.0016, 0.0007, 0.0005],\n",
" [0.0003, 0.0026, 0.0004, ..., 0.0005, 0.0009, 0.0005]],\n",
" device='cuda:0')\n",
"True\n"
]
}
],
"source": [
"x = torch.randn(1823, 781, device='cuda')\n",
"y_tri = softmax(x)\n",
"y_ref = torch.softmax(x, axis=1)\n",
"print(y_tri)\n",
"print(y_ref)\n",
"print(torch.allclose(y_tri, y_ref))"
]
},
{
"cell_type": "markdown",
"id": "regular-andrew",
"metadata": {},
"source": [
"Seems to work!"
]
},
{
"cell_type": "markdown",
"id": "polished-australia",
"metadata": {},
"source": [
"# Writing a Benchmark"
]
},
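{
"cell_type": "markdown",
"id": "bandwidth-worked-example",
"metadata": {},
"source": [
"The benchmark in the next cell converts each measured time into an effective bandwidth through its `gbps` helper. Spelled out, with $s$ the element size in bytes and $t$ the measured time in milliseconds (both symbols are introduced here for illustration, and the unit of $t$ is inferred from the `ms * 1e-3` factor in the code):\n",
"\n",
"$$\\text{bandwidth (GB/s)} = \\frac{M \\cdot N \\cdot s \\cdot 10^{-9}}{t \\cdot 10^{-3}}$$\n",
"\n",
"As a worked example with a made-up timing: for $M = 4096$, $N = 1024$ and `float32` ($s = 4$), the input holds $4096 \\cdot 1024 \\cdot 4 = 16{,}777{,}216$ bytes. If a kernel hypothetically took $t = 0.1$ ms, the formula would report $0.016777216 / 0.0001 \\approx 168$ GB/s.\n",
"\n",
"Note that the numerator counts the bytes of `x` only once. A fused kernel also writes $MN$ elements back, so the reported figure is best read as a relative measure across the three implementations rather than as total DRAM traffic."
]
},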
{
"cell_type": "code",
"execution_count": 74,
"id": "chubby-audit",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"M = 4096\n",
"Ns = [128*i for i in range(2, 50)]\n",
"# effective bandwidth (GB/s) of each implementation, one entry per N\n",
"tri_gbps = []\n",
"ref_gbps = []\n",
"def_gbps = []\n",
"for N in Ns:\n",
"    x = torch.randn(M, N, device='cuda', dtype=torch.float32)\n",
"    # convert a runtime in milliseconds to GB/s, counting the bytes of x once\n",
"    gbps = lambda ms: x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)\n",
"    tri_gbps += [gbps(triton.testing.do_bench(lambda: softmax(x)))]\n",
"    ref_gbps += [gbps(triton.testing.do_bench(lambda: torch.softmax(x, axis=1)))]\n",
"    def_gbps += [gbps(triton.testing.do_bench(lambda: naive_softmax(x)))]\n",
"plt.xlabel('N')\n",
"plt.ylabel('Bandwidth (GB/s)')\n",
"plt.plot(Ns, tri_gbps, label = 'Triton')\n",
"plt.plot(Ns, ref_gbps, label = 'Torch')\n",
"plt.plot(Ns, def_gbps, label = 'Naive')\n",
"plt.legend()\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}