[CODEGEN] Fixed bug that caused conditional operator to not always properly mask load operations

Also includes minor improvement to benchmarking infrastructure
Philippe Tillet
2021-03-07 14:53:48 -05:00
parent d1d09566b1
commit 5b9afaa688
9 changed files with 146 additions and 64 deletions
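
The codegen fix concerns Triton-C's conditional operator: an expression such as check ? *px : 0 is supposed to compile to a predicated load, so that lanes where check is false never touch memory. Previously the mask was not always applied, and kernels had to spell the guard out with the explicit masked-dereference syntax *?(check)px (see the cross-entropy diff below). As a point of reference, here is a minimal sketch of the same guarded-load pattern in Triton's present-day Python frontend; this API postdates the commit, and the kernel below is illustrative only:

import torch
import triton
import triton.language as tl

@triton.jit
def guarded_copy(src, dst, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    check = offs < n
    # masked load: lanes with check == False read `other` (0.0) instead of
    # touching memory -- the semantics the codegen fix guarantees for
    # `check ? *ptr : 0` in Triton-C
    x = tl.load(src + offs, mask=check, other=0.0)
    tl.store(dst + offs, x, mask=check)

# usage sketch
n = 1000
src = torch.randn(n, device="cuda")
dst = torch.empty_like(src)
guarded_copy[(triton.cdiv(n, 1024),)](src, dst, n, BLOCK=1024)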

View File

@@ -2,33 +2,39 @@ import triton
 import torch
 import os
 def rounded_linspace(low, high, steps, div):
     ret = torch.linspace(low, high, steps)
     ret = (ret.int() + div - 1) // div * div
     ret = torch.unique(ret)
     return list(map(int, ret))
 # Square benchmarks
 nt = {False: "n", True: "t"}
 square_confs = [
     triton.testing.Benchmark(
         x_names=["M", "N", "K"],
-        x_vals=rounded_linspace(512, 8192, 17, 128),
+        x_vals=rounded_linspace(512, 8192, 32, 128),
         y_name="provider",
         y_vals=["torch", "triton", "cutlass"],
         y_lines=["Torch", "Triton", "CUTLASS"],
         ylabel="TFLOPS",
         loglog=False,
         plot_name=f"matmul-square-{nt[AT]}{nt[BT]}",
-        args={"AT": AT, "BT": BT, "dtype": torch.float16},
-    ) for AT in [False, True] for BT in [False, True]
+        args={
+            "AT": AT,
+            "BT": BT,
+            "dtype": torch.float16
+        },
+    ) for AT in [False] for BT in [False]
 ]
 # Transformer training benchmarks
 transformer_confs = [
     triton.testing.Benchmark(
         x_names=[x],
-        x_vals = rounded_linspace(NK//16, NK, 33, 128),
+        x_vals = rounded_linspace(NK//16, NK, 32, 128),
         y_name="provider",
         y_vals=["torch", "triton", "cutlass"],
         y_lines=["Torch", "Triton", "CUTLASS"],
@@ -41,21 +47,21 @@ transformer_confs = [
     for M in [2048]
 ]
-@triton.testing.perf_report(square_confs)
-def bench_op(M, N, K, AT, BT, dtype, provider, warmup=5, rep=40):
+@triton.testing.perf_report(transformer_confs)
+def bench_op(M, N, K, AT, BT, dtype, provider, warmup=10, rep=50):
     a = torch.rand((K, M) if AT else (M, K), device="cuda", dtype=dtype)
     b = torch.rand((N, K) if BT else (K, N), device="cuda", dtype=dtype)
     if AT: a = a.t()
     if BT: b = b.t()
-    num_flops = 2 * M * N * K
+    tflops = lambda ms: 2. * M * N * K / ms * 1e-9
     if provider == "torch":
-        torch_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), warmup=warmup, rep=rep)
-        torch_tflops = num_flops / torch_ms * 1e-9
-        return torch_tflops
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), warmup=warmup, rep=rep)
+        return tflops(ms), tflops(max_ms), tflops(min_ms)
     if provider == "triton":
-        triton_ms = triton.testing.do_bench(lambda: triton.ops.matmul(a, b), warmup=warmup, rep=rep)
-        triton_tflops = num_flops / triton_ms * 1e-9
-        return triton_tflops
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: triton.ops.matmul(a, b), warmup=warmup, rep=rep)
+        return tflops(ms), tflops(max_ms), tflops(min_ms)
     if provider == "cutlass" and "CUTLASS_PROFILER" in os.environ:
         import subprocess
         import tempfile
@@ -87,6 +93,6 @@ def bench_op(M, N, K, AT, BT, dtype, provider, warmup=5, rep=40):
         subprocess.run(cmd, stdout=subprocess.PIPE)
         # read CSV output
         df_c = pd.read_csv(f"{fname}.gemm.csv")
-        cutlass_tflops = max(df_c["GFLOPs"]) / 1e3
-        return cutlass_tflops
+        tflops = max(df_c["GFLOPs"]) / 1e3
+        return tflops
     return None
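
For reference, the decorated bench_op is a Mark object (see the triton.testing diff below), so a driver script might look like the sketch here; the ./results path is an assumption, not part of this diff:

import os

# hypothetical driver for the benchmark above; Mark.run sweeps every
# x value and provider, then writes one CSV and PNG per configuration
os.makedirs("./results", exist_ok=True)
bench_op.run(result_path="./results", with_plot=True)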

View File

@@ -26,7 +26,7 @@ __global__ void backward(TYPE *neg_logprobs, long *indices, TYPE *dneg_logprobs,
     TYPE local_dn = *(dneg_logprobs + row);
     // We know d(-log(p[i])/dlogit[k] = -id_mat[i,k] + p[k]
     // and we have -log(p[k]) stored, so this is easy
-    TYPE intermediate[TILE] = check ? exp(-(float[TILE]) *?(check)px) : 0;
+    TYPE intermediate[TILE] = check ? exp(-(float[TILE]) * px) : 0;
     // selected_logit_idx is selected logit index for our token
     bool find_one[TILE] = ((0 ... TILE) == local_ind);
     intermediate = intermediate - ((TYPE[TILE])find_one);
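
This hunk is the payoff of the codegen fix: the explicit masked dereference *?(check)px is no longer needed, because the conditional check ? ... : 0 now masks the load of px by itself. The tile implements the identity in the comment: with p[k] = exp(-neg_logprobs[k]), the gradient is d(-log p[i])/dlogit[k] = p[k] - 1[k == i], scaled per row by the incoming gradient. A plain PyTorch reference of the same math, useful for checking the kernel (the helper below is hypothetical, not part of this commit):

import torch

def cross_entropy_backward_ref(neg_logprobs, indices, dneg_logprobs):
    # recover p[k] from the stored negative log-probabilities
    p = torch.exp(-neg_logprobs)
    # one-hot indicator of the selected logit in each row
    one_hot = torch.zeros_like(p)
    one_hot.scatter_(1, indices[:, None], 1.0)
    # d(-log p[i])/dlogit[k] = p[k] - 1[k == i], scaled by the row gradient
    return (p - one_hot) * dneg_logprobs[:, None]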

View File

@@ -26,35 +26,34 @@ def allclose(x, y):
     return err < tol
-def do_bench(fn, warmup=10, rep=50, grad_to_none=None, clear_l2=False):
-    start_event = torch.cuda.Event(enable_timing=True)
-    end_event = torch.cuda.Event(enable_timing=True)
-    # warmup to put the clock in a stable regime
-    ret = fn()
-    for i in range(warmup):
-        fn()
-    torch.cuda.synchronize()
-    total_ms = 0
+def do_bench(fn, warmup=10, rep=50, grad_to_none=None):
+    # We maintain a buffer of 256 MB that we clear
+    # before each kernel call to make sure that the L2
+    # doesn't contain any input data before the run
+    start_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
+    end_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
+    cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda')
 
-    for i in range(rep):
+    for i in range(warmup + rep):
         # we don't want `fn` to accumulate gradient values
         # if it contains a backward pass. So we clear the
         # provided gradients
         if grad_to_none is not None:
             grad_to_none.grad = None
-        # reset L2
+        # we clear the L2 cache before each run
         cache.zero_()
         # record time of `fn`
-        start_event.record()
+        if i >= warmup:
+            start_event[i - warmup].record()
         fn()
-        end_event.record()
-        torch.cuda.synchronize()
-        total_ms += start_event.elapsed_time(end_event)
-    # return the average runtime of `fn`
-    return total_ms / rep
+        if i >= warmup:
+            end_event[i - warmup].record()
+    torch.cuda.synchronize()
+    times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)])
+    q = torch.quantile(times, torch.tensor([0.1, 0.5, 0.9]))
+    min_ms = q[0].item()
+    mean_ms = q[1].item()
+    max_ms = q[2].item()
+    return mean_ms, min_ms, max_ms
 
 class Benchmark:
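
With this change do_bench times each of the rep measured iterations with its own CUDA event pair and returns the 0.5, 0.1 and 0.9 quantiles, so the first value is really a median rather than a mean. A usage sketch (tensor shapes are illustrative):

import torch
import triton

a = torch.randn((1024, 1024), device="cuda", dtype=torch.float16)
b = torch.randn((1024, 1024), device="cuda", dtype=torch.float16)
# returns (p50, p10, p90) latencies in milliseconds
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), warmup=10, rep=50)
print(f"p50 {ms:.3f} ms (p10 {min_ms:.3f} / p90 {max_ms:.3f})")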
@@ -79,22 +78,42 @@ class Mark:
         import matplotlib.pyplot as plt
         import pandas as pd
         import os
-        df = pd.DataFrame(columns=[bench.x_names[0]] + bench.y_lines)
+        y_mean = bench.y_lines
+        y_min = [f'{x}-min' for x in bench.y_lines]
+        y_max = [f'{x}-max' for x in bench.y_lines]
+        df = pd.DataFrame(columns=[bench.x_names[0]] + y_mean + y_min + y_max)
         for x in bench.x_vals:
             x_args = {x_name: x for x_name in bench.x_names}
-            row = [self.fn(**x_args, **{bench.y_name: y}, **bench.args) for y in bench.y_vals]
-            df.loc[len(df)] = [x] + row
+            row_mean, row_min, row_max = [], [], []
+            for y in bench.y_vals:
+                ret = self.fn(**x_args, **{bench.y_name: y}, **bench.args)
+                try:
+                    y_mean, y_min, y_max = ret
+                except TypeError:
+                    y_mean, y_min, y_max = ret, None, None
+                row_mean += [y_mean]
+                row_min += [y_min]
+                row_max += [y_max]
+            df.loc[len(df)] = [x] + row_mean + row_min + row_max
         if with_plot and bench.plot_name:
             plt.figure()
+            ax = plt.subplot()
             xlabel = " = ".join(bench.x_names)
-            plot = df.plot(x=bench.x_names[0], y=bench.y_lines)
-            plot.set_xlabel(xlabel)
-            plot.set_ylabel(bench.ylabel)
-            plot.set_title(bench.plot_name)
-            plot.set_xscale("log" if bench.loglog else "linear")
-            plot.set_yscale("log" if bench.loglog else "linear")
+            x = bench.x_names[0]
+            for y in bench.y_lines:
+                y_min, y_max = df[y + '-min'], df[y + '-max']
+                ax.plot(df[x], df[y], label=y)
+                if y_min is not None and y_max is not None:
+                    ax.fill_between(df[x], y_min, y_max, alpha=0.5)
+            ax.legend()
+            ax.set_xlabel(xlabel)
+            ax.set_ylabel(bench.ylabel)
+            ax.set_title(bench.plot_name)
+            ax.set_xscale("log" if bench.loglog else "linear")
+            ax.set_yscale("log" if bench.loglog else "linear")
             plt.savefig(os.path.join(result_path, f"{bench.plot_name}.png"))
-        df.to_csv(os.path.join(result_path, f"{bench.plot_name}.csv"))
+        df = df[[bench.x_names[0]] + bench.y_lines]
+        df.to_csv(os.path.join(result_path, f"{bench.plot_name}.csv"), float_format='%.1f', index=False)
 
     def run(self, result_path, with_plot):
         with open(os.path.join(result_path, "results.html"), "w") as html:
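
The fill_between call above is what turns the new -min/-max columns into shaded uncertainty bands around each provider's curve. A self-contained sketch with made-up numbers, mirroring the column layout that _run builds:

import matplotlib.pyplot as plt
import pandas as pd

# fabricated example data in the same layout as _run's DataFrame
df = pd.DataFrame({"M": [512, 1024, 2048, 4096],
                   "Triton": [40., 55., 62., 64.],
                   "Triton-min": [38., 52., 60., 62.],
                   "Triton-max": [42., 57., 64., 66.]})
ax = plt.subplot()
ax.plot(df["M"], df["Triton"], label="Triton")
# shade the p10-p90 band around the median curve
ax.fill_between(df["M"], df["Triton-min"], df["Triton-max"], alpha=0.5)
ax.legend()
ax.set_xlabel("M = N = K")
ax.set_ylabel("TFLOPS")
plt.savefig("band.png")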