<!DOCTYPE html>
< html class = "writer-html5" lang = "en" >
< head >
< meta charset = "utf-8" / >
< meta name = "viewport" content = "width=device-width, initial-scale=1.0" / >
< title > Vector Addition — Triton documentation< / title >
< link rel = "stylesheet" href = "../../_static/css/theme.css" type = "text/css" / >
< link rel = "stylesheet" href = "../../_static/pygments.css" type = "text/css" / >
< link rel = "stylesheet" href = "../../_static/gallery.css" type = "text/css" / >
< link rel = "stylesheet" href = "../../_static/gallery-binder.css" type = "text/css" / >
< link rel = "stylesheet" href = "../../_static/gallery-dataframe.css" type = "text/css" / >
< link rel = "stylesheet" href = "../../_static/gallery-rendered-html.css" type = "text/css" / >
< link rel = "stylesheet" href = "../../_static/css/custom.css" type = "text/css" / >
<!-- [if lt IE 9]>
< script src = "../../_static/js/html5shiv.min.js" > < / script >
<![endif]-->
< script type = "text/javascript" id = "documentation_options" data-url_root = "../../" src = "../../_static/documentation_options.js" > < / script >
< script src = "../../_static/jquery.js" > < / script >
< script src = "../../_static/underscore.js" > < / script >
< script src = "../../_static/doctools.js" > < / script >
< script type = "text/javascript" src = "../../_static/js/theme.js" > < / script >
< link rel = "index" title = "Index" href = "../../genindex.html" / >
< link rel = "search" title = "Search" href = "../../search.html" / >
< link rel = "next" title = "Fused Softmax" href = "02-fused-softmax.html" / >
< link rel = "prev" title = "Tutorials" href = "index.html" / >
< / head >
< body class = "wy-body-for-nav" >
< div class = "wy-grid-for-nav" >
< nav data-toggle = "wy-nav-shift" class = "wy-nav-side" >
< div class = "wy-side-scroll" >
< div class = "wy-side-nav-search" >
< a href = "../../index.html" class = "icon icon-home" > Triton
< / a >
< div role = "search" >
< form id = "rtd-search-form" class = "wy-form" action = "../../search.html" method = "get" >
< input type = "text" name = "q" placeholder = "Search docs" / >
< input type = "hidden" name = "check_keywords" value = "yes" / >
< input type = "hidden" name = "area" value = "default" / >
< / form >
< / div >
< / div >
< div class = "wy-menu wy-menu-vertical" data-spy = "affix" role = "navigation" aria-label = "main navigation" >
< p class = "caption" > < span class = "caption-text" > Getting Started< / span > < / p >
< ul class = "current" >
< li class = "toctree-l1" > < a class = "reference internal" href = "../installation.html" > Installation< / a > < / li >
< li class = "toctree-l1 current" > < a class = "reference internal" href = "index.html" > Tutorials< / a > < ul class = "current" >
< li class = "toctree-l2 current" > < a class = "current reference internal" href = "#" > Vector Addition< / a > < ul >
< li class = "toctree-l3" > < a class = "reference internal" href = "#compute-kernel" > Compute Kernel< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "#torch-bindings" > Torch Bindings< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "#unit-test" > Unit Test< / a > < / li >
< li class = "toctree-l3" > < a class = "reference internal" href = "#benchmark" > Benchmark< / a > < / li >
< / ul >
< / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "02-fused-softmax.html" > Fused Softmax< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "03-matrix-multiplication.html" > Matrix Multiplication< / a > < / li >
< / ul >
< / li >
< / ul >
< p class = "caption" > < span class = "caption-text" > Programming Guide< / span > < / p >
< ul >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../programming-guide/chapter-1/introduction.html" > Introduction< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../programming-guide/chapter-2/related-work.html" > Related Work< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../programming-guide/chapter-3/triton-c.html" > The Triton-C Language< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../programming-guide/chapter-4/triton-ir.html" > The Triton-IR Intermediate Representation< / a > < / li >
< / ul >
< / div >
< / div >
< / nav >
< section data-toggle = "wy-nav-shift" class = "wy-nav-content-wrap" >
< nav class = "wy-nav-top" aria-label = "top navigation" >
< i data-toggle = "wy-nav-top" class = "fa fa-bars" > < / i >
< a href = "../../index.html" > Triton< / a >
< / nav >
< div class = "wy-nav-content" >
< div class = "rst-content" >
< div role = "navigation" aria-label = "breadcrumbs navigation" >
< ul class = "wy-breadcrumbs" >
< li > < a href = "../../index.html" class = "icon icon-home" > < / a > » < / li >
< li > < a href = "index.html" > Tutorials< / a > » < / li >
< li > Vector Addition< / li >
< li class = "wy-breadcrumbs-aside" >
< a href = "../../_sources/getting-started/tutorials/01-vector-add.rst.txt" rel = "nofollow" > View page source< / a >
< / li >
< / ul >
< hr / >
< / div >
< div role = "main" class = "document" itemscope = "itemscope" itemtype = "http://schema.org/Article" >
< div itemprop = "articleBody" >
< div class = "sphx-glr-download-link-note admonition note" >
< p class = "admonition-title" > Note< / p >
< p > Click < a class = "reference internal" href = "#sphx-glr-download-getting-started-tutorials-01-vector-add-py" > < span class = "std std-ref" > here< / span > < / a >
to download the full example code< / p >
< / div >
< div class = "sphx-glr-example-title section" id = "vector-addition" >
< span id = "sphx-glr-getting-started-tutorials-01-vector-add-py" > < / span > < h1 > Vector Addition< a class = "headerlink" href = "#vector-addition" title = "Permalink to this headline" > ¶< / a > < / h1 >
< p > In this tutorial, you will write a simple vector addition using Triton and learn about:< / p >
< ul class = "simple" >
< li > < p > The basic syntax of the Triton programming language< / p > < / li >
< li > < p > The best practices for creating PyTorch custom operators using the < code class = "code docutils literal notranslate" > < span class = "pre" > triton.kernel< / span > < / code > Python API< / p > < / li >
< li > < p > The best practices for validating and benchmarking custom ops against native reference implementations< / p > < / li >
< / ul >
< div class = "section" id = "compute-kernel" >
< h2 > Compute Kernel< a class = "headerlink" href = "#compute-kernel" title = "Permalink to this headline" > ¶< / a > < / h2 >
< p > Each compute kernel is declared using the < code class = "code docutils literal notranslate" > < span class = "pre" > __global__< / span > < / code > attribute, and executed many times in parallel
on different chunks of data (see the < a class = "reference external" href = "https://en.wikipedia.org/wiki/SPMD" > Single Program, Multiple Data< / a >
programming model for more details).< / p >
< blockquote >
< div > < div class = "highlight-C notranslate" > < div class = "highlight" > < pre > < span > < / span > < span class = "n" > __global__< / span > < span class = "kt" > void< / span > < span class = "n" > add< / span > < span class = "p" > (< / span > < span class = "kt" > float< / span > < span class = "o" > *< / span > < span class = "n" > z< / span > < span class = "p" > ,< / span > < span class = "kt" > float< / span > < span class = "o" > *< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "kt" > float< / span > < span class = "o" > *< / span > < span class = "n" > y< / span > < span class = "p" > ,< / span > < span class = "kt" > int< / span > < span class = "n" > N< / span > < span class = "p" > ){< / span >
< span class = "c1" > // The `get_program_id(i)` returns the i-th coordinate< / span >
< span class = "c1" > // of the program in the overarching SPMD context< / span >
< span class = "c1" > // (a.k.a launch grid). This is what allows us to process< / span >
< span class = "c1" > // different chunks of data in parallel.< / span >
< span class = "c1" > // For those familiar with CUDA, `get_program_id({0,1,2})`< / span >
< span class = "c1" > // is similar to blockIdx.{x,y,z}< / span >
< span class = "kt" > int< / span > < span class = "n" > pid< / span > < span class = "o" > =< / span > < span class = "n" > get_program_id< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > );< / span >
< span class = "c1" > // In Triton, arrays are first-class citizens. In other words,< / span >
< span class = "c1" > // they are primitive data-types and are -- contrary to C and< / span >
< span class = "c1" > // CUDA -- not implemented as pointers to contiguous chunks of< / span >
< span class = "c1" > // memory.< / span >
< span class = "c1" > // In the few lines below, we create an array of `BLOCK` pointers< / span >
< span class = "c1" > // whose memory values are, e.g.:< / span >
< span class = "c1" > // [z + pid*BLOCK + 0, z + pid*BLOCK + 1, ..., z + pid*BLOCK + BLOCK - 1]< / span >
< span class = "c1" > // Note: here BLOCK is expected to be a pre-processor macro defined at compile-time< / span >
< span class = "kt" > int< / span > < span class = "n" > offset< / span > < span class = "p" > [< / span > < span class = "n" > BLOCK< / span > < span class = "p" > ]< / span > < span class = "o" > =< / span > < span class = "n" > pid< / span > < span class = "o" > *< / span > < span class = "n" > BLOCK< / span > < span class = "o" > +< / span > < span class = "mi" > 0< / span > < span class = "p" > ...< / span > < span class = "n" > BLOCK< / span > < span class = "p" > ;< / span >
< span class = "kt" > float< / span > < span class = "o" > *< / span > < span class = "n" > pz< / span > < span class = "p" > [< / span > < span class = "n" > BLOCK< / span > < span class = "p" > ]< / span > < span class = "o" > =< / span > < span class = "n" > z< / span > < span class = "o" > +< / span > < span class = "n" > offset< / span > < span class = "p" > ;< / span >
< span class = "kt" > float< / span > < span class = "o" > *< / span > < span class = "n" > px< / span > < span class = "p" > [< / span > < span class = "n" > BLOCK< / span > < span class = "p" > ]< / span > < span class = "o" > =< / span > < span class = "n" > x< / span > < span class = "o" > +< / span > < span class = "n" > offset< / span > < span class = "p" > ;< / span >
< span class = "kt" > float< / span > < span class = "o" > *< / span > < span class = "n" > py< / span > < span class = "p" > [< / span > < span class = "n" > BLOCK< / span > < span class = "p" > ]< / span > < span class = "o" > =< / span > < span class = "n" > y< / span > < span class = "o" > +< / span > < span class = "n" > offset< / span > < span class = "p" > ;< / span >
< span class = "c1" > // Simple element-wise control-flow for load/store operations can< / span >
< span class = "c1" > // be achieved using the ternary operator `cond ? val_true : val_false`< / span >
< span class = "c1" > // or the conditional dereferencing operator `*?(cond)ptr`.< / span >
< span class = "c1" > // Here, we make sure that we do not access memory out-of-bounds when we< / span >
< span class = "c1" > // write-back `z`< / span >
< span class = "kt" > bool< / span > < span class = "n" > check< / span > < span class = "p" > [< / span > < span class = "n" > BLOCK< / span > < span class = "p" > ]< / span > < span class = "o" > =< / span > < span class = "n" > offset< / span > < span class = "o" > < < / span > < span class = "n" > N< / span > < span class = "p" > ;< / span >
< span class = "o" > *?< / span > < span class = "p" > (< / span > < span class = "n" > check< / span > < span class = "p" > )< / span > < span class = "n" > pz< / span > < span class = "o" > =< / span > < span class = "o" > *?< / span > < span class = "p" > (< / span > < span class = "n" > check< / span > < span class = "p" > )< / span > < span class = "n" > px< / span > < span class = "o" > +< / span > < span class = "o" > *?< / span > < span class = "p" > (< / span > < span class = "n" > check< / span > < span class = "p" > )< / span > < span class = "n" > py< / span > < span class = "p" > ;< / span >
< span class = "p" > }< / span >
< / pre > < / div >
< / div >
< / div > < / blockquote >
< p > The existence of arrays as a primitive data-type for Triton comes with a number of advantages that are highlighted in the < a class = "reference external" href = "http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf" > MAPL’ 2019 Triton paper< / a > .< / p >
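< p > To build intuition for the blocked, masked computation above, here is a hypothetical pure-Python sketch (not Triton code) of what a single program instance computes; the function and variable names are illustrative only:< / p >

```python
# Hypothetical pure-Python sketch (not Triton code) of what a single
# program instance of the kernel above computes: each instance handles
# one BLOCK-sized chunk of the vector, masking out-of-bounds accesses
# at the tail.
def add_one_block(z, x, y, N, pid, BLOCK):
    # int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;
    offsets = [pid * BLOCK + i for i in range(BLOCK)]
    # bool check[BLOCK] = offset < N;  (masked load/store)
    for off in offsets:
        if off < N:
            z[off] = x[off] + y[off]

# The launch grid runs ceil(N / BLOCK) such programs (here, sequentially):
N, BLOCK = 10, 4
x = [float(i) for i in range(N)]
y = [1.0] * N
z = [0.0] * N
for pid in range((N + BLOCK - 1) // BLOCK):
    add_one_block(z, x, y, N, pid, BLOCK)
# z now holds the element-wise sum x + y
```

< p > On a GPU, the program instances run in parallel rather than in a loop; the mask is what makes the last, partially-filled block safe.< / p >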
< / div >
< div class = "section" id = "torch-bindings" >
< h2 > Torch Bindings< a class = "headerlink" href = "#torch-bindings" title = "Permalink to this headline" > ¶< / a > < / h2 >
< p > The only class that matters when it comes to using Triton with Torch is the < code class = "code docutils literal notranslate" > < span class = "pre" > triton.kernel< / span > < / code > class. It allows you to transform the above C-like function into a callable Python object that can be used to modify < code class = "code docutils literal notranslate" > < span class = "pre" > torch.tensor< / span > < / code > objects. To create a < code class = "code docutils literal notranslate" > < span class = "pre" > triton.kernel< / span > < / code > , you only need three things:< / p >
< ul class = "simple" >
< li > < p > < code class = "code docutils literal notranslate" > < span class = "pre" > source:< / span > < span class = "pre" > string< / span > < / code > : the source-code of the kernel you want to create< / p > < / li >
< li > < p > < code class = "code docutils literal notranslate" > < span class = "pre" > device:< / span > < span class = "pre" > torch.device< / span > < / code > : the device you want to compile this code for< / p > < / li >
< li > < p > < code class = "code docutils literal notranslate" > < span class = "pre" > defines:< / span > < span class = "pre" > dict< / span > < / code > : the set of macros that you want the pre-processor to < cite > #define< / cite > for you< / p > < / li >
< / ul >
< div class = "highlight-default notranslate" > < div class = "highlight" > < pre > < span > < / span > < span class = "kn" > import< / span > < span class = "nn" > torch< / span >
< span class = "kn" > import< / span > < span class = "nn" > triton< / span >
< span class = "c1" > # source-code for Triton compute kernel< / span >
< span class = "c1" > # here we just copy-paste the above code without the extensive comments.< / span >
< span class = "c1" > # you may prefer to store it in a .c file and load it from there instead.< / span >
< span class = "n" > _src< / span > < span class = "o" > =< / span > < span class = "s2" > " " " < / span >
< span class = "s2" > __global__ void add(float* z, float* x, float* y, int N){< / span >
< span class = "s2" > // program id< / span >
< span class = "s2" > int pid = get_program_id(0);< / span >
< span class = "s2" > // create arrays of pointers< / span >
< span class = "s2" > int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;< / span >
< span class = "s2" > float* pz[BLOCK] = z + offset;< / span >
< span class = "s2" > float* px[BLOCK] = x + offset;< / span >
< span class = "s2" > float* py[BLOCK] = y + offset;< / span >
< span class = "s2" > // bounds checking< / span >
< span class = "s2" > bool check[BLOCK] = offset < N;< / span >
< span class = "s2" > // write-back< / span >
< span class = "s2" > *?(check)pz = *?(check)px + *?(check)py;< / span >
< span class = "s2" > }< / span >
< span class = "s2" > " " " < / span >
< span class = "c1" > # This function returns a callable `triton.kernel` object created from the above source code.< / span >
< span class = "c1" > # For portability, we maintain a cache of kernels for different `torch.device`< / span >
< span class = "c1" > # We compile the kernel with -DBLOCK=1024< / span >
< span class = "k" > def< / span > < span class = "nf" > make_add_kernel< / span > < span class = "p" > (< / span > < span class = "n" > device< / span > < span class = "p" > ):< / span >
< span class = "n" > cache< / span > < span class = "o" > =< / span > < span class = "n" > make_add_kernel< / span > < span class = "o" > .< / span > < span class = "n" > cache< / span >
< span class = "k" > if< / span > < span class = "n" > device< / span > < span class = "ow" > not< / span > < span class = "ow" > in< / span > < span class = "n" > cache< / span > < span class = "p" > :< / span >
< span class = "n" > defines< / span > < span class = "o" > =< / span > < span class = "p" > {< / span > < span class = "s1" > ' BLOCK' < / span > < span class = "p" > :< / span > < span class = "mi" > 1024< / span > < span class = "p" > }< / span >
< span class = "n" > cache< / span > < span class = "p" > [< / span > < span class = "n" > device< / span > < span class = "p" > ]< / span > < span class = "o" > =< / span > < span class = "n" > triton< / span > < span class = "o" > .< / span > < span class = "n" > kernel< / span > < span class = "p" > (< / span > < span class = "n" > _src< / span > < span class = "p" > ,< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "n" > device< / span > < span class = "p" > ,< / span > < span class = "n" > defines< / span > < span class = "o" > =< / span > < span class = "n" > defines< / span > < span class = "p" > )< / span >
< span class = "k" > return< / span > < span class = "n" > cache< / span > < span class = "p" > [< / span > < span class = "n" > device< / span > < span class = "p" > ]< / span >
< span class = "n" > make_add_kernel< / span > < span class = "o" > .< / span > < span class = "n" > cache< / span > < span class = "o" > =< / span > < span class = "nb" > dict< / span > < span class = "p" > ()< / span >
< span class = "c1" > # This is a standard torch custom autograd Function;< / span >
< span class = "c1" > # The only difference is that we can now use the above kernel in the `forward` and `backward` functions.< / span >
< span class = "k" > class< / span > < span class = "nc" > _add< / span > < span class = "p" > (< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > autograd< / span > < span class = "o" > .< / span > < span class = "n" > Function< / span > < span class = "p" > ):< / span >
< span class = "nd" > @staticmethod< / span >
< span class = "k" > def< / span > < span class = "nf" > forward< / span > < span class = "p" > (< / span > < span class = "n" > ctx< / span > < span class = "p" > ,< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > y< / span > < span class = "p" > ):< / span >
< span class = "c1" > # constraints of the op< / span >
< span class = "k" > assert< / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > dtype< / span > < span class = "o" > ==< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span >
< span class = "c1" > # *allocate output*< / span >
< span class = "n" > z< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > empty_like< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > )< / span >
< span class = "c1" > # *create launch grid*:< / span >
< span class = "c1" > # this is a function which takes compilation parameters `opt`< / span >
< span class = "c1" > # as input and returns a tuple of int (i.e., launch grid) for the kernel.< / span >
< span class = "c1" > # triton.cdiv is a shortcut for ceil division:< / span >
< span class = "c1" > # triton.cdiv(a, b) = (a + b - 1) // b< / span >
< span class = "n" > N< / span > < span class = "o" > =< / span > < span class = "n" > z< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span > < span class = "p" > [< / span > < span class = "mi" > 0< / span > < span class = "p" > ]< / span >
< span class = "n" > grid< / span > < span class = "o" > =< / span > < span class = "k" > lambda< / span > < span class = "n" > opt< / span > < span class = "p" > :< / span > < span class = "p" > (< / span > < span class = "n" > triton< / span > < span class = "o" > .< / span > < span class = "n" > cdiv< / span > < span class = "p" > (< / span > < span class = "n" > N< / span > < span class = "p" > ,< / span > < span class = "n" > opt< / span > < span class = "o" > .< / span > < span class = "n" > BLOCK< / span > < span class = "p" > ),< / span > < span class = "p" > )< / span >
< span class = "c1" > # *launch kernel*:< / span >
< span class = "c1" > # pointer to the data of torch tensors can be retrieved with< / span >
< span class = "c1" > # the `.data_ptr()` method< / span >
< span class = "n" > kernel< / span > < span class = "o" > =< / span > < span class = "n" > make_add_kernel< / span > < span class = "p" > (< / span > < span class = "n" > z< / span > < span class = "o" > .< / span > < span class = "n" > device< / span > < span class = "p" > )< / span >
< span class = "n" > kernel< / span > < span class = "p" > (< / span > < span class = "n" > z< / span > < span class = "o" > .< / span > < span class = "n" > data_ptr< / span > < span class = "p" > (),< / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > data_ptr< / span > < span class = "p" > (),< / span > < span class = "n" > y< / span > < span class = "o" > .< / span > < span class = "n" > data_ptr< / span > < span class = "p" > (),< / span > < span class = "n" > N< / span > < span class = "p" > ,< / span > < span class = "n" > grid< / span > < span class = "o" > =< / span > < span class = "n" > grid< / span > < span class = "p" > )< / span >
< span class = "k" > return< / span > < span class = "n" > z< / span >
< span class = "c1" > # Just like for standard PyTorch ops, we use the `.apply` method to create a callable object for our function< / span >
< span class = "n" > add< / span > < span class = "o" > =< / span > < span class = "n" > _add< / span > < span class = "o" > .< / span > < span class = "n" > apply< / span >
< / pre > < / div >
< / div >
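< p > The < code class = "code docutils literal notranslate" > < span class = "pre" > triton.cdiv< / span > < / code > shortcut used for the launch grid above is plain ceil division; a pure-Python sketch of the arithmetic (not the Triton API itself):< / p >

```python
# Plain-Python sketch of the ceil division used to size the launch grid:
# cdiv(a, b) == ceil(a / b), computed with integer arithmetic only.
def cdiv(a, b):
    return (a + b - 1) // b

# With N = 98432 elements and BLOCK = 1024, the grid has 97 programs:
# 96 full blocks cover 98304 elements, and one final (masked) block
# covers the remaining 128.
num_programs = cdiv(98432, 1024)
```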
< p > We can now use the above function to compute the sum of two < cite > torch.tensor< / cite > objects:< / p >
< / div >
< div class = "section" id = "unit-test" >
< h2 > Unit Test< a class = "headerlink" href = "#unit-test" title = "Permalink to this headline" > ¶< / a > < / h2 >
< p > Of course, the first thing we should check is whether the kernel is correct. This is pretty easy to test, as shown below:< / p >
< div class = "highlight-default notranslate" > < div class = "highlight" > < pre > < span > < / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > manual_seed< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span >
< span class = "n" > x< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > rand< / span > < span class = "p" > (< / span > < span class = "mi" > 98432< / span > < span class = "p" > ,< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "s1" > ' cuda' < / span > < span class = "p" > )< / span >
< span class = "n" > y< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > rand< / span > < span class = "p" > (< / span > < span class = "mi" > 98432< / span > < span class = "p" > ,< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "s1" > ' cuda' < / span > < span class = "p" > )< / span >
< span class = "n" > za< / span > < span class = "o" > =< / span > < span class = "n" > x< / span > < span class = "o" > +< / span > < span class = "n" > y< / span >
< span class = "n" > zb< / span > < span class = "o" > =< / span > < span class = "n" > add< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > y< / span > < span class = "p" > )< / span >
< span class = "nb" > print< / span > < span class = "p" > (< / span > < span class = "n" > za< / span > < span class = "p" > )< / span >
< span class = "nb" > print< / span > < span class = "p" > (< / span > < span class = "n" > zb< / span > < span class = "p" > )< / span >
< span class = "nb" > print< / span > < span class = "p" > (< / span > < span class = "sa" > f< / span > < span class = "s1" > ' The maximum difference between torch and triton is ' < / span > < span class = "sa" > f< / span > < span class = "s1" > ' < / span > < span class = "si" > {< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > max< / span > < span class = "p" > (< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > abs< / span > < span class = "p" > (< / span > < span class = "n" > za< / span > < span class = "o" > -< / span > < span class = "n" > zb< / span > < span class = "p" > ))< / span > < span class = "si" > }< / span > < span class = "s1" > ' < / span > < span class = "p" > )< / span >
< / pre > < / div >
< / div >
< p class = "sphx-glr-script-out" > Out:< / p >
< div class = "sphx-glr-script-out highlight-none notranslate" > < div class = "highlight" > < pre > < span > < / span > tensor([1.3713, 1.3076, 0.4940, ..., 0.6724, 1.2141, 0.9733], device=' cuda:0' )
tensor([1.3713, 1.3076, 0.4940, ..., 0.6724, 1.2141, 0.9733], device=' cuda:0' )
The maximum difference between torch and triton is 0.0
< / pre > < / div >
< / div >
< p > Seems like we’re good to go!< / p >
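< p > Rather than eyeballing printed tensors, the same check can be phrased as an assertion on the maximum absolute difference; a plain-Python sketch of the comparison (lists stand in for torch tensors here):< / p >

```python
# Sketch of the correctness check above: assert that the largest
# element-wise deviation between the reference result and the custom
# op's result is (here, exactly) zero.
def max_abs_diff(za, zb):
    return max(abs(a - b) for a, b in zip(za, zb))

za = [1.3713, 1.3076, 0.4940]  # reference result (torch: x + y)
zb = [1.3713, 1.3076, 0.4940]  # stand-in for the custom add(x, y)
assert max_abs_diff(za, zb) == 0.0
```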
< / div >
< div class = "section" id = "benchmark" >
< h2 > Benchmark< a class = "headerlink" href = "#benchmark" title = "Permalink to this headline" > ¶< / a > < / h2 >
< p > We can now benchmark our custom op for vectors of increasing sizes to get a sense of how it does relative to PyTorch.
To make things easier, Triton has a set of built-in utilities that allow us to concisely plot the performance of our custom op
for different problem sizes.< / p >
< div class = "highlight-default notranslate" > < div class = "highlight" > < pre > < span > < / span > < span class = "nd" > @triton< / span > < span class = "o" > .< / span > < span class = "n" > testing< / span > < span class = "o" > .< / span > < span class = "n" > perf_report< / span > < span class = "p" > (< / span >
< span class = "n" > triton< / span > < span class = "o" > .< / span > < span class = "n" > testing< / span > < span class = "o" > .< / span > < span class = "n" > Benchmark< / span > < span class = "p" > (< / span >
< span class = "n" > x_names< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "s1" > ' size' < / span > < span class = "p" > ],< / span > < span class = "c1" > # argument names to use as an x-axis for the plot< / span >
< span class = "n" > x_vals< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "mi" > 2< / span > < span class = "o" > **< / span > < span class = "n" > i< / span > < span class = "k" > for< / span > < span class = "n" > i< / span > < span class = "ow" > in< / span > < span class = "nb" > range< / span > < span class = "p" > (< / span > < span class = "mi" > 12< / span > < span class = "p" > ,< / span > < span class = "mi" > 28< / span > < span class = "p" > ,< / span > < span class = "mi" > 1< / span > < span class = "p" > )],< / span > < span class = "c1" > # different possible values for `x_name`< / span >
< span class = "n" > x_log< / span > < span class = "o" > =< / span > < span class = "kc" > True< / span > < span class = "p" > ,< / span > < span class = "c1" > # x axis is logarithmic< / span >
< span class = "n" > y_name< / span > < span class = "o" > =< / span > < span class = "s1" > ' provider' < / span > < span class = "p" > ,< / span > < span class = "c1" > # argument name whose value corresponds to a different line in the plot< / span >
< span class = "n" > y_vals< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "s1" > ' torch' < / span > < span class = "p" > ,< / span > < span class = "s1" > ' triton' < / span > < span class = "p" > ],< / span > < span class = "c1" > # possible keys for `y_name`< / span >
< span class = "n" > y_lines< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "s2" > " Torch" < / span > < span class = "p" > ,< / span > < span class = "s2" > " Triton" < / span > < span class = "p" > ],< / span > < span class = "c1" > # label name for the lines< / span >
< span class = "n" > ylabel< / span > < span class = "o" > =< / span > < span class = "s2" > " GB/s" < / span > < span class = "p" > ,< / span > < span class = "c1" > # label name for the y-axis< / span >
< span class = "n" > plot_name< / span > < span class = "o" > =< / span > < span class = "s2" > " vector-add-performance" < / span > < span class = "p" > ,< / span > < span class = "c1" > # name for the plot. Used also as a file name for saving the plot.< / span >
< span class = "n" > args< / span > < span class = "o" > =< / span > < span class = "p" > {}< / span > < span class = "c1" > # values for function arguments not in `x_names` and `y_name`< / span >
< span class = "p" > )< / span >
< span class = "p" > )< / span >
< span class = "k" > def< / span > < span class = "nf" > benchmark< / span > < span class = "p" > (< / span > < span class = "n" > size< / span > < span class = "p" > ,< / span > < span class = "n" > provider< / span > < span class = "p" > ):< / span >
< span class = "n" > x< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > rand< / span > < span class = "p" > (< / span > < span class = "n" > size< / span > < span class = "p" > ,< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "s1" > ' cuda' < / span > < span class = "p" > ,< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span >
< span class = "n" > y< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > rand< / span > < span class = "p" > (< / span > < span class = "n" > size< / span > < span class = "p" > ,< / span > < span class = "n" > device< / span > < span class = "o" > =< / span > < span class = "s1" > ' cuda' < / span > < span class = "p" > ,< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span >
< span class = "k" > if< / span > < span class = "n" > provider< / span > < span class = "o" > ==< / span > < span class = "s1" > ' torch' < / span > < span class = "p" > :< / span >
< span class = "n" > ms< / span > < span class = "p" > ,< / span > < span class = "n" > min_ms< / span > < span class = "p" > ,< / span > < span class = "n" > max_ms< / span > < span class = "o" > =< / span > < span class = "n" > triton< / span > < span class = "o" > .< / span > < span class = "n" > testing< / span > < span class = "o" > .< / span > < span class = "n" > do_bench< / span > < span class = "p" > (< / span > < span class = "k" > lambda< / span > < span class = "p" > :< / span > < span class = "n" > x< / span > < span class = "o" > +< / span > < span class = "n" > y< / span > < span class = "p" > )< / span >
< span class = "k" > if< / span > < span class = "n" > provider< / span > < span class = "o" > ==< / span > < span class = "s1" > ' triton' < / span > < span class = "p" > :< / span >
< span class = "n" > ms< / span > < span class = "p" > ,< / span > < span class = "n" > min_ms< / span > < span class = "p" > ,< / span > < span class = "n" > max_ms< / span > < span class = "o" > =< / span > < span class = "n" > triton< / span > < span class = "o" > .< / span > < span class = "n" > testing< / span > < span class = "o" > .< / span > < span class = "n" > do_bench< / span > < span class = "p" > (< / span > < span class = "k" > lambda< / span > < span class = "p" > :< / span > < span class = "n" > add< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > y< / span > < span class = "p" > ))< / span >
< span class = "n" > gbps< / span > < span class = "o" > =< / span > < span class = "k" > lambda< / span > < span class = "n" > ms< / span > < span class = "p" > :< / span > < span class = "mi" > 12< / span > < span class = "o" > *< / span > < span class = "n" > size< / span > < span class = "o" > /< / span > < span class = "n" > ms< / span > < span class = "o" > *< / span > < span class = "mf" > 1e-6< / span >
< span class = "k" > return< / span > < span class = "n" > gbps< / span > < span class = "p" > (< / span > < span class = "n" > ms< / span > < span class = "p" > ),< / span > < span class = "n" > gbps< / span > < span class = "p" > (< / span > < span class = "n" > max_ms< / span > < span class = "p" > ),< / span > < span class = "n" > gbps< / span > < span class = "p" > (< / span > < span class = "n" > min_ms< / span > < span class = "p" > )< / span >
< / pre > < / div >
< / div >
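< p > The < cite > gbps< / cite > lambda above converts a measured time into effective memory throughput: each vector-add element moves 12 bytes (two < cite > float32< / cite > reads plus one write, 4 bytes each), so GB/s is < cite > 12 * size / ms * 1e-6< / cite > . A minimal standalone sketch of that conversion (the helper name is ours, not part of Triton):< / p >

```python
def gbps(size, ms):
    """Effective bandwidth in GB/s for a float32 vector add.

    Each element requires reading x and y and writing the output:
    3 accesses * 4 bytes = 12 bytes moved per element.
    """
    bytes_moved = 3 * 4 * size          # two reads + one write of float32
    seconds = ms * 1e-3                 # do_bench reports milliseconds
    return bytes_moved / seconds / 1e9  # bytes/s -> GB/s

# Equivalent to the `12 * size / ms * 1e-6` expression used in the tutorial.
print(gbps(2**27, 10.0))
```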
< p > We can now run the decorated function above. Pass < cite > show_plots=True< / cite > to see the plots and/or
< cite > save_path='/path/to/results/'< / cite > to save them to disk along with raw CSV data.< / p >
< div class = "highlight-default notranslate" > < div class = "highlight" > < pre > < span > < / span > < span class = "n" > benchmark< / span > < span class = "o" > .< / span > < span class = "n" > run< / span > < span class = "p" > (< / span > < span class = "n" > show_plots< / span > < span class = "o" > =< / span > < span class = "kc" > True< / span > < span class = "p" > )< / span >
< / pre > < / div >
< / div >
< img alt = "01 vector add" class = "sphx-glr-single-img" src = "../../_images/sphx_glr_01-vector-add_001.png" / >
< p class = "sphx-glr-timing" > < strong > Total running time of the script:< / strong > ( 0 minutes 9.497 seconds)< / p >
< div class = "sphx-glr-footer class sphx-glr-footer-example docutils container" id = "sphx-glr-download-getting-started-tutorials-01-vector-add-py" >
< div class = "sphx-glr-download sphx-glr-download-python docutils container" >
< p > < a class = "reference download internal" download = "" href = "../../_downloads/62d97d49a32414049819dd8bb8378080/01-vector-add.py" > < code class = "xref download docutils literal notranslate" > < span class = "pre" > Download< / span > < span class = "pre" > Python< / span > < span class = "pre" > source< / span > < span class = "pre" > code:< / span > < span class = "pre" > 01-vector-add.py< / span > < / code > < / a > < / p >
< / div >
< div class = "sphx-glr-download sphx-glr-download-jupyter docutils container" >
< p > < a class = "reference download internal" download = "" href = "../../_downloads/f191ee1e78dc52eb5f7cba88f71cef2f/01-vector-add.ipynb" > < code class = "xref download docutils literal notranslate" > < span class = "pre" > Download< / span > < span class = "pre" > Jupyter< / span > < span class = "pre" > notebook:< / span > < span class = "pre" > 01-vector-add.ipynb< / span > < / code > < / a > < / p >
< / div >
< / div >
< p class = "sphx-glr-signature" > < a class = "reference external" href = "https://sphinx-gallery.github.io" > Gallery generated by Sphinx-Gallery< / a > < / p >
< / div >
< / div >
< / div >
< / div >
< footer >
< div class = "rst-footer-buttons" role = "navigation" aria-label = "footer navigation" >
< a href = "02-fused-softmax.html" class = "btn btn-neutral float-right" title = "Fused Softmax" accesskey = "n" rel = "next" > Next < span class = "fa fa-arrow-circle-right" aria-hidden = "true" > < / span > < / a >
< a href = "index.html" class = "btn btn-neutral float-left" title = "Tutorials" accesskey = "p" rel = "prev" > < span class = "fa fa-arrow-circle-left" aria-hidden = "true" > < / span > Previous< / a >
< / div >
< hr / >
< div role = "contentinfo" >
< p >
© Copyright 2020, Philippe Tillet.
< / p >
< / div >
Built with < a href = "https://www.sphinx-doc.org/" > Sphinx< / a > using a
< a href = "https://github.com/readthedocs/sphinx_rtd_theme" > theme< / a >
provided by < a href = "https://readthedocs.org" > Read the Docs< / a > .
< / footer >
< / div >
< / div >
< / section >
< / div >
< script type = "text/javascript" >
jQuery(function () {
SphinxRtdTheme.Navigation.enable(true);
});
< / script >
< / body >
< / html >