<!DOCTYPE html>
<html class="writer-html5" lang="en" >
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Vector Addition &mdash; Triton documentation</title>
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
<link rel="stylesheet" href="../../_static/gallery-rendered-html.css" type="text/css" />
<!--[if lt IE 9]>
<script src="../../_static/js/html5shiv.min.js"></script>
<![endif]-->
<script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
<script src="../../_static/jquery.js"></script>
<script src="../../_static/underscore.js"></script>
<script src="../../_static/doctools.js"></script>
<script async="async" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/javascript" src="../../_static/js/theme.js"></script>
<link rel="index" title="Index" href="../../genindex.html" />
<link rel="search" title="Search" href="../../search.html" />
<link rel="next" title="Fused Softmax" href="02-fused-softmax.html" />
<link rel="prev" title="Tutorials" href="index.html" />
</head>
<body class="wy-body-for-nav">
<div class="wy-grid-for-nav">
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
<div class="wy-side-scroll">
<div class="wy-side-nav-search" >
<a href="../../index.html" class="icon icon-home"> Triton
</a>
<div role="search">
<form id="rtd-search-form" class="wy-form" action="../../search.html" method="get">
<input type="text" name="q" placeholder="Search docs" />
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form>
</div>
</div>
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
<p class="caption"><span class="caption-text">Getting Started</span></p>
<ul class="current">
<li class="toctree-l1"><a class="reference internal" href="../installation.html">Installation</a></li>
<li class="toctree-l1 current"><a class="reference internal" href="index.html">Tutorials</a><ul class="current">
<li class="toctree-l2 current"><a class="current reference internal" href="#">Vector Addition</a><ul>
<li class="toctree-l3"><a class="reference internal" href="#compute-kernel">Compute Kernel</a></li>
<li class="toctree-l3"><a class="reference internal" href="#torch-bindings">Torch bindings</a></li>
<li class="toctree-l3"><a class="reference internal" href="#unit-test">Unit Test</a></li>
<li class="toctree-l3"><a class="reference internal" href="#benchmarking">Benchmarking</a></li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="02-fused-softmax.html">Fused Softmax</a></li>
</ul>
</li>
</ul>
</div>
</div>
</nav>
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
<nav class="wy-nav-top" aria-label="top navigation">
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
<a href="../../index.html">Triton</a>
</nav>
<div class="wy-nav-content">
<div class="rst-content">
<div role="navigation" aria-label="breadcrumbs navigation">
<ul class="wy-breadcrumbs">
<li><a href="../../index.html" class="icon icon-home"></a> &raquo;</li>
<li><a href="index.html">Tutorials</a> &raquo;</li>
<li>Vector Addition</li>
<li class="wy-breadcrumbs-aside">
<a href="../../_sources/getting-started/tutorials/01-vector-add.rst.txt" rel="nofollow"> View page source</a>
</li>
</ul>
<hr/>
</div>
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
<div itemprop="articleBody">
<div class="sphx-glr-download-link-note admonition note">
<p class="admonition-title">Note</p>
<p>Click <a class="reference internal" href="#sphx-glr-download-getting-started-tutorials-01-vector-add-py"><span class="std std-ref">here</span></a>
to download the full example code</p>
</div>
<div class="sphx-glr-example-title section" id="vector-addition">
<span id="sphx-glr-getting-started-tutorials-01-vector-add-py"></span><h1>Vector Addition<a class="headerlink" href="#vector-addition" title="Permalink to this headline"></a></h1>
<p>In this tutorial, you will write a simple vector addition using Triton and learn about:</p>
<ul class="simple">
<li><p>The basic syntax of the Triton programming language</p></li>
<li><p>The best practices for creating PyTorch custom operators using the <code class="code docutils literal notranslate"><span class="pre">triton.kernel</span></code> Python API</p></li>
<li><p>The best practices for validating and benchmarking custom ops against native reference implementations</p></li>
</ul>
<div class="section" id="compute-kernel">
<h2>Compute Kernel<a class="headerlink" href="#compute-kernel" title="Permalink to this headline"></a></h2>
<p>Each compute kernel is declared using the <code class="code docutils literal notranslate"><span class="pre">__global__</span></code> attribute, and executed many times in parallel
on different chunks of data (see the <a class="reference external" href="https://en.wikipedia.org/wiki/SPMD">Single Program, Multiple Data (SPMD)</a>
programming model for more details). A short Python sketch after the kernel below makes the mapping from program IDs to chunks of data concrete.</p>
<blockquote>
<div><div class="highlight-C notranslate"><div class="highlight"><pre><span></span><span class="n">__global__</span> <span class="kt">void</span> <span class="n">add</span><span class="p">(</span><span class="kt">float</span><span class="o">*</span> <span class="n">z</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">x</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">y</span><span class="p">,</span> <span class="kt">int</span> <span class="n">N</span><span class="p">){</span>
<span class="c1">// The `get_program_id(i)` returns the i-th coordinate</span>
<span class="c1">// of the program in the overarching SPMD context</span>
<span class="c1">// (a.k.a launch grid). This is what allows us to process</span>
<span class="c1">// different chunks of data in parallel.</span>
<span class="c1">// For those familiar with CUDA, `get_program_id({0,1,2})`</span>
<span class="c1">// is similar to blockIdx.{x,y,z}</span>
<span class="kt">int</span> <span class="n">pid</span> <span class="o">=</span> <span class="n">get_program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">);</span>
<span class="c1">// In Triton, arrays are first-class citizens. In other words,</span>
<span class="c1">// they are primitive data-types and are -- contrary to C and</span>
<span class="c1">// CUDA -- not implemented as pointers to contiguous chunks of</span>
<span class="c1">// memory.</span>
<span class="c1">// In the few lines below, we create an array of `BLOCK` pointers</span>
<span class="c1">// whose memory values are, e.g.:</span>
<span class="c1">// [z + pid*BLOCK + 0, z + pid*BLOCK + 1, ..., z + pid*BLOCK + BLOCK - 1]</span>
<span class="c1">// Note: here BLOCK is expected to be a pre-processor macro defined at compile-time</span>
<span class="kt">int</span> <span class="n">offset</span><span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">pid</span> <span class="o">*</span> <span class="n">BLOCK</span> <span class="o">+</span> <span class="mi">0</span> <span class="p">...</span> <span class="n">BLOCK</span><span class="p">;</span>
<span class="kt">float</span><span class="o">*</span> <span class="n">pz</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">z</span> <span class="o">+</span> <span class="n">offset</span><span class="p">;</span>
<span class="kt">float</span><span class="o">*</span> <span class="n">px</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">offset</span><span class="p">;</span>
<span class="kt">float</span><span class="o">*</span> <span class="n">py</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">y</span> <span class="o">+</span> <span class="n">offset</span><span class="p">;</span>
<span class="c1">// Simple element-wise control-flow for load/store operations can</span>
<span class="c1">// be achieved using the ternary operator `cond ? val_true : val_false`</span>
<span class="c1">// or the conditional dereferencing operator `*?(cond)ptr`.</span>
<span class="c1">// Here, we make sure that we do not access memory out-of-bounds when we</span>
<span class="c1">// write-back `z`</span>
<span class="kt">bool</span> <span class="n">check</span><span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">offset</span> <span class="o">&lt;</span> <span class="n">N</span><span class="p">;</span>
<span class="o">*?</span><span class="p">(</span><span class="n">check</span><span class="p">)</span><span class="n">pz</span> <span class="o">=</span> <span class="o">*?</span><span class="p">(</span><span class="n">check</span><span class="p">)</span><span class="n">px</span> <span class="o">+</span> <span class="o">*?</span><span class="p">(</span><span class="n">check</span><span class="p">)</span><span class="n">py</span><span class="p">;</span>
<span class="p">}</span>
</pre></div>
</div>
</div></blockquote>
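<p>To make the SPMD launch concrete, here is a small plain-Python sketch (an illustration added here, with hypothetical numbers) of how program IDs map to chunks of data:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span># Hypothetical example: N = 10,000 elements processed with BLOCK = 1024.
# The launch grid contains ceil(N / BLOCK) = 10 programs; program `pid`
# handles offsets [pid * BLOCK, ..., pid * BLOCK + BLOCK - 1], and the
# bounds check `offset &lt; N` masks out-of-range accesses in the last program.
N, BLOCK = 10_000, 1024
num_programs = (N + BLOCK - 1) // BLOCK          # ceil division, i.e. 10
for pid in range(num_programs):
    offsets = range(pid * BLOCK, (pid + 1) * BLOCK)
    valid = [o for o in offsets if o &lt; N]        # masked offsets
</pre></div>
</div>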
<p>The availability of arrays as a primitive data-type in Triton comes with a number of advantages that are highlighted in the <a class="reference external" href="http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf">MAPL2019 Triton paper</a>.</p>
</div>
<div class="section" id="torch-bindings">
<h2>Torch bindings<a class="headerlink" href="#torch-bindings" title="Permalink to this headline"></a></h2>
<p>The only thing that matters when it comes to Triton and Torch is the <code class="code docutils literal notranslate"><span class="pre">triton.kernel</span></code> class, which lets you transform the above C-like function into a callable Python object that can be used to modify <code class="code docutils literal notranslate"><span class="pre">torch.tensor</span></code> objects. To create a <code class="code docutils literal notranslate"><span class="pre">triton.kernel</span></code>, you only need three things:</p>
<ul class="simple">
<li><p><code class="code docutils literal notranslate"><span class="pre">source:</span> <span class="pre">string</span></code>: the source-code of the kernel you want to create</p></li>
<li><p><code class="code docutils literal notranslate"><span class="pre">device:</span> <span class="pre">torch.device</span></code>: the device you want to compile this code for</p></li>
<li><p><code class="code docutils literal notranslate"><span class="pre">defines:</span> <span class="pre">dict</span></code>: the set of macros that you want the pre-processor to <cite>#define</cite> for you</p></li>
</ul>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">triton</span>
<span class="c1"># source-code for Triton compute kernel</span>
<span class="c1"># here we just copy-paste the above code without the extensive comments.</span>
<span class="c1"># you may prefer to store it in a .c file and load it from there instead.</span>
<span class="n">_src</span> <span class="o">=</span> <span class="s2">&quot;&quot;&quot;</span>
<span class="s2">__global__ void add(float* z, float* x, float* y, int N){</span>
<span class="s2"> // program id</span>
<span class="s2"> int pid = get_program_id(0);</span>
<span class="s2"> // create arrays of pointers</span>
<span class="s2"> int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;</span>
<span class="s2"> float* pz[BLOCK] = z + offset;</span>
<span class="s2"> float* px[BLOCK] = x + offset;</span>
<span class="s2"> float* py[BLOCK] = y + offset;</span>
<span class="s2"> // bounds checking</span>
<span class="s2"> bool check[BLOCK] = offset &lt; N;</span>
<span class="s2"> // write-back</span>
<span class="s2"> *?(check)pz = *?(check)px + *?(check)py;</span>
<span class="s2">}</span>
<span class="s2"> &quot;&quot;&quot;</span>
<span class="c1"># This function returns a callable `triton.kernel` object created from the above source code.</span>
<span class="c1"># For portability, we maintain a cache of kernels for different `torch.device`</span>
<span class="c1"># We compile the kernel with -DBLOCK=1024</span>
<span class="k">def</span> <span class="nf">make_add_kernel</span><span class="p">(</span><span class="n">device</span><span class="p">):</span>
<span class="n">cache</span> <span class="o">=</span> <span class="n">make_add_kernel</span><span class="o">.</span><span class="n">cache</span>
<span class="k">if</span> <span class="n">device</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">cache</span><span class="p">:</span>
<span class="n">defines</span> <span class="o">=</span> <span class="p">{</span><span class="s1">&#39;BLOCK&#39;</span><span class="p">:</span> <span class="mi">1024</span><span class="p">}</span>
<span class="n">cache</span><span class="p">[</span><span class="n">device</span><span class="p">]</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">kernel</span><span class="p">(</span><span class="n">_src</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">defines</span><span class="o">=</span><span class="n">defines</span><span class="p">)</span>
<span class="k">return</span> <span class="n">cache</span><span class="p">[</span><span class="n">device</span><span class="p">]</span>
<span class="n">make_add_kernel</span><span class="o">.</span><span class="n">cache</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span>
<span class="c1"># This is a standard torch custom autograd Function;</span>
<span class="c1"># The only difference is that we can now use the above kernel in the `forward` and `backward` functions.</span>
<span class="k">class</span> <span class="nc">_add</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">autograd</span><span class="o">.</span><span class="n">Function</span><span class="p">):</span>
<span class="nd">@staticmethod</span>
<span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
<span class="c1"># constraints of the op</span>
<span class="k">assert</span> <span class="n">x</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span>
<span class="c1"># *allocate output*</span>
<span class="n">z</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="c1"># *create launch grid*:</span>
<span class="c1"># this is a function which takes compilation parameters `opt`</span>
<span class="c1"># as input and returns a tuple of int (i.e., launch grid) for the kernel.</span>
<span class="c1"># triton.cdiv is a shortcut for ceil division:</span>
<span class="c1"># triton.cdiv(a, b) = (a + b - 1) // b</span>
<span class="n">N</span> <span class="o">=</span> <span class="n">z</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">grid</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">opt</span><span class="p">:</span> <span class="p">(</span><span class="n">triton</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">opt</span><span class="o">.</span><span class="n">BLOCK</span><span class="p">),</span> <span class="p">)</span>
<span class="c1"># *launch kernel*:</span>
<span class="c1"># pointers to the data of torch tensors can be retrieved with</span>
<span class="c1"># the `.data_ptr()` method</span>
<span class="n">kernel</span> <span class="o">=</span> <span class="n">make_add_kernel</span><span class="p">(</span><span class="n">z</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">kernel</span><span class="p">(</span><span class="n">z</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">(),</span> <span class="n">x</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">(),</span> <span class="n">y</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">(),</span> <span class="n">N</span><span class="p">,</span> <span class="n">grid</span><span class="o">=</span><span class="n">grid</span><span class="p">)</span>
<span class="k">return</span> <span class="n">z</span>
<span class="c1"># Just like with standard PyTorch ops, we use the `.apply` method to create a callable object for our function</span>
<span class="n">add</span> <span class="o">=</span> <span class="n">_add</span><span class="o">.</span><span class="n">apply</span>
</pre></div>
</div>
<p>We can now use the above function to compute the element-wise sum of two <cite>torch.tensor</cite> objects.</p>
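<p>As a quick illustration (a minimal sketch, not part of the original example code; it assumes a CUDA device is available), the call looks just like a regular PyTorch op:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span># Sketch: `add` behaves like `x + y` for 1D float32 CUDA tensors.
x = torch.rand(8, device=&#39;cuda&#39;)   # float32 by default
y = torch.rand(8, device=&#39;cuda&#39;)
z = add(x, y)                      # computed by the Triton kernel
</pre></div>
</div>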
</div>
<div class="section" id="unit-test">
<h2>Unit Test<a class="headerlink" href="#unit-test" title="Permalink to this headline"></a></h2>
<p>Of course, the first thing we should check is whether the kernel is correct. This is pretty easy to test, as shown below:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">torch</span><span class="o">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">98432</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">98432</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="n">za</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span>
<span class="n">zb</span> <span class="o">=</span> <span class="n">add</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">za</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">zb</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;The maximum difference between torch and triton is &#39;</span> <span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">abs</span><span class="p">(</span><span class="n">za</span> <span class="o">-</span> <span class="n">zb</span><span class="p">))</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
</pre></div>
</div>
<p class="sphx-glr-script-out">Out:</p>
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>tensor([1.3713, 1.3076, 0.4940, ..., 0.6682, 1.1984, 1.2696], device=&#39;cuda:0&#39;)
tensor([1.3713, 1.3076, 0.4940, ..., 0.6682, 1.1984, 1.2696], device=&#39;cuda:0&#39;)
The maximum difference between torch and triton is 0.0
</pre></div>
</div>
<p>Seems like we're good to go!</p>
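<p>In practice, you may prefer an automated check over eyeballing the printed tensors. A minimal sketch using <code class="code docutils literal notranslate"><span class="pre">torch.allclose</span></code> (our addition, not part of the original tutorial) could look like this:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span># Sketch: assert that the Triton kernel matches the PyTorch reference.
assert torch.allclose(za, zb), &#39;Triton and Torch results differ&#39;
print(&#39;Triton and Torch match within floating-point tolerance&#39;)
</pre></div>
</div>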
</div>
<div class="section" id="benchmarking">
<h2>Benchmarking<a class="headerlink" href="#benchmarking" title="Permalink to this headline"></a></h2>
<p>We can now benchmark our custom op for vectors of increasing sizes to get a sense of how it does relative to PyTorch.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span>
<span class="c1"># There are three tensors of 4N bytes each. So the bandwidth of a given kernel</span>
<span class="c1"># is 12N / time_ms * 1e-6 GB/s</span>
<span class="n">gbps</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">N</span><span class="p">,</span> <span class="n">ms</span><span class="p">:</span> <span class="mi">12</span> <span class="o">*</span> <span class="n">N</span> <span class="o">/</span> <span class="n">ms</span> <span class="o">*</span> <span class="mf">1e-6</span>
<span class="c1"># We want to benchmark small and large vectors alike</span>
<span class="n">sizes</span> <span class="o">=</span> <span class="p">[</span><span class="mi">2</span><span class="o">**</span><span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">12</span><span class="p">,</span> <span class="mi">25</span><span class="p">,</span> <span class="mi">1</span><span class="p">)]</span>
<span class="n">triton_bw</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">torch_bw</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">N</span> <span class="ow">in</span> <span class="n">sizes</span><span class="p">:</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="c1"># Triton provides a do_bench utility function that can be used to benchmark</span>
<span class="c1"># arbitrary workloads. It supports a `warmup` parameter that is used to stabilize</span>
<span class="c1"># GPU clock speeds as well as a `rep` parameter that controls the number of times</span>
<span class="c1"># the benchmark is repeated. Importantly, we set `clear_l2 = True` to make sure</span>
<span class="c1"># that the L2 cache does not contain any element of x before each kernel call when</span>
<span class="c1"># N is small.</span>
<span class="n">do_bench</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">fn</span><span class="p">:</span> <span class="n">gbps</span><span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="n">fn</span><span class="p">,</span> <span class="n">warmup</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">rep</span><span class="o">=</span><span class="mi">100</span><span class="p">,</span> <span class="n">clear_l2</span><span class="o">=</span><span class="kc">True</span><span class="p">))</span>
<span class="n">triton_bw</span> <span class="o">+=</span> <span class="p">[</span><span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">add</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">))]</span>
<span class="n">torch_bw</span> <span class="o">+=</span> <span class="p">[</span><span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span><span class="p">)]</span>
<span class="c1"># We plot the results on a semi-log x-axis</span>
<span class="n">plt</span><span class="o">.</span><span class="n">semilogx</span><span class="p">(</span><span class="n">sizes</span><span class="p">,</span> <span class="n">triton_bw</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s1">&#39;Triton&#39;</span><span class="p">)</span>
<span class="n">plt</span><span class="o">.</span><span class="n">semilogx</span><span class="p">(</span><span class="n">sizes</span><span class="p">,</span> <span class="n">torch_bw</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s1">&#39;Torch&#39;</span><span class="p">)</span>
<span class="n">plt</span><span class="o">.</span><span class="n">legend</span><span class="p">()</span>
<span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span>
</pre></div>
</div>
<img alt="01 vector add" class="sphx-glr-single-img" src="../../_images/sphx_glr_01-vector-add_001.png" />
<p>Seems like our simple element-wise operation operates at peak bandwidth. While this is a fairly low bar for a custom GPU programming language, it is a good start before we move on to more advanced operations.</p>
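<p>If you know your GPU's theoretical peak memory bandwidth, a small sketch like the one below can report utilization directly; the 900 GB/s figure used here is an assumed, roughly V100-class value, so substitute your own hardware's number:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span># Sketch: report measured bandwidth as a fraction of an (assumed) peak.
PEAK_GBPS = 900  # hypothetical value; use your GPU&#39;s datasheet figure
for N, bw in zip(sizes, triton_bw):
    print(f&#39;N = {N:&gt;9}: {bw:6.1f} GB/s ({100 * bw / PEAK_GBPS:.0f}% of peak)&#39;)
</pre></div>
</div>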
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 0 minutes 4.784 seconds)</p>
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-01-vector-add-py">
<div class="sphx-glr-download sphx-glr-download-python docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/62d97d49a32414049819dd8bb8378080/01-vector-add.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">01-vector-add.py</span></code></a></p>
</div>
<div class="sphx-glr-download sphx-glr-download-jupyter docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/f191ee1e78dc52eb5f7cba88f71cef2f/01-vector-add.ipynb"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Jupyter</span> <span class="pre">notebook:</span> <span class="pre">01-vector-add.ipynb</span></code></a></p>
</div>
</div>
<p class="sphx-glr-signature"><a class="reference external" href="https://sphinx-gallery.github.io">Gallery generated by Sphinx-Gallery</a></p>
</div>
</div>
</div>
</div>
<footer>
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
<a href="02-fused-softmax.html" class="btn btn-neutral float-right" title="Fused Softmax" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
<a href="index.html" class="btn btn-neutral float-left" title="Tutorials" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
</div>
<hr/>
<div role="contentinfo">
<p>
&#169; Copyright 2020, Philippe Tillet.
</p>
</div>
Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
<a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
</div>
</div>
</section>
</div>
<script type="text/javascript">
jQuery(function () {
SphinxRtdTheme.Navigation.enable(true);
});
</script>
</body>
</html>