diff --git a/master/.buildinfo b/master/.buildinfo index 089f93c02..9e27b9682 100644 --- a/master/.buildinfo +++ b/master/.buildinfo @@ -1,4 +1,4 @@ # Sphinx build info version 1 # This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done. -config: 255e5b7e27427f0ab6fda308ee6aef63 +config: 8d52c5eda79abb41e578ed40b306519c tags: 645f666f9bcd5a90fca523b33c5a78b7 diff --git a/master/.doctrees/environment.pickle b/master/.doctrees/environment.pickle index 9b5c9c2d3..228c82170 100644 Binary files a/master/.doctrees/environment.pickle and b/master/.doctrees/environment.pickle differ diff --git a/master/.doctrees/getting-started/installation.doctree b/master/.doctrees/getting-started/installation.doctree index 6190cd7e2..27f8fc457 100644 Binary files a/master/.doctrees/getting-started/installation.doctree and b/master/.doctrees/getting-started/installation.doctree differ diff --git a/master/.doctrees/getting-started/tutorials/01-vector-add.doctree b/master/.doctrees/getting-started/tutorials/01-vector-add.doctree index 62d77901f..3b3a60da3 100644 Binary files a/master/.doctrees/getting-started/tutorials/01-vector-add.doctree and b/master/.doctrees/getting-started/tutorials/01-vector-add.doctree differ diff --git a/master/.doctrees/getting-started/tutorials/02-fused-softmax.doctree b/master/.doctrees/getting-started/tutorials/02-fused-softmax.doctree index 72c82fd7d..2fd1643ea 100644 Binary files a/master/.doctrees/getting-started/tutorials/02-fused-softmax.doctree and b/master/.doctrees/getting-started/tutorials/02-fused-softmax.doctree differ diff --git a/master/.doctrees/getting-started/tutorials/03-matrix-multiplication.doctree b/master/.doctrees/getting-started/tutorials/03-matrix-multiplication.doctree index 52d467dbb..405936c08 100644 Binary files a/master/.doctrees/getting-started/tutorials/03-matrix-multiplication.doctree and b/master/.doctrees/getting-started/tutorials/03-matrix-multiplication.doctree differ diff --git a/master/.doctrees/getting-started/tutorials/04-low-memory-dropout.doctree b/master/.doctrees/getting-started/tutorials/04-low-memory-dropout.doctree index 89a1a4b75..48060c346 100644 Binary files a/master/.doctrees/getting-started/tutorials/04-low-memory-dropout.doctree and b/master/.doctrees/getting-started/tutorials/04-low-memory-dropout.doctree differ diff --git a/master/.doctrees/getting-started/tutorials/05-layer-norm.doctree b/master/.doctrees/getting-started/tutorials/05-layer-norm.doctree index cbe9fac56..122cb1c5b 100644 Binary files a/master/.doctrees/getting-started/tutorials/05-layer-norm.doctree and b/master/.doctrees/getting-started/tutorials/05-layer-norm.doctree differ diff --git a/master/.doctrees/getting-started/tutorials/06-fused-attention.doctree b/master/.doctrees/getting-started/tutorials/06-fused-attention.doctree new file mode 100644 index 000000000..a19207450 Binary files /dev/null and b/master/.doctrees/getting-started/tutorials/06-fused-attention.doctree differ diff --git a/master/.doctrees/getting-started/tutorials/07-libdevice-function.doctree b/master/.doctrees/getting-started/tutorials/07-libdevice-function.doctree new file mode 100644 index 000000000..b1970bcf6 Binary files /dev/null and b/master/.doctrees/getting-started/tutorials/07-libdevice-function.doctree differ diff --git a/master/.doctrees/getting-started/tutorials/index.doctree b/master/.doctrees/getting-started/tutorials/index.doctree index dd4cc705b..26b6f1bdb 100644 Binary files 
a/master/.doctrees/getting-started/tutorials/index.doctree and b/master/.doctrees/getting-started/tutorials/index.doctree differ diff --git a/master/.doctrees/getting-started/tutorials/sg_execution_times.doctree b/master/.doctrees/getting-started/tutorials/sg_execution_times.doctree index 34eeecbd3..2740ae711 100644 Binary files a/master/.doctrees/getting-started/tutorials/sg_execution_times.doctree and b/master/.doctrees/getting-started/tutorials/sg_execution_times.doctree differ diff --git a/master/.doctrees/index.doctree b/master/.doctrees/index.doctree index abcb3bfaa..b0b8ae7db 100644 Binary files a/master/.doctrees/index.doctree and b/master/.doctrees/index.doctree differ diff --git a/master/.doctrees/programming-guide/chapter-1/introduction.doctree b/master/.doctrees/programming-guide/chapter-1/introduction.doctree index 8828c92c6..e6fafc875 100644 Binary files a/master/.doctrees/programming-guide/chapter-1/introduction.doctree and b/master/.doctrees/programming-guide/chapter-1/introduction.doctree differ diff --git a/master/.doctrees/programming-guide/chapter-2/related-work.doctree b/master/.doctrees/programming-guide/chapter-2/related-work.doctree index 14fe22504..bfc855a8e 100644 Binary files a/master/.doctrees/programming-guide/chapter-2/related-work.doctree and b/master/.doctrees/programming-guide/chapter-2/related-work.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.Config.doctree b/master/.doctrees/python-api/generated/triton.Config.doctree index af4eee914..840d99316 100644 Binary files a/master/.doctrees/python-api/generated/triton.Config.doctree and b/master/.doctrees/python-api/generated/triton.Config.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.autotune.doctree b/master/.doctrees/python-api/generated/triton.autotune.doctree index 961b39c57..147595499 100644 Binary files a/master/.doctrees/python-api/generated/triton.autotune.doctree and b/master/.doctrees/python-api/generated/triton.autotune.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.heuristics.doctree b/master/.doctrees/python-api/generated/triton.heuristics.doctree index 95f4939bd..e13bad010 100644 Binary files a/master/.doctrees/python-api/generated/triton.heuristics.doctree and b/master/.doctrees/python-api/generated/triton.heuristics.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.jit.doctree b/master/.doctrees/python-api/generated/triton.jit.doctree index e28f973b0..c585e9642 100644 Binary files a/master/.doctrees/python-api/generated/triton.jit.doctree and b/master/.doctrees/python-api/generated/triton.jit.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.arange.doctree b/master/.doctrees/python-api/generated/triton.language.arange.doctree index f6e0c711e..328bffa17 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.arange.doctree and b/master/.doctrees/python-api/generated/triton.language.arange.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.atomic_add.doctree b/master/.doctrees/python-api/generated/triton.language.atomic_add.doctree index aed2fb893..d20fcb9f8 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.atomic_add.doctree and b/master/.doctrees/python-api/generated/triton.language.atomic_add.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.atomic_and.doctree b/master/.doctrees/python-api/generated/triton.language.atomic_and.doctree index 
47db74f71..7f17e9888 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.atomic_and.doctree and b/master/.doctrees/python-api/generated/triton.language.atomic_and.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.atomic_cas.doctree b/master/.doctrees/python-api/generated/triton.language.atomic_cas.doctree index 109a81608..bffe0c6e7 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.atomic_cas.doctree and b/master/.doctrees/python-api/generated/triton.language.atomic_cas.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.atomic_max.doctree b/master/.doctrees/python-api/generated/triton.language.atomic_max.doctree index c96c0b82f..a4704bc1e 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.atomic_max.doctree and b/master/.doctrees/python-api/generated/triton.language.atomic_max.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.atomic_min.doctree b/master/.doctrees/python-api/generated/triton.language.atomic_min.doctree index d7b752ea3..0f426700b 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.atomic_min.doctree and b/master/.doctrees/python-api/generated/triton.language.atomic_min.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.atomic_or.doctree b/master/.doctrees/python-api/generated/triton.language.atomic_or.doctree index be238e853..57b07b41c 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.atomic_or.doctree and b/master/.doctrees/python-api/generated/triton.language.atomic_or.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.atomic_xchg.doctree b/master/.doctrees/python-api/generated/triton.language.atomic_xchg.doctree index 64d5a0847..95149aa00 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.atomic_xchg.doctree and b/master/.doctrees/python-api/generated/triton.language.atomic_xchg.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.atomic_xor.doctree b/master/.doctrees/python-api/generated/triton.language.atomic_xor.doctree index 9dbf4a3c8..ae15a586d 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.atomic_xor.doctree and b/master/.doctrees/python-api/generated/triton.language.atomic_xor.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.broadcast_to.doctree b/master/.doctrees/python-api/generated/triton.language.broadcast_to.doctree index 939971777..358ab043e 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.broadcast_to.doctree and b/master/.doctrees/python-api/generated/triton.language.broadcast_to.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.cos.doctree b/master/.doctrees/python-api/generated/triton.language.cos.doctree index 49eb4707b..504e5e619 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.cos.doctree and b/master/.doctrees/python-api/generated/triton.language.cos.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.dot.doctree b/master/.doctrees/python-api/generated/triton.language.dot.doctree index 3870262c6..ae6f8e68a 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.dot.doctree and b/master/.doctrees/python-api/generated/triton.language.dot.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.exp.doctree 
b/master/.doctrees/python-api/generated/triton.language.exp.doctree index 767f42eb5..54536be64 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.exp.doctree and b/master/.doctrees/python-api/generated/triton.language.exp.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.load.doctree b/master/.doctrees/python-api/generated/triton.language.load.doctree index 7caedc95b..ca2ae6e0f 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.load.doctree and b/master/.doctrees/python-api/generated/triton.language.load.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.log.doctree b/master/.doctrees/python-api/generated/triton.language.log.doctree index d0fd10524..d53a32dd4 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.log.doctree and b/master/.doctrees/python-api/generated/triton.language.log.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.max.doctree b/master/.doctrees/python-api/generated/triton.language.max.doctree index 42d6c57df..b19b58c84 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.max.doctree and b/master/.doctrees/python-api/generated/triton.language.max.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.maximum.doctree b/master/.doctrees/python-api/generated/triton.language.maximum.doctree index cc7f88946..7c78bb4a8 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.maximum.doctree and b/master/.doctrees/python-api/generated/triton.language.maximum.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.min.doctree b/master/.doctrees/python-api/generated/triton.language.min.doctree index 4976ec400..4e2142cea 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.min.doctree and b/master/.doctrees/python-api/generated/triton.language.min.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.minimum.doctree b/master/.doctrees/python-api/generated/triton.language.minimum.doctree index 7a181d868..ebbf565ba 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.minimum.doctree and b/master/.doctrees/python-api/generated/triton.language.minimum.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.multiple_of.doctree b/master/.doctrees/python-api/generated/triton.language.multiple_of.doctree index 804ddc047..fc6f7e4fc 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.multiple_of.doctree and b/master/.doctrees/python-api/generated/triton.language.multiple_of.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.num_programs.doctree b/master/.doctrees/python-api/generated/triton.language.num_programs.doctree index ebe0dd843..6a8bc44b3 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.num_programs.doctree and b/master/.doctrees/python-api/generated/triton.language.num_programs.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.program_id.doctree b/master/.doctrees/python-api/generated/triton.language.program_id.doctree index 0b28f85ff..703f5d0be 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.program_id.doctree and b/master/.doctrees/python-api/generated/triton.language.program_id.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.rand.doctree 
b/master/.doctrees/python-api/generated/triton.language.rand.doctree index 6232ace14..a73f482fb 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.rand.doctree and b/master/.doctrees/python-api/generated/triton.language.rand.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.randint.doctree b/master/.doctrees/python-api/generated/triton.language.randint.doctree index 27d4373a6..9c8214953 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.randint.doctree and b/master/.doctrees/python-api/generated/triton.language.randint.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.randint4x.doctree b/master/.doctrees/python-api/generated/triton.language.randint4x.doctree index bb31c78bc..772709f0d 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.randint4x.doctree and b/master/.doctrees/python-api/generated/triton.language.randint4x.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.randn.doctree b/master/.doctrees/python-api/generated/triton.language.randn.doctree index 281efe064..4fd8f42ea 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.randn.doctree and b/master/.doctrees/python-api/generated/triton.language.randn.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.ravel.doctree b/master/.doctrees/python-api/generated/triton.language.ravel.doctree index bfe37463c..954262cd4 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.ravel.doctree and b/master/.doctrees/python-api/generated/triton.language.ravel.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.reshape.doctree b/master/.doctrees/python-api/generated/triton.language.reshape.doctree index 9d6707e69..05f3e82c0 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.reshape.doctree and b/master/.doctrees/python-api/generated/triton.language.reshape.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.sigmoid.doctree b/master/.doctrees/python-api/generated/triton.language.sigmoid.doctree index 1d7385111..6c4a18f65 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.sigmoid.doctree and b/master/.doctrees/python-api/generated/triton.language.sigmoid.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.sin.doctree b/master/.doctrees/python-api/generated/triton.language.sin.doctree index 53c9b4122..cb427d01d 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.sin.doctree and b/master/.doctrees/python-api/generated/triton.language.sin.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.softmax.doctree b/master/.doctrees/python-api/generated/triton.language.softmax.doctree index 1d82361ea..2d5becdfb 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.softmax.doctree and b/master/.doctrees/python-api/generated/triton.language.softmax.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.sqrt.doctree b/master/.doctrees/python-api/generated/triton.language.sqrt.doctree index 34f3cc3ed..3828ac4fd 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.sqrt.doctree and b/master/.doctrees/python-api/generated/triton.language.sqrt.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.store.doctree 
b/master/.doctrees/python-api/generated/triton.language.store.doctree index a17c57279..02efc289a 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.store.doctree and b/master/.doctrees/python-api/generated/triton.language.store.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.sum.doctree b/master/.doctrees/python-api/generated/triton.language.sum.doctree index 2e1d42093..311694380 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.sum.doctree and b/master/.doctrees/python-api/generated/triton.language.sum.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.where.doctree b/master/.doctrees/python-api/generated/triton.language.where.doctree index 8093efb0f..f5bca8b4b 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.where.doctree and b/master/.doctrees/python-api/generated/triton.language.where.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.language.zeros.doctree b/master/.doctrees/python-api/generated/triton.language.zeros.doctree index f32f1b54d..be2e95ee7 100644 Binary files a/master/.doctrees/python-api/generated/triton.language.zeros.doctree and b/master/.doctrees/python-api/generated/triton.language.zeros.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.testing.Benchmark.doctree b/master/.doctrees/python-api/generated/triton.testing.Benchmark.doctree index 193618794..3d7eec07c 100644 Binary files a/master/.doctrees/python-api/generated/triton.testing.Benchmark.doctree and b/master/.doctrees/python-api/generated/triton.testing.Benchmark.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.testing.do_bench.doctree b/master/.doctrees/python-api/generated/triton.testing.do_bench.doctree index 3b555cbef..64c6facdd 100644 Binary files a/master/.doctrees/python-api/generated/triton.testing.do_bench.doctree and b/master/.doctrees/python-api/generated/triton.testing.do_bench.doctree differ diff --git a/master/.doctrees/python-api/generated/triton.testing.perf_report.doctree b/master/.doctrees/python-api/generated/triton.testing.perf_report.doctree index fd490963c..b1e92ec76 100644 Binary files a/master/.doctrees/python-api/generated/triton.testing.perf_report.doctree and b/master/.doctrees/python-api/generated/triton.testing.perf_report.doctree differ diff --git a/master/.doctrees/python-api/triton.doctree b/master/.doctrees/python-api/triton.doctree index 666b7335d..e8dd455f0 100644 Binary files a/master/.doctrees/python-api/triton.doctree and b/master/.doctrees/python-api/triton.doctree differ diff --git a/master/.doctrees/python-api/triton.language.doctree b/master/.doctrees/python-api/triton.language.doctree index 0ac2e2ebc..214cb8df7 100644 Binary files a/master/.doctrees/python-api/triton.language.doctree and b/master/.doctrees/python-api/triton.language.doctree differ diff --git a/master/.doctrees/python-api/triton.testing.doctree b/master/.doctrees/python-api/triton.testing.doctree index 69db5f8ce..562048866 100644 Binary files a/master/.doctrees/python-api/triton.testing.doctree and b/master/.doctrees/python-api/triton.testing.doctree differ diff --git a/master/_downloads/1bc2e471d2fb0ec017c4d1d0890db4e2/07-libdevice-function.ipynb b/master/_downloads/1bc2e471d2fb0ec017c4d1d0890db4e2/07-libdevice-function.ipynb new file mode 100644 index 000000000..d3b665487 --- /dev/null +++ b/master/_downloads/1bc2e471d2fb0ec017c4d1d0890db4e2/07-libdevice-function.ipynb @@ -0,0 +1,97 @@ +{ + 
"cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n# Libdevice function\nTriton can invoke a custom function from an external library.\nIn this example, we will use the `libdevice` library to apply `asin` on a tensor.\nPlease refer to https://docs.nvidia.com/cuda/libdevice-users-guide/index.html regarding the semantics of all available libdevice functions.\n\nIn `trition/language/libdevice.py`, we try to aggregate functions with the same computation but different data types together.\nFor example, both `__nv_asin` and `__nvasinf` calculate the principal value of the arc sine of the input, but `__nv_asin` operates on `double` and `__nv_asinf` operates on `float`.\nUsing triton, you can simply call `tl.libdevice.asinf`.\ntriton automatically selects the correct underlying device function to invoke based on input and output types.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## asin Kernel\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "import torch\n\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef asin_kernel(\n x_ptr,\n y_ptr,\n n_elements,\n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(x_ptr + offsets, mask=mask)\n x = tl.libdevice.asin(x)\n tl.store(y_ptr + offsets, x, mask=mask)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Using the default libdevice library path\nWe can use the default libdevice library path encoded in `triton/language/libdevice.py`\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "torch.manual_seed(0)\nsize = 98432\nx = torch.rand(size, device='cuda')\noutput_triton = torch.zeros(size, device='cuda')\noutput_torch = torch.asin(x)\nassert x.is_cuda and output_triton.is_cuda\nn_elements = output_torch.numel()\ngrid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\nasin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024)\nprint(output_torch)\nprint(output_triton)\nprint(\n f'The maximum difference between torch and triton is '\n f'{torch.max(torch.abs(output_torch - output_triton))}'\n)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Customize the libdevice library path\nWe can also customize the libdevice library path by passing the path to the `libdevice` library to the `asin` kernel.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "output_triton = torch.empty_like(x)\nasin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024,\n extern_libs={'libdevice': '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'})\nprint(output_torch)\nprint(output_triton)\nprint(\n f'The maximum difference between torch and triton is '\n f'{torch.max(torch.abs(output_torch - output_triton))}'\n)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": 
"text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/master/_downloads/3176accb6c7288b0e45f41d94eebacb9/06-fused-attention.ipynb b/master/_downloads/3176accb6c7288b0e45f41d94eebacb9/06-fused-attention.ipynb new file mode 100644 index 000000000..f3bc0d2e5 --- /dev/null +++ b/master/_downloads/3176accb6c7288b0e45f41d94eebacb9/06-fused-attention.ipynb @@ -0,0 +1,54 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n# Fused Attention\nThis is a Triton implementation of the Flash Attention algorithm \n(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "import pytest\nimport torch\n\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _fwd_kernel(\n Q, K, V, sm_scale,\n TMP, L, M, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug\n Out,\n stride_qz, stride_qh, stride_qm, stride_qk,\n stride_kz, stride_kh, stride_kn, stride_kk,\n stride_vz, stride_vh, stride_vk, stride_vn,\n stride_oz, stride_oh, stride_om, stride_on,\n Z, H, N_CTX,\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n # initialize offsets\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk\n off_k = off_hz * stride_qh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk\n off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk\n # Initialize pointers to Q, K, V\n q_ptrs = Q + off_q\n k_ptrs = K + off_k\n v_ptrs = V + off_v\n # initialize pointer to m and l\n t_ptrs = TMP + off_hz * N_CTX + offs_m\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n # load q: it will stay in SRAM throughout\n q = tl.load(q_ptrs)\n # loop over k, v and update accumulator\n for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n # -- compute qk ----\n k = tl.load(k_ptrs + start_n * stride_kn)\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k, trans_b=True)\n qk *= sm_scale\n qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float(\"-inf\"))\n # -- compute m_ij, p, l_ij\n m_ij = tl.max(qk, 1)\n p = tl.exp(qk - m_ij[:, None])\n l_ij = tl.sum(p, 1)\n # -- update m_i and l_i\n m_i_new = tl.maximum(m_i, m_ij)\n alpha = tl.exp(m_i - m_i_new)\n beta = tl.exp(m_ij - m_i_new)\n l_i_new = alpha * l_i + beta * l_ij\n # -- update output accumulator --\n # scale p\n p_scale = beta / l_i_new\n p = p * p_scale[:, None]\n # scale acc\n acc_scale = l_i / l_i_new * alpha\n tl.store(t_ptrs, acc_scale)\n acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load\n acc = acc * acc_scale[:, None]\n # update acc\n v = tl.load(v_ptrs + start_n * stride_vk)\n p = p.to(tl.float16)\n acc += tl.dot(p, 
v)\n # update m_i and l_i\n l_i = l_i_new\n m_i = m_i_new\n # rematerialize offsets to save registers\n start_m = tl.program_id(0)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n # write back l and m\n l_ptrs = L + off_hz * N_CTX + offs_m\n m_ptrs = M + off_hz * N_CTX + offs_m\n tl.store(l_ptrs, l_i)\n tl.store(m_ptrs, m_i)\n # initialize pointers to output\n offs_n = tl.arange(0, BLOCK_DMODEL)\n off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on\n out_ptrs = Out + off_o\n tl.store(out_ptrs, acc)\n\n\n@triton.jit\ndef _bwd_preprocess(\n Out, DO, L,\n NewDO, Delta,\n BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,\n):\n off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)\n off_n = tl.arange(0, D_HEAD)\n # load\n o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)\n do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)\n denom = tl.load(L + off_m).to(tl.float32)\n # compute\n do = do / denom[:, None]\n delta = tl.sum(o * do, axis=1)\n # write-back\n tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)\n tl.store(Delta + off_m, delta)\n\n\n@triton.jit\ndef _bwd_kernel(\n Q, K, V, sm_scale, Out, DO,\n DQ, DK, DV,\n L, M,\n D,\n stride_qz, stride_qh, stride_qm, stride_qk,\n stride_kz, stride_kh, stride_kn, stride_kk,\n stride_vz, stride_vh, stride_vk, stride_vn,\n Z, H, N_CTX,\n num_block,\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n off_hz = tl.program_id(0)\n off_z = off_hz // H\n off_h = off_hz % H\n # offset pointers for batch/head\n Q += off_z * stride_qz + off_h * stride_qh\n K += off_z * stride_qz + off_h * stride_qh\n V += off_z * stride_qz + off_h * stride_qh\n DO += off_z * stride_qz + off_h * stride_qh\n DQ += off_z * stride_qz + off_h * stride_qh\n DK += off_z * stride_qz + off_h * stride_qh\n DV += off_z * stride_qz + off_h * stride_qh\n for start_n in range(0, num_block):\n lo = start_n * BLOCK_M\n # initialize row/col offsets\n offs_qm = lo + tl.arange(0, BLOCK_M)\n offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_m = tl.arange(0, BLOCK_N)\n offs_k = tl.arange(0, BLOCK_DMODEL)\n # initialize pointers to value-like data\n q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)\n v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n # pointer to row-wise quantities in value-like data\n D_ptrs = D + off_hz * N_CTX\n m_ptrs = M + off_hz * N_CTX\n # initialize dv amd dk\n dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n # k and v stay in SRAM throughout\n k = tl.load(k_ptrs)\n v = tl.load(v_ptrs)\n # loop over rows\n for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):\n offs_m_curr = start_m + offs_m\n # load q, k, v, do on-chip\n q = tl.load(q_ptrs)\n # recompute p = softmax(qk, dim=-1).T\n # NOTE: `do` is pre-divided by `l`; no normalization here\n qk = tl.dot(q, k, trans_b=True)\n qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float(\"-inf\"))\n m = tl.load(m_ptrs + offs_m_curr)\n p = tl.exp(qk * sm_scale - m[:, None])\n # compute dv\n do = tl.load(do_ptrs)\n dv += tl.dot(p.to(tl.float16), do, trans_a=True)\n # compute dp = dot(v, do)\n Di = tl.load(D_ptrs + offs_m_curr)\n dp = 
tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]\n dp += tl.dot(do, v, trans_b=True)\n # compute ds = p * (dp - delta[:, None])\n ds = p * dp * sm_scale\n # compute dk = dot(ds.T, q)\n dk += tl.dot(ds.to(tl.float16), q, trans_a=True)\n # # compute dq\n dq = tl.load(dq_ptrs, eviction_policy=\"evict_last\")\n dq += tl.dot(ds.to(tl.float16), k)\n tl.store(dq_ptrs, dq, eviction_policy=\"evict_last\")\n # # increment pointers\n dq_ptrs += BLOCK_M * stride_qm\n q_ptrs += BLOCK_M * stride_qm\n do_ptrs += BLOCK_M * stride_qm\n # write-back\n dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)\n tl.store(dv_ptrs, dv)\n tl.store(dk_ptrs, dk)\n\n\nclass _attention(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, q, k, v, sm_scale):\n BLOCK = 128\n # shape constraints\n Lq, Lk = q.shape[-1], k.shape[-1]\n assert Lq == Lk\n o = torch.empty_like(q)\n grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])\n tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)\n L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)\n m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)\n _fwd_kernel[grid](\n q, k, v, sm_scale,\n tmp, L, m,\n o,\n q.stride(0), q.stride(1), q.stride(2), q.stride(3),\n k.stride(0), k.stride(1), k.stride(2), k.stride(3),\n v.stride(0), v.stride(1), v.stride(2), v.stride(3),\n o.stride(0), o.stride(1), o.stride(2), o.stride(3),\n q.shape[0], q.shape[1], q.shape[2],\n BLOCK_M=BLOCK, BLOCK_N=BLOCK,\n BLOCK_DMODEL=64, num_warps=4,\n num_stages=1,\n )\n ctx.save_for_backward(q, k, v, o, L, m)\n ctx.BLOCK = BLOCK\n ctx.grid = grid\n ctx.sm_scale = sm_scale\n ctx.BLOCK_DMODEL = 64\n return o\n\n @staticmethod\n def backward(ctx, do):\n q, k, v, o, l, m = ctx.saved_tensors\n do = do.contiguous()\n dq = torch.zeros_like(q, dtype=torch.float32)\n dk = torch.empty_like(k)\n dv = torch.empty_like(v)\n do_scaled = torch.empty_like(do)\n delta = torch.empty_like(l)\n _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](\n o, do, l,\n do_scaled, delta,\n BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,\n )\n _bwd_kernel[(ctx.grid[1],)](\n q, k, v, ctx.sm_scale,\n o, do_scaled,\n dq, dk, dv,\n l, m,\n delta,\n q.stride(0), q.stride(1), q.stride(2), q.stride(3),\n k.stride(0), k.stride(1), k.stride(2), k.stride(3),\n v.stride(0), v.stride(1), v.stride(2), v.stride(3),\n q.shape[0], q.shape[1], q.shape[2],\n ctx.grid[0],\n BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK,\n BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,\n num_stages=1,\n )\n return dq, dk, dv, None\n\n\nattention = _attention.apply\n\n\n@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 2048, 64)])\ndef test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):\n torch.manual_seed(20)\n q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device=\"cuda\").normal_(mean=0, std=.5).requires_grad_()\n k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device=\"cuda\").normal_(mean=0, std=.5).requires_grad_()\n v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device=\"cuda\").normal_(mean=0, std=.5).requires_grad_()\n sm_scale = 0.3\n dout = torch.randn_like(q)\n # reference implementation\n M = torch.tril(torch.ones((N_CTX, N_CTX), device=\"cuda\"))\n p = torch.matmul(q, k.transpose(2, 3)) * sm_scale\n for z in range(Z):\n for h in range(H):\n p[:, :, M == 0] = float(\"-inf\")\n p = torch.softmax(p.float(), 
dim=-1).half()\n ref_out = torch.matmul(p, v)\n ref_out.backward(dout)\n ref_dv, v.grad = v.grad.clone(), None\n ref_dk, k.grad = k.grad.clone(), None\n ref_dq, q.grad = q.grad.clone(), None\n # triton implementation\n tri_out = attention(q, k, v, sm_scale)\n tri_out.backward(dout)\n tri_dv, v.grad = v.grad.clone(), None\n tri_dk, k.grad = k.grad.clone(), None\n tri_dq, q.grad = q.grad.clone(), None\n # compare\n triton.testing.assert_almost_equal(ref_out, tri_out)\n triton.testing.assert_almost_equal(ref_dv, tri_dv)\n triton.testing.assert_almost_equal(ref_dk, tri_dk)\n triton.testing.assert_almost_equal(ref_dq, tri_dq)\n\n\ntry:\n from flash_attn.flash_attn_interface import flash_attn_func\n HAS_FLASH = True\nexcept BaseException:\n HAS_FLASH = False\n\nBATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64\n# vary seq length for fixed head and batch=4\nconfigs = [triton.testing.Benchmark(\n x_names=['N_CTX'],\n x_vals=[2**i for i in range(10, 16)],\n line_arg='provider',\n line_vals=['triton'] + (['flash'] if HAS_FLASH else []),\n line_names=['Triton'] + (['Flash'] if HAS_FLASH else []),\n styles=[('red', '-'), ('blue', '-')],\n ylabel='ms',\n plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',\n args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode}\n) for mode in ['bwd']]\n\n\n@triton.testing.perf_report(configs)\ndef bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.float16, device=\"cuda\"):\n assert mode in ['fwd', 'bwd']\n warmup = 25\n rep = 100\n if provider == \"triton\":\n q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device=\"cuda\", requires_grad=True)\n k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device=\"cuda\", requires_grad=True)\n v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device=\"cuda\", requires_grad=True)\n sm_scale = 1.3\n fn = lambda: attention(q, k, v, sm_scale)\n if mode == 'bwd':\n o = fn()\n do = torch.randn_like(o)\n fn = lambda: o.backward(do, retain_graph=True)\n ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)\n return ms\n if provider == \"flash\":\n lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)\n cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32)\n cu_seqlens[1:] = lengths.cumsum(0)\n qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True)\n fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True)\n if mode == 'bwd':\n o = fn()\n do = torch.randn_like(o)\n fn = lambda: o.backward(do, retain_graph=True)\n ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)\n return ms\n\n# only works on A100 at the moment\n# bench_flash_attention.run(save_path='.', print_data=True)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/master/_downloads/3ff29f967ace7985da24aab10352fc76/07-libdevice-function.py b/master/_downloads/3ff29f967ace7985da24aab10352fc76/07-libdevice-function.py new file mode 100644 index 000000000..bb5f7b26d --- /dev/null +++ 
b/master/_downloads/3ff29f967ace7985da24aab10352fc76/07-libdevice-function.py @@ -0,0 +1,74 @@ +""" +Libdevice function +================== +Triton can invoke a custom function from an external library. +In this example, we will use the `libdevice` library to apply `asin` on a tensor. +Please refer to https://docs.nvidia.com/cuda/libdevice-users-guide/index.html regarding the semantics of all available libdevice functions. + +In `triton/language/libdevice.py`, we try to aggregate functions with the same computation but different data types together. +For example, both `__nv_asin` and `__nv_asinf` calculate the principal value of the arc sine of the input, but `__nv_asin` operates on `double` and `__nv_asinf` operates on `float`. +Using Triton, you can simply call `tl.libdevice.asin`. +Triton automatically selects the correct underlying device function to invoke based on input and output types. +""" + +# %% +# asin Kernel +# -------------------------- + +import torch + +import triton +import triton.language as tl + + +@triton.jit +def asin_kernel( + x_ptr, + y_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + x = tl.libdevice.asin(x) + tl.store(y_ptr + offsets, x, mask=mask) + +# %% +# Using the default libdevice library path +# -------------------------- +# We can use the default libdevice library path encoded in `triton/language/libdevice.py` + + +torch.manual_seed(0) +size = 98432 +x = torch.rand(size, device='cuda') +output_triton = torch.zeros(size, device='cuda') +output_torch = torch.asin(x) +assert x.is_cuda and output_triton.is_cuda +n_elements = output_torch.numel() +grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) +asin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024) +print(output_torch) +print(output_triton) +print( + f'The maximum difference between torch and triton is ' + f'{torch.max(torch.abs(output_torch - output_triton))}' +) + +# %% +# Customize the libdevice library path +# -------------------------- +# We can also customize the libdevice library path by passing the path to the `libdevice` library to the `asin` kernel. 
+ +output_triton = torch.empty_like(x) +asin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024, + extern_libs={'libdevice': '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'}) +print(output_torch) +print(output_triton) +print( + f'The maximum difference between torch and triton is ' + f'{torch.max(torch.abs(output_torch - output_triton))}' +) diff --git a/master/_downloads/54a35f6ec55f9746935b9566fb6bb1df/06-fused-attention.py b/master/_downloads/54a35f6ec55f9746935b9566fb6bb1df/06-fused-attention.py new file mode 100644 index 000000000..c19ee498a --- /dev/null +++ b/master/_downloads/54a35f6ec55f9746935b9566fb6bb1df/06-fused-attention.py @@ -0,0 +1,354 @@ +""" +Fused Attention +=============== +This is a Triton implementation of the Flash Attention algorithm +(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf) +""" + +import pytest +import torch + +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel( + Q, K, V, sm_scale, + TMP, L, M, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug + Out, + stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, + stride_vz, stride_vh, stride_vk, stride_vn, + stride_oz, stride_oh, stride_om, stride_on, + Z, H, N_CTX, + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, +): + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk + off_k = off_hz * stride_qh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk + off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk + # Initialize pointers to Q, K, V + q_ptrs = Q + off_q + k_ptrs = K + off_k + v_ptrs = V + off_v + # initialize pointer to m and l + t_ptrs = TMP + off_hz * N_CTX + offs_m + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # load q: it will stay in SRAM throughout + q = tl.load(q_ptrs) + # loop over k, v and update accumulator + for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + start_n * stride_kn) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k, trans_b=True) + qk *= sm_scale + qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float("-inf")) + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + beta = tl.exp(m_ij - m_i_new) + l_i_new = alpha * l_i + beta * l_ij + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + tl.store(t_ptrs, acc_scale) + acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(v_ptrs + start_n * stride_vk) + p = p.to(tl.float16) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + # rematerialize offsets to save registers + start_m = tl.program_id(0) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + # write back l and m + l_ptrs = L 
+ off_hz * N_CTX + offs_m + m_ptrs = M + off_hz * N_CTX + offs_m + tl.store(l_ptrs, l_i) + tl.store(m_ptrs, m_i) + # initialize pointers to output + offs_n = tl.arange(0, BLOCK_DMODEL) + off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on + out_ptrs = Out + off_o + tl.store(out_ptrs, acc) + + +@triton.jit +def _bwd_preprocess( + Out, DO, L, + NewDO, Delta, + BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr, +): + off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) + off_n = tl.arange(0, D_HEAD) + # load + o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) + do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) + denom = tl.load(L + off_m).to(tl.float32) + # compute + do = do / denom[:, None] + delta = tl.sum(o * do, axis=1) + # write-back + tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do) + tl.store(Delta + off_m, delta) + + +@triton.jit +def _bwd_kernel( + Q, K, V, sm_scale, Out, DO, + DQ, DK, DV, + L, M, + D, + stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, + stride_vz, stride_vh, stride_vk, stride_vn, + Z, H, N_CTX, + num_block, + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, +): + off_hz = tl.program_id(0) + off_z = off_hz // H + off_h = off_hz % H + # offset pointers for batch/head + Q += off_z * stride_qz + off_h * stride_qh + K += off_z * stride_qz + off_h * stride_qh + V += off_z * stride_qz + off_h * stride_qh + DO += off_z * stride_qz + off_h * stride_qh + DQ += off_z * stride_qz + off_h * stride_qh + DK += off_z * stride_qz + off_h * stride_qh + DV += off_z * stride_qz + off_h * stride_qh + for start_n in range(0, num_block): + lo = start_n * BLOCK_M + # initialize row/col offsets + offs_qm = lo + tl.arange(0, BLOCK_M) + offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M) + offs_m = tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, BLOCK_DMODEL) + # initialize pointers to value-like data + q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) + k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) + v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk) + do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) + dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) + # pointer to row-wise quantities in value-like data + D_ptrs = D + off_hz * N_CTX + m_ptrs = M + off_hz * N_CTX + # initialize dv amd dk + dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # k and v stay in SRAM throughout + k = tl.load(k_ptrs) + v = tl.load(v_ptrs) + # loop over rows + for start_m in range(lo, num_block * BLOCK_M, BLOCK_M): + offs_m_curr = start_m + offs_m + # load q, k, v, do on-chip + q = tl.load(q_ptrs) + # recompute p = softmax(qk, dim=-1).T + # NOTE: `do` is pre-divided by `l`; no normalization here + qk = tl.dot(q, k, trans_b=True) + qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf")) + m = tl.load(m_ptrs + offs_m_curr) + p = tl.exp(qk * sm_scale - m[:, None]) + # compute dv + do = tl.load(do_ptrs) + dv += tl.dot(p.to(tl.float16), do, trans_a=True) + # compute dp = dot(v, do) + Di = tl.load(D_ptrs + offs_m_curr) + dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None] + dp += tl.dot(do, v, trans_b=True) + # compute ds = p * (dp - delta[:, None]) + ds = p * dp * sm_scale + # compute dk = dot(ds.T, q) + dk += tl.dot(ds.to(tl.float16), 
q, trans_a=True) + # # compute dq + dq = tl.load(dq_ptrs, eviction_policy="evict_last") + dq += tl.dot(ds.to(tl.float16), k) + tl.store(dq_ptrs, dq, eviction_policy="evict_last") + # # increment pointers + dq_ptrs += BLOCK_M * stride_qm + q_ptrs += BLOCK_M * stride_qm + do_ptrs += BLOCK_M * stride_qm + # write-back + dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk) + dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) + tl.store(dv_ptrs, dv) + tl.store(dk_ptrs, dk) + + +class _attention(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, k, v, sm_scale): + BLOCK = 128 + # shape constraints + Lq, Lk = q.shape[-1], k.shape[-1] + assert Lq == Lk + o = torch.empty_like(q) + grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1]) + tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + _fwd_kernel[grid]( + q, k, v, sm_scale, + tmp, L, m, + o, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + v.stride(0), v.stride(1), v.stride(2), v.stride(3), + o.stride(0), o.stride(1), o.stride(2), o.stride(3), + q.shape[0], q.shape[1], q.shape[2], + BLOCK_M=BLOCK, BLOCK_N=BLOCK, + BLOCK_DMODEL=64, num_warps=4, + num_stages=1, + ) + ctx.save_for_backward(q, k, v, o, L, m) + ctx.BLOCK = BLOCK + ctx.grid = grid + ctx.sm_scale = sm_scale + ctx.BLOCK_DMODEL = 64 + return o + + @staticmethod + def backward(ctx, do): + q, k, v, o, l, m = ctx.saved_tensors + do = do.contiguous() + dq = torch.zeros_like(q, dtype=torch.float32) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + do_scaled = torch.empty_like(do) + delta = torch.empty_like(l) + _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )]( + o, do, l, + do_scaled, delta, + BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL, + ) + _bwd_kernel[(ctx.grid[1],)]( + q, k, v, ctx.sm_scale, + o, do_scaled, + dq, dk, dv, + l, m, + delta, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + v.stride(0), v.stride(1), v.stride(2), v.stride(3), + q.shape[0], q.shape[1], q.shape[2], + ctx.grid[0], + BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK, + BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8, + num_stages=1, + ) + return dq, dk, dv, None + + +attention = _attention.apply + + +@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 2048, 64)]) +def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16): + torch.manual_seed(20) + q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_() + k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_() + v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_() + sm_scale = 0.3 + dout = torch.randn_like(q) + # reference implementation + M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) + p = torch.matmul(q, k.transpose(2, 3)) * sm_scale + for z in range(Z): + for h in range(H): + p[:, :, M == 0] = float("-inf") + p = torch.softmax(p.float(), dim=-1).half() + ref_out = torch.matmul(p, v) + ref_out.backward(dout) + ref_dv, v.grad = v.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dq, q.grad = q.grad.clone(), None + # triton implementation + tri_out = attention(q, k, v, 
sm_scale) + tri_out.backward(dout) + tri_dv, v.grad = v.grad.clone(), None + tri_dk, k.grad = k.grad.clone(), None + tri_dq, q.grad = q.grad.clone(), None + # compare + triton.testing.assert_almost_equal(ref_out, tri_out) + triton.testing.assert_almost_equal(ref_dv, tri_dv) + triton.testing.assert_almost_equal(ref_dk, tri_dk) + triton.testing.assert_almost_equal(ref_dq, tri_dq) + + +try: + from flash_attn.flash_attn_interface import flash_attn_func + HAS_FLASH = True +except BaseException: + HAS_FLASH = False + +BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64 +# vary seq length for fixed head and batch=4 +configs = [triton.testing.Benchmark( + x_names=['N_CTX'], + x_vals=[2**i for i in range(10, 16)], + line_arg='provider', + line_vals=['triton'] + (['flash'] if HAS_FLASH else []), + line_names=['Triton'] + (['Flash'] if HAS_FLASH else []), + styles=[('red', '-'), ('blue', '-')], + ylabel='ms', + plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}', + args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode} +) for mode in ['bwd']] + + +@triton.testing.perf_report(configs) +def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.float16, device="cuda"): + assert mode in ['fwd', 'bwd'] + warmup = 25 + rep = 100 + if provider == "triton": + q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + sm_scale = 1.3 + fn = lambda: attention(q, k, v, sm_scale) + if mode == 'bwd': + o = fn() + do = torch.randn_like(o) + fn = lambda: o.backward(do, retain_graph=True) + ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep) + return ms + if provider == "flash": + lengths = torch.full((BATCH,), fill_value=N_CTX, device=device) + cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32) + cu_seqlens[1:] = lengths.cumsum(0) + qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True) + fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True) + if mode == 'bwd': + o = fn() + do = torch.randn_like(o) + fn = lambda: o.backward(do, retain_graph=True) + ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep) + return ms + +# only works on A100 at the moment +# bench_flash_attention.run(save_path='.', print_data=True) diff --git a/master/_downloads/662999063954282841dc90b8945f85ce/tutorials_jupyter.zip b/master/_downloads/662999063954282841dc90b8945f85ce/tutorials_jupyter.zip index 6dc0a1b1e..25a09311b 100644 Binary files a/master/_downloads/662999063954282841dc90b8945f85ce/tutorials_jupyter.zip and b/master/_downloads/662999063954282841dc90b8945f85ce/tutorials_jupyter.zip differ diff --git a/master/_downloads/763344228ae6bc253ed1a6cf586aa30d/tutorials_python.zip b/master/_downloads/763344228ae6bc253ed1a6cf586aa30d/tutorials_python.zip index e46aa1875..cafc41408 100644 Binary files a/master/_downloads/763344228ae6bc253ed1a6cf586aa30d/tutorials_python.zip and b/master/_downloads/763344228ae6bc253ed1a6cf586aa30d/tutorials_python.zip differ diff --git a/master/_downloads/935c0dd0fbeb4b2e69588471cbb2d4b2/05-layer-norm.py b/master/_downloads/935c0dd0fbeb4b2e69588471cbb2d4b2/05-layer-norm.py index 9880b428f..333cb80ec 100644 --- a/master/_downloads/935c0dd0fbeb4b2e69588471cbb2d4b2/05-layer-norm.py +++ 
b/master/_downloads/935c0dd0fbeb4b2e69588471cbb2d4b2/05-layer-norm.py @@ -128,17 +128,19 @@ def _layer_norm_bwd_dwdb( cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for i in range(0, M, BLOCK_SIZE_M): - rows = i + tl.arange(0, BLOCK_SIZE_M) - mask = (rows[:, None] < M) & (cols[None, :] < N) - offs = rows[:, None] * N + cols[None, :] - a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32) - dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32) - mean = tl.load(Mean + rows, mask=rows < M, other=0.) - rstd = tl.load(Var + rows, mask=rows < M, other=0.) - a_hat = (a - mean[:, None]) * rstd[:, None] - dw += dout * a_hat - db += dout + UNROLL: tl.constexpr = 4 + for i in range(0, M, BLOCK_SIZE_M * UNROLL): + for j in range(UNROLL): + rows = i + j * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + mask = (rows[:, None] < M) & (cols[None, :] < N) + offs = rows[:, None] * N + cols[None, :] + a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32) + dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32) + mean = tl.load(Mean + rows, mask=rows < M, other=0.) + rstd = tl.load(Var + rows, mask=rows < M, other=0.) + a_hat = (a - mean[:, None]) * rstd[:, None] + dw += dout * a_hat + db += dout sum_dw = tl.sum(dw, axis=0) sum_db = tl.sum(db, axis=0) tl.store(DW + cols, sum_dw, mask=cols < N) @@ -211,7 +213,15 @@ class LayerNorm(torch.autograd.Function): BLOCK_SIZE_N=ctx.BLOCK_SIZE, num_warps=ctx.num_warps, ) - # accumulate partial sums in separate kernel + if N > 10240: + BLOCK_SIZE_N = 128 + BLOCK_SIZE_M = 32 + num_warps = 4 + else: + # maximize occupancy for small N + BLOCK_SIZE_N = 16 + BLOCK_SIZE_M = 16 + num_warps = 8 grid = lambda meta: [triton.cdiv(N, meta["BLOCK_SIZE_N"])] _layer_norm_bwd_dwdb[grid]( a, dout, @@ -220,17 +230,11 @@ class LayerNorm(torch.autograd.Function): dbias, M, N, - BLOCK_SIZE_M=32, - BLOCK_SIZE_N=128, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + num_warps=num_warps ) - return (da, None, dweight, dbias, None, None, - None, None, None, None, - None, - None, None, None, - None, - None, None, None, - None, None, None, - None, None, None) + return (da, None, dweight, dbias, None) def layer_norm(a, normalized_shape, weight, bias, eps): diff --git a/master/_downloads/ae7fff29e1b574187bc930ed94bcc353/05-layer-norm.ipynb b/master/_downloads/ae7fff29e1b574187bc930ed94bcc353/05-layer-norm.ipynb index 0930144cf..6838e6470 100644 --- a/master/_downloads/ae7fff29e1b574187bc930ed94bcc353/05-layer-norm.ipynb +++ b/master/_downloads/ae7fff29e1b574187bc930ed94bcc353/05-layer-norm.ipynb @@ -26,7 +26,7 @@ }, "outputs": [], "source": [ - "import torch\n\nimport triton\nimport triton.language as tl\n\ntry:\n # This is https://github.com/NVIDIA/apex, NOT the apex on PyPi, so it\n # should not be added to extras_require in setup.py.\n import apex\n HAS_APEX = True\nexcept ModuleNotFoundError:\n HAS_APEX = False\n\n\n@triton.jit\ndef _layer_norm_fwd_fused(\n Out,\n A,\n Weight,\n Bias,\n Mean, Rstd,\n stride, N, eps,\n BLOCK_SIZE: tl.constexpr,\n):\n # position of elements processed by this program\n row = tl.program_id(0)\n Out += row * stride\n A += row * stride\n # compute mean\n mean = 0\n _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n a = tl.load(A + cols, mask=cols < N, other=0., eviction_policy=\"evict_last\").to(tl.float32)\n _mean += a\n mean = 
tl.sum(_mean, axis=0) / N\n # compute variance\n _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n a = tl.load(A + cols, mask=cols < N, other=0., eviction_policy=\"evict_last\").to(tl.float32)\n a = tl.where(cols < N, a - mean, 0.)\n _var += a * a\n var = tl.sum(_var, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n # write-back mean/rstd\n tl.store(Mean + row, mean)\n tl.store(Rstd + row, rstd)\n # multiply by weight and add bias\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n mask = cols < N\n weight = tl.load(Weight + cols, mask=mask)\n bias = tl.load(Bias + cols, mask=mask)\n a = tl.load(A + cols, mask=mask, other=0., eviction_policy=\"evict_first\").to(tl.float32)\n a_hat = (a - mean) * rstd\n out = a_hat * weight + bias\n # # write-back\n tl.store(Out + cols, out, mask=mask)\n\n# Backward pass (DA + partial DW + partial DB)\n\n\n@triton.jit\ndef _layer_norm_bwd_dx_fused(\n _DA,\n _DOut,\n _A,\n Weight,\n Mean, Rstd,\n stride, NumRows, NumCols, eps,\n BLOCK_SIZE_N: tl.constexpr,\n):\n # position of elements processed by this program\n pid = tl.program_id(0)\n row = pid\n A = _A + row * stride\n DOut = _DOut + row * stride\n DA = _DA + row * stride\n mean = tl.load(Mean + row)\n rstd = tl.load(Rstd + row)\n # load data to SRAM\n _mean1 = tl.zeros([BLOCK_SIZE_N], dtype=tl.float32)\n _mean2 = tl.zeros([BLOCK_SIZE_N], dtype=tl.float32)\n for off in range(0, NumCols, BLOCK_SIZE_N):\n cols = off + tl.arange(0, BLOCK_SIZE_N)\n mask = cols < NumCols\n a = tl.load(A + cols, mask=mask, other=0).to(tl.float32)\n dout = tl.load(DOut + cols, mask=mask, other=0).to(tl.float32)\n weight = tl.load(Weight + cols, mask=mask, other=0).to(tl.float32)\n a_hat = (a - mean) * rstd\n wdout = weight * dout\n _mean1 += a_hat * wdout\n _mean2 += wdout\n mean1 = tl.sum(_mean1, axis=0) / NumCols\n mean2 = 0.\n mean2 = tl.sum(_mean2, axis=0) / NumCols\n for off in range(0, NumCols, BLOCK_SIZE_N):\n cols = off + tl.arange(0, BLOCK_SIZE_N)\n mask = cols < NumCols\n a = tl.load(A + cols, mask=mask, other=0).to(tl.float32)\n dout = tl.load(DOut + cols, mask=mask, other=0).to(tl.float32)\n weight = tl.load(Weight + cols, mask=mask, other=0).to(tl.float32)\n a_hat = (a - mean) * rstd\n wdout = weight * dout\n da = (wdout - (a_hat * mean1 + mean2)) * rstd\n # write-back dx\n tl.store(DA + cols, da, mask=mask)\n\n\n# Backward pass (total DW + total DB)\n@triton.jit\ndef _layer_norm_bwd_dwdb(\n A, DOut,\n Mean, Var,\n DW,\n DB,\n M, N,\n BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr,\n):\n pid = tl.program_id(0)\n cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for i in range(0, M, BLOCK_SIZE_M):\n rows = i + tl.arange(0, BLOCK_SIZE_M)\n mask = (rows[:, None] < M) & (cols[None, :] < N)\n offs = rows[:, None] * N + cols[None, :]\n a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32)\n dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32)\n mean = tl.load(Mean + rows, mask=rows < M, other=0.)\n rstd = tl.load(Var + rows, mask=rows < M, other=0.)\n a_hat = (a - mean[:, None]) * rstd[:, None]\n dw += dout * a_hat\n db += dout\n sum_dw = tl.sum(dw, axis=0)\n sum_db = tl.sum(db, axis=0)\n tl.store(DW + cols, sum_dw, mask=cols < N)\n tl.store(DB + cols, sum_db, mask=cols < N)\n\n\nclass LayerNorm(torch.autograd.Function):\n @staticmethod\n def forward(ctx, a, 
normalized_shape, weight, bias, eps):\n # allocate output\n out = torch.empty_like(a)\n # reshape input data into 2D tensor\n a_arg = a.reshape(-1, a.shape[-1])\n M, N = a_arg.shape\n mean = torch.empty((M,), dtype=torch.float32, device=\"cuda\")\n rstd = torch.empty((M,), dtype=torch.float32, device=\"cuda\")\n # Less than 64KB per feature: enqueue fused kernel\n MAX_FUSED_SIZE = 65536 // a.element_size()\n BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n BLOCK_SIZE = max(BLOCK_SIZE, 128)\n BLOCK_SIZE = min(BLOCK_SIZE, 4096)\n # heuristics for number of warps\n num_warps = min(max(BLOCK_SIZE // 256, 1), 8)\n _layer_norm_fwd_fused[(M,)](\n out,\n a_arg,\n weight,\n bias,\n mean, rstd,\n a_arg.stride(0), N, eps,\n BLOCK_SIZE=BLOCK_SIZE,\n num_warps=num_warps,\n )\n ctx.save_for_backward(\n a, weight, bias, mean, rstd,\n )\n ctx.BLOCK_SIZE = BLOCK_SIZE\n ctx.num_warps = num_warps\n ctx.eps = eps\n if hasattr(bias, \"config\"):\n assert bias.config.grad_scale_name == weight.config.grad_scale_name\n grad_scale_name = bias.config.grad_scale_name\n else:\n grad_scale_name = None\n ctx.grad_scale_gain_bias_name = grad_scale_name\n return out\n\n @staticmethod\n def backward(ctx, dout):\n assert dout.is_contiguous()\n a, weight, bias, mean, var = ctx.saved_tensors\n # heuristics for amount of parallel reduction stream for DG/DB\n N = weight.shape[0]\n # allocate output\n da = torch.empty_like(dout)\n # enqueue kernel using forward pass heuristics\n # also compute partial sums for DW and DB\n x_arg = a.reshape(-1, a.shape[-1])\n M, N = x_arg.shape\n dweight = torch.empty((weight.shape[0],), dtype=weight.dtype, device=weight.device)\n dbias = torch.empty((weight.shape[0],), dtype=weight.dtype, device=weight.device)\n _layer_norm_bwd_dx_fused[(M,)](\n da,\n dout,\n a,\n weight,\n mean, var,\n x_arg.stride(0), M, N,\n ctx.eps,\n BLOCK_SIZE_N=ctx.BLOCK_SIZE,\n num_warps=ctx.num_warps,\n )\n # accumulate partial sums in separate kernel\n grid = lambda meta: [triton.cdiv(N, meta[\"BLOCK_SIZE_N\"])]\n _layer_norm_bwd_dwdb[grid](\n a, dout,\n mean, var,\n dweight,\n dbias,\n M,\n N,\n BLOCK_SIZE_M=32,\n BLOCK_SIZE_N=128,\n )\n return (da, None, dweight, dbias, None, None,\n None, None, None, None,\n None,\n None, None, None,\n None,\n None, None, None,\n None, None, None,\n None, None, None)\n\n\ndef layer_norm(a, normalized_shape, weight, bias, eps):\n return LayerNorm.apply(a, normalized_shape, weight, bias, eps)\n\n\ndef test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):\n torch.manual_seed(0)\n # create data\n x_shape = (M, N)\n w_shape = (x_shape[-1], )\n weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)\n bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)\n x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')\n dy = .1 * torch.randn_like(x)\n x.requires_grad_(True)\n # forward pass\n y_tri = layer_norm(x, w_shape, weight, bias, eps)\n y_ref = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype)\n # backward pass (triton)\n y_tri.backward(dy, retain_graph=True)\n dx_tri, dw_tri, db_tri = [_.grad.clone() for _ in [x, weight, bias]]\n x.grad, weight.grad, bias.grad = None, None, None\n # backward pass (torch)\n y_ref.backward(dy, retain_graph=True)\n dx_ref, dw_ref, db_ref = [_.grad.clone() for _ in [x, weight, bias]]\n # compare\n triton.testing.assert_almost_equal(y_tri, y_ref)\n triton.testing.assert_almost_equal(dx_tri, dx_ref)\n triton.testing.assert_almost_equal(db_tri, db_ref, decimal=1)\n 
triton.testing.assert_almost_equal(dw_tri, dw_ref, decimal=1)\n\n\n@triton.testing.perf_report(\n triton.testing.Benchmark(\n x_names=['N'],\n x_vals=[512 * i for i in range(2, 32)],\n line_arg='provider',\n line_vals=['triton', 'torch'] + (['apex'] if HAS_APEX else []),\n line_names=['Triton', 'Torch'] + (['Apex'] if HAS_APEX else []),\n styles=[('blue', '-'), ('green', '-'), ('orange', '-')],\n ylabel='GB/s',\n plot_name='layer-norm',\n args={'M': 4096, 'dtype': torch.float16, 'mode': 'forward'}\n )\n)\ndef bench_layer_norm(M, N, dtype, provider, mode, eps=1e-5, device='cuda'):\n # create data\n x_shape = (M, N)\n w_shape = (x_shape[-1], )\n weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)\n bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)\n x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')\n dy = .1 * torch.randn_like(x)\n x.requires_grad_(True)\n # utility functions\n if provider == 'triton':\n y_fwd = lambda: layer_norm(x, w_shape, weight, bias, eps)\n if provider == 'torch':\n y_fwd = lambda: torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps)\n if provider == 'apex':\n apex_layer_norm = apex.normalization.FusedLayerNorm(w_shape).to(x.device).to(x.dtype)\n y_fwd = lambda: apex_layer_norm(x)\n # forward pass\n if mode == 'forward':\n gbps = lambda ms: 2 * x.numel() * x.element_size() / ms * 1e-6\n ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, rep=500)\n # backward pass\n if mode == 'backward':\n gbps = lambda ms: 3 * x.numel() * x.element_size() / ms * 1e-6\n y = y_fwd()\n ms, min_ms, max_ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True),\n grad_to_none=[x], rep=500)\n return gbps(ms), gbps(max_ms), gbps(min_ms)\n\n\n# test_layer_norm(1151, 8192, torch.float16)\nbench_layer_norm.run(save_path='.', print_data=True)" + "import torch\n\nimport triton\nimport triton.language as tl\n\ntry:\n # This is https://github.com/NVIDIA/apex, NOT the apex on PyPi, so it\n # should not be added to extras_require in setup.py.\n import apex\n HAS_APEX = True\nexcept ModuleNotFoundError:\n HAS_APEX = False\n\n\n@triton.jit\ndef _layer_norm_fwd_fused(\n Out,\n A,\n Weight,\n Bias,\n Mean, Rstd,\n stride, N, eps,\n BLOCK_SIZE: tl.constexpr,\n):\n # position of elements processed by this program\n row = tl.program_id(0)\n Out += row * stride\n A += row * stride\n # compute mean\n mean = 0\n _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n a = tl.load(A + cols, mask=cols < N, other=0., eviction_policy=\"evict_last\").to(tl.float32)\n _mean += a\n mean = tl.sum(_mean, axis=0) / N\n # compute variance\n _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n a = tl.load(A + cols, mask=cols < N, other=0., eviction_policy=\"evict_last\").to(tl.float32)\n a = tl.where(cols < N, a - mean, 0.)\n _var += a * a\n var = tl.sum(_var, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n # write-back mean/rstd\n tl.store(Mean + row, mean)\n tl.store(Rstd + row, rstd)\n # multiply by weight and add bias\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n mask = cols < N\n weight = tl.load(Weight + cols, mask=mask)\n bias = tl.load(Bias + cols, mask=mask)\n a = tl.load(A + cols, mask=mask, other=0., eviction_policy=\"evict_first\").to(tl.float32)\n a_hat = (a - mean) * rstd\n out = a_hat * weight + bias\n # # write-back\n tl.store(Out + cols, out, 
mask=mask)\n\n# Backward pass (DA + partial DW + partial DB)\n\n\n@triton.jit\ndef _layer_norm_bwd_dx_fused(\n _DA,\n _DOut,\n _A,\n Weight,\n Mean, Rstd,\n stride, NumRows, NumCols, eps,\n BLOCK_SIZE_N: tl.constexpr,\n):\n # position of elements processed by this program\n pid = tl.program_id(0)\n row = pid\n A = _A + row * stride\n DOut = _DOut + row * stride\n DA = _DA + row * stride\n mean = tl.load(Mean + row)\n rstd = tl.load(Rstd + row)\n # load data to SRAM\n _mean1 = tl.zeros([BLOCK_SIZE_N], dtype=tl.float32)\n _mean2 = tl.zeros([BLOCK_SIZE_N], dtype=tl.float32)\n for off in range(0, NumCols, BLOCK_SIZE_N):\n cols = off + tl.arange(0, BLOCK_SIZE_N)\n mask = cols < NumCols\n a = tl.load(A + cols, mask=mask, other=0).to(tl.float32)\n dout = tl.load(DOut + cols, mask=mask, other=0).to(tl.float32)\n weight = tl.load(Weight + cols, mask=mask, other=0).to(tl.float32)\n a_hat = (a - mean) * rstd\n wdout = weight * dout\n _mean1 += a_hat * wdout\n _mean2 += wdout\n mean1 = tl.sum(_mean1, axis=0) / NumCols\n mean2 = 0.\n mean2 = tl.sum(_mean2, axis=0) / NumCols\n for off in range(0, NumCols, BLOCK_SIZE_N):\n cols = off + tl.arange(0, BLOCK_SIZE_N)\n mask = cols < NumCols\n a = tl.load(A + cols, mask=mask, other=0).to(tl.float32)\n dout = tl.load(DOut + cols, mask=mask, other=0).to(tl.float32)\n weight = tl.load(Weight + cols, mask=mask, other=0).to(tl.float32)\n a_hat = (a - mean) * rstd\n wdout = weight * dout\n da = (wdout - (a_hat * mean1 + mean2)) * rstd\n # write-back dx\n tl.store(DA + cols, da, mask=mask)\n\n\n# Backward pass (total DW + total DB)\n@triton.jit\ndef _layer_norm_bwd_dwdb(\n A, DOut,\n Mean, Var,\n DW,\n DB,\n M, N,\n BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr,\n):\n pid = tl.program_id(0)\n cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n UNROLL: tl.constexpr = 4\n for i in range(0, M, BLOCK_SIZE_M * UNROLL):\n for j in range(UNROLL):\n rows = i + j * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n mask = (rows[:, None] < M) & (cols[None, :] < N)\n offs = rows[:, None] * N + cols[None, :]\n a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32)\n dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32)\n mean = tl.load(Mean + rows, mask=rows < M, other=0.)\n rstd = tl.load(Var + rows, mask=rows < M, other=0.)\n a_hat = (a - mean[:, None]) * rstd[:, None]\n dw += dout * a_hat\n db += dout\n sum_dw = tl.sum(dw, axis=0)\n sum_db = tl.sum(db, axis=0)\n tl.store(DW + cols, sum_dw, mask=cols < N)\n tl.store(DB + cols, sum_db, mask=cols < N)\n\n\nclass LayerNorm(torch.autograd.Function):\n @staticmethod\n def forward(ctx, a, normalized_shape, weight, bias, eps):\n # allocate output\n out = torch.empty_like(a)\n # reshape input data into 2D tensor\n a_arg = a.reshape(-1, a.shape[-1])\n M, N = a_arg.shape\n mean = torch.empty((M,), dtype=torch.float32, device=\"cuda\")\n rstd = torch.empty((M,), dtype=torch.float32, device=\"cuda\")\n # Less than 64KB per feature: enqueue fused kernel\n MAX_FUSED_SIZE = 65536 // a.element_size()\n BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n BLOCK_SIZE = max(BLOCK_SIZE, 128)\n BLOCK_SIZE = min(BLOCK_SIZE, 4096)\n # heuristics for number of warps\n num_warps = min(max(BLOCK_SIZE // 256, 1), 8)\n _layer_norm_fwd_fused[(M,)](\n out,\n a_arg,\n weight,\n bias,\n mean, rstd,\n a_arg.stride(0), N, eps,\n BLOCK_SIZE=BLOCK_SIZE,\n num_warps=num_warps,\n )\n ctx.save_for_backward(\n 
a, weight, bias, mean, rstd,\n )\n ctx.BLOCK_SIZE = BLOCK_SIZE\n ctx.num_warps = num_warps\n ctx.eps = eps\n if hasattr(bias, \"config\"):\n assert bias.config.grad_scale_name == weight.config.grad_scale_name\n grad_scale_name = bias.config.grad_scale_name\n else:\n grad_scale_name = None\n ctx.grad_scale_gain_bias_name = grad_scale_name\n return out\n\n @staticmethod\n def backward(ctx, dout):\n assert dout.is_contiguous()\n a, weight, bias, mean, var = ctx.saved_tensors\n # heuristics for amount of parallel reduction stream for DG/DB\n N = weight.shape[0]\n # allocate output\n da = torch.empty_like(dout)\n # enqueue kernel using forward pass heuristics\n # also compute partial sums for DW and DB\n x_arg = a.reshape(-1, a.shape[-1])\n M, N = x_arg.shape\n dweight = torch.empty((weight.shape[0],), dtype=weight.dtype, device=weight.device)\n dbias = torch.empty((weight.shape[0],), dtype=weight.dtype, device=weight.device)\n _layer_norm_bwd_dx_fused[(M,)](\n da,\n dout,\n a,\n weight,\n mean, var,\n x_arg.stride(0), M, N,\n ctx.eps,\n BLOCK_SIZE_N=ctx.BLOCK_SIZE,\n num_warps=ctx.num_warps,\n )\n if N > 10240:\n BLOCK_SIZE_N = 128\n BLOCK_SIZE_M = 32\n num_warps = 4\n else:\n # maximize occupancy for small N\n BLOCK_SIZE_N = 16\n BLOCK_SIZE_M = 16\n num_warps = 8\n grid = lambda meta: [triton.cdiv(N, meta[\"BLOCK_SIZE_N\"])]\n _layer_norm_bwd_dwdb[grid](\n a, dout,\n mean, var,\n dweight,\n dbias,\n M,\n N,\n BLOCK_SIZE_M=BLOCK_SIZE_M,\n BLOCK_SIZE_N=BLOCK_SIZE_N,\n num_warps=num_warps\n )\n return (da, None, dweight, dbias, None)\n\n\ndef layer_norm(a, normalized_shape, weight, bias, eps):\n return LayerNorm.apply(a, normalized_shape, weight, bias, eps)\n\n\ndef test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):\n torch.manual_seed(0)\n # create data\n x_shape = (M, N)\n w_shape = (x_shape[-1], )\n weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)\n bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)\n x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')\n dy = .1 * torch.randn_like(x)\n x.requires_grad_(True)\n # forward pass\n y_tri = layer_norm(x, w_shape, weight, bias, eps)\n y_ref = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype)\n # backward pass (triton)\n y_tri.backward(dy, retain_graph=True)\n dx_tri, dw_tri, db_tri = [_.grad.clone() for _ in [x, weight, bias]]\n x.grad, weight.grad, bias.grad = None, None, None\n # backward pass (torch)\n y_ref.backward(dy, retain_graph=True)\n dx_ref, dw_ref, db_ref = [_.grad.clone() for _ in [x, weight, bias]]\n # compare\n triton.testing.assert_almost_equal(y_tri, y_ref)\n triton.testing.assert_almost_equal(dx_tri, dx_ref)\n triton.testing.assert_almost_equal(db_tri, db_ref, decimal=1)\n triton.testing.assert_almost_equal(dw_tri, dw_ref, decimal=1)\n\n\n@triton.testing.perf_report(\n triton.testing.Benchmark(\n x_names=['N'],\n x_vals=[512 * i for i in range(2, 32)],\n line_arg='provider',\n line_vals=['triton', 'torch'] + (['apex'] if HAS_APEX else []),\n line_names=['Triton', 'Torch'] + (['Apex'] if HAS_APEX else []),\n styles=[('blue', '-'), ('green', '-'), ('orange', '-')],\n ylabel='GB/s',\n plot_name='layer-norm',\n args={'M': 4096, 'dtype': torch.float16, 'mode': 'forward'}\n )\n)\ndef bench_layer_norm(M, N, dtype, provider, mode, eps=1e-5, device='cuda'):\n # create data\n x_shape = (M, N)\n w_shape = (x_shape[-1], )\n weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)\n bias = torch.rand(w_shape, dtype=dtype, 
device='cuda', requires_grad=True)\n x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')\n dy = .1 * torch.randn_like(x)\n x.requires_grad_(True)\n # utility functions\n if provider == 'triton':\n y_fwd = lambda: layer_norm(x, w_shape, weight, bias, eps)\n if provider == 'torch':\n y_fwd = lambda: torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps)\n if provider == 'apex':\n apex_layer_norm = apex.normalization.FusedLayerNorm(w_shape).to(x.device).to(x.dtype)\n y_fwd = lambda: apex_layer_norm(x)\n # forward pass\n if mode == 'forward':\n gbps = lambda ms: 2 * x.numel() * x.element_size() / ms * 1e-6\n ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, rep=500)\n # backward pass\n if mode == 'backward':\n gbps = lambda ms: 3 * x.numel() * x.element_size() / ms * 1e-6\n y = y_fwd()\n ms, min_ms, max_ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True),\n grad_to_none=[x], rep=500)\n return gbps(ms), gbps(max_ms), gbps(min_ms)\n\n\n# test_layer_norm(1151, 8192, torch.float16)\nbench_layer_norm.run(save_path='.', print_data=True)" ] } ], diff --git a/master/_images/sphx_glr_01-vector-add_001.png b/master/_images/sphx_glr_01-vector-add_001.png index 264f2bc3f..faedfc5df 100644 Binary files a/master/_images/sphx_glr_01-vector-add_001.png and b/master/_images/sphx_glr_01-vector-add_001.png differ diff --git a/master/_images/sphx_glr_01-vector-add_thumb.png b/master/_images/sphx_glr_01-vector-add_thumb.png index bf1427cde..3d6da8f70 100644 Binary files a/master/_images/sphx_glr_01-vector-add_thumb.png and b/master/_images/sphx_glr_01-vector-add_thumb.png differ diff --git a/master/_images/sphx_glr_02-fused-softmax_001.png b/master/_images/sphx_glr_02-fused-softmax_001.png index ca3de72d8..4ccac83a3 100644 Binary files a/master/_images/sphx_glr_02-fused-softmax_001.png and b/master/_images/sphx_glr_02-fused-softmax_001.png differ diff --git a/master/_images/sphx_glr_02-fused-softmax_thumb.png b/master/_images/sphx_glr_02-fused-softmax_thumb.png index 38856739b..b52715562 100644 Binary files a/master/_images/sphx_glr_02-fused-softmax_thumb.png and b/master/_images/sphx_glr_02-fused-softmax_thumb.png differ diff --git a/master/_images/sphx_glr_03-matrix-multiplication_001.png b/master/_images/sphx_glr_03-matrix-multiplication_001.png index 5a1c16378..61fb63b0d 100644 Binary files a/master/_images/sphx_glr_03-matrix-multiplication_001.png and b/master/_images/sphx_glr_03-matrix-multiplication_001.png differ diff --git a/master/_images/sphx_glr_03-matrix-multiplication_thumb.png b/master/_images/sphx_glr_03-matrix-multiplication_thumb.png index 3b0d7c0cf..8f4303601 100644 Binary files a/master/_images/sphx_glr_03-matrix-multiplication_thumb.png and b/master/_images/sphx_glr_03-matrix-multiplication_thumb.png differ diff --git a/master/_images/sphx_glr_05-layer-norm_001.png b/master/_images/sphx_glr_05-layer-norm_001.png index 5dbc57259..33c12cadf 100644 Binary files a/master/_images/sphx_glr_05-layer-norm_001.png and b/master/_images/sphx_glr_05-layer-norm_001.png differ diff --git a/master/_images/sphx_glr_05-layer-norm_thumb.png b/master/_images/sphx_glr_05-layer-norm_thumb.png index 98d623bcf..588cf049c 100644 Binary files a/master/_images/sphx_glr_05-layer-norm_thumb.png and b/master/_images/sphx_glr_05-layer-norm_thumb.png differ diff --git a/master/_images/sphx_glr_06-fused-attention_thumb.png b/master/_images/sphx_glr_06-fused-attention_thumb.png new file mode 100644 index 000000000..8a5fed589 Binary files /dev/null and 
b/master/_images/sphx_glr_06-fused-attention_thumb.png differ diff --git a/master/_images/sphx_glr_07-libdevice-function_thumb.png b/master/_images/sphx_glr_07-libdevice-function_thumb.png new file mode 100644 index 000000000..8a5fed589 Binary files /dev/null and b/master/_images/sphx_glr_07-libdevice-function_thumb.png differ diff --git a/master/_sources/getting-started/tutorials/01-vector-add.rst.txt b/master/_sources/getting-started/tutorials/01-vector-add.rst.txt index 20f8ec750..e3022bcf7 100644 --- a/master/_sources/getting-started/tutorials/01-vector-add.rst.txt +++ b/master/_sources/getting-started/tutorials/01-vector-add.rst.txt @@ -238,7 +238,7 @@ We can now run the decorated function above. Pass `print_data=True` to see the p 3 32768.0 76.800002 76.800002 4 65536.0 127.999995 127.999995 5 131072.0 219.428568 219.428568 - 6 262144.0 341.333321 384.000001 + 6 262144.0 341.333321 341.333321 7 524288.0 472.615390 472.615390 8 1048576.0 614.400016 614.400016 9 2097152.0 722.823517 722.823517 @@ -255,7 +255,7 @@ We can now run the decorated function above. Pass `print_data=True` to see the p .. rst-class:: sphx-glr-timing - **Total running time of the script:** ( 1 minutes 34.829 seconds) + **Total running time of the script:** ( 1 minutes 50.020 seconds) .. _sphx_glr_download_getting-started_tutorials_01-vector-add.py: diff --git a/master/_sources/getting-started/tutorials/02-fused-softmax.rst.txt b/master/_sources/getting-started/tutorials/02-fused-softmax.rst.txt index 6766f31b8..b4b0ad79a 100644 --- a/master/_sources/getting-started/tutorials/02-fused-softmax.rst.txt +++ b/master/_sources/getting-started/tutorials/02-fused-softmax.rst.txt @@ -278,17 +278,17 @@ We will then compare its performance against (1) :code:`torch.softmax` and (2) t softmax-performance: N Triton Torch (native) Torch (jit) - 0 256.0 546.133347 512.000001 188.321838 - 1 384.0 614.400016 585.142862 153.600004 - 2 512.0 655.360017 585.142849 154.566038 + 0 256.0 546.133347 546.133347 186.181817 + 1 384.0 614.400016 585.142862 151.703707 + 2 512.0 655.360017 606.814814 154.566038 3 640.0 706.206879 640.000002 160.000000 - 4 768.0 722.823517 664.216187 162.754967 + 4 768.0 722.823517 664.216187 163.839992 .. ... ... ... ... 93 12160.0 812.359066 406.179533 198.936606 - 94 12288.0 812.429770 415.222812 199.298541 - 95 12416.0 812.498981 412.149375 198.954424 - 96 12544.0 812.566838 412.758863 199.209928 - 97 12672.0 811.007961 412.097543 199.264875 + 94 12288.0 812.429770 415.222812 199.096718 + 95 12416.0 812.498981 412.149375 198.854847 + 96 12544.0 810.925276 412.971190 199.012395 + 97 12672.0 811.007961 412.097543 199.167004 [98 rows x 4 columns] @@ -306,7 +306,7 @@ In the above plot, we can see that: .. rst-class:: sphx-glr-timing - **Total running time of the script:** ( 3 minutes 18.076 seconds) + **Total running time of the script:** ( 3 minutes 32.089 seconds) .. _sphx_glr_download_getting-started_tutorials_02-fused-softmax.py: diff --git a/master/_sources/getting-started/tutorials/03-matrix-multiplication.rst.txt b/master/_sources/getting-started/tutorials/03-matrix-multiplication.rst.txt index 612443a28..5a0558c63 100644 --- a/master/_sources/getting-started/tutorials/03-matrix-multiplication.rst.txt +++ b/master/_sources/getting-started/tutorials/03-matrix-multiplication.rst.txt @@ -459,37 +459,37 @@ We can now compare the performance of our kernel against that of cuBLAS. Here we matmul-performance: M cuBLAS ... Triton Triton (+ LeakyReLU) - 0 256.0 2.730667 ... 2.978909 2.978909 + 0 256.0 2.730667 ... 
3.276800 2.978909 1 384.0 7.372800 ... 7.899428 7.899428 - 2 512.0 14.563555 ... 15.420235 15.420235 + 2 512.0 14.563555 ... 16.384000 15.420235 3 640.0 22.260869 ... 24.380953 24.380953 4 768.0 32.768000 ... 35.389441 34.028308 - 5 896.0 37.971025 ... 40.140799 39.025776 + 5 896.0 39.025776 ... 40.140799 39.025776 6 1024.0 49.932191 ... 53.773130 52.428801 - 7 1152.0 45.242181 ... 48.161033 47.396572 + 7 1152.0 45.242181 ... 47.396572 47.396572 8 1280.0 51.200001 ... 57.690139 57.690139 - 9 1408.0 64.138541 ... 68.147202 65.684049 - 10 1536.0 79.526831 ... 81.355034 78.643199 - 11 1664.0 63.372618 ... 63.372618 62.492442 + 9 1408.0 64.138541 ... 68.147202 66.485074 + 10 1536.0 80.430545 ... 80.430545 78.643199 + 11 1664.0 62.929456 ... 63.372618 62.492442 12 1792.0 72.983276 ... 72.983276 59.154861 - 13 1920.0 68.776119 ... 71.626943 70.892307 - 14 2048.0 73.584279 ... 78.033565 76.959706 - 15 2176.0 83.155572 ... 87.494120 86.367588 - 16 2304.0 68.446623 ... 78.064941 77.057651 - 17 2432.0 71.305746 ... 86.179335 85.393507 - 18 2560.0 77.833728 ... 82.956960 81.715711 - 19 2688.0 83.737433 ... 91.185232 89.464755 - 20 2816.0 82.446516 ... 84.523664 83.712490 - 21 2944.0 81.967162 ... 83.758038 82.373605 - 22 3072.0 82.420822 ... 88.750943 86.579673 - 23 3200.0 81.528664 ... 91.233074 95.665176 - 24 3328.0 83.516586 ... 85.908470 83.323259 - 25 3456.0 81.435930 ... 92.138932 90.180725 - 26 3584.0 83.954614 ... 91.189190 95.858629 - 27 3712.0 85.822459 ... 83.806497 87.783251 - 28 3840.0 80.901241 ... 89.259080 89.548180 - 29 3968.0 87.913500 ... 92.829164 84.096442 - 30 4096.0 93.825748 ... 89.299883 90.139506 + 13 1920.0 69.120002 ... 71.257735 71.257735 + 14 2048.0 73.584279 ... 78.398206 77.314362 + 15 2176.0 83.155572 ... 87.494120 85.998493 + 16 2304.0 68.446623 ... 78.320893 77.558029 + 17 2432.0 71.305746 ... 86.711310 75.421383 + 18 2560.0 77.833728 ... 82.747477 81.715711 + 19 2688.0 83.552988 ... 90.532356 89.464755 + 20 2816.0 84.197315 ... 84.035084 84.035084 + 21 2944.0 82.784108 ... 83.969728 83.060049 + 22 3072.0 81.825298 ... 89.593522 88.473602 + 23 3200.0 84.768213 ... 96.096095 95.808380 + 24 3328.0 83.226931 ... 85.908470 84.596116 + 25 3456.0 81.766291 ... 91.824110 91.097818 + 26 3584.0 87.466332 ... 91.194972 94.847460 + 27 3712.0 85.822459 ... 87.246590 87.860458 + 28 3840.0 81.859361 ... 87.011801 90.168771 + 29 3968.0 89.921841 ... 91.954739 85.271796 + 30 4096.0 93.596744 ... 88.243079 90.382307 [31 rows x 5 columns] @@ -499,7 +499,7 @@ We can now compare the performance of our kernel against that of cuBLAS. Here we .. rst-class:: sphx-glr-timing - **Total running time of the script:** ( 5 minutes 52.578 seconds) + **Total running time of the script:** ( 7 minutes 13.827 seconds) .. _sphx_glr_download_getting-started_tutorials_03-matrix-multiplication.py: diff --git a/master/_sources/getting-started/tutorials/04-low-memory-dropout.rst.txt b/master/_sources/getting-started/tutorials/04-low-memory-dropout.rst.txt index 0f9d8dc7a..edbf188ee 100644 --- a/master/_sources/getting-started/tutorials/04-low-memory-dropout.rst.txt +++ b/master/_sources/getting-started/tutorials/04-low-memory-dropout.rst.txt @@ -240,7 +240,7 @@ References .. rst-class:: sphx-glr-timing - **Total running time of the script:** ( 0 minutes 0.476 seconds) + **Total running time of the script:** ( 0 minutes 0.279 seconds) .. 
_sphx_glr_download_getting-started_tutorials_04-low-memory-dropout.py: diff --git a/master/_sources/getting-started/tutorials/05-layer-norm.rst.txt b/master/_sources/getting-started/tutorials/05-layer-norm.rst.txt index 8750ca528..295b2a90c 100644 --- a/master/_sources/getting-started/tutorials/05-layer-norm.rst.txt +++ b/master/_sources/getting-started/tutorials/05-layer-norm.rst.txt @@ -21,7 +21,7 @@ Layer Normalization ==================== -.. GENERATED FROM PYTHON SOURCE LINES 5-312 +.. GENERATED FROM PYTHON SOURCE LINES 5-316 @@ -40,34 +40,34 @@ Layer Normalization N Triton Torch Apex 0 1024.0 585.142849 277.694907 468.114273 1 1536.0 630.153868 323.368435 511.999982 - 2 2048.0 682.666643 334.367358 520.126988 - 3 2560.0 694.237267 365.714281 518.481028 - 4 3072.0 712.347810 378.092307 501.551037 - 5 3584.0 725.873439 384.859062 458.751978 - 6 4096.0 728.177767 381.023256 458.293714 - 7 4608.0 670.254540 396.387087 426.173427 - 8 5120.0 694.237267 397.669909 426.666652 - 9 5632.0 704.000002 396.969169 413.357796 - 10 6144.0 702.171410 402.885254 411.313806 + 2 2048.0 668.734716 337.814445 528.516136 + 3 2560.0 694.237267 362.477870 512.000013 + 4 3072.0 712.347810 375.206126 501.551037 + 5 3584.0 725.873439 384.859062 451.527536 + 6 4096.0 728.177767 381.023256 455.111095 + 7 4608.0 670.254540 396.387087 421.302872 + 8 5120.0 688.403381 395.748783 422.268057 + 9 5632.0 698.542675 396.969169 409.599997 + 10 6144.0 702.171410 402.885254 409.600010 11 6656.0 700.631610 400.360920 400.360920 - 12 7168.0 695.078767 396.844306 388.772874 - 13 7680.0 682.666656 393.846167 387.634072 - 14 8192.0 642.509816 393.609605 372.363633 - 15 8704.0 627.315309 389.005597 380.502740 - 16 9216.0 606.814809 407.337026 383.999986 - 17 9728.0 589.575753 409.599987 383.369452 - 18 10240.0 566.920437 408.578556 382.803739 - 19 10752.0 549.623009 411.559798 381.445676 - 20 11264.0 536.380957 406.826188 373.134567 - 21 11776.0 523.377770 410.492372 377.587162 - 22 12288.0 517.389457 414.784810 383.251457 - 23 12800.0 505.679014 410.420828 376.470582 - 24 13312.0 494.180982 405.699062 376.976995 - 25 13824.0 482.934503 411.888257 379.389355 - 26 14336.0 471.967074 406.695045 374.185964 - 27 14848.0 461.297068 408.192434 375.304904 - 28 15360.0 454.269882 406.214870 378.092307 - 29 15872.0 447.887117 407.627589 376.225175 + 12 7168.0 678.627194 386.154893 384.859062 + 13 7680.0 682.666656 391.337574 386.415087 + 14 8192.0 645.674867 390.095241 376.643677 + 15 8704.0 624.502255 390.095225 379.465939 + 16 9216.0 604.327881 405.098894 383.002605 + 17 9728.0 585.142883 409.599987 382.427505 + 18 10240.0 564.965524 409.600010 382.803739 + 19 10752.0 546.133312 410.577576 380.601764 + 20 11264.0 531.634232 395.228063 370.069806 + 21 11776.0 520.486200 409.599991 376.831982 + 22 12288.0 516.031509 413.911572 383.251457 + 23 12800.0 504.433489 410.420828 375.779805 + 24 13312.0 494.180982 405.699062 376.310952 + 25 13824.0 481.882350 411.888257 378.739711 + 26 14336.0 471.967074 401.709294 372.969090 + 27 14848.0 461.297068 407.492270 375.898745 + 28 15360.0 453.431739 406.887417 378.092307 + 29 15872.0 447.098578 406.323209 376.225175 @@ -204,17 +204,19 @@ Layer Normalization cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for i in range(0, M, BLOCK_SIZE_M): - rows = i + tl.arange(0, BLOCK_SIZE_M) - mask = (rows[:, None] < M) & (cols[None, :] < N) - offs = rows[:, None] * N + cols[None, :] 
- a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32) - dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32) - mean = tl.load(Mean + rows, mask=rows < M, other=0.) - rstd = tl.load(Var + rows, mask=rows < M, other=0.) - a_hat = (a - mean[:, None]) * rstd[:, None] - dw += dout * a_hat - db += dout + UNROLL: tl.constexpr = 4 + for i in range(0, M, BLOCK_SIZE_M * UNROLL): + for j in range(UNROLL): + rows = i + j * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + mask = (rows[:, None] < M) & (cols[None, :] < N) + offs = rows[:, None] * N + cols[None, :] + a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32) + dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32) + mean = tl.load(Mean + rows, mask=rows < M, other=0.) + rstd = tl.load(Var + rows, mask=rows < M, other=0.) + a_hat = (a - mean[:, None]) * rstd[:, None] + dw += dout * a_hat + db += dout sum_dw = tl.sum(dw, axis=0) sum_db = tl.sum(db, axis=0) tl.store(DW + cols, sum_dw, mask=cols < N) @@ -287,7 +289,15 @@ Layer Normalization BLOCK_SIZE_N=ctx.BLOCK_SIZE, num_warps=ctx.num_warps, ) - # accumulate partial sums in separate kernel + if N > 10240: + BLOCK_SIZE_N = 128 + BLOCK_SIZE_M = 32 + num_warps = 4 + else: + # maximize occupancy for small N + BLOCK_SIZE_N = 16 + BLOCK_SIZE_M = 16 + num_warps = 8 grid = lambda meta: [triton.cdiv(N, meta["BLOCK_SIZE_N"])] _layer_norm_bwd_dwdb[grid]( a, dout, @@ -296,17 +306,11 @@ Layer Normalization dbias, M, N, - BLOCK_SIZE_M=32, - BLOCK_SIZE_N=128, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + num_warps=num_warps ) - return (da, None, dweight, dbias, None, None, - None, None, None, None, - None, - None, None, None, - None, - None, None, None, - None, None, None, - None, None, None) + return (da, None, dweight, dbias, None) def layer_norm(a, normalized_shape, weight, bias, eps): @@ -389,7 +393,7 @@ Layer Normalization .. rst-class:: sphx-glr-timing - **Total running time of the script:** ( 5 minutes 24.641 seconds) + **Total running time of the script:** ( 5 minutes 32.552 seconds) .. _sphx_glr_download_getting-started_tutorials_05-layer-norm.py: diff --git a/master/_sources/getting-started/tutorials/06-fused-attention.rst.txt b/master/_sources/getting-started/tutorials/06-fused-attention.rst.txt new file mode 100644 index 000000000..2876c4654 --- /dev/null +++ b/master/_sources/getting-started/tutorials/06-fused-attention.rst.txt @@ -0,0 +1,416 @@ + +.. DO NOT EDIT. +.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY. +.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE: +.. "getting-started/tutorials/06-fused-attention.py" +.. LINE NUMBERS ARE GIVEN BELOW. + +.. only:: html + + .. note:: + :class: sphx-glr-download-link-note + + Click :ref:`here ` + to download the full example code + +.. rst-class:: sphx-glr-example-title + +.. _sphx_glr_getting-started_tutorials_06-fused-attention.py: + + +Fused Attention +=============== +This is a Triton implementation of the Flash Attention algorithm +(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf) + +.. GENERATED FROM PYTHON SOURCE LINES 7-355 + + + + + + + +.. 
code-block:: default + + + import pytest + import torch + + import triton + import triton.language as tl + + + @triton.jit + def _fwd_kernel( + Q, K, V, sm_scale, + TMP, L, M, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug + Out, + stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, + stride_vz, stride_vh, stride_vk, stride_vn, + stride_oz, stride_oh, stride_om, stride_on, + Z, H, N_CTX, + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk + off_k = off_hz * stride_qh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk + off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk + # Initialize pointers to Q, K, V + q_ptrs = Q + off_q + k_ptrs = K + off_k + v_ptrs = V + off_v + # initialize pointer to m and l + t_ptrs = TMP + off_hz * N_CTX + offs_m + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # load q: it will stay in SRAM throughout + q = tl.load(q_ptrs) + # loop over k, v and update accumulator + for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + start_n * stride_kn) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k, trans_b=True) + qk *= sm_scale + qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float("-inf")) + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + beta = tl.exp(m_ij - m_i_new) + l_i_new = alpha * l_i + beta * l_ij + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + tl.store(t_ptrs, acc_scale) + acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(v_ptrs + start_n * stride_vk) + p = p.to(tl.float16) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + # rematerialize offsets to save registers + start_m = tl.program_id(0) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + # write back l and m + l_ptrs = L + off_hz * N_CTX + offs_m + m_ptrs = M + off_hz * N_CTX + offs_m + tl.store(l_ptrs, l_i) + tl.store(m_ptrs, m_i) + # initialize pointers to output + offs_n = tl.arange(0, BLOCK_DMODEL) + off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on + out_ptrs = Out + off_o + tl.store(out_ptrs, acc) + + + @triton.jit + def _bwd_preprocess( + Out, DO, L, + NewDO, Delta, + BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr, + ): + off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) + off_n = tl.arange(0, D_HEAD) + # load + o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) + do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) + denom = tl.load(L + off_m).to(tl.float32) + # compute + do = do / denom[:, None] + delta = tl.sum(o * do, axis=1) + # write-back + tl.store(NewDO + off_m[:, None] 
* D_HEAD + off_n[None, :], do) + tl.store(Delta + off_m, delta) + + + @triton.jit + def _bwd_kernel( + Q, K, V, sm_scale, Out, DO, + DQ, DK, DV, + L, M, + D, + stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, + stride_vz, stride_vh, stride_vk, stride_vn, + Z, H, N_CTX, + num_block, + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + off_hz = tl.program_id(0) + off_z = off_hz // H + off_h = off_hz % H + # offset pointers for batch/head + Q += off_z * stride_qz + off_h * stride_qh + K += off_z * stride_qz + off_h * stride_qh + V += off_z * stride_qz + off_h * stride_qh + DO += off_z * stride_qz + off_h * stride_qh + DQ += off_z * stride_qz + off_h * stride_qh + DK += off_z * stride_qz + off_h * stride_qh + DV += off_z * stride_qz + off_h * stride_qh + for start_n in range(0, num_block): + lo = start_n * BLOCK_M + # initialize row/col offsets + offs_qm = lo + tl.arange(0, BLOCK_M) + offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M) + offs_m = tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, BLOCK_DMODEL) + # initialize pointers to value-like data + q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) + k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) + v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk) + do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) + dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) + # pointer to row-wise quantities in value-like data + D_ptrs = D + off_hz * N_CTX + m_ptrs = M + off_hz * N_CTX + # initialize dv amd dk + dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # k and v stay in SRAM throughout + k = tl.load(k_ptrs) + v = tl.load(v_ptrs) + # loop over rows + for start_m in range(lo, num_block * BLOCK_M, BLOCK_M): + offs_m_curr = start_m + offs_m + # load q, k, v, do on-chip + q = tl.load(q_ptrs) + # recompute p = softmax(qk, dim=-1).T + # NOTE: `do` is pre-divided by `l`; no normalization here + qk = tl.dot(q, k, trans_b=True) + qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf")) + m = tl.load(m_ptrs + offs_m_curr) + p = tl.exp(qk * sm_scale - m[:, None]) + # compute dv + do = tl.load(do_ptrs) + dv += tl.dot(p.to(tl.float16), do, trans_a=True) + # compute dp = dot(v, do) + Di = tl.load(D_ptrs + offs_m_curr) + dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None] + dp += tl.dot(do, v, trans_b=True) + # compute ds = p * (dp - delta[:, None]) + ds = p * dp * sm_scale + # compute dk = dot(ds.T, q) + dk += tl.dot(ds.to(tl.float16), q, trans_a=True) + # # compute dq + dq = tl.load(dq_ptrs, eviction_policy="evict_last") + dq += tl.dot(ds.to(tl.float16), k) + tl.store(dq_ptrs, dq, eviction_policy="evict_last") + # # increment pointers + dq_ptrs += BLOCK_M * stride_qm + q_ptrs += BLOCK_M * stride_qm + do_ptrs += BLOCK_M * stride_qm + # write-back + dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk) + dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) + tl.store(dv_ptrs, dv) + tl.store(dk_ptrs, dk) + + + class _attention(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, k, v, sm_scale): + BLOCK = 128 + # shape constraints + Lq, Lk = q.shape[-1], k.shape[-1] + assert Lq == Lk + o = torch.empty_like(q) + grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1]) + tmp = torch.empty((q.shape[0] * q.shape[1], 
q.shape[2]), device=q.device, dtype=torch.float32) + L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + _fwd_kernel[grid]( + q, k, v, sm_scale, + tmp, L, m, + o, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + v.stride(0), v.stride(1), v.stride(2), v.stride(3), + o.stride(0), o.stride(1), o.stride(2), o.stride(3), + q.shape[0], q.shape[1], q.shape[2], + BLOCK_M=BLOCK, BLOCK_N=BLOCK, + BLOCK_DMODEL=64, num_warps=4, + num_stages=1, + ) + ctx.save_for_backward(q, k, v, o, L, m) + ctx.BLOCK = BLOCK + ctx.grid = grid + ctx.sm_scale = sm_scale + ctx.BLOCK_DMODEL = 64 + return o + + @staticmethod + def backward(ctx, do): + q, k, v, o, l, m = ctx.saved_tensors + do = do.contiguous() + dq = torch.zeros_like(q, dtype=torch.float32) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + do_scaled = torch.empty_like(do) + delta = torch.empty_like(l) + _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )]( + o, do, l, + do_scaled, delta, + BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL, + ) + _bwd_kernel[(ctx.grid[1],)]( + q, k, v, ctx.sm_scale, + o, do_scaled, + dq, dk, dv, + l, m, + delta, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + v.stride(0), v.stride(1), v.stride(2), v.stride(3), + q.shape[0], q.shape[1], q.shape[2], + ctx.grid[0], + BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK, + BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8, + num_stages=1, + ) + return dq, dk, dv, None + + + attention = _attention.apply + + + @pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 2048, 64)]) + def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16): + torch.manual_seed(20) + q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_() + k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_() + v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_() + sm_scale = 0.3 + dout = torch.randn_like(q) + # reference implementation + M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) + p = torch.matmul(q, k.transpose(2, 3)) * sm_scale + for z in range(Z): + for h in range(H): + p[:, :, M == 0] = float("-inf") + p = torch.softmax(p.float(), dim=-1).half() + ref_out = torch.matmul(p, v) + ref_out.backward(dout) + ref_dv, v.grad = v.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dq, q.grad = q.grad.clone(), None + # triton implementation + tri_out = attention(q, k, v, sm_scale) + tri_out.backward(dout) + tri_dv, v.grad = v.grad.clone(), None + tri_dk, k.grad = k.grad.clone(), None + tri_dq, q.grad = q.grad.clone(), None + # compare + triton.testing.assert_almost_equal(ref_out, tri_out) + triton.testing.assert_almost_equal(ref_dv, tri_dv) + triton.testing.assert_almost_equal(ref_dk, tri_dk) + triton.testing.assert_almost_equal(ref_dq, tri_dq) + + + try: + from flash_attn.flash_attn_interface import flash_attn_func + HAS_FLASH = True + except BaseException: + HAS_FLASH = False + + BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64 + # vary seq length for fixed head and batch=4 + configs = [triton.testing.Benchmark( + x_names=['N_CTX'], + x_vals=[2**i for i in range(10, 16)], + line_arg='provider', + line_vals=['triton'] + (['flash'] if HAS_FLASH else []), + line_names=['Triton'] + (['Flash'] if HAS_FLASH else []), 
+ styles=[('red', '-'), ('blue', '-')], + ylabel='ms', + plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}', + args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode} + ) for mode in ['bwd']] + + + @triton.testing.perf_report(configs) + def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.float16, device="cuda"): + assert mode in ['fwd', 'bwd'] + warmup = 25 + rep = 100 + if provider == "triton": + q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + sm_scale = 1.3 + fn = lambda: attention(q, k, v, sm_scale) + if mode == 'bwd': + o = fn() + do = torch.randn_like(o) + fn = lambda: o.backward(do, retain_graph=True) + ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep) + return ms + if provider == "flash": + lengths = torch.full((BATCH,), fill_value=N_CTX, device=device) + cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32) + cu_seqlens[1:] = lengths.cumsum(0) + qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True) + fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True) + if mode == 'bwd': + o = fn() + do = torch.randn_like(o) + fn = lambda: o.backward(do, retain_graph=True) + ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep) + return ms + + # only works on A100 at the moment + # bench_flash_attention.run(save_path='.', print_data=True) + + +.. rst-class:: sphx-glr-timing + + **Total running time of the script:** ( 0 minutes 0.072 seconds) + + +.. _sphx_glr_download_getting-started_tutorials_06-fused-attention.py: + + +.. only :: html + + .. container:: sphx-glr-footer + :class: sphx-glr-footer-example + + + + .. container:: sphx-glr-download sphx-glr-download-python + + :download:`Download Python source code: 06-fused-attention.py <06-fused-attention.py>` + + + + .. container:: sphx-glr-download sphx-glr-download-jupyter + + :download:`Download Jupyter notebook: 06-fused-attention.ipynb <06-fused-attention.ipynb>` + + +.. only:: html + + .. rst-class:: sphx-glr-signature + + `Gallery generated by Sphinx-Gallery `_ diff --git a/master/_sources/getting-started/tutorials/07-libdevice-function.rst.txt b/master/_sources/getting-started/tutorials/07-libdevice-function.rst.txt new file mode 100644 index 000000000..5761392dc --- /dev/null +++ b/master/_sources/getting-started/tutorials/07-libdevice-function.rst.txt @@ -0,0 +1,183 @@ + +.. DO NOT EDIT. +.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY. +.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE: +.. "getting-started/tutorials/07-libdevice-function.py" +.. LINE NUMBERS ARE GIVEN BELOW. + +.. only:: html + + .. note:: + :class: sphx-glr-download-link-note + + Click :ref:`here ` + to download the full example code + +.. rst-class:: sphx-glr-example-title + +.. _sphx_glr_getting-started_tutorials_07-libdevice-function.py: + + +Libdevice function +=============== +Triton can invoke a custom function from an external library. +In this example, we will use the `libdevice` library to apply `asin` on a tensor. +Please refer to https://docs.nvidia.com/cuda/libdevice-users-guide/index.html regarding the semantics of all available libdevice functions. 
+ +In `triton/language/libdevice.py`, we try to aggregate functions with the same computation but different data types together. +For example, both `__nv_asin` and `__nv_asinf` calculate the principal value of the arc sine of the input, but `__nv_asin` operates on `double` and `__nv_asinf` operates on `float`. +Using Triton, you can simply call `tl.libdevice.asin`. +Triton automatically selects the correct underlying device function to invoke based on input and output types. + +.. GENERATED FROM PYTHON SOURCE LINES 15-17 + +asin Kernel +-------------------------- + +.. GENERATED FROM PYTHON SOURCE LINES 17-39 + +.. code-block:: default + + + import torch + + import triton + import triton.language as tl + + + @triton.jit + def asin_kernel( + x_ptr, + y_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + x = tl.libdevice.asin(x) + tl.store(y_ptr + offsets, x, mask=mask) + + + + + + + + +.. GENERATED FROM PYTHON SOURCE LINES 40-43 + +Using the default libdevice library path +-------------------------- +We can use the default libdevice library path encoded in `triton/language/libdevice.py` + +.. GENERATED FROM PYTHON SOURCE LINES 43-61 + +.. code-block:: default + + + + torch.manual_seed(0) + size = 98432 + x = torch.rand(size, device='cuda') + output_triton = torch.zeros(size, device='cuda') + output_torch = torch.asin(x) + assert x.is_cuda and output_triton.is_cuda + n_elements = output_torch.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + asin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024) + print(output_torch) + print(output_triton) + print( + f'The maximum difference between torch and triton is ' + f'{torch.max(torch.abs(output_torch - output_triton))}' + ) + + + + + +.. rst-class:: sphx-glr-script-out + + Out: + + .. code-block:: none + + tensor([0.4105, 0.5430, 0.0249, ..., 0.0424, 0.5351, 0.8149], device='cuda:0') + tensor([0.4105, 0.5430, 0.0249, ..., 0.0424, 0.5351, 0.8149], device='cuda:0') + The maximum difference between torch and triton is 2.384185791015625e-07 + + + + +.. GENERATED FROM PYTHON SOURCE LINES 62-65 + +Customize the libdevice library path +-------------------------- +We can also customize the libdevice library path by passing the path to the `libdevice` library to the `asin` kernel. + +.. GENERATED FROM PYTHON SOURCE LINES 65-75 + +.. code-block:: default + + + output_triton = torch.empty_like(x) + asin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024, + extern_libs={'libdevice': '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'}) + print(output_torch) + print(output_triton) + print( + f'The maximum difference between torch and triton is ' + f'{torch.max(torch.abs(output_torch - output_triton))}' + ) + + + + +.. rst-class:: sphx-glr-script-out + + Out: + + .. code-block:: none + + tensor([0.4105, 0.5430, 0.0249, ..., 0.0424, 0.5351, 0.8149], device='cuda:0') + tensor([0.4105, 0.5430, 0.0249, ..., 0.0424, 0.5351, 0.8149], device='cuda:0') + The maximum difference between torch and triton is 2.384185791015625e-07 + + + + + +.. rst-class:: sphx-glr-timing + + **Total running time of the script:** ( 0 minutes 0.501 seconds) + + +.. _sphx_glr_download_getting-started_tutorials_07-libdevice-function.py: + + +.. only :: html + + .. container:: sphx-glr-footer + :class: sphx-glr-footer-example + + + + ..
container:: sphx-glr-download sphx-glr-download-python + + :download:`Download Python source code: 07-libdevice-function.py <07-libdevice-function.py>` + + + + .. container:: sphx-glr-download sphx-glr-download-jupyter + + :download:`Download Jupyter notebook: 07-libdevice-function.ipynb <07-libdevice-function.ipynb>` + + +.. only:: html + + .. rst-class:: sphx-glr-signature + + `Gallery generated by Sphinx-Gallery `_ diff --git a/master/_sources/getting-started/tutorials/index.rst.txt b/master/_sources/getting-started/tutorials/index.rst.txt index c8a39cdaf..ecb119afe 100644 --- a/master/_sources/getting-started/tutorials/index.rst.txt +++ b/master/_sources/getting-started/tutorials/index.rst.txt @@ -122,6 +122,48 @@ To install the dependencies for the tutorials: :hidden: /getting-started/tutorials/05-layer-norm + +.. raw:: html + +
+ +.. only:: html + + .. figure:: /getting-started/tutorials/images/thumb/sphx_glr_06-fused-attention_thumb.png + :alt: Fused Attention + + :ref:`sphx_glr_getting-started_tutorials_06-fused-attention.py` + +.. raw:: html + +
+ + +.. toctree:: + :hidden: + + /getting-started/tutorials/06-fused-attention + +.. raw:: html + +
+ +.. only:: html + + .. figure:: /getting-started/tutorials/images/thumb/sphx_glr_07-libdevice-function_thumb.png + :alt: Libdevice function + + :ref:`sphx_glr_getting-started_tutorials_07-libdevice-function.py` + +.. raw:: html + +
+ + +.. toctree:: + :hidden: + + /getting-started/tutorials/07-libdevice-function .. raw:: html
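The layer-norm changes earlier in this diff replace the fixed `BLOCK_SIZE_M=32, BLOCK_SIZE_N=128` launch of `_layer_norm_bwd_dwdb` with a shape-dependent choice. A minimal sketch of that selection logic as a standalone helper (the function name `dwdb_launch_params` is hypothetical; the threshold and values are the ones shown in the diff):

    def dwdb_launch_params(N):
        # Wide rows: larger 32x128 tiles amortize the row loop; 4 warps per program.
        if N > 10240:
            return dict(BLOCK_SIZE_M=32, BLOCK_SIZE_N=128, num_warps=4)
        # Narrow rows: small 16x16 tiles with 8 warps to maximize occupancy.
        return dict(BLOCK_SIZE_M=16, BLOCK_SIZE_N=16, num_warps=8)

In the updated `LayerNorm.backward` this choice is inlined and its values are passed directly to the `_layer_norm_bwd_dwdb[grid](...)` launch.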
diff --git a/master/_sources/getting-started/tutorials/sg_execution_times.rst.txt b/master/_sources/getting-started/tutorials/sg_execution_times.rst.txt index 7e379c471..44bfca6b7 100644 --- a/master/_sources/getting-started/tutorials/sg_execution_times.rst.txt +++ b/master/_sources/getting-started/tutorials/sg_execution_times.rst.txt @@ -5,16 +5,20 @@ Computation times ================= -**16:10.599** total execution time for **getting-started_tutorials** files: +**18:09.339** total execution time for **getting-started_tutorials** files: +---------------------------------------------------------------------------------------------------------+-----------+--------+ -| :ref:`sphx_glr_getting-started_tutorials_03-matrix-multiplication.py` (``03-matrix-multiplication.py``) | 05:52.578 | 0.0 MB | +| :ref:`sphx_glr_getting-started_tutorials_03-matrix-multiplication.py` (``03-matrix-multiplication.py``) | 07:13.827 | 0.0 MB | +---------------------------------------------------------------------------------------------------------+-----------+--------+ -| :ref:`sphx_glr_getting-started_tutorials_05-layer-norm.py` (``05-layer-norm.py``) | 05:24.641 | 0.0 MB | +| :ref:`sphx_glr_getting-started_tutorials_05-layer-norm.py` (``05-layer-norm.py``) | 05:32.552 | 0.0 MB | +---------------------------------------------------------------------------------------------------------+-----------+--------+ -| :ref:`sphx_glr_getting-started_tutorials_02-fused-softmax.py` (``02-fused-softmax.py``) | 03:18.076 | 0.0 MB | +| :ref:`sphx_glr_getting-started_tutorials_02-fused-softmax.py` (``02-fused-softmax.py``) | 03:32.089 | 0.0 MB | +---------------------------------------------------------------------------------------------------------+-----------+--------+ -| :ref:`sphx_glr_getting-started_tutorials_01-vector-add.py` (``01-vector-add.py``) | 01:34.829 | 0.0 MB | +| :ref:`sphx_glr_getting-started_tutorials_01-vector-add.py` (``01-vector-add.py``) | 01:50.020 | 0.0 MB | +---------------------------------------------------------------------------------------------------------+-----------+--------+ -| :ref:`sphx_glr_getting-started_tutorials_04-low-memory-dropout.py` (``04-low-memory-dropout.py``) | 00:00.476 | 0.0 MB | +| :ref:`sphx_glr_getting-started_tutorials_07-libdevice-function.py` (``07-libdevice-function.py``) | 00:00.501 | 0.0 MB | ++---------------------------------------------------------------------------------------------------------+-----------+--------+ +| :ref:`sphx_glr_getting-started_tutorials_04-low-memory-dropout.py` (``04-low-memory-dropout.py``) | 00:00.279 | 0.0 MB | ++---------------------------------------------------------------------------------------------------------+-----------+--------+ +| :ref:`sphx_glr_getting-started_tutorials_06-fused-attention.py` (``06-fused-attention.py``) | 00:00.072 | 0.0 MB | +---------------------------------------------------------------------------------------------------------+-----------+--------+ diff --git a/master/getting-started/tutorials/01-vector-add.html b/master/getting-started/tutorials/01-vector-add.html index 4447503c2..3917cd1d1 100644 --- a/master/getting-started/tutorials/01-vector-add.html +++ b/master/getting-started/tutorials/01-vector-add.html @@ -105,6 +105,8 @@
  • Matrix Multiplication
  • Low-Memory Dropout
  • Layer Normalization
+ • Fused Attention
+ • Libdevice function
@@ -328,7 +330,7 @@ for different problem sizes.
     3      32768.0   76.800002   76.800002
     4      65536.0  127.999995  127.999995
     5     131072.0  219.428568  219.428568
-    6     262144.0  341.333321  384.000001
+    6     262144.0  341.333321  341.333321
     7     524288.0  472.615390  472.615390
     8    1048576.0  614.400016  614.400016
     9    2097152.0  722.823517  722.823517
@@ -340,7 +342,7 @@ for different problem sizes.
    15  134217728.0  849.737435  850.656574
-Total running time of the script: ( 1 minutes 34.829 seconds)
+Total running time of the script: ( 1 minutes 50.020 seconds)
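Throughout these diffs the benchmark tables report bandwidth rather than time. For the layer-norm benchmark, the conversion is the `gbps` lambda visible in the notebook source above (`2 * x.numel() * x.element_size() / ms * 1e-6` in forward mode). A small worked sketch for the N = 2048 row of the layer-norm table (M = 4096, fp16), where the millisecond value is back-solved from the reported throughput rather than measured:

    # Forward layer-norm traffic model: read x once, write y once.
    M, N, elem_size = 4096, 2048, 2            # fp16 -> 2 bytes per element
    bytes_moved = 2 * M * N * elem_size        # = 33_554_432 bytes per call
    gbps = lambda ms: bytes_moved / ms * 1e-6  # MB per ms == GB per s
    print(gbps(0.0502))                        # ~668 GB/s, in line with the table

The same pattern, bytes moved divided by measured milliseconds, underlies the vector-add and softmax tables as well, with the byte count adjusted for how many tensors each kernel reads and writes.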