diff --git a/python/triton/testing.py b/python/triton/testing.py index 510375b08..c31ebe4dd 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -1,18 +1,21 @@ import torch import os + def sparsify_tensor(x, mask, block): ret = torch.empty((x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device) for idx, (h, i, j) in enumerate(zip(*mask.nonzero(as_tuple=True))): ret[:, idx, :, :] = x[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block] return ret + def mask_tensor(x, mask, block, value=0): ret = x.clone() for h, i, j in zip(*(mask == 0).nonzero(as_tuple=True)): ret[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block] = value return ret + def allclose(x, y): assert x.dtype == y.dtype diff = abs(x - y) @@ -22,22 +25,37 @@ def allclose(x, y): err = torch.max(diff) / torch.max(x_max, y_max) return err < tol -def do_bench(fn, flops=0, warmup=10, rep=50, grad_to_none=None): + +def do_bench(fn, warmup=10, rep=50, grad_to_none=None, clear_l2=False): start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) + # warmup to put the clock in a stable regime ret = fn() for i in range(warmup): fn() torch.cuda.synchronize() - start_event.record() + total_ms = 0 + # We maintain a buffer of 256 MB that we clear + # before each kernel call to make sure that the L2 + # doesn't contain any input data before the run + cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda') for i in range(rep): + # we don't want `fn` to accumulate gradient values + # if it contains a backward pass. So we clear the + # provided gradients if grad_to_none is not None: grad_to_none.grad = None + # reset L2 + cache.zero_() + # record time of `fn` + start_event.record() fn() - end_event.record() - torch.cuda.synchronize() - time_ms = start_event.elapsed_time(end_event) / rep - return time_ms + end_event.record() + torch.cuda.synchronize() + total_ms += start_event.elapsed_time(end_event) + # return the average runtime of `fn` + return total_ms / rep + class Benchmark: def __init__(self, x_names, x_vals, y_name, y_vals, y_lines, ylabel, loglog, plot_name, args): @@ -51,6 +69,7 @@ class Benchmark: self.plot_name = plot_name self.args = args + class Mark: def __init__(self, fn, benchmarks): self.fn = fn @@ -85,6 +104,7 @@ class Mark: html.write(f"\n") html.write("\n") + def perf_report(benchmarks): wrapper = lambda fn: Mark(fn, benchmarks) return wrapper