<!DOCTYPE html>
<html class="writer-html5" lang="en" >
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Vector Addition &mdash; Triton documentation</title>
<link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
<link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="../../_static/gallery.css" type="text/css" />
<link rel="stylesheet" href="../../_static/gallery-binder.css" type="text/css" />
<link rel="stylesheet" href="../../_static/gallery-dataframe.css" type="text/css" />
<link rel="stylesheet" href="../../_static/gallery-rendered-html.css" type="text/css" />
<link rel="stylesheet" href="../../_static/css/custom.css" type="text/css" />
<!--[if lt IE 9]>
<script src="../../_static/js/html5shiv.min.js"></script>
<![endif]-->
<script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
<script src="../../_static/jquery.js"></script>
<script src="../../_static/underscore.js"></script>
<script src="../../_static/doctools.js"></script>
<script type="text/javascript" src="../../_static/js/theme.js"></script>
<link rel="index" title="Index" href="../../genindex.html" />
<link rel="search" title="Search" href="../../search.html" />
<link rel="next" title="Fused Softmax" href="02-fused-softmax.html" />
<link rel="prev" title="Tutorials" href="index.html" />
</head>
<body class="wy-body-for-nav">
<div class="wy-grid-for-nav">
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
<div class="wy-side-scroll">
<div class="wy-side-nav-search" >
<a href="../../index.html" class="icon icon-home"> Triton
</a>
<div role="search">
<form id="rtd-search-form" class="wy-form" action="../../search.html" method="get">
<input type="text" name="q" placeholder="Search docs" />
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form>
</div>
</div>
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
<p class="caption"><span class="caption-text">Getting Started</span></p>
<ul class="current">
<li class="toctree-l1"><a class="reference internal" href="../installation.html">Installation</a></li>
<li class="toctree-l1 current"><a class="reference internal" href="index.html">Tutorials</a><ul class="current">
<li class="toctree-l2 current"><a class="current reference internal" href="#">Vector Addition</a><ul>
<li class="toctree-l3"><a class="reference internal" href="#compute-kernel">Compute Kernel</a></li>
<li class="toctree-l3"><a class="reference internal" href="#benchmark">Benchmark</a></li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="02-fused-softmax.html">Fused Softmax</a></li>
<li class="toctree-l2"><a class="reference internal" href="03-matrix-multiplication.html">Matrix Multiplication</a></li>
</ul>
</li>
</ul>
<p class="caption"><span class="caption-text">Language Reference</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../language-reference/python-api/index.html">Python API</a></li>
</ul>
<p class="caption"><span class="caption-text">Programming Guide</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-1/introduction.html">Introduction</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-2/related-work.html">Related Work</a></li>
</ul>
</div>
</div>
</nav>
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
<nav class="wy-nav-top" aria-label="top navigation">
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
<a href="../../index.html">Triton</a>
</nav>
<div class="wy-nav-content">
<div class="rst-content">
<div role="navigation" aria-label="breadcrumbs navigation">
<ul class="wy-breadcrumbs">
<li><a href="../../index.html" class="icon icon-home"></a> &raquo;</li>
<li><a href="index.html">Tutorials</a> &raquo;</li>
<li>Vector Addition</li>
<li class="wy-breadcrumbs-aside">
<a href="../../_sources/getting-started/tutorials/01-vector-add.rst.txt" rel="nofollow"> View page source</a>
</li>
</ul>
<hr/>
</div>
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
<div itemprop="articleBody">
<div class="sphx-glr-download-link-note admonition note">
<p class="admonition-title">Note</p>
<p>Click <a class="reference internal" href="#sphx-glr-download-getting-started-tutorials-01-vector-add-py"><span class="std std-ref">here</span></a>
to download the full example code</p>
</div>
<div class="sphx-glr-example-title section" id="vector-addition">
<span id="sphx-glr-getting-started-tutorials-01-vector-add-py"></span><h1>Vector Addition<a class="headerlink" href="#vector-addition" title="Permalink to this headline"></a></h1>
<p>In this tutorial, you will write a simple vector addition using Triton and learn about:</p>
<ul class="simple">
<li><p>The basic programming model used by Triton.</p></li>
<li><p>The <cite>triton.jit</cite> decorator, which constitutes the main entry point for writing Triton kernels.</p></li>
<li><p>The best practices for validating and benchmarking custom ops against native reference implementations.</p></li>
</ul>
<div class="section" id="compute-kernel">
<h2>Compute Kernel<a class="headerlink" href="#compute-kernel" title="Permalink to this headline"></a></h2>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">triton</span>
<span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
<span class="k">def</span> <span class="nf">_add</span><span class="p">(</span>
<span class="n">X</span><span class="p">,</span> <span class="c1"># *Pointer* to first input vector</span>
<span class="n">Y</span><span class="p">,</span> <span class="c1"># *Pointer* to second input vector</span>
<span class="n">Z</span><span class="p">,</span> <span class="c1"># *Pointer* to output vector</span>
<span class="n">N</span><span class="p">,</span> <span class="c1"># Size of the vector</span>
<span class="o">**</span><span class="n">meta</span> <span class="c1"># Optional meta-parameters for the kernel</span>
<span class="p">):</span>
<span class="n">pid</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="c1"># Create an offset for the blocks of pointers to be</span>
<span class="c1"># processed by this program instance</span>
<span class="n">offsets</span> <span class="o">=</span> <span class="n">pid</span> <span class="o">*</span> <span class="n">meta</span><span class="p">[</span><span class="s1">&#39;BLOCK&#39;</span><span class="p">]</span> <span class="o">+</span> <span class="n">triton</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">meta</span><span class="p">[</span><span class="s1">&#39;BLOCK&#39;</span><span class="p">])</span>
<span class="c1"># Create a mask to guard memory operations against</span>
<span class="c1"># out-of-bounds accesses</span>
<span class="n">mask</span> <span class="o">=</span> <span class="n">offsets</span> <span class="o">&lt;</span> <span class="n">N</span>
<span class="c1"># Load x</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">X</span> <span class="o">+</span> <span class="n">offsets</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Y</span> <span class="o">+</span> <span class="n">offsets</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
<span class="c1"># Write back x + y</span>
<span class="n">z</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span>
<span class="n">triton</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">Z</span> <span class="o">+</span> <span class="n">offsets</span><span class="p">,</span> <span class="n">z</span><span class="p">)</span>
</pre></div>
</div>
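<p>As a concrete example of how the grid and mask interact (arithmetic worked out from the sizes used later in this tutorial, not produced by the code): with <cite>BLOCK=1024</cite> and <cite>N=98432</cite>, the launch grid contains <cite>cdiv(98432, 1024) = 97</cite> program instances. Instance <cite>pid=2</cite> handles offsets 2048 through 3071, while the last instance (<cite>pid=96</cite>) covers offsets 98304 through 99327 and relies on the mask to skip the 896 offsets beyond <cite>N</cite>.</p>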
<p>We can also declare a helper function that handles allocating the output vector
and enqueueing the kernel.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">add</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
<span class="n">z</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="n">N</span> <span class="o">=</span> <span class="n">z</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="c1"># The SPMD launch grid denotes the number of kernel instances that should execute in parallel.</span>
<span class="c1"># It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -&gt; Tuple[int]</span>
<span class="n">grid</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">meta</span><span class="p">:</span> <span class="p">(</span><span class="n">triton</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">meta</span><span class="p">[</span><span class="s1">&#39;BLOCK&#39;</span><span class="p">]),</span> <span class="p">)</span>
<span class="c1"># NOTE:</span>
<span class="c1"># - torch.tensor objects are implicitly converted to pointers to their first element.</span>
<span class="c1"># - `triton.jit`&#39;ed functions can be subscripted with a launch grid to obtain a callable GPU kernel</span>
<span class="c1"># - don&#39;t forget to pass meta-parameters as keywords arguments</span>
<span class="n">_add</span><span class="p">[</span><span class="n">grid</span><span class="p">](</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">z</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">BLOCK</span><span class="o">=</span><span class="mi">1024</span><span class="p">)</span>
<span class="c1"># We return a handle to z but, since `torch.cuda.synchronize()` hasn&#39;t been called, the kernel is still</span>
<span class="c1"># running asynchronously.</span>
<span class="k">return</span> <span class="n">z</span>
</pre></div>
</div>
<p>We can now use the above function to compute the sum of two <cite>torch.tensor</cite> objects and test our results:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">torch</span><span class="o">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="n">size</span> <span class="o">=</span> <span class="mi">98432</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="n">za</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span>
<span class="n">zb</span> <span class="o">=</span> <span class="n">add</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">za</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">zb</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;The maximum difference between torch and triton is &#39;</span> <span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">abs</span><span class="p">(</span><span class="n">za</span> <span class="o">-</span> <span class="n">zb</span><span class="p">))</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
</pre></div>
</div>
<p class="sphx-glr-script-out">Out:</p>
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>tensor([1.3713, 1.3076, 0.4940, ..., 0.6724, 1.2141, 0.9733], device=&#39;cuda:0&#39;)
tensor([1.3713, 1.3076, 0.4940, ..., 0.6724, 1.2141, 0.9733], device=&#39;cuda:0&#39;)
The maximum difference between torch and triton is 0.0
</pre></div>
</div>
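<p>For an automated check rather than a visual one, we can also assert that the two results match. Below is a minimal sketch using plain PyTorch (<cite>torch.allclose</cite> is a standard PyTorch function, not part of the tutorial code above):</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span># Minimal validation sketch: compare the two results element-wise
# within floating-point tolerances.
assert torch.allclose(za, zb), &#39;Triton and Torch results differ&#39;
</pre></div>
</div>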
<p>Seems like we&#39;re good to go!</p>
</div>
<div class="section" id="benchmark">
<h2>Benchmark<a class="headerlink" href="#benchmark" title="Permalink to this headline"></a></h2>
<p>We can now benchmark our custom op for vectors of increasing sizes to get a sense of how it performs relative to PyTorch.
To make things easier, Triton has a set of built-in utilities that allow us to concisely plot the performance of our custom op
for different problem sizes.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="nd">@triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">perf_report</span><span class="p">(</span>
<span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">Benchmark</span><span class="p">(</span>
<span class="n">x_names</span><span class="o">=</span><span class="p">[</span><span class="s1">&#39;size&#39;</span><span class="p">],</span> <span class="c1"># argument names to use as an x-axis for the plot</span>
<span class="n">x_vals</span><span class="o">=</span><span class="p">[</span><span class="mi">2</span><span class="o">**</span><span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">12</span><span class="p">,</span> <span class="mi">28</span><span class="p">,</span> <span class="mi">1</span><span class="p">)],</span> <span class="c1"># different possible values for `x_name`</span>
<span class="n">x_log</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="c1"># x axis is logarithmic</span>
<span class="n">y_name</span><span class="o">=</span><span class="s1">&#39;provider&#39;</span><span class="p">,</span> <span class="c1"># argument name whose value corresponds to a different line in the plot</span>
<span class="n">y_vals</span><span class="o">=</span><span class="p">[</span><span class="s1">&#39;torch&#39;</span><span class="p">,</span> <span class="s1">&#39;triton&#39;</span><span class="p">],</span> <span class="c1"># possible keys for `y_name`</span>
<span class="n">y_lines</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;Torch&quot;</span><span class="p">,</span> <span class="s2">&quot;Triton&quot;</span><span class="p">],</span> <span class="c1"># label name for the lines</span>
<span class="n">ylabel</span><span class="o">=</span><span class="s2">&quot;GB/s&quot;</span><span class="p">,</span> <span class="c1"># label name for the y-axis</span>
<span class="n">plot_name</span><span class="o">=</span><span class="s2">&quot;vector-add-performance&quot;</span><span class="p">,</span> <span class="c1"># name for the plot. Used also as a file name for saving the plot.</span>
<span class="n">args</span><span class="o">=</span><span class="p">{}</span> <span class="c1"># values for function arguments not in `x_names` and `y_name`</span>
<span class="p">)</span>
<span class="p">)</span>
<span class="k">def</span> <span class="nf">benchmark</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="n">provider</span><span class="p">):</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s1">&#39;torch&#39;</span><span class="p">:</span>
<span class="n">ms</span><span class="p">,</span> <span class="n">min_ms</span><span class="p">,</span> <span class="n">max_ms</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span><span class="p">)</span>
<span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s1">&#39;triton&#39;</span><span class="p">:</span>
<span class="n">ms</span><span class="p">,</span> <span class="n">min_ms</span><span class="p">,</span> <span class="n">max_ms</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">add</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">))</span>
<span class="n">gbps</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">ms</span><span class="p">:</span> <span class="mi">12</span> <span class="o">*</span> <span class="n">size</span> <span class="o">/</span> <span class="n">ms</span> <span class="o">*</span> <span class="mf">1e-6</span>
<span class="k">return</span> <span class="n">gbps</span><span class="p">(</span><span class="n">ms</span><span class="p">),</span> <span class="n">gbps</span><span class="p">(</span><span class="n">max_ms</span><span class="p">),</span> <span class="n">gbps</span><span class="p">(</span><span class="n">min_ms</span><span class="p">)</span>
</pre></div>
</div>
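<p>Note that the <cite>gbps</cite> lambda converts runtime into effective memory bandwidth: each element requires reading <cite>x</cite> and <cite>y</cite> and writing <cite>z</cite>, i.e. three float32 accesses of 4 bytes each, hence the factor of 12 bytes per element; the <cite>1e-6</cite> factor then converts bytes per millisecond into GB/s.</p>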
<p>We can now run the decorated function above. Pass <cite>show_plots=True</cite> to see the plots and/or
<cite>save_path='/path/to/results/'</cite> to save them to disk along with raw CSV data.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">benchmark</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">show_plots</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
</pre></div>
</div>
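<p>For instance, to both display the plots and persist them to disk along with the CSV data, we can pass both arguments together (the directory below is just a placeholder):</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span># Hypothetical output directory; any writable path works.
benchmark.run(show_plots=True, save_path=&#39;./vector-add-results/&#39;)
</pre></div>
</div>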
<img alt="01 vector add" class="sphx-glr-single-img" src="../../_images/sphx_glr_01-vector-add_001.png" />
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 0 minutes 5.812 seconds)</p>
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-01-vector-add-py">
<div class="sphx-glr-download sphx-glr-download-python docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/62d97d49a32414049819dd8bb8378080/01-vector-add.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">01-vector-add.py</span></code></a></p>
</div>
<div class="sphx-glr-download sphx-glr-download-jupyter docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/f191ee1e78dc52eb5f7cba88f71cef2f/01-vector-add.ipynb"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Jupyter</span> <span class="pre">notebook:</span> <span class="pre">01-vector-add.ipynb</span></code></a></p>
</div>
</div>
<p class="sphx-glr-signature"><a class="reference external" href="https://sphinx-gallery.github.io">Gallery generated by Sphinx-Gallery</a></p>
</div>
</div>
</div>
</div>
<footer>
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
<a href="02-fused-softmax.html" class="btn btn-neutral float-right" title="Fused Softmax" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
<a href="index.html" class="btn btn-neutral float-left" title="Tutorials" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
</div>
<hr/>
<div role="contentinfo">
<p>
&#169; Copyright 2020, Philippe Tillet.
</p>
</div>
Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
<a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
</div>
</div>
</section>
</div>
<script type="text/javascript">
jQuery(function () {
SphinxRtdTheme.Navigation.enable(true);
});
</script>
</body>
</html>