# Fused Softmax

Custom GPU kernels for elementwise additions are educationally valuable but won't get you very far in practice. Let us consider instead the case of a simple (numerically stabilized) softmax operation:
```python
import torch

# Compute the row-wise softmax of x \in R^{M \times N}
def naive_softmax(x):
    # read MN elements ; write M elements
    x_max = torch.max(x, axis=1)[0]
    # read 2MN elements ; write MN elements
    z = x - x_max[:, None]
    # read MN elements ; write MN elements
    numerator = torch.exp(z)
    # read MN elements ; write M elements
    denominator = torch.sum(numerator, axis=1)
    # read 2MN elements ; write MN elements
    ret = numerator / denominator[:, None]
    # in total: read 7MN elements ; write 3MN + 2M elements
    return ret
```
When implemented naively in PyTorch, computing $y$ requires reading $7MN$ elements from DRAM and writing back $3MN + 2M$ elements.

Instead, we want to write a custom "fused" PyTorch operator that reads X only once and does all the necessary computations on-chip. This would require reading and writing back only $MN$ elements each, so we could expect a theoretical speed-up of 5x (roughly $10MN$ elements of memory traffic versus $2MN$). In practice, though, we expect less because our kernel will spend some time computing exponentials and moving data around in shared memory.
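To make that 5x estimate concrete, here is a back-of-the-envelope sketch (the shape values are arbitrary examples, not from the tutorial; it just restates the element counts annotated in the code above):

```python
M, N = 4096, 1024                       # example shape (arbitrary)
naive_traffic = 7 * M * N + 3 * M * N   # reads + writes, ignoring the 2M row-sized terms
fused_traffic = M * N + M * N           # read x once, write y once
print(naive_traffic / fused_traffic)    # 5.0
```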
## Writing the Compute Kernel

Our softmax kernel works as follows: each program loads a row of X and writes back a normalized row of Y. Note that one important limitation of Triton is that each block must have a power-of-two number of elements, which means that we need to guard the memory operations properly if we want to handle any possible input shape:
```c
__global__ void softmax(float* Y, float* X, int stride_xm, int stride_ym, int M, int N){
    // row index
    int m = get_program_id(0);
    // column indices
    int n [BLOCK] = 0 ... BLOCK;
    // the memory addresses of all the elements
    // that we want to load can be computed as follows
    float* px [BLOCK] = X + m*stride_xm + n;
    // because BLOCK has to be a power of two
    // (per Triton-C specs), it is important
    // to guard each memory operation with predicates
    // or we will read out of bounds
    bool check[BLOCK] = n < N;
    float x [BLOCK] = check ? *px : -F32_INFINITY;
    // syntax for reduction in Triton is:
    //   x[..., OPERATOR, ...]
    //          ^
    //        index
    // The operators currently supported are {min, max, +}
    float z [BLOCK] = x - x[max];
    // The exponential in Triton is fast but approximate
    // (i.e., like __expf in CUDA)
    float num [BLOCK] = exp(z);
    float denom = num[+];
    // The result of the reduction is now stored in y
    float y [BLOCK] = num / denom;
    // We write it back
    float* py [BLOCK] = Y + m*stride_ym + n;
    *?(check)py = y;
}
```
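To see what each program instance computes, here is a minimal PyTorch sketch of the per-row semantics, including the `-inf` padding used by the guarded load (`softmax_row_reference` is an illustrative helper written for this explanation, not part of the tutorial's code):

```python
import torch

def softmax_row_reference(row, BLOCK):
    # Guarded load: out-of-bounds lanes read -inf, so they cannot win the max
    x = torch.full((BLOCK,), float('-inf'))
    x[: row.numel()] = row
    z = x - x.max()           # x - x[max] in Triton-C
    num = torch.exp(z)        # exp(-inf) == 0, so padding lanes vanish
    denom = num.sum()         # num[+] in Triton-C
    y = num / denom
    return y[: row.numel()]   # guarded store: padding lanes are never written

row = torch.randn(781)
print(torch.allclose(softmax_row_reference(row, 1024), torch.softmax(row, dim=0)))  # True
```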
## Writing the Torch bindings
```python
import torch
import triton

# source code for the Triton compute kernel
_src = """
__global__ void softmax(float* Y, float* X, int stride_ym, int stride_xm, int M, int N){
    int m = get_program_id(0);
    int n [BLOCK] = 0 ... BLOCK;
    float* px [BLOCK] = X + m*stride_xm + n;
    bool check[BLOCK] = n < N;
    float x [BLOCK] = check ? *px : -F32_INFINITY;
    float z [BLOCK] = x - x[max];
    float num [BLOCK] = exp(z);
    float denom = num[+];
    float y [BLOCK] = num / denom;
    float* py [BLOCK] = Y + m*stride_ym + n;
    *?(check)py = y;
}
"""

# We need to make sure that BLOCK is the smallest power of two
# greater than or equal to the number of columns N of the input matrix.
# Different values of BLOCK will result in different kernels.
def next_power_of_2(n):
    n -= 1
    n |= n >> 1
    n |= n >> 2
    n |= n >> 4
    n |= n >> 8
    n |= n >> 16
    n += 1
    return n

_kernels = dict()
def make_kernel(N, device):
    BLOCK = next_power_of_2(N)
    key = (BLOCK, device)
    if key not in _kernels:
        defines = {'BLOCK': BLOCK}
        _kernels[key] = triton.kernel(_src, device=device, defines=defines)
    return _kernels[key]

class _softmax(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x):
        # constraints of the op
        assert x.dtype == torch.float32
        y = torch.empty_like(x)
        # *create launch grid*:
        # here we just launch a grid of M programs
        M, N = y.shape
        grid = lambda opt: (M, )
        # *launch kernel*:
        kernel = make_kernel(N, y.device)
        kernel(y.data_ptr(), x.data_ptr(), y.stride(0), x.stride(0), M, N, grid=grid)
        return y

softmax = _softmax.apply
```
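As a quick sanity check on the bit-twiddling in `next_power_of_2` (an illustrative aside, not from the original notebook): it propagates the highest set bit of `n - 1` into every lower position, then adds one, which rounds up to the next power of two and leaves exact powers of two unchanged.

```python
print([next_power_of_2(n) for n in (1, 2, 3, 781, 1024)])
# [1, 2, 4, 1024, 1024]
```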
## Writing a Unit Test
```python
x = torch.randn(1823, 781, device='cuda')
y_tri = softmax(x)
y_ref = torch.softmax(x, axis=1)
print(y_tri)
print(y_ref)
print(torch.allclose(y_tri, y_ref))
```

```
tensor([[0.0004, 0.0006, 0.0004,  ..., 0.0005, 0.0004, 0.0010],
        [0.0003, 0.0029, 0.0004,  ..., 0.0007, 0.0017, 0.0004],
        [0.0002, 0.0006, 0.0005,  ..., 0.0028, 0.0009, 0.0003],
        ...,
        [0.0017, 0.0005, 0.0010,  ..., 0.0006, 0.0004, 0.0001],
        [0.0010, 0.0006, 0.0001,  ..., 0.0006, 0.0017, 0.0014],
        [0.0037, 0.0012, 0.0006,  ..., 0.0003, 0.0005, 0.0003]],
       device='cuda:0')
tensor([[0.0004, 0.0006, 0.0004,  ..., 0.0005, 0.0004, 0.0010],
        [0.0003, 0.0029, 0.0004,  ..., 0.0007, 0.0017, 0.0004],
        [0.0002, 0.0006, 0.0005,  ..., 0.0028, 0.0009, 0.0003],
        ...,
        [0.0017, 0.0005, 0.0010,  ..., 0.0006, 0.0004, 0.0001],
        [0.0010, 0.0006, 0.0001,  ..., 0.0006, 0.0017, 0.0014],
        [0.0037, 0.0012, 0.0006,  ..., 0.0003, 0.0005, 0.0003]],
       device='cuda:0')
True
```
Seems to work!
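One caveat worth keeping in mind: the kernel's exponential is approximate (like `__expf` in CUDA), so if `torch.allclose` ever fails at its default tolerances on other inputs, comparing with explicit tolerances is a reasonable fallback (a suggestion, not part of the original test):

```python
print(torch.allclose(y_tri, y_ref, rtol=1e-4, atol=1e-6))
```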
## Writing a Benchmark
```python
import matplotlib.pyplot as plt

M = 4096
Ns = [128 * i for i in range(2, 50)]
# These lists hold effective bandwidths in GB/s (converted from milliseconds)
tri_gbps = []
ref_gbps = []
def_gbps = []
for N in Ns:
    x = torch.randn(M, N, device='cuda', dtype=torch.float32)
    gbps = lambda ms: x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)
    tri_gbps += [gbps(triton.testing.do_bench(lambda: softmax(x)))]
    ref_gbps += [gbps(triton.testing.do_bench(lambda: torch.softmax(x, axis=1)))]
    def_gbps += [gbps(triton.testing.do_bench(lambda: naive_softmax(x)))]
plt.xlabel('N')
plt.ylabel('Bandwidth (GB/s)')
plt.plot(Ns, tri_gbps, label='Triton')
plt.plot(Ns, ref_gbps, label='Torch')
plt.plot(Ns, def_gbps, label='Naive')
plt.legend()
plt.show()
```
![Bandwidth (GB/s) achieved by the Triton, Torch, and naive softmax implementations as a function of N](../_images/tutorials_02-fused-softmax_12_0.png)