[CODEGEN] Fixed bug that caused conditional operator to not always

properly mask load operations

Also includes minor improvement to benchmarking infrastructure
This commit is contained in:
Philippe Tillet
2021-03-07 14:53:48 -05:00
parent d1d09566b1
commit 5b9afaa688
9 changed files with 146 additions and 64 deletions

View File

@@ -2,33 +2,39 @@ import triton
import torch
import os
def rounded_linspace(low, high, steps, div):
ret = torch.linspace(low, high, steps)
ret = (ret.int() + div - 1) // div * div
ret = torch.unique(ret)
return list(map(int, ret))
# Square benchmarks
nt = {False: "n", True: "t"}
square_confs = [
triton.testing.Benchmark(
x_names=["M", "N", "K"],
x_vals=rounded_linspace(512, 8192, 17, 128),
x_vals=rounded_linspace(512, 8192, 32, 128),
y_name="provider",
y_vals=["torch", "triton", "cutlass"],
y_lines=["Torch", "Triton", "CUTLASS"],
ylabel="TFLOPS",
loglog=False,
plot_name=f"matmul-square-{nt[AT]}{nt[BT]}",
args={"AT": AT, "BT": BT, "dtype": torch.float16},
) for AT in [False, True] for BT in [False, True]
args={
"AT": AT,
"BT": BT,
"dtype": torch.float16
},
) for AT in [False] for BT in [False]
]
# Transformer training benchmarks
transformer_confs = [
triton.testing.Benchmark(
x_names=[x],
x_vals = rounded_linspace(NK//16, NK, 33, 128),
x_vals = rounded_linspace(NK//16, NK, 32, 128),
y_name="provider",
y_vals=["torch", "triton", "cutlass"],
y_lines=["Torch", "Triton", "CUTLASS"],
@@ -41,21 +47,21 @@ transformer_confs = [
for M in [2048]
]
@triton.testing.perf_report(square_confs)
def bench_op(M, N, K, AT, BT, dtype, provider, warmup=5, rep=40):
@triton.testing.perf_report(transformer_confs)
def bench_op(M, N, K, AT, BT, dtype, provider, warmup=10, rep=50):
a = torch.rand((K, M) if AT else (M, K), device="cuda", dtype=dtype)
b = torch.rand((N, K) if BT else (K, N), device="cuda", dtype=dtype)
if AT: a = a.t()
if BT: b = b.t()
num_flops = 2 * M * N * K
tflops = lambda ms: 2. * M * N * K / ms * 1e-9
if provider == "torch":
torch_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), warmup=warmup, rep=rep)
torch_tflops = num_flops / torch_ms * 1e-9
return torch_tflops
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), warmup=warmup, rep=rep)
return tflops(ms), tflops(max_ms), tflops(min_ms)
if provider == "triton":
triton_ms = triton.testing.do_bench(lambda: triton.ops.matmul(a, b), warmup=warmup, rep=rep)
triton_tflops = num_flops / triton_ms * 1e-9
return triton_tflops
ms, min_ms, max_ms = triton.testing.do_bench(lambda: triton.ops.matmul(a, b), warmup=warmup, rep=rep)
return tflops(ms), tflops(max_ms), tflops(min_ms)
if provider == "cutlass" and "CUTLASS_PROFILER" in os.environ:
import subprocess
import tempfile
@@ -87,6 +93,6 @@ def bench_op(M, N, K, AT, BT, dtype, provider, warmup=5, rep=40):
subprocess.run(cmd, stdout=subprocess.PIPE)
# read CSV output
df_c = pd.read_csv(f"{fname}.gemm.csv")
cutlass_tflops = max(df_c["GFLOPs"]) / 1e3
return cutlass_tflops
tflops = max(df_c["GFLOPs"]) / 1e3
return tflops
return None