[DOCS] Various improvements and typo fixes
commit 1fdb465b71, committed by Philippe Tillet (parent 3f6ba1020d)
@@ -17,15 +17,11 @@ square_confs = [
         x_names=["M", "N", "K"],
         x_vals=rounded_linspace(512, 8192, 32, 128),
         y_name="provider",
-        y_vals=["torch", "triton", "cutlass"],
-        y_lines=["Torch", "Triton", "CUTLASS"],
+        y_vals=["cublas", "triton", "cutlass"],
+        y_lines=["cuBLAS", "Triton", "CUTLASS"],
         ylabel="TFLOPS",
         plot_name=f"matmul-square-{nt[AT]}{nt[BT]}",
-        args={
-            "AT": AT,
-            "BT": BT,
-            "dtype": torch.float16
-        },
+        args={"AT": AT, "BT": BT, "dtype": torch.float16},
     ) for AT in [False] for BT in [False]
 ]
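For context, `rounded_linspace(512, 8192, 32, 128)` above is assumed to produce 32 matrix sizes between 512 and 8192, each rounded to a multiple of 128 so the shapes stay tile-friendly. A minimal sketch of such a helper (the actual implementation in triton.testing may differ):

import numpy as np

def rounded_linspace(low, high, steps, div):
    # evenly spaced values in [low, high], snapped to the nearest multiple of `div`
    ret = np.linspace(low, high, steps)
    ret = np.unique((ret / div).round().astype(int) * div)
    return list(map(int, ret))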
@@ -35,8 +31,8 @@ transformer_confs = [
         x_names=[x],
         x_vals = rounded_linspace(NK//16, NK, 32, 128),
         y_name="provider",
-        y_vals=["torch", "triton", "cutlass"],
-        y_lines=["Torch", "Triton", "CUTLASS"],
+        y_vals=["cublas", "triton", "cutlass"],
+        y_lines=["cuBLAS", "Triton", "CUTLASS"],
         ylabel="TFLOPS",
         plot_name=f"matmul-M{M}-{'NK'.replace(x, '')}{NK}",
         args= {"M": M, 'NK'.replace(x,''): NK, "AT": False, "BT": False, "dtype": torch.float16}
@@ -46,7 +42,7 @@ transformer_confs = [
 ]
 
 
-@triton.testing.perf_report(transformer_confs)
+@triton.testing.perf_report(square_confs)
 def bench_op(M, N, K, AT, BT, dtype, provider, warmup=25, rep=75):
     a = torch.rand((K, M) if AT else (M, K), device="cuda", dtype=dtype)
     b = torch.rand((N, K) if BT else (K, N), device="cuda", dtype=dtype)
@@ -54,7 +50,7 @@ def bench_op(M, N, K, AT, BT, dtype, provider, warmup=25, rep=75):
     if BT: b = b.t()
     num_flops = 2 * M * N * K
     tflops = lambda ms: 2. * M * N * K / ms * 1e-9
-    if provider == "torch":
+    if provider == "cublas":
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), warmup=warmup, rep=rep)
         return tflops(ms), tflops(max_ms), tflops(min_ms)
     if provider == "triton":
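For reference, the tflops lambda above follows from the 2·M·N·K floating-point operations of a GEMM and do_bench timings in milliseconds: ops per ms × 1e-9 gives TFLOPS (for example, M = N = K = 4096 in 10 ms is 2·4096³ / 10 × 1e-9 ≈ 13.7 TFLOPS). The slowest time (max_ms) yields the lower throughput bound and the fastest time (min_ms) the upper bound, which is why the two appear swapped in the return statement.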
@@ -59,7 +59,7 @@ class CMakeBuild(build_ext):
         if not os.path.exists(llvm_build_dir):
             os.makedirs(llvm_build_dir)
         # python directories
-        python_include_dirs = distutils.sysconfig.get_python_inc()
+        python_include_dirs = [distutils.sysconfig.get_python_inc()] + ['/usr/local/cuda/include']
         python_lib_dirs = distutils.sysconfig.get_config_var("LIBDIR")
         cmake_args = [
             "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir,
@@ -68,7 +68,7 @@ class CMakeBuild(build_ext):
             #'-DPYTHON_EXECUTABLE=' + sys.executable,
             #'-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON,
             "-DTRITON_LLVM_BUILD_DIR=" + llvm_build_dir,
-            "-DPYTHON_INCLUDE_DIRS=" + ";".join([python_include_dirs])
+            "-DPYTHON_INCLUDE_DIRS=" + ";".join(python_include_dirs)
         ]
         # configuration
         cfg = "Debug" if self.debug else "Release"
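The two setup.py hunks go together: python_include_dirs becomes a list (the Python headers plus the CUDA headers under /usr/local/cuda/include), so the ";".join(...) that builds -DPYTHON_INCLUDE_DIRS now joins the list itself rather than wrapping a single string in a one-element list. CMake expects multiple paths separated by semicolons.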
@@ -1,32 +1,44 @@
 import os
 import struct
-from typing import Optional, Dict, List
+from typing import Optional, Dict, List, Callable
 import torch
 # C bindings
 import triton._C.libtriton.triton as _triton
 
 codes = {
-    _triton.runtime.arg_type.int1: 'B', _triton.runtime.arg_type.int8: 'B', _triton.runtime.arg_type.int32: 'I',
-    _triton.runtime.arg_type.int64: 'Q', _triton.runtime.arg_type.half: 'H', _triton.runtime.arg_type.float: 'f',
-    _triton.runtime.arg_type.double: 'd', _triton.runtime.arg_type.buffer: 'P'
+    _triton.runtime.arg_type.int1: 'B',
+    _triton.runtime.arg_type.int8: 'B',
+    _triton.runtime.arg_type.int32: 'I',
+    _triton.runtime.arg_type.int64: 'Q',
+    _triton.runtime.arg_type.half: 'H',
+    _triton.runtime.arg_type.float: 'f',
+    _triton.runtime.arg_type.double: 'd',
+    _triton.runtime.arg_type.buffer: 'P',
 }
 
 
 def th_to_triton(obj):
     """ Convert a `torch.dtype` to a Triton-C type string. """
     tys = {
-        torch.int8: 'char', torch.int16: 'short', torch.int32: 'int', torch.int64: 'long',\
-        torch.float16: 'half', torch.float32: 'float', torch.float64: 'double'
+        torch.int8: 'char',
+        torch.int16: 'short',
+        torch.int32: 'int',
+        torch.int64: 'long',
+        torch.float16: 'half',
+        torch.float32: 'float',
+        torch.float64: 'double',
     }
     if isinstance(obj, torch.dtype):
         return tys[obj]
     return str(obj)
 
 
-def cdiv(a, b):
+def cdiv(a: int, b: int) -> int:
     """ Ceil division (a + b - 1) // b """
     return (a + b - 1) // b
 
 
-def read(path, kernel_names: Optional[List] = None):
+def read(path: str, kernel_names: Optional[List] = None) -> str:
     """ Extracts the source code for `kernel_names` from the given `path` file. """
     if kernel_names is None:
         kernel_names = []
     with open(path, 'r') as f:
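For orientation, the `struct` format characters in `codes` are what allow kernel arguments to be packed into the byte buffer mentioned in `__call__` further down. A minimal illustration of the packing idea, using a hypothetical signature rather than the actual code in kernel.py:

import struct

# hypothetical signature: an int32 scalar followed by a buffer pointer -> format "IP"
fmt = "IP"                    # in kernel.py this would be ''.join(codes[t] for t in fn.signature())
n = 1024
ptr = 0x7F0000000000          # stands in for tensor.data_ptr()
params = struct.pack(fmt, n, ptr)
print(len(params))            # native alignment: 4 bytes for 'I', padding, then 8 bytes for 'P'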
@@ -39,19 +51,31 @@ config = _triton.runtime.config
 
 
 class kernel:
+    """
+    A class used to represent a Triton kernel.
+    """
     def __init__(
         self,
-        src,
-        device,
+        src: str,
+        device: torch.device,
         defines: Optional[Dict] = None,
         num_warps: int = 4,
-        autotune_vals: Optional[List] = None,
+        autotune_configs: Optional[List] = None,
         autotune_key: Optional[List] = None
     ):
+        """
+        :param src: The source code of the kernel.
+        :param device: The device to compile the kernel for.
+        :param defines: A dictionary of preprocessor #define directives for the compiler.
+        :param num_warps: Optimization flag for the compiler's internal auto-parallelization engine.
+        :param autotune_configs: A list of triton.config objects for the autotuner to try.
+        :param autotune_key: A list of kernel argument names whose change in value should trigger the autotuner to re-run.
+        """
+
         if defines is None:
             defines = {}
-        if autotune_vals is None:
-            autotune_vals = []
+        if autotune_configs is None:
+            autotune_configs = []
         if autotune_key is None:
             autotune_key = []
         # check if src is empty
@@ -74,11 +98,17 @@ class kernel:
         self.opt = _triton.runtime.options()
         self.opt.defines = {k: th_to_triton(v) for k, v in defines.items()}
         self.opt.num_warps = num_warps
-        # autotune_vals = [({}, 4)]
-        self.fn = _triton.runtime.function(self.src, self.opt, self.device, autotune_vals, autotune_key)
+        # autotune_configs = [({}, 4)]
+        self.fn = _triton.runtime.function(self.src, self.opt, self.device, autotune_configs, autotune_key)
         self.tys = ''.join([codes[x] for x in self.fn.signature()])
 
-    def __call__(self, *args, grid):
+    def __call__(self, *args, grid: Callable[[_triton.runtime.options], tuple]):
+        """
+        Runs the kernel on the given arguments and launch grid.
+        :param args: The arguments to the kernel, in the order in which they appear in the Triton-C source.
+        :param grid: The launch grid for the kernel, i.e., a callable that transforms compilation options into a tuple of at most 3 integers.
+        :return: None
+        """
         # make sure that the executing thread is on the right device
         torch.cuda.set_device(self.device_id)
         # pack parameters into a byte buffer
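Taken together with the constructor above, launching a kernel follows the same pattern used by ops/matmul.py further down in this commit. A minimal usage sketch, in which add.c and its TILE macro are hypothetical placeholders:

import torch
import triton

src = triton.read("add.c")                 # hypothetical Triton-C source file
device = torch.device("cuda")
kernel = triton.kernel(src, device=device, defines={"TILE": 128}, num_warps=4, autotune_key=["N"])

N = 1024
x, y, z = (torch.rand(N, device=device) for _ in range(3))
# grid: maps compilation options (including #defines such as TILE) to a launch grid of at most 3 integers
kernel(x.data_ptr(), y.data_ptr(), z.data_ptr(), N, grid=lambda opt: (triton.cdiv(N, opt.TILE),))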
@@ -2,6 +2,7 @@ import torch
 import triton
+import os
 
 
 class _matmul(torch.autograd.Function):
     src = triton.read(os.path.join(os.path.dirname(__file__), "matmul.c"))
@@ -83,7 +84,7 @@ class _matmul(torch.autograd.Function):
                 _matmul.src,
                 device,
                 defines=defines,
-                autotune_vals=_matmul._CONFIGS,
+                autotune_configs=_matmul._CONFIGS,
                 autotune_key=["M", "N", "K"],
             )
             kernel = _matmul._kernels[key]
@@ -93,24 +94,8 @@ class _matmul(torch.autograd.Function):
         locks = _matmul._locks[device]
         # enqueue
         alpha = 1.0
-        args = [
-            a.data_ptr(),
-            b.data_ptr(),
-            c.data_ptr(),
-            alpha,
-            M,
-            N,
-            K,
-            lda,
-            ldb,
-            ldc,
-            locks.data_ptr(),
-        ]
-        grid = lambda opt: [
-            triton.cdiv(M, opt.TM) * triton.cdiv(N, opt.TN),
-            1,
-            opt.SPLITK,
-        ]
+        args = [a.data_ptr(), b.data_ptr(), c.data_ptr(), alpha, M, N, K, lda, ldb, ldc, locks.data_ptr()]
+        grid = lambda opt: [triton.cdiv(M, opt.TM) * triton.cdiv(N, opt.TN), 1, opt.SPLITK]
         kernel(*args, grid=grid)
         return c
@@ -119,4 +104,5 @@ class _matmul(torch.autograd.Function):
         c = _matmul._call(a, b)
         return c
 
+
 matmul = _matmul.apply
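Since _matmul subclasses torch.autograd.Function, the exported matmul is applied directly to CUDA tensors. A brief usage sketch, assuming the module is exposed as triton.ops.matmul:

import torch
import triton

a = torch.rand((512, 256), device="cuda", dtype=torch.float16)
b = torch.rand((256, 128), device="cuda", dtype=torch.float16)
c = triton.ops.matmul(a, b)   # dispatches to _matmul._call
print(torch.allclose(c, torch.matmul(a, b), rtol=1e-3, atol=1e-3))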
@@ -108,9 +108,10 @@ class Benchmark:
         y_name,
         y_vals,
         y_lines,
-        ylabel,
         plot_name,
         args,
+        xlabel='',
+        ylabel='',
         x_log=False,
         y_log=False,
     ):
@@ -121,6 +122,8 @@ class Benchmark:
         self.y_vals = y_vals
         self.y_lines = y_lines
         self.y_log = y_log
+        # plot info
+        self.xlabel = xlabel
         self.ylabel = ylabel
         self.plot_name = plot_name
         self.args = args
@@ -131,7 +134,7 @@ class Mark:
         self.fn = fn
         self.benchmarks = benchmarks
 
-    def _run(self, bench, save_path, show_plots):
+    def _run(self, bench, save_path, show_plots, print_data):
         import matplotlib.pyplot as plt
         import pandas as pd
         import os
@@ -155,7 +158,6 @@ class Mark:
         if bench.plot_name:
             plt.figure()
             ax = plt.subplot()
-            xlabel = " = ".join(bench.x_names)
             x = bench.x_names[0]
             for y in bench.y_lines:
                 y_min, y_max = df[y + '-min'], df[y + '-max']
@@ -163,27 +165,30 @@ class Mark:
             if y_min is not None and y_max is not None:
                 ax.fill_between(df[x], y_min, y_max, alpha=0.5)
             ax.legend()
+            xlabel = bench.xlabel if bench.xlabel else " = ".join(bench.x_names)
             ax.set_xlabel(xlabel)
             ax.set_ylabel(bench.ylabel)
-            ax.set_title(bench.plot_name)
+            #ax.set_title(bench.plot_name)
             ax.set_xscale("log" if bench.x_log else "linear")
             ax.set_yscale("log" if bench.y_log else "linear")
             if show_plots:
                 plt.show()
             if save_path:
                 plt.savefig(os.path.join(save_path, f"{bench.plot_name}.png"))
+        df = df[[bench.x_names[0]] + bench.y_lines]
+        if print_data:
+            print(df)
         if save_path:
-            df = df[[bench.x_names[0]] + bench.y_lines]
             df.to_csv(os.path.join(save_path, f"{bench.plot_name}.csv"), float_format='%.1f', index=False)
 
-    def run(self, show_plots=False, save_path=''):
+    def run(self, show_plots=False, print_data=False, save_path=''):
         has_single_bench = isinstance(self.benchmarks, Benchmark)
         benchmarks = [self.benchmarks] if has_single_bench else self.benchmarks
         if save_path:
             html = open(os.path.join(save_path, "results.html"), "w")
             html.write("<html><body>\n")
         for bench in benchmarks:
-            self._run(bench, save_path, show_plots)
+            self._run(bench, save_path, show_plots, print_data)
             if save_path:
                 html.write(f"<image src=\"{bench.plot_name}.png\"/>\n")
         if save_path:
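The testing.py changes add an optional xlabel to Benchmark and a print_data flag that prints the result table in addition to plotting or saving it. A minimal sketch of how the two combine; the copy benchmark below is a stand-in, not code from the repository:

import torch
import triton

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["N"],
        x_vals=[2**i for i in range(10, 16)],
        y_name="provider",
        y_vals=["torch"],
        y_lines=["Torch"],
        ylabel="GB/s",
        xlabel="vector size",   # new: overrides the default " = ".join(x_names)
        plot_name="copy-performance",
        args={},
    ))
def bench_copy(N, provider):
    x = torch.rand(N, device="cuda")
    ms, min_ms, max_ms = triton.testing.do_bench(lambda: x.clone())
    gbps = lambda t: 2 * x.numel() * x.element_size() / t * 1e-6  # read + write, ms -> GB/s
    return gbps(ms), gbps(max_ms), gbps(min_ms)

bench_copy.run(show_plots=False, print_data=True, save_path="")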
@@ -229,7 +229,13 @@ def make_kernel(device, dtype):
     cache = make_kernel.cache
     if key not in cache:
         defines = {'TYPE': dtype}
-        cache[key] = triton.kernel(src, device=device, defines=defines, autotune_vals=autotune_configs, autotune_key=autotune_key)
+        cache[key] = triton.kernel(
+            src,
+            device=device,
+            defines=defines,
+            autotune_configs=autotune_configs,
+            autotune_key=autotune_key,
+        )
     return cache[key]
@@ -319,7 +325,7 @@ print(torch.allclose(c_0, c_1, rtol=1e-3, atol=1e-3))
 # .. code-block:: bash
 #
 #   export CUTLASS_INCLUDE_DIR=/tmp/cutlass/build/install/include/
-#   export CUTLASS_LIBRARY_DIR=/tmp/cutlass/build/install/lib/a
+#   export CUTLASS_LIBRARY_DIR=/tmp/cutlass/build/install/lib/
 #   pip uninstall -y triton
 #   pip install -e "git+https://github.com/ptillet/triton.git#egg=triton&subdirectory=python"
 #
@@ -343,8 +349,8 @@ print(torch.allclose(c_0, c_2, rtol=1e-3, atol=1e-3))
         x_names=['M', 'N', 'K'],  # argument names to use as an x-axis for the plot
         x_vals=[256 * i for i in range(2, 33)],  # different possible values for `x_name`
         y_name='provider',  # argument name whose value corresponds to a different line in the plot
-        y_vals=['torch', 'triton', 'cutlass'],  # possible keys for `y_name`
-        y_lines=["Torch", "Triton", 'CUTLASS'],  # label name for the lines
+        y_vals=['cublas', 'triton', 'cutlass'],  # possible keys for `y_name`
+        y_lines=["cuBLAS", "Triton", 'CUTLASS'],  # label name for the lines
         ylabel="TFLOPS",  # label name for the y-axis
         plot_name="matmul-performance",  # name for the plot. Used also as a file name for saving the plot.
         args={}
@@ -353,7 +359,7 @@ print(torch.allclose(c_0, c_2, rtol=1e-3, atol=1e-3))
 def benchmark(M, N, K, provider):
     a = torch.randn((M, K), device='cuda', dtype=torch.float16)
     b = torch.randn((K, N), device='cuda', dtype=torch.float16)
-    if provider == 'torch':
+    if provider == 'cublas':
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b))
     if provider == 'triton':
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: dot(a, b))
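With the provider renamed to 'cublas', the tutorial benchmark is invoked as before; assuming the decorated function above is bound to the name benchmark, the new print_data flag can be used like this:

# prints the TFLOPS table and, because save_path is set, also writes matmul-performance.png/.csv
benchmark.run(show_plots=True, print_data=True, save_path=".")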