2021-02-08 12:16:41 -08:00
|
|
|
import torch
|
2021-03-04 01:51:11 -05:00
|
|
|
import os
|
2021-02-08 12:16:41 -08:00
|
|
|
|
2021-03-06 22:02:18 -05:00
|
|
|
|
2021-02-08 12:16:41 -08:00
|
|
|
def sparsify_tensor(x, mask, block):
    """Gather the blocks of ``x`` selected by ``mask`` into a packed tensor.

    ``mask`` is indexed as ``(h, i, j)``; every non-zero entry selects the
    ``block x block`` tile ``x[:, h, i*block:(i+1)*block, j*block:(j+1)*block]``.
    The selected tiles are stacked along dimension 1 in the order returned by
    ``mask.nonzero()``.

    Returns a tensor of shape ``(x.size(0), mask.sum(), block, block)`` with
    the dtype and device of ``x``.
    """
    heads, rows, cols = mask.nonzero(as_tuple=True)
    packed = torch.empty(
        (x.size(0), mask.sum(), block, block),
        dtype=x.dtype,
        device=x.device,
    )
    # Plain Python ints keep the slice arithmetic off the tensor path.
    coords = zip(heads.tolist(), rows.tolist(), cols.tolist())
    for slot, (h, r, c) in enumerate(coords):
        row_lo, col_lo = r * block, c * block
        packed[:, slot, :, :] = x[:, h, row_lo:row_lo + block, col_lo:col_lo + block]
    return packed
|
|
|
|
|
2021-03-06 22:02:18 -05:00
|
|
|
|
2021-02-08 12:16:41 -08:00
|
|
|
def mask_tensor(x, mask, block, value=0):
    """Return a copy of ``x`` with every masked-out block filled with ``value``.

    ``mask`` is indexed as ``(h, i, j)``; wherever it equals zero, the tile
    ``[:, h, i*block:(i+1)*block, j*block:(j+1)*block]`` of the copy is
    overwritten with ``value``. ``x`` itself is left untouched.
    """
    masked = x.clone()
    heads, rows, cols = (mask == 0).nonzero(as_tuple=True)
    for h, r, c in zip(heads.tolist(), rows.tolist(), cols.tolist()):
        row_lo, col_lo = r * block, c * block
        masked[:, h, row_lo:row_lo + block, col_lo:col_lo + block] = value
    return masked
|
|
|
|
|
2021-03-06 22:02:18 -05:00
|
|
|
|
2021-02-08 12:16:41 -08:00
|
|
|
def allclose(x, y, tol=1e-2):
    """Check that ``x`` and ``y`` agree up to a relative tolerance.

    The largest element-wise absolute difference is divided by the largest
    absolute value found in either tensor, and the result is compared
    against ``tol``.

    Normalising by the magnitude (rather than the signed maximum) keeps the
    error non-negative, so tensors whose entries are all negative are no
    longer reported as close regardless of their difference.

    Args:
        x: first tensor.
        y: second tensor; must have the same dtype as ``x``.
        tol: relative tolerance; the comparison is strict (``err < tol``).

    Returns:
        A 0-dim boolean tensor, true when the relative error is below ``tol``.

    Raises:
        AssertionError: if the dtypes of ``x`` and ``y`` differ.
    """
    assert x.dtype == y.dtype
    max_diff = torch.max(torch.abs(x - y))
    scale = torch.max(torch.max(torch.abs(x)), torch.max(torch.abs(y)))
    if scale == 0:
        # Both tensors are entirely zero: avoid 0 / 0 = NaN, which would
        # otherwise make two identical zero tensors compare as different.
        return max_diff == 0
    return max_diff / scale < tol
|
|
|
|
|
2021-03-06 22:02:18 -05:00
|
|
|
|
|
|
|
def do_bench(fn, warmup=10, rep=50, grad_to_none=None, clear_l2=False):
    """Return the average runtime of ``fn`` in milliseconds, timed with CUDA events.

    ``fn`` is first called ``warmup + 1`` times to put the clock in a stable
    regime, then timed ``rep`` times. Before every timed call a 256 MB device
    buffer is zeroed so that the L2 cache does not contain any input data
    from the previous run.

    Args:
        fn: zero-argument callable to benchmark.
        warmup: number of untimed calls made after the initial call.
        rep: number of timed calls; must be at least 1.
        grad_to_none: a tensor, or a list/tuple of tensors, whose ``.grad``
            is reset to ``None`` before every timed call so that a backward
            pass inside ``fn`` does not accumulate gradient values.
        clear_l2: accepted for interface compatibility; this implementation
            does not read it and always clears the cache buffer.

    Returns:
        The mean elapsed time per call, in milliseconds.

    Raises:
        ValueError: if ``rep`` is smaller than 1.
    """
    if rep < 1:
        # Fail before allocating device memory instead of dividing by zero
        # at the very end of the run.
        raise ValueError(f"rep must be a positive integer, got {rep!r}")

    if grad_to_none is None:
        grad_tensors = ()
    elif isinstance(grad_to_none, (list, tuple)):
        grad_tensors = tuple(grad_to_none)
    else:
        grad_tensors = (grad_to_none,)

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    # warmup to put the clock in a stable regime; the results are discarded
    # so that no output of `fn` is kept alive during the measurements
    fn()
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()

    total_ms = 0
    # We maintain a buffer of 256 MB that we clear
    # before each kernel call to make sure that the L2
    # doesn't contain any input data before the run
    cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda')
    for _ in range(rep):
        # we don't want `fn` to accumulate gradient values
        # if it contains a backward pass. So we clear the
        # provided gradients
        for tensor in grad_tensors:
            tensor.grad = None
        # reset L2
        cache.zero_()
        # record time of `fn`
        start_event.record()
        fn()
        end_event.record()
        # wait for both events to complete before reading the elapsed time
        torch.cuda.synchronize()
        total_ms += start_event.elapsed_time(end_event)
    # return the average runtime of `fn`
    return total_ms / rep
|
|
|
|
|
2021-02-08 12:16:41 -08:00
|
|
|
|
|
|
|
class Benchmark:
    """Plain container describing one benchmark sweep and how to plot it.

    The constructor stores every argument unchanged on the instance under
    the same name; no validation or conversion is performed.
    """

    def __init__(self, x_names, x_vals, y_name, y_vals, y_lines, ylabel, loglog, plot_name, args):
        # swept arguments and the values they take
        self.x_names, self.x_vals = x_names, x_vals
        # argument that selects a line, its values, and the line labels
        self.y_name, self.y_vals, self.y_lines = y_name, y_vals, y_lines
        # plot cosmetics
        self.ylabel, self.loglog, self.plot_name = ylabel, loglog, plot_name
        # extra keyword arguments stored alongside the sweep description
        self.args = args
|
|
|
|
|
2021-03-06 22:02:18 -05:00
|
|
|
|
2021-02-08 12:16:41 -08:00
|
|
|
class Mark:
    """Pair a benchmarked function with the benchmark descriptions to run on it."""

    def __init__(self, fn, benchmarks):
        """Store the function and the iterable of benchmark descriptions."""
        self.fn = fn
        self.benchmarks = benchmarks

    def _run(self, bench, result_path, with_plot):
        """Run one benchmark, write its CSV and, if requested, its PNG plot.

        For every ``x`` in ``bench.x_vals`` and every ``y`` in
        ``bench.y_vals``, ``self.fn`` is called with each name in
        ``bench.x_names`` bound to ``x``, ``bench.y_name`` bound to ``y``,
        plus ``bench.args``. Each ``x`` yields one row of the results table.

        Args:
            bench: benchmark description (see the attributes read below).
            result_path: existing directory receiving the output files.
            with_plot: when true and ``bench.plot_name`` is set, also save
                ``<plot_name>.png``.
        """
        import os

        import pandas as pd

        df = pd.DataFrame(columns=[bench.x_names[0]] + bench.y_lines)
        for x in bench.x_vals:
            x_args = {x_name: x for x_name in bench.x_names}
            row = [self.fn(**x_args, **{bench.y_name: y}, **bench.args) for y in bench.y_vals]
            df.loc[len(df)] = [x] + row

        if with_plot and bench.plot_name:
            # Imported lazily so that CSV-only runs do not need matplotlib.
            import matplotlib.pyplot as plt

            axes = df.plot(x=bench.x_names[0], y=bench.y_lines)
            axes.set_xlabel(" = ".join(bench.x_names))
            axes.set_ylabel(bench.ylabel)
            axes.set_title(bench.plot_name)
            scale = "log" if bench.loglog else "linear"
            axes.set_xscale(scale)
            axes.set_yscale(scale)
            # Save through the figure that owns these axes rather than the
            # implicit "current figure", then close it so figures do not
            # pile up when many benchmarks are run.
            figure = axes.get_figure()
            figure.savefig(os.path.join(result_path, f"{bench.plot_name}.png"))
            plt.close(figure)

        df.to_csv(os.path.join(result_path, f"{bench.plot_name}.csv"))

    def run(self, result_path, with_plot):
        """Run every benchmark and write ``results.html`` referencing the plots.

        The HTML file contains one image tag per benchmark pointing at
        ``<plot_name>.png``; the tag is written whether or not the plot was
        generated.
        """
        with open(os.path.join(result_path, "results.html"), "w") as html:
            html.write("<html><body>\n")
            for bench in self.benchmarks:
                self._run(bench, result_path, with_plot)
                html.write(f"<image src=\"{bench.plot_name}.png\"/>\n")
            html.write("</body></html>\n")
|
2021-02-08 12:16:41 -08:00
|
|
|
|
2021-03-06 22:02:18 -05:00
|
|
|
|
2021-02-08 12:16:41 -08:00
|
|
|
def perf_report(benchmarks):
    """Build a decorator that wraps a function in a ``Mark`` for ``benchmarks``.

    The returned callable takes the function to benchmark and returns
    ``Mark(fn, benchmarks)``.
    """
    def wrapper(fn):
        return Mark(fn, benchmarks)

    return wrapper
|