[DOCS] Various improvements and typo fixes

Philippe Tillet
2021-03-29 02:35:13 -04:00
committed by Philippe Tillet
parent 3f6ba1020d
commit 1fdb465b71
17 changed files with 207 additions and 99 deletions

View File

@@ -17,15 +17,11 @@ square_confs = [
         x_names=["M", "N", "K"],
         x_vals=rounded_linspace(512, 8192, 32, 128),
         y_name="provider",
-        y_vals=["torch", "triton", "cutlass"],
-        y_lines=["Torch", "Triton", "CUTLASS"],
+        y_vals=["cublas", "triton", "cutlass"],
+        y_lines=["cuBLAS", "Triton", "CUTLASS"],
         ylabel="TFLOPS",
         plot_name=f"matmul-square-{nt[AT]}{nt[BT]}",
-        args={
-            "AT": AT,
-            "BT": BT,
-            "dtype": torch.float16
-        },
+        args={"AT": AT, "BT": BT, "dtype": torch.float16},
     ) for AT in [False] for BT in [False]
 ]
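For context, rounded_linspace(512, 8192, 32, 128) presumably yields 32 roughly evenly spaced sizes between 512 and 8192, snapped to multiples of 128 so every size aligns with tile boundaries. A sketch under those assumed semantics (not the repository's implementation):

    import numpy as np

    def rounded_linspace(low, high, n, div):
        # n points from low to high, rounded to multiples of div (duplicates dropped)
        vals = np.linspace(low, high, n)
        return np.unique(div * np.rint(vals / div).astype(int)).tolist()

    print(rounded_linspace(512, 8192, 8, 128))
    # [512, 1664, 2688, 3840, 4864, 6016, 7040, 8192]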
@@ -35,8 +31,8 @@ transformer_confs = [
         x_names=[x],
         x_vals = rounded_linspace(NK//16, NK, 32, 128),
         y_name="provider",
-        y_vals=["torch", "triton", "cutlass"],
-        y_lines=["Torch", "Triton", "CUTLASS"],
+        y_vals=["cublas", "triton", "cutlass"],
+        y_lines=["cuBLAS", "Triton", "CUTLASS"],
         ylabel="TFLOPS",
         plot_name=f"matmul-M{M}-{'NK'.replace(x, '')}{NK}",
         args= {"M": M, 'NK'.replace(x,''): NK, "AT": False, "BT": False, "dtype": torch.float16}
@@ -46,7 +42,7 @@ transformer_confs = [
 ]


-@triton.testing.perf_report(transformer_confs)
+@triton.testing.perf_report(square_confs)
 def bench_op(M, N, K, AT, BT, dtype, provider, warmup=25, rep=75):
     a = torch.rand((K, M) if AT else (M, K), device="cuda", dtype=dtype)
     b = torch.rand((N, K) if BT else (K, N), device="cuda", dtype=dtype)
@@ -54,7 +50,7 @@ def bench_op(M, N, K, AT, BT, dtype, provider, warmup=25, rep=75):
     if BT: b = b.t()
     num_flops = 2 * M * N * K
     tflops = lambda ms: 2. * M * N * K / ms * 1e-9
-    if provider == "torch":
+    if provider == "cublas":
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), warmup=warmup, rep=rep)
         return tflops(ms), tflops(max_ms), tflops(min_ms)
     if provider == "triton":
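As a quick sanity check of the tflops lambda above: do_bench reports milliseconds, and a GEMM performs 2*M*N*K floating-point operations, so 2*M*N*K / (ms * 1e-3) / 1e12 TFLOPS simplifies to 2*M*N*K / ms * 1e-9. A standalone sketch with hypothetical numbers:

    M = N = K = 4096
    ms = 10.0                          # hypothetical median time from do_bench
    print(2. * M * N * K / ms * 1e-9)  # ~13.7 TFLOPS

Note the deliberate swap in the return statement: the fastest run (min_ms) yields the highest throughput, so tflops(max_ms) becomes the lower bound and tflops(min_ms) the upper bound.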

View File

@@ -59,7 +59,7 @@ class CMakeBuild(build_ext):
         if not os.path.exists(llvm_build_dir):
             os.makedirs(llvm_build_dir)
         # python directories
-        python_include_dirs = distutils.sysconfig.get_python_inc()
+        python_include_dirs = [distutils.sysconfig.get_python_inc()] + ['/usr/local/cuda/include']
         python_lib_dirs = distutils.sysconfig.get_config_var("LIBDIR")
         cmake_args = [
             "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir,
@@ -68,7 +68,7 @@ class CMakeBuild(build_ext):
             #'-DPYTHON_EXECUTABLE=' + sys.executable,
             #'-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON,
             "-DTRITON_LLVM_BUILD_DIR=" + llvm_build_dir,
-            "-DPYTHON_INCLUDE_DIRS=" + ";".join([python_include_dirs])
+            "-DPYTHON_INCLUDE_DIRS=" + ";".join(python_include_dirs)
         ]
         # configuration
         cfg = "Debug" if self.debug else "Release"
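The join fix matters because python_include_dirs is now a list: CMake expects list-valued variables as a single ;-separated string. A minimal illustration (paths are placeholders):

    dirs = ['/usr/include/python3.8', '/usr/local/cuda/include']
    print("-DPYTHON_INCLUDE_DIRS=" + ";".join(dirs))
    # -DPYTHON_INCLUDE_DIRS=/usr/include/python3.8;/usr/local/cuda/include

The old code wrapped a single string in a list before joining, so it worked but only ever passed one directory; the new code passes the CUDA include directory as well.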

View File

@@ -1,32 +1,44 @@
 import os
 import struct
-from typing import Optional, Dict, List
+from typing import Optional, Dict, List, Callable
 import torch
 # C bindings
 import triton._C.libtriton.triton as _triton

 codes = {
-    _triton.runtime.arg_type.int1: 'B', _triton.runtime.arg_type.int8: 'B', _triton.runtime.arg_type.int32: 'I',
-    _triton.runtime.arg_type.int64: 'Q', _triton.runtime.arg_type.half: 'H', _triton.runtime.arg_type.float: 'f',
-    _triton.runtime.arg_type.double: 'd', _triton.runtime.arg_type.buffer: 'P'
+    _triton.runtime.arg_type.int1: 'B',
+    _triton.runtime.arg_type.int8: 'B',
+    _triton.runtime.arg_type.int32: 'I',
+    _triton.runtime.arg_type.int64: 'Q',
+    _triton.runtime.arg_type.half: 'H',
+    _triton.runtime.arg_type.float: 'f',
+    _triton.runtime.arg_type.double: 'd',
+    _triton.runtime.arg_type.buffer: 'P',
 }


 def th_to_triton(obj):
     """ Convert a `torch.dtype` to a Triton-C type string. """
     tys = {
-        torch.int8: 'char', torch.int16: 'short', torch.int32: 'int', torch.int64: 'long',\
-        torch.float16: 'half', torch.float32: 'float', torch.float64: 'double'
+        torch.int8: 'char',
+        torch.int16: 'short',
+        torch.int32: 'int',
+        torch.int64: 'long',
+        torch.float16: 'half',
+        torch.float32: 'float',
+        torch.float64: 'double',
     }
     if isinstance(obj, torch.dtype):
         return tys[obj]
     return str(obj)


-def cdiv(a, b):
+def cdiv(a: int, b: int) -> int:
     """ Ceil division (a + b - 1) // b"""
     return (a + b - 1) // b


-def read(path, kernel_names: Optional[List] = None):
+def read(path: str, kernel_names: Optional[List] = None) -> str:
     """ Extracts the source code for `kernel_names` from the given `path` file."""
     if kernel_names is None:
         kernel_names = []
     with open(path, 'r') as f:
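The codes table maps runtime argument types to Python struct format characters; kernel.__init__ below concatenates them into self.tys, which __call__ uses to pack call arguments into the byte buffer mentioned at the end of the next hunk. A standalone sketch of that packing, assuming a signature of two buffers and an int32:

    import struct

    tys = 'PPI'  # buffer, buffer, int32 -- per the table above
    buf = struct.pack(tys, 0x7f0000000000, 0x7f0000001000, 1024)
    print(len(buf))  # 20 bytes on a 64-bit platform (8 + 8 + 4)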
@@ -39,19 +51,31 @@ config = _triton.runtime.config
 class kernel:
     """
     A class used to represent a Triton kernel.
     """

     def __init__(
         self,
-        src,
-        device,
+        src: str,
+        device: torch.device,
         defines: Optional[Dict] = None,
         num_warps: int = 4,
-        autotune_vals: Optional[List] = None,
+        autotune_configs: Optional[List] = None,
         autotune_key: Optional[List] = None
     ):
+        """
+        :param src: The source code of the kernel.
+        :param device: The device to compile the kernel for.
+        :param defines: A dictionary of preprocessor #define for the compiler.
+        :param num_warps: Optimization flag for the compiler's internal auto-parallelization engine.
+        :param autotune_configs: A list of triton.config objects for the autotuner to try.
+        :param autotune_key: A list of kernel argument names whose change in value should trigger the autotuner to re-run.
+        """
         if defines is None:
             defines = {}
-        if autotune_vals is None:
-            autotune_vals = []
+        if autotune_configs is None:
+            autotune_configs = []
         if autotune_key is None:
             autotune_key = []
         # check if src is empty
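For reference, a sketch of constructing a kernel with the renamed parameters (mirroring the matmul op updated later in this commit; the source file name is illustrative):

    import torch
    import triton

    src = triton.read('matmul.c')             # Triton-C source
    kernel = triton.kernel(
        src,
        device=torch.device('cuda:0'),
        defines={'TYPE': torch.float16},      # expanded as a preprocessor #define
        num_warps=4,
        autotune_configs=[],                  # triton.config objects to try
        autotune_key=['M', 'N', 'K'],         # re-autotune when these values change
    )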
@@ -74,11 +98,17 @@ class kernel:
         self.opt = _triton.runtime.options()
         self.opt.defines = {k: th_to_triton(v) for k, v in defines.items()}
         self.opt.num_warps = num_warps
-        # autotune_vals = [({}, 4)]
-        self.fn = _triton.runtime.function(self.src, self.opt, self.device, autotune_vals, autotune_key)
+        # autotune_configs = [({}, 4)]
+        self.fn = _triton.runtime.function(self.src, self.opt, self.device, autotune_configs, autotune_key)
         self.tys = ''.join([codes[x] for x in self.fn.signature()])

-    def __call__(self, *args, grid):
+    def __call__(self, *args, grid: Callable[[_triton.runtime.options], tuple]):
+        """
+        Runs the kernel on the given arguments and launch grid.
+        :param args: The arguments to the kernel, in the order that they appear in the Triton-C source.
+        :param grid: The launch grid for the kernel, i.e., a callable that transforms compilation options into a tuple of at most 3 integers.
+        :return: None
+        """
         # make sure that the executing thread is on the right device
         torch.cuda.set_device(self.device_id)
         # pack parameters into a byte buffer
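And a sketch of a matching call site: the grid callable receives the chosen compilation options and returns up to three launch dimensions (the option name TM and the argument list here are assumptions for illustration):

    grid = lambda opt: (triton.cdiv(N, opt.TM),)        # 1D launch grid
    kernel(x.data_ptr(), y.data_ptr(), N, grid=grid)    # args in Triton-C order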

View File

@@ -2,6 +2,7 @@ import torch
 import triton
 import os

+
 class _matmul(torch.autograd.Function):
     src = triton.read(os.path.join(os.path.dirname(__file__), "matmul.c"))
@@ -83,7 +84,7 @@ class _matmul(torch.autograd.Function):
                 _matmul.src,
                 device,
                 defines=defines,
-                autotune_vals=_matmul._CONFIGS,
+                autotune_configs=_matmul._CONFIGS,
                 autotune_key=["M", "N", "K"],
             )
         kernel = _matmul._kernels[key]
@@ -93,24 +94,8 @@ class _matmul(torch.autograd.Function):
         locks = _matmul._locks[device]
         # enqueue
         alpha = 1.0
-        args = [
-            a.data_ptr(),
-            b.data_ptr(),
-            c.data_ptr(),
-            alpha,
-            M,
-            N,
-            K,
-            lda,
-            ldb,
-            ldc,
-            locks.data_ptr(),
-        ]
-        grid = lambda opt: [
-            triton.cdiv(M, opt.TM) * triton.cdiv(N, opt.TN),
-            1,
-            opt.SPLITK,
-        ]
+        args = [a.data_ptr(), b.data_ptr(), c.data_ptr(), alpha, M, N, K, lda, ldb, ldc, locks.data_ptr()]
+        grid = lambda opt: [triton.cdiv(M, opt.TM) * triton.cdiv(N, opt.TN), 1, opt.SPLITK]
         kernel(*args, grid=grid)
         return c
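The reflowed grid lambda is worth unpacking: it launches one program instance per TM x TN output tile, times SPLITK partial reductions along K (later reconciled through locks). A worked sketch with assumed tile sizes:

    M, N = 1024, 1024              # output shape
    TM, TN, SPLITK = 128, 128, 2   # one possible autotuner configuration
    tiles = triton.cdiv(M, TM) * triton.cdiv(N, TN)
    print([tiles, 1, SPLITK])      # [64, 1, 2]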
@@ -119,4 +104,5 @@ class _matmul(torch.autograd.Function):
         c = _matmul._call(a, b)
         return c

+
 matmul = _matmul.apply

View File

@@ -108,9 +108,10 @@ class Benchmark:
         y_name,
         y_vals,
         y_lines,
-        ylabel,
         plot_name,
         args,
+        xlabel='',
+        ylabel='',
         x_log=False,
         y_log=False,
     ):
@@ -121,6 +122,8 @@ class Benchmark:
         self.y_vals = y_vals
         self.y_lines = y_lines
         self.y_log = y_log
+        # plot info
+        self.xlabel = xlabel
         self.ylabel = ylabel
         self.plot_name = plot_name
         self.args = args
@@ -131,7 +134,7 @@ class Mark:
         self.fn = fn
         self.benchmarks = benchmarks

-    def _run(self, bench, save_path, show_plots):
+    def _run(self, bench, save_path, show_plots, print_data):
         import matplotlib.pyplot as plt
         import pandas as pd
         import os
@@ -155,7 +158,6 @@ class Mark:
         if bench.plot_name:
             plt.figure()
             ax = plt.subplot()
-            xlabel = " = ".join(bench.x_names)
             x = bench.x_names[0]
             for y in bench.y_lines:
                 y_min, y_max = df[y + '-min'], df[y + '-max']
@@ -163,27 +165,30 @@ class Mark:
                 if y_min is not None and y_max is not None:
                     ax.fill_between(df[x], y_min, y_max, alpha=0.5)
             ax.legend()
+            xlabel = bench.xlabel if bench.xlabel else " = ".join(bench.x_names)
             ax.set_xlabel(xlabel)
             ax.set_ylabel(bench.ylabel)
-            ax.set_title(bench.plot_name)
+            #ax.set_title(bench.plot_name)
             ax.set_xscale("log" if bench.x_log else "linear")
             ax.set_yscale("log" if bench.y_log else "linear")
             if show_plots:
                 plt.show()
             if save_path:
                 plt.savefig(os.path.join(save_path, f"{bench.plot_name}.png"))
-        df = df[[bench.x_names[0]] + bench.y_lines]
+        if print_data:
+            print(df)
         if save_path:
+            df = df[[bench.x_names[0]] + bench.y_lines]
             df.to_csv(os.path.join(save_path, f"{bench.plot_name}.csv"), float_format='%.1f', index=False)

-    def run(self, show_plots=False, save_path=''):
+    def run(self, show_plots=False, print_data=False, save_path=''):
         has_single_bench = isinstance(self.benchmarks, Benchmark)
         benchmarks = [self.benchmarks] if has_single_bench else self.benchmarks
         if save_path:
             html = open(os.path.join(save_path, "results.html"), "w")
             html.write("<html><body>\n")
         for bench in benchmarks:
-            self._run(bench, save_path, show_plots)
+            self._run(bench, save_path, show_plots, print_data)
             if save_path:
                 html.write(f"<image src=\"{bench.plot_name}.png\"/>\n")
         if save_path:
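Net effect of these testing.py changes: Benchmark gains an optional xlabel override (falling back to joining x_names), plot titles are dropped, and a print_data flag is threaded from Mark.run down into _run. A usage sketch, assuming a @triton.testing.perf_report-decorated function named benchmark:

    benchmark.run(show_plots=False, print_data=True, save_path='/tmp/results')
    # prints one row per x value with a column per y_line, and also writes
    # a .png and .csv per plot_name plus a results.html index under /tmp/results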

View File

@@ -229,7 +229,13 @@ def make_kernel(device, dtype):
     cache = make_kernel.cache
     if key not in cache:
         defines = {'TYPE': dtype}
-        cache[key] = triton.kernel(src, device=device, defines=defines, autotune_vals=autotune_configs, autotune_key=autotune_key)
+        cache[key] = triton.kernel(
+            src,
+            device=device,
+            defines=defines,
+            autotune_configs=autotune_configs,
+            autotune_key=autotune_key,
+        )
     return cache[key]
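The reformatted call also highlights the tutorial's compile-once pattern: kernels are memoized in make_kernel.cache per device and dtype, so compilation and autotuning happen only on first use. A sketch of the behavior (key scheme assumed):

    k1 = make_kernel(torch.device('cuda:0'), torch.float16)  # first call: compile + autotune
    k2 = make_kernel(torch.device('cuda:0'), torch.float16)  # second call: cache hit
    assert k1 is k2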
@@ -319,7 +325,7 @@ print(torch.allclose(c_0, c_1, rtol=1e-3, atol=1e-3))
 #  .. code-block:: bash
 #
 #      export CUTLASS_INCLUDE_DIR=/tmp/cutlass/build/install/include/
-#      export CUTLASS_LIBRARY_DIR=/tmp/cutlass/build/install/lib/a
+#      export CUTLASS_LIBRARY_DIR=/tmp/cutlass/build/install/lib/
 #      pip uninstall -y triton
 #      pip install -e "git+https://github.com/ptillet/triton.git#egg=triton&subdirectory=python"
 #
@@ -343,8 +349,8 @@ print(torch.allclose(c_0, c_2, rtol=1e-3, atol=1e-3))
         x_names=['M', 'N', 'K'],  # argument names to use as an x-axis for the plot
         x_vals=[256 * i for i in range(2, 33)],  # different possible values for `x_name`
         y_name='provider',  # argument name whose value corresponds to a different line in the plot
-        y_vals=['torch', 'triton', 'cutlass'],  # possible keys for `y_name`
-        y_lines=["Torch", "Triton", 'CUTLASS'],  # label name for the lines
+        y_vals=['cublas', 'triton', 'cutlass'],  # possible keys for `y_name`
+        y_lines=["cuBLAS", "Triton", 'CUTLASS'],  # label name for the lines
         ylabel="TFLOPS",  # label name for the y-axis
         plot_name="matmul-performance",  # name for the plot. Used also as a file name for saving the plot.
         args={}
@@ -353,7 +359,7 @@ print(torch.allclose(c_0, c_2, rtol=1e-3, atol=1e-3))
 def benchmark(M, N, K, provider):
     a = torch.randn((M, K), device='cuda', dtype=torch.float16)
     b = torch.randn((K, N), device='cuda', dtype=torch.float16)
-    if provider == 'torch':
+    if provider == 'cublas':
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b))
     if provider == 'triton':
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: dot(a, b))
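This hunk ends mid-function; as in bench_op earlier, the three timings from do_bench are then converted into a (median, max, min) TFLOPS triple. A sketch of the usual closing lines, assumed here by analogy rather than quoted from the diff:

    perf = lambda ms: 2. * M * N * K / ms * 1e-9  # same conversion as bench_op
    return perf(ms), perf(max_ms), perf(min_ms)   # fastest run => highest TFLOPS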