<p>Here, we want to re-tune our kernel only when the shape of the input matrices changes:</p>
<span class="n">cache</span> <span class="o">=</span> <span class="n">make_kernel</span><span class="o">.</span><span class="n">cache</span>
<span class="k">if</span> <span class="n">key</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">cache</span><span class="p">:</span>
    <span class="n">defines</span> <span class="o">=</span> <span class="p">{</span><span class="s1">'TYPE'</span><span class="p">:</span> <span class="n">dtype</span><span class="p">}</span>
    <span class="n">cache</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">kernel</span><span class="p">(</span>
        <span class="n">src</span><span class="p">,</span>
        <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span>
        <span class="n">defines</span><span class="o">=</span><span class="n">defines</span><span class="p">,</span>
        <span class="n">autotune_configs</span><span class="o">=</span><span class="n">autotune_configs</span><span class="p">,</span>
        <span class="n">autotune_key</span><span class="o">=</span><span class="n">autotune_key</span><span class="p">,</span>
    <span class="p">)</span>
<span class="k">return</span> <span class="n">cache</span><span class="p">[</span><span class="n">key</span><span class="p">]</span>
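<p>For context, this caching logic forms the tail of a <code class="code docutils literal notranslate"><span class="pre">make_kernel</span></code> helper. Below is a minimal sketch of how the full wrapper might look, assuming <code class="code docutils literal notranslate"><span class="pre">src</span></code>, <code class="code docutils literal notranslate"><span class="pre">autotune_configs</span></code> and <code class="code docutils literal notranslate"><span class="pre">autotune_key</span></code> have been defined earlier in the script:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span># sketch: compile a kernel once per (device, dtype) pair and reuse it afterwards
def make_kernel(device, dtype):
    key = (device, dtype)
    cache = make_kernel.cache
    if key not in cache:
        # TYPE is a placeholder in the kernel source `src`, specialized here
        defines = {'TYPE': dtype}
        cache[key] = triton.kernel(
            src,
            device=device,
            defines=defines,
            autotune_configs=autotune_configs,
            autotune_key=autotune_key,
        )
    return cache[key]

# the cache lives on the function object itself
make_kernel.cache = dict()
</pre></div>
</div>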
<p>Triton comes with some basic Python bindings for benchmarking CUTLASS. These are compiled when the environment variables <code class="code docutils literal notranslate"><span class="pre">CUTLASS_INCLUDE_DIR</span></code> and <code class="code docutils literal notranslate"><span class="pre">CUTLASS_LIBRARY_DIR</span></code> are set during the installation process. To re-install Triton with the updated CUTLASS bindings, run the following commands:</p>
<div class="highlight-bash notranslate"><div class="highlight"><pre><span></span><span class="nb">export</span> <span class="nv">CUTLASS_INCLUDE_DIR</span><span class="o">=</span>/tmp/cutlass/build/install/include/
|
|
|
|
|
<span class="nb">export</span> <span class="nv">CUTLASS_LIBRARY_DIR</span><span class="o">=</span>/tmp/cutlass/build/install/lib/a
|
|
|
|
|
<span class="nb">export</span> <span class="nv">CUTLASS_LIBRARY_DIR</span><span class="o">=</span>/tmp/cutlass/build/install/lib/
|
|
|
|
|
pip uninstall -y triton
|
|
|
|
|
pip install -e <span class="s2">"git+https://github.com/ptillet/triton.git#egg=triton&subdirectory=python"</span>
|
|
|
|
|
</pre></div>
|
|
|
|
|
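<p>Once re-installed, you can check that the bindings were actually built. A quick sanity check, assuming the bindings surface a CUTLASS-backed matmul under <code class="code docutils literal notranslate"><span class="pre">triton.testing</span></code> (the attribute name below is an assumption, not a documented API):</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span>import triton
import triton.testing

# assumption: when the CUTLASS_* variables were set at install time, the
# bindings expose a matmul wrapper used by the benchmark below
print(hasattr(triton.testing, 'cutlass_matmul'))  # True if the bindings were built
</pre></div>
</div>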
<span class="n">x_names</span><span class="o">=</span><span class="p">[</span><span class="s1">'M'</span><span class="p">,</span> <span class="s1">'N'</span><span class="p">,</span> <span class="s1">'K'</span><span class="p">],</span> <span class="c1"># argument names to use as an x-axis for the plot</span>
|
|
|
|
|
<span class="n">x_vals</span><span class="o">=</span><span class="p">[</span><span class="mi">256</span> <span class="o">*</span> <span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">33</span><span class="p">)],</span> <span class="c1"># different possible values for `x_name`</span>
|
|
|
|
|
<span class="n">y_name</span><span class="o">=</span><span class="s1">'provider'</span><span class="p">,</span> <span class="c1"># argument name whose value corresponds to a different line in the plot</span>
|
|
|
|
|
<span class="n">y_vals</span><span class="o">=</span><span class="p">[</span><span class="s1">'torch'</span><span class="p">,</span> <span class="s1">'triton'</span><span class="p">,</span> <span class="s1">'cutlass'</span><span class="p">],</span> <span class="c1"># possible keys for `y_name`</span>
|
|
|
|
|
<span class="n">y_lines</span><span class="o">=</span><span class="p">[</span><span class="s2">"Torch"</span><span class="p">,</span> <span class="s2">"Triton"</span><span class="p">,</span> <span class="s1">'CUTLASS'</span><span class="p">],</span> <span class="c1"># label name for the lines</span>
|
|
|
|
|
<span class="n">y_vals</span><span class="o">=</span><span class="p">[</span><span class="s1">'cublas'</span><span class="p">,</span> <span class="s1">'triton'</span><span class="p">,</span> <span class="s1">'cutlass'</span><span class="p">],</span> <span class="c1"># possible keys for `y_name`</span>
|
|
|
|
|
<span class="n">y_lines</span><span class="o">=</span><span class="p">[</span><span class="s2">"cuBLAS"</span><span class="p">,</span> <span class="s2">"Triton"</span><span class="p">,</span> <span class="s1">'CUTLASS'</span><span class="p">],</span> <span class="c1"># label name for the lines</span>
|
|
|
|
|
<span class="n">ylabel</span><span class="o">=</span><span class="s2">"TFLOPS"</span><span class="p">,</span> <span class="c1"># label name for the y-axis</span>
|
|
|
|
|
<span class="n">plot_name</span><span class="o">=</span><span class="s2">"matmul-performance"</span><span class="p">,</span> <span class="c1"># name for the plot. Used also as a file name for saving the plot.</span>
|
|
|
|
|
<span class="n">args</span><span class="o">=</span><span class="p">{}</span>
|
|
|
|
|
@@ -559,7 +565,7 @@ True
|
|
|
|
|
<span class="k">def</span> <span class="nf">benchmark</span><span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">K</span><span class="p">,</span> <span class="n">provider</span><span class="p">):</span>
|
|
|
|
|
<span class="n">a</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">((</span><span class="n">M</span><span class="p">,</span> <span class="n">K</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>
|
|
|
|
|
<span class="n">b</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">((</span><span class="n">K</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>
|
|
|
|
|
<span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s1">'torch'</span><span class="p">:</span>
|
|
|
|
|
<span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s1">'cublas'</span><span class="p">:</span>
|
|
|
|
|
<span class="n">ms</span><span class="p">,</span> <span class="n">min_ms</span><span class="p">,</span> <span class="n">max_ms</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">))</span>
|
|
|
|
|
<span class="k">if</span> <span class="n">provider</span> <span class="o">==</span> <span class="s1">'triton'</span><span class="p">:</span>
|
|
|
|
|
<span class="n">ms</span><span class="p">,</span> <span class="n">min_ms</span><span class="p">,</span> <span class="n">max_ms</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">do_bench</span><span class="p">(</span><span class="k">lambda</span><span class="p">:</span> <span class="n">dot</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">))</span>
|
|
|
|
|
@@ -572,9 +578,9 @@ True
|
|
|
|
|
<span class="n">benchmark</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">show_plots</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
|
|
|
|
</pre></div>
</div>
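<p>Note that the excerpt above elides the <code class="code docutils literal notranslate"><span class="pre">'cutlass'</span></code> branch of <code class="code docutils literal notranslate"><span class="pre">benchmark</span></code> and the conversion of the measured runtimes to TFLOPS. A sketch of that conversion (the helper name is hypothetical, for illustration only):</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span># an (M, K) x (K, N) matmul performs 2*M*N*K floating-point operations:
# one multiply and one add per term of each inner product
def to_tflops(ms, M, N, K):
    return 2 * M * N * K * 1e-12 / (ms * 1e-3)

# `benchmark` would then return throughputs instead of times; the slowest
# run (max_ms) maps to the lowest throughput and vice versa:
#   to_tflops(ms, M, N, K), to_tflops(max_ms, M, N, K), to_tflops(min_ms, M, N, K)
</pre></div>
</div>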
<img alt="03 matrix multiplication" class="sphx-glr-single-img" src="../../_images/sphx_glr_03-matrix-multiplication_001.png" />
<p>As we can see, our kernel performs well: it is in fact faster than CUTLASS here, and therefore likely comparable to the absolute best CUDA code an expert could write.</p>
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 1 minutes 5.861 seconds)</p>
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-03-matrix-multiplication-py">
<div class="sphx-glr-download sphx-glr-download-python docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/d5fee5b55a64e47f1b5724ec39adf171/03-matrix-multiplication.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">03-matrix-multiplication.py</span></code></a></p>