[RUNTIME] Auto-tuning now works as expected when the values of
autotune_key change
This commit is contained in:
@@ -78,21 +78,30 @@ def do_bench(fn, flops = 0, warmup = 10, rep = 50):
|
||||
end_event.record()
|
||||
th.cuda.synchronize()
|
||||
time_ms = start_event.elapsed_time(end_event) / rep
|
||||
return time_ms, flops/time_ms*1e-9, ret
|
||||
return time_ms
|
||||
|
||||
|
||||
def perf_op(dtype=th.float16, warmup=10, rep=50):
|
||||
AT, BT = False, False
|
||||
import pandas as pd
|
||||
AT, BT = False, False
|
||||
df = pd.DataFrame(columns=['AT', 'BT', 'N', 'TRITON', 'TORCH'])
|
||||
Ns = [128, 256, 512, 1024, 1536, 2048, 3072, 4096, 6144, 8192]
|
||||
# Ns = [128, 256, 512, 1024, 2048, 3072, 4096, 6144, 8192]
|
||||
Ns = [8192]
|
||||
configs = [(AT, BT, N, N, N) for AT in [False, True] for BT in [False, True] for N in Ns]
|
||||
for AT, BT, M, N, K in configs:
|
||||
a = th.randn((K, M) if AT else (M, K), device='cuda', dtype=dtype) / K**.5
|
||||
b = th.randn((N, K) if BT else (K, N), device='cuda', dtype=dtype) / K**.5
|
||||
if AT: a = a.t()
|
||||
if BT: b = b.t()
|
||||
TH_MS, TH_TFLOPS, _ = do_bench(lambda: th.matmul(a, b), flops = M*N*K*2, warmup = warmup, rep = rep)
|
||||
TT_MS, TT_TFLOPS, _ = do_bench(lambda: tt.ops.matmul(a, b), flops = M*N*K*2, warmup = warmup, rep = rep)
|
||||
df = df.append({'AT': AT, 'BT': BT, 'N': N, 'TRITON': TT_TFLOPS, 'TORCH': TH_TFLOPS}, ignore_index=True)
|
||||
# benchmarks
|
||||
torch_ms = do_bench(lambda: th.matmul(a, b), warmup = warmup, rep = rep)
|
||||
triton_ms = do_bench(lambda: tt.ops.matmul(a, b), warmup = warmup, rep = rep)
|
||||
# store result
|
||||
num_flops = 2*M*N*K
|
||||
torch_tflops = num_flops / torch_ms * 1e-9
|
||||
triton_tflops = num_flops / triton_ms * 1e-9
|
||||
#print(min(alpha*bandwidth*1e-12, max_tflops), triton_tflops)
|
||||
#./tools/profiler/cutlass_profiler --m=8192 --n=8192 --k=8192 --A=f16:column --B=f16:column --C=f16:column --accum=f32 --operation=gemm
|
||||
df = df.append({'AT': AT, 'BT': BT, 'N': N, 'TRITON': triton_tflops, 'TORCH': torch_tflops}, ignore_index=True)
|
||||
pd.options.display.float_format = lambda x: '{:.2f}'.format(x)
|
||||
print(df)
|
Reference in New Issue
Block a user