[GH-PAGES] Updated website
This commit is contained in:
@@ -250,45 +250,50 @@ Benchmarking
|
||||
Here we will benchmark our operation as a function of the number of columns in the input matrix -- assuming 4096 rows.
|
||||
We will then compare its performance against (1) :code:`torch.softmax` and (2) the :code:`naive_softmax` defined above.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 181-204
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 181-209
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
M = 4096
|
||||
Ns = [256 * i for i in range(2, 50)]
|
||||
tri_bw = []
|
||||
ref_bw = []
|
||||
def_bw = []
|
||||
for N in Ns:
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=['N'], # argument names to use as an x-axis for the plot
|
||||
x_vals=[256 * i for i in range(2, 50)], # different possible values for `x_name`
|
||||
y_name='provider', # argument name whose value corresponds to a different line in the plot
|
||||
y_vals=['torch', 'triton', 'naive'], # possible keys for `y_name`
|
||||
y_lines=["Torch", "Triton", 'Naive'], # label name for the lines
|
||||
ylabel="GB/s", # label name for the y-axis
|
||||
plot_name="softmax-performance", # name for the plot. Used also as a file name for saving the plot.
|
||||
args={'M': 4096} # values for function arguments not in `x_names` and `y_name`
|
||||
)
|
||||
)
|
||||
def benchmark(M, N, provider):
|
||||
x = torch.randn(M, N, device='cuda', dtype=torch.float32)
|
||||
gbps = lambda ms: x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)
|
||||
do_bench = lambda fn: gbps(triton.testing.do_bench(fn, warmup=10, rep=100, clear_l2=True))
|
||||
tri_bw += [do_bench(lambda: softmax(x))]
|
||||
ref_bw += [do_bench(lambda: torch.softmax(x, axis=1))]
|
||||
def_bw += [do_bench(lambda: naive_softmax(x))]
|
||||
plt.xlabel('N')
|
||||
plt.ylabel('Bandwidth (GB/s)')
|
||||
plt.plot(Ns, tri_bw, label='Triton')
|
||||
plt.plot(Ns, ref_bw, label='Torch')
|
||||
plt.plot(Ns, def_bw, label='Naive')
|
||||
plt.legend()
|
||||
plt.show()
|
||||
if provider == 'torch':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
|
||||
if provider == 'triton':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x))
|
||||
if provider == 'naive':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x))
|
||||
gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)
|
||||
return gbps(ms), gbps(max_ms), gbps(min_ms)
|
||||
|
||||
|
||||
benchmark.run(show_plots=True)
|
||||
|
||||
|
||||
|
||||
|
||||
.. image:: /getting-started/tutorials/images/sphx_glr_02-fused-softmax_001.png
|
||||
:alt: 02 fused softmax
|
||||
:alt: softmax-performance
|
||||
:class: sphx-glr-single-img
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 205-210
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 210-215
|
||||
|
||||
In the above plot, we can see that:
|
||||
|
||||
@@ -300,7 +305,7 @@ In the above plot, we can see that:
|
||||
|
||||
.. rst-class:: sphx-glr-timing
|
||||
|
||||
**Total running time of the script:** ( 0 minutes 33.773 seconds)
|
||||
**Total running time of the script:** ( 0 minutes 21.653 seconds)
|
||||
|
||||
|
||||
.. _sphx_glr_download_getting-started_tutorials_02-fused-softmax.py:
|
||||
|
Reference in New Issue
Block a user