[GH-PAGES] Updated website
@@ -98,7 +98,6 @@
<li class="toctree-l2 current"><a class="current reference internal" href="#">Fused Softmax</a><ul>
<li class="toctree-l3"><a class="reference internal" href="#motivations">Motivations</a></li>
<li class="toctree-l3"><a class="reference internal" href="#compute-kernel">Compute Kernel</a></li>
<li class="toctree-l3"><a class="reference internal" href="#torch-bindings">Torch Bindings</a></li>
<li class="toctree-l3"><a class="reference internal" href="#unit-test">Unit Test</a></li>
<li class="toctree-l3"><a class="reference internal" href="#benchmark">Benchmark</a></li>
</ul>
@@ -107,12 +106,14 @@
</ul>
</li>
</ul>
<p class="caption"><span class="caption-text">Language Reference</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../language-reference/python-api/index.html">Python API</a></li>
</ul>
<p class="caption"><span class="caption-text">Programming Guide</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-1/introduction.html">Introduction</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-2/related-work.html">Related Work</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-3/triton-c.html">The Triton-C Language</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-4/triton-ir.html">The Triton-IR Intermediate Representation</a></li>
</ul>
@@ -192,8 +193,7 @@ to download the full example code</p>
<p>In this tutorial, you will write a fused softmax operation (that outperforms PyTorch) and learn about:</p>
<ul class="simple">
<li><p>The benefits of kernel fusion for bandwidth-bound operations.</p></li>
<li><p>The syntax and usage of reduction operators in Triton.</p></li>
<li><p>The automatic vectorization capabilities of the Triton compiler.</p></li>
</ul>
<div class="section" id="motivations">
|
||||
<h2>Motivations<a class="headerlink" href="#motivations" title="Permalink to this headline">¶</a></h2>
|
||||
@@ -220,78 +220,41 @@ Let us consider instead the case of a simple (numerically stabilized) softmax op
|
||||
</div>
|
||||
<p>When implemented naively in pytorch, computing <code class="code docutils literal notranslate"><span class="pre">y</span> <span class="pre">=</span> <span class="pre">naive_softmax(x)</span></code> for <span class="math notranslate nohighlight">\(x \in R^{M \times N}\)</span> requires reading <span class="math notranslate nohighlight">\(7MN\)</span> elements from DRAM and writing back <span class="math notranslate nohighlight">\(3MN + 2M\)</span> elements.
|
||||
This is obviously wasteful; we’d prefer to have a custom “fused” kernel that only reads X once and does all the necessary computations on-chip.
|
||||
In this case, we would be reading and writing back only <span class="math notranslate nohighlight">\(MN\)</span> bytes, so we could expect a theoretical speed-up of ~5x (i.e., <span class="math notranslate nohighlight">\((10MN + 2M) / 2MN\)</span>).
|
||||
This solution would require reading and writing back only <span class="math notranslate nohighlight">\(MN\)</span> bytes, so we could expect a theoretical speed-up of ~5x (i.e., <span class="math notranslate nohighlight">\((10MN + 2M) / 2MN\)</span>).
|
||||
In practice, though, we would be getting a bit less as our kernel computes exponentials and internally moves data around in shared memory.</p>
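<p>To make this accounting concrete, here is a sketch of what such a naive implementation looks like (the tutorial’s actual <code class="code docutils literal notranslate"><span class="pre">naive_softmax</span></code> appears earlier on this page; this reproduction is only illustrative, and it counts each broadcast vector as a full <span class="math notranslate nohighlight">\(MN\)</span> read so that the tally matches the figures above):</p>
<div class="highlight-default notranslate"><div class="highlight"><pre>import torch

def naive_softmax(x):
    # x has shape (M, N): reads MN elements, writes M elements
    x_max = x.max(dim=1)[0]
    # reads MN + MN (broadcast) elements, writes MN elements
    z = x - x_max[:, None]
    # reads MN elements, writes MN elements
    numerator = torch.exp(z)
    # reads MN elements, writes M elements
    denominator = numerator.sum(dim=1)
    # reads MN + MN (broadcast) elements, writes MN elements
    ret = numerator / denominator[:, None]
    # total: 7MN element reads and 3MN + 2M element writes
    return ret
</pre></div>
</div>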
</div>
<div class="section" id="compute-kernel">
|
||||
<h2>Compute Kernel<a class="headerlink" href="#compute-kernel" title="Permalink to this headline">¶</a></h2>
|
||||
<p>Our softmax kernel works as follows: each program loads a row of the input X, normalizes it and writes back the result to the output Y.
|
||||
<p>Our softmax kernel works as follows: each program loads a row of the input matrix X, normalizes it and writes back the result to the output Y.
|
||||
Note that one important limitation of Triton is that each block must have a power-of-two number of elements,
|
||||
so we need to internally “pad” tiles and guard the memory operations properly if we want to handle any possible input shapes:</p>
|
||||
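<p>As an aside, the guarded load can be pictured with a NumPy analogy (this is not Triton code, just a sketch of why padding with -inf is safe): -inf is the identity for the max reduction, and exp(-inf) is 0, so padded lanes do not perturb either reduction.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre>import numpy as np

BLOCK, N = 8, 5                 # power-of-two block, irregular row length
n = np.arange(BLOCK)            # column indices: 0 ... BLOCK
mask = n &lt; N                    # guard: True for in-bounds columns
x = np.full(BLOCK, -np.inf)     # out-of-bounds lanes are padded with -inf ...
x[mask] = np.random.rand(N)     # ... in-bounds lanes hold real data
# padded lanes contribute nothing to the max or to the sum
print(np.exp(x - x.max()).sum())
</pre></div>
</div>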
<blockquote>
<div><div class="highlight-C notranslate"><div class="highlight"><pre><span></span><span class="n">__global__</span> <span class="kt">void</span> <span class="n">softmax</span><span class="p">(</span><span class="kt">float</span><span class="o">*</span> <span class="n">Y</span><span class="p">,</span> <span class="kt">float</span><span class="o">*</span> <span class="n">X</span><span class="p">,</span> <span class="kt">int</span> <span class="n">stride_xm</span><span class="p">,</span> <span class="kt">int</span> <span class="n">stride_ym</span><span class="p">,</span> <span class="kt">int</span> <span class="n">M</span><span class="p">,</span> <span class="kt">int</span> <span class="n">N</span><span class="p">){</span>
  <span class="c1">// row index</span>
  <span class="kt">int</span> <span class="n">m</span> <span class="o">=</span> <span class="n">get_program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">);</span>
  <span class="c1">// column indices</span>
  <span class="kt">int</span> <span class="n">n</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span> <span class="p">...</span> <span class="n">BLOCK</span><span class="p">;</span>
  <span class="c1">// the memory address of all the elements</span>
  <span class="c1">// that we want to load can be computed as follows</span>
  <span class="kt">float</span><span class="o">*</span> <span class="n">px</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">X</span> <span class="o">+</span> <span class="n">m</span><span class="o">*</span><span class="n">stride_xm</span> <span class="o">+</span> <span class="n">n</span><span class="p">;</span>
  <span class="c1">// because BLOCK has to be a power of two</span>
  <span class="c1">// (per Triton-C specs), it is important</span>
  <span class="c1">// to guard each memory operation with predicates</span>
  <span class="c1">// or we will read out of bounds</span>
  <span class="kt">bool</span> <span class="n">check</span><span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">n</span> <span class="o"><</span> <span class="n">N</span><span class="p">;</span>
  <span class="kt">float</span> <span class="n">x</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">check</span> <span class="o">?</span> <span class="o">*</span><span class="nl">px</span> <span class="p">:</span> <span class="o">-</span><span class="n">F32_INFINITY</span><span class="p">;</span>
  <span class="c1">// syntax for reduction in Triton is:</span>
  <span class="c1">// x[:, :, OPERATOR, :, :]</span>
  <span class="c1">//            ^</span>
  <span class="c1">//          index</span>
  <span class="c1">// where operator is in {min, max, +}</span>
  <span class="c1">// for 1D vectors, this is just x[OPERATOR].</span>
  <span class="kt">float</span> <span class="n">z</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">x</span> <span class="o">-</span> <span class="n">x</span><span class="p">[</span><span class="n">max</span><span class="p">];</span>
  <span class="c1">// Note that exponentials in Triton are fast</span>
  <span class="c1">// but approximate (i.e., think __expf in CUDA)</span>
  <span class="kt">float</span> <span class="n">num</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">exp</span><span class="p">(</span><span class="n">z</span><span class="p">);</span>
  <span class="kt">float</span> <span class="n">denom</span> <span class="o">=</span> <span class="n">num</span><span class="p">[</span><span class="o">+</span><span class="p">];</span>
  <span class="c1">// The result of the reduction (denom) is used to normalize the row</span>
  <span class="kt">float</span> <span class="n">y</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">num</span> <span class="o">/</span> <span class="n">denom</span><span class="p">;</span>
  <span class="c1">// We write it back</span>
  <span class="kt">float</span><span class="o">*</span> <span class="n">py</span> <span class="p">[</span><span class="n">BLOCK</span><span class="p">]</span> <span class="o">=</span> <span class="n">Y</span> <span class="o">+</span> <span class="n">m</span><span class="o">*</span><span class="n">stride_ym</span> <span class="o">+</span> <span class="n">n</span><span class="p">;</span>
  <span class="o">*?</span><span class="p">(</span><span class="n">check</span><span class="p">)</span><span class="n">py</span> <span class="o">=</span> <span class="n">y</span><span class="p">;</span>
<span class="p">}</span>
</pre></div>
</div>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">triton</span>
|
||||
|
||||
|
||||
<span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
|
||||
<span class="k">def</span> <span class="nf">_softmax</span><span class="p">(</span><span class="n">Y</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">stride_xm</span><span class="p">,</span> <span class="n">stride_ym</span><span class="p">,</span> <span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="o">**</span><span class="n">meta</span><span class="p">):</span>
|
||||
<span class="c1"># row index</span>
|
||||
<span class="n">m</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
|
||||
<span class="c1"># col indices</span>
|
||||
<span class="n">n</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">meta</span><span class="p">[</span><span class="s1">'BLOCK'</span><span class="p">])</span>
|
||||
<span class="c1"># the memory address of all the elements</span>
|
||||
<span class="c1"># that we want to load can be computed as follows</span>
|
||||
<span class="n">X</span> <span class="o">=</span> <span class="n">X</span> <span class="o">+</span> <span class="n">m</span> <span class="o">*</span> <span class="n">stride_xm</span> <span class="o">+</span> <span class="n">n</span>
|
||||
<span class="n">x</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">n</span> <span class="o"><</span> <span class="n">N</span><span class="p">,</span> <span class="n">other</span><span class="o">=-</span><span class="nb">float</span><span class="p">(</span><span class="s1">'inf'</span><span class="p">))</span>
|
||||
<span class="c1"># Substract maximum for numerical stability</span>
|
||||
<span class="n">z</span> <span class="o">=</span> <span class="n">x</span> <span class="o">-</span> <span class="n">triton</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
|
||||
<span class="c1"># Note that exponentials in Triton are fast</span>
|
||||
<span class="c1"># but approximate (i.e., think __expf in CUDA)</span>
|
||||
<span class="n">num</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">z</span><span class="p">)</span>
|
||||
<span class="n">denom</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">num</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
|
||||
<span class="n">y</span> <span class="o">=</span> <span class="n">num</span> <span class="o">/</span> <span class="n">denom</span>
|
||||
<span class="c1"># Write back to Y</span>
|
||||
<span class="n">Y</span> <span class="o">=</span> <span class="n">Y</span> <span class="o">+</span> <span class="n">m</span> <span class="o">*</span> <span class="n">stride_ym</span> <span class="o">+</span> <span class="n">n</span>
|
||||
<span class="n">triton</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">Y</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">n</span> <span class="o"><</span> <span class="n">N</span><span class="p">)</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
</div></blockquote>
|
||||
</div>
|
||||
<div class="section" id="torch-bindings">
|
||||
<h2>Torch Bindings<a class="headerlink" href="#torch-bindings" title="Permalink to this headline">¶</a></h2>
|
||||
<p>Here our torch bindings is quite similar to that of the vector addition mentioned in the previous tutorial.
|
||||
We just need to make sure that BLOCK is the smallest power of two greater than the number of columns N of the input matrix.
|
||||
This means that different values of BLOCK will result in different kernels</p>
|
||||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">torch</span>
|
||||
<span class="kn">import</span> <span class="nn">triton</span>
|
||||
|
||||
<span class="c1"># Source code for the Triton kernel</span>
|
||||
<span class="n">_src</span> <span class="o">=</span> <span class="s2">"""</span>
|
||||
<span class="s2">__global__ void softmax(float* Y, float* X, int stride_ym, int stride_xm, int M, int N){</span>
|
||||
<span class="s2"> int m = get_program_id(0);</span>
|
||||
<span class="s2"> int n [BLOCK] = 0 ... BLOCK;</span>
|
||||
<span class="s2"> float* px [BLOCK] = X + m*stride_xm + n;</span>
|
||||
<span class="s2"> bool check[BLOCK] = n < N;</span>
|
||||
<span class="s2"> float x [BLOCK] = check ? *px : -F32_INFINITY;</span>
|
||||
<span class="s2"> float z [BLOCK] = x - x[max];</span>
|
||||
<span class="s2"> float num [BLOCK] = exp(z);</span>
|
||||
<span class="s2"> float denom = num[+];</span>
|
||||
<span class="s2"> float y [BLOCK] = num / denom;</span>
|
||||
<span class="s2"> float* py [BLOCK] = Y + m*stride_ym + n;</span>
|
||||
<span class="s2"> *?(check)py = y;</span>
|
||||
<span class="s2">}</span>
|
||||
<span class="s2">"""</span>
|
||||
|
||||
|
||||
<span class="c1"># helper function to get the smaller power-of-two larger than a given number</span>
|
||||
<span class="k">def</span> <span class="nf">next_power_of_2</span><span class="p">(</span><span class="n">n</span><span class="p">):</span>
|
||||
<p>We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.</p>
|
||||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">next_power_of_2</span><span class="p">(</span><span class="n">n</span><span class="p">):</span>
|
||||
<span class="n">n</span> <span class="o">-=</span> <span class="mi">1</span>
|
||||
<span class="n">n</span> <span class="o">|=</span> <span class="n">n</span> <span class="o">>></span> <span class="mi">1</span>
|
||||
<span class="n">n</span> <span class="o">|=</span> <span class="n">n</span> <span class="o">>></span> <span class="mi">2</span>
|
||||
@@ -302,11 +265,9 @@ This means that different values of BLOCK will result in different kernels</p>
|
||||
<span class="k">return</span> <span class="n">n</span>
|
||||
|
||||
|
||||
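<span class="c1"># (illustration) e.g. next_power_of_2(1000) == 1024, and next_power_of_2(1024) == 1024</span>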
<span class="c1"># kernel caching mechanism</span>
|
||||
<span class="k">def</span> <span class="nf">make_kernel</span><span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">device</span><span class="p">):</span>
|
||||
<span class="n">cache</span> <span class="o">=</span> <span class="n">make_kernel</span><span class="o">.</span><span class="n">cache</span>
|
||||
<span class="c1"># Now are kernels are indexed not only by the provided device but also</span>
|
||||
<span class="c1"># by the rounded number of columns in the input matrix</span>
|
||||
<span class="k">def</span> <span class="nf">softmax</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
|
||||
<span class="n">M</span><span class="p">,</span> <span class="n">N</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span>
|
||||
<span class="c1"># The block size is the smallest power of two greater than the number of columns in `x`</span>
|
||||
<span class="n">BLOCK</span> <span class="o">=</span> <span class="n">next_power_of_2</span><span class="p">(</span><span class="n">N</span><span class="p">)</span>
|
||||
<span class="c1"># Another trick we can use is to ask the compiler to parallelize each</span>
|
||||
<span class="c1"># row-normalization more aggressively -- i.e., with more warps -- vectors</span>
|
||||
@@ -316,36 +277,13 @@ This means that different values of BLOCK will result in different kernels</p>
    <span class="n">num_warps</span> <span class="o">=</span> <span class="mi">4</span>
    <span class="k">if</span> <span class="n">BLOCK</span> <span class="o">>=</span> <span class="mi">2048</span><span class="p">:</span> <span class="n">num_warps</span> <span class="o">=</span> <span class="mi">8</span>
    <span class="k">if</span> <span class="n">BLOCK</span> <span class="o">>=</span> <span class="mi">4096</span><span class="p">:</span> <span class="n">num_warps</span> <span class="o">=</span> <span class="mi">16</span>
    <span class="c1"># Each (BLOCK, num_warps, device) results in a different kernel</span>
    <span class="n">key</span> <span class="o">=</span> <span class="p">(</span><span class="n">BLOCK</span><span class="p">,</span> <span class="n">num_warps</span><span class="p">,</span> <span class="n">device</span><span class="p">)</span>
    <span class="k">if</span> <span class="n">key</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">cache</span><span class="p">:</span>
        <span class="n">defines</span> <span class="o">=</span> <span class="p">{</span><span class="s1">'BLOCK'</span><span class="p">:</span> <span class="n">BLOCK</span><span class="p">}</span>
        <span class="n">cache</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">kernel</span><span class="p">(</span><span class="n">_src</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">defines</span><span class="o">=</span><span class="n">defines</span><span class="p">,</span> <span class="n">num_warps</span><span class="o">=</span><span class="n">num_warps</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">cache</span><span class="p">[</span><span class="n">key</span><span class="p">]</span>


<span class="n">make_kernel</span><span class="o">.</span><span class="n">cache</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span>

<span class="k">class</span> <span class="nc">_softmax</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">autograd</span><span class="o">.</span><span class="n">Function</span><span class="p">):</span>
|
||||
<span class="nd">@staticmethod</span>
|
||||
<span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
|
||||
<span class="c1"># constraints of the op</span>
|
||||
<span class="k">assert</span> <span class="n">x</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span>
|
||||
<span class="n">y</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
|
||||
<span class="c1"># The launch grid is simple: we have one kernel instance per row of the input matrix</span>
|
||||
<span class="n">M</span><span class="p">,</span> <span class="n">N</span> <span class="o">=</span> <span class="n">y</span><span class="o">.</span><span class="n">shape</span>
|
||||
<span class="n">grid</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">opt</span><span class="p">:</span> <span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="p">)</span>
|
||||
<span class="c1"># Launch kernel</span>
|
||||
<span class="n">kernel</span> <span class="o">=</span> <span class="n">make_kernel</span><span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">y</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
||||
<span class="n">kernel</span><span class="p">(</span><span class="n">y</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">(),</span> <span class="n">x</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">(),</span> <span class="n">y</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">x</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">grid</span><span class="o">=</span><span class="n">grid</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="n">y</span>
|
||||
|
||||
|
||||
<span class="n">softmax</span> <span class="o">=</span> <span class="n">_softmax</span><span class="o">.</span><span class="n">apply</span>
|
||||
<span class="c1"># Allocate output</span>
|
||||
<span class="n">y</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
|
||||
<span class="c1"># Enqueue kernel. The launch grid is simple: we have one kernel instance per row of the input matrix</span>
|
||||
<span class="n">_softmax</span><span class="p">[(</span><span class="n">M</span><span class="p">,</span> <span class="p">)](</span><span class="n">y</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">x</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">y</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">BLOCK</span><span class="o">=</span><span class="n">BLOCK</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="n">y</span>
|
||||
</pre></div>
</div>
<p>We can use the above softmax function to compute the row-wise softmax of a given matrix.</p>
</div>
<div class="section" id="unit-test">
<h2>Unit Test<a class="headerlink" href="#unit-test" title="Permalink to this headline">¶</a></h2>
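<p>The body of the unit test falls outside this hunk. As a sketch only (the exact shapes and tolerance used by the tutorial are not shown here), a minimal check against <code class="code docutils literal notranslate"><span class="pre">torch.softmax</span></code> could look like:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre>torch.manual_seed(0)
# An irregular (non-power-of-two) number of columns exercises the padding logic
x = torch.randn(1823, 781, device='cuda')
y_tri = softmax(x)
y_ref = torch.softmax(x, dim=1)
# The two results should agree to within floating-point tolerance
print(torch.allclose(y_tri, y_ref))
</pre></div>
</div>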
@@ -405,7 +343,7 @@ This means that – when temporary data is too large to fit entirely in the GPU
Note that our Triton kernel is not only faster than PyTorch’s CUDA kernel, it is also <strong>easier to read, understand and maintain</strong>.</p></li>
</ul>
</div></blockquote>
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 0 minutes 20.767 seconds)</p>
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-02-fused-softmax-py">
|
||||
<div class="sphx-glr-download sphx-glr-download-python docutils container">
|
||||
<p><a class="reference download internal" download="" href="../../_downloads/d91442ac2982c4e0cc3ab0f43534afbc/02-fused-softmax.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">02-fused-softmax.py</span></code></a></p>
|
||||
|
Reference in New Issue
Block a user