triton/python/triton/testing.py
Philippe Tillet b8f2875d28 [PYTHON] Changed benchmarking strategy. Instead of
enqueueing many kernels before synchronizing, the kernels are now enqueued one
by one.

This makes it possible to clear the L2 cache before running the workload, and
also potentially collect some variance data for error bars in plots.

2021-03-06 22:02:18 -05:00
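
For contrast, here is a minimal sketch of the batched strategy this commit replaces; it is not part of the file below, and the name `do_bench_batched` is hypothetical. Because every repetition is enqueued before a single synchronize, only an aggregate time is available and there is no point between launches at which an L2-clearing memset could run.

import torch

def do_bench_batched(fn, rep=50):
    # enqueue every repetition back-to-back, then synchronize once
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(rep):
        fn()
    end.record()
    torch.cuda.synchronize()
    # only an average is recoverable: no per-repetition samples for error bars
    return start.elapsed_time(end) / rep

The `do_bench` in testing.py below instead records an event pair around each call, so the cache buffer can be zeroed between launches and each repetition yields its own timing sample.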


import torch
import os

def sparsify_tensor(x, mask, block):
    # Gather the `block x block` tiles of `x` selected by `mask` into a dense
    # tensor of shape (batch, nnz, block, block).
    ret = torch.empty((x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device)
    for idx, (h, i, j) in enumerate(zip(*mask.nonzero(as_tuple=True))):
        ret[:, idx, :, :] = x[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block]
    return ret

def mask_tensor(x, mask, block, value=0):
    # Set the `block x block` tiles of `x` that are masked out (mask == 0) to `value`.
    ret = x.clone()
    for h, i, j in zip(*(mask == 0).nonzero(as_tuple=True)):
        ret[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block] = value
    return ret

def allclose(x, y):
    # Relative error check: the largest absolute difference, normalized by the
    # larger of the two maxima, must stay below a 1e-2 tolerance.
    assert x.dtype == y.dtype
    diff = abs(x - y)
    x_max = torch.max(x)
    y_max = torch.max(y)
    tol = 1e-2
    err = torch.max(diff) / torch.max(x_max, y_max)
    return err < tol

def do_bench(fn, warmup=10, rep=50, grad_to_none=None, clear_l2=False):
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    # warmup to put the clock in a stable regime
    ret = fn()
    for i in range(warmup):
        fn()
    torch.cuda.synchronize()
    total_ms = 0
    # We maintain a buffer of 256 MB that we clear
    # before each kernel call to make sure that the L2
    # doesn't contain any input data before the run
    cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda')
    for i in range(rep):
        # we don't want `fn` to accumulate gradient values
        # if it contains a backward pass. So we clear the
        # provided gradients
        if grad_to_none is not None:
            grad_to_none.grad = None
        # reset L2
        cache.zero_()
        # record time of `fn`
        start_event.record()
        fn()
        end_event.record()
        torch.cuda.synchronize()
        total_ms += start_event.elapsed_time(end_event)
    # return the average runtime of `fn`
    return total_ms / rep

class Benchmark:
    # Describes one benchmark sweep: `x_names`/`x_vals` define the x-axis,
    # `y_name`/`y_vals`/`y_lines` define one curve per value of `y_name`,
    # and the remaining fields control the generated plot.
    def __init__(self, x_names, x_vals, y_name, y_vals, y_lines, ylabel, loglog, plot_name, args):
        self.x_names = x_names
        self.x_vals = x_vals
        self.y_name = y_name
        self.y_vals = y_vals
        self.y_lines = y_lines
        self.ylabel = ylabel
        self.loglog = loglog
        self.plot_name = plot_name
        self.args = args

class Mark:
    def __init__(self, fn, benchmarks):
        self.fn = fn
        self.benchmarks = benchmarks

    def _run(self, bench, result_path, with_plot):
        import matplotlib.pyplot as plt
        import pandas as pd
        import os
        # one row per x value, one column per line in `y_lines`
        df = pd.DataFrame(columns=[bench.x_names[0]] + bench.y_lines)
        for x in bench.x_vals:
            x_args = {x_name: x for x_name in bench.x_names}
            row = [self.fn(**x_args, **{bench.y_name: y}, **bench.args) for y in bench.y_vals]
            df.loc[len(df)] = [x] + row
        if with_plot and bench.plot_name:
            xlabel = " = ".join(bench.x_names)
            plot = df.plot(x=bench.x_names[0], y=bench.y_lines)
            plot.set_xlabel(xlabel)
            plot.set_ylabel(bench.ylabel)
            plot.set_title(bench.plot_name)
            plot.set_xscale("log" if bench.loglog else "linear")
            plot.set_yscale("log" if bench.loglog else "linear")
            plt.savefig(os.path.join(result_path, f"{bench.plot_name}.png"))
        df.to_csv(os.path.join(result_path, f"{bench.plot_name}.csv"))

    def run(self, result_path, with_plot):
        with open(os.path.join(result_path, "results.html"), "w") as html:
            html.write("<html><body>\n")
            for bench in self.benchmarks:
                self._run(bench, result_path, with_plot)
                html.write(f"<img src=\"{bench.plot_name}.png\"/>\n")
            html.write("</body></html>\n")

def perf_report(benchmarks):
    # decorator: wraps a benchmarking function in a `Mark` over `benchmarks`
    wrapper = lambda fn: Mark(fn, benchmarks)
    return wrapper
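
A hedged usage sketch of the machinery above (not part of this file; the function `bench_add`, the argument names `N` and `provider`, and the workload being timed are illustrative assumptions): `perf_report` wraps the decorated function in a `Mark`, and `run` then sweeps `x_vals`, calls the function once per x value and provider, and writes a CSV and PNG per benchmark plus a `results.html`.

import torch
import triton.testing

@triton.testing.perf_report([
    triton.testing.Benchmark(
        x_names=['N'],
        x_vals=[2**i for i in range(12, 20)],
        y_name='provider',
        y_vals=['torch', 'triton'],
        y_lines=['Torch', 'Triton'],
        ylabel='ms',
        loglog=True,
        plot_name='elementwise-add',
        args={},
    )
])
def bench_add(N, provider):
    x = torch.rand(N, device='cuda')
    y = torch.rand(N, device='cuda')
    # both providers time torch's addition here; a real benchmark would
    # dispatch to a Triton kernel when provider == 'triton'
    return triton.testing.do_bench(lambda: x + y)

bench_add.run(result_path='.', with_plot=True)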