Files
triton/tutorials/01-vector-add.html
2021-03-06 03:09:03 -05:00

697 lines
37 KiB
HTML
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

<!DOCTYPE html>
<html class="writer-html5" lang="en" >
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Vector Addition &mdash; Triton documentation</title>
<link rel="stylesheet" href="../_static/css/theme.css" type="text/css" />
<link rel="stylesheet" href="../_static/pygments.css" type="text/css" />
<!--[if lt IE 9]>
<script src="../_static/js/html5shiv.min.js"></script>
<![endif]-->
<script type="text/javascript" id="documentation_options" data-url_root="../" src="../_static/documentation_options.js"></script>
<script src="../_static/jquery.js"></script>
<script src="../_static/underscore.js"></script>
<script src="../_static/doctools.js"></script>
<script crossorigin="anonymous" integrity="sha256-Ae2Vz/4ePdIu6ZyI/5ZGsYnb+m0JlOmKPjt6XZ9JJkA=" src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js"></script>
<script async="async" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/x-mathjax-config">MathJax.Hub.Config({"tex2jax": {"inlineMath": [["$", "$"], ["\\(", "\\)"]], "processEscapes": true, "ignoreClass": "document", "processClass": "math|output_area"}})</script>
<script type="text/javascript" src="../_static/js/theme.js"></script>
<link rel="index" title="Index" href="../genindex.html" />
<link rel="search" title="Search" href="../search.html" />
<link rel="next" title="Fused Softmax" href="02-fused-softmax.html" />
<link rel="prev" title="From Source" href="../installation/from-source.html" />
</head>
<body class="wy-body-for-nav">
<div class="wy-grid-for-nav">
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
<div class="wy-side-scroll">
<div class="wy-side-nav-search" >
<a href="../index.html" class="icon icon-home"> Triton
</a>
<!-- Sidebar search box: sends the typed query as `q` to ../search.html using GET. -->
<div role="search">
<form id="rtd-search-form" class="wy-form" action="../search.html" method="get">
<!-- The placeholder alone does not give the field an accessible name, so it is labelled explicitly. -->
<input type="text" name="q" placeholder="Search docs" aria-label="Search docs" />
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form>
</div>
</div>
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
<p class="caption"><span class="caption-text">Installation Instructions</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../installation/packaged-binaries.html">Packaged Binaries</a></li>
<li class="toctree-l1"><a class="reference internal" href="../installation/from-source.html">From Source</a></li>
</ul>
<p class="caption"><span class="caption-text">Tutorials</span></p>
<ul class="current">
<li class="toctree-l1 current"><a class="current reference internal" href="#">Vector Addition</a><ul>
<li class="toctree-l2"><a class="reference internal" href="#Writing-the-Compute-Kernel">Writing the Compute Kernel</a></li>
<li class="toctree-l2"><a class="reference internal" href="#Writing-the-Torch-bindings">Writing the Torch bindings</a></li>
<li class="toctree-l2"><a class="reference internal" href="#Writing-a-Unit-Test">Writing a Unit Test</a></li>
<li class="toctree-l2"><a class="reference internal" href="#Writing-a-Benchmark">Writing a Benchmark</a></li>
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="02-fused-softmax.html">Fused Softmax</a></li>
</ul>
</div>
</div>
</nav>
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
<nav class="wy-nav-top" aria-label="top navigation">
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
<a href="../index.html">Triton</a>
</nav>
<div class="wy-nav-content">
<div class="rst-content">
<!-- Breadcrumb trail: home link, current page title, and a link to the page source. -->
<div role="navigation" aria-label="breadcrumbs navigation">
<ul class="wy-breadcrumbs">
<!-- The home link is icon-only, so aria-label supplies its accessible name. -->
<li><a href="../index.html" class="icon icon-home" aria-label="Home"></a> &raquo;</li>
<li>Vector Addition</li>
<li class="wy-breadcrumbs-aside">
<a href="../_sources/tutorials/01-vector-add.ipynb.txt" rel="nofollow"> View page source</a>
</li>
</ul>
<hr/>
</div>
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
<div itemprop="articleBody">
<style>
/* CSS for nbsphinx extension */
/* remove conflicting styling from Sphinx themes */
div.nbinput.container div.prompt *,
div.nboutput.container div.prompt *,
div.nbinput.container div.input_area pre,
div.nboutput.container div.output_area pre,
div.nbinput.container div.input_area .highlight,
div.nboutput.container div.output_area .highlight {
border: none;
padding: 0;
margin: 0;
box-shadow: none;
}
div.nbinput.container > div[class*=highlight],
div.nboutput.container > div[class*=highlight] {
margin: 0;
}
div.nbinput.container div.prompt *,
div.nboutput.container div.prompt * {
background: none;
}
div.nboutput.container div.output_area .highlight,
div.nboutput.container div.output_area pre {
background: unset;
}
div.nboutput.container div.output_area div.highlight {
color: unset; /* override Pygments text color */
}
/* avoid gaps between output lines */
div.nboutput.container div[class*=highlight] pre {
line-height: normal;
}
/* input/output containers */
div.nbinput.container,
div.nboutput.container {
display: -webkit-flex;
display: flex;
align-items: flex-start;
margin: 0;
width: 100%;
}
@media (max-width: 540px) {
div.nbinput.container,
div.nboutput.container {
flex-direction: column;
}
}
/* input container */
div.nbinput.container {
padding-top: 5px;
}
/* last container */
div.nblast.container {
padding-bottom: 5px;
}
/* input prompt */
div.nbinput.container div.prompt pre {
color: #307FC1;
}
/* output prompt */
div.nboutput.container div.prompt pre {
color: #BF5B3D;
}
/* all prompts */
div.nbinput.container div.prompt,
div.nboutput.container div.prompt {
width: 4.5ex;
padding-top: 5px;
position: relative;
user-select: none;
}
div.nbinput.container div.prompt > div,
div.nboutput.container div.prompt > div {
position: absolute;
right: 0;
margin-right: 0.3ex;
}
@media (max-width: 540px) {
div.nbinput.container div.prompt,
div.nboutput.container div.prompt {
width: unset;
text-align: left;
padding: 0.4em;
}
div.nboutput.container div.prompt.empty {
padding: 0;
}
div.nbinput.container div.prompt > div,
div.nboutput.container div.prompt > div {
position: unset;
}
}
/* disable scrollbars on prompts */
div.nbinput.container div.prompt pre,
div.nboutput.container div.prompt pre {
overflow: hidden;
}
/* input/output area */
div.nbinput.container div.input_area,
div.nboutput.container div.output_area {
-webkit-flex: 1;
flex: 1;
overflow: auto;
}
@media (max-width: 540px) {
div.nbinput.container div.input_area,
div.nboutput.container div.output_area {
width: 100%;
}
}
/* input area */
div.nbinput.container div.input_area {
border: 1px solid #e0e0e0;
border-radius: 2px;
/*background: #f5f5f5;*/
}
/* override MathJax center alignment in output cells */
div.nboutput.container div[class*=MathJax] {
text-align: left !important;
}
/* override sphinx.ext.imgmath center alignment in output cells */
div.nboutput.container div.math p {
text-align: left;
}
/* standard error */
div.nboutput.container div.output_area.stderr {
background: #fdd;
}
/* ANSI colors */
.ansi-black-fg { color: #3E424D; }
.ansi-black-bg { background-color: #3E424D; }
.ansi-black-intense-fg { color: #282C36; }
.ansi-black-intense-bg { background-color: #282C36; }
.ansi-red-fg { color: #E75C58; }
.ansi-red-bg { background-color: #E75C58; }
.ansi-red-intense-fg { color: #B22B31; }
.ansi-red-intense-bg { background-color: #B22B31; }
.ansi-green-fg { color: #00A250; }
.ansi-green-bg { background-color: #00A250; }
.ansi-green-intense-fg { color: #007427; }
.ansi-green-intense-bg { background-color: #007427; }
.ansi-yellow-fg { color: #DDB62B; }
.ansi-yellow-bg { background-color: #DDB62B; }
.ansi-yellow-intense-fg { color: #B27D12; }
.ansi-yellow-intense-bg { background-color: #B27D12; }
.ansi-blue-fg { color: #208FFB; }
.ansi-blue-bg { background-color: #208FFB; }
.ansi-blue-intense-fg { color: #0065CA; }
.ansi-blue-intense-bg { background-color: #0065CA; }
.ansi-magenta-fg { color: #D160C4; }
.ansi-magenta-bg { background-color: #D160C4; }
.ansi-magenta-intense-fg { color: #A03196; }
.ansi-magenta-intense-bg { background-color: #A03196; }
.ansi-cyan-fg { color: #60C6C8; }
.ansi-cyan-bg { background-color: #60C6C8; }
.ansi-cyan-intense-fg { color: #258F8F; }
.ansi-cyan-intense-bg { background-color: #258F8F; }
.ansi-white-fg { color: #C5C1B4; }
.ansi-white-bg { background-color: #C5C1B4; }
.ansi-white-intense-fg { color: #A1A6B2; }
.ansi-white-intense-bg { background-color: #A1A6B2; }
.ansi-default-inverse-fg { color: #FFFFFF; }
.ansi-default-inverse-bg { background-color: #000000; }
.ansi-bold { font-weight: bold; }
.ansi-underline { text-decoration: underline; }
div.nbinput.container div.input_area div[class*=highlight] > pre,
div.nboutput.container div.output_area div[class*=highlight] > pre,
div.nboutput.container div.output_area div[class*=highlight].math,
div.nboutput.container div.output_area.rendered_html,
div.nboutput.container div.output_area > div.output_javascript,
div.nboutput.container div.output_area:not(.rendered_html) > img{
padding: 5px;
margin: 0;
}
/* fix copybtn overflow problem in chromium (needed for 'sphinx_copybutton') */
div.nbinput.container div.input_area > div[class^='highlight'],
div.nboutput.container div.output_area > div[class^='highlight']{
overflow-y: hidden;
}
/* hide copybtn icon on prompts (needed for 'sphinx_copybutton') */
.prompt a.copybtn {
display: none;
}
/* Some additional styling taken from the Jupyter notebook CSS */
div.rendered_html table {
border: none;
border-collapse: collapse;
border-spacing: 0;
color: black;
font-size: 12px;
table-layout: fixed;
}
div.rendered_html thead {
border-bottom: 1px solid black;
vertical-align: bottom;
}
div.rendered_html tr,
div.rendered_html th,
div.rendered_html td {
text-align: right;
vertical-align: middle;
padding: 0.5em 0.5em;
line-height: normal;
white-space: normal;
max-width: none;
border: none;
}
div.rendered_html th {
font-weight: bold;
}
div.rendered_html tbody tr:nth-child(odd) {
background: #f5f5f5;
}
div.rendered_html tbody tr:hover {
background: rgba(66, 165, 245, 0.2);
}
/* CSS overrides for sphinx_rtd_theme */
/* 24px margin */
.nbinput.nblast.container,
.nboutput.nblast.container {
margin-bottom: 19px; /* padding has already 5px */
}
/* ... except between code cells! */
.nblast.container + .nbinput.container {
margin-top: -19px;
}
.admonition > p:before {
margin-right: 4px; /* make room for the exclamation icon */
}
/* Fix math alignment, see https://github.com/rtfd/sphinx_rtd_theme/pull/686 */
.math {
text-align: unset;
}
</style>
<div class="section" id="Vector-Addition">
<h1>Vector Addition<a class="headerlink" href="#Vector-Addition" title="Permalink to this headline"></a></h1>
<p>In this tutorial, we will see how to construct a simple, high-performance vector addition using Triton. You will learn:</p>
<ul class="simple">
<li><p>The basic syntax of the Triton programming language</p></li>
<li><p>The best practices for creating PyTorch custom operators using the <code class="docutils literal notranslate"><span class="pre">triton.kernel</span></code> Python API</p></li>
<li><p>The best practices for validating and benchmarking custom ops against native reference implementations</p></li>
</ul>
<div class="section" id="Writing-the-Compute-Kernel">
<h2>Writing the Compute Kernel<a class="headerlink" href="#Writing-the-Compute-Kernel" title="Permalink to this headline"></a></h2>
<p>Each compute kernel is declared using the <code class="docutils literal notranslate"><span class="pre">__global__</span></code> attribute, and executed many times in parallel on different chunks of data (See the <a class="reference external" href="https://en.wikipedia.org/wiki/SPMD">Single Program, Multiple Data</a> programming model for more details).</p>
<div class="highlight-c notranslate"><div class="highlight"><pre><span></span><span class="n">__global__</span> <span class="kt">void</span> <span class="n">add</span><span class="p">(</span><span class="kt">float</span><span class="o">*</span> <span class="n">z</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">x</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">y</span><span class="p">,</span> <span class="kt">int</span> <span class="n">N</span><span class="p">){</span>
<span class="c1">// The `get_program_id(i)` returns the i-th coordinate</span>
<span class="c1">// of the program in the overarching SPMD context</span>
<span class="c1">// (a.k.a launch grid). This is what allows us to process</span>
<span class="c1">// different chunks of data in parallel.</span>
<span class="c1">// For those familiar with CUDA, `get_program_id({0,1,2})`</span>
<span class="c1">// is similar to blockIdx.{x,y,z}</span>
<span class="kt">int</span> <span class="n">pid</span> <span class="o">=</span> <span class="n">get_program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">);</span>
<span class="c1">// In Triton, arrays are first-class citizens. In other words,</span>
<span class="c1">// they are primitive data-types and are -- contrary to C and</span>
<span class="c1">// CUDA -- not implemented as pointers to contiguous chunks of</span>
<span class="c1">// memory.</span>
<span class="c1">// In the few lines below, we create an array of `BLOCK` pointers</span>
<span class="c1">// whose memory values are, e.g.:</span>
<span class="c1">// [z + pid*BLOCK + 0, z + pid*BLOCK + 1, ..., z + pid*BLOCK + BLOCK - 1]</span>
<span class="c1">// Note: here BLOCK is expected to be a pre-processor macro defined at compile-time</span>
<span class="kt">int</span> <span class="n">offset</span><span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">pid</span> <span class="o">*</span> <span class="n">BLOCK</span> <span class="o">+</span> <span class="mi">0</span> <span class="p">...</span> <span class="n">BLOCK</span><span class="p">;</span>
<span class="kt">float</span><span class="o">*</span> <span class="n">pz</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">z</span> <span class="o">+</span> <span class="n">offset</span><span class="p">;</span>
<span class="kt">float</span><span class="o">*</span> <span class="n">px</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">offset</span><span class="p">;</span>
<span class="kt">float</span><span class="o">*</span> <span class="n">py</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">y</span> <span class="o">+</span> <span class="n">offset</span><span class="p">;</span>
<span class="c1">// Simple element-wise control-flow for load/store operations can</span>
<span class="c1">// be achieved using the ternary operator `cond ? val_true : val_false`</span>
<span class="c1">// or the conditional dereferencing operator `*?(cond)ptr`</span>
<span class="c1">// Here, we make sure that we do not access memory out-of-bounds when we</span>
<span class="c1">// write-back `z`</span>
<span class="kt">bool</span> <span class="n">check</span><span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">offset</span> <span class="o">&lt;</span> <span class="n">N</span><span class="p">;</span>
<span class="o">*?</span><span class="p">(</span><span class="n">check</span><span class="p">)</span><span class="n">pz</span> <span class="o">=</span> <span class="o">*?</span><span class="p">(</span><span class="n">check</span><span class="p">)</span><span class="n">px</span> <span class="o">+</span> <span class="o">*?</span><span class="p">(</span><span class="n">check</span><span class="p">)</span><span class="n">py</span><span class="p">;</span>
<span class="p">}</span>
</pre></div>
</div>
<p>The existence of arrays as a primitive data-type for Triton comes with a number of advantages that are highlighted in the <a class="reference external" href="http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf">MAPL2019 Triton paper</a>.</p>
</div>
<div class="section" id="Writing-the-Torch-bindings">
<h2>Writing the Torch bindings<a class="headerlink" href="#Writing-the-Torch-bindings" title="Permalink to this headline"></a></h2>
<p>The only thing that matters when it comes to Triton and Torch is the <code class="docutils literal notranslate"><span class="pre">triton.kernel</span></code> class. This allows you to transform the above C-like function into a callable python object that can be used to modify <code class="docutils literal notranslate"><span class="pre">torch.tensor</span></code> objects.</p>
<p>To create a <code class="docutils literal notranslate"><span class="pre">triton.kernel</span></code>, you only need three things:</p>
<ul class="simple">
<li><p><code class="docutils literal notranslate"><span class="pre">source:</span> <span class="pre">string</span></code>: the source-code of the kernel you want to create</p></li>
<li><p><code class="docutils literal notranslate"><span class="pre">device:</span> <span class="pre">torch.device</span></code>: the device you want to compile this code for</p></li>
<li><p><code class="docutils literal notranslate"><span class="pre">defines:</span> <span class="pre">dict</span></code>: the set of macros that you want the pre-processor to <code class="docutils literal notranslate"><span class="pre">#define</span></code> for you</p></li>
</ul>
<p>Note: The constructor of <code class="docutils literal notranslate"><span class="pre">triton.kernel</span></code> does some just-in-time compilation, so expect some overhead there. For this reason, I personally like to initialize kernels lazily in a cache (see <code class="docutils literal notranslate"><span class="pre">_kernels</span></code> variable below). This also makes it possible to choose the compilation device dynamically based on the type of the operator's inputs.</p>
<div class="nbinput nblast docutils container">
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[10]:
</pre></div>
</div>
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre>
<span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">triton</span>
<span class="c1"># source-code for Triton compute kernel</span>
<span class="c1"># here we just copy-paste the above code without the extensive comments.</span>
<span class="c1"># you may prefer to store it in a .c file and load it from there instead.</span>
<span class="n">_src</span> <span class="o">=</span> <span class="s2">&quot;&quot;&quot;</span>
<span class="s2">__global__ void add(float* z, float* x, float* y, int N){</span>
<span class="s2"> // program id</span>
<span class="s2"> int pid = get_program_id(0);</span>
<span class="s2"> // create arrays of pointers</span>
<span class="s2"> int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;</span>
<span class="s2"> float* pz[BLOCK] = z + offset;</span>
<span class="s2"> float* px[BLOCK] = x + offset;</span>
<span class="s2"> float* py[BLOCK] = y + offset;</span>
<span class="s2"> // bounds checking</span>
<span class="s2"> bool check[BLOCK] = offset &lt; N;</span>
<span class="s2"> // write-back</span>
<span class="s2"> *?(check)pz = *?(check)px + *?(check)py;</span>
<span class="s2">}</span>
<span class="s2"> &quot;&quot;&quot;</span>
<span class="c1"># This function returns a callable `triton.kernel` object</span>
<span class="c1"># created from the above source code.</span>
<span class="c1"># For portability, we maintain a cache of kernels for different `torch.device`</span>
<span class="c1"># We compile the kernel with -DBLOCK=1024</span>
<span class="n">_kernels</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span>
<span class="k">def</span> <span class="nf">make_add_kernel</span><span class="p">(</span><span class="n">device</span><span class="p">):</span>
<span class="k">if</span> <span class="n">device</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">_kernels</span><span class="p">:</span>
<span class="n">defines</span> <span class="o">=</span> <span class="p">{</span><span class="s1">&#39;BLOCK&#39;</span><span class="p">:</span> <span class="mi">1024</span><span class="p">}</span>
<span class="n">_kernels</span><span class="p">[</span><span class="n">device</span><span class="p">]</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">kernel</span><span class="p">(</span><span class="n">_src</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">defines</span><span class="o">=</span><span class="n">defines</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_kernels</span><span class="p">[</span><span class="n">device</span><span class="p">]</span>
<span class="c1"># This is a standard torch custom autograd Function</span>
<span class="c1"># The only difference is that we can now use the above kernel</span>
<span class="c1"># in the `forward` and `backward` functions.</span>
<span class="k">class</span> <span class="nc">_add</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">autograd</span><span class="o">.</span><span class="n">Function</span><span class="p">):</span>
<span class="nd">@staticmethod</span>
<span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
<span class="c1"># constraints of the op</span>
<span class="k">assert</span> <span class="n">x</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span>
<span class="c1"># *allocate output*</span>
<span class="n">z</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="c1"># *create launch grid*:</span>
<span class="c1"># this is a function which takes compilation parameters `opt`</span>
<span class="c1"># as input and returns a tuple of int (i.e., launch grid) for the kernel.</span>
<span class="c1"># triton.cdiv is a shortcut for ceil division:</span>
<span class="c1"># triton.cdiv(a, b) = (a + b - 1) // b</span>
<span class="n">N</span> <span class="o">=</span> <span class="n">z</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">grid</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">opt</span><span class="p">:</span> <span class="p">(</span><span class="n">triton</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">opt</span><span class="o">.</span><span class="n">BLOCK</span><span class="p">),</span> <span class="p">)</span>
<span class="c1"># *launch kernel*:</span>
<span class="c1"># pointer to the data of torch tensors can be retrieved with</span>
<span class="c1"># the `.data_ptr()` method</span>
<span class="n">kernel</span> <span class="o">=</span> <span class="n">make_add_kernel</span><span class="p">(</span><span class="n">z</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">kernel</span><span class="p">(</span><span class="n">z</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">(),</span> <span class="n">x</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">(),</span> <span class="n">y</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">(),</span> <span class="n">N</span><span class="p">,</span> <span class="n">grid</span> <span class="o">=</span> <span class="n">grid</span><span class="p">)</span>
<span class="k">return</span> <span class="n">z</span>
<span class="c1"># Just like with standard PyTorch ops</span>
<span class="c1"># We use the `.apply` method to create a</span>
<span class="c1"># callable object for our function</span>
<span class="n">add</span> <span class="o">=</span> <span class="n">_add</span><span class="o">.</span><span class="n">apply</span>
</pre></div>
</div>
</div>
<p>At this point <code class="docutils literal notranslate"><span class="pre">add(x,</span> <span class="pre">y)</span></code> is equivalent to <code class="docutils literal notranslate"><span class="pre">x</span> <span class="pre">+</span> <span class="pre">y</span></code> for contiguous tensors. Now let's test and benchmark it!</p>
</div>
<div class="section" id="Writing-a-Unit-Test">
<h2>Writing a Unit Test<a class="headerlink" href="#Writing-a-Unit-Test" title="Permalink to this headline"></a></h2>
<div class="nbinput docutils container">
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[9]:
</pre></div>
</div>
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre>
<span></span><span class="n">torch</span><span class="o">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">98432</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">98432</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="n">za</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span>
<span class="n">zb</span> <span class="o">=</span> <span class="n">add</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">za</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">zb</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;The maximum difference between torch and triton is &#39;</span>
<span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">abs</span><span class="p">(</span><span class="n">za</span> <span class="o">-</span> <span class="n">zb</span><span class="p">))</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
</pre></div>
</div>
</div>
<div class="nboutput nblast docutils container">
<div class="prompt empty docutils container">
</div>
<div class="output_area docutils container">
<div class="highlight"><pre>
tensor([1.3713, 1.3076, 0.4940, ..., 0.6682, 1.1984, 1.2696], device=&#39;cuda:0&#39;)
tensor([1.3713, 1.3076, 0.4940, ..., 0.6682, 1.1984, 1.2696], device=&#39;cuda:0&#39;)
The maximum difference between torch and triton is 0.0
</pre></div></div>
</div>
<p>Seems to work!</p>
</div>
<div class="section" id="Writing-a-Benchmark">
<h2>Writing a Benchmark<a class="headerlink" href="#Writing-a-Benchmark" title="Permalink to this headline"></a></h2>
<p>The performance of our GPU code can be benchmarked using the <code class="docutils literal notranslate"><span class="pre">torch.cuda.Event(enable_timing=True)</span></code> wrapper. Below is a simple function that benchmarks <code class="docutils literal notranslate"><span class="pre">rep</span></code> runs of our kernels after <code class="docutils literal notranslate"><span class="pre">warmup</span></code> “cold” runs.</p>
<div class="nbinput nblast docutils container">
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[11]:
</pre></div>
</div>
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre>
<span></span><span class="c1"># We now want to benchmark the performance of `add`</span>
<span class="c1"># Against that of PyTorch for increasing vector sizes</span>
<span class="k">def</span> <span class="nf">do_bench</span><span class="p">(</span><span class="n">fn</span><span class="p">,</span> <span class="n">warmup</span> <span class="o">=</span> <span class="mi">10</span><span class="p">,</span> <span class="n">rep</span> <span class="o">=</span> <span class="mi">50</span><span class="p">):</span>
<span class="n">start_event</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">Event</span><span class="p">(</span><span class="n">enable_timing</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">end_event</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">Event</span><span class="p">(</span><span class="n">enable_timing</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">ret</span> <span class="o">=</span> <span class="n">fn</span><span class="p">()</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">warmup</span><span class="p">):</span>
<span class="n">fn</span><span class="p">()</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">synchronize</span><span class="p">()</span>
<span class="n">start_event</span><span class="o">.</span><span class="n">record</span><span class="p">()</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">rep</span><span class="p">):</span>
<span class="n">fn</span><span class="p">()</span>
<span class="n">end_event</span><span class="o">.</span><span class="n">record</span><span class="p">()</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">synchronize</span><span class="p">()</span>
<span class="n">time_ms</span> <span class="o">=</span> <span class="n">start_event</span><span class="o">.</span><span class="n">elapsed_time</span><span class="p">(</span><span class="n">end_event</span><span class="p">)</span> <span class="o">/</span> <span class="n">rep</span>
<span class="k">return</span> <span class="n">time_ms</span>
</pre></div>
</div>
</div>
<p>We can now benchmark our custom op for vectors of increasing sizes to get a sense of how it does</p>
<div class="nbinput docutils container">
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[15]:
</pre></div>
</div>
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre>
<span></span><span class="k">for</span> <span class="n">N</span> <span class="ow">in</span> <span class="p">[</span><span class="mi">2</span><span class="o">**</span><span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">17</span><span class="p">,</span> <span class="mi">26</span><span class="p">,</span> <span class="mi">1</span><span class="p">)]:</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="n">triton_ms</span> <span class="o">=</span> <span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">add</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">))</span>
<span class="n">torch_ms</span> <span class="o">=</span> <span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span><span class="p">)</span>
<span class="c1"># print the vector size followed by the triton and torch runtimes (in ms)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">N</span><span class="si">}</span><span class="s1"> </span><span class="si">{</span><span class="n">triton_ms</span><span class="si">:</span><span class="s1">.3f</span><span class="si">}</span><span class="s1"> </span><span class="si">{</span><span class="n">torch_ms</span><span class="si">:</span><span class="s1">.3f</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
</pre></div>
</div>
</div>
<div class="nboutput nblast docutils container">
<div class="prompt empty docutils container">
</div>
<div class="output_area docutils container">
<div class="highlight"><pre>
131072 0.020 0.003
262144 0.019 0.004
524288 0.016 0.016
1048576 0.033 0.033
2097152 0.071 0.070
4194304 0.142 0.144
8388608 0.287 0.286
16777216 0.572 0.568
33554432 1.139 1.110
</pre></div></div>
</div>
<p>Our op is on par with Torch's vectorized element-wise kernel when the vectors are large enough. One caveat is that the latency of PyTorch is much smaller for small vectors (3us vs 18-20us). Reducing this latency is something we are actively working on.</p>
</div>
</div>
</div>
</div>
<footer>
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
<a href="02-fused-softmax.html" class="btn btn-neutral float-right" title="Fused Softmax" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
<a href="../installation/from-source.html" class="btn btn-neutral float-left" title="From Source" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
</div>
<hr/>
<div role="contentinfo">
<p>
&#169; Copyright 2020, Philippe Tillet.
</p>
</div>
Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
<a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
</div>
</div>
</section>
</div>
<script type="text/javascript">
  // Passing a function to jQuery() defers it until the DOM is ready; at that
  // point the theme's navigation component is switched on.
  // NOTE(review): the `true` argument is presumably the theme's
  // sticky-navigation flag — confirm against ../_static/js/theme.js.
  jQuery(function () {
    var navigation = SphinxRtdTheme.Navigation;
    navigation.enable(true);
  });
</script>
</body>
</html>