triton/getting-started/tutorials/01-vector-add.html
2021-03-23 17:10:07 -04:00

<!DOCTYPE html>
<html class="writer-html5" lang="en" >
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Vector Addition &mdash; Triton documentation</title>
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
<link rel="stylesheet" href="../../_static/gallery-rendered-html.css" type="text/css" />
<link rel="stylesheet" href="../../_static/css/custom.css" type="text/css" />
<!--[if lt IE 9]>
<script src="../../_static/js/html5shiv.min.js"></script>
<![endif]-->
<script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
<script src="../../_static/jquery.js"></script>
<script src="../../_static/underscore.js"></script>
<script src="../../_static/doctools.js"></script>
<script type="text/javascript" src="../../_static/js/theme.js"></script>
<link rel="index" title="Index" href="../../genindex.html" />
<link rel="search" title="Search" href="../../search.html" />
<link rel="next" title="Fused Softmax" href="02-fused-softmax.html" />
<link rel="prev" title="Tutorials" href="index.html" />
</head>
<body class="wy-body-for-nav">
<div class="wy-grid-for-nav">
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
<div class="wy-side-scroll">
<div class="wy-side-nav-search" >
<a href="../../index.html" class="icon icon-home"> Triton
</a>
<div role="search">
<form id="rtd-search-form" class="wy-form" action="../../search.html" method="get">
<input type="text" name="q" placeholder="Search docs" />
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form>
</div>
</div>
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
<p class="caption"><span class="caption-text">Getting Started</span></p>
<ul class="current">
<li class="toctree-l1"><a class="reference internal" href="../installation.html">Installation</a></li>
<li class="toctree-l1 current"><a class="reference internal" href="index.html">Tutorials</a><ul class="current">
<li class="toctree-l2 current"><a class="current reference internal" href="#">Vector Addition</a><ul>
<li class="toctree-l3"><a class="reference internal" href="#compute-kernel">Compute Kernel</a></li>
<li class="toctree-l3"><a class="reference internal" href="#torch-bindings">Torch Bindings</a></li>
<li class="toctree-l3"><a class="reference internal" href="#unit-test">Unit Test</a></li>
<li class="toctree-l3"><a class="reference internal" href="#benchmark">Benchmark</a></li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="02-fused-softmax.html">Fused Softmax</a></li>
<li class="toctree-l2"><a class="reference internal" href="03-matrix-multiplication.html">Matrix Multiplication</a></li>
</ul>
</li>
</ul>
<p class="caption"><span class="caption-text">Programming Guide</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-1/introduction.html">Introduction</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-2/related-work.html">Related Work</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-3/triton-c.html">The Triton-C Language</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-4/triton-ir.html">The Triton-IR Intermediate Representation</a></li>
</ul>
</div>
</div>
</nav>
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
<nav class="wy-nav-top" aria-label="top navigation">
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
<a href="../../index.html">Triton</a>
</nav>
<div class="wy-nav-content">
<div class="rst-content">
<div role="navigation" aria-label="breadcrumbs navigation">
<ul class="wy-breadcrumbs">
<li><a href="../../index.html" class="icon icon-home"></a> &raquo;</li>
<li><a href="index.html">Tutorials</a> &raquo;</li>
<li>Vector Addition</li>
<li class="wy-breadcrumbs-aside">
<a href="../../_sources/getting-started/tutorials/01-vector-add.rst.txt" rel="nofollow"> View page source</a>
</li>
</ul>
<hr/>
</div>
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
<div itemprop="articleBody">
<div class="sphx-glr-download-link-note admonition note">
<p class="admonition-title">Note</p>
<p>Click <a class="reference internal" href="#sphx-glr-download-getting-started-tutorials-01-vector-add-py"><span class="std std-ref">here</span></a>
to download the full example code</p>
</div>
<div class="sphx-glr-example-title section" id="vector-addition">
<span id="sphx-glr-getting-started-tutorials-01-vector-add-py"></span><h1>Vector Addition<a class="headerlink" href="#vector-addition" title="Permalink to this headline"></a></h1>
<p>In this tutorial, you will write a simple vector addition using Triton and learn about:</p>
<ul class="simple">
<li><p>The basic syntax of the Triton programming language</p></li>
<li><p>The best practices for creating PyTorch custom operators using the <code class="code docutils literal notranslate"><span class="pre">triton.kernel</span></code> Python API</p></li>
<li><p>The best practices for validating and benchmarking custom ops against native reference implementations</p></li>
</ul>
<div class="section" id="compute-kernel">
<h2>Compute Kernel<a class="headerlink" href="#compute-kernel" title="Permalink to this headline"></a></h2>
<p>Each compute kernel is declared using the <code class="code docutils literal notranslate"><span class="pre">__global__</span></code> attribute and is executed many times in parallel
on different chunks of data (see the <a class="reference external" href="https://en.wikipedia.org/wiki/SPMD">Single Program, Multiple Data</a>
programming model for more details).</p>
<blockquote>
<div><div class="highlight-C notranslate"><div class="highlight"><pre><span></span><span class="n">__global__</span> <span class="kt">void</span> <span class="n">add</span><span class="p">(</span><span class="kt">float</span><span class="o">*</span> <span class="n">z</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">x</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">y</span><span class="p">,</span> <span class="kt">int</span> <span class="n">N</span><span class="p">){</span>
<span class="c1">// `get_program_id(i)` returns the i-th coordinate</span>
<span class="c1">// of the program in the overarching SPMD context</span>
<span class="c1">// (a.k.a. the launch grid). This is what allows us to process</span>
<span class="c1">// different chunks of data in parallel.</span>
<span class="c1">// For those familiar with CUDA, `get_program_id({0,1,2})`</span>
<span class="c1">// is similar to blockIdx.{x,y,z}</span>
<span class="kt">int</span> <span class="n">pid</span> <span class="o">=</span> <span class="n">get_program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">);</span>
<span class="c1">// In Triton, arrays are first-class citizens. In other words,</span>
<span class="c1">// they are primitive data types and are -- contrary to C and</span>
<span class="c1">// CUDA -- not implemented as pointers to contiguous chunks of</span>
<span class="c1">// memory.</span>
<span class="c1">// In the few lines below, we create arrays of `BLOCK` pointers</span>
<span class="c1">// whose values are, e.g.:</span>
<span class="c1">// [z + pid*BLOCK + 0, z + pid*BLOCK + 1, ..., z + pid*BLOCK + BLOCK - 1]</span>
<span class="c1">// Note: here BLOCK is expected to be a pre-processor macro defined at compile-time</span>
<span class="kt">int</span> <span class="n">offset</span><span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">pid</span> <span class="o">*</span> <span class="n">BLOCK</span> <span class="o">+</span> <span class="mi">0</span> <span class="p">...</span> <span class="n">BLOCK</span><span class="p">;</span>
<span class="kt">float</span><span class="o">*</span> <span class="n">pz</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">z</span> <span class="o">+</span> <span class="n">offset</span><span class="p">;</span>
<span class="kt">float</span><span class="o">*</span> <span class="n">px</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">offset</span><span class="p">;</span>
<span class="kt">float</span><span class="o">*</span> <span class="n">py</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">y</span> <span class="o">+</span> <span class="n">offset</span><span class="p">;</span>
<span class="c1">// Simple element-wise control flow for load/store operations can</span>
<span class="c1">// be achieved using the ternary operator `cond ? val_true : val_false`</span>
<span class="c1">// or the conditional dereferencing operator `*?(cond)ptr`.</span>
<span class="c1">// Here, we make sure that we do not access memory out of bounds when we</span>
<span class="c1">// read `x` and `y` or write back `z`</span>
<span class="kt">bool</span> <span class="n">check</span><span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">offset</span> <span class="o">&lt;</span> <span class="n">N</span><span class="p">;</span>
<span class="o">*?</span><span class="p">(</span><span class="n">check</span><span class="p">)</span><span class="n">pz</span> <span class="o">=</span> <span class="o">*?</span><span class="p">(</span><span class="n">check</span><span class="p">)</span><span class="n">px</span> <span class="o">+</span> <span class="o">*?</span><span class="p">(</span><span class="n">check</span><span class="p">)</span><span class="n">py</span><span class="p">;</span>
<span class="p">}</span>
</pre></div>
</div>
</div></blockquote>
<p>The existence of arrays as a primitive data-type for Triton comes with a number of advantages that are highlighted in the <a class="reference external" href="http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf">MAPL2019 Triton paper</a>.</p>
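<p>To make the SPMD model concrete, the sketch below emulates the kernel above in plain Python. It is only an illustration of the semantics, not Triton code: the helper names <code class="code docutils literal notranslate"><span class="pre">add_program</span></code> and <code class="code docutils literal notranslate"><span class="pre">launch_add</span></code> are hypothetical, Python lists stand in for device memory, and the programs of the launch grid run one after another rather than in parallel.</p>

```python
# Plain-Python emulation of the SPMD kernel above (illustrative only, not Triton).

def cdiv(a, b):
    """Ceiling division: (a + b - 1) // b."""
    return (a + b - 1) // b


def add_program(z, x, y, n, pid, block):
    """Emulate one program instance, which owns `block` consecutive elements."""
    # Counterpart of `int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;`
    offsets = [pid * block + i for i in range(block)]
    # Counterpart of `bool check[BLOCK] = offset < N;`
    check = [off < n for off in offsets]
    # Counterpart of `*?(check)pz = *?(check)px + *?(check)py;`
    # Lanes whose check is False neither read nor write memory.
    for off, ok in zip(offsets, check):
        if ok:
            z[off] = x[off] + y[off]


def launch_add(x, y, block=4):
    """Run every program of a 1D launch grid sequentially."""
    n = len(x)
    z = [0.0] * n
    for pid in range(cdiv(n, block)):
        add_program(z, x, y, n, pid, block)
    return z


x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
y = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0]
# 6 elements with a block of 4: two programs, and the second one masks 2 lanes.
print(launch_add(x, y))
```

<p>Without the bounds check, the second program in this example would index past the end of the lists, which is the out-of-bounds access the kernel guards against.</p>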
</div>
<div class="section" id="torch-bindings">
<h2>Torch Bindings<a class="headerlink" href="#torch-bindings" title="Permalink to this headline"></a></h2>
<p>The only thing that matters when it comes to Triton and Torch is the <code class="code docutils literal notranslate"><span class="pre">triton.kernel</span></code> class. It lets you transform the above C-like function into a callable Python object that can be used to modify <code class="code docutils literal notranslate"><span class="pre">torch.tensor</span></code> objects. To create a <code class="code docutils literal notranslate"><span class="pre">triton.kernel</span></code>, you only need three things:</p>
<ul class="simple">
<li><p><code class="code docutils literal notranslate"><span class="pre">source:</span> <span class="pre">string</span></code>: the source-code of the kernel you want to create</p></li>
<li><p><code class="code docutils literal notranslate"><span class="pre">device:</span> <span class="pre">torch.device</span></code>: the device you want to compile this code for</p></li>
<li><p><code class="code docutils literal notranslate"><span class="pre">defines:</span> <span class="pre">dict</span></code>: the set of macros that you want the pre-processor to <cite>#define</cite> for you</p></li>
</ul>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">triton</span>
<span class="c1"># source-code for Triton compute kernel</span>
<span class="c1"># here we just copy-paste the above code without the extensive comments.</span>
<span class="c1"># you may prefer to store it in a .c file and load it from there instead.</span>
<span class="n">_src</span> <span class="o">=</span> <span class="s2">&quot;&quot;&quot;</span>
<span class="s2">__global__ void add(float* z, float* x, float* y, int N){</span>
<span class="s2"> // program id</span>
<span class="s2"> int pid = get_program_id(0);</span>
<span class="s2"> // create arrays of pointers</span>
<span class="s2"> int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;</span>
<span class="s2"> float* pz[BLOCK] = z + offset;</span>
<span class="s2"> float* px[BLOCK] = x + offset;</span>
<span class="s2"> float* py[BLOCK] = y + offset;</span>
<span class="s2"> // bounds checking</span>
<span class="s2"> bool check[BLOCK] = offset &lt; N;</span>
<span class="s2"> // write-back</span>
<span class="s2"> *?(check)pz = *?(check)px + *?(check)py;</span>
<span class="s2">}</span>
<span class="s2"> &quot;&quot;&quot;</span>
<span class="c1"># This function returns a callable `triton.kernel` object created from the above source code.</span>
<span class="c1"># For portability, we maintain a cache of kernels for different `torch.device`</span>
<span class="c1"># We compile the kernel with -DBLOCK=1024</span>
<span class="k">def</span> <span class="nf">make_add_kernel</span><span class="p">(</span><span class="n">device</span><span class="p">):</span>
    <span class="n">cache</span> <span class="o">=</span> <span class="n">make_add_kernel</span><span class="o">.</span><span class="n">cache</span>
    <span class="k">if</span> <span class="n">device</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">cache</span><span class="p">:</span>
        <span class="n">defines</span> <span class="o">=</span> <span class="p">{</span><span class="s1">&#39;BLOCK&#39;</span><span class="p">:</span> <span class="mi">1024</span><span class="p">}</span>
        <span class="n">cache</span><span class="p">[</span><span class="n">device</span><span class="p">]</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">kernel</span><span class="p">(</span><span class="n">_src</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">defines</span><span class="o">=</span><span class="n">defines</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">cache</span><span class="p">[</span><span class="n">device</span><span class="p">]</span>
<span class="n">make_add_kernel</span><span class="o">.</span><span class="n">cache</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span>
<span class="c1"># This is a standard torch custom autograd Function.</span>
<span class="c1"># The only difference is that we can now use the above kernel in the `forward` and `backward` functions.</span>
<span class="k">class</span> <span class="nc">_add</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">autograd</span><span class="o">.</span><span class="n">Function</span><span class="p">):</span>
    <span class="nd">@staticmethod</span>
    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
        <span class="c1"># constraints of the op</span>
        <span class="k">assert</span> <span class="n">x</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span>
        <span class="c1"># *allocate output*</span>
        <span class="n">z</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
        <span class="c1"># *create launch grid*:</span>
        <span class="c1"># this is a function which takes compilation parameters `opt`</span>
        <span class="c1"># as input and returns a tuple of int (i.e., launch grid) for the kernel.</span>
        <span class="c1"># triton.cdiv is a shortcut for ceil division:</span>
        <span class="c1"># triton.cdiv(a, b) = (a + b - 1) // b</span>
        <span class="n">N</span> <span class="o">=</span> <span class="n">z</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
        <span class="n">grid</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">opt</span><span class="p">:</span> <span class="p">(</span><span class="n">triton</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">opt</span><span class="o">.</span><span class="n">BLOCK</span><span class="p">),</span> <span class="p">)</span>
        <span class="c1"># *launch kernel*:</span>
        <span class="c1"># pointer to the data of torch tensors can be retrieved with</span>
        <span class="c1"># the `.data_ptr()` method</span>
        <span class="n">kernel</span> <span class="o">=</span> <span class="n">make_add_kernel</span><span class="p">(</span><span class="n">z</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
        <span class="n">kernel</span><span class="p">(</span><span class="n">z</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">(),</span> <span class="n">x</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">(),</span> <span class="n">y</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">(),</span> <span class="n">N</span><span class="p">,</span> <span class="n">grid</span><span class="o">=</span><span class="n">grid</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">z</span>
<span class="c1"># Just like with standard PyTorch ops, we use the `.apply` method to create a callable object for our function</span>
<span class="n">add</span> <span class="o">=</span> <span class="n">_add</span><span class="o">.</span><span class="n">apply</span>
</pre></div>
</div>
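<p>The launch grid arithmetic used in <code class="code docutils literal notranslate"><span class="pre">forward</span></code> can be checked by hand. The sketch below reimplements the ceil-division formula quoted in the comments and evaluates it for a vector of 98432 elements with <code class="code docutils literal notranslate"><span class="pre">BLOCK=1024</span></code>. The <code class="code docutils literal notranslate"><span class="pre">SimpleNamespace</span></code> object is a hypothetical stand-in for the compilation parameters that Triton passes to the grid function; it is not Triton's own type.</p>

```python
from types import SimpleNamespace


def cdiv(a, b):
    """Ceiling division, using the formula quoted for triton.cdiv."""
    return (a + b - 1) // b


N = 98432
# Hypothetical stand-in for the compilation parameters `opt`.
opt = SimpleNamespace(BLOCK=1024)

# Same shape as the tutorial's grid function: one axis, enough programs to cover N.
grid = lambda opt: (cdiv(N, opt.BLOCK),)

num_programs = grid(opt)[0]
# Number of in-bounds elements handled by the last program.
last_valid = N - (num_programs - 1) * opt.BLOCK

print(grid(opt))    # (97,)
print(last_valid)   # 128
```

<p>Since 98432 is not a multiple of 1024, the last of the 97 programs has only 128 in-bounds elements, so this size exercises the kernel's bounds check.</p>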
<p>We can now use the above function to compute the sum of two <cite>torch.tensor</cite> objects:</p>
</div>
<div class="section" id="unit-test">
<h2>Unit Test<a class="headerlink" href="#unit-test" title="Permalink to this headline"></a></h2>
<p>Of course, the first thing that we should check is whether the kernel is correct. This is pretty easy to test, as shown below:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">torch</span><span class="o">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">98432</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">98432</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="n">za</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span>
<span class="n">zb</span> <span class="o">=</span> <span class="n">add</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">za</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">zb</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;The maximum difference between torch and triton is &#39;</span> <span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">abs</span><span class="p">(</span><span class="n">za</span> <span class="o">-</span> <span class="n">zb</span><span class="p">))</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
</pre></div>
</div>
<p class="sphx-glr-script-out">Out:</p>
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>tensor([1.3713, 1.3076, 0.4940, ..., 0.6724, 1.2141, 0.9733], device=&#39;cuda:0&#39;)
tensor([1.3713, 1.3076, 0.4940, ..., 0.6724, 1.2141, 0.9733], device=&#39;cuda:0&#39;)
The maximum difference between torch and triton is 0.0
</pre></div>
</div>
<p>Seems like we're good to go!</p>
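<p>A maximum difference of exactly 0.0 is plausible here because each output element is the result of a single addition. For kernels that accumulate or reorder floating-point operations, an exact match is usually too strict, and a tolerance-based comparison is the more common check. The sketch below shows the idea in plain Python; the helpers <code class="code docutils literal notranslate"><span class="pre">max_abs_diff</span></code> and <code class="code docutils literal notranslate"><span class="pre">all_close</span></code> and their default tolerances are assumptions made for illustration, not part of Triton.</p>

```python
import math


def max_abs_diff(a, b):
    """Largest element-wise absolute difference between two sequences."""
    return max(abs(p - q) for p, q in zip(a, b))


def all_close(a, b, rel_tol=1e-6, abs_tol=1e-8):
    """True if the sequences have equal length and match within tolerance."""
    return len(a) == len(b) and all(
        math.isclose(p, q, rel_tol=rel_tol, abs_tol=abs_tol)
        for p, q in zip(a, b)
    )


# 0.1 + 0.2 is not exactly 0.3 in binary floating point.
reference = [0.1 + 0.2, 1.5, 3.0]
candidate = [0.3, 1.5, 3.0]

print(max_abs_diff(reference, candidate))  # tiny but non-zero
print(all_close(reference, candidate))     # True
```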
</div>
<div class="section" id="benchmark">
<h2>Benchmark<a class="headerlink" href="#benchmark" title="Permalink to this headline"></a></h2>
<p>We can now benchmark our custom op on vectors of increasing size to get a sense of how it performs relative to PyTorch.
To make things easier, Triton has a set of built-in utilities that allow us to concisely plot the performance of our custom op
for different problem sizes.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="nd">@triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">perf_report</span><span class="p">(</span>
    <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">Benchmark</span><span class="p">(</span>
        <span class="n">x_names</span><span class="o">=</span><span class="p">[</span><span class="s1">&#39;size&#39;</span><span class="p">],</span> <span class="c1"># argument names to use as an x-axis for the plot</span>
        <span class="n">x_vals</span><span class="o">=</span><span class="p">[</span><span class="mi">2</span><span class="o">**</span><span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">12</span><span class="p">,</span> <span class="mi">28</span><span class="p">,</span> <span class="mi">1</span><span class="p">)],</span> <span class="c1"># different possible values for `x_names`</span>
        <span class="n">x_log</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="c1"># x axis is logarithmic</span>
        <span class="n">y_name</span><span class="o">=</span><span class="s1">&#39;provider&#39;</span><span class="p">,</span> <span class="c1"># argument name whose value corresponds to a different line in the plot</span>
        <span class="n">y_vals</span><span class="o">=</span><span class="p">[</span><span class="s1">&#39;torch&#39;</span><span class="p">,</span> <span class="s1">&#39;triton&#39;</span><span class="p">],</span> <span class="c1"># possible keys for `y_name`</span>
        <span class="n">y_lines</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;Torch&quot;</span><span class="p">,</span> <span class="s2">&quot;Triton&quot;</span><span class="p">],</span> <span class="c1"># label name for the lines</span>
        <span class="n">ylabel</span><span class="o">=</span><span class="s2">&quot;GB/s&quot;</span><span class="p">,</span> <span class="c1"># label name for the y-axis</span>
        <span class="n">plot_name</span><span class="o">=</span><span class="s2">&quot;vector-add-performance&quot;</span><span class="p">,</span> <span class="c1"># name for the plot. Used also as a file name for saving the plot.</span>
        <span class="n">args</span><span class="o">=</span><span class="p">{}</span> <span class="c1"># values for function arguments not in `x_names` and `y_name`</span>
    <span class="p">)</span>
<span class="p">)</span>
<span class="k">def</span> <span class="nf">benchmark</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="n">provider</span><span class="p">):</span>
    <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
    <span class="n">y</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
    <span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s1">&#39;torch&#39;</span><span class="p">:</span>
        <span class="n">ms</span><span class="p">,</span> <span class="n">min_ms</span><span class="p">,</span> <span class="n">max_ms</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span><span class="p">)</span>
    <span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s1">&#39;triton&#39;</span><span class="p">:</span>
        <span class="n">ms</span><span class="p">,</span> <span class="n">min_ms</span><span class="p">,</span> <span class="n">max_ms</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">add</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">))</span>
    <span class="n">gbps</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">ms</span><span class="p">:</span> <span class="mi">12</span> <span class="o">*</span> <span class="n">size</span> <span class="o">/</span> <span class="n">ms</span> <span class="o">*</span> <span class="mf">1e-6</span>
    <span class="k">return</span> <span class="n">gbps</span><span class="p">(</span><span class="n">ms</span><span class="p">),</span> <span class="n">gbps</span><span class="p">(</span><span class="n">max_ms</span><span class="p">),</span> <span class="n">gbps</span><span class="p">(</span><span class="n">min_ms</span><span class="p">)</span>
</pre></div>
</div>
<p>We can now run the decorated function above. Pass <cite>show_plots=True</cite> to see the plots and/or
<cite>save_path='/path/to/results/'</cite> to save them to disk along with the raw CSV data.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">benchmark</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">show_plots</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
</pre></div>
</div>
<img alt="vector-add-performance" class="sphx-glr-single-img" src="../../_images/sphx_glr_01-vector-add_001.png" />
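<p>The GB/s figures on the y-axis come from the <code class="code docutils literal notranslate"><span class="pre">gbps</span></code> lambda in the benchmark. One way to read its constant 12 is as three float32 tensors (two reads and one write) at 4 bytes per element, with the factor <code class="code docutils literal notranslate"><span class="pre">1e-6</span></code> converting bytes per millisecond to gigabytes per second. The sketch below spells this out; the parameter names and the timings are made up for illustration.</p>

```python
def gbps(size, ms, bytes_per_element=4, num_tensors=3):
    """Effective bandwidth in GB/s for an element-wise op on `size` elements.

    Assumes every element of every tensor crosses the memory bus exactly once.
    """
    total_bytes = num_tensors * bytes_per_element * size
    seconds = ms * 1e-3
    return total_bytes / seconds * 1e-9


size = 2**27
# Hypothetical timings in milliseconds, not measured values.
ms, min_ms, max_ms = 2.0, 1.9, 2.2

print(gbps(size, ms))                           # same value as 12 * size / ms * 1e-6
print(gbps(size, max_ms) < gbps(size, min_ms))  # True: slowest run, lowest GB/s
```

<p>This also explains the order of the values returned by <code class="code docutils literal notranslate"><span class="pre">benchmark</span></code>: the slowest run (<code class="code docutils literal notranslate"><span class="pre">max_ms</span></code>) gives the lowest throughput and the fastest run (<code class="code docutils literal notranslate"><span class="pre">min_ms</span></code>) gives the highest.</p>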
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 0 minutes 7.756 seconds)</p>
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-01-vector-add-py">
<div class="sphx-glr-download sphx-glr-download-python docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/62d97d49a32414049819dd8bb8378080/01-vector-add.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">01-vector-add.py</span></code></a></p>
</div>
<div class="sphx-glr-download sphx-glr-download-jupyter docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/f191ee1e78dc52eb5f7cba88f71cef2f/01-vector-add.ipynb"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Jupyter</span> <span class="pre">notebook:</span> <span class="pre">01-vector-add.ipynb</span></code></a></p>
</div>
</div>
<p class="sphx-glr-signature"><a class="reference external" href="https://sphinx-gallery.github.io">Gallery generated by Sphinx-Gallery</a></p>
</div>
</div>
</div>
</div>
<footer>
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
<a href="02-fused-softmax.html" class="btn btn-neutral float-right" title="Fused Softmax" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
<a href="index.html" class="btn btn-neutral float-left" title="Tutorials" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
</div>
<hr/>
<div role="contentinfo">
<p>
&#169; Copyright 2020, Philippe Tillet.
</p>
</div>
Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
<a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
</div>
</div>
</section>
</div>
<script type="text/javascript">
jQuery(function () {
SphinxRtdTheme.Navigation.enable(true);
});
</script>
</body>
</html>