From 85752037ebb82713d35ad618c277b638d1c1b084 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Sat, 6 Mar 2021 22:02:18 -0500 Subject: [PATCH] [PYTHON] Changed benchmarking strategy. Instead of enqueueing many kernels before synchronizing, the kernels are now enqueued one by one. This makes it possible to clear the L2 cache before running the workload, and also potentially collect some variance data for error bars in plots --- python/triton/testing.py | 32 ++++++++++++++++++++++++++------ 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/python/triton/testing.py b/python/triton/testing.py index 510375b08..c31ebe4dd 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -1,18 +1,21 @@ import torch import os + def sparsify_tensor(x, mask, block): ret = torch.empty((x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device) for idx, (h, i, j) in enumerate(zip(*mask.nonzero(as_tuple=True))): ret[:, idx, :, :] = x[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block] return ret + def mask_tensor(x, mask, block, value=0): ret = x.clone() for h, i, j in zip(*(mask == 0).nonzero(as_tuple=True)): ret[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block] = value return ret + def allclose(x, y): assert x.dtype == y.dtype diff = abs(x - y) @@ -22,22 +25,37 @@ def allclose(x, y): err = torch.max(diff) / torch.max(x_max, y_max) return err < tol -def do_bench(fn, flops=0, warmup=10, rep=50, grad_to_none=None): + +def do_bench(fn, warmup=10, rep=50, grad_to_none=None, clear_l2=False): start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) + # warmup to put the clock in a stable regime ret = fn() for i in range(warmup): fn() torch.cuda.synchronize() - start_event.record() + total_ms = 0 + # We maintain a buffer of 256 MB that we clear + # before each kernel call to make sure that the L2 + # doesn't contain any input data before the run + cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda') for i in range(rep): + # we don't want `fn` to accumulate gradient values + # if it contains a backward pass. So we clear the + # provided gradients if grad_to_none is not None: grad_to_none.grad = None + # reset L2 + cache.zero_() + # record time of `fn` + start_event.record() fn() - end_event.record() - torch.cuda.synchronize() - time_ms = start_event.elapsed_time(end_event) / rep - return time_ms + end_event.record() + torch.cuda.synchronize() + total_ms += start_event.elapsed_time(end_event) + # return the average runtime of `fn` + return total_ms / rep + class Benchmark: def __init__(self, x_names, x_vals, y_name, y_vals, y_lines, ylabel, loglog, plot_name, args): @@ -51,6 +69,7 @@ class Benchmark: self.plot_name = plot_name self.args = args + class Mark: def __init__(self, fn, benchmarks): self.fn = fn @@ -85,6 +104,7 @@ class Mark: html.write(f"\n") html.write("\n") + def perf_report(benchmarks): wrapper = lambda fn: Mark(fn, benchmarks) return wrapper