[PYTHON] Added benchmark code for CUTLASS
This commit is contained in:
@@ -83,10 +83,11 @@ def do_bench(fn, flops = 0, warmup = 10, rep = 50):
|
|||||||
|
|
||||||
def perf_op(dtype=th.float16, warmup=10, rep=50):
|
def perf_op(dtype=th.float16, warmup=10, rep=50):
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
import os
|
||||||
AT, BT = False, False
|
AT, BT = False, False
|
||||||
df = pd.DataFrame(columns=['AT', 'BT', 'N', 'TRITON', 'TORCH'])
|
has_cutlass = 'CUTLASS_PROFILER' in os.environ
|
||||||
# Ns = [128, 256, 512, 1024, 2048, 3072, 4096, 6144, 8192]
|
df = pd.DataFrame(columns=['AT', 'BT', 'N', 'TRITON', 'TORCH', 'CUTLASS'])
|
||||||
Ns = [8192]
|
Ns = [128, 256, 512, 1024, 2048, 3072, 4096, 6144]
|
||||||
configs = [(AT, BT, N, N, N) for AT in [False, True] for BT in [False, True] for N in Ns]
|
configs = [(AT, BT, N, N, N) for AT in [False, True] for BT in [False, True] for N in Ns]
|
||||||
for AT, BT, M, N, K in configs:
|
for AT, BT, M, N, K in configs:
|
||||||
a = th.randn((K, M) if AT else (M, K), device='cuda', dtype=dtype) / K**.5
|
a = th.randn((K, M) if AT else (M, K), device='cuda', dtype=dtype) / K**.5
|
||||||
@@ -100,8 +101,25 @@ def perf_op(dtype=th.float16, warmup=10, rep=50):
|
|||||||
num_flops = 2*M*N*K
|
num_flops = 2*M*N*K
|
||||||
torch_tflops = num_flops / torch_ms * 1e-9
|
torch_tflops = num_flops / torch_ms * 1e-9
|
||||||
triton_tflops = num_flops / triton_ms * 1e-9
|
triton_tflops = num_flops / triton_ms * 1e-9
|
||||||
#print(min(alpha*bandwidth*1e-12, max_tflops), triton_tflops)
|
if 'CUTLASS_PROFILER' in os.environ:
|
||||||
#./tools/profiler/cutlass_profiler --m=8192 --n=8192 --k=8192 --A=f16:column --B=f16:column --C=f16:column --accum=f32 --operation=gemm
|
import subprocess
|
||||||
df = df.append({'AT': AT, 'BT': BT, 'N': N, 'TRITON': triton_tflops, 'TORCH': torch_tflops}, ignore_index=True)
|
# run program specified by CUTLASS_PROFILER env variable
|
||||||
|
layout_a = 'column' if AT else 'row'
|
||||||
|
layout_b = 'column' if BT else 'row'
|
||||||
|
# create temporary file name
|
||||||
|
import tempfile
|
||||||
|
fd, fname = tempfile.mkstemp()
|
||||||
|
# run program and gets its output
|
||||||
|
cmd = [os.environ['CUTLASS_PROFILER'], f'--m={M}', f'--n={N}', f'--k={K}', f'--A=f16:{layout_a}', f'--B=f16:{layout_b}', \
|
||||||
|
'--C=f16:column', '--accum=f32', '--operation=gemm', '--verification-enabled=false', '--warmup-iterations=10', \
|
||||||
|
'--profiling-iterations=50', f'--output={fname}', '--verbose=false']
|
||||||
|
# run cmd
|
||||||
|
subprocess.run(cmd, stdout=subprocess.PIPE)
|
||||||
|
# read CSV output
|
||||||
|
df_c = pd.read_csv(f'{fname}.gemm.csv')
|
||||||
|
cutlass_tflops = max(df_c['GFLOPs'])/1e3
|
||||||
|
else:
|
||||||
|
cutlass_tflops = None
|
||||||
|
df = df.append({'AT': AT, 'BT': BT, 'N': N, 'TRITON': triton_tflops, 'TORCH': torch_tflops, 'CUTLASS': cutlass_tflops}, ignore_index=True)
|
||||||
pd.options.display.float_format = lambda x: '{:.2f}'.format(x)
|
pd.options.display.float_format = lambda x: '{:.2f}'.format(x)
|
||||||
print(df)
|
print(df)
|
Reference in New Issue
Block a user