From cc84a476a3feca482d67accbb3c71beb09ccb3e9 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Thu, 4 Feb 2021 15:35:53 -0500 Subject: [PATCH] [TESTS] test_matmul.py now plots benchmarks --- python/tests/test_matmul.py | 105 +++++++++++++++++++++--------------- 1 file changed, 63 insertions(+), 42 deletions(-) diff --git a/python/tests/test_matmul.py b/python/tests/test_matmul.py index 9a3e71c3d..bc13a19c6 100644 --- a/python/tests/test_matmul.py +++ b/python/tests/test_matmul.py @@ -81,50 +81,71 @@ def do_bench(fn, flops = 0, warmup = 10, rep = 50): time_ms = start_event.elapsed_time(end_event) / rep return time_ms - -def perf_op(AT=False, BT=False, MODE='square', dtype=th.float16, warmup=10, rep=50): - import pandas as pd +def time_all(fn, x_names, x_vals, y_name, y_vals, y_lines, ylabel, loglog=True, plot_name='', **kwargs): import matplotlib.pyplot as plt + import pandas as pd + df = pd.DataFrame(columns = [x_names[0]] + y_lines) + for x in x_vals: + x_args = {x_name: x for x_name in x_names} + row = [fn(**x_args, **{y_name: y}, **kwargs) for y in y_vals] + df.loc[len(df)] = [x] + row + print(df) + if plot_name: + df.plot(x=x_names[0], y=y_lines, ylabel=ylabel, xlabel=' = '.join(x_names), title=f'{plot_name}', loglog=loglog) + plt.savefig(f'{plot_name}.pdf') + +def perf_op(M, N, K, AT, BT, dtype, provider, warmup=10, rep=50): import os - has_cutlass = 'CUTLASS_PROFILER' in os.environ - df = pd.DataFrame(columns=['N', 'Triton', 'Torch', 'CUTLASS']) - Ns = [128, 256, 512, 1024, 1536, 2048, 2560, 3072, 4096, 5120, 6144] - configs = [(AT, BT, N, N, N) for AT in [False, True] for BT in [False, True] for N in Ns] - for AT, BT, M, N, K in configs: - a = th.randn((K, M) if AT else (M, K), device='cuda', dtype=dtype) / K**.5 - b = th.randn((N, K) if BT else (K, N), device='cuda', dtype=dtype) / K**.5 - if AT: a = a.t() - if BT: b = b.t() - # benchmarks + a = th.randn((K, M) if AT else (M, K), device='cuda', dtype=dtype) / K**.5 + b = th.randn((N, K) if BT else (K, N), device='cuda', dtype=dtype) / K**.5 + if AT: a = a.t() + if BT: b = b.t() + num_flops = 2*M*N*K + if provider == 'torch': torch_ms = do_bench(lambda: th.matmul(a, b), warmup = warmup, rep = rep) - triton_ms = do_bench(lambda: tt.ops.matmul(a, b), warmup = warmup, rep = rep) - # store result - num_flops = 2*M*N*K torch_tflops = num_flops / torch_ms * 1e-9 + return torch_tflops + if provider == 'triton': + triton_ms = do_bench(lambda: tt.ops.matmul(a, b), warmup = warmup, rep = rep) triton_tflops = num_flops / triton_ms * 1e-9 - if 'CUTLASS_PROFILER' in os.environ: - import subprocess - # run program specified by CUTLASS_PROFILER env variable - layout_a = 'column' if AT else 'row' - layout_b = 'column' if BT else 'row' - # create temporary file name - import tempfile - fd, fname = tempfile.mkstemp() - # run program and gets its output - cmd = [os.environ['CUTLASS_PROFILER'], f'--m={M}', f'--n={N}', f'--k={K}', f'--A=f16:{layout_a}', f'--B=f16:{layout_b}', \ - '--C=f16:column', '--accum=f32', '--operation=gemm', '--verification-enabled=false', '--warmup-iterations=10', \ - '--profiling-iterations=50', f'--output={fname}', '--verbose=false'] - # run cmd - subprocess.run(cmd, stdout=subprocess.PIPE) - # read CSV output - df_c = pd.read_csv(f'{fname}.gemm.csv') - cutlass_tflops = max(df_c['GFLOPs'])/1e3 - else: - cutlass_tflops = None - df = df.append({'N': N, 'Triton': triton_tflops, 'Torch': torch_tflops, 'CUTLASS': cutlass_tflops}, ignore_index=True) - # name - AT = {True: 'T', False: 'N'}[AT] - BT = {True: 'T', False: 'N'}[BT] - name = f'{AT}{BT}' - df.plot.line(x='N', y=['Triton', 'Torch', 'CUTLASS'], title = f'{AT}{BT}', ax=ax[0,0], color=['purple', 'blue', 'green']) - plt.savefig(f'matmul-{mode}-{name}.pdf') \ No newline at end of file + return triton_tflops + if provider == 'cutlass' and 'CUTLASS_PROFILER' in os.environ: + import subprocess + import tempfile + import pandas as pd + # run program specified by CUTLASS_PROFILER env variable + layout_a = 'column' if AT else 'row' + layout_b = 'column' if BT else 'row' + # create temporary file name + fd, fname = tempfile.mkstemp() + # run program and gets its output + cmd = [os.environ['CUTLASS_PROFILER'], f'--m={M}', f'--n={N}', f'--k={K}', f'--A=f16:{layout_a}', f'--B=f16:{layout_b}', \ + '--C=f16:column', '--accum=f32', '--operation=gemm', '--verification-enabled=false', '--warmup-iterations=10', \ + '--profiling-iterations=50', f'--output={fname}', '--verbose=false'] + # run cmd + subprocess.run(cmd, stdout=subprocess.PIPE) + # read CSV output + df_c = pd.read_csv(f'{fname}.gemm.csv') + cutlass_tflops = max(df_c['GFLOPs'])/1e3 + return cutlass_tflops + return None + +if __name__ == '__main__': + # # square + x_square = [128, 256, 512, 1024, 2048, 3072, 4096, 6144] + time_all(perf_op, x_names = ['M', 'N', 'K'], x_vals = x_square, y_name = 'provider' , y_vals = ['torch', 'triton', 'cutlass'], + ylabel = 'TFLOPS', y_lines = ['Torch', 'Triton', 'CUTLASS'], AT = False, BT = False, dtype = th.float16, loglog=False, plot_name = 'matmul-square-nn') + time_all(perf_op, x_names = ['M', 'N', 'K'], x_vals = x_square, y_name = 'provider' , y_vals = ['torch', 'triton', 'cutlass'], + ylabel = 'TFLOPS', y_lines = ['Torch', 'Triton', 'CUTLASS'], AT = False, BT = True, dtype = th.float16, loglog=False, plot_name = 'matmul-square-nt') + time_all(perf_op, x_names = ['M', 'N', 'K'], x_vals = x_square, y_name = 'provider' , y_vals = ['torch', 'triton', 'cutlass'], + ylabel = 'TFLOPS', y_lines = ['Torch', 'Triton', 'CUTLASS'], AT = True, BT = False, dtype = th.float16, loglog=False, plot_name = 'matmul-square-tn') + time_all(perf_op, x_names = ['M', 'N', 'K'], x_vals = x_square, y_name = 'provider' , y_vals = ['torch', 'triton', 'cutlass'], + ylabel = 'TFLOPS', y_lines = ['Torch', 'Triton', 'CUTLASS'], AT = True, BT = True, dtype = th.float16, loglog=False, plot_name = 'matmul-square-tt') + # tall-skinny + x_tall_skinny = [64, 96, 128, 160, 192, 256, 320, 384, 512, 768, 1024, 1536] + time_all(perf_op, x_names = ['M'], x_vals = x_tall_skinny, y_name = 'provider', y_vals = ['torch', 'triton', 'cutlass'], + ylabel = 'TFLOPS', y_lines = ['Torch', 'Triton', 'CUTLASS'], AT = False, BT = False, N=2048, K=2048, dtype = th.float16, loglog=False, plot_name = 'matmul-tall-skinny-2k-2k') + time_all(perf_op, x_names = ['M'], x_vals = x_tall_skinny, y_name = 'provider', y_vals = ['torch', 'triton', 'cutlass'], + ylabel = 'TFLOPS', y_lines = ['Torch', 'Triton', 'CUTLASS'], AT = False, BT = False, N=4096, K=4096, dtype = th.float16, loglog=False, plot_name = 'matmul-tall-skinny-4k-4k') + time_all(perf_op, x_names = ['M'], x_vals = x_tall_skinny, y_name = 'provider', y_vals = ['torch', 'triton', 'cutlass'], + ylabel = 'TFLOPS', y_lines = ['Torch', 'Triton', 'CUTLASS'], AT = False, BT = False, N=6144, K=6144, dtype = th.float16, loglog=False, plot_name = 'matmul-tall-skinny-6k-6k') \ No newline at end of file