triton/python/triton/testing.py
Philippe Tillet 2f80a98776 [BUILD] Added automatic nightly build releases to pip in CI; removed build-time dependence on LLVM and PyTorch (#77)
Recently there have been more and more reports about installation issues:

    - Installing Triton before upgrading PyTorch can create issues, because Triton uses some torch headers at build time

    - llvm-10-dev not being available on some platforms, and llvm-11-dev not being available on others (e.g., Ubuntu)

    - Absence of nightly builds

This PR should fix all these issues. Some CMake tricks are used to download and install LLVM at build time. The Triton Python bindings were modified to remove their dependence on PyTorch ops. A midnight CI job was added to generate binary wheels for all Triton versions and upload them to PyPI's new triton-nightly project.

This PR will also make it very easy to use LLVM forks in the future for whatever needs arise.
2021-07-27 12:38:49 -07:00


import torch
import os

# the CUTLASS bindings are optional; fall back to None if they were not built
try:
    import triton._C.libtriton.cutlass as _cutlass
except ImportError:
    _cutlass = None

def sparsify_tensor(x, mask, block):
    ret = torch.empty((x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device)
    for idx, (h, i, j) in enumerate(zip(*mask.nonzero(as_tuple=True))):
        ret[:, idx, :, :] = x[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block]
    return ret

def cutlass_matmul(a, b):
    if _cutlass is None:
        raise RuntimeError("Cannot find cutlass library")
    M, N = a.shape[0], b.shape[1]
    Ka, Kb = a.shape[1], b.shape[0]
    assert Ka == Kb
    assert a.dtype == b.dtype
    assert a.device == b.device
    # allocate output
    c = torch.empty_strided((M, N), (1, M), dtype=a.dtype, device=a.device)
    # run function
    dtype = str(a.dtype).split('.')[-1]
    _cutlass.matmul(a.data_ptr(), b.data_ptr(), c.data_ptr(),
                    M, N, Ka,
                    a.stride(0), a.stride(1),
                    b.stride(0), b.stride(1),
                    c.stride(0), c.stride(1),
                    dtype, dtype, dtype,
                    a.device.index, torch.cuda.current_stream(a.device).cuda_stream)
    return c

def mask_tensor(x, mask, block, value=0):
    ret = x.clone()
    for h, i, j in zip(*(mask == 0).nonzero(as_tuple=True)):
        ret[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block] = value
    return ret

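
# Illustrative usage sketch (not part of the original file): build a random
# block-sparsity layout, then extract and mask the corresponding blocks.
# The shapes, block size, and variable names below are assumptions.
def _example_block_helpers():
    Z, H, M, N, block = 2, 4, 64, 64, 16
    x = torch.randn(Z, H, M, N)
    layout = torch.randint(0, 2, (H, M // block, N // block))
    x_sparse = sparsify_tensor(x, layout, block)  # shape (Z, layout.sum(), block, block)
    x_masked = mask_tensor(x, layout, block)      # zeroed wherever layout == 0
    return x_sparse, x_masked
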
def allclose(x, y):
    assert x.dtype == y.dtype
    diff = abs(x - y)
    x_max = torch.max(x)
    y_max = torch.max(y)
    tol = 1e-2
    err = torch.max(diff) / torch.max(x_max, y_max)
    return err < tol

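
# Illustrative usage sketch (not part of the original file): check the optional
# CUTLASS matmul against torch.matmul with the relative-error test above.
# Shapes and dtype are assumptions; this requires a GPU and the cutlass bindings.
def _example_cutlass_check():
    a = torch.randn(512, 512, device='cuda', dtype=torch.float16)
    b = torch.randn(512, 512, device='cuda', dtype=torch.float16)
    c_ref = torch.matmul(a, b)
    c_cut = cutlass_matmul(a, b)  # raises RuntimeError if cutlass was not built
    return allclose(c_cut, c_ref)
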
def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.2, 0.8]):
    """Benchmark `fn` with CUDA events.

    :param warmup: target warm-up time, in ms
    :param rep: number of timed repetitions
    :param grad_to_none: tensor whose gradient is cleared before each call, so that
                         backward passes do not accumulate gradients across runs
    :param percentiles: quantiles of the measured times returned along with the median
    """
    # Estimate the runtime of the function
    fn()
    torch.cuda.synchronize()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(5):
        fn()
    end_event.record()
    torch.cuda.synchronize()
    estimate_ms = start_event.elapsed_time(end_event) / 5
    # We maintain a buffer of 256 MB that we clear
    # before each kernel call to make sure that the L2
    # doesn't contain any input data before the run
    start_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
    end_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
    cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda')
    # Warm-up
    for _ in range(int(warmup / estimate_ms)):
        fn()
    # Benchmark
    for i in range(rep):
        # we don't want `fn` to accumulate gradient values
        # if it contains a backward pass. So we clear the
        # provided gradients
        if grad_to_none is not None:
            grad_to_none.grad = None
        # we clear the L2 cache before each run
        cache.zero_()
        # record time of `fn`
        start_event[i].record()
        fn()
        end_event[i].record()
    torch.cuda.synchronize()
    times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)])
    percentiles = torch.quantile(times, torch.tensor(percentiles)).tolist()
    med_ms = torch.median(times).item()
    if percentiles:
        return tuple([med_ms] + percentiles)
    else:
        return med_ms

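
# Illustrative usage sketch (not part of the original file): time a matrix
# multiplication with do_bench. Shapes and dtype are assumptions; with the
# default percentiles, the call returns (median, 20th percentile, 80th
# percentile) in milliseconds.
def _example_do_bench():
    a = torch.randn(1024, 1024, device='cuda', dtype=torch.float16)
    b = torch.randn(1024, 1024, device='cuda', dtype=torch.float16)
    ms, low_ms, high_ms = do_bench(lambda: torch.matmul(a, b))
    return ms, low_ms, high_ms
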
class Benchmark:
    def __init__(
        self,
        x_names,
        x_vals,
        y_name,
        y_vals,
        y_lines,
        ylabel,
        plot_name,
        args,
        x_log=False,
        y_log=False,
    ):
        self.x_names = x_names
        self.x_vals = x_vals
        self.x_log = x_log
        self.y_name = y_name
        self.y_vals = y_vals
        self.y_lines = y_lines
        self.y_log = y_log
        self.ylabel = ylabel
        self.plot_name = plot_name
        self.args = args

class Mark:
    def __init__(self, fn, benchmarks):
        self.fn = fn
        self.benchmarks = benchmarks

    def _run(self, bench, save_path, show_plots):
        import matplotlib.pyplot as plt
        import pandas as pd
        import os
        y_mean = bench.y_lines
        y_min = [f'{x}-min' for x in bench.y_lines]
        y_max = [f'{x}-max' for x in bench.y_lines]
        df = pd.DataFrame(columns=[bench.x_names[0]] + y_mean + y_min + y_max)
        for x in bench.x_vals:
            x_args = {x_name: x for x_name in bench.x_names}
            row_mean, row_min, row_max = [], [], []
            for y in bench.y_vals:
                ret = self.fn(**x_args, **{bench.y_name: y}, **bench.args)
                try:
                    y_mean, y_min, y_max = ret
                except TypeError:
                    y_mean, y_min, y_max = ret, None, None
                row_mean += [y_mean]
                row_min += [y_min]
                row_max += [y_max]
            df.loc[len(df)] = [x] + row_mean + row_min + row_max
        if bench.plot_name:
            plt.figure()
            ax = plt.subplot()
            xlabel = " = ".join(bench.x_names)
            x = bench.x_names[0]
            for y in bench.y_lines:
                y_min, y_max = df[y + '-min'], df[y + '-max']
                ax.plot(df[x], df[y], label=y)
                if y_min is not None and y_max is not None:
                    ax.fill_between(df[x], y_min, y_max, alpha=0.5)
            ax.legend()
            ax.set_xlabel(xlabel)
            ax.set_ylabel(bench.ylabel)
            ax.set_title(bench.plot_name)
            ax.set_xscale("log" if bench.x_log else "linear")
            ax.set_yscale("log" if bench.y_log else "linear")
            if show_plots:
                plt.show()
            if save_path:
                plt.savefig(os.path.join(save_path, f"{bench.plot_name}.png"))
        if save_path:
            df = df[[bench.x_names[0]] + bench.y_lines]
            df.to_csv(os.path.join(save_path, f"{bench.plot_name}.csv"), float_format='%.1f', index=False)
    def run(self, show_plots=False, save_path=''):
        has_single_bench = isinstance(self.benchmarks, Benchmark)
        benchmarks = [self.benchmarks] if has_single_bench else self.benchmarks
        if save_path:
            html = open(os.path.join(save_path, "results.html"), "w")
            html.write("<html><body>\n")
        for bench in benchmarks:
            self._run(bench, save_path, show_plots)
            if save_path:
                html.write(f"<image src=\"{bench.plot_name}.png\"/>\n")
        if save_path:
            html.write("</body></html>\n")

def perf_report(benchmarks):
    wrapper = lambda fn: Mark(fn, benchmarks)
    return wrapper
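
# Illustrative usage sketch (not part of the original file): declare a benchmark
# grid with @perf_report and run it. The argument names, value ranges, provider
# labels, and output path below are assumptions.
_example_bench = Benchmark(
    x_names=['N'],                 # argument(s) swept along the x-axis
    x_vals=[128, 256, 512, 1024],  # values taken by `N`
    y_name='provider',             # argument that selects each line of the plot
    y_vals=['torch'],              # one benchmarked provider per entry
    y_lines=['Torch'],             # legend label for each entry of y_vals
    ylabel='ms',
    plot_name='example-matmul',
    args={},
)


@perf_report(_example_bench)
def _example_matmul_benchmark(N, provider):
    a = torch.randn(N, N, device='cuda', dtype=torch.float16)
    b = torch.randn(N, N, device='cuda', dtype=torch.float16)
    if provider == 'torch':
        return do_bench(lambda: torch.matmul(a, b))


# _example_matmul_benchmark.run(show_plots=False, save_path='/tmp')  # hypothetical output directory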