triton/_downloads/034d953b6214fedce6ea03803c712b89/02-fused-softmax.ipynb
2021-04-21 01:40:29 -04:00

{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"%matplotlib inline"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n# Fused Softmax\nIn this tutorial, you will write a fused softmax operation that outperforms PyTorch's native implementation on large inputs, and learn about:\n\n- The benefits of kernel fusion for bandwidth-bound operations.\n- The reduction operators in Triton.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Motivations\nCustom GPU kernels for elementwise additions are educationally valuable but won't get you very far in practice.\nLet us consider instead the case of a simple (numerically stabilized) softmax operation:\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import torch\n\n\n# Compute the row-wise softmax of x\ndef naive_softmax(x):\n    # read MN elements ; write M elements\n    x_max = torch.max(x, axis=1)[0]\n    # read 2MN elements (counting the broadcast x_max[:, None] as MN) ; write MN elements\n    z = x - x_max[:, None]\n    # read MN elements ; write MN elements\n    numerator = torch.exp(z)\n    # read MN elements ; write M elements\n    denominator = torch.sum(numerator, axis=1)\n    # read 2MN elements ; write MN elements\n    ret = numerator / denominator[:, None]\n    # in total: read 7MN elements ; write 3MN + 2M elements\n    return ret"
]
},
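{
"cell_type": "markdown",
"metadata": {},
"source": [
"The max subtraction in :code:`naive_softmax` is what makes it numerically stabilized. Softmax is unchanged when the same constant is subtracted from every element of a row, but after the subtraction the exponentials can no longer overflow. In float32, :code:`exp` already overflows to infinity for inputs above roughly 88.7. The stdlib-only sketch below illustrates this on the CPU. :code:`softmax_unstable` and :code:`softmax_stable` are hypothetical helpers written for illustration and are not part of PyTorch or Triton. Python's :code:`math.exp` works in double precision and raises :code:`OverflowError` instead of returning infinity.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import math\n",
"\n",
"\n",
"def softmax_unstable(row):\n",
"    # exp() of large inputs overflows; math.exp raises OverflowError\n",
"    exps = [math.exp(v) for v in row]\n",
"    total = sum(exps)\n",
"    return [e / total for e in exps]\n",
"\n",
"\n",
"def softmax_stable(row):\n",
"    # Subtracting the row maximum makes every exponent <= 0, so exp() cannot overflow\n",
"    row_max = max(row)\n",
"    exps = [math.exp(v - row_max) for v in row]\n",
"    total = sum(exps)\n",
"    return [e / total for e in exps]\n",
"\n",
"\n",
"row = [1000.0, 1001.0, 1002.0]\n",
"try:\n",
"    softmax_unstable(row)\n",
"    overflowed = False\n",
"except OverflowError:\n",
"    overflowed = True\n",
"\n",
"print(overflowed)  # True: exp(1000.0) does not fit in a double\n",
"print(softmax_stable(row))  # finite probabilities that sum to 1"
]
},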
{
"cell_type": "markdown",
"metadata": {},
"source": [
"When implemented naively in PyTorch, computing :code:`y = naive_softmax(x)` for $x \\in \\mathbb{R}^{M \\times N}$ requires reading $7MN$ elements from DRAM and writing back $3MN + 2M$ elements.\nThis is obviously wasteful; we'd prefer a custom \"fused\" kernel that reads X only once and does all the necessary computations on-chip.\nSuch a kernel would only need to read $MN$ elements and write back $MN$ elements, so we could expect a theoretical speed-up of ~5x (i.e., $(10MN + 2M) / 2MN$).\nIn practice, the speed-up will be somewhat lower, because our kernel also spends time computing exponentials and moving data around internally in shared memory.\n\n"
]
},
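{
"cell_type": "markdown",
"metadata": {},
"source": [
"The traffic estimate above can be checked with a few lines of arithmetic. This sketch reuses the element counts from the comments in :code:`naive_softmax` (7MN reads, 3MN + 2M writes). It also assumes an ideal fused kernel that reads and writes each element exactly once. :code:`naive_traffic` and :code:`fused_traffic` are hypothetical helper names.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def naive_traffic(M, N):\n",
"    # Element counts from the comments in naive_softmax: 7MN reads, 3MN + 2M writes\n",
"    reads = 7 * M * N\n",
"    writes = 3 * M * N + 2 * M\n",
"    return reads + writes\n",
"\n",
"\n",
"def fused_traffic(M, N):\n",
"    # Ideal fused kernel: read each element of X once, write each element of Y once\n",
"    return 2 * M * N\n",
"\n",
"\n",
"M, N = 4096, 1024\n",
"ratio = naive_traffic(M, N) / fused_traffic(M, N)\n",
"# (10MN + 2M) / 2MN = 5 + 1/N, which tends to 5 as rows get longer\n",
"print(ratio)  # 5.0009765625"
]
},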
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Compute Kernel\nOur softmax kernel works as follows: each program instance loads one row of the input matrix X, normalizes it, and writes the result to the corresponding row of the output Y.\nNote that one important limitation of Triton is that each block must have a power-of-two number of elements,\nso to handle arbitrary input shapes we need to internally \"pad\" each row to the next power of two and guard the memory operations with masks:\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import triton\n\n\n@triton.jit\ndef _softmax(Y, X, stride_xm, stride_ym, M, N, **meta):\n    # row index\n    m = triton.program_id(0)\n    # col indices\n    n = triton.arange(0, meta['BLOCK'])\n    # the memory address of all the elements\n    # that we want to load can be computed as follows\n    X = X + m * stride_xm + n\n    # out-of-bounds columns (n >= N) are padded with -inf: they cannot win the max,\n    # and exp(-inf) = 0, so they do not contribute to the sum either\n    x = triton.load(X, mask=n < N, other=-float('inf'))\n    # Subtract maximum for numerical stability\n    z = x - triton.max(x, axis=0)\n    # Note that exponentials in Triton are fast\n    # but approximate (i.e., think __expf in CUDA)\n    num = triton.exp(z)\n    denom = triton.sum(num, axis=0)\n    y = num / denom\n    # Write back to Y, again masking out the padded columns\n    Y = Y + m * stride_ym + n\n    triton.store(Y, y, mask=n < N)"
]
},
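{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the masking logic concrete, here is a plain-Python CPU emulation of what a single program instance of :code:`_softmax` computes for one row. It is only a sketch under simplifying assumptions. :code:`emulate_softmax_program` is a hypothetical helper, and Python lists stand in for Triton blocks. It also uses the standard double-precision :code:`math.exp` rather than the fast approximate exponential Triton uses.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import math\n",
"\n",
"\n",
"def emulate_softmax_program(x_row, N, BLOCK):\n",
"    # Hypothetical CPU stand-in for one `_softmax` program instance (one row)\n",
"    n = list(range(BLOCK))  # like triton.arange(0, BLOCK)\n",
"    mask = [i < N for i in n]  # like n < N\n",
"    # Masked load: padded lanes get -inf so they cannot win the max\n",
"    x = [x_row[i] if mask[i] else -math.inf for i in n]\n",
"    row_max = max(x)\n",
"    z = [v - row_max for v in x]\n",
"    # exp(-inf) == 0.0, so padded lanes add nothing to the denominator\n",
"    num = [math.exp(v) for v in z]\n",
"    denom = sum(num)\n",
"    y = [v / denom for v in num]\n",
"    # Masked store: only the first N lanes are written back\n",
"    return [y[i] for i in n if mask[i]]\n",
"\n",
"\n",
"row = [0.5, -1.0, 2.0]  # N = 3 columns, padded to BLOCK = 4\n",
"out = emulate_softmax_program(row, N=3, BLOCK=4)\n",
"print(len(out), sum(out))"
]
},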
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def next_power_of_2(n):\n    # Round n up to a power of two (bit-smearing trick; valid for 1 <= n <= 2**32)\n    n -= 1\n    n |= n >> 1\n    n |= n >> 2\n    n |= n >> 4\n    n |= n >> 8\n    n |= n >> 16\n    n += 1\n    return n\n\n\ndef softmax(x):\n    M, N = x.shape\n    # The block size is the smallest power of two greater than or equal to the number of columns in `x`\n    BLOCK = next_power_of_2(N)\n    # Another trick we can use is to ask the compiler to parallelize each\n    # row-normalization more aggressively -- i.e., with more warps -- when\n    # rows are longer\n    # You will see in the next tutorial how to auto-tune this value in a more natural\n    # way so you don't have to come up with manual heuristics yourself\n    num_warps = 4\n    if BLOCK >= 2048:\n        num_warps = 8\n    if BLOCK >= 4096:\n        num_warps = 16\n    # Allocate output\n    y = torch.empty_like(x)\n    # Enqueue kernel. The launch grid is simple: we have one kernel instance per row of the input matrix\n    _softmax[(M, )](y, x, x.stride(0), y.stride(0), M, N, num_warps=num_warps, BLOCK=BLOCK)\n    return y"
]
},
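{
"cell_type": "markdown",
"metadata": {},
"source": [
"The bit-smearing trick in :code:`next_power_of_2` copies the highest set bit of :code:`n - 1` into every lower bit position, so adding 1 yields the next power of two. With shifts up to 16, it is valid for 1 <= n <= 2**32. The sketch below cross-checks it against an equivalent formulation based on Python's :code:`int.bit_length`. It then prints the block size and warp count that the heuristic in :code:`softmax` would pick for a few column counts. :code:`next_power_of_2_bitlength` and :code:`pick_num_warps` are hypothetical names introduced for illustration, and no GPU is needed.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def next_power_of_2(n):\n",
"    # Same bit-smearing trick as in the cell above (valid for 1 <= n <= 2**32)\n",
"    n -= 1\n",
"    n |= n >> 1\n",
"    n |= n >> 2\n",
"    n |= n >> 4\n",
"    n |= n >> 8\n",
"    n |= n >> 16\n",
"    n += 1\n",
"    return n\n",
"\n",
"\n",
"def next_power_of_2_bitlength(n):\n",
"    # Hypothetical equivalent for n >= 1: shift 1 left by the bit length of n - 1\n",
"    return 1 << (n - 1).bit_length()\n",
"\n",
"\n",
"def pick_num_warps(block):\n",
"    # Hypothetical helper reproducing the manual heuristic in `softmax`\n",
"    num_warps = 4\n",
"    if block >= 2048:\n",
"        num_warps = 8\n",
"    if block >= 4096:\n",
"        num_warps = 16\n",
"    return num_warps\n",
"\n",
"\n",
"for N in [1, 781, 1024, 1025, 3000]:\n",
"    assert next_power_of_2(N) == next_power_of_2_bitlength(N)\n",
"    BLOCK = next_power_of_2(N)\n",
"    print(N, BLOCK, pick_num_warps(BLOCK))"
]
},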
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Unit Test\n\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We make sure that we test our kernel on a matrix with an irregular number of rows and columns.\nThis will allow us to verify that our padding mechanism works.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"torch.manual_seed(0)\nx = torch.randn(1823, 781, device='cuda')\ny_tri = softmax(x)\ny_ref = torch.softmax(x, axis=1)\nprint(torch.allclose(y_tri, y_ref))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As expected, the results match up to floating-point tolerance.\n\n"
]
},
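{
"cell_type": "markdown",
"metadata": {},
"source": [
"Bit-for-bit equality is not expected here. The Triton kernel uses a fast approximate exponential and reduces each row in a different order than :code:`torch.softmax`, and floating-point addition is not associative. That is why the test relies on :code:`torch.allclose`, whose documented criterion is $|a - b| \\le atol + rtol \\cdot |b|$ element-wise (defaults :code:`rtol=1e-05`, :code:`atol=1e-08`). The stdlib sketch below mimics that criterion on Python floats with a hypothetical :code:`allclose` helper. It also shows that merely changing how a sum is accumulated already breaks exact equality.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import math\n",
"\n",
"\n",
"def allclose(a, b, rtol=1e-05, atol=1e-08):\n",
"    # Hypothetical helper: same element-wise criterion as torch.allclose,\n",
"    # |a - b| <= atol + rtol * |b|, applied to equal-length lists of floats\n",
"    return len(a) == len(b) and all(\n",
"        abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b)\n",
"    )\n",
"\n",
"\n",
"values = [0.1] * 10\n",
"forward = sum(values)       # naive left-to-right accumulation\n",
"exact = math.fsum(values)   # correctly rounded sum of the same values\n",
"print(forward, exact)       # 0.9999999999999999 1.0\n",
"print(forward == exact, allclose([forward], [exact]))  # False True"
]
},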
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Benchmark\nHere we will benchmark our operation as a function of the number of columns in the input matrix -- assuming 4096 rows.\nWe will then compare its performance against (1) :code:`torch.softmax` and (2) the :code:`naive_softmax` defined above.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"@triton.testing.perf_report(\n    triton.testing.Benchmark(\n        x_names=['N'],  # argument names to use as an x-axis for the plot\n        x_vals=[256 * i for i in range(2, 50)],  # different possible values for `x_name`\n        y_name='provider',  # argument name whose value corresponds to a different line in the plot\n        y_vals=['torch', 'triton', 'naive'],  # possible keys for `y_name`\n        y_lines=['Torch', 'Triton', 'Naive'],  # label name for the lines\n        ylabel='GB/s',  # label name for the y-axis\n        plot_name='softmax-performance',  # name for the plot. Used also as a file name for saving the plot.\n        args={'M': 4096}  # values for function arguments not in `x_names` and `y_name`\n    )\n)\ndef benchmark(M, N, provider):\n    x = torch.randn(M, N, device='cuda', dtype=torch.float32)\n    if provider == 'torch':\n        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))\n    if provider == 'triton':\n        ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x))\n    if provider == 'naive':\n        ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x))\n    # effective bandwidth: one read of x plus one write of y, whatever the provider\n    gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)\n    # the longest time (max_ms) gives the lowest bandwidth, hence the swapped order\n    return gbps(ms), gbps(max_ms), gbps(min_ms)\n\n\nbenchmark.run(show_plots=True)"
]
},
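{
"cell_type": "markdown",
"metadata": {},
"source": [
"The :code:`gbps` lambda reports an effective bandwidth. It always divides the size of one read of :code:`x` plus one write of :code:`y` (i.e., $2MN$ elements times the element size) by the measured time, regardless of how much data a provider actually moves. Because this numerator is fixed, a provider that moves more data than necessary, such as :code:`naive_softmax`, simply shows up with a lower GB/s figure. The arithmetic is reproduced below with a hypothetical :code:`effective_gbps` helper. The 0.1 ms timing is made up for illustration and is not a measured result.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def effective_gbps(M, N, element_size, ms):\n",
"    # Same formula as the `gbps` lambda: one read of x plus one write of y\n",
"    bytes_moved = 2 * M * N * element_size\n",
"    return bytes_moved * 1e-9 / (ms * 1e-3)\n",
"\n",
"\n",
"# Made-up timing for illustration only, not a measured result\n",
"M, N, ms = 4096, 2048, 0.1\n",
"print(effective_gbps(M, N, element_size=4, ms=ms))  # about 671 GB/s"
]
},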
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the above plot, we can see that:\n\n - Triton is 4-5x faster than the naive implementation, which is consistent with our theoretical predictions.\n - Triton is significantly faster than :code:`torch.softmax` for very large input matrices. My guess from looking at the source code of the `PyTorch kernel <https://github.com/pytorch/pytorch/blob/9409a3a39b7149bb2d833a89e0c944109bef7c27/caffe2/operators/softmax_ops.cu#L240>`_ is that PyTorch only partially fuses the computation of the softmax.\n   This means that, when temporary data is too large to fit entirely in the GPU's cache, it transfers almost twice the amount of data necessary.\n   Note that our Triton kernel is not only faster than PyTorch's CUDA kernel, it is also **easier to read, understand and maintain**.\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}