[DOCS] Improved plots in tutorials

Philippe Tillet
2021-03-11 00:29:16 -05:00
parent eacbb73968
commit 50e58d73db
8 changed files with 122 additions and 82 deletions
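A note on the throughput numbers in the softmax diff below: the new `gbps` lambda multiplies the tensor size by 2 because the softmax kernel reads every element of `x` once and writes every element of the output once, so the effective memory traffic is twice the tensor's footprint. A rough sketch of that conversion, using a made-up runtime purely for illustration:

# Illustrative only -- converts a hypothetical runtime into effective bandwidth
# the same way the benchmark in the diff below does.
M, N = 4096, 4096
nbytes = M * N * 4                  # float32 input: 4 bytes per element
traffic = 2 * nbytes                # one full read + one full write
ms = 0.5                            # hypothetical runtime in milliseconds
gbps = traffic * 1e-9 / (ms * 1e-3)
print(f"{gbps:.1f} GB/s")           # ~268 GB/s for these made-up numbers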


@@ -179,27 +179,32 @@ print(torch.allclose(y_tri, y_ref))
# Here we will benchmark our operation as a function of the number of columns in the input matrix -- assuming 4096 rows.
# We will then compare its performance against (1) :code:`torch.softmax` and (2) the :code:`naive_softmax` defined above.
-import matplotlib.pyplot as plt
-M = 4096
-Ns = [256 * i for i in range(2, 50)]
-tri_bw = []
-ref_bw = []
-def_bw = []
-for N in Ns:
+@triton.testing.perf_report(
+    triton.testing.Benchmark(
+        x_names=['N'],  # argument names to use as an x-axis for the plot
+        x_vals=[256 * i for i in range(2, 50)],  # different possible values for `x_name`
+        y_name='provider',  # argument name whose value corresponds to a different line in the plot
+        y_vals=['torch', 'triton', 'naive'],  # possible keys for `y_name`
+        y_lines=["Torch", "Triton", 'Naive'],  # label name for the lines
+        ylabel="GB/s",  # label name for the y-axis
+        plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
+        args={'M': 4096}  # values for function arguments not in `x_names` and `y_name`
+    )
+)
+def benchmark(M, N, provider):
     x = torch.randn(M, N, device='cuda', dtype=torch.float32)
-    gbps = lambda ms: x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)
-    do_bench = lambda fn: gbps(triton.testing.do_bench(fn, warmup=10, rep=100, clear_l2=True))
-    tri_bw += [do_bench(lambda: softmax(x))]
-    ref_bw += [do_bench(lambda: torch.softmax(x, axis=1))]
-    def_bw += [do_bench(lambda: naive_softmax(x))]
-plt.xlabel('N')
-plt.ylabel('Bandwidth (GB/s)')
-plt.plot(Ns, tri_bw, label='Triton')
-plt.plot(Ns, ref_bw, label='Torch')
-plt.plot(Ns, def_bw, label='Naive')
-plt.legend()
-plt.show()
+    if provider == 'torch':
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
+    if provider == 'triton':
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x))
+    if provider == 'naive':
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x))
+    gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)
+    return gbps(ms), gbps(max_ms), gbps(min_ms)
+benchmark.run(show_plots=True)
# %%
# In the above plot, we can see that: