[PYTHON] Changed benchmarking strategy. Instead of enqueueing many
kernels before synchronizing, the kernels are now enqueued one by one. This makes it possible to clear the L2 cache before running the workload, and also potentially collect some variance data for error bars in plots
This commit is contained in:
@@ -1,18 +1,21 @@
|
|||||||
import torch
|
import torch
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
|
||||||
def sparsify_tensor(x, mask, block):
|
def sparsify_tensor(x, mask, block):
|
||||||
ret = torch.empty((x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device)
|
ret = torch.empty((x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device)
|
||||||
for idx, (h, i, j) in enumerate(zip(*mask.nonzero(as_tuple=True))):
|
for idx, (h, i, j) in enumerate(zip(*mask.nonzero(as_tuple=True))):
|
||||||
ret[:, idx, :, :] = x[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block]
|
ret[:, idx, :, :] = x[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block]
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
def mask_tensor(x, mask, block, value=0):
|
def mask_tensor(x, mask, block, value=0):
|
||||||
ret = x.clone()
|
ret = x.clone()
|
||||||
for h, i, j in zip(*(mask == 0).nonzero(as_tuple=True)):
|
for h, i, j in zip(*(mask == 0).nonzero(as_tuple=True)):
|
||||||
ret[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block] = value
|
ret[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block] = value
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
def allclose(x, y):
|
def allclose(x, y):
|
||||||
assert x.dtype == y.dtype
|
assert x.dtype == y.dtype
|
||||||
diff = abs(x - y)
|
diff = abs(x - y)
|
||||||
@@ -22,22 +25,37 @@ def allclose(x, y):
|
|||||||
err = torch.max(diff) / torch.max(x_max, y_max)
|
err = torch.max(diff) / torch.max(x_max, y_max)
|
||||||
return err < tol
|
return err < tol
|
||||||
|
|
||||||
def do_bench(fn, flops=0, warmup=10, rep=50, grad_to_none=None):
|
|
||||||
|
def do_bench(fn, warmup=10, rep=50, grad_to_none=None, clear_l2=False):
|
||||||
start_event = torch.cuda.Event(enable_timing=True)
|
start_event = torch.cuda.Event(enable_timing=True)
|
||||||
end_event = torch.cuda.Event(enable_timing=True)
|
end_event = torch.cuda.Event(enable_timing=True)
|
||||||
|
# warmup to put the clock in a stable regime
|
||||||
ret = fn()
|
ret = fn()
|
||||||
for i in range(warmup):
|
for i in range(warmup):
|
||||||
fn()
|
fn()
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
start_event.record()
|
total_ms = 0
|
||||||
|
# We maintain a buffer of 256 MB that we clear
|
||||||
|
# before each kernel call to make sure that the L2
|
||||||
|
# doesn't contain any input data before the run
|
||||||
|
cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda')
|
||||||
for i in range(rep):
|
for i in range(rep):
|
||||||
|
# we don't want `fn` to accumulate gradient values
|
||||||
|
# if it contains a backward pass. So we clear the
|
||||||
|
# provided gradients
|
||||||
if grad_to_none is not None:
|
if grad_to_none is not None:
|
||||||
grad_to_none.grad = None
|
grad_to_none.grad = None
|
||||||
|
# reset L2
|
||||||
|
cache.zero_()
|
||||||
|
# record time of `fn`
|
||||||
|
start_event.record()
|
||||||
fn()
|
fn()
|
||||||
end_event.record()
|
end_event.record()
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
time_ms = start_event.elapsed_time(end_event) / rep
|
total_ms += start_event.elapsed_time(end_event)
|
||||||
return time_ms
|
# return the average runtime of `fn`
|
||||||
|
return total_ms / rep
|
||||||
|
|
||||||
|
|
||||||
class Benchmark:
|
class Benchmark:
|
||||||
def __init__(self, x_names, x_vals, y_name, y_vals, y_lines, ylabel, loglog, plot_name, args):
|
def __init__(self, x_names, x_vals, y_name, y_vals, y_lines, ylabel, loglog, plot_name, args):
|
||||||
@@ -51,6 +69,7 @@ class Benchmark:
|
|||||||
self.plot_name = plot_name
|
self.plot_name = plot_name
|
||||||
self.args = args
|
self.args = args
|
||||||
|
|
||||||
|
|
||||||
class Mark:
|
class Mark:
|
||||||
def __init__(self, fn, benchmarks):
|
def __init__(self, fn, benchmarks):
|
||||||
self.fn = fn
|
self.fn = fn
|
||||||
@@ -85,6 +104,7 @@ class Mark:
|
|||||||
html.write(f"<image src=\"{bench.plot_name}.png\"/>\n")
|
html.write(f"<image src=\"{bench.plot_name}.png\"/>\n")
|
||||||
html.write("</body></html>\n")
|
html.write("</body></html>\n")
|
||||||
|
|
||||||
|
|
||||||
def perf_report(benchmarks):
|
def perf_report(benchmarks):
|
||||||
wrapper = lambda fn: Mark(fn, benchmarks)
|
wrapper = lambda fn: Mark(fn, benchmarks)
|
||||||
return wrapper
|
return wrapper
|
||||||
|
Reference in New Issue
Block a user