From ad005d49acf91e9f9e8fccd21bffe461117bc856 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Sun, 31 Jan 2021 21:23:01 -0500 Subject: [PATCH] [PYTHON] Added benchmark code for CUTLASS --- python/tests/test_matmul.py | 30 ++++++++++++++++++++++++------ 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/python/tests/test_matmul.py b/python/tests/test_matmul.py index 2857cde2f..fb48ce311 100644 --- a/python/tests/test_matmul.py +++ b/python/tests/test_matmul.py @@ -83,10 +83,11 @@ def do_bench(fn, flops = 0, warmup = 10, rep = 50): def perf_op(dtype=th.float16, warmup=10, rep=50): import pandas as pd + import os AT, BT = False, False - df = pd.DataFrame(columns=['AT', 'BT', 'N', 'TRITON', 'TORCH']) - # Ns = [128, 256, 512, 1024, 2048, 3072, 4096, 6144, 8192] - Ns = [8192] + has_cutlass = 'CUTLASS_PROFILER' in os.environ + df = pd.DataFrame(columns=['AT', 'BT', 'N', 'TRITON', 'TORCH', 'CUTLASS']) + Ns = [128, 256, 512, 1024, 2048, 3072, 4096, 6144] configs = [(AT, BT, N, N, N) for AT in [False, True] for BT in [False, True] for N in Ns] for AT, BT, M, N, K in configs: a = th.randn((K, M) if AT else (M, K), device='cuda', dtype=dtype) / K**.5 @@ -100,8 +101,25 @@ def perf_op(dtype=th.float16, warmup=10, rep=50): num_flops = 2*M*N*K torch_tflops = num_flops / torch_ms * 1e-9 triton_tflops = num_flops / triton_ms * 1e-9 - #print(min(alpha*bandwidth*1e-12, max_tflops), triton_tflops) - #./tools/profiler/cutlass_profiler --m=8192 --n=8192 --k=8192 --A=f16:column --B=f16:column --C=f16:column --accum=f32 --operation=gemm - df = df.append({'AT': AT, 'BT': BT, 'N': N, 'TRITON': triton_tflops, 'TORCH': torch_tflops}, ignore_index=True) + if 'CUTLASS_PROFILER' in os.environ: + import subprocess + # run program specified by CUTLASS_PROFILER env variable + layout_a = 'column' if AT else 'row' + layout_b = 'column' if BT else 'row' + # create temporary file name + import tempfile + fd, fname = tempfile.mkstemp() + # run program and gets its output + cmd = [os.environ['CUTLASS_PROFILER'], f'--m={M}', f'--n={N}', f'--k={K}', f'--A=f16:{layout_a}', f'--B=f16:{layout_b}', \ + '--C=f16:column', '--accum=f32', '--operation=gemm', '--verification-enabled=false', '--warmup-iterations=10', \ + '--profiling-iterations=50', f'--output={fname}', '--verbose=false'] + # run cmd + subprocess.run(cmd, stdout=subprocess.PIPE) + # read CSV output + df_c = pd.read_csv(f'{fname}.gemm.csv') + cutlass_tflops = max(df_c['GFLOPs'])/1e3 + else: + cutlass_tflops = None + df = df.append({'AT': AT, 'BT': BT, 'N': N, 'TRITON': triton_tflops, 'TORCH': torch_tflops, 'CUTLASS': cutlass_tflops}, ignore_index=True) pd.options.display.float_format = lambda x: '{:.2f}'.format(x) print(df) \ No newline at end of file