[GH-PAGES] Updated website
This commit is contained in:
@@ -102,9 +102,11 @@
|
||||
</li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="tutorials/index.html">Tutorials</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Language Reference</span></p>
|
||||
<p class="caption"><span class="caption-text">Python API</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../language-reference/python-api/index.html">Python API</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../python-api/triton.html">triton</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../python-api/triton.language.html">triton.language</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../python-api/triton.testing.html">triton.testing</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Programming Guide</span></p>
|
||||
<ul>
|
||||
|
@@ -103,9 +103,11 @@
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Language Reference</span></p>
|
||||
<p class="caption"><span class="caption-text">Python API</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../language-reference/python-api/index.html">Python API</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.html">triton</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.language.html">triton.language</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.testing.html">triton.testing</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Programming Guide</span></p>
|
||||
<ul>
|
||||
@@ -196,6 +198,7 @@ to download the full example code</p>
|
||||
<div class="section" id="compute-kernel">
|
||||
<h2>Compute Kernel<a class="headerlink" href="#compute-kernel" title="Permalink to this headline">¶</a></h2>
|
||||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">torch</span>
|
||||
<span class="kn">import</span> <span class="nn">triton.language</span> <span class="k">as</span> <span class="nn">tl</span>
|
||||
<span class="kn">import</span> <span class="nn">triton</span>
|
||||
|
||||
|
||||
@@ -207,19 +210,19 @@ to download the full example code</p>
|
||||
<span class="n">N</span><span class="p">,</span> <span class="c1"># Size of the vector</span>
|
||||
<span class="o">**</span><span class="n">meta</span> <span class="c1"># Optional meta-parameters for the kernel</span>
|
||||
<span class="p">):</span>
|
||||
<span class="n">pid</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
|
||||
<span class="n">pid</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
|
||||
<span class="c1"># Create an offset for the blocks of pointers to be</span>
|
||||
<span class="c1"># processed by this program instance</span>
|
||||
<span class="n">offsets</span> <span class="o">=</span> <span class="n">pid</span> <span class="o">*</span> <span class="n">meta</span><span class="p">[</span><span class="s1">'BLOCK'</span><span class="p">]</span> <span class="o">+</span> <span class="n">triton</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">meta</span><span class="p">[</span><span class="s1">'BLOCK'</span><span class="p">])</span>
|
||||
<span class="n">offsets</span> <span class="o">=</span> <span class="n">pid</span> <span class="o">*</span> <span class="n">meta</span><span class="p">[</span><span class="s1">'BLOCK'</span><span class="p">]</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">meta</span><span class="p">[</span><span class="s1">'BLOCK'</span><span class="p">])</span>
|
||||
<span class="c1"># Create a mask to guard memory operations against</span>
|
||||
<span class="c1"># out-of-bounds accesses</span>
|
||||
<span class="n">mask</span> <span class="o">=</span> <span class="n">offsets</span> <span class="o"><</span> <span class="n">N</span>
|
||||
<span class="c1"># Load x</span>
|
||||
<span class="n">x</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">X</span> <span class="o">+</span> <span class="n">offsets</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
|
||||
<span class="n">y</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Y</span> <span class="o">+</span> <span class="n">offsets</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
|
||||
<span class="n">x</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">X</span> <span class="o">+</span> <span class="n">offsets</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
|
||||
<span class="n">y</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Y</span> <span class="o">+</span> <span class="n">offsets</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
|
||||
<span class="c1"># Write back x + y</span>
|
||||
<span class="n">z</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span>
|
||||
<span class="n">triton</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">Z</span> <span class="o">+</span> <span class="n">offsets</span><span class="p">,</span> <span class="n">z</span><span class="p">)</span>
|
||||
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">Z</span> <span class="o">+</span> <span class="n">offsets</span><span class="p">,</span> <span class="n">z</span><span class="p">)</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>We can also declara a helper function that handles allocating the output vector
|
||||
@@ -270,9 +273,9 @@ for different problem sizes.</p>
|
||||
<span class="n">x_names</span><span class="o">=</span><span class="p">[</span><span class="s1">'size'</span><span class="p">],</span> <span class="c1"># argument names to use as an x-axis for the plot</span>
|
||||
<span class="n">x_vals</span><span class="o">=</span><span class="p">[</span><span class="mi">2</span><span class="o">**</span><span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">12</span><span class="p">,</span> <span class="mi">28</span><span class="p">,</span> <span class="mi">1</span><span class="p">)],</span> <span class="c1"># different possible values for `x_name`</span>
|
||||
<span class="n">x_log</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="c1"># x axis is logarithmic</span>
|
||||
<span class="n">y_name</span><span class="o">=</span><span class="s1">'provider'</span><span class="p">,</span> <span class="c1"># argument name whose value corresponds to a different line in the plot</span>
|
||||
<span class="n">y_vals</span><span class="o">=</span><span class="p">[</span><span class="s1">'torch'</span><span class="p">,</span> <span class="s1">'triton'</span><span class="p">],</span> <span class="c1"># possible keys for `y_name`</span>
|
||||
<span class="n">y_lines</span><span class="o">=</span><span class="p">[</span><span class="s2">"Torch"</span><span class="p">,</span> <span class="s2">"Triton"</span><span class="p">],</span> <span class="c1"># label name for the lines</span>
|
||||
<span class="n">line_arg</span><span class="o">=</span><span class="s1">'provider'</span><span class="p">,</span> <span class="c1"># argument name whose value corresponds to a different line in the plot</span>
|
||||
<span class="n">line_vals</span><span class="o">=</span><span class="p">[</span><span class="s1">'torch'</span><span class="p">,</span> <span class="s1">'triton'</span><span class="p">],</span> <span class="c1"># possible values for `line_arg`</span>
|
||||
<span class="n">line_names</span><span class="o">=</span><span class="p">[</span><span class="s2">"Torch"</span><span class="p">,</span> <span class="s2">"Triton"</span><span class="p">],</span> <span class="c1"># label name for the lines</span>
|
||||
<span class="n">ylabel</span><span class="o">=</span><span class="s2">"GB/s"</span><span class="p">,</span> <span class="c1"># label name for the y-axis</span>
|
||||
<span class="n">plot_name</span><span class="o">=</span><span class="s2">"vector-add-performance"</span><span class="p">,</span> <span class="c1"># name for the plot. Used also as a file name for saving the plot.</span>
|
||||
<span class="n">args</span><span class="o">=</span><span class="p">{}</span> <span class="c1"># values for function arguments not in `x_names` and `y_name`</span>
|
||||
@@ -295,7 +298,7 @@ for different problem sizes.</p>
|
||||
</pre></div>
|
||||
</div>
|
||||
<img alt="01 vector add" class="sphx-glr-single-img" src="../../_images/sphx_glr_01-vector-add_001.png" />
|
||||
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 0 minutes 7.044 seconds)</p>
|
||||
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 0 minutes 7.682 seconds)</p>
|
||||
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-01-vector-add-py">
|
||||
<div class="sphx-glr-download sphx-glr-download-python docutils container">
|
||||
<p><a class="reference download internal" download="" href="../../_downloads/62d97d49a32414049819dd8bb8378080/01-vector-add.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">01-vector-add.py</span></code></a></p>
|
||||
|
@@ -106,9 +106,11 @@
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Language Reference</span></p>
|
||||
<p class="caption"><span class="caption-text">Python API</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../language-reference/python-api/index.html">Python API</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.html">triton</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.language.html">triton.language</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.testing.html">triton.testing</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Programming Guide</span></p>
|
||||
<ul>
|
||||
@@ -229,28 +231,29 @@ In practice, though, we would be getting a bit less as our kernel computes expon
|
||||
Note that one important limitation of Triton is that each block must have a power-of-two number of elements,
|
||||
so we need to internally “pad” tiles and guard the memory operations properly if we want to handle any possible input shapes:</p>
|
||||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">triton</span>
|
||||
<span class="kn">import</span> <span class="nn">triton.language</span> <span class="k">as</span> <span class="nn">tl</span>
|
||||
|
||||
|
||||
<span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
|
||||
<span class="k">def</span> <span class="nf">_softmax</span><span class="p">(</span><span class="n">Y</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">stride_xm</span><span class="p">,</span> <span class="n">stride_ym</span><span class="p">,</span> <span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="o">**</span><span class="n">meta</span><span class="p">):</span>
|
||||
<span class="c1"># row index</span>
|
||||
<span class="n">m</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
|
||||
<span class="n">m</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
|
||||
<span class="c1"># col indices</span>
|
||||
<span class="n">n</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">meta</span><span class="p">[</span><span class="s1">'BLOCK'</span><span class="p">])</span>
|
||||
<span class="n">n</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">meta</span><span class="p">[</span><span class="s1">'BLOCK'</span><span class="p">])</span>
|
||||
<span class="c1"># the memory address of all the elements</span>
|
||||
<span class="c1"># that we want to load can be computed as follows</span>
|
||||
<span class="n">X</span> <span class="o">=</span> <span class="n">X</span> <span class="o">+</span> <span class="n">m</span> <span class="o">*</span> <span class="n">stride_xm</span> <span class="o">+</span> <span class="n">n</span>
|
||||
<span class="n">x</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">n</span> <span class="o"><</span> <span class="n">N</span><span class="p">,</span> <span class="n">other</span><span class="o">=-</span><span class="nb">float</span><span class="p">(</span><span class="s1">'inf'</span><span class="p">))</span>
|
||||
<span class="n">x</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">n</span> <span class="o"><</span> <span class="n">N</span><span class="p">,</span> <span class="n">other</span><span class="o">=-</span><span class="nb">float</span><span class="p">(</span><span class="s1">'inf'</span><span class="p">))</span>
|
||||
<span class="c1"># Substract maximum for numerical stability</span>
|
||||
<span class="n">z</span> <span class="o">=</span> <span class="n">x</span> <span class="o">-</span> <span class="n">triton</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
|
||||
<span class="n">z</span> <span class="o">=</span> <span class="n">x</span> <span class="o">-</span> <span class="n">tl</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
|
||||
<span class="c1"># Note that exponentials in Triton are fast</span>
|
||||
<span class="c1"># but approximate (i.e., think __expf in CUDA)</span>
|
||||
<span class="n">num</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">z</span><span class="p">)</span>
|
||||
<span class="n">denom</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">num</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
|
||||
<span class="n">num</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">z</span><span class="p">)</span>
|
||||
<span class="n">denom</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">num</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
|
||||
<span class="n">y</span> <span class="o">=</span> <span class="n">num</span> <span class="o">/</span> <span class="n">denom</span>
|
||||
<span class="c1"># Write back to Y</span>
|
||||
<span class="n">Y</span> <span class="o">=</span> <span class="n">Y</span> <span class="o">+</span> <span class="n">m</span> <span class="o">*</span> <span class="n">stride_ym</span> <span class="o">+</span> <span class="n">n</span>
|
||||
<span class="n">triton</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">Y</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">n</span> <span class="o"><</span> <span class="n">N</span><span class="p">)</span>
|
||||
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">Y</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">n</span> <span class="o"><</span> <span class="n">N</span><span class="p">)</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.</p>
|
||||
@@ -310,9 +313,9 @@ We will then compare its performance against (1) <code class="code docutils lite
|
||||
<span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">Benchmark</span><span class="p">(</span>
|
||||
<span class="n">x_names</span><span class="o">=</span><span class="p">[</span><span class="s1">'N'</span><span class="p">],</span> <span class="c1"># argument names to use as an x-axis for the plot</span>
|
||||
<span class="n">x_vals</span><span class="o">=</span><span class="p">[</span><span class="mi">256</span> <span class="o">*</span> <span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">50</span><span class="p">)],</span> <span class="c1"># different possible values for `x_name`</span>
|
||||
<span class="n">y_name</span><span class="o">=</span><span class="s1">'provider'</span><span class="p">,</span> <span class="c1"># argument name whose value corresponds to a different line in the plot</span>
|
||||
<span class="n">y_vals</span><span class="o">=</span><span class="p">[</span><span class="s1">'torch'</span><span class="p">,</span> <span class="s1">'triton'</span><span class="p">,</span> <span class="s1">'naive'</span><span class="p">],</span> <span class="c1"># possible keys for `y_name`</span>
|
||||
<span class="n">y_lines</span><span class="o">=</span><span class="p">[</span><span class="s2">"Torch"</span><span class="p">,</span> <span class="s2">"Triton"</span><span class="p">,</span> <span class="s1">'Naive'</span><span class="p">],</span> <span class="c1"># label name for the lines</span>
|
||||
<span class="n">line_arg</span><span class="o">=</span><span class="s1">'provider'</span><span class="p">,</span> <span class="c1"># argument name whose value corresponds to a different line in the plot</span>
|
||||
<span class="n">line_vals</span><span class="o">=</span><span class="p">[</span><span class="s1">'torch'</span><span class="p">,</span> <span class="s1">'triton'</span><span class="p">,</span> <span class="s1">'naive'</span><span class="p">],</span> <span class="c1"># possible values for `line_arg``</span>
|
||||
<span class="n">line_names</span><span class="o">=</span><span class="p">[</span><span class="s2">"Torch"</span><span class="p">,</span> <span class="s2">"Triton"</span><span class="p">,</span> <span class="s1">'Naive'</span><span class="p">],</span> <span class="c1"># label name for the lines</span>
|
||||
<span class="n">ylabel</span><span class="o">=</span><span class="s2">"GB/s"</span><span class="p">,</span> <span class="c1"># label name for the y-axis</span>
|
||||
<span class="n">plot_name</span><span class="o">=</span><span class="s2">"softmax-performance"</span><span class="p">,</span> <span class="c1"># name for the plot. Used also as a file name for saving the plot.</span>
|
||||
<span class="n">args</span><span class="o">=</span><span class="p">{</span><span class="s1">'M'</span><span class="p">:</span> <span class="mi">4096</span><span class="p">}</span> <span class="c1"># values for function arguments not in `x_names` and `y_name`</span>
|
||||
@@ -343,7 +346,7 @@ This means that – when temporary data is too large to fit entirely in the GPU
|
||||
Note that our Triton kernel is not only faster than PyTorch’s CUDA kernel, it is also <strong>easier to read, understand and maintain</strong>.</p></li>
|
||||
</ul>
|
||||
</div></blockquote>
|
||||
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 0 minutes 20.176 seconds)</p>
|
||||
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 0 minutes 20.250 seconds)</p>
|
||||
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-02-fused-softmax-py">
|
||||
<div class="sphx-glr-download sphx-glr-download-python docutils container">
|
||||
<p><a class="reference download internal" download="" href="../../_downloads/d91442ac2982c4e0cc3ab0f43534afbc/02-fused-softmax.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">02-fused-softmax.py</span></code></a></p>
|
||||
|
@@ -43,7 +43,7 @@
|
||||
|
||||
<link rel="index" title="Index" href="../../genindex.html" />
|
||||
<link rel="search" title="Search" href="../../search.html" />
|
||||
<link rel="next" title="Python API" href="../../language-reference/python-api/index.html" />
|
||||
<link rel="next" title="triton" href="../../python-api/triton.html" />
|
||||
<link rel="prev" title="Fused Softmax" href="02-fused-softmax.html" />
|
||||
</head>
|
||||
|
||||
@@ -113,9 +113,11 @@
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Language Reference</span></p>
|
||||
<p class="caption"><span class="caption-text">Python API</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../language-reference/python-api/index.html">Python API</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.html">triton</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.language.html">triton.language</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.testing.html">triton.testing</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Programming Guide</span></p>
|
||||
<ul>
|
||||
@@ -299,6 +301,7 @@ This can be done by ‘super-grouping’ blocks in groups of <code class="code d
|
||||
<h2>Final Result<a class="headerlink" href="#final-result" title="Permalink to this headline">¶</a></h2>
|
||||
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">torch</span>
|
||||
<span class="kn">import</span> <span class="nn">triton</span>
|
||||
<span class="kn">import</span> <span class="nn">triton.language</span> <span class="k">as</span> <span class="nn">tl</span>
|
||||
|
||||
<span class="c1"># %</span>
|
||||
<span class="c1"># :code:`triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:</span>
|
||||
@@ -308,9 +311,9 @@ This can be done by ‘super-grouping’ blocks in groups of <code class="code d
|
||||
|
||||
<span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
|
||||
<span class="k">def</span> <span class="nf">sigmoid</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
|
||||
<span class="n">ret_true</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">/</span> <span class="p">(</span><span class="mi">1</span> <span class="o">+</span> <span class="n">triton</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="o">-</span><span class="n">x</span><span class="p">))</span>
|
||||
<span class="n">ret_false</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="mi">1</span> <span class="o">+</span> <span class="n">triton</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">x</span><span class="p">))</span>
|
||||
<span class="k">return</span> <span class="n">triton</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">x</span> <span class="o">>=</span> <span class="mi">0</span><span class="p">,</span> <span class="n">ret_true</span><span class="p">,</span> <span class="n">ret_false</span><span class="p">)</span>
|
||||
<span class="n">ret_true</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">/</span> <span class="p">(</span><span class="mi">1</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="o">-</span><span class="n">x</span><span class="p">))</span>
|
||||
<span class="n">ret_false</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="mi">1</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">x</span><span class="p">))</span>
|
||||
<span class="k">return</span> <span class="n">tl</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">x</span> <span class="o">>=</span> <span class="mi">0</span><span class="p">,</span> <span class="n">ret_true</span><span class="p">,</span> <span class="n">ret_false</span><span class="p">)</span>
|
||||
|
||||
|
||||
<span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
|
||||
@@ -335,7 +338,7 @@ This can be done by ‘super-grouping’ blocks in groups of <code class="code d
|
||||
<span class="n">BLOCK_K</span> <span class="o">=</span> <span class="n">META</span><span class="p">[</span><span class="s1">'BLOCK_K'</span><span class="p">]</span>
|
||||
<span class="n">GROUP_M</span> <span class="o">=</span> <span class="mi">8</span>
|
||||
<span class="c1"># matrix multiplication</span>
|
||||
<span class="n">pid</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
|
||||
<span class="n">pid</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
|
||||
<span class="n">grid_m</span> <span class="o">=</span> <span class="p">(</span><span class="n">M</span> <span class="o">+</span> <span class="n">BLOCK_M</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">//</span> <span class="n">BLOCK_M</span>
|
||||
<span class="n">grid_n</span> <span class="o">=</span> <span class="p">(</span><span class="n">N</span> <span class="o">+</span> <span class="n">BLOCK_N</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">//</span> <span class="n">BLOCK_N</span>
|
||||
<span class="c1"># re-order program ID for better L2 performance</span>
|
||||
@@ -345,16 +348,16 @@ This can be done by ‘super-grouping’ blocks in groups of <code class="code d
|
||||
<span class="n">pid_m</span> <span class="o">=</span> <span class="n">group_id</span> <span class="o">*</span> <span class="n">GROUP_M</span> <span class="o">+</span> <span class="p">(</span><span class="n">pid</span> <span class="o">%</span> <span class="n">group_size</span><span class="p">)</span>
|
||||
<span class="n">pid_n</span> <span class="o">=</span> <span class="p">(</span><span class="n">pid</span> <span class="o">%</span> <span class="n">width</span><span class="p">)</span> <span class="o">//</span> <span class="p">(</span><span class="n">group_size</span><span class="p">)</span>
|
||||
<span class="c1"># do matrix multiplication</span>
|
||||
<span class="n">rm</span> <span class="o">=</span> <span class="n">pid_m</span> <span class="o">*</span> <span class="n">BLOCK_M</span> <span class="o">+</span> <span class="n">triton</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_M</span><span class="p">)</span>
|
||||
<span class="n">rn</span> <span class="o">=</span> <span class="n">pid_n</span> <span class="o">*</span> <span class="n">BLOCK_N</span> <span class="o">+</span> <span class="n">triton</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_N</span><span class="p">)</span>
|
||||
<span class="n">rk</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_K</span><span class="p">)</span>
|
||||
<span class="n">rm</span> <span class="o">=</span> <span class="n">pid_m</span> <span class="o">*</span> <span class="n">BLOCK_M</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_M</span><span class="p">)</span>
|
||||
<span class="n">rn</span> <span class="o">=</span> <span class="n">pid_n</span> <span class="o">*</span> <span class="n">BLOCK_N</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_N</span><span class="p">)</span>
|
||||
<span class="n">rk</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_K</span><span class="p">)</span>
|
||||
<span class="n">A</span> <span class="o">=</span> <span class="n">A</span> <span class="o">+</span> <span class="p">(</span><span class="n">rm</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_am</span> <span class="o">+</span> <span class="n">rk</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">*</span> <span class="n">stride_ak</span><span class="p">)</span>
|
||||
<span class="n">B</span> <span class="o">=</span> <span class="n">B</span> <span class="o">+</span> <span class="p">(</span><span class="n">rk</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_bk</span> <span class="o">+</span> <span class="n">rn</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">*</span> <span class="n">stride_bn</span><span class="p">)</span>
|
||||
<span class="n">acc</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">BLOCK_M</span><span class="p">,</span> <span class="n">BLOCK_N</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">triton</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
|
||||
<span class="n">acc</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">BLOCK_M</span><span class="p">,</span> <span class="n">BLOCK_N</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
|
||||
<span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">K</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="o">-</span><span class="n">BLOCK_K</span><span class="p">):</span>
|
||||
<span class="n">a</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">A</span><span class="p">)</span>
|
||||
<span class="n">b</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">B</span><span class="p">)</span>
|
||||
<span class="n">acc</span> <span class="o">+=</span> <span class="n">triton</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
|
||||
<span class="n">a</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">A</span><span class="p">)</span>
|
||||
<span class="n">b</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">B</span><span class="p">)</span>
|
||||
<span class="n">acc</span> <span class="o">+=</span> <span class="n">tl</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
|
||||
<span class="n">A</span> <span class="o">+=</span> <span class="n">BLOCK_K</span> <span class="o">*</span> <span class="n">stride_ak</span>
|
||||
<span class="n">B</span> <span class="o">+=</span> <span class="n">BLOCK_K</span> <span class="o">*</span> <span class="n">stride_bk</span>
|
||||
<span class="c1"># triton can accept arbitrary activation function</span>
|
||||
@@ -362,11 +365,11 @@ This can be done by ‘super-grouping’ blocks in groups of <code class="code d
|
||||
<span class="k">if</span> <span class="n">META</span><span class="p">[</span><span class="s1">'ACTIVATION'</span><span class="p">]:</span>
|
||||
<span class="n">acc</span> <span class="o">=</span> <span class="n">META</span><span class="p">[</span><span class="s1">'ACTIVATION'</span><span class="p">](</span><span class="n">acc</span><span class="p">)</span>
|
||||
<span class="c1"># rematerialize rm and rn to save registers</span>
|
||||
<span class="n">rm</span> <span class="o">=</span> <span class="n">pid_m</span> <span class="o">*</span> <span class="n">BLOCK_M</span> <span class="o">+</span> <span class="n">triton</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_M</span><span class="p">)</span>
|
||||
<span class="n">rn</span> <span class="o">=</span> <span class="n">pid_n</span> <span class="o">*</span> <span class="n">BLOCK_N</span> <span class="o">+</span> <span class="n">triton</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_N</span><span class="p">)</span>
|
||||
<span class="n">rm</span> <span class="o">=</span> <span class="n">pid_m</span> <span class="o">*</span> <span class="n">BLOCK_M</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_M</span><span class="p">)</span>
|
||||
<span class="n">rn</span> <span class="o">=</span> <span class="n">pid_n</span> <span class="o">*</span> <span class="n">BLOCK_N</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_N</span><span class="p">)</span>
|
||||
<span class="n">C</span> <span class="o">=</span> <span class="n">C</span> <span class="o">+</span> <span class="p">(</span><span class="n">rm</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_cm</span> <span class="o">+</span> <span class="n">rn</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">*</span> <span class="n">stride_cn</span><span class="p">)</span>
|
||||
<span class="n">mask</span> <span class="o">=</span> <span class="p">(</span><span class="n">rm</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o"><</span> <span class="n">M</span><span class="p">)</span> <span class="o">&</span> <span class="p">(</span><span class="n">rn</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o"><</span> <span class="n">N</span><span class="p">)</span>
|
||||
<span class="n">triton</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">C</span><span class="p">,</span> <span class="n">acc</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
|
||||
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">C</span><span class="p">,</span> <span class="n">acc</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>We can also create a convenience wrapper function that only takes two input tensors
|
||||
@@ -406,32 +409,32 @@ and (1) checks any shape constraint; (2) allocates the output; (3) launches the
|
||||
</pre></div>
|
||||
</div>
|
||||
<p class="sphx-glr-script-out">Out:</p>
|
||||
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>tensor([[-5.9605e-08, 5.1094e+01, -1.8477e-05, ..., 2.6547e+01,
|
||||
-7.2598e-05, -4.2510e-04],
|
||||
[-2.7100e-01, -3.0220e-05, 5.9414e+00, ..., 2.8340e+00,
|
||||
-1.8644e-04, 1.3094e+01],
|
||||
[-1.5332e-01, 4.8125e+00, 8.4277e-01, ..., 3.6387e+00,
|
||||
4.3375e+01, 1.6865e+00],
|
||||
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>tensor([[-4.5061e-05, 4.1656e+01, 1.7500e+01, ..., -2.7405e-02,
|
||||
-2.3251e-03, -0.0000e+00],
|
||||
[-1.0967e-04, -4.2915e-06, -0.0000e+00, ..., -1.4901e-06,
|
||||
-0.0000e+00, 1.4367e+01],
|
||||
[ 5.8156e+01, -0.0000e+00, -1.4603e-04, ..., 1.3930e+01,
|
||||
-2.1362e-01, 9.4062e+00],
|
||||
...,
|
||||
[-0.0000e+00, 2.9453e+01, -4.7684e-07, ..., 6.2617e+00,
|
||||
4.1133e+00, -0.0000e+00],
|
||||
[ 1.6562e+01, -8.1539e-04, 1.3836e+01, ..., 1.9844e+00,
|
||||
-1.1238e-02, 8.4375e+00],
|
||||
[-1.0876e-01, -2.7295e-01, 3.2156e+01, ..., -1.6907e-02,
|
||||
-0.0000e+00, -0.0000e+00]], device='cuda:0', dtype=torch.float16)
|
||||
tensor([[-5.9605e-08, 5.1094e+01, -1.8537e-05, ..., 2.6547e+01,
|
||||
-7.2658e-05, -4.2605e-04],
|
||||
[-2.7100e-01, -3.0220e-05, 5.9414e+00, ..., 2.8340e+00,
|
||||
-1.8632e-04, 1.3094e+01],
|
||||
[-1.5332e-01, 4.8125e+00, 8.4277e-01, ..., 3.6387e+00,
|
||||
4.3375e+01, 1.6875e+00],
|
||||
[ 2.3703e+01, -9.2163e-02, -1.3471e-05, ..., -9.5215e-02,
|
||||
2.0047e+01, 1.4891e+01],
|
||||
[-1.9073e-06, 5.0664e+00, -0.0000e+00, ..., 2.0281e+01,
|
||||
-1.7583e-05, 3.8000e+01],
|
||||
[-1.7285e-05, 5.3945e+00, -1.3916e-01, ..., -2.0984e-01,
|
||||
5.3750e+00, -1.5993e-03]], device='cuda:0', dtype=torch.float16)
|
||||
tensor([[-4.4942e-05, 4.1656e+01, 1.7500e+01, ..., -2.7405e-02,
|
||||
-2.3232e-03, -0.0000e+00],
|
||||
[-1.1003e-04, -4.2915e-06, -0.0000e+00, ..., -1.4901e-06,
|
||||
-0.0000e+00, 1.4367e+01],
|
||||
[ 5.8156e+01, -0.0000e+00, -1.4639e-04, ..., 1.3930e+01,
|
||||
-2.1362e-01, 9.4062e+00],
|
||||
...,
|
||||
[-0.0000e+00, 2.9453e+01, -4.7684e-07, ..., 6.2617e+00,
|
||||
4.1133e+00, -0.0000e+00],
|
||||
[ 1.6562e+01, -8.1778e-04, 1.3836e+01, ..., 1.9844e+00,
|
||||
-1.1238e-02, 8.4375e+00],
|
||||
[-1.0876e-01, -2.7295e-01, 3.2156e+01, ..., -1.6891e-02,
|
||||
-0.0000e+00, -0.0000e+00]], device='cuda:0', dtype=torch.float16)
|
||||
[ 2.3703e+01, -9.2163e-02, -1.3471e-05, ..., -9.5276e-02,
|
||||
2.0047e+01, 1.4891e+01],
|
||||
[-1.9073e-06, 5.0664e+00, -0.0000e+00, ..., 2.0281e+01,
|
||||
-1.7583e-05, 3.8000e+01],
|
||||
[-1.7345e-05, 5.3945e+00, -1.3916e-01, ..., -2.0984e-01,
|
||||
5.3750e+00, -1.6031e-03]], device='cuda:0', dtype=torch.float16)
|
||||
tensor(True, device='cuda:0')
|
||||
</pre></div>
|
||||
</div>
|
||||
@@ -445,9 +448,9 @@ tensor(True, device='cuda:0')
|
||||
<span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">Benchmark</span><span class="p">(</span>
|
||||
<span class="n">x_names</span><span class="o">=</span><span class="p">[</span><span class="s1">'M'</span><span class="p">,</span> <span class="s1">'N'</span><span class="p">,</span> <span class="s1">'K'</span><span class="p">],</span> <span class="c1"># argument names to use as an x-axis for the plot</span>
|
||||
<span class="n">x_vals</span><span class="o">=</span><span class="p">[</span><span class="mi">256</span> <span class="o">*</span> <span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">33</span><span class="p">)],</span> <span class="c1"># different possible values for `x_name`</span>
|
||||
<span class="n">y_name</span><span class="o">=</span><span class="s1">'provider'</span><span class="p">,</span> <span class="c1"># argument name whose value corresponds to a different line in the plot</span>
|
||||
<span class="n">y_vals</span><span class="o">=</span><span class="p">[</span><span class="s1">'cublas'</span><span class="p">,</span> <span class="s1">'triton'</span><span class="p">],</span> <span class="c1"># possible keys for `y_name`</span>
|
||||
<span class="n">y_lines</span><span class="o">=</span><span class="p">[</span><span class="s2">"cuBLAS"</span><span class="p">,</span> <span class="s2">"Triton"</span><span class="p">],</span> <span class="c1"># label name for the lines</span>
|
||||
<span class="n">line_arg</span><span class="o">=</span><span class="s1">'provider'</span><span class="p">,</span> <span class="c1"># argument name whose value corresponds to a different line in the plot</span>
|
||||
<span class="n">line_vals</span><span class="o">=</span><span class="p">[</span><span class="s1">'cublas'</span><span class="p">,</span> <span class="s1">'triton'</span><span class="p">],</span> <span class="c1"># possible values for `line_arg``</span>
|
||||
<span class="n">line_names</span><span class="o">=</span><span class="p">[</span><span class="s2">"cuBLAS"</span><span class="p">,</span> <span class="s2">"Triton"</span><span class="p">],</span> <span class="c1"># label name for the lines</span>
|
||||
<span class="n">ylabel</span><span class="o">=</span><span class="s2">"TFLOPS"</span><span class="p">,</span> <span class="c1"># label name for the y-axis</span>
|
||||
<span class="n">plot_name</span><span class="o">=</span><span class="s2">"matmul-performance"</span><span class="p">,</span> <span class="c1"># name for the plot. Used also as a file name for saving the plot.</span>
|
||||
<span class="n">args</span><span class="o">=</span><span class="p">{}</span>
|
||||
@@ -465,7 +468,7 @@ tensor(True, device='cuda:0')
|
||||
<span class="k">return</span> <span class="n">perf</span><span class="p">(</span><span class="n">ms</span><span class="p">),</span> <span class="n">perf</span><span class="p">(</span><span class="n">max_ms</span><span class="p">),</span> <span class="n">perf</span><span class="p">(</span><span class="n">min_ms</span><span class="p">)</span>
|
||||
|
||||
|
||||
<span class="n">benchmark</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">print_data</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||||
<span class="n">benchmark</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">show_plots</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">print_data</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<img alt="03 matrix multiplication" class="sphx-glr-single-img" src="../../_images/sphx_glr_03-matrix-multiplication_001.png" />
|
||||
@@ -473,38 +476,38 @@ tensor(True, device='cuda:0')
|
||||
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span> M cuBLAS Triton
|
||||
0 512.0 20.164923 15.420235
|
||||
1 768.0 58.982401 40.215272
|
||||
2 1024.0 91.180520 72.315584
|
||||
3 1280.0 157.538463 117.028568
|
||||
4 1536.0 153.867127 144.446699
|
||||
2 1024.0 95.325090 72.315584
|
||||
3 1280.0 151.703703 117.028568
|
||||
4 1536.0 153.867127 150.593357
|
||||
5 1792.0 208.137481 190.498706
|
||||
6 2048.0 199.728763 152.520144
|
||||
7 2304.0 246.266731 178.267699
|
||||
8 2560.0 235.741014 215.578957
|
||||
9 2816.0 231.990461 198.246398
|
||||
10 3072.0 236.916752 221.184001
|
||||
11 3328.0 239.173747 210.500857
|
||||
12 3584.0 248.385067 230.552287
|
||||
13 3840.0 251.917998 222.519114
|
||||
14 4096.0 263.172024 244.032234
|
||||
15 4352.0 249.595626 232.307632
|
||||
16 4608.0 276.560014 254.803966
|
||||
17 4864.0 266.614125 245.366501
|
||||
18 5120.0 257.003930 238.096276
|
||||
19 5376.0 252.676487 236.527241
|
||||
20 5632.0 270.057027 248.514009
|
||||
21 5888.0 264.206935 242.511113
|
||||
22 6144.0 259.441481 241.205983
|
||||
23 6400.0 257.157204 235.078047
|
||||
24 6656.0 254.161678 232.699140
|
||||
25 6912.0 251.844029 233.178785
|
||||
26 7168.0 253.282797 231.740709
|
||||
27 7424.0 251.868505 230.377264
|
||||
28 7680.0 250.988932 231.606284
|
||||
29 7936.0 253.293068 229.692102
|
||||
30 8192.0 253.002304 231.360005
|
||||
6 2048.0 202.135135 151.146088
|
||||
7 2304.0 251.451276 178.267699
|
||||
8 2560.0 237.449270 218.453323
|
||||
9 2816.0 238.329010 200.987140
|
||||
10 3072.0 243.017615 223.806730
|
||||
11 3328.0 244.868356 210.500857
|
||||
12 3584.0 250.460703 232.941430
|
||||
13 3840.0 256.593972 225.697957
|
||||
14 4096.0 266.305018 247.634187
|
||||
15 4352.0 247.675667 237.797917
|
||||
16 4608.0 280.621108 260.713476
|
||||
17 4864.0 272.431168 252.534501
|
||||
18 5120.0 265.596772 245.223576
|
||||
19 5376.0 261.381955 244.335299
|
||||
20 5632.0 283.439220 260.383339
|
||||
21 5888.0 276.674704 254.103421
|
||||
22 6144.0 274.869441 252.078378
|
||||
23 6400.0 269.190319 249.027231
|
||||
24 6656.0 269.252160 249.104840
|
||||
25 6912.0 267.069377 247.115909
|
||||
26 7168.0 268.504352 246.006552
|
||||
27 7424.0 267.373291 246.355964
|
||||
28 7680.0 266.406511 245.760004
|
||||
29 7936.0 228.348876 248.331598
|
||||
30 8192.0 227.680622 247.977332
|
||||
</pre></div>
|
||||
</div>
|
||||
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 0 minutes 32.933 seconds)</p>
|
||||
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 0 minutes 37.657 seconds)</p>
|
||||
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-03-matrix-multiplication-py">
|
||||
<div class="sphx-glr-download sphx-glr-download-python docutils container">
|
||||
<p><a class="reference download internal" download="" href="../../_downloads/d5fee5b55a64e47f1b5724ec39adf171/03-matrix-multiplication.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">03-matrix-multiplication.py</span></code></a></p>
|
||||
@@ -524,7 +527,7 @@ tensor(True, device='cuda:0')
|
||||
</div>
|
||||
<footer>
|
||||
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
|
||||
<a href="../../language-reference/python-api/index.html" class="btn btn-neutral float-right" title="Python API" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
|
||||
<a href="../../python-api/triton.html" class="btn btn-neutral float-right" title="triton" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
|
||||
<a href="02-fused-softmax.html" class="btn btn-neutral float-left" title="Fused Softmax" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
|
||||
</div>
|
||||
|
||||
|
@@ -99,9 +99,11 @@
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Language Reference</span></p>
|
||||
<p class="caption"><span class="caption-text">Python API</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../language-reference/python-api/index.html">Python API</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.html">triton</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.language.html">triton.language</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.testing.html">triton.testing</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Programming Guide</span></p>
|
||||
<ul>
|
||||
|
@@ -92,9 +92,11 @@
|
||||
<li class="toctree-l1"><a class="reference internal" href="../installation.html">Installation</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="index.html">Tutorials</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Language Reference</span></p>
|
||||
<p class="caption"><span class="caption-text">Python API</span></p>
|
||||
<ul>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../language-reference/python-api/index.html">Python API</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.html">triton</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.language.html">triton.language</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../python-api/triton.testing.html">triton.testing</a></li>
|
||||
</ul>
|
||||
<p class="caption"><span class="caption-text">Programming Guide</span></p>
|
||||
<ul>
|
||||
@@ -169,7 +171,7 @@
|
||||
|
||||
<div class="section" id="computation-times">
|
||||
<span id="sphx-glr-getting-started-tutorials-sg-execution-times"></span><h1>Computation times<a class="headerlink" href="#computation-times" title="Permalink to this headline">¶</a></h1>
|
||||
<p><strong>01:00.154</strong> total execution time for <strong>getting-started_tutorials</strong> files:</p>
|
||||
<p><strong>00:37.657</strong> total execution time for <strong>getting-started_tutorials</strong> files:</p>
|
||||
<table class="docutils align-default">
|
||||
<colgroup>
|
||||
<col style="width: 85%" />
|
||||
@@ -178,15 +180,15 @@
|
||||
</colgroup>
|
||||
<tbody>
|
||||
<tr class="row-odd"><td><p><a class="reference internal" href="03-matrix-multiplication.html#sphx-glr-getting-started-tutorials-03-matrix-multiplication-py"><span class="std std-ref">Matrix Multiplication</span></a> (<code class="docutils literal notranslate"><span class="pre">03-matrix-multiplication.py</span></code>)</p></td>
|
||||
<td><p>00:32.933</p></td>
|
||||
<td><p>00:37.657</p></td>
|
||||
<td><p>0.0 MB</p></td>
|
||||
</tr>
|
||||
<tr class="row-even"><td><p><a class="reference internal" href="02-fused-softmax.html#sphx-glr-getting-started-tutorials-02-fused-softmax-py"><span class="std std-ref">Fused Softmax</span></a> (<code class="docutils literal notranslate"><span class="pre">02-fused-softmax.py</span></code>)</p></td>
|
||||
<td><p>00:20.176</p></td>
|
||||
<tr class="row-even"><td><p><a class="reference internal" href="01-vector-add.html#sphx-glr-getting-started-tutorials-01-vector-add-py"><span class="std std-ref">Vector Addition</span></a> (<code class="docutils literal notranslate"><span class="pre">01-vector-add.py</span></code>)</p></td>
|
||||
<td><p>00:00.000</p></td>
|
||||
<td><p>0.0 MB</p></td>
|
||||
</tr>
|
||||
<tr class="row-odd"><td><p><a class="reference internal" href="01-vector-add.html#sphx-glr-getting-started-tutorials-01-vector-add-py"><span class="std std-ref">Vector Addition</span></a> (<code class="docutils literal notranslate"><span class="pre">01-vector-add.py</span></code>)</p></td>
|
||||
<td><p>00:07.044</p></td>
|
||||
<tr class="row-odd"><td><p><a class="reference internal" href="02-fused-softmax.html#sphx-glr-getting-started-tutorials-02-fused-softmax-py"><span class="std std-ref">Fused Softmax</span></a> (<code class="docutils literal notranslate"><span class="pre">02-fused-softmax.py</span></code>)</p></td>
|
||||
<td><p>00:00.000</p></td>
|
||||
<td><p>0.0 MB</p></td>
|
||||
</tr>
|
||||
</tbody>
|
||||
|
Reference in New Issue
Block a user