@@ -241,11 +241,9 @@ algorithm to multiply a (MxK) by a (KxN) matrix:</p>
<div class="section" id="compute-kernel">
<h2>Compute Kernel<a class="headerlink" href="#compute-kernel" title="Permalink to this headline">¶</a></h2>
<p>The above algorithm is, actually, fairly straightforward to implement in Triton.
The main difficulty comes from the computation of the memory locations at which blocks</p>
<blockquote>
<div><p>of <code class="code docutils literal notranslate"><span class="pre">A</span></code> and <code class="code docutils literal notranslate"><span class="pre">B</span></code> must be read in the inner loop. For that, we need</p>
</div></blockquote>
<p>multi-dimensional pointer arithmetics.</p>
The main difficulty comes from the computation of the memory locations at which blocks
of <code class="code docutils literal notranslate"><span class="pre">A</span></code> and <code class="code docutils literal notranslate"><span class="pre">B</span></code> must be read in the inner loop. For that, we need
multi-dimensional pointer arithmetics.</p>
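<p>Before turning to pointers, it may help to restate the blocked algorithm itself in plain Python. The sketch below is an illustration added for this page (it is not the tutorial's kernel, and the matrix and block sizes are arbitrary choices): each <code class="code docutils literal notranslate"><span class="pre">(m,</span> <span class="pre">n)</span></code> iteration plays the role of one program instance, and the inner loop over <code class="code docutils literal notranslate"><span class="pre">k</span></code> is the one that must read blocks of <code class="code docutils literal notranslate"><span class="pre">A</span></code> and <code class="code docutils literal notranslate"><span class="pre">B</span></code>.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre>
# Plain-NumPy restatement of the blocked matmul algorithm (illustration only;
# matrix sizes and block sizes are arbitrary).
import numpy as np

M, N, K = 6, 8, 10
BM, BN, BK = 2, 4, 5
A = np.random.rand(M, K).astype(np.float32)
B = np.random.rand(K, N).astype(np.float32)
C = np.zeros((M, N), dtype=np.float32)

for m in range(0, M, BM):          # one (m, n) pair ~ one program instance
    for n in range(0, N, BN):
        acc = np.zeros((BM, BN), dtype=np.float32)
        for k in range(0, K, BK):  # inner loop: reads one block of A and one of B
            acc += A[m:m+BM, k:k+BK] @ B[k:k+BK, n:n+BN]
        C[m:m+BM, n:n+BN] = acc

assert np.allclose(C, A @ B)
</pre></div>
</div>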
<div class="section" id="pointer-arithmetics">
<h3>Pointer Arithmetics<a class="headerlink" href="#pointer-arithmetics" title="Permalink to this headline">¶</a></h3>
<p>For a row-major 2D tensor <code class="code docutils literal notranslate"><span class="pre">X</span></code>, the memory location of <code class="code docutils literal notranslate"><span class="pre">X[i,</span> <span class="pre">j]</span></code> is given by <code class="code docutils literal notranslate"><span class="pre">&amp;X[i,</span> <span class="pre">j]</span> <span class="pre">=</span> <span class="pre">X</span> <span class="pre">+</span> <span class="pre">i*stride_xi</span> <span class="pre">+</span> <span class="pre">j*stride_xj</span></code>.</p>
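<p>As a quick illustration of this address formula (a sketch written for this page, not the tutorial's kernel), the NumPy snippet below builds a <code class="code docutils literal notranslate"><span class="pre">[BLOCK_M,</span> <span class="pre">BLOCK_K]</span></code> grid of element offsets by broadcasting row and column indices, then checks it against direct slicing. The sizes and the <code class="code docutils literal notranslate"><span class="pre">BLOCK_M</span></code>/<code class="code docutils literal notranslate"><span class="pre">BLOCK_K</span></code> names are assumptions made for the example.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre>
# Illustration only: check &amp;X[i, j] = X + i*stride_xi + j*stride_xj for a whole
# block of offsets, using NumPy flat indexing as a stand-in for raw pointers.
import numpy as np

M, K = 8, 6                     # matrix shape (arbitrary)
BLOCK_M, BLOCK_K = 2, 3         # block shape (arbitrary)
X = np.arange(M * K, dtype=np.float32).reshape(M, K)  # row-major 2D tensor
stride_xi = X.strides[0] // X.itemsize  # = K for a contiguous row-major array
stride_xj = X.strides[1] // X.itemsize  # = 1

i0, j0 = 4, 3                           # top-left corner of the block
offs_i = i0 + np.arange(BLOCK_M)        # row indices covered by the block
offs_j = j0 + np.arange(BLOCK_K)        # column indices covered by the block
# Broadcast to a [BLOCK_M, BLOCK_K] grid of element offsets, the same way the
# kernel builds its grid of pointers for blocks of A and B.
block_offsets = offs_i[:, None] * stride_xi + offs_j[None, :] * stride_xj

assert np.array_equal(X.flat[block_offsets], X[i0:i0+BLOCK_M, j0:j0+BLOCK_K])
</pre></div>
</div>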
@@ -282,11 +280,9 @@ Therefore, blocks of pointers for <code class="code docutils literal notranslate
</div>
<div class="section" id="l2-cache-optimizations">
<h3>L2 Cache Optimizations<a class="headerlink" href="#l2-cache-optimizations" title="Permalink to this headline">¶</a></h3>
<dl class="simple">
<dt>As mentioned above, each program instance computes a <code class="code docutils literal notranslate"><span class="pre">[BLOCK_SIZE_M,</span> <span class="pre">BLOCK_SIZE_N]</span></code></dt><dd><p>block of <code class="code docutils literal notranslate"><span class="pre">C</span></code>.</p>
</dd>
</dl>
<p>It is important to remember that the order in which these blocks are computed does
<p>As mentioned above, each program instance computes a <code class="code docutils literal notranslate"><span class="pre">[BLOCK_SIZE_M,</span> <span class="pre">BLOCK_SIZE_N]</span></code>
block of <code class="code docutils literal notranslate"><span class="pre">C</span></code>.
It is important to remember that the order in which these blocks are computed does
matter, since it affects the L2 cache hit rate of our program, and unfortunately, a
simple row-major ordering</p>
<blockquote>
@@ -313,17 +309,15 @@ switching to the next column:</p>
</pre></div>
</div>
</div></blockquote>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="c1"># For example, in the following matmul where each matrix is 9 blocks by 9 blocks,</span>
<span class="c1"># we can see that if we compute the output in row-major ordering, we need to load 90</span>
<span class="c1"># blocks into SRAM to compute the first 9 output blocks, but if we do it in grouped</span>
<span class="c1"># ordering, we only need to load 54 blocks.</span>
<span class="c1"># .. image:: grouped_vs_row_major_ordering.png</span>
<span class="c1">#</span>
<span class="c1"># In practice, this can improve the performance of our matrix multiplication kernel by</span>
<span class="c1"># more than 10\% on some hardware architecture (e.g., 220 to 245 TFLOPS on A100).</span>
<span class="c1">#</span>
</pre></div>
</div>
<p>For example, in the following matmul where each matrix is 9 blocks by 9 blocks,
we can see that if we compute the output in row-major ordering, we need to load 90
blocks into SRAM to compute the first 9 output blocks, but if we do it in grouped
ordering, we only need to load 54 blocks.</p>
<blockquote>
<div><img alt="../../_images/grouped_vs_row_major_ordering.png" src="../../_images/grouped_vs_row_major_ordering.png" />
</div></blockquote>
<p>In practice, this can improve the performance of our matrix multiplication kernel by
more than 10% on some hardware architectures (e.g., 220 to 245 TFLOPS on A100).</p>
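<p>To make the block-ordering argument concrete, here is a small self-contained Python sketch (an illustration added for this page, not the tutorial's kernel code). It maps a linear program id to a <code class="code docutils literal notranslate"><span class="pre">(pid_m,</span> <span class="pre">pid_n)</span></code> output block under row-major and under grouped ordering, then counts how many blocks of <code class="code docutils literal notranslate"><span class="pre">A</span></code> and <code class="code docutils literal notranslate"><span class="pre">B</span></code> the first 9 program instances touch in the 9-by-9 example above; the group width of 3 is an assumption chosen for the example.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre>
# Sketch: reproduce the 90-vs-54 block count from the figure above.
num_pid_m = num_pid_n = 9   # 9x9 grid of output blocks
GROUP_M = 3                 # assumed group width along M

def row_major(pid):
    return pid // num_pid_n, pid % num_pid_n

def grouped(pid):
    # Walk GROUP_M rows of output blocks before moving to the next column.
    num_pid_in_group = GROUP_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    return pid_m, pid_n

def blocks_loaded(order, n_programs=9):
    # Output block (m, n) reads row-strip m of A and column-strip n of B;
    # each strip is 9 blocks long along K.
    a_rows, b_cols = set(), set()
    for pid in range(n_programs):
        m, n = order(pid)
        a_rows.add(m)
        b_cols.add(n)
    return 9 * (len(a_rows) + len(b_cols))

print(blocks_loaded(row_major))  # 90 blocks of A and B
print(blocks_loaded(grouped))    # 54 blocks of A and B
</pre></div>
</div>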
</div>
</div>
<div class="section" id="final-result">
@@ -501,8 +495,8 @@ and (1) checks any shape constraint; (2) allocates the output; (3) launches the
<span class="n">b</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">((</span><span class="mi">512</span><span class="p">,</span> <span class="mi">512</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>
<span class="n">triton_output</span> <span class="o">=</span> <span class="n">matmul</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span>
<span class="n">torch_output</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">triton_output</span><span class="si">=}</span><span class="s2">"</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">torch_output</span><span class="si">=}</span><span class="s2">"</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"triton_output=</span><span class="si">{</span><span class="n">triton_output</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"torch_output=</span><span class="si">{</span><span class="n">torch_output</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
<span class="k">if</span> <span class="n">triton</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">allclose</span><span class="p">(</span><span class="n">triton_output</span><span class="p">,</span> <span class="n">torch_output</span><span class="p">):</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">"✅ Triton and Torch match"</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
@@ -582,41 +576,41 @@ torch_output=tensor([[ 1.1045, -36.9688, 31.4688, ..., -11.3906, 24.4531, -3
M cuBLAS ... Triton Triton (+ LeakyReLU)
0 128.0 0.455111 ... 0.512000 0.512000
1 256.0 2.730667 ... 3.276800 2.978909
2 384.0 7.372800 ... 8.507077 8.507077
3 512.0 14.563555 ... 16.384000 15.420235
2 384.0 7.372800 ... 7.899428 7.899428
3 512.0 14.563555 ... 16.384000 16.384000
4 640.0 22.260869 ... 24.380953 24.380953
5 768.0 32.768000 ... 34.028308 34.028308
6 896.0 39.025776 ... 40.140799 35.123201
6 896.0 39.025776 ... 40.140799 35.150663
7 1024.0 49.932191 ... 52.428801 52.428801
8 1152.0 44.566925 ... 46.656000 45.938215
8 1152.0 44.566925 ... 46.656000 46.656000
9 1280.0 51.200001 ... 56.109587 56.109587
10 1408.0 64.138541 ... 64.902096 64.138541
11 1536.0 80.430545 ... 76.106321 75.296679
12 1664.0 63.372618 ... 62.492442 62.061463
12 1664.0 62.929456 ... 62.061463 62.061463
13 1792.0 72.983276 ... 69.810085 69.379162
14 1920.0 68.435645 ... 67.764707 69.818184
15 2048.0 73.584279 ... 75.234154 74.898285
16 2176.0 83.500614 ... 81.143743 78.916269
17 2304.0 68.056616 ... 73.501144 73.051599
18 2432.0 71.125224 ... 80.269900 80.963875
19 2560.0 77.833728 ... 76.920185 76.382283
20 2688.0 80.027544 ... 79.524227 82.284288
21 2816.0 83.392363 ... 79.587973 76.785575
22 2944.0 82.509987 ... 79.230573 79.993627
23 3072.0 81.589488 ... 83.761985 82.301023
24 3200.0 84.768213 ... 89.385477 89.012517
25 3328.0 80.617354 ... 80.707733 86.217120
26 3456.0 81.518272 ... 85.223646 82.183044
27 3584.0 84.033077 ... 93.564405 95.047985
28 3712.0 86.267139 ... 88.015279 89.194055
29 3840.0 84.874902 ... 88.402879 87.217666
30 3968.0 92.442373 ... 87.850207 87.347124
31 4096.0 93.531519 ... 85.926841 85.871865
14 1920.0 69.120002 ... 70.892307 69.120002
15 2048.0 73.584279 ... 74.898285 74.565406
16 2176.0 83.155572 ... 80.817862 79.855747
17 2304.0 68.446623 ... 72.828879 73.275679
18 2432.0 71.305746 ... 82.388456 81.908060
19 2560.0 78.019048 ... 77.283019 75.676673
20 2688.0 83.552988 ... 83.552988 83.922689
21 2816.0 81.827785 ... 77.330158 79.154642
22 2944.0 81.166173 ... 77.747321 79.483304
23 3072.0 79.863336 ... 82.661468 82.420822
24 3200.0 83.660130 ... 90.395483 85.906037
25 3328.0 83.226931 ... 87.368079 83.613586
26 3456.0 80.220468 ... 81.600781 83.459178
27 3584.0 87.466332 ... 92.887804 84.983685
28 3712.0 84.159518 ... 83.178475 83.666116
29 3840.0 83.591840 ... 84.228485 85.663823
30 3968.0 91.885495 ... 84.680037 84.154440
31 4096.0 89.181212 ... 90.260743 90.200084
[32 rows x 5 columns]
</pre></div>
</div>
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 2 minutes 30.126 seconds)</p>
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 2 minutes 29.710 seconds)</p>
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-03-matrix-multiplication-py">
<div class="sphx-glr-download sphx-glr-download-python docutils container">
<p><a class="reference download internal" download="" href="../../_downloads/d5fee5b55a64e47f1b5724ec39adf171/03-matrix-multiplication.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">03-matrix-multiplication.py</span></code></a></p>