[DOCS] Improved plots in tutorials
@@ -30,7 +30,7 @@ steps:
       source $(venv)/bin/activate
       pip install matplotlib pandas
       cd python/bench
-      python -m run --with-plots
+      python -m run

 - publish: python/bench/results
   artifact: Benchmarks
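Note: the CI step drops `--with-plots` because plotting is no longer a runner flag; `run.py` (further down in this diff) now passes `save_path` to each benchmark, which writes PNGs and CSVs unconditionally. A rough sketch of what the step amounts to (the `subprocess` wrapper is illustrative; the module name, working directory, and output location are taken from this diff):

    import subprocess

    # Equivalent of the updated CI step: run the bench suite from python/bench;
    # results land in python/bench/results, which the pipeline then publishes.
    subprocess.run(["python", "-m", "run"], cwd="python/bench", check=True)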
@@ -14,7 +14,6 @@ square_confs = [
         y_vals = [16, 32, 64],
         y_lines = ['Block16', 'Block32', 'Block64'],
         ylabel = 'TFLOPS',
-        loglog = False,
         plot_name = f'{op_mode}-{layout_mode}-square-{nt[AT]}{nt[BT]}',
         args = {'layout_mode': layout_mode, 'op_mode': op_mode,
                 'AT': AT, 'BT': BT, 'dtype': torch.float16, 'provider': 'triton'}
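The same one-line deletion recurs in the bench configs below: the `loglog` knob is removed from every `Benchmark(...)` call because the class (see the `testing.py` hunks further down) now takes separate `x_log`/`y_log` flags, both defaulting to False. A sketch of a config written against the new signature (the x-axis names, values, and plot name here are illustrative, not from this diff):

    import torch
    import triton

    conf = triton.testing.Benchmark(
        x_names=['M', 'N', 'K'],                # illustrative x-axis arguments
        x_vals=[128 * i for i in range(1, 9)],  # illustrative sizes
        y_name='block',
        y_vals=[16, 32, 64],
        y_lines=['Block16', 'Block32', 'Block64'],
        ylabel='TFLOPS',
        plot_name='square-example',             # illustrative plot name
        args={'dtype': torch.float16, 'provider': 'triton'},
        x_log=True,                             # replaces the old loglog flag, per axis
    )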
@@ -65,7 +64,6 @@ square_confs = [
         y_vals = [16, 32, 64],
         y_lines = ['Block16', 'Block32', 'Block64'],
         ylabel = 'GBPS',
-        loglog = False,
         plot_name = f'{layout_mode}-square',
         args = {'layout_mode': layout_mode, 'dtype': torch.float16, 'provider': 'triton'}
     )\
@@ -9,7 +9,6 @@ confs = [
         y_vals = ['triton', 'torch'],
         y_lines = ['Triton', 'Torch'],
         ylabel = 'GBPS',
-        loglog = False,
         plot_name = f'{mode}-2048',
         args = {'M': 2048, 'dtype': torch.float16, 'mode': mode}
     )\
@@ -20,7 +20,6 @@ square_confs = [
         y_vals=["torch", "triton", "cutlass"],
         y_lines=["Torch", "Triton", "CUTLASS"],
         ylabel="TFLOPS",
-        loglog=False,
         plot_name=f"matmul-square-{nt[AT]}{nt[BT]}",
         args={
             "AT": AT,
@@ -39,17 +38,16 @@ transformer_confs = [
         y_vals=["torch", "triton", "cutlass"],
         y_lines=["Torch", "Triton", "CUTLASS"],
         ylabel="TFLOPS",
-        loglog=False,
         plot_name=f"matmul-M{M}-{'NK'.replace(x, '')}{NK}",
         args= {"M": M, 'NK'.replace(x,''): NK, "AT": False, "BT": False, "dtype": torch.float16}
-    ) for NK in [8192]\
+    ) for NK in [12288]\
       for i, x in enumerate(["N", "K"])\
       for M in [2048]
 ]


-@triton.testing.perf_report(square_confs)
-def bench_op(M, N, K, AT, BT, dtype, provider, warmup=10, rep=50):
+@triton.testing.perf_report(transformer_confs)
+def bench_op(M, N, K, AT, BT, dtype, provider, warmup=25, rep=75):
     a = torch.rand((K, M) if AT else (M, K), device="cuda", dtype=dtype)
     b = torch.rand((N, K) if BT else (K, N), device="cuda", dtype=dtype)
     if AT: a = a.t()
@@ -4,7 +4,8 @@ import os
 import inspect
 import triton

-def run_all(result_dir, with_plots, names):
+
+def run_all(result_dir, names):
     if not os.path.exists(result_dir):
         os.makedirs(result_dir)
     for mod in os.listdir(os.path.dirname(os.path.realpath(__file__))):
@@ -26,16 +27,17 @@ def run_all(result_dir, with_plots, names):
         curr_dir = os.path.join(curr_dir, name.replace('bench_', ''))
         if not os.path.exists(curr_dir):
             os.makedirs(curr_dir)
-        bench.run(curr_dir, with_plots)
+        bench.run(save_path=curr_dir)


 def main(args):
     parser = argparse.ArgumentParser(description="Run the benchmark suite.")
     parser.add_argument("-r", "--result-dir", type=str, default='results', required=False)
     parser.add_argument("-n", "--names", type=str, default='', required=False)
-    parser.add_argument("-p", "--with-plots", dest='with_plots', action='store_true')
     parser.set_defaults(feature=False)
     args = parser.parse_args(args)
-    run_all(args.result_dir, args.with_plots, args.names)
+    run_all(args.result_dir, args.names)


 if __name__ == '__main__':
     main(sys.argv[1:])
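With the `-p/--with-plots` flag gone, the runner's CLI surface is just `-r/--result-dir` and `-n/--names`. A minimal sketch of driving it programmatically (assuming the current directory is `python/bench` so that `run` is importable):

    from run import main

    # Same as `python -m run -r results -n matmul` on the command line:
    # run the benchmarks selected by --names and save their CSVs/PNGs
    # under results/.
    main(['--result-dir', 'results', '--names', 'matmul'])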
@@ -40,14 +40,29 @@ def allclose(x, y):
     return err < tol


-def do_bench(fn, warmup=10, rep=50, grad_to_none=None):
+def do_bench(fn, warmup=50, rep=50, grad_to_none=None, percentiles=[0.2, 0.8]):
+    # Estimate the runtime of the function
+    fn()
+    torch.cuda.synchronize()
+    start_event = torch.cuda.Event(enable_timing=True)
+    end_event = torch.cuda.Event(enable_timing=True)
+    start_event.record()
+    for _ in range(5):
+        fn()
+    end_event.record()
+    torch.cuda.synchronize()
+    estimate_ms = start_event.elapsed_time(end_event) / 5
     # We maintain a buffer of 256 MB that we clear
     # before each kernel call to make sure that the L2
     # doesn't contain any input data before the run
     start_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
     end_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
     cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda')
-    for i in range(warmup + rep):
+    # Warm-up
+    for _ in range(int(warmup / estimate_ms)):
+        fn()
+    # Benchmark
+    for i in range(rep):
         # we don't want `fn` to accumulate gradient values
         # if it contains a backward pass. So we clear the
         # provided gradients
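Two behavioral changes are worth spelling out: `do_bench` now times five calls up front to estimate the per-call runtime, and `warmup` becomes a time budget in milliseconds rather than an iteration count. A worked sketch of the conversion (numbers illustrative):

    # If the 5-call estimate says one call takes ~0.1 ms and warmup=50 (the new
    # default), the warm-up loop runs int(50 / 0.1) = 500 extra calls before
    # any timed iteration.
    estimate_ms = 0.1
    warmup = 50
    n_warmup = int(warmup / estimate_ms)
    assert n_warmup == 500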
@@ -56,29 +71,41 @@ def do_bench(fn, warmup=10, rep=50, grad_to_none=None):
         # we clear the L2 cache before each run
         cache.zero_()
         # record time of `fn`
-        if i >= warmup:
-            start_event[i - warmup].record()
+        start_event[i].record()
         fn()
-        if i >= warmup:
-            end_event[i - warmup].record()
+        end_event[i].record()
     torch.cuda.synchronize()
     times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)])
-    q = torch.quantile(times, torch.tensor([0.1, 0.5, 0.9]))
-    min_ms = q[0].item()
-    mean_ms = q[1].item()
-    max_ms = q[2].item()
-    return mean_ms, min_ms, max_ms
+    percentiles = torch.quantile(times, torch.tensor(percentiles)).tolist()
+    med_ms = torch.median(times).item()
+    if percentiles:
+        return tuple([med_ms] + percentiles)
+    else:
+        return med_ms


 class Benchmark:
-    def __init__(self, x_names, x_vals, y_name, y_vals, y_lines, ylabel, loglog, plot_name, args):
+    def __init__(
+        self,
+        x_names,
+        x_vals,
+        y_name,
+        y_vals,
+        y_lines,
+        ylabel,
+        plot_name,
+        args,
+        x_log=False,
+        y_log=False,
+    ):
         self.x_names = x_names
         self.x_vals = x_vals
+        self.x_log = x_log
         self.y_name = y_name
         self.y_vals = y_vals
         self.y_lines = y_lines
+        self.y_log = y_log
         self.ylabel = ylabel
-        self.loglog = loglog
         self.plot_name = plot_name
         self.args = args
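The return convention changes too: instead of hard-coded 0.1/0.5/0.9 quantiles returned as (mean_ms, min_ms, max_ms), `do_bench` now returns the median plus whatever `percentiles` were requested. A sketch of how callers unpack it under the defaults (tensor sizes here are illustrative):

    import torch
    import triton

    x = torch.rand(1 << 20, device='cuda')
    y = torch.rand(1 << 20, device='cuda')
    # percentiles defaults to [0.2, 0.8], so three values come back:
    # the median and the 20th/80th percentiles of the timed runs.
    ms, low_ms, high_ms = triton.testing.do_bench(lambda: x + y)
    # With percentiles=[] the function would return the median alone.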
@@ -88,7 +115,7 @@ class Mark:
         self.fn = fn
         self.benchmarks = benchmarks

-    def _run(self, bench, result_path, with_plot):
+    def _run(self, bench, save_path, show_plots):
         import matplotlib.pyplot as plt
         import pandas as pd
         import os
@@ -109,7 +136,7 @@ class Mark:
             row_min += [y_min]
             row_max += [y_max]
             df.loc[len(df)] = [x] + row_mean + row_min + row_max
-        if with_plot and bench.plot_name:
+        if bench.plot_name:
             plt.figure()
             ax = plt.subplot()
             xlabel = " = ".join(bench.x_names)
@@ -123,18 +150,27 @@ class Mark:
             ax.set_xlabel(xlabel)
             ax.set_ylabel(bench.ylabel)
             ax.set_title(bench.plot_name)
-            ax.set_xscale("log" if bench.loglog else "linear")
-            ax.set_yscale("log" if bench.loglog else "linear")
-            plt.savefig(os.path.join(result_path, f"{bench.plot_name}.png"))
-        df = df[[bench.x_names[0]] + bench.y_lines]
-        df.to_csv(os.path.join(result_path, f"{bench.plot_name}.csv"), float_format='%.1f', index=False)
+            ax.set_xscale("log" if bench.x_log else "linear")
+            ax.set_yscale("log" if bench.y_log else "linear")
+            if show_plots:
+                plt.show()
+            if save_path:
+                plt.savefig(os.path.join(save_path, f"{bench.plot_name}.png"))
+        if save_path:
+            df = df[[bench.x_names[0]] + bench.y_lines]
+            df.to_csv(os.path.join(save_path, f"{bench.plot_name}.csv"), float_format='%.1f', index=False)

-    def run(self, result_path, with_plot):
-        with open(os.path.join(result_path, "results.html"), "w") as html:
+    def run(self, show_plots=False, save_path=''):
+        has_single_bench = isinstance(self.benchmarks, Benchmark)
+        benchmarks = [self.benchmarks] if has_single_bench else self.benchmarks
+        if save_path:
+            html = open(os.path.join(save_path, "results.html"), "w")
             html.write("<html><body>\n")
-            for bench in self.benchmarks:
-                self._run(bench, result_path, with_plot)
+        for bench in benchmarks:
+            self._run(bench, save_path, show_plots)
+            if save_path:
                 html.write(f"<image src=\"{bench.plot_name}.png\"/>\n")
+        if save_path:
             html.write("</body></html>\n")
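`Mark.run` is now keyword-driven, accepts a bare `Benchmark` as well as a list, and decouples showing plots from saving them. A self-contained sketch against the new API (the copy benchmark, its config, and the output directory are illustrative):

    import os
    import torch
    import triton

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=['N'], x_vals=[2**i for i in range(12, 20)],
            y_name='provider', y_vals=['torch'], y_lines=['Torch'],
            ylabel='GB/s', plot_name='copy-performance', args={}, x_log=True,
        )
    )
    def bench_copy(N, provider):
        x = torch.rand(N, device='cuda')
        ms, low_ms, high_ms = triton.testing.do_bench(lambda: x.clone())
        # clone reads and writes N elements, hence the factor of 2
        gbps = lambda ms: 2 * N * x.element_size() * 1e-9 / (ms * 1e-3)
        return gbps(ms), gbps(high_ms), gbps(low_ms)

    os.makedirs('results', exist_ok=True)  # save_path must already exist
    bench_copy.run(show_plots=False, save_path='results')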
@@ -147,33 +147,35 @@ print(f'The maximum difference between torch and triton is ' f'{torch.max(torch.
 # Benchmarking
 # --------------------------
 # We can now benchmark our custom op for vectors of increasing sizes to get a sense of how it does relative to PyTorch.
+# To make things easier, Triton has a set of built-in utilities that allow us to concisely plot the performance of our custom op.
+# for different problem sizes.

-import matplotlib.pyplot as plt
-
-# There are three tensors of 4N bytes each. So the bandwidth of a given kernel
-# is 12N / time_ms * 1e-6 GB/s
-gbps = lambda N, ms: 12 * N / ms * 1e-6
-# We want to benchmark small and large vector alike
-sizes = [2**i for i in range(12, 25, 1)]
-triton_bw = []
-torch_bw = []
-for N in sizes:
-    x = torch.rand(N, device='cuda', dtype=torch.float32)
-    y = torch.rand(N, device='cuda', dtype=torch.float32)
-    # Triton provide a `do_bench` utility function that can be used to benchmark
-    # arbitrary workloads. It supports a `warmup` parameter that is used to stabilize
-    # GPU clock speeds as well as a `rep` parameter that controls the number of times
-    # the benchmark is repeated. Importantly, we set `clear_l2 = True` to make sure
-    # that the L2 cache does not contain any element of x before each kernel call when
-    # N is small.
-    do_bench = lambda fn: gbps(N, triton.testing.do_bench(fn, warmup=10, rep=100, clear_l2=True))
-    triton_bw += [do_bench(lambda: add(x, y))]
-    torch_bw += [do_bench(lambda: x + y)]
-# We plot the results as a semi-log
-plt.semilogx(sizes, triton_bw, label='Triton')
-plt.semilogx(sizes, torch_bw, label='Torch')
-plt.legend()
-plt.show()
+@triton.testing.perf_report(
+    triton.testing.Benchmark(
+        x_names=['size'],  # argument names to use as an x-axis for the plot
+        x_vals=[2**i for i in range(12, 28, 1)],  # different possible values for `x_name`
+        x_log=True,  # x axis is logarithmic
+        y_name='provider',  # argument name whose value corresponds to a different line in the plot
+        y_vals=['torch', 'triton'],  # possible keys for `y_name`
+        y_lines=["Torch", "Triton"],  # label name for the lines
+        ylabel="GB/s",  # label name for the y-axis
+        plot_name="vector-add-performance",  # name for the plot. Used also as a file name for saving the plot.
+        args={}  # values for function arguments not in `x_names` and `y_name`
+    )
+)
+def benchmark(size, provider):
+    x = torch.rand(size, device='cuda', dtype=torch.float32)
+    y = torch.rand(size, device='cuda', dtype=torch.float32)
+    if provider == 'torch':
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y)
+    if provider == 'triton':
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y))
+    gbps = lambda ms: 12 * size / ms * 1e-6
+    return gbps(ms), gbps(max_ms), gbps(min_ms)


 # %%
-# Seems like our simple element-wise operation operates at peak bandwidth. While this is a fairly low bar for a custom GPU programming language, this is a good start before we move to more advanced operations.
+# We can now run the decorated function above. Pass `show_plots=True` to see the plots and/or
+# `save_path='/path/to/results/' to save them to disk along with raw CSV data
+benchmark.run(show_plots=True)
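The bandwidth formula carried over from the old prose still holds: the add kernel reads two float32 tensors and writes one, i.e. three tensors of 4*size bytes each, so 12*size bytes move per call. A quick check that the compact `12 * size / ms * 1e-6` form equals bytes/seconds expressed in GB/s (values illustrative):

    import math

    size, ms = 2**25, 1.0
    compact = 12 * size / ms * 1e-6              # formula in the tutorial
    explicit = (12 * size * 1e-9) / (ms * 1e-3)  # bytes * GB/byte / seconds
    assert math.isclose(compact, explicit)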
@@ -179,27 +179,32 @@ print(torch.allclose(y_tri, y_ref))
 # Here we will benchmark our operation as a function of the number of columns in the input matrix -- assuming 4096 rows.
 # We will then compare its performance against (1) :code:`torch.softmax` and (2) the :code:`naive_softmax` defined above.

-import matplotlib.pyplot as plt
-
-M = 4096
-Ns = [256 * i for i in range(2, 50)]
-tri_bw = []
-ref_bw = []
-def_bw = []
-for N in Ns:
+@triton.testing.perf_report(
+    triton.testing.Benchmark(
+        x_names=['N'],  # argument names to use as an x-axis for the plot
+        x_vals=[256 * i for i in range(2, 50)],  # different possible values for `x_name`
+        y_name='provider',  # argument name whose value corresponds to a different line in the plot
+        y_vals=['torch', 'triton', 'naive'],  # possible keys for `y_name`
+        y_lines=["Torch", "Triton", 'Naive'],  # label name for the lines
+        ylabel="GB/s",  # label name for the y-axis
+        plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
+        args={'M': 4096}  # values for function arguments not in `x_names` and `y_name`
+    )
+)
+def benchmark(M, N, provider):
     x = torch.randn(M, N, device='cuda', dtype=torch.float32)
-    gbps = lambda ms: x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)
-    do_bench = lambda fn: gbps(triton.testing.do_bench(fn, warmup=10, rep=100, clear_l2=True))
-    tri_bw += [do_bench(lambda: softmax(x))]
-    ref_bw += [do_bench(lambda: torch.softmax(x, axis=1))]
-    def_bw += [do_bench(lambda: naive_softmax(x))]
-plt.xlabel('N')
-plt.ylabel('Bandwidth (GB/s)')
-plt.plot(Ns, tri_bw, label='Triton')
-plt.plot(Ns, ref_bw, label='Torch')
-plt.plot(Ns, def_bw, label='Naive')
-plt.legend()
-plt.show()
+    if provider == 'torch':
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
+    if provider == 'triton':
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x))
+    if provider == 'naive':
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x))
+    gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)
+    return gbps(ms), gbps(max_ms), gbps(min_ms)
+
+
+benchmark.run(show_plots=True)

 # %%
 # In the above plot, we can see that:
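In the softmax benchmark, the `2 *` factor added to the `gbps` lambda counts each element of `x` as read once and written once (input plus output), where the old formula counted only one pass. A check of the arithmetic for the tutorial's M=4096 with an illustrative N and timing:

    import torch

    x = torch.empty(4096, 512)  # float32; CPU is fine for checking arithmetic
    ms = 1.0
    gbps = 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)
    assert gbps == 2 * 4096 * 512 * 4 * 1e-9 / 1e-3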