[GH-PAGES] Added missing files

This commit is contained in:
Philippe Tillet
2021-03-06 03:09:03 -05:00
parent 8b18c19875
commit 61cf3ddd96
33 changed files with 17543 additions and 0 deletions

4 .buildinfo Normal file

@@ -0,0 +1,4 @@
# Sphinx build info version 1
# This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done.
config: 76c4cbf22d8ff0aa13ac3f4683f2586c
tags: 645f666f9bcd5a90fca523b33c5a78b7

3 .gitignore vendored Normal file

@@ -0,0 +1,3 @@
.vscode
./docs/_build/*
./build/*

Binary file not shown.

Binary image file added (18 KiB)

@@ -0,0 +1,329 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "acute-possession",
"metadata": {},
"source": [
"# Vector Addition"
]
},
{
"cell_type": "markdown",
"id": "median-malaysia",
"metadata": {},
"source": [
"In this tutorial, we will see how to construct a simple, high-performance vector addition using Triton. You will learn:\n",
"* The basic syntax of the Triton programming language\n",
"* The best practices for creating PyTorch custom operators using the `triton.kernel` Python API\n",
"* The best practices for validating and benchmarking custom ops against native reference implementations"
]
},
{
"cell_type": "markdown",
"id": "identical-conditions",
"metadata": {},
"source": [
"## Writing the Compute Kernel"
]
},
{
"cell_type": "markdown",
"id": "collectible-belle",
"metadata": {},
"source": [
"Each compute kernel is declared using the `__global__` attribute, and executed many times in parallel on different chunks of data (See the [Single Program, Multiple Data](https://en.wikipedia.org/wiki/SPMD) programming model for more details).\n",
"\n",
"\n",
"```c\n",
"__global__ void add(float* z, float* x, float* y, int N){\n",
" // The `get_program_id(i)` returns the i-th coordinate\n",
" // of the program in the overaching SPMD context\n",
" // (a.k.a launch grid). This is what allows us to process\n",
" // different chunks of data in parallel.\n",
" // For those similar with CUDA, `get_program_id({0,1,2})`\n",
" // is similar to blockIdx.{x,y,z}\n",
" int pid = get_program_id(0);\n",
" // In Triton, arrays are first-class citizen. In other words,\n",
" // they are primitives data-types and are -- contrary to C and\n",
" // CUDA -- not implemented as pointers to contiguous chunks of\n",
" // memory.\n",
" // In the few lines below, we create an array of `BLOCK` pointers\n",
" // whose memory values are, e.g.:\n",
" // [z + pid*BLOCK + 0, z + pid*BLOCK + 1, ..., z + pid*BLOCK + BLOCK - 1]\n",
" // Note: here BLOCK is expected to be a pre-processor macro defined at compile-time\n",
" int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;\n",
" float* pz [BLOCK] = z + offset;\n",
" float* px [BLOCK] = x + offset;\n",
" float* py [BLOCK] = y + offset;\n",
" // Simple element-wise control-flow for load/store operations can\n",
" // be achieved using the the ternary operator `cond ? val_true : val_false`\n",
" // or the conditional dereferencing operator `*?(cond)ptr\n",
" // Here, we make sure that we do not access memory out-of-bounds when we\n",
" // write-back `z`\n",
" bool check[BLOCK] = offset < N;\n",
" *?(check)pz = *?(check)px + *?(check)py;\n",
"}\n",
"```\n",
"\n",
"The existence of arrays as a primitive data-type for Triton comes with a number of advantages that are highlighted in the [MAPL'2019 Triton paper](http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf)."
]
},
{
"cell_type": "markdown",
"id": "forbidden-wednesday",
"metadata": {},
"source": [
"## Writing the Torch bindings"
]
},
{
"cell_type": "markdown",
"id": "numerical-agency",
"metadata": {},
"source": [
"The only thing that matters when it comes to Triton and Torch is the `triton.kernel` class. This allows you to transform the above C-like function into a callable python object that can be used to modify `torch.tensor` objects.\n",
"\n",
"To create a `triton.kernel`, you only need three things:\n",
"* `source: string`: the source-code of the kernel you want to create\n",
"* `device: torch.device`: the device you want to compile this code for\n",
"* `defines: dict`: the set of macros that you want the pre-processor to `#define` for you\n",
"\n",
"Note: The constructor of `triton.kernel` does some just-in-time compilation, so expect some overhead there. For this reason, I personally like to initialize kernels lazily in a cache (see `_kernels` variable below). This also makes it possible to choose the compilation device dynamically based on the type of the operator's inputs."
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "sporting-keyboard",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import triton\n",
"\n",
"# source-code for Triton compute kernel\n",
"# here we just copy-paste the above code without the extensive comments.\n",
"# you may prefer to store it in a .c file and load it from there instead.\n",
"_src = \"\"\"\n",
"__global__ void add(float* z, float* x, float* y, int N){\n",
" // program id\n",
" int pid = get_program_id(0);\n",
" // create arrays of pointers\n",
" int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;\n",
" float* pz[BLOCK] = z + offset;\n",
" float* px[BLOCK] = x + offset;\n",
" float* py[BLOCK] = y + offset;\n",
" // bounds checking\n",
" bool check[BLOCK] = offset < N;\n",
" // write-back\n",
" *?(check)pz = *?(check)px + *?(check)py;\n",
"}\n",
" \"\"\"\n",
"# This function returns a callable `triton.kernel` object\n",
"# created from the above source code.\n",
"# For portability, we maintain a cache of kernels for different `torch.device`\n",
"# We compile the kernel with -DBLOCK=1024\n",
"_kernels = dict()\n",
"def make_add_kernel(device):\n",
" if device not in _kernels:\n",
" defines = {'BLOCK': 1024}\n",
" _kernels[device] = triton.kernel(_src, device=device, defines=defines)\n",
" return _kernels[device]\n",
"\n",
"# This is a standard torch custom autograd Function\n",
"# The only difference is that we can now use the above kernel\n",
"# in the `forward` and `backward` functions.`\n",
"class _add(torch.autograd.Function):\n",
" \n",
" @staticmethod\n",
" def forward(ctx, x, y):\n",
" # constraints of the op\n",
" assert x.dtype == torch.float32\n",
" # *allocate output*\n",
" z = torch.empty_like(x)\n",
" # *create launch grid*:\n",
" # this is a function which takes compilation parameters `opt`\n",
" # as input and returns a tuple of int (i.e., launch grid) for the kernel.\n",
" # triton.cdiv is a shortcut for ceil division:\n",
" # triton.cdiv(a, b) = (a + b - 1) // b\n",
" N = z.shape[0]\n",
" grid = lambda opt: (triton.cdiv(N, opt.BLOCK), )\n",
" # *launch kernel*:\n",
" # pointer to the data of torch tensors can be retrieved with\n",
" # the `.data_ptr()` method\n",
" kernel = make_add_kernel(z.device)\n",
" kernel(z.data_ptr(), x.data_ptr(), y.data_ptr(), N, grid = grid)\n",
" return z\n",
"# Just like we standard PyTorch ops\n",
"# We use the `.apply` method to create a \n",
"# callable object for our function\n",
"add = _add.apply"
]
},
{
"cell_type": "markdown",
"id": "separated-polyester",
"metadata": {},
"source": [
"At this point `add(x, y)` is equivalent to `x + y` for contiguous tensors. Now let's test and benchmark it!"
]
},
{
"cell_type": "markdown",
"id": "exclusive-salvation",
"metadata": {},
"source": [
"## Writing a Unit Test"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "supported-ribbon",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([1.3713, 1.3076, 0.4940, ..., 0.6682, 1.1984, 1.2696], device='cuda:0')\n",
"tensor([1.3713, 1.3076, 0.4940, ..., 0.6682, 1.1984, 1.2696], device='cuda:0')\n",
"The maximum difference between torch and triton is 0.0\n"
]
}
],
"source": [
"torch.manual_seed(0)\n",
"x = torch.rand(98432, device='cuda')\n",
"y = torch.rand(98432, device='cuda')\n",
"za = x + y\n",
"zb = add(x, y)\n",
"print(za)\n",
"print(zb)\n",
"print(f'The maximum difference between torch and triton is '\n",
" f'{torch.max(torch.abs(za - zb))}')"
]
},
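{
"cell_type": "markdown",
"id": "allclose-note",
"metadata": {},
"source": [
"For an automated check, we can also assert the result programmatically rather than eyeballing the printout. A minimal sketch, using the `za` and `zb` tensors from the cell above:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "allclose-sketch",
"metadata": {},
"outputs": [],
"source": [
"# Minimal sketch of a programmatic check on top of the printout above.\n",
"# `torch.allclose` is the usual idiom for comparing float tensors.\n",
"assert torch.allclose(za, zb), 'triton and torch disagree'"
]
},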
{
"cell_type": "markdown",
"id": "otherwise-canadian",
"metadata": {},
"source": [
"Seems to work!"
]
},
{
"cell_type": "markdown",
"id": "polished-australia",
"metadata": {},
"source": [
"## Writing a Benchmark"
]
},
{
"cell_type": "markdown",
"id": "historic-glass",
"metadata": {},
"source": [
"The performance of our GPU code can be benchmark using the `torch.cuda.Event(enable_timing=True)` wrapper. Below is a simple function that benchmarks `rep` runs of our kernels after `warmup` \"cold\" runs."
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "strange-luxembourg",
"metadata": {},
"outputs": [],
"source": [
"# We now want to benchmark the performance of `add`\n",
"# Against that of PyTorch for increasing vector sizes\n",
"def do_bench(fn, warmup = 10, rep = 50):\n",
" start_event = torch.cuda.Event(enable_timing=True)\n",
" end_event = torch.cuda.Event(enable_timing=True)\n",
" ret = fn()\n",
" for i in range(warmup):\n",
" fn()\n",
" torch.cuda.synchronize()\n",
" start_event.record()\n",
" for i in range(rep):\n",
" fn()\n",
" end_event.record()\n",
" torch.cuda.synchronize()\n",
" time_ms = start_event.elapsed_time(end_event) / rep\n",
" return time_ms"
]
},
{
"cell_type": "markdown",
"id": "hairy-claim",
"metadata": {},
"source": [
"We can now benchmark our custom op for vectors of increasing sizes to get a sense of how it does"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "pleasant-valley",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"131072 0.020 0.003\n",
"262144 0.019 0.004\n",
"524288 0.016 0.016\n",
"1048576 0.033 0.033\n",
"2097152 0.071 0.070\n",
"4194304 0.142 0.144\n",
"8388608 0.287 0.286\n",
"16777216 0.572 0.568\n",
"33554432 1.139 1.110\n"
]
}
],
"source": [
"for N in [2**i for i in range(17, 26, 1)]:\n",
" x = torch.rand(N, device='cuda')\n",
" y = torch.rand(N, device='cuda')\n",
" triton_ms = do_bench(lambda: add(x, y))\n",
" torch_ms = do_bench(lambda: x + y)\n",
" # print the performance of triton and torch as well as the achieved bandwidth\n",
" print(f'{N} {triton_ms:.3f} {torch_ms:.3f}')"
]
},
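{
"cell_type": "markdown",
"id": "bandwidth-note",
"metadata": {},
"source": [
"Raw runtimes are hard to interpret on their own; since vector addition reads two `float32` vectors and writes one, each call moves `3 * N * 4` bytes, so we can also derive the achieved memory bandwidth from the measured time. A minimal sketch on top of `do_bench` (the exact numbers will depend on your GPU):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bandwidth-sketch",
"metadata": {},
"outputs": [],
"source": [
"# Minimal sketch: derive achieved bandwidth (GB/s) from the measured time,\n",
"# assuming `add` and `do_bench` as defined above.\n",
"# 2 reads + 1 write of 4-byte floats per element.\n",
"for N in [2**i for i in range(17, 26, 1)]:\n",
"    x = torch.rand(N, device='cuda')\n",
"    y = torch.rand(N, device='cuda')\n",
"    triton_ms = do_bench(lambda: add(x, y))\n",
"    gbps = 3 * N * 4 / (triton_ms * 1e-3) / 1e9\n",
"    print(f'{N} {triton_ms:.3f} ms {gbps:.1f} GB/s')"
]
},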
{
"cell_type": "markdown",
"id": "juvenile-supplement",
"metadata": {},
"source": [
"Our op is on-par with Torch's vectorized element-wise kernel when the vectors are large enough. One caveat is that the latency of PyTorch is much smaller for small vectors (3us vs 18-20us). This is something we are actively working on to reduce."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

File diff suppressed because one or more lines are too long

Several binary files not shown.

File diff suppressed because it is too large

Binary image file added (434 KiB)

Several binary files not shown.

@@ -0,0 +1,12 @@
var DOCUMENTATION_OPTIONS = {
    URL_ROOT: document.getElementById("documentation_options").getAttribute('data-url_root'),
    VERSION: '',
    LANGUAGE: 'None',
    COLLAPSE_INDEX: false,
    BUILDER: 'html',
    FILE_SUFFIX: '.html',
    LINK_SUFFIX: '.html',
    HAS_SOURCE: true,
    SOURCELINK_SUFFIX: '.txt',
    NAVIGATION_WITH_KEYS: false
};

10872 _static/jquery-3.5.1.js vendored Normal file

File diff suppressed because it is too large

1 _static/js/badge_only.js Normal file

@@ -0,0 +1 @@
!function(e){var t={};function r(n){if(t[n])return t[n].exports;var o=t[n]={i:n,l:!1,exports:{}};return e[n].call(o.exports,o,o.exports,r),o.l=!0,o.exports}r.m=e,r.c=t,r.d=function(e,t,n){r.o(e,t)||Object.defineProperty(e,t,{enumerable:!0,get:n})},r.r=function(e){"undefined"!=typeof Symbol&&Symbol.toStringTag&&Object.defineProperty(e,Symbol.toStringTag,{value:"Module"}),Object.defineProperty(e,"__esModule",{value:!0})},r.t=function(e,t){if(1&t&&(e=r(e)),8&t)return e;if(4&t&&"object"==typeof e&&e&&e.__esModule)return e;var n=Object.create(null);if(r.r(n),Object.defineProperty(n,"default",{enumerable:!0,value:e}),2&t&&"string"!=typeof e)for(var o in e)r.d(n,o,function(t){return e[t]}.bind(null,o));return n},r.n=function(e){var t=e&&e.__esModule?function(){return e.default}:function(){return e};return r.d(t,"a",t),t},r.o=function(e,t){return Object.prototype.hasOwnProperty.call(e,t)},r.p="",r(r.s=4)}({4:function(e,t,r){}});

4 _static/js/html5shiv-printshiv.min.js vendored Normal file

@@ -0,0 +1,4 @@
/**
* @preserve HTML5 Shiv 3.7.3-pre | @afarkas @jdalton @jon_neal @rem | MIT/GPL2 Licensed
*/
!function(a,b){function c(a,b){var c=a.createElement("p"),d=a.getElementsByTagName("head")[0]||a.documentElement;return c.innerHTML="x<style>"+b+"</style>",d.insertBefore(c.lastChild,d.firstChild)}function d(){var a=y.elements;return"string"==typeof a?a.split(" "):a}function e(a,b){var c=y.elements;"string"!=typeof c&&(c=c.join(" ")),"string"!=typeof a&&(a=a.join(" ")),y.elements=c+" "+a,j(b)}function f(a){var b=x[a[v]];return b||(b={},w++,a[v]=w,x[w]=b),b}function g(a,c,d){if(c||(c=b),q)return c.createElement(a);d||(d=f(c));var e;return e=d.cache[a]?d.cache[a].cloneNode():u.test(a)?(d.cache[a]=d.createElem(a)).cloneNode():d.createElem(a),!e.canHaveChildren||t.test(a)||e.tagUrn?e:d.frag.appendChild(e)}function h(a,c){if(a||(a=b),q)return a.createDocumentFragment();c=c||f(a);for(var e=c.frag.cloneNode(),g=0,h=d(),i=h.length;i>g;g++)e.createElement(h[g]);return e}function i(a,b){b.cache||(b.cache={},b.createElem=a.createElement,b.createFrag=a.createDocumentFragment,b.frag=b.createFrag()),a.createElement=function(c){return y.shivMethods?g(c,a,b):b.createElem(c)},a.createDocumentFragment=Function("h,f","return function(){var n=f.cloneNode(),c=n.createElement;h.shivMethods&&("+d().join().replace(/[\w\-:]+/g,function(a){return b.createElem(a),b.frag.createElement(a),'c("'+a+'")'})+");return n}")(y,b.frag)}function j(a){a||(a=b);var d=f(a);return!y.shivCSS||p||d.hasCSS||(d.hasCSS=!!c(a,"article,aside,dialog,figcaption,figure,footer,header,hgroup,main,nav,section{display:block}mark{background:#FF0;color:#000}template{display:none}")),q||i(a,d),a}function k(a){for(var b,c=a.getElementsByTagName("*"),e=c.length,f=RegExp("^(?:"+d().join("|")+")$","i"),g=[];e--;)b=c[e],f.test(b.nodeName)&&g.push(b.applyElement(l(b)));return g}function l(a){for(var b,c=a.attributes,d=c.length,e=a.ownerDocument.createElement(A+":"+a.nodeName);d--;)b=c[d],b.specified&&e.setAttribute(b.nodeName,b.nodeValue);return e.style.cssText=a.style.cssText,e}function m(a){for(var b,c=a.split("{"),e=c.length,f=RegExp("(^|[\\s,>+~])("+d().join("|")+")(?=[[\\s,>+~#.:]|$)","gi"),g="$1"+A+"\\:$2";e--;)b=c[e]=c[e].split("}"),b[b.length-1]=b[b.length-1].replace(f,g),c[e]=b.join("}");return c.join("{")}function n(a){for(var b=a.length;b--;)a[b].removeNode()}function o(a){function b(){clearTimeout(g._removeSheetTimer),d&&d.removeNode(!0),d=null}var d,e,g=f(a),h=a.namespaces,i=a.parentWindow;return!B||a.printShived?a:("undefined"==typeof h[A]&&h.add(A),i.attachEvent("onbeforeprint",function(){b();for(var f,g,h,i=a.styleSheets,j=[],l=i.length,n=Array(l);l--;)n[l]=i[l];for(;h=n.pop();)if(!h.disabled&&z.test(h.media)){try{f=h.imports,g=f.length}catch(o){g=0}for(l=0;g>l;l++)n.push(f[l]);try{j.push(h.cssText)}catch(o){}}j=m(j.reverse().join("")),e=k(a),d=c(a,j)}),i.attachEvent("onafterprint",function(){n(e),clearTimeout(g._removeSheetTimer),g._removeSheetTimer=setTimeout(b,500)}),a.printShived=!0,a)}var p,q,r="3.7.3",s=a.html5||{},t=/^<|^(?:button|map|select|textarea|object|iframe|option|optgroup)$/i,u=/^(?:a|b|code|div|fieldset|h1|h2|h3|h4|h5|h6|i|label|li|ol|p|q|span|strong|style|table|tbody|td|th|tr|ul)$/i,v="_html5shiv",w=0,x={};!function(){try{var a=b.createElement("a");a.innerHTML="<xyz></xyz>",p="hidden"in a,q=1==a.childNodes.length||function(){b.createElement("a");var a=b.createDocumentFragment();return"undefined"==typeof a.cloneNode||"undefined"==typeof a.createDocumentFragment||"undefined"==typeof a.createElement}()}catch(c){p=!0,q=!0}}();var y={elements:s.elements||"abbr article aside audio bdi canvas data datalist details dialog 
figcaption figure footer header hgroup main mark meter nav output picture progress section summary template time video",version:r,shivCSS:s.shivCSS!==!1,supportsUnknownElements:q,shivMethods:s.shivMethods!==!1,type:"default",shivDocument:j,createElement:g,createDocumentFragment:h,addElements:e};a.html5=y,j(b);var z=/^$|\b(?:all|print)\b/,A="html5shiv",B=!q&&function(){var c=b.documentElement;return!("undefined"==typeof b.namespaces||"undefined"==typeof b.parentWindow||"undefined"==typeof c.applyElement||"undefined"==typeof c.removeNode||"undefined"==typeof a.attachEvent)}();y.type+=" print",y.shivPrint=o,o(b),"object"==typeof module&&module.exports&&(module.exports=y)}("undefined"!=typeof window?window:this,document);

4 _static/js/html5shiv.min.js vendored Normal file

@@ -0,0 +1,4 @@
/**
* @preserve HTML5 Shiv 3.7.3 | @afarkas @jdalton @jon_neal @rem | MIT/GPL2 Licensed
*/
!function(a,b){function c(a,b){var c=a.createElement("p"),d=a.getElementsByTagName("head")[0]||a.documentElement;return c.innerHTML="x<style>"+b+"</style>",d.insertBefore(c.lastChild,d.firstChild)}function d(){var a=t.elements;return"string"==typeof a?a.split(" "):a}function e(a,b){var c=t.elements;"string"!=typeof c&&(c=c.join(" ")),"string"!=typeof a&&(a=a.join(" ")),t.elements=c+" "+a,j(b)}function f(a){var b=s[a[q]];return b||(b={},r++,a[q]=r,s[r]=b),b}function g(a,c,d){if(c||(c=b),l)return c.createElement(a);d||(d=f(c));var e;return e=d.cache[a]?d.cache[a].cloneNode():p.test(a)?(d.cache[a]=d.createElem(a)).cloneNode():d.createElem(a),!e.canHaveChildren||o.test(a)||e.tagUrn?e:d.frag.appendChild(e)}function h(a,c){if(a||(a=b),l)return a.createDocumentFragment();c=c||f(a);for(var e=c.frag.cloneNode(),g=0,h=d(),i=h.length;i>g;g++)e.createElement(h[g]);return e}function i(a,b){b.cache||(b.cache={},b.createElem=a.createElement,b.createFrag=a.createDocumentFragment,b.frag=b.createFrag()),a.createElement=function(c){return t.shivMethods?g(c,a,b):b.createElem(c)},a.createDocumentFragment=Function("h,f","return function(){var n=f.cloneNode(),c=n.createElement;h.shivMethods&&("+d().join().replace(/[\w\-:]+/g,function(a){return b.createElem(a),b.frag.createElement(a),'c("'+a+'")'})+");return n}")(t,b.frag)}function j(a){a||(a=b);var d=f(a);return!t.shivCSS||k||d.hasCSS||(d.hasCSS=!!c(a,"article,aside,dialog,figcaption,figure,footer,header,hgroup,main,nav,section{display:block}mark{background:#FF0;color:#000}template{display:none}")),l||i(a,d),a}var k,l,m="3.7.3-pre",n=a.html5||{},o=/^<|^(?:button|map|select|textarea|object|iframe|option|optgroup)$/i,p=/^(?:a|b|code|div|fieldset|h1|h2|h3|h4|h5|h6|i|label|li|ol|p|q|span|strong|style|table|tbody|td|th|tr|ul)$/i,q="_html5shiv",r=0,s={};!function(){try{var a=b.createElement("a");a.innerHTML="<xyz></xyz>",k="hidden"in a,l=1==a.childNodes.length||function(){b.createElement("a");var a=b.createDocumentFragment();return"undefined"==typeof a.cloneNode||"undefined"==typeof a.createDocumentFragment||"undefined"==typeof a.createElement}()}catch(c){k=!0,l=!0}}();var t={elements:n.elements||"abbr article aside audio bdi canvas data datalist details dialog figcaption figure footer header hgroup main mark meter nav output picture progress section summary template time video",version:m,shivCSS:n.shivCSS!==!1,supportsUnknownElements:l,shivMethods:n.shivMethods!==!1,type:"default",shivDocument:j,createElement:g,createDocumentFragment:h,addElements:e};a.html5=t,j(b),"object"==typeof module&&module.exports&&(module.exports=t)}("undefined"!=typeof window?window:this,document);

297 _static/language_data.js Normal file

@@ -0,0 +1,297 @@
/*
* language_data.js
* ~~~~~~~~~~~~~~~~
*
* This script contains the language-specific data used by searchtools.js,
* namely the list of stopwords, stemmer, scorer and splitter.
*
* :copyright: Copyright 2007-2021 by the Sphinx team, see AUTHORS.
* :license: BSD, see LICENSE for details.
*
*/
var stopwords = ["a","and","are","as","at","be","but","by","for","if","in","into","is","it","near","no","not","of","on","or","such","that","the","their","then","there","these","they","this","to","was","will","with"];
/* Non-minified version JS is _stemmer.js if file is provided */
/**
* Porter Stemmer
*/
var Stemmer = function() {
var step2list = {
ational: 'ate',
tional: 'tion',
enci: 'ence',
anci: 'ance',
izer: 'ize',
bli: 'ble',
alli: 'al',
entli: 'ent',
eli: 'e',
ousli: 'ous',
ization: 'ize',
ation: 'ate',
ator: 'ate',
alism: 'al',
iveness: 'ive',
fulness: 'ful',
ousness: 'ous',
aliti: 'al',
iviti: 'ive',
biliti: 'ble',
logi: 'log'
};
var step3list = {
icate: 'ic',
ative: '',
alize: 'al',
iciti: 'ic',
ical: 'ic',
ful: '',
ness: ''
};
var c = "[^aeiou]"; // consonant
var v = "[aeiouy]"; // vowel
var C = c + "[^aeiouy]*"; // consonant sequence
var V = v + "[aeiou]*"; // vowel sequence
var mgr0 = "^(" + C + ")?" + V + C; // [C]VC... is m>0
var meq1 = "^(" + C + ")?" + V + C + "(" + V + ")?$"; // [C]VC[V] is m=1
var mgr1 = "^(" + C + ")?" + V + C + V + C; // [C]VCVC... is m>1
var s_v = "^(" + C + ")?" + v; // vowel in stem
this.stemWord = function (w) {
var stem;
var suffix;
var firstch;
var origword = w;
if (w.length < 3)
return w;
var re;
var re2;
var re3;
var re4;
firstch = w.substr(0,1);
if (firstch == "y")
w = firstch.toUpperCase() + w.substr(1);
// Step 1a
re = /^(.+?)(ss|i)es$/;
re2 = /^(.+?)([^s])s$/;
if (re.test(w))
w = w.replace(re,"$1$2");
else if (re2.test(w))
w = w.replace(re2,"$1$2");
// Step 1b
re = /^(.+?)eed$/;
re2 = /^(.+?)(ed|ing)$/;
if (re.test(w)) {
var fp = re.exec(w);
re = new RegExp(mgr0);
if (re.test(fp[1])) {
re = /.$/;
w = w.replace(re,"");
}
}
else if (re2.test(w)) {
var fp = re2.exec(w);
stem = fp[1];
re2 = new RegExp(s_v);
if (re2.test(stem)) {
w = stem;
re2 = /(at|bl|iz)$/;
re3 = new RegExp("([^aeiouylsz])\\1$");
re4 = new RegExp("^" + C + v + "[^aeiouwxy]$");
if (re2.test(w))
w = w + "e";
else if (re3.test(w)) {
re = /.$/;
w = w.replace(re,"");
}
else if (re4.test(w))
w = w + "e";
}
}
// Step 1c
re = /^(.+?)y$/;
if (re.test(w)) {
var fp = re.exec(w);
stem = fp[1];
re = new RegExp(s_v);
if (re.test(stem))
w = stem + "i";
}
// Step 2
re = /^(.+?)(ational|tional|enci|anci|izer|bli|alli|entli|eli|ousli|ization|ation|ator|alism|iveness|fulness|ousness|aliti|iviti|biliti|logi)$/;
if (re.test(w)) {
var fp = re.exec(w);
stem = fp[1];
suffix = fp[2];
re = new RegExp(mgr0);
if (re.test(stem))
w = stem + step2list[suffix];
}
// Step 3
re = /^(.+?)(icate|ative|alize|iciti|ical|ful|ness)$/;
if (re.test(w)) {
var fp = re.exec(w);
stem = fp[1];
suffix = fp[2];
re = new RegExp(mgr0);
if (re.test(stem))
w = stem + step3list[suffix];
}
// Step 4
re = /^(.+?)(al|ance|ence|er|ic|able|ible|ant|ement|ment|ent|ou|ism|ate|iti|ous|ive|ize)$/;
re2 = /^(.+?)(s|t)(ion)$/;
if (re.test(w)) {
var fp = re.exec(w);
stem = fp[1];
re = new RegExp(mgr1);
if (re.test(stem))
w = stem;
}
else if (re2.test(w)) {
var fp = re2.exec(w);
stem = fp[1] + fp[2];
re2 = new RegExp(mgr1);
if (re2.test(stem))
w = stem;
}
// Step 5
re = /^(.+?)e$/;
if (re.test(w)) {
var fp = re.exec(w);
stem = fp[1];
re = new RegExp(mgr1);
re2 = new RegExp(meq1);
re3 = new RegExp("^" + C + v + "[^aeiouwxy]$");
if (re.test(stem) || (re2.test(stem) && !(re3.test(stem))))
w = stem;
}
re = /ll$/;
re2 = new RegExp(mgr1);
if (re.test(w) && re2.test(w)) {
re = /.$/;
w = w.replace(re,"");
}
// and turn initial Y back to y
if (firstch == "y")
w = firstch.toLowerCase() + w.substr(1);
return w;
}
}
var splitChars = (function() {
var result = {};
var singles = [96, 180, 187, 191, 215, 247, 749, 885, 903, 907, 909, 930, 1014, 1648,
1748, 1809, 2416, 2473, 2481, 2526, 2601, 2609, 2612, 2615, 2653, 2702,
2706, 2729, 2737, 2740, 2857, 2865, 2868, 2910, 2928, 2948, 2961, 2971,
2973, 3085, 3089, 3113, 3124, 3213, 3217, 3241, 3252, 3295, 3341, 3345,
3369, 3506, 3516, 3633, 3715, 3721, 3736, 3744, 3748, 3750, 3756, 3761,
3781, 3912, 4239, 4347, 4681, 4695, 4697, 4745, 4785, 4799, 4801, 4823,
4881, 5760, 5901, 5997, 6313, 7405, 8024, 8026, 8028, 8030, 8117, 8125,
8133, 8181, 8468, 8485, 8487, 8489, 8494, 8527, 11311, 11359, 11687, 11695,
11703, 11711, 11719, 11727, 11735, 12448, 12539, 43010, 43014, 43019, 43587,
43696, 43713, 64286, 64297, 64311, 64317, 64319, 64322, 64325, 65141];
var i, j, start, end;
for (i = 0; i < singles.length; i++) {
result[singles[i]] = true;
}
var ranges = [[0, 47], [58, 64], [91, 94], [123, 169], [171, 177], [182, 184], [706, 709],
[722, 735], [741, 747], [751, 879], [888, 889], [894, 901], [1154, 1161],
[1318, 1328], [1367, 1368], [1370, 1376], [1416, 1487], [1515, 1519], [1523, 1568],
[1611, 1631], [1642, 1645], [1750, 1764], [1767, 1773], [1789, 1790], [1792, 1807],
[1840, 1868], [1958, 1968], [1970, 1983], [2027, 2035], [2038, 2041], [2043, 2047],
[2070, 2073], [2075, 2083], [2085, 2087], [2089, 2307], [2362, 2364], [2366, 2383],
[2385, 2391], [2402, 2405], [2419, 2424], [2432, 2436], [2445, 2446], [2449, 2450],
[2483, 2485], [2490, 2492], [2494, 2509], [2511, 2523], [2530, 2533], [2546, 2547],
[2554, 2564], [2571, 2574], [2577, 2578], [2618, 2648], [2655, 2661], [2672, 2673],
[2677, 2692], [2746, 2748], [2750, 2767], [2769, 2783], [2786, 2789], [2800, 2820],
[2829, 2830], [2833, 2834], [2874, 2876], [2878, 2907], [2914, 2917], [2930, 2946],
[2955, 2957], [2966, 2968], [2976, 2978], [2981, 2983], [2987, 2989], [3002, 3023],
[3025, 3045], [3059, 3076], [3130, 3132], [3134, 3159], [3162, 3167], [3170, 3173],
[3184, 3191], [3199, 3204], [3258, 3260], [3262, 3293], [3298, 3301], [3312, 3332],
[3386, 3388], [3390, 3423], [3426, 3429], [3446, 3449], [3456, 3460], [3479, 3481],
[3518, 3519], [3527, 3584], [3636, 3647], [3655, 3663], [3674, 3712], [3717, 3718],
[3723, 3724], [3726, 3731], [3752, 3753], [3764, 3772], [3774, 3775], [3783, 3791],
[3802, 3803], [3806, 3839], [3841, 3871], [3892, 3903], [3949, 3975], [3980, 4095],
[4139, 4158], [4170, 4175], [4182, 4185], [4190, 4192], [4194, 4196], [4199, 4205],
[4209, 4212], [4226, 4237], [4250, 4255], [4294, 4303], [4349, 4351], [4686, 4687],
[4702, 4703], [4750, 4751], [4790, 4791], [4806, 4807], [4886, 4887], [4955, 4968],
[4989, 4991], [5008, 5023], [5109, 5120], [5741, 5742], [5787, 5791], [5867, 5869],
[5873, 5887], [5906, 5919], [5938, 5951], [5970, 5983], [6001, 6015], [6068, 6102],
[6104, 6107], [6109, 6111], [6122, 6127], [6138, 6159], [6170, 6175], [6264, 6271],
[6315, 6319], [6390, 6399], [6429, 6469], [6510, 6511], [6517, 6527], [6572, 6592],
[6600, 6607], [6619, 6655], [6679, 6687], [6741, 6783], [6794, 6799], [6810, 6822],
[6824, 6916], [6964, 6980], [6988, 6991], [7002, 7042], [7073, 7085], [7098, 7167],
[7204, 7231], [7242, 7244], [7294, 7400], [7410, 7423], [7616, 7679], [7958, 7959],
[7966, 7967], [8006, 8007], [8014, 8015], [8062, 8063], [8127, 8129], [8141, 8143],
[8148, 8149], [8156, 8159], [8173, 8177], [8189, 8303], [8306, 8307], [8314, 8318],
[8330, 8335], [8341, 8449], [8451, 8454], [8456, 8457], [8470, 8472], [8478, 8483],
[8506, 8507], [8512, 8516], [8522, 8525], [8586, 9311], [9372, 9449], [9472, 10101],
[10132, 11263], [11493, 11498], [11503, 11516], [11518, 11519], [11558, 11567],
[11622, 11630], [11632, 11647], [11671, 11679], [11743, 11822], [11824, 12292],
[12296, 12320], [12330, 12336], [12342, 12343], [12349, 12352], [12439, 12444],
[12544, 12548], [12590, 12592], [12687, 12689], [12694, 12703], [12728, 12783],
[12800, 12831], [12842, 12880], [12896, 12927], [12938, 12976], [12992, 13311],
[19894, 19967], [40908, 40959], [42125, 42191], [42238, 42239], [42509, 42511],
[42540, 42559], [42592, 42593], [42607, 42622], [42648, 42655], [42736, 42774],
[42784, 42785], [42889, 42890], [42893, 43002], [43043, 43055], [43062, 43071],
[43124, 43137], [43188, 43215], [43226, 43249], [43256, 43258], [43260, 43263],
[43302, 43311], [43335, 43359], [43389, 43395], [43443, 43470], [43482, 43519],
[43561, 43583], [43596, 43599], [43610, 43615], [43639, 43641], [43643, 43647],
[43698, 43700], [43703, 43704], [43710, 43711], [43715, 43738], [43742, 43967],
[44003, 44015], [44026, 44031], [55204, 55215], [55239, 55242], [55292, 55295],
[57344, 63743], [64046, 64047], [64110, 64111], [64218, 64255], [64263, 64274],
[64280, 64284], [64434, 64466], [64830, 64847], [64912, 64913], [64968, 65007],
[65020, 65135], [65277, 65295], [65306, 65312], [65339, 65344], [65371, 65381],
[65471, 65473], [65480, 65481], [65488, 65489], [65496, 65497]];
for (i = 0; i < ranges.length; i++) {
start = ranges[i][0];
end = ranges[i][1];
for (j = start; j <= end; j++) {
result[j] = true;
}
}
return result;
})();
function splitQuery(query) {
var result = [];
var start = -1;
for (var i = 0; i < query.length; i++) {
if (splitChars[query.charCodeAt(i)]) {
if (start !== -1) {
result.push(query.slice(start, i));
start = -1;
}
} else if (start === -1) {
start = i;
}
}
if (start !== -1) {
result.push(query.slice(start));
}
return result;
}

999 _static/underscore-1.3.1.js Normal file

@@ -0,0 +1,999 @@
// Underscore.js 1.3.1
// (c) 2009-2012 Jeremy Ashkenas, DocumentCloud Inc.
// Underscore is freely distributable under the MIT license.
// Portions of Underscore are inspired or borrowed from Prototype,
// Oliver Steele's Functional, and John Resig's Micro-Templating.
// For all details and documentation:
// http://documentcloud.github.com/underscore
(function() {
// Baseline setup
// --------------
// Establish the root object, `window` in the browser, or `global` on the server.
var root = this;
// Save the previous value of the `_` variable.
var previousUnderscore = root._;
// Establish the object that gets returned to break out of a loop iteration.
var breaker = {};
// Save bytes in the minified (but not gzipped) version:
var ArrayProto = Array.prototype, ObjProto = Object.prototype, FuncProto = Function.prototype;
// Create quick reference variables for speed access to core prototypes.
var slice = ArrayProto.slice,
unshift = ArrayProto.unshift,
toString = ObjProto.toString,
hasOwnProperty = ObjProto.hasOwnProperty;
// All **ECMAScript 5** native function implementations that we hope to use
// are declared here.
var
nativeForEach = ArrayProto.forEach,
nativeMap = ArrayProto.map,
nativeReduce = ArrayProto.reduce,
nativeReduceRight = ArrayProto.reduceRight,
nativeFilter = ArrayProto.filter,
nativeEvery = ArrayProto.every,
nativeSome = ArrayProto.some,
nativeIndexOf = ArrayProto.indexOf,
nativeLastIndexOf = ArrayProto.lastIndexOf,
nativeIsArray = Array.isArray,
nativeKeys = Object.keys,
nativeBind = FuncProto.bind;
// Create a safe reference to the Underscore object for use below.
var _ = function(obj) { return new wrapper(obj); };
// Export the Underscore object for **Node.js**, with
// backwards-compatibility for the old `require()` API. If we're in
// the browser, add `_` as a global object via a string identifier,
// for Closure Compiler "advanced" mode.
if (typeof exports !== 'undefined') {
if (typeof module !== 'undefined' && module.exports) {
exports = module.exports = _;
}
exports._ = _;
} else {
root['_'] = _;
}
// Current version.
_.VERSION = '1.3.1';
// Collection Functions
// --------------------
// The cornerstone, an `each` implementation, aka `forEach`.
// Handles objects with the built-in `forEach`, arrays, and raw objects.
// Delegates to **ECMAScript 5**'s native `forEach` if available.
var each = _.each = _.forEach = function(obj, iterator, context) {
if (obj == null) return;
if (nativeForEach && obj.forEach === nativeForEach) {
obj.forEach(iterator, context);
} else if (obj.length === +obj.length) {
for (var i = 0, l = obj.length; i < l; i++) {
if (i in obj && iterator.call(context, obj[i], i, obj) === breaker) return;
}
} else {
for (var key in obj) {
if (_.has(obj, key)) {
if (iterator.call(context, obj[key], key, obj) === breaker) return;
}
}
}
};
// Return the results of applying the iterator to each element.
// Delegates to **ECMAScript 5**'s native `map` if available.
_.map = _.collect = function(obj, iterator, context) {
var results = [];
if (obj == null) return results;
if (nativeMap && obj.map === nativeMap) return obj.map(iterator, context);
each(obj, function(value, index, list) {
results[results.length] = iterator.call(context, value, index, list);
});
if (obj.length === +obj.length) results.length = obj.length;
return results;
};
// **Reduce** builds up a single result from a list of values, aka `inject`,
// or `foldl`. Delegates to **ECMAScript 5**'s native `reduce` if available.
_.reduce = _.foldl = _.inject = function(obj, iterator, memo, context) {
var initial = arguments.length > 2;
if (obj == null) obj = [];
if (nativeReduce && obj.reduce === nativeReduce) {
if (context) iterator = _.bind(iterator, context);
return initial ? obj.reduce(iterator, memo) : obj.reduce(iterator);
}
each(obj, function(value, index, list) {
if (!initial) {
memo = value;
initial = true;
} else {
memo = iterator.call(context, memo, value, index, list);
}
});
if (!initial) throw new TypeError('Reduce of empty array with no initial value');
return memo;
};
// The right-associative version of reduce, also known as `foldr`.
// Delegates to **ECMAScript 5**'s native `reduceRight` if available.
_.reduceRight = _.foldr = function(obj, iterator, memo, context) {
var initial = arguments.length > 2;
if (obj == null) obj = [];
if (nativeReduceRight && obj.reduceRight === nativeReduceRight) {
if (context) iterator = _.bind(iterator, context);
return initial ? obj.reduceRight(iterator, memo) : obj.reduceRight(iterator);
}
var reversed = _.toArray(obj).reverse();
if (context && !initial) iterator = _.bind(iterator, context);
return initial ? _.reduce(reversed, iterator, memo, context) : _.reduce(reversed, iterator);
};
// Return the first value which passes a truth test. Aliased as `detect`.
_.find = _.detect = function(obj, iterator, context) {
var result;
any(obj, function(value, index, list) {
if (iterator.call(context, value, index, list)) {
result = value;
return true;
}
});
return result;
};
// Return all the elements that pass a truth test.
// Delegates to **ECMAScript 5**'s native `filter` if available.
// Aliased as `select`.
_.filter = _.select = function(obj, iterator, context) {
var results = [];
if (obj == null) return results;
if (nativeFilter && obj.filter === nativeFilter) return obj.filter(iterator, context);
each(obj, function(value, index, list) {
if (iterator.call(context, value, index, list)) results[results.length] = value;
});
return results;
};
// Return all the elements for which a truth test fails.
_.reject = function(obj, iterator, context) {
var results = [];
if (obj == null) return results;
each(obj, function(value, index, list) {
if (!iterator.call(context, value, index, list)) results[results.length] = value;
});
return results;
};
// Determine whether all of the elements match a truth test.
// Delegates to **ECMAScript 5**'s native `every` if available.
// Aliased as `all`.
_.every = _.all = function(obj, iterator, context) {
var result = true;
if (obj == null) return result;
if (nativeEvery && obj.every === nativeEvery) return obj.every(iterator, context);
each(obj, function(value, index, list) {
if (!(result = result && iterator.call(context, value, index, list))) return breaker;
});
return result;
};
// Determine if at least one element in the object matches a truth test.
// Delegates to **ECMAScript 5**'s native `some` if available.
// Aliased as `any`.
var any = _.some = _.any = function(obj, iterator, context) {
iterator || (iterator = _.identity);
var result = false;
if (obj == null) return result;
if (nativeSome && obj.some === nativeSome) return obj.some(iterator, context);
each(obj, function(value, index, list) {
if (result || (result = iterator.call(context, value, index, list))) return breaker;
});
return !!result;
};
// Determine if a given value is included in the array or object using `===`.
// Aliased as `contains`.
_.include = _.contains = function(obj, target) {
var found = false;
if (obj == null) return found;
if (nativeIndexOf && obj.indexOf === nativeIndexOf) return obj.indexOf(target) != -1;
found = any(obj, function(value) {
return value === target;
});
return found;
};
// Invoke a method (with arguments) on every item in a collection.
_.invoke = function(obj, method) {
var args = slice.call(arguments, 2);
return _.map(obj, function(value) {
return (_.isFunction(method) ? method || value : value[method]).apply(value, args);
});
};
// Convenience version of a common use case of `map`: fetching a property.
_.pluck = function(obj, key) {
return _.map(obj, function(value){ return value[key]; });
};
// Return the maximum element or (element-based computation).
_.max = function(obj, iterator, context) {
if (!iterator && _.isArray(obj)) return Math.max.apply(Math, obj);
if (!iterator && _.isEmpty(obj)) return -Infinity;
var result = {computed : -Infinity};
each(obj, function(value, index, list) {
var computed = iterator ? iterator.call(context, value, index, list) : value;
computed >= result.computed && (result = {value : value, computed : computed});
});
return result.value;
};
// Return the minimum element (or element-based computation).
_.min = function(obj, iterator, context) {
if (!iterator && _.isArray(obj)) return Math.min.apply(Math, obj);
if (!iterator && _.isEmpty(obj)) return Infinity;
var result = {computed : Infinity};
each(obj, function(value, index, list) {
var computed = iterator ? iterator.call(context, value, index, list) : value;
computed < result.computed && (result = {value : value, computed : computed});
});
return result.value;
};
// Shuffle an array.
_.shuffle = function(obj) {
var shuffled = [], rand;
each(obj, function(value, index, list) {
if (index == 0) {
shuffled[0] = value;
} else {
rand = Math.floor(Math.random() * (index + 1));
shuffled[index] = shuffled[rand];
shuffled[rand] = value;
}
});
return shuffled;
};
// Sort the object's values by a criterion produced by an iterator.
_.sortBy = function(obj, iterator, context) {
return _.pluck(_.map(obj, function(value, index, list) {
return {
value : value,
criteria : iterator.call(context, value, index, list)
};
}).sort(function(left, right) {
var a = left.criteria, b = right.criteria;
return a < b ? -1 : a > b ? 1 : 0;
}), 'value');
};
// Groups the object's values by a criterion. Pass either a string attribute
// to group by, or a function that returns the criterion.
_.groupBy = function(obj, val) {
var result = {};
var iterator = _.isFunction(val) ? val : function(obj) { return obj[val]; };
each(obj, function(value, index) {
var key = iterator(value, index);
(result[key] || (result[key] = [])).push(value);
});
return result;
};
// Use a comparator function to figure out at what index an object should
// be inserted so as to maintain order. Uses binary search.
_.sortedIndex = function(array, obj, iterator) {
iterator || (iterator = _.identity);
var low = 0, high = array.length;
while (low < high) {
var mid = (low + high) >> 1;
iterator(array[mid]) < iterator(obj) ? low = mid + 1 : high = mid;
}
return low;
};
// Safely convert anything iterable into a real, live array.
_.toArray = function(iterable) {
if (!iterable) return [];
if (iterable.toArray) return iterable.toArray();
if (_.isArray(iterable)) return slice.call(iterable);
if (_.isArguments(iterable)) return slice.call(iterable);
return _.values(iterable);
};
// Return the number of elements in an object.
_.size = function(obj) {
return _.toArray(obj).length;
};
// Array Functions
// ---------------
// Get the first element of an array. Passing **n** will return the first N
// values in the array. Aliased as `head`. The **guard** check allows it to work
// with `_.map`.
_.first = _.head = function(array, n, guard) {
return (n != null) && !guard ? slice.call(array, 0, n) : array[0];
};
// Returns everything but the last entry of the array. Especcialy useful on
// the arguments object. Passing **n** will return all the values in
// the array, excluding the last N. The **guard** check allows it to work with
// `_.map`.
_.initial = function(array, n, guard) {
return slice.call(array, 0, array.length - ((n == null) || guard ? 1 : n));
};
// Get the last element of an array. Passing **n** will return the last N
// values in the array. The **guard** check allows it to work with `_.map`.
_.last = function(array, n, guard) {
if ((n != null) && !guard) {
return slice.call(array, Math.max(array.length - n, 0));
} else {
return array[array.length - 1];
}
};
// Returns everything but the first entry of the array. Aliased as `tail`.
// Especially useful on the arguments object. Passing an **index** will return
// the rest of the values in the array from that index onward. The **guard**
// check allows it to work with `_.map`.
_.rest = _.tail = function(array, index, guard) {
return slice.call(array, (index == null) || guard ? 1 : index);
};
// Trim out all falsy values from an array.
_.compact = function(array) {
return _.filter(array, function(value){ return !!value; });
};
// Return a completely flattened version of an array.
_.flatten = function(array, shallow) {
return _.reduce(array, function(memo, value) {
if (_.isArray(value)) return memo.concat(shallow ? value : _.flatten(value));
memo[memo.length] = value;
return memo;
}, []);
};
// Return a version of the array that does not contain the specified value(s).
_.without = function(array) {
return _.difference(array, slice.call(arguments, 1));
};
// Produce a duplicate-free version of the array. If the array has already
// been sorted, you have the option of using a faster algorithm.
// Aliased as `unique`.
_.uniq = _.unique = function(array, isSorted, iterator) {
var initial = iterator ? _.map(array, iterator) : array;
var result = [];
_.reduce(initial, function(memo, el, i) {
if (0 == i || (isSorted === true ? _.last(memo) != el : !_.include(memo, el))) {
memo[memo.length] = el;
result[result.length] = array[i];
}
return memo;
}, []);
return result;
};
// Produce an array that contains the union: each distinct element from all of
// the passed-in arrays.
_.union = function() {
return _.uniq(_.flatten(arguments, true));
};
// Produce an array that contains every item shared between all the
// passed-in arrays. (Aliased as "intersect" for back-compat.)
_.intersection = _.intersect = function(array) {
var rest = slice.call(arguments, 1);
return _.filter(_.uniq(array), function(item) {
return _.every(rest, function(other) {
return _.indexOf(other, item) >= 0;
});
});
};
// Take the difference between one array and a number of other arrays.
// Only the elements present in just the first array will remain.
_.difference = function(array) {
var rest = _.flatten(slice.call(arguments, 1));
return _.filter(array, function(value){ return !_.include(rest, value); });
};
// Zip together multiple lists into a single array -- elements that share
// an index go together.
_.zip = function() {
var args = slice.call(arguments);
var length = _.max(_.pluck(args, 'length'));
var results = new Array(length);
for (var i = 0; i < length; i++) results[i] = _.pluck(args, "" + i);
return results;
};
// If the browser doesn't supply us with indexOf (I'm looking at you, **MSIE**),
// we need this function. Return the position of the first occurrence of an
// item in an array, or -1 if the item is not included in the array.
// Delegates to **ECMAScript 5**'s native `indexOf` if available.
// If the array is large and already in sort order, pass `true`
// for **isSorted** to use binary search.
_.indexOf = function(array, item, isSorted) {
if (array == null) return -1;
var i, l;
if (isSorted) {
i = _.sortedIndex(array, item);
return array[i] === item ? i : -1;
}
if (nativeIndexOf && array.indexOf === nativeIndexOf) return array.indexOf(item);
for (i = 0, l = array.length; i < l; i++) if (i in array && array[i] === item) return i;
return -1;
};
// Delegates to **ECMAScript 5**'s native `lastIndexOf` if available.
_.lastIndexOf = function(array, item) {
if (array == null) return -1;
if (nativeLastIndexOf && array.lastIndexOf === nativeLastIndexOf) return array.lastIndexOf(item);
var i = array.length;
while (i--) if (i in array && array[i] === item) return i;
return -1;
};
// Generate an integer Array containing an arithmetic progression. A port of
// the native Python `range()` function. See
// [the Python documentation](http://docs.python.org/library/functions.html#range).
_.range = function(start, stop, step) {
if (arguments.length <= 1) {
stop = start || 0;
start = 0;
}
step = arguments[2] || 1;
var len = Math.max(Math.ceil((stop - start) / step), 0);
var idx = 0;
var range = new Array(len);
while(idx < len) {
range[idx++] = start;
start += step;
}
return range;
};
// Function (ahem) Functions
// ------------------
// Reusable constructor function for prototype setting.
var ctor = function(){};
// Create a function bound to a given object (assigning `this`, and arguments,
// optionally). Binding with arguments is also known as `curry`.
// Delegates to **ECMAScript 5**'s native `Function.bind` if available.
// We check for `func.bind` first, to fail fast when `func` is undefined.
_.bind = function bind(func, context) {
var bound, args;
if (func.bind === nativeBind && nativeBind) return nativeBind.apply(func, slice.call(arguments, 1));
if (!_.isFunction(func)) throw new TypeError;
args = slice.call(arguments, 2);
return bound = function() {
if (!(this instanceof bound)) return func.apply(context, args.concat(slice.call(arguments)));
ctor.prototype = func.prototype;
var self = new ctor;
var result = func.apply(self, args.concat(slice.call(arguments)));
if (Object(result) === result) return result;
return self;
};
};
// Bind all of an object's methods to that object. Useful for ensuring that
// all callbacks defined on an object belong to it.
_.bindAll = function(obj) {
var funcs = slice.call(arguments, 1);
if (funcs.length == 0) funcs = _.functions(obj);
each(funcs, function(f) { obj[f] = _.bind(obj[f], obj); });
return obj;
};
// Memoize an expensive function by storing its results.
_.memoize = function(func, hasher) {
var memo = {};
hasher || (hasher = _.identity);
return function() {
var key = hasher.apply(this, arguments);
return _.has(memo, key) ? memo[key] : (memo[key] = func.apply(this, arguments));
};
};
// Delays a function for the given number of milliseconds, and then calls
// it with the arguments supplied.
_.delay = function(func, wait) {
var args = slice.call(arguments, 2);
return setTimeout(function(){ return func.apply(func, args); }, wait);
};
// Defers a function, scheduling it to run after the current call stack has
// cleared.
_.defer = function(func) {
return _.delay.apply(_, [func, 1].concat(slice.call(arguments, 1)));
};
// Returns a function, that, when invoked, will only be triggered at most once
// during a given window of time.
_.throttle = function(func, wait) {
var context, args, timeout, throttling, more;
var whenDone = _.debounce(function(){ more = throttling = false; }, wait);
return function() {
context = this; args = arguments;
var later = function() {
timeout = null;
if (more) func.apply(context, args);
whenDone();
};
if (!timeout) timeout = setTimeout(later, wait);
if (throttling) {
more = true;
} else {
func.apply(context, args);
}
whenDone();
throttling = true;
};
};
// Returns a function, that, as long as it continues to be invoked, will not
// be triggered. The function will be called after it stops being called for
// N milliseconds.
_.debounce = function(func, wait) {
var timeout;
return function() {
var context = this, args = arguments;
var later = function() {
timeout = null;
func.apply(context, args);
};
clearTimeout(timeout);
timeout = setTimeout(later, wait);
};
};
// Returns a function that will be executed at most one time, no matter how
// often you call it. Useful for lazy initialization.
_.once = function(func) {
var ran = false, memo;
return function() {
if (ran) return memo;
ran = true;
return memo = func.apply(this, arguments);
};
};
// Returns the first function passed as an argument to the second,
// allowing you to adjust arguments, run code before and after, and
// conditionally execute the original function.
_.wrap = function(func, wrapper) {
return function() {
var args = [func].concat(slice.call(arguments, 0));
return wrapper.apply(this, args);
};
};
// Returns a function that is the composition of a list of functions, each
// consuming the return value of the function that follows.
_.compose = function() {
var funcs = arguments;
return function() {
var args = arguments;
for (var i = funcs.length - 1; i >= 0; i--) {
args = [funcs[i].apply(this, args)];
}
return args[0];
};
};
// Returns a function that will only be executed after being called N times.
_.after = function(times, func) {
if (times <= 0) return func();
return function() {
if (--times < 1) { return func.apply(this, arguments); }
};
};
// Object Functions
// ----------------
// Retrieve the names of an object's properties.
// Delegates to **ECMAScript 5**'s native `Object.keys`
_.keys = nativeKeys || function(obj) {
if (obj !== Object(obj)) throw new TypeError('Invalid object');
var keys = [];
for (var key in obj) if (_.has(obj, key)) keys[keys.length] = key;
return keys;
};
// Retrieve the values of an object's properties.
_.values = function(obj) {
return _.map(obj, _.identity);
};
// Return a sorted list of the function names available on the object.
// Aliased as `methods`
_.functions = _.methods = function(obj) {
var names = [];
for (var key in obj) {
if (_.isFunction(obj[key])) names.push(key);
}
return names.sort();
};
// Extend a given object with all the properties in passed-in object(s).
_.extend = function(obj) {
each(slice.call(arguments, 1), function(source) {
for (var prop in source) {
obj[prop] = source[prop];
}
});
return obj;
};
// Fill in a given object with default properties.
_.defaults = function(obj) {
each(slice.call(arguments, 1), function(source) {
for (var prop in source) {
if (obj[prop] == null) obj[prop] = source[prop];
}
});
return obj;
};
// Create a (shallow-cloned) duplicate of an object.
_.clone = function(obj) {
if (!_.isObject(obj)) return obj;
return _.isArray(obj) ? obj.slice() : _.extend({}, obj);
};
// Invokes interceptor with the obj, and then returns obj.
// The primary purpose of this method is to "tap into" a method chain, in
// order to perform operations on intermediate results within the chain.
_.tap = function(obj, interceptor) {
interceptor(obj);
return obj;
};
// Internal recursive comparison function.
function eq(a, b, stack) {
// Identical objects are equal. `0 === -0`, but they aren't identical.
// See the Harmony `egal` proposal: http://wiki.ecmascript.org/doku.php?id=harmony:egal.
if (a === b) return a !== 0 || 1 / a == 1 / b;
// A strict comparison is necessary because `null == undefined`.
if (a == null || b == null) return a === b;
// Unwrap any wrapped objects.
if (a._chain) a = a._wrapped;
if (b._chain) b = b._wrapped;
// Invoke a custom `isEqual` method if one is provided.
if (a.isEqual && _.isFunction(a.isEqual)) return a.isEqual(b);
if (b.isEqual && _.isFunction(b.isEqual)) return b.isEqual(a);
// Compare `[[Class]]` names.
var className = toString.call(a);
if (className != toString.call(b)) return false;
switch (className) {
// Strings, numbers, dates, and booleans are compared by value.
case '[object String]':
// Primitives and their corresponding object wrappers are equivalent; thus, `"5"` is
// equivalent to `new String("5")`.
return a == String(b);
case '[object Number]':
// `NaN`s are equivalent, but non-reflexive. An `egal` comparison is performed for
// other numeric values.
return a != +a ? b != +b : (a == 0 ? 1 / a == 1 / b : a == +b);
case '[object Date]':
case '[object Boolean]':
// Coerce dates and booleans to numeric primitive values. Dates are compared by their
// millisecond representations. Note that invalid dates with millisecond representations
// of `NaN` are not equivalent.
return +a == +b;
// RegExps are compared by their source patterns and flags.
case '[object RegExp]':
return a.source == b.source &&
a.global == b.global &&
a.multiline == b.multiline &&
a.ignoreCase == b.ignoreCase;
}
if (typeof a != 'object' || typeof b != 'object') return false;
// Assume equality for cyclic structures. The algorithm for detecting cyclic
// structures is adapted from ES 5.1 section 15.12.3, abstract operation `JO`.
var length = stack.length;
while (length--) {
// Linear search. Performance is inversely proportional to the number of
// unique nested structures.
if (stack[length] == a) return true;
}
// Add the first object to the stack of traversed objects.
stack.push(a);
var size = 0, result = true;
// Recursively compare objects and arrays.
if (className == '[object Array]') {
// Compare array lengths to determine if a deep comparison is necessary.
size = a.length;
result = size == b.length;
if (result) {
// Deep compare the contents, ignoring non-numeric properties.
while (size--) {
// Ensure commutative equality for sparse arrays.
if (!(result = size in a == size in b && eq(a[size], b[size], stack))) break;
}
}
} else {
// Objects with different constructors are not equivalent.
if ('constructor' in a != 'constructor' in b || a.constructor != b.constructor) return false;
// Deep compare objects.
for (var key in a) {
if (_.has(a, key)) {
// Count the expected number of properties.
size++;
// Deep compare each member.
if (!(result = _.has(b, key) && eq(a[key], b[key], stack))) break;
}
}
// Ensure that both objects contain the same number of properties.
if (result) {
for (key in b) {
if (_.has(b, key) && !(size--)) break;
}
result = !size;
}
}
// Remove the first object from the stack of traversed objects.
stack.pop();
return result;
}
// Perform a deep comparison to check if two objects are equal.
_.isEqual = function(a, b) {
return eq(a, b, []);
};
// Is a given array, string, or object empty?
// An "empty" object has no enumerable own-properties.
_.isEmpty = function(obj) {
if (_.isArray(obj) || _.isString(obj)) return obj.length === 0;
for (var key in obj) if (_.has(obj, key)) return false;
return true;
};
// Is a given value a DOM element?
_.isElement = function(obj) {
return !!(obj && obj.nodeType == 1);
};
// Is a given value an array?
// Delegates to ECMA5's native Array.isArray
_.isArray = nativeIsArray || function(obj) {
return toString.call(obj) == '[object Array]';
};
// Is a given variable an object?
_.isObject = function(obj) {
return obj === Object(obj);
};
// Is a given variable an arguments object?
_.isArguments = function(obj) {
return toString.call(obj) == '[object Arguments]';
};
if (!_.isArguments(arguments)) {
_.isArguments = function(obj) {
return !!(obj && _.has(obj, 'callee'));
};
}
// Is a given value a function?
_.isFunction = function(obj) {
return toString.call(obj) == '[object Function]';
};
// Is a given value a string?
_.isString = function(obj) {
return toString.call(obj) == '[object String]';
};
// Is a given value a number?
_.isNumber = function(obj) {
return toString.call(obj) == '[object Number]';
};
// Is the given value `NaN`?
_.isNaN = function(obj) {
// `NaN` is the only value for which `===` is not reflexive.
return obj !== obj;
};
// Is a given value a boolean?
_.isBoolean = function(obj) {
return obj === true || obj === false || toString.call(obj) == '[object Boolean]';
};
// Is a given value a date?
_.isDate = function(obj) {
return toString.call(obj) == '[object Date]';
};
// Is the given value a regular expression?
_.isRegExp = function(obj) {
return toString.call(obj) == '[object RegExp]';
};
// Is a given value equal to null?
_.isNull = function(obj) {
return obj === null;
};
// Is a given variable undefined?
_.isUndefined = function(obj) {
return obj === void 0;
};
// Has own property?
_.has = function(obj, key) {
return hasOwnProperty.call(obj, key);
};
// Utility Functions
// -----------------
// Run Underscore.js in *noConflict* mode, returning the `_` variable to its
// previous owner. Returns a reference to the Underscore object.
_.noConflict = function() {
root._ = previousUnderscore;
return this;
};
// Keep the identity function around for default iterators.
_.identity = function(value) {
return value;
};
// Run a function **n** times.
_.times = function (n, iterator, context) {
for (var i = 0; i < n; i++) iterator.call(context, i);
};
// Escape a string for HTML interpolation.
_.escape = function(string) {
return (''+string).replace(/&/g, '&amp;').replace(/</g, '&lt;').replace(/>/g, '&gt;').replace(/"/g, '&quot;').replace(/'/g, '&#x27;').replace(/\//g,'&#x2F;');
};
// Add your own custom functions to the Underscore object, ensuring that
// they're correctly added to the OOP wrapper as well.
_.mixin = function(obj) {
each(_.functions(obj), function(name){
addToWrapper(name, _[name] = obj[name]);
});
};
// Generate a unique integer id (unique within the entire client session).
// Useful for temporary DOM ids.
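  // e.g. the first call to _.uniqueId('contact_') returns 'contact_0',
  // the next one 'contact_1', and so on.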
var idCounter = 0;
_.uniqueId = function(prefix) {
var id = idCounter++;
return prefix ? prefix + id : id;
};
// By default, Underscore uses ERB-style template delimiters, change the
// following template settings to use alternative delimiters.
_.templateSettings = {
evaluate : /<%([\s\S]+?)%>/g,
interpolate : /<%=([\s\S]+?)%>/g,
escape : /<%-([\s\S]+?)%>/g
};
// When customizing `templateSettings`, if you don't want to define an
// interpolation, evaluation or escaping regex, we need one that is
// guaranteed not to match.
var noMatch = /.^/;
// Within an interpolation, evaluation, or escaping, remove HTML escaping
// that had been previously added.
var unescape = function(code) {
return code.replace(/\\\\/g, '\\').replace(/\\'/g, "'");
};
// JavaScript micro-templating, similar to John Resig's implementation.
// Underscore templating handles arbitrary delimiters, preserves whitespace,
// and correctly escapes quotes within interpolated code.
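  // e.g. _.template('Hello <%= name %>!', {name: 'world'}) returns 'Hello world!'.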
_.template = function(str, data) {
var c = _.templateSettings;
var tmpl = 'var __p=[],print=function(){__p.push.apply(__p,arguments);};' +
'with(obj||{}){__p.push(\'' +
str.replace(/\\/g, '\\\\')
.replace(/'/g, "\\'")
.replace(c.escape || noMatch, function(match, code) {
return "',_.escape(" + unescape(code) + "),'";
})
.replace(c.interpolate || noMatch, function(match, code) {
return "'," + unescape(code) + ",'";
})
.replace(c.evaluate || noMatch, function(match, code) {
return "');" + unescape(code).replace(/[\r\n\t]/g, ' ') + ";__p.push('";
})
.replace(/\r/g, '\\r')
.replace(/\n/g, '\\n')
.replace(/\t/g, '\\t')
+ "');}return __p.join('');";
var func = new Function('obj', '_', tmpl);
if (data) return func(data, _);
return function(data) {
return func.call(this, data, _);
};
};
// Add a "chain" function, which will delegate to the wrapper.
_.chain = function(obj) {
return _(obj).chain();
};
// The OOP Wrapper
// ---------------
// If Underscore is called as a function, it returns a wrapped object that
// can be used OO-style. This wrapper holds altered versions of all the
// underscore functions. Wrapped objects may be chained.
var wrapper = function(obj) { this._wrapped = obj; };
// Expose `wrapper.prototype` as `_.prototype`
_.prototype = wrapper.prototype;
// Helper function to continue chaining intermediate results.
var result = function(obj, chain) {
return chain ? _(obj).chain() : obj;
};
// A method to easily add functions to the OOP wrapper.
var addToWrapper = function(name, func) {
wrapper.prototype[name] = function() {
var args = slice.call(arguments);
unshift.call(args, this._wrapped);
return result(func.apply(_, args), this._chain);
};
};
// Add all of the Underscore functions to the wrapper object.
_.mixin(_);
// Add all mutator Array functions to the wrapper.
each(['pop', 'push', 'reverse', 'shift', 'sort', 'splice', 'unshift'], function(name) {
var method = ArrayProto[name];
wrapper.prototype[name] = function() {
var wrapped = this._wrapped;
method.apply(wrapped, arguments);
var length = wrapped.length;
if ((name == 'shift' || name == 'splice') && length === 0) delete wrapped[0];
return result(wrapped, this._chain);
};
});
// Add all accessor Array functions to the wrapper.
each(['concat', 'join', 'slice'], function(name) {
var method = ArrayProto[name];
wrapper.prototype[name] = function() {
return result(method.apply(this._wrapped, arguments), this._chain);
};
});
// Start chaining a wrapped Underscore object.
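  // e.g. _([1, 2, 3]).chain().reverse().value() returns [3, 2, 1].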
wrapper.prototype.chain = function() {
this._chain = true;
return this;
};
// Extracts the result from a wrapped and chained object.
wrapper.prototype.value = function() {
return this._wrapped;
};
}).call(this);

View File

@@ -0,0 +1,697 @@
<!DOCTYPE html>
<html class="writer-html5" lang="en" >
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Vector Addition &mdash; Triton documentation</title>
<link rel="stylesheet" href="../_static/css/theme.css" type="text/css" />
<link rel="stylesheet" href="../_static/pygments.css" type="text/css" />
<!--[if lt IE 9]>
<script src="../_static/js/html5shiv.min.js"></script>
<![endif]-->
<script type="text/javascript" id="documentation_options" data-url_root="../" src="../_static/documentation_options.js"></script>
<script src="../_static/jquery.js"></script>
<script src="../_static/underscore.js"></script>
<script src="../_static/doctools.js"></script>
<script crossorigin="anonymous" integrity="sha256-Ae2Vz/4ePdIu6ZyI/5ZGsYnb+m0JlOmKPjt6XZ9JJkA=" src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js"></script>
<script async="async" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/x-mathjax-config">MathJax.Hub.Config({"tex2jax": {"inlineMath": [["$", "$"], ["\\(", "\\)"]], "processEscapes": true, "ignoreClass": "document", "processClass": "math|output_area"}})</script>
<script type="text/javascript" src="../_static/js/theme.js"></script>
<link rel="index" title="Index" href="../genindex.html" />
<link rel="search" title="Search" href="../search.html" />
<link rel="next" title="Fused Softmax" href="02-fused-softmax.html" />
<link rel="prev" title="From Source" href="../installation/from-source.html" />
</head>
<body class="wy-body-for-nav">
<div class="wy-grid-for-nav">
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
<div class="wy-side-scroll">
<div class="wy-side-nav-search" >
<a href="../index.html" class="icon icon-home"> Triton
</a>
<div role="search">
<form id="rtd-search-form" class="wy-form" action="../search.html" method="get">
<input type="text" name="q" placeholder="Search docs" />
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form>
</div>
</div>
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
<p class="caption"><span class="caption-text">Installation Instructions</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../installation/packaged-binaries.html">Packaged Binaries</a></li>
<li class="toctree-l1"><a class="reference internal" href="../installation/from-source.html">From Source</a></li>
</ul>
<p class="caption"><span class="caption-text">Tutorials</span></p>
<ul class="current">
<li class="toctree-l1 current"><a class="current reference internal" href="#">Vector Addition</a><ul>
<li class="toctree-l2"><a class="reference internal" href="#Writing-the-Compute-Kernel">Writing the Compute Kernel</a></li>
<li class="toctree-l2"><a class="reference internal" href="#Writing-the-Torch-bindings">Writing the Torch bindings</a></li>
<li class="toctree-l2"><a class="reference internal" href="#Writing-a-Unit-Test">Writing a Unit Test</a></li>
<li class="toctree-l2"><a class="reference internal" href="#Writing-a-Benchmark">Writing a Benchmark</a></li>
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="02-fused-softmax.html">Fused Softmax</a></li>
</ul>
</div>
</div>
</nav>
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
<nav class="wy-nav-top" aria-label="top navigation">
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
<a href="../index.html">Triton</a>
</nav>
<div class="wy-nav-content">
<div class="rst-content">
<div role="navigation" aria-label="breadcrumbs navigation">
<ul class="wy-breadcrumbs">
<li><a href="../index.html" class="icon icon-home"></a> &raquo;</li>
<li>Vector Addition</li>
<li class="wy-breadcrumbs-aside">
<a href="../_sources/tutorials/01-vector-add.ipynb.txt" rel="nofollow"> View page source</a>
</li>
</ul>
<hr/>
</div>
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
<div itemprop="articleBody">
<style>
/* CSS for nbsphinx extension */
/* remove conflicting styling from Sphinx themes */
div.nbinput.container div.prompt *,
div.nboutput.container div.prompt *,
div.nbinput.container div.input_area pre,
div.nboutput.container div.output_area pre,
div.nbinput.container div.input_area .highlight,
div.nboutput.container div.output_area .highlight {
border: none;
padding: 0;
margin: 0;
box-shadow: none;
}
div.nbinput.container > div[class*=highlight],
div.nboutput.container > div[class*=highlight] {
margin: 0;
}
div.nbinput.container div.prompt *,
div.nboutput.container div.prompt * {
background: none;
}
div.nboutput.container div.output_area .highlight,
div.nboutput.container div.output_area pre {
background: unset;
}
div.nboutput.container div.output_area div.highlight {
color: unset; /* override Pygments text color */
}
/* avoid gaps between output lines */
div.nboutput.container div[class*=highlight] pre {
line-height: normal;
}
/* input/output containers */
div.nbinput.container,
div.nboutput.container {
display: -webkit-flex;
display: flex;
align-items: flex-start;
margin: 0;
width: 100%;
}
@media (max-width: 540px) {
div.nbinput.container,
div.nboutput.container {
flex-direction: column;
}
}
/* input container */
div.nbinput.container {
padding-top: 5px;
}
/* last container */
div.nblast.container {
padding-bottom: 5px;
}
/* input prompt */
div.nbinput.container div.prompt pre {
color: #307FC1;
}
/* output prompt */
div.nboutput.container div.prompt pre {
color: #BF5B3D;
}
/* all prompts */
div.nbinput.container div.prompt,
div.nboutput.container div.prompt {
width: 4.5ex;
padding-top: 5px;
position: relative;
user-select: none;
}
div.nbinput.container div.prompt > div,
div.nboutput.container div.prompt > div {
position: absolute;
right: 0;
margin-right: 0.3ex;
}
@media (max-width: 540px) {
div.nbinput.container div.prompt,
div.nboutput.container div.prompt {
width: unset;
text-align: left;
padding: 0.4em;
}
div.nboutput.container div.prompt.empty {
padding: 0;
}
div.nbinput.container div.prompt > div,
div.nboutput.container div.prompt > div {
position: unset;
}
}
/* disable scrollbars on prompts */
div.nbinput.container div.prompt pre,
div.nboutput.container div.prompt pre {
overflow: hidden;
}
/* input/output area */
div.nbinput.container div.input_area,
div.nboutput.container div.output_area {
-webkit-flex: 1;
flex: 1;
overflow: auto;
}
@media (max-width: 540px) {
div.nbinput.container div.input_area,
div.nboutput.container div.output_area {
width: 100%;
}
}
/* input area */
div.nbinput.container div.input_area {
border: 1px solid #e0e0e0;
border-radius: 2px;
/*background: #f5f5f5;*/
}
/* override MathJax center alignment in output cells */
div.nboutput.container div[class*=MathJax] {
text-align: left !important;
}
/* override sphinx.ext.imgmath center alignment in output cells */
div.nboutput.container div.math p {
text-align: left;
}
/* standard error */
div.nboutput.container div.output_area.stderr {
background: #fdd;
}
/* ANSI colors */
.ansi-black-fg { color: #3E424D; }
.ansi-black-bg { background-color: #3E424D; }
.ansi-black-intense-fg { color: #282C36; }
.ansi-black-intense-bg { background-color: #282C36; }
.ansi-red-fg { color: #E75C58; }
.ansi-red-bg { background-color: #E75C58; }
.ansi-red-intense-fg { color: #B22B31; }
.ansi-red-intense-bg { background-color: #B22B31; }
.ansi-green-fg { color: #00A250; }
.ansi-green-bg { background-color: #00A250; }
.ansi-green-intense-fg { color: #007427; }
.ansi-green-intense-bg { background-color: #007427; }
.ansi-yellow-fg { color: #DDB62B; }
.ansi-yellow-bg { background-color: #DDB62B; }
.ansi-yellow-intense-fg { color: #B27D12; }
.ansi-yellow-intense-bg { background-color: #B27D12; }
.ansi-blue-fg { color: #208FFB; }
.ansi-blue-bg { background-color: #208FFB; }
.ansi-blue-intense-fg { color: #0065CA; }
.ansi-blue-intense-bg { background-color: #0065CA; }
.ansi-magenta-fg { color: #D160C4; }
.ansi-magenta-bg { background-color: #D160C4; }
.ansi-magenta-intense-fg { color: #A03196; }
.ansi-magenta-intense-bg { background-color: #A03196; }
.ansi-cyan-fg { color: #60C6C8; }
.ansi-cyan-bg { background-color: #60C6C8; }
.ansi-cyan-intense-fg { color: #258F8F; }
.ansi-cyan-intense-bg { background-color: #258F8F; }
.ansi-white-fg { color: #C5C1B4; }
.ansi-white-bg { background-color: #C5C1B4; }
.ansi-white-intense-fg { color: #A1A6B2; }
.ansi-white-intense-bg { background-color: #A1A6B2; }
.ansi-default-inverse-fg { color: #FFFFFF; }
.ansi-default-inverse-bg { background-color: #000000; }
.ansi-bold { font-weight: bold; }
.ansi-underline { text-decoration: underline; }
div.nbinput.container div.input_area div[class*=highlight] > pre,
div.nboutput.container div.output_area div[class*=highlight] > pre,
div.nboutput.container div.output_area div[class*=highlight].math,
div.nboutput.container div.output_area.rendered_html,
div.nboutput.container div.output_area > div.output_javascript,
div.nboutput.container div.output_area:not(.rendered_html) > img{
padding: 5px;
margin: 0;
}
/* fix copybtn overflow problem in chromium (needed for 'sphinx_copybutton') */
div.nbinput.container div.input_area > div[class^='highlight'],
div.nboutput.container div.output_area > div[class^='highlight']{
overflow-y: hidden;
}
/* hide copybtn icon on prompts (needed for 'sphinx_copybutton') */
.prompt a.copybtn {
display: none;
}
/* Some additional styling taken from the Jupyter notebook CSS */
div.rendered_html table {
border: none;
border-collapse: collapse;
border-spacing: 0;
color: black;
font-size: 12px;
table-layout: fixed;
}
div.rendered_html thead {
border-bottom: 1px solid black;
vertical-align: bottom;
}
div.rendered_html tr,
div.rendered_html th,
div.rendered_html td {
text-align: right;
vertical-align: middle;
padding: 0.5em 0.5em;
line-height: normal;
white-space: normal;
max-width: none;
border: none;
}
div.rendered_html th {
font-weight: bold;
}
div.rendered_html tbody tr:nth-child(odd) {
background: #f5f5f5;
}
div.rendered_html tbody tr:hover {
background: rgba(66, 165, 245, 0.2);
}
/* CSS overrides for sphinx_rtd_theme */
/* 24px margin */
.nbinput.nblast.container,
.nboutput.nblast.container {
margin-bottom: 19px; /* padding has already 5px */
}
/* ... except between code cells! */
.nblast.container + .nbinput.container {
margin-top: -19px;
}
.admonition > p:before {
margin-right: 4px; /* make room for the exclamation icon */
}
/* Fix math alignment, see https://github.com/rtfd/sphinx_rtd_theme/pull/686 */
.math {
text-align: unset;
}
</style>
<div class="section" id="Vector-Addition">
<h1>Vector Addition<a class="headerlink" href="#Vector-Addition" title="Permalink to this headline"></a></h1>
<p>In this tutorial, we will see how to construct a simple, high-performance vector addition using Triton. You will learn:</p>
<ul class="simple">
<li>The basic syntax of the Triton programming language</li>
<li>The best practices for creating PyTorch custom operators using the <code class="docutils literal notranslate"><span class="pre">triton.kernel</span></code> Python API</li>
<li>The best practices for validating and benchmarking custom ops against native reference implementations</li>
</ul>
<div class="section" id="Writing-the-Compute-Kernel">
<h2>Writing the Compute Kernel<a class="headerlink" href="#Writing-the-Compute-Kernel" title="Permalink to this headline"></a></h2>
<p>Each compute kernel is declared using the <code class="docutils literal notranslate"><span class="pre">__global__</span></code> attribute, and executed many times in parallel on different chunks of data (see the <a class="reference external" href="https://en.wikipedia.org/wiki/SPMD">Single Program, Multiple Data</a> programming model for more details).</p>
<div class="highlight-c notranslate"><div class="highlight"><pre><span></span><span class="n">__global__</span> <span class="kt">void</span> <span class="n">add</span><span class="p">(</span><span class="kt">float</span><span class="o">*</span> <span class="n">z</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">x</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">y</span><span class="p">,</span> <span class="kt">int</span> <span class="n">N</span><span class="p">){</span>
<span class="c1">// The `get_program_id(i)` returns the i-th coordinate</span>
<span class="c1">// of the program in the overaching SPMD context</span>
<span class="c1">// (a.k.a launch grid). This is what allows us to process</span>
<span class="c1">// different chunks of data in parallel.</span>
<span class="c1">// For those similar with CUDA, `get_program_id({0,1,2})`</span>
<span class="c1">// is similar to blockIdx.{x,y,z}</span>
<span class="kt">int</span> <span class="n">pid</span> <span class="o">=</span> <span class="n">get_program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">);</span>
<span class="c1">// In Triton, arrays are first-class citizen. In other words,</span>
<span class="c1">// they are primitives data-types and are -- contrary to C and</span>
<span class="c1">// CUDA -- not implemented as pointers to contiguous chunks of</span>
<span class="c1">// memory.</span>
<span class="c1">// In the few lines below, we create an array of `BLOCK` pointers</span>
<span class="c1">// whose memory values are, e.g.:</span>
<span class="c1">// [z + pid*BLOCK + 0, z + pid*BLOCK + 1, ..., z + pid*BLOCK + BLOCK - 1]</span>
<span class="c1">// Note: here BLOCK is expected to be a pre-processor macro defined at compile-time</span>
<span class="kt">int</span> <span class="n">offset</span><span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">pid</span> <span class="o">*</span> <span class="n">BLOCK</span> <span class="o">+</span> <span class="mi">0</span> <span class="p">...</span> <span class="n">BLOCK</span><span class="p">;</span>
<span class="kt">float</span><span class="o">*</span> <span class="n">pz</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">z</span> <span class="o">+</span> <span class="n">offset</span><span class="p">;</span>
<span class="kt">float</span><span class="o">*</span> <span class="n">px</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">offset</span><span class="p">;</span>
<span class="kt">float</span><span class="o">*</span> <span class="n">py</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">y</span> <span class="o">+</span> <span class="n">offset</span><span class="p">;</span>
<span class="c1">// Simple element-wise control-flow for load/store operations can</span>
<span class="c1">// be achieved using the the ternary operator `cond ? val_true : val_false`</span>
<span class="c1">// or the conditional dereferencing operator `*?(cond)ptr</span>
<span class="c1">// Here, we make sure that we do not access memory out-of-bounds when we</span>
<span class="c1">// write-back `z`</span>
<span class="kt">bool</span> <span class="n">check</span><span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">offset</span> <span class="o">&lt;</span> <span class="n">N</span><span class="p">;</span>
<span class="o">*?</span><span class="p">(</span><span class="n">check</span><span class="p">)</span><span class="n">pz</span> <span class="o">=</span> <span class="o">*?</span><span class="p">(</span><span class="n">check</span><span class="p">)</span><span class="n">px</span> <span class="o">+</span> <span class="o">*?</span><span class="p">(</span><span class="n">check</span><span class="p">)</span><span class="n">py</span><span class="p">;</span>
<span class="p">}</span>
</pre></div>
</div>
<p>The existence of arrays as a primitive data-type for Triton comes with a number of advantages that are highlighted in the <a class="reference external" href="http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf">MAPL'2019 Triton paper</a>.</p>
</div>
<div class="section" id="Writing-the-Torch-bindings">
<h2>Writing the Torch bindings<a class="headerlink" href="#Writing-the-Torch-bindings" title="Permalink to this headline"></a></h2>
<p>The only thing that matters when it comes to Triton and Torch is the <code class="docutils literal notranslate"><span class="pre">triton.kernel</span></code> class. This allows you to transform the above C-like function into a callable python object that can be used to modify <code class="docutils literal notranslate"><span class="pre">torch.tensor</span></code> objects.</p>
<p>To create a <code class="docutils literal notranslate"><span class="pre">triton.kernel</span></code>, you only need three things:</p>
<ul class="simple">
<li><code class="docutils literal notranslate"><span class="pre">source: string</span></code>: the source-code of the kernel you want to create</li>
<li><code class="docutils literal notranslate"><span class="pre">device: torch.device</span></code>: the device you want to compile this code for</li>
<li><code class="docutils literal notranslate"><span class="pre">defines: dict</span></code>: the set of macros that you want the pre-processor to <code class="docutils literal notranslate"><span class="pre">#define</span></code> for you</li>
</ul>
<p>Note: The constructor of <code class="docutils literal notranslate"><span class="pre">triton.kernel</span></code> does some just-in-time compilation, so expect some overhead there. For this reason, I personally like to initialize kernels lazily in a cache (see <code class="docutils literal notranslate"><span class="pre">_kernels</span></code> variable below). This also makes it possible to choose the compilation device dynamically based on the type of the operator's inputs.</p>
<div class="nbinput nblast docutils container">
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[10]:
</pre></div>
</div>
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre>
<span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">triton</span>
<span class="c1"># source-code for Triton compute kernel</span>
<span class="c1"># here we just copy-paste the above code without the extensive comments.</span>
<span class="c1"># you may prefer to store it in a .c file and load it from there instead.</span>
<span class="n">_src</span> <span class="o">=</span> <span class="s2">&quot;&quot;&quot;</span>
<span class="s2">__global__ void add(float* z, float* x, float* y, int N){</span>
<span class="s2"> // program id</span>
<span class="s2"> int pid = get_program_id(0);</span>
<span class="s2"> // create arrays of pointers</span>
<span class="s2"> int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;</span>
<span class="s2"> float* pz[BLOCK] = z + offset;</span>
<span class="s2"> float* px[BLOCK] = x + offset;</span>
<span class="s2"> float* py[BLOCK] = y + offset;</span>
<span class="s2"> // bounds checking</span>
<span class="s2"> bool check[BLOCK] = offset &lt; N;</span>
<span class="s2"> // write-back</span>
<span class="s2"> *?(check)pz = *?(check)px + *?(check)py;</span>
<span class="s2">}</span>
<span class="s2"> &quot;&quot;&quot;</span>
<span class="c1"># This function returns a callable `triton.kernel` object</span>
<span class="c1"># created from the above source code.</span>
<span class="c1"># For portability, we maintain a cache of kernels for different `torch.device`</span>
<span class="c1"># We compile the kernel with -DBLOCK=1024</span>
<span class="n">_kernels</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span>
<span class="k">def</span> <span class="nf">make_add_kernel</span><span class="p">(</span><span class="n">device</span><span class="p">):</span>
<span class="k">if</span> <span class="n">device</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">_kernels</span><span class="p">:</span>
<span class="n">defines</span> <span class="o">=</span> <span class="p">{</span><span class="s1">&#39;BLOCK&#39;</span><span class="p">:</span> <span class="mi">1024</span><span class="p">}</span>
<span class="n">_kernels</span><span class="p">[</span><span class="n">device</span><span class="p">]</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">kernel</span><span class="p">(</span><span class="n">_src</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">defines</span><span class="o">=</span><span class="n">defines</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_kernels</span><span class="p">[</span><span class="n">device</span><span class="p">]</span>
<span class="c1"># This is a standard torch custom autograd Function</span>
<span class="c1"># The only difference is that we can now use the above kernel</span>
<span class="c1"># in the `forward` and `backward` functions.`</span>
<span class="k">class</span> <span class="nc">_add</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">autograd</span><span class="o">.</span><span class="n">Function</span><span class="p">):</span>
<span class="nd">@staticmethod</span>
<span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
<span class="c1"># constraints of the op</span>
<span class="k">assert</span> <span class="n">x</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span>
<span class="c1"># *allocate output*</span>
<span class="n">z</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="c1"># *create launch grid*:</span>
<span class="c1"># this is a function which takes compilation parameters `opt`</span>
<span class="c1"># as input and returns a tuple of int (i.e., launch grid) for the kernel.</span>
<span class="c1"># triton.cdiv is a shortcut for ceil division:</span>
<span class="c1"># triton.cdiv(a, b) = (a + b - 1) // b</span>
<span class="n">N</span> <span class="o">=</span> <span class="n">z</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">grid</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">opt</span><span class="p">:</span> <span class="p">(</span><span class="n">triton</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">opt</span><span class="o">.</span><span class="n">BLOCK</span><span class="p">),</span> <span class="p">)</span>
<span class="c1"># *launch kernel*:</span>
<span class="c1"># pointer to the data of torch tensors can be retrieved with</span>
<span class="c1"># the `.data_ptr()` method</span>
<span class="n">kernel</span> <span class="o">=</span> <span class="n">make_add_kernel</span><span class="p">(</span><span class="n">z</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">kernel</span><span class="p">(</span><span class="n">z</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">(),</span> <span class="n">x</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">(),</span> <span class="n">y</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">(),</span> <span class="n">N</span><span class="p">,</span> <span class="n">grid</span> <span class="o">=</span> <span class="n">grid</span><span class="p">)</span>
<span class="k">return</span> <span class="n">z</span>
<span class="c1"># Just like we standard PyTorch ops</span>
<span class="c1"># We use the `.apply` method to create a</span>
<span class="c1"># callable object for our function</span>
<span class="n">add</span> <span class="o">=</span> <span class="n">_add</span><span class="o">.</span><span class="n">apply</span>
</pre></div>
</div>
</div>
<p>At this point <code class="docutils literal notranslate"><span class="pre">add(x,</span> <span class="pre">y)</span></code> is equivalent to <code class="docutils literal notranslate"><span class="pre">x</span> <span class="pre">+</span> <span class="pre">y</span></code> for contiguous tensors. Now let's test and benchmark it!</p>
</div>
<div class="section" id="Writing-a-Unit-Test">
<h2>Writing a Unit Test<a class="headerlink" href="#Writing-a-Unit-Test" title="Permalink to this headline"></a></h2>
<div class="nbinput docutils container">
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[9]:
</pre></div>
</div>
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre>
<span></span><span class="n">torch</span><span class="o">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">98432</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">98432</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="n">za</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span>
<span class="n">zb</span> <span class="o">=</span> <span class="n">add</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">za</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">zb</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;The maximum difference between torch and triton is &#39;</span>
<span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">abs</span><span class="p">(</span><span class="n">za</span> <span class="o">-</span> <span class="n">zb</span><span class="p">))</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
</pre></div>
</div>
</div>
<div class="nboutput nblast docutils container">
<div class="prompt empty docutils container">
</div>
<div class="output_area docutils container">
<div class="highlight"><pre>
tensor([1.3713, 1.3076, 0.4940, ..., 0.6682, 1.1984, 1.2696], device=&#39;cuda:0&#39;)
tensor([1.3713, 1.3076, 0.4940, ..., 0.6682, 1.1984, 1.2696], device=&#39;cuda:0&#39;)
The maximum difference between torch and triton is 0.0
</pre></div></div>
</div>
<p>Seems to work!</p>
</div>
<div class="section" id="Writing-a-Benchmark">
<h2>Writing a Benchmark<a class="headerlink" href="#Writing-a-Benchmark" title="Permalink to this headline"></a></h2>
<p>The performance of our GPU code can be benchmarked using the <code class="docutils literal notranslate"><span class="pre">torch.cuda.Event(enable_timing=True)</span></code> wrapper. Below is a simple function that benchmarks <code class="docutils literal notranslate"><span class="pre">rep</span></code> runs of our kernels after <code class="docutils literal notranslate"><span class="pre">warmup</span></code> “cold” runs.</p>
<div class="nbinput nblast docutils container">
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[11]:
</pre></div>
</div>
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre>
<span></span><span class="c1"># We now want to benchmark the performance of `add`</span>
<span class="c1"># Against that of PyTorch for increasing vector sizes</span>
<span class="k">def</span> <span class="nf">do_bench</span><span class="p">(</span><span class="n">fn</span><span class="p">,</span> <span class="n">warmup</span> <span class="o">=</span> <span class="mi">10</span><span class="p">,</span> <span class="n">rep</span> <span class="o">=</span> <span class="mi">50</span><span class="p">):</span>
<span class="n">start_event</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">Event</span><span class="p">(</span><span class="n">enable_timing</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">end_event</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">Event</span><span class="p">(</span><span class="n">enable_timing</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">ret</span> <span class="o">=</span> <span class="n">fn</span><span class="p">()</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">warmup</span><span class="p">):</span>
<span class="n">fn</span><span class="p">()</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">synchronize</span><span class="p">()</span>
<span class="n">start_event</span><span class="o">.</span><span class="n">record</span><span class="p">()</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">rep</span><span class="p">):</span>
<span class="n">fn</span><span class="p">()</span>
<span class="n">end_event</span><span class="o">.</span><span class="n">record</span><span class="p">()</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">synchronize</span><span class="p">()</span>
<span class="n">time_ms</span> <span class="o">=</span> <span class="n">start_event</span><span class="o">.</span><span class="n">elapsed_time</span><span class="p">(</span><span class="n">end_event</span><span class="p">)</span> <span class="o">/</span> <span class="n">rep</span>
<span class="k">return</span> <span class="n">time_ms</span>
</pre></div>
</div>
</div>
<p>We can now benchmark our custom op for vectors of increasing sizes to get a sense of how it does.</p>
<div class="nbinput docutils container">
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[15]:
</pre></div>
</div>
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre>
<span></span><span class="k">for</span> <span class="n">N</span> <span class="ow">in</span> <span class="p">[</span><span class="mi">2</span><span class="o">**</span><span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">17</span><span class="p">,</span> <span class="mi">26</span><span class="p">,</span> <span class="mi">1</span><span class="p">)]:</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="n">triton_ms</span> <span class="o">=</span> <span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">add</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">))</span>
<span class="n">torch_ms</span> <span class="o">=</span> <span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span><span class="p">)</span>
<span class="c1"># print the performance of triton and torch as well as the achieved bandwidth</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">N</span><span class="si">}</span><span class="s1"> </span><span class="si">{</span><span class="n">triton_ms</span><span class="si">:</span><span class="s1">.3f</span><span class="si">}</span><span class="s1"> </span><span class="si">{</span><span class="n">torch_ms</span><span class="si">:</span><span class="s1">.3f</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
</pre></div>
</div>
</div>
<div class="nboutput nblast docutils container">
<div class="prompt empty docutils container">
</div>
<div class="output_area docutils container">
<div class="highlight"><pre>
131072 0.020 0.003
262144 0.019 0.004
524288 0.016 0.016
1048576 0.033 0.033
2097152 0.071 0.070
4194304 0.142 0.144
8388608 0.287 0.286
16777216 0.572 0.568
33554432 1.139 1.110
</pre></div></div>
</div>
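<p>Since each call of <code class="docutils literal notranslate"><span class="pre">add</span></code> reads two float32 vectors and writes one, roughly <code class="docutils literal notranslate"><span class="pre">3 * N * 4</span></code> bytes move per call, so the timings above can be converted into an effective-bandwidth estimate. Below is a minimal sketch (a hypothetical <code class="docutils literal notranslate"><span class="pre">gbps</span></code> helper, not part of the notebook above), assuming the latencies returned by <code class="docutils literal notranslate"><span class="pre">do_bench</span></code>:</p>
<div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span># hypothetical helper: 3 tensors (x, y, z) of N float32 (4-byte) elements move per call
def gbps(time_ms, N):
    return 3 * N * 4 * 1e-9 / (time_ms * 1e-3)
</pre></div></div>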
<p>Our op is on-par with Torch's vectorized element-wise kernel when the vectors are large enough. One caveat is that the latency of PyTorch is much smaller for small vectors (3us vs 18-20us). This is something we are actively working to reduce.</p>
</div>
</div>
</div>
</div>
<footer>
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
<a href="02-fused-softmax.html" class="btn btn-neutral float-right" title="Fused Softmax" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
<a href="../installation/from-source.html" class="btn btn-neutral float-left" title="From Source" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
</div>
<hr/>
<div role="contentinfo">
<p>
&#169; Copyright 2020, Philippe Tillet.
</p>
</div>
Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
<a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
</div>
</div>
</section>
</div>
<script type="text/javascript">
jQuery(function () {
SphinxRtdTheme.Navigation.enable(true);
});
</script>
</body>
</html>

View File

@@ -0,0 +1,329 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "acute-possession",
"metadata": {},
"source": [
"# Vector Addition"
]
},
{
"cell_type": "markdown",
"id": "median-malaysia",
"metadata": {},
"source": [
"In this tutorial, we will see how to construct a simple, high-performance vector addition using Triton. You will learn:\n",
"* The basic syntax of the Triton programming language\n",
"* The best practices for creating PyTorch custom operators using the `triton.kernel` Python API\n",
"* The best practices for validating and benchmarking custom ops against native reference implementations"
]
},
{
"cell_type": "markdown",
"id": "identical-conditions",
"metadata": {},
"source": [
"## Writing the Compute Kernel"
]
},
{
"cell_type": "markdown",
"id": "collectible-belle",
"metadata": {},
"source": [
"Each compute kernel is declared using the `__global__` attribute, and executed many times in parallel on different chunks of data (See the [Single Program, Multiple Data](https://en.wikipedia.org/wiki/SPMD) programming model for more details).\n",
"\n",
"\n",
"```c\n",
"__global__ void add(float* z, float* x, float* y, int N){\n",
" // The `get_program_id(i)` returns the i-th coordinate\n",
" // of the program in the overaching SPMD context\n",
" // (a.k.a launch grid). This is what allows us to process\n",
" // different chunks of data in parallel.\n",
" // For those similar with CUDA, `get_program_id({0,1,2})`\n",
" // is similar to blockIdx.{x,y,z}\n",
" int pid = get_program_id(0);\n",
" // In Triton, arrays are first-class citizen. In other words,\n",
" // they are primitives data-types and are -- contrary to C and\n",
" // CUDA -- not implemented as pointers to contiguous chunks of\n",
" // memory.\n",
" // In the few lines below, we create an array of `BLOCK` pointers\n",
" // whose memory values are, e.g.:\n",
" // [z + pid*BLOCK + 0, z + pid*BLOCK + 1, ..., z + pid*BLOCK + BLOCK - 1]\n",
" // Note: here BLOCK is expected to be a pre-processor macro defined at compile-time\n",
" int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;\n",
" float* pz [BLOCK] = z + offset;\n",
" float* px [BLOCK] = x + offset;\n",
" float* py [BLOCK] = y + offset;\n",
" // Simple element-wise control-flow for load/store operations can\n",
" // be achieved using the the ternary operator `cond ? val_true : val_false`\n",
" // or the conditional dereferencing operator `*?(cond)ptr\n",
" // Here, we make sure that we do not access memory out-of-bounds when we\n",
" // write-back `z`\n",
" bool check[BLOCK] = offset < N;\n",
" *?(check)pz = *?(check)px + *?(check)py;\n",
"}\n",
"```\n",
"\n",
"The existence of arrays as a primitive data-type for Triton comes with a number of advantages that are highlighted in the [MAPL'2019 Triton paper](http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf)."
]
},
{
"cell_type": "markdown",
"id": "forbidden-wednesday",
"metadata": {},
"source": [
"## Writing the Torch bindings"
]
},
{
"cell_type": "markdown",
"id": "numerical-agency",
"metadata": {},
"source": [
"The only thing that matters when it comes to Triton and Torch is the `triton.kernel` class. This allows you to transform the above C-like function into a callable python object that can be used to modify `torch.tensor` objects.\n",
"\n",
"To create a `triton.kernel`, you only need three things:\n",
"* `source: string`: the source-code of the kernel you want to create\n",
"* `device: torch.device`: the device you want to compile this code for\n",
"* `defines: dict`: the set of macros that you want the pre-processor to `#define` for you\n",
"\n",
"Note: The constructor of `triton.kernel` does some just-in-time compilation, so expect some overhead there. For this reason, I personally like to initialize kernels lazily in a cache (see `_kernels` variable below). This also makes it possible to choose the compilation device dynamically based on the type of the operator's inputs."
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "sporting-keyboard",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import triton\n",
"\n",
"# source-code for Triton compute kernel\n",
"# here we just copy-paste the above code without the extensive comments.\n",
"# you may prefer to store it in a .c file and load it from there instead.\n",
"_src = \"\"\"\n",
"__global__ void add(float* z, float* x, float* y, int N){\n",
" // program id\n",
" int pid = get_program_id(0);\n",
" // create arrays of pointers\n",
" int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;\n",
" float* pz[BLOCK] = z + offset;\n",
" float* px[BLOCK] = x + offset;\n",
" float* py[BLOCK] = y + offset;\n",
" // bounds checking\n",
" bool check[BLOCK] = offset < N;\n",
" // write-back\n",
" *?(check)pz = *?(check)px + *?(check)py;\n",
"}\n",
" \"\"\"\n",
"# This function returns a callable `triton.kernel` object\n",
"# created from the above source code.\n",
"# For portability, we maintain a cache of kernels for different `torch.device`\n",
"# We compile the kernel with -DBLOCK=1024\n",
"_kernels = dict()\n",
"def make_add_kernel(device):\n",
" if device not in _kernels:\n",
" defines = {'BLOCK': 1024}\n",
" _kernels[device] = triton.kernel(_src, device=device, defines=defines)\n",
" return _kernels[device]\n",
"\n",
"# This is a standard torch custom autograd Function\n",
"# The only difference is that we can now use the above kernel\n",
"# in the `forward` and `backward` functions.`\n",
"class _add(torch.autograd.Function):\n",
" \n",
" @staticmethod\n",
" def forward(ctx, x, y):\n",
" # constraints of the op\n",
" assert x.dtype == torch.float32\n",
" # *allocate output*\n",
" z = torch.empty_like(x)\n",
" # *create launch grid*:\n",
" # this is a function which takes compilation parameters `opt`\n",
" # as input and returns a tuple of int (i.e., launch grid) for the kernel.\n",
" # triton.cdiv is a shortcut for ceil division:\n",
" # triton.cdiv(a, b) = (a + b - 1) // b\n",
" N = z.shape[0]\n",
" grid = lambda opt: (triton.cdiv(N, opt.BLOCK), )\n",
" # *launch kernel*:\n",
" # pointer to the data of torch tensors can be retrieved with\n",
" # the `.data_ptr()` method\n",
" kernel = make_add_kernel(z.device)\n",
" kernel(z.data_ptr(), x.data_ptr(), y.data_ptr(), N, grid = grid)\n",
" return z\n",
"# Just like we standard PyTorch ops\n",
"# We use the `.apply` method to create a \n",
"# callable object for our function\n",
"add = _add.apply"
]
},
{
"cell_type": "markdown",
"id": "separated-polyester",
"metadata": {},
"source": [
"At this point `add(x, y)` is equivalent to `x + y` for contiguous tensors. Now let's test and benchmark it!"
]
},
{
"cell_type": "markdown",
"id": "exclusive-salvation",
"metadata": {},
"source": [
"## Writing a Unit Test"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "supported-ribbon",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([1.3713, 1.3076, 0.4940, ..., 0.6682, 1.1984, 1.2696], device='cuda:0')\n",
"tensor([1.3713, 1.3076, 0.4940, ..., 0.6682, 1.1984, 1.2696], device='cuda:0')\n",
"The maximum difference between torch and triton is 0.0\n"
]
}
],
"source": [
"torch.manual_seed(0)\n",
"x = torch.rand(98432, device='cuda')\n",
"y = torch.rand(98432, device='cuda')\n",
"za = x + y\n",
"zb = add(x, y)\n",
"print(za)\n",
"print(zb)\n",
"print(f'The maximum difference between torch and triton is '\n",
" f'{torch.max(torch.abs(za - zb))}')"
]
},
{
"cell_type": "markdown",
"id": "otherwise-canadian",
"metadata": {},
"source": [
"Seems to work!"
]
},
{
"cell_type": "markdown",
"id": "polished-australia",
"metadata": {},
"source": [
"## Writing a Benchmark"
]
},
{
"cell_type": "markdown",
"id": "historic-glass",
"metadata": {},
"source": [
"The performance of our GPU code can be benchmark using the `torch.cuda.Event(enable_timing=True)` wrapper. Below is a simple function that benchmarks `rep` runs of our kernels after `warmup` \"cold\" runs."
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "strange-luxembourg",
"metadata": {},
"outputs": [],
"source": [
"# We now want to benchmark the performance of `add`\n",
"# Against that of PyTorch for increasing vector sizes\n",
"def do_bench(fn, warmup = 10, rep = 50):\n",
" start_event = torch.cuda.Event(enable_timing=True)\n",
" end_event = torch.cuda.Event(enable_timing=True)\n",
" ret = fn()\n",
" for i in range(warmup):\n",
" fn()\n",
" torch.cuda.synchronize()\n",
" start_event.record()\n",
" for i in range(rep):\n",
" fn()\n",
" end_event.record()\n",
" torch.cuda.synchronize()\n",
" time_ms = start_event.elapsed_time(end_event) / rep\n",
" return time_ms"
]
},
{
"cell_type": "markdown",
"id": "hairy-claim",
"metadata": {},
"source": [
"We can now benchmark our custom op for vectors of increasing sizes to get a sense of how it does"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "pleasant-valley",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"131072 0.020 0.003\n",
"262144 0.019 0.004\n",
"524288 0.016 0.016\n",
"1048576 0.033 0.033\n",
"2097152 0.071 0.070\n",
"4194304 0.142 0.144\n",
"8388608 0.287 0.286\n",
"16777216 0.572 0.568\n",
"33554432 1.139 1.110\n"
]
}
],
"source": [
"for N in [2**i for i in range(17, 26, 1)]:\n",
" x = torch.rand(N, device='cuda')\n",
" y = torch.rand(N, device='cuda')\n",
" triton_ms = do_bench(lambda: add(x, y))\n",
" torch_ms = do_bench(lambda: x + y)\n",
" # print the performance of triton and torch as well as the achieved bandwidth\n",
" print(f'{N} {triton_ms:.3f} {torch_ms:.3f}')"
]
},
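{
"cell_type": "markdown",
"id": "bandwidth-estimate",
"metadata": {},
"source": [
"Since each call reads two float32 vectors and writes one, roughly `3 * N * 4` bytes move per call, so these timings can be converted into an effective-bandwidth estimate. A minimal sketch (a hypothetical `gbps` helper, not part of the code above): `gbps = lambda time_ms, N: 3 * N * 4 * 1e-9 / (time_ms * 1e-3)`."
]
},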
{
"cell_type": "markdown",
"id": "juvenile-supplement",
"metadata": {},
"source": [
"Our op is on-par with Torch's vectorized element-wise kernel when the vectors are large enough. One caveat is that the latency of PyTorch is much smaller for small vectors (3us vs 18-20us). This is something we are actively working on to reduce."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,705 @@
<!DOCTYPE html>
<html class="writer-html5" lang="en" >
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Fused Softmax &mdash; Triton documentation</title>
<link rel="stylesheet" href="../_static/css/theme.css" type="text/css" />
<link rel="stylesheet" href="../_static/pygments.css" type="text/css" />
<!--[if lt IE 9]>
<script src="../_static/js/html5shiv.min.js"></script>
<![endif]-->
<script type="text/javascript" id="documentation_options" data-url_root="../" src="../_static/documentation_options.js"></script>
<script src="../_static/jquery.js"></script>
<script src="../_static/underscore.js"></script>
<script src="../_static/doctools.js"></script>
<script crossorigin="anonymous" integrity="sha256-Ae2Vz/4ePdIu6ZyI/5ZGsYnb+m0JlOmKPjt6XZ9JJkA=" src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js"></script>
<script async="async" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/x-mathjax-config">MathJax.Hub.Config({"tex2jax": {"inlineMath": [["$", "$"], ["\\(", "\\)"]], "processEscapes": true, "ignoreClass": "document", "processClass": "math|output_area"}})</script>
<script type="text/javascript" src="../_static/js/theme.js"></script>
<link rel="index" title="Index" href="../genindex.html" />
<link rel="search" title="Search" href="../search.html" />
<link rel="prev" title="Vector Addition" href="01-vector-add.html" />
</head>
<body class="wy-body-for-nav">
<div class="wy-grid-for-nav">
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
<div class="wy-side-scroll">
<div class="wy-side-nav-search" >
<a href="../index.html" class="icon icon-home"> Triton
</a>
<div role="search">
<form id="rtd-search-form" class="wy-form" action="../search.html" method="get">
<input type="text" name="q" placeholder="Search docs" />
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form>
</div>
</div>
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
<p class="caption"><span class="caption-text">Installation Instructions</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../installation/packaged-binaries.html">Packaged Binaries</a></li>
<li class="toctree-l1"><a class="reference internal" href="../installation/from-source.html">From Source</a></li>
</ul>
<p class="caption"><span class="caption-text">Tutorials</span></p>
<ul class="current">
<li class="toctree-l1"><a class="reference internal" href="01-vector-add.html">Vector Addition</a></li>
<li class="toctree-l1 current"><a class="current reference internal" href="#">Fused Softmax</a><ul>
<li class="toctree-l2"><a class="reference internal" href="#Writing-the-Compute-Kernel">Writing the Compute Kernel</a></li>
<li class="toctree-l2"><a class="reference internal" href="#Writing-the-Torch-bindings">Writing the Torch bindings</a></li>
<li class="toctree-l2"><a class="reference internal" href="#Writing-a-Unit-Test">Writing a Unit Test</a></li>
<li class="toctree-l2"><a class="reference internal" href="#Writing-a-Benchmark">Writing a Benchmark</a></li>
</ul>
</li>
</ul>
</div>
</div>
</nav>
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
<nav class="wy-nav-top" aria-label="top navigation">
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
<a href="../index.html">Triton</a>
</nav>
<div class="wy-nav-content">
<div class="rst-content">
<div role="navigation" aria-label="breadcrumbs navigation">
<ul class="wy-breadcrumbs">
<li><a href="../index.html" class="icon icon-home"></a> &raquo;</li>
<li>Fused Softmax</li>
<li class="wy-breadcrumbs-aside">
<a href="../_sources/tutorials/02-fused-softmax.ipynb.txt" rel="nofollow"> View page source</a>
</li>
</ul>
<hr/>
</div>
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
<div itemprop="articleBody">
<style>
/* CSS for nbsphinx extension */
/* remove conflicting styling from Sphinx themes */
div.nbinput.container div.prompt *,
div.nboutput.container div.prompt *,
div.nbinput.container div.input_area pre,
div.nboutput.container div.output_area pre,
div.nbinput.container div.input_area .highlight,
div.nboutput.container div.output_area .highlight {
border: none;
padding: 0;
margin: 0;
box-shadow: none;
}
div.nbinput.container > div[class*=highlight],
div.nboutput.container > div[class*=highlight] {
margin: 0;
}
div.nbinput.container div.prompt *,
div.nboutput.container div.prompt * {
background: none;
}
div.nboutput.container div.output_area .highlight,
div.nboutput.container div.output_area pre {
background: unset;
}
div.nboutput.container div.output_area div.highlight {
color: unset; /* override Pygments text color */
}
/* avoid gaps between output lines */
div.nboutput.container div[class*=highlight] pre {
line-height: normal;
}
/* input/output containers */
div.nbinput.container,
div.nboutput.container {
display: -webkit-flex;
display: flex;
align-items: flex-start;
margin: 0;
width: 100%;
}
@media (max-width: 540px) {
div.nbinput.container,
div.nboutput.container {
flex-direction: column;
}
}
/* input container */
div.nbinput.container {
padding-top: 5px;
}
/* last container */
div.nblast.container {
padding-bottom: 5px;
}
/* input prompt */
div.nbinput.container div.prompt pre {
color: #307FC1;
}
/* output prompt */
div.nboutput.container div.prompt pre {
color: #BF5B3D;
}
/* all prompts */
div.nbinput.container div.prompt,
div.nboutput.container div.prompt {
width: 4.5ex;
padding-top: 5px;
position: relative;
user-select: none;
}
div.nbinput.container div.prompt > div,
div.nboutput.container div.prompt > div {
position: absolute;
right: 0;
margin-right: 0.3ex;
}
@media (max-width: 540px) {
div.nbinput.container div.prompt,
div.nboutput.container div.prompt {
width: unset;
text-align: left;
padding: 0.4em;
}
div.nboutput.container div.prompt.empty {
padding: 0;
}
div.nbinput.container div.prompt > div,
div.nboutput.container div.prompt > div {
position: unset;
}
}
/* disable scrollbars on prompts */
div.nbinput.container div.prompt pre,
div.nboutput.container div.prompt pre {
overflow: hidden;
}
/* input/output area */
div.nbinput.container div.input_area,
div.nboutput.container div.output_area {
-webkit-flex: 1;
flex: 1;
overflow: auto;
}
@media (max-width: 540px) {
div.nbinput.container div.input_area,
div.nboutput.container div.output_area {
width: 100%;
}
}
/* input area */
div.nbinput.container div.input_area {
border: 1px solid #e0e0e0;
border-radius: 2px;
/*background: #f5f5f5;*/
}
/* override MathJax center alignment in output cells */
div.nboutput.container div[class*=MathJax] {
text-align: left !important;
}
/* override sphinx.ext.imgmath center alignment in output cells */
div.nboutput.container div.math p {
text-align: left;
}
/* standard error */
div.nboutput.container div.output_area.stderr {
background: #fdd;
}
/* ANSI colors */
.ansi-black-fg { color: #3E424D; }
.ansi-black-bg { background-color: #3E424D; }
.ansi-black-intense-fg { color: #282C36; }
.ansi-black-intense-bg { background-color: #282C36; }
.ansi-red-fg { color: #E75C58; }
.ansi-red-bg { background-color: #E75C58; }
.ansi-red-intense-fg { color: #B22B31; }
.ansi-red-intense-bg { background-color: #B22B31; }
.ansi-green-fg { color: #00A250; }
.ansi-green-bg { background-color: #00A250; }
.ansi-green-intense-fg { color: #007427; }
.ansi-green-intense-bg { background-color: #007427; }
.ansi-yellow-fg { color: #DDB62B; }
.ansi-yellow-bg { background-color: #DDB62B; }
.ansi-yellow-intense-fg { color: #B27D12; }
.ansi-yellow-intense-bg { background-color: #B27D12; }
.ansi-blue-fg { color: #208FFB; }
.ansi-blue-bg { background-color: #208FFB; }
.ansi-blue-intense-fg { color: #0065CA; }
.ansi-blue-intense-bg { background-color: #0065CA; }
.ansi-magenta-fg { color: #D160C4; }
.ansi-magenta-bg { background-color: #D160C4; }
.ansi-magenta-intense-fg { color: #A03196; }
.ansi-magenta-intense-bg { background-color: #A03196; }
.ansi-cyan-fg { color: #60C6C8; }
.ansi-cyan-bg { background-color: #60C6C8; }
.ansi-cyan-intense-fg { color: #258F8F; }
.ansi-cyan-intense-bg { background-color: #258F8F; }
.ansi-white-fg { color: #C5C1B4; }
.ansi-white-bg { background-color: #C5C1B4; }
.ansi-white-intense-fg { color: #A1A6B2; }
.ansi-white-intense-bg { background-color: #A1A6B2; }
.ansi-default-inverse-fg { color: #FFFFFF; }
.ansi-default-inverse-bg { background-color: #000000; }
.ansi-bold { font-weight: bold; }
.ansi-underline { text-decoration: underline; }
div.nbinput.container div.input_area div[class*=highlight] > pre,
div.nboutput.container div.output_area div[class*=highlight] > pre,
div.nboutput.container div.output_area div[class*=highlight].math,
div.nboutput.container div.output_area.rendered_html,
div.nboutput.container div.output_area > div.output_javascript,
div.nboutput.container div.output_area:not(.rendered_html) > img{
padding: 5px;
margin: 0;
}
/* fix copybtn overflow problem in chromium (needed for 'sphinx_copybutton') */
div.nbinput.container div.input_area > div[class^='highlight'],
div.nboutput.container div.output_area > div[class^='highlight']{
overflow-y: hidden;
}
/* hide copybtn icon on prompts (needed for 'sphinx_copybutton') */
.prompt a.copybtn {
display: none;
}
/* Some additional styling taken form the Jupyter notebook CSS */
div.rendered_html table {
border: none;
border-collapse: collapse;
border-spacing: 0;
color: black;
font-size: 12px;
table-layout: fixed;
}
div.rendered_html thead {
border-bottom: 1px solid black;
vertical-align: bottom;
}
div.rendered_html tr,
div.rendered_html th,
div.rendered_html td {
text-align: right;
vertical-align: middle;
padding: 0.5em 0.5em;
line-height: normal;
white-space: normal;
max-width: none;
border: none;
}
div.rendered_html th {
font-weight: bold;
}
div.rendered_html tbody tr:nth-child(odd) {
background: #f5f5f5;
}
div.rendered_html tbody tr:hover {
background: rgba(66, 165, 245, 0.2);
}
/* CSS overrides for sphinx_rtd_theme */
/* 24px margin */
.nbinput.nblast.container,
.nboutput.nblast.container {
margin-bottom: 19px; /* padding has already 5px */
}
/* ... except between code cells! */
.nblast.container + .nbinput.container {
margin-top: -19px;
}
.admonition > p:before {
margin-right: 4px; /* make room for the exclamation icon */
}
/* Fix math alignment, see https://github.com/rtfd/sphinx_rtd_theme/pull/686 */
.math {
text-align: unset;
}
</style>
<div class="section" id="Fused-Softmax">
<h1>Fused Softmax<a class="headerlink" href="#Fused-Softmax" title="Permalink to this headline"></a></h1>
<p>Custom GPU kernels for elementwise additions are educationally valuable but won't get you very far in practice. Let us consider instead the case of a simple (numerically stabilized) softmax operation:</p>
<div class="nbinput nblast docutils container">
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[1]:
</pre></div>
</div>
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre>
<span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="c1"># Compute the row-wise softmax of x \in R^{M \times N}</span>
<span class="k">def</span> <span class="nf">naive_softmax</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
<span class="c1"># read MN elements ; write M elements</span>
<span class="n">x_max</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
<span class="c1"># read 2MN elements ; write MN elements</span>
<span class="n">z</span> <span class="o">=</span> <span class="n">x</span> <span class="o">-</span> <span class="n">x_max</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span>
<span class="c1"># read MN elements ; write MN elements</span>
<span class="n">numerator</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="c1"># read MN elements ; write M elements</span>
<span class="n">denominator</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">numerator</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="c1"># read 2MN elements ; write MN elements</span>
<span class="n">ret</span> <span class="o">=</span> <span class="n">numerator</span> <span class="o">/</span> <span class="n">denominator</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span>
<span class="c1"># in total: read 7MN elements ; wrote 3MN + 2M elements</span>
<span class="k">return</span> <span class="n">ret</span>
</pre></div>
</div>
</div>
<p>When implemented naively in PyTorch, computing <span class="math notranslate nohighlight">\(y\)</span> requires reading <span class="math notranslate nohighlight">\(7MN\)</span> elements from DRAM and writing back <span class="math notranslate nohighlight">\(3MN + 2M\)</span> elements.</p>
<p>Instead, we want to write a custom “fused” PyTorch operator that reads X only once and does all the necessary computations on-chip. This would require reading only <span class="math notranslate nohighlight">\(MN\)</span> elements and writing back <span class="math notranslate nohighlight">\(MN\)</span> elements, so we could expect a theoretical speed-up of about 5x. In practice, though, we expect less because our kernel will spend some time computing exponentials and moving data around in shared memory.</p>
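<p>As a quick back-of-the-envelope check of that 5x figure (a sketch; the example sizes are hypothetical, and the <span class="math notranslate nohighlight">\(2M\)</span> term is negligible for large <span class="math notranslate nohighlight">\(N\)</span>):</p>
<div class="highlight-python notranslate"><div class="highlight"><pre>
M, N = 4096, 4096                     # hypothetical example sizes
naive_traffic = 7*M*N + 3*M*N + 2*M   # elements read + written by naive_softmax
fused_traffic = M*N + M*N             # fused kernel: read X once, write Y once
print(naive_traffic / fused_traffic)  # ~5.0
</pre></div>
</div>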
<div class="section" id="Writing-the-Compute-Kernel">
<h2>Writing the Compute Kernel<a class="headerlink" href="#Writing-the-Compute-Kernel" title="Permalink to this headline"></a></h2>
<p>Our softmax kernel works as follows: each program loads a row of X and writes back a normalized row of Y. Note that one important limitation of Triton is that each block must have a power-of-two number of elements, which means that we need to guard the memory operations properly if we want to handle any possible input shapes:</p>
<div class="highlight-c notranslate"><div class="highlight"><pre><span></span><span class="n">__global__</span> <span class="kt">void</span> <span class="n">softmax</span><span class="p">(</span><span class="kt">float</span><span class="o">*</span> <span class="n">Y</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">X</span><span class="p">,</span> <span class="kt">int</span> <span class="n">stride_xm</span><span class="p">,</span> <span class="kt">int</span> <span class="n">stride_ym</span><span class="p">,</span> <span class="kt">int</span> <span class="n">M</span><span class="p">,</span> <span class="kt">int</span> <span class="n">N</span><span class="p">){</span>
<span class="c1">// row index</span>
<span class="kt">int</span> <span class="n">m</span> <span class="o">=</span> <span class="n">get_program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">);</span>
<span class="c1">// column indices</span>
<span class="kt">int</span> <span class="n">n</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span> <span class="p">...</span> <span class="n">BLOCK</span><span class="p">;</span>
<span class="c1">// the memory address of all the elements</span>
<span class="c1">// that we want to load can be computed as follows</span>
<span class="kt">float</span><span class="o">*</span> <span class="n">px</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">X</span> <span class="o">+</span> <span class="n">m</span><span class="o">*</span><span class="n">stride_xm</span> <span class="o">+</span> <span class="n">n</span><span class="p">;</span>
<span class="c1">// because BLOCK has to be a power of two</span>
<span class="c1">// (per Triton-C specs), it is important</span>
<span class="c1">// to guard each memory operation with predicates</span>
<span class="c1">// or we will read out of bounds</span>
<span class="kt">bool</span> <span class="n">check</span><span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">n</span> <span class="o">&lt;</span> <span class="n">N</span><span class="p">;</span>
<span class="kt">float</span> <span class="n">x</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">check</span> <span class="o">?</span> <span class="o">*</span><span class="nl">px</span> <span class="p">:</span> <span class="o">-</span><span class="n">F32_INFINITY</span><span class="p">;</span>
<span class="c1">// syntax for reduction in Triton is:</span>
<span class="c1">// x[..., OPERATOR, ...]</span>
<span class="c1">// ^</span>
<span class="c1">// index</span>
<span class="c1">// The operators currently supported are {min, max, +}</span>
<span class="kt">float</span> <span class="n">z</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">x</span> <span class="o">-</span> <span class="n">x</span><span class="p">[</span><span class="n">max</span><span class="p">];</span>
<span class="c1">// The exponential in Triton is fast but approximate</span>
<span class="c1">// (i.e., like __expf in CUDA)</span>
<span class="kt">float</span> <span class="n">num</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">exp</span><span class="p">(</span><span class="n">z</span><span class="p">);</span>
<span class="kt">float</span> <span class="n">denom</span> <span class="o">=</span> <span class="n">num</span><span class="p">[</span><span class="o">+</span><span class="p">];</span>
<span class="c1">// The result of the reduction is now stored in y</span>
<span class="kt">float</span> <span class="n">y</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">num</span> <span class="o">/</span> <span class="n">denom</span><span class="p">;</span>
<span class="c1">// We write it back</span>
<span class="kt">float</span><span class="o">*</span> <span class="n">py</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">Y</span> <span class="o">+</span> <span class="n">m</span><span class="o">*</span><span class="n">stride_ym</span> <span class="o">+</span> <span class="n">n</span><span class="p">;</span>
<span class="o">*?</span><span class="p">(</span><span class="n">check</span><span class="p">)</span><span class="n">py</span> <span class="o">=</span> <span class="n">y</span><span class="p">;</span>
<span class="p">}</span>
</pre></div>
</div>
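<p>For readers who think in PyTorch, each program instance above conceptually computes something like the following row-wise sketch (an analogy only; it ignores the block/predication machinery, and the helper name is illustrative):</p>
<div class="highlight-python notranslate"><div class="highlight"><pre>
def softmax_row(X, Y, m, N):
    # what program instance `m` computes, conceptually
    row = X[m, :N]               # predicated load; masked lanes read -inf
    z = row - row.max()          # x - x[max]: a max-reduction, then a subtraction
    num = z.exp()                # fast, approximate exponential in Triton
    Y[m, :N] = num / num.sum()   # num[+] is a sum-reduction
</pre></div>
</div>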
</div>
<div class="section" id="Writing-the-Torch-bindings">
<h2>Writing the Torch bindings<a class="headerlink" href="#Writing-the-Torch-bindings" title="Permalink to this headline"></a></h2>
<div class="nbinput nblast docutils container">
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[2]:
</pre></div>
</div>
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre>
<span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">triton</span>
<span class="c1"># source-code for Triton compute kernel</span>
<span class="n">_src</span> <span class="o">=</span> <span class="s2">&quot;&quot;&quot;</span>
<span class="s2">__global__ void softmax(float* Y, float* X, int stride_ym, int stride_xm, int M, int N){</span>
<span class="s2"> int m = get_program_id(0);</span>
<span class="s2"> int n [BLOCK] = 0 ... BLOCK;</span>
<span class="s2"> float* px [BLOCK] = X + m*stride_xm + n;</span>
<span class="s2"> bool check[BLOCK] = n &lt; N;</span>
<span class="s2"> float x [BLOCK] = check ? *px : -F32_INFINITY;</span>
<span class="s2"> float z [BLOCK] = x - x[max];</span>
<span class="s2"> float num [BLOCK] = exp(z);</span>
<span class="s2"> float denom = num[+];</span>
<span class="s2"> float y [BLOCK] = num / denom;</span>
<span class="s2"> float* py [BLOCK] = Y + m*stride_ym + n;</span>
<span class="s2"> *?(check)py = y;</span>
<span class="s2">}</span>
<span class="s2">&quot;&quot;&quot;</span>
<span class="c1"># We need to make sure that BLOCK is the smallest power of two</span>
<span class="c1"># greater than the number of rows N of the input matrix.</span>
<span class="c1"># Different values of BLOCK will result in different kernels</span>
<span class="k">def</span> <span class="nf">next_power_of_2</span><span class="p">(</span><span class="n">n</span><span class="p">):</span>
<span class="n">n</span> <span class="o">-=</span> <span class="mi">1</span>
<span class="n">n</span> <span class="o">|=</span> <span class="n">n</span> <span class="o">&gt;&gt;</span> <span class="mi">1</span>
<span class="n">n</span> <span class="o">|=</span> <span class="n">n</span> <span class="o">&gt;&gt;</span> <span class="mi">2</span>
<span class="n">n</span> <span class="o">|=</span> <span class="n">n</span> <span class="o">&gt;&gt;</span> <span class="mi">4</span>
<span class="n">n</span> <span class="o">|=</span> <span class="n">n</span> <span class="o">&gt;&gt;</span> <span class="mi">8</span>
<span class="n">n</span> <span class="o">|=</span> <span class="n">n</span> <span class="o">&gt;&gt;</span> <span class="mi">16</span>
<span class="n">n</span> <span class="o">+=</span> <span class="mi">1</span>
<span class="k">return</span> <span class="n">n</span>
<span class="n">_kernels</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span>
<span class="k">def</span> <span class="nf">make_kernel</span><span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">device</span><span class="p">):</span>
<span class="n">BLOCK</span> <span class="o">=</span> <span class="n">next_power_of_2</span><span class="p">(</span><span class="n">N</span><span class="p">)</span>
<span class="n">key</span> <span class="o">=</span> <span class="p">(</span><span class="n">BLOCK</span><span class="p">,</span> <span class="n">device</span><span class="p">)</span>
<span class="k">if</span> <span class="n">key</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">_kernels</span><span class="p">:</span>
<span class="n">defines</span> <span class="o">=</span> <span class="p">{</span><span class="s1">&#39;BLOCK&#39;</span><span class="p">:</span> <span class="n">BLOCK</span><span class="p">}</span>
<span class="n">_kernels</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">kernel</span><span class="p">(</span><span class="n">_src</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">defines</span><span class="o">=</span><span class="n">defines</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_kernels</span><span class="p">[</span><span class="n">key</span><span class="p">]</span>
<span class="k">class</span> <span class="nc">_softmax</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">autograd</span><span class="o">.</span><span class="n">Function</span><span class="p">):</span>
<span class="nd">@staticmethod</span>
<span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
<span class="c1"># constraints of the op</span>
<span class="k">assert</span> <span class="n">x</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="c1"># *create launch grid*:</span>
<span class="c1"># here we just launch a grid of M programs</span>
<span class="n">M</span><span class="p">,</span> <span class="n">N</span> <span class="o">=</span> <span class="n">y</span><span class="o">.</span><span class="n">shape</span>
<span class="n">grid</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">opt</span><span class="p">:</span> <span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="p">)</span>
<span class="c1"># *launch kernel*:</span>
<span class="n">kernel</span> <span class="o">=</span> <span class="n">make_kernel</span><span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">y</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">kernel</span><span class="p">(</span><span class="n">y</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">(),</span> <span class="n">x</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">(),</span> <span class="n">y</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">x</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">grid</span> <span class="o">=</span> <span class="n">grid</span><span class="p">)</span>
<span class="k">return</span> <span class="n">y</span>
<span class="n">softmax</span> <span class="o">=</span> <span class="n">_softmax</span><span class="o">.</span><span class="n">apply</span>
</pre></div>
</div>
</div>
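<p>As a quick illustration of the rounding and caching behavior (hypothetical widths; one kernel is compiled and cached per (BLOCK, device) pair):</p>
<div class="highlight-python notranslate"><div class="highlight"><pre>
print(next_power_of_2(781))    # 1024
print(next_power_of_2(1024))   # 1024: powers of two are left unchanged
# inputs of width 781 and 1000 therefore share the same BLOCK=1024 kernel,
# which is compiled once and then reused from the _kernels cache
</pre></div>
</div>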
</div>
<div class="section" id="Writing-a-Unit-Test">
<h2>Writing a Unit Test<a class="headerlink" href="#Writing-a-Unit-Test" title="Permalink to this headline"></a></h2>
<div class="nbinput docutils container">
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[3]:
</pre></div>
</div>
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre>
<span></span><span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">1823</span><span class="p">,</span> <span class="mi">781</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="n">y_tri</span> <span class="o">=</span> <span class="n">softmax</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="n">y_ref</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">y_tri</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">y_ref</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">allclose</span><span class="p">(</span><span class="n">y_tri</span><span class="p">,</span> <span class="n">y_ref</span><span class="p">))</span>
</pre></div>
</div>
</div>
<div class="nboutput nblast docutils container">
<div class="prompt empty docutils container">
</div>
<div class="output_area docutils container">
<div class="highlight"><pre>
tensor([[0.0004, 0.0006, 0.0004, ..., 0.0005, 0.0004, 0.0010],
[0.0003, 0.0029, 0.0004, ..., 0.0007, 0.0017, 0.0004],
[0.0002, 0.0006, 0.0005, ..., 0.0028, 0.0009, 0.0003],
...,
[0.0017, 0.0005, 0.0010, ..., 0.0006, 0.0004, 0.0001],
[0.0010, 0.0006, 0.0001, ..., 0.0006, 0.0017, 0.0014],
[0.0037, 0.0012, 0.0006, ..., 0.0003, 0.0005, 0.0003]],
device=&#39;cuda:0&#39;)
tensor([[0.0004, 0.0006, 0.0004, ..., 0.0005, 0.0004, 0.0010],
[0.0003, 0.0029, 0.0004, ..., 0.0007, 0.0017, 0.0004],
[0.0002, 0.0006, 0.0005, ..., 0.0028, 0.0009, 0.0003],
...,
[0.0017, 0.0005, 0.0010, ..., 0.0006, 0.0004, 0.0001],
[0.0010, 0.0006, 0.0001, ..., 0.0006, 0.0017, 0.0014],
[0.0037, 0.0012, 0.0006, ..., 0.0003, 0.0005, 0.0003]],
device=&#39;cuda:0&#39;)
True
</pre></div></div>
</div>
<p>Seems to work!</p>
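<p>As an optional extra sanity check (a sketch under the same setup), a power-of-two width, where no masking is needed, should pass as well:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre>
x2 = torch.randn(128, 1024, device='cuda')
print(torch.allclose(softmax(x2), torch.softmax(x2, axis=1)))  # expected: True
</pre></div>
</div>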
</div>
<div class="section" id="Writing-a-Benchmark">
<h2>Writing a Benchmark<a class="headerlink" href="#Writing-a-Benchmark" title="Permalink to this headline"></a></h2>
<div class="nbinput docutils container">
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[4]:
</pre></div>
</div>
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre>
<span></span><span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span>
<span class="n">M</span> <span class="o">=</span> <span class="mi">4096</span>
<span class="n">Ns</span> <span class="o">=</span> <span class="p">[</span><span class="mi">128</span><span class="o">*</span><span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">50</span><span class="p">)]</span>
<span class="n">tri_ms</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">ref_ms</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">def_ms</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">N</span> <span class="ow">in</span> <span class="n">Ns</span><span class="p">:</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">gbps</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">ms</span><span class="p">:</span> <span class="n">x</span><span class="o">.</span><span class="n">nelement</span><span class="p">()</span> <span class="o">*</span> <span class="n">x</span><span class="o">.</span><span class="n">element_size</span><span class="p">()</span> <span class="o">*</span> <span class="mf">1e-9</span> <span class="o">/</span> <span class="p">(</span><span class="n">ms</span> <span class="o">*</span> <span class="mf">1e-3</span><span class="p">)</span>
<span class="n">tri_ms</span> <span class="o">+=</span> <span class="p">[</span><span class="n">gbps</span><span class="p">(</span><span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">softmax</span><span class="p">(</span><span class="n">x</span><span class="p">)))]</span>
<span class="n">ref_ms</span> <span class="o">+=</span> <span class="p">[</span><span class="n">gbps</span><span class="p">(</span><span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)))]</span>
<span class="n">def_ms</span> <span class="o">+=</span> <span class="p">[</span><span class="n">gbps</span><span class="p">(</span><span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">naive_softmax</span><span class="p">(</span><span class="n">x</span><span class="p">)))]</span>
<span class="n">plt</span><span class="o">.</span><span class="n">xlabel</span><span class="p">(</span><span class="s1">&#39;N&#39;</span><span class="p">)</span>
<span class="n">plt</span><span class="o">.</span><span class="n">ylabel</span><span class="p">(</span><span class="s1">&#39;Bandwidth (GB/s)&#39;</span><span class="p">)</span>
<span class="n">plt</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">Ns</span><span class="p">,</span> <span class="n">tri_ms</span><span class="p">,</span> <span class="n">label</span> <span class="o">=</span> <span class="s1">&#39;Triton&#39;</span><span class="p">)</span>
<span class="n">plt</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">Ns</span><span class="p">,</span> <span class="n">ref_ms</span><span class="p">,</span> <span class="n">label</span> <span class="o">=</span> <span class="s1">&#39;Torch&#39;</span><span class="p">)</span>
<span class="n">plt</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">Ns</span><span class="p">,</span> <span class="n">def_ms</span><span class="p">,</span> <span class="n">label</span> <span class="o">=</span> <span class="s1">&#39;Naive&#39;</span><span class="p">)</span>
<span class="n">plt</span><span class="o">.</span><span class="n">legend</span><span class="p">()</span>
<span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span>
</pre></div>
</div>
</div>
<div class="nboutput nblast docutils container">
<div class="prompt empty docutils container">
</div>
<div class="output_area docutils container">
<img alt="../_images/tutorials_02-fused-softmax_12_0.png" src="../_images/tutorials_02-fused-softmax_12_0.png" />
</div>
</div>
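<p>Note that the gbps lambda above only counts the bytes of X read per call; since even the fused kernel also writes MN elements back, the plotted values are a consistent relative metric rather than the absolute DRAM bandwidth achieved. A variant that counts both directions might look like this (a sketch; gbps_rw is an illustrative name):</p>
<div class="highlight-python notranslate"><div class="highlight"><pre>
# count both the read of X and the write of Y (illustrative sketch)
gbps_rw = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)
</pre></div>
</div>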
</div>
</div>
</div>
</div>
<footer>
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
<a href="01-vector-add.html" class="btn btn-neutral float-left" title="Vector Addition" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
</div>
<hr/>
<div role="contentinfo">
<p>
&#169; Copyright 2020, Philippe Tillet.
</p>
</div>
Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
<a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
</div>
</div>
</section>
</div>
<script type="text/javascript">
jQuery(function () {
SphinxRtdTheme.Navigation.enable(true);
});
</script>
</body>
</html>

File diff suppressed because one or more lines are too long