[GH-PAGES] Updated website
@@ -430,8 +430,8 @@ more than 10% on some hardware architecture (e.g., 220 to 245 TFLOPS on A100).
         b_ptrs += BLOCK_SIZE_K * stride_bk
     # you can fuse arbitrary activation functions here
     # while the accumulator is still in FP32!
-    if ACTIVATION:
-        accumulator = ACTIVATION(accumulator)
+    if ACTIVATION == "leaky_relu":
+        accumulator = leaky_relu(accumulator)
     c = accumulator.to(tl.float16)

     # -----------------------------------------------------------
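The new branch dispatches on the string "leaky_relu" and calls a leaky_relu device function defined elsewhere on the page. A minimal sketch of such a helper is shown below; the exact body (in particular the 0.01 negative slope) is an assumption, not taken from this diff.

import triton
import triton.language as tl

@triton.jit
def leaky_relu(x):
    # element-wise LeakyReLU, applied while the accumulator is still in FP32;
    # the 0.01 negative slope is an assumed default
    return tl.where(x >= 0, x, 0.01 * x)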
@@ -452,7 +452,7 @@ more than 10% on some hardware architecture (e.g., 220 to 245 TFLOPS on A100).

 We can now create a convenience wrapper function that only takes two input tensors
 and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel.

-def matmul(a, b, activation=None):
+def matmul(a, b, activation=""):
     # checks constraints
     assert a.shape[1] == b.shape[0], "incompatible dimensions"
     assert a.is_contiguous(), "matrix A must be contiguous"
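Only the signature and the first assertions appear in this hunk. A sketch of how the wrapper typically continues is shown below, assuming the kernel defined earlier on the page is named matmul_kernel and takes the usual pointer/stride arguments plus an ACTIVATION meta-parameter; those names are assumptions, not taken from this diff.

import torch
import triton

def matmul(a, b, activation=""):
    # (1) checks constraints
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    assert a.is_contiguous(), "matrix A must be contiguous"
    assert b.is_contiguous(), "matrix B must be contiguous"
    M, K = a.shape
    K, N = b.shape
    # (2) allocates the output
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)
    # (3) launches the kernel on a 1D grid: one program per output tile
    grid = lambda META: (
        triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
    )
    matmul_kernel[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        ACTIVATION=activation,  # "" or "leaky_relu", matching the kernel branch above
    )
    return c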
@@ -554,7 +554,7 @@ torch_output=tensor([[ 1.1045, -36.9688, 31.4688, ..., -11.3906, 24.4531, -3
         )
     if provider == 'triton + relu':
         ms, min_ms, max_ms = triton.testing.do_bench(
-            lambda: matmul(a, b, activation=leaky_relu)
+            lambda: matmul(a, b, activation="leaky_relu")
         )
     perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
     return perf(ms), perf(max_ms), perf(min_ms)
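The perf lambda converts a runtime in milliseconds into TFLOPS: a matmul performs 2 * M * N * K floating-point operations (one multiply and one add per inner-product term). A quick sanity check with made-up numbers:

M = N = K = 4096         # hypothetical problem size
ms = 1.5                 # hypothetical runtime in milliseconds
flops = 2 * M * N * K    # ~1.37e11 floating-point operations
tflops = flops * 1e-12 / (ms * 1e-3)
print(round(tflops, 1))  # ~91.6 TFLOPS, the same order of magnitude as the table below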
@@ -567,42 +567,42 @@ torch_output=tensor([[ 1.1045, -36.9688, 31.4688, ..., -11.3906, 24.4531, -3
 Out:
 matmul-performance:
           M     cuBLAS  ...     Triton  Triton (+ LeakyReLU)
-0     256.0   2.730667  ...   2.978909              2.978909
-1     384.0   7.372800  ...   7.899428              7.899428
+0     256.0   2.978909  ...   2.978909              2.978909
+1     384.0   7.372800  ...   8.507077              8.507077
 2     512.0  14.563555  ...  15.420235             15.420235
 3     640.0  22.260869  ...  24.380953             24.380953
 4     768.0  32.768000  ...  34.028308             34.028308
 5     896.0  39.025776  ...  40.140799             39.025776
 6    1024.0  49.932191  ...  53.773130             52.428801
-7    1152.0  45.242181  ...  47.396572             47.396572
+7    1152.0  45.242181  ...  48.161033             47.396572
 8    1280.0  51.200001  ...  57.690139             57.690139
 9    1408.0  64.138541  ...  68.147202             67.305878
-10   1536.0  80.430545  ...  80.430545             79.526831
+10   1536.0  80.430545  ...  81.355034             79.526831
 11   1664.0  63.372618  ...  63.372618             62.492442
 12   1792.0  72.983276  ...  73.460287             59.467852
-13   1920.0  68.776119  ...  71.626943             71.257735
-14   2048.0  73.908442  ...  78.398206             77.314362
-15   2176.0  83.500614  ...  87.494120             85.998493
-16   2304.0  68.446623  ...  78.064941             77.307030
-17   2432.0  71.125224  ...  86.179335             85.653855
-18   2560.0  77.833728  ...  82.331658             81.108913
-19   2688.0  83.737433  ...  91.185232             89.888756
-20   2816.0  83.233216  ...  84.441840             84.197315
-21   2944.0  81.564701  ...  83.758038             82.373605
-22   3072.0  82.540970  ...  89.593522             88.335577
-23   3200.0  83.989503  ...  95.096582             89.012517
-24   3328.0  82.464255  ...  82.939284             84.596116
-25   3456.0  81.932484  ...  90.994998             91.200871
-26   3584.0  87.127323  ...  99.354022             92.600816
-27   3712.0  84.159518  ...  89.353616             83.247783
-28   3840.0  85.136259  ...  93.484358             86.738820
-29   3968.0  92.302520  ...  87.976885             90.926929
-30   4096.0  91.741443  ...  90.933416             91.304576
+13   1920.0  68.776119  ...  71.257735             70.892307
+14   2048.0  73.584279  ...  78.033565             76.959706
+15   2176.0  83.155572  ...  87.494120             85.998493
+16   2304.0  68.446623  ...  78.320893             77.558029
+17   2432.0  71.305746  ...  86.711310             75.320281
+18   2560.0  77.833728  ...  82.747477             81.715711
+19   2688.0  83.552988  ...  90.532356             89.676257
+20   2816.0  83.552120  ...  84.035084             83.392363
+21   2944.0  81.832567  ...  83.758038             81.967162
+22   3072.0  82.540970  ...  89.877939             89.170242
+23   3200.0  84.321474  ...  96.822991             95.380032
+24   3328.0  83.034941  ...  85.806075             84.596116
+25   3456.0  82.183044  ...  91.928814             87.632137
+26   3584.0  87.381330  ...  92.696281             96.891584
+27   3712.0  84.694652  ...  87.244203             88.092894
+28   3840.0  85.136259  ...  88.900318             90.279183
+29   3968.0  88.008611  ...  92.547541             84.268854
+30   4096.0  93.368854  ...  87.781379             86.592080

 [31 rows x 5 columns]
-Total running time of the script: ( 6 minutes 44.471 seconds)
+Total running time of the script: ( 6 minutes 21.318 seconds)
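The 31-row matmul-performance table above (M from 256.0 to 4096.0) is the kind of output produced by wrapping the benchmark function in triton.testing.perf_report. The sketch below shows the general shape of that declaration; the exact x_vals, provider labels, and the column elided as "..." above are assumptions, not taken from this diff.

import triton

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['M', 'N', 'K'],                 # sweep square problem sizes together
        x_vals=[128 * i for i in range(2, 33)],  # 256 ... 4096, i.e. 31 rows (assumed)
        line_arg='provider',                     # one table column per provider
        line_vals=['cublas', 'cublas + relu', 'triton', 'triton + relu'],
        line_names=['cuBLAS', 'cuBLAS (+ LeakyReLU)', 'Triton', 'Triton (+ LeakyReLU)'],
        ylabel='TFLOPS',
        plot_name='matmul-performance',
        args={},
    )
)
def benchmark(M, N, K, provider):
    # time the selected provider as in the 'triton + relu' branch above and
    # return perf(ms), perf(max_ms), perf(min_ms)
    ...

# benchmark.run(print_data=True) would then print a table like the one above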
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-03-matrix-multiplication-py">
|
||||
<div class="sphx-glr-download sphx-glr-download-python docutils container">
|
||||
<p><a class="reference download internal" download="" href="../../_downloads/d5fee5b55a64e47f1b5724ec39adf171/03-matrix-multiplication.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">03-matrix-multiplication.py</span></code></a></p>
|
||||