[GH-PAGES] Added missing files
4
.buildinfo
Normal file
@@ -0,0 +1,4 @@
|
||||
# Sphinx build info version 1
|
||||
# This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done.
|
||||
config: 76c4cbf22d8ff0aa13ac3f4683f2586c
|
||||
tags: 645f666f9bcd5a90fca523b33c5a78b7
|
3
.gitignore
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
.vscode
|
||||
./docs/_build/*
|
||||
./build/*
|
BIN
_images/tutorials_02-fused-softmax_12_0.png
Normal file
Binary file not shown.
After: Size: 18 KiB
329
_sources/tutorials/01-vector-add.ipynb.txt
Normal file
@@ -0,0 +1,329 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "acute-possession",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Vector Addition"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "median-malaysia",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"In this tutorial, we will see how to construct a simple, high-performance vector addition using Triton. You will learn:\n",
|
||||
"* The basic syntax of the Triton programming language\n",
|
||||
"* The best practices for creating PyTorch custom operators using the `triton.kernel` Python API\n",
|
||||
"* The best practices for validating and benchmarking custom ops against native reference implementations"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "identical-conditions",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Writing the Compute Kernel"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "collectible-belle",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Each compute kernel is declared using the `__global__` attribute, and executed many times in parallel on different chunks of data (See the [Single Program, Multiple Data](https://en.wikipedia.org/wiki/SPMD) programming model for more details).\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"```c\n",
|
||||
"__global__ void add(float* z, float* x, float* y, int N){\n",
|
||||
" // The `get_program_id(i)` returns the i-th coordinate\n",
|
||||
" // of the program in the overaching SPMD context\n",
|
||||
" // (a.k.a launch grid). This is what allows us to process\n",
|
||||
" // different chunks of data in parallel.\n",
|
||||
" // For those similar with CUDA, `get_program_id({0,1,2})`\n",
|
||||
" // is similar to blockIdx.{x,y,z}\n",
|
||||
" int pid = get_program_id(0);\n",
|
||||
" // In Triton, arrays are first-class citizen. In other words,\n",
|
||||
" // they are primitives data-types and are -- contrary to C and\n",
|
||||
" // CUDA -- not implemented as pointers to contiguous chunks of\n",
|
||||
" // memory.\n",
|
||||
" // In the few lines below, we create an array of `BLOCK` pointers\n",
|
||||
" // whose memory values are, e.g.:\n",
|
||||
" // [z + pid*BLOCK + 0, z + pid*BLOCK + 1, ..., z + pid*BLOCK + BLOCK - 1]\n",
|
||||
" // Note: here BLOCK is expected to be a pre-processor macro defined at compile-time\n",
|
||||
" int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;\n",
|
||||
" float* pz [BLOCK] = z + offset;\n",
|
||||
" float* px [BLOCK] = x + offset;\n",
|
||||
" float* py [BLOCK] = y + offset;\n",
|
||||
" // Simple element-wise control-flow for load/store operations can\n",
|
||||
" // be achieved using the the ternary operator `cond ? val_true : val_false`\n",
|
||||
" // or the conditional dereferencing operator `*?(cond)ptr\n",
|
||||
" // Here, we make sure that we do not access memory out-of-bounds when we\n",
|
||||
" // write-back `z`\n",
|
||||
" bool check[BLOCK] = offset < N;\n",
|
||||
" *?(check)pz = *?(check)px + *?(check)py;\n",
|
||||
"}\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"The existence of arrays as a primitive data-type for Triton comes with a number of advantages that are highlighted in the [MAPL'2019 Triton paper](http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "forbidden-wednesday",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Writing the Torch bindings"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "numerical-agency",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The only thing that matters when it comes to Triton and Torch is the `triton.kernel` class. This allows you to transform the above C-like function into a callable python object that can be used to modify `torch.tensor` objects.\n",
|
||||
"\n",
|
||||
"To create a `triton.kernel`, you only need three things:\n",
|
||||
"* `source: string`: the source-code of the kernel you want to create\n",
|
||||
"* `device: torch.device`: the device you want to compile this code for\n",
|
||||
"* `defines: dict`: the set of macros that you want the pre-processor to `#define` for you\n",
|
||||
"\n",
|
||||
"Note: The constructor of `triton.kernel` does some just-in-time compilation, so expect some overhead there. For this reason, I personally like to initialize kernels lazily in a cache (see `_kernels` variable below). This also makes it possible to choose the compilation device dynamically based on the type of the operator's inputs."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "sporting-keyboard",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"import triton\n",
|
||||
"\n",
|
||||
"# source-code for Triton compute kernel\n",
|
||||
"# here we just copy-paste the above code without the extensive comments.\n",
|
||||
"# you may prefer to store it in a .c file and load it from there instead.\n",
|
||||
"_src = \"\"\"\n",
|
||||
"__global__ void add(float* z, float* x, float* y, int N){\n",
|
||||
" // program id\n",
|
||||
" int pid = get_program_id(0);\n",
|
||||
" // create arrays of pointers\n",
|
||||
" int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;\n",
|
||||
" float* pz[BLOCK] = z + offset;\n",
|
||||
" float* px[BLOCK] = x + offset;\n",
|
||||
" float* py[BLOCK] = y + offset;\n",
|
||||
" // bounds checking\n",
|
||||
" bool check[BLOCK] = offset < N;\n",
|
||||
" // write-back\n",
|
||||
" *?(check)pz = *?(check)px + *?(check)py;\n",
|
||||
"}\n",
|
||||
" \"\"\"\n",
|
||||
"# This function returns a callable `triton.kernel` object\n",
|
||||
"# created from the above source code.\n",
|
||||
"# For portability, we maintain a cache of kernels for different `torch.device`\n",
|
||||
"# We compile the kernel with -DBLOCK=1024\n",
|
||||
"_kernels = dict()\n",
|
||||
"def make_add_kernel(device):\n",
|
||||
" if device not in _kernels:\n",
|
||||
" defines = {'BLOCK': 1024}\n",
|
||||
" _kernels[device] = triton.kernel(_src, device=device, defines=defines)\n",
|
||||
" return _kernels[device]\n",
|
||||
"\n",
|
||||
"# This is a standard torch custom autograd Function\n",
|
||||
"# The only difference is that we can now use the above kernel\n",
|
||||
"# in the `forward` and `backward` functions.`\n",
|
||||
"class _add(torch.autograd.Function):\n",
|
||||
" \n",
|
||||
" @staticmethod\n",
|
||||
" def forward(ctx, x, y):\n",
|
||||
" # constraints of the op\n",
|
||||
" assert x.dtype == torch.float32\n",
|
||||
" # *allocate output*\n",
|
||||
" z = torch.empty_like(x)\n",
|
||||
" # *create launch grid*:\n",
|
||||
" # this is a function which takes compilation parameters `opt`\n",
|
||||
" # as input and returns a tuple of int (i.e., launch grid) for the kernel.\n",
|
||||
" # triton.cdiv is a shortcut for ceil division:\n",
|
||||
" # triton.cdiv(a, b) = (a + b - 1) // b\n",
|
||||
" N = z.shape[0]\n",
|
||||
" grid = lambda opt: (triton.cdiv(N, opt.BLOCK), )\n",
|
||||
" # *launch kernel*:\n",
|
||||
" # pointer to the data of torch tensors can be retrieved with\n",
|
||||
" # the `.data_ptr()` method\n",
|
||||
" kernel = make_add_kernel(z.device)\n",
|
||||
" kernel(z.data_ptr(), x.data_ptr(), y.data_ptr(), N, grid = grid)\n",
|
||||
" return z\n",
|
||||
"# Just like we standard PyTorch ops\n",
|
||||
"# We use the `.apply` method to create a \n",
|
||||
"# callable object for our function\n",
|
||||
"add = _add.apply"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "separated-polyester",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"At this point `add(x, y)` is equivalent to `x + y` for contiguous tensors. Now let's test and benchmark it!"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "exclusive-salvation",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Writing a Unit Test"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "supported-ribbon",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor([1.3713, 1.3076, 0.4940, ..., 0.6682, 1.1984, 1.2696], device='cuda:0')\n",
|
||||
"tensor([1.3713, 1.3076, 0.4940, ..., 0.6682, 1.1984, 1.2696], device='cuda:0')\n",
|
||||
"The maximum difference between torch and triton is 0.0\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"torch.manual_seed(0)\n",
|
||||
"x = torch.rand(98432, device='cuda')\n",
|
||||
"y = torch.rand(98432, device='cuda')\n",
|
||||
"za = x + y\n",
|
||||
"zb = add(x, y)\n",
|
||||
"print(za)\n",
|
||||
"print(zb)\n",
|
||||
"print(f'The maximum difference between torch and triton is '\n",
|
||||
" f'{torch.max(torch.abs(za - zb))}')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "otherwise-canadian",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Seems to work!"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "polished-australia",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Writing a Benchmark"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "historic-glass",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The performance of our GPU code can be benchmark using the `torch.cuda.Event(enable_timing=True)` wrapper. Below is a simple function that benchmarks `rep` runs of our kernels after `warmup` \"cold\" runs."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"id": "strange-luxembourg",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# We now want to benchmark the performance of `add`\n",
|
||||
"# Against that of PyTorch for increasing vector sizes\n",
|
||||
"def do_bench(fn, warmup = 10, rep = 50):\n",
|
||||
" start_event = torch.cuda.Event(enable_timing=True)\n",
|
||||
" end_event = torch.cuda.Event(enable_timing=True)\n",
|
||||
" ret = fn()\n",
|
||||
" for i in range(warmup):\n",
|
||||
" fn()\n",
|
||||
" torch.cuda.synchronize()\n",
|
||||
" start_event.record()\n",
|
||||
" for i in range(rep):\n",
|
||||
" fn()\n",
|
||||
" end_event.record()\n",
|
||||
" torch.cuda.synchronize()\n",
|
||||
" time_ms = start_event.elapsed_time(end_event) / rep\n",
|
||||
" return time_ms"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "hairy-claim",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can now benchmark our custom op for vectors of increasing sizes to get a sense of how it does"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"id": "pleasant-valley",
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"131072 0.020 0.003\n",
|
||||
"262144 0.019 0.004\n",
|
||||
"524288 0.016 0.016\n",
|
||||
"1048576 0.033 0.033\n",
|
||||
"2097152 0.071 0.070\n",
|
||||
"4194304 0.142 0.144\n",
|
||||
"8388608 0.287 0.286\n",
|
||||
"16777216 0.572 0.568\n",
|
||||
"33554432 1.139 1.110\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"for N in [2**i for i in range(17, 26, 1)]:\n",
|
||||
" x = torch.rand(N, device='cuda')\n",
|
||||
" y = torch.rand(N, device='cuda')\n",
|
||||
" triton_ms = do_bench(lambda: add(x, y))\n",
|
||||
" torch_ms = do_bench(lambda: x + y)\n",
|
||||
" # print the performance of triton and torch as well as the achieved bandwidth\n",
|
||||
" print(f'{N} {triton_ms:.3f} {torch_ms:.3f}')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "juvenile-supplement",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Our op is on-par with Torch's vectorized element-wise kernel when the vectors are large enough. One caveat is that the latency of PyTorch is much smaller for small vectors (3us vs 18-20us). This is something we are actively working on to reduce."
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.7.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
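For reference, the benchmark cell above prints raw runtimes only; since float32 vector addition reads `x` and `y` and writes `z`, it moves 3 * N * 4 bytes, so the achieved memory bandwidth can be derived from those timings. Below is a minimal, illustrative sketch (not part of this commit): `do_bench` and `add` are the helpers defined in the notebook, and `effective_bandwidth_gbps` is a hypothetical name.

```python
# Hypothetical helper (not part of the commit): convert a runtime measured in
# milliseconds into effective memory bandwidth for float32 vector addition,
# which moves 3 * N * 4 bytes (two loads for x and y, one store for z).
def effective_bandwidth_gbps(n_elements: int, time_ms: float) -> float:
    bytes_moved = 3 * n_elements * 4               # bytes read + written
    return bytes_moved / (time_ms * 1e-3) / 1e9    # bytes/s -> GB/s

# Example usage, with the notebook's helpers assumed to be in scope:
#   triton_ms = do_bench(lambda: add(x, y))
#   print(f'{x.numel()} elements: {triton_ms:.3f} ms, '
#         f'{effective_bandwidth_gbps(x.numel(), triton_ms):.1f} GB/s')
```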
308
_sources/tutorials/02-fused-softmax.ipynb.txt
Normal file
File diff suppressed because one or more lines are too long
BIN
_static/css/fonts/Roboto-Slab-Bold.woff
Normal file
Binary file not shown.
BIN
_static/css/fonts/Roboto-Slab-Bold.woff2
Normal file
Binary file not shown.
BIN
_static/css/fonts/Roboto-Slab-Regular.woff
Normal file
Binary file not shown.
BIN
_static/css/fonts/Roboto-Slab-Regular.woff2
Normal file
Binary file not shown.
BIN
_static/css/fonts/fontawesome-webfont.eot
Normal file
Binary file not shown.
2671
_static/css/fonts/fontawesome-webfont.svg
Normal file
File diff suppressed because it is too large
After: Size: 434 KiB
BIN
_static/css/fonts/fontawesome-webfont.ttf
Normal file
Binary file not shown.
BIN
_static/css/fonts/fontawesome-webfont.woff
Normal file
Binary file not shown.
BIN
_static/css/fonts/fontawesome-webfont.woff2
Normal file
Binary file not shown.
BIN
_static/css/fonts/lato-bold-italic.woff
Normal file
Binary file not shown.
BIN
_static/css/fonts/lato-bold-italic.woff2
Normal file
Binary file not shown.
BIN
_static/css/fonts/lato-bold.woff
Normal file
Binary file not shown.
BIN
_static/css/fonts/lato-bold.woff2
Normal file
Binary file not shown.
BIN
_static/css/fonts/lato-normal-italic.woff
Normal file
Binary file not shown.
BIN
_static/css/fonts/lato-normal-italic.woff2
Normal file
Binary file not shown.
BIN
_static/css/fonts/lato-normal.woff
Normal file
Binary file not shown.
BIN
_static/css/fonts/lato-normal.woff2
Normal file
Binary file not shown.
12
_static/documentation_options.js
Normal file
@@ -0,0 +1,12 @@
|
||||
var DOCUMENTATION_OPTIONS = {
|
||||
URL_ROOT: document.getElementById("documentation_options").getAttribute('data-url_root'),
|
||||
VERSION: '',
|
||||
LANGUAGE: 'None',
|
||||
COLLAPSE_INDEX: false,
|
||||
BUILDER: 'html',
|
||||
FILE_SUFFIX: '.html',
|
||||
LINK_SUFFIX: '.html',
|
||||
HAS_SOURCE: true,
|
||||
SOURCELINK_SUFFIX: '.txt',
|
||||
NAVIGATION_WITH_KEYS: false
|
||||
};
|
10872
_static/jquery-3.5.1.js
vendored
Normal file
File diff suppressed because it is too large
1
_static/js/badge_only.js
Normal file
@@ -0,0 +1 @@
|
||||
!function(e){var t={};function r(n){if(t[n])return t[n].exports;var o=t[n]={i:n,l:!1,exports:{}};return e[n].call(o.exports,o,o.exports,r),o.l=!0,o.exports}r.m=e,r.c=t,r.d=function(e,t,n){r.o(e,t)||Object.defineProperty(e,t,{enumerable:!0,get:n})},r.r=function(e){"undefined"!=typeof Symbol&&Symbol.toStringTag&&Object.defineProperty(e,Symbol.toStringTag,{value:"Module"}),Object.defineProperty(e,"__esModule",{value:!0})},r.t=function(e,t){if(1&t&&(e=r(e)),8&t)return e;if(4&t&&"object"==typeof e&&e&&e.__esModule)return e;var n=Object.create(null);if(r.r(n),Object.defineProperty(n,"default",{enumerable:!0,value:e}),2&t&&"string"!=typeof e)for(var o in e)r.d(n,o,function(t){return e[t]}.bind(null,o));return n},r.n=function(e){var t=e&&e.__esModule?function(){return e.default}:function(){return e};return r.d(t,"a",t),t},r.o=function(e,t){return Object.prototype.hasOwnProperty.call(e,t)},r.p="",r(r.s=4)}({4:function(e,t,r){}});
|
4
_static/js/html5shiv-printshiv.min.js
vendored
Normal file
@@ -0,0 +1,4 @@
|
||||
/**
|
||||
* @preserve HTML5 Shiv 3.7.3-pre | @afarkas @jdalton @jon_neal @rem | MIT/GPL2 Licensed
|
||||
*/
|
||||
!function(a,b){function c(a,b){var c=a.createElement("p"),d=a.getElementsByTagName("head")[0]||a.documentElement;return c.innerHTML="x<style>"+b+"</style>",d.insertBefore(c.lastChild,d.firstChild)}function d(){var a=y.elements;return"string"==typeof a?a.split(" "):a}function e(a,b){var c=y.elements;"string"!=typeof c&&(c=c.join(" ")),"string"!=typeof a&&(a=a.join(" ")),y.elements=c+" "+a,j(b)}function f(a){var b=x[a[v]];return b||(b={},w++,a[v]=w,x[w]=b),b}function g(a,c,d){if(c||(c=b),q)return c.createElement(a);d||(d=f(c));var e;return e=d.cache[a]?d.cache[a].cloneNode():u.test(a)?(d.cache[a]=d.createElem(a)).cloneNode():d.createElem(a),!e.canHaveChildren||t.test(a)||e.tagUrn?e:d.frag.appendChild(e)}function h(a,c){if(a||(a=b),q)return a.createDocumentFragment();c=c||f(a);for(var e=c.frag.cloneNode(),g=0,h=d(),i=h.length;i>g;g++)e.createElement(h[g]);return e}function i(a,b){b.cache||(b.cache={},b.createElem=a.createElement,b.createFrag=a.createDocumentFragment,b.frag=b.createFrag()),a.createElement=function(c){return y.shivMethods?g(c,a,b):b.createElem(c)},a.createDocumentFragment=Function("h,f","return function(){var n=f.cloneNode(),c=n.createElement;h.shivMethods&&("+d().join().replace(/[\w\-:]+/g,function(a){return b.createElem(a),b.frag.createElement(a),'c("'+a+'")'})+");return n}")(y,b.frag)}function j(a){a||(a=b);var d=f(a);return!y.shivCSS||p||d.hasCSS||(d.hasCSS=!!c(a,"article,aside,dialog,figcaption,figure,footer,header,hgroup,main,nav,section{display:block}mark{background:#FF0;color:#000}template{display:none}")),q||i(a,d),a}function k(a){for(var b,c=a.getElementsByTagName("*"),e=c.length,f=RegExp("^(?:"+d().join("|")+")$","i"),g=[];e--;)b=c[e],f.test(b.nodeName)&&g.push(b.applyElement(l(b)));return g}function l(a){for(var b,c=a.attributes,d=c.length,e=a.ownerDocument.createElement(A+":"+a.nodeName);d--;)b=c[d],b.specified&&e.setAttribute(b.nodeName,b.nodeValue);return e.style.cssText=a.style.cssText,e}function m(a){for(var b,c=a.split("{"),e=c.length,f=RegExp("(^|[\\s,>+~])("+d().join("|")+")(?=[[\\s,>+~#.:]|$)","gi"),g="$1"+A+"\\:$2";e--;)b=c[e]=c[e].split("}"),b[b.length-1]=b[b.length-1].replace(f,g),c[e]=b.join("}");return c.join("{")}function n(a){for(var b=a.length;b--;)a[b].removeNode()}function o(a){function b(){clearTimeout(g._removeSheetTimer),d&&d.removeNode(!0),d=null}var d,e,g=f(a),h=a.namespaces,i=a.parentWindow;return!B||a.printShived?a:("undefined"==typeof h[A]&&h.add(A),i.attachEvent("onbeforeprint",function(){b();for(var f,g,h,i=a.styleSheets,j=[],l=i.length,n=Array(l);l--;)n[l]=i[l];for(;h=n.pop();)if(!h.disabled&&z.test(h.media)){try{f=h.imports,g=f.length}catch(o){g=0}for(l=0;g>l;l++)n.push(f[l]);try{j.push(h.cssText)}catch(o){}}j=m(j.reverse().join("")),e=k(a),d=c(a,j)}),i.attachEvent("onafterprint",function(){n(e),clearTimeout(g._removeSheetTimer),g._removeSheetTimer=setTimeout(b,500)}),a.printShived=!0,a)}var p,q,r="3.7.3",s=a.html5||{},t=/^<|^(?:button|map|select|textarea|object|iframe|option|optgroup)$/i,u=/^(?:a|b|code|div|fieldset|h1|h2|h3|h4|h5|h6|i|label|li|ol|p|q|span|strong|style|table|tbody|td|th|tr|ul)$/i,v="_html5shiv",w=0,x={};!function(){try{var a=b.createElement("a");a.innerHTML="<xyz></xyz>",p="hidden"in a,q=1==a.childNodes.length||function(){b.createElement("a");var a=b.createDocumentFragment();return"undefined"==typeof a.cloneNode||"undefined"==typeof a.createDocumentFragment||"undefined"==typeof a.createElement}()}catch(c){p=!0,q=!0}}();var y={elements:s.elements||"abbr article aside audio bdi canvas data datalist details dialog 
figcaption figure footer header hgroup main mark meter nav output picture progress section summary template time video",version:r,shivCSS:s.shivCSS!==!1,supportsUnknownElements:q,shivMethods:s.shivMethods!==!1,type:"default",shivDocument:j,createElement:g,createDocumentFragment:h,addElements:e};a.html5=y,j(b);var z=/^$|\b(?:all|print)\b/,A="html5shiv",B=!q&&function(){var c=b.documentElement;return!("undefined"==typeof b.namespaces||"undefined"==typeof b.parentWindow||"undefined"==typeof c.applyElement||"undefined"==typeof c.removeNode||"undefined"==typeof a.attachEvent)}();y.type+=" print",y.shivPrint=o,o(b),"object"==typeof module&&module.exports&&(module.exports=y)}("undefined"!=typeof window?window:this,document);
|
4
_static/js/html5shiv.min.js
vendored
Normal file
@@ -0,0 +1,4 @@
|
||||
/**
|
||||
* @preserve HTML5 Shiv 3.7.3 | @afarkas @jdalton @jon_neal @rem | MIT/GPL2 Licensed
|
||||
*/
|
||||
!function(a,b){function c(a,b){var c=a.createElement("p"),d=a.getElementsByTagName("head")[0]||a.documentElement;return c.innerHTML="x<style>"+b+"</style>",d.insertBefore(c.lastChild,d.firstChild)}function d(){var a=t.elements;return"string"==typeof a?a.split(" "):a}function e(a,b){var c=t.elements;"string"!=typeof c&&(c=c.join(" ")),"string"!=typeof a&&(a=a.join(" ")),t.elements=c+" "+a,j(b)}function f(a){var b=s[a[q]];return b||(b={},r++,a[q]=r,s[r]=b),b}function g(a,c,d){if(c||(c=b),l)return c.createElement(a);d||(d=f(c));var e;return e=d.cache[a]?d.cache[a].cloneNode():p.test(a)?(d.cache[a]=d.createElem(a)).cloneNode():d.createElem(a),!e.canHaveChildren||o.test(a)||e.tagUrn?e:d.frag.appendChild(e)}function h(a,c){if(a||(a=b),l)return a.createDocumentFragment();c=c||f(a);for(var e=c.frag.cloneNode(),g=0,h=d(),i=h.length;i>g;g++)e.createElement(h[g]);return e}function i(a,b){b.cache||(b.cache={},b.createElem=a.createElement,b.createFrag=a.createDocumentFragment,b.frag=b.createFrag()),a.createElement=function(c){return t.shivMethods?g(c,a,b):b.createElem(c)},a.createDocumentFragment=Function("h,f","return function(){var n=f.cloneNode(),c=n.createElement;h.shivMethods&&("+d().join().replace(/[\w\-:]+/g,function(a){return b.createElem(a),b.frag.createElement(a),'c("'+a+'")'})+");return n}")(t,b.frag)}function j(a){a||(a=b);var d=f(a);return!t.shivCSS||k||d.hasCSS||(d.hasCSS=!!c(a,"article,aside,dialog,figcaption,figure,footer,header,hgroup,main,nav,section{display:block}mark{background:#FF0;color:#000}template{display:none}")),l||i(a,d),a}var k,l,m="3.7.3-pre",n=a.html5||{},o=/^<|^(?:button|map|select|textarea|object|iframe|option|optgroup)$/i,p=/^(?:a|b|code|div|fieldset|h1|h2|h3|h4|h5|h6|i|label|li|ol|p|q|span|strong|style|table|tbody|td|th|tr|ul)$/i,q="_html5shiv",r=0,s={};!function(){try{var a=b.createElement("a");a.innerHTML="<xyz></xyz>",k="hidden"in a,l=1==a.childNodes.length||function(){b.createElement("a");var a=b.createDocumentFragment();return"undefined"==typeof a.cloneNode||"undefined"==typeof a.createDocumentFragment||"undefined"==typeof a.createElement}()}catch(c){k=!0,l=!0}}();var t={elements:n.elements||"abbr article aside audio bdi canvas data datalist details dialog figcaption figure footer header hgroup main mark meter nav output picture progress section summary template time video",version:m,shivCSS:n.shivCSS!==!1,supportsUnknownElements:l,shivMethods:n.shivMethods!==!1,type:"default",shivDocument:j,createElement:g,createDocumentFragment:h,addElements:e};a.html5=t,j(b),"object"==typeof module&&module.exports&&(module.exports=t)}("undefined"!=typeof window?window:this,document);
|
297
_static/language_data.js
Normal file
@@ -0,0 +1,297 @@
|
||||
/*
|
||||
* language_data.js
|
||||
* ~~~~~~~~~~~~~~~~
|
||||
*
|
||||
* This script contains the language-specific data used by searchtools.js,
|
||||
* namely the list of stopwords, stemmer, scorer and splitter.
|
||||
*
|
||||
* :copyright: Copyright 2007-2021 by the Sphinx team, see AUTHORS.
|
||||
* :license: BSD, see LICENSE for details.
|
||||
*
|
||||
*/
|
||||
|
||||
var stopwords = ["a","and","are","as","at","be","but","by","for","if","in","into","is","it","near","no","not","of","on","or","such","that","the","their","then","there","these","they","this","to","was","will","with"];
|
||||
|
||||
|
||||
/* Non-minified version JS is _stemmer.js if file is provided */
|
||||
/**
|
||||
* Porter Stemmer
|
||||
*/
|
||||
var Stemmer = function() {
|
||||
|
||||
var step2list = {
|
||||
ational: 'ate',
|
||||
tional: 'tion',
|
||||
enci: 'ence',
|
||||
anci: 'ance',
|
||||
izer: 'ize',
|
||||
bli: 'ble',
|
||||
alli: 'al',
|
||||
entli: 'ent',
|
||||
eli: 'e',
|
||||
ousli: 'ous',
|
||||
ization: 'ize',
|
||||
ation: 'ate',
|
||||
ator: 'ate',
|
||||
alism: 'al',
|
||||
iveness: 'ive',
|
||||
fulness: 'ful',
|
||||
ousness: 'ous',
|
||||
aliti: 'al',
|
||||
iviti: 'ive',
|
||||
biliti: 'ble',
|
||||
logi: 'log'
|
||||
};
|
||||
|
||||
var step3list = {
|
||||
icate: 'ic',
|
||||
ative: '',
|
||||
alize: 'al',
|
||||
iciti: 'ic',
|
||||
ical: 'ic',
|
||||
ful: '',
|
||||
ness: ''
|
||||
};
|
||||
|
||||
var c = "[^aeiou]"; // consonant
|
||||
var v = "[aeiouy]"; // vowel
|
||||
var C = c + "[^aeiouy]*"; // consonant sequence
|
||||
var V = v + "[aeiou]*"; // vowel sequence
|
||||
|
||||
var mgr0 = "^(" + C + ")?" + V + C; // [C]VC... is m>0
|
||||
var meq1 = "^(" + C + ")?" + V + C + "(" + V + ")?$"; // [C]VC[V] is m=1
|
||||
var mgr1 = "^(" + C + ")?" + V + C + V + C; // [C]VCVC... is m>1
|
||||
var s_v = "^(" + C + ")?" + v; // vowel in stem
|
||||
|
||||
this.stemWord = function (w) {
|
||||
var stem;
|
||||
var suffix;
|
||||
var firstch;
|
||||
var origword = w;
|
||||
|
||||
if (w.length < 3)
|
||||
return w;
|
||||
|
||||
var re;
|
||||
var re2;
|
||||
var re3;
|
||||
var re4;
|
||||
|
||||
firstch = w.substr(0,1);
|
||||
if (firstch == "y")
|
||||
w = firstch.toUpperCase() + w.substr(1);
|
||||
|
||||
// Step 1a
|
||||
re = /^(.+?)(ss|i)es$/;
|
||||
re2 = /^(.+?)([^s])s$/;
|
||||
|
||||
if (re.test(w))
|
||||
w = w.replace(re,"$1$2");
|
||||
else if (re2.test(w))
|
||||
w = w.replace(re2,"$1$2");
|
||||
|
||||
// Step 1b
|
||||
re = /^(.+?)eed$/;
|
||||
re2 = /^(.+?)(ed|ing)$/;
|
||||
if (re.test(w)) {
|
||||
var fp = re.exec(w);
|
||||
re = new RegExp(mgr0);
|
||||
if (re.test(fp[1])) {
|
||||
re = /.$/;
|
||||
w = w.replace(re,"");
|
||||
}
|
||||
}
|
||||
else if (re2.test(w)) {
|
||||
var fp = re2.exec(w);
|
||||
stem = fp[1];
|
||||
re2 = new RegExp(s_v);
|
||||
if (re2.test(stem)) {
|
||||
w = stem;
|
||||
re2 = /(at|bl|iz)$/;
|
||||
re3 = new RegExp("([^aeiouylsz])\\1$");
|
||||
re4 = new RegExp("^" + C + v + "[^aeiouwxy]$");
|
||||
if (re2.test(w))
|
||||
w = w + "e";
|
||||
else if (re3.test(w)) {
|
||||
re = /.$/;
|
||||
w = w.replace(re,"");
|
||||
}
|
||||
else if (re4.test(w))
|
||||
w = w + "e";
|
||||
}
|
||||
}
|
||||
|
||||
// Step 1c
|
||||
re = /^(.+?)y$/;
|
||||
if (re.test(w)) {
|
||||
var fp = re.exec(w);
|
||||
stem = fp[1];
|
||||
re = new RegExp(s_v);
|
||||
if (re.test(stem))
|
||||
w = stem + "i";
|
||||
}
|
||||
|
||||
// Step 2
|
||||
re = /^(.+?)(ational|tional|enci|anci|izer|bli|alli|entli|eli|ousli|ization|ation|ator|alism|iveness|fulness|ousness|aliti|iviti|biliti|logi)$/;
|
||||
if (re.test(w)) {
|
||||
var fp = re.exec(w);
|
||||
stem = fp[1];
|
||||
suffix = fp[2];
|
||||
re = new RegExp(mgr0);
|
||||
if (re.test(stem))
|
||||
w = stem + step2list[suffix];
|
||||
}
|
||||
|
||||
// Step 3
|
||||
re = /^(.+?)(icate|ative|alize|iciti|ical|ful|ness)$/;
|
||||
if (re.test(w)) {
|
||||
var fp = re.exec(w);
|
||||
stem = fp[1];
|
||||
suffix = fp[2];
|
||||
re = new RegExp(mgr0);
|
||||
if (re.test(stem))
|
||||
w = stem + step3list[suffix];
|
||||
}
|
||||
|
||||
// Step 4
|
||||
re = /^(.+?)(al|ance|ence|er|ic|able|ible|ant|ement|ment|ent|ou|ism|ate|iti|ous|ive|ize)$/;
|
||||
re2 = /^(.+?)(s|t)(ion)$/;
|
||||
if (re.test(w)) {
|
||||
var fp = re.exec(w);
|
||||
stem = fp[1];
|
||||
re = new RegExp(mgr1);
|
||||
if (re.test(stem))
|
||||
w = stem;
|
||||
}
|
||||
else if (re2.test(w)) {
|
||||
var fp = re2.exec(w);
|
||||
stem = fp[1] + fp[2];
|
||||
re2 = new RegExp(mgr1);
|
||||
if (re2.test(stem))
|
||||
w = stem;
|
||||
}
|
||||
|
||||
// Step 5
|
||||
re = /^(.+?)e$/;
|
||||
if (re.test(w)) {
|
||||
var fp = re.exec(w);
|
||||
stem = fp[1];
|
||||
re = new RegExp(mgr1);
|
||||
re2 = new RegExp(meq1);
|
||||
re3 = new RegExp("^" + C + v + "[^aeiouwxy]$");
|
||||
if (re.test(stem) || (re2.test(stem) && !(re3.test(stem))))
|
||||
w = stem;
|
||||
}
|
||||
re = /ll$/;
|
||||
re2 = new RegExp(mgr1);
|
||||
if (re.test(w) && re2.test(w)) {
|
||||
re = /.$/;
|
||||
w = w.replace(re,"");
|
||||
}
|
||||
|
||||
// and turn initial Y back to y
|
||||
if (firstch == "y")
|
||||
w = firstch.toLowerCase() + w.substr(1);
|
||||
return w;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
var splitChars = (function() {
|
||||
var result = {};
|
||||
var singles = [96, 180, 187, 191, 215, 247, 749, 885, 903, 907, 909, 930, 1014, 1648,
|
||||
1748, 1809, 2416, 2473, 2481, 2526, 2601, 2609, 2612, 2615, 2653, 2702,
|
||||
2706, 2729, 2737, 2740, 2857, 2865, 2868, 2910, 2928, 2948, 2961, 2971,
|
||||
2973, 3085, 3089, 3113, 3124, 3213, 3217, 3241, 3252, 3295, 3341, 3345,
|
||||
3369, 3506, 3516, 3633, 3715, 3721, 3736, 3744, 3748, 3750, 3756, 3761,
|
||||
3781, 3912, 4239, 4347, 4681, 4695, 4697, 4745, 4785, 4799, 4801, 4823,
|
||||
4881, 5760, 5901, 5997, 6313, 7405, 8024, 8026, 8028, 8030, 8117, 8125,
|
||||
8133, 8181, 8468, 8485, 8487, 8489, 8494, 8527, 11311, 11359, 11687, 11695,
|
||||
11703, 11711, 11719, 11727, 11735, 12448, 12539, 43010, 43014, 43019, 43587,
|
||||
43696, 43713, 64286, 64297, 64311, 64317, 64319, 64322, 64325, 65141];
|
||||
var i, j, start, end;
|
||||
for (i = 0; i < singles.length; i++) {
|
||||
result[singles[i]] = true;
|
||||
}
|
||||
var ranges = [[0, 47], [58, 64], [91, 94], [123, 169], [171, 177], [182, 184], [706, 709],
|
||||
[722, 735], [741, 747], [751, 879], [888, 889], [894, 901], [1154, 1161],
|
||||
[1318, 1328], [1367, 1368], [1370, 1376], [1416, 1487], [1515, 1519], [1523, 1568],
|
||||
[1611, 1631], [1642, 1645], [1750, 1764], [1767, 1773], [1789, 1790], [1792, 1807],
|
||||
[1840, 1868], [1958, 1968], [1970, 1983], [2027, 2035], [2038, 2041], [2043, 2047],
|
||||
[2070, 2073], [2075, 2083], [2085, 2087], [2089, 2307], [2362, 2364], [2366, 2383],
|
||||
[2385, 2391], [2402, 2405], [2419, 2424], [2432, 2436], [2445, 2446], [2449, 2450],
|
||||
[2483, 2485], [2490, 2492], [2494, 2509], [2511, 2523], [2530, 2533], [2546, 2547],
|
||||
[2554, 2564], [2571, 2574], [2577, 2578], [2618, 2648], [2655, 2661], [2672, 2673],
|
||||
[2677, 2692], [2746, 2748], [2750, 2767], [2769, 2783], [2786, 2789], [2800, 2820],
|
||||
[2829, 2830], [2833, 2834], [2874, 2876], [2878, 2907], [2914, 2917], [2930, 2946],
|
||||
[2955, 2957], [2966, 2968], [2976, 2978], [2981, 2983], [2987, 2989], [3002, 3023],
|
||||
[3025, 3045], [3059, 3076], [3130, 3132], [3134, 3159], [3162, 3167], [3170, 3173],
|
||||
[3184, 3191], [3199, 3204], [3258, 3260], [3262, 3293], [3298, 3301], [3312, 3332],
|
||||
[3386, 3388], [3390, 3423], [3426, 3429], [3446, 3449], [3456, 3460], [3479, 3481],
|
||||
[3518, 3519], [3527, 3584], [3636, 3647], [3655, 3663], [3674, 3712], [3717, 3718],
|
||||
[3723, 3724], [3726, 3731], [3752, 3753], [3764, 3772], [3774, 3775], [3783, 3791],
|
||||
[3802, 3803], [3806, 3839], [3841, 3871], [3892, 3903], [3949, 3975], [3980, 4095],
|
||||
[4139, 4158], [4170, 4175], [4182, 4185], [4190, 4192], [4194, 4196], [4199, 4205],
|
||||
[4209, 4212], [4226, 4237], [4250, 4255], [4294, 4303], [4349, 4351], [4686, 4687],
|
||||
[4702, 4703], [4750, 4751], [4790, 4791], [4806, 4807], [4886, 4887], [4955, 4968],
|
||||
[4989, 4991], [5008, 5023], [5109, 5120], [5741, 5742], [5787, 5791], [5867, 5869],
|
||||
[5873, 5887], [5906, 5919], [5938, 5951], [5970, 5983], [6001, 6015], [6068, 6102],
|
||||
[6104, 6107], [6109, 6111], [6122, 6127], [6138, 6159], [6170, 6175], [6264, 6271],
|
||||
[6315, 6319], [6390, 6399], [6429, 6469], [6510, 6511], [6517, 6527], [6572, 6592],
|
||||
[6600, 6607], [6619, 6655], [6679, 6687], [6741, 6783], [6794, 6799], [6810, 6822],
|
||||
[6824, 6916], [6964, 6980], [6988, 6991], [7002, 7042], [7073, 7085], [7098, 7167],
|
||||
[7204, 7231], [7242, 7244], [7294, 7400], [7410, 7423], [7616, 7679], [7958, 7959],
|
||||
[7966, 7967], [8006, 8007], [8014, 8015], [8062, 8063], [8127, 8129], [8141, 8143],
|
||||
[8148, 8149], [8156, 8159], [8173, 8177], [8189, 8303], [8306, 8307], [8314, 8318],
|
||||
[8330, 8335], [8341, 8449], [8451, 8454], [8456, 8457], [8470, 8472], [8478, 8483],
|
||||
[8506, 8507], [8512, 8516], [8522, 8525], [8586, 9311], [9372, 9449], [9472, 10101],
|
||||
[10132, 11263], [11493, 11498], [11503, 11516], [11518, 11519], [11558, 11567],
|
||||
[11622, 11630], [11632, 11647], [11671, 11679], [11743, 11822], [11824, 12292],
|
||||
[12296, 12320], [12330, 12336], [12342, 12343], [12349, 12352], [12439, 12444],
|
||||
[12544, 12548], [12590, 12592], [12687, 12689], [12694, 12703], [12728, 12783],
|
||||
[12800, 12831], [12842, 12880], [12896, 12927], [12938, 12976], [12992, 13311],
|
||||
[19894, 19967], [40908, 40959], [42125, 42191], [42238, 42239], [42509, 42511],
|
||||
[42540, 42559], [42592, 42593], [42607, 42622], [42648, 42655], [42736, 42774],
|
||||
[42784, 42785], [42889, 42890], [42893, 43002], [43043, 43055], [43062, 43071],
|
||||
[43124, 43137], [43188, 43215], [43226, 43249], [43256, 43258], [43260, 43263],
|
||||
[43302, 43311], [43335, 43359], [43389, 43395], [43443, 43470], [43482, 43519],
|
||||
[43561, 43583], [43596, 43599], [43610, 43615], [43639, 43641], [43643, 43647],
|
||||
[43698, 43700], [43703, 43704], [43710, 43711], [43715, 43738], [43742, 43967],
|
||||
[44003, 44015], [44026, 44031], [55204, 55215], [55239, 55242], [55292, 55295],
|
||||
[57344, 63743], [64046, 64047], [64110, 64111], [64218, 64255], [64263, 64274],
|
||||
[64280, 64284], [64434, 64466], [64830, 64847], [64912, 64913], [64968, 65007],
|
||||
[65020, 65135], [65277, 65295], [65306, 65312], [65339, 65344], [65371, 65381],
|
||||
[65471, 65473], [65480, 65481], [65488, 65489], [65496, 65497]];
|
||||
for (i = 0; i < ranges.length; i++) {
|
||||
start = ranges[i][0];
|
||||
end = ranges[i][1];
|
||||
for (j = start; j <= end; j++) {
|
||||
result[j] = true;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
})();
|
||||
|
||||
function splitQuery(query) {
|
||||
var result = [];
|
||||
var start = -1;
|
||||
for (var i = 0; i < query.length; i++) {
|
||||
if (splitChars[query.charCodeAt(i)]) {
|
||||
if (start !== -1) {
|
||||
result.push(query.slice(start, i));
|
||||
start = -1;
|
||||
}
|
||||
} else if (start === -1) {
|
||||
start = i;
|
||||
}
|
||||
}
|
||||
if (start !== -1) {
|
||||
result.push(query.slice(start));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
|
999
_static/underscore-1.3.1.js
Normal file
@@ -0,0 +1,999 @@
|
||||
// Underscore.js 1.3.1
|
||||
// (c) 2009-2012 Jeremy Ashkenas, DocumentCloud Inc.
|
||||
// Underscore is freely distributable under the MIT license.
|
||||
// Portions of Underscore are inspired or borrowed from Prototype,
|
||||
// Oliver Steele's Functional, and John Resig's Micro-Templating.
|
||||
// For all details and documentation:
|
||||
// http://documentcloud.github.com/underscore
|
||||
|
||||
(function() {
|
||||
|
||||
// Baseline setup
|
||||
// --------------
|
||||
|
||||
// Establish the root object, `window` in the browser, or `global` on the server.
|
||||
var root = this;
|
||||
|
||||
// Save the previous value of the `_` variable.
|
||||
var previousUnderscore = root._;
|
||||
|
||||
// Establish the object that gets returned to break out of a loop iteration.
|
||||
var breaker = {};
|
||||
|
||||
// Save bytes in the minified (but not gzipped) version:
|
||||
var ArrayProto = Array.prototype, ObjProto = Object.prototype, FuncProto = Function.prototype;
|
||||
|
||||
// Create quick reference variables for speed access to core prototypes.
|
||||
var slice = ArrayProto.slice,
|
||||
unshift = ArrayProto.unshift,
|
||||
toString = ObjProto.toString,
|
||||
hasOwnProperty = ObjProto.hasOwnProperty;
|
||||
|
||||
// All **ECMAScript 5** native function implementations that we hope to use
|
||||
// are declared here.
|
||||
var
|
||||
nativeForEach = ArrayProto.forEach,
|
||||
nativeMap = ArrayProto.map,
|
||||
nativeReduce = ArrayProto.reduce,
|
||||
nativeReduceRight = ArrayProto.reduceRight,
|
||||
nativeFilter = ArrayProto.filter,
|
||||
nativeEvery = ArrayProto.every,
|
||||
nativeSome = ArrayProto.some,
|
||||
nativeIndexOf = ArrayProto.indexOf,
|
||||
nativeLastIndexOf = ArrayProto.lastIndexOf,
|
||||
nativeIsArray = Array.isArray,
|
||||
nativeKeys = Object.keys,
|
||||
nativeBind = FuncProto.bind;
|
||||
|
||||
// Create a safe reference to the Underscore object for use below.
|
||||
var _ = function(obj) { return new wrapper(obj); };
|
||||
|
||||
// Export the Underscore object for **Node.js**, with
|
||||
// backwards-compatibility for the old `require()` API. If we're in
|
||||
// the browser, add `_` as a global object via a string identifier,
|
||||
// for Closure Compiler "advanced" mode.
|
||||
if (typeof exports !== 'undefined') {
|
||||
if (typeof module !== 'undefined' && module.exports) {
|
||||
exports = module.exports = _;
|
||||
}
|
||||
exports._ = _;
|
||||
} else {
|
||||
root['_'] = _;
|
||||
}
|
||||
|
||||
// Current version.
|
||||
_.VERSION = '1.3.1';
|
||||
|
||||
// Collection Functions
|
||||
// --------------------
|
||||
|
||||
// The cornerstone, an `each` implementation, aka `forEach`.
|
||||
// Handles objects with the built-in `forEach`, arrays, and raw objects.
|
||||
// Delegates to **ECMAScript 5**'s native `forEach` if available.
|
||||
var each = _.each = _.forEach = function(obj, iterator, context) {
|
||||
if (obj == null) return;
|
||||
if (nativeForEach && obj.forEach === nativeForEach) {
|
||||
obj.forEach(iterator, context);
|
||||
} else if (obj.length === +obj.length) {
|
||||
for (var i = 0, l = obj.length; i < l; i++) {
|
||||
if (i in obj && iterator.call(context, obj[i], i, obj) === breaker) return;
|
||||
}
|
||||
} else {
|
||||
for (var key in obj) {
|
||||
if (_.has(obj, key)) {
|
||||
if (iterator.call(context, obj[key], key, obj) === breaker) return;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Return the results of applying the iterator to each element.
|
||||
// Delegates to **ECMAScript 5**'s native `map` if available.
|
||||
_.map = _.collect = function(obj, iterator, context) {
|
||||
var results = [];
|
||||
if (obj == null) return results;
|
||||
if (nativeMap && obj.map === nativeMap) return obj.map(iterator, context);
|
||||
each(obj, function(value, index, list) {
|
||||
results[results.length] = iterator.call(context, value, index, list);
|
||||
});
|
||||
if (obj.length === +obj.length) results.length = obj.length;
|
||||
return results;
|
||||
};
|
||||
|
||||
// **Reduce** builds up a single result from a list of values, aka `inject`,
|
||||
// or `foldl`. Delegates to **ECMAScript 5**'s native `reduce` if available.
|
||||
_.reduce = _.foldl = _.inject = function(obj, iterator, memo, context) {
|
||||
var initial = arguments.length > 2;
|
||||
if (obj == null) obj = [];
|
||||
if (nativeReduce && obj.reduce === nativeReduce) {
|
||||
if (context) iterator = _.bind(iterator, context);
|
||||
return initial ? obj.reduce(iterator, memo) : obj.reduce(iterator);
|
||||
}
|
||||
each(obj, function(value, index, list) {
|
||||
if (!initial) {
|
||||
memo = value;
|
||||
initial = true;
|
||||
} else {
|
||||
memo = iterator.call(context, memo, value, index, list);
|
||||
}
|
||||
});
|
||||
if (!initial) throw new TypeError('Reduce of empty array with no initial value');
|
||||
return memo;
|
||||
};
|
||||
|
||||
// The right-associative version of reduce, also known as `foldr`.
|
||||
// Delegates to **ECMAScript 5**'s native `reduceRight` if available.
|
||||
_.reduceRight = _.foldr = function(obj, iterator, memo, context) {
|
||||
var initial = arguments.length > 2;
|
||||
if (obj == null) obj = [];
|
||||
if (nativeReduceRight && obj.reduceRight === nativeReduceRight) {
|
||||
if (context) iterator = _.bind(iterator, context);
|
||||
return initial ? obj.reduceRight(iterator, memo) : obj.reduceRight(iterator);
|
||||
}
|
||||
var reversed = _.toArray(obj).reverse();
|
||||
if (context && !initial) iterator = _.bind(iterator, context);
|
||||
return initial ? _.reduce(reversed, iterator, memo, context) : _.reduce(reversed, iterator);
|
||||
};
|
||||
|
||||
// Return the first value which passes a truth test. Aliased as `detect`.
|
||||
_.find = _.detect = function(obj, iterator, context) {
|
||||
var result;
|
||||
any(obj, function(value, index, list) {
|
||||
if (iterator.call(context, value, index, list)) {
|
||||
result = value;
|
||||
return true;
|
||||
}
|
||||
});
|
||||
return result;
|
||||
};
|
||||
|
||||
// Return all the elements that pass a truth test.
|
||||
// Delegates to **ECMAScript 5**'s native `filter` if available.
|
||||
// Aliased as `select`.
|
||||
_.filter = _.select = function(obj, iterator, context) {
|
||||
var results = [];
|
||||
if (obj == null) return results;
|
||||
if (nativeFilter && obj.filter === nativeFilter) return obj.filter(iterator, context);
|
||||
each(obj, function(value, index, list) {
|
||||
if (iterator.call(context, value, index, list)) results[results.length] = value;
|
||||
});
|
||||
return results;
|
||||
};
|
||||
|
||||
// Return all the elements for which a truth test fails.
|
||||
_.reject = function(obj, iterator, context) {
|
||||
var results = [];
|
||||
if (obj == null) return results;
|
||||
each(obj, function(value, index, list) {
|
||||
if (!iterator.call(context, value, index, list)) results[results.length] = value;
|
||||
});
|
||||
return results;
|
||||
};
|
||||
|
||||
// Determine whether all of the elements match a truth test.
|
||||
// Delegates to **ECMAScript 5**'s native `every` if available.
|
||||
// Aliased as `all`.
|
||||
_.every = _.all = function(obj, iterator, context) {
|
||||
var result = true;
|
||||
if (obj == null) return result;
|
||||
if (nativeEvery && obj.every === nativeEvery) return obj.every(iterator, context);
|
||||
each(obj, function(value, index, list) {
|
||||
if (!(result = result && iterator.call(context, value, index, list))) return breaker;
|
||||
});
|
||||
return result;
|
||||
};
|
||||
|
||||
// Determine if at least one element in the object matches a truth test.
|
||||
// Delegates to **ECMAScript 5**'s native `some` if available.
|
||||
// Aliased as `any`.
|
||||
var any = _.some = _.any = function(obj, iterator, context) {
|
||||
iterator || (iterator = _.identity);
|
||||
var result = false;
|
||||
if (obj == null) return result;
|
||||
if (nativeSome && obj.some === nativeSome) return obj.some(iterator, context);
|
||||
each(obj, function(value, index, list) {
|
||||
if (result || (result = iterator.call(context, value, index, list))) return breaker;
|
||||
});
|
||||
return !!result;
|
||||
};
|
||||
|
||||
// Determine if a given value is included in the array or object using `===`.
|
||||
// Aliased as `contains`.
|
||||
_.include = _.contains = function(obj, target) {
|
||||
var found = false;
|
||||
if (obj == null) return found;
|
||||
if (nativeIndexOf && obj.indexOf === nativeIndexOf) return obj.indexOf(target) != -1;
|
||||
found = any(obj, function(value) {
|
||||
return value === target;
|
||||
});
|
||||
return found;
|
||||
};
|
||||
|
||||
// Invoke a method (with arguments) on every item in a collection.
|
||||
_.invoke = function(obj, method) {
|
||||
var args = slice.call(arguments, 2);
|
||||
return _.map(obj, function(value) {
|
||||
return (_.isFunction(method) ? method || value : value[method]).apply(value, args);
|
||||
});
|
||||
};
|
||||
|
||||
// Convenience version of a common use case of `map`: fetching a property.
|
||||
_.pluck = function(obj, key) {
|
||||
return _.map(obj, function(value){ return value[key]; });
|
||||
};
|
||||
|
||||
// Return the maximum element or (element-based computation).
|
||||
_.max = function(obj, iterator, context) {
|
||||
if (!iterator && _.isArray(obj)) return Math.max.apply(Math, obj);
|
||||
if (!iterator && _.isEmpty(obj)) return -Infinity;
|
||||
var result = {computed : -Infinity};
|
||||
each(obj, function(value, index, list) {
|
||||
var computed = iterator ? iterator.call(context, value, index, list) : value;
|
||||
computed >= result.computed && (result = {value : value, computed : computed});
|
||||
});
|
||||
return result.value;
|
||||
};
|
||||
|
||||
// Return the minimum element (or element-based computation).
|
||||
_.min = function(obj, iterator, context) {
|
||||
if (!iterator && _.isArray(obj)) return Math.min.apply(Math, obj);
|
||||
if (!iterator && _.isEmpty(obj)) return Infinity;
|
||||
var result = {computed : Infinity};
|
||||
each(obj, function(value, index, list) {
|
||||
var computed = iterator ? iterator.call(context, value, index, list) : value;
|
||||
computed < result.computed && (result = {value : value, computed : computed});
|
||||
});
|
||||
return result.value;
|
||||
};
|
||||
|
||||
// Shuffle an array.
|
||||
_.shuffle = function(obj) {
|
||||
var shuffled = [], rand;
|
||||
each(obj, function(value, index, list) {
|
||||
if (index == 0) {
|
||||
shuffled[0] = value;
|
||||
} else {
|
||||
rand = Math.floor(Math.random() * (index + 1));
|
||||
shuffled[index] = shuffled[rand];
|
||||
shuffled[rand] = value;
|
||||
}
|
||||
});
|
||||
return shuffled;
|
||||
};
|
||||
|
||||
// Sort the object's values by a criterion produced by an iterator.
|
||||
_.sortBy = function(obj, iterator, context) {
|
||||
return _.pluck(_.map(obj, function(value, index, list) {
|
||||
return {
|
||||
value : value,
|
||||
criteria : iterator.call(context, value, index, list)
|
||||
};
|
||||
}).sort(function(left, right) {
|
||||
var a = left.criteria, b = right.criteria;
|
||||
return a < b ? -1 : a > b ? 1 : 0;
|
||||
}), 'value');
|
||||
};
|
||||
|
||||
// Groups the object's values by a criterion. Pass either a string attribute
|
||||
// to group by, or a function that returns the criterion.
|
||||
_.groupBy = function(obj, val) {
|
||||
var result = {};
|
||||
var iterator = _.isFunction(val) ? val : function(obj) { return obj[val]; };
|
||||
each(obj, function(value, index) {
|
||||
var key = iterator(value, index);
|
||||
(result[key] || (result[key] = [])).push(value);
|
||||
});
|
||||
return result;
|
||||
};
|
||||
|
||||
// Use a comparator function to figure out at what index an object should
|
||||
// be inserted so as to maintain order. Uses binary search.
|
||||
_.sortedIndex = function(array, obj, iterator) {
|
||||
iterator || (iterator = _.identity);
|
||||
var low = 0, high = array.length;
|
||||
while (low < high) {
|
||||
var mid = (low + high) >> 1;
|
||||
iterator(array[mid]) < iterator(obj) ? low = mid + 1 : high = mid;
|
||||
}
|
||||
return low;
|
||||
};
|
||||
|
||||
// Safely convert anything iterable into a real, live array.
|
||||
_.toArray = function(iterable) {
|
||||
if (!iterable) return [];
|
||||
if (iterable.toArray) return iterable.toArray();
|
||||
if (_.isArray(iterable)) return slice.call(iterable);
|
||||
if (_.isArguments(iterable)) return slice.call(iterable);
|
||||
return _.values(iterable);
|
||||
};
|
||||
|
||||
// Return the number of elements in an object.
|
||||
_.size = function(obj) {
|
||||
return _.toArray(obj).length;
|
||||
};
|
||||
|
||||
// Array Functions
|
||||
// ---------------
|
||||
|
||||
// Get the first element of an array. Passing **n** will return the first N
|
||||
// values in the array. Aliased as `head`. The **guard** check allows it to work
|
||||
// with `_.map`.
|
||||
_.first = _.head = function(array, n, guard) {
|
||||
return (n != null) && !guard ? slice.call(array, 0, n) : array[0];
|
||||
};
|
||||
|
||||
// Returns everything but the last entry of the array. Especcialy useful on
|
||||
// the arguments object. Passing **n** will return all the values in
|
||||
// the array, excluding the last N. The **guard** check allows it to work with
|
||||
// `_.map`.
|
||||
_.initial = function(array, n, guard) {
|
||||
return slice.call(array, 0, array.length - ((n == null) || guard ? 1 : n));
|
||||
};
|
||||
|
||||
// Get the last element of an array. Passing **n** will return the last N
|
||||
// values in the array. The **guard** check allows it to work with `_.map`.
|
||||
_.last = function(array, n, guard) {
|
||||
if ((n != null) && !guard) {
|
||||
return slice.call(array, Math.max(array.length - n, 0));
|
||||
} else {
|
||||
return array[array.length - 1];
|
||||
}
|
||||
};
|
||||
|
||||
// Returns everything but the first entry of the array. Aliased as `tail`.
|
||||
// Especially useful on the arguments object. Passing an **index** will return
|
||||
// the rest of the values in the array from that index onward. The **guard**
|
||||
// check allows it to work with `_.map`.
|
||||
_.rest = _.tail = function(array, index, guard) {
|
||||
return slice.call(array, (index == null) || guard ? 1 : index);
|
||||
};
|
||||
|
||||
// Trim out all falsy values from an array.
|
||||
_.compact = function(array) {
|
||||
return _.filter(array, function(value){ return !!value; });
|
||||
};
|
||||
|
||||
// Return a completely flattened version of an array.
|
||||
_.flatten = function(array, shallow) {
|
||||
return _.reduce(array, function(memo, value) {
|
||||
if (_.isArray(value)) return memo.concat(shallow ? value : _.flatten(value));
|
||||
memo[memo.length] = value;
|
||||
return memo;
|
||||
}, []);
|
||||
};
|
||||
|
||||
// Return a version of the array that does not contain the specified value(s).
|
||||
_.without = function(array) {
|
||||
return _.difference(array, slice.call(arguments, 1));
|
||||
};
|
||||
|
||||
// Produce a duplicate-free version of the array. If the array has already
|
||||
// been sorted, you have the option of using a faster algorithm.
|
||||
// Aliased as `unique`.
|
||||
_.uniq = _.unique = function(array, isSorted, iterator) {
|
||||
var initial = iterator ? _.map(array, iterator) : array;
|
||||
var result = [];
|
||||
_.reduce(initial, function(memo, el, i) {
|
||||
if (0 == i || (isSorted === true ? _.last(memo) != el : !_.include(memo, el))) {
|
||||
memo[memo.length] = el;
|
||||
result[result.length] = array[i];
|
||||
}
|
||||
return memo;
|
||||
}, []);
|
||||
return result;
|
||||
};
|
||||
|
||||
// Produce an array that contains the union: each distinct element from all of
|
||||
// the passed-in arrays.
|
||||
_.union = function() {
|
||||
return _.uniq(_.flatten(arguments, true));
|
||||
};
|
||||
|
||||
// Produce an array that contains every item shared between all the
|
||||
// passed-in arrays. (Aliased as "intersect" for back-compat.)
|
||||
_.intersection = _.intersect = function(array) {
|
||||
var rest = slice.call(arguments, 1);
|
||||
return _.filter(_.uniq(array), function(item) {
|
||||
return _.every(rest, function(other) {
|
||||
return _.indexOf(other, item) >= 0;
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
// Take the difference between one array and a number of other arrays.
|
||||
// Only the elements present in just the first array will remain.
|
||||
_.difference = function(array) {
|
||||
var rest = _.flatten(slice.call(arguments, 1));
|
||||
return _.filter(array, function(value){ return !_.include(rest, value); });
|
||||
};
|
||||
|
||||
// Zip together multiple lists into a single array -- elements that share
|
||||
// an index go together.
|
||||
_.zip = function() {
|
||||
var args = slice.call(arguments);
|
||||
var length = _.max(_.pluck(args, 'length'));
|
||||
var results = new Array(length);
|
||||
for (var i = 0; i < length; i++) results[i] = _.pluck(args, "" + i);
|
||||
return results;
|
||||
};
|
||||
|
||||
// If the browser doesn't supply us with indexOf (I'm looking at you, **MSIE**),
|
||||
// we need this function. Return the position of the first occurrence of an
|
||||
// item in an array, or -1 if the item is not included in the array.
|
||||
// Delegates to **ECMAScript 5**'s native `indexOf` if available.
|
||||
// If the array is large and already in sort order, pass `true`
|
||||
// for **isSorted** to use binary search.
|
||||
_.indexOf = function(array, item, isSorted) {
|
||||
if (array == null) return -1;
|
||||
var i, l;
|
||||
if (isSorted) {
|
||||
i = _.sortedIndex(array, item);
|
||||
return array[i] === item ? i : -1;
|
||||
}
|
||||
if (nativeIndexOf && array.indexOf === nativeIndexOf) return array.indexOf(item);
|
||||
for (i = 0, l = array.length; i < l; i++) if (i in array && array[i] === item) return i;
|
||||
return -1;
|
||||
};
|
||||
|
||||
// Delegates to **ECMAScript 5**'s native `lastIndexOf` if available.
|
||||
_.lastIndexOf = function(array, item) {
|
||||
if (array == null) return -1;
|
||||
if (nativeLastIndexOf && array.lastIndexOf === nativeLastIndexOf) return array.lastIndexOf(item);
|
||||
var i = array.length;
|
||||
while (i--) if (i in array && array[i] === item) return i;
|
||||
return -1;
|
||||
};
|
||||
|
||||
// Generate an integer Array containing an arithmetic progression. A port of
|
||||
// the native Python `range()` function. See
|
||||
// [the Python documentation](http://docs.python.org/library/functions.html#range).
|
||||
_.range = function(start, stop, step) {
|
||||
if (arguments.length <= 1) {
|
||||
stop = start || 0;
|
||||
start = 0;
|
||||
}
|
||||
step = arguments[2] || 1;
|
||||
|
||||
var len = Math.max(Math.ceil((stop - start) / step), 0);
|
||||
var idx = 0;
|
||||
var range = new Array(len);
|
||||
|
||||
while(idx < len) {
|
||||
range[idx++] = start;
|
||||
start += step;
|
||||
}
|
||||
|
||||
return range;
|
||||
};
|
||||
|
||||
// Function (ahem) Functions
|
||||
// ------------------
|
||||
|
||||
// Reusable constructor function for prototype setting.
|
||||
var ctor = function(){};
|
||||
|
||||
// Create a function bound to a given object (assigning `this`, and arguments,
|
||||
// optionally). Binding with arguments is also known as `curry`.
|
||||
// Delegates to **ECMAScript 5**'s native `Function.bind` if available.
|
||||
// We check for `func.bind` first, to fail fast when `func` is undefined.
|
||||
_.bind = function bind(func, context) {
|
||||
var bound, args;
|
||||
if (func.bind === nativeBind && nativeBind) return nativeBind.apply(func, slice.call(arguments, 1));
|
||||
if (!_.isFunction(func)) throw new TypeError;
|
||||
args = slice.call(arguments, 2);
|
||||
return bound = function() {
|
||||
if (!(this instanceof bound)) return func.apply(context, args.concat(slice.call(arguments)));
|
||||
ctor.prototype = func.prototype;
|
||||
var self = new ctor;
|
||||
var result = func.apply(self, args.concat(slice.call(arguments)));
|
||||
if (Object(result) === result) return result;
|
||||
return self;
|
||||
};
|
||||
};
|
||||
|
||||
// Bind all of an object's methods to that object. Useful for ensuring that
|
||||
// all callbacks defined on an object belong to it.
|
||||
_.bindAll = function(obj) {
|
||||
var funcs = slice.call(arguments, 1);
|
||||
if (funcs.length == 0) funcs = _.functions(obj);
|
||||
each(funcs, function(f) { obj[f] = _.bind(obj[f], obj); });
|
||||
return obj;
|
||||
};
|
||||
|
||||
// Memoize an expensive function by storing its results.
|
||||
_.memoize = function(func, hasher) {
|
||||
var memo = {};
|
||||
hasher || (hasher = _.identity);
|
||||
return function() {
|
||||
var key = hasher.apply(this, arguments);
|
||||
return _.has(memo, key) ? memo[key] : (memo[key] = func.apply(this, arguments));
|
||||
};
|
||||
};
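  // Usage sketch (illustrative only, not part of the library): memoize a
  // pure function, optionally keying the cache with a custom hasher when
  // the arguments are not useful cache keys by themselves.
  //
  //   var slowSquare = function(n) { return n * n; };
  //   var fastSquare = _.memoize(slowSquare);
  //   fastSquare(4); // computed once
  //   fastSquare(4); // served from the cache
  //
  //   var area = _.memoize(function(rect) { return rect.w * rect.h; },
  //                        function(rect) { return rect.w + 'x' + rect.h; });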
|
||||
|
||||
// Delays a function for the given number of milliseconds, and then calls
|
||||
// it with the arguments supplied.
|
||||
_.delay = function(func, wait) {
|
||||
var args = slice.call(arguments, 2);
|
||||
return setTimeout(function(){ return func.apply(func, args); }, wait);
|
||||
};
|
||||
|
||||
// Defers a function, scheduling it to run after the current call stack has
|
||||
// cleared.
|
||||
_.defer = function(func) {
|
||||
return _.delay.apply(_, [func, 1].concat(slice.call(arguments, 1)));
|
||||
};
|
||||
|
||||
// Returns a function, that, when invoked, will only be triggered at most once
|
||||
// during a given window of time.
|
||||
_.throttle = function(func, wait) {
|
||||
var context, args, timeout, throttling, more;
|
||||
var whenDone = _.debounce(function(){ more = throttling = false; }, wait);
|
||||
return function() {
|
||||
context = this; args = arguments;
|
||||
var later = function() {
|
||||
timeout = null;
|
||||
if (more) func.apply(context, args);
|
||||
whenDone();
|
||||
};
|
||||
if (!timeout) timeout = setTimeout(later, wait);
|
||||
if (throttling) {
|
||||
more = true;
|
||||
} else {
|
||||
func.apply(context, args);
|
||||
}
|
||||
whenDone();
|
||||
throttling = true;
|
||||
};
|
||||
};
|
||||
|
||||
// Returns a function, that, as long as it continues to be invoked, will not
|
||||
// be triggered. The function will be called after it stops being called for
|
||||
// N milliseconds.
|
||||
_.debounce = function(func, wait) {
|
||||
var timeout;
|
||||
return function() {
|
||||
var context = this, args = arguments;
|
||||
var later = function() {
|
||||
timeout = null;
|
||||
func.apply(context, args);
|
||||
};
|
||||
clearTimeout(timeout);
|
||||
timeout = setTimeout(later, wait);
|
||||
};
|
||||
};
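  // Usage sketch (illustrative only): throttle fires at most once per `wait`
  // window while the source event keeps firing; debounce waits until the
  // source event has been quiet for `wait` milliseconds, then fires once.
  // `updatePosition` and `recalculateLayout` below are assumed to be defined
  // elsewhere by the caller.
  //
  //   var onScroll = _.throttle(updatePosition, 100);    // at most every 100ms
  //   var onResize = _.debounce(recalculateLayout, 300);  // 300ms after the last resize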
|
||||
|
||||
// Returns a function that will be executed at most one time, no matter how
|
||||
// often you call it. Useful for lazy initialization.
|
||||
_.once = function(func) {
|
||||
var ran = false, memo;
|
||||
return function() {
|
||||
if (ran) return memo;
|
||||
ran = true;
|
||||
return memo = func.apply(this, arguments);
|
||||
};
|
||||
};
|
||||
|
||||
// Returns the first function passed as an argument to the second,
|
||||
// allowing you to adjust arguments, run code before and after, and
|
||||
// conditionally execute the original function.
|
||||
_.wrap = function(func, wrapper) {
|
||||
return function() {
|
||||
var args = [func].concat(slice.call(arguments, 0));
|
||||
return wrapper.apply(this, args);
|
||||
};
|
||||
};
|
||||
|
||||
// Returns a function that is the composition of a list of functions, each
|
||||
// consuming the return value of the function that follows.
|
||||
_.compose = function() {
|
||||
var funcs = arguments;
|
||||
return function() {
|
||||
var args = arguments;
|
||||
for (var i = funcs.length - 1; i >= 0; i--) {
|
||||
args = [funcs[i].apply(this, args)];
|
||||
}
|
||||
return args[0];
|
||||
};
|
||||
};
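  // Usage sketch (illustrative only): the functions are applied right-to-left,
  // so _.compose(f, g)(x) is equivalent to f(g(x)).
  //
  //   var greet   = function(name) { return 'hi: ' + name; };
  //   var exclaim = function(statement) { return statement + '!'; };
  //   var welcome = _.compose(greet, exclaim);
  //   welcome('moe'); // => 'hi: moe!'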
|
||||
|
||||
// Returns a function that will only be executed after being called N times.
|
||||
_.after = function(times, func) {
|
||||
if (times <= 0) return func();
|
||||
return function() {
|
||||
if (--times < 1) { return func.apply(this, arguments); }
|
||||
};
|
||||
};
|
||||
|
||||
// Object Functions
|
||||
// ----------------
|
||||
|
||||
// Retrieve the names of an object's properties.
|
||||
// Delegates to **ECMAScript 5**'s native `Object.keys`
|
||||
_.keys = nativeKeys || function(obj) {
|
||||
if (obj !== Object(obj)) throw new TypeError('Invalid object');
|
||||
var keys = [];
|
||||
for (var key in obj) if (_.has(obj, key)) keys[keys.length] = key;
|
||||
return keys;
|
||||
};
|
||||
|
||||
// Retrieve the values of an object's properties.
|
||||
_.values = function(obj) {
|
||||
return _.map(obj, _.identity);
|
||||
};
|
||||
|
||||
// Return a sorted list of the function names available on the object.
|
||||
// Aliased as `methods`
|
||||
_.functions = _.methods = function(obj) {
|
||||
var names = [];
|
||||
for (var key in obj) {
|
||||
if (_.isFunction(obj[key])) names.push(key);
|
||||
}
|
||||
return names.sort();
|
||||
};
|
||||
|
||||
// Extend a given object with all the properties in passed-in object(s).
|
||||
_.extend = function(obj) {
|
||||
each(slice.call(arguments, 1), function(source) {
|
||||
for (var prop in source) {
|
||||
obj[prop] = source[prop];
|
||||
}
|
||||
});
|
||||
return obj;
|
||||
};
|
||||
|
||||
// Fill in a given object with default properties.
|
||||
_.defaults = function(obj) {
|
||||
each(slice.call(arguments, 1), function(source) {
|
||||
for (var prop in source) {
|
||||
if (obj[prop] == null) obj[prop] = source[prop];
|
||||
}
|
||||
});
|
||||
return obj;
|
||||
};
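  // Usage sketch (illustrative only): extend overwrites existing properties,
  // while defaults only fills in properties that are null or undefined.
  //
  //   var iceCream = {flavor: 'chocolate'};
  //   _.extend({}, iceCream, {flavor: 'vanilla', sprinkles: 'lots'});
  //   // => {flavor: 'vanilla', sprinkles: 'lots'}
  //   _.defaults(iceCream, {flavor: 'vanilla', sprinkles: 'lots'});
  //   // => {flavor: 'chocolate', sprinkles: 'lots'}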
|
||||
|
||||
// Create a (shallow-cloned) duplicate of an object.
|
||||
_.clone = function(obj) {
|
||||
if (!_.isObject(obj)) return obj;
|
||||
return _.isArray(obj) ? obj.slice() : _.extend({}, obj);
|
||||
};
|
||||
|
||||
// Invokes interceptor with the obj, and then returns obj.
|
||||
// The primary purpose of this method is to "tap into" a method chain, in
|
||||
// order to perform operations on intermediate results within the chain.
|
||||
_.tap = function(obj, interceptor) {
|
||||
interceptor(obj);
|
||||
return obj;
|
||||
};
|
||||
|
||||
// Internal recursive comparison function.
|
||||
function eq(a, b, stack) {
|
||||
// Identical objects are equal. `0 === -0`, but they aren't identical.
|
||||
// See the Harmony `egal` proposal: http://wiki.ecmascript.org/doku.php?id=harmony:egal.
|
||||
if (a === b) return a !== 0 || 1 / a == 1 / b;
|
||||
// A strict comparison is necessary because `null == undefined`.
|
||||
if (a == null || b == null) return a === b;
|
||||
// Unwrap any wrapped objects.
|
||||
if (a._chain) a = a._wrapped;
|
||||
if (b._chain) b = b._wrapped;
|
||||
// Invoke a custom `isEqual` method if one is provided.
|
||||
if (a.isEqual && _.isFunction(a.isEqual)) return a.isEqual(b);
|
||||
if (b.isEqual && _.isFunction(b.isEqual)) return b.isEqual(a);
|
||||
// Compare `[[Class]]` names.
|
||||
var className = toString.call(a);
|
||||
if (className != toString.call(b)) return false;
|
||||
switch (className) {
|
||||
// Strings, numbers, dates, and booleans are compared by value.
|
||||
case '[object String]':
|
||||
// Primitives and their corresponding object wrappers are equivalent; thus, `"5"` is
|
||||
// equivalent to `new String("5")`.
|
||||
return a == String(b);
|
||||
case '[object Number]':
|
||||
// `NaN`s are equivalent, but non-reflexive. An `egal` comparison is performed for
|
||||
// other numeric values.
|
||||
return a != +a ? b != +b : (a == 0 ? 1 / a == 1 / b : a == +b);
|
||||
case '[object Date]':
|
||||
case '[object Boolean]':
|
||||
// Coerce dates and booleans to numeric primitive values. Dates are compared by their
|
||||
// millisecond representations. Note that invalid dates with millisecond representations
|
||||
// of `NaN` are not equivalent.
|
||||
return +a == +b;
|
||||
// RegExps are compared by their source patterns and flags.
|
||||
case '[object RegExp]':
|
||||
return a.source == b.source &&
|
||||
a.global == b.global &&
|
||||
a.multiline == b.multiline &&
|
||||
a.ignoreCase == b.ignoreCase;
|
||||
}
|
||||
if (typeof a != 'object' || typeof b != 'object') return false;
|
||||
// Assume equality for cyclic structures. The algorithm for detecting cyclic
|
||||
// structures is adapted from ES 5.1 section 15.12.3, abstract operation `JO`.
|
||||
var length = stack.length;
|
||||
while (length--) {
|
||||
// Linear search. Performance is inversely proportional to the number of
|
||||
// unique nested structures.
|
||||
if (stack[length] == a) return true;
|
||||
}
|
||||
// Add the first object to the stack of traversed objects.
|
||||
stack.push(a);
|
||||
var size = 0, result = true;
|
||||
// Recursively compare objects and arrays.
|
||||
if (className == '[object Array]') {
|
||||
// Compare array lengths to determine if a deep comparison is necessary.
|
||||
size = a.length;
|
||||
result = size == b.length;
|
||||
if (result) {
|
||||
// Deep compare the contents, ignoring non-numeric properties.
|
||||
while (size--) {
|
||||
// Ensure commutative equality for sparse arrays.
|
||||
if (!(result = size in a == size in b && eq(a[size], b[size], stack))) break;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Objects with different constructors are not equivalent.
|
||||
if ('constructor' in a != 'constructor' in b || a.constructor != b.constructor) return false;
|
||||
// Deep compare objects.
|
||||
for (var key in a) {
|
||||
if (_.has(a, key)) {
|
||||
// Count the expected number of properties.
|
||||
size++;
|
||||
// Deep compare each member.
|
||||
if (!(result = _.has(b, key) && eq(a[key], b[key], stack))) break;
|
||||
}
|
||||
}
|
||||
// Ensure that both objects contain the same number of properties.
|
||||
if (result) {
|
||||
for (key in b) {
|
||||
if (_.has(b, key) && !(size--)) break;
|
||||
}
|
||||
result = !size;
|
||||
}
|
||||
}
|
||||
// Remove the first object from the stack of traversed objects.
|
||||
stack.pop();
|
||||
return result;
|
||||
}
|
||||
|
||||
// Perform a deep comparison to check if two objects are equal.
|
||||
_.isEqual = function(a, b) {
|
||||
return eq(a, b, []);
|
||||
};
|
||||
|
||||
// Is a given array, string, or object empty?
|
||||
// An "empty" object has no enumerable own-properties.
|
||||
_.isEmpty = function(obj) {
|
||||
if (_.isArray(obj) || _.isString(obj)) return obj.length === 0;
|
||||
for (var key in obj) if (_.has(obj, key)) return false;
|
||||
return true;
|
||||
};
|
||||
|
||||
// Is a given value a DOM element?
|
||||
_.isElement = function(obj) {
|
||||
return !!(obj && obj.nodeType == 1);
|
||||
};
|
||||
|
||||
// Is a given value an array?
|
||||
// Delegates to ECMA5's native Array.isArray
|
||||
_.isArray = nativeIsArray || function(obj) {
|
||||
return toString.call(obj) == '[object Array]';
|
||||
};
|
||||
|
||||
// Is a given variable an object?
|
||||
_.isObject = function(obj) {
|
||||
return obj === Object(obj);
|
||||
};
|
||||
|
||||
// Is a given variable an arguments object?
|
||||
_.isArguments = function(obj) {
|
||||
return toString.call(obj) == '[object Arguments]';
|
||||
};
|
||||
if (!_.isArguments(arguments)) {
|
||||
_.isArguments = function(obj) {
|
||||
return !!(obj && _.has(obj, 'callee'));
|
||||
};
|
||||
}
|
||||
|
||||
// Is a given value a function?
|
||||
_.isFunction = function(obj) {
|
||||
return toString.call(obj) == '[object Function]';
|
||||
};
|
||||
|
||||
// Is a given value a string?
|
||||
_.isString = function(obj) {
|
||||
return toString.call(obj) == '[object String]';
|
||||
};
|
||||
|
||||
// Is a given value a number?
|
||||
_.isNumber = function(obj) {
|
||||
return toString.call(obj) == '[object Number]';
|
||||
};
|
||||
|
||||
// Is the given value `NaN`?
|
||||
_.isNaN = function(obj) {
|
||||
// `NaN` is the only value for which `===` is not reflexive.
|
||||
return obj !== obj;
|
||||
};
|
||||
|
||||
// Is a given value a boolean?
|
||||
_.isBoolean = function(obj) {
|
||||
return obj === true || obj === false || toString.call(obj) == '[object Boolean]';
|
||||
};
|
||||
|
||||
// Is a given value a date?
|
||||
_.isDate = function(obj) {
|
||||
return toString.call(obj) == '[object Date]';
|
||||
};
|
||||
|
||||
// Is the given value a regular expression?
|
||||
_.isRegExp = function(obj) {
|
||||
return toString.call(obj) == '[object RegExp]';
|
||||
};
|
||||
|
||||
// Is a given value equal to null?
|
||||
_.isNull = function(obj) {
|
||||
return obj === null;
|
||||
};
|
||||
|
||||
// Is a given variable undefined?
|
||||
_.isUndefined = function(obj) {
|
||||
return obj === void 0;
|
||||
};
|
||||
|
||||
// Has own property?
|
||||
_.has = function(obj, key) {
|
||||
return hasOwnProperty.call(obj, key);
|
||||
};
|
||||
|
||||
// Utility Functions
|
||||
// -----------------
|
||||
|
||||
// Run Underscore.js in *noConflict* mode, returning the `_` variable to its
|
||||
// previous owner. Returns a reference to the Underscore object.
|
||||
_.noConflict = function() {
|
||||
root._ = previousUnderscore;
|
||||
return this;
|
||||
};
|
||||
|
||||
// Keep the identity function around for default iterators.
|
||||
_.identity = function(value) {
|
||||
return value;
|
||||
};
|
||||
|
||||
// Run a function **n** times.
|
||||
_.times = function (n, iterator, context) {
|
||||
for (var i = 0; i < n; i++) iterator.call(context, i);
|
||||
};
|
||||
|
||||
// Escape a string for HTML interpolation.
|
||||
_.escape = function(string) {
|
||||
return (''+string).replace(/&/g, '&').replace(/</g, '<').replace(/>/g, '>').replace(/"/g, '"').replace(/'/g, ''').replace(/\//g,'/');
|
||||
};
|
||||
|
||||
// Add your own custom functions to the Underscore object, ensuring that
|
||||
// they're correctly added to the OOP wrapper as well.
|
||||
_.mixin = function(obj) {
|
||||
each(_.functions(obj), function(name){
|
||||
addToWrapper(name, _[name] = obj[name]);
|
||||
});
|
||||
};
|
||||
|
||||
// Generate a unique integer id (unique within the entire client session).
|
||||
// Useful for temporary DOM ids.
|
||||
var idCounter = 0;
|
||||
_.uniqueId = function(prefix) {
|
||||
var id = idCounter++;
|
||||
return prefix ? prefix + id : id;
|
||||
};
|
||||
|
||||
// By default, Underscore uses ERB-style template delimiters, change the
|
||||
// following template settings to use alternative delimiters.
|
||||
_.templateSettings = {
|
||||
evaluate : /<%([\s\S]+?)%>/g,
|
||||
interpolate : /<%=([\s\S]+?)%>/g,
|
||||
escape : /<%-([\s\S]+?)%>/g
|
||||
};
|
||||
|
||||
// When customizing `templateSettings`, if you don't want to define an
|
||||
// interpolation, evaluation or escaping regex, we need one that is
|
||||
// guaranteed not to match.
|
||||
var noMatch = /.^/;
|
||||
|
||||
// Within an interpolation, evaluation, or escaping, remove HTML escaping
|
||||
// that had been previously added.
|
||||
var unescape = function(code) {
|
||||
return code.replace(/\\\\/g, '\\').replace(/\\'/g, "'");
|
||||
};
|
||||
|
||||
// JavaScript micro-templating, similar to John Resig's implementation.
|
||||
// Underscore templating handles arbitrary delimiters, preserves whitespace,
|
||||
// and correctly escapes quotes within interpolated code.
|
||||
_.template = function(str, data) {
|
||||
var c = _.templateSettings;
|
||||
var tmpl = 'var __p=[],print=function(){__p.push.apply(__p,arguments);};' +
|
||||
'with(obj||{}){__p.push(\'' +
|
||||
str.replace(/\\/g, '\\\\')
|
||||
.replace(/'/g, "\\'")
|
||||
.replace(c.escape || noMatch, function(match, code) {
|
||||
return "',_.escape(" + unescape(code) + "),'";
|
||||
})
|
||||
.replace(c.interpolate || noMatch, function(match, code) {
|
||||
return "'," + unescape(code) + ",'";
|
||||
})
|
||||
.replace(c.evaluate || noMatch, function(match, code) {
|
||||
return "');" + unescape(code).replace(/[\r\n\t]/g, ' ') + ";__p.push('";
|
||||
})
|
||||
.replace(/\r/g, '\\r')
|
||||
.replace(/\n/g, '\\n')
|
||||
.replace(/\t/g, '\\t')
|
||||
+ "');}return __p.join('');";
|
||||
var func = new Function('obj', '_', tmpl);
|
||||
if (data) return func(data, _);
|
||||
return function(data) {
|
||||
return func.call(this, data, _);
|
||||
};
|
||||
};
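  // Usage sketch (illustrative only), using the default ERB-style delimiters
  // declared in `_.templateSettings` above:
  //
  //   var compiled = _.template('hello: <%= name %>');
  //   compiled({name: 'moe'}); // => 'hello: moe'
  //
  //   // Evaluation (<% %>) and escaping (<%- %>) blocks work the same way:
  //   var list = _.template('<% _.each(people, function(n){ %><li><%- n %></li><% }); %>');
  //   list({people: ['moe', 'curly']}); // => '<li>moe</li><li>curly</li>'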
|
||||
|
||||
// Add a "chain" function, which will delegate to the wrapper.
|
||||
_.chain = function(obj) {
|
||||
return _(obj).chain();
|
||||
};
|
||||
|
||||
// The OOP Wrapper
|
||||
// ---------------
|
||||
|
||||
// If Underscore is called as a function, it returns a wrapped object that
|
||||
// can be used OO-style. This wrapper holds altered versions of all the
|
||||
// underscore functions. Wrapped objects may be chained.
|
||||
var wrapper = function(obj) { this._wrapped = obj; };
|
||||
|
||||
// Expose `wrapper.prototype` as `_.prototype`
|
||||
_.prototype = wrapper.prototype;
|
||||
|
||||
// Helper function to continue chaining intermediate results.
|
||||
var result = function(obj, chain) {
|
||||
return chain ? _(obj).chain() : obj;
|
||||
};
|
||||
|
||||
// A method to easily add functions to the OOP wrapper.
|
||||
var addToWrapper = function(name, func) {
|
||||
wrapper.prototype[name] = function() {
|
||||
var args = slice.call(arguments);
|
||||
unshift.call(args, this._wrapped);
|
||||
return result(func.apply(_, args), this._chain);
|
||||
};
|
||||
};
|
||||
|
||||
// Add all of the Underscore functions to the wrapper object.
|
||||
_.mixin(_);
|
||||
|
||||
// Add all mutator Array functions to the wrapper.
|
||||
each(['pop', 'push', 'reverse', 'shift', 'sort', 'splice', 'unshift'], function(name) {
|
||||
var method = ArrayProto[name];
|
||||
wrapper.prototype[name] = function() {
|
||||
var wrapped = this._wrapped;
|
||||
method.apply(wrapped, arguments);
|
||||
var length = wrapped.length;
|
||||
if ((name == 'shift' || name == 'splice') && length === 0) delete wrapped[0];
|
||||
return result(wrapped, this._chain);
|
||||
};
|
||||
});
|
||||
|
||||
// Add all accessor Array functions to the wrapper.
|
||||
each(['concat', 'join', 'slice'], function(name) {
|
||||
var method = ArrayProto[name];
|
||||
wrapper.prototype[name] = function() {
|
||||
return result(method.apply(this._wrapped, arguments), this._chain);
|
||||
};
|
||||
});
|
||||
|
||||
// Start chaining a wrapped Underscore object.
|
||||
wrapper.prototype.chain = function() {
|
||||
this._chain = true;
|
||||
return this;
|
||||
};
|
||||
|
||||
// Extracts the result from a wrapped and chained object.
|
||||
wrapper.prototype.value = function() {
|
||||
return this._wrapped;
|
||||
};
|
||||
|
||||
}).call(this);
|
697
tutorials/01-vector-add.html
Normal file
@@ -0,0 +1,697 @@
|
||||
|
||||
|
||||
<!DOCTYPE html>
|
||||
<html class="writer-html5" lang="en" >
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
|
||||
<title>Vector Addition — Triton documentation</title>
|
||||
|
||||
|
||||
|
||||
<link rel="stylesheet" href="../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../_static/pygments.css" type="text/css" />
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<!--[if lt IE 9]>
|
||||
<script src="../_static/js/html5shiv.min.js"></script>
|
||||
<![endif]-->
|
||||
|
||||
|
||||
<script type="text/javascript" id="documentation_options" data-url_root="../" src="../_static/documentation_options.js"></script>
|
||||
<script src="../_static/jquery.js"></script>
|
||||
<script src="../_static/underscore.js"></script>
|
||||
<script src="../_static/doctools.js"></script>
|
||||
<script crossorigin="anonymous" integrity="sha256-Ae2Vz/4ePdIu6ZyI/5ZGsYnb+m0JlOmKPjt6XZ9JJkA=" src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js"></script>
|
||||
<script async="async" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
|
||||
<script type="text/x-mathjax-config">MathJax.Hub.Config({"tex2jax": {"inlineMath": [["$", "$"], ["\\(", "\\)"]], "processEscapes": true, "ignoreClass": "document", "processClass": "math|output_area"}})</script>
|
||||
|
||||
<script type="text/javascript" src="../_static/js/theme.js"></script>
|
||||
|
||||
|
||||
<link rel="index" title="Index" href="../genindex.html" />
|
||||
<link rel="search" title="Search" href="../search.html" />
|
||||
<link rel="next" title="Fused Softmax" href="02-fused-softmax.html" />
|
||||
<link rel="prev" title="From Source" href="../installation/from-source.html" />
|
||||
</head>
|
||||
|
||||
<body class="wy-body-for-nav">
|
||||
|
||||
|
||||
<div class="wy-grid-for-nav">
|
||||
|
||||
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
|
||||
<div class="wy-side-scroll">
|
||||
<div class="wy-side-nav-search" >
|
||||
|
||||
|
||||
|
||||
<a href="../index.html" class="icon icon-home"> Triton
|
||||
|
||||
|
||||
|
||||
</a>
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<div role="search">
|
||||
<form id="rtd-search-form" class="wy-form" action="../search.html" method="get">
|
||||
<input type="text" name="q" placeholder="Search docs" />
|
||||
<input type="hidden" name="check_keywords" value="yes" />
|
||||
<input type="hidden" name="area" value="default" />
|
||||
</form>
|
||||
</div>
|
||||
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Installation Instructions</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../installation/packaged-binaries.html">Packaged Binaries</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../installation/from-source.html">From Source</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Tutorials</span></p>
|
||||
<ul class="current">
|
||||
<li class="toctree-l1 current"><a class="current reference internal" href="#">Vector Addition</a><ul>
|
||||
<li class="toctree-l2"><a class="reference internal" href="#Writing-the-Compute-Kernel">Writing the Compute Kernel</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="#Writing-the-Torch-bindings">Writing the Torch bindings</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="#Writing-a-Unit-Test">Writing a Unit Test</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="#Writing-a-Benchmark">Writing a Benchmark</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="02-fused-softmax.html">Fused Softmax</a></li>
|
||||
</ul>
|
||||
|
||||
|
||||
|
||||
</div>
|
||||
|
||||
</div>
|
||||
</nav>
|
||||
|
||||
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
|
||||
|
||||
|
||||
<nav class="wy-nav-top" aria-label="top navigation">
|
||||
|
||||
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
|
||||
<a href="../index.html">Triton</a>
|
||||
|
||||
</nav>
|
||||
|
||||
|
||||
<div class="wy-nav-content">
|
||||
|
||||
<div class="rst-content">
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<div role="navigation" aria-label="breadcrumbs navigation">
|
||||
|
||||
<ul class="wy-breadcrumbs">
|
||||
|
||||
<li><a href="../index.html" class="icon icon-home"></a> »</li>
|
||||
|
||||
<li>Vector Addition</li>
|
||||
|
||||
|
||||
<li class="wy-breadcrumbs-aside">
|
||||
|
||||
|
||||
<a href="../_sources/tutorials/01-vector-add.ipynb.txt" rel="nofollow"> View page source</a>
|
||||
|
||||
|
||||
</li>
|
||||
|
||||
</ul>
|
||||
|
||||
|
||||
<hr/>
|
||||
</div>
|
||||
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
|
||||
<div itemprop="articleBody">
|
||||
|
||||
|
||||
<style>
|
||||
/* CSS for nbsphinx extension */
|
||||
|
||||
/* remove conflicting styling from Sphinx themes */
|
||||
div.nbinput.container div.prompt *,
|
||||
div.nboutput.container div.prompt *,
|
||||
div.nbinput.container div.input_area pre,
|
||||
div.nboutput.container div.output_area pre,
|
||||
div.nbinput.container div.input_area .highlight,
|
||||
div.nboutput.container div.output_area .highlight {
|
||||
border: none;
|
||||
padding: 0;
|
||||
margin: 0;
|
||||
box-shadow: none;
|
||||
}
|
||||
|
||||
div.nbinput.container > div[class*=highlight],
|
||||
div.nboutput.container > div[class*=highlight] {
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
div.nbinput.container div.prompt *,
|
||||
div.nboutput.container div.prompt * {
|
||||
background: none;
|
||||
}
|
||||
|
||||
div.nboutput.container div.output_area .highlight,
|
||||
div.nboutput.container div.output_area pre {
|
||||
background: unset;
|
||||
}
|
||||
|
||||
div.nboutput.container div.output_area div.highlight {
|
||||
color: unset; /* override Pygments text color */
|
||||
}
|
||||
|
||||
/* avoid gaps between output lines */
|
||||
div.nboutput.container div[class*=highlight] pre {
|
||||
line-height: normal;
|
||||
}
|
||||
|
||||
/* input/output containers */
|
||||
div.nbinput.container,
|
||||
div.nboutput.container {
|
||||
display: -webkit-flex;
|
||||
display: flex;
|
||||
align-items: flex-start;
|
||||
margin: 0;
|
||||
width: 100%;
|
||||
}
|
||||
@media (max-width: 540px) {
|
||||
div.nbinput.container,
|
||||
div.nboutput.container {
|
||||
flex-direction: column;
|
||||
}
|
||||
}
|
||||
|
||||
/* input container */
|
||||
div.nbinput.container {
|
||||
padding-top: 5px;
|
||||
}
|
||||
|
||||
/* last container */
|
||||
div.nblast.container {
|
||||
padding-bottom: 5px;
|
||||
}
|
||||
|
||||
/* input prompt */
|
||||
div.nbinput.container div.prompt pre {
|
||||
color: #307FC1;
|
||||
}
|
||||
|
||||
/* output prompt */
|
||||
div.nboutput.container div.prompt pre {
|
||||
color: #BF5B3D;
|
||||
}
|
||||
|
||||
/* all prompts */
|
||||
div.nbinput.container div.prompt,
|
||||
div.nboutput.container div.prompt {
|
||||
width: 4.5ex;
|
||||
padding-top: 5px;
|
||||
position: relative;
|
||||
user-select: none;
|
||||
}
|
||||
|
||||
div.nbinput.container div.prompt > div,
|
||||
div.nboutput.container div.prompt > div {
|
||||
position: absolute;
|
||||
right: 0;
|
||||
margin-right: 0.3ex;
|
||||
}
|
||||
|
||||
@media (max-width: 540px) {
|
||||
div.nbinput.container div.prompt,
|
||||
div.nboutput.container div.prompt {
|
||||
width: unset;
|
||||
text-align: left;
|
||||
padding: 0.4em;
|
||||
}
|
||||
div.nboutput.container div.prompt.empty {
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
div.nbinput.container div.prompt > div,
|
||||
div.nboutput.container div.prompt > div {
|
||||
position: unset;
|
||||
}
|
||||
}
|
||||
|
||||
/* disable scrollbars on prompts */
|
||||
div.nbinput.container div.prompt pre,
|
||||
div.nboutput.container div.prompt pre {
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
/* input/output area */
|
||||
div.nbinput.container div.input_area,
|
||||
div.nboutput.container div.output_area {
|
||||
-webkit-flex: 1;
|
||||
flex: 1;
|
||||
overflow: auto;
|
||||
}
|
||||
@media (max-width: 540px) {
|
||||
div.nbinput.container div.input_area,
|
||||
div.nboutput.container div.output_area {
|
||||
width: 100%;
|
||||
}
|
||||
}
|
||||
|
||||
/* input area */
|
||||
div.nbinput.container div.input_area {
|
||||
border: 1px solid #e0e0e0;
|
||||
border-radius: 2px;
|
||||
/*background: #f5f5f5;*/
|
||||
}
|
||||
|
||||
/* override MathJax center alignment in output cells */
|
||||
div.nboutput.container div[class*=MathJax] {
|
||||
text-align: left !important;
|
||||
}
|
||||
|
||||
/* override sphinx.ext.imgmath center alignment in output cells */
|
||||
div.nboutput.container div.math p {
|
||||
text-align: left;
|
||||
}
|
||||
|
||||
/* standard error */
|
||||
div.nboutput.container div.output_area.stderr {
|
||||
background: #fdd;
|
||||
}
|
||||
|
||||
/* ANSI colors */
|
||||
.ansi-black-fg { color: #3E424D; }
|
||||
.ansi-black-bg { background-color: #3E424D; }
|
||||
.ansi-black-intense-fg { color: #282C36; }
|
||||
.ansi-black-intense-bg { background-color: #282C36; }
|
||||
.ansi-red-fg { color: #E75C58; }
|
||||
.ansi-red-bg { background-color: #E75C58; }
|
||||
.ansi-red-intense-fg { color: #B22B31; }
|
||||
.ansi-red-intense-bg { background-color: #B22B31; }
|
||||
.ansi-green-fg { color: #00A250; }
|
||||
.ansi-green-bg { background-color: #00A250; }
|
||||
.ansi-green-intense-fg { color: #007427; }
|
||||
.ansi-green-intense-bg { background-color: #007427; }
|
||||
.ansi-yellow-fg { color: #DDB62B; }
|
||||
.ansi-yellow-bg { background-color: #DDB62B; }
|
||||
.ansi-yellow-intense-fg { color: #B27D12; }
|
||||
.ansi-yellow-intense-bg { background-color: #B27D12; }
|
||||
.ansi-blue-fg { color: #208FFB; }
|
||||
.ansi-blue-bg { background-color: #208FFB; }
|
||||
.ansi-blue-intense-fg { color: #0065CA; }
|
||||
.ansi-blue-intense-bg { background-color: #0065CA; }
|
||||
.ansi-magenta-fg { color: #D160C4; }
|
||||
.ansi-magenta-bg { background-color: #D160C4; }
|
||||
.ansi-magenta-intense-fg { color: #A03196; }
|
||||
.ansi-magenta-intense-bg { background-color: #A03196; }
|
||||
.ansi-cyan-fg { color: #60C6C8; }
|
||||
.ansi-cyan-bg { background-color: #60C6C8; }
|
||||
.ansi-cyan-intense-fg { color: #258F8F; }
|
||||
.ansi-cyan-intense-bg { background-color: #258F8F; }
|
||||
.ansi-white-fg { color: #C5C1B4; }
|
||||
.ansi-white-bg { background-color: #C5C1B4; }
|
||||
.ansi-white-intense-fg { color: #A1A6B2; }
|
||||
.ansi-white-intense-bg { background-color: #A1A6B2; }
|
||||
|
||||
.ansi-default-inverse-fg { color: #FFFFFF; }
|
||||
.ansi-default-inverse-bg { background-color: #000000; }
|
||||
|
||||
.ansi-bold { font-weight: bold; }
|
||||
.ansi-underline { text-decoration: underline; }
|
||||
|
||||
|
||||
div.nbinput.container div.input_area div[class*=highlight] > pre,
|
||||
div.nboutput.container div.output_area div[class*=highlight] > pre,
|
||||
div.nboutput.container div.output_area div[class*=highlight].math,
|
||||
div.nboutput.container div.output_area.rendered_html,
|
||||
div.nboutput.container div.output_area > div.output_javascript,
|
||||
div.nboutput.container div.output_area:not(.rendered_html) > img{
|
||||
padding: 5px;
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
/* fix copybtn overflow problem in chromium (needed for 'sphinx_copybutton') */
|
||||
div.nbinput.container div.input_area > div[class^='highlight'],
|
||||
div.nboutput.container div.output_area > div[class^='highlight']{
|
||||
overflow-y: hidden;
|
||||
}
|
||||
|
||||
/* hide copybtn icon on prompts (needed for 'sphinx_copybutton') */
|
||||
.prompt a.copybtn {
|
||||
display: none;
|
||||
}
|
||||
|
||||
/* Some additional styling taken form the Jupyter notebook CSS */
|
||||
div.rendered_html table {
|
||||
border: none;
|
||||
border-collapse: collapse;
|
||||
border-spacing: 0;
|
||||
color: black;
|
||||
font-size: 12px;
|
||||
table-layout: fixed;
|
||||
}
|
||||
div.rendered_html thead {
|
||||
border-bottom: 1px solid black;
|
||||
vertical-align: bottom;
|
||||
}
|
||||
div.rendered_html tr,
|
||||
div.rendered_html th,
|
||||
div.rendered_html td {
|
||||
text-align: right;
|
||||
vertical-align: middle;
|
||||
padding: 0.5em 0.5em;
|
||||
line-height: normal;
|
||||
white-space: normal;
|
||||
max-width: none;
|
||||
border: none;
|
||||
}
|
||||
div.rendered_html th {
|
||||
font-weight: bold;
|
||||
}
|
||||
div.rendered_html tbody tr:nth-child(odd) {
|
||||
background: #f5f5f5;
|
||||
}
|
||||
div.rendered_html tbody tr:hover {
|
||||
background: rgba(66, 165, 245, 0.2);
|
||||
}
|
||||
|
||||
/* CSS overrides for sphinx_rtd_theme */
|
||||
|
||||
/* 24px margin */
|
||||
.nbinput.nblast.container,
|
||||
.nboutput.nblast.container {
|
||||
margin-bottom: 19px; /* padding has already 5px */
|
||||
}
|
||||
|
||||
/* ... except between code cells! */
|
||||
.nblast.container + .nbinput.container {
|
||||
margin-top: -19px;
|
||||
}
|
||||
|
||||
.admonition > p:before {
|
||||
margin-right: 4px; /* make room for the exclamation icon */
|
||||
}
|
||||
|
||||
/* Fix math alignment, see https://github.com/rtfd/sphinx_rtd_theme/pull/686 */
|
||||
.math {
|
||||
text-align: unset;
|
||||
}
|
||||
</style>
|
||||
<div class="section" id="Vector-Addition">
|
||||
<h1>Vector Addition<a class="headerlink" href="#Vector-Addition" title="Permalink to this headline">¶</a></h1>
|
||||
<p>In this tutorial, we will see how to construct a simple, high-performance vector addition using Triton. You will learn:</p>
<ul class="simple">
<li>The basic syntax of the Triton programming language</li>
<li>The best practices for creating PyTorch custom operators using the <code class="docutils literal notranslate"><span class="pre">triton.kernel</span></code> Python API</li>
<li>The best practices for validating and benchmarking custom ops against native reference implementations</li>
</ul>
|
||||
<div class="section" id="Writing-the-Compute-Kernel">
|
||||
<h2>Writing the Compute Kernel<a class="headerlink" href="#Writing-the-Compute-Kernel" title="Permalink to this headline">¶</a></h2>
|
||||
<p>Each compute kernel is declared using the <code class="docutils literal notranslate"><span class="pre">__global__</span></code> attribute, and executed many times in parallel on different chunks of data (See the <a class="reference external" href="https://en.wikipedia.org/wiki/SPMD">Single Program, Multiple Data</a> programming model for more details).</p>
|
||||
<div class="highlight-c notranslate"><div class="highlight"><pre><span></span><span class="n">__global__</span> <span class="kt">void</span> <span class="n">add</span><span class="p">(</span><span class="kt">float</span><span class="o">*</span> <span class="n">z</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">x</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">y</span><span class="p">,</span> <span class="kt">int</span> <span class="n">N</span><span class="p">){</span>
|
||||
<span class="c1">// The `get_program_id(i)` returns the i-th coordinate</span>
|
||||
<span class="c1">// of the program in the overaching SPMD context</span>
|
||||
<span class="c1">// (a.k.a launch grid). This is what allows us to process</span>
|
||||
<span class="c1">// different chunks of data in parallel.</span>
|
||||
<span class="c1">// For those similar with CUDA, `get_program_id({0,1,2})`</span>
|
||||
<span class="c1">// is similar to blockIdx.{x,y,z}</span>
|
||||
<span class="kt">int</span> <span class="n">pid</span> <span class="o">=</span> <span class="n">get_program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">);</span>
|
||||
<span class="c1">// In Triton, arrays are first-class citizen. In other words,</span>
|
||||
<span class="c1">// they are primitives data-types and are -- contrary to C and</span>
|
||||
<span class="c1">// CUDA -- not implemented as pointers to contiguous chunks of</span>
|
||||
<span class="c1">// memory.</span>
|
||||
<span class="c1">// In the few lines below, we create an array of `BLOCK` pointers</span>
|
||||
<span class="c1">// whose memory values are, e.g.:</span>
|
||||
<span class="c1">// [z + pid*BLOCK + 0, z + pid*BLOCK + 1, ..., z + pid*BLOCK + BLOCK - 1]</span>
|
||||
<span class="c1">// Note: here BLOCK is expected to be a pre-processor macro defined at compile-time</span>
|
||||
<span class="kt">int</span> <span class="n">offset</span><span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">pid</span> <span class="o">*</span> <span class="n">BLOCK</span> <span class="o">+</span> <span class="mi">0</span> <span class="p">...</span> <span class="n">BLOCK</span><span class="p">;</span>
|
||||
<span class="kt">float</span><span class="o">*</span> <span class="n">pz</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">z</span> <span class="o">+</span> <span class="n">offset</span><span class="p">;</span>
|
||||
<span class="kt">float</span><span class="o">*</span> <span class="n">px</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">offset</span><span class="p">;</span>
|
||||
<span class="kt">float</span><span class="o">*</span> <span class="n">py</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">y</span> <span class="o">+</span> <span class="n">offset</span><span class="p">;</span>
|
||||
<span class="c1">// Simple element-wise control-flow for load/store operations can</span>
|
||||
<span class="c1">// be achieved using the the ternary operator `cond ? val_true : val_false`</span>
|
||||
<span class="c1">// or the conditional dereferencing operator `*?(cond)ptr</span>
|
||||
<span class="c1">// Here, we make sure that we do not access memory out-of-bounds when we</span>
|
||||
<span class="c1">// write-back `z`</span>
|
||||
<span class="kt">bool</span> <span class="n">check</span><span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">offset</span> <span class="o"><</span> <span class="n">N</span><span class="p">;</span>
|
||||
<span class="o">*?</span><span class="p">(</span><span class="n">check</span><span class="p">)</span><span class="n">pz</span> <span class="o">=</span> <span class="o">*?</span><span class="p">(</span><span class="n">check</span><span class="p">)</span><span class="n">px</span> <span class="o">+</span> <span class="o">*?</span><span class="p">(</span><span class="n">check</span><span class="p">)</span><span class="n">py</span><span class="p">;</span>
|
||||
<span class="p">}</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>The existence of arrays as a primitive data-type for Triton comes with a number of advantages that are highlighted in the <a class="reference external" href="http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf">MAPL’2019 Triton paper</a>.</p>
|
||||
</div>
|
||||
<div class="section" id="Writing-the-Torch-bindings">
|
||||
<h2>Writing the Torch bindings<a class="headerlink" href="#Writing-the-Torch-bindings" title="Permalink to this headline">¶</a></h2>
|
||||
<p>The only thing that matters when it comes to Triton and Torch is the <code class="docutils literal notranslate"><span class="pre">triton.kernel</span></code> class. This allows you to transform the above C-like function into a callable python object that can be used to modify <code class="docutils literal notranslate"><span class="pre">torch.tensor</span></code> objects.</p>
|
||||
<p>To create a <code class="docutils literal notranslate"><span class="pre">triton.kernel</span></code>, you only need three things:</p>
<ul class="simple">
<li><code class="docutils literal notranslate"><span class="pre">source:</span> <span class="pre">string</span></code>: the source-code of the kernel you want to create</li>
<li><code class="docutils literal notranslate"><span class="pre">device:</span> <span class="pre">torch.device</span></code>: the device you want to compile this code for</li>
<li><code class="docutils literal notranslate"><span class="pre">defines:</span> <span class="pre">dict</span></code>: the set of macros that you want the pre-processor to <code class="docutils literal notranslate"><span class="pre">#define</span></code> for you</li>
</ul>
|
||||
<p>Note: The constructor of <code class="docutils literal notranslate"><span class="pre">triton.kernel</span></code> does some just-in-time compilation, so expect some overhead there. For this reason, I personally like to initialize kernels lazily in a cache (see <code class="docutils literal notranslate"><span class="pre">_kernels</span></code> variable below). This also makes it possible to choose the compilation device dynamically based on the type of the operator’s inputs.</p>
|
||||
<div class="nbinput nblast docutils container">
|
||||
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[10]:
|
||||
</pre></div>
|
||||
</div>
|
||||
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre>
|
||||
<span></span><span class="kn">import</span> <span class="nn">torch</span>
|
||||
<span class="kn">import</span> <span class="nn">triton</span>
|
||||
|
||||
<span class="c1"># source-code for Triton compute kernel</span>
|
||||
<span class="c1"># here we just copy-paste the above code without the extensive comments.</span>
|
||||
<span class="c1"># you may prefer to store it in a .c file and load it from there instead.</span>
|
||||
<span class="n">_src</span> <span class="o">=</span> <span class="s2">"""</span>
|
||||
<span class="s2">__global__ void add(float* z, float* x, float* y, int N){</span>
|
||||
<span class="s2"> // program id</span>
|
||||
<span class="s2"> int pid = get_program_id(0);</span>
|
||||
<span class="s2"> // create arrays of pointers</span>
|
||||
<span class="s2"> int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;</span>
|
||||
<span class="s2"> float* pz[BLOCK] = z + offset;</span>
|
||||
<span class="s2"> float* px[BLOCK] = x + offset;</span>
|
||||
<span class="s2"> float* py[BLOCK] = y + offset;</span>
|
||||
<span class="s2"> // bounds checking</span>
|
||||
<span class="s2"> bool check[BLOCK] = offset < N;</span>
|
||||
<span class="s2"> // write-back</span>
|
||||
<span class="s2"> *?(check)pz = *?(check)px + *?(check)py;</span>
|
||||
<span class="s2">}</span>
|
||||
<span class="s2"> """</span>
|
||||
<span class="c1"># This function returns a callable `triton.kernel` object</span>
|
||||
<span class="c1"># created from the above source code.</span>
|
||||
<span class="c1"># For portability, we maintain a cache of kernels for different `torch.device`</span>
|
||||
<span class="c1"># We compile the kernel with -DBLOCK=1024</span>
|
||||
<span class="n">_kernels</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span>
|
||||
<span class="k">def</span> <span class="nf">make_add_kernel</span><span class="p">(</span><span class="n">device</span><span class="p">):</span>
|
||||
<span class="k">if</span> <span class="n">device</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">_kernels</span><span class="p">:</span>
|
||||
<span class="n">defines</span> <span class="o">=</span> <span class="p">{</span><span class="s1">'BLOCK'</span><span class="p">:</span> <span class="mi">1024</span><span class="p">}</span>
|
||||
<span class="n">_kernels</span><span class="p">[</span><span class="n">device</span><span class="p">]</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">kernel</span><span class="p">(</span><span class="n">_src</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">defines</span><span class="o">=</span><span class="n">defines</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="n">_kernels</span><span class="p">[</span><span class="n">device</span><span class="p">]</span>
|
||||
|
||||
<span class="c1"># This is a standard torch custom autograd Function</span>
|
||||
<span class="c1"># The only difference is that we can now use the above kernel</span>
|
||||
<span class="c1"># in the `forward` and `backward` functions.`</span>
|
||||
<span class="k">class</span> <span class="nc">_add</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">autograd</span><span class="o">.</span><span class="n">Function</span><span class="p">):</span>
|
||||
|
||||
<span class="nd">@staticmethod</span>
|
||||
<span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
|
||||
<span class="c1"># constraints of the op</span>
|
||||
<span class="k">assert</span> <span class="n">x</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span>
|
||||
<span class="c1"># *allocate output*</span>
|
||||
<span class="n">z</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
|
||||
<span class="c1"># *create launch grid*:</span>
|
||||
<span class="c1"># this is a function which takes compilation parameters `opt`</span>
|
||||
<span class="c1"># as input and returns a tuple of int (i.e., launch grid) for the kernel.</span>
|
||||
<span class="c1"># triton.cdiv is a shortcut for ceil division:</span>
|
||||
<span class="c1"># triton.cdiv(a, b) = (a + b - 1) // b</span>
|
||||
<span class="n">N</span> <span class="o">=</span> <span class="n">z</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
|
||||
<span class="n">grid</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">opt</span><span class="p">:</span> <span class="p">(</span><span class="n">triton</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">opt</span><span class="o">.</span><span class="n">BLOCK</span><span class="p">),</span> <span class="p">)</span>
|
||||
<span class="c1"># *launch kernel*:</span>
|
||||
<span class="c1"># pointer to the data of torch tensors can be retrieved with</span>
|
||||
<span class="c1"># the `.data_ptr()` method</span>
|
||||
<span class="n">kernel</span> <span class="o">=</span> <span class="n">make_add_kernel</span><span class="p">(</span><span class="n">z</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
||||
<span class="n">kernel</span><span class="p">(</span><span class="n">z</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">(),</span> <span class="n">x</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">(),</span> <span class="n">y</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">(),</span> <span class="n">N</span><span class="p">,</span> <span class="n">grid</span> <span class="o">=</span> <span class="n">grid</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="n">z</span>
|
||||
<span class="c1"># Just like we standard PyTorch ops</span>
|
||||
<span class="c1"># We use the `.apply` method to create a</span>
|
||||
<span class="c1"># callable object for our function</span>
|
||||
<span class="n">add</span> <span class="o">=</span> <span class="n">_add</span><span class="o">.</span><span class="n">apply</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<p>At this point <code class="docutils literal notranslate"><span class="pre">add(x,</span> <span class="pre">y)</span></code> is equivalent to <code class="docutils literal notranslate"><span class="pre">x</span> <span class="pre">+</span> <span class="pre">y</span></code> for contiguous tensors. Now let’s test and benchmark it!</p>
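<p>One caveat worth keeping in mind (a sketch under the assumption that callers may pass multi-dimensional or non-contiguous views): the wrapper above reads its size from <code class="docutils literal notranslate"><span class="pre">shape[0]</span></code> and indexes flat, contiguous float32 buffers, so such inputs should be flattened and made contiguous first.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre>
# Illustrative sketch only (not part of the tutorial code): `add` indexes
# flat, contiguous float32 buffers and reads its size from `shape[0]`, so
# non-contiguous views should be flattened and made contiguous first.
x_view = torch.rand(256, 384, device='cuda').t()   # non-contiguous view
x_flat = x_view.contiguous().view(-1)              # 98304 contiguous elements
y_flat = torch.rand(256 * 384, device='cuda')
z_flat = add(x_flat, y_flat)
</pre></div></div>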
|
||||
</div>
|
||||
<div class="section" id="Writing-a-Unit-Test">
|
||||
<h2>Writing a Unit Test<a class="headerlink" href="#Writing-a-Unit-Test" title="Permalink to this headline">¶</a></h2>
|
||||
<div class="nbinput docutils container">
|
||||
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[9]:
|
||||
</pre></div>
|
||||
</div>
|
||||
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre>
|
||||
<span></span><span class="n">torch</span><span class="o">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
|
||||
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">98432</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">)</span>
|
||||
<span class="n">y</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">98432</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">)</span>
|
||||
<span class="n">za</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span>
|
||||
<span class="n">zb</span> <span class="o">=</span> <span class="n">add</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>
|
||||
<span class="nb">print</span><span class="p">(</span><span class="n">za</span><span class="p">)</span>
|
||||
<span class="nb">print</span><span class="p">(</span><span class="n">zb</span><span class="p">)</span>
|
||||
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s1">'The maximum difference between torch and triton is '</span>
|
||||
<span class="sa">f</span><span class="s1">'</span><span class="si">{</span><span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">abs</span><span class="p">(</span><span class="n">za</span> <span class="o">-</span> <span class="n">zb</span><span class="p">))</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="nboutput nblast docutils container">
|
||||
<div class="prompt empty docutils container">
|
||||
</div>
|
||||
<div class="output_area docutils container">
|
||||
<div class="highlight"><pre>
|
||||
tensor([1.3713, 1.3076, 0.4940, ..., 0.6682, 1.1984, 1.2696], device='cuda:0')
|
||||
tensor([1.3713, 1.3076, 0.4940, ..., 0.6682, 1.1984, 1.2696], device='cuda:0')
|
||||
The maximum difference between torch and triton is 0.0
|
||||
</pre></div></div>
|
||||
</div>
|
||||
<p>Seems to work!</p>
|
||||
</div>
|
||||
<div class="section" id="Writing-a-Benchmark">
|
||||
<h2>Writing a Benchmark<a class="headerlink" href="#Writing-a-Benchmark" title="Permalink to this headline">¶</a></h2>
|
||||
<p>The performance of our GPU code can be benchmarked using the <code class="docutils literal notranslate"><span class="pre">torch.cuda.Event(enable_timing=True)</span></code> wrapper. Below is a simple function that benchmarks <code class="docutils literal notranslate"><span class="pre">rep</span></code> runs of our kernels after <code class="docutils literal notranslate"><span class="pre">warmup</span></code> “cold” runs.</p>
|
||||
<div class="nbinput nblast docutils container">
|
||||
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[11]:
|
||||
</pre></div>
|
||||
</div>
|
||||
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre>
|
||||
<span></span><span class="c1"># We now want to benchmark the performance of `add`</span>
|
||||
<span class="c1"># Against that of PyTorch for increasing vector sizes</span>
|
||||
<span class="k">def</span> <span class="nf">do_bench</span><span class="p">(</span><span class="n">fn</span><span class="p">,</span> <span class="n">warmup</span> <span class="o">=</span> <span class="mi">10</span><span class="p">,</span> <span class="n">rep</span> <span class="o">=</span> <span class="mi">50</span><span class="p">):</span>
|
||||
<span class="n">start_event</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">Event</span><span class="p">(</span><span class="n">enable_timing</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||||
<span class="n">end_event</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">Event</span><span class="p">(</span><span class="n">enable_timing</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||||
<span class="n">ret</span> <span class="o">=</span> <span class="n">fn</span><span class="p">()</span>
|
||||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">warmup</span><span class="p">):</span>
|
||||
<span class="n">fn</span><span class="p">()</span>
|
||||
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">synchronize</span><span class="p">()</span>
|
||||
<span class="n">start_event</span><span class="o">.</span><span class="n">record</span><span class="p">()</span>
|
||||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">rep</span><span class="p">):</span>
|
||||
<span class="n">fn</span><span class="p">()</span>
|
||||
<span class="n">end_event</span><span class="o">.</span><span class="n">record</span><span class="p">()</span>
|
||||
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">synchronize</span><span class="p">()</span>
|
||||
<span class="n">time_ms</span> <span class="o">=</span> <span class="n">start_event</span><span class="o">.</span><span class="n">elapsed_time</span><span class="p">(</span><span class="n">end_event</span><span class="p">)</span> <span class="o">/</span> <span class="n">rep</span>
|
||||
<span class="k">return</span> <span class="n">time_ms</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<p>We can now benchmark our custom op for vectors of increasing sizes to get a sense of how it does</p>
|
||||
<div class="nbinput docutils container">
|
||||
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[15]:
|
||||
</pre></div>
|
||||
</div>
|
||||
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre>
|
||||
<span></span><span class="k">for</span> <span class="n">N</span> <span class="ow">in</span> <span class="p">[</span><span class="mi">2</span><span class="o">**</span><span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">17</span><span class="p">,</span> <span class="mi">26</span><span class="p">,</span> <span class="mi">1</span><span class="p">)]:</span>
|
||||
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">)</span>
|
||||
<span class="n">y</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">)</span>
|
||||
<span class="n">triton_ms</span> <span class="o">=</span> <span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">add</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">))</span>
|
||||
<span class="n">torch_ms</span> <span class="o">=</span> <span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span><span class="p">)</span>
|
||||
<span class="c1"># print the performance of triton and torch as well as the achieved bandwidth</span>
|
||||
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s1">'</span><span class="si">{</span><span class="n">N</span><span class="si">}</span><span class="s1"> </span><span class="si">{</span><span class="n">triton_ms</span><span class="si">:</span><span class="s1">.3f</span><span class="si">}</span><span class="s1"> </span><span class="si">{</span><span class="n">torch_ms</span><span class="si">:</span><span class="s1">.3f</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="nboutput nblast docutils container">
|
||||
<div class="prompt empty docutils container">
|
||||
</div>
|
||||
<div class="output_area docutils container">
|
||||
<div class="highlight"><pre>
|
||||
131072 0.020 0.003
|
||||
262144 0.019 0.004
|
||||
524288 0.016 0.016
|
||||
1048576 0.033 0.033
|
||||
2097152 0.071 0.070
|
||||
4194304 0.142 0.144
|
||||
8388608 0.287 0.286
|
||||
16777216 0.572 0.568
|
||||
33554432 1.139 1.110
|
||||
</pre></div></div>
|
||||
</div>
|
||||
<p>Our op is on par with Torch’s vectorized element-wise kernel once the vectors are large enough. One caveat is that PyTorch’s launch latency is much smaller for small vectors (3us vs 18-20us); this is something we are actively working to reduce.</p>
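<p>If we also want the achieved memory bandwidth rather than just the wall-clock time, the timings above can be converted with a small helper. The sketch below is a rough estimate that assumes each call reads the two inputs and writes the output, i.e. moves <span class="math notranslate nohighlight">\(3N\)</span> float32 elements:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre>
# rough estimate of the achieved bandwidth (GB/s) from a time measured in milliseconds;
# assumes 2 reads + 1 write of N float32 elements (4 bytes each)
def gbps(time_ms, N):
    return 3 * N * 4 * 1e-9 / (time_ms * 1e-3)

# e.g. for the largest size above: gbps(1.139, 2**25) is roughly 354 GB/s
</pre></div></div>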
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<footer>
|
||||
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
|
||||
<a href="02-fused-softmax.html" class="btn btn-neutral float-right" title="Fused Softmax" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
|
||||
<a href="../installation/from-source.html" class="btn btn-neutral float-left" title="From Source" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
|
||||
</div>
|
||||
|
||||
<hr/>
|
||||
|
||||
<div role="contentinfo">
|
||||
<p>
|
||||
© Copyright 2020, Philippe Tillet.
|
||||
|
||||
</p>
|
||||
</div>
|
||||
|
||||
|
||||
|
||||
Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
|
||||
|
||||
<a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
|
||||
|
||||
provided by <a href="https://readthedocs.org">Read the Docs</a>.
|
||||
|
||||
</footer>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</section>
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
<script type="text/javascript">
|
||||
jQuery(function () {
|
||||
SphinxRtdTheme.Navigation.enable(true);
|
||||
});
|
||||
</script>
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
</body>
|
||||
</html>
|
329
tutorials/01-vector-add.ipynb
Normal file
329
tutorials/01-vector-add.ipynb
Normal file
@@ -0,0 +1,329 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "acute-possession",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Vector Addition"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "median-malaysia",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"In this tutorial, we will see how to construct a simple, high-performance vector addition using Triton. You will learn:\n",
|
||||
"* The basic syntax of the Triton programming language\n",
|
||||
"* The best practices for creating PyTorch custom operators using the `triton.kernel` Python API\n",
|
||||
"* The best practices for validating and benchmarking custom ops against native reference implementations"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "identical-conditions",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Writing the Compute Kernel"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "collectible-belle",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Each compute kernel is declared using the `__global__` attribute, and executed many times in parallel on different chunks of data (See the [Single Program, Multiple Data](https://en.wikipedia.org/wiki/SPMD) programming model for more details).\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"```c\n",
|
||||
"__global__ void add(float* z, float* x, float* y, int N){\n",
|
||||
" // The `get_program_id(i)` returns the i-th coordinate\n",
|
||||
" // of the program in the overaching SPMD context\n",
|
||||
" // (a.k.a launch grid). This is what allows us to process\n",
|
||||
" // different chunks of data in parallel.\n",
|
||||
" // For those similar with CUDA, `get_program_id({0,1,2})`\n",
|
||||
" // is similar to blockIdx.{x,y,z}\n",
|
||||
" int pid = get_program_id(0);\n",
|
||||
" // In Triton, arrays are first-class citizen. In other words,\n",
|
||||
" // they are primitives data-types and are -- contrary to C and\n",
|
||||
" // CUDA -- not implemented as pointers to contiguous chunks of\n",
|
||||
" // memory.\n",
|
||||
" // In the few lines below, we create an array of `BLOCK` pointers\n",
|
||||
" // whose memory values are, e.g.:\n",
|
||||
" // [z + pid*BLOCK + 0, z + pid*BLOCK + 1, ..., z + pid*BLOCK + BLOCK - 1]\n",
|
||||
" // Note: here BLOCK is expected to be a pre-processor macro defined at compile-time\n",
|
||||
" int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;\n",
|
||||
" float* pz [BLOCK] = z + offset;\n",
|
||||
" float* px [BLOCK] = x + offset;\n",
|
||||
" float* py [BLOCK] = y + offset;\n",
|
||||
" // Simple element-wise control-flow for load/store operations can\n",
|
||||
" // be achieved using the the ternary operator `cond ? val_true : val_false`\n",
|
||||
" // or the conditional dereferencing operator `*?(cond)ptr\n",
|
||||
" // Here, we make sure that we do not access memory out-of-bounds when we\n",
|
||||
" // write-back `z`\n",
|
||||
" bool check[BLOCK] = offset < N;\n",
|
||||
" *?(check)pz = *?(check)px + *?(check)py;\n",
|
||||
"}\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"The existence of arrays as a primitive data-type for Triton comes with a number of advantages that are highlighted in the [MAPL'2019 Triton paper](http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "forbidden-wednesday",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Writing the Torch bindings"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "numerical-agency",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The only thing that matters when it comes to Triton and Torch is the `triton.kernel` class. This allows you to transform the above C-like function into a callable python object that can be used to modify `torch.tensor` objects.\n",
|
||||
"\n",
|
||||
"To create a `triton.kernel`, you only need three things:\n",
|
||||
"* `source: string`: the source-code of the kernel you want to create\n",
|
||||
"* `device: torch.device`: the device you want to compile this code for\n",
|
||||
"* `defines: dict`: the set of macros that you want the pre-processor to `#define` for you\n",
|
||||
"\n",
|
||||
"Note: The constructor of `triton.kernel` does some just-in-time compilation, so expect some overhead there. For this reason, I personally like to initialize kernels lazily in a cache (see `_kernels` variable below). This also makes it possible to choose the compilation device dynamically based on the type of the operator's inputs."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "sporting-keyboard",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"import triton\n",
|
||||
"\n",
|
||||
"# source-code for Triton compute kernel\n",
|
||||
"# here we just copy-paste the above code without the extensive comments.\n",
|
||||
"# you may prefer to store it in a .c file and load it from there instead.\n",
|
||||
"_src = \"\"\"\n",
|
||||
"__global__ void add(float* z, float* x, float* y, int N){\n",
|
||||
" // program id\n",
|
||||
" int pid = get_program_id(0);\n",
|
||||
" // create arrays of pointers\n",
|
||||
" int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;\n",
|
||||
" float* pz[BLOCK] = z + offset;\n",
|
||||
" float* px[BLOCK] = x + offset;\n",
|
||||
" float* py[BLOCK] = y + offset;\n",
|
||||
" // bounds checking\n",
|
||||
" bool check[BLOCK] = offset < N;\n",
|
||||
" // write-back\n",
|
||||
" *?(check)pz = *?(check)px + *?(check)py;\n",
|
||||
"}\n",
|
||||
" \"\"\"\n",
|
||||
"# This function returns a callable `triton.kernel` object\n",
|
||||
"# created from the above source code.\n",
|
||||
"# For portability, we maintain a cache of kernels for different `torch.device`\n",
|
||||
"# We compile the kernel with -DBLOCK=1024\n",
|
||||
"_kernels = dict()\n",
|
||||
"def make_add_kernel(device):\n",
|
||||
" if device not in _kernels:\n",
|
||||
" defines = {'BLOCK': 1024}\n",
|
||||
" _kernels[device] = triton.kernel(_src, device=device, defines=defines)\n",
|
||||
" return _kernels[device]\n",
|
||||
"\n",
|
||||
"# This is a standard torch custom autograd Function\n",
|
||||
"# The only difference is that we can now use the above kernel\n",
|
||||
"# in the `forward` and `backward` functions.`\n",
|
||||
"class _add(torch.autograd.Function):\n",
|
||||
" \n",
|
||||
" @staticmethod\n",
|
||||
" def forward(ctx, x, y):\n",
|
||||
" # constraints of the op\n",
|
||||
" assert x.dtype == torch.float32\n",
|
||||
" # *allocate output*\n",
|
||||
" z = torch.empty_like(x)\n",
|
||||
" # *create launch grid*:\n",
|
||||
" # this is a function which takes compilation parameters `opt`\n",
|
||||
" # as input and returns a tuple of int (i.e., launch grid) for the kernel.\n",
|
||||
" # triton.cdiv is a shortcut for ceil division:\n",
|
||||
" # triton.cdiv(a, b) = (a + b - 1) // b\n",
|
||||
" N = z.shape[0]\n",
|
||||
" grid = lambda opt: (triton.cdiv(N, opt.BLOCK), )\n",
|
||||
" # *launch kernel*:\n",
|
||||
" # pointer to the data of torch tensors can be retrieved with\n",
|
||||
" # the `.data_ptr()` method\n",
|
||||
" kernel = make_add_kernel(z.device)\n",
|
||||
" kernel(z.data_ptr(), x.data_ptr(), y.data_ptr(), N, grid = grid)\n",
|
||||
" return z\n",
|
||||
"# Just like we standard PyTorch ops\n",
|
||||
"# We use the `.apply` method to create a \n",
|
||||
"# callable object for our function\n",
|
||||
"add = _add.apply"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "separated-polyester",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"At this point `add(x, y)` is equivalent to `x + y` for contiguous tensors. Now let's test and benchmark it!"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "exclusive-salvation",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Writing a Unit Test"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "supported-ribbon",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor([1.3713, 1.3076, 0.4940, ..., 0.6682, 1.1984, 1.2696], device='cuda:0')\n",
|
||||
"tensor([1.3713, 1.3076, 0.4940, ..., 0.6682, 1.1984, 1.2696], device='cuda:0')\n",
|
||||
"The maximum difference between torch and triton is 0.0\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"torch.manual_seed(0)\n",
|
||||
"x = torch.rand(98432, device='cuda')\n",
|
||||
"y = torch.rand(98432, device='cuda')\n",
|
||||
"za = x + y\n",
|
||||
"zb = add(x, y)\n",
|
||||
"print(za)\n",
|
||||
"print(zb)\n",
|
||||
"print(f'The maximum difference between torch and triton is '\n",
|
||||
" f'{torch.max(torch.abs(za - zb))}')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "otherwise-canadian",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Seems to work!"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "polished-australia",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Writing a Benchmark"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "historic-glass",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The performance of our GPU code can be benchmark using the `torch.cuda.Event(enable_timing=True)` wrapper. Below is a simple function that benchmarks `rep` runs of our kernels after `warmup` \"cold\" runs."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"id": "strange-luxembourg",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# We now want to benchmark the performance of `add`\n",
|
||||
"# Against that of PyTorch for increasing vector sizes\n",
|
||||
"def do_bench(fn, warmup = 10, rep = 50):\n",
|
||||
" start_event = torch.cuda.Event(enable_timing=True)\n",
|
||||
" end_event = torch.cuda.Event(enable_timing=True)\n",
|
||||
" ret = fn()\n",
|
||||
" for i in range(warmup):\n",
|
||||
" fn()\n",
|
||||
" torch.cuda.synchronize()\n",
|
||||
" start_event.record()\n",
|
||||
" for i in range(rep):\n",
|
||||
" fn()\n",
|
||||
" end_event.record()\n",
|
||||
" torch.cuda.synchronize()\n",
|
||||
" time_ms = start_event.elapsed_time(end_event) / rep\n",
|
||||
" return time_ms"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "hairy-claim",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can now benchmark our custom op for vectors of increasing sizes to get a sense of how it does"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"id": "pleasant-valley",
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"131072 0.020 0.003\n",
|
||||
"262144 0.019 0.004\n",
|
||||
"524288 0.016 0.016\n",
|
||||
"1048576 0.033 0.033\n",
|
||||
"2097152 0.071 0.070\n",
|
||||
"4194304 0.142 0.144\n",
|
||||
"8388608 0.287 0.286\n",
|
||||
"16777216 0.572 0.568\n",
|
||||
"33554432 1.139 1.110\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"for N in [2**i for i in range(17, 26, 1)]:\n",
|
||||
" x = torch.rand(N, device='cuda')\n",
|
||||
" y = torch.rand(N, device='cuda')\n",
|
||||
" triton_ms = do_bench(lambda: add(x, y))\n",
|
||||
" torch_ms = do_bench(lambda: x + y)\n",
|
||||
" # print the performance of triton and torch as well as the achieved bandwidth\n",
|
||||
" print(f'{N} {triton_ms:.3f} {torch_ms:.3f}')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "juvenile-supplement",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Our op is on-par with Torch's vectorized element-wise kernel when the vectors are large enough. One caveat is that the latency of PyTorch is much smaller for small vectors (3us vs 18-20us). This is something we are actively working on to reduce."
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.7.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
705
tutorials/02-fused-softmax.html
Normal file
705
tutorials/02-fused-softmax.html
Normal file
@@ -0,0 +1,705 @@
|
||||
|
||||
|
||||
<!DOCTYPE html>
|
||||
<html class="writer-html5" lang="en" >
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
|
||||
<title>Fused Softmax — Triton documentation</title>
|
||||
|
||||
|
||||
|
||||
<link rel="stylesheet" href="../_static/css/theme.css" type="text/css" />
|
||||
<link rel="stylesheet" href="../_static/pygments.css" type="text/css" />
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<!--[if lt IE 9]>
|
||||
<script src="../_static/js/html5shiv.min.js"></script>
|
||||
<![endif]-->
|
||||
|
||||
|
||||
<script type="text/javascript" id="documentation_options" data-url_root="../" src="../_static/documentation_options.js"></script>
|
||||
<script src="../_static/jquery.js"></script>
|
||||
<script src="../_static/underscore.js"></script>
|
||||
<script src="../_static/doctools.js"></script>
|
||||
<script crossorigin="anonymous" integrity="sha256-Ae2Vz/4ePdIu6ZyI/5ZGsYnb+m0JlOmKPjt6XZ9JJkA=" src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js"></script>
|
||||
<script async="async" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
|
||||
<script type="text/x-mathjax-config">MathJax.Hub.Config({"tex2jax": {"inlineMath": [["$", "$"], ["\\(", "\\)"]], "processEscapes": true, "ignoreClass": "document", "processClass": "math|output_area"}})</script>
|
||||
|
||||
<script type="text/javascript" src="../_static/js/theme.js"></script>
|
||||
|
||||
|
||||
<link rel="index" title="Index" href="../genindex.html" />
|
||||
<link rel="search" title="Search" href="../search.html" />
|
||||
<link rel="prev" title="Vector Addition" href="01-vector-add.html" />
|
||||
</head>
|
||||
|
||||
<body class="wy-body-for-nav">
|
||||
|
||||
|
||||
<div class="wy-grid-for-nav">
|
||||
|
||||
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
|
||||
<div class="wy-side-scroll">
|
||||
<div class="wy-side-nav-search" >
|
||||
|
||||
|
||||
|
||||
<a href="../index.html" class="icon icon-home"> Triton
|
||||
|
||||
|
||||
|
||||
</a>
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<div role="search">
|
||||
<form id="rtd-search-form" class="wy-form" action="../search.html" method="get">
|
||||
<input type="text" name="q" placeholder="Search docs" />
|
||||
<input type="hidden" name="check_keywords" value="yes" />
|
||||
<input type="hidden" name="area" value="default" />
|
||||
</form>
|
||||
</div>
|
||||
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<p class="caption"><span class="caption-text">Installation Instructions</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../installation/packaged-binaries.html">Packaged Binaries</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../installation/from-source.html">From Source</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Tutorials</span></p>
|
||||
<ul class="current">
|
||||
<li class="toctree-l1"><a class="reference internal" href="01-vector-add.html">Vector Addition</a></li>
|
||||
<li class="toctree-l1 current"><a class="current reference internal" href="#">Fused Softmax</a><ul>
|
||||
<li class="toctree-l2"><a class="reference internal" href="#Writing-the-Compute-Kernel">Writing the Compute Kernel</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="#Writing-the-Torch-bindings">Writing the Torch bindings</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="#Writing-a-Unit-Test">Writing a Unit Test</a></li>
|
||||
<li class="toctree-l2"><a class="reference internal" href="#Writing-a-Benchmark">Writing a Benchmark</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
|
||||
|
||||
|
||||
</div>
|
||||
|
||||
</div>
|
||||
</nav>
|
||||
|
||||
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
|
||||
|
||||
|
||||
<nav class="wy-nav-top" aria-label="top navigation">
|
||||
|
||||
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
|
||||
<a href="../index.html">Triton</a>
|
||||
|
||||
</nav>
|
||||
|
||||
|
||||
<div class="wy-nav-content">
|
||||
|
||||
<div class="rst-content">
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
<div role="navigation" aria-label="breadcrumbs navigation">
|
||||
|
||||
<ul class="wy-breadcrumbs">
|
||||
|
||||
<li><a href="../index.html" class="icon icon-home"></a> »</li>
|
||||
|
||||
<li>Fused Softmax</li>
|
||||
|
||||
|
||||
<li class="wy-breadcrumbs-aside">
|
||||
|
||||
|
||||
<a href="../_sources/tutorials/02-fused-softmax.ipynb.txt" rel="nofollow"> View page source</a>
|
||||
|
||||
|
||||
</li>
|
||||
|
||||
</ul>
|
||||
|
||||
|
||||
<hr/>
|
||||
</div>
|
||||
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
|
||||
<div itemprop="articleBody">
|
||||
|
||||
|
||||
<style>
|
||||
/* CSS for nbsphinx extension */
|
||||
|
||||
/* remove conflicting styling from Sphinx themes */
|
||||
div.nbinput.container div.prompt *,
|
||||
div.nboutput.container div.prompt *,
|
||||
div.nbinput.container div.input_area pre,
|
||||
div.nboutput.container div.output_area pre,
|
||||
div.nbinput.container div.input_area .highlight,
|
||||
div.nboutput.container div.output_area .highlight {
|
||||
border: none;
|
||||
padding: 0;
|
||||
margin: 0;
|
||||
box-shadow: none;
|
||||
}
|
||||
|
||||
div.nbinput.container > div[class*=highlight],
|
||||
div.nboutput.container > div[class*=highlight] {
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
div.nbinput.container div.prompt *,
|
||||
div.nboutput.container div.prompt * {
|
||||
background: none;
|
||||
}
|
||||
|
||||
div.nboutput.container div.output_area .highlight,
|
||||
div.nboutput.container div.output_area pre {
|
||||
background: unset;
|
||||
}
|
||||
|
||||
div.nboutput.container div.output_area div.highlight {
|
||||
color: unset; /* override Pygments text color */
|
||||
}
|
||||
|
||||
/* avoid gaps between output lines */
|
||||
div.nboutput.container div[class*=highlight] pre {
|
||||
line-height: normal;
|
||||
}
|
||||
|
||||
/* input/output containers */
|
||||
div.nbinput.container,
|
||||
div.nboutput.container {
|
||||
display: -webkit-flex;
|
||||
display: flex;
|
||||
align-items: flex-start;
|
||||
margin: 0;
|
||||
width: 100%;
|
||||
}
|
||||
@media (max-width: 540px) {
|
||||
div.nbinput.container,
|
||||
div.nboutput.container {
|
||||
flex-direction: column;
|
||||
}
|
||||
}
|
||||
|
||||
/* input container */
|
||||
div.nbinput.container {
|
||||
padding-top: 5px;
|
||||
}
|
||||
|
||||
/* last container */
|
||||
div.nblast.container {
|
||||
padding-bottom: 5px;
|
||||
}
|
||||
|
||||
/* input prompt */
|
||||
div.nbinput.container div.prompt pre {
|
||||
color: #307FC1;
|
||||
}
|
||||
|
||||
/* output prompt */
|
||||
div.nboutput.container div.prompt pre {
|
||||
color: #BF5B3D;
|
||||
}
|
||||
|
||||
/* all prompts */
|
||||
div.nbinput.container div.prompt,
|
||||
div.nboutput.container div.prompt {
|
||||
width: 4.5ex;
|
||||
padding-top: 5px;
|
||||
position: relative;
|
||||
user-select: none;
|
||||
}
|
||||
|
||||
div.nbinput.container div.prompt > div,
|
||||
div.nboutput.container div.prompt > div {
|
||||
position: absolute;
|
||||
right: 0;
|
||||
margin-right: 0.3ex;
|
||||
}
|
||||
|
||||
@media (max-width: 540px) {
|
||||
div.nbinput.container div.prompt,
|
||||
div.nboutput.container div.prompt {
|
||||
width: unset;
|
||||
text-align: left;
|
||||
padding: 0.4em;
|
||||
}
|
||||
div.nboutput.container div.prompt.empty {
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
div.nbinput.container div.prompt > div,
|
||||
div.nboutput.container div.prompt > div {
|
||||
position: unset;
|
||||
}
|
||||
}
|
||||
|
||||
/* disable scrollbars on prompts */
|
||||
div.nbinput.container div.prompt pre,
|
||||
div.nboutput.container div.prompt pre {
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
/* input/output area */
|
||||
div.nbinput.container div.input_area,
|
||||
div.nboutput.container div.output_area {
|
||||
-webkit-flex: 1;
|
||||
flex: 1;
|
||||
overflow: auto;
|
||||
}
|
||||
@media (max-width: 540px) {
|
||||
div.nbinput.container div.input_area,
|
||||
div.nboutput.container div.output_area {
|
||||
width: 100%;
|
||||
}
|
||||
}
|
||||
|
||||
/* input area */
|
||||
div.nbinput.container div.input_area {
|
||||
border: 1px solid #e0e0e0;
|
||||
border-radius: 2px;
|
||||
/*background: #f5f5f5;*/
|
||||
}
|
||||
|
||||
/* override MathJax center alignment in output cells */
|
||||
div.nboutput.container div[class*=MathJax] {
|
||||
text-align: left !important;
|
||||
}
|
||||
|
||||
/* override sphinx.ext.imgmath center alignment in output cells */
|
||||
div.nboutput.container div.math p {
|
||||
text-align: left;
|
||||
}
|
||||
|
||||
/* standard error */
|
||||
div.nboutput.container div.output_area.stderr {
|
||||
background: #fdd;
|
||||
}
|
||||
|
||||
/* ANSI colors */
|
||||
.ansi-black-fg { color: #3E424D; }
|
||||
.ansi-black-bg { background-color: #3E424D; }
|
||||
.ansi-black-intense-fg { color: #282C36; }
|
||||
.ansi-black-intense-bg { background-color: #282C36; }
|
||||
.ansi-red-fg { color: #E75C58; }
|
||||
.ansi-red-bg { background-color: #E75C58; }
|
||||
.ansi-red-intense-fg { color: #B22B31; }
|
||||
.ansi-red-intense-bg { background-color: #B22B31; }
|
||||
.ansi-green-fg { color: #00A250; }
|
||||
.ansi-green-bg { background-color: #00A250; }
|
||||
.ansi-green-intense-fg { color: #007427; }
|
||||
.ansi-green-intense-bg { background-color: #007427; }
|
||||
.ansi-yellow-fg { color: #DDB62B; }
|
||||
.ansi-yellow-bg { background-color: #DDB62B; }
|
||||
.ansi-yellow-intense-fg { color: #B27D12; }
|
||||
.ansi-yellow-intense-bg { background-color: #B27D12; }
|
||||
.ansi-blue-fg { color: #208FFB; }
|
||||
.ansi-blue-bg { background-color: #208FFB; }
|
||||
.ansi-blue-intense-fg { color: #0065CA; }
|
||||
.ansi-blue-intense-bg { background-color: #0065CA; }
|
||||
.ansi-magenta-fg { color: #D160C4; }
|
||||
.ansi-magenta-bg { background-color: #D160C4; }
|
||||
.ansi-magenta-intense-fg { color: #A03196; }
|
||||
.ansi-magenta-intense-bg { background-color: #A03196; }
|
||||
.ansi-cyan-fg { color: #60C6C8; }
|
||||
.ansi-cyan-bg { background-color: #60C6C8; }
|
||||
.ansi-cyan-intense-fg { color: #258F8F; }
|
||||
.ansi-cyan-intense-bg { background-color: #258F8F; }
|
||||
.ansi-white-fg { color: #C5C1B4; }
|
||||
.ansi-white-bg { background-color: #C5C1B4; }
|
||||
.ansi-white-intense-fg { color: #A1A6B2; }
|
||||
.ansi-white-intense-bg { background-color: #A1A6B2; }
|
||||
|
||||
.ansi-default-inverse-fg { color: #FFFFFF; }
|
||||
.ansi-default-inverse-bg { background-color: #000000; }
|
||||
|
||||
.ansi-bold { font-weight: bold; }
|
||||
.ansi-underline { text-decoration: underline; }
|
||||
|
||||
|
||||
div.nbinput.container div.input_area div[class*=highlight] > pre,
|
||||
div.nboutput.container div.output_area div[class*=highlight] > pre,
|
||||
div.nboutput.container div.output_area div[class*=highlight].math,
|
||||
div.nboutput.container div.output_area.rendered_html,
|
||||
div.nboutput.container div.output_area > div.output_javascript,
|
||||
div.nboutput.container div.output_area:not(.rendered_html) > img{
|
||||
padding: 5px;
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
/* fix copybtn overflow problem in chromium (needed for 'sphinx_copybutton') */
|
||||
div.nbinput.container div.input_area > div[class^='highlight'],
|
||||
div.nboutput.container div.output_area > div[class^='highlight']{
|
||||
overflow-y: hidden;
|
||||
}
|
||||
|
||||
/* hide copybtn icon on prompts (needed for 'sphinx_copybutton') */
|
||||
.prompt a.copybtn {
|
||||
display: none;
|
||||
}
|
||||
|
||||
/* Some additional styling taken form the Jupyter notebook CSS */
|
||||
div.rendered_html table {
|
||||
border: none;
|
||||
border-collapse: collapse;
|
||||
border-spacing: 0;
|
||||
color: black;
|
||||
font-size: 12px;
|
||||
table-layout: fixed;
|
||||
}
|
||||
div.rendered_html thead {
|
||||
border-bottom: 1px solid black;
|
||||
vertical-align: bottom;
|
||||
}
|
||||
div.rendered_html tr,
|
||||
div.rendered_html th,
|
||||
div.rendered_html td {
|
||||
text-align: right;
|
||||
vertical-align: middle;
|
||||
padding: 0.5em 0.5em;
|
||||
line-height: normal;
|
||||
white-space: normal;
|
||||
max-width: none;
|
||||
border: none;
|
||||
}
|
||||
div.rendered_html th {
|
||||
font-weight: bold;
|
||||
}
|
||||
div.rendered_html tbody tr:nth-child(odd) {
|
||||
background: #f5f5f5;
|
||||
}
|
||||
div.rendered_html tbody tr:hover {
|
||||
background: rgba(66, 165, 245, 0.2);
|
||||
}
|
||||
|
||||
/* CSS overrides for sphinx_rtd_theme */
|
||||
|
||||
/* 24px margin */
|
||||
.nbinput.nblast.container,
|
||||
.nboutput.nblast.container {
|
||||
margin-bottom: 19px; /* padding has already 5px */
|
||||
}
|
||||
|
||||
/* ... except between code cells! */
|
||||
.nblast.container + .nbinput.container {
|
||||
margin-top: -19px;
|
||||
}
|
||||
|
||||
.admonition > p:before {
|
||||
margin-right: 4px; /* make room for the exclamation icon */
|
||||
}
|
||||
|
||||
/* Fix math alignment, see https://github.com/rtfd/sphinx_rtd_theme/pull/686 */
|
||||
.math {
|
||||
text-align: unset;
|
||||
}
|
||||
</style>
|
||||
<div class="section" id="Fused-Softmax">
|
||||
<h1>Fused Softmax<a class="headerlink" href="#Fused-Softmax" title="Permalink to this headline">¶</a></h1>
|
||||
<p>Custom GPU kernels for elementwise additions are educationally valuable but won’t get you very far in practice. Let us consider instead the case of a simple (numerically stabilized) softmax operation:</p>
|
||||
<div class="nbinput nblast docutils container">
|
||||
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[1]:
|
||||
</pre></div>
|
||||
</div>
|
||||
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre>
|
||||
<span></span><span class="kn">import</span> <span class="nn">torch</span>
|
||||
|
||||
<span class="c1"># Compute the row-wise softmax of x \in R^{M \times N}</span>
|
||||
<span class="k">def</span> <span class="nf">naive_softmax</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
|
||||
<span class="c1"># read MN elements ; write M elements</span>
|
||||
<span class="n">x_max</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
|
||||
<span class="c1"># read 2MN elements ; write MN elements</span>
|
||||
<span class="n">z</span> <span class="o">=</span> <span class="n">x</span> <span class="o">-</span> <span class="n">x_max</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span>
|
||||
<span class="c1"># read MN elements ; write MN elements</span>
|
||||
<span class="n">numerator</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
|
||||
<span class="c1"># read MN elements ; write M elements</span>
|
||||
<span class="n">denominator</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">numerator</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
|
||||
<span class="c1"># read 2MN elements ; write MN elements</span>
|
||||
<span class="n">ret</span> <span class="o">=</span> <span class="n">numerator</span> <span class="o">/</span> <span class="n">denominator</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span>
|
||||
<span class="c1"># in total: read 7MN elements ; wrote 3MN + 2M elements</span>
|
||||
<span class="k">return</span> <span class="n">ret</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<p>When implemented naively in PyTorch, computing <span class="math notranslate nohighlight">\(y\)</span> requires reading <span class="math notranslate nohighlight">\(7MN\)</span> elements from DRAM and writing back <span class="math notranslate nohighlight">\(3MN + 2M\)</span> elements.</p>
|
||||
<p>Instead, we want to write a custom “fused” PyTorch operator that only reads X once and does all the necessary computations on-chip. This would require reading and writing back only <span class="math notranslate nohighlight">\(MN\)</span> elements each way, i.e. <span class="math notranslate nohighlight">\(2MN\)</span> elements in total, so we could expect a theoretical speed-up of roughly <span class="math notranslate nohighlight">\((10MN + 2M)/(2MN) \approx 5\times\)</span>. In practice, though, we expect less because our kernel will spend some time computing exponentials and moving data around in shared memory.</p>
|
||||
<div class="section" id="Writing-the-Compute-Kernel">
|
||||
<h2>Writing the Compute Kernel<a class="headerlink" href="#Writing-the-Compute-Kernel" title="Permalink to this headline">¶</a></h2>
|
||||
<p>Our softmax kernel works as follows: each program loads a row of X and writes back a normalized row of Y. Note that one important limitation of Triton is that each block must have a power-of-two number of elements, which means that we need to guard the memory operations properly if we want to handle any possible input shapes:</p>
|
||||
<div class="highlight-c notranslate"><div class="highlight"><pre><span></span><span class="n">__global__</span> <span class="kt">void</span> <span class="n">softmax</span><span class="p">(</span><span class="kt">float</span><span class="o">*</span> <span class="n">Y</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">X</span><span class="p">,</span> <span class="kt">int</span> <span class="n">stride_xm</span><span class="p">,</span> <span class="kt">int</span> <span class="n">stride_ym</span><span class="p">,</span> <span class="kt">int</span> <span class="n">M</span><span class="p">,</span> <span class="kt">int</span> <span class="n">N</span><span class="p">){</span>
|
||||
<span class="c1">// row index</span>
|
||||
<span class="kt">int</span> <span class="n">m</span> <span class="o">=</span> <span class="n">get_program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">);</span>
|
||||
<span class="c1">// column indices</span>
|
||||
<span class="kt">int</span> <span class="n">n</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span> <span class="p">...</span> <span class="n">BLOCK</span><span class="p">;</span>
|
||||
<span class="c1">// the memory address of all the elements</span>
|
||||
<span class="c1">// that we want to load can be computed as follows</span>
|
||||
<span class="kt">float</span><span class="o">*</span> <span class="n">px</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">X</span> <span class="o">+</span> <span class="n">m</span><span class="o">*</span><span class="n">stride_xm</span> <span class="o">+</span> <span class="n">n</span><span class="p">;</span>
|
||||
<span class="c1">// because BLOCK has to be a power of two</span>
|
||||
<span class="c1">// (per Triton-C specs), it is important</span>
|
||||
<span class="c1">// to guard each memory operation with predicates</span>
|
||||
<span class="c1">// or we will read out of bounds</span>
|
||||
<span class="kt">bool</span> <span class="n">check</span><span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">n</span> <span class="o"><</span> <span class="n">N</span><span class="p">;</span>
|
||||
<span class="kt">float</span> <span class="n">x</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">check</span> <span class="o">?</span> <span class="o">*</span><span class="nl">px</span> <span class="p">:</span> <span class="o">-</span><span class="n">F32_INFINITY</span><span class="p">;</span>
|
||||
<span class="c1">// syntax for reduction in Triton is:</span>
|
||||
<span class="c1">// x[..., OPERATOR, ...]</span>
|
||||
<span class="c1">// ^</span>
|
||||
<span class="c1">// index</span>
|
||||
<span class="c1">// The operators currently supported are {min, max, +}</span>
|
||||
<span class="kt">float</span> <span class="n">z</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">x</span> <span class="o">-</span> <span class="n">x</span><span class="p">[</span><span class="n">max</span><span class="p">];</span>
|
||||
<span class="c1">// The exponential in Triton is fast but approximate</span>
|
||||
<span class="c1">// (i.e., like __expf in CUDA)</span>
|
||||
<span class="kt">float</span> <span class="n">num</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">exp</span><span class="p">(</span><span class="n">z</span><span class="p">);</span>
|
||||
<span class="kt">float</span> <span class="n">denom</span> <span class="o">=</span> <span class="n">num</span><span class="p">[</span><span class="o">+</span><span class="p">];</span>
|
||||
<span class="c1">// The result of the reduction is now stored in y</span>
|
||||
<span class="kt">float</span> <span class="n">y</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">num</span> <span class="o">/</span> <span class="n">denom</span><span class="p">;</span>
|
||||
<span class="c1">// We write it back</span>
|
||||
<span class="kt">float</span><span class="o">*</span> <span class="n">py</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">Y</span> <span class="o">+</span> <span class="n">m</span><span class="o">*</span><span class="n">stride_ym</span> <span class="o">+</span> <span class="n">n</span><span class="p">;</span>
|
||||
<span class="o">*?</span><span class="p">(</span><span class="n">check</span><span class="p">)</span><span class="n">py</span> <span class="o">=</span> <span class="n">y</span><span class="p">;</span>
|
||||
<span class="p">}</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="section" id="Writing-the-Torch-bindings">
|
||||
<h2>Writing the Torch bindings<a class="headerlink" href="#Writing-the-Torch-bindings" title="Permalink to this headline">¶</a></h2>
|
||||
<div class="nbinput nblast docutils container">
|
||||
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[2]:
|
||||
</pre></div>
|
||||
</div>
|
||||
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre>
|
||||
<span></span><span class="kn">import</span> <span class="nn">torch</span>
|
||||
<span class="kn">import</span> <span class="nn">triton</span>
|
||||
|
||||
<span class="c1"># source-code for Triton compute kernel</span>
|
||||
<span class="n">_src</span> <span class="o">=</span> <span class="s2">"""</span>
|
||||
<span class="s2">__global__ void softmax(float* Y, float* X, int stride_ym, int stride_xm, int M, int N){</span>
|
||||
<span class="s2"> int m = get_program_id(0);</span>
|
||||
<span class="s2"> int n [BLOCK] = 0 ... BLOCK;</span>
|
||||
<span class="s2"> float* px [BLOCK] = X + m*stride_xm + n;</span>
|
||||
<span class="s2"> bool check[BLOCK] = n < N;</span>
|
||||
<span class="s2"> float x [BLOCK] = check ? *px : -F32_INFINITY;</span>
|
||||
<span class="s2"> float z [BLOCK] = x - x[max];</span>
|
||||
<span class="s2"> float num [BLOCK] = exp(z);</span>
|
||||
<span class="s2"> float denom = num[+];</span>
|
||||
<span class="s2"> float y [BLOCK] = num / denom;</span>
|
||||
<span class="s2"> float* py [BLOCK] = Y + m*stride_ym + n;</span>
|
||||
<span class="s2"> *?(check)py = y;</span>
|
||||
<span class="s2">}</span>
|
||||
<span class="s2">"""</span>
|
||||
|
||||
<span class="c1"># We need to make sure that BLOCK is the smallest power of two</span>
|
||||
<span class="c1"># greater than the number of rows N of the input matrix.</span>
|
||||
<span class="c1"># Different values of BLOCK will result in different kernels</span>
|
||||
<span class="k">def</span> <span class="nf">next_power_of_2</span><span class="p">(</span><span class="n">n</span><span class="p">):</span>
|
||||
<span class="n">n</span> <span class="o">-=</span> <span class="mi">1</span>
|
||||
<span class="n">n</span> <span class="o">|=</span> <span class="n">n</span> <span class="o">>></span> <span class="mi">1</span>
|
||||
<span class="n">n</span> <span class="o">|=</span> <span class="n">n</span> <span class="o">>></span> <span class="mi">2</span>
|
||||
<span class="n">n</span> <span class="o">|=</span> <span class="n">n</span> <span class="o">>></span> <span class="mi">4</span>
|
||||
<span class="n">n</span> <span class="o">|=</span> <span class="n">n</span> <span class="o">>></span> <span class="mi">8</span>
|
||||
<span class="n">n</span> <span class="o">|=</span> <span class="n">n</span> <span class="o">>></span> <span class="mi">16</span>
|
||||
<span class="n">n</span> <span class="o">+=</span> <span class="mi">1</span>
|
||||
<span class="k">return</span> <span class="n">n</span>
|
||||
|
||||
<span class="n">_kernels</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span>
|
||||
<span class="k">def</span> <span class="nf">make_kernel</span><span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">device</span><span class="p">):</span>
|
||||
<span class="n">BLOCK</span> <span class="o">=</span> <span class="n">next_power_of_2</span><span class="p">(</span><span class="n">N</span><span class="p">)</span>
|
||||
<span class="n">key</span> <span class="o">=</span> <span class="p">(</span><span class="n">BLOCK</span><span class="p">,</span> <span class="n">device</span><span class="p">)</span>
|
||||
<span class="k">if</span> <span class="n">key</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">_kernels</span><span class="p">:</span>
|
||||
<span class="n">defines</span> <span class="o">=</span> <span class="p">{</span><span class="s1">'BLOCK'</span><span class="p">:</span> <span class="n">BLOCK</span><span class="p">}</span>
|
||||
<span class="n">_kernels</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">kernel</span><span class="p">(</span><span class="n">_src</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">defines</span><span class="o">=</span><span class="n">defines</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="n">_kernels</span><span class="p">[</span><span class="n">key</span><span class="p">]</span>
|
||||
|
||||
<span class="k">class</span> <span class="nc">_softmax</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">autograd</span><span class="o">.</span><span class="n">Function</span><span class="p">):</span>
|
||||
|
||||
<span class="nd">@staticmethod</span>
|
||||
<span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
|
||||
<span class="c1"># constraints of the op</span>
|
||||
<span class="k">assert</span> <span class="n">x</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span>
|
||||
<span class="n">y</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
|
||||
<span class="c1"># *create launch grid*:</span>
|
||||
<span class="c1"># here we just launch a grid of M programs</span>
|
||||
<span class="n">M</span><span class="p">,</span> <span class="n">N</span> <span class="o">=</span> <span class="n">y</span><span class="o">.</span><span class="n">shape</span>
|
||||
<span class="n">grid</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">opt</span><span class="p">:</span> <span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="p">)</span>
|
||||
<span class="c1"># *launch kernel*:</span>
|
||||
<span class="n">kernel</span> <span class="o">=</span> <span class="n">make_kernel</span><span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">y</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
||||
<span class="n">kernel</span><span class="p">(</span><span class="n">y</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">(),</span> <span class="n">x</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">(),</span> <span class="n">y</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">x</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">grid</span> <span class="o">=</span> <span class="n">grid</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="n">y</span>
|
||||
|
||||
<span class="n">softmax</span> <span class="o">=</span> <span class="n">_softmax</span><span class="o">.</span><span class="n">apply</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
</div>
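<p>The bit-twiddling in next_power_of_2 rounds N up to the smallest power of two greater than or equal to it. For positive integers, the shorter expression below computes the same value and can serve as a readability check (it is not used by the bindings themselves):</p>
<div class="highlight-python notranslate"><div class="highlight"><pre>
def next_power_of_2_alt(n):
    # (n - 1).bit_length() is the number of bits needed to represent n - 1,
    # so shifting 1 by that amount rounds n up to a power of two
    return 1 << (n - 1).bit_length()

assert next_power_of_2_alt(781) == 1024   # e.g. the column count used in the unit test below
assert next_power_of_2_alt(1024) == 1024  # powers of two are left unchanged
</pre></div></div>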
|
||||
</div>
|
||||
<div class="section" id="Writing-a-Unit-Test">
|
||||
<h2>Writing a Unit Test<a class="headerlink" href="#Writing-a-Unit-Test" title="Permalink to this headline">¶</a></h2>
|
||||
<div class="nbinput docutils container">
|
||||
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[3]:
|
||||
</pre></div>
|
||||
</div>
|
||||
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre>
|
||||
<span></span><span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">1823</span><span class="p">,</span> <span class="mi">781</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">)</span>
|
||||
<span class="n">y_tri</span> <span class="o">=</span> <span class="n">softmax</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
|
||||
<span class="n">y_ref</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
|
||||
<span class="nb">print</span><span class="p">(</span><span class="n">y_tri</span><span class="p">)</span>
|
||||
<span class="nb">print</span><span class="p">(</span><span class="n">y_ref</span><span class="p">)</span>
|
||||
<span class="nb">print</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">allclose</span><span class="p">(</span><span class="n">y_tri</span><span class="p">,</span> <span class="n">y_ref</span><span class="p">))</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="nboutput nblast docutils container">
|
||||
<div class="prompt empty docutils container">
|
||||
</div>
|
||||
<div class="output_area docutils container">
|
||||
<div class="highlight"><pre>
|
||||
tensor([[0.0004, 0.0006, 0.0004, ..., 0.0005, 0.0004, 0.0010],
|
||||
[0.0003, 0.0029, 0.0004, ..., 0.0007, 0.0017, 0.0004],
|
||||
[0.0002, 0.0006, 0.0005, ..., 0.0028, 0.0009, 0.0003],
|
||||
...,
|
||||
[0.0017, 0.0005, 0.0010, ..., 0.0006, 0.0004, 0.0001],
|
||||
[0.0010, 0.0006, 0.0001, ..., 0.0006, 0.0017, 0.0014],
|
||||
[0.0037, 0.0012, 0.0006, ..., 0.0003, 0.0005, 0.0003]],
|
||||
device='cuda:0')
|
||||
tensor([[0.0004, 0.0006, 0.0004, ..., 0.0005, 0.0004, 0.0010],
|
||||
[0.0003, 0.0029, 0.0004, ..., 0.0007, 0.0017, 0.0004],
|
||||
[0.0002, 0.0006, 0.0005, ..., 0.0028, 0.0009, 0.0003],
|
||||
...,
|
||||
[0.0017, 0.0005, 0.0010, ..., 0.0006, 0.0004, 0.0001],
|
||||
[0.0010, 0.0006, 0.0001, ..., 0.0006, 0.0017, 0.0014],
|
||||
[0.0037, 0.0012, 0.0006, ..., 0.0003, 0.0005, 0.0003]],
|
||||
device='cuda:0')
|
||||
True
|
||||
</pre></div></div>
|
||||
</div>
|
||||
<p>Seems to work!</p>
|
||||
</div>
|
||||
<div class="section" id="Writing-a-Benchmark">
|
||||
<h2>Writing a Benchmark<a class="headerlink" href="#Writing-a-Benchmark" title="Permalink to this headline">¶</a></h2>
|
||||
<div class="nbinput docutils container">
|
||||
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[4]:
|
||||
</pre></div>
|
||||
</div>
|
||||
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre>
|
||||
<span></span><span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span>
|
||||
|
||||
<span class="n">M</span> <span class="o">=</span> <span class="mi">4096</span>
|
||||
<span class="n">Ns</span> <span class="o">=</span> <span class="p">[</span><span class="mi">128</span><span class="o">*</span><span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">50</span><span class="p">)]</span>
|
||||
<span class="n">tri_ms</span> <span class="o">=</span> <span class="p">[]</span>
|
||||
<span class="n">ref_ms</span> <span class="o">=</span> <span class="p">[]</span>
|
||||
<span class="n">def_ms</span> <span class="o">=</span> <span class="p">[]</span>
|
||||
<span class="k">for</span> <span class="n">N</span> <span class="ow">in</span> <span class="n">Ns</span><span class="p">:</span>
|
||||
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
|
||||
<span class="n">gbps</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">ms</span><span class="p">:</span> <span class="n">x</span><span class="o">.</span><span class="n">nelement</span><span class="p">()</span> <span class="o">*</span> <span class="n">x</span><span class="o">.</span><span class="n">element_size</span><span class="p">()</span> <span class="o">*</span> <span class="mf">1e-9</span> <span class="o">/</span> <span class="p">(</span><span class="n">ms</span> <span class="o">*</span> <span class="mf">1e-3</span><span class="p">)</span>
|
||||
<span class="n">tri_ms</span> <span class="o">+=</span> <span class="p">[</span><span class="n">gbps</span><span class="p">(</span><span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">softmax</span><span class="p">(</span><span class="n">x</span><span class="p">)))]</span>
|
||||
<span class="n">ref_ms</span> <span class="o">+=</span> <span class="p">[</span><span class="n">gbps</span><span class="p">(</span><span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)))]</span>
|
||||
<span class="n">def_ms</span> <span class="o">+=</span> <span class="p">[</span><span class="n">gbps</span><span class="p">(</span><span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">naive_softmax</span><span class="p">(</span><span class="n">x</span><span class="p">)))]</span>
|
||||
<span class="n">plt</span><span class="o">.</span><span class="n">xlabel</span><span class="p">(</span><span class="s1">'N'</span><span class="p">)</span>
|
||||
<span class="n">plt</span><span class="o">.</span><span class="n">ylabel</span><span class="p">(</span><span class="s1">'Bandwidth (GB/s)'</span><span class="p">)</span>
|
||||
<span class="n">plt</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">Ns</span><span class="p">,</span> <span class="n">tri_ms</span><span class="p">,</span> <span class="n">label</span> <span class="o">=</span> <span class="s1">'Triton'</span><span class="p">)</span>
|
||||
<span class="n">plt</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">Ns</span><span class="p">,</span> <span class="n">ref_ms</span><span class="p">,</span> <span class="n">label</span> <span class="o">=</span> <span class="s1">'Torch'</span><span class="p">)</span>
|
||||
<span class="n">plt</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">Ns</span><span class="p">,</span> <span class="n">def_ms</span><span class="p">,</span> <span class="n">label</span> <span class="o">=</span> <span class="s1">'Naive'</span><span class="p">)</span>
|
||||
<span class="n">plt</span><span class="o">.</span><span class="n">legend</span><span class="p">()</span>
|
||||
<span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="nboutput nblast docutils container">
|
||||
<div class="prompt empty docutils container">
|
||||
</div>
|
||||
<div class="output_area docutils container">
|
||||
<img alt="../_images/tutorials_02-fused-softmax_12_0.png" src="../_images/tutorials_02-fused-softmax_12_0.png" />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<footer>
|
||||
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
|
||||
<a href="01-vector-add.html" class="btn btn-neutral float-left" title="Vector Addition" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
|
||||
</div>
|
||||
|
||||
<hr/>
|
||||
|
||||
<div role="contentinfo">
|
||||
<p>
|
||||
© Copyright 2020, Philippe Tillet.
|
||||
|
||||
</p>
|
||||
</div>
|
||||
|
||||
|
||||
|
||||
Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
|
||||
|
||||
<a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
|
||||
|
||||
provided by <a href="https://readthedocs.org">Read the Docs</a>.
|
||||
|
||||
</footer>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</section>
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
<script type="text/javascript">
|
||||
jQuery(function () {
|
||||
SphinxRtdTheme.Navigation.enable(true);
|
||||
});
|
||||
</script>
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
</body>
|
||||
</html>
|
308
tutorials/02-fused-softmax.ipynb
Normal file
308
tutorials/02-fused-softmax.ipynb
Normal file
File diff suppressed because one or more lines are too long