diff --git a/master/.buildinfo b/master/.buildinfo
index 089f93c02..9e27b9682 100644
--- a/master/.buildinfo
+++ b/master/.buildinfo
@@ -1,4 +1,4 @@
# Sphinx build info version 1
# This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done.
-config: 255e5b7e27427f0ab6fda308ee6aef63
+config: 8d52c5eda79abb41e578ed40b306519c
tags: 645f666f9bcd5a90fca523b33c5a78b7
diff --git a/master/.doctrees/environment.pickle b/master/.doctrees/environment.pickle
index 9b5c9c2d3..228c82170 100644
Binary files a/master/.doctrees/environment.pickle and b/master/.doctrees/environment.pickle differ
diff --git a/master/.doctrees/getting-started/installation.doctree b/master/.doctrees/getting-started/installation.doctree
index 6190cd7e2..27f8fc457 100644
Binary files a/master/.doctrees/getting-started/installation.doctree and b/master/.doctrees/getting-started/installation.doctree differ
diff --git a/master/.doctrees/getting-started/tutorials/01-vector-add.doctree b/master/.doctrees/getting-started/tutorials/01-vector-add.doctree
index 62d77901f..3b3a60da3 100644
Binary files a/master/.doctrees/getting-started/tutorials/01-vector-add.doctree and b/master/.doctrees/getting-started/tutorials/01-vector-add.doctree differ
diff --git a/master/.doctrees/getting-started/tutorials/02-fused-softmax.doctree b/master/.doctrees/getting-started/tutorials/02-fused-softmax.doctree
index 72c82fd7d..2fd1643ea 100644
Binary files a/master/.doctrees/getting-started/tutorials/02-fused-softmax.doctree and b/master/.doctrees/getting-started/tutorials/02-fused-softmax.doctree differ
diff --git a/master/.doctrees/getting-started/tutorials/03-matrix-multiplication.doctree b/master/.doctrees/getting-started/tutorials/03-matrix-multiplication.doctree
index 52d467dbb..405936c08 100644
Binary files a/master/.doctrees/getting-started/tutorials/03-matrix-multiplication.doctree and b/master/.doctrees/getting-started/tutorials/03-matrix-multiplication.doctree differ
diff --git a/master/.doctrees/getting-started/tutorials/04-low-memory-dropout.doctree b/master/.doctrees/getting-started/tutorials/04-low-memory-dropout.doctree
index 89a1a4b75..48060c346 100644
Binary files a/master/.doctrees/getting-started/tutorials/04-low-memory-dropout.doctree and b/master/.doctrees/getting-started/tutorials/04-low-memory-dropout.doctree differ
diff --git a/master/.doctrees/getting-started/tutorials/05-layer-norm.doctree b/master/.doctrees/getting-started/tutorials/05-layer-norm.doctree
index cbe9fac56..122cb1c5b 100644
Binary files a/master/.doctrees/getting-started/tutorials/05-layer-norm.doctree and b/master/.doctrees/getting-started/tutorials/05-layer-norm.doctree differ
diff --git a/master/.doctrees/getting-started/tutorials/06-fused-attention.doctree b/master/.doctrees/getting-started/tutorials/06-fused-attention.doctree
new file mode 100644
index 000000000..a19207450
Binary files /dev/null and b/master/.doctrees/getting-started/tutorials/06-fused-attention.doctree differ
diff --git a/master/.doctrees/getting-started/tutorials/07-libdevice-function.doctree b/master/.doctrees/getting-started/tutorials/07-libdevice-function.doctree
new file mode 100644
index 000000000..b1970bcf6
Binary files /dev/null and b/master/.doctrees/getting-started/tutorials/07-libdevice-function.doctree differ
diff --git a/master/.doctrees/getting-started/tutorials/index.doctree b/master/.doctrees/getting-started/tutorials/index.doctree
index dd4cc705b..26b6f1bdb 100644
Binary files a/master/.doctrees/getting-started/tutorials/index.doctree and b/master/.doctrees/getting-started/tutorials/index.doctree differ
diff --git a/master/.doctrees/getting-started/tutorials/sg_execution_times.doctree b/master/.doctrees/getting-started/tutorials/sg_execution_times.doctree
index 34eeecbd3..2740ae711 100644
Binary files a/master/.doctrees/getting-started/tutorials/sg_execution_times.doctree and b/master/.doctrees/getting-started/tutorials/sg_execution_times.doctree differ
diff --git a/master/.doctrees/index.doctree b/master/.doctrees/index.doctree
index abcb3bfaa..b0b8ae7db 100644
Binary files a/master/.doctrees/index.doctree and b/master/.doctrees/index.doctree differ
diff --git a/master/.doctrees/programming-guide/chapter-1/introduction.doctree b/master/.doctrees/programming-guide/chapter-1/introduction.doctree
index 8828c92c6..e6fafc875 100644
Binary files a/master/.doctrees/programming-guide/chapter-1/introduction.doctree and b/master/.doctrees/programming-guide/chapter-1/introduction.doctree differ
diff --git a/master/.doctrees/programming-guide/chapter-2/related-work.doctree b/master/.doctrees/programming-guide/chapter-2/related-work.doctree
index 14fe22504..bfc855a8e 100644
Binary files a/master/.doctrees/programming-guide/chapter-2/related-work.doctree and b/master/.doctrees/programming-guide/chapter-2/related-work.doctree differ
diff --git a/master/.doctrees/python-api/generated/triton.Config.doctree b/master/.doctrees/python-api/generated/triton.Config.doctree
index af4eee914..840d99316 100644
Binary files a/master/.doctrees/python-api/generated/triton.Config.doctree and b/master/.doctrees/python-api/generated/triton.Config.doctree differ
diff --git a/master/.doctrees/python-api/generated/triton.autotune.doctree b/master/.doctrees/python-api/generated/triton.autotune.doctree
index 961b39c57..147595499 100644
Binary files a/master/.doctrees/python-api/generated/triton.autotune.doctree and b/master/.doctrees/python-api/generated/triton.autotune.doctree differ
diff --git a/master/.doctrees/python-api/generated/triton.heuristics.doctree b/master/.doctrees/python-api/generated/triton.heuristics.doctree
index 95f4939bd..e13bad010 100644
Binary files a/master/.doctrees/python-api/generated/triton.heuristics.doctree and b/master/.doctrees/python-api/generated/triton.heuristics.doctree differ
diff --git a/master/.doctrees/python-api/generated/triton.jit.doctree b/master/.doctrees/python-api/generated/triton.jit.doctree
index e28f973b0..c585e9642 100644
Binary files a/master/.doctrees/python-api/generated/triton.jit.doctree and b/master/.doctrees/python-api/generated/triton.jit.doctree differ
diff --git a/master/.doctrees/python-api/generated/triton.language.arange.doctree b/master/.doctrees/python-api/generated/triton.language.arange.doctree
index f6e0c711e..328bffa17 100644
Binary files a/master/.doctrees/python-api/generated/triton.language.arange.doctree and b/master/.doctrees/python-api/generated/triton.language.arange.doctree differ
diff --git a/master/.doctrees/python-api/generated/triton.language.atomic_add.doctree b/master/.doctrees/python-api/generated/triton.language.atomic_add.doctree
index aed2fb893..d20fcb9f8 100644
Binary files a/master/.doctrees/python-api/generated/triton.language.atomic_add.doctree and b/master/.doctrees/python-api/generated/triton.language.atomic_add.doctree differ
diff --git a/master/.doctrees/python-api/generated/triton.language.atomic_and.doctree b/master/.doctrees/python-api/generated/triton.language.atomic_and.doctree
index 47db74f71..7f17e9888 100644
Binary files a/master/.doctrees/python-api/generated/triton.language.atomic_and.doctree and b/master/.doctrees/python-api/generated/triton.language.atomic_and.doctree differ
diff --git a/master/.doctrees/python-api/generated/triton.language.atomic_cas.doctree b/master/.doctrees/python-api/generated/triton.language.atomic_cas.doctree
index 109a81608..bffe0c6e7 100644
Binary files a/master/.doctrees/python-api/generated/triton.language.atomic_cas.doctree and b/master/.doctrees/python-api/generated/triton.language.atomic_cas.doctree differ
diff --git a/master/.doctrees/python-api/generated/triton.language.atomic_max.doctree b/master/.doctrees/python-api/generated/triton.language.atomic_max.doctree
index c96c0b82f..a4704bc1e 100644
Binary files a/master/.doctrees/python-api/generated/triton.language.atomic_max.doctree and b/master/.doctrees/python-api/generated/triton.language.atomic_max.doctree differ
diff --git a/master/.doctrees/python-api/generated/triton.language.atomic_min.doctree b/master/.doctrees/python-api/generated/triton.language.atomic_min.doctree
index d7b752ea3..0f426700b 100644
Binary files a/master/.doctrees/python-api/generated/triton.language.atomic_min.doctree and b/master/.doctrees/python-api/generated/triton.language.atomic_min.doctree differ
diff --git a/master/.doctrees/python-api/generated/triton.language.atomic_or.doctree b/master/.doctrees/python-api/generated/triton.language.atomic_or.doctree
index be238e853..57b07b41c 100644
Binary files a/master/.doctrees/python-api/generated/triton.language.atomic_or.doctree and b/master/.doctrees/python-api/generated/triton.language.atomic_or.doctree differ
diff --git a/master/.doctrees/python-api/generated/triton.language.atomic_xchg.doctree b/master/.doctrees/python-api/generated/triton.language.atomic_xchg.doctree
index 64d5a0847..95149aa00 100644
Binary files a/master/.doctrees/python-api/generated/triton.language.atomic_xchg.doctree and b/master/.doctrees/python-api/generated/triton.language.atomic_xchg.doctree differ
diff --git a/master/.doctrees/python-api/generated/triton.language.atomic_xor.doctree b/master/.doctrees/python-api/generated/triton.language.atomic_xor.doctree
index 9dbf4a3c8..ae15a586d 100644
Binary files a/master/.doctrees/python-api/generated/triton.language.atomic_xor.doctree and b/master/.doctrees/python-api/generated/triton.language.atomic_xor.doctree differ
diff --git a/master/.doctrees/python-api/generated/triton.language.broadcast_to.doctree b/master/.doctrees/python-api/generated/triton.language.broadcast_to.doctree
index 939971777..358ab043e 100644
Binary files a/master/.doctrees/python-api/generated/triton.language.broadcast_to.doctree and b/master/.doctrees/python-api/generated/triton.language.broadcast_to.doctree differ
diff --git a/master/.doctrees/python-api/generated/triton.language.cos.doctree b/master/.doctrees/python-api/generated/triton.language.cos.doctree
index 49eb4707b..504e5e619 100644
Binary files a/master/.doctrees/python-api/generated/triton.language.cos.doctree and b/master/.doctrees/python-api/generated/triton.language.cos.doctree differ
diff --git a/master/.doctrees/python-api/generated/triton.language.dot.doctree b/master/.doctrees/python-api/generated/triton.language.dot.doctree
index 3870262c6..ae6f8e68a 100644
Binary files a/master/.doctrees/python-api/generated/triton.language.dot.doctree and b/master/.doctrees/python-api/generated/triton.language.dot.doctree differ
diff --git a/master/.doctrees/python-api/generated/triton.language.exp.doctree b/master/.doctrees/python-api/generated/triton.language.exp.doctree
index 767f42eb5..54536be64 100644
Binary files a/master/.doctrees/python-api/generated/triton.language.exp.doctree and b/master/.doctrees/python-api/generated/triton.language.exp.doctree differ
diff --git a/master/.doctrees/python-api/generated/triton.language.load.doctree b/master/.doctrees/python-api/generated/triton.language.load.doctree
index 7caedc95b..ca2ae6e0f 100644
Binary files a/master/.doctrees/python-api/generated/triton.language.load.doctree and b/master/.doctrees/python-api/generated/triton.language.load.doctree differ
diff --git a/master/.doctrees/python-api/generated/triton.language.log.doctree b/master/.doctrees/python-api/generated/triton.language.log.doctree
index d0fd10524..d53a32dd4 100644
Binary files a/master/.doctrees/python-api/generated/triton.language.log.doctree and b/master/.doctrees/python-api/generated/triton.language.log.doctree differ
diff --git a/master/.doctrees/python-api/generated/triton.language.max.doctree b/master/.doctrees/python-api/generated/triton.language.max.doctree
index 42d6c57df..b19b58c84 100644
Binary files a/master/.doctrees/python-api/generated/triton.language.max.doctree and b/master/.doctrees/python-api/generated/triton.language.max.doctree differ
diff --git a/master/.doctrees/python-api/generated/triton.language.maximum.doctree b/master/.doctrees/python-api/generated/triton.language.maximum.doctree
index cc7f88946..7c78bb4a8 100644
Binary files a/master/.doctrees/python-api/generated/triton.language.maximum.doctree and b/master/.doctrees/python-api/generated/triton.language.maximum.doctree differ
diff --git a/master/.doctrees/python-api/generated/triton.language.min.doctree b/master/.doctrees/python-api/generated/triton.language.min.doctree
index 4976ec400..4e2142cea 100644
Binary files a/master/.doctrees/python-api/generated/triton.language.min.doctree and b/master/.doctrees/python-api/generated/triton.language.min.doctree differ
diff --git a/master/.doctrees/python-api/generated/triton.language.minimum.doctree b/master/.doctrees/python-api/generated/triton.language.minimum.doctree
index 7a181d868..ebbf565ba 100644
Binary files a/master/.doctrees/python-api/generated/triton.language.minimum.doctree and b/master/.doctrees/python-api/generated/triton.language.minimum.doctree differ
diff --git a/master/.doctrees/python-api/generated/triton.language.multiple_of.doctree b/master/.doctrees/python-api/generated/triton.language.multiple_of.doctree
index 804ddc047..fc6f7e4fc 100644
Binary files a/master/.doctrees/python-api/generated/triton.language.multiple_of.doctree and b/master/.doctrees/python-api/generated/triton.language.multiple_of.doctree differ
diff --git a/master/.doctrees/python-api/generated/triton.language.num_programs.doctree b/master/.doctrees/python-api/generated/triton.language.num_programs.doctree
index ebe0dd843..6a8bc44b3 100644
Binary files a/master/.doctrees/python-api/generated/triton.language.num_programs.doctree and b/master/.doctrees/python-api/generated/triton.language.num_programs.doctree differ
diff --git a/master/.doctrees/python-api/generated/triton.language.program_id.doctree b/master/.doctrees/python-api/generated/triton.language.program_id.doctree
index 0b28f85ff..703f5d0be 100644
Binary files a/master/.doctrees/python-api/generated/triton.language.program_id.doctree and b/master/.doctrees/python-api/generated/triton.language.program_id.doctree differ
diff --git a/master/.doctrees/python-api/generated/triton.language.rand.doctree b/master/.doctrees/python-api/generated/triton.language.rand.doctree
index 6232ace14..a73f482fb 100644
Binary files a/master/.doctrees/python-api/generated/triton.language.rand.doctree and b/master/.doctrees/python-api/generated/triton.language.rand.doctree differ
diff --git a/master/.doctrees/python-api/generated/triton.language.randint.doctree b/master/.doctrees/python-api/generated/triton.language.randint.doctree
index 27d4373a6..9c8214953 100644
Binary files a/master/.doctrees/python-api/generated/triton.language.randint.doctree and b/master/.doctrees/python-api/generated/triton.language.randint.doctree differ
diff --git a/master/.doctrees/python-api/generated/triton.language.randint4x.doctree b/master/.doctrees/python-api/generated/triton.language.randint4x.doctree
index bb31c78bc..772709f0d 100644
Binary files a/master/.doctrees/python-api/generated/triton.language.randint4x.doctree and b/master/.doctrees/python-api/generated/triton.language.randint4x.doctree differ
diff --git a/master/.doctrees/python-api/generated/triton.language.randn.doctree b/master/.doctrees/python-api/generated/triton.language.randn.doctree
index 281efe064..4fd8f42ea 100644
Binary files a/master/.doctrees/python-api/generated/triton.language.randn.doctree and b/master/.doctrees/python-api/generated/triton.language.randn.doctree differ
diff --git a/master/.doctrees/python-api/generated/triton.language.ravel.doctree b/master/.doctrees/python-api/generated/triton.language.ravel.doctree
index bfe37463c..954262cd4 100644
Binary files a/master/.doctrees/python-api/generated/triton.language.ravel.doctree and b/master/.doctrees/python-api/generated/triton.language.ravel.doctree differ
diff --git a/master/.doctrees/python-api/generated/triton.language.reshape.doctree b/master/.doctrees/python-api/generated/triton.language.reshape.doctree
index 9d6707e69..05f3e82c0 100644
Binary files a/master/.doctrees/python-api/generated/triton.language.reshape.doctree and b/master/.doctrees/python-api/generated/triton.language.reshape.doctree differ
diff --git a/master/.doctrees/python-api/generated/triton.language.sigmoid.doctree b/master/.doctrees/python-api/generated/triton.language.sigmoid.doctree
index 1d7385111..6c4a18f65 100644
Binary files a/master/.doctrees/python-api/generated/triton.language.sigmoid.doctree and b/master/.doctrees/python-api/generated/triton.language.sigmoid.doctree differ
diff --git a/master/.doctrees/python-api/generated/triton.language.sin.doctree b/master/.doctrees/python-api/generated/triton.language.sin.doctree
index 53c9b4122..cb427d01d 100644
Binary files a/master/.doctrees/python-api/generated/triton.language.sin.doctree and b/master/.doctrees/python-api/generated/triton.language.sin.doctree differ
diff --git a/master/.doctrees/python-api/generated/triton.language.softmax.doctree b/master/.doctrees/python-api/generated/triton.language.softmax.doctree
index 1d82361ea..2d5becdfb 100644
Binary files a/master/.doctrees/python-api/generated/triton.language.softmax.doctree and b/master/.doctrees/python-api/generated/triton.language.softmax.doctree differ
diff --git a/master/.doctrees/python-api/generated/triton.language.sqrt.doctree b/master/.doctrees/python-api/generated/triton.language.sqrt.doctree
index 34f3cc3ed..3828ac4fd 100644
Binary files a/master/.doctrees/python-api/generated/triton.language.sqrt.doctree and b/master/.doctrees/python-api/generated/triton.language.sqrt.doctree differ
diff --git a/master/.doctrees/python-api/generated/triton.language.store.doctree b/master/.doctrees/python-api/generated/triton.language.store.doctree
index a17c57279..02efc289a 100644
Binary files a/master/.doctrees/python-api/generated/triton.language.store.doctree and b/master/.doctrees/python-api/generated/triton.language.store.doctree differ
diff --git a/master/.doctrees/python-api/generated/triton.language.sum.doctree b/master/.doctrees/python-api/generated/triton.language.sum.doctree
index 2e1d42093..311694380 100644
Binary files a/master/.doctrees/python-api/generated/triton.language.sum.doctree and b/master/.doctrees/python-api/generated/triton.language.sum.doctree differ
diff --git a/master/.doctrees/python-api/generated/triton.language.where.doctree b/master/.doctrees/python-api/generated/triton.language.where.doctree
index 8093efb0f..f5bca8b4b 100644
Binary files a/master/.doctrees/python-api/generated/triton.language.where.doctree and b/master/.doctrees/python-api/generated/triton.language.where.doctree differ
diff --git a/master/.doctrees/python-api/generated/triton.language.zeros.doctree b/master/.doctrees/python-api/generated/triton.language.zeros.doctree
index f32f1b54d..be2e95ee7 100644
Binary files a/master/.doctrees/python-api/generated/triton.language.zeros.doctree and b/master/.doctrees/python-api/generated/triton.language.zeros.doctree differ
diff --git a/master/.doctrees/python-api/generated/triton.testing.Benchmark.doctree b/master/.doctrees/python-api/generated/triton.testing.Benchmark.doctree
index 193618794..3d7eec07c 100644
Binary files a/master/.doctrees/python-api/generated/triton.testing.Benchmark.doctree and b/master/.doctrees/python-api/generated/triton.testing.Benchmark.doctree differ
diff --git a/master/.doctrees/python-api/generated/triton.testing.do_bench.doctree b/master/.doctrees/python-api/generated/triton.testing.do_bench.doctree
index 3b555cbef..64c6facdd 100644
Binary files a/master/.doctrees/python-api/generated/triton.testing.do_bench.doctree and b/master/.doctrees/python-api/generated/triton.testing.do_bench.doctree differ
diff --git a/master/.doctrees/python-api/generated/triton.testing.perf_report.doctree b/master/.doctrees/python-api/generated/triton.testing.perf_report.doctree
index fd490963c..b1e92ec76 100644
Binary files a/master/.doctrees/python-api/generated/triton.testing.perf_report.doctree and b/master/.doctrees/python-api/generated/triton.testing.perf_report.doctree differ
diff --git a/master/.doctrees/python-api/triton.doctree b/master/.doctrees/python-api/triton.doctree
index 666b7335d..e8dd455f0 100644
Binary files a/master/.doctrees/python-api/triton.doctree and b/master/.doctrees/python-api/triton.doctree differ
diff --git a/master/.doctrees/python-api/triton.language.doctree b/master/.doctrees/python-api/triton.language.doctree
index 0ac2e2ebc..214cb8df7 100644
Binary files a/master/.doctrees/python-api/triton.language.doctree and b/master/.doctrees/python-api/triton.language.doctree differ
diff --git a/master/.doctrees/python-api/triton.testing.doctree b/master/.doctrees/python-api/triton.testing.doctree
index 69db5f8ce..562048866 100644
Binary files a/master/.doctrees/python-api/triton.testing.doctree and b/master/.doctrees/python-api/triton.testing.doctree differ
diff --git a/master/_downloads/1bc2e471d2fb0ec017c4d1d0890db4e2/07-libdevice-function.ipynb b/master/_downloads/1bc2e471d2fb0ec017c4d1d0890db4e2/07-libdevice-function.ipynb
new file mode 100644
index 000000000..d3b665487
--- /dev/null
+++ b/master/_downloads/1bc2e471d2fb0ec017c4d1d0890db4e2/07-libdevice-function.ipynb
@@ -0,0 +1,97 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "%matplotlib inline"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n# Libdevice function\nTriton can invoke a custom function from an external library.\nIn this example, we will use the `libdevice` library to apply `asin` on a tensor.\nPlease refer to https://docs.nvidia.com/cuda/libdevice-users-guide/index.html regarding the semantics of all available libdevice functions.\n\nIn `trition/language/libdevice.py`, we try to aggregate functions with the same computation but different data types together.\nFor example, both `__nv_asin` and `__nvasinf` calculate the principal value of the arc sine of the input, but `__nv_asin` operates on `double` and `__nv_asinf` operates on `float`.\nUsing triton, you can simply call `tl.libdevice.asinf`.\ntriton automatically selects the correct underlying device function to invoke based on input and output types.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## asin Kernel\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "import torch\n\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef asin_kernel(\n x_ptr,\n y_ptr,\n n_elements,\n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(x_ptr + offsets, mask=mask)\n x = tl.libdevice.asin(x)\n tl.store(y_ptr + offsets, x, mask=mask)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Using the default libdevice library path\nWe can use the default libdevice library path encoded in `triton/language/libdevice.py`\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "torch.manual_seed(0)\nsize = 98432\nx = torch.rand(size, device='cuda')\noutput_triton = torch.zeros(size, device='cuda')\noutput_torch = torch.asin(x)\nassert x.is_cuda and output_triton.is_cuda\nn_elements = output_torch.numel()\ngrid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\nasin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024)\nprint(output_torch)\nprint(output_triton)\nprint(\n f'The maximum difference between torch and triton is '\n f'{torch.max(torch.abs(output_torch - output_triton))}'\n)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Customize the libdevice library path\nWe can also customize the libdevice library path by passing the path to the `libdevice` library to the `asin` kernel.\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "output_triton = torch.empty_like(x)\nasin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024,\n extern_libs={'libdevice': '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'})\nprint(output_torch)\nprint(output_triton)\nprint(\n f'The maximum difference between torch and triton is '\n f'{torch.max(torch.abs(output_torch - output_triton))}'\n)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.10"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
\ No newline at end of file
diff --git a/master/_downloads/3176accb6c7288b0e45f41d94eebacb9/06-fused-attention.ipynb b/master/_downloads/3176accb6c7288b0e45f41d94eebacb9/06-fused-attention.ipynb
new file mode 100644
index 000000000..f3bc0d2e5
--- /dev/null
+++ b/master/_downloads/3176accb6c7288b0e45f41d94eebacb9/06-fused-attention.ipynb
@@ -0,0 +1,54 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "%matplotlib inline"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n# Fused Attention\nThis is a Triton implementation of the Flash Attention algorithm \n(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "import pytest\nimport torch\n\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _fwd_kernel(\n Q, K, V, sm_scale,\n TMP, L, M, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug\n Out,\n stride_qz, stride_qh, stride_qm, stride_qk,\n stride_kz, stride_kh, stride_kn, stride_kk,\n stride_vz, stride_vh, stride_vk, stride_vn,\n stride_oz, stride_oh, stride_om, stride_on,\n Z, H, N_CTX,\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n # initialize offsets\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk\n off_k = off_hz * stride_qh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk\n off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk\n # Initialize pointers to Q, K, V\n q_ptrs = Q + off_q\n k_ptrs = K + off_k\n v_ptrs = V + off_v\n # initialize pointer to m and l\n t_ptrs = TMP + off_hz * N_CTX + offs_m\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n # load q: it will stay in SRAM throughout\n q = tl.load(q_ptrs)\n # loop over k, v and update accumulator\n for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n # -- compute qk ----\n k = tl.load(k_ptrs + start_n * stride_kn)\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k, trans_b=True)\n qk *= sm_scale\n qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float(\"-inf\"))\n # -- compute m_ij, p, l_ij\n m_ij = tl.max(qk, 1)\n p = tl.exp(qk - m_ij[:, None])\n l_ij = tl.sum(p, 1)\n # -- update m_i and l_i\n m_i_new = tl.maximum(m_i, m_ij)\n alpha = tl.exp(m_i - m_i_new)\n beta = tl.exp(m_ij - m_i_new)\n l_i_new = alpha * l_i + beta * l_ij\n # -- update output accumulator --\n # scale p\n p_scale = beta / l_i_new\n p = p * p_scale[:, None]\n # scale acc\n acc_scale = l_i / l_i_new * alpha\n tl.store(t_ptrs, acc_scale)\n acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load\n acc = acc * acc_scale[:, None]\n # update acc\n v = tl.load(v_ptrs + start_n * stride_vk)\n p = p.to(tl.float16)\n acc += tl.dot(p, v)\n # update m_i and l_i\n l_i = l_i_new\n m_i = m_i_new\n # rematerialize offsets to save registers\n start_m = tl.program_id(0)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n # write back l and m\n l_ptrs = L + off_hz * N_CTX + offs_m\n m_ptrs = M + off_hz * N_CTX + offs_m\n tl.store(l_ptrs, l_i)\n tl.store(m_ptrs, m_i)\n # initialize pointers to output\n offs_n = tl.arange(0, BLOCK_DMODEL)\n off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on\n out_ptrs = Out + off_o\n tl.store(out_ptrs, acc)\n\n\n@triton.jit\ndef _bwd_preprocess(\n Out, DO, L,\n NewDO, Delta,\n BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,\n):\n off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)\n off_n = tl.arange(0, D_HEAD)\n # load\n o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)\n do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)\n denom = tl.load(L + off_m).to(tl.float32)\n # compute\n do = do / denom[:, None]\n delta = tl.sum(o * do, axis=1)\n # write-back\n tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], 
do)\n tl.store(Delta + off_m, delta)\n\n\n@triton.jit\ndef _bwd_kernel(\n Q, K, V, sm_scale, Out, DO,\n DQ, DK, DV,\n L, M,\n D,\n stride_qz, stride_qh, stride_qm, stride_qk,\n stride_kz, stride_kh, stride_kn, stride_kk,\n stride_vz, stride_vh, stride_vk, stride_vn,\n Z, H, N_CTX,\n num_block,\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n off_hz = tl.program_id(0)\n off_z = off_hz // H\n off_h = off_hz % H\n # offset pointers for batch/head\n Q += off_z * stride_qz + off_h * stride_qh\n K += off_z * stride_qz + off_h * stride_qh\n V += off_z * stride_qz + off_h * stride_qh\n DO += off_z * stride_qz + off_h * stride_qh\n DQ += off_z * stride_qz + off_h * stride_qh\n DK += off_z * stride_qz + off_h * stride_qh\n DV += off_z * stride_qz + off_h * stride_qh\n for start_n in range(0, num_block):\n lo = start_n * BLOCK_M\n # initialize row/col offsets\n offs_qm = lo + tl.arange(0, BLOCK_M)\n offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_m = tl.arange(0, BLOCK_N)\n offs_k = tl.arange(0, BLOCK_DMODEL)\n # initialize pointers to value-like data\n q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)\n v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n # pointer to row-wise quantities in value-like data\n D_ptrs = D + off_hz * N_CTX\n m_ptrs = M + off_hz * N_CTX\n # initialize dv amd dk\n dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n # k and v stay in SRAM throughout\n k = tl.load(k_ptrs)\n v = tl.load(v_ptrs)\n # loop over rows\n for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):\n offs_m_curr = start_m + offs_m\n # load q, k, v, do on-chip\n q = tl.load(q_ptrs)\n # recompute p = softmax(qk, dim=-1).T\n # NOTE: `do` is pre-divided by `l`; no normalization here\n qk = tl.dot(q, k, trans_b=True)\n qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float(\"-inf\"))\n m = tl.load(m_ptrs + offs_m_curr)\n p = tl.exp(qk * sm_scale - m[:, None])\n # compute dv\n do = tl.load(do_ptrs)\n dv += tl.dot(p.to(tl.float16), do, trans_a=True)\n # compute dp = dot(v, do)\n Di = tl.load(D_ptrs + offs_m_curr)\n dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]\n dp += tl.dot(do, v, trans_b=True)\n # compute ds = p * (dp - delta[:, None])\n ds = p * dp * sm_scale\n # compute dk = dot(ds.T, q)\n dk += tl.dot(ds.to(tl.float16), q, trans_a=True)\n # # compute dq\n dq = tl.load(dq_ptrs, eviction_policy=\"evict_last\")\n dq += tl.dot(ds.to(tl.float16), k)\n tl.store(dq_ptrs, dq, eviction_policy=\"evict_last\")\n # # increment pointers\n dq_ptrs += BLOCK_M * stride_qm\n q_ptrs += BLOCK_M * stride_qm\n do_ptrs += BLOCK_M * stride_qm\n # write-back\n dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)\n tl.store(dv_ptrs, dv)\n tl.store(dk_ptrs, dk)\n\n\nclass _attention(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, q, k, v, sm_scale):\n BLOCK = 128\n # shape constraints\n Lq, Lk = q.shape[-1], k.shape[-1]\n assert Lq == Lk\n o = torch.empty_like(q)\n grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])\n tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, 
dtype=torch.float32)\n L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)\n m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)\n _fwd_kernel[grid](\n q, k, v, sm_scale,\n tmp, L, m,\n o,\n q.stride(0), q.stride(1), q.stride(2), q.stride(3),\n k.stride(0), k.stride(1), k.stride(2), k.stride(3),\n v.stride(0), v.stride(1), v.stride(2), v.stride(3),\n o.stride(0), o.stride(1), o.stride(2), o.stride(3),\n q.shape[0], q.shape[1], q.shape[2],\n BLOCK_M=BLOCK, BLOCK_N=BLOCK,\n BLOCK_DMODEL=64, num_warps=4,\n num_stages=1,\n )\n ctx.save_for_backward(q, k, v, o, L, m)\n ctx.BLOCK = BLOCK\n ctx.grid = grid\n ctx.sm_scale = sm_scale\n ctx.BLOCK_DMODEL = 64\n return o\n\n @staticmethod\n def backward(ctx, do):\n q, k, v, o, l, m = ctx.saved_tensors\n do = do.contiguous()\n dq = torch.zeros_like(q, dtype=torch.float32)\n dk = torch.empty_like(k)\n dv = torch.empty_like(v)\n do_scaled = torch.empty_like(do)\n delta = torch.empty_like(l)\n _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](\n o, do, l,\n do_scaled, delta,\n BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,\n )\n _bwd_kernel[(ctx.grid[1],)](\n q, k, v, ctx.sm_scale,\n o, do_scaled,\n dq, dk, dv,\n l, m,\n delta,\n q.stride(0), q.stride(1), q.stride(2), q.stride(3),\n k.stride(0), k.stride(1), k.stride(2), k.stride(3),\n v.stride(0), v.stride(1), v.stride(2), v.stride(3),\n q.shape[0], q.shape[1], q.shape[2],\n ctx.grid[0],\n BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK,\n BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,\n num_stages=1,\n )\n return dq, dk, dv, None\n\n\nattention = _attention.apply\n\n\n@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 2048, 64)])\ndef test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):\n torch.manual_seed(20)\n q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device=\"cuda\").normal_(mean=0, std=.5).requires_grad_()\n k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device=\"cuda\").normal_(mean=0, std=.5).requires_grad_()\n v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device=\"cuda\").normal_(mean=0, std=.5).requires_grad_()\n sm_scale = 0.3\n dout = torch.randn_like(q)\n # reference implementation\n M = torch.tril(torch.ones((N_CTX, N_CTX), device=\"cuda\"))\n p = torch.matmul(q, k.transpose(2, 3)) * sm_scale\n for z in range(Z):\n for h in range(H):\n p[:, :, M == 0] = float(\"-inf\")\n p = torch.softmax(p.float(), dim=-1).half()\n ref_out = torch.matmul(p, v)\n ref_out.backward(dout)\n ref_dv, v.grad = v.grad.clone(), None\n ref_dk, k.grad = k.grad.clone(), None\n ref_dq, q.grad = q.grad.clone(), None\n # triton implementation\n tri_out = attention(q, k, v, sm_scale)\n tri_out.backward(dout)\n tri_dv, v.grad = v.grad.clone(), None\n tri_dk, k.grad = k.grad.clone(), None\n tri_dq, q.grad = q.grad.clone(), None\n # compare\n triton.testing.assert_almost_equal(ref_out, tri_out)\n triton.testing.assert_almost_equal(ref_dv, tri_dv)\n triton.testing.assert_almost_equal(ref_dk, tri_dk)\n triton.testing.assert_almost_equal(ref_dq, tri_dq)\n\n\ntry:\n from flash_attn.flash_attn_interface import flash_attn_func\n HAS_FLASH = True\nexcept BaseException:\n HAS_FLASH = False\n\nBATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64\n# vary seq length for fixed head and batch=4\nconfigs = [triton.testing.Benchmark(\n x_names=['N_CTX'],\n x_vals=[2**i for i in range(10, 16)],\n line_arg='provider',\n line_vals=['triton'] + (['flash'] if HAS_FLASH else []),\n line_names=['Triton'] + (['Flash'] if HAS_FLASH else []),\n styles=[('red', '-'), 
('blue', '-')],\n ylabel='ms',\n plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',\n args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode}\n) for mode in ['bwd']]\n\n\n@triton.testing.perf_report(configs)\ndef bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.float16, device=\"cuda\"):\n assert mode in ['fwd', 'bwd']\n warmup = 25\n rep = 100\n if provider == \"triton\":\n q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device=\"cuda\", requires_grad=True)\n k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device=\"cuda\", requires_grad=True)\n v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device=\"cuda\", requires_grad=True)\n sm_scale = 1.3\n fn = lambda: attention(q, k, v, sm_scale)\n if mode == 'bwd':\n o = fn()\n do = torch.randn_like(o)\n fn = lambda: o.backward(do, retain_graph=True)\n ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)\n return ms\n if provider == \"flash\":\n lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)\n cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32)\n cu_seqlens[1:] = lengths.cumsum(0)\n qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True)\n fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True)\n if mode == 'bwd':\n o = fn()\n do = torch.randn_like(o)\n fn = lambda: o.backward(do, retain_graph=True)\n ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)\n return ms\n\n# only works on A100 at the moment\n# bench_flash_attention.run(save_path='.', print_data=True)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.10"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
\ No newline at end of file
diff --git a/master/_downloads/3ff29f967ace7985da24aab10352fc76/07-libdevice-function.py b/master/_downloads/3ff29f967ace7985da24aab10352fc76/07-libdevice-function.py
new file mode 100644
index 000000000..bb5f7b26d
--- /dev/null
+++ b/master/_downloads/3ff29f967ace7985da24aab10352fc76/07-libdevice-function.py
@@ -0,0 +1,74 @@
+"""
+Libdevice function
+==================
+Triton can invoke a custom function from an external library.
+In this example, we will use the `libdevice` library to apply `asin` on a tensor.
+Please refer to https://docs.nvidia.com/cuda/libdevice-users-guide/index.html regarding the semantics of all available libdevice functions.
+
+In `triton/language/libdevice.py`, we group functions that perform the same computation on different data types.
+For example, both `__nv_asin` and `__nv_asinf` calculate the principal value of the arc sine of the input, but `__nv_asin` operates on `double` and `__nv_asinf` operates on `float`.
+Using Triton, you can simply call `tl.libdevice.asin`.
+Triton automatically selects the correct underlying device function to invoke based on input and output types.
+"""
+
+# %%
+# asin Kernel
+# --------------------------
+
+import torch
+
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def asin_kernel(
+ x_ptr,
+ y_ptr,
+ n_elements,
+ BLOCK_SIZE: tl.constexpr,
+):
+ pid = tl.program_id(axis=0)
+ block_start = pid * BLOCK_SIZE
+ offsets = block_start + tl.arange(0, BLOCK_SIZE)
+ mask = offsets < n_elements
+ x = tl.load(x_ptr + offsets, mask=mask)
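+    # Triton picks the underlying libdevice function (__nv_asin or __nv_asinf) based on x's element type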
+ x = tl.libdevice.asin(x)
+ tl.store(y_ptr + offsets, x, mask=mask)
+
+# %%
+# Using the default libdevice library path
+# ----------------------------------------
+# We can use the default libdevice library path encoded in `triton/language/libdevice.py`.
+
+
+torch.manual_seed(0)
+size = 98432
+x = torch.rand(size, device='cuda')
+output_triton = torch.zeros(size, device='cuda')
+output_torch = torch.asin(x)
+assert x.is_cuda and output_triton.is_cuda
+n_elements = output_torch.numel()
+grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
+asin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024)
+print(output_torch)
+print(output_triton)
+print(
+ f'The maximum difference between torch and triton is '
+ f'{torch.max(torch.abs(output_torch - output_triton))}'
+)
+
+# %%
+# Customize the libdevice library path
+# ------------------------------------
+# We can also customize the libdevice library path by passing the path to the `libdevice` library to the `asin` kernel.
+
+output_triton = torch.empty_like(x)
+asin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024,
+ extern_libs={'libdevice': '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'})
+print(output_torch)
+print(output_triton)
+print(
+ f'The maximum difference between torch and triton is '
+ f'{torch.max(torch.abs(output_torch - output_triton))}'
+)
diff --git a/master/_downloads/54a35f6ec55f9746935b9566fb6bb1df/06-fused-attention.py b/master/_downloads/54a35f6ec55f9746935b9566fb6bb1df/06-fused-attention.py
new file mode 100644
index 000000000..c19ee498a
--- /dev/null
+++ b/master/_downloads/54a35f6ec55f9746935b9566fb6bb1df/06-fused-attention.py
@@ -0,0 +1,354 @@
+"""
+Fused Attention
+===============
+This is a Triton implementation of the Flash Attention algorithm
+(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf)
+"""
+
+import pytest
+import torch
+
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _fwd_kernel(
+ Q, K, V, sm_scale,
+    TMP, L, M, # NOTE: TMP is a scratchpad buffer to work around a compiler bug
+ Out,
+ stride_qz, stride_qh, stride_qm, stride_qk,
+ stride_kz, stride_kh, stride_kn, stride_kk,
+ stride_vz, stride_vh, stride_vk, stride_vn,
+ stride_oz, stride_oh, stride_om, stride_on,
+ Z, H, N_CTX,
+ BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+):
+ start_m = tl.program_id(0)
+ off_hz = tl.program_id(1)
+ # initialize offsets
+ offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_n = tl.arange(0, BLOCK_N)
+ offs_d = tl.arange(0, BLOCK_DMODEL)
+ off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
+ off_k = off_hz * stride_qh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
+ off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk
+ # Initialize pointers to Q, K, V
+ q_ptrs = Q + off_q
+ k_ptrs = K + off_k
+ v_ptrs = V + off_v
+ # initialize pointer to m and l
+ t_ptrs = TMP + off_hz * N_CTX + offs_m
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
+ acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+ # load q: it will stay in SRAM throughout
+ q = tl.load(q_ptrs)
+ # loop over k, v and update accumulator
+ for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):
+ start_n = tl.multiple_of(start_n, BLOCK_N)
+ # -- compute qk ----
+ k = tl.load(k_ptrs + start_n * stride_kn)
+ qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
+ qk += tl.dot(q, k, trans_b=True)
+ qk *= sm_scale
+ qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float("-inf"))
+ # -- compute m_ij, p, l_ij
+ m_ij = tl.max(qk, 1)
+ p = tl.exp(qk - m_ij[:, None])
+ l_ij = tl.sum(p, 1)
+ # -- update m_i and l_i
+ m_i_new = tl.maximum(m_i, m_ij)
+ alpha = tl.exp(m_i - m_i_new)
+ beta = tl.exp(m_ij - m_i_new)
+ l_i_new = alpha * l_i + beta * l_ij
+ # -- update output accumulator --
+ # scale p
+ p_scale = beta / l_i_new
+ p = p * p_scale[:, None]
+ # scale acc
+ acc_scale = l_i / l_i_new * alpha
+ tl.store(t_ptrs, acc_scale)
+ acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load
+ acc = acc * acc_scale[:, None]
+ # update acc
+ v = tl.load(v_ptrs + start_n * stride_vk)
+ p = p.to(tl.float16)
+ acc += tl.dot(p, v)
+ # update m_i and l_i
+ l_i = l_i_new
+ m_i = m_i_new
+ # rematerialize offsets to save registers
+ start_m = tl.program_id(0)
+ offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ # write back l and m
+ l_ptrs = L + off_hz * N_CTX + offs_m
+ m_ptrs = M + off_hz * N_CTX + offs_m
+ tl.store(l_ptrs, l_i)
+ tl.store(m_ptrs, m_i)
+ # initialize pointers to output
+ offs_n = tl.arange(0, BLOCK_DMODEL)
+ off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
+ out_ptrs = Out + off_o
+ tl.store(out_ptrs, acc)
+
+
+@triton.jit
+def _bwd_preprocess(
+ Out, DO, L,
+ NewDO, Delta,
+ BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
+):
+ off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
+ off_n = tl.arange(0, D_HEAD)
+ # load
+ o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
+ do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
+ denom = tl.load(L + off_m).to(tl.float32)
+ # compute
+ do = do / denom[:, None]
+ delta = tl.sum(o * do, axis=1)
+ # write-back
+ tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)
+ tl.store(Delta + off_m, delta)
+
+
+@triton.jit
+def _bwd_kernel(
+ Q, K, V, sm_scale, Out, DO,
+ DQ, DK, DV,
+ L, M,
+ D,
+ stride_qz, stride_qh, stride_qm, stride_qk,
+ stride_kz, stride_kh, stride_kn, stride_kk,
+ stride_vz, stride_vh, stride_vk, stride_vn,
+ Z, H, N_CTX,
+ num_block,
+ BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+):
+ off_hz = tl.program_id(0)
+ off_z = off_hz // H
+ off_h = off_hz % H
+ # offset pointers for batch/head
+ Q += off_z * stride_qz + off_h * stride_qh
+ K += off_z * stride_qz + off_h * stride_qh
+ V += off_z * stride_qz + off_h * stride_qh
+ DO += off_z * stride_qz + off_h * stride_qh
+ DQ += off_z * stride_qz + off_h * stride_qh
+ DK += off_z * stride_qz + off_h * stride_qh
+ DV += off_z * stride_qz + off_h * stride_qh
+ for start_n in range(0, num_block):
+ lo = start_n * BLOCK_M
+ # initialize row/col offsets
+ offs_qm = lo + tl.arange(0, BLOCK_M)
+ offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_m = tl.arange(0, BLOCK_N)
+ offs_k = tl.arange(0, BLOCK_DMODEL)
+ # initialize pointers to value-like data
+ q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
+ k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
+ v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
+ do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
+ dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
+ # pointer to row-wise quantities in value-like data
+ D_ptrs = D + off_hz * N_CTX
+ m_ptrs = M + off_hz * N_CTX
+        # initialize dv and dk
+ dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+ dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+ # k and v stay in SRAM throughout
+ k = tl.load(k_ptrs)
+ v = tl.load(v_ptrs)
+ # loop over rows
+ for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):
+ offs_m_curr = start_m + offs_m
+ # load q, k, v, do on-chip
+ q = tl.load(q_ptrs)
+ # recompute p = softmax(qk, dim=-1).T
+ # NOTE: `do` is pre-divided by `l`; no normalization here
+ qk = tl.dot(q, k, trans_b=True)
+ qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
+ m = tl.load(m_ptrs + offs_m_curr)
+ p = tl.exp(qk * sm_scale - m[:, None])
+ # compute dv
+ do = tl.load(do_ptrs)
+ dv += tl.dot(p.to(tl.float16), do, trans_a=True)
+ # compute dp = dot(v, do)
+ Di = tl.load(D_ptrs + offs_m_curr)
+ dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
+ dp += tl.dot(do, v, trans_b=True)
+ # compute ds = p * (dp - delta[:, None])
+ ds = p * dp * sm_scale
+ # compute dk = dot(ds.T, q)
+ dk += tl.dot(ds.to(tl.float16), q, trans_a=True)
+ # # compute dq
+ dq = tl.load(dq_ptrs, eviction_policy="evict_last")
+ dq += tl.dot(ds.to(tl.float16), k)
+ tl.store(dq_ptrs, dq, eviction_policy="evict_last")
+ # # increment pointers
+ dq_ptrs += BLOCK_M * stride_qm
+ q_ptrs += BLOCK_M * stride_qm
+ do_ptrs += BLOCK_M * stride_qm
+ # write-back
+ dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
+ dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
+ tl.store(dv_ptrs, dv)
+ tl.store(dk_ptrs, dk)
+
+
+class _attention(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, q, k, v, sm_scale):
+ BLOCK = 128
+ # shape constraints
+ Lq, Lk = q.shape[-1], k.shape[-1]
+ assert Lq == Lk
+ o = torch.empty_like(q)
+ grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])
+ tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
+ L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
+ m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
+ _fwd_kernel[grid](
+ q, k, v, sm_scale,
+ tmp, L, m,
+ o,
+ q.stride(0), q.stride(1), q.stride(2), q.stride(3),
+ k.stride(0), k.stride(1), k.stride(2), k.stride(3),
+ v.stride(0), v.stride(1), v.stride(2), v.stride(3),
+ o.stride(0), o.stride(1), o.stride(2), o.stride(3),
+ q.shape[0], q.shape[1], q.shape[2],
+ BLOCK_M=BLOCK, BLOCK_N=BLOCK,
+ BLOCK_DMODEL=64, num_warps=4,
+ num_stages=1,
+ )
+ ctx.save_for_backward(q, k, v, o, L, m)
+ ctx.BLOCK = BLOCK
+ ctx.grid = grid
+ ctx.sm_scale = sm_scale
+ ctx.BLOCK_DMODEL = 64
+ return o
+
+ @staticmethod
+ def backward(ctx, do):
+ q, k, v, o, l, m = ctx.saved_tensors
+ do = do.contiguous()
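+        # dq accumulates contributions from every key block inside _bwd_kernel, so start it zeroed (in fp32)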
+ dq = torch.zeros_like(q, dtype=torch.float32)
+ dk = torch.empty_like(k)
+ dv = torch.empty_like(v)
+ do_scaled = torch.empty_like(do)
+ delta = torch.empty_like(l)
+ _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
+ o, do, l,
+ do_scaled, delta,
+ BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
+ )
+ _bwd_kernel[(ctx.grid[1],)](
+ q, k, v, ctx.sm_scale,
+ o, do_scaled,
+ dq, dk, dv,
+ l, m,
+ delta,
+ q.stride(0), q.stride(1), q.stride(2), q.stride(3),
+ k.stride(0), k.stride(1), k.stride(2), k.stride(3),
+ v.stride(0), v.stride(1), v.stride(2), v.stride(3),
+ q.shape[0], q.shape[1], q.shape[2],
+ ctx.grid[0],
+ BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK,
+ BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
+ num_stages=1,
+ )
+ return dq, dk, dv, None
+
+
+attention = _attention.apply
+
+
+@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 2048, 64)])
+def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
+ torch.manual_seed(20)
+ q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
+ k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
+ v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
+ sm_scale = 0.3
+ dout = torch.randn_like(q)
+ # reference implementation
+ M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
+ p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
+ for z in range(Z):
+ for h in range(H):
+ p[:, :, M == 0] = float("-inf")
+ p = torch.softmax(p.float(), dim=-1).half()
+ ref_out = torch.matmul(p, v)
+ ref_out.backward(dout)
+ ref_dv, v.grad = v.grad.clone(), None
+ ref_dk, k.grad = k.grad.clone(), None
+ ref_dq, q.grad = q.grad.clone(), None
+ # triton implementation
+ tri_out = attention(q, k, v, sm_scale)
+ tri_out.backward(dout)
+ tri_dv, v.grad = v.grad.clone(), None
+ tri_dk, k.grad = k.grad.clone(), None
+ tri_dq, q.grad = q.grad.clone(), None
+ # compare
+ triton.testing.assert_almost_equal(ref_out, tri_out)
+ triton.testing.assert_almost_equal(ref_dv, tri_dv)
+ triton.testing.assert_almost_equal(ref_dk, tri_dk)
+ triton.testing.assert_almost_equal(ref_dq, tri_dq)
+
+
+try:
+ from flash_attn.flash_attn_interface import flash_attn_func
+ HAS_FLASH = True
+except BaseException:
+ HAS_FLASH = False
+
+BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
+# vary seq length for fixed head and batch=4
+configs = [triton.testing.Benchmark(
+ x_names=['N_CTX'],
+ x_vals=[2**i for i in range(10, 16)],
+ line_arg='provider',
+ line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
+ line_names=['Triton'] + (['Flash'] if HAS_FLASH else []),
+ styles=[('red', '-'), ('blue', '-')],
+ ylabel='ms',
+ plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
+ args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode}
+) for mode in ['bwd']]
+
+
+@triton.testing.perf_report(configs)
+def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.float16, device="cuda"):
+ assert mode in ['fwd', 'bwd']
+ warmup = 25
+ rep = 100
+ if provider == "triton":
+ q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
+ k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
+ v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
+ sm_scale = 1.3
+ fn = lambda: attention(q, k, v, sm_scale)
+ if mode == 'bwd':
+ o = fn()
+ do = torch.randn_like(o)
+ fn = lambda: o.backward(do, retain_graph=True)
+ ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
+ return ms
+ if provider == "flash":
+ lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
+ cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32)
+ cu_seqlens[1:] = lengths.cumsum(0)
+ qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True)
+ fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True)
+ if mode == 'bwd':
+ o = fn()
+ do = torch.randn_like(o)
+ fn = lambda: o.backward(do, retain_graph=True)
+ ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
+ return ms
+
+# only works on A100 at the moment
+# bench_flash_attention.run(save_path='.', print_data=True)
diff --git a/master/_downloads/662999063954282841dc90b8945f85ce/tutorials_jupyter.zip b/master/_downloads/662999063954282841dc90b8945f85ce/tutorials_jupyter.zip
index 6dc0a1b1e..25a09311b 100644
Binary files a/master/_downloads/662999063954282841dc90b8945f85ce/tutorials_jupyter.zip and b/master/_downloads/662999063954282841dc90b8945f85ce/tutorials_jupyter.zip differ
diff --git a/master/_downloads/763344228ae6bc253ed1a6cf586aa30d/tutorials_python.zip b/master/_downloads/763344228ae6bc253ed1a6cf586aa30d/tutorials_python.zip
index e46aa1875..cafc41408 100644
Binary files a/master/_downloads/763344228ae6bc253ed1a6cf586aa30d/tutorials_python.zip and b/master/_downloads/763344228ae6bc253ed1a6cf586aa30d/tutorials_python.zip differ
diff --git a/master/_downloads/935c0dd0fbeb4b2e69588471cbb2d4b2/05-layer-norm.py b/master/_downloads/935c0dd0fbeb4b2e69588471cbb2d4b2/05-layer-norm.py
index 9880b428f..333cb80ec 100644
--- a/master/_downloads/935c0dd0fbeb4b2e69588471cbb2d4b2/05-layer-norm.py
+++ b/master/_downloads/935c0dd0fbeb4b2e69588471cbb2d4b2/05-layer-norm.py
@@ -128,17 +128,19 @@ def _layer_norm_bwd_dwdb(
cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
- for i in range(0, M, BLOCK_SIZE_M):
- rows = i + tl.arange(0, BLOCK_SIZE_M)
- mask = (rows[:, None] < M) & (cols[None, :] < N)
- offs = rows[:, None] * N + cols[None, :]
- a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32)
- dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32)
- mean = tl.load(Mean + rows, mask=rows < M, other=0.)
- rstd = tl.load(Var + rows, mask=rows < M, other=0.)
- a_hat = (a - mean[:, None]) * rstd[:, None]
- dw += dout * a_hat
- db += dout
+ UNROLL: tl.constexpr = 4
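+    # accumulate per-column partial sums of dw and db over all rows; the row loop is unrolled by a factor of UNROLL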
+ for i in range(0, M, BLOCK_SIZE_M * UNROLL):
+ for j in range(UNROLL):
+ rows = i + j * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+ mask = (rows[:, None] < M) & (cols[None, :] < N)
+ offs = rows[:, None] * N + cols[None, :]
+ a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32)
+ dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32)
+ mean = tl.load(Mean + rows, mask=rows < M, other=0.)
+ rstd = tl.load(Var + rows, mask=rows < M, other=0.)
+ a_hat = (a - mean[:, None]) * rstd[:, None]
+ dw += dout * a_hat
+ db += dout
sum_dw = tl.sum(dw, axis=0)
sum_db = tl.sum(db, axis=0)
tl.store(DW + cols, sum_dw, mask=cols < N)
@@ -211,7 +213,15 @@ class LayerNorm(torch.autograd.Function):
BLOCK_SIZE_N=ctx.BLOCK_SIZE,
num_warps=ctx.num_warps,
)
- # accumulate partial sums in separate kernel
+ if N > 10240:
+ BLOCK_SIZE_N = 128
+ BLOCK_SIZE_M = 32
+ num_warps = 4
+ else:
+ # maximize occupancy for small N
+ BLOCK_SIZE_N = 16
+ BLOCK_SIZE_M = 16
+ num_warps = 8
grid = lambda meta: [triton.cdiv(N, meta["BLOCK_SIZE_N"])]
_layer_norm_bwd_dwdb[grid](
a, dout,
@@ -220,17 +230,11 @@ class LayerNorm(torch.autograd.Function):
dbias,
M,
N,
- BLOCK_SIZE_M=32,
- BLOCK_SIZE_N=128,
+ BLOCK_SIZE_M=BLOCK_SIZE_M,
+ BLOCK_SIZE_N=BLOCK_SIZE_N,
+ num_warps=num_warps
)
- return (da, None, dweight, dbias, None, None,
- None, None, None, None,
- None,
- None, None, None,
- None,
- None, None, None,
- None, None, None,
- None, None, None)
+ return (da, None, dweight, dbias, None)
def layer_norm(a, normalized_shape, weight, bias, eps):
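The branch added in this hunk selects the reduction kernel's launch parameters from N alone. A minimal standalone sketch of the same heuristic (the helper name is ours; the threshold and values are taken from the hunk above):

    import triton

    def dwdb_launch_config(N):
        # Wide rows: larger column tiles, fewer warps per program.
        if N > 10240:
            return dict(BLOCK_SIZE_M=32, BLOCK_SIZE_N=128, num_warps=4)
        # Narrow rows: small tiles with more warps to maximize occupancy.
        return dict(BLOCK_SIZE_M=16, BLOCK_SIZE_N=16, num_warps=8)

    for N in (2048, 16384):
        cfg = dwdb_launch_config(N)
        print(N, cfg, "programs launched:", triton.cdiv(N, cfg["BLOCK_SIZE_N"]))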
diff --git a/master/_downloads/ae7fff29e1b574187bc930ed94bcc353/05-layer-norm.ipynb b/master/_downloads/ae7fff29e1b574187bc930ed94bcc353/05-layer-norm.ipynb
index 0930144cf..6838e6470 100644
--- a/master/_downloads/ae7fff29e1b574187bc930ed94bcc353/05-layer-norm.ipynb
+++ b/master/_downloads/ae7fff29e1b574187bc930ed94bcc353/05-layer-norm.ipynb
@@ -26,7 +26,7 @@
},
"outputs": [],
"source": [
- "import torch\n\nimport triton\nimport triton.language as tl\n\ntry:\n # This is https://github.com/NVIDIA/apex, NOT the apex on PyPi, so it\n # should not be added to extras_require in setup.py.\n import apex\n HAS_APEX = True\nexcept ModuleNotFoundError:\n HAS_APEX = False\n\n\n@triton.jit\ndef _layer_norm_fwd_fused(\n Out,\n A,\n Weight,\n Bias,\n Mean, Rstd,\n stride, N, eps,\n BLOCK_SIZE: tl.constexpr,\n):\n # position of elements processed by this program\n row = tl.program_id(0)\n Out += row * stride\n A += row * stride\n # compute mean\n mean = 0\n _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n a = tl.load(A + cols, mask=cols < N, other=0., eviction_policy=\"evict_last\").to(tl.float32)\n _mean += a\n mean = tl.sum(_mean, axis=0) / N\n # compute variance\n _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n a = tl.load(A + cols, mask=cols < N, other=0., eviction_policy=\"evict_last\").to(tl.float32)\n a = tl.where(cols < N, a - mean, 0.)\n _var += a * a\n var = tl.sum(_var, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n # write-back mean/rstd\n tl.store(Mean + row, mean)\n tl.store(Rstd + row, rstd)\n # multiply by weight and add bias\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n mask = cols < N\n weight = tl.load(Weight + cols, mask=mask)\n bias = tl.load(Bias + cols, mask=mask)\n a = tl.load(A + cols, mask=mask, other=0., eviction_policy=\"evict_first\").to(tl.float32)\n a_hat = (a - mean) * rstd\n out = a_hat * weight + bias\n # # write-back\n tl.store(Out + cols, out, mask=mask)\n\n# Backward pass (DA + partial DW + partial DB)\n\n\n@triton.jit\ndef _layer_norm_bwd_dx_fused(\n _DA,\n _DOut,\n _A,\n Weight,\n Mean, Rstd,\n stride, NumRows, NumCols, eps,\n BLOCK_SIZE_N: tl.constexpr,\n):\n # position of elements processed by this program\n pid = tl.program_id(0)\n row = pid\n A = _A + row * stride\n DOut = _DOut + row * stride\n DA = _DA + row * stride\n mean = tl.load(Mean + row)\n rstd = tl.load(Rstd + row)\n # load data to SRAM\n _mean1 = tl.zeros([BLOCK_SIZE_N], dtype=tl.float32)\n _mean2 = tl.zeros([BLOCK_SIZE_N], dtype=tl.float32)\n for off in range(0, NumCols, BLOCK_SIZE_N):\n cols = off + tl.arange(0, BLOCK_SIZE_N)\n mask = cols < NumCols\n a = tl.load(A + cols, mask=mask, other=0).to(tl.float32)\n dout = tl.load(DOut + cols, mask=mask, other=0).to(tl.float32)\n weight = tl.load(Weight + cols, mask=mask, other=0).to(tl.float32)\n a_hat = (a - mean) * rstd\n wdout = weight * dout\n _mean1 += a_hat * wdout\n _mean2 += wdout\n mean1 = tl.sum(_mean1, axis=0) / NumCols\n mean2 = 0.\n mean2 = tl.sum(_mean2, axis=0) / NumCols\n for off in range(0, NumCols, BLOCK_SIZE_N):\n cols = off + tl.arange(0, BLOCK_SIZE_N)\n mask = cols < NumCols\n a = tl.load(A + cols, mask=mask, other=0).to(tl.float32)\n dout = tl.load(DOut + cols, mask=mask, other=0).to(tl.float32)\n weight = tl.load(Weight + cols, mask=mask, other=0).to(tl.float32)\n a_hat = (a - mean) * rstd\n wdout = weight * dout\n da = (wdout - (a_hat * mean1 + mean2)) * rstd\n # write-back dx\n tl.store(DA + cols, da, mask=mask)\n\n\n# Backward pass (total DW + total DB)\n@triton.jit\ndef _layer_norm_bwd_dwdb(\n A, DOut,\n Mean, Var,\n DW,\n DB,\n M, N,\n BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr,\n):\n pid = tl.program_id(0)\n cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), 
dtype=tl.float32)\n db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for i in range(0, M, BLOCK_SIZE_M):\n rows = i + tl.arange(0, BLOCK_SIZE_M)\n mask = (rows[:, None] < M) & (cols[None, :] < N)\n offs = rows[:, None] * N + cols[None, :]\n a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32)\n dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32)\n mean = tl.load(Mean + rows, mask=rows < M, other=0.)\n rstd = tl.load(Var + rows, mask=rows < M, other=0.)\n a_hat = (a - mean[:, None]) * rstd[:, None]\n dw += dout * a_hat\n db += dout\n sum_dw = tl.sum(dw, axis=0)\n sum_db = tl.sum(db, axis=0)\n tl.store(DW + cols, sum_dw, mask=cols < N)\n tl.store(DB + cols, sum_db, mask=cols < N)\n\n\nclass LayerNorm(torch.autograd.Function):\n @staticmethod\n def forward(ctx, a, normalized_shape, weight, bias, eps):\n # allocate output\n out = torch.empty_like(a)\n # reshape input data into 2D tensor\n a_arg = a.reshape(-1, a.shape[-1])\n M, N = a_arg.shape\n mean = torch.empty((M,), dtype=torch.float32, device=\"cuda\")\n rstd = torch.empty((M,), dtype=torch.float32, device=\"cuda\")\n # Less than 64KB per feature: enqueue fused kernel\n MAX_FUSED_SIZE = 65536 // a.element_size()\n BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n BLOCK_SIZE = max(BLOCK_SIZE, 128)\n BLOCK_SIZE = min(BLOCK_SIZE, 4096)\n # heuristics for number of warps\n num_warps = min(max(BLOCK_SIZE // 256, 1), 8)\n _layer_norm_fwd_fused[(M,)](\n out,\n a_arg,\n weight,\n bias,\n mean, rstd,\n a_arg.stride(0), N, eps,\n BLOCK_SIZE=BLOCK_SIZE,\n num_warps=num_warps,\n )\n ctx.save_for_backward(\n a, weight, bias, mean, rstd,\n )\n ctx.BLOCK_SIZE = BLOCK_SIZE\n ctx.num_warps = num_warps\n ctx.eps = eps\n if hasattr(bias, \"config\"):\n assert bias.config.grad_scale_name == weight.config.grad_scale_name\n grad_scale_name = bias.config.grad_scale_name\n else:\n grad_scale_name = None\n ctx.grad_scale_gain_bias_name = grad_scale_name\n return out\n\n @staticmethod\n def backward(ctx, dout):\n assert dout.is_contiguous()\n a, weight, bias, mean, var = ctx.saved_tensors\n # heuristics for amount of parallel reduction stream for DG/DB\n N = weight.shape[0]\n # allocate output\n da = torch.empty_like(dout)\n # enqueue kernel using forward pass heuristics\n # also compute partial sums for DW and DB\n x_arg = a.reshape(-1, a.shape[-1])\n M, N = x_arg.shape\n dweight = torch.empty((weight.shape[0],), dtype=weight.dtype, device=weight.device)\n dbias = torch.empty((weight.shape[0],), dtype=weight.dtype, device=weight.device)\n _layer_norm_bwd_dx_fused[(M,)](\n da,\n dout,\n a,\n weight,\n mean, var,\n x_arg.stride(0), M, N,\n ctx.eps,\n BLOCK_SIZE_N=ctx.BLOCK_SIZE,\n num_warps=ctx.num_warps,\n )\n # accumulate partial sums in separate kernel\n grid = lambda meta: [triton.cdiv(N, meta[\"BLOCK_SIZE_N\"])]\n _layer_norm_bwd_dwdb[grid](\n a, dout,\n mean, var,\n dweight,\n dbias,\n M,\n N,\n BLOCK_SIZE_M=32,\n BLOCK_SIZE_N=128,\n )\n return (da, None, dweight, dbias, None, None,\n None, None, None, None,\n None,\n None, None, None,\n None,\n None, None, None,\n None, None, None,\n None, None, None)\n\n\ndef layer_norm(a, normalized_shape, weight, bias, eps):\n return LayerNorm.apply(a, normalized_shape, weight, bias, eps)\n\n\ndef test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):\n torch.manual_seed(0)\n # create data\n x_shape = (M, N)\n w_shape = (x_shape[-1], )\n weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)\n bias = torch.rand(w_shape, dtype=dtype, device='cuda', 
requires_grad=True)\n x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')\n dy = .1 * torch.randn_like(x)\n x.requires_grad_(True)\n # forward pass\n y_tri = layer_norm(x, w_shape, weight, bias, eps)\n y_ref = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype)\n # backward pass (triton)\n y_tri.backward(dy, retain_graph=True)\n dx_tri, dw_tri, db_tri = [_.grad.clone() for _ in [x, weight, bias]]\n x.grad, weight.grad, bias.grad = None, None, None\n # backward pass (torch)\n y_ref.backward(dy, retain_graph=True)\n dx_ref, dw_ref, db_ref = [_.grad.clone() for _ in [x, weight, bias]]\n # compare\n triton.testing.assert_almost_equal(y_tri, y_ref)\n triton.testing.assert_almost_equal(dx_tri, dx_ref)\n triton.testing.assert_almost_equal(db_tri, db_ref, decimal=1)\n triton.testing.assert_almost_equal(dw_tri, dw_ref, decimal=1)\n\n\n@triton.testing.perf_report(\n triton.testing.Benchmark(\n x_names=['N'],\n x_vals=[512 * i for i in range(2, 32)],\n line_arg='provider',\n line_vals=['triton', 'torch'] + (['apex'] if HAS_APEX else []),\n line_names=['Triton', 'Torch'] + (['Apex'] if HAS_APEX else []),\n styles=[('blue', '-'), ('green', '-'), ('orange', '-')],\n ylabel='GB/s',\n plot_name='layer-norm',\n args={'M': 4096, 'dtype': torch.float16, 'mode': 'forward'}\n )\n)\ndef bench_layer_norm(M, N, dtype, provider, mode, eps=1e-5, device='cuda'):\n # create data\n x_shape = (M, N)\n w_shape = (x_shape[-1], )\n weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)\n bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)\n x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')\n dy = .1 * torch.randn_like(x)\n x.requires_grad_(True)\n # utility functions\n if provider == 'triton':\n y_fwd = lambda: layer_norm(x, w_shape, weight, bias, eps)\n if provider == 'torch':\n y_fwd = lambda: torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps)\n if provider == 'apex':\n apex_layer_norm = apex.normalization.FusedLayerNorm(w_shape).to(x.device).to(x.dtype)\n y_fwd = lambda: apex_layer_norm(x)\n # forward pass\n if mode == 'forward':\n gbps = lambda ms: 2 * x.numel() * x.element_size() / ms * 1e-6\n ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, rep=500)\n # backward pass\n if mode == 'backward':\n gbps = lambda ms: 3 * x.numel() * x.element_size() / ms * 1e-6\n y = y_fwd()\n ms, min_ms, max_ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True),\n grad_to_none=[x], rep=500)\n return gbps(ms), gbps(max_ms), gbps(min_ms)\n\n\n# test_layer_norm(1151, 8192, torch.float16)\nbench_layer_norm.run(save_path='.', print_data=True)"
+ "import torch\n\nimport triton\nimport triton.language as tl\n\ntry:\n # This is https://github.com/NVIDIA/apex, NOT the apex on PyPi, so it\n # should not be added to extras_require in setup.py.\n import apex\n HAS_APEX = True\nexcept ModuleNotFoundError:\n HAS_APEX = False\n\n\n@triton.jit\ndef _layer_norm_fwd_fused(\n Out,\n A,\n Weight,\n Bias,\n Mean, Rstd,\n stride, N, eps,\n BLOCK_SIZE: tl.constexpr,\n):\n # position of elements processed by this program\n row = tl.program_id(0)\n Out += row * stride\n A += row * stride\n # compute mean\n mean = 0\n _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n a = tl.load(A + cols, mask=cols < N, other=0., eviction_policy=\"evict_last\").to(tl.float32)\n _mean += a\n mean = tl.sum(_mean, axis=0) / N\n # compute variance\n _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n a = tl.load(A + cols, mask=cols < N, other=0., eviction_policy=\"evict_last\").to(tl.float32)\n a = tl.where(cols < N, a - mean, 0.)\n _var += a * a\n var = tl.sum(_var, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n # write-back mean/rstd\n tl.store(Mean + row, mean)\n tl.store(Rstd + row, rstd)\n # multiply by weight and add bias\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n mask = cols < N\n weight = tl.load(Weight + cols, mask=mask)\n bias = tl.load(Bias + cols, mask=mask)\n a = tl.load(A + cols, mask=mask, other=0., eviction_policy=\"evict_first\").to(tl.float32)\n a_hat = (a - mean) * rstd\n out = a_hat * weight + bias\n # # write-back\n tl.store(Out + cols, out, mask=mask)\n\n# Backward pass (DA + partial DW + partial DB)\n\n\n@triton.jit\ndef _layer_norm_bwd_dx_fused(\n _DA,\n _DOut,\n _A,\n Weight,\n Mean, Rstd,\n stride, NumRows, NumCols, eps,\n BLOCK_SIZE_N: tl.constexpr,\n):\n # position of elements processed by this program\n pid = tl.program_id(0)\n row = pid\n A = _A + row * stride\n DOut = _DOut + row * stride\n DA = _DA + row * stride\n mean = tl.load(Mean + row)\n rstd = tl.load(Rstd + row)\n # load data to SRAM\n _mean1 = tl.zeros([BLOCK_SIZE_N], dtype=tl.float32)\n _mean2 = tl.zeros([BLOCK_SIZE_N], dtype=tl.float32)\n for off in range(0, NumCols, BLOCK_SIZE_N):\n cols = off + tl.arange(0, BLOCK_SIZE_N)\n mask = cols < NumCols\n a = tl.load(A + cols, mask=mask, other=0).to(tl.float32)\n dout = tl.load(DOut + cols, mask=mask, other=0).to(tl.float32)\n weight = tl.load(Weight + cols, mask=mask, other=0).to(tl.float32)\n a_hat = (a - mean) * rstd\n wdout = weight * dout\n _mean1 += a_hat * wdout\n _mean2 += wdout\n mean1 = tl.sum(_mean1, axis=0) / NumCols\n mean2 = 0.\n mean2 = tl.sum(_mean2, axis=0) / NumCols\n for off in range(0, NumCols, BLOCK_SIZE_N):\n cols = off + tl.arange(0, BLOCK_SIZE_N)\n mask = cols < NumCols\n a = tl.load(A + cols, mask=mask, other=0).to(tl.float32)\n dout = tl.load(DOut + cols, mask=mask, other=0).to(tl.float32)\n weight = tl.load(Weight + cols, mask=mask, other=0).to(tl.float32)\n a_hat = (a - mean) * rstd\n wdout = weight * dout\n da = (wdout - (a_hat * mean1 + mean2)) * rstd\n # write-back dx\n tl.store(DA + cols, da, mask=mask)\n\n\n# Backward pass (total DW + total DB)\n@triton.jit\ndef _layer_norm_bwd_dwdb(\n A, DOut,\n Mean, Var,\n DW,\n DB,\n M, N,\n BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr,\n):\n pid = tl.program_id(0)\n cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), 
dtype=tl.float32)\n db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n UNROLL: tl.constexpr = 4\n for i in range(0, M, BLOCK_SIZE_M * UNROLL):\n for j in range(UNROLL):\n rows = i + j * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n mask = (rows[:, None] < M) & (cols[None, :] < N)\n offs = rows[:, None] * N + cols[None, :]\n a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32)\n dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32)\n mean = tl.load(Mean + rows, mask=rows < M, other=0.)\n rstd = tl.load(Var + rows, mask=rows < M, other=0.)\n a_hat = (a - mean[:, None]) * rstd[:, None]\n dw += dout * a_hat\n db += dout\n sum_dw = tl.sum(dw, axis=0)\n sum_db = tl.sum(db, axis=0)\n tl.store(DW + cols, sum_dw, mask=cols < N)\n tl.store(DB + cols, sum_db, mask=cols < N)\n\n\nclass LayerNorm(torch.autograd.Function):\n @staticmethod\n def forward(ctx, a, normalized_shape, weight, bias, eps):\n # allocate output\n out = torch.empty_like(a)\n # reshape input data into 2D tensor\n a_arg = a.reshape(-1, a.shape[-1])\n M, N = a_arg.shape\n mean = torch.empty((M,), dtype=torch.float32, device=\"cuda\")\n rstd = torch.empty((M,), dtype=torch.float32, device=\"cuda\")\n # Less than 64KB per feature: enqueue fused kernel\n MAX_FUSED_SIZE = 65536 // a.element_size()\n BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n BLOCK_SIZE = max(BLOCK_SIZE, 128)\n BLOCK_SIZE = min(BLOCK_SIZE, 4096)\n # heuristics for number of warps\n num_warps = min(max(BLOCK_SIZE // 256, 1), 8)\n _layer_norm_fwd_fused[(M,)](\n out,\n a_arg,\n weight,\n bias,\n mean, rstd,\n a_arg.stride(0), N, eps,\n BLOCK_SIZE=BLOCK_SIZE,\n num_warps=num_warps,\n )\n ctx.save_for_backward(\n a, weight, bias, mean, rstd,\n )\n ctx.BLOCK_SIZE = BLOCK_SIZE\n ctx.num_warps = num_warps\n ctx.eps = eps\n if hasattr(bias, \"config\"):\n assert bias.config.grad_scale_name == weight.config.grad_scale_name\n grad_scale_name = bias.config.grad_scale_name\n else:\n grad_scale_name = None\n ctx.grad_scale_gain_bias_name = grad_scale_name\n return out\n\n @staticmethod\n def backward(ctx, dout):\n assert dout.is_contiguous()\n a, weight, bias, mean, var = ctx.saved_tensors\n # heuristics for amount of parallel reduction stream for DG/DB\n N = weight.shape[0]\n # allocate output\n da = torch.empty_like(dout)\n # enqueue kernel using forward pass heuristics\n # also compute partial sums for DW and DB\n x_arg = a.reshape(-1, a.shape[-1])\n M, N = x_arg.shape\n dweight = torch.empty((weight.shape[0],), dtype=weight.dtype, device=weight.device)\n dbias = torch.empty((weight.shape[0],), dtype=weight.dtype, device=weight.device)\n _layer_norm_bwd_dx_fused[(M,)](\n da,\n dout,\n a,\n weight,\n mean, var,\n x_arg.stride(0), M, N,\n ctx.eps,\n BLOCK_SIZE_N=ctx.BLOCK_SIZE,\n num_warps=ctx.num_warps,\n )\n if N > 10240:\n BLOCK_SIZE_N = 128\n BLOCK_SIZE_M = 32\n num_warps = 4\n else:\n # maximize occupancy for small N\n BLOCK_SIZE_N = 16\n BLOCK_SIZE_M = 16\n num_warps = 8\n grid = lambda meta: [triton.cdiv(N, meta[\"BLOCK_SIZE_N\"])]\n _layer_norm_bwd_dwdb[grid](\n a, dout,\n mean, var,\n dweight,\n dbias,\n M,\n N,\n BLOCK_SIZE_M=BLOCK_SIZE_M,\n BLOCK_SIZE_N=BLOCK_SIZE_N,\n num_warps=num_warps\n )\n return (da, None, dweight, dbias, None)\n\n\ndef layer_norm(a, normalized_shape, weight, bias, eps):\n return LayerNorm.apply(a, normalized_shape, weight, bias, eps)\n\n\ndef test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):\n torch.manual_seed(0)\n # create data\n x_shape = (M, N)\n w_shape = (x_shape[-1], )\n weight = 
torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)\n bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)\n x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')\n dy = .1 * torch.randn_like(x)\n x.requires_grad_(True)\n # forward pass\n y_tri = layer_norm(x, w_shape, weight, bias, eps)\n y_ref = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype)\n # backward pass (triton)\n y_tri.backward(dy, retain_graph=True)\n dx_tri, dw_tri, db_tri = [_.grad.clone() for _ in [x, weight, bias]]\n x.grad, weight.grad, bias.grad = None, None, None\n # backward pass (torch)\n y_ref.backward(dy, retain_graph=True)\n dx_ref, dw_ref, db_ref = [_.grad.clone() for _ in [x, weight, bias]]\n # compare\n triton.testing.assert_almost_equal(y_tri, y_ref)\n triton.testing.assert_almost_equal(dx_tri, dx_ref)\n triton.testing.assert_almost_equal(db_tri, db_ref, decimal=1)\n triton.testing.assert_almost_equal(dw_tri, dw_ref, decimal=1)\n\n\n@triton.testing.perf_report(\n triton.testing.Benchmark(\n x_names=['N'],\n x_vals=[512 * i for i in range(2, 32)],\n line_arg='provider',\n line_vals=['triton', 'torch'] + (['apex'] if HAS_APEX else []),\n line_names=['Triton', 'Torch'] + (['Apex'] if HAS_APEX else []),\n styles=[('blue', '-'), ('green', '-'), ('orange', '-')],\n ylabel='GB/s',\n plot_name='layer-norm',\n args={'M': 4096, 'dtype': torch.float16, 'mode': 'forward'}\n )\n)\ndef bench_layer_norm(M, N, dtype, provider, mode, eps=1e-5, device='cuda'):\n # create data\n x_shape = (M, N)\n w_shape = (x_shape[-1], )\n weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)\n bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)\n x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')\n dy = .1 * torch.randn_like(x)\n x.requires_grad_(True)\n # utility functions\n if provider == 'triton':\n y_fwd = lambda: layer_norm(x, w_shape, weight, bias, eps)\n if provider == 'torch':\n y_fwd = lambda: torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps)\n if provider == 'apex':\n apex_layer_norm = apex.normalization.FusedLayerNorm(w_shape).to(x.device).to(x.dtype)\n y_fwd = lambda: apex_layer_norm(x)\n # forward pass\n if mode == 'forward':\n gbps = lambda ms: 2 * x.numel() * x.element_size() / ms * 1e-6\n ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, rep=500)\n # backward pass\n if mode == 'backward':\n gbps = lambda ms: 3 * x.numel() * x.element_size() / ms * 1e-6\n y = y_fwd()\n ms, min_ms, max_ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True),\n grad_to_none=[x], rep=500)\n return gbps(ms), gbps(max_ms), gbps(min_ms)\n\n\n# test_layer_norm(1151, 8192, torch.float16)\nbench_layer_norm.run(save_path='.', print_data=True)"
]
}
],
diff --git a/master/_images/sphx_glr_01-vector-add_001.png b/master/_images/sphx_glr_01-vector-add_001.png
index 264f2bc3f..faedfc5df 100644
Binary files a/master/_images/sphx_glr_01-vector-add_001.png and b/master/_images/sphx_glr_01-vector-add_001.png differ
diff --git a/master/_images/sphx_glr_01-vector-add_thumb.png b/master/_images/sphx_glr_01-vector-add_thumb.png
index bf1427cde..3d6da8f70 100644
Binary files a/master/_images/sphx_glr_01-vector-add_thumb.png and b/master/_images/sphx_glr_01-vector-add_thumb.png differ
diff --git a/master/_images/sphx_glr_02-fused-softmax_001.png b/master/_images/sphx_glr_02-fused-softmax_001.png
index ca3de72d8..4ccac83a3 100644
Binary files a/master/_images/sphx_glr_02-fused-softmax_001.png and b/master/_images/sphx_glr_02-fused-softmax_001.png differ
diff --git a/master/_images/sphx_glr_02-fused-softmax_thumb.png b/master/_images/sphx_glr_02-fused-softmax_thumb.png
index 38856739b..b52715562 100644
Binary files a/master/_images/sphx_glr_02-fused-softmax_thumb.png and b/master/_images/sphx_glr_02-fused-softmax_thumb.png differ
diff --git a/master/_images/sphx_glr_03-matrix-multiplication_001.png b/master/_images/sphx_glr_03-matrix-multiplication_001.png
index 5a1c16378..61fb63b0d 100644
Binary files a/master/_images/sphx_glr_03-matrix-multiplication_001.png and b/master/_images/sphx_glr_03-matrix-multiplication_001.png differ
diff --git a/master/_images/sphx_glr_03-matrix-multiplication_thumb.png b/master/_images/sphx_glr_03-matrix-multiplication_thumb.png
index 3b0d7c0cf..8f4303601 100644
Binary files a/master/_images/sphx_glr_03-matrix-multiplication_thumb.png and b/master/_images/sphx_glr_03-matrix-multiplication_thumb.png differ
diff --git a/master/_images/sphx_glr_05-layer-norm_001.png b/master/_images/sphx_glr_05-layer-norm_001.png
index 5dbc57259..33c12cadf 100644
Binary files a/master/_images/sphx_glr_05-layer-norm_001.png and b/master/_images/sphx_glr_05-layer-norm_001.png differ
diff --git a/master/_images/sphx_glr_05-layer-norm_thumb.png b/master/_images/sphx_glr_05-layer-norm_thumb.png
index 98d623bcf..588cf049c 100644
Binary files a/master/_images/sphx_glr_05-layer-norm_thumb.png and b/master/_images/sphx_glr_05-layer-norm_thumb.png differ
diff --git a/master/_images/sphx_glr_06-fused-attention_thumb.png b/master/_images/sphx_glr_06-fused-attention_thumb.png
new file mode 100644
index 000000000..8a5fed589
Binary files /dev/null and b/master/_images/sphx_glr_06-fused-attention_thumb.png differ
diff --git a/master/_images/sphx_glr_07-libdevice-function_thumb.png b/master/_images/sphx_glr_07-libdevice-function_thumb.png
new file mode 100644
index 000000000..8a5fed589
Binary files /dev/null and b/master/_images/sphx_glr_07-libdevice-function_thumb.png differ
diff --git a/master/_sources/getting-started/tutorials/01-vector-add.rst.txt b/master/_sources/getting-started/tutorials/01-vector-add.rst.txt
index 20f8ec750..e3022bcf7 100644
--- a/master/_sources/getting-started/tutorials/01-vector-add.rst.txt
+++ b/master/_sources/getting-started/tutorials/01-vector-add.rst.txt
@@ -238,7 +238,7 @@ We can now run the decorated function above. Pass `print_data=True` to see the p
3 32768.0 76.800002 76.800002
4 65536.0 127.999995 127.999995
5 131072.0 219.428568 219.428568
- 6 262144.0 341.333321 384.000001
+ 6 262144.0 341.333321 341.333321
7 524288.0 472.615390 472.615390
8 1048576.0 614.400016 614.400016
9 2097152.0 722.823517 722.823517
@@ -255,7 +255,7 @@ We can now run the decorated function above. Pass `print_data=True` to see the p
.. rst-class:: sphx-glr-timing
- **Total running time of the script:** ( 1 minutes 34.829 seconds)
+ **Total running time of the script:** ( 1 minutes 50.020 seconds)
.. _sphx_glr_download_getting-started_tutorials_01-vector-add.py:
diff --git a/master/_sources/getting-started/tutorials/02-fused-softmax.rst.txt b/master/_sources/getting-started/tutorials/02-fused-softmax.rst.txt
index 6766f31b8..b4b0ad79a 100644
--- a/master/_sources/getting-started/tutorials/02-fused-softmax.rst.txt
+++ b/master/_sources/getting-started/tutorials/02-fused-softmax.rst.txt
@@ -278,17 +278,17 @@ We will then compare its performance against (1) :code:`torch.softmax` and (2) t
softmax-performance:
N Triton Torch (native) Torch (jit)
- 0 256.0 546.133347 512.000001 188.321838
- 1 384.0 614.400016 585.142862 153.600004
- 2 512.0 655.360017 585.142849 154.566038
+ 0 256.0 546.133347 546.133347 186.181817
+ 1 384.0 614.400016 585.142862 151.703707
+ 2 512.0 655.360017 606.814814 154.566038
3 640.0 706.206879 640.000002 160.000000
- 4 768.0 722.823517 664.216187 162.754967
+ 4 768.0 722.823517 664.216187 163.839992
.. ... ... ... ...
93 12160.0 812.359066 406.179533 198.936606
- 94 12288.0 812.429770 415.222812 199.298541
- 95 12416.0 812.498981 412.149375 198.954424
- 96 12544.0 812.566838 412.758863 199.209928
- 97 12672.0 811.007961 412.097543 199.264875
+ 94 12288.0 812.429770 415.222812 199.096718
+ 95 12416.0 812.498981 412.149375 198.854847
+ 96 12544.0 810.925276 412.971190 199.012395
+ 97 12672.0 811.007961 412.097543 199.167004
[98 rows x 4 columns]
@@ -306,7 +306,7 @@ In the above plot, we can see that:
.. rst-class:: sphx-glr-timing
- **Total running time of the script:** ( 3 minutes 18.076 seconds)
+ **Total running time of the script:** ( 3 minutes 32.089 seconds)
.. _sphx_glr_download_getting-started_tutorials_02-fused-softmax.py:
diff --git a/master/_sources/getting-started/tutorials/03-matrix-multiplication.rst.txt b/master/_sources/getting-started/tutorials/03-matrix-multiplication.rst.txt
index 612443a28..5a0558c63 100644
--- a/master/_sources/getting-started/tutorials/03-matrix-multiplication.rst.txt
+++ b/master/_sources/getting-started/tutorials/03-matrix-multiplication.rst.txt
@@ -459,37 +459,37 @@ We can now compare the performance of our kernel against that of cuBLAS. Here we
matmul-performance:
M cuBLAS ... Triton Triton (+ LeakyReLU)
- 0 256.0 2.730667 ... 2.978909 2.978909
+ 0 256.0 2.730667 ... 3.276800 2.978909
1 384.0 7.372800 ... 7.899428 7.899428
- 2 512.0 14.563555 ... 15.420235 15.420235
+ 2 512.0 14.563555 ... 16.384000 15.420235
3 640.0 22.260869 ... 24.380953 24.380953
4 768.0 32.768000 ... 35.389441 34.028308
- 5 896.0 37.971025 ... 40.140799 39.025776
+ 5 896.0 39.025776 ... 40.140799 39.025776
6 1024.0 49.932191 ... 53.773130 52.428801
- 7 1152.0 45.242181 ... 48.161033 47.396572
+ 7 1152.0 45.242181 ... 47.396572 47.396572
8 1280.0 51.200001 ... 57.690139 57.690139
- 9 1408.0 64.138541 ... 68.147202 65.684049
- 10 1536.0 79.526831 ... 81.355034 78.643199
- 11 1664.0 63.372618 ... 63.372618 62.492442
+ 9 1408.0 64.138541 ... 68.147202 66.485074
+ 10 1536.0 80.430545 ... 80.430545 78.643199
+ 11 1664.0 62.929456 ... 63.372618 62.492442
12 1792.0 72.983276 ... 72.983276 59.154861
- 13 1920.0 68.776119 ... 71.626943 70.892307
- 14 2048.0 73.584279 ... 78.033565 76.959706
- 15 2176.0 83.155572 ... 87.494120 86.367588
- 16 2304.0 68.446623 ... 78.064941 77.057651
- 17 2432.0 71.305746 ... 86.179335 85.393507
- 18 2560.0 77.833728 ... 82.956960 81.715711
- 19 2688.0 83.737433 ... 91.185232 89.464755
- 20 2816.0 82.446516 ... 84.523664 83.712490
- 21 2944.0 81.967162 ... 83.758038 82.373605
- 22 3072.0 82.420822 ... 88.750943 86.579673
- 23 3200.0 81.528664 ... 91.233074 95.665176
- 24 3328.0 83.516586 ... 85.908470 83.323259
- 25 3456.0 81.435930 ... 92.138932 90.180725
- 26 3584.0 83.954614 ... 91.189190 95.858629
- 27 3712.0 85.822459 ... 83.806497 87.783251
- 28 3840.0 80.901241 ... 89.259080 89.548180
- 29 3968.0 87.913500 ... 92.829164 84.096442
- 30 4096.0 93.825748 ... 89.299883 90.139506
+ 13 1920.0 69.120002 ... 71.257735 71.257735
+ 14 2048.0 73.584279 ... 78.398206 77.314362
+ 15 2176.0 83.155572 ... 87.494120 85.998493
+ 16 2304.0 68.446623 ... 78.320893 77.558029
+ 17 2432.0 71.305746 ... 86.711310 75.421383
+ 18 2560.0 77.833728 ... 82.747477 81.715711
+ 19 2688.0 83.552988 ... 90.532356 89.464755
+ 20 2816.0 84.197315 ... 84.035084 84.035084
+ 21 2944.0 82.784108 ... 83.969728 83.060049
+ 22 3072.0 81.825298 ... 89.593522 88.473602
+ 23 3200.0 84.768213 ... 96.096095 95.808380
+ 24 3328.0 83.226931 ... 85.908470 84.596116
+ 25 3456.0 81.766291 ... 91.824110 91.097818
+ 26 3584.0 87.466332 ... 91.194972 94.847460
+ 27 3712.0 85.822459 ... 87.246590 87.860458
+ 28 3840.0 81.859361 ... 87.011801 90.168771
+ 29 3968.0 89.921841 ... 91.954739 85.271796
+ 30 4096.0 93.596744 ... 88.243079 90.382307
[31 rows x 5 columns]
@@ -499,7 +499,7 @@ We can now compare the performance of our kernel against that of cuBLAS. Here we
.. rst-class:: sphx-glr-timing
- **Total running time of the script:** ( 5 minutes 52.578 seconds)
+ **Total running time of the script:** ( 7 minutes 13.827 seconds)
.. _sphx_glr_download_getting-started_tutorials_03-matrix-multiplication.py:
diff --git a/master/_sources/getting-started/tutorials/04-low-memory-dropout.rst.txt b/master/_sources/getting-started/tutorials/04-low-memory-dropout.rst.txt
index 0f9d8dc7a..edbf188ee 100644
--- a/master/_sources/getting-started/tutorials/04-low-memory-dropout.rst.txt
+++ b/master/_sources/getting-started/tutorials/04-low-memory-dropout.rst.txt
@@ -240,7 +240,7 @@ References
.. rst-class:: sphx-glr-timing
- **Total running time of the script:** ( 0 minutes 0.476 seconds)
+ **Total running time of the script:** ( 0 minutes 0.279 seconds)
.. _sphx_glr_download_getting-started_tutorials_04-low-memory-dropout.py:
diff --git a/master/_sources/getting-started/tutorials/05-layer-norm.rst.txt b/master/_sources/getting-started/tutorials/05-layer-norm.rst.txt
index 8750ca528..295b2a90c 100644
--- a/master/_sources/getting-started/tutorials/05-layer-norm.rst.txt
+++ b/master/_sources/getting-started/tutorials/05-layer-norm.rst.txt
@@ -21,7 +21,7 @@
Layer Normalization
====================
-.. GENERATED FROM PYTHON SOURCE LINES 5-312
+.. GENERATED FROM PYTHON SOURCE LINES 5-316
@@ -40,34 +40,34 @@ Layer Normalization
N Triton Torch Apex
0 1024.0 585.142849 277.694907 468.114273
1 1536.0 630.153868 323.368435 511.999982
- 2 2048.0 682.666643 334.367358 520.126988
- 3 2560.0 694.237267 365.714281 518.481028
- 4 3072.0 712.347810 378.092307 501.551037
- 5 3584.0 725.873439 384.859062 458.751978
- 6 4096.0 728.177767 381.023256 458.293714
- 7 4608.0 670.254540 396.387087 426.173427
- 8 5120.0 694.237267 397.669909 426.666652
- 9 5632.0 704.000002 396.969169 413.357796
- 10 6144.0 702.171410 402.885254 411.313806
+ 2 2048.0 668.734716 337.814445 528.516136
+ 3 2560.0 694.237267 362.477870 512.000013
+ 4 3072.0 712.347810 375.206126 501.551037
+ 5 3584.0 725.873439 384.859062 451.527536
+ 6 4096.0 728.177767 381.023256 455.111095
+ 7 4608.0 670.254540 396.387087 421.302872
+ 8 5120.0 688.403381 395.748783 422.268057
+ 9 5632.0 698.542675 396.969169 409.599997
+ 10 6144.0 702.171410 402.885254 409.600010
11 6656.0 700.631610 400.360920 400.360920
- 12 7168.0 695.078767 396.844306 388.772874
- 13 7680.0 682.666656 393.846167 387.634072
- 14 8192.0 642.509816 393.609605 372.363633
- 15 8704.0 627.315309 389.005597 380.502740
- 16 9216.0 606.814809 407.337026 383.999986
- 17 9728.0 589.575753 409.599987 383.369452
- 18 10240.0 566.920437 408.578556 382.803739
- 19 10752.0 549.623009 411.559798 381.445676
- 20 11264.0 536.380957 406.826188 373.134567
- 21 11776.0 523.377770 410.492372 377.587162
- 22 12288.0 517.389457 414.784810 383.251457
- 23 12800.0 505.679014 410.420828 376.470582
- 24 13312.0 494.180982 405.699062 376.976995
- 25 13824.0 482.934503 411.888257 379.389355
- 26 14336.0 471.967074 406.695045 374.185964
- 27 14848.0 461.297068 408.192434 375.304904
- 28 15360.0 454.269882 406.214870 378.092307
- 29 15872.0 447.887117 407.627589 376.225175
+ 12 7168.0 678.627194 386.154893 384.859062
+ 13 7680.0 682.666656 391.337574 386.415087
+ 14 8192.0 645.674867 390.095241 376.643677
+ 15 8704.0 624.502255 390.095225 379.465939
+ 16 9216.0 604.327881 405.098894 383.002605
+ 17 9728.0 585.142883 409.599987 382.427505
+ 18 10240.0 564.965524 409.600010 382.803739
+ 19 10752.0 546.133312 410.577576 380.601764
+ 20 11264.0 531.634232 395.228063 370.069806
+ 21 11776.0 520.486200 409.599991 376.831982
+ 22 12288.0 516.031509 413.911572 383.251457
+ 23 12800.0 504.433489 410.420828 375.779805
+ 24 13312.0 494.180982 405.699062 376.310952
+ 25 13824.0 481.882350 411.888257 378.739711
+ 26 14336.0 471.967074 401.709294 372.969090
+ 27 14848.0 461.297068 407.492270 375.898745
+ 28 15360.0 453.431739 406.887417 378.092307
+ 29 15872.0 447.098578 406.323209 376.225175
@@ -204,17 +204,19 @@ Layer Normalization
cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
- for i in range(0, M, BLOCK_SIZE_M):
- rows = i + tl.arange(0, BLOCK_SIZE_M)
- mask = (rows[:, None] < M) & (cols[None, :] < N)
- offs = rows[:, None] * N + cols[None, :]
- a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32)
- dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32)
- mean = tl.load(Mean + rows, mask=rows < M, other=0.)
- rstd = tl.load(Var + rows, mask=rows < M, other=0.)
- a_hat = (a - mean[:, None]) * rstd[:, None]
- dw += dout * a_hat
- db += dout
+ UNROLL: tl.constexpr = 4
+ for i in range(0, M, BLOCK_SIZE_M * UNROLL):
+ for j in range(UNROLL):
+ rows = i + j * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+ mask = (rows[:, None] < M) & (cols[None, :] < N)
+ offs = rows[:, None] * N + cols[None, :]
+ a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32)
+ dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32)
+ mean = tl.load(Mean + rows, mask=rows < M, other=0.)
+ rstd = tl.load(Var + rows, mask=rows < M, other=0.)
+ a_hat = (a - mean[:, None]) * rstd[:, None]
+ dw += dout * a_hat
+ db += dout
sum_dw = tl.sum(dw, axis=0)
sum_db = tl.sum(db, axis=0)
tl.store(DW + cols, sum_dw, mask=cols < N)
@@ -287,7 +289,15 @@ Layer Normalization
BLOCK_SIZE_N=ctx.BLOCK_SIZE,
num_warps=ctx.num_warps,
)
- # accumulate partial sums in separate kernel
+ if N > 10240:
+ BLOCK_SIZE_N = 128
+ BLOCK_SIZE_M = 32
+ num_warps = 4
+ else:
+ # maximize occupancy for small N
+ BLOCK_SIZE_N = 16
+ BLOCK_SIZE_M = 16
+ num_warps = 8
grid = lambda meta: [triton.cdiv(N, meta["BLOCK_SIZE_N"])]
_layer_norm_bwd_dwdb[grid](
a, dout,
@@ -296,17 +306,11 @@ Layer Normalization
dbias,
M,
N,
- BLOCK_SIZE_M=32,
- BLOCK_SIZE_N=128,
+ BLOCK_SIZE_M=BLOCK_SIZE_M,
+ BLOCK_SIZE_N=BLOCK_SIZE_N,
+ num_warps=num_warps
)
- return (da, None, dweight, dbias, None, None,
- None, None, None, None,
- None,
- None, None, None,
- None,
- None, None, None,
- None, None, None,
- None, None, None)
+ return (da, None, dweight, dbias, None)
def layer_norm(a, normalized_shape, weight, bias, eps):
@@ -389,7 +393,7 @@ Layer Normalization
.. rst-class:: sphx-glr-timing
- **Total running time of the script:** ( 5 minutes 24.641 seconds)
+ **Total running time of the script:** ( 5 minutes 32.552 seconds)
.. _sphx_glr_download_getting-started_tutorials_05-layer-norm.py:
diff --git a/master/_sources/getting-started/tutorials/06-fused-attention.rst.txt b/master/_sources/getting-started/tutorials/06-fused-attention.rst.txt
new file mode 100644
index 000000000..2876c4654
--- /dev/null
+++ b/master/_sources/getting-started/tutorials/06-fused-attention.rst.txt
@@ -0,0 +1,416 @@
+
+.. DO NOT EDIT.
+.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
+.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
+.. "getting-started/tutorials/06-fused-attention.py"
+.. LINE NUMBERS ARE GIVEN BELOW.
+
+.. only:: html
+
+ .. note::
+ :class: sphx-glr-download-link-note
+
+ Click :ref:`here