[DOCS] Switched tutorials to Python and use Sphinx Gallery
docs/conf.py (13 changed lines)
@@ -30,9 +30,18 @@
 # Add any Sphinx extension module names here, as strings. They can be
 # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
 # ones.
-extensions = ['nbsphinx', 'sphinx.ext.autosectionlabel']
+extensions = ['sphinx.ext.autosectionlabel']
 autosectionlabel_prefix_document = True
 
+# Sphinx gallery
+extensions += ['sphinx_gallery.gen_gallery']
+sphinx_gallery_conf = {
+    'examples_dirs': '../python/tutorials/',
+    'gallery_dirs': 'tutorials',
+    'filename_pattern': '',
+    'ignore_pattern': r'__init__\.py',
+}
+
 # Add any paths that contain templates here, relative to this directory.
 templates_path = ['_templates']
 
@@ -69,7 +78,7 @@ language = None
 # List of patterns, relative to source directory, that match files and
 # directories to ignore when looking for source files.
 # This patterns also effect to html_static_path and html_extra_path
-exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store', '**.ipynb_checkpoints']
+exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
 
 # The name of the Pygments (syntax highlighting) style to use.
 pygments_style = 'sphinx'
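With this configuration, sphinx_gallery.gen_gallery builds the `tutorials` pages by running every Python script it finds under `examples_dirs` (here `python/tutorials/`; the empty `filename_pattern` matches every script and `ignore_pattern` skips `__init__.py`). As a rough sketch of the layout such a script is expected to follow (illustrative only; the new tutorial files added below use exactly this structure):

    """
    My Tutorial
    ===========
    The module docstring must begin with a reST title; it becomes the first text cell.
    """

    # %%
    # Comment blocks introduced by `# %%` are rendered as text cells in the gallery.

    print('plain statements become code cells and are executed at doc-build time')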
docs/index.rst
@@ -15,7 +15,6 @@ Welcome to Triton's documentation!
 
 .. toctree::
    :maxdepth: 1
-   :caption: Tutorials
+   :caption: Installation Instructions
 
-   Vector Addition <tutorials/01-vector-add.ipynb>
-   Fused Softmax <tutorials/02-fused-softmax.ipynb>
+   tutorials/index
docs/tutorials/01-vector-add.ipynb (symlink deleted)
@@ -1 +0,0 @@
-../../python/tutorials/01-vector-add.ipynb

docs/tutorials/01-vector-add.ipynb (new file, 158 lines)
@@ -0,0 +1,158 @@
{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Vector Addition\nIn this tutorial, we will see how to construct a simple, high-performance vector addition using Triton. You will learn:\n* The basic syntax of the Triton programming language\n* The best practices for creating PyTorch custom operators using the `triton.kernel` Python API\n* The best practices for validating and benchmarking custom ops against native reference implementations\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Writing the Compute Kernel\n\nEach compute kernel is declared using the :code:`__global__` attribute, and executed many times in parallel\non different chunks of data (See the `Single Program, Multiple Data <https://en.wikipedia.org/wiki/SPMD>`_\nprogramming model for more details).\n\n  .. code-block:: C\n\n      __global__ void add(float* z, float* x, float* y, int N){\n          // The `get_program_id(i)` returns the i-th coordinate\n          // of the program in the overarching SPMD context\n          // (a.k.a launch grid). This is what allows us to process\n          // different chunks of data in parallel.\n          // For those familiar with CUDA, `get_program_id({0,1,2})`\n          // is similar to blockIdx.{x,y,z}\n          int pid = get_program_id(0);\n          // In Triton, arrays are first-class citizens. In other words,\n          // they are primitive data-types and are -- contrary to C and\n          // CUDA -- not implemented as pointers to contiguous chunks of\n          // memory.\n          // In the few lines below, we create an array of `BLOCK` pointers\n          // whose memory values are, e.g.:\n          // [z + pid*BLOCK + 0, z + pid*BLOCK + 1, ..., z + pid*BLOCK + BLOCK - 1]\n          // Note: here BLOCK is expected to be a pre-processor macro defined at compile-time\n          int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;\n          float* pz [BLOCK] = z + offset;\n          float* px [BLOCK] = x + offset;\n          float* py [BLOCK] = y + offset;\n          // Simple element-wise control-flow for load/store operations can\n          // be achieved using the ternary operator `cond ? val_true : val_false`\n          // or the conditional dereferencing operator `*?(cond)ptr`\n          // Here, we make sure that we do not access memory out-of-bounds when we\n          // write-back `z`\n          bool check[BLOCK] = offset < N;\n          *?(check)pz = *?(check)px + *?(check)py;\n      }\n\nThe existence of arrays as a primitive data-type for Triton comes with a number of advantages that are highlighted in the `MAPL'2019 Triton paper <http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf>`_.\n\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Writing the Torch bindings\nThe only thing that matters when it comes to Triton and Torch is the `triton.kernel` class. This allows you to transform the above C-like function into a callable python object that can be used to modify `torch.tensor` objects.\n\nTo create a `triton.kernel`, you only need three things:\n- `source: string`: the source-code of the kernel you want to create\n- `device: torch.device`: the device you want to compile this code for\n- `defines: dict`: the set of macros that you want the pre-processor to `#define` for you\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import torch\nimport triton"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "source-code for Triton compute kernel\nhere we just copy-paste the above code without the extensive comments.\nyou may prefer to store it in a .c file and load it from there instead.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "_src = \"\"\"\n__global__ void add(float* z, float* x, float* y, int N){\n    // program id\n    int pid = get_program_id(0);\n    // create arrays of pointers\n    int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;\n    float* pz[BLOCK] = z + offset;\n    float* px[BLOCK] = x + offset;\n    float* py[BLOCK] = y + offset;\n    // bounds checking\n    bool check[BLOCK] = offset < N;\n    // write-back\n    *?(check)pz = *?(check)px + *?(check)py;\n}\n    \"\"\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "This function returns a callable `triton.kernel` object\ncreated from the above source code.\nFor portability, we maintain a cache of kernels for different `torch.device`\nWe compile the kernel with -DBLOCK=1024\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def make_add_kernel(device):\n    cache = make_add_kernel.cache\n    if device not in cache:\n        defines = {'BLOCK': 1024}\n        cache[device] = triton.kernel(_src, device=device, defines=defines)\n    return cache[device]\n\n\nmake_add_kernel.cache = dict()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "This is a standard torch custom autograd Function\nThe only difference is that we can now use the above kernel\nin the `forward` and `backward` functions.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class _add(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, x, y):\n        # constraints of the op\n        assert x.dtype == torch.float32\n        # *allocate output*\n        z = torch.empty_like(x)\n        # *create launch grid*:\n        # this is a function which takes compilation parameters `opt`\n        # as input and returns a tuple of int (i.e., launch grid) for the kernel.\n        # triton.cdiv is a shortcut for ceil division:\n        # triton.cdiv(a, b) = (a + b - 1) // b\n        N = z.shape[0]\n        grid = lambda opt: (triton.cdiv(N, opt.BLOCK), )\n        # *launch kernel*:\n        # pointer to the data of torch tensors can be retrieved with\n        # the `.data_ptr()` method\n        kernel = make_add_kernel(z.device)\n        kernel(z.data_ptr(), x.data_ptr(), y.data_ptr(), N, grid=grid)\n        return z\n\n\n# Just like with standard PyTorch ops, we use the `.apply` method to create a callable object for our function\nadd = _add.apply"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Writing a Unit Test\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "torch.manual_seed(0)\nx = torch.rand(98432, device='cuda')\ny = torch.rand(98432, device='cuda')\nza = x + y\nzb = add(x, y)\nprint(za)\nprint(zb)\nprint(f'The maximum difference between torch and triton is ' f'{torch.max(torch.abs(za - zb))}')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Writing a Benchmark\nWe can now benchmark our custom op for vectors of increasing sizes to get a sense of how it does\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "warmup = 10\nrep = 200\nfor N in [2**i for i in range(17, 26, 1)]:\n    x = torch.rand(N, device='cuda')\n    y = torch.rand(N, device='cuda')\n    triton_ms = triton.testing.do_bench(lambda: add(x, y), warmup=warmup, rep=rep)\n    torch_ms = triton.testing.do_bench(lambda: x + y, warmup=warmup, rep=rep)\n    # print the performance of triton and torch as well as the achieved bandwidth\n    print(f'{N} {triton_ms:.3f} {torch_ms:.3f}')"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
docs/tutorials/02-fused-softmax.ipynb (symlink deleted)
@@ -1 +0,0 @@
-../../python/tutorials/02-fused-softmax.ipynb
python/tutorials/01-vector-add.ipynb (deleted, 329 lines)
@@ -1,329 +0,0 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "acute-possession",
   "metadata": {},
   "source": [
    "# Vector Addition"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "median-malaysia",
   "metadata": {},
   "source": [
    "In this tutorial, we will see how to construct a simple, high-performance vector addition using Triton. You will learn:\n",
    "* The basic syntax of the Triton programming language\n",
    "* The best practices for creating PyTorch custom operators using the `triton.kernel` Python API\n",
    "* The best practices for validating and benchmarking custom ops against native reference implementations"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "identical-conditions",
   "metadata": {},
   "source": [
    "## Writing the Compute Kernel"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "collectible-belle",
   "metadata": {},
   "source": [
    "Each compute kernel is declared using the `__global__` attribute, and executed many times in parallel on different chunks of data (See the [Single Program, Multiple Data](https://en.wikipedia.org/wiki/SPMD) programming model for more details).\n",
    "\n",
    "\n",
    "```c\n",
    "__global__ void add(float* z, float* x, float* y, int N){\n",
    "    // The `get_program_id(i)` returns the i-th coordinate\n",
    "    // of the program in the overaching SPMD context\n",
    "    // (a.k.a launch grid). This is what allows us to process\n",
    "    // different chunks of data in parallel.\n",
    "    // For those similar with CUDA, `get_program_id({0,1,2})`\n",
    "    // is similar to blockIdx.{x,y,z}\n",
    "    int pid = get_program_id(0);\n",
    "    // In Triton, arrays are first-class citizen. In other words,\n",
    "    // they are primitives data-types and are -- contrary to C and\n",
    "    // CUDA -- not implemented as pointers to contiguous chunks of\n",
    "    // memory.\n",
    "    // In the few lines below, we create an array of `BLOCK` pointers\n",
    "    // whose memory values are, e.g.:\n",
    "    // [z + pid*BLOCK + 0, z + pid*BLOCK + 1, ..., z + pid*BLOCK + BLOCK - 1]\n",
    "    // Note: here BLOCK is expected to be a pre-processor macro defined at compile-time\n",
    "    int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;\n",
    "    float* pz [BLOCK] = z + offset;\n",
    "    float* px [BLOCK] = x + offset;\n",
    "    float* py [BLOCK] = y + offset;\n",
    "    // Simple element-wise control-flow for load/store operations can\n",
    "    // be achieved using the the ternary operator `cond ? val_true : val_false`\n",
    "    // or the conditional dereferencing operator `*?(cond)ptr\n",
    "    // Here, we make sure that we do not access memory out-of-bounds when we\n",
    "    // write-back `z`\n",
    "    bool check[BLOCK] = offset < N;\n",
    "    *?(check)pz = *?(check)px + *?(check)py;\n",
    "}\n",
    "```\n",
    "\n",
    "The existence of arrays as a primitive data-type for Triton comes with a number of advantages that are highlighted in the [MAPL'2019 Triton paper](http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "forbidden-wednesday",
   "metadata": {},
   "source": [
    "## Writing the Torch bindings"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "numerical-agency",
   "metadata": {},
   "source": [
    "The only thing that matters when it comes to Triton and Torch is the `triton.kernel` class. This allows you to transform the above C-like function into a callable python object that can be used to modify `torch.tensor` objects.\n",
    "\n",
    "To create a `triton.kernel`, you only need three things:\n",
    "* `source: string`: the source-code of the kernel you want to create\n",
    "* `device: torch.device`: the device you want to compile this code for\n",
    "* `defines: dict`: the set of macros that you want the pre-processor to `#define` for you\n",
    "\n",
    "Note: The constructor of `triton.kernel` does some just-in-time compilation, so expect some overhead there. For this reason, I personally like to initialize kernels lazily in a cache (see `_kernels` variable below). This also makes it possible to choose the compilation device dynamically based on the type of the operator's inputs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "sporting-keyboard",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import triton\n",
    "\n",
    "# source-code for Triton compute kernel\n",
    "# here we just copy-paste the above code without the extensive comments.\n",
    "# you may prefer to store it in a .c file and load it from there instead.\n",
    "_src = \"\"\"\n",
    "__global__ void add(float* z, float* x, float* y, int N){\n",
    "    // program id\n",
    "    int pid = get_program_id(0);\n",
    "    // create arrays of pointers\n",
    "    int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;\n",
    "    float* pz[BLOCK] = z + offset;\n",
    "    float* px[BLOCK] = x + offset;\n",
    "    float* py[BLOCK] = y + offset;\n",
    "    // bounds checking\n",
    "    bool check[BLOCK] = offset < N;\n",
    "    // write-back\n",
    "    *?(check)pz = *?(check)px + *?(check)py;\n",
    "}\n",
    "    \"\"\"\n",
    "# This function returns a callable `triton.kernel` object\n",
    "# created from the above source code.\n",
    "# For portability, we maintain a cache of kernels for different `torch.device`\n",
    "# We compile the kernel with -DBLOCK=1024\n",
    "_kernels = dict()\n",
    "def make_add_kernel(device):\n",
    "    if device not in _kernels:\n",
    "        defines = {'BLOCK': 1024}\n",
    "        _kernels[device] = triton.kernel(_src, device=device, defines=defines)\n",
    "    return _kernels[device]\n",
    "\n",
    "# This is a standard torch custom autograd Function\n",
    "# The only difference is that we can now use the above kernel\n",
    "# in the `forward` and `backward` functions.`\n",
    "class _add(torch.autograd.Function):\n",
    "    \n",
    "    @staticmethod\n",
    "    def forward(ctx, x, y):\n",
    "        # constraints of the op\n",
    "        assert x.dtype == torch.float32\n",
    "        # *allocate output*\n",
    "        z = torch.empty_like(x)\n",
    "        # *create launch grid*:\n",
    "        # this is a function which takes compilation parameters `opt`\n",
    "        # as input and returns a tuple of int (i.e., launch grid) for the kernel.\n",
    "        # triton.cdiv is a shortcut for ceil division:\n",
    "        # triton.cdiv(a, b) = (a + b - 1) // b\n",
    "        N = z.shape[0]\n",
    "        grid = lambda opt: (triton.cdiv(N, opt.BLOCK), )\n",
    "        # *launch kernel*:\n",
    "        # pointer to the data of torch tensors can be retrieved with\n",
    "        # the `.data_ptr()` method\n",
    "        kernel = make_add_kernel(z.device)\n",
    "        kernel(z.data_ptr(), x.data_ptr(), y.data_ptr(), N, grid = grid)\n",
    "        return z\n",
    "# Just like we standard PyTorch ops\n",
    "# We use the `.apply` method to create a \n",
    "# callable object for our function\n",
    "add = _add.apply"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "separated-polyester",
   "metadata": {},
   "source": [
    "At this point `add(x, y)` is equivalent to `x + y` for contiguous tensors. Now let's test and benchmark it!"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "exclusive-salvation",
   "metadata": {},
   "source": [
    "## Writing a Unit Test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "supported-ribbon",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([1.3713, 1.3076, 0.4940, ..., 0.6682, 1.1984, 1.2696], device='cuda:0')\n",
      "tensor([1.3713, 1.3076, 0.4940, ..., 0.6682, 1.1984, 1.2696], device='cuda:0')\n",
      "The maximum difference between torch and triton is 0.0\n"
     ]
    }
   ],
   "source": [
    "torch.manual_seed(0)\n",
    "x = torch.rand(98432, device='cuda')\n",
    "y = torch.rand(98432, device='cuda')\n",
    "za = x + y\n",
    "zb = add(x, y)\n",
    "print(za)\n",
    "print(zb)\n",
    "print(f'The maximum difference between torch and triton is '\n",
    "      f'{torch.max(torch.abs(za - zb))}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "otherwise-canadian",
   "metadata": {},
   "source": [
    "Seems to work!"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "polished-australia",
   "metadata": {},
   "source": [
    "## Writing a Benchmark"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "historic-glass",
   "metadata": {},
   "source": [
    "The performance of our GPU code can be benchmark using the `torch.cuda.Event(enable_timing=True)` wrapper. Below is a simple function that benchmarks `rep` runs of our kernels after `warmup` \"cold\" runs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "strange-luxembourg",
   "metadata": {},
   "outputs": [],
   "source": [
    "# We now want to benchmark the performance of `add`\n",
    "# Against that of PyTorch for increasing vector sizes\n",
    "def do_bench(fn, warmup = 10, rep = 50):\n",
    "    start_event = torch.cuda.Event(enable_timing=True)\n",
    "    end_event = torch.cuda.Event(enable_timing=True)\n",
    "    ret = fn()\n",
    "    for i in range(warmup):\n",
    "        fn()\n",
    "    torch.cuda.synchronize()\n",
    "    start_event.record()\n",
    "    for i in range(rep):\n",
    "        fn()\n",
    "    end_event.record()\n",
    "    torch.cuda.synchronize()\n",
    "    time_ms = start_event.elapsed_time(end_event) / rep\n",
    "    return time_ms"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "hairy-claim",
   "metadata": {},
   "source": [
    "We can now benchmark our custom op for vectors of increasing sizes to get a sense of how it does"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "pleasant-valley",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "131072 0.020 0.003\n",
      "262144 0.019 0.004\n",
      "524288 0.016 0.016\n",
      "1048576 0.033 0.033\n",
      "2097152 0.071 0.070\n",
      "4194304 0.142 0.144\n",
      "8388608 0.287 0.286\n",
      "16777216 0.572 0.568\n",
      "33554432 1.139 1.110\n"
     ]
    }
   ],
   "source": [
    "for N in [2**i for i in range(17, 26, 1)]:\n",
    "    x = torch.rand(N, device='cuda')\n",
    "    y = torch.rand(N, device='cuda')\n",
    "    triton_ms = do_bench(lambda: add(x, y))\n",
    "    torch_ms = do_bench(lambda: x + y)\n",
    "    # print the performance of triton and torch as well as the achieved bandwidth\n",
    "    print(f'{N} {triton_ms:.3f} {torch_ms:.3f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "juvenile-supplement",
   "metadata": {},
   "source": [
    "Our op is on-par with Torch's vectorized element-wise kernel when the vectors are large enough. One caveat is that the latency of PyTorch is much smaller for small vectors (3us vs 18-20us). This is something we are actively working on to reduce."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
python/tutorials/01-vector-add.py (new file, 158 lines)
@@ -0,0 +1,158 @@
"""
Vector Addition
=================
In this tutorial, we will see how to construct a simple, high-performance vector addition using Triton. You will learn:
* The basic syntax of the Triton programming language
* The best practices for creating PyTorch custom operators using the `triton.kernel` Python API
* The best practices for validating and benchmarking custom ops against native reference implementations
"""

# %%
# Writing the Compute Kernel
# --------------------------
#
# Each compute kernel is declared using the :code:`__global__` attribute, and executed many times in parallel
# on different chunks of data (See the `Single Program, Multiple Data <https://en.wikipedia.org/wiki/SPMD>`_
# programming model for more details).
#
#  .. code-block:: C
#
#      __global__ void add(float* z, float* x, float* y, int N){
#          // The `get_program_id(i)` returns the i-th coordinate
#          // of the program in the overarching SPMD context
#          // (a.k.a launch grid). This is what allows us to process
#          // different chunks of data in parallel.
#          // For those familiar with CUDA, `get_program_id({0,1,2})`
#          // is similar to blockIdx.{x,y,z}
#          int pid = get_program_id(0);
#          // In Triton, arrays are first-class citizens. In other words,
#          // they are primitive data-types and are -- contrary to C and
#          // CUDA -- not implemented as pointers to contiguous chunks of
#          // memory.
#          // In the few lines below, we create an array of `BLOCK` pointers
#          // whose memory values are, e.g.:
#          // [z + pid*BLOCK + 0, z + pid*BLOCK + 1, ..., z + pid*BLOCK + BLOCK - 1]
#          // Note: here BLOCK is expected to be a pre-processor macro defined at compile-time
#          int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;
#          float* pz [BLOCK] = z + offset;
#          float* px [BLOCK] = x + offset;
#          float* py [BLOCK] = y + offset;
#          // Simple element-wise control-flow for load/store operations can
#          // be achieved using the ternary operator `cond ? val_true : val_false`
#          // or the conditional dereferencing operator `*?(cond)ptr`
#          // Here, we make sure that we do not access memory out-of-bounds when we
#          // write-back `z`
#          bool check[BLOCK] = offset < N;
#          *?(check)pz = *?(check)px + *?(check)py;
#      }
#
# The existence of arrays as a primitive data-type for Triton comes with a number of advantages that are highlighted in the `MAPL'2019 Triton paper <http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf>`_.

# %%
# Writing the Torch bindings
# --------------------------
# The only thing that matters when it comes to Triton and Torch is the `triton.kernel` class. This allows you to transform the above C-like function into a callable python object that can be used to modify `torch.tensor` objects.
#
# To create a `triton.kernel`, you only need three things:
# - `source: string`: the source-code of the kernel you want to create
# - `device: torch.device`: the device you want to compile this code for
# - `defines: dict`: the set of macros that you want the pre-processor to `#define` for you

import torch
import triton

# %%
# source-code for Triton compute kernel
# here we just copy-paste the above code without the extensive comments.
# you may prefer to store it in a .c file and load it from there instead.

_src = """
__global__ void add(float* z, float* x, float* y, int N){
    // program id
    int pid = get_program_id(0);
    // create arrays of pointers
    int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;
    float* pz[BLOCK] = z + offset;
    float* px[BLOCK] = x + offset;
    float* py[BLOCK] = y + offset;
    // bounds checking
    bool check[BLOCK] = offset < N;
    // write-back
    *?(check)pz = *?(check)px + *?(check)py;
}
    """

# %%
# This function returns a callable `triton.kernel` object
# created from the above source code.
# For portability, we maintain a cache of kernels for different `torch.device`
# We compile the kernel with -DBLOCK=1024


def make_add_kernel(device):
    cache = make_add_kernel.cache
    if device not in cache:
        defines = {'BLOCK': 1024}
        cache[device] = triton.kernel(_src, device=device, defines=defines)
    return cache[device]


make_add_kernel.cache = dict()

# %%
# This is a standard torch custom autograd Function
# The only difference is that we can now use the above kernel
# in the `forward` and `backward` functions.


class _add(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, y):
        # constraints of the op
        assert x.dtype == torch.float32
        # *allocate output*
        z = torch.empty_like(x)
        # *create launch grid*:
        # this is a function which takes compilation parameters `opt`
        # as input and returns a tuple of int (i.e., launch grid) for the kernel.
        # triton.cdiv is a shortcut for ceil division:
        # triton.cdiv(a, b) = (a + b - 1) // b
        N = z.shape[0]
        grid = lambda opt: (triton.cdiv(N, opt.BLOCK), )
        # *launch kernel*:
        # pointer to the data of torch tensors can be retrieved with
        # the `.data_ptr()` method
        kernel = make_add_kernel(z.device)
        kernel(z.data_ptr(), x.data_ptr(), y.data_ptr(), N, grid=grid)
        return z


# Just like with standard PyTorch ops, we use the `.apply` method to create a callable object for our function
add = _add.apply

# %%
# Writing a Unit Test
# --------------------------
torch.manual_seed(0)
x = torch.rand(98432, device='cuda')
y = torch.rand(98432, device='cuda')
za = x + y
zb = add(x, y)
print(za)
print(zb)
print(f'The maximum difference between torch and triton is ' f'{torch.max(torch.abs(za - zb))}')

# %%
# Writing a Benchmark
# --------------------------
# We can now benchmark our custom op for vectors of increasing sizes to get a sense of how it does

warmup = 10
rep = 200
for N in [2**i for i in range(17, 26, 1)]:
    x = torch.rand(N, device='cuda')
    y = torch.rand(N, device='cuda')
    triton_ms = triton.testing.do_bench(lambda: add(x, y), warmup=warmup, rep=rep)
    torch_ms = triton.testing.do_bench(lambda: x + y, warmup=warmup, rep=rep)
    # print the performance of triton and torch as well as the achieved bandwidth
    print(f'{N} {triton_ms:.3f} {torch_ms:.3f}')
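The benchmark above prints only the measured times, even though the final comment mentions achieved bandwidth. A minimal sketch of how that bandwidth could be derived from the same timings, assuming the op moves two float32 input vectors and one output vector of length N (the `add_gbps` helper is illustrative and not part of the committed file):

    # Hypothetical helper: effective memory bandwidth of z = x + y on N float32 elements.
    def add_gbps(N, time_ms):
        bytes_moved = 3 * N * 4  # read x, read y, write z; 4 bytes per float32 element
        return bytes_moved * 1e-9 / (time_ms * 1e-3)  # GB moved per second

    # e.g. inside the loop above:
    # print(f'{N} {add_gbps(N, triton_ms):.1f} GB/s (triton) {add_gbps(N, torch_ms):.1f} GB/s (torch)')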
python/tutorials/02-fused-softmax.ipynb: file diff suppressed because one or more lines are too long.
python/tutorials/02-fused-softmax.py (new file, 181 lines)
@@ -0,0 +1,181 @@
"""
Fused Softmax
=================
"""

# %%
# Custom GPU kernels for elementwise additions are educationally valuable but won't get you very far in practice.
# Let us consider instead the case of a simple (numerically stabilized) softmax operation:

import torch


# Compute the row-wise softmax of x \in R^{M \times N}
def naive_softmax(x):
    # read MN elements ; write M elements
    x_max = torch.max(x, axis=1)[0]
    # read 2MN elements ; write MN elements
    z = x - x_max[:, None]
    # read MN elements ; write MN elements
    numerator = torch.exp(z)
    # read MN elements ; write M elements
    denominator = torch.sum(numerator, axis=1)
    # read 2MN elements ; write MN elements
    ret = numerator / denominator[:, None]
    # in total: read 7MN elements ; wrote 3MN + 2M elements
    return ret


# %%
# When implemented naively in pytorch, computing :math:`y` requires reading :math:`7MN` elements from DRAM and writing back :math:`3MN + 2M` elements.
# Instead, we want to write a custom "fused" pytorch operator that only reads X once and does all the necessary computations on-chip. This would require reading :math:`MN` elements and writing back :math:`MN` elements, so we could expect a theoretical speed-up of 5x. In practice, though, we expect less because our kernel will spend some time computing exponentials and moving data around in shared memory.

# %%
# Writing the Compute Kernel
# ----------------------------
# Our softmax kernel works as follows: each program loads a row of X and writes back a normalized row of Y. Note that one important limitation of Triton is that each block must have a power-of-two number of elements, which means that we need to guard the memory operations properly if we want to handle any possible input shapes:
#
#  .. code-block:: C
#
#      __global__ void softmax(float* Y, float* X, int stride_xm, int stride_ym, int M, int N){
#          // row index
#          int m = get_program_id(0);
#          // column indices
#          int n [BLOCK] = 0 ... BLOCK;
#          // the memory address of all the elements
#          // that we want to load can be computed as follows
#          float* px [BLOCK] = X + m*stride_xm + n;
#          // because BLOCK has to be a power of two
#          // (per Triton-C specs), it is important
#          // to guard each memory operation with predicates
#          // or we will read out of bounds
#          bool check[BLOCK] = n < N;
#          float x [BLOCK] = check ? *px : -F32_INFINITY;
#          // syntax for reduction in Triton is:
#          // x[..., OPERATOR, ...]
#          //            ^
#          //          index
#          // The operators currently supported are {min, max, +}
#          float z [BLOCK] = x - x[max];
#          // The exponential in Triton is fast but approximate
#          // (i.e., like __expf in CUDA)
#          float num [BLOCK] = exp(z);
#          float denom = num[+];
#          // The result of the reduction is now stored in y
#          float y [BLOCK] = num / denom;
#          // We write it back
#          float* py [BLOCK] = Y + m*stride_ym + n;
#          *?(check)py = y;
#      }

# %%
# Writing the Compute Kernel
# ----------------------------

import torch
import triton

# %%
# source-code for Triton compute kernel
_src = """
__global__ void softmax(float* Y, float* X, int stride_ym, int stride_xm, int M, int N){
    int m = get_program_id(0);
    int n [BLOCK] = 0 ... BLOCK;
    float* px [BLOCK] = X + m*stride_xm + n;
    bool check[BLOCK] = n < N;
    float x [BLOCK] = check ? *px : -F32_INFINITY;
    float z [BLOCK] = x - x[max];
    float num [BLOCK] = exp(z);
    float denom = num[+];
    float y [BLOCK] = num / denom;
    float* py [BLOCK] = Y + m*stride_ym + n;
    *?(check)py = y;
}
"""


# %%
# Writing the Torch bindings
# ----------------------------
# We need to make sure that BLOCK is the smallest power of two
# greater than the number of columns N of the input matrix.
# Different values of BLOCK will result in different kernels
def next_power_of_2(n):
    n -= 1
    n |= n >> 1
    n |= n >> 2
    n |= n >> 4
    n |= n >> 8
    n |= n >> 16
    n += 1
    return n


_kernels = dict()


def make_kernel(N, device):
    BLOCK = next_power_of_2(N)
    key = (BLOCK, device)
    if key not in _kernels:
        defines = {'BLOCK': BLOCK}
        _kernels[key] = triton.kernel(_src, device=device, defines=defines)
    return _kernels[key]


class _softmax(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # constraints of the op
        assert x.dtype == torch.float32
        y = torch.empty_like(x)
        # *create launch grid*:
        # here we just launch a grid of M programs
        M, N = y.shape
        grid = lambda opt: (M, )
        # *launch kernel*:
        kernel = make_kernel(N, y.device)
        kernel(y.data_ptr(), x.data_ptr(), y.stride(0), x.stride(0), M, N, grid=grid)
        return y


softmax = _softmax.apply

# %%
# Unit Test
# ----------

x = torch.randn(1823, 781, device='cuda')
y_tri = softmax(x)
y_ref = torch.softmax(x, axis=1)
print(y_tri)
print(y_ref)
print(torch.allclose(y_tri, y_ref))

# %%
# Seems to work!

# %%
# Benchmark
# ----------

import matplotlib.pyplot as plt

M = 4096
Ns = [128 * i for i in range(2, 50)]
tri_ms = []
ref_ms = []
def_ms = []
for N in Ns:
    x = torch.randn(M, N, device='cuda', dtype=torch.float32)
    gbps = lambda ms: x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)
    tri_ms += [gbps(triton.testing.do_bench(lambda: softmax(x)))]
    ref_ms += [gbps(triton.testing.do_bench(lambda: torch.softmax(x, axis=1)))]
    def_ms += [gbps(triton.testing.do_bench(lambda: naive_softmax(x)))]
plt.xlabel('N')
plt.ylabel('Bandwidth (GB/s)')
plt.plot(Ns, tri_ms, label='Triton')
plt.plot(Ns, ref_ms, label='Torch')
plt.plot(Ns, def_ms, label='Naive')
plt.legend()
plt.show()
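The ~5x figure quoted in the comments above follows directly from the element counts annotated in naive_softmax. A quick back-of-the-envelope check (illustrative only; not part of the committed file):

    # Elements moved by naive_softmax (reads + writes) versus a fused kernel
    # that reads X once and writes Y once, for an M x N input.
    M, N = 4096, 1024
    naive_traffic = 7 * M * N + 3 * M * N + 2 * M   # read 7MN, write 3MN + 2M
    fused_traffic = M * N + M * N                   # read MN, write MN
    print(naive_traffic / fused_traffic)            # ~5.0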
python/tutorials/README.rst (new file, 4 lines)
@@ -0,0 +1,4 @@
Triton Tutorials
==================

Below is a gallery of tutorials to help you get started with Triton.