diff --git a/.ci/azure-pipelines.yml b/.ci/azure-pipelines.yml
index 06d9c47fd..bfe07a8c1 100644
--- a/.ci/azure-pipelines.yml
+++ b/.ci/azure-pipelines.yml
@@ -30,7 +30,7 @@ steps:
     source $(venv)/bin/activate
     pip install matplotlib pandas
     cd python/bench
-    python -m run --with-plots
+    python -m run
 - publish: python/bench/results
   artifact: Benchmarks
diff --git a/python/bench/bench_blocksparse.py b/python/bench/bench_blocksparse.py
index 519c16fe4..313cef108 100644
--- a/python/bench/bench_blocksparse.py
+++ b/python/bench/bench_blocksparse.py
@@ -14,7 +14,6 @@ square_confs = [
         y_vals = [16, 32, 64],
         y_lines = ['Block16', 'Block32', 'Block64'],
         ylabel = 'TFLOPS',
-        loglog = False,
         plot_name = f'{op_mode}-{layout_mode}-square-{nt[AT]}{nt[BT]}',
         args = {'layout_mode': layout_mode, 'op_mode': op_mode, 'AT': AT, 'BT': BT, 'dtype': torch.float16, 'provider': 'triton'}
     )\
@@ -65,7 +64,6 @@ square_confs = [
         y_vals = [16, 32, 64],
         y_lines = ['Block16', 'Block32', 'Block64'],
         ylabel = 'GBPS',
-        loglog = False,
         plot_name = f'{layout_mode}-square',
         args = {'layout_mode': layout_mode, 'dtype': torch.float16, 'provider': 'triton'}
     )\
diff --git a/python/bench/bench_cross_entropy.py b/python/bench/bench_cross_entropy.py
index 1a8b2189b..f8ac18a1d 100644
--- a/python/bench/bench_cross_entropy.py
+++ b/python/bench/bench_cross_entropy.py
@@ -9,7 +9,6 @@ confs = [
         y_vals = ['triton', 'torch'],
         y_lines = ['Triton', 'Torch'],
         ylabel = 'GBPS',
-        loglog = False,
         plot_name = f'{mode}-2048',
         args = {'M': 2048, 'dtype': torch.float16, 'mode': mode}
     )\
diff --git a/python/bench/bench_matmul.py b/python/bench/bench_matmul.py
index 3648657c2..ef4371052 100644
--- a/python/bench/bench_matmul.py
+++ b/python/bench/bench_matmul.py
@@ -20,7 +20,6 @@ square_confs = [
         y_vals=["torch", "triton", "cutlass"],
         y_lines=["Torch", "Triton", "CUTLASS"],
         ylabel="TFLOPS",
-        loglog=False,
         plot_name=f"matmul-square-{nt[AT]}{nt[BT]}",
         args={
             "AT": AT,
@@ -39,17 +38,16 @@ transformer_confs = [
         y_vals=["torch", "triton", "cutlass"],
         y_lines=["Torch", "Triton", "CUTLASS"],
         ylabel="TFLOPS",
-        loglog=False,
         plot_name=f"matmul-M{M}-{'NK'.replace(x, '')}{NK}",
         args= {"M": M, 'NK'.replace(x,''): NK, "AT": False, "BT": False, "dtype": torch.float16}
-    ) for NK in [8192]\
+    ) for NK in [12288]\
     for i, x in enumerate(["N", "K"])\
     for M in [2048]
 ]
 
 
-@triton.testing.perf_report(square_confs)
-def bench_op(M, N, K, AT, BT, dtype, provider, warmup=10, rep=50):
+@triton.testing.perf_report(transformer_confs)
+def bench_op(M, N, K, AT, BT, dtype, provider, warmup=25, rep=75):
     a = torch.rand((K, M) if AT else (M, K), device="cuda", dtype=dtype)
     b = torch.rand((N, K) if BT else (K, N), device="cuda", dtype=dtype)
     if AT: a = a.t()
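Note on the bench_*.py config changes above: the `loglog` keyword is dropped outright rather than replaced inline. A minimal sketch (not part of the patch) of how one of these configs could opt back into logarithmic axes, assuming the new `x_log`/`y_log` keyword arguments that `triton.testing.Benchmark` gains later in this diff; the concrete names and values below are illustrative only:

import torch
import triton

# Hypothetical config shaped like the square_confs entries above, but using the
# new x_log/y_log flags instead of the removed `loglog` argument.
conf = triton.testing.Benchmark(
    x_names=['M', 'N', 'K'],
    x_vals=[256, 512, 1024, 2048],
    y_name='block',
    y_vals=[16, 32, 64],
    y_lines=['Block16', 'Block32', 'Block64'],
    ylabel='TFLOPS',
    plot_name='square-example',
    args={'dtype': torch.float16, 'provider': 'triton'},
    x_log=True,  # previously: loglog=True (both axes at once)
    y_log=True,
)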
parser.add_argument("-n", "--names", type=str, default='', required=False) - parser.add_argument("-p", "--with-plots", dest='with_plots', action='store_true') parser.set_defaults(feature=False) args = parser.parse_args(args) - run_all(args.result_dir, args.with_plots, args.names) + run_all(args.result_dir, args.names) + if __name__ == '__main__': main(sys.argv[1:]) diff --git a/python/triton/testing.py b/python/triton/testing.py index b3b64a498..6613e0bed 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -40,14 +40,29 @@ def allclose(x, y): return err < tol -def do_bench(fn, warmup=10, rep=50, grad_to_none=None): +def do_bench(fn, warmup=50, rep=50, grad_to_none=None, percentiles=[0.2, 0.8]): + # Estimate the runtime of the function + fn() + torch.cuda.synchronize() + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + for _ in range(5): + fn() + end_event.record() + torch.cuda.synchronize() + estimate_ms = start_event.elapsed_time(end_event) / 5 # We maintain a buffer of 256 MB that we clear # before each kernel call to make sure that the L2 # doesn't contain any input data before the run start_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)] end_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)] cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda') - for i in range(warmup + rep): + # Warm-up + for _ in range(int(warmup / estimate_ms)): + fn() + # Benchmark + for i in range(rep): # we don't want `fn` to accumulate gradient values # if it contains a backward pass. So we clear the # provided gradients @@ -56,29 +71,41 @@ def do_bench(fn, warmup=10, rep=50, grad_to_none=None): # we clear the L2 cache before each run cache.zero_() # record time of `fn` - if i >= warmup: - start_event[i - warmup].record() + start_event[i].record() fn() - if i >= warmup: - end_event[i - warmup].record() + end_event[i].record() torch.cuda.synchronize() times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)]) - q = torch.quantile(times, torch.tensor([0.1, 0.5, 0.9])) - min_ms = q[0].item() - mean_ms = q[1].item() - max_ms = q[2].item() - return mean_ms, min_ms, max_ms + percentiles = torch.quantile(times, torch.tensor(percentiles)).tolist() + med_ms = torch.median(times).item() + if percentiles: + return tuple([med_ms] + percentiles) + else: + return med_ms class Benchmark: - def __init__(self, x_names, x_vals, y_name, y_vals, y_lines, ylabel, loglog, plot_name, args): + def __init__( + self, + x_names, + x_vals, + y_name, + y_vals, + y_lines, + ylabel, + plot_name, + args, + x_log=False, + y_log=False, + ): self.x_names = x_names self.x_vals = x_vals + self.x_log = x_log self.y_name = y_name self.y_vals = y_vals self.y_lines = y_lines + self.y_log = y_log self.ylabel = ylabel - self.loglog = loglog self.plot_name = plot_name self.args = args @@ -88,7 +115,7 @@ class Mark: self.fn = fn self.benchmarks = benchmarks - def _run(self, bench, result_path, with_plot): + def _run(self, bench, save_path, show_plots): import matplotlib.pyplot as plt import pandas as pd import os @@ -109,7 +136,7 @@ class Mark: row_min += [y_min] row_max += [y_max] df.loc[len(df)] = [x] + row_mean + row_min + row_max - if with_plot and bench.plot_name: + if bench.plot_name: plt.figure() ax = plt.subplot() xlabel = " = ".join(bench.x_names) @@ -123,18 +150,27 @@ class Mark: ax.set_xlabel(xlabel) ax.set_ylabel(bench.ylabel) ax.set_title(bench.plot_name) - ax.set_xscale("log" if 
bench.loglog else "linear") - ax.set_yscale("log" if bench.loglog else "linear") - plt.savefig(os.path.join(result_path, f"{bench.plot_name}.png")) - df = df[[bench.x_names[0]] + bench.y_lines] - df.to_csv(os.path.join(result_path, f"{bench.plot_name}.csv"), float_format='%.1f', index=False) + ax.set_xscale("log" if bench.x_log else "linear") + ax.set_yscale("log" if bench.y_log else "linear") + if show_plots: + plt.show() + if save_path: + plt.savefig(os.path.join(save_path, f"{bench.plot_name}.png")) + if save_path: + df = df[[bench.x_names[0]] + bench.y_lines] + df.to_csv(os.path.join(save_path, f"{bench.plot_name}.csv"), float_format='%.1f', index=False) - def run(self, result_path, with_plot): - with open(os.path.join(result_path, "results.html"), "w") as html: + def run(self, show_plots=False, save_path=''): + has_single_bench = isinstance(self.benchmarks, Benchmark) + benchmarks = [self.benchmarks] if has_single_bench else self.benchmarks + if save_path: + html = open(os.path.join(save_path, "results.html"), "w") html.write("\n") - for bench in self.benchmarks: - self._run(bench, result_path, with_plot) + for bench in benchmarks: + self._run(bench, save_path, show_plots) + if save_path: html.write(f"\n") + if save_path: html.write("\n") diff --git a/python/tutorials/01-vector-add.py b/python/tutorials/01-vector-add.py index f04f1add5..39fdca2f0 100644 --- a/python/tutorials/01-vector-add.py +++ b/python/tutorials/01-vector-add.py @@ -147,33 +147,35 @@ print(f'The maximum difference between torch and triton is ' f'{torch.max(torch. # Benchmarking # -------------------------- # We can now benchmark our custom op for vectors of increasing sizes to get a sense of how it does relative to PyTorch. +# To make things easier, Triton has a set of built-in utilities that allow us to concisely plot the performance of our custom op. +# for different problem sizes. -import matplotlib.pyplot as plt -# There are three tensors of 4N bytes each. So the bandwidth of a given kernel -# is 12N / time_ms * 1e-6 GB/s -gbps = lambda N, ms: 12 * N / ms * 1e-6 -# We want to benchmark small and large vector alike -sizes = [2**i for i in range(12, 25, 1)] -triton_bw = [] -torch_bw = [] -for N in sizes: - x = torch.rand(N, device='cuda', dtype=torch.float32) - y = torch.rand(N, device='cuda', dtype=torch.float32) - # Triton provide a do_bench utility function that can be used to benchmark - # arbitrary workloads. It supports a `warmup` parameter that is used to stabilize - # GPU clock speeds as well as a `rep` parameter that controls the number of times - # the benchmark is repeated. Importantly, we set `clear_l2 = True` to make sure - # that the L2 cache does not contain any element of x before each kernel call when - # N is small. 
diff --git a/python/tutorials/01-vector-add.py b/python/tutorials/01-vector-add.py
index f04f1add5..39fdca2f0 100644
--- a/python/tutorials/01-vector-add.py
+++ b/python/tutorials/01-vector-add.py
@@ -147,33 +147,35 @@ print(f'The maximum difference between torch and triton is ' f'{torch.max(torch.
 # Benchmarking
 # --------------------------
 # We can now benchmark our custom op for vectors of increasing sizes to get a sense of how it does relative to PyTorch.
+# To make things easier, Triton has a set of built-in utilities that allow us to concisely plot the performance of our custom op
+# for different problem sizes.
 
-import matplotlib.pyplot as plt
-# There are three tensors of 4N bytes each. So the bandwidth of a given kernel
-# is 12N / time_ms * 1e-6 GB/s
-gbps = lambda N, ms: 12 * N / ms * 1e-6
-# We want to benchmark small and large vector alike
-sizes = [2**i for i in range(12, 25, 1)]
-triton_bw = []
-torch_bw = []
-for N in sizes:
-    x = torch.rand(N, device='cuda', dtype=torch.float32)
-    y = torch.rand(N, device='cuda', dtype=torch.float32)
-    # Triton provide a do_bench utility function that can be used to benchmark
-    # arbitrary workloads. It supports a `warmup` parameter that is used to stabilize
-    # GPU clock speeds as well as a `rep` parameter that controls the number of times
-    # the benchmark is repeated. Importantly, we set `clear_l2 = True` to make sure
-    # that the L2 cache does not contain any element of x before each kernel call when
-    # N is small.
-    do_bench = lambda fn: gbps(N, triton.testing.do_bench(fn, warmup=10, rep=100, clear_l2=True))
-    triton_bw += [do_bench(lambda: add(x, y))]
-    torch_bw += [do_bench(lambda: x + y)]
-# We plot the results as a semi-log
-plt.semilogx(sizes, triton_bw, label='Triton')
-plt.semilogx(sizes, torch_bw, label='Torch')
-plt.legend()
-plt.show()
+@triton.testing.perf_report(
+    triton.testing.Benchmark(
+        x_names=['size'],  # argument names to use as an x-axis for the plot
+        x_vals=[2**i for i in range(12, 28, 1)],  # different possible values for `x_name`
+        x_log=True,  # x axis is logarithmic
+        y_name='provider',  # argument name whose value corresponds to a different line in the plot
+        y_vals=['torch', 'triton'],  # possible keys for `y_name`
+        y_lines=["Torch", "Triton"],  # label name for the lines
+        ylabel="GB/s",  # label name for the y-axis
+        plot_name="vector-add-performance",  # name for the plot. Used also as a file name for saving the plot.
+        args={}  # values for function arguments not in `x_names` and `y_name`
+    )
+)
+def benchmark(size, provider):
+    x = torch.rand(size, device='cuda', dtype=torch.float32)
+    y = torch.rand(size, device='cuda', dtype=torch.float32)
+    if provider == 'torch':
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y)
+    if provider == 'triton':
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y))
+    gbps = lambda ms: 12 * size / ms * 1e-6
+    return gbps(ms), gbps(max_ms), gbps(min_ms)
+
 
 # %%
-# Seems like our simple element-wise operation operates at peak bandwidth. While this is a fairly low bar for a custom GPU programming language, this is a good start before we move to more advanced operations.
\ No newline at end of file
+# We can now run the decorated function above. Pass `show_plots=True` to see the plots and/or
+# `save_path='/path/to/results/'` to save them to disk along with the raw CSV data.
+benchmark.run(show_plots=True)
\ No newline at end of file
diff --git a/python/tutorials/02-fused-softmax.py b/python/tutorials/02-fused-softmax.py
index af8ca44cf..4b2db3cd7 100644
--- a/python/tutorials/02-fused-softmax.py
+++ b/python/tutorials/02-fused-softmax.py
@@ -179,27 +179,32 @@ print(torch.allclose(y_tri, y_ref))
 # Here we will benchmark our operation as a function of the number of columns in the input matrix -- assuming 4096 rows.
 # We will then compare its performance against (1) :code:`torch.softmax` and (2) the :code:`naive_softmax` defined above.
 
-import matplotlib.pyplot as plt
-M = 4096
-Ns = [256 * i for i in range(2, 50)]
-tri_bw = []
-ref_bw = []
-def_bw = []
-for N in Ns:
+@triton.testing.perf_report(
+    triton.testing.Benchmark(
+        x_names=['N'],  # argument names to use as an x-axis for the plot
+        x_vals=[256 * i for i in range(2, 50)],  # different possible values for `x_name`
+        y_name='provider',  # argument name whose value corresponds to a different line in the plot
+        y_vals=['torch', 'triton', 'naive'],  # possible keys for `y_name`
+        y_lines=["Torch", "Triton", 'Naive'],  # label name for the lines
+        ylabel="GB/s",  # label name for the y-axis
+        plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
+        args={'M': 4096}  # values for function arguments not in `x_names` and `y_name`
+    )
+)
+def benchmark(M, N, provider):
     x = torch.randn(M, N, device='cuda', dtype=torch.float32)
-    gbps = lambda ms: x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)
-    do_bench = lambda fn: gbps(triton.testing.do_bench(fn, warmup=10, rep=100, clear_l2=True))
-    tri_bw += [do_bench(lambda: softmax(x))]
-    ref_bw += [do_bench(lambda: torch.softmax(x, axis=1))]
-    def_bw += [do_bench(lambda: naive_softmax(x))]
-plt.xlabel('N')
-plt.ylabel('Bandwidth (GB/s)')
-plt.plot(Ns, tri_bw, label='Triton')
-plt.plot(Ns, ref_bw, label='Torch')
-plt.plot(Ns, def_bw, label='Naive')
-plt.legend()
-plt.show()
+    if provider == 'torch':
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
+    if provider == 'triton':
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x))
+    if provider == 'naive':
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x))
+    gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)
+    return gbps(ms), gbps(max_ms), gbps(min_ms)
+
+
+benchmark.run(show_plots=True)
 
 # %%
 # In the above plot, we can see that:
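As a usage note (not part of the patch): with the reworked `Mark.run` signature introduced in python/triton/testing.py above, the tutorial benchmarks can also write their figures and raw CSV tables to disk instead of, or in addition to, displaying them. A minimal sketch, assuming the output directory is created up front since `Mark._run` does not create it itself:

import os

# Hypothetical output directory; _run saves <plot_name>.png and <plot_name>.csv
# into it, and run opens a results.html in the same directory when save_path is set.
out_dir = 'results/softmax'
os.makedirs(out_dir, exist_ok=True)
benchmark.run(show_plots=False, save_path=out_dir)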