[CODEGEN] Major performance improvements on A100 (#70)
Improved handling of asynchronous copy, scheduling and synchronization for A100. Now achieving CUTLASS-like performance on large square dense matrix multiplication tasks
This commit is contained in:
committed by
Philippe Tillet
parent
045ab5d62a
commit
5b83259592
@@ -2,58 +2,74 @@ import triton
|
||||
import torch
|
||||
|
||||
# square benchmarks
|
||||
nt = {False: 'n', True: 't'}
|
||||
nt = {False: "n", True: "t"}
|
||||
square_confs = [
|
||||
triton.testing.Benchmark(
|
||||
x_names = ['M', 'N', 'K'],
|
||||
x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144],
|
||||
y_name = 'provider',
|
||||
y_vals = ['torch', 'triton', 'cutlass'],
|
||||
y_lines = ['Torch', 'Triton', 'CUTLASS'],
|
||||
ylabel = 'TFLOPS',
|
||||
loglog = False,
|
||||
plot_name = f'matmul-square-{nt[AT]}{nt[BT]}',
|
||||
args = {'AT': False, 'BT': False, 'dtype': torch.float16}
|
||||
)\
|
||||
for AT in [False, True] for BT in [False, True]
|
||||
x_names=["M", "N", "K"],
|
||||
x_vals=[512 * i for i in range(1, 16)],
|
||||
y_name="provider",
|
||||
y_vals=["torch", "triton", "cutlass"],
|
||||
y_lines=["Torch", "Triton", "CUTLASS"],
|
||||
ylabel="TFLOPS",
|
||||
loglog=False,
|
||||
plot_name=f"matmul-square-{nt[AT]}{nt[BT]}",
|
||||
args={"AT": AT, "BT": BT, "dtype": torch.float16},
|
||||
) for AT in [False, True] for BT in [False, True]
|
||||
]
|
||||
|
||||
@triton.testing.perf_report(square_confs)
|
||||
def bench_op(M, N, K, AT, BT, dtype, provider, warmup=5, rep=5):
|
||||
def bench_op(M, N, K, AT, BT, dtype, provider, warmup=5, rep=20):
|
||||
import os
|
||||
a = torch.randn((K, M) if AT else (M, K), device='cuda', dtype=dtype) / K**.5
|
||||
b = torch.randn((N, K) if BT else (K, N), device='cuda', dtype=dtype) / K**.5
|
||||
if AT: a = a.t()
|
||||
if BT: b = b.t()
|
||||
|
||||
a = torch.rand((K, M) if AT else (M, K), device="cuda", dtype=dtype)
|
||||
b = torch.rand((N, K) if BT else (K, N), device="cuda", dtype=dtype)
|
||||
if AT:
|
||||
a = a.t()
|
||||
if BT:
|
||||
b = b.t()
|
||||
num_flops = 2 * M * N * K
|
||||
if provider == 'torch':
|
||||
if provider == "torch":
|
||||
torch_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), warmup=warmup, rep=rep)
|
||||
torch_tflops = num_flops / torch_ms * 1e-9
|
||||
return torch_tflops
|
||||
if provider == 'triton':
|
||||
if provider == "triton":
|
||||
triton_ms = triton.testing.do_bench(lambda: triton.ops.matmul(a, b), warmup=warmup, rep=rep)
|
||||
triton_tflops = num_flops / triton_ms * 1e-9
|
||||
return triton_tflops
|
||||
if provider == 'cutlass' and 'CUTLASS_PROFILER' in os.environ:
|
||||
if provider == "cutlass" and "CUTLASS_PROFILER" in os.environ:
|
||||
import subprocess
|
||||
import tempfile
|
||||
import pandas as pd
|
||||
|
||||
# run program specified by CUTLASS_PROFILER env variable
|
||||
layout_a = 'column' if AT else 'row'
|
||||
layout_b = 'column' if BT else 'row'
|
||||
layout_a = "column" if AT else "row"
|
||||
layout_b = "column" if BT else "row"
|
||||
# create temporary file name
|
||||
fd, fname = tempfile.mkstemp()
|
||||
# run program and gets its output
|
||||
cmd = [os.environ['CUTLASS_PROFILER'], f'--m={M}', f'--n={N}', f'--k={K}', f'--A=f16:{layout_a}', f'--B=f16:{layout_b}', \
|
||||
'--C=f16:column', '--accum=f32', '--operation=gemm', '--verification-enabled=false', f'--warmup-iterations={warmup}', \
|
||||
f'--profiling-iterations={rep}', f'--output={fname}', '--verbose=false']
|
||||
cmd = [
|
||||
os.environ["CUTLASS_PROFILER"],
|
||||
f"--m={M}",
|
||||
f"--n={N}",
|
||||
f"--k={K}",
|
||||
f"--A=f16:{layout_a}",
|
||||
f"--B=f16:{layout_b}",
|
||||
"--C=f16:column",
|
||||
"--accum=f32",
|
||||
"--operation=gemm",
|
||||
"--verification-enabled=false",
|
||||
f"--warmup-iterations={warmup}",
|
||||
f"--profiling-iterations={rep}",
|
||||
f"--output={fname}",
|
||||
"--verbose=false",
|
||||
]
|
||||
# run cmd
|
||||
subprocess.run(cmd, stdout=subprocess.PIPE)
|
||||
# read CSV output
|
||||
df_c = pd.read_csv(f'{fname}.gemm.csv')
|
||||
cutlass_tflops = max(df_c['GFLOPs']) / 1e3
|
||||
df_c = pd.read_csv(f"{fname}.gemm.csv")
|
||||
cutlass_tflops = max(df_c["GFLOPs"]) / 1e3
|
||||
return cutlass_tflops
|
||||
return None
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
bench_op.run()
|
||||
|
Reference in New Issue
Block a user