[LANG] Added support for constexpr (#361)

Philippe Tillet
2021-10-30 00:32:58 -07:00
committed by GitHub
parent 770ea96cca
commit 2acaa4d0dd
16 changed files with 355 additions and 365 deletions

View File

@@ -67,8 +67,8 @@ def test_matmul(M, N, K):
import triton.language as tl
@triton.jit
def _add(x_ptr, y_ptr, output_ptr, n_elements, **meta):
BLOCK_SIZE = meta['BLOCK_SIZE']
def _add(x_ptr, y_ptr, output_ptr, n_elements,
BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
@@ -99,7 +99,7 @@ def test_elementwise(N):
z = torch.empty((N, ), dtype=torch.float16, device='cuda')
x = torch.randn_like(z)
y = torch.randn_like(z)
grid = lambda meta: (triton.cdiv(N, meta['BLOCK_SIZE']), )
grid = lambda args: (triton.cdiv(N, args['BLOCK_SIZE']), )
fn = lambda: _add[grid](x, y, z, N, BLOCK_SIZE=1024)
ms = triton.testing.do_bench(fn, percentiles=None, warmup=10, rep=250)
cur_gpu_perf = 3.*N*z.element_size()/ms*1e-6
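Note: the hunk above shows the new calling convention end-to-end. Meta-parameters are no longer collected in **meta; they are declared as tl.constexpr parameters and passed as ordinary keyword arguments, and the grid callable now receives the dict of those values rather than a separate meta dict. A minimal, self-contained sketch of the new style (assumes a CUDA device; not part of this diff):

import torch
import triton
import triton.language as tl

@triton.jit
def _add(x_ptr, y_ptr, output_ptr, n_elements,
         BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(output_ptr + offsets, x + y, mask=mask)

x = torch.randn(4096, device='cuda')
y = torch.randn_like(x)
z = torch.empty_like(x)
# the grid callable is passed the constexpr values by name
grid = lambda args: (triton.cdiv(x.numel(), args['BLOCK_SIZE']),)
_add[grid](x, y, z, x.numel(), BLOCK_SIZE=1024)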

View File

@@ -40,7 +40,7 @@ def patch_kernel(template, to_replace):
def test_empty_kernel(dtype_x, device='cuda'):
SIZE = 128
@triton.jit
def kernel(X, **meta):
def kernel(X, SIZE: tl.constexpr):
pass
x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device)
kernel[(1, )](x, SIZE=SIZE, num_warps=4)
@@ -50,8 +50,8 @@ def _test_unary(dtype_x, expr, torch_expr=None, device='cuda'):
SIZE = 128
# define the kernel / launch-grid
@triton.jit
def kernel(Z, X, **meta):
off = tl.arange(0, meta['SIZE'])
def kernel(Z, X, SIZE: tl.constexpr):
off = tl.arange(0, SIZE)
x = tl.load(X + off)
z = GENERATE_TEST_HERE
tl.store(Z + off, z)
@@ -73,8 +73,8 @@ def _test_binary(dtype_x, dtype_y, expr, mode_x='real', mode_y='real', device='c
SIZE = 128
# define the kernel / launch-grid
@triton.jit
def kernel(Z, X, Y, **meta):
off = tl.arange(0, meta['SIZE'])
def kernel(Z, X, Y, SIZE: tl.constexpr):
off = tl.arange(0, SIZE)
x = tl.load(X + off)
y = tl.load(Y + off)
z = GENERATE_TEST_HERE
@@ -203,8 +203,7 @@ def test_index1d(expr, device='cuda'):
# Triton kernel
@triton.jit
def kernel(Z, X, **meta):
SIZE = meta['SIZE']
def kernel(Z, X, SIZE: tl.constexpr):
m = tl.arange(0, SIZE)
n = tl.arange(0, SIZE)
x = tl.load(X_PTR_EXPR)
@@ -290,7 +289,7 @@ def test_atomic_rmw(op, dtype_x, mode, device='cuda'):
# triton kernel
@triton.jit
def kernel(X, Z, **meta):
def kernel(X, Z):
pid = tl.program_id(0)
x = tl.load(X + pid)
old = GENERATE_TEST_HERE
@@ -344,9 +343,9 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
# triton kernel
@triton.jit
def kernel(X, Z, **meta):
def kernel(X, Z, BITCAST: tl.constexpr):
x = tl.load(X)
z = x.to(Z.dtype.element_ty, bitcast=meta['BITCAST'])
z = x.to(Z.dtype.element_ty, bitcast = BITCAST)
tl.store(Z, z)
# triton result
@@ -373,8 +372,8 @@ def test_reduce1d(dtype, shape, device='cuda'):
# triton kernel
@triton.jit
def kernel(X, Z, **meta):
x = tl.load(X + tl.arange(0, meta['BLOCK']))
def kernel(X, Z, BLOCK: tl.constexpr):
x = tl.load(X + tl.arange(0, BLOCK))
tl.store(Z, tl.sum(x, axis=0))
x = triton.testing.random((shape,), dtype=dtype, device=device)
@@ -395,11 +394,11 @@ def test_reduce2d(dtype, shape, axis, device='cuda'):
dtype = cvt[dtype]
# triton kernel
@triton.jit
def kernel(X, Z, **meta):
range_m = tl.arange(0, meta['BLOCK_M'])
range_n = tl.arange(0, meta['BLOCK_N'])
x = tl.load(X + range_m[:, None]*meta['BLOCK_N'] + range_n[None, :])
z = tl.sum(x, axis=meta['AXIS'])
def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr):
range_m = tl.arange(0, BLOCK_M)
range_n = tl.arange(0, BLOCK_N)
x = tl.load(X + range_m[:, None]*BLOCK_N + range_n[None, :])
z = tl.sum(x, axis=AXIS)
tl.store(Z + range_m, z)
# input
x = triton.testing.random(shape, dtype=dtype, device=device)
@@ -429,9 +428,8 @@ def test_permute(dtype, shape, perm, device='cuda'):
# triton kernel
@triton.jit
def kernel(X, stride_xm, stride_xn,
Z, stride_zm, stride_zn, **meta):
BLOCK_M = meta['BLOCK_M']
BLOCK_N = meta['BLOCK_N']
Z, stride_zm, stride_zn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
off_m = tl.arange(0, BLOCK_M)
off_n = tl.arange(0, BLOCK_N)
Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * stride_xn
@@ -464,10 +462,9 @@ def test_dot(epilogue, device='cuda'):
@triton.jit
def kernel(X, stride_xm, stride_xk,
Y, stride_yk, stride_yn,
Z, stride_zm, stride_zn, **meta):
BLOCK_M = meta['BLOCK_M']
BLOCK_K = meta['BLOCK_K']
BLOCK_N = meta['BLOCK_N']
Z, stride_zm, stride_zn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr):
off_m = tl.arange(0, BLOCK_M)
off_n = tl.arange(0, BLOCK_N)
off_k = tl.arange(0, BLOCK_K)
@@ -475,12 +472,12 @@ def test_dot(epilogue, device='cuda'):
Ys = Y + off_k[:, None] * stride_yk + off_n[None, :] * stride_yn
Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
z = tl.dot(tl.load(Xs), tl.load(Ys))
if meta['ADD_MATRIX']:
if ADD_MATRIX:
z += tl.load(Zs)
if meta['ADD_ROWS']:
if ADD_ROWS:
ZRs = Z + off_m * stride_zm
z += tl.load(ZRs)[:, None]
if meta['ADD_COLS']:
if ADD_COLS:
ZCs = Z + off_n * stride_zn
z += tl.load(ZCs)[None, :]
tl.store(Zs, z)
@@ -517,7 +514,7 @@ def test_dot(epilogue, device='cuda'):
def test_dot_without_load():
@triton.jit
def kernel(out, **meta):
def kernel(out):
pid = tl.program_id(axis=0)
a = tl.zeros((32, 32), tl.float32)
b = tl.zeros((32, 32), tl.float32)
@@ -538,9 +535,10 @@ def test_arange(start, device='cuda'):
BLOCK = 128
z_tri = torch.empty(BLOCK, dtype=torch.int32, device=device)
@triton.jit
def _kernel(z, **meta):
off = tl.arange(0, meta['BLOCK'])
val = tl.arange(meta['START'], meta['END'])
def _kernel(z, BLOCK: tl.constexpr,
START: tl.constexpr, END: tl.constexpr):
off = tl.arange(0, BLOCK)
val = tl.arange(START, END)
tl.store(z + off, val)
_kernel[(1,)](z_tri, START=start, END=start+BLOCK, BLOCK=BLOCK)
z_ref = torch.arange(start, BLOCK+start, dtype=torch.int32, device=device)
@@ -564,10 +562,8 @@ def test_masked_load_shared_memory(dtype, device='cuda'):
@triton.jit
def _kernel(in1_ptr, in2_ptr, output_ptr,
in_stride, in2_stride, out_stride,
in_numel, in2_numel, out_numel, **meta):
M = meta['M']
N = meta['N']
K = meta['K']
in_numel, in2_numel, out_numel,
M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
M_offsets = tl.arange(0, M)
N_offsets = tl.arange(0, N)
@@ -605,14 +601,13 @@ def test_load_cache_modifier(cache):
dst = torch.empty(128, device='cuda')
@triton.jit
def _kernel(dst, src, **meta):
def _kernel(dst, src, CACHE: tl.constexpr):
offsets = tl.arange(0, 128)
x = tl.load(src+offsets, cache_modifier=meta['CACHE'])
x = tl.load(src+offsets, cache_modifier=CACHE)
tl.store(dst+offsets, x)
pgm = _kernel[(1,)](dst, src, CACHE=cache)
ptx = pgm.asm['ptx']
if cache == '':
assert 'ld.global.ca' not in ptx
assert 'ld.global.cg' not in ptx
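Note: the cache-modifier test above relies on another consequence of this change: constexpr arguments are stripped out before the remaining arguments are packed for the GPU, so compile-time parameters may now be Python strings (Kernel._type_name gains a 'str' case later in this commit). A hedged sketch of that pattern (assumes a CUDA device):

import torch
import triton
import triton.language as tl

@triton.jit
def _copy(dst, src, CACHE: tl.constexpr):
    offsets = tl.arange(0, 128)
    x = tl.load(src + offsets, cache_modifier=CACHE)
    tl.store(dst + offsets, x)

src = torch.randn(128, device='cuda')
dst = torch.empty_like(src)
_copy[(1,)](dst, src, CACHE='.cg')  # '', '.ca' and '.cg' are the values exercised above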
@@ -644,7 +639,7 @@ def test_load_cache_modifier(cache):
#----------------
def test_noop(device='cuda'):
@triton.jit
def kernel(**meta):
def kernel(x):
pass
x = triton.testing.random((1,), dtype=torch.int32, device=device)
kernel[(1, )](x)

View File

@@ -21,7 +21,7 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=
}[MODE]
layout = torch.randint(2, (H, shape[0] // BLOCK, shape[1] // BLOCK))
# triton result
op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B)
op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B, device="cuda")
ra = triton.testing.sparsify_tensor(a, layout, BLOCK) if MODE == "dsd" else a
rb = triton.testing.sparsify_tensor(b, layout, BLOCK) if MODE == "dds" else b
rc = triton.testing.catch_oor(lambda: op(ra, rb), pytest)
@@ -151,8 +151,8 @@ def triton_attention(
value: torch.Tensor,
scale: float,
):
sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(layout, block, "sdd", trans_a=False, trans_b=True)
sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(layout, block, "dsd", trans_a=False, trans_b=False)
sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(layout, block, "sdd", trans_a=False, trans_b=True, device=value.device)
sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(layout, block, "dsd", trans_a=False, trans_b=False, device=value.device)
sparse_softmax = triton.ops.blocksparse.softmax(
layout,
block,
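Note: these test changes reflect the reworked blocksparse API. The look-up tables are now built once at construction on an explicit device (the per-dtype/device lut cache is removed further down), and the operator accepts an optional preallocated output. A hedged sketch mirroring the test above (assumes a CUDA device):

import torch
import triton

Z, H, M, N, K, BLOCK = 1, 2, 256, 256, 256, 32
layout = torch.randint(2, (H, M // BLOCK, N // BLOCK))
op = triton.ops.blocksparse.matmul(layout, BLOCK, "sdd",
                                   trans_a=False, trans_b=True, device="cuda")
a = torch.randn((Z, H, M, K), dtype=torch.float16, device="cuda")
b = torch.randn((Z, H, N, K), dtype=torch.float16, device="cuda")
c = op(a, b)          # allocates the sparse output
c = op(a, b, out=c)   # or writes into a preallocated tensor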

View File

@@ -66,8 +66,8 @@ import torch
def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE):
torch.manual_seed(0)
# nuke kernel decorators -- will set meta-parameters manually
META = {'BLOCK_M': BLOCK_M, 'BLOCK_N': BLOCK_N, 'BLOCK_K': BLOCK_K, 'SPLIT_K': SPLIT_K}
configs = [triton.Config(meta=META, num_warps=NWARP, num_stages=NSTAGE)]
kwargs = {'BLOCK_M': BLOCK_M, 'BLOCK_N': BLOCK_N, 'BLOCK_K': BLOCK_K, 'SPLIT_K': SPLIT_K}
configs = [triton.Config(kwargs=kwargs, num_warps=NWARP, num_stages=NSTAGE)]
kernel = triton.ops._matmul.kernel
decorators = kernel.kernel_decorators
kernel.kernel_decorators = []

View File

@@ -103,7 +103,9 @@ class CodeGenerator(ast.NodeVisitor):
arg_values = []
for i, arg_name in enumerate(arg_names):
if i in self.constants:
cst = triton.language.core._to_ir(self.constants[i], self.builder)
cst = self.constants[i]
if not isinstance(cst, triton.language.constexpr):
cst = triton.language.constexpr(self.constants[i])
arg_values.append(cst)
else:
if i in self.attributes:
@@ -114,6 +116,7 @@ class CodeGenerator(ast.NodeVisitor):
fn.add_attr(i + 1, attr)
fn.args[i].name = arg_name
arg_values.append(fn.args[i])
for arg_name, arg_value in zip(arg_names, arg_values):
self.set_value(arg_name, arg_value)
if inline:
@@ -139,6 +142,22 @@ class CodeGenerator(ast.NodeVisitor):
ast.NodeVisitor.generic_visit(self, node)
return node.arg
def visit_AnnAssign(self, node):
# extract attributes
annotation = self.visit(node.annotation)
target = self.visit(node.target)
value = self.visit(node.value)
# constexpr
if annotation == triton.language.constexpr:
if target in self.lscope:
raise ValueError(f'{target} is already defined.'
f' constexpr cannot be reassigned.')
self.lscope[target] = triton.language.constexpr(value)
return self.lscope[target]
# default: call visit_Assign
return self.visit_Assign(node)
def visit_Assign(self, node):
_names = []
for target in node.targets:
@@ -151,6 +170,9 @@ class CodeGenerator(ast.NodeVisitor):
if not isinstance(values, tuple):
values = [values]
for name, value in zip(names, values):
# by default, constexpr are assigned into python variable
if isinstance(value, triton.language.constexpr):
value = value.value
if not isinstance(value, triton.language.block):
value = triton.language.core._to_ir(value, self.builder)
self.set_value(name, value)
@@ -181,6 +203,10 @@ class CodeGenerator(ast.NodeVisitor):
def visit_BinOp(self, node):
lhs = self.visit(node.left)
rhs = self.visit(node.right)
if isinstance(lhs, triton.language.core.constexpr):
lhs = lhs.value
if isinstance(rhs, triton.language.core.constexpr):
rhs = rhs.value
fn = {
ast.Add: '__add__',
ast.Sub: '__sub__',
@@ -195,17 +221,13 @@ class CodeGenerator(ast.NodeVisitor):
ast.BitOr: '__or__',
ast.BitXor: '__xor__',
}[type(node.op)]
kws = dict()
if self.is_triton_object(lhs):
kws['_builder'] = self.builder
ret = getattr(lhs, fn)(rhs, **kws)
if ret is NotImplemented:
if self.is_triton_object(rhs):
kws['_builder'] = self.builder
return getattr(lhs, fn)(rhs, _builder=self.builder)
elif self.is_triton_object(rhs):
fn = fn[:2] + 'r' + fn[2:]
ret = getattr(rhs, fn)(lhs, **kws)
return ret
return getattr(rhs, fn)(lhs, _builder=self.builder)
else:
return getattr(lhs, fn)(rhs)
def visit_If(self, node):
cond = self.visit(node.test)
@@ -254,6 +276,10 @@ class CodeGenerator(ast.NodeVisitor):
assert len(node.ops) == 1
lhs = self.visit(node.left)
rhs = self.visit(node.comparators[0])
if isinstance(lhs, triton.language.core.constexpr):
lhs = lhs.value
if isinstance(rhs, triton.language.core.constexpr):
rhs = rhs.value
fn = {
ast.Eq: '__eq__',
ast.NotEq: '__ne__',
@@ -274,6 +300,8 @@ class CodeGenerator(ast.NodeVisitor):
def visit_UnaryOp(self, node):
op = self.visit(node.operand)
if isinstance(op, triton.language.core.constexpr):
op = op.value
fn = {
ast.USub: '__neg__',
ast.UAdd: '__pos__',
@@ -394,7 +422,7 @@ class CodeGenerator(ast.NodeVisitor):
return fn(*args, **kws)
def visit_Num(self, node):
return node.n
return triton.language.constexpr(node.n)
def visit_Attribute(self, node):
lhs = self.visit(node.value)
@@ -477,6 +505,8 @@ class Kernel:
}
if hasattr(obj, 'data_ptr'):
return type_names[obj.dtype]
if isinstance(obj, triton.language.core.constexpr):
obj = obj.value
if isinstance(obj, int):
if abs(obj) <= 0xffffffff:
return 'I'
@@ -485,6 +515,8 @@ class Kernel:
return 'f'
if isinstance(obj, bool):
return 'B'
if isinstance(obj, str):
return 'str'
assert False
@@ -537,7 +569,8 @@ class Kernel:
def __init__(self, fn):
self.fn = fn
def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages, **meta):
def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages):
wargs = [arg for arg in wargs if not isinstance(arg, triton.language.constexpr)]
# create IR module
context = _triton.ir.context()
# get just-in-time proto-type of kernel
@@ -547,7 +580,7 @@ class Kernel:
# generate Triton-IR
# export symbols visible from self.fn into code-generator object
gscope = sys.modules[self.fn.module].__dict__
generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=meta)
generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=dict())
try:
generator.visit(self.fn.parse())
except Exception as e:
@@ -566,7 +599,19 @@ class Kernel:
raise OutOfResources(shared_mem, max_shared_memory, "shared memory")
return Binary(backend, name, asm, shared_mem, num_warps)
def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **meta):
def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **kwargs):
# handle arguments passed by name
kwargs = {self.fn.arg_names.index(name): value for name, value in kwargs.items()}
wargs = list(wargs)
for i, pos in enumerate(sorted(kwargs)):
wargs.insert(pos + i, kwargs[pos])
if len(wargs) != len(self.fn.arg_names):
raise TypeError(f"Function takes {len(self.fn.arg_names)} positional arguments but {len(wargs)} were given")
# handle annotations
for name, type in self.fn.__annotations__.items():
pos = self.fn.arg_names.index(name)
assert type == triton.language.core.constexpr
wargs[pos] = type(wargs[pos])
# device inference
tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
if len(tensor_idxs) == 0:
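Note: since meta-parameters are now regular kernel parameters, __call__ has to map keyword arguments back to positional slots and wrap anything annotated with tl.constexpr. A standalone sketch of the keyword-to-positional mapping above, using hypothetical argument names:

arg_names = ['x_ptr', 'y_ptr', 'n_elements', 'BLOCK_SIZE']
wargs = [0x1000, 0x2000, 4096]        # arguments passed positionally
kwargs = {'BLOCK_SIZE': 1024}         # meta-parameters passed by name
by_pos = {arg_names.index(name): value for name, value in kwargs.items()}
for i, pos in enumerate(sorted(by_pos)):
    wargs.insert(pos + i, by_pos[pos])
assert wargs == [0x1000, 0x2000, 4096, 1024]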
@@ -601,18 +646,19 @@ class Kernel:
args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)]
attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) \
if isinstance(a, int) and i not in self.fn.do_not_specialize}
# transforms ints whose value is one into constants for just-in-time compilation
constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1}
constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)})
# compute hash for caching this kernel
types_key = Kernel._types_key(*wargs, tensor_idxs=tensor_idxs)
attr_key = tuple(attributes.items())
meta_key = tuple(sorted(meta.items()))
const_key = tuple(constants.items())
compute_capability = torch.cuda.get_device_capability(device)
key = (
self.fn.cache_key, version_key(), compute_capability,
types_key, attr_key, num_warps, num_stages, meta_key, const_key
types_key, attr_key, num_warps, num_stages, const_key
)
key = repr(key)
@@ -644,7 +690,7 @@ class Kernel:
binary = self._compile(
*wargs, device=device_idx, attributes=attributes,
num_warps=num_warps, num_stages=num_stages,
constants=constants, **meta
constants=constants,
)
if bin_cache_path:
assert bin_lock_path is not None
@@ -657,12 +703,15 @@ class Kernel:
drv_cache[key] = LoadedBinary(device_idx, binary)
# pack arguments
fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg) for i, arg in enumerate(wargs)])
params = struct.pack(fmt, *args)
fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg) for i, arg in enumerate(wargs) if not isinstance(arg, triton.language.core.constexpr)])
params = struct.pack(fmt, *[arg for arg in args if not isinstance(arg, triton.language.core.constexpr)])
# enqueue cached function into stream
callable = drv_cache[key]
stream = torch.cuda.current_stream(device_idx).cuda_stream
grid = grid(meta) if hasattr(grid, '__call__') else grid
csts = {self.fn.arg_names[i]: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.core.constexpr)}
grid = grid(csts) if hasattr(grid, '__call__') else grid
if isinstance(grid, int):
grid = tuple(grid)
callable(stream, params, *grid)
return callable
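Note: constexpr arguments never reach the GPU; they are dropped when the parameter buffer is packed and instead feed the grid callable and the compilation cache key. A standalone sketch of that filtering (the 'P' vs 'I' classification below is simplified and hypothetical, as is the pointer value):

import struct
import triton.language as tl

wargs = [0x7f0000000000, 4096, tl.constexpr(1024)]   # pointer, n_elements, BLOCK_SIZE
runtime_args = [a for a in wargs if not isinstance(a, tl.constexpr)]
fmt = ''.join('P' if a > 0xffffffff else 'I' for a in runtime_args)
params = struct.pack(fmt, *runtime_args)             # only two values are sent to the GPU
csts = {'BLOCK_SIZE': wargs[2].value}                # constexprs go to the grid callable
assert fmt == 'PI' and csts == {'BLOCK_SIZE': 1024}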
@@ -697,31 +746,31 @@ class Autotuner:
def _bench(self, *args, config, **meta):
# check for conflicts, i.e. meta-parameters both provided
# as kwargs and by the autotuner
conflicts = meta.keys() & config.meta.keys()
conflicts = meta.keys() & config.kwargs.keys()
if conflicts:
raise ValueError(
f"Conflicting meta-parameters: {', '.join(conflicts)}."
" Make sure that you don't re-define auto-tuned symbols."
)
# augment meta-parameters with tunable ones
current = dict(meta, **config.meta)
current = dict(meta, **config.kwargs)
def kernel_call():
self.hook(args)
self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
return triton.testing.do_bench(kernel_call)
def __call__(self, *args, **meta):
def __call__(self, *args, **kwargs):
if len(self.configs) > 1:
key = tuple([args[i] for i in self.key_idx])
if key not in self.cache:
timings = {config: self._bench(*args, config=config, **meta) \
timings = {config: self._bench(*args, config=config, **kwargs) \
for config in self.configs}
self.cache[key] = builtins.min(timings, key=timings.get)
self.hook(args)
config = self.cache[key]
else:
config = self.configs[0]
return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **meta, **config.meta)
return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
@functools.lru_cache()
@@ -769,6 +818,8 @@ class JITFunction:
# when called with a grid using __getitem__
self.kernel_decorators = []
self.kernel = None
# annotations
self.__annotations__ = fn.__annotations__
# forward docs
self.__doc__ = fn.__doc__
@@ -839,8 +890,8 @@ class Config:
Mostly useful for matrix multiplication workloads on SM80+ GPUs.
:type num_stages: int
"""
def __init__(self, meta, num_warps=4, num_stages=2):
self.meta = meta
def __init__(self, kwargs, num_warps=4, num_stages=2):
self.kwargs = kwargs
self.num_warps = num_warps
self.num_stages = num_stages
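Note: the Config constructor's first argument is renamed from meta to kwargs, matching the fact that tunable values are now forwarded as ordinary keyword arguments of the launch. A sketch (not taken from this diff):

import triton

configs = [
    triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1},
                  num_warps=4, num_stages=2),
    triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1},
                  num_warps=2, num_stages=4),
]
# these would typically be handed to @triton.autotune(configs=configs, key=[...])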

View File

@@ -14,9 +14,11 @@ def _to_ir(x, builder):
return builder.get_int64(x)
elif isinstance(x, float):
return builder.get_float32(x)
if isinstance(x, block):
elif isinstance(x, constexpr):
return _to_ir(x.value, builder)
elif isinstance(x, block):
return x.handle
if isinstance(x, dtype):
elif isinstance(x, dtype):
return x.handle(builder)
return x
@@ -257,6 +259,86 @@ class block:
return frontend.cast(self, dtype, _builder)
# -----------------------
# constexpr
# -----------------------
class constexpr:
"""
This class is used to store a value that is known at compile-time.
"""
def __init__(self, value):
self.value = value
def __add__(self, other):
return self.value + other.value
def __radd__(self, other):
return other.value + self.value
def __sub__(self, other):
return self.value - other.value
def __rsub__(self, other):
return other.value - self.value
def __mul__(self, other):
return self.value * other.value
def __rmul__(self, other):
return other.value * self.value
def __truediv__(self, other):
return self.value / other.value
def __rtruediv__(self, other):
return other.value / self.value
def __floordiv__(self, other):
return self.value // other.value
def __rfloordiv__(self, other):
return other.value // self.value
#
def __gt__(self, other):
return self.value > other.value
def __rgt__(self, other):
return other.value > self.value
def __ge__(self, other):
return self.value >= other.value
def __rge__(self, other):
return other.value >= self.value
def __lt__(self, other):
return self.value < other.value
def __rlt__(self, other):
return other.value < self.value
def __le__(self, other):
return self.value <= other.value
def __rle__(self, other):
return other.value <= self.value
def __eq__(self, other):
return self.value == other.value
def __ne__(self, other):
return self.value != other.value
def __bool__(self):
return bool(self.value)
def __call__(self, *args, **kwds):
return self.value(*args, **kwds)
# -----------------------
# SPMD Programming Model
# -----------------------
@@ -312,7 +394,12 @@ def zeros(shape, dtype, _builder=None):
:param dtype: Data-type of the new array, e.g., :code:`tl.float16`
:type dtype: DType
"""
shape = [int(x.handle) if isinstance(x, block) else x for x in shape]
for i, d in enumerate(shape):
if not isinstance(d, constexpr):
raise TypeError(f"Shape element {i} must have type `constexpr`")
if not isinstance(d.value, int):
raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]")
shape = [x.value for x in shape]
return frontend.zeros(shape, dtype, _builder)
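Note: tl.zeros now requires every shape element to be a constexpr, which is the main reason block sizes must be declared with tl.constexpr annotations. A hedged, self-contained example (assumes a CUDA device):

import torch
import triton
import triton.language as tl

@triton.jit
def _fill_zeros(out_ptr, BLOCK: tl.constexpr):
    acc = tl.zeros((BLOCK,), dtype=tl.float32)   # OK: BLOCK is a constexpr
    tl.store(out_ptr + tl.arange(0, BLOCK), acc)

out = torch.ones(128, device='cuda')
_fill_zeros[(1,)](out, BLOCK=128)
# passing the block size as a plain, un-annotated runtime argument would now be
# rejected by the check above ("Shape element 0 must have type `constexpr`")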

View File

@@ -1,6 +1,5 @@
import triton
import triton.language as tl
import triton._C.libtriton as libtriton
import torch
# ********************************************************
@@ -21,54 +20,46 @@ def _sdd_kernel(
stride_za, stride_ha, stride_ma, stride_ak,
stride_zb, stride_hb, stride_bk, stride_nb,
stride_zc, stride_hc, stride_mc, stride_nc,
K, grid_offset, lut, **meta
K, grid_offset, lut,
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
BLOCK: tl.constexpr, EVEN_K: tl.constexpr
):
TILE_M = meta['TILE_M']
TILE_N = meta['TILE_N']
TILE_K = meta['TILE_K']
BLOCK = meta['BLOCK']
#------------#
#- Prologue -#
#------------#
pid1 = tl.program_id(1) + grid_offset
blockidm = tl.arange(0, TILE_M) // BLOCK
blockidn = tl.arange(0, TILE_N) // BLOCK
offlutm = blockidm * (TILE_N // BLOCK) * 4
offlutn = blockidn * 4
header = lut + pid1 * (TILE_M // BLOCK) * (TILE_N // BLOCK) * 4
# batch offset
off_z = tl.program_id(2)
# head offset
off_h = tl.load(header + 0)
block_id = tl.program_id(1) + grid_offset
lut += block_id * 3
# offsets
off_z = tl.program_id(2) # batch
off_h = tl.load(lut + 0) # head
# initialize pointers to A
start_am = tl.load(header + 1 + offlutm)
start_am = tl.load(lut + 1)
offs_am = start_am * BLOCK + (tl.arange(0, TILE_M) % BLOCK)
offs_ak = tl.arange(0, TILE_K)
a_ptrs = A + off_z * stride_za \
a_ptrs = A + (off_z * stride_za \
+ off_h * stride_ha \
+ offs_am[:, None] * stride_ma \
+ offs_ak[None, :] * stride_ak
+ offs_ak[None, :] * stride_ak)
# initialize pointers to B
start_bn = tl.load(header + 2 + offlutn)
start_bn = tl.load(lut + 2)
offs_bn = start_bn * BLOCK + (tl.arange(0, TILE_N) % BLOCK)
offs_bk = tl.arange(0, TILE_K)
b_ptrs = B + off_z * stride_zb \
b_ptrs = B + (off_z * stride_zb \
+ off_h * stride_hb \
+ offs_bn[None, :] * stride_nb \
+ offs_bk[:, None] * stride_bk
+ offs_bk[:, None] * stride_bk)
## ---------------- ##
## Inner Loop ##
## ---------------- ##
acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
for k in range(K, 0, -TILE_K):
if meta['EVEN_K']:
if EVEN_K:
a = tl.load(a_ptrs)
b = tl.load(b_ptrs)
else:
a = tl.load(a_ptrs, mask=offs_ak[None, :] < k, other=0.)
b = tl.load(b_ptrs, mask=offs_bk[:, None] < k, other=0.)
a = tl.load(a_ptrs)
b = tl.load(b_ptrs)
acc += tl.dot(a, b)
a_ptrs += TILE_K * stride_ak
b_ptrs += TILE_K * stride_bk
@@ -76,22 +67,15 @@ def _sdd_kernel(
## ---------------- ##
## Epilogue ##
## ---------------- ##
blockidm = tl.arange(0, TILE_M) // BLOCK
blockidn = tl.arange(0, TILE_N) // BLOCK
offlutm = blockidm * (TILE_N // BLOCK) * 4
offlutn = blockidn * 4
off_block_id = 3 + offlutm[:, None] + offlutn[None, :]
block_id = tl.load(header + off_block_id)
# initialize pointers to C
offs_cm = tl.arange(0, TILE_M) % BLOCK
offs_cn = tl.arange(0, TILE_N) % BLOCK
pc = C + off_z * stride_zc \
pc = C + (off_z * stride_zc \
+ block_id * stride_hc \
+ offs_cm[:, None] * stride_mc \
+ offs_cn[None, :] * stride_nc
+ offs_cn[None, :] * stride_nc)
tl.store(pc, c, mask=True)
def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, luts, num_locks, widths, packs):
def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out = None):
# (A * B)^T = B^T * A^T
if trans_c:
a, b = b, a
@@ -102,46 +86,28 @@ def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, luts, num_locks,
Ka, Kb = a.shape[a_dim], b.shape[b_dim]
if Ka != Kb:
raise ValueError(f"Inner dimension mismatch (A: {Ka} vs B: {Kb})")
if Ka % 16 != 0:
raise ValueError('Reduction size for SDD must be a multiple of 16')
# allocate output
n_blocks = sum([width * pack * pack for width, pack in zip(widths, packs)])
c = torch.zeros((a.shape[0], n_blocks, block, block), dtype=a.dtype, device=a.device)
# each iteration of the loop below
# computes the value for one group of super-blocks
# (e.g., all 4x4 super-blocks)
for lut, width, pack in zip(luts, widths, packs):
# maximum grid size in Triton/CUDA is 64k but we may have more
# super-blocks than that.
max_grid = 65535
for off_grid in range(0, width, max_grid):
grid = [1, min(max_grid, width - off_grid), c.shape[0]]
# fmt: off
pgm = _sdd_kernel[grid](
a, b, c,
a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3),
b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
c.stride(0), c.stride(1), c.stride(2), c.stride(3),
Ka, off_grid, lut,
TILE_M = block*pack, TILE_N = block*pack, TILE_K = 32, BLOCK = block, num_stages=3,
num_warps=4,
)
# print(pgm.asm['ptx'])
# exit()
if out is None:
c = torch.empty((a.shape[0], lut.shape[0], block, block), dtype=a.dtype, device=a.device)
else:
assert out.shape == (a.shape[0], lut.shape[0], block, block)
c = out
grid = [1, c.shape[1], c.shape[0]]
_sdd_kernel[grid](
a, b, c,
a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3),
b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
c.stride(0), c.stride(1), c.stride(2), c.stride(3),
Ka, 0, lut,
TILE_M = block, TILE_N = block, TILE_K = 32, BLOCK = block, num_stages=4,
num_warps=4,
)
return c
def sdd_lut(layout, block, device):
start_width = 128 // block
layout = layout.type(torch.int32)
superblocks = libtriton.superblock(layout.data_ptr(), layout.shape[0], layout.shape[1], layout.shape[2], start_width)
luts, widths, packs = [], [], []
for size, nnz in superblocks:
nnz = nnz.reshape(-1, 4)
width = nnz.shape[0] // (size * size)
luts.append(torch.from_numpy(nnz).type(torch.int32).to(device))
widths.append(width)
packs.append(size)
return luts, None, widths, packs
lut = layout.nonzero(as_tuple=False).to(device).int()
return lut, None
# -----------------------------
# Dense = Sparse x Dense (DSD)
@@ -154,12 +120,10 @@ def _dsd_kernel(
stride_az, stride_ha, stride_am, stride_ak,
stride_zb, stride_hb, stride_bk, stride_bn,
stride_zc, stride_hc, stride_cm, stride_cn,
DS0, DS1, lut, **meta
DS0, DS1, lut,
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr
):
TILE_M = meta['TILE_M']
TILE_N = meta['TILE_N']
TILE_K = meta['TILE_K']
GROUP_SIZE_M = meta['GROUP_SIZE_M']
#------------#
#- Prologue -#
#------------#
@@ -167,9 +131,9 @@ def _dsd_kernel(
pid_n = tl.program_id(1)
num_pid_m = tl.num_programs(0)
num_pid_n = tl.num_programs(1)
pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M)
pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_SIZE_M)
pidz = tl.program_id(2)
header = lut + pid_m * 4
header = lut + pid_n * 4
offset = tl.load(header + 0)
K = tl.load(header + 1)
column = tl.load(header + 2)
@@ -185,7 +149,8 @@ def _dsd_kernel(
+ offs_am[:, None] * stride_am \
+ offs_ak[None, :] * stride_ak
# initialize pointers to B (dense)
offs_bn = pid_n*TILE_N + tl.arange(0, TILE_N)
offs_bn = pid_m*TILE_N + tl.arange(0, TILE_N)
offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn % DS0, TILE_N), TILE_N)
start_bk = tl.load(pinc)
start_bk = tl.multiple_of(start_bk, 8) # compiler hint
offs_bk = start_bk + tl.arange(0, TILE_K)
@@ -197,28 +162,33 @@ def _dsd_kernel(
## Inner Loop ##
## ---------------- ##
acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
pinc += 2
inc_a = tl.load(pinc + 1)
inc_a = tl.multiple_of(inc_a, 8)
inc_b = tl.load(pinc)
inc_b = tl.multiple_of(inc_b, 8)
for k in range(K, 0, -TILE_K):
a = tl.load(pa, mask=True)
b = tl.load(pb, mask=offs_bn[None, :] < DS0)
acc += tl.dot(a, b)
pa += inc_a
pb += inc_b*stride_bk
pinc += 2
inc_a = tl.load(pinc + 1)
inc_a = tl.multiple_of(inc_a, 8)
inc_b = tl.load(pinc)
inc_b = tl.multiple_of(inc_b, 8)
pa += inc_a
pb += inc_b*stride_bk
c = acc.to(C.dtype.element_ty)
# initialize pointers to C
offs_cm = column*TILE_M + tl.arange(0, TILE_M)
offs_cn = pid_n*TILE_N + tl.arange(0, TILE_N)
offs_cn = pid_m*TILE_N + tl.arange(0, TILE_N)
pc = C + off_h * stride_hc \
+ pidz * stride_zc \
+ offs_cm[:, None] * stride_cm \
+ offs_cn[None, :] * stride_cn
tl.store(pc, c, mask = offs_cn[None, :] < DS0)
def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, width, packs):
def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out = None):
# shapes / dtypes
AS1 = block * spdims[2 if trans_a else 1]
BS0 = b.size(0)
@@ -230,11 +200,15 @@ def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, w
CS1 = BS1
CS2 = BS3 if trans_c else AS1
CS3 = AS1 if trans_c else BS3
c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
if out is None:
c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
else:
assert out.shape == (CS0, CS1, CS2, CS3)
c = out
# meta-parameter heuristics
TILE_N = {16: 256, 32: 256, 64: 128, 128: 128}[block]
TILE_N = 128
# compute output
grid = lambda meta: [width, triton.cdiv(BS3, meta['TILE_N']), BS0]
grid = lambda meta: [triton.cdiv(BS3, meta['TILE_N']), width, BS0]
# fmt: off
_dsd_kernel[grid](
a, b, c,
@@ -242,8 +216,8 @@ def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, w
b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3),
BS3, AS1, lut,
TILE_M = block, TILE_N=TILE_N, TILE_K = min(block, 32), BLOCK = block, num_stages=3,
num_warps=4, GROUP_SIZE_M=8,
TILE_M = block, TILE_N=TILE_N, TILE_K = min(block, 32), BLOCK = block, num_stages=4,
num_warps=4, GROUP_SIZE_M=4,
)
# exit()
return c
@@ -323,7 +297,7 @@ def dsd_lut(layout, block, step, trans, device):
lut = torch.cat((header, incs))
lut = lut.type(torch.int32).to(device)
# create locks
return lut, None, width, None
return lut, width
# -----------------------------
# Dense = Dense x Sparse (DDS)
@@ -334,12 +308,10 @@ def _dds_kernel(
stride_za, stride_ha, stride_ma, stride_ka,
stride_zb, stride_hb, stride_bk, stride_bn,
stride_zc, stride_hc, stride_mc, stride_nc,
DS0, DS1, lut, **meta
DS0, DS1, lut,
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr,
):
TILE_M = meta['TILE_M']
TILE_N = meta['TILE_N']
TILE_K = meta['TILE_K']
GROUP_SIZE_M = meta['GROUP_SIZE_M']
#------------#
#- Prologue -#
#------------#
@@ -347,16 +319,17 @@ def _dds_kernel(
pid_n = tl.program_id(1)
num_pid_m = tl.num_programs(0)
num_pid_n = tl.num_programs(1)
pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M)
pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_SIZE_M)
pid_z = tl.program_id(2)
header = lut + pid_m * 4
header = lut + pid_n * 4
offset = tl.load(header + 0)
AS1 = tl.load(header + 1)
column = tl.load(header + 2)
off_h = tl.load(header + 3)
pinc = lut + offset
# initialize pointers to A (dense)
offs_am = pid_n*TILE_M + tl.arange(0, TILE_M)
offs_am = pid_m*TILE_M + tl.arange(0, TILE_M)
offs_am = tl.max_contiguous(tl.multiple_of(offs_am % DS0, TILE_M), TILE_M)
start_ak = tl.load(pinc)
start_ak = tl.multiple_of(start_ak, 8)
offs_ak = start_ak + tl.arange(0, TILE_K)
@@ -394,7 +367,7 @@ def _dds_kernel(
## ---------------- ##
c = acc.to(C.dtype.element_ty)
# initialize pointers to C (dense)
offs_cm = pid_n * TILE_M + tl.arange(0, TILE_M)
offs_cm = pid_m * TILE_M + tl.arange(0, TILE_M)
offs_cn = column * TILE_N + tl.arange(0, TILE_N)
ptrs_c = C + off_h * stride_hc \
+ pid_z * stride_zc \
@@ -403,7 +376,7 @@ def _dds_kernel(
# write back
tl.store(ptrs_c, c, mask = offs_cm[:, None] < DS0)
def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, width, packs):
def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out = None):
# shapes / dtypes
AS0 = a.size(0)
AS1 = a.size(1)
@@ -415,9 +388,13 @@ def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, w
CS1 = AS1
CS2 = BS2 if trans_c else AS2
CS3 = AS2 if trans_c else BS2
c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
if out is None:
c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
else:
assert out.shape == (CS0, CS1, CS2, CS3)
c = out
TILE_M = {16: 256, 32: 256, 64: 128, 128: 128}[block]
grid = lambda meta: [width, triton.cdiv(AS2, meta['TILE_M']), AS0]
grid = lambda meta: [triton.cdiv(AS2, meta['TILE_M']), width, AS0]
# fmt: off
_dds_kernel[grid](
a, b, c,
@@ -425,8 +402,8 @@ def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, w
b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3),
AS2, BS2, lut,
TILE_M = TILE_M, TILE_N = block, TILE_K = min(block, 32), BLOCK = block, num_stages=3,
num_warps=4, GROUP_SIZE_M=8,
TILE_M = TILE_M, TILE_N = block, TILE_K = min(block, 32), BLOCK = block, num_stages=4,
num_warps=4, GROUP_SIZE_M=4,
)
return c
@@ -439,25 +416,23 @@ class _matmul(torch.autograd.Function):
@staticmethod
def forward(
ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block, c_lut, c_num_locks, c_width, c_packs, da_lut, da_num_locks,
da_width, da_packs, db_lut, db_num_locks, db_width, db_packs
ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block,
c_lut, c_width, da_lut, da_width, db_lut, db_width, out
):
c = _matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block, c_lut, c_num_locks, c_width, c_packs)
c = _matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block, c_lut, c_width, out=out)
# save for backward
ctx.save_for_backward(a, b)
ctx.da_num_locks = da_num_locks
ctx.da_lut = da_lut
ctx.da_width = da_width
ctx.da_packs = da_packs
ctx.db_lut = db_lut
ctx.db_num_locks = db_num_locks
ctx.db_width = db_width
ctx.db_packs = db_packs
ctx.mode = mode
ctx.spdims = spdims
ctx.block = block
ctx.trans_a = trans_a
ctx.trans_b = trans_b
ctx.trans_c = trans_c
ctx.has_out = out is not None
return c
@staticmethod
@@ -466,155 +441,55 @@ class _matmul(torch.autograd.Function):
a, b = ctx.saved_tensors
da, db = None, None
mode = ctx.mode
# gradients w.r.t. a
if ctx.needs_input_grad[0]:
mode_da = mode[1] + mode[0] + mode[2]
da = _matmul.fn[mode_da](
dc, b, False, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block, ctx.da_lut, ctx.da_num_locks, ctx.da_width,
ctx.da_packs
dc, b, ctx.trans_c, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block, ctx.da_lut, ctx.da_width,
)
# gradients w.r.t. b
if ctx.needs_input_grad[1]:
mode_db = mode[2] + mode[1] + mode[0]
db = _matmul.fn[mode_db](
a, dc, not ctx.trans_a, False, ctx.trans_b, ctx.spdims, ctx.block, ctx.db_lut, ctx.db_num_locks, ctx.db_width,
ctx.db_packs
a, dc, not ctx.trans_a, ctx.trans_c, ctx.trans_b, ctx.spdims, ctx.block, ctx.db_lut, ctx.db_width,
)
dout = dc if ctx.has_out else None
return da, db, None, None, None,\
None, None, None, None,\
None, None, None, None, None, None,\
None, None, None, None, None, None,\
None, None, None, None, None, None
None, None, None, None, None, dout
class matmul:
def make_lut(self, dtype, device):
key = (dtype, device)
if key in self.lut_cache:
return self.lut_cache[key]
# C look-up table
layout, block = self.layout, self.block
step = min(block, 32)
if self.mode == 'sdd':
c_lut, c_num_locks, c_width, c_packs = sdd_lut(layout, block, device)
elif self.mode == 'dsd':
c_lut, c_num_locks, c_width, c_packs = dsd_lut(layout, block, step, not self.trans_a, device)
elif self.mode == 'dds':
c_lut, c_num_locks, c_width, c_packs = dsd_lut(layout, block, step, self.trans_b, device)
# DA look-up table
if self.mode == 'sdd':
da_lut, da_num_locks, da_width, da_packs = dsd_lut(layout, block, step, True, device)
elif self.mode == 'dsd':
da_lut, da_num_locks, da_width, da_packs = sdd_lut(layout, block, device)
elif self.mode == 'dds':
da_lut, da_num_locks, da_width, da_packs = dsd_lut(layout, block, step, not self.trans_b, device)
# DB look-up table
if self.mode == 'sdd':
db_lut, db_num_locks, db_width, db_packs = dsd_lut(layout, block, step, False, device)
elif self.mode == 'dsd':
db_lut, db_num_locks, db_width, db_packs = dsd_lut(layout, block, step, self.trans_a, device)
elif self.mode == 'dds':
db_lut, db_num_locks, db_width, db_packs = sdd_lut(layout, block, device)
self.lut_cache[key] = (c_lut, c_num_locks, c_width, c_packs,
da_lut, da_num_locks, da_width, da_packs,
db_lut, db_num_locks, db_width, db_packs)
return self.lut_cache[key]
def __init__(self, layout, block, mode, trans_a=False, trans_b=False):
def __init__(self, layout, block, mode, device, trans_a=False, trans_b=False, trans_c=False):
if mode not in ['sdd', 'dsd', 'dds']:
raise NotImplementedError('Supported modes are: sdd, dsd, dds')
# look-up table cache
self.lut_cache = dict()
# attributes
self.block = block
self.mode = mode
self.trans_a = trans_a
self.trans_b = trans_b
layout_dim = layout.ndim
assert layout_dim in (2, 3), "Layout should be a 2 or 3 dimensional tensor of 0s and 1s"
if not mode == 'sdd':
# Dims to be reduced on the 'inside' of the matmul, either -1 or -2
trans_dense, trans_sparse, sparse_inner = (trans_b, trans_a, -1) if mode == 'dsd' else (trans_a, trans_b, -2)
self.dense_inner_dim = -((sparse_inner % 2) + 1) if not trans_dense else sparse_inner
sparse_inner = sparse_inner if not trans_sparse else -((sparse_inner % 2) + 1)
# Inner dim of the dense input should be equal to the inner dim of the sparse input
self.dense_inner_size = layout.shape[sparse_inner] * block
# Expected shape for sparse inputs
self.sparse_shape = (layout.sum().item(), block, block)
# Support using the same layout across attention heads etc.
if layout_dim == 2:
layout = layout.unsqueeze(0)
layout = layout.long() # Above code assumes the layout tensor is an integral type
self.trans_c = trans_c
self.layout = layout
self.spdims = layout.shape
step = min(block, 32)
if self.mode == 'sdd':
self.c_lut, self.c_width = sdd_lut(layout, block, device)
self.da_lut, self.da_width = dsd_lut(layout, block, step, True, device)
self.db_lut, self.db_width = dsd_lut(layout, block, step, False, device)
if self.mode == 'dsd':
self.c_lut, self.c_width = dsd_lut(layout, block, step, not self.trans_a, device)
self.da_lut, self.da_width = sdd_lut(layout, block, device)
self.db_lut, self.db_width = dsd_lut(layout, block, step, self.trans_a, device)
if self.mode == 'dds':
self.c_lut, self.c_width = dsd_lut(layout, block, step, self.trans_b, device)
self.da_lut, self.da_width = dsd_lut(layout, block, step, not self.trans_b, device)
self.db_lut, self.db_width = sdd_lut(layout, block, device)
def __call__(self, a, b):
c_lut, c_num_locks, c_width, c_packs,\
da_lut, da_num_locks, da_width, da_packs,\
db_lut, db_num_locks, db_width, db_packs = self.make_lut(a.dtype, a.device)
# If we don't check for invalid shapes, devices, & dtypes here, they will lead to undefined behavior
# and potential illegal memory accesses
original_dims = max(a.ndim, b.ndim)
a, b = self._validate_inputs(a, b)
# execute
def __call__(self, a, b, out = None):
c = _matmul.apply(
a, b, self.trans_a, self.trans_b, False, self.mode, self.spdims, self.block, c_lut, c_num_locks, c_width,
c_packs, da_lut, da_num_locks, da_width, da_packs, db_lut, db_num_locks, db_width, db_packs
a, b, self.trans_a, self.trans_b, self.trans_c, self.mode, self.spdims, self.block,
self.c_lut, self.c_width,
self.da_lut, self.da_width,
self.db_lut, self.db_width,
out
)
# This removes any leading singleton dimensions we may have added to the tensor that weren't in the input
dims_to_trim = c.ndim - original_dims
for _ in range(dims_to_trim):
c = c.squeeze(0)
return c
def _validate_inputs(self, a, b):
if a.device != b.device:
raise ValueError(f"Inputs must be on the same device; got {a.device} for tensor A "
f"and {b.device} for tensor B")
if not a.is_cuda:
raise ValueError("Only GPU devices are supported for now")
# When autocast is enabled, torch.matmul autocasts to float16, so we do the same here
if torch.is_autocast_enabled():
a, b = a.half(), b.half()
elif a.dtype != b.dtype:
raise ValueError(f"Inputs must be the same dtype; got {a.dtype} for A and {b.dtype} for B")
mode, trans_a, trans_b = self.mode, self.trans_a, self.trans_b
if mode != 'sdd':
# One input is sparse
dense, dense_name, sparse, sparse_name = (a, 'A', b, 'B') if mode == 'dds' else (b, 'B', a, 'A')
dense_inner = dense.shape[self.dense_inner_dim]
if dense_inner != self.dense_inner_size:
raise ValueError(f"Expected tensor {dense_name} to have size {self.dense_inner_size} at dim "
f"{self.dense_inner_dim % dense.ndim}, got {dense_inner}.")
if sparse.shape[-len(self.sparse_shape):] != self.sparse_shape:
raise ValueError(f"Expected tensor with trailing dimensions of shape {self.sparse_shape} for argument "
f"{sparse_name}, got {sparse.shape}")
def add_extra_dims(x):
# Add extra leading singleton dimensions if needed
dims_needed = 4 - x.ndim
if dims_needed > 0:
singletons = [1] * dims_needed
x = x.view(*singletons, *x.shape)
elif dims_needed < 0:
raise ValueError("Tensors with more than 4 dimensions are not currently supported")
return x
# Pad shapes with leading singleton dimensions
a = add_extra_dims(a)
b = add_extra_dims(b)
return a, b
return c

View File

@@ -16,10 +16,9 @@ def num_warps(n):
@triton.jit
def _forward(
X, scale, LUT, RPE, KP_M, ATTN_M, is_causal, sizemax, stride_zx, stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm,
**meta
TN: tl.constexpr, BLOCK: tl.constexpr, APPLY_SCALE: tl.constexpr, APPLY_RPE: tl.constexpr, APPLY_KP_MASK: tl.constexpr,
KP_MASK_MUL: tl.constexpr, APPLY_ATTN_MASK: tl.constexpr, ATTN_MASK_MUL: tl.constexpr,
):
TN = meta['TN']
BLOCK = meta['BLOCK']
pidhm = tl.program_id(0)
pidz = tl.program_id(1)
# create index ranges
@@ -43,25 +42,25 @@ def _forward(
x = tl.load(px, mask=check, other=-float('inf'))
x = x.to(tl.float32)
# apply scale
if meta['APPLY_SCALE']:
if APPLY_SCALE:
x = x * scale
# apply RPE
if meta['APPLY_RPE']:
if APPLY_RPE:
prpe = RPE + pidz * stride_zrpe + headid * stride_hrpe + columnid * BLOCK + rowid * BLOCK * stride_srpe + rxm * stride_srpe + rxn
rpe = tl.load(prpe, mask=check, other=0)
x = x + rpe
# apply key-padding mask
if meta['APPLY_KP_MASK']:
if APPLY_KP_MASK:
pkp_m = KP_M + pidz * stride_zkpm + columnid * BLOCK + rxn
kp_m = tl.load(pkp_m, mask=check, other=-float('inf'))
if meta['KP_MASK_MUL']:
if KP_MASK_MUL:
kp_m = tl.where(kp_m == 0, -float('inf'), 0.)
x = x + kp_m
# apply attention mask
if meta['APPLY_ATTN_MASK']:
if APPLY_ATTN_MASK:
pattn_m = ATTN_M + columnid * BLOCK + rowid * BLOCK * stride_zattnm + rxm * stride_zattnm + rxn
attn_m = tl.load(pattn_m, mask=check, other=-float('inf'))
if meta['ATTN_MASK_MUL']:
if ATTN_MASK_MUL:
attn_m = tl.where(attn_m == 0, -float('inf'), 0.)
x = x + attn_m
# apply causal mask
@@ -75,11 +74,9 @@ def _forward(
@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[4] * meta['BLOCK'])})
@triton.heuristics({'TN': lambda *args, **meta: triton.next_power_of_2(args[4]) * meta['BLOCK']})
@triton.jit
def _backward(X, scale, DX, LUT, sizemax, stride_zx, stride_zdx, **meta):
def _backward(X, scale, DX, LUT, sizemax, stride_zx, stride_zdx, TN: tl.constexpr, BLOCK: tl.constexpr):
pidhm = tl.program_id(0)
pidz = tl.program_id(1)
TN = meta['TN']
BLOCK = meta['BLOCK']
# create index ranges
rxm = pidhm % BLOCK
rbm = pidhm // BLOCK
@@ -172,8 +169,7 @@ class _softmax(torch.autograd.Function):
APPLY_KP_MASK = apply_kp_mask,
APPLY_ATTN_MASK = apply_attn_mask,
KP_MASK_MUL = (kp_mask_mode == 'mul'),
ATTN_MASK_MUL = (attn_mask_mode == 'mul'),
force_nc_cache = True)
ATTN_MASK_MUL = (attn_mask_mode == 'mul'))
# save to context
ctx.mark_dirty(x)
ctx.save_for_backward(x, lut)
@@ -196,7 +192,7 @@ class _softmax(torch.autograd.Function):
# run kernel
M = x.shape[0]
grid = lambda opt: [ctx.spdims[0] * ctx.spdims[1] * ctx.block, M]
_backward[grid](x, ctx.scale, dx, lut, ctx.maxlut, x.stride(0), dx.stride(0), force_nc_cache=True, BLOCK=ctx.block)
_backward[grid](x, ctx.scale, dx, lut, ctx.maxlut, x.stride(0), dx.stride(0), BLOCK=ctx.block)
return dx, None, None, None, None, None, None, None, None, None, None, None, None, None, None

View File

@@ -26,8 +26,7 @@ def num_warps(N):
@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[4])})
@triton.heuristics({'BLOCK': lambda *args, **meta: next_power_of_2(args[4])})
@triton.jit
def _forward(LOGITS, PROBS, IDX, LOSS, N, **meta):
BLOCK = meta['BLOCK']
def _forward(LOGITS, PROBS, IDX, LOSS, N, BLOCK: tl.constexpr):
row = tl.program_id(0)
cols = tl.arange(0, BLOCK)
idx = tl.load(IDX + row)
@@ -52,8 +51,7 @@ def _forward(LOGITS, PROBS, IDX, LOSS, N, **meta):
@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[3])})
@triton.heuristics({'BLOCK': lambda *args, **meta: next_power_of_2(args[3])})
@triton.jit
def _backward(PROBS, IDX, DPROBS, N, **meta):
BLOCK = meta['BLOCK']
def _backward(PROBS, IDX, DPROBS, N, BLOCK: tl.constexpr):
row = tl.program_id(0)
cols = tl.arange(0, BLOCK)
idx = tl.load(IDX + row)

View File

@@ -26,13 +26,9 @@ def _kernel(A, B, C, M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
LOCKS, **META):
# extract meta-parameters
BLOCK_M = META['BLOCK_M']
BLOCK_N = META['BLOCK_N']
BLOCK_K = META['BLOCK_K']
GROUP_M = META['GROUP_M']
SPLIT_K = META['SPLIT_K']
LOCKS,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr):
# matrix multiplication
pid = tl.program_id(0)
pid_z = tl.program_id(1)
@@ -55,7 +51,7 @@ def _kernel(A, B, C, M, N, K,
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(K, 0, -BLOCK_K*SPLIT_K):
if META['EVEN_K']:
if EVEN_K:
a = tl.load(A)
b = tl.load(B)
else:
@@ -113,14 +109,11 @@ class _matmul(torch.autograd.Function):
locks = _matmul._locks[device]
# launch kernel
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
_kernel[grid](a, b, c,
M, N, K,
_kernel[grid](a, b, c, M, N, K,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
locks,
GROUP_M=8)
# done
locks, GROUP_M=8)
return c
@staticmethod

View File

@@ -12,6 +12,7 @@ In this tutorial, you will write a simple vector addition using Triton and learn
# Compute Kernel
# --------------------------
from triton.language.core import constexpr
import torch
import triton
import triton.language as tl
@@ -23,9 +24,9 @@ def add_kernel(
y_ptr, # *Pointer* to second input vector
output_ptr, # *Pointer* to output vector
n_elements, # Size of the vector
**meta, # Optional meta-parameters for the kernel
BLOCK_SIZE: tl.constexpr, # Number of elements each program should process
# NOTE: `constexpr` so it can be used as a shape value
):
BLOCK_SIZE = meta['BLOCK_SIZE'] # How many inputs each program should process
# There are multiple 'program's processing different data. We identify which program
# we are here
pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0
@@ -37,8 +38,8 @@ def add_kernel(
offsets = block_start + tl.arange(0, BLOCK_SIZE)
# Create a mask to guard memory operations against out-of-bounds accesses
mask = offsets < n_elements
# Load x and y from DRAM, masking out any extar elements in case the input is not a
# multiple of the block size
# Load x and y from DRAM, masking out any extra elements in case
# the input is not a multiple of the block size
x = tl.load(x_ptr + offsets, mask=mask)
y = tl.load(y_ptr + offsets, mask=mask)
output = x + y
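Note: the tutorial's matching host-side wrapper now passes BLOCK_SIZE as a keyword at launch time; a sketch consistent with the kernel above (the full wrapper is outside this hunk):

def add(x: torch.Tensor, y: torch.Tensor):
    output = torch.empty_like(x)
    n_elements = output.numel()
    # the grid callable receives the kernel's constexpr values by name
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    return output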

View File

@@ -65,11 +65,11 @@ import triton.language as tl
@triton.jit
def softmax_kernel(
output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, **meta
output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols,
BLOCK_SIZE: tl.constexpr
):
# The rows of the softmax are independent, so we parallelize across those
row_idx = tl.program_id(0)
BLOCK_SIZE = meta['BLOCK_SIZE']
# The stride represents how much we need to increase the pointer to advance 1 row
row_start_ptr = input_ptr + row_idx * input_row_stride
# The block size is the next power of two greater than n_cols, so we can fit each
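Note: the corresponding launch picks BLOCK_SIZE as the next power of two that covers one row and passes it as a constexpr keyword; a sketch consistent with the kernel above (the tutorial's full wrapper, including its num_warps heuristic, is outside this hunk):

def softmax(x: torch.Tensor):
    n_rows, n_cols = x.shape
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    y = torch.empty_like(x)
    softmax_kernel[(n_rows,)](
        y, x, x.stride(0), y.stride(0), n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return y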

View File

@@ -182,17 +182,13 @@ def matmul_kernel(
stride_bk, stride_bn,
stride_cm, stride_cn,
# Meta-parameters
**meta,
):
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
ACTIVATION: tl.constexpr,
):
"""Kernel for computing the matmul C = A x B.
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
"""
# extract meta-parameters
BLOCK_SIZE_M = meta['BLOCK_SIZE_M']
BLOCK_SIZE_N = meta['BLOCK_SIZE_N']
BLOCK_SIZE_K = meta['BLOCK_SIZE_K']
GROUP_SIZE_M = 8
# -----------------------------------------------------------
# Map program ids `pid` to the block of C it should compute.
# This is done in a grouped ordering to promote L2 data reuse