[DOCS] Improved plots in tutorials
This commit is contained in:
@@ -179,27 +179,32 @@ print(torch.allclose(y_tri, y_ref))
|
||||
# Here we will benchmark our operation as a function of the number of columns in the input matrix -- assuming 4096 rows.
|
||||
# We will then compare its performance against (1) :code:`torch.softmax` and (2) the :code:`naive_softmax` defined above.
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
M = 4096
|
||||
Ns = [256 * i for i in range(2, 50)]
|
||||
tri_bw = []
|
||||
ref_bw = []
|
||||
def_bw = []
|
||||
for N in Ns:
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=['N'], # argument names to use as an x-axis for the plot
|
||||
x_vals=[256 * i for i in range(2, 50)], # different possible values for `x_name`
|
||||
y_name='provider', # argument name whose value corresponds to a different line in the plot
|
||||
y_vals=['torch', 'triton', 'naive'], # possible keys for `y_name`
|
||||
y_lines=["Torch", "Triton", 'Naive'], # label name for the lines
|
||||
ylabel="GB/s", # label name for the y-axis
|
||||
plot_name="softmax-performance", # name for the plot. Used also as a file name for saving the plot.
|
||||
args={'M': 4096} # values for function arguments not in `x_names` and `y_name`
|
||||
)
|
||||
)
|
||||
def benchmark(M, N, provider):
|
||||
x = torch.randn(M, N, device='cuda', dtype=torch.float32)
|
||||
gbps = lambda ms: x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)
|
||||
do_bench = lambda fn: gbps(triton.testing.do_bench(fn, warmup=10, rep=100, clear_l2=True))
|
||||
tri_bw += [do_bench(lambda: softmax(x))]
|
||||
ref_bw += [do_bench(lambda: torch.softmax(x, axis=1))]
|
||||
def_bw += [do_bench(lambda: naive_softmax(x))]
|
||||
plt.xlabel('N')
|
||||
plt.ylabel('Bandwidth (GB/s)')
|
||||
plt.plot(Ns, tri_bw, label='Triton')
|
||||
plt.plot(Ns, ref_bw, label='Torch')
|
||||
plt.plot(Ns, def_bw, label='Naive')
|
||||
plt.legend()
|
||||
plt.show()
|
||||
if provider == 'torch':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
|
||||
if provider == 'triton':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x))
|
||||
if provider == 'naive':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x))
|
||||
gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)
|
||||
return gbps(ms), gbps(max_ms), gbps(min_ms)
|
||||
|
||||
|
||||
benchmark.run(show_plots=True)
|
||||
|
||||
# %%
|
||||
# In the above plot, we can see that:
|
||||
|
Reference in New Issue
Block a user