[LANG] Added support for constexpr (#361)

Author: Philippe Tillet
Date: 2021-10-30 00:32:58 -07:00
Committed by: GitHub
Parent: 770ea96cca
Commit: 2acaa4d0dd
16 changed files with 355 additions and 365 deletions
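For context, a minimal sketch of the user-facing change this commit enables (the kernel and launch below are illustrative, not taken from the diff): meta-parameters that were previously passed through `**meta` are now declared as `tl.constexpr` arguments and supplied like ordinary keyword arguments, while the grid callable still receives their values as a dict.

```python
import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements,
               BLOCK_SIZE: tl.constexpr):  # compile-time constant, no **meta
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

x = torch.randn(4096, device='cuda')
y = torch.randn(4096, device='cuda')
out = torch.empty_like(x)
# the grid callable is given the constexpr values, keyed by argument name
grid = lambda args: (triton.cdiv(x.numel(), args['BLOCK_SIZE']),)
add_kernel[grid](x, y, out, x.numel(), BLOCK_SIZE=1024)
```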


@@ -198,8 +198,6 @@ scanline_layout::scanline_layout(size_t num_warps,
   bool is_dot = std::any_of(values.begin(), values.end(),
                             [&](ir::value* v) { return dynamic_cast<ir::dot_inst*>(v); });
   std::vector<ir::value*> ptrs;
   for(ir::value *v: values)
     for(ir::user *usr: v->get_users())
@@ -215,7 +213,6 @@ scanline_layout::scanline_layout(size_t num_warps,
       contiguous = std::max<int>(contiguous, std::min<int>(align->get(ptr, i), 128 / nbits));
     }
     nts_[i] = clamp(size / num_threads, 1, std::min<int>(contiguous, shape_[i]));
     mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]);
     size /= shape_[i];


@@ -77,7 +77,6 @@ void coalesce::run(ir::module &mod) {
       builder.insert(new_x);
       x->replace_all_uses_with(new_x);
       new_x->replace_uses_of_with(new_x, x);
-      // new_x->replace_uses_of_with(new_x, new_x);
     }
   }
   for(ir::function *fn: mod.get_function_list())
@@ -101,6 +100,8 @@ void coalesce::run(ir::module &mod) {
     ir::instruction* curr = queue.back();
     seen.insert(curr);
     queue.pop_back();
+    if(auto dot_inst = dynamic_cast<ir::dot_inst*>(curr))
+      break;
     if(auto io_inst = dynamic_cast<ir::io_inst*>(curr)){
       in_contig = align_->contiguous(io_inst->get_pointer_operand());
       break;


@@ -178,7 +178,7 @@ std::string ptx_to_cubin(const std::string& ptx, int cc) {
   ofs.close();
   std::string cmd;
   int err;
-  cmd = ptxas + " -v --gpu-name=sm_" + std::to_string(cc) + " " + fsrc + " -o " + fsrc + ".o";
+  cmd = ptxas + " -v --gpu-name=sm_" + std::to_string(cc) + " " + fsrc + " -o " + fsrc + ".o 2> " + flog;
   err = system(cmd.c_str());
   CUmodule ret;
   std::ifstream _cubin(_fbin, std::ios::binary );


@@ -67,8 +67,8 @@ def test_matmul(M, N, K):
 import triton.language as tl
 @triton.jit
-def _add(x_ptr, y_ptr, output_ptr, n_elements, **meta):
-    BLOCK_SIZE = meta['BLOCK_SIZE']
+def _add(x_ptr, y_ptr, output_ptr, n_elements,
+         BLOCK_SIZE: tl.constexpr):
     pid = tl.program_id(axis=0)
     block_start = pid * BLOCK_SIZE
     offsets = block_start + tl.arange(0, BLOCK_SIZE)
@@ -99,7 +99,7 @@ def test_elementwise(N):
     z = torch.empty((N, ), dtype=torch.float16, device='cuda')
     x = torch.randn_like(z)
     y = torch.randn_like(z)
-    grid = lambda meta: (triton.cdiv(N, meta['BLOCK_SIZE']), )
+    grid = lambda args: (triton.cdiv(N, args['BLOCK_SIZE']), )
     fn = lambda: _add[grid](x, y, z, N, BLOCK_SIZE=1024)
     ms = triton.testing.do_bench(fn, percentiles=None, warmup=10, rep=250)
     cur_gpu_perf = 3.*N*z.element_size()/ms*1e-6


@@ -40,7 +40,7 @@ def patch_kernel(template, to_replace):
 def test_empty_kernel(dtype_x, device='cuda'):
     SIZE = 128
     @triton.jit
-    def kernel(X, **meta):
+    def kernel(X, SIZE: tl.constexpr):
         pass
     x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device)
     kernel[(1, )](x, SIZE=SIZE, num_warps=4)
@@ -50,8 +50,8 @@ def _test_unary(dtype_x, expr, torch_expr=None, device='cuda'):
     SIZE = 128
     # define the kernel / launch-grid
     @triton.jit
-    def kernel(Z, X, **meta):
-        off = tl.arange(0, meta['SIZE'])
+    def kernel(Z, X, SIZE: tl.constexpr):
+        off = tl.arange(0, SIZE)
         x = tl.load(X + off)
         z = GENERATE_TEST_HERE
         tl.store(Z + off, z)
@@ -73,8 +73,8 @@ def _test_binary(dtype_x, dtype_y, expr, mode_x='real', mode_y='real', device='c
     SIZE = 128
     # define the kernel / launch-grid
     @triton.jit
-    def kernel(Z, X, Y, **meta):
-        off = tl.arange(0, meta['SIZE'])
+    def kernel(Z, X, Y, SIZE: tl.constexpr):
+        off = tl.arange(0, SIZE)
         x = tl.load(X + off)
         y = tl.load(Y + off)
         z = GENERATE_TEST_HERE
@@ -203,8 +203,7 @@ def test_index1d(expr, device='cuda'):
     # Triton kernel
     @triton.jit
-    def kernel(Z, X, **meta):
-        SIZE = meta['SIZE']
+    def kernel(Z, X, SIZE: tl.constexpr):
         m = tl.arange(0, SIZE)
         n = tl.arange(0, SIZE)
         x = tl.load(X_PTR_EXPR)
@@ -290,7 +289,7 @@ def test_atomic_rmw(op, dtype_x, mode, device='cuda'):
     # triton kernel
     @triton.jit
-    def kernel(X, Z, **meta):
+    def kernel(X, Z):
         pid = tl.program_id(0)
         x = tl.load(X + pid)
         old = GENERATE_TEST_HERE
@@ -344,9 +343,9 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
     # triton kernel
     @triton.jit
-    def kernel(X, Z, **meta):
+    def kernel(X, Z, BITCAST: tl.constexpr):
         x = tl.load(X)
-        z = x.to(Z.dtype.element_ty, bitcast=meta['BITCAST'])
+        z = x.to(Z.dtype.element_ty, bitcast = BITCAST)
         tl.store(Z, z)
     # triton result
@@ -373,8 +372,8 @@ def test_reduce1d(dtype, shape, device='cuda'):
     # triton kernel
     @triton.jit
-    def kernel(X, Z, **meta):
-        x = tl.load(X + tl.arange(0, meta['BLOCK']))
+    def kernel(X, Z, BLOCK: tl.constexpr):
+        x = tl.load(X + tl.arange(0, BLOCK))
         tl.store(Z, tl.sum(x, axis=0))
     x = triton.testing.random((shape,), dtype=dtype, device=device)
@@ -395,11 +394,11 @@ def test_reduce2d(dtype, shape, axis, device='cuda'):
     dtype = cvt[dtype]
     # triton kernel
     @triton.jit
-    def kernel(X, Z, **meta):
-        range_m = tl.arange(0, meta['BLOCK_M'])
-        range_n = tl.arange(0, meta['BLOCK_N'])
-        x = tl.load(X + range_m[:, None]*meta['BLOCK_N'] + range_n[None, :])
-        z = tl.sum(x, axis=meta['AXIS'])
+    def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr):
+        range_m = tl.arange(0, BLOCK_M)
+        range_n = tl.arange(0, BLOCK_N)
+        x = tl.load(X + range_m[:, None]*BLOCK_N + range_n[None, :])
+        z = tl.sum(x, axis=AXIS)
         tl.store(Z + range_m, z)
     # input
     x = triton.testing.random(shape, dtype=dtype, device=device)
@@ -429,9 +428,8 @@ def test_permute(dtype, shape, perm, device='cuda'):
     # triton kernel
     @triton.jit
     def kernel(X, stride_xm, stride_xn,
-               Z, stride_zm, stride_zn, **meta):
-        BLOCK_M = meta['BLOCK_M']
-        BLOCK_N = meta['BLOCK_N']
+               Z, stride_zm, stride_zn,
+               BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
         off_m = tl.arange(0, BLOCK_M)
         off_n = tl.arange(0, BLOCK_N)
         Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * stride_xn
@@ -464,10 +462,9 @@ def test_dot(epilogue, device='cuda'):
     @triton.jit
     def kernel(X, stride_xm, stride_xk,
                Y, stride_yk, stride_yn,
-               Z, stride_zm, stride_zn, **meta):
-        BLOCK_M = meta['BLOCK_M']
-        BLOCK_K = meta['BLOCK_K']
-        BLOCK_N = meta['BLOCK_N']
+               Z, stride_zm, stride_zn,
+               BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+               ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr):
         off_m = tl.arange(0, BLOCK_M)
         off_n = tl.arange(0, BLOCK_N)
         off_k = tl.arange(0, BLOCK_K)
@@ -475,12 +472,12 @@ def test_dot(epilogue, device='cuda'):
         Ys = Y + off_k[:, None] * stride_yk + off_n[None, :] * stride_yn
         Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
         z = tl.dot(tl.load(Xs), tl.load(Ys))
-        if meta['ADD_MATRIX']:
+        if ADD_MATRIX:
             z += tl.load(Zs)
-        if meta['ADD_ROWS']:
+        if ADD_ROWS:
             ZRs = Z + off_m * stride_zm
             z += tl.load(ZRs)[:, None]
-        if meta['ADD_COLS']:
+        if ADD_COLS:
             ZCs = Z + off_n * stride_zn
             z += tl.load(ZCs)[None, :]
         tl.store(Zs, z)
@@ -517,7 +514,7 @@ def test_dot(epilogue, device='cuda'):
 def test_dot_without_load():
     @triton.jit
-    def kernel(out, **meta):
+    def kernel(out):
         pid = tl.program_id(axis=0)
         a = tl.zeros((32, 32), tl.float32)
         b = tl.zeros((32, 32), tl.float32)
@@ -538,9 +535,10 @@ def test_arange(start, device='cuda'):
     BLOCK = 128
     z_tri = torch.empty(BLOCK, dtype=torch.int32, device=device)
     @triton.jit
-    def _kernel(z, **meta):
-        off = tl.arange(0, meta['BLOCK'])
-        val = tl.arange(meta['START'], meta['END'])
+    def _kernel(z, BLOCK: tl.constexpr,
+                START: tl.constexpr, END: tl.constexpr):
+        off = tl.arange(0, BLOCK)
+        val = tl.arange(START, END)
         tl.store(z + off, val)
     _kernel[(1,)](z_tri, START=start, END=start+BLOCK, BLOCK=BLOCK)
     z_ref = torch.arange(start, BLOCK+start, dtype=torch.int32, device=device)
@@ -564,10 +562,8 @@ def test_masked_load_shared_memory(dtype, device='cuda'):
     @triton.jit
     def _kernel(in1_ptr, in2_ptr, output_ptr,
                 in_stride, in2_stride, out_stride,
-                in_numel, in2_numel, out_numel, **meta):
-        M = meta['M']
-        N = meta['N']
-        K = meta['K']
+                in_numel, in2_numel, out_numel,
+                M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
         M_offsets = tl.arange(0, M)
         N_offsets = tl.arange(0, N)
@@ -605,14 +601,13 @@ def test_load_cache_modifier(cache):
     dst = torch.empty(128, device='cuda')
     @triton.jit
-    def _kernel(dst, src, **meta):
+    def _kernel(dst, src, CACHE: tl.constexpr):
         offsets = tl.arange(0, 128)
-        x = tl.load(src+offsets, cache_modifier=meta['CACHE'])
+        x = tl.load(src+offsets, cache_modifier=CACHE)
         tl.store(dst+offsets, x)
     pgm = _kernel[(1,)](dst, src, CACHE=cache)
     ptx = pgm.asm['ptx']
     if cache == '':
         assert 'ld.global.ca' not in ptx
         assert 'ld.global.cg' not in ptx
@@ -644,7 +639,7 @@ def test_load_cache_modifier(cache):
 #----------------
 def test_noop(device='cuda'):
     @triton.jit
-    def kernel(**meta):
+    def kernel(x):
         pass
     x = triton.testing.random((1,), dtype=torch.int32, device=device)
     kernel[(1, )](x)
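Because constexpr values are folded into the compilation cache key (see the const_key changes in code_gen.py below), branching on them specializes the kernel at compile time rather than at run time. A minimal sketch of that pattern (assumed kernel, illustrative only):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def scale_kernel(X, Z, scale, BLOCK: tl.constexpr, APPLY_SCALE: tl.constexpr):
    off = tl.arange(0, BLOCK)
    x = tl.load(X + off)
    if APPLY_SCALE:  # evaluated while building IR; the branch disappears when False
        x = x * scale
    tl.store(Z + off, x)

x = torch.randn(128, device='cuda')
z = torch.empty_like(x)
# each distinct (BLOCK, APPLY_SCALE) combination compiles and caches its own binary
scale_kernel[(1,)](x, z, 2.0, BLOCK=128, APPLY_SCALE=True)
scale_kernel[(1,)](x, z, 2.0, BLOCK=128, APPLY_SCALE=False)
```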


@@ -21,7 +21,7 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=
     }[MODE]
     layout = torch.randint(2, (H, shape[0] // BLOCK, shape[1] // BLOCK))
     # triton result
-    op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B)
+    op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B, device="cuda")
     ra = triton.testing.sparsify_tensor(a, layout, BLOCK) if MODE == "dsd" else a
     rb = triton.testing.sparsify_tensor(b, layout, BLOCK) if MODE == "dds" else b
     rc = triton.testing.catch_oor(lambda: op(ra, rb), pytest)
@@ -151,8 +151,8 @@ def triton_attention(
     value: torch.Tensor,
     scale: float,
 ):
-    sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(layout, block, "sdd", trans_a=False, trans_b=True)
-    sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(layout, block, "dsd", trans_a=False, trans_b=False)
+    sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(layout, block, "sdd", trans_a=False, trans_b=True, device=value.device)
+    sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(layout, block, "dsd", trans_a=False, trans_b=False, device=value.device)
     sparse_softmax = triton.ops.blocksparse.softmax(
         layout,
         block,


@@ -66,8 +66,8 @@ import torch
 def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE):
     torch.manual_seed(0)
     # nuke kernel decorators -- will set meta-parameters manually
-    META = {'BLOCK_M': BLOCK_M, 'BLOCK_N': BLOCK_N, 'BLOCK_K': BLOCK_K, 'SPLIT_K': SPLIT_K}
-    configs = [triton.Config(meta=META, num_warps=NWARP, num_stages=NSTAGE)]
+    kwargs = {'BLOCK_M': BLOCK_M, 'BLOCK_N': BLOCK_N, 'BLOCK_K': BLOCK_K, 'SPLIT_K': SPLIT_K}
+    configs = [triton.Config(kwargs=kwargs, num_warps=NWARP, num_stages=NSTAGE)]
     kernel = triton.ops._matmul.kernel
     decorators = kernel.kernel_decorators
     kernel.kernel_decorators = []
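For reference, a short sketch of the renamed interface (tile sizes below are assumed for illustration): triton.Config now takes its meta-parameters as kwargs instead of meta, matching the Autotuner changes in code_gen.py further down.

```python
import triton

# `meta=...` is now `kwargs=...`; these dicts are merged into the launch keyword arguments
configs = [
    triton.Config(kwargs={'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_warps=4, num_stages=3),
    triton.Config(kwargs={'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 32}, num_warps=2, num_stages=4),
]
# such configs would typically be handed to triton.autotune on a @triton.jit kernel
```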


@@ -103,7 +103,9 @@ class CodeGenerator(ast.NodeVisitor):
         arg_values = []
         for i, arg_name in enumerate(arg_names):
             if i in self.constants:
-                cst = triton.language.core._to_ir(self.constants[i], self.builder)
+                cst = self.constants[i]
+                if not isinstance(cst, triton.language.constexpr):
+                    cst = triton.language.constexpr(self.constants[i])
                 arg_values.append(cst)
             else:
                 if i in self.attributes:
@@ -114,6 +116,7 @@ class CodeGenerator(ast.NodeVisitor):
                     fn.add_attr(i + 1, attr)
                 fn.args[i].name = arg_name
                 arg_values.append(fn.args[i])
         for arg_name, arg_value in zip(arg_names, arg_values):
             self.set_value(arg_name, arg_value)
         if inline:
@@ -139,6 +142,22 @@ class CodeGenerator(ast.NodeVisitor):
         ast.NodeVisitor.generic_visit(self, node)
         return node.arg

+    def visit_AnnAssign(self, node):
+        # extract attributes
+        annotation = self.visit(node.annotation)
+        target = self.visit(node.target)
+        value = self.visit(node.value)
+        # constexpr
+        if annotation == triton.language.constexpr:
+            if target in self.lscope:
+                raise ValueError(f'{target} is already defined.'
+                                 f' constexpr cannot be reassigned.')
+            self.lscope[target] = triton.language.constexpr(value)
+            return self.lscope[target]
+        # default: call visit_Assign
+        return self.visit_Assign(node)
+
     def visit_Assign(self, node):
         _names = []
         for target in node.targets:
@@ -151,6 +170,9 @@ class CodeGenerator(ast.NodeVisitor):
         if not isinstance(values, tuple):
             values = [values]
         for name, value in zip(names, values):
+            # by default, constexpr are assigned into python variable
+            if isinstance(value, triton.language.constexpr):
+                value = value.value
             if not isinstance(value, triton.language.block):
                 value = triton.language.core._to_ir(value, self.builder)
             self.set_value(name, value)
@@ -181,6 +203,10 @@ class CodeGenerator(ast.NodeVisitor):
     def visit_BinOp(self, node):
         lhs = self.visit(node.left)
         rhs = self.visit(node.right)
+        if isinstance(lhs, triton.language.core.constexpr):
+            lhs = lhs.value
+        if isinstance(rhs, triton.language.core.constexpr):
+            rhs = rhs.value
         fn = {
             ast.Add: '__add__',
             ast.Sub: '__sub__',
@@ -195,17 +221,13 @@ class CodeGenerator(ast.NodeVisitor):
             ast.BitOr: '__or__',
             ast.BitXor: '__xor__',
         }[type(node.op)]
-        kws = dict()
         if self.is_triton_object(lhs):
-            kws['_builder'] = self.builder
-        ret = getattr(lhs, fn)(rhs, **kws)
-        if ret is NotImplemented:
-            if self.is_triton_object(rhs):
-                kws['_builder'] = self.builder
+            return getattr(lhs, fn)(rhs, _builder=self.builder)
+        elif self.is_triton_object(rhs):
             fn = fn[:2] + 'r' + fn[2:]
-            ret = getattr(rhs, fn)(lhs, **kws)
-        return ret
+            return getattr(rhs, fn)(lhs, _builder=self.builder)
+        else:
+            return getattr(lhs, fn)(rhs)

     def visit_If(self, node):
         cond = self.visit(node.test)
@@ -254,6 +276,10 @@ class CodeGenerator(ast.NodeVisitor):
         assert len(node.ops) == 1
         lhs = self.visit(node.left)
         rhs = self.visit(node.comparators[0])
+        if isinstance(lhs, triton.language.core.constexpr):
+            lhs = lhs.value
+        if isinstance(rhs, triton.language.core.constexpr):
+            rhs = rhs.value
         fn = {
             ast.Eq: '__eq__',
             ast.NotEq: '__ne__',
@@ -274,6 +300,8 @@ class CodeGenerator(ast.NodeVisitor):
     def visit_UnaryOp(self, node):
         op = self.visit(node.operand)
+        if isinstance(op, triton.language.core.constexpr):
+            op = op.value
         fn = {
             ast.USub: '__neg__',
             ast.UAdd: '__pos__',
@@ -394,7 +422,7 @@ class CodeGenerator(ast.NodeVisitor):
         return fn(*args, **kws)

     def visit_Num(self, node):
-        return node.n
+        return triton.language.constexpr(node.n)

     def visit_Attribute(self, node):
         lhs = self.visit(node.value)
@@ -477,6 +505,8 @@ class Kernel:
         }
         if hasattr(obj, 'data_ptr'):
             return type_names[obj.dtype]
+        if isinstance(obj, triton.language.core.constexpr):
+            obj = obj.value
         if isinstance(obj, int):
             if abs(obj) <= 0xffffffff:
                 return 'I'
@@ -485,6 +515,8 @@ class Kernel:
             return 'f'
         if isinstance(obj, bool):
             return 'B'
+        if isinstance(obj, str):
+            return 'str'
         assert False
@@ -537,7 +569,8 @@ class Kernel:
     def __init__(self, fn):
         self.fn = fn

-    def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages, **meta):
+    def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages):
+        wargs = [arg for arg in wargs if not isinstance(arg, triton.language.constexpr)]
         # create IR module
         context = _triton.ir.context()
         # get just-in-time proto-type of kernel
@@ -547,7 +580,7 @@ class Kernel:
         # generate Triton-IR
         # export symbols visible from self.fn into code-generator object
         gscope = sys.modules[self.fn.module].__dict__
-        generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=meta)
+        generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=dict())
         try:
             generator.visit(self.fn.parse())
         except Exception as e:
@@ -566,7 +599,19 @@ class Kernel:
             raise OutOfResources(shared_mem, max_shared_memory, "shared memory")
         return Binary(backend, name, asm, shared_mem, num_warps)

-    def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **meta):
+    def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **kwargs):
+        # handle arguments passed by name
+        kwargs = {self.fn.arg_names.index(name): value for name, value in kwargs.items()}
+        wargs = list(wargs)
+        for i, pos in enumerate(sorted(kwargs)):
+            wargs.insert(pos + i, kwargs[pos])
+        if len(wargs) != len(self.fn.arg_names):
+            raise TypeError(f"Function takes {len(self.fn.arg_names)} positional arguments but {len(wargs)} were given")
+        # handle annotations
+        for name, type in self.fn.__annotations__.items():
+            pos = self.fn.arg_names.index(name)
+            assert type == triton.language.core.constexpr
+            wargs[pos] = type(wargs[pos])
         # device inference
         tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
         if len(tensor_idxs) == 0:
@@ -601,18 +646,19 @@ class Kernel:
         args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)]
         attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) \
                       if isinstance(a, int) and i not in self.fn.do_not_specialize}
         # transforms ints whose value is one into constants for just-in-time compilation
         constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1}
+        constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)})
         # compute hash for caching this kernel
         types_key = Kernel._types_key(*wargs, tensor_idxs=tensor_idxs)
         attr_key = tuple(attributes.items())
-        meta_key = tuple(sorted(meta.items()))
         const_key = tuple(constants.items())
         compute_capability = torch.cuda.get_device_capability(device)
         key = (
             self.fn.cache_key, version_key(), compute_capability,
-            types_key, attr_key, num_warps, num_stages, meta_key, const_key
+            types_key, attr_key, num_warps, num_stages, const_key
         )
         key = repr(key)
@@ -644,7 +690,7 @@ class Kernel:
             binary = self._compile(
                 *wargs, device=device_idx, attributes=attributes,
                 num_warps=num_warps, num_stages=num_stages,
-                constants=constants, **meta
+                constants=constants,
             )
             if bin_cache_path:
                 assert bin_lock_path is not None
@@ -657,12 +703,15 @@ class Kernel:
             drv_cache[key] = LoadedBinary(device_idx, binary)
         # pack arguments
-        fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg) for i, arg in enumerate(wargs)])
-        params = struct.pack(fmt, *args)
+        fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg) for i, arg in enumerate(wargs) if not isinstance(arg, triton.language.core.constexpr)])
+        params = struct.pack(fmt, *[arg for arg in args if not isinstance(arg, triton.language.core.constexpr)])
         # enqueue cached function into stream
         callable = drv_cache[key]
         stream = torch.cuda.current_stream(device_idx).cuda_stream
-        grid = grid(meta) if hasattr(grid, '__call__') else grid
+        csts = {self.fn.arg_names[i]: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.core.constexpr)}
+        grid = grid(csts) if hasattr(grid, '__call__') else grid
+        if isinstance(grid, int):
+            grid = tuple(grid)
         callable(stream, params, *grid)
         return callable
@@ -697,31 +746,31 @@ class Autotuner:
     def _bench(self, *args, config, **meta):
         # check for conflicts, i.e. meta-parameters both provided
         # as kwargs and by the autotuner
-        conflicts = meta.keys() & config.meta.keys()
+        conflicts = meta.keys() & config.kwargs.keys()
         if conflicts:
             raise ValueError(
                 f"Conflicting meta-parameters: {', '.join(conflicts)}."
                 " Make sure that you don't re-define auto-tuned symbols."
             )
         # augment meta-parameters with tunable ones
-        current = dict(meta, **config.meta)
+        current = dict(meta, **config.kwargs)
         def kernel_call():
             self.hook(args)
             self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
         return triton.testing.do_bench(kernel_call)

-    def __call__(self, *args, **meta):
+    def __call__(self, *args, **kwargs):
         if len(self.configs) > 1:
             key = tuple([args[i] for i in self.key_idx])
             if key not in self.cache:
-                timings = {config: self._bench(*args, config=config, **meta) \
+                timings = {config: self._bench(*args, config=config, **kwargs) \
                            for config in self.configs}
                 self.cache[key] = builtins.min(timings, key=timings.get)
                 self.hook(args)
             config = self.cache[key]
         else:
             config = self.configs[0]
-        return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **meta, **config.meta)
+        return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)

 @functools.lru_cache()
@@ -769,6 +818,8 @@ class JITFunction:
         # when called with a grid using __getitem__
         self.kernel_decorators = []
         self.kernel = None
+        # annotations
+        self.__annotations__ = fn.__annotations__
         # forward docs
         self.__doc__ = fn.__doc__
@@ -839,8 +890,8 @@ class Config:
         Mostly useful for matrix multiplication workloads on SM80+ GPUs.
     :type num_stages: int
     """
-    def __init__(self, meta, num_warps=4, num_stages=2):
-        self.meta = meta
+    def __init__(self, kwargs, num_warps=4, num_stages=2):
+        self.kwargs = kwargs
         self.num_warps = num_warps
         self.num_stages = num_stages
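To make the new Kernel.__call__ behaviour above concrete, here is a small, self-contained illustration in plain Python (names assumed; this is not the Triton API itself) of the step that maps keyword meta-parameters back onto their positional slots before constexpr slots are stripped from the packed launch parameters.

```python
def order_args(arg_names, args, kwargs):
    """Mimics the keyword-to-positional mapping added to Kernel.__call__ (illustrative only)."""
    # map each keyword argument to the index of the parameter it names
    by_pos = {arg_names.index(name): value for name, value in kwargs.items()}
    ordered = list(args)
    # insert keyword values at their positional slots, left to right
    for i, pos in enumerate(sorted(by_pos)):
        ordered.insert(pos + i, by_pos[pos])
    if len(ordered) != len(arg_names):
        raise TypeError(f"expected {len(arg_names)} arguments, got {len(ordered)}")
    return ordered

# kernel signature: def kernel(X, Z, BLOCK: tl.constexpr)
print(order_args(['X', 'Z', 'BLOCK'], ('x_ptr', 'z_ptr'), {'BLOCK': 128}))
# -> ['x_ptr', 'z_ptr', 128]; the constexpr slot is later excluded from struct.pack
```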


@@ -14,9 +14,11 @@ def _to_ir(x, builder):
         return builder.get_int64(x)
     elif isinstance(x, float):
         return builder.get_float32(x)
-    if isinstance(x, block):
+    elif isinstance(x, constexpr):
+        return _to_ir(x.value, builder)
+    elif isinstance(x, block):
         return x.handle
-    if isinstance(x, dtype):
+    elif isinstance(x, dtype):
         return x.handle(builder)
     return x
@@ -257,6 +259,86 @@ class block:
         return frontend.cast(self, dtype, _builder)

+# -----------------------
+# constexpr
+# -----------------------
+
+class constexpr:
+    """
+    This class is used to store a value that is known at compile-time.
+    """
+    def __init__(self, value):
+        self.value = value
+    def __add__(self, other):
+        return self.value + other.value
+    def __radd__(self, other):
+        return other.value + self.value
+    def __sub__(self, other):
+        return self.value - other.value
+    def __rsub__(self, other):
+        return other.value - self.value
+    def __mul__(self, other):
+        return self.value * other.value
+    def __rmul__(self, other):
+        return other.value * self.value
+    def __truediv__(self, other):
+        return self.value / other.value
+    def __rtruediv__(self, other):
+        return other.value / self.value
+    def __floordiv__(self, other):
+        return self.value // other.value
+    def __rfloordiv__(self, other):
+        return other.value // self.value
+    #
+    def __gt__(self, other):
+        return self.value > other.value
+    def __rgt__(self, other):
+        return other.value > self.value
+    def __ge__(self, other):
+        return self.value >= other.value
+    def __rge__(self, other):
+        return other.value >= self.value
+    def __lt__(self, other):
+        return self.value < other.value
+    def __rlt__(self, other):
+        return other.value < self.value
+    def __le__(self, other):
+        return self.value <= other.value
+    def __rle__(self, other):
+        return other.value <= self.value
+    def __eq__(self, other):
+        return self.value == other.value
+    def __ne__(self, other):
+        return self.value != other.value
+    def __bool__(self):
+        return bool(self.value)
+    def __call__(self, *args, **kwds):
+        return self.value(*args, **kwds)
+
 # -----------------------
 # SPMD Programming Model
 # -----------------------
@@ -312,7 +394,12 @@ def zeros(shape, dtype, _builder=None):
     :param dtype: Data-type of the new array, e.g., :code:`tl.float16`
     :type dtype: DType
     """
-    shape = [int(x.handle) if isinstance(x, block) else x for x in shape]
+    for i, d in enumerate(shape):
+        if not isinstance(d, constexpr):
+            raise TypeError(f"Shape element {i} must have type `constexpr`")
+        if not isinstance(d.value, int):
+            raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]")
+    shape = [x.value for x in shape]
     return frontend.zeros(shape, dtype, _builder)
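A quick plain-Python sanity check of the wrapper's semantics (illustrative; run outside the JIT): note that the operators above assume the other operand is also a constexpr, since they unconditionally read other.value, and they return plain Python values rather than new constexpr objects.

```python
from triton.language import constexpr

a, b = constexpr(128), constexpr(32)
print(a + b)               # 160 -- a plain int, not a constexpr
print(a // b)              # 4
print(a > b)               # True
print(bool(constexpr(0)))  # False -- lets `if FLAG:` fold away at compile time
# mixing with a raw int fails here: `a + 4` raises AttributeError on `.value`
```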


@@ -1,6 +1,5 @@
import triton import triton
import triton.language as tl import triton.language as tl
import triton._C.libtriton as libtriton
import torch import torch
# ******************************************************** # ********************************************************
@@ -21,54 +20,46 @@ def _sdd_kernel(
stride_za, stride_ha, stride_ma, stride_ak, stride_za, stride_ha, stride_ma, stride_ak,
stride_zb, stride_hb, stride_bk, stride_nb, stride_zb, stride_hb, stride_bk, stride_nb,
stride_zc, stride_hc, stride_mc, stride_nc, stride_zc, stride_hc, stride_mc, stride_nc,
K, grid_offset, lut, **meta K, grid_offset, lut,
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
BLOCK: tl.constexpr, EVEN_K: tl.constexpr
): ):
TILE_M = meta['TILE_M']
TILE_N = meta['TILE_N']
TILE_K = meta['TILE_K']
BLOCK = meta['BLOCK']
#------------# #------------#
#- Prologue -# #- Prologue -#
#------------# #------------#
pid1 = tl.program_id(1) + grid_offset block_id = tl.program_id(1) + grid_offset
blockidm = tl.arange(0, TILE_M) // BLOCK lut += block_id * 3
blockidn = tl.arange(0, TILE_N) // BLOCK # offsets
offlutm = blockidm * (TILE_N // BLOCK) * 4 off_z = tl.program_id(2) # batch
offlutn = blockidn * 4 off_h = tl.load(lut + 0) # head
header = lut + pid1 * (TILE_M // BLOCK) * (TILE_N // BLOCK) * 4
# batch offset
off_z = tl.program_id(2)
# head offset
off_h = tl.load(header + 0)
# initialize pointers to A # initialize pointers to A
start_am = tl.load(header + 1 + offlutm) start_am = tl.load(lut + 1)
offs_am = start_am * BLOCK + (tl.arange(0, TILE_M) % BLOCK) offs_am = start_am * BLOCK + (tl.arange(0, TILE_M) % BLOCK)
offs_ak = tl.arange(0, TILE_K) offs_ak = tl.arange(0, TILE_K)
a_ptrs = A + off_z * stride_za \ a_ptrs = A + (off_z * stride_za \
+ off_h * stride_ha \ + off_h * stride_ha \
+ offs_am[:, None] * stride_ma \ + offs_am[:, None] * stride_ma \
+ offs_ak[None, :] * stride_ak + offs_ak[None, :] * stride_ak)
# initialize pointers to B # initialize pointers to B
start_bn = tl.load(header + 2 + offlutn) start_bn = tl.load(lut + 2)
offs_bn = start_bn * BLOCK + (tl.arange(0, TILE_N) % BLOCK) offs_bn = start_bn * BLOCK + (tl.arange(0, TILE_N) % BLOCK)
offs_bk = tl.arange(0, TILE_K) offs_bk = tl.arange(0, TILE_K)
b_ptrs = B + off_z * stride_zb \ b_ptrs = B + (off_z * stride_zb \
+ off_h * stride_hb \ + off_h * stride_hb \
+ offs_bn[None, :] * stride_nb \ + offs_bn[None, :] * stride_nb \
+ offs_bk[:, None] * stride_bk + offs_bk[:, None] * stride_bk)
## ---------------- ## ## ---------------- ##
## Inner Loop ## ## Inner Loop ##
## ---------------- ## ## ---------------- ##
acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32) acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
for k in range(K, 0, -TILE_K): for k in range(K, 0, -TILE_K):
if meta['EVEN_K']: if EVEN_K:
a = tl.load(a_ptrs) a = tl.load(a_ptrs)
b = tl.load(b_ptrs) b = tl.load(b_ptrs)
else: else:
a = tl.load(a_ptrs, mask=offs_ak[None, :] < k, other=0.) a = tl.load(a_ptrs, mask=offs_ak[None, :] < k, other=0.)
b = tl.load(b_ptrs, mask=offs_bk[:, None] < k, other=0.) b = tl.load(b_ptrs, mask=offs_bk[:, None] < k, other=0.)
a = tl.load(a_ptrs)
b = tl.load(b_ptrs)
acc += tl.dot(a, b) acc += tl.dot(a, b)
a_ptrs += TILE_K * stride_ak a_ptrs += TILE_K * stride_ak
b_ptrs += TILE_K * stride_bk b_ptrs += TILE_K * stride_bk
@@ -76,22 +67,15 @@ def _sdd_kernel(
## ---------------- ## ## ---------------- ##
## Epilogue ## ## Epilogue ##
## ---------------- ## ## ---------------- ##
blockidm = tl.arange(0, TILE_M) // BLOCK
blockidn = tl.arange(0, TILE_N) // BLOCK
offlutm = blockidm * (TILE_N // BLOCK) * 4
offlutn = blockidn * 4
off_block_id = 3 + offlutm[:, None] + offlutn[None, :]
block_id = tl.load(header + off_block_id)
# initialize pointers to C
offs_cm = tl.arange(0, TILE_M) % BLOCK offs_cm = tl.arange(0, TILE_M) % BLOCK
offs_cn = tl.arange(0, TILE_N) % BLOCK offs_cn = tl.arange(0, TILE_N) % BLOCK
pc = C + off_z * stride_zc \ pc = C + (off_z * stride_zc \
+ block_id * stride_hc \ + block_id * stride_hc \
+ offs_cm[:, None] * stride_mc \ + offs_cm[:, None] * stride_mc \
+ offs_cn[None, :] * stride_nc + offs_cn[None, :] * stride_nc)
tl.store(pc, c, mask=True) tl.store(pc, c, mask=True)
def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, luts, num_locks, widths, packs): def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out = None):
# (A * B)^T = B^T * A^T # (A * B)^T = B^T * A^T
if trans_c: if trans_c:
a, b = b, a a, b = b, a
@@ -102,46 +86,28 @@ def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, luts, num_locks,
Ka, Kb = a.shape[a_dim], b.shape[b_dim] Ka, Kb = a.shape[a_dim], b.shape[b_dim]
if Ka != Kb: if Ka != Kb:
raise ValueError(f"Inner dimension mismatch (A: {Ka} vs B: {Kb})") raise ValueError(f"Inner dimension mismatch (A: {Ka} vs B: {Kb})")
if Ka % 16 != 0:
raise ValueError('Reduction size for SDD must be a multiple of 16')
# allocate output # allocate output
n_blocks = sum([width * pack * pack for width, pack in zip(widths, packs)]) if out is None:
c = torch.zeros((a.shape[0], n_blocks, block, block), dtype=a.dtype, device=a.device) c = torch.empty((a.shape[0], lut.shape[0], block, block), dtype=a.dtype, device=a.device)
# each iteration of the loop below else:
# computes the value for one group of super-blocks assert out.shape == (a.shape[0], lut.shape[0], block, block)
# (e.g., all 4x4 super-blocks) c = out
for lut, width, pack in zip(luts, widths, packs): grid = [1, c.shape[1], c.shape[0]]
# maximum grid size in Triton/CUDA is 64k but we may have more _sdd_kernel[grid](
# super-blocks than that. a, b, c,
max_grid = 65535 a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3),
for off_grid in range(0, width, max_grid): b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
grid = [1, min(max_grid, width - off_grid), c.shape[0]] c.stride(0), c.stride(1), c.stride(2), c.stride(3),
# fmt: off Ka, 0, lut,
pgm = _sdd_kernel[grid]( TILE_M = block, TILE_N = block, TILE_K = 32, BLOCK = block, num_stages=4,
a, b, c, num_warps=4,
a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3), )
b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
c.stride(0), c.stride(1), c.stride(2), c.stride(3),
Ka, off_grid, lut,
TILE_M = block*pack, TILE_N = block*pack, TILE_K = 32, BLOCK = block, num_stages=3,
num_warps=4,
)
# print(pgm.asm['ptx'])
# exit()
return c return c
def sdd_lut(layout, block, device): def sdd_lut(layout, block, device):
start_width = 128 // block lut = layout.nonzero(as_tuple=False).to(device).int()
layout = layout.type(torch.int32) return lut, None
superblocks = libtriton.superblock(layout.data_ptr(), layout.shape[0], layout.shape[1], layout.shape[2], start_width)
luts, widths, packs = [], [], []
for size, nnz in superblocks:
nnz = nnz.reshape(-1, 4)
width = nnz.shape[0] // (size * size)
luts.append(torch.from_numpy(nnz).type(torch.int32).to(device))
widths.append(width)
packs.append(size)
return luts, None, widths, packs
# ----------------------------- # -----------------------------
# Dense = Sparse x Dense (DSD) # Dense = Sparse x Dense (DSD)
@@ -154,12 +120,10 @@ def _dsd_kernel(
stride_az, stride_ha, stride_am, stride_ak, stride_az, stride_ha, stride_am, stride_ak,
stride_zb, stride_hb, stride_bk, stride_bn, stride_zb, stride_hb, stride_bk, stride_bn,
stride_zc, stride_hc, stride_cm, stride_cn, stride_zc, stride_hc, stride_cm, stride_cn,
DS0, DS1, lut, **meta DS0, DS1, lut,
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr
): ):
TILE_M = meta['TILE_M']
TILE_N = meta['TILE_N']
TILE_K = meta['TILE_K']
GROUP_SIZE_M = meta['GROUP_SIZE_M']
#------------# #------------#
#- Prologue -# #- Prologue -#
#------------# #------------#
@@ -167,9 +131,9 @@ def _dsd_kernel(
pid_n = tl.program_id(1) pid_n = tl.program_id(1)
num_pid_m = tl.num_programs(0) num_pid_m = tl.num_programs(0)
num_pid_n = tl.num_programs(1) num_pid_n = tl.num_programs(1)
pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_SIZE_M)
pidz = tl.program_id(2) pidz = tl.program_id(2)
header = lut + pid_m * 4 header = lut + pid_n * 4
offset = tl.load(header + 0) offset = tl.load(header + 0)
K = tl.load(header + 1) K = tl.load(header + 1)
column = tl.load(header + 2) column = tl.load(header + 2)
@@ -185,7 +149,8 @@ def _dsd_kernel(
+ offs_am[:, None] * stride_am \ + offs_am[:, None] * stride_am \
+ offs_ak[None, :] * stride_ak + offs_ak[None, :] * stride_ak
# initialize pointers to B (dense) # initialize pointers to B (dense)
offs_bn = pid_n*TILE_N + tl.arange(0, TILE_N) offs_bn = pid_m*TILE_N + tl.arange(0, TILE_N)
offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn % DS0, TILE_N), TILE_N)
start_bk = tl.load(pinc) start_bk = tl.load(pinc)
start_bk = tl.multiple_of(start_bk, 8) # compiler hint start_bk = tl.multiple_of(start_bk, 8) # compiler hint
offs_bk = start_bk + tl.arange(0, TILE_K) offs_bk = start_bk + tl.arange(0, TILE_K)
@@ -197,28 +162,33 @@ def _dsd_kernel(
## Inner Loop ## ## Inner Loop ##
## ---------------- ## ## ---------------- ##
acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32) acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
pinc += 2
inc_a = tl.load(pinc + 1)
inc_a = tl.multiple_of(inc_a, 8)
inc_b = tl.load(pinc)
inc_b = tl.multiple_of(inc_b, 8)
for k in range(K, 0, -TILE_K): for k in range(K, 0, -TILE_K):
a = tl.load(pa, mask=True) a = tl.load(pa, mask=True)
b = tl.load(pb, mask=offs_bn[None, :] < DS0) b = tl.load(pb, mask=offs_bn[None, :] < DS0)
acc += tl.dot(a, b) acc += tl.dot(a, b)
pa += inc_a
pb += inc_b*stride_bk
pinc += 2 pinc += 2
inc_a = tl.load(pinc + 1) inc_a = tl.load(pinc + 1)
inc_a = tl.multiple_of(inc_a, 8) inc_a = tl.multiple_of(inc_a, 8)
inc_b = tl.load(pinc) inc_b = tl.load(pinc)
inc_b = tl.multiple_of(inc_b, 8) inc_b = tl.multiple_of(inc_b, 8)
pa += inc_a
pb += inc_b*stride_bk
c = acc.to(C.dtype.element_ty) c = acc.to(C.dtype.element_ty)
# initialize pointers to C # initialize pointers to C
offs_cm = column*TILE_M + tl.arange(0, TILE_M) offs_cm = column*TILE_M + tl.arange(0, TILE_M)
offs_cn = pid_n*TILE_N + tl.arange(0, TILE_N) offs_cn = pid_m*TILE_N + tl.arange(0, TILE_N)
pc = C + off_h * stride_hc \ pc = C + off_h * stride_hc \
+ pidz * stride_zc \ + pidz * stride_zc \
+ offs_cm[:, None] * stride_cm \ + offs_cm[:, None] * stride_cm \
+ offs_cn[None, :] * stride_cn + offs_cn[None, :] * stride_cn
tl.store(pc, c, mask = offs_cn[None, :] < DS0) tl.store(pc, c, mask = offs_cn[None, :] < DS0)
def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, width, packs): def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out = None):
# shapes / dtypes # shapes / dtypes
AS1 = block * spdims[2 if trans_a else 1] AS1 = block * spdims[2 if trans_a else 1]
BS0 = b.size(0) BS0 = b.size(0)
@@ -230,11 +200,15 @@ def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, w
CS1 = BS1 CS1 = BS1
CS2 = BS3 if trans_c else AS1 CS2 = BS3 if trans_c else AS1
CS3 = AS1 if trans_c else BS3 CS3 = AS1 if trans_c else BS3
c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device) if out is None:
c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
else:
assert out.shape == (CS0, CS1, CS2, CS3)
c = out
# meta-parameter heuristics # meta-parameter heuristics
TILE_N = {16: 256, 32: 256, 64: 128, 128: 128}[block] TILE_N = 128
# compute output # compute output
grid = lambda meta: [width, triton.cdiv(BS3, meta['TILE_N']), BS0] grid = lambda meta: [triton.cdiv(BS3, meta['TILE_N']), width, BS0]
# fmt: off # fmt: off
_dsd_kernel[grid]( _dsd_kernel[grid](
a, b, c, a, b, c,
@@ -242,8 +216,8 @@ def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, w
b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3), b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3), c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3),
BS3, AS1, lut, BS3, AS1, lut,
TILE_M = block, TILE_N=TILE_N, TILE_K = min(block, 32), BLOCK = block, num_stages=3, TILE_M = block, TILE_N=TILE_N, TILE_K = min(block, 32), BLOCK = block, num_stages=4,
num_warps=4, GROUP_SIZE_M=8, num_warps=4, GROUP_SIZE_M=4,
) )
# exit() # exit()
return c return c
@@ -323,7 +297,7 @@ def dsd_lut(layout, block, step, trans, device):
lut = torch.cat((header, incs)) lut = torch.cat((header, incs))
lut = lut.type(torch.int32).to(device) lut = lut.type(torch.int32).to(device)
# create locks # create locks
return lut, None, width, None return lut, width
# ----------------------------- # -----------------------------
# Dense = Dense x Sparse (DDS) # Dense = Dense x Sparse (DDS)
@@ -334,12 +308,10 @@ def _dds_kernel(
stride_za, stride_ha, stride_ma, stride_ka, stride_za, stride_ha, stride_ma, stride_ka,
stride_zb, stride_hb, stride_bk, stride_bn, stride_zb, stride_hb, stride_bk, stride_bn,
stride_zc, stride_hc, stride_mc, stride_nc, stride_zc, stride_hc, stride_mc, stride_nc,
DS0, DS1, lut, **meta DS0, DS1, lut,
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr,
): ):
TILE_M = meta['TILE_M']
TILE_N = meta['TILE_N']
TILE_K = meta['TILE_K']
GROUP_SIZE_M = meta['GROUP_SIZE_M']
#------------# #------------#
#- Prologue -# #- Prologue -#
#------------# #------------#
@@ -347,16 +319,17 @@ def _dds_kernel(
pid_n = tl.program_id(1) pid_n = tl.program_id(1)
num_pid_m = tl.num_programs(0) num_pid_m = tl.num_programs(0)
num_pid_n = tl.num_programs(1) num_pid_n = tl.num_programs(1)
pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_SIZE_M)
pid_z = tl.program_id(2) pid_z = tl.program_id(2)
header = lut + pid_m * 4 header = lut + pid_n * 4
offset = tl.load(header + 0) offset = tl.load(header + 0)
AS1 = tl.load(header + 1) AS1 = tl.load(header + 1)
column = tl.load(header + 2) column = tl.load(header + 2)
off_h = tl.load(header + 3) off_h = tl.load(header + 3)
pinc = lut + offset pinc = lut + offset
# initialize pointers to A (dense) # initialize pointers to A (dense)
offs_am = pid_n*TILE_M + tl.arange(0, TILE_M) offs_am = pid_m*TILE_M + tl.arange(0, TILE_M)
offs_am = tl.max_contiguous(tl.multiple_of(offs_am % DS0, TILE_M), TILE_M)
start_ak = tl.load(pinc) start_ak = tl.load(pinc)
start_ak = tl.multiple_of(start_ak, 8) start_ak = tl.multiple_of(start_ak, 8)
offs_ak = start_ak + tl.arange(0, TILE_K) offs_ak = start_ak + tl.arange(0, TILE_K)
@@ -394,7 +367,7 @@ def _dds_kernel(
## ---------------- ## ## ---------------- ##
c = acc.to(C.dtype.element_ty) c = acc.to(C.dtype.element_ty)
# initialize pointers to C (dense) # initialize pointers to C (dense)
offs_cm = pid_n * TILE_M + tl.arange(0, TILE_M) offs_cm = pid_m * TILE_M + tl.arange(0, TILE_M)
offs_cn = column * TILE_N + tl.arange(0, TILE_N) offs_cn = column * TILE_N + tl.arange(0, TILE_N)
ptrs_c = C + off_h * stride_hc \ ptrs_c = C + off_h * stride_hc \
+ pid_z * stride_zc \ + pid_z * stride_zc \
@@ -403,7 +376,7 @@ def _dds_kernel(
# write back # write back
tl.store(ptrs_c, c, mask = offs_cm[:, None] < DS0) tl.store(ptrs_c, c, mask = offs_cm[:, None] < DS0)
def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, width, packs): def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out = None):
# shapes / dtypes # shapes / dtypes
AS0 = a.size(0) AS0 = a.size(0)
AS1 = a.size(1) AS1 = a.size(1)
@@ -415,9 +388,13 @@ def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, w
CS1 = AS1 CS1 = AS1
CS2 = BS2 if trans_c else AS2 CS2 = BS2 if trans_c else AS2
CS3 = AS2 if trans_c else BS2 CS3 = AS2 if trans_c else BS2
c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device) if out is None:
c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
else:
assert out.shape == (CS0, CS1, CS2, CS3)
c = out
TILE_M = {16: 256, 32: 256, 64: 128, 128: 128}[block] TILE_M = {16: 256, 32: 256, 64: 128, 128: 128}[block]
grid = lambda meta: [width, triton.cdiv(AS2, meta['TILE_M']), AS0] grid = lambda meta: [triton.cdiv(AS2, meta['TILE_M']), width, AS0]
# fmt: off # fmt: off
_dds_kernel[grid]( _dds_kernel[grid](
a, b, c, a, b, c,
@@ -425,8 +402,8 @@ def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, w
b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3), b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3), c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3),
AS2, BS2, lut, AS2, BS2, lut,
TILE_M = TILE_M, TILE_N = block, TILE_K = min(block, 32), BLOCK = block, num_stages=3, TILE_M = TILE_M, TILE_N = block, TILE_K = min(block, 32), BLOCK = block, num_stages=4,
num_warps=4, GROUP_SIZE_M=8, num_warps=4, GROUP_SIZE_M=4,
) )
return c return c
@@ -439,25 +416,23 @@ class _matmul(torch.autograd.Function):
@staticmethod @staticmethod
def forward( def forward(
ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block, c_lut, c_num_locks, c_width, c_packs, da_lut, da_num_locks, ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block,
da_width, da_packs, db_lut, db_num_locks, db_width, db_packs c_lut, c_width, da_lut, da_width, db_lut, db_width, out
): ):
c = _matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block, c_lut, c_num_locks, c_width, c_packs) c = _matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block, c_lut, c_width, out=out)
# save for backward # save for backward
ctx.save_for_backward(a, b) ctx.save_for_backward(a, b)
ctx.da_num_locks = da_num_locks
ctx.da_lut = da_lut ctx.da_lut = da_lut
ctx.da_width = da_width ctx.da_width = da_width
ctx.da_packs = da_packs
ctx.db_lut = db_lut ctx.db_lut = db_lut
ctx.db_num_locks = db_num_locks
ctx.db_width = db_width ctx.db_width = db_width
ctx.db_packs = db_packs
ctx.mode = mode ctx.mode = mode
ctx.spdims = spdims ctx.spdims = spdims
ctx.block = block ctx.block = block
ctx.trans_a = trans_a ctx.trans_a = trans_a
ctx.trans_b = trans_b ctx.trans_b = trans_b
ctx.trans_c = trans_c
ctx.has_out = out is not None
return c return c
@staticmethod @staticmethod
@@ -466,155 +441,55 @@ class _matmul(torch.autograd.Function):
a, b = ctx.saved_tensors a, b = ctx.saved_tensors
da, db = None, None da, db = None, None
mode = ctx.mode mode = ctx.mode
# gradients w.r.t. a # gradients w.r.t. a
if ctx.needs_input_grad[0]: if ctx.needs_input_grad[0]:
mode_da = mode[1] + mode[0] + mode[2] mode_da = mode[1] + mode[0] + mode[2]
da = _matmul.fn[mode_da]( da = _matmul.fn[mode_da](
dc, b, False, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block, ctx.da_lut, ctx.da_num_locks, ctx.da_width, dc, b, ctx.trans_c, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block, ctx.da_lut, ctx.da_width,
ctx.da_packs
) )
# gradients w.r.t. b # gradients w.r.t. b
if ctx.needs_input_grad[1]: if ctx.needs_input_grad[1]:
mode_db = mode[2] + mode[1] + mode[0] mode_db = mode[2] + mode[1] + mode[0]
db = _matmul.fn[mode_db]( db = _matmul.fn[mode_db](
a, dc, not ctx.trans_a, False, ctx.trans_b, ctx.spdims, ctx.block, ctx.db_lut, ctx.db_num_locks, ctx.db_width, a, dc, not ctx.trans_a, ctx.trans_c, ctx.trans_b, ctx.spdims, ctx.block, ctx.db_lut, ctx.db_width,
ctx.db_packs
) )
dout = dc if ctx.has_out else None
return da, db, None, None, None,\ return da, db, None, None, None,\
None, None, None, None,\ None, None, None, None,\
None, None, None, None, None, None,\ None, None, None, None, None, dout
None, None, None, None, None, None,\
None, None, None, None, None, None
class matmul: class matmul:
def make_lut(self, dtype, device):
key = (dtype, device)
if key in self.lut_cache:
return self.lut_cache[key]
# C look-up table
layout, block = self.layout, self.block
step = min(block, 32)
if self.mode == 'sdd':
c_lut, c_num_locks, c_width, c_packs = sdd_lut(layout, block, device)
elif self.mode == 'dsd':
c_lut, c_num_locks, c_width, c_packs = dsd_lut(layout, block, step, not self.trans_a, device)
elif self.mode == 'dds':
c_lut, c_num_locks, c_width, c_packs = dsd_lut(layout, block, step, self.trans_b, device)
# DA look-up table
if self.mode == 'sdd':
da_lut, da_num_locks, da_width, da_packs = dsd_lut(layout, block, step, True, device)
elif self.mode == 'dsd':
da_lut, da_num_locks, da_width, da_packs = sdd_lut(layout, block, device)
elif self.mode == 'dds':
da_lut, da_num_locks, da_width, da_packs = dsd_lut(layout, block, step, not self.trans_b, device)
# DB look-up table
if self.mode == 'sdd':
db_lut, db_num_locks, db_width, db_packs = dsd_lut(layout, block, step, False, device)
elif self.mode == 'dsd':
db_lut, db_num_locks, db_width, db_packs = dsd_lut(layout, block, step, self.trans_a, device)
elif self.mode == 'dds':
db_lut, db_num_locks, db_width, db_packs = sdd_lut(layout, block, device)
self.lut_cache[key] = (c_lut, c_num_locks, c_width, c_packs,
da_lut, da_num_locks, da_width, da_packs,
db_lut, db_num_locks, db_width, db_packs)
return self.lut_cache[key]
def __init__(self, layout, block, mode, trans_a=False, trans_b=False): def __init__(self, layout, block, mode, device, trans_a=False, trans_b=False, trans_c=False):
if mode not in ['sdd', 'dsd', 'dds']: if mode not in ['sdd', 'dsd', 'dds']:
raise NotImplementedError('Supported modes are: sdd, dsd, dds') raise NotImplementedError('Supported modes are: sdd, dsd, dds')
# look-up table cache
self.lut_cache = dict()
# attributes
self.block = block self.block = block
self.mode = mode self.mode = mode
self.trans_a = trans_a self.trans_a = trans_a
self.trans_b = trans_b self.trans_b = trans_b
self.trans_c = trans_c
layout_dim = layout.ndim
assert layout_dim in (2, 3), "Layout should be a 2 or 3 dimensional tensor of 0s and 1s"
if not mode == 'sdd':
# Dims to be reduced on the 'inside' of the matmul, either -1 or -2
trans_dense, trans_sparse, sparse_inner = (trans_b, trans_a, -1) if mode == 'dsd' else (trans_a, trans_b, -2)
self.dense_inner_dim = -((sparse_inner % 2) + 1) if not trans_dense else sparse_inner
sparse_inner = sparse_inner if not trans_sparse else -((sparse_inner % 2) + 1)
# Inner dim of the dense input should be equal to the inner dim of the sparse input
self.dense_inner_size = layout.shape[sparse_inner] * block
# Expected shape for sparse inputs
self.sparse_shape = (layout.sum().item(), block, block)
# Support using the same layout across attention heads etc.
if layout_dim == 2:
layout = layout.unsqueeze(0)
layout = layout.long() # Above code assumes the layout tensor is an integral type
self.layout = layout self.layout = layout
self.spdims = layout.shape self.spdims = layout.shape
step = min(block, 32)
if self.mode == 'sdd':
self.c_lut, self.c_width = sdd_lut(layout, block, device)
self.da_lut, self.da_width = dsd_lut(layout, block, step, True, device)
self.db_lut, self.db_width = dsd_lut(layout, block, step, False, device)
if self.mode == 'dsd':
self.c_lut, self.c_width = dsd_lut(layout, block, step, not self.trans_a, device)
self.da_lut, self.da_width = sdd_lut(layout, block, device)
self.db_lut, self.db_width = dsd_lut(layout, block, step, self.trans_a, device)
if self.mode == 'dds':
self.c_lut, self.c_width = dsd_lut(layout, block, step, self.trans_b, device)
self.da_lut, self.da_width = dsd_lut(layout, block, step, not self.trans_b, device)
self.db_lut, self.db_width = sdd_lut(layout, block, device)
def __call__(self, a, b): def __call__(self, a, b, out = None):
c_lut, c_num_locks, c_width, c_packs,\
da_lut, da_num_locks, da_width, da_packs,\
db_lut, db_num_locks, db_width, db_packs = self.make_lut(a.dtype, a.device)
# If we don't check for invalid shapes, devices, & dtypes here, they will lead to undefined behavior
# and potential illegal memory accesses
original_dims = max(a.ndim, b.ndim)
a, b = self._validate_inputs(a, b)
# execute
c = _matmul.apply( c = _matmul.apply(
a, b, self.trans_a, self.trans_b, False, self.mode, self.spdims, self.block, c_lut, c_num_locks, c_width, a, b, self.trans_a, self.trans_b, self.trans_c, self.mode, self.spdims, self.block,
c_packs, da_lut, da_num_locks, da_width, da_packs, db_lut, db_num_locks, db_width, db_packs self.c_lut, self.c_width,
self.da_lut, self.da_width,
self.db_lut, self.db_width,
out
) )
# This removes any leading singleton dimensions we may have added to the tensor that weren't in the input
dims_to_trim = c.ndim - original_dims
for _ in range(dims_to_trim):
c = c.squeeze(0)
return c return c
def _validate_inputs(self, a, b):
if a.device != b.device:
raise ValueError(f"Inputs must be on the same device; got {a.device} for tensor A "
f"and {b.device} for tensor B")
if not a.is_cuda:
raise ValueError("Only GPU devices are supported for now")
# When autocast is enabled, torch.matmul autocasts to float16, so we do the same here
if torch.is_autocast_enabled():
a, b = a.half(), b.half()
elif a.dtype != b.dtype:
raise ValueError(f"Inputs must be the same dtype; got {a.dtype} for A and {b.dtype} for B")
mode, trans_a, trans_b = self.mode, self.trans_a, self.trans_b
if mode != 'sdd':
# One input is sparse
dense, dense_name, sparse, sparse_name = (a, 'A', b, 'B') if mode == 'dds' else (b, 'B', a, 'A')
dense_inner = dense.shape[self.dense_inner_dim]
if dense_inner != self.dense_inner_size:
raise ValueError(f"Expected tensor {dense_name} to have size {self.dense_inner_size} at dim "
f"{self.dense_inner_dim % dense.ndim}, got {dense_inner}.")
if sparse.shape[-len(self.sparse_shape):] != self.sparse_shape:
raise ValueError(f"Expected tensor with trailing dimensions of shape {self.sparse_shape} for argument "
f"{sparse_name}, got {sparse.shape}")
def add_extra_dims(x):
# Add extra leading singleton dimensions if needed
dims_needed = 4 - x.ndim
if dims_needed > 0:
singletons = [1] * dims_needed
x = x.view(*singletons, *x.shape)
elif dims_needed < 0:
raise ValueError("Tensors with more than 4 dimensions are not currently supported")
return x
# Pad shapes with leading singleton dimensions
a = add_extra_dims(a)
b = add_extra_dims(b)
return a, b
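Continuing the sketch above (shapes are assumptions): in 'sdd' mode both inputs are dense, _validate_inputs pads 3-D inputs up to 4-D, and the trailing dimensions of the block-sparse result follow self.sparse_shape.

a = torch.randn((2, 4 * block, 4 * block), dtype=torch.float16, device='cuda')
b = torch.randn((2, 4 * block, 4 * block), dtype=torch.float16, device='cuda')
c = dot(a, b)   # trailing dims of c are (layout.sum(), block, block)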

View File

@@ -16,10 +16,9 @@ def num_warps(n):
@triton.jit @triton.jit
def _forward( def _forward(
X, scale, LUT, RPE, KP_M, ATTN_M, is_causal, sizemax, stride_zx, stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm, X, scale, LUT, RPE, KP_M, ATTN_M, is_causal, sizemax, stride_zx, stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm,
**meta TN: tl.constexpr, BLOCK: tl.constexpr, APPLY_SCALE: tl.constexpr, APPLY_RPE: tl.constexpr, APPLY_KP_MASK: tl.constexpr,
KP_MASK_MUL: tl.constexpr, APPLY_ATTN_MASK: tl.constexpr, ATTN_MASK_MUL: tl.constexpr,
): ):
TN = meta['TN']
BLOCK = meta['BLOCK']
pidhm = tl.program_id(0) pidhm = tl.program_id(0)
pidz = tl.program_id(1) pidz = tl.program_id(1)
# create index ranges # create index ranges
@@ -43,25 +42,25 @@ def _forward(
x = tl.load(px, mask=check, other=-float('inf')) x = tl.load(px, mask=check, other=-float('inf'))
x = x.to(tl.float32) x = x.to(tl.float32)
# apply scale # apply scale
if meta['APPLY_SCALE']: if APPLY_SCALE:
x = x * scale x = x * scale
# apply RPE # apply RPE
if meta['APPLY_RPE']: if APPLY_RPE:
prpe = RPE + pidz * stride_zrpe + headid * stride_hrpe + columnid * BLOCK + rowid * BLOCK * stride_srpe + rxm * stride_srpe + rxn prpe = RPE + pidz * stride_zrpe + headid * stride_hrpe + columnid * BLOCK + rowid * BLOCK * stride_srpe + rxm * stride_srpe + rxn
rpe = tl.load(prpe, mask=check, other=0) rpe = tl.load(prpe, mask=check, other=0)
x = x + rpe x = x + rpe
# apply key-padding mask # apply key-padding mask
if meta['APPLY_KP_MASK']: if APPLY_KP_MASK:
pkp_m = KP_M + pidz * stride_zkpm + columnid * BLOCK + rxn pkp_m = KP_M + pidz * stride_zkpm + columnid * BLOCK + rxn
kp_m = tl.load(pkp_m, mask=check, other=-float('inf')) kp_m = tl.load(pkp_m, mask=check, other=-float('inf'))
if meta['KP_MASK_MUL']: if KP_MASK_MUL:
kp_m = tl.where(kp_m == 0, -float('inf'), 0.) kp_m = tl.where(kp_m == 0, -float('inf'), 0.)
x = x + kp_m x = x + kp_m
# apply attention mask # apply attention mask
if meta['APPLY_ATTN_MASK']: if APPLY_ATTN_MASK:
pattn_m = ATTN_M + columnid * BLOCK + rowid * BLOCK * stride_zattnm + rxm * stride_zattnm + rxn pattn_m = ATTN_M + columnid * BLOCK + rowid * BLOCK * stride_zattnm + rxm * stride_zattnm + rxn
attn_m = tl.load(pattn_m, mask=check, other=-float('inf')) attn_m = tl.load(pattn_m, mask=check, other=-float('inf'))
if meta['ATTN_MASK_MUL']: if ATTN_MASK_MUL:
attn_m = tl.where(attn_m == 0, -float('inf'), 0.) attn_m = tl.where(attn_m == 0, -float('inf'), 0.)
x = x + attn_m x = x + attn_m
# apply causal mask # apply causal mask
@@ -75,11 +74,9 @@ def _forward(
@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[4] * meta['BLOCK'])}) @triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[4] * meta['BLOCK'])})
@triton.heuristics({'TN': lambda *args, **meta: triton.next_power_of_2(args[4]) * meta['BLOCK']}) @triton.heuristics({'TN': lambda *args, **meta: triton.next_power_of_2(args[4]) * meta['BLOCK']})
@triton.jit @triton.jit
def _backward(X, scale, DX, LUT, sizemax, stride_zx, stride_zdx, **meta): def _backward(X, scale, DX, LUT, sizemax, stride_zx, stride_zdx, TN: tl.constexpr, BLOCK: tl.constexpr):
pidhm = tl.program_id(0) pidhm = tl.program_id(0)
pidz = tl.program_id(1) pidz = tl.program_id(1)
TN = meta['TN']
BLOCK = meta['BLOCK']
# create index ranges # create index ranges
rxm = pidhm % BLOCK rxm = pidhm % BLOCK
rbm = pidhm // BLOCK rbm = pidhm // BLOCK
@@ -172,8 +169,7 @@ class _softmax(torch.autograd.Function):
APPLY_KP_MASK = apply_kp_mask, APPLY_KP_MASK = apply_kp_mask,
APPLY_ATTN_MASK = apply_attn_mask, APPLY_ATTN_MASK = apply_attn_mask,
KP_MASK_MUL = (kp_mask_mode == 'mul'), KP_MASK_MUL = (kp_mask_mode == 'mul'),
ATTN_MASK_MUL = (attn_mask_mode == 'mul'), ATTN_MASK_MUL = (attn_mask_mode == 'mul'))
force_nc_cache = True)
# save to context # save to context
ctx.mark_dirty(x) ctx.mark_dirty(x)
ctx.save_for_backward(x, lut) ctx.save_for_backward(x, lut)
@@ -196,7 +192,7 @@ class _softmax(torch.autograd.Function):
# run kernel # run kernel
M = x.shape[0] M = x.shape[0]
grid = lambda opt: [ctx.spdims[0] * ctx.spdims[1] * ctx.block, M] grid = lambda opt: [ctx.spdims[0] * ctx.spdims[1] * ctx.block, M]
_backward[grid](x, ctx.scale, dx, lut, ctx.maxlut, x.stride(0), dx.stride(0), force_nc_cache=True, BLOCK=ctx.block) _backward[grid](x, ctx.scale, dx, lut, ctx.maxlut, x.stride(0), dx.stride(0), BLOCK=ctx.block)
return dx, None, None, None, None, None, None, None, None, None, None, None, None, None, None return dx, None, None, None, None, None, None, None, None, None, None, None, None, None, None
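TN and num_warps are not passed explicitly at the call sites: the two @triton.heuristics decorators above derive them from the launch arguments (sizemax is args[4]) and forward TN to the kernel as a tl.constexpr. A self-contained illustration of that arithmetic, with an assumed num_warps helper and assumed sizes:

import triton

def num_warps(n):
    # stand-in for the module's helper; the thresholds here are assumptions
    if n <= 128:
        return 1
    if n <= 256:
        return 2
    if n <= 512:
        return 4
    return 8

BLOCK, sizemax = 16, 48
TN = triton.next_power_of_2(sizemax) * BLOCK     # 64 * 16 = 1024
print(TN, num_warps(sizemax * BLOCK))            # 1024 8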

View File

@@ -26,8 +26,7 @@ def num_warps(N):
@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[4])}) @triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[4])})
@triton.heuristics({'BLOCK': lambda *args, **meta: next_power_of_2(args[4])}) @triton.heuristics({'BLOCK': lambda *args, **meta: next_power_of_2(args[4])})
@triton.jit @triton.jit
def _forward(LOGITS, PROBS, IDX, LOSS, N, **meta): def _forward(LOGITS, PROBS, IDX, LOSS, N, BLOCK: tl.constexpr):
BLOCK = meta['BLOCK']
row = tl.program_id(0) row = tl.program_id(0)
cols = tl.arange(0, BLOCK) cols = tl.arange(0, BLOCK)
idx = tl.load(IDX + row) idx = tl.load(IDX + row)
@@ -52,8 +51,7 @@ def _forward(LOGITS, PROBS, IDX, LOSS, N, **meta):
@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[3])}) @triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[3])})
@triton.heuristics({'BLOCK': lambda *args, **meta: next_power_of_2(args[3])}) @triton.heuristics({'BLOCK': lambda *args, **meta: next_power_of_2(args[3])})
@triton.jit @triton.jit
def _backward(PROBS, IDX, DPROBS, N, **meta): def _backward(PROBS, IDX, DPROBS, N, BLOCK: tl.constexpr):
BLOCK = meta['BLOCK']
row = tl.program_id(0) row = tl.program_id(0)
cols = tl.arange(0, BLOCK) cols = tl.arange(0, BLOCK)
idx = tl.load(IDX + row) idx = tl.load(IDX + row)
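The same idea drives this kernel: BLOCK is the next power of two >= N, columns past N are masked out of the load, and the compile-time bound is what makes tl.arange legal. A minimal sketch under those assumptions (a row-max reduction, not the actual loss kernel):

import torch
import triton
import triton.language as tl

@triton.jit
def row_max_kernel(X, OUT, N, stride_row, BLOCK: tl.constexpr):
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK)                       # BLOCK is a compile-time constant
    x = tl.load(X + row * stride_row + cols,
                mask=cols < N, other=-float('inf'))  # pad the tail of the row
    tl.store(OUT + row, tl.max(x, 0))

x = torch.randn(4, 1000, device='cuda')
out = torch.empty(4, device='cuda')
row_max_kernel[(4,)](x, out, 1000, x.stride(0), BLOCK=triton.next_power_of_2(1000))
assert torch.allclose(out, x.max(dim=1).values)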

View File

@@ -26,13 +26,9 @@ def _kernel(A, B, C, M, N, K,
stride_am, stride_ak, stride_am, stride_ak,
stride_bk, stride_bn, stride_bk, stride_bn,
stride_cm, stride_cn, stride_cm, stride_cn,
LOCKS, **META): LOCKS,
# extract meta-parameters BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
BLOCK_M = META['BLOCK_M'] GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr):
BLOCK_N = META['BLOCK_N']
BLOCK_K = META['BLOCK_K']
GROUP_M = META['GROUP_M']
SPLIT_K = META['SPLIT_K']
# matrix multiplication # matrix multiplication
pid = tl.program_id(0) pid = tl.program_id(0)
pid_z = tl.program_id(1) pid_z = tl.program_id(1)
@@ -55,7 +51,7 @@ def _kernel(A, B, C, M, N, K,
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(K, 0, -BLOCK_K*SPLIT_K): for k in range(K, 0, -BLOCK_K*SPLIT_K):
if META['EVEN_K']: if EVEN_K:
a = tl.load(A) a = tl.load(A)
b = tl.load(B) b = tl.load(B)
else: else:
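EVEN_K is now an ordinary constexpr flag: when K is divisible by BLOCK_K * SPLIT_K every iteration reads a full tile, so the masked-load branch is compiled away. A tiny illustration of how such a flag is derived (the values are assumptions):

BLOCK_K, SPLIT_K = 32, 2
for K in (1024, 1000):
    EVEN_K = K % (BLOCK_K * SPLIT_K) == 0
    print(K, EVEN_K)   # 1024 True, then 1000 False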
@@ -113,14 +109,11 @@ class _matmul(torch.autograd.Function):
locks = _matmul._locks[device] locks = _matmul._locks[device]
# launch kernel # launch kernel
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K']) grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
_kernel[grid](a, b, c, _kernel[grid](a, b, c, M, N, K,
M, N, K,
a.stride(0), a.stride(1), a.stride(0), a.stride(1),
b.stride(0), b.stride(1), b.stride(0), b.stride(1),
c.stride(0), c.stride(1), c.stride(0), c.stride(1),
locks, locks, GROUP_M=8)
GROUP_M=8)
# done
return c return c
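The launch grid is still a callable of the chosen meta-parameters; only the kernel-side unpacking changed. A small sketch of what the grid evaluates to for assumed problem and tile sizes:

import triton

M, N = 1024, 512
META = {'BLOCK_M': 128, 'BLOCK_N': 128, 'SPLIT_K': 2}   # assumed autotuner choice
grid = (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),
        META['SPLIT_K'])
print(grid)   # (32, 2)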
@staticmethod @staticmethod

View File

@@ -12,6 +12,7 @@ In this tutorial, you will write a simple vector addition using Triton and learn
# Compute Kernel # Compute Kernel
# -------------------------- # --------------------------
from triton.language.core import constexpr
import torch import torch
import triton import triton
import triton.language as tl import triton.language as tl
@@ -23,9 +24,9 @@ def add_kernel(
y_ptr, # *Pointer* to second input vector y_ptr, # *Pointer* to second input vector
output_ptr, # *Pointer* to output vector output_ptr, # *Pointer* to output vector
n_elements, # Size of the vector n_elements, # Size of the vector
**meta, # Optional meta-parameters for the kernel BLOCK_SIZE: tl.constexpr, # Number of elements each program should process
# NOTE: `constexpr` so it can be used as a shape value
): ):
BLOCK_SIZE = meta['BLOCK_SIZE'] # How many inputs each program should process
# There are multiple 'program's processing different data. We identify which program # There are multiple 'program's processing different data. We identify which program
# we are here # we are here
pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0 pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0
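The note in the signature is the crux of the change: a tl.constexpr parameter is a compile-time constant, so it can appear where Triton requires a static value, such as the bound of tl.arange. A minimal sketch independent of the tutorial kernel (the names and the block size of 8 are made up):

import torch
import triton
import triton.language as tl

@triton.jit
def iota_kernel(out_ptr, BLOCK_SIZE: tl.constexpr):
    offs = tl.arange(0, BLOCK_SIZE)   # legal only because BLOCK_SIZE is a constexpr
    tl.store(out_ptr + offs, offs)

out = torch.empty(8, dtype=torch.int32, device='cuda')
iota_kernel[(1,)](out, BLOCK_SIZE=8)
print(out)   # tensor([0, 1, 2, 3, 4, 5, 6, 7], ...)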
@@ -37,8 +38,8 @@ def add_kernel(
offsets = block_start + tl.arange(0, BLOCK_SIZE) offsets = block_start + tl.arange(0, BLOCK_SIZE)
# Create a mask to guard memory operations against out-of-bounds accesses # Create a mask to guard memory operations against out-of-bounds accesses
mask = offsets < n_elements mask = offsets < n_elements
# Load x and y from DRAM, masking out any extar elements in case the input is not a # Load x and y from DRAM, masking out any extra elements in case
# multiple of the block size # the input is not a multiple of the block size
x = tl.load(x_ptr + offsets, mask=mask) x = tl.load(x_ptr + offsets, mask=mask)
y = tl.load(y_ptr + offsets, mask=mask) y = tl.load(y_ptr + offsets, mask=mask)
output = x + y output = x + y
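On the host side the constexpr becomes a plain keyword argument at launch. The sketch below is essentially the tutorial's own driver code; the vector length and BLOCK_SIZE=1024 are the usual tutorial values, assumed here:

x = torch.rand(98432, device='cuda')
y = torch.rand(98432, device='cuda')
output = torch.empty_like(x)
n_elements = output.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)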

View File

@@ -65,11 +65,11 @@ import triton.language as tl
@triton.jit @triton.jit
def softmax_kernel( def softmax_kernel(
output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, **meta output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols,
BLOCK_SIZE: tl.constexpr
): ):
# The rows of the softmax are independent, so we parallelize across those # The rows of the softmax are independent, so we parallelize across those
row_idx = tl.program_id(0) row_idx = tl.program_id(0)
BLOCK_SIZE = meta['BLOCK_SIZE']
# The stride represents how much we need to increase the pointer to advance 1 row # The stride represents how much we need to increase the pointer to advance 1 row
row_start_ptr = input_ptr + row_idx * input_row_stride row_start_ptr = input_ptr + row_idx * input_row_stride
# The block size is the next power of two greater than n_cols, so we can fit each # The block size is the next power of two greater than n_cols, so we can fit each

View File

@@ -182,17 +182,13 @@ def matmul_kernel(
stride_bk, stride_bn, stride_bk, stride_bn,
stride_cm, stride_cn, stride_cm, stride_cn,
# Meta-parameters # Meta-parameters
**meta, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
): GROUP_SIZE_M: tl.constexpr,
ACTIVATION: tl.constexpr,
):
"""Kernel for computing the matmul C = A x B. """Kernel for computing the matmul C = A x B.
A has shape (M, K), B has shape (K, N) and C has shape (M, N) A has shape (M, K), B has shape (K, N) and C has shape (M, N)
""" """
# extract meta-parameters
BLOCK_SIZE_M = meta['BLOCK_SIZE_M']
BLOCK_SIZE_N = meta['BLOCK_SIZE_N']
BLOCK_SIZE_K = meta['BLOCK_SIZE_K']
GROUP_SIZE_M = 8
# ----------------------------------------------------------- # -----------------------------------------------------------
# Map program ids `pid` to the block of C it should compute. # Map program ids `pid` to the block of C it should compute.
# This is done in a grouped ordering to promote L2 data reuse # This is done in a grouped ordering to promote L2 data reuse