|
|
|
@@ -43,7 +43,7 @@
|
|
|
|
|
|
|
|
|
|
<link rel="index" title="Index" href="../../genindex.html" />
|
|
|
|
|
<link rel="search" title="Search" href="../../search.html" />
|
|
|
|
|
<link rel="next" title="Introduction" href="../../programming-guide/introduction.html" />
|
|
|
|
|
<link rel="next" title="Introduction" href="../../programming-guide/chapter-1/introduction.html" />
|
|
|
|
|
<link rel="prev" title="Fused Softmax" href="02-fused-softmax.html" />
|
|
|
|
|
</head>
|
|
|
|
|
|
|
|
|
@@ -121,9 +121,10 @@
|
|
|
|
|
</ul>
|
|
|
|
|
<p class="caption"><span class="caption-text">Programming Guide</span></p>
|
|
|
|
|
<ul>
|
|
|
|
|
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/introduction.html">Introduction</a></li>
|
|
|
|
|
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/related-work.html">Related Work</a></li>
|
|
|
|
|
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/triton-c.html">The Triton-C Language</a></li>
|
|
|
|
|
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-1/introduction.html">Introduction</a></li>
|
|
|
|
|
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-2/related-work.html">Related Work</a></li>
|
|
|
|
|
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-3/triton-c.html">The Triton-C Language</a></li>
|
|
|
|
|
<li class="toctree-l1"><a class="reference internal" href="../../programming-guide/chapter-4/triton-ir.html">The Triton-IR Intermediate Representation</a></li>
|
|
|
|
|
</ul>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -355,46 +356,14 @@ If <code class="code docutils literal notranslate"><span class="pre">TYPE</span>
|
|
|
|
|
<span class="kn">import</span> <span class="nn">triton</span>
|
|
|
|
|
|
|
|
|
|
<span class="n">autotune_configs</span> <span class="o">=</span> <span class="p">[</span>
|
|
|
|
|
<span class="n">triton</span><span class="o">.</span><span class="n">config</span><span class="p">(</span><span class="n">defines</span><span class="o">=</span><span class="p">{</span>
|
|
|
|
|
<span class="s2">"MB"</span><span class="p">:</span> <span class="s2">"128"</span><span class="p">,</span>
|
|
|
|
|
<span class="s2">"NB"</span><span class="p">:</span> <span class="s2">"128"</span><span class="p">,</span>
|
|
|
|
|
<span class="s2">"KB"</span><span class="p">:</span> <span class="s2">"32"</span>
|
|
|
|
|
<span class="p">},</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
|
|
|
|
|
<span class="n">triton</span><span class="o">.</span><span class="n">config</span><span class="p">(</span><span class="n">defines</span><span class="o">=</span><span class="p">{</span>
|
|
|
|
|
<span class="s1">'MB'</span><span class="p">:</span> <span class="s1">'64'</span><span class="p">,</span>
|
|
|
|
|
<span class="s1">'NB'</span><span class="p">:</span> <span class="s1">'128'</span><span class="p">,</span>
|
|
|
|
|
<span class="s1">'KB'</span><span class="p">:</span> <span class="s1">'32'</span>
|
|
|
|
|
<span class="p">},</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
|
|
|
|
|
<span class="n">triton</span><span class="o">.</span><span class="n">config</span><span class="p">(</span><span class="n">defines</span><span class="o">=</span><span class="p">{</span>
|
|
|
|
|
<span class="s1">'MB'</span><span class="p">:</span> <span class="s1">'128'</span><span class="p">,</span>
|
|
|
|
|
<span class="s1">'NB'</span><span class="p">:</span> <span class="s1">'64'</span><span class="p">,</span>
|
|
|
|
|
<span class="s1">'KB'</span><span class="p">:</span> <span class="s1">'32'</span>
|
|
|
|
|
<span class="p">},</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
|
|
|
|
|
<span class="n">triton</span><span class="o">.</span><span class="n">config</span><span class="p">(</span><span class="n">defines</span><span class="o">=</span><span class="p">{</span>
|
|
|
|
|
<span class="s1">'MB'</span><span class="p">:</span> <span class="s1">'64'</span><span class="p">,</span>
|
|
|
|
|
<span class="s1">'NB'</span><span class="p">:</span> <span class="s1">'64'</span><span class="p">,</span>
|
|
|
|
|
<span class="s1">'KB'</span><span class="p">:</span> <span class="s1">'64'</span>
|
|
|
|
|
<span class="p">},</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
|
|
|
|
|
<span class="n">triton</span><span class="o">.</span><span class="n">config</span><span class="p">(</span><span class="n">defines</span><span class="o">=</span><span class="p">{</span>
|
|
|
|
|
<span class="s1">'MB'</span><span class="p">:</span> <span class="s1">'32'</span><span class="p">,</span>
|
|
|
|
|
<span class="s1">'NB'</span><span class="p">:</span> <span class="s1">'128'</span><span class="p">,</span>
|
|
|
|
|
<span class="s1">'KB'</span><span class="p">:</span> <span class="s1">'64'</span>
|
|
|
|
|
<span class="p">},</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
|
|
|
|
|
<span class="n">triton</span><span class="o">.</span><span class="n">config</span><span class="p">(</span><span class="n">defines</span><span class="o">=</span><span class="p">{</span>
|
|
|
|
|
<span class="s1">'MB'</span><span class="p">:</span> <span class="s1">'128'</span><span class="p">,</span>
|
|
|
|
|
<span class="s1">'NB'</span><span class="p">:</span> <span class="s1">'32'</span><span class="p">,</span>
|
|
|
|
|
<span class="s1">'KB'</span><span class="p">:</span> <span class="s1">'64'</span>
|
|
|
|
|
<span class="p">},</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
|
|
|
|
|
<span class="n">triton</span><span class="o">.</span><span class="n">config</span><span class="p">(</span><span class="n">defines</span><span class="o">=</span><span class="p">{</span>
|
|
|
|
|
<span class="s1">'MB'</span><span class="p">:</span> <span class="s1">'64'</span><span class="p">,</span>
|
|
|
|
|
<span class="s1">'NB'</span><span class="p">:</span> <span class="s1">'32'</span><span class="p">,</span>
|
|
|
|
|
<span class="s1">'KB'</span><span class="p">:</span> <span class="s1">'64'</span>
|
|
|
|
|
<span class="p">},</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">2</span><span class="p">),</span>
|
|
|
|
|
<span class="n">triton</span><span class="o">.</span><span class="n">config</span><span class="p">(</span><span class="n">defines</span><span class="o">=</span><span class="p">{</span>
|
|
|
|
|
<span class="s1">'MB'</span><span class="p">:</span> <span class="s1">'32'</span><span class="p">,</span>
|
|
|
|
|
<span class="s1">'NB'</span><span class="p">:</span> <span class="s1">'64'</span><span class="p">,</span>
|
|
|
|
|
<span class="s1">'KB'</span><span class="p">:</span> <span class="s1">'64'</span>
|
|
|
|
|
<span class="p">},</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
|
|
|
|
|
<span class="n">triton</span><span class="o">.</span><span class="n">config</span><span class="p">(</span><span class="n">defines</span><span class="o">=</span><span class="p">{</span><span class="s2">"MB"</span><span class="p">:</span> <span class="s2">"128"</span><span class="p">,</span> <span class="s2">"NB"</span><span class="p">:</span> <span class="s2">"128"</span><span class="p">,</span> <span class="s2">"KB"</span><span class="p">:</span> <span class="s2">"32"</span><span class="p">},</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
|
|
|
|
|
<span class="n">triton</span><span class="o">.</span><span class="n">config</span><span class="p">(</span><span class="n">defines</span><span class="o">=</span><span class="p">{</span><span class="s1">'MB'</span><span class="p">:</span> <span class="s1">'64'</span><span class="p">,</span> <span class="s1">'NB'</span><span class="p">:</span> <span class="s1">'128'</span><span class="p">,</span> <span class="s1">'KB'</span><span class="p">:</span> <span class="s1">'32'</span><span class="p">},</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
|
|
|
|
|
<span class="n">triton</span><span class="o">.</span><span class="n">config</span><span class="p">(</span><span class="n">defines</span><span class="o">=</span><span class="p">{</span><span class="s1">'MB'</span><span class="p">:</span> <span class="s1">'128'</span><span class="p">,</span> <span class="s1">'NB'</span><span class="p">:</span> <span class="s1">'64'</span><span class="p">,</span> <span class="s1">'KB'</span><span class="p">:</span> <span class="s1">'32'</span><span class="p">},</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
|
|
|
|
|
<span class="n">triton</span><span class="o">.</span><span class="n">config</span><span class="p">(</span><span class="n">defines</span><span class="o">=</span><span class="p">{</span><span class="s1">'MB'</span><span class="p">:</span> <span class="s1">'64'</span><span class="p">,</span> <span class="s1">'NB'</span><span class="p">:</span> <span class="s1">'64'</span><span class="p">,</span> <span class="s1">'KB'</span><span class="p">:</span> <span class="s1">'64'</span><span class="p">},</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
|
|
|
|
|
<span class="n">triton</span><span class="o">.</span><span class="n">config</span><span class="p">(</span><span class="n">defines</span><span class="o">=</span><span class="p">{</span><span class="s1">'MB'</span><span class="p">:</span> <span class="s1">'32'</span><span class="p">,</span> <span class="s1">'NB'</span><span class="p">:</span> <span class="s1">'128'</span><span class="p">,</span> <span class="s1">'KB'</span><span class="p">:</span> <span class="s1">'64'</span><span class="p">},</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
|
|
|
|
|
<span class="n">triton</span><span class="o">.</span><span class="n">config</span><span class="p">(</span><span class="n">defines</span><span class="o">=</span><span class="p">{</span><span class="s1">'MB'</span><span class="p">:</span> <span class="s1">'128'</span><span class="p">,</span> <span class="s1">'NB'</span><span class="p">:</span> <span class="s1">'32'</span><span class="p">,</span> <span class="s1">'KB'</span><span class="p">:</span> <span class="s1">'64'</span><span class="p">},</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
|
|
|
|
|
<span class="n">triton</span><span class="o">.</span><span class="n">config</span><span class="p">(</span><span class="n">defines</span><span class="o">=</span><span class="p">{</span><span class="s1">'MB'</span><span class="p">:</span> <span class="s1">'64'</span><span class="p">,</span> <span class="s1">'NB'</span><span class="p">:</span> <span class="s1">'32'</span><span class="p">,</span> <span class="s1">'KB'</span><span class="p">:</span> <span class="s1">'64'</span><span class="p">},</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">2</span><span class="p">),</span>
|
|
|
|
|
<span class="n">triton</span><span class="o">.</span><span class="n">config</span><span class="p">(</span><span class="n">defines</span><span class="o">=</span><span class="p">{</span><span class="s1">'MB'</span><span class="p">:</span> <span class="s1">'32'</span><span class="p">,</span> <span class="s1">'NB'</span><span class="p">:</span> <span class="s1">'64'</span><span class="p">,</span> <span class="s1">'KB'</span><span class="p">:</span> <span class="s1">'64'</span><span class="p">},</span> <span class="n">num_warps</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
|
|
|
|
|
<span class="p">]</span>
|
|
|
|
|
</pre></div>
|
|
|
|
|
</div>
|
|
|
|
@@ -490,21 +459,21 @@ Note that we need to modify the :code`atol` and <code class="code docutils liter
|
|
|
|
|
</pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
<p class="sphx-glr-script-out">Out:</p>
|
|
|
|
|
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>tensor([[199.0000, 199.1250, 195.8750, ..., 190.6250, 200.7500, 186.3750],
|
|
|
|
|
[196.1250, 201.6250, 197.6250, ..., 189.6250, 197.7500, 190.0000],
|
|
|
|
|
[198.0000, 196.6250, 200.1250, ..., 198.6250, 199.7500, 190.8750],
|
|
|
|
|
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>tensor([[199.6250, 198.0000, 195.0000, ..., 186.0000, 193.6250, 202.1250],
|
|
|
|
|
[192.6250, 193.6250, 190.7500, ..., 184.2500, 191.2500, 192.1250],
|
|
|
|
|
[192.3750, 196.6250, 188.8750, ..., 185.5000, 188.7500, 191.8750],
|
|
|
|
|
...,
|
|
|
|
|
[190.3750, 192.0000, 190.5000, ..., 187.0000, 191.7500, 180.8750],
|
|
|
|
|
[185.2500, 187.6250, 181.2500, ..., 185.1250, 188.2500, 175.5000],
|
|
|
|
|
[191.6250, 191.6250, 194.2500, ..., 188.2500, 192.1250, 182.0000]],
|
|
|
|
|
[196.6250, 199.8750, 196.1250, ..., 182.6250, 194.5000, 200.8750],
|
|
|
|
|
[199.2500, 200.3750, 191.7500, ..., 186.8750, 192.8750, 193.5000],
|
|
|
|
|
[193.5000, 195.2500, 194.1250, ..., 188.3750, 192.6250, 198.3750]],
|
|
|
|
|
device='cuda:0', dtype=torch.float16)
|
|
|
|
|
tensor([[199.0000, 199.1250, 195.8750, ..., 190.6250, 200.7500, 186.3750],
|
|
|
|
|
[196.1250, 201.6250, 197.6250, ..., 189.6250, 197.7500, 190.0000],
|
|
|
|
|
[198.0000, 196.6250, 200.1250, ..., 198.6250, 199.7500, 190.8750],
|
|
|
|
|
tensor([[199.6250, 198.0000, 195.0000, ..., 186.0000, 193.6250, 202.1250],
|
|
|
|
|
[192.6250, 193.6250, 190.7500, ..., 184.2500, 191.2500, 192.1250],
|
|
|
|
|
[192.3750, 196.6250, 188.8750, ..., 185.5000, 188.7500, 191.8750],
|
|
|
|
|
...,
|
|
|
|
|
[190.3750, 192.0000, 190.5000, ..., 187.0000, 191.7500, 180.8750],
|
|
|
|
|
[185.2500, 187.6250, 181.2500, ..., 185.1250, 188.2500, 175.5000],
|
|
|
|
|
[191.6250, 191.6250, 194.2500, ..., 188.2500, 192.1250, 182.0000]],
|
|
|
|
|
[196.6250, 199.8750, 196.1250, ..., 182.6250, 194.5000, 200.8750],
|
|
|
|
|
[199.2500, 200.3750, 191.7500, ..., 186.8750, 192.8750, 193.5000],
|
|
|
|
|
[193.5000, 195.2500, 194.1250, ..., 188.3750, 192.6250, 198.3750]],
|
|
|
|
|
device='cuda:0', dtype=torch.float16)
|
|
|
|
|
True
|
|
|
|
|
</pre></div>
|
|
|
|
@@ -518,7 +487,7 @@ True
|
|
|
|
|
For this reason, we will instead compare the performance of our kernel against <a class="reference external" href="https://github.com/NVIDIA/cutlass/">CUTLASS</a>, a highly optimized CUDA library for matrix multiplication written by NVIDIA themselves.
|
|
|
|
|
To install CUTLASS, you need a recent version of cmake:</p>
|
|
|
|
|
<blockquote>
|
|
|
|
|
<div><div class="highlight-bash notranslate"><div class="highlight"><pre><span></span><span class="nb">cd</span> /tmp/
|
|
|
|
|
<div><div class="highlight-bash notranslate"><div class="highlight"><pre><span></span><span class="nb">cd</span> /path/to/cutlass/
|
|
|
|
|
git clone https://github.com/NVIDIA/cutlass.git
|
|
|
|
|
<span class="nb">cd</span> cutlass
|
|
|
|
|
mkdir build
|
|
|
|
@@ -546,7 +515,7 @@ make -j8 install
|
|
|
|
|
Triton comes with some basic Python bindings for benchmarking CUTLASS. These will be compiled when the environment variables <code class="code docutils literal notranslate"><span class="pre">CUTLASS_INCLUDE_DIR</span></code> and <code class="code docutils literal notranslate"><span class="pre">CUTLASS_LIBRARY_DIR</span></code> are set during the installation process.
|
|
|
|
|
To re-install Triton with the updated CUTLASS bindings, run the following command:</p>
|
|
|
|
|
<div class="highlight-bash notranslate"><div class="highlight"><pre><span></span><span class="nb">export</span> <span class="nv">CUTLASS_INCLUDE_DIR</span><span class="o">=</span>/tmp/cutlass/build/install/include/
|
|
|
|
|
<span class="nb">export</span> <span class="nv">CUTLASS_LIBRARY_DIR</span><span class="o">=</span>/tmp/cutlass/build/install/lib/
|
|
|
|
|
<span class="nb">export</span> <span class="nv">CUTLASS_LIBRARY_DIR</span><span class="o">=</span>/tmp/cutlass/build/install/lib/a
|
|
|
|
|
pip uninstall -y triton
|
|
|
|
|
pip install -e <span class="s2">"git+https://github.com/ptillet/triton.git#egg=triton&subdirectory=python"</span>
|
|
|
|
|
</pre></div>
|
|
|
|
@@ -559,13 +528,13 @@ pip install -e <span class="s2">"git+https://github.com/ptillet/triton.git#
|
|
|
|
|
</pre></div>
|
|
|
|
|
</div>
|
|
|
|
|
<p class="sphx-glr-script-out">Out:</p>
|
|
|
|
|
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>tensor([[199.0000, 199.1250, 195.8750, ..., 190.6250, 200.7500, 186.3750],
|
|
|
|
|
[196.1250, 201.6250, 197.6250, ..., 189.6250, 197.7500, 190.0000],
|
|
|
|
|
[198.0000, 196.6250, 200.1250, ..., 198.6250, 199.7500, 190.8750],
|
|
|
|
|
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>tensor([[199.6250, 198.0000, 195.0000, ..., 186.0000, 193.6250, 202.1250],
|
|
|
|
|
[192.6250, 193.6250, 190.7500, ..., 184.2500, 191.2500, 192.1250],
|
|
|
|
|
[192.3750, 196.6250, 188.8750, ..., 185.5000, 188.7500, 191.8750],
|
|
|
|
|
...,
|
|
|
|
|
[190.3750, 192.0000, 190.5000, ..., 187.0000, 191.7500, 180.8750],
|
|
|
|
|
[185.2500, 187.6250, 181.2500, ..., 185.1250, 188.2500, 175.5000],
|
|
|
|
|
[191.6250, 191.6250, 194.2500, ..., 188.2500, 192.1250, 182.0000]],
|
|
|
|
|
[196.6250, 199.8750, 196.1250, ..., 182.6250, 194.5000, 200.8750],
|
|
|
|
|
[199.2500, 200.3750, 191.7500, ..., 186.8750, 192.8750, 193.5000],
|
|
|
|
|
[193.5000, 195.2500, 194.1250, ..., 188.3750, 192.6250, 198.3750]],
|
|
|
|
|
device='cuda:0', dtype=torch.float16)
|
|
|
|
|
True
|
|
|
|
|
</pre></div>
|
|
|
|
@@ -605,7 +574,7 @@ True
|
|
|
|
|
</div>
|
|
|
|
|
<img alt="matmul-performance" class="sphx-glr-single-img" src="../../_images/sphx_glr_03-matrix-multiplication_001.png" />
|
|
|
|
|
<p>As we can see, the performance of our kernel is pretty good. It is in fact faster than CUTLASS, and therefore probably comparable to the absolute best CUDA code an expert could write.</p>
|
|
|
|
|
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 1 minutes 10.094 seconds)</p>
|
|
|
|
|
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 1 minutes 6.502 seconds)</p>
|
|
|
|
|
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-03-matrix-multiplication-py">
|
|
|
|
|
<div class="sphx-glr-download sphx-glr-download-python docutils container">
|
|
|
|
|
<p><a class="reference download internal" download="" href="../../_downloads/d5fee5b55a64e47f1b5724ec39adf171/03-matrix-multiplication.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">03-matrix-multiplication.py</span></code></a></p>
|
|
|
|
@@ -625,7 +594,7 @@ True
|
|
|
|
|
</div>
|
|
|
|
|
<footer>
|
|
|
|
|
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
|
|
|
|
|
<a href="../../programming-guide/introduction.html" class="btn btn-neutral float-right" title="Introduction" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
|
|
|
|
|
<a href="../../programming-guide/chapter-1/introduction.html" class="btn btn-neutral float-right" title="Introduction" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
|
|
|
|
|
<a href="02-fused-softmax.html" class="btn btn-neutral float-left" title="Fused Softmax" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
|
|
|
|
|
</div>
|
|
|
|
|
|
|
|
|
|