From ca04da35755c02e5b665c08727a1861b7e92af0c Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Sat, 6 Mar 2021 14:03:01 -0500 Subject: [PATCH] [DOCS] Switched tutorials to Python and use Sphinx Gallery --- docs/conf.py | 13 +- docs/index.rst | 7 +- docs/tutorials/01-vector-add.ipynb | 159 +++++++++++- docs/tutorials/02-fused-softmax.ipynb | 1 - python/tutorials/01-vector-add.ipynb | 329 ------------------------ python/tutorials/01-vector-add.py | 158 ++++++++++++ python/tutorials/02-fused-softmax.ipynb | 308 ---------------------- python/tutorials/02-fused-softmax.py | 181 +++++++++++++ python/tutorials/README.rst | 4 + 9 files changed, 515 insertions(+), 645 deletions(-) mode change 120000 => 100644 docs/tutorials/01-vector-add.ipynb delete mode 120000 docs/tutorials/02-fused-softmax.ipynb delete mode 100644 python/tutorials/01-vector-add.ipynb create mode 100644 python/tutorials/01-vector-add.py delete mode 100644 python/tutorials/02-fused-softmax.ipynb create mode 100644 python/tutorials/02-fused-softmax.py create mode 100644 python/tutorials/README.rst diff --git a/docs/conf.py b/docs/conf.py index fdf193464..85ebd6889 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -30,9 +30,18 @@ # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. -extensions = ['nbsphinx', 'sphinx.ext.autosectionlabel'] +extensions = ['sphinx.ext.autosectionlabel'] autosectionlabel_prefix_document = True +# Sphinx gallery +extensions += ['sphinx_gallery.gen_gallery'] +sphinx_gallery_conf = { + 'examples_dirs': '../python/tutorials/', + 'gallery_dirs': 'tutorials', + 'filename_pattern': '', + 'ignore_pattern': r'__init__\.py', +} + # Add any paths that contain templates here, relative to this directory. templates_path = ['_templates'] @@ -69,7 +78,7 @@ language = None # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This patterns also effect to html_static_path and html_extra_path -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store', '**.ipynb_checkpoints'] +exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] # The name of the Pygments (syntax highlighting) style to use. pygments_style = 'sphinx' diff --git a/docs/index.rst b/docs/index.rst index d42e953fe..77b005964 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -15,7 +15,6 @@ Welcome to Triton's documentation! .. toctree:: :maxdepth: 1 - :caption: Tutorials - - Vector Addition - Fused Softmax \ No newline at end of file + :caption: Installation Instructions + + tutorials/index \ No newline at end of file diff --git a/docs/tutorials/01-vector-add.ipynb b/docs/tutorials/01-vector-add.ipynb deleted file mode 120000 index 047b85dda..000000000 --- a/docs/tutorials/01-vector-add.ipynb +++ /dev/null @@ -1 +0,0 @@ -../../python/tutorials/01-vector-add.ipynb \ No newline at end of file diff --git a/docs/tutorials/01-vector-add.ipynb b/docs/tutorials/01-vector-add.ipynb new file mode 100644 index 000000000..6856022dc --- /dev/null +++ b/docs/tutorials/01-vector-add.ipynb @@ -0,0 +1,158 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n# Vector Addition\nIn this tutorial, we will see how to construct a simple, high-performance vector addition using Triton. 
You will learn:\n* The basic syntax of the Triton programming language\n* The best practices for creating PyTorch custom operators using the `triton.kernel` Python API\n* The best practices for validating and benchmarking custom ops against native reference implementations\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Writing the Compute Kernel\n\nEach compute kernel is declared using the :code:`__global__` attribute, and executed many times in parallel\non different chunks of data (See the `Single Program, Multiple Data <(https://en.wikipedia.org/wiki/SPMD>`_)\nprogramming model for more details).\n\n .. code-block:: C\n\n __global__ void add(float* z, float* x, float* y, int N){\n // The `get_program_id(i)` returns the i-th coordinate\n // of the program in the overaching SPMD context\n // (a.k.a launch grid). This is what allows us to process\n // different chunks of data in parallel.\n // For those similar with CUDA, `get_program_id({0,1,2})`\n // is similar to blockIdx.{x,y,z}\n int pid = get_program_id(0);\n // In Triton, arrays are first-class citizen. In other words,\n // they are primitives data-types and are -- contrary to C and\n // CUDA -- not implemented as pointers to contiguous chunks of\n // memory.\n // In the few lines below, we create an array of `BLOCK` pointers\n // whose memory values are, e.g.:\n // [z + pid*BLOCK + 0, z + pid*BLOCK + 1, ..., z + pid*BLOCK + BLOCK - 1]\n // Note: here BLOCK is expected to be a pre-processor macro defined at compile-time\n int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;\n float* pz [BLOCK] = z + offset;\n float* px [BLOCK] = x + offset;\n float* py [BLOCK] = y + offset;\n // Simple element-wise control-flow for load/store operations can\n // be achieved using the the ternary operator `cond ? val_true : val_false`\n // or the conditional dereferencing operator `*?(cond)ptr\n // Here, we make sure that we do not access memory out-of-bounds when we\n // write-back `z`\n bool check[BLOCK] = offset < N;\n *?(check)pz = *?(check)px + *?(check)py;\n }\n\nThe existence of arrays as a primitive data-type for Triton comes with a number of advantages that are highlighted in the `MAPL'2019 Triton paper `_.\n\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Writing the Torch bindings\nThe only thing that matters when it comes to Triton and Torch is the `triton.kernel` class. 
This allows you to transform the above C-like function into a callable python object that can be used to modify `torch.tensor` objects.\n\nTo create a `triton.kernel`, you only need three things:\n- `source: string`: the source-code of the kernel you want to create\n- `device: torch.device`: the device you want to compile this code for\n- `defines: dict`: the set of macros that you want the pre-processor to `#define` for you\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "import torch\nimport triton" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "source-code for Triton compute kernel\nhere we just copy-paste the above code without the extensive comments.\nyou may prefer to store it in a .c file and load it from there instead.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "_src = \"\"\"\n__global__ void add(float* z, float* x, float* y, int N){\n // program id\n int pid = get_program_id(0);\n // create arrays of pointers\n int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;\n float* pz[BLOCK] = z + offset;\n float* px[BLOCK] = x + offset;\n float* py[BLOCK] = y + offset;\n // bounds checking\n bool check[BLOCK] = offset < N;\n // write-back\n *?(check)pz = *?(check)px + *?(check)py;\n}\n \"\"\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This function returns a callable `triton.kernel` object\ncreated from the above source code.\nFor portability, we maintain a cache of kernels for different `torch.device`\nWe compile the kernel with -DBLOCK=1024\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "def make_add_kernel(device):\n cache = make_add_kernel.cache\n if device not in cache:\n defines = {'BLOCK': 1024}\n cache[device] = triton.kernel(_src, device=device, defines=defines)\n return cache[device]\n\n\nmake_add_kernel.cache = dict()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This is a standard torch custom autograd Function\nThe only difference is that we can now use the above kernel\nin the `forward` and `backward` functions.`\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "class _add(torch.autograd.Function):\n @staticmethod\n def forward(ctx, x, y):\n # constraints of the op\n assert x.dtype == torch.float32\n # *allocate output*\n z = torch.empty_like(x)\n # *create launch grid*:\n # this is a function which takes compilation parameters `opt`\n # as input and returns a tuple of int (i.e., launch grid) for the kernel.\n # triton.cdiv is a shortcut for ceil division:\n # triton.cdiv(a, b) = (a + b - 1) // b\n N = z.shape[0]\n grid = lambda opt: (triton.cdiv(N, opt.BLOCK), )\n # *launch kernel*:\n # pointer to the data of torch tensors can be retrieved with\n # the `.data_ptr()` method\n kernel = make_add_kernel(z.device)\n kernel(z.data_ptr(), x.data_ptr(), y.data_ptr(), N, grid=grid)\n return z\n\n\n# Just like we standard PyTorch ops We use the `.apply` method to create a callable object for our function\nadd = _add.apply" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Writing a Unit Test\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + 
}, + "outputs": [], + "source": [ + "torch.manual_seed(0)\nx = torch.rand(98432, device='cuda')\ny = torch.rand(98432, device='cuda')\nza = x + y\nzb = add(x, y)\nprint(za)\nprint(zb)\nprint(f'The maximum difference between torch and triton is ' f'{torch.max(torch.abs(za - zb))}')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Writing a Benchmark\nWe can now benchmark our custom op for vectors of increasing sizes to get a sense of how it does\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "warmup = 10\nrep = 200\nfor N in [2**i for i in range(17, 26, 1)]:\n x = torch.rand(N, device='cuda')\n y = torch.rand(N, device='cuda')\n triton_ms = triton.testing.do_bench(lambda: add(x, y), warmup=warmup, rep=rep)\n torch_ms = triton.testing.do_bench(lambda: x + y, warmup=warmup, rep=rep)\n # print the performance of triton and torch as well as the achieved bandwidth\n print(f'{N} {triton_ms:.3f} {torch_ms:.3f}')" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.9" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/docs/tutorials/02-fused-softmax.ipynb b/docs/tutorials/02-fused-softmax.ipynb deleted file mode 120000 index c3c9a225c..000000000 --- a/docs/tutorials/02-fused-softmax.ipynb +++ /dev/null @@ -1 +0,0 @@ -../../python/tutorials/02-fused-softmax.ipynb \ No newline at end of file diff --git a/python/tutorials/01-vector-add.ipynb b/python/tutorials/01-vector-add.ipynb deleted file mode 100644 index f5d7d4c26..000000000 --- a/python/tutorials/01-vector-add.ipynb +++ /dev/null @@ -1,329 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "acute-possession", - "metadata": {}, - "source": [ - "# Vector Addition" - ] - }, - { - "cell_type": "markdown", - "id": "median-malaysia", - "metadata": {}, - "source": [ - "In this tutorial, we will see how to construct a simple, high-performance vector addition using Triton. You will learn:\n", - "* The basic syntax of the Triton programming language\n", - "* The best practices for creating PyTorch custom operators using the `triton.kernel` Python API\n", - "* The best practices for validating and benchmarking custom ops against native reference implementations" - ] - }, - { - "cell_type": "markdown", - "id": "identical-conditions", - "metadata": {}, - "source": [ - "## Writing the Compute Kernel" - ] - }, - { - "cell_type": "markdown", - "id": "collectible-belle", - "metadata": {}, - "source": [ - "Each compute kernel is declared using the `__global__` attribute, and executed many times in parallel on different chunks of data (See the [Single Program, Multiple Data](https://en.wikipedia.org/wiki/SPMD) programming model for more details).\n", - "\n", - "\n", - "```c\n", - "__global__ void add(float* z, float* x, float* y, int N){\n", - " // The `get_program_id(i)` returns the i-th coordinate\n", - " // of the program in the overaching SPMD context\n", - " // (a.k.a launch grid). 
This is what allows us to process\n", - " // different chunks of data in parallel.\n", - " // For those similar with CUDA, `get_program_id({0,1,2})`\n", - " // is similar to blockIdx.{x,y,z}\n", - " int pid = get_program_id(0);\n", - " // In Triton, arrays are first-class citizen. In other words,\n", - " // they are primitives data-types and are -- contrary to C and\n", - " // CUDA -- not implemented as pointers to contiguous chunks of\n", - " // memory.\n", - " // In the few lines below, we create an array of `BLOCK` pointers\n", - " // whose memory values are, e.g.:\n", - " // [z + pid*BLOCK + 0, z + pid*BLOCK + 1, ..., z + pid*BLOCK + BLOCK - 1]\n", - " // Note: here BLOCK is expected to be a pre-processor macro defined at compile-time\n", - " int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;\n", - " float* pz [BLOCK] = z + offset;\n", - " float* px [BLOCK] = x + offset;\n", - " float* py [BLOCK] = y + offset;\n", - " // Simple element-wise control-flow for load/store operations can\n", - " // be achieved using the the ternary operator `cond ? val_true : val_false`\n", - " // or the conditional dereferencing operator `*?(cond)ptr\n", - " // Here, we make sure that we do not access memory out-of-bounds when we\n", - " // write-back `z`\n", - " bool check[BLOCK] = offset < N;\n", - " *?(check)pz = *?(check)px + *?(check)py;\n", - "}\n", - "```\n", - "\n", - "The existence of arrays as a primitive data-type for Triton comes with a number of advantages that are highlighted in the [MAPL'2019 Triton paper](http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf)." - ] - }, - { - "cell_type": "markdown", - "id": "forbidden-wednesday", - "metadata": {}, - "source": [ - "## Writing the Torch bindings" - ] - }, - { - "cell_type": "markdown", - "id": "numerical-agency", - "metadata": {}, - "source": [ - "The only thing that matters when it comes to Triton and Torch is the `triton.kernel` class. This allows you to transform the above C-like function into a callable python object that can be used to modify `torch.tensor` objects.\n", - "\n", - "To create a `triton.kernel`, you only need three things:\n", - "* `source: string`: the source-code of the kernel you want to create\n", - "* `device: torch.device`: the device you want to compile this code for\n", - "* `defines: dict`: the set of macros that you want the pre-processor to `#define` for you\n", - "\n", - "Note: The constructor of `triton.kernel` does some just-in-time compilation, so expect some overhead there. For this reason, I personally like to initialize kernels lazily in a cache (see `_kernels` variable below). This also makes it possible to choose the compilation device dynamically based on the type of the operator's inputs." - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "sporting-keyboard", - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "import triton\n", - "\n", - "# source-code for Triton compute kernel\n", - "# here we just copy-paste the above code without the extensive comments.\n", - "# you may prefer to store it in a .c file and load it from there instead.\n", - "_src = \"\"\"\n", - "__global__ void add(float* z, float* x, float* y, int N){\n", - " // program id\n", - " int pid = get_program_id(0);\n", - " // create arrays of pointers\n", - " int offset[BLOCK] = pid * BLOCK + 0 ... 
BLOCK;\n", - " float* pz[BLOCK] = z + offset;\n", - " float* px[BLOCK] = x + offset;\n", - " float* py[BLOCK] = y + offset;\n", - " // bounds checking\n", - " bool check[BLOCK] = offset < N;\n", - " // write-back\n", - " *?(check)pz = *?(check)px + *?(check)py;\n", - "}\n", - " \"\"\"\n", - "# This function returns a callable `triton.kernel` object\n", - "# created from the above source code.\n", - "# For portability, we maintain a cache of kernels for different `torch.device`\n", - "# We compile the kernel with -DBLOCK=1024\n", - "_kernels = dict()\n", - "def make_add_kernel(device):\n", - " if device not in _kernels:\n", - " defines = {'BLOCK': 1024}\n", - " _kernels[device] = triton.kernel(_src, device=device, defines=defines)\n", - " return _kernels[device]\n", - "\n", - "# This is a standard torch custom autograd Function\n", - "# The only difference is that we can now use the above kernel\n", - "# in the `forward` and `backward` functions.`\n", - "class _add(torch.autograd.Function):\n", - " \n", - " @staticmethod\n", - " def forward(ctx, x, y):\n", - " # constraints of the op\n", - " assert x.dtype == torch.float32\n", - " # *allocate output*\n", - " z = torch.empty_like(x)\n", - " # *create launch grid*:\n", - " # this is a function which takes compilation parameters `opt`\n", - " # as input and returns a tuple of int (i.e., launch grid) for the kernel.\n", - " # triton.cdiv is a shortcut for ceil division:\n", - " # triton.cdiv(a, b) = (a + b - 1) // b\n", - " N = z.shape[0]\n", - " grid = lambda opt: (triton.cdiv(N, opt.BLOCK), )\n", - " # *launch kernel*:\n", - " # pointer to the data of torch tensors can be retrieved with\n", - " # the `.data_ptr()` method\n", - " kernel = make_add_kernel(z.device)\n", - " kernel(z.data_ptr(), x.data_ptr(), y.data_ptr(), N, grid = grid)\n", - " return z\n", - "# Just like we standard PyTorch ops\n", - "# We use the `.apply` method to create a \n", - "# callable object for our function\n", - "add = _add.apply" - ] - }, - { - "cell_type": "markdown", - "id": "separated-polyester", - "metadata": {}, - "source": [ - "At this point `add(x, y)` is equivalent to `x + y` for contiguous tensors. Now let's test and benchmark it!" - ] - }, - { - "cell_type": "markdown", - "id": "exclusive-salvation", - "metadata": {}, - "source": [ - "## Writing a Unit Test" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "supported-ribbon", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([1.3713, 1.3076, 0.4940, ..., 0.6682, 1.1984, 1.2696], device='cuda:0')\n", - "tensor([1.3713, 1.3076, 0.4940, ..., 0.6682, 1.1984, 1.2696], device='cuda:0')\n", - "The maximum difference between torch and triton is 0.0\n" - ] - } - ], - "source": [ - "torch.manual_seed(0)\n", - "x = torch.rand(98432, device='cuda')\n", - "y = torch.rand(98432, device='cuda')\n", - "za = x + y\n", - "zb = add(x, y)\n", - "print(za)\n", - "print(zb)\n", - "print(f'The maximum difference between torch and triton is '\n", - " f'{torch.max(torch.abs(za - zb))}')" - ] - }, - { - "cell_type": "markdown", - "id": "otherwise-canadian", - "metadata": {}, - "source": [ - "Seems to work!" - ] - }, - { - "cell_type": "markdown", - "id": "polished-australia", - "metadata": {}, - "source": [ - "## Writing a Benchmark" - ] - }, - { - "cell_type": "markdown", - "id": "historic-glass", - "metadata": {}, - "source": [ - "The performance of our GPU code can be benchmark using the `torch.cuda.Event(enable_timing=True)` wrapper. 
Below is a simple function that benchmarks `rep` runs of our kernels after `warmup` \"cold\" runs." - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "strange-luxembourg", - "metadata": {}, - "outputs": [], - "source": [ - "# We now want to benchmark the performance of `add`\n", - "# Against that of PyTorch for increasing vector sizes\n", - "def do_bench(fn, warmup = 10, rep = 50):\n", - " start_event = torch.cuda.Event(enable_timing=True)\n", - " end_event = torch.cuda.Event(enable_timing=True)\n", - " ret = fn()\n", - " for i in range(warmup):\n", - " fn()\n", - " torch.cuda.synchronize()\n", - " start_event.record()\n", - " for i in range(rep):\n", - " fn()\n", - " end_event.record()\n", - " torch.cuda.synchronize()\n", - " time_ms = start_event.elapsed_time(end_event) / rep\n", - " return time_ms" - ] - }, - { - "cell_type": "markdown", - "id": "hairy-claim", - "metadata": {}, - "source": [ - "We can now benchmark our custom op for vectors of increasing sizes to get a sense of how it does" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "pleasant-valley", - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "131072 0.020 0.003\n", - "262144 0.019 0.004\n", - "524288 0.016 0.016\n", - "1048576 0.033 0.033\n", - "2097152 0.071 0.070\n", - "4194304 0.142 0.144\n", - "8388608 0.287 0.286\n", - "16777216 0.572 0.568\n", - "33554432 1.139 1.110\n" - ] - } - ], - "source": [ - "for N in [2**i for i in range(17, 26, 1)]:\n", - " x = torch.rand(N, device='cuda')\n", - " y = torch.rand(N, device='cuda')\n", - " triton_ms = do_bench(lambda: add(x, y))\n", - " torch_ms = do_bench(lambda: x + y)\n", - " # print the performance of triton and torch as well as the achieved bandwidth\n", - " print(f'{N} {triton_ms:.3f} {torch_ms:.3f}')" - ] - }, - { - "cell_type": "markdown", - "id": "juvenile-supplement", - "metadata": {}, - "source": [ - "Our op is on-par with Torch's vectorized element-wise kernel when the vectors are large enough. One caveat is that the latency of PyTorch is much smaller for small vectors (3us vs 18-20us). This is something we are actively working on to reduce." - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.9" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/python/tutorials/01-vector-add.py b/python/tutorials/01-vector-add.py new file mode 100644 index 000000000..732a8845a --- /dev/null +++ b/python/tutorials/01-vector-add.py @@ -0,0 +1,158 @@ +""" +Vector Addition +================= +In this tutorial, we will see how to construct a simple, high-performance vector addition using Triton. 
You will learn:
+* The basic syntax of the Triton programming language
+* The best practices for creating PyTorch custom operators using the `triton.kernel` Python API
+* The best practices for validating and benchmarking custom ops against native reference implementations
+"""
+
+# %%
+# Writing the Compute Kernel
+# --------------------------
+#
+# Each compute kernel is declared using the :code:`__global__` attribute, and executed many times in parallel
+# on different chunks of data (See the `Single Program, Multiple Data <https://en.wikipedia.org/wiki/SPMD>`_
+# programming model for more details).
+#
+# .. code-block:: C
+#
+#     __global__ void add(float* z, float* x, float* y, int N){
+#         // The `get_program_id(i)` returns the i-th coordinate
+#         // of the program in the overarching SPMD context
+#         // (a.k.a. launch grid). This is what allows us to process
+#         // different chunks of data in parallel.
+#         // For those familiar with CUDA, `get_program_id({0,1,2})`
+#         // is similar to blockIdx.{x,y,z}
+#         int pid = get_program_id(0);
+#         // In Triton, arrays are first-class citizens. In other words,
+#         // they are primitive data-types and are -- contrary to C and
+#         // CUDA -- not implemented as pointers to contiguous chunks of
+#         // memory.
+#         // In the few lines below, we create an array of `BLOCK` pointers
+#         // whose memory values are, e.g.:
+#         // [z + pid*BLOCK + 0, z + pid*BLOCK + 1, ..., z + pid*BLOCK + BLOCK - 1]
+#         // Note: here BLOCK is expected to be a pre-processor macro defined at compile-time
+#         int offset[BLOCK] = pid * BLOCK + 0 ... 
BLOCK; + float* pz[BLOCK] = z + offset; + float* px[BLOCK] = x + offset; + float* py[BLOCK] = y + offset; + // bounds checking + bool check[BLOCK] = offset < N; + // write-back + *?(check)pz = *?(check)px + *?(check)py; +} + """ + +# %% +# This function returns a callable `triton.kernel` object +# created from the above source code. +# For portability, we maintain a cache of kernels for different `torch.device` +# We compile the kernel with -DBLOCK=1024 + + +def make_add_kernel(device): + cache = make_add_kernel.cache + if device not in cache: + defines = {'BLOCK': 1024} + cache[device] = triton.kernel(_src, device=device, defines=defines) + return cache[device] + + +make_add_kernel.cache = dict() + +# %% +# This is a standard torch custom autograd Function +# The only difference is that we can now use the above kernel +# in the `forward` and `backward` functions.` + + +class _add(torch.autograd.Function): + @staticmethod + def forward(ctx, x, y): + # constraints of the op + assert x.dtype == torch.float32 + # *allocate output* + z = torch.empty_like(x) + # *create launch grid*: + # this is a function which takes compilation parameters `opt` + # as input and returns a tuple of int (i.e., launch grid) for the kernel. + # triton.cdiv is a shortcut for ceil division: + # triton.cdiv(a, b) = (a + b - 1) // b + N = z.shape[0] + grid = lambda opt: (triton.cdiv(N, opt.BLOCK), ) + # *launch kernel*: + # pointer to the data of torch tensors can be retrieved with + # the `.data_ptr()` method + kernel = make_add_kernel(z.device) + kernel(z.data_ptr(), x.data_ptr(), y.data_ptr(), N, grid=grid) + return z + + +# Just like we standard PyTorch ops We use the `.apply` method to create a callable object for our function +add = _add.apply + +# %% +# Writing a Unit Test +# -------------------------- +torch.manual_seed(0) +x = torch.rand(98432, device='cuda') +y = torch.rand(98432, device='cuda') +za = x + y +zb = add(x, y) +print(za) +print(zb) +print(f'The maximum difference between torch and triton is ' f'{torch.max(torch.abs(za - zb))}') + +# %% +# Writing a Benchmark +# -------------------------- +# We can now benchmark our custom op for vectors of increasing sizes to get a sense of how it does + +warmup = 10 +rep = 200 +for N in [2**i for i in range(17, 26, 1)]: + x = torch.rand(N, device='cuda') + y = torch.rand(N, device='cuda') + triton_ms = triton.testing.do_bench(lambda: add(x, y), warmup=warmup, rep=rep) + torch_ms = triton.testing.do_bench(lambda: x + y, warmup=warmup, rep=rep) + # print the performance of triton and torch as well as the achieved bandwidth + print(f'{N} {triton_ms:.3f} {torch_ms:.3f}') \ No newline at end of file diff --git a/python/tutorials/02-fused-softmax.ipynb b/python/tutorials/02-fused-softmax.ipynb deleted file mode 100644 index fb28888d4..000000000 --- a/python/tutorials/02-fused-softmax.ipynb +++ /dev/null @@ -1,308 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "induced-zoning", - "metadata": {}, - "source": [ - "# Fused Softmax" - ] - }, - { - "cell_type": "markdown", - "id": "median-malaysia", - "metadata": {}, - "source": [ - "Custom GPU kernels for elementwise additions are educationally valuable but won't get you very far in practice. 
Let us consider instead the case of a simple (numerically stabilized) softmax operation:" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "precise-professor", - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "\n", - "# Compute the row-wise softmax of x \\in R^{M \\times N}\n", - "def naive_softmax(x):\n", - " # read MN elements ; write M elements\n", - " x_max = torch.max(x, axis=1)[0]\n", - " # read 2MN elements ; write MN elements\n", - " z = x - x_max[:, None]\n", - " # read MN elements ; write MN elements\n", - " numerator = torch.exp(x)\n", - " # read MN elements ; write M elements\n", - " denominator = torch.sum(numerator, axis=1)\n", - " # read 2MN elements ; write MN elements\n", - " ret = numerator / denominator[:, None]\n", - " # in total: read 7MN elements ; wrote 3MN + 2M elements\n", - " return ret" - ] - }, - { - "cell_type": "markdown", - "id": "gorgeous-monday", - "metadata": {}, - "source": [ - "When implemented naively in pytorch, computing $y$ requires reading $7MN$ elements from DRAM and writing back $3MN + 2M$ elements.\n", - "\n", - "Instead, we want to write a custom \"fused\" pytorch operators that only reads X once and does all the necessary computations on-chip. This would require reading and writing back only $MN$ bytes, so we could expect a theoretical speed-up of 5x. In practice, though, we expect less because our kernel will spend some time computing exponentials and moving data around in shared memory." - ] - }, - { - "cell_type": "markdown", - "id": "identical-conditions", - "metadata": {}, - "source": [ - "## Writing the Compute Kernel" - ] - }, - { - "cell_type": "markdown", - "id": "prepared-apparatus", - "metadata": {}, - "source": [ - "Our softmax kernel works as follows: each program loads a row of X and writes back a normalized row of Y. Note that one important limitation of Triton is that each block must have a power-of-two number of elements, which means that we need to guard the memory operations properly if we want to handle any possible input shapes:\n", - "\n", - "```c\n", - "__global__ void softmax(float* Y, float* X, int stride_xm, int stride_ym, int M, int N){\n", - " // row index\n", - " int m = get_program_id(0);\n", - " // column indices\n", - " int n [BLOCK] = 0 ... BLOCK;\n", - " // the memory address of all the elements\n", - " // that we want to load can be computed as follows\n", - " float* px [BLOCK] = X + m*stride_xm + n;\n", - " // because BLOCK has to be a power of two\n", - " // (per Triton-C specs), it is important\n", - " // to guard each memory operation with predicates\n", - " // or we will read out of bounds\n", - " bool check[BLOCK] = n < N;\n", - " float x [BLOCK] = check ? 
*px : -F32_INFINITY;\n", - " // syntax for reduction in Triton is:\n", - " // x[..., OPERATOR, ...]\n", - " // ^\n", - " // index\n", - " // The operators currently supported are {min, max, +}\n", - " float z [BLOCK] = x - x[max];\n", - " // The exponential in Triton is fast but approximate \n", - " // (i.e., like __expf in CUDA)\n", - " float num [BLOCK] = exp(z);\n", - " float denom = num[+];\n", - " // The result of the reduction is now stored in y\n", - " float y [BLOCK] = num / denom;\n", - " // We write it back\n", - " float* py [BLOCK] = Y + m*stride_ym + n;\n", - " *?(check)py = y; \n", - "}\n", - "```" - ] - }, - { - "cell_type": "markdown", - "id": "forbidden-wednesday", - "metadata": {}, - "source": [ - "## Writing the Torch bindings" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "former-pottery", - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "import triton\n", - "\n", - "# source-code for Triton compute kernel\n", - "_src = \"\"\"\n", - "__global__ void softmax(float* Y, float* X, int stride_ym, int stride_xm, int M, int N){\n", - " int m = get_program_id(0);\n", - " int n [BLOCK] = 0 ... BLOCK;\n", - " float* px [BLOCK] = X + m*stride_xm + n;\n", - " bool check[BLOCK] = n < N;\n", - " float x [BLOCK] = check ? *px : -F32_INFINITY;\n", - " float z [BLOCK] = x - x[max];\n", - " float num [BLOCK] = exp(z);\n", - " float denom = num[+];\n", - " float y [BLOCK] = num / denom;\n", - " float* py [BLOCK] = Y + m*stride_ym + n;\n", - " *?(check)py = y; \n", - "}\n", - "\"\"\"\n", - "\n", - "# We need to make sure that BLOCK is the smallest power of two\n", - "# greater than the number of rows N of the input matrix.\n", - "# Different values of BLOCK will result in different kernels\n", - "def next_power_of_2(n):\n", - " n -= 1\n", - " n |= n >> 1\n", - " n |= n >> 2\n", - " n |= n >> 4\n", - " n |= n >> 8\n", - " n |= n >> 16\n", - " n += 1\n", - " return n\n", - "\n", - "_kernels = dict()\n", - "def make_kernel(N, device):\n", - " BLOCK = next_power_of_2(N)\n", - " key = (BLOCK, device)\n", - " if key not in _kernels:\n", - " defines = {'BLOCK': BLOCK}\n", - " _kernels[key] = triton.kernel(_src, device=device, defines=defines)\n", - " return _kernels[key]\n", - "\n", - "class _softmax(torch.autograd.Function):\n", - " \n", - " @staticmethod\n", - " def forward(ctx, x):\n", - " # constraints of the op\n", - " assert x.dtype == torch.float32\n", - " y = torch.empty_like(x)\n", - " # *create launch grid*:\n", - " # here we just launch a grid of M programs\n", - " M, N = y.shape\n", - " grid = lambda opt: (M, )\n", - " # *launch kernel*:\n", - " kernel = make_kernel(N, y.device)\n", - " kernel(y.data_ptr(), x.data_ptr(), y.stride(0), x.stride(0), M, N, grid = grid)\n", - " return y\n", - " \n", - "softmax = _softmax.apply" - ] - }, - { - "cell_type": "markdown", - "id": "exclusive-salvation", - "metadata": {}, - "source": [ - "## Writing a Unit Test" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "pretty-prospect", - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([[0.0004, 0.0006, 0.0004, ..., 0.0005, 0.0004, 0.0010],\n", - " [0.0003, 0.0029, 0.0004, ..., 0.0007, 0.0017, 0.0004],\n", - " [0.0002, 0.0006, 0.0005, ..., 0.0028, 0.0009, 0.0003],\n", - " ...,\n", - " [0.0017, 0.0005, 0.0010, ..., 0.0006, 0.0004, 0.0001],\n", - " [0.0010, 0.0006, 0.0001, ..., 0.0006, 0.0017, 0.0014],\n", - " [0.0037, 0.0012, 0.0006, ..., 0.0003, 0.0005, 0.0003]],\n", 
- " device='cuda:0')\n", - "tensor([[0.0004, 0.0006, 0.0004, ..., 0.0005, 0.0004, 0.0010],\n", - " [0.0003, 0.0029, 0.0004, ..., 0.0007, 0.0017, 0.0004],\n", - " [0.0002, 0.0006, 0.0005, ..., 0.0028, 0.0009, 0.0003],\n", - " ...,\n", - " [0.0017, 0.0005, 0.0010, ..., 0.0006, 0.0004, 0.0001],\n", - " [0.0010, 0.0006, 0.0001, ..., 0.0006, 0.0017, 0.0014],\n", - " [0.0037, 0.0012, 0.0006, ..., 0.0003, 0.0005, 0.0003]],\n", - " device='cuda:0')\n", - "True\n" - ] - } - ], - "source": [ - "x = torch.randn(1823, 781, device='cuda')\n", - "y_tri = softmax(x)\n", - "y_ref = torch.softmax(x, axis=1)\n", - "print(y_tri)\n", - "print(y_ref)\n", - "print(torch.allclose(y_tri, y_ref))" - ] - }, - { - "cell_type": "markdown", - "id": "regular-andrew", - "metadata": {}, - "source": [ - "Seems to work!" - ] - }, - { - "cell_type": "markdown", - "id": "polished-australia", - "metadata": {}, - "source": [ - "## Writing a Benchmark" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "chubby-audit", - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAEGCAYAAACKB4k+AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Il7ecAAAACXBIWXMAAAsTAAALEwEAmpwYAABH8ElEQVR4nO3dd3xUVfr48c+TmfRKGi0EAlKkSQfFgmIv4LoWXAu6KmtddVdFd/enbnddXVdXvypr731X7GtHRToIUqUTICQkkF6mnN8f52YIGNLIZDLJ8/Y1rztz596Zc3Fyn3tPeY4YY1BKKaUAIkJdAKWUUu2HBgWllFIBGhSUUkoFaFBQSikVoEFBKaVUgDvUBTgU6enppk+fPqEuhlJKhZXFixfvNsZk1PdeWAeFPn36sGjRolAXQymlwoqIbDnYe1p9pJRSKkCDglJKqQANCkoppQI0KCillArQoKCUUiogaEFBRHqJyOciskpEVorIjc76VBH5WER+cJZdnPUiIg+JyHoRWS4io4JVNqWUUvUL5p2CF/i1MWYwMAG4TkQGA7cDnxpj+gOfOq8BTgP6O48ZwKNBLJtSSql6BG2cgjFmJ7DTeV4qIquBnsBUYJKz2bPAF8BMZ/1zxubyniciKSLS3fkcFabKqr34fIbkuMhD/iyf35C7p4KeKbG4XU27nqnx+lm3q5SSKg8V1T7Ka7xU1Pgor/ZSWeNj0sBMhmUlH3LZWkOVx0dJpYe9lR5KqzwM7p5MbJQr1MVSnUybDF4TkT7ASGA+0LXOiT4P6Oo87wlsq7NbrrNuv6AgIjOwdxJkZ2cHr9CdkMfnJ7+0mt2l1STEuMlMjCYh2o2INOtz/H7DtxsLeX3RNj5cmYfXZzh5SFcuHJfNxH7pREQ0/fN27K3kqx8KmLNuN1+v301xpYfEaDfj+6Yx8bA0Jh6WTv/MhEAZ/X7DmrxSvlm/m2827Gb+xiIqPb6Dfv6sORt5+/qJ9M1IaHKZvD4/xZUe9lR42FtRQ1F5DXsrPJRWezHGICLUHqEICFDp8VNa5aGkykNpldd5eCip9FJc6WFvZQ1VHv9+3/OL4/pyx2mHN7lcSrWGoAcFEUkA3gRuMsaU1D3BGGOMiDRrlh9jzCxgFsCYMWN0hqADFFd4WJNXwpq8UtbklbBuVxl+Y4iNdNlH1L6lz2/YVVJFXkkVecXVFJZXc+CcS3FRLjITo8lMiiEzMZruyTFkp8aRlRpHdmocPVNiiYm0V7NbCyt4Y/E23lyyne17K0mMcfPTUVnERrp4c0ku76/IIzs1jmnjenHe6F5kJEYHvscYw+6yGrbtqWBbUQXLtu1lzroCNhSUA9A1KZqTBnfliF4prNpRzDfrC/lk9S4AMhKjOapfGj6/Ye6GQorKawDolxHP+WOyGJeTRlpCFPFRbuKiXYHl3nIPUx/5ml88v5j/XjeR+OiG/xz2VtRw1XOLWLh5T4v+37gjhMQYN0mxkSTGuEmMjqR3WhwpcZGkxEWRHBtJcmwkKXGRPPrFBuZvLGrR9yh1KIIaFEQkEhsQXjTGvOWs3lVbLSQi3YF8Z/12oFed3bOcdaoRn6zaxcsLtrJ6Zwk7iqsC61PiIhnQNZE4t4vKGh/FlR4qPT6qanxUenyICF2TYuiaFM3QHsl0TYqhW3IMGQnRlFV72VVSRX5pNfml1ewqqWLljhI+XrWLau++K1oR6JoYQ0pcJGvyShGBY/pnMPO0QZw8uGsgYNxyykA+WpnHS/O3cu+Ha3ng43UcNyADr9+wraiC3D2V+31utDuC8X3TuHBcNscOyNjvbqDWtqIK5m7YzdwNhczdUEiEwKQBGUw8LJ2Jh6XTLTmmwX+3pJhI/nXhKC59aj63vbmchy8cedC7opIqD9OfWsDqnaVcM6kfXROj6RIfRUpcFKlxUaTERZIUE2lvC5zAajAYY1/GRrqIiYxo8l3X99tLePLrjVR5fIF/Q6XaQtCCgthf/5PAamPMP+q8NRuYDtzjLN+us/56EXkFGA8Ua3tCw4orPfzhnVW8uSSXrC6xjMtJZVD3JAZ1S+Tw7klkJkY3u+qnMcYYCkqr2VpUwbY9FWwtrGRrUQX5pVWcdUQPzhnVk+7JsT/aLybSxdQRPZk6oifr88t449t1DFrxV/KicuiWcQLHD+xNr9Q4srrE0su5C2nsZNgrNY4LUrO5YGzLqxGP7p/OracM4m8frmFkrxSuPKbvj7Ypr/by86cXsnJHCY9dPJoTB3et55Na1+jeXXjsS8OK7cWM7ZMa9O9TqlYw7xQmApcAK0RkmbPuN9hg8JqIXAFsAc533nsfOB1YD1QAlwexbO1KjdfPsm17mbthN99uKKSwvIafjOzJBWN7kZ4QXe8+c9YVMPPN5eSXVvPLEw7j+hP6E+UO/rATEbFVSUkxjGnhyeqwzARuP7wAlv7P9lHb8jh4x0CXKdD1LEhtwkm3ci/kr4b8Vc5jNbiioN
d46DUOssZATNMakK8+ri/fbdvLXz9Yw5AeyRzZL23f19T4uOLZhSzZuoeHfzaqTQICwKjsFACWbNmjQaGd8fkNheXVRLtcREdGEO1u+h1gLb/fsKu0im1FleTuqcBvsFWKTrVi7fOk2Egim9iporWIObASOYyMGTPGhGuW1O+3F/PlugLmbSxk4eYiqjx+RGBIjyTio9zM31REpEs4fVh3LpnQm9G9uyAilFd7+cv7q3lx/lb6ZcTzj/NHcESvlFAfTvN9+Xf4/E9w1eew8XNYNRt2LrPvdR0K3UeA8YHfC/46S0857P4BSurULEYnQcYg8FRC/kowfkAgc7ANEL3GQ84xkJx10OKUVnk4+5FvKK708M4NR9M9OZYqj4+rnlvE1+t388D5Izh7ZM9g/ov8yPH3fUH/zARmXTqmTb83lDw+P+XVXsqqbS+xao+faq+PGq+f6sDDR7Q7grSEaNIToklLiCLxgA4RZdVecvdUkOucdHP3VFJe4yNCwBUhRIg4S3C7IkiItifgpJg6y5hIiis9bNxdzsaCcjbtLmNjQTlbCiuo8e1fhVrbZhfjtNfFR7mIi3ITH71vCcL2vZXkOtWldT/jYLonxzD7+qP3a39rDSKy2BhT7w9Lg0KwGAPbFkDecsg+EroOodLj553vdvD8vC2s2F4MwMCuiRzZL40j+6UxISct0HVzfX4ZL87fwhuLcymt8jKoWyJTRvTglQXb2LangiuPzuHXJw8M3/rmVy+GXavgl0v2rdu7FVa/Yx97NkNEJES4IMK97+GOgrTDIPNwe9LPHGxP9rUnhOpS2L7Y/ttvmw/bFkK1/bcmtR/kHAt9j4M+x0J82n5FWp9fytSHv2FAt0ReuGI8N76ylE9W53PvT4dz/thetLVfvbaMOesKWPjbE1u9GrC9ePqbTTz1zSbKq32UVXup8TZ+oqxPlDuC9PgoEmMiyS+tYk+FZ7/3YyIjSIyJxO83+I3B5zf4DfiNwePz4/E1fB6MdAnZqXH0zUigb3o8PVJi8foNVR4fVR4flU47XaXzurzaR0WNl/Jqu6682ovPb+iREkuv1Fh6dYmjV2pcoMrUHSH79Uorq/ZSVF7DXz9Yw0Xjs/nD1KEt+nc5GA0KrcHnhbzvILkXJGQefLuijfDdq7D8VdizKbC6JDKdzzxD+aRmGDtSx3P2xGGcPqz7QauHalXUeJm9bAfPfbuFVTtLyE6N477zjmBcTphXKfxzOPQcBec9E9zv8ftt9dKmObDpS9j8DdSU2ve6DYOscdDVCS6Zh/PB+iqueXEJXZOi2VVSzR/PHsolE3oHt4wH8eL8Lfz2P9/z5a2T6J0WH5IyBNvk+7/A6zcc0z+d+Gg3CVFuu4y2PcRi3LaKJsoVQXSkiyhXBFHuCGq8fnaX2R5zu0tr2O0sS6o8dE2KJquLPdnWLtPioxoMrFUeHyVOF2G79FBS5SUh2kXf9AR74m7jahyA3/5nBa8u3MYnvzqOPumt9xtoKCiE9SQ7bcZbDa9Nh3Uf2NfxGdB1iK3m6DoEMgbCjmXw3SuQuwAQyDmG/FG/5J7vk3HlzmOSbzknRS7m7KgvMGUPI9+Pgh19ITLOPqLiIDIWIuMhqQcMOgNckcRFuZk2LpsLxvZic2EF3ZJiwn9AU+Ve2LsFRk8P/ndFREC3ofZx5LXg88COpbDxSxskVrwOi0oCm5+W2J3PumbzWWEqg4cP46gUgbwSezEQmxL88tYxuncXAJZs3dMhg0JhWTUbCsqZeeogrpnUL6RliXGqfjITQ1qMH7lxcn/eWrKd+/63lod/1jaZfzQoNKZuQDj+txCdCLu+h10rYeET4N3XBZSMw+HEu2HY+eRHpPGTR+ZSUePliskzGDu2F/HxkbB9CbL+E9j4BeQuAk+FrQuvKbd16LVSsuHoX8GIi8Btr3JyWvFKIaTyVthltyPa/rtdkU47wzg47lZbzVeyvU6j9Wpy8lfx8/LPiVj3Aayrs290kg0OR14HIy8KelH7ZyaSEO1m8ZY9/GTkwdtDwlXteI9xOV1CXJL2KzMphiuPyeFfn63nF8cWt8noew0KDakbEM64H8Zeuf/7fp+tLspfBV36QLfhIEJljY8rZ31LUXkNr/3iyP3/R/Yaax/H37H/Zxljr2I95bB1Psy5F969CebcB0ffBCMvgciG+92Hjbzldtl9eGjLAbYtIjnLPvqfZFcBYgyU74birbB3GxRvs20eW+fB7BsgtS/0PjKoRXNFCCOzU1i8ZW9QvydUFm4uItodwdCe7SPNSHs149i+vDBvC3/7cA0vXDk+6N+nqbMPxlvTcEAA2wia3h8GT4XuR4AIPr/hxleWsmJ7MQ9dOLLpkV3ENqLGdoGBp8KVn8LFb0JyT3j/FnhoBMx7zN5VhLudyyGhW8NtM6EmAgkZ0HM0DDkbjroBTv87XPYudOkNb1xug0aQjcruwtq8EsqqvUH/rrryS6pYtm1vUL9j4eYijuiVQrQ7zKtDgywxJpIbTujP1+t389UPBUH/Pg0K9fHWwGuXNhwQDuIv76/mf6t2ceeZgznpUPq0i8BhJ8LPP4JLZ9ueMx/OhIfHwcr/8qN8FOEkb3n7uEtoiZhkOP85qCiCt66yd4tBNLp3F/wGvgvyCfpAd769kmmzvqU8SMGovNrLyh0ljNMxGE1y0YRssrrEcs8Ha/D7g/u3r0HhQN4aeN25Qzj9vmYFhOe+3cyTX2/isqP6cPnEnNYpj4jtQnn5e3Dp2xCTZMv3zBn2ijvceKqgYK2tagtX3YbZu4YNn8FX9wf1q0ZkpyACi7e0LN9SSxRXevhsTT5VHj+frclvfIcWWLp1Lz6/YWy496JrI9FuF7ecPJCVO0p4Z/mOoH6XBoUDvX0drH3fBoRxVwVW5+6p4ImvNvL5mnz2VtT8aLfP1uzi7tkrOfHwTP7fmYODU7a+k+AXc+DMB6BgDTx+LMz+JZQF/5ay1eSvsg3q4XqnUGvUpTB8Gnz+F9tp4GD2boUPZsIX90Bx81N5JcVEMiAzsU2Dwkff51Hj8xPtjuD9FcHJNLNgcxERsm/ktmrclCN6cHj3JO7/37oWj+doCm1ormvPZljxGky8ab+A8NmaXdz86ncUV+4bENM3I55R2V0YmZ1CZmIMN76ylME9knhw2khczUgN3WwRLhjzcxhyDsz5O8x/DFb+BybdDuOvsV0w27PaRuZwvlMAewd35j/sKOw3r4RffAVJ3fe9X15o7yIW/tu+9nngy3th4Gkw5nLoe0L9/6/q9obKGgOxXRjVuwvvLt+B32+alXa8pd7+bjt90uI4dkAGry3aRnm1t9EMss21cFMRg3skkRhz6PNsdBYREcLMUwdy2dMLeXnBVqYf1Sco36NBoa5lLwESCAg+v+GBj9fx8OfrGdw9iVd/MYE95R6WbN3D0q17+GxNPm8szgWgR3IMT04f2+p/PAcVmwKn/BlGXwYf/cY+Nn8DP3m0yTl/QmLncohOtr21wl1UPJz3LPz7eHjzCtv246uGef8H3zwENWUw4mcw6Q4bFBY/A0tfgDXv2uMff
RkMONWm7di5zI512fkdVDgN2EdeD6f8mdG9u/Dygq1sKCijf9fgdqTPL6li7oZCbjihP0f1S+O5b7fw2Zp8zjqiR6t9R43Xz9Jte5h2CIkMO6vjBmRwZN80Hvr0B346OouEIJxvNCjU8vtg6Ytw2GRIzqKgtJobX1nK3A2FTBvbi7unDAmklKhNmGaMYUthBd/vKGZ07y50TQpBl9H0/vCz12D+4zYw/PsEuOBFyBzU9mVpirzltk6+o6RtyBwEZ/4T/jPDpu7YsQTKdsHAM2Dynfv/fzjp93D8b2waj0VPwSd32weAuGzqjgGn2p5sC2bZsTDsq2JZvGVP0IPCO8t3YgxMHdGDPmnxZCRG8/6Kna0aFFbuKKbK4w//UfkhICLcftogpj7yDf+es5GbTxrQ6t+hQaHWxi+gJBdO+RMLNhVx/UtLKK708Pdzh3PemPrz3ogIfdLjW3X4eYuIwISr7cn29enwxGQ4+/9sV9n2xO+zJ7pRbTCSuS0dcQFsnWvvBLKPhPOfh+yD9Cd3R8Owc+0jfw1sX2QHPXYdbEe019q+2KbmAHLS4+kSF8niLXuYNi64V9ezl21nWM9k+jkz0Z02tBuvLdpGRY2XuKjWOV0s3GwnD9Lsry1zRK8Ubjl5AMf0zwjK57fzCug2tPR5iE3ltZJhXPjvecRFufjvdRMPGhDapT4TYcaXNmPoa5faq9Agd5lslsL1dgR3uDcy1+f0++Hqr+HyDw4eEA6UOQhGXgxZo/cPCAAZA6B0B1QVIyKM7t2FxVuD29i8aXc53+UWM3XEvruC04d1b/VeSAs27SEnPb7VM392Jtef0D9o2ZE1KIDtc77mPRh+AQ98sYXhWcnMvuFoDu+eFOqSNV9yT7j8fVtf/fUD8MJPobLteq40aGcHaWSuj8vdutViGU610+4fABjVuwsbC8rZU/7jnm+t5e1l2237+fB9QWFsn1QyEqN5b3nr9ELy+w2LthQxto+mtmivghYUROQpEckXke/rrBshIvNEZJmILBKRcc56EZGHRGS9iCwXkbbJ/FRr+Wvgq8E/4iIKSqs5sm+anVoxXLmj4awH4ayHYPNX8NHvQl0iK+87cEXbBIKqYenOv1HBGsCObAZYui04Ad4Yw+xlO5iQk7bfNKauCOG0od34fG0+FTWHPpBtfUEZeys8LZ6gSQVfMO8UngFOPWDdvcDvjTEjgDud1wCnAf2dxwzg0SCWa3/G2KqjHiPZmzQQr990nNva0dNh3C/gu5ds/XWo7VxuG1NdYRxw20qXPnYmuYK1AByRlYIrQoI2XuH77SVs3F2+X9VRrdasQqptT9CRzO1X0IKCMWYOUHTgaqC2TiYZqB2aNxV4zljzgBQR6U5b2LnMZj0deTEFpdUAZCZ2kMRzAMf8GqIS4NM/hLYcxoR3eou25nJDWv9AUIiNcjGkR1LQgsLby7YT6RJOG/rjP7uxfVJJT2idKqSFm4rISIymd1rcIX+WCo62blO4Cfi7iGwD7gNqU4X2BLbV2S7XWRd8S18AdwwMPTcQFDrMnQLY2cUm/hLWvmczfIZKca5t2+iI7QnBkjEgUH0Etgrpu23FeJswjWNz+PyGd5bvYNLAzMDMf3W5IoTTh7VOFdLCzXsY1ye1w84k1xG0dVC4BrjZGNMLuBl4srkfICIznPaIRQUFh5jewVMJy1+Hw6dAbAoFZXZuhA4VFAAmXAsJXW1vpFAl0qudQ6F7COZQCFcZg2yajJoKwDY2V3p8rMkrbdWvmb+pkF0l1fVWHdVqjSqk7Xsr2b63UhuZ27m2DgrTgbec568D45zn24G6fT+znHU/YoyZZYwZY4wZk5FxiP10V79r5+8ddQkA+SUd8E4B7Mjb426Drd/Cuo9CU4a85YDYmepU06QPAAwU2h5ItTOxNacKKXdPBc9/u5lHPl9P8QHzFteavWwH8VEuJg86eFbf2iqkQ8mFtHCTrU3WRub2ra0Hr+0AjgO+AE4AfnDWzwauF5FXgPFAsTEmOJm46lr6HKT0ht5HA1BQWk1spIv4cJ/usj6jpsO3j8Cnv7eTyUS08THuXA5ph9kApZqmtltqwTrofgQ9kmPolhTD4i17Dpr3xuc3LN26h0/X5PPZ6nzW7tp3V/HvrzbyyxP6c/GE3kS57fVgtdfH+yt2csqQbg1O81pbhXQoA9kWbi4iMdodnl29O5GgBQUReRmYBKSLSC5wF3AV8KCIuIEqbE8jgPeB04H1QAVwebDKFbBnsx0xevzvAonJCsqqyUiM7pj1na5IOOH/2clhlr9qc/K0pbzldgpM1XRp/Wz6C6ddoXYQ29frd3P37JV4fH68PoPH78fjM1TW+Fi8pYg9FR7cEcLYPqn87ozDOWFQJlUeP395fzV/eHcVz327mdtPG8QpQ7rx5doCSqq8TB3ZeBPe6cO6B3Ih1R3L0FQLNxcxqneX4CaMVIcsaEHBGHPhQd4aXc+2BrguWGWp19IXAYER+4pZUFpNZkerOqpr8NnQ/UGb7nnIOW03vWdFkZ3OshlzUyjseJPUnP0am08d2o0v1ubz5pJcolwRuF1CpCvCeQjHD8pk8qCuHDMg/UdjbZ6/YhxfrCvgL++t5uoXljC2TxcEIS0+iolOPq+G1K1Cam5Q2FNew7pdZUwd0Tb9R1TLdc7cR36fzYjqJL+rVVBazWGZCSEsWJBFRMCJd8PzZ8OiJ+0E9G2hPc3JHG4yBsHudYGXZx3Ro8XJ6USE4wdmcsxh6by2KJd/fLyO3WXVTD+yN25X482Lh1KFtMhpBxnTWxuZ27vOmeZi4+c2+d3Ii/dbXVt91KH1Ox76Hg9z7oOq4v3fKy+0VWpLX7RX960lkN5Cex41W/oAKNxgZwRsJW5XBD8bn80Xt07izz8Zyg2T+zd539peSLe8/h3r85veC2rh5iKiXBFBy9ejWk/nvFNI7GEnqhl4emBVtdfH3goPGQkdPCiAvVuYdRy8ezMk9YBdq2z20rK8fdtkDrGT1Me1Qk+RvOWQ1NOOmVDNkzHIzlRXtLHV06EnRLu5aHzvZu0zrk8q107qx9PfbOaD7/M4ZXA3rjv+MIZlNTyHx8LNRQzPSg6kn1ftV+cMCl0H2ykt69hdZq/EOvydAkCPETDsPFjx+r5cRP1OsP8uXYdAdZmdTez5s+3EMbEph/Z9O5froLWWyqiTA6kdzJERESHcduogrjymL09/s4ln5m7mw5V5HDsgg+uPP4xxOanUeP3sKqkir6SKncVV5BVX8v32Yq48pm+oi6+aoHMGhXp0yNHMDZnyMBx3u5Njp56fgfsFeOVn8MI5cMl/IaaF3QhrKmw/+yFnH0JhO7F0p2rHSXfRXqTGR/Hrkwcy49i+PD9vC09+tYnzH/+WlLhI9tYzHiIpxs3Jgw8+DkK1HxoUHJ0uKETGQPphB39/wMlw/rN2XoYXz4WL34LoFjTC71oJxq93Ci0VFQ8p2bC7fQWFWokxkVw76TAuPyqH1xdvY01eKV0TY+ieHEO35H1LnYs5fGhQcHTIZHiHatAZ8NMn4I2fw8vT7LSfUfUkMqso
snMLu2MgPh3i0iAmxfZ20p5Hhy59YLu7UzhQbJSLS4/sE+piqFagQcFRGxTSEqJCXJJ2ZshPwOeFt66y1UkXvmJnBNs6b9+jvqtYiYDYVPB7bYBIDqMZ7NqbjIG2V5jf1/Yj0VWno0HBUVBWRWp8FJFN6K/d6Qw/D3w18Pa1cG+OnVITIDrZTj05/HzIGmNPWhVFUFEIFbudZSH0mtB6M5J1RhmDwFdtR+Gn9Qt1aVQHp0HBkV9S3Tm6o7bUyIvsPMLrP4Weo+wE9RmDAilCVBAFeiCt1aCggk6DgqNTDFw7VEPPsQ/VttIH2OXutdgUYUoFj17mOQpKNSiodio2BRK6tfvGZtUxaFDATlre4ZPhqfCW0f57ILXIkufgP9fY9OCqXdCgAJRWe6n2+vVOQbVfGYNsUAjVzHnBYAx8cQ989xL833h4+zrYu63x/VRQaVCgEw5cU+EnYwB4yu1c1x3FjqVQsh1O+iOMvxqWvwb/GgUf3gHlu0Nduk5LgwJ1goL2PlLtVe0sbO10ZHOLrHnPTiI08mI49a9wwxLbvXn+Y/DgEXbej+rWnY9aNS5oQUFEnhKRfBH5/oD1N4jIGhFZKSL31ll/h4isF5G1InJKsMpVn3y9U1DtXXqdbqkdxZp3ofdR+zLxpvSCqY/AtfPtXCdf/g0eGgWLnrYDKFWbCOadwjPAqXVXiMjxwFTgCGPMEOA+Z/1gYBowxNnn/0SkzYZuavWRavfi0+0I8Y4SFHavt5lfDz/rx+9lDIDzn4MrP7PjMt69CR4/BtZ/0ubF7IyCFhSMMXOAA2dquQa4xxhT7WyT76yfCrxijKk2xmzCztXcZhP6FpRWE+WKIDlWk3apdkpkX2NzR7DmXbsc2MC4i6zRcPkHNkB4KuGFn8Lz59j5P1TQtHWbwgDgGBGZLyJfishYZ31PoG63g1xn3Y+IyAwRWSQiiwoKClqlULVjFERTMaj2LGOAvbruCD2Q1rwH3UfYKqOGiMDgqXDdfDjlL7B9ETw2Ed6/1aZlV62urYOCG0gFJgC3Aq9JM8/ExphZxpgxxpgxGRkZrVKogrJq0rXqSLV3GYOgai+Ut87FUMiU5kHuAhh0ZtP3cUfbOcV/uQzGXgkLZsGsSfumelWtpq2DQi7wlrEWAH4gHdgO1L1kyHLWtYmCUs17pMJAbbqLgjX1v++pAr+/7crTUmvft8vDmxEUasWlwul/h0v+Y+cYf2IyfPtIeBx3mGjroPBf4HgAERkARAG7gdnANBGJFpEcoD+woK0KVVBapY3Mqv2r7ZZ6YLtCVQl8cjfckw33D4S3r4c177ff6pXV70Jq333H0xL9ToBr5sJhJ8JHv4EXfwqlu1qvjJ1Y0BLiicjLwCQgXURygbuAp4CnnG6qNcB0Y4wBVorIa8AqwAtcZ4zxBatsdXl9fgrLazQoqPYvqQdEJe4LCj4vLHkGPv+rTVU+9Fw7y92qt2Hp83bSo76TYMCpMOAUu3+oVRXbuSEmXHPo6dTj02DaS7DoKRsYHj3KdmkdeGrj+6qDClpQMMZceJC3Lj7I9n8G/hys8hxMUXkNxmh3VBUGRPY1Nq/7CP73/+xgtt4T4eTXbUpzAG8NbJ0Laz+wVTXrPrTr0w6DnGOhzzH2kVCnTc7vt3Npb5sPW+fbpafSdhkddi70HN06c2L88DH4Pc1rT2iICIy9wv4bvHklvHwBjPk5nPwnO5WparZOnzo7X0czq3CSMQiWvQibv4LUfnDBi3ba1LonbHeUvUPoOwlOvQfyV9s+/pu/guWv2ytrgIzDIXsClOywDb+Ve+z62FToNd5+5qInYf6jkNIbhv7UProOaXmAWPMuxGdC1tjGt22OzEFw1afw6R9sG8PGL+GcWXbyJ9UsnT4oFJQ5czMnaVBQYaDfCfaEd9QN9orY3cj0sSLQdbB9TPylrXLaucxW4Wz+Cla8YauVBp1pA0Sv8faOovakX1Vs2wC+fxO+eRC+/ocNTJPvgkHNnNvBU2XvFIadG5zJmdzRcMqfbVXZf66BJ0+GY2+FY28Bl45BaioNCnqnoMLJsHPto6Vcbnv1nDUGjvlV49vHJNtZ90ZeBGUFsPptWPgkvHIhjL7Mjh1oajXNpjlQUwaD6hnF3JpyjoVrvoEPZsKX98AP/4Nz/g3phwX3ezuIRsO1iMSIyLki8qCIvC4iz4nIbSIypC0KGGya4kKpJkrIsGMEZnwJE2+Exc/CY8fA9sVN23/NO7ahPOeY4JYT7MRE5zwO5z0DezbBY0fDNw+13x5Z7UiDQUFEfg98AxwJzAceB17D9hC6R0Q+FpHhQS9lEBWUVpMY4yYmss1SLSkV3txRcNIfYPo74K2GJ06CL//ecNI6v882fA842VbztJUhP4FrvrWB6OP/Z7Ovzv0X1JS3XRnCTGPVRwuMMXcd5L1/iEgmkN3KZWpTOg2nUi2Uc4ytpnnv1/D5n2D9x3D2ozaJ3YG2LbAjsQed0fblTOoOF70Om7+xmVf/9zv4+p+2XWbslRCd0PZlascavFMwxrx34DoRiRCRJOf9fGPMomAVri3oaGalDkFsCpz7JJzzhO3l9K9RNv3El/dC3op9eZrWvAuuKDjspNCVtc9EmD4bfv4RdBsGn9wFDw6Hrx+wdzwKaOKIZhF5SUSSRCQe+B5YJSK3BrdobaOgrJrMpJhQF0Op8Db8PLh2nu2VFBFpJ8h57Gj45zB47xZY+V/IOQ5ikkJdUtvL6tL/whUfQ4+RdjT4v0+wQUw1Oc3FYGNMCXA28AGQA1wSrEK1Jb1TUKqVJPe0PZqu/BhuWQdTHoZuw2HpC1CSa7Odtie9xsHFb8LPXrNVW7OOh6/u7/QT+jS1S2qkiERig8LDxhiPiIR9/t6KGi9l1V5tU1CqtSVkwqhL7MNTaa/Ce44OdanqN+AUe5fz3q/s4Le1H8DZj3XaLqxNvVN4HNgMxANzRKQ3UBKsQrUV7Y6qVBuIjLVX5RHtuIdfXKrtvvrTJ2H3D7bqa/6sTpl9tbEuqUeKiBhjHjLG9DTGnO4ksNuKk+00nGlQUErtZ9i59q6hz0T44FZ4firs2RLqUrWpxu4ULgUWi8grInKZiHQDcOZDCPuKNx3NrJT6kaTucNEbcOY/YfsSm3110VMdY8a7JmisS+o1xphRwN1AF+AZEflWRP4iIseKSDu+H2xcbd4jvVNQSu1HBMZcDtd+a1OCvHszPH827N0a6pIFXZPaFIwxa4wxDxhjTgVOAL4GzsOOcg5bBaXVuCKE1PhGkooppTqnlGy45L9w5gOQuwj+7yhY9HSHvmtoVqpCEYkDhgALjTE3GGMOmpdWRJ4SkXxnQp0D3/u1iBgRSXdei4g8JCLrRWS5iIxq7oG0REFpNWnxUbgiWiFPvFKqYxKxGWmvmWvnrHj3JnvXsPO7UJcsKBpraJ4iIptFZImInA6sBB4GVojI9EY++xngR1MgiUgv4GRsY3Wt07BTcPYHZgCPNvkIDoGmuFBKNVmX3nDp23DGP2D7Unj8WHhpWtMTAoaJxu4U/og9gf8CmwhvsjFmAjAcuKWhHY0
xc4Ciet56ALgNqHv/NRV4zmnAngekiEj3ph1Cy+VrUFBKNUftTG83r4Djfwdbv7WjoV841+Z36gAaCwp+Y8w6Y8xCYJMxZiPYnEfYTKnNIiJTge3GmAPvu3oC2+q8znXW1fcZM0RkkYgsKigoaG4R9qOjmZVSLRKTDMfdCjetsKk9ti+GJ0+C586GLXNDXbpD0lhQiBCRLiKSBvid56kiktqEfffjtEf8BrizhWUFwBgzyxgzxhgzJiMjo/EdDsLvN+wu0zsFpdQhiEmyqT1uWgEn/RF2fQ9PnwZPnmLn0Q7DBunG0lwkA4uB2pbYJXXea+7R9sPmTPpO7FR/WcASERkHbAd61dk2y1kXNHsrPXj9hkwNCh2Sx+MhNzeXqqqqUBclJGJiYsjKyiIyUqehbBPRCXa607FX2lxPcx+Cl86HrkPh6Jth8Nl21rsw0GApjTF9WuuLjDErgMza1yKyGRhjjNktIrOB60XkFWA8UGyM2dla312ffaOZNUNqR5Sbm0tiYiJ9+vRBWjrJfJgyxlBYWEhubi45OTmhLk7nEhUH42fYMQ4rXrdpud+8Aj77Exx9E4y4qN3PF91Y7yOXiCTUeT3BGbR2rIgkNrLvy8C3wEARyRWRKxrY/H1gI7Ae+DdwbZOPoIU0xUXHVlVVRVpaWqcLCAAiQlpaWqe9S2oXXJEw4mdw7Xw4/3k778Q7N9rR0Ws/aNfVSo3dz/wNyAfudV6/jJ1PIQZblTTzYDsaYy5s6IPr3oU4+ZSua7y4rSe/1P7BaFDouDpjQKjVmY+9XYmIgMFT4PCzYO378PGd8PI06HMMnPxHO59DO9NYY/Fk4B91Xu81xpyF7aY6MWilagN6p6CUajMidirSa+fB6fdB/io7Q91bv4Di3FCXbj+N9j46IPHdTAhc2Yf1xKYFpdXERrqIjwrr9E2qnSosLGTEiBGMGDGCbt260bNnz8DrmpoaAGbPns0999wDwH//+19WrVoVyiKrtuCKhHFXwS+X2gbolf+Bf42GT/8I1WWhLh3QeFCIqtt2YIz5H4CIJGOrkMJWgdMdVW+zVTCkpaWxbNkyli1bxtVXX83NN98ceB0VFYXX62XKlCncfvvtgAaFTicmGU68G25YZKuWvroPHh4D370S8jkcGgsK/wZeFZHs2hXOBDsvA08Es2DBVlBard1RVZu67LLLuPrqqxk/fjy33XYbzzzzDNdffz1z585l9uzZ3HrrrYwYMYINGzawbNkyJkyYwPDhw/nJT37Cnj17AJg0aRIzZ85k3LhxDBgwgK+++irER6UOSUo2/PQJO190Ynf4zy/gyRNh28KQFamxLqn/EJEK4GsRiXdWlwH3GGPaJD9RsBSUVnNYZljXgKkm+v07K1m1o3UnChzcI4m7zhrS7P1yc3OZO3cuLpeLZ555BoCjjjqKKVOmcOaZZ3LuuecCMHz4cP71r39x3HHHceedd/L73/+ef/7znwB4vV4WLFjA+++/z+9//3s++eST1josFSq9xsGVn8LyV+GTu21gGHY+TP5/kNzLtkm0kUZHUxhjHgMeq61GMsaUBr1UbaCgrJoj+6WFuhiqkznvvPNwuRpuxyouLmbv3r0cd9xxAEyfPp3zzjsv8P4555wDwOjRo9m8eXPQyqraWEQEjLjQVid9/Q+Y+zCseA1cURDbBWJT7TLOWQ44FQ4/s9WL0WBQEJGLgZeMMf76goGI9AO6G2O+bvWSBVG118feCo/mPeokWnJFHyzx8fGNb9SI6Gj7u3W5XHi9YT8BojpQdAJMvhNGXQqrZkNFIVQWQUURVO6Foo32eUp22wcFIA1YKiKLsekuCrANzIcBxwG7gdtbvVRBtrvM9v7Q7qiqvUhMTKS01F53JScn06VLF7766iuOOeYYnn/++cBdg+pEuvSxqTPaWGNtCg+KyMPY2dYmYlNmVwKrgUuMMWE5N52OUVDtzbRp07jqqqt46KGHeOONN3j22We5+uqrqaiooG/fvjz99NOhLqLqJMS04+HWjRkzZoxZtGhRs/f7eNUurnpuEbOvn8jwrJTWL5gKudWrV3P44YeHuhghpf8G6mBEZPHBZs5sVvrrjqJrUjQXjsumZ0psqIuilFLtSnjkcm1lw7NS9A5BKaXq0SnvFJRSStWvSXcKIhIN/BToU3cfY8wfglMspZRSodDU6qO3gWJst9Tq4BVHKaVUKDU1KGQZY05tzgeLyFPAmUC+MWaos+7vwFlADbABuNwYs9d57w7gCsAH/NIY81Fzvk8ppdSha2qbwlwRGdbMz34GODCQfAwMNcYMB9YBdwCIyGBgGjDE2ef/RERzWquw1ZTU2c1x9913c9999wWhpErtr7E0FysA42x3uYhsxFYfCXZaheEH29cYM0dE+hyw7n91Xs4DznWeTwVeMcZUA5tEZD0wDjudp1JhpzZ1NtgTekJCArfcckuj+/l8vkZzIykVTI3dKZyJre45DZva4mTnde36Q/Fz4APneU9gW533cp11PyIiM0RkkYgsKigoOMQiKNV2Pv30U0aOHMmwYcP4+c9/TnW1bZ7r06cPM2fOZNSoUbz++ut8+OGHjBo1iiOOOILJkycH9l+1ahWTJk2ib9++PPTQQ6E6DNXBNZbmYguAiDxvjLmk7nsi8jxwSb07NkJEfgt4gRebu68xZhYwC+yI5pZ8v+pkPrgd8la07md2Gwan3dPkzauqqrjsssv49NNPGTBgAJdeeimPPvooN910E2DvLJYsWUJBQQGjRo1izpw55OTkUFRUFPiMNWvW8Pnnn1NaWsrAgQO55ppriIyMbN3jUp1eU9sU9ksz6dT3j27JF4rIZdg7jYvMvhwb24FedTbLctYp1SH4fD5ycnIYMGAAYNNhz5kzJ/D+BRdcAMC8efM49thjycnJASA1NTWwzRlnnEF0dDTp6elkZmaya9euNjwC1Vk01qZwB/AbIFZEamcpEWzvoVnN/TIRORW4DTjOGFNR563ZwEsi8g+gB9AfWNDcz1eqXs24og+VpqTUrk2ZDZo2WwVPg3cKxpi/GmMSgb8bY5KcR6IxJs0Yc0dD+4rIy9iG4oEikisiVwAPA4nAxyKyTEQec75nJfAasAr4ELjOGOM79MNTqn1wuVxs3ryZ9evXAxw0HfaECROYM2cOmzZtAtiv+kipttDYncIo5+nrdZ4HGGOWHGxfY8yF9ax+soHt/wz8uaHyKBWuYmJiePrppznvvPPwer2MHTuWq6+++kfbZWRkMGvWLM455xz8fj+ZmZl8/PHHISix6qwaTJ0tIp87T2OAMcB32Oqj4cAiY8yRQS9hA1qaOlt1fJo2Wv8N1MG1OHW2MeZ4Y8zxwE5glDFmjDFmNDASbQhWSqkOp6m9jwYaYwJ9+owx3wN6CaKUUh1MU3MfLReRJ4AXnNcXAcuDUySllFKh0tSgcDlwDXCj83oO8GhQSqSUUipkmhQUjDFVwAPOQymlVAfVWJfU14wx59dJjLefhhLiKaWUCj+NNTTXVhfVJsA78KGUOggR4de//nXg9X333cfdd9/d4D6zZ8/mnnva/w
hs1XE11iV1p/P0RCDKGLOl7iP4xVMqfEVHR/PWW2+xe/fuJu8zZcoUbr/99iCWSqmGNbVLajbwuIhsFJHXReQGERkRxHIpFfbcbjczZszggQd+3BT3zjvvMH78eEaOHMmJJ54YSG73zDPPcP3111NcXEzv3r3x+/0AlJeX06tXLzweDxs2bODUU09l9OjRHHPMMaxZs6ZNj0t1bE1taL4LQERigauAW4F/AjobiGr3/rbgb6wpat0T56DUQcwcN7PR7a677jqGDx/Obbfdtt/6o48+mnnz5iEiPPHEE9x7773cf//9gfeTk5MZMWIEX375Jccffzzvvvsup5xyCpGRkcyYMYPHHnuM/v37M3/+fK699lo+++yzVj0+1Xk1KSiIyO+AiUACsBS4BfgqiOVSqkNISkri0ksv5aGHHiI2NjawPjc3lwsuuICdO3dSU1MTSJVd1wUXXMCrr77K8ccfzyuvvMK1115LWVkZc+fO5bzzzgtsVztZj1KtoanjFM7BTorzHvAl8K0zdaZS7V5TruiD6aabbmLUqFFcfvnlgXU33HADv/rVr5gyZQpffPFFvQ3QU6ZM4Te/+Q1FRUUsXryYE044gfLyclJSUgJTfSrV2prUpmCMGYVtbF4AnASsEJGvg1kwpTqK1NRUzj//fJ58cl+S4OLiYnr2tDPOPvvss/Xul5CQwNixY7nxxhs588wzcblcJCUlkZOTw+uvvw6AMYbvvvsu+AehOo0mBQURGYpNbTEduACbDE8rMZVqol//+tf79UK6++67Oe+88xg9ejTp6ekH3e+CCy7ghRdeCMzMBvDiiy/y5JNPcsQRRzBkyBDefvvtoJZddS4Nps4ObCTyLrYN4StgoTHG04R9nsKOb8g3xgx11qUCrwJ9gM3A+caYPSIiwIPA6UAFcFlDczXU0tTZ6mA0bbT+G6iDa3Hq7FrGmDONMX8zxsxtSkBwPAOcesC624FPjTH9gU+d1wCnYafg7A/MQPMqKaVUSDSW5qLe9Ba1GkpzYYyZIyJ9Dlg9FZjkPH8W+AKY6ax/ztjblnkikiIi3esMnlNKKdUGGut9dKazvM5ZPu8sL2rh93Wtc6LPA7o6z3sC2+psl+us+1FQEJEZ2LsJsrOzW1gM1RkYY7A1k51PU6qFlapPY2kuatNZnGSMuc0Ys8J53A6cfChf7NwVNPuXa4yZ5cwANyYjI+NQiqA6sJiYGAoLCzvlydEYQ2FhITExMaEuigpDTR2nICIy0RjzjfPiKJqeIqOuXbXVQiLSHch31m8HetXZLgud7lMdgqysLHJzcykoKAh1UUIiJiaGrKysUBdDhaGmBoUrgKdEJBkQYA/w8xZ832xst9Z7nOXbddZfLyKvAOOBYm1PUIciMjKy3lHCSqmGNTX30WLgCCcoYIwpbmwfEXkZ26icLiK5wF3YYPCaiFwBbAHOdzZ/H9sddT22S+rlP/pApZRSQdfU3EfRwE+x4wvctY13xpg/HGwfY8yFB3lrcj3bGvY1ZiullAqRplYfvQ0UA4sBzXmklFIdVFODQpYx5sCBaEoppTqYpvYgmisiw4JaEqWUUiHX1DuFo4HLRGQTtvpIsE0BBx3RrJRSKvw0NSicFtRSKKWUahea2iV1C4CIZAI6TFIppTqops6nMEVEfgA2YWde2wx8EMRyKaWUCoGmNjT/EZgArDPG5GDHGswLWqmUUkqFRFODgscYUwhEiEiEMeZzoN4JGpRSSoWvpjY07xWRBGAO8KKI5APlwSuWUkqpUGjqncJUbE6im4EPgQ3AWcEqlFJKqdBoau+j2rsCv4i8BxSazpioXimlOrgG7xREZIKIfCEib4nISBH5HvgeOy+Cpr1QSqkOprE7hYeB3wDJwGfAacaYeSIyCHgZW5WklFKqg2isTcFtjPmfMeZ1IM8YMw/AGLMm+EVTSinV1hoLCv46zysPeK/FbQoicrOIrBSR70XkZRGJEZEcEZkvIutF5FURiWrp5yullGqZxoLCESJSIiKlwHDnee3rFmVNFZGewC+BMcaYoYALmAb8DXjAGHMYdrrPK1ry+UoppVquwaBgjHEZY5KMMYnGGLfzvPZ15CF8rxuIFRE3EAfsBE4A3nDefxY4+xA+XymlVAs0dZxCqzHGbAfuA7Zig0HtjG57jTFeZ7NcoGd9+4vIDBFZJCKLCgoK2qLISinVabR5UBCRLtjBcDlADyAeaHL3VmPMLGPMGGPMmIyMjCCVUimlOqc2DwrAicAmY0yBMcYDvAVMBFKc6iSALGB7CMqmlFKdWiiCwlZggojEiYhgM66uAj4HznW2mQ68HYKyKaVUpxaKNoX52AblJcAKpwyzgJnAr0RkPZAGPNnWZVNKqc6uqVlSW5Ux5i7grgNWbwTGhaA4SimlHKGoPlJKKdVOaVBQSikVoEFBKaVUgAYFpZRSARoUlFJKBWhQUEopFaBBQSmlVIAGBaWUUgEaFJRSSgVoUFBKKRWgQUEppVSABgWllFIBGhSUUkoFaFBQSikVoEFBKaVUQEiCgoikiMgbIrJGRFaLyJEikioiH4vID86ySyjKppRSnVmo7hQeBD40xgwCjgBWA7cDnxpj+gOfOq+VUkq1oTYPCiKSDByLM92mMabGGLMXmAo862z2LHB2W5dNKaU6u1DcKeQABcDTIrJURJ4QkXigqzFmp7NNHtA1BGVTSqlOLRRBwQ2MAh41xowEyjmgqsgYYwBT384iMkNEFonIooKCgqAXVimlOpNQBIVcINcYM995/QY2SOwSke4AzjK/vp2NMbOMMWOMMWMyMjLapMBKKdVZtHlQMMbkAdtEZKCzajKwCpgNTHfWTQfebuuyKaVUZ+cO0ffeALwoIlHARuBybIB6TUSuALYA54eobEop1WmFJCgYY5YBY+p5a3IbF0UppVQdOqJZKaVUgAYFpZRSARoUlFJKBWhQUEopFaBBQSmlVECouqQqpTqRvPI8iquLATAHJCvwGz9+48dnfHbpt0s/fqT2P7HLCIlARPbbr3ZfYwx+4w98bu0+AILQLb4bOck5gf0bYoxhU/Em8ivz8fl9eP1evH4vHuPB6/cSQQQ9E3uSnZhNSnTKQT/Tb/zkleexpWQLuyp2YYwJHL9N3GC5I9xEu6KJdEUS7Yom2hVNlCsKt7ip8ddQ46uh2ldNjW/f8/5d+jM0fWgz/i80jQYF1W7V/rH7/D58xv5h1p44ak8ABmOf48fr9wb+aOr+IfmNn/TYdLrFdyM1JpUIqf8Gubi6mO1l29letp38inw8Pg9e4w18r8/vw2u8+/0x7/ccE9i29iTiNd5A+Wu3N+wre1JUElP6TWFct3FNOlmFo/c2vscdX93xo2AQCqkxqYztNpZx3cYxptsYcpJskPAbPxv2bmDRrkUszFvI4l2LKaoqatJnJkYlkp2YTXZSNtmJ2VT7qtlSsoWtJVvZVrqNGn9NUI7l8qGXByUoSN0fdbgZM2aMWbRoUaiL0aEYY/AZHxXeCio8FZTVlFHuLae8ppwyTxnVvmq8fu++E7Zzkvb6vVT7qqnyVlHpraTKV
0WV1z5q/DWBE2PdE73f+Knx77sCqvZVB557/J79rvpai1vcZMZl0jW+K13julLjq2F72XZ2lO2g1FPa6L4REoErwnXwbSLcREZE4hIX7gh3YBkhEfuueCVw/Utehb2C7pvclwsHXchZ/c4iPjK+tQ87ZObtnMc1n1zDiIwRXHz4xfu/KbULwSUu+28rLiIi7LL2Kr9u4DfGBAJxREQEEUQQIfsetf+2tfvUMtgr/4V5C1mQt4D8CptFJz02nf4p/VlTtIY91XsA6B7fnTFdxzC221iyErOIjIjEHeHe7/+nz+9je9l2e/Iv3crWkq1sLd3KzvKduMVNr8ReZCdl0zupt10m9qZ7QnfcYq/DD7wAqL2gqf0b8Pg9gb+1KFcUURFRgbuHKJd9nhSVREJUQov+v4jIYmNMfWPFNCiEM7/xs7tyN9tKt1FQURD4QdU9uVb5qiivKafUU0ppTSllNWWU1tjnld7KfSd14w3cih+KCIkg1h1LjCuGGHcMse7YwB9V4I/eObG6xBX4gdf+4OveNkdGRAa2q/2DrP2M2pNr7UkB7Ak58BkRUYE/IEHYXbmbXRW7yCvPY1fFrsDzyIhIeib0pGdCT7ISswLPu8V3I9oVHfj+2u9sbdW+aj7c9CEvrXmJVYWriI+MZ2q/qUwbNI2c5JxW/762tLZoLdM/nE73+O48e9qzJEUlhbpIgL3w2Va6jQV5C1iYt5D1e9czKHUQY7uNZWy3sfRM6Nniz/b4PLgiXAe9G20vNCiECWMMZZ4y9lbttVfnHvuo8FRQ7rFX6nnleeSW5rKtdBu5ZblU+6ob/MyoiCgSohJIjEokITKBhKgEe4URmUCsO3b/k22dE2CsO5aEqATiIuNIiEwgPjKeOHccse7Y/barPUm7IlzEuGJwR7g7bDVIMBljWLF7BS+veZkPN3+I1+/ljL5nMHPsTLrEhN8khDvLdnLx+xcjIrxw+gt0i+8W6iKpOjQohJDf+CmpLmF35W4KqwrtsrKQwqrCwLKoqojCSrv0+D0Nfl6sO5Zeib3ISsiyy0S7zIzLJMYdQ4wrhihXFDHuGCIjItv9FYv6sd2Vu3lp9Us8/f3TJEUnccf4Ozil9ylhE2yLq4u59INLKago4NnTnqV/l/6hLpI6gAaFIKnyVlFcXcze6r0UVBawo2wHeeV57Czfyc7ynYGqCq/f+6N9IyMiSYtNIzUmldSYVNJi0kiNtcuU6BQSIu1VenxkfOARFxlHYmRi2Jwc1KFZW7SWO+feyarCVUzOnszvJvyO9Nj0UBerQdW+amb8bwYrdq/g8ZMeZ2y3saEukqqHBoVW8MOeH7h/0f0UVBZQXF1McXUxVb6qH23nEhdd47rSLb4b3RO60y2uG+mx6aTHppMWm2YfMWkkRSXpyV01yuv38tyq53hk6SPEuGOYOW4mZ/U9q8ndKr1+L5W+Srx+Ly5xBdpp3OJusMG8JXx+H7fOuZWPt3zM34/9O6fmnNqqn69ajwaFQ+TxebjgvQvIr8hnZOZIUqJTSIlOITk62T6iksmMy6RbfDcyYjNa/Y9NqU3Fm7hr7l0szV/K4amHkxCVYLu/+n14/Pu6ztb4amzvL28V1b7qQFfY+ggSCBAHtinVbcyv29cfbM+Z2vOGqf3PCUCFVYXcOuZWLh1yafD/UVSLNRQUdJxCEzyx4gl+2PMDD5/wMMf1Oi7UxVGdUE5yDs+c+gwvr3mZjzZ/hM/vIyoiCrfbHTixuyPcxLhjiHZFB9qXYt2xRLuibTfKOmMtPH5PYFDWgV2L6w4kq3vyr2WM2e9Ope7gsqHpQzl/oE6FEs46ZVCo8lbx1g9vMW3QtEYbYtcWrWXW8lmc0fcMDQgqpCIkgosOv4iLDr8o1EVRHVjIuqaIiEtElorIu87rHBGZLyLrReRVZ1a2oPhg0wf8dcFf+cv8v9BQ9ZnX7+XOuXeSFJ3EzLEzg1UcpZRqN0LZX/FGYHWd138DHjDGHAbsAa4I1heffdjZXD70cl5d+2qDgeG5Vc+xqnAVd4y/Iyz7iiulVHOFJCiISBZwBvCE81qAE4A3nE2eBc4O4vdz86ibuWzIZbyy9hXuWXDPjwLDpuJNPLL0ESZnT+aU3qcEqyhKKdWuhKpN4Z/AbUCi8zoN2GuMqe3QnwvUO9ZcRGYAMwCys7NbXAAR4Vejf4Xf+Hlu1XNESAS3jb0tkBzrrrl3Ee2O5rfjf6tdR5VSnUabBwURORPIN8YsFpFJzd3fGDMLmAW2S+ohloVbxtyC3/h5YfULiAi3jrmVl9e8zNL8pfxp4p/IiMs4lK9QSqmwEoo7hYnAFBE5HYgBkoAHgRQRcTt3C1nA9rYojIhw29jbMBieX/U85Z5yPtj0ARN7TmRKvyltUQSllGo32rxNwRhzhzEmyxjTB5gGfGaMuQj4HDjX2Ww68HZblUlEmDl2JhcOupC3fngLQbhrwl1abaSU6nTa0ziFmcArIvInYCnwZFt+uYhwx7g76BHfg74pfeme0L0tv14ppdoFTXOhlFKdTENpLjSvslJKqQANCkoppQI0KCillArQoKCUUipAg4JSSqkADQpKKaUCNCgopZQK0KCglFIqIKwHr4lIAbAFSAd2h7g4rUWPpf3pKMcBHedYOspxQGiOpbcxpt5sn2EdFGqJyKKDjc4LN3os7U9HOQ7oOMfSUY4D2t+xaPWRUkqpAA0KSimlAjpKUJgV6gK0Ij2W9qejHAd0nGPpKMcB7exYOkSbglJKqdbRUe4UlFJKtQINCkoppQLCPiiIyKkislZE1ovI7aEuT31E5CkRyReR7+usSxWRj0XkB2fZxVkvIvKQczzLRWRUnX2mO9v/ICLTQ3AcvUTkcxFZJSIrReTGcDwWEYkRkQUi8p1zHL931ueIyHynvK+KSJSzPtp5vd55v0+dz7rDWb9WRE5py+OoS0RcIrJURN51XoflsYjIZhFZISLLRGSRsy6sfl/O96eIyBsiskZEVovIkWFzHMaYsH0ALmAD0BeIAr4DBoe6XPWU81hgFPB9nXX3Arc7z28H/uY8Px34ABBgAjDfWZ8KbHSWXZznXdr4OLoDo5znicA6YHC4HYtTngTneSQw3ynfa8A0Z/1jwDXO82uBx5zn04BXneeDnd9cNJDj/BZdIfqN/Qp4CXjXeR2WxwJsBtIPWBdWvy+nDM8CVzrPo4CUcDmONv/xtvI//JHAR3Ve3wHcEepyHaSsfdg/KKwFujvPuwNrneePAxceuB1wIfB4nfX7bReiY3obOCmcjwWIA5YA47GjSt0H/raAj4AjneduZzs58PdWd7s2PoYs4FPgBOBdp2zheiyb+XFQCKvfF5AMbMLpyBNuxxHu1Uc9gW11Xuc668JBV2PMTud5HtDVeX6wY2pXx+pUO4zEXmWH3bE41S3LgHzgY+yV8V5jjLeeMgXK67xfDKTRDo7D8U/gNsDvvE4jfI/FAP8TkcUiMsNZF26/rxygAHjaqdJ7QkTiCZPjCPeg0CEYexkQNn2DRSQBeBO4yRhTUve9cDkWY4zPGDMCe5U9DhgU2hK1jIic
CeQbYxaHuiyt5GhjzCjgNOA6ETm27pth8vtyY6uLHzXGjATKsdVFAe35OMI9KGwHetV5neWsCwe7RKQ7gLPMd9Yf7JjaxbGKSCQ2ILxojHnLWR2WxwJgjNkLfI6tYkkREXc9ZQqU13k/GSikfRzHRGCKiGwGXsFWIT1IeB4LxpjtzjIf+A82YIfb7ysXyDXGzHdev4ENEmFxHOEeFBYC/Z2eFlHYhrPZIS5TU80GansTTMfWz9euv9TpkTABKHZuOT8CThaRLk6vhZOddW1GRAR4ElhtjPlHnbfC6lhEJENEUpznsdh2kdXY4HDuQY6j9vjOBT5zrvRmA9OcHj05QH9gQZschMMYc4cxJssY0wf7+//MGHMRYXgsIhIvIom1z7G/i+8Js9+XMSYP2CYiA51Vk4FVYXMcbdX4EsRGndOxvWA2AL8NdXkOUsaXgZ2AB3sVcQW2HvdT4AfgEyDV2VaAR5zjWQGMqfM5PwfWO4/LQ3AcR2NveZcDy5zH6eF2LMBwYKlzHN8Ddzrr+2JPhOuB14FoZ32M83q9837fOp/1W+f41gKnhfh3Nol9vY/C7licMn/nPFbW/j2H2+/L+f4RwCLnN/ZfbO+hsDgOTXOhlFIqINyrj5RSSrUiDQpKKaUCNCgopZQK0KCglFIqQIOCUkqpAA0KSrUiETEicn+d17eIyN0hLJJSzaJBQanWVQ2cIyLpoS6IUi2hQUGp1uXFzrl7c6gLolRLaFBQqvU9AlwkIsmhLohSzaVBQalWZmzm2OeAX4a6LEo1lwYFpYLjn9gcV/EhLodSzaJBQakgMMYUYafEvCLUZVGqOTQoKBU89wPaC0mFFc2SqpRSKkDvFJRSSgVoUFBKKRWgQUEppVSABgWllFIBGhSUUkoFaFBQSikVoEFBKaVUwP8Hw/2QpdfE5gYAAAAASUVORK5CYII=\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "import matplotlib.pyplot as plt\n", - "\n", - "M = 4096\n", - "Ns = [128*i for i in range(2, 50)]\n", - "tri_ms = []\n", - "ref_ms = []\n", - "def_ms = []\n", - "for N in Ns:\n", - " x = torch.randn(M, N, device='cuda', dtype=torch.float32)\n", - " gbps = lambda ms: x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)\n", - " tri_ms += [gbps(triton.testing.do_bench(lambda: softmax(x)))]\n", - " ref_ms += [gbps(triton.testing.do_bench(lambda: torch.softmax(x, axis=1)))]\n", - " def_ms += [gbps(triton.testing.do_bench(lambda: naive_softmax(x)))]\n", - "plt.xlabel('N')\n", - "plt.ylabel('Bandwidth (GB/s)')\n", - "plt.plot(Ns, tri_ms, label = 'Triton')\n", - "plt.plot(Ns, ref_ms, label = 'Torch')\n", - "plt.plot(Ns, def_ms, label = 'Naive')\n", - "plt.legend()\n", - "plt.show()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.9" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/python/tutorials/02-fused-softmax.py b/python/tutorials/02-fused-softmax.py new file mode 100644 index 000000000..dfbb273fe --- /dev/null +++ b/python/tutorials/02-fused-softmax.py @@ -0,0 +1,181 @@ +""" +Fused Softmax +================= +""" + +# %% +# Custom GPU kernels for elementwise additions are educationally valuable but won't get you very far in practice. +# Let us consider instead the case of a simple (numerically stabilized) softmax operation: + +import torch + + +# Compute the row-wise softmax of x \in R^{M \times N} +def naive_softmax(x): + # read MN elements ; write M elements + x_max = torch.max(x, axis=1)[0] + # read 2MN elements ; write MN elements + z = x - x_max[:, None] + # read MN elements ; write MN elements + numerator = torch.exp(x) + # read MN elements ; write M elements + denominator = torch.sum(numerator, axis=1) + # read 2MN elements ; write MN elements + ret = numerator / denominator[:, None] + # in total: read 7MN elements ; wrote 3MN + 2M elements + return ret + + +# %% +# When implemented naively in pytorch, computing :math:`y` requires reading :math:`7MN` elements from DRAM and writing back :math:`3MN + 2M` elements. +# Instead, we want to write a custom "fused" pytorch operators that only reads X once and does all the necessary computations on-chip. This would require reading and writing back only :math:`MN` bytes, so we could expect a theoretical speed-up of 5x. In practice, though, we expect less because our kernel will spend some time computing exponentials and moving data around in shared memory. + +# %% +# Writing the Compute Kernel +# ---------------------------- +# Our softmax kernel works as follows: each program loads a row of X and writes back a normalized row of Y. Note that one important limitation of Triton is that each block must have a power-of-two number of elements, which means that we need to guard the memory operations properly if we want to handle any possible input shapes: +# +# .. code-block:: C +# +# __global__ void softmax(float* Y, float* X, int stride_xm, int stride_ym, int M, int N){ +# // row index +# int m = get_program_id(0); +# // column indices +# int n [BLOCK] = 0 ... 
BLOCK; +# // the memory address of all the elements +# // that we want to load can be computed as follows +# float* px [BLOCK] = X + m*stride_xm + n; +# // because BLOCK has to be a power of two +# // (per Triton-C specs), it is important +# // to guard each memory operation with predicates +# // or we will read out of bounds +# bool check[BLOCK] = n < N; +# float x [BLOCK] = check ? *px : -F32_INFINITY; +# // syntax for reduction in Triton is: +# // x[..., OPERATOR, ...] +# // ^ +# // index +# // The operators currently supported are {min, max, +} +# float z [BLOCK] = x - x[max]; +# // The exponential in Triton is fast but approximate +# // (i.e., like __expf in CUDA) +# float num [BLOCK] = exp(z); +# float denom = num[+]; +# // The result of the reduction is now stored in y +# float y [BLOCK] = num / denom; +# // We write it back +# float* py [BLOCK] = Y + m*stride_ym + n; +# *?(check)py = y; +# } + +# %% +# Writing the Compute Kernel +# ---------------------------- + +import torch +import triton + +# %% +# source-code for Triton compute kernel +_src = """ +__global__ void softmax(float* Y, float* X, int stride_ym, int stride_xm, int M, int N){ + int m = get_program_id(0); + int n [BLOCK] = 0 ... BLOCK; + float* px [BLOCK] = X + m*stride_xm + n; + bool check[BLOCK] = n < N; + float x [BLOCK] = check ? *px : -F32_INFINITY; + float z [BLOCK] = x - x[max]; + float num [BLOCK] = exp(z); + float denom = num[+]; + float y [BLOCK] = num / denom; + float* py [BLOCK] = Y + m*stride_ym + n; + *?(check)py = y; +} +""" + + +# %% +# Writing the Torch bindings +# ---------------------------- +# We need to make sure that BLOCK is the smallest power of two +# greater than the number of rows N of the input matrix. +# Different values of BLOCK will result in different kernels +def next_power_of_2(n): + n -= 1 + n |= n >> 1 + n |= n >> 2 + n |= n >> 4 + n |= n >> 8 + n |= n >> 16 + n += 1 + return n + + +_kernels = dict() + + +def make_kernel(N, device): + BLOCK = next_power_of_2(N) + key = (BLOCK, device) + if key not in _kernels: + defines = {'BLOCK': BLOCK} + _kernels[key] = triton.kernel(_src, device=device, defines=defines) + return _kernels[key] + + +class _softmax(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + # constraints of the op + assert x.dtype == torch.float32 + y = torch.empty_like(x) + # *create launch grid*: + # here we just launch a grid of M programs + M, N = y.shape + grid = lambda opt: (M, ) + # *launch kernel*: + kernel = make_kernel(N, y.device) + kernel(y.data_ptr(), x.data_ptr(), y.stride(0), x.stride(0), M, N, grid=grid) + return y + + +softmax = _softmax.apply + +# %% +# Unit Test +# ---------- + +x = torch.randn(1823, 781, device='cuda') +y_tri = softmax(x) +y_ref = torch.softmax(x, axis=1) +print(y_tri) +print(y_ref) +print(torch.allclose(y_tri, y_ref)) + +# %% +# Seems to work! 
+ +# %% +# Benchmark +# ---------- + +import matplotlib.pyplot as plt + +M = 4096 +Ns = [128 * i for i in range(2, 50)] +tri_ms = [] +ref_ms = [] +def_ms = [] +for N in Ns: + x = torch.randn(M, N, device='cuda', dtype=torch.float32) + gbps = lambda ms: x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3) + tri_ms += [gbps(triton.testing.do_bench(lambda: softmax(x)))] + ref_ms += [gbps(triton.testing.do_bench(lambda: torch.softmax(x, axis=1)))] + def_ms += [gbps(triton.testing.do_bench(lambda: naive_softmax(x)))] +plt.xlabel('N') +plt.ylabel('Bandwidth (GB/s)') +plt.plot(Ns, tri_ms, label='Triton') +plt.plot(Ns, ref_ms, label='Torch') +plt.plot(Ns, def_ms, label='Naive') +plt.legend() +plt.show() \ No newline at end of file diff --git a/python/tutorials/README.rst b/python/tutorials/README.rst new file mode 100644 index 000000000..f47d9864f --- /dev/null +++ b/python/tutorials/README.rst @@ -0,0 +1,4 @@ +Triton Tutorials +================== + +Below is a gallery of tutorials to help you get started with Triton. \ No newline at end of file