<div class="section" id="Vector-Addition">
|
||
<h1>Vector Addition<a class="headerlink" href="#Vector-Addition" title="Permalink to this headline">¶</a></h1>
|
||
<p>In this tutorial, we will see how to construct a simple, high-performance vector addition using Triton. You will learn: * The basic syntax of the Triton programming language * The best practices for creating PyTorch custom operators using the <code class="docutils literal notranslate"><span class="pre">triton.kernel</span></code> Python API * The best practices for validating and benchmarking custom ops against native reference implementations</p>
## Writing the Compute Kernel

Each compute kernel is declared using the `__global__` attribute, and executed many times in parallel on different chunks of data (see the [Single Program, Multiple Data](https://en.wikipedia.org/wiki/SPMD) programming model for more details).
<div class="highlight-c notranslate"><div class="highlight"><pre><span></span><span class="n">__global__</span> <span class="kt">void</span> <span class="n">add</span><span class="p">(</span><span class="kt">float</span><span class="o">*</span> <span class="n">z</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">x</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">y</span><span class="p">,</span> <span class="kt">int</span> <span class="n">N</span><span class="p">){</span>
|
||
<span class="c1">// The `get_program_id(i)` returns the i-th coordinate</span>
|
||
<span class="c1">// of the program in the overaching SPMD context</span>
|
||
<span class="c1">// (a.k.a launch grid). This is what allows us to process</span>
|
||
<span class="c1">// different chunks of data in parallel.</span>
|
||
<span class="c1">// For those similar with CUDA, `get_program_id({0,1,2})`</span>
|
||
<span class="c1">// is similar to blockIdx.{x,y,z}</span>
|
||
<span class="kt">int</span> <span class="n">pid</span> <span class="o">=</span> <span class="n">get_program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">);</span>
|
||
<span class="c1">// In Triton, arrays are first-class citizen. In other words,</span>
|
||
<span class="c1">// they are primitives data-types and are -- contrary to C and</span>
|
||
<span class="c1">// CUDA -- not implemented as pointers to contiguous chunks of</span>
|
||
<span class="c1">// memory.</span>
|
||
<span class="c1">// In the few lines below, we create an array of `BLOCK` pointers</span>
|
||
<span class="c1">// whose memory values are, e.g.:</span>
|
||
<span class="c1">// [z + pid*BLOCK + 0, z + pid*BLOCK + 1, ..., z + pid*BLOCK + BLOCK - 1]</span>
|
||
<span class="c1">// Note: here BLOCK is expected to be a pre-processor macro defined at compile-time</span>
|
||
<span class="kt">int</span> <span class="n">offset</span><span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">pid</span> <span class="o">*</span> <span class="n">BLOCK</span> <span class="o">+</span> <span class="mi">0</span> <span class="p">...</span> <span class="n">BLOCK</span><span class="p">;</span>
|
||
<span class="kt">float</span><span class="o">*</span> <span class="n">pz</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">z</span> <span class="o">+</span> <span class="n">offset</span><span class="p">;</span>
|
||
<span class="kt">float</span><span class="o">*</span> <span class="n">px</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">offset</span><span class="p">;</span>
|
||
<span class="kt">float</span><span class="o">*</span> <span class="n">py</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">y</span> <span class="o">+</span> <span class="n">offset</span><span class="p">;</span>
|
||
<span class="c1">// Simple element-wise control-flow for load/store operations can</span>
|
||
<span class="c1">// be achieved using the the ternary operator `cond ? val_true : val_false`</span>
|
||
<span class="c1">// or the conditional dereferencing operator `*?(cond)ptr</span>
|
||
<span class="c1">// Here, we make sure that we do not access memory out-of-bounds when we</span>
|
||
<span class="c1">// write-back `z`</span>
|
||
<span class="kt">bool</span> <span class="n">check</span><span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">offset</span> <span class="o"><</span> <span class="n">N</span><span class="p">;</span>
|
||
<span class="o">*?</span><span class="p">(</span><span class="n">check</span><span class="p">)</span><span class="n">pz</span> <span class="o">=</span> <span class="o">*?</span><span class="p">(</span><span class="n">check</span><span class="p">)</span><span class="n">px</span> <span class="o">+</span> <span class="o">*?</span><span class="p">(</span><span class="n">check</span><span class="p">)</span><span class="n">py</span><span class="p">;</span>
|
||
<span class="p">}</span>
|
||
</pre></div>
|
||
</div>
The existence of arrays as a primitive data type in Triton comes with a number of advantages that are highlighted in the [MAPL'2019 Triton paper](http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf).
## Writing the Torch bindings

The only thing that matters when it comes to Triton and Torch is the `triton.kernel` class. It allows you to transform the C-like function above into a callable Python object that can be used to modify `torch.tensor` objects.
To create a `triton.kernel`, you only need three things:

* `source: string`: the source code of the kernel you want to create
* `device: torch.device`: the device you want to compile this code for
* `defines: dict`: the set of macros that you want the pre-processor to `#define` for you
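For instance, setting aside the lazy caching discussed in the note below, a bare construction could look like this sketch, where `_src` is a Python string holding the kernel source (as in the full listing that follows):

```python
# Minimal sketch (no caching): JIT-compile the kernel above for the
# current CUDA device, with the pre-processor macro BLOCK set to 1024.
kernel = triton.kernel(_src, device=torch.device('cuda'), defines={'BLOCK': 1024})
```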
Note: The constructor of `triton.kernel` does some just-in-time compilation, so expect some overhead there. For this reason, I personally like to initialize kernels lazily in a cache (see the `_kernels` variable below). This also makes it possible to choose the compilation device dynamically based on the type of the operator's inputs.
```python
import torch
import triton

# source code for the Triton compute kernel
# here we just copy-paste the above code without the extensive comments.
# you may prefer to store it in a .c file and load it from there instead.
_src = """
__global__ void add(float* z, float* x, float* y, int N){
    // program id
    int pid = get_program_id(0);
    // create arrays of pointers
    int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;
    float* pz[BLOCK] = z + offset;
    float* px[BLOCK] = x + offset;
    float* py[BLOCK] = y + offset;
    // bounds checking
    bool check[BLOCK] = offset < N;
    // write-back
    *?(check)pz = *?(check)px + *?(check)py;
}
"""
# This function returns a callable `triton.kernel` object
# created from the above source code.
# For portability, we maintain a cache of kernels for different `torch.device`s.
# We compile the kernel with -DBLOCK=1024
_kernels = dict()
def make_add_kernel(device):
    if device not in _kernels:
        defines = {'BLOCK': 1024}
        _kernels[device] = triton.kernel(_src, device=device, defines=defines)
    return _kernels[device]

# This is a standard torch custom autograd Function.
# The only difference is that we can now use the above kernel
# in the `forward` and `backward` functions.
class _add(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, y):
        # constraints of the op
        assert x.dtype == torch.float32
        # *allocate output*
        z = torch.empty_like(x)
        # *create launch grid*:
        # this is a function which takes compilation parameters `opt`
        # as input and returns a tuple of ints (i.e., the launch grid) for the kernel.
        # triton.cdiv is a shortcut for ceiling division:
        # triton.cdiv(a, b) = (a + b - 1) // b
        N = z.shape[0]
        grid = lambda opt: (triton.cdiv(N, opt.BLOCK), )
        # *launch kernel*:
        # the pointer to the data of a torch tensor can be retrieved with
        # the `.data_ptr()` method
        kernel = make_add_kernel(z.device)
        kernel(z.data_ptr(), x.data_ptr(), y.data_ptr(), N, grid=grid)
        return z

# Just like with standard PyTorch ops,
# we use the `.apply` method to create a
# callable object for our function
add = _add.apply
```
At this point, `add(x, y)` is equivalent to `x + y` for contiguous tensors. Now let's test and benchmark it!
## Writing a Unit Test
```python
torch.manual_seed(0)
x = torch.rand(98432, device='cuda')
y = torch.rand(98432, device='cuda')
za = x + y
zb = add(x, y)
print(za)
print(zb)
print(f'The maximum difference between torch and triton is '
      f'{torch.max(torch.abs(za - zb))}')
```
```
tensor([1.3713, 1.3076, 0.4940, ..., 0.6682, 1.1984, 1.2696], device='cuda:0')
tensor([1.3713, 1.3076, 0.4940, ..., 0.6682, 1.1984, 1.2696], device='cuda:0')
The maximum difference between torch and triton is 0.0
```
Seems to work!
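In an automated test suite, one might replace the prints above with an assertion; here is a minimal sketch using `torch.allclose`:

```python
# Minimal sketch of an automated check: raise if the two results diverge.
# Exact equality is expected here since both kernels perform the same
# float32 additions, but `torch.allclose` also tolerates benign rounding.
assert torch.allclose(za, zb), "triton and torch results differ"
```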
## Writing a Benchmark

The performance of our GPU code can be benchmarked using the `torch.cuda.Event(enable_timing=True)` wrapper. Below is a simple function that benchmarks `rep` runs of our kernels after `warmup` "cold" runs.
```python
# We now want to benchmark the performance of `add`
# against that of PyTorch for increasing vector sizes.
def do_bench(fn, warmup=10, rep=50):
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    ret = fn()
    for i in range(warmup):
        fn()
    torch.cuda.synchronize()
    start_event.record()
    for i in range(rep):
        fn()
    end_event.record()
    torch.cuda.synchronize()
    time_ms = start_event.elapsed_time(end_event) / rep
    return time_ms
```
We can now benchmark our custom op for vectors of increasing sizes to get a sense of how it does.
```python
for N in [2**i for i in range(17, 26, 1)]:
    x = torch.rand(N, device='cuda')
    y = torch.rand(N, device='cuda')
    triton_ms = do_bench(lambda: add(x, y))
    torch_ms = do_bench(lambda: x + y)
    # print the vector size along with the triton and torch timings (in ms)
    print(f'{N} {triton_ms:.3f} {torch_ms:.3f}')
```
```
131072 0.020 0.003
262144 0.019 0.004
524288 0.016 0.016
1048576 0.033 0.033
2097152 0.071 0.070
4194304 0.142 0.144
8388608 0.287 0.286
16777216 0.572 0.568
33554432 1.139 1.110
```
Our op is on par with Torch's vectorized element-wise kernel when the vectors are large enough. One caveat is that the latency of PyTorch is much smaller for small vectors (3 µs vs. 18-20 µs). This is something we are actively working to reduce.
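Since vector addition is memory-bound, it can also be informative to convert these timings into achieved memory bandwidth. Below is a minimal sketch, assuming float32 inputs (each element involves two 4-byte reads and one 4-byte write, i.e. 12 bytes of traffic) and reusing `N`, `triton_ms` and `torch_ms` from the last benchmark iteration:

```python
# Effective bandwidth in GB/s: 3 tensors of N float32 elements
# (two reads + one write) moved in `time_ms` milliseconds.
def gbps(time_ms, N):
    return 3 * N * 4 * 1e-9 / (time_ms * 1e-3)

print(f'{N}: {gbps(triton_ms, N):.1f} GB/s (triton) vs. {gbps(torch_ms, N):.1f} GB/s (torch)')
```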