[LANG] Added support for constexpr (#361)
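This change replaces the `**meta` dictionary that Triton kernels previously used for compile-time meta-parameters with explicit arguments annotated as `tl.constexpr`. The code generator wraps annotated arguments (and numeric literals) in a new `triton.language.constexpr` class, `Kernel.__call__` learns to accept arguments by name, `triton.Config` renames its `meta` field to `kwargs`, and the blocksparse ops are updated to the new convention. A minimal sketch of the new kernel style, mirroring the tests below (the full example itself is illustrative, not taken verbatim from the diff):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def _add(x_ptr, y_ptr, output_ptr, n_elements,
             BLOCK_SIZE: tl.constexpr):  # compile-time constant; no more meta['BLOCK_SIZE']
        pid = tl.program_id(axis=0)
        offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(x_ptr + offsets, mask=mask)
        y = tl.load(y_ptr + offsets, mask=mask)
        tl.store(output_ptr + offsets, x + y, mask=mask)

    x = torch.randn(4096, device='cuda')
    y = torch.randn(4096, device='cuda')
    z = torch.empty_like(x)
    grid = lambda args: (triton.cdiv(x.numel(), args['BLOCK_SIZE']), )
    _add[grid](x, y, z, x.numel(), BLOCK_SIZE=1024)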
@@ -198,8 +198,6 @@ scanline_layout::scanline_layout(size_t num_warps,
   bool is_dot = std::any_of(values.begin(), values.end(),
                             [&](ir::value* v) { return dynamic_cast<ir::dot_inst*>(v); });

-
-
   std::vector<ir::value*> ptrs;
   for(ir::value *v: values)
     for(ir::user *usr: v->get_users())
@@ -215,7 +213,6 @@ scanline_layout::scanline_layout(size_t num_warps,
       contiguous = std::max<int>(contiguous, std::min<int>(align->get(ptr, i), 128 / nbits));
     }

-
     nts_[i] = clamp(size / num_threads, 1, std::min<int>(contiguous, shape_[i]));
     mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]);
     size /= shape_[i];
@@ -77,7 +77,6 @@ void coalesce::run(ir::module &mod) {
       builder.insert(new_x);
       x->replace_all_uses_with(new_x);
       new_x->replace_uses_of_with(new_x, x);
-      // new_x->replace_uses_of_with(new_x, new_x);
     }
   }
   for(ir::function *fn: mod.get_function_list())
@@ -101,6 +100,8 @@ void coalesce::run(ir::module &mod) {
     ir::instruction* curr = queue.back();
     seen.insert(curr);
     queue.pop_back();
+    if(auto dot_inst = dynamic_cast<ir::dot_inst*>(curr))
+      break;
     if(auto io_inst = dynamic_cast<ir::io_inst*>(curr)){
       in_contig = align_->contiguous(io_inst->get_pointer_operand());
       break;
@@ -178,7 +178,7 @@ std::string ptx_to_cubin(const std::string& ptx, int cc) {
   ofs.close();
   std::string cmd;
   int err;
-  cmd = ptxas + " -v --gpu-name=sm_" + std::to_string(cc) + " " + fsrc + " -o " + fsrc + ".o";
+  cmd = ptxas + " -v --gpu-name=sm_" + std::to_string(cc) + " " + fsrc + " -o " + fsrc + ".o 2> " + flog;
   err = system(cmd.c_str());
   CUmodule ret;
   std::ifstream _cubin(_fbin, std::ios::binary );
@@ -67,8 +67,8 @@ def test_matmul(M, N, K):
 import triton.language as tl

 @triton.jit
-def _add(x_ptr, y_ptr, output_ptr, n_elements, **meta):
-    BLOCK_SIZE = meta['BLOCK_SIZE']
+def _add(x_ptr, y_ptr, output_ptr, n_elements,
+         BLOCK_SIZE: tl.constexpr):
     pid = tl.program_id(axis=0)
     block_start = pid * BLOCK_SIZE
     offsets = block_start + tl.arange(0, BLOCK_SIZE)
@@ -99,7 +99,7 @@ def test_elementwise(N):
     z = torch.empty((N, ), dtype=torch.float16, device='cuda')
     x = torch.randn_like(z)
     y = torch.randn_like(z)
-    grid = lambda meta: (triton.cdiv(N, meta['BLOCK_SIZE']), )
+    grid = lambda args: (triton.cdiv(N, args['BLOCK_SIZE']), )
     fn = lambda: _add[grid](x, y, z, N, BLOCK_SIZE=1024)
     ms = triton.testing.do_bench(fn, percentiles=None, warmup=10, rep=250)
     cur_gpu_perf = 3.*N*z.element_size()/ms*1e-6
@@ -40,7 +40,7 @@ def patch_kernel(template, to_replace):
 def test_empty_kernel(dtype_x, device='cuda'):
     SIZE = 128
     @triton.jit
-    def kernel(X, **meta):
+    def kernel(X, SIZE: tl.constexpr):
         pass
     x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device)
     kernel[(1, )](x, SIZE=SIZE, num_warps=4)
@@ -50,8 +50,8 @@ def _test_unary(dtype_x, expr, torch_expr=None, device='cuda'):
     SIZE = 128
     # define the kernel / launch-grid
     @triton.jit
-    def kernel(Z, X, **meta):
-        off = tl.arange(0, meta['SIZE'])
+    def kernel(Z, X, SIZE: tl.constexpr):
+        off = tl.arange(0, SIZE)
         x = tl.load(X + off)
         z = GENERATE_TEST_HERE
         tl.store(Z + off, z)
@@ -73,8 +73,8 @@ def _test_binary(dtype_x, dtype_y, expr, mode_x='real', mode_y='real', device='c
     SIZE = 128
     # define the kernel / launch-grid
     @triton.jit
-    def kernel(Z, X, Y, **meta):
-        off = tl.arange(0, meta['SIZE'])
+    def kernel(Z, X, Y, SIZE: tl.constexpr):
+        off = tl.arange(0, SIZE)
         x = tl.load(X + off)
         y = tl.load(Y + off)
         z = GENERATE_TEST_HERE
@@ -203,8 +203,7 @@ def test_index1d(expr, device='cuda'):

     # Triton kernel
     @triton.jit
-    def kernel(Z, X, **meta):
-        SIZE = meta['SIZE']
+    def kernel(Z, X, SIZE: tl.constexpr):
         m = tl.arange(0, SIZE)
         n = tl.arange(0, SIZE)
         x = tl.load(X_PTR_EXPR)
@@ -290,7 +289,7 @@ def test_atomic_rmw(op, dtype_x, mode, device='cuda'):

     # triton kernel
     @triton.jit
-    def kernel(X, Z, **meta):
+    def kernel(X, Z):
         pid = tl.program_id(0)
         x = tl.load(X + pid)
         old = GENERATE_TEST_HERE
@@ -344,9 +343,9 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):

     # triton kernel
     @triton.jit
-    def kernel(X, Z, **meta):
+    def kernel(X, Z, BITCAST: tl.constexpr):
         x = tl.load(X)
-        z = x.to(Z.dtype.element_ty, bitcast=meta['BITCAST'])
+        z = x.to(Z.dtype.element_ty, bitcast = BITCAST)
         tl.store(Z, z)

     # triton result
@@ -373,8 +372,8 @@ def test_reduce1d(dtype, shape, device='cuda'):

     # triton kernel
     @triton.jit
-    def kernel(X, Z, **meta):
-        x = tl.load(X + tl.arange(0, meta['BLOCK']))
+    def kernel(X, Z, BLOCK: tl.constexpr):
+        x = tl.load(X + tl.arange(0, BLOCK))
         tl.store(Z, tl.sum(x, axis=0))

     x = triton.testing.random((shape,), dtype=dtype, device=device)
@@ -395,11 +394,11 @@ def test_reduce2d(dtype, shape, axis, device='cuda'):
     dtype = cvt[dtype]
     # triton kernel
     @triton.jit
-    def kernel(X, Z, **meta):
-        range_m = tl.arange(0, meta['BLOCK_M'])
-        range_n = tl.arange(0, meta['BLOCK_N'])
-        x = tl.load(X + range_m[:, None]*meta['BLOCK_N'] + range_n[None, :])
-        z = tl.sum(x, axis=meta['AXIS'])
+    def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr):
+        range_m = tl.arange(0, BLOCK_M)
+        range_n = tl.arange(0, BLOCK_N)
+        x = tl.load(X + range_m[:, None]*BLOCK_N + range_n[None, :])
+        z = tl.sum(x, axis=AXIS)
         tl.store(Z + range_m, z)
     # input
     x = triton.testing.random(shape, dtype=dtype, device=device)
@@ -429,9 +428,8 @@ def test_permute(dtype, shape, perm, device='cuda'):
     # triton kernel
     @triton.jit
     def kernel(X, stride_xm, stride_xn,
-               Z, stride_zm, stride_zn, **meta):
-        BLOCK_M = meta['BLOCK_M']
-        BLOCK_N = meta['BLOCK_N']
+               Z, stride_zm, stride_zn,
+               BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
         off_m = tl.arange(0, BLOCK_M)
         off_n = tl.arange(0, BLOCK_N)
         Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * stride_xn
@@ -464,10 +462,9 @@ def test_dot(epilogue, device='cuda'):
     @triton.jit
     def kernel(X, stride_xm, stride_xk,
                Y, stride_yk, stride_yn,
-               Z, stride_zm, stride_zn, **meta):
-        BLOCK_M = meta['BLOCK_M']
-        BLOCK_K = meta['BLOCK_K']
-        BLOCK_N = meta['BLOCK_N']
+               Z, stride_zm, stride_zn,
+               BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+               ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr):
         off_m = tl.arange(0, BLOCK_M)
         off_n = tl.arange(0, BLOCK_N)
         off_k = tl.arange(0, BLOCK_K)
@@ -475,12 +472,12 @@ def test_dot(epilogue, device='cuda'):
         Ys = Y + off_k[:, None] * stride_yk + off_n[None, :] * stride_yn
         Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
         z = tl.dot(tl.load(Xs), tl.load(Ys))
-        if meta['ADD_MATRIX']:
+        if ADD_MATRIX:
             z += tl.load(Zs)
-        if meta['ADD_ROWS']:
+        if ADD_ROWS:
             ZRs = Z + off_m * stride_zm
             z += tl.load(ZRs)[:, None]
-        if meta['ADD_COLS']:
+        if ADD_COLS:
             ZCs = Z + off_n * stride_zn
             z += tl.load(ZCs)[None, :]
         tl.store(Zs, z)
@@ -517,7 +514,7 @@ def test_dot(epilogue, device='cuda'):

 def test_dot_without_load():
     @triton.jit
-    def kernel(out, **meta):
+    def kernel(out):
         pid = tl.program_id(axis=0)
         a = tl.zeros((32, 32), tl.float32)
         b = tl.zeros((32, 32), tl.float32)
@@ -538,9 +535,10 @@ def test_arange(start, device='cuda'):
     BLOCK = 128
     z_tri = torch.empty(BLOCK, dtype=torch.int32, device=device)
     @triton.jit
-    def _kernel(z, **meta):
-        off = tl.arange(0, meta['BLOCK'])
-        val = tl.arange(meta['START'], meta['END'])
+    def _kernel(z, BLOCK: tl.constexpr,
+                START: tl.constexpr, END: tl.constexpr):
+        off = tl.arange(0, BLOCK)
+        val = tl.arange(START, END)
         tl.store(z + off, val)
     _kernel[(1,)](z_tri, START=start, END=start+BLOCK, BLOCK=BLOCK)
     z_ref = torch.arange(start, BLOCK+start, dtype=torch.int32, device=device)
@@ -564,10 +562,8 @@ def test_masked_load_shared_memory(dtype, device='cuda'):
     @triton.jit
     def _kernel(in1_ptr, in2_ptr, output_ptr,
                 in_stride, in2_stride, out_stride,
-                in_numel, in2_numel, out_numel, **meta):
-        M = meta['M']
-        N = meta['N']
-        K = meta['K']
+                in_numel, in2_numel, out_numel,
+                M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):

         M_offsets = tl.arange(0, M)
         N_offsets = tl.arange(0, N)
@@ -605,14 +601,13 @@ def test_load_cache_modifier(cache):
     dst = torch.empty(128, device='cuda')
-
     @triton.jit
-    def _kernel(dst, src, **meta):
+    def _kernel(dst, src, CACHE: tl.constexpr):
         offsets = tl.arange(0, 128)
-        x = tl.load(src+offsets, cache_modifier=meta['CACHE'])
+        x = tl.load(src+offsets, cache_modifier=CACHE)
         tl.store(dst+offsets, x)

     pgm = _kernel[(1,)](dst, src, CACHE=cache)
     ptx = pgm.asm['ptx']

     if cache == '':
         assert 'ld.global.ca' not in ptx
         assert 'ld.global.cg' not in ptx
@@ -644,7 +639,7 @@ def test_load_cache_modifier(cache):
 #----------------
 def test_noop(device='cuda'):
     @triton.jit
-    def kernel(**meta):
+    def kernel(x):
         pass
     x = triton.testing.random((1,), dtype=torch.int32, device=device)
     kernel[(1, )](x)
@@ -21,7 +21,7 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=
     }[MODE]
     layout = torch.randint(2, (H, shape[0] // BLOCK, shape[1] // BLOCK))
     # triton result
-    op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B)
+    op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B, device="cuda")
     ra = triton.testing.sparsify_tensor(a, layout, BLOCK) if MODE == "dsd" else a
     rb = triton.testing.sparsify_tensor(b, layout, BLOCK) if MODE == "dds" else b
     rc = triton.testing.catch_oor(lambda: op(ra, rb), pytest)
@@ -151,8 +151,8 @@ def triton_attention(
     value: torch.Tensor,
     scale: float,
 ):
-    sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(layout, block, "sdd", trans_a=False, trans_b=True)
-    sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(layout, block, "dsd", trans_a=False, trans_b=False)
+    sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(layout, block, "sdd", trans_a=False, trans_b=True, device=value.device)
+    sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(layout, block, "dsd", trans_a=False, trans_b=False, device=value.device)
     sparse_softmax = triton.ops.blocksparse.softmax(
         layout,
         block,
@@ -66,8 +66,8 @@ import torch
 def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE):
     torch.manual_seed(0)
     # nuke kernel decorators -- will set meta-parameters manually
-    META = {'BLOCK_M': BLOCK_M, 'BLOCK_N': BLOCK_N, 'BLOCK_K': BLOCK_K, 'SPLIT_K': SPLIT_K}
-    configs = [triton.Config(meta=META, num_warps=NWARP, num_stages=NSTAGE)]
+    kwargs = {'BLOCK_M': BLOCK_M, 'BLOCK_N': BLOCK_N, 'BLOCK_K': BLOCK_K, 'SPLIT_K': SPLIT_K}
+    configs = [triton.Config(kwargs=kwargs, num_warps=NWARP, num_stages=NSTAGE)]
     kernel = triton.ops._matmul.kernel
     decorators = kernel.kernel_decorators
     kernel.kernel_decorators = []
@@ -103,7 +103,9 @@ class CodeGenerator(ast.NodeVisitor):
         arg_values = []
         for i, arg_name in enumerate(arg_names):
             if i in self.constants:
-                cst = triton.language.core._to_ir(self.constants[i], self.builder)
+                cst = self.constants[i]
+                if not isinstance(cst, triton.language.constexpr):
+                    cst = triton.language.constexpr(self.constants[i])
                 arg_values.append(cst)
             else:
                 if i in self.attributes:
@@ -114,6 +116,7 @@ class CodeGenerator(ast.NodeVisitor):
                     fn.add_attr(i + 1, attr)
                 fn.args[i].name = arg_name
                 arg_values.append(fn.args[i])
+
         for arg_name, arg_value in zip(arg_names, arg_values):
             self.set_value(arg_name, arg_value)
         if inline:
@@ -139,6 +142,22 @@ class CodeGenerator(ast.NodeVisitor):
         ast.NodeVisitor.generic_visit(self, node)
         return node.arg

+    def visit_AnnAssign(self, node):
+        # extract attributes
+        annotation = self.visit(node.annotation)
+        target = self.visit(node.target)
+        value = self.visit(node.value)
+        # constexpr
+        if annotation == triton.language.constexpr:
+            if target in self.lscope:
+                raise ValueError(f'{target} is already defined.'
+                                 f' constexpr cannot be reassigned.')
+            self.lscope[target] = triton.language.constexpr(value)
+            return self.lscope[target]
+        # default: call visit_Assign
+        return self.visit_Assign(node)
+
+
     def visit_Assign(self, node):
         _names = []
         for target in node.targets:
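With `visit_AnnAssign` in place, a `tl.constexpr` annotation is also recognized on assignments inside a kernel body: the value is recorded in the generator's local scope as a `triton.language.constexpr`, and annotating an already-defined name raises a `ValueError`. A sketch of what this accepts (hypothetical kernel, not taken from the diff):

    import triton
    import triton.language as tl

    @triton.jit
    def iota(Z):
        SIZE: tl.constexpr = 128     # stored as tl.constexpr(128) in the generator's lscope
        off = tl.arange(0, SIZE)     # usable wherever a compile-time int is required
        tl.store(Z + off, off)
        # SIZE: tl.constexpr = 256   # would raise ValueError: constexpr cannot be reassigned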
@@ -151,6 +170,9 @@ class CodeGenerator(ast.NodeVisitor):
         if not isinstance(values, tuple):
             values = [values]
         for name, value in zip(names, values):
+            # by default, constexpr are assigned into python variable
+            if isinstance(value, triton.language.constexpr):
+                value = value.value
             if not isinstance(value, triton.language.block):
                 value = triton.language.core._to_ir(value, self.builder)
             self.set_value(name, value)
@@ -181,6 +203,10 @@ class CodeGenerator(ast.NodeVisitor):
     def visit_BinOp(self, node):
         lhs = self.visit(node.left)
         rhs = self.visit(node.right)
+        if isinstance(lhs, triton.language.core.constexpr):
+            lhs = lhs.value
+        if isinstance(rhs, triton.language.core.constexpr):
+            rhs = rhs.value
         fn = {
             ast.Add: '__add__',
             ast.Sub: '__sub__',
@@ -195,17 +221,13 @@ class CodeGenerator(ast.NodeVisitor):
             ast.BitOr: '__or__',
             ast.BitXor: '__xor__',
         }[type(node.op)]
-        kws = dict()
-
         if self.is_triton_object(lhs):
-            kws['_builder'] = self.builder
-        ret = getattr(lhs, fn)(rhs, **kws)
-        if ret is NotImplemented:
-            if self.is_triton_object(rhs):
-                kws['_builder'] = self.builder
+            return getattr(lhs, fn)(rhs, _builder=self.builder)
+        elif self.is_triton_object(rhs):
             fn = fn[:2] + 'r' + fn[2:]
-            ret = getattr(rhs, fn)(lhs, **kws)
-        return ret
+            return getattr(rhs, fn)(lhs, _builder=self.builder)
+        else:
+            return getattr(lhs, fn)(rhs)

     def visit_If(self, node):
         cond = self.visit(node.test)
@@ -254,6 +276,10 @@ class CodeGenerator(ast.NodeVisitor):
         assert len(node.ops) == 1
         lhs = self.visit(node.left)
         rhs = self.visit(node.comparators[0])
+        if isinstance(lhs, triton.language.core.constexpr):
+            lhs = lhs.value
+        if isinstance(rhs, triton.language.core.constexpr):
+            rhs = rhs.value
         fn = {
             ast.Eq: '__eq__',
             ast.NotEq: '__ne__',
@@ -274,6 +300,8 @@ class CodeGenerator(ast.NodeVisitor):

     def visit_UnaryOp(self, node):
         op = self.visit(node.operand)
+        if isinstance(op, triton.language.core.constexpr):
+            op = op.value
         fn = {
             ast.USub: '__neg__',
             ast.UAdd: '__pos__',
@@ -394,7 +422,7 @@ class CodeGenerator(ast.NodeVisitor):
         return fn(*args, **kws)

     def visit_Num(self, node):
-        return node.n
+        return triton.language.constexpr(node.n)

     def visit_Attribute(self, node):
         lhs = self.visit(node.value)
@@ -477,6 +505,8 @@ class Kernel:
         }
         if hasattr(obj, 'data_ptr'):
             return type_names[obj.dtype]
+        if isinstance(obj, triton.language.core.constexpr):
+            obj = obj.value
         if isinstance(obj, int):
             if abs(obj) <= 0xffffffff:
                 return 'I'
@@ -485,6 +515,8 @@ class Kernel:
             return 'f'
         if isinstance(obj, bool):
             return 'B'
+        if isinstance(obj, str):
+            return 'str'
         assert False


@@ -537,7 +569,8 @@ class Kernel:
     def __init__(self, fn):
         self.fn = fn

-    def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages, **meta):
+    def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages):
+        wargs = [arg for arg in wargs if not isinstance(arg, triton.language.constexpr)]
        # create IR module
         context = _triton.ir.context()
         # get just-in-time proto-type of kernel
@@ -547,7 +580,7 @@ class Kernel:
         # generate Triton-IR
         # export symbols visible from self.fn into code-generator object
         gscope = sys.modules[self.fn.module].__dict__
-        generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=meta)
+        generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=dict())
         try:
             generator.visit(self.fn.parse())
         except Exception as e:
@@ -566,7 +599,19 @@ class Kernel:
             raise OutOfResources(shared_mem, max_shared_memory, "shared memory")
         return Binary(backend, name, asm, shared_mem, num_warps)

-    def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **meta):
+    def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **kwargs):
+        # handle arguments passed by name
+        kwargs = {self.fn.arg_names.index(name): value for name, value in kwargs.items()}
+        wargs = list(wargs)
+        for i, pos in enumerate(sorted(kwargs)):
+            wargs.insert(pos + i, kwargs[pos])
+        if len(wargs) != len(self.fn.arg_names):
+            raise TypeError(f"Function takes {len(self.fn.arg_names)} positional arguments but {len(wargs)} were given")
+        # handle annotations
+        for name, type in self.fn.__annotations__.items():
+            pos = self.fn.arg_names.index(name)
+            assert type == triton.language.core.constexpr
+            wargs[pos] = type(wargs[pos])
         # device inference
         tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
         if len(tensor_idxs) == 0:
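`Kernel.__call__` now folds keyword arguments into their positional slots via `self.fn.arg_names`, wraps every argument annotated with `tl.constexpr`, and (further down in this diff) hands the grid callable a dict mapping constexpr argument names to their values instead of the old `**meta` dict. A launch sketch under those assumptions (names are illustrative):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def copy(src, dst, n, BLOCK: tl.constexpr):
        offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        mask = offs < n
        tl.store(dst + offs, tl.load(src + offs, mask=mask), mask=mask)

    src = torch.arange(2048, device='cuda', dtype=torch.float32)
    dst = torch.empty_like(src)
    # BLOCK may be passed by keyword; the grid callable receives {'BLOCK': 256}
    copy[lambda c: (triton.cdiv(src.numel(), c['BLOCK']), )](src, dst, src.numel(), BLOCK=256)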
@@ -601,18 +646,19 @@ class Kernel:
         args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)]
         attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) \
                       if isinstance(a, int) and i not in self.fn.do_not_specialize}
+
         # transforms ints whose value is one into constants for just-in-time compilation
         constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1}
+        constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)})
+
         # compute hash for caching this kernel
         types_key = Kernel._types_key(*wargs, tensor_idxs=tensor_idxs)
         attr_key = tuple(attributes.items())
-        meta_key = tuple(sorted(meta.items()))
         const_key = tuple(constants.items())
         compute_capability = torch.cuda.get_device_capability(device)
-
         key = (
             self.fn.cache_key, version_key(), compute_capability,
-            types_key, attr_key, num_warps, num_stages, meta_key, const_key
+            types_key, attr_key, num_warps, num_stages, const_key
         )
         key = repr(key)

@@ -644,7 +690,7 @@ class Kernel:
             binary = self._compile(
                 *wargs, device=device_idx, attributes=attributes,
                 num_warps=num_warps, num_stages=num_stages,
-                constants=constants, **meta
+                constants=constants,
             )
             if bin_cache_path:
                 assert bin_lock_path is not None
@@ -657,12 +703,15 @@ class Kernel:

             drv_cache[key] = LoadedBinary(device_idx, binary)
         # pack arguments
-        fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg) for i, arg in enumerate(wargs)])
-        params = struct.pack(fmt, *args)
+        fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg) for i, arg in enumerate(wargs) if not isinstance(arg, triton.language.core.constexpr)])
+        params = struct.pack(fmt, *[arg for arg in args if not isinstance(arg, triton.language.core.constexpr)])
         # enqueue cached function into stream
         callable = drv_cache[key]
         stream = torch.cuda.current_stream(device_idx).cuda_stream
-        grid = grid(meta) if hasattr(grid, '__call__') else grid
+        csts = {self.fn.arg_names[i]: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.core.constexpr)}
+        grid = grid(csts) if hasattr(grid, '__call__') else grid
+        if isinstance(grid, int):
+            grid = tuple(grid)
         callable(stream, params, *grid)
         return callable

@@ -697,31 +746,31 @@ class Autotuner:
     def _bench(self, *args, config, **meta):
         # check for conflicts, i.e. meta-parameters both provided
        # as kwargs and by the autotuner
-        conflicts = meta.keys() & config.meta.keys()
+        conflicts = meta.keys() & config.kwargs.keys()
         if conflicts:
             raise ValueError(
                 f"Conflicting meta-parameters: {', '.join(conflicts)}."
                 " Make sure that you don't re-define auto-tuned symbols."
             )
         # augment meta-parameters with tunable ones
-        current = dict(meta, **config.meta)
+        current = dict(meta, **config.kwargs)
         def kernel_call():
             self.hook(args)
             self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
         return triton.testing.do_bench(kernel_call)

-    def __call__(self, *args, **meta):
+    def __call__(self, *args, **kwargs):
         if len(self.configs) > 1:
             key = tuple([args[i] for i in self.key_idx])
             if key not in self.cache:
-                timings = {config: self._bench(*args, config=config, **meta) \
+                timings = {config: self._bench(*args, config=config, **kwargs) \
                            for config in self.configs}
                 self.cache[key] = builtins.min(timings, key=timings.get)
                 self.hook(args)
             config = self.cache[key]
         else:
             config = self.configs[0]
-        return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **meta, **config.meta)
+        return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)


 @functools.lru_cache()
@@ -769,6 +818,8 @@ class JITFunction:
         # when called with a grid using __getitem__
         self.kernel_decorators = []
         self.kernel = None
+        # annotations
+        self.__annotations__ = fn.__annotations__
         # forward docs
         self.__doc__ = fn.__doc__

@@ -839,8 +890,8 @@ class Config:
     Mostly useful for matrix multiplication workloads on SM80+ GPUs.
     :type num_stages: int
     """
-    def __init__(self, meta, num_warps=4, num_stages=2):
-        self.meta = meta
+    def __init__(self, kwargs, num_warps=4, num_stages=2):
+        self.kwargs = kwargs
         self.num_warps = num_warps
         self.num_stages = num_stages

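`triton.Config` now takes its meta-parameter dictionary under the name `kwargs` (stored as `self.kwargs`), and the autotuner merges `config.kwargs` rather than `config.meta` into each launch. A sketch of autotuned usage under the renamed field (the kernel and key are illustrative):

    import triton
    import triton.language as tl

    @triton.autotune(
        configs=[
            triton.Config(kwargs={'BLOCK': 1024}, num_warps=4),
            triton.Config(kwargs={'BLOCK': 256}, num_warps=2),
        ],
        key=['n'],
    )
    @triton.jit
    def scale(x_ptr, n, BLOCK: tl.constexpr):
        offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        mask = offs < n
        tl.store(x_ptr + offs, tl.load(x_ptr + offs, mask=mask) * 2, mask=mask)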
@@ -14,9 +14,11 @@ def _to_ir(x, builder):
         return builder.get_int64(x)
     elif isinstance(x, float):
         return builder.get_float32(x)
-    if isinstance(x, block):
+    elif isinstance(x, constexpr):
+        return _to_ir(x.value, builder)
+    elif isinstance(x, block):
         return x.handle
-    if isinstance(x, dtype):
+    elif isinstance(x, dtype):
         return x.handle(builder)
     return x

@@ -257,6 +259,86 @@ class block:
         return frontend.cast(self, dtype, _builder)


+# -----------------------
+# constexpr
+# -----------------------
+
+class constexpr:
+    """
+    This class is used to store a value that is known at compile-time.
+    """
+
+    def __init__(self, value):
+        self.value = value
+
+    def __add__(self, other):
+        return self.value + other.value
+
+    def __radd__(self, other):
+        return other.value + self.value
+
+    def __sub__(self, other):
+        return self.value - other.value
+
+    def __rsub__(self, other):
+        return other.value - self.value
+
+    def __mul__(self, other):
+        return self.value * other.value
+
+    def __rmul__(self, other):
+        return other.value * self.value
+
+    def __truediv__(self, other):
+        return self.value / other.value
+
+    def __rtruediv__(self, other):
+        return other.value / self.value
+
+    def __floordiv__(self, other):
+        return self.value // other.value
+
+    def __rfloordiv__(self, other):
+        return other.value // self.value
+
+    #
+
+    def __gt__(self, other):
+        return self.value > other.value
+
+    def __rgt__(self, other):
+        return other.value > self.value
+
+    def __ge__(self, other):
+        return self.value >= other.value
+
+    def __rge__(self, other):
+        return other.value >= self.value
+
+    def __lt__(self, other):
+        return self.value < other.value
+
+    def __rlt__(self, other):
+        return other.value < self.value
+
+    def __le__(self, other):
+        return self.value <= other.value
+
+    def __rle__(self, other):
+        return other.value <= self.value
+
+    def __eq__(self, other):
+        return self.value == other.value
+
+    def __ne__(self, other):
+        return self.value != other.value
+
+    def __bool__(self):
+        return bool(self.value)
+
+    def __call__(self, *args, **kwds):
+        return self.value(*args, **kwds)
+
 # -----------------------
 # SPMD Programming Model
 # -----------------------
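The new `constexpr` wrapper stores a compile-time value and forwards arithmetic and comparisons to the wrapped `.value`; note that the operators as written expect `other` to also be a `constexpr`, and that they return plain Python values rather than new wrappers. Outside the JIT its behavior looks like this (a sketch, assuming `constexpr` is imported from `triton.language`):

    from triton.language import constexpr

    a = constexpr(128)
    b = constexpr(32)
    print(a // b)              # 4 -- a plain int: __floordiv__ returns self.value // other.value
    print(a > b)               # True
    print(bool(constexpr(0)))  # False, via __bool__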
@@ -312,7 +394,12 @@ def zeros(shape, dtype, _builder=None):
     :param dtype: Data-type of the new array, e.g., :code:`tl.float16`
     :type dtype: DType
     """
-    shape = [int(x.handle) if isinstance(x, block) else x for x in shape]
+    for i, d in enumerate(shape):
+        if not isinstance(d, constexpr):
+            raise TypeError(f"Shape element {i} must have type `constexpr`")
+        if not isinstance(d.value, int):
+            raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]")
+    shape = [x.value for x in shape]
     return frontend.zeros(shape, dtype, _builder)


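`tl.zeros` now requires every shape element to be a `constexpr` wrapping an `int`. Literals qualify because `visit_Num` wraps numbers in `constexpr`, and `tl.constexpr` arguments stay wrapped, so both forms below are accepted while a runtime integer shape would raise `TypeError` (a sketch; the kernel is illustrative):

    import triton
    import triton.language as tl

    @triton.jit
    def colsum(out, BLOCK: tl.constexpr):
        acc = tl.zeros((BLOCK, BLOCK), dtype=tl.float32)  # BLOCK is a constexpr argument
        alt = tl.zeros((32, 32), dtype=tl.float32)        # literals become constexpr via visit_Num
        tl.store(out + tl.arange(0, BLOCK), tl.sum(acc, axis=0))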
@@ -1,6 +1,5 @@
 import triton
 import triton.language as tl
-import triton._C.libtriton as libtriton
 import torch

 # ********************************************************
@@ -21,54 +20,46 @@ def _sdd_kernel(
     stride_za, stride_ha, stride_ma, stride_ak,
     stride_zb, stride_hb, stride_bk, stride_nb,
     stride_zc, stride_hc, stride_mc, stride_nc,
-    K, grid_offset, lut, **meta
+    K, grid_offset, lut,
+    TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
+    BLOCK: tl.constexpr, EVEN_K: tl.constexpr
 ):
-    TILE_M = meta['TILE_M']
-    TILE_N = meta['TILE_N']
-    TILE_K = meta['TILE_K']
-    BLOCK = meta['BLOCK']
     #------------#
     #- Prologue -#
     #------------#
-    pid1 = tl.program_id(1) + grid_offset
-    blockidm = tl.arange(0, TILE_M) // BLOCK
-    blockidn = tl.arange(0, TILE_N) // BLOCK
-    offlutm = blockidm * (TILE_N // BLOCK) * 4
-    offlutn = blockidn * 4
-    header = lut + pid1 * (TILE_M // BLOCK) * (TILE_N // BLOCK) * 4
-    # batch offset
-    off_z = tl.program_id(2)
-    # head offset
-    off_h = tl.load(header + 0)
+    block_id = tl.program_id(1) + grid_offset
+    lut += block_id * 3
+    # offsets
+    off_z = tl.program_id(2) # batch
+    off_h = tl.load(lut + 0) # head
     # initialize pointers to A
-    start_am = tl.load(header + 1 + offlutm)
+    start_am = tl.load(lut + 1)
     offs_am = start_am * BLOCK + (tl.arange(0, TILE_M) % BLOCK)
     offs_ak = tl.arange(0, TILE_K)
-    a_ptrs = A + off_z * stride_za \
+    a_ptrs = A + (off_z * stride_za \
              + off_h * stride_ha \
              + offs_am[:, None] * stride_ma \
-             + offs_ak[None, :] * stride_ak
+             + offs_ak[None, :] * stride_ak)
     # initialize pointers to B
-    start_bn = tl.load(header + 2 + offlutn)
+    start_bn = tl.load(lut + 2)
     offs_bn = start_bn * BLOCK + (tl.arange(0, TILE_N) % BLOCK)
     offs_bk = tl.arange(0, TILE_K)
-    b_ptrs = B + off_z * stride_zb \
+    b_ptrs = B + (off_z * stride_zb \
              + off_h * stride_hb \
              + offs_bn[None, :] * stride_nb \
-             + offs_bk[:, None] * stride_bk
+             + offs_bk[:, None] * stride_bk)
     ## ---------------- ##
     ## Inner Loop ##
     ## ---------------- ##
     acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
     for k in range(K, 0, -TILE_K):
-        if meta['EVEN_K']:
+        if EVEN_K:
             a = tl.load(a_ptrs)
             b = tl.load(b_ptrs)
         else:
             a = tl.load(a_ptrs, mask=offs_ak[None, :] < k, other=0.)
             b = tl.load(b_ptrs, mask=offs_bk[:, None] < k, other=0.)
-        a = tl.load(a_ptrs)
-        b = tl.load(b_ptrs)
         acc += tl.dot(a, b)
         a_ptrs += TILE_K * stride_ak
         b_ptrs += TILE_K * stride_bk
@@ -76,22 +67,15 @@ def _sdd_kernel(
     ## ---------------- ##
     ## Epilogue ##
     ## ---------------- ##
-    blockidm = tl.arange(0, TILE_M) // BLOCK
-    blockidn = tl.arange(0, TILE_N) // BLOCK
-    offlutm = blockidm * (TILE_N // BLOCK) * 4
-    offlutn = blockidn * 4
-    off_block_id = 3 + offlutm[:, None] + offlutn[None, :]
-    block_id = tl.load(header + off_block_id)
-    # initialize pointers to C
     offs_cm = tl.arange(0, TILE_M) % BLOCK
     offs_cn = tl.arange(0, TILE_N) % BLOCK
-    pc = C + off_z * stride_zc \
+    pc = C + (off_z * stride_zc \
        + block_id * stride_hc \
        + offs_cm[:, None] * stride_mc \
-       + offs_cn[None, :] * stride_nc
+       + offs_cn[None, :] * stride_nc)
     tl.store(pc, c, mask=True)

-def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, luts, num_locks, widths, packs):
+def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out = None):
     # (A * B)^T = B^T * A^T
     if trans_c:
         a, b = b, a
@@ -102,46 +86,28 @@ def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, luts, num_locks,
     Ka, Kb = a.shape[a_dim], b.shape[b_dim]
     if Ka != Kb:
         raise ValueError(f"Inner dimension mismatch (A: {Ka} vs B: {Kb})")
-    if Ka % 16 != 0:
-        raise ValueError('Reduction size for SDD must be a multiple of 16')
     # allocate output
-    n_blocks = sum([width * pack * pack for width, pack in zip(widths, packs)])
-    c = torch.zeros((a.shape[0], n_blocks, block, block), dtype=a.dtype, device=a.device)
-    # each iteration of the loop below
-    # computes the value for one group of super-blocks
-    # (e.g., all 4x4 super-blocks)
-    for lut, width, pack in zip(luts, widths, packs):
-        # maximum grid size in Triton/CUDA is 64k but we may have more
-        # super-blocks than that.
-        max_grid = 65535
-        for off_grid in range(0, width, max_grid):
-            grid = [1, min(max_grid, width - off_grid), c.shape[0]]
-            # fmt: off
-            pgm = _sdd_kernel[grid](
-                a, b, c,
-                a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3),
-                b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
-                c.stride(0), c.stride(1), c.stride(2), c.stride(3),
-                Ka, off_grid, lut,
-                TILE_M = block*pack, TILE_N = block*pack, TILE_K = 32, BLOCK = block, num_stages=3,
-                num_warps=4,
-            )
-    # print(pgm.asm['ptx'])
-    # exit()
+    if out is None:
+        c = torch.empty((a.shape[0], lut.shape[0], block, block), dtype=a.dtype, device=a.device)
+    else:
+        assert out.shape == (a.shape[0], lut.shape[0], block, block)
+        c = out
+    grid = [1, c.shape[1], c.shape[0]]
+    _sdd_kernel[grid](
+        a, b, c,
+        a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3),
+        b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
+        c.stride(0), c.stride(1), c.stride(2), c.stride(3),
+        Ka, 0, lut,
+        TILE_M = block, TILE_N = block, TILE_K = 32, BLOCK = block, num_stages=4,
+        num_warps=4,
+    )
     return c


 def sdd_lut(layout, block, device):
-    start_width = 128 // block
-    layout = layout.type(torch.int32)
-    superblocks = libtriton.superblock(layout.data_ptr(), layout.shape[0], layout.shape[1], layout.shape[2], start_width)
-    luts, widths, packs = [], [], []
-    for size, nnz in superblocks:
-        nnz = nnz.reshape(-1, 4)
-        width = nnz.shape[0] // (size * size)
-        luts.append(torch.from_numpy(nnz).type(torch.int32).to(device))
-        widths.append(width)
-        packs.append(size)
-    return luts, None, widths, packs
+    lut = layout.nonzero(as_tuple=False).to(device).int()
+    return lut, None

 # -----------------------------
 # Dense = Sparse x Dense (DSD)
@@ -154,12 +120,10 @@ def _dsd_kernel(
     stride_az, stride_ha, stride_am, stride_ak,
     stride_zb, stride_hb, stride_bk, stride_bn,
     stride_zc, stride_hc, stride_cm, stride_cn,
-    DS0, DS1, lut, **meta
+    DS0, DS1, lut,
+    TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr
 ):
-    TILE_M = meta['TILE_M']
-    TILE_N = meta['TILE_N']
-    TILE_K = meta['TILE_K']
-    GROUP_SIZE_M = meta['GROUP_SIZE_M']
     #------------#
     #- Prologue -#
     #------------#
@@ -167,9 +131,9 @@ def _dsd_kernel(
     pid_n = tl.program_id(1)
     num_pid_m = tl.num_programs(0)
     num_pid_n = tl.num_programs(1)
-    pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M)
+    pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_SIZE_M)
     pidz = tl.program_id(2)
-    header = lut + pid_m * 4
+    header = lut + pid_n * 4
     offset = tl.load(header + 0)
     K = tl.load(header + 1)
     column = tl.load(header + 2)
@@ -185,7 +149,8 @@ def _dsd_kernel(
          + offs_am[:, None] * stride_am \
          + offs_ak[None, :] * stride_ak
     # initialize pointers to B (dense)
-    offs_bn = pid_n*TILE_N + tl.arange(0, TILE_N)
+    offs_bn = pid_m*TILE_N + tl.arange(0, TILE_N)
+    offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn % DS0, TILE_N), TILE_N)
     start_bk = tl.load(pinc)
     start_bk = tl.multiple_of(start_bk, 8)  # compiler hint
     offs_bk = start_bk + tl.arange(0, TILE_K)
@@ -197,28 +162,33 @@ def _dsd_kernel(
     ## Inner Loop ##
     ## ---------------- ##
     acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
+    pinc += 2
+    inc_a = tl.load(pinc + 1)
+    inc_a = tl.multiple_of(inc_a, 8)
+    inc_b = tl.load(pinc)
+    inc_b = tl.multiple_of(inc_b, 8)
     for k in range(K, 0, -TILE_K):
         a = tl.load(pa, mask=True)
         b = tl.load(pb, mask=offs_bn[None, :] < DS0)
         acc += tl.dot(a, b)
+        pa += inc_a
+        pb += inc_b*stride_bk
         pinc += 2
         inc_a = tl.load(pinc + 1)
         inc_a = tl.multiple_of(inc_a, 8)
         inc_b = tl.load(pinc)
         inc_b = tl.multiple_of(inc_b, 8)
-        pa += inc_a
-        pb += inc_b*stride_bk
     c = acc.to(C.dtype.element_ty)
     # initialize pointers to C
     offs_cm = column*TILE_M + tl.arange(0, TILE_M)
-    offs_cn = pid_n*TILE_N + tl.arange(0, TILE_N)
+    offs_cn = pid_m*TILE_N + tl.arange(0, TILE_N)
     pc = C + off_h * stride_hc \
        + pidz * stride_zc \
        + offs_cm[:, None] * stride_cm \
        + offs_cn[None, :] * stride_cn
     tl.store(pc, c, mask = offs_cn[None, :] < DS0)

-def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, width, packs):
+def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out = None):
     # shapes / dtypes
     AS1 = block * spdims[2 if trans_a else 1]
     BS0 = b.size(0)
@@ -230,11 +200,15 @@ def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, w
     CS1 = BS1
     CS2 = BS3 if trans_c else AS1
     CS3 = AS1 if trans_c else BS3
-    c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
+    if out is None:
+        c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
+    else:
+        assert out.shape == (CS0, CS1, CS2, CS3)
+        c = out
     # meta-parameter heuristics
-    TILE_N = {16: 256, 32: 256, 64: 128, 128: 128}[block]
+    TILE_N = 128
     # compute output
-    grid = lambda meta: [width, triton.cdiv(BS3, meta['TILE_N']), BS0]
+    grid = lambda meta: [triton.cdiv(BS3, meta['TILE_N']), width, BS0]
     # fmt: off
     _dsd_kernel[grid](
         a, b, c,
@@ -242,8 +216,8 @@ def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, w
         b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
         c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3),
         BS3, AS1, lut,
-        TILE_M = block, TILE_N=TILE_N, TILE_K = min(block, 32), BLOCK = block, num_stages=3,
-        num_warps=4, GROUP_SIZE_M=8,
+        TILE_M = block, TILE_N=TILE_N, TILE_K = min(block, 32), BLOCK = block, num_stages=4,
+        num_warps=4, GROUP_SIZE_M=4,
     )
     # exit()
     return c
@@ -323,7 +297,7 @@ def dsd_lut(layout, block, step, trans, device):
|
|||||||
lut = torch.cat((header, incs))
|
lut = torch.cat((header, incs))
|
||||||
lut = lut.type(torch.int32).to(device)
|
lut = lut.type(torch.int32).to(device)
|
||||||
# create locks
|
# create locks
|
||||||
return lut, None, width, None
|
return lut, width
|
||||||
|
|
||||||
# -----------------------------
|
# -----------------------------
|
||||||
# Dense = Dense x Sparse (DDS)
|
# Dense = Dense x Sparse (DDS)
|
||||||
@@ -334,12 +308,10 @@ def _dds_kernel(
|
|||||||
stride_za, stride_ha, stride_ma, stride_ka,
|
stride_za, stride_ha, stride_ma, stride_ka,
|
||||||
stride_zb, stride_hb, stride_bk, stride_bn,
|
stride_zb, stride_hb, stride_bk, stride_bn,
|
||||||
stride_zc, stride_hc, stride_mc, stride_nc,
|
stride_zc, stride_hc, stride_mc, stride_nc,
|
||||||
DS0, DS1, lut, **meta
|
DS0, DS1, lut,
|
||||||
|
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
|
||||||
|
GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr,
|
||||||
):
|
):
|
||||||
TILE_M = meta['TILE_M']
|
|
||||||
TILE_N = meta['TILE_N']
|
|
||||||
TILE_K = meta['TILE_K']
|
|
||||||
GROUP_SIZE_M = meta['GROUP_SIZE_M']
|
|
||||||
#------------#
|
#------------#
|
||||||
#- Prologue -#
|
#- Prologue -#
|
||||||
#------------#
|
#------------#
|
||||||
@@ -347,16 +319,17 @@ def _dds_kernel(
|
|||||||
pid_n = tl.program_id(1)
|
pid_n = tl.program_id(1)
|
||||||
num_pid_m = tl.num_programs(0)
|
num_pid_m = tl.num_programs(0)
|
||||||
num_pid_n = tl.num_programs(1)
|
num_pid_n = tl.num_programs(1)
|
||||||
pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M)
|
pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_SIZE_M)
|
||||||
pid_z = tl.program_id(2)
|
pid_z = tl.program_id(2)
|
||||||
header = lut + pid_m * 4
|
header = lut + pid_n * 4
|
||||||
offset = tl.load(header + 0)
|
offset = tl.load(header + 0)
|
||||||
AS1 = tl.load(header + 1)
|
AS1 = tl.load(header + 1)
|
||||||
column = tl.load(header + 2)
|
column = tl.load(header + 2)
|
||||||
off_h = tl.load(header + 3)
|
off_h = tl.load(header + 3)
|
||||||
pinc = lut + offset
|
pinc = lut + offset
|
||||||
# initialize pointers to A (dense)
|
# initialize pointers to A (dense)
|
||||||
offs_am = pid_n*TILE_M + tl.arange(0, TILE_M)
|
offs_am = pid_m*TILE_M + tl.arange(0, TILE_M)
|
||||||
|
offs_am = tl.max_contiguous(tl.multiple_of(offs_am % DS0, TILE_M), TILE_M)
|
||||||
start_ak = tl.load(pinc)
|
start_ak = tl.load(pinc)
|
||||||
start_ak = tl.multiple_of(start_ak, 8)
|
start_ak = tl.multiple_of(start_ak, 8)
|
||||||
offs_ak = start_ak + tl.arange(0, TILE_K)
|
offs_ak = start_ak + tl.arange(0, TILE_K)
|
||||||
@@ -394,7 +367,7 @@ def _dds_kernel(
     ## ---------------- ##
     c = acc.to(C.dtype.element_ty)
     # initialize pointers to C (dense)
-    offs_cm = pid_n * TILE_M + tl.arange(0, TILE_M)
+    offs_cm = pid_m * TILE_M + tl.arange(0, TILE_M)
     offs_cn = column * TILE_N + tl.arange(0, TILE_N)
     ptrs_c = C + off_h * stride_hc \
              + pid_z * stride_zc \
@@ -403,7 +376,7 @@ def _dds_kernel(
     # write back
     tl.store(ptrs_c, c, mask = offs_cm[:, None] < DS0)

-def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, width, packs):
+def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out = None):
     # shapes / dtypes
     AS0 = a.size(0)
     AS1 = a.size(1)
@@ -415,9 +388,13 @@ def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, w
     CS1 = AS1
     CS2 = BS2 if trans_c else AS2
     CS3 = AS2 if trans_c else BS2
-    c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
+    if out is None:
+        c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
+    else:
+        assert out.shape == (CS0, CS1, CS2, CS3)
+        c = out
     TILE_M = {16: 256, 32: 256, 64: 128, 128: 128}[block]
-    grid = lambda meta: [width, triton.cdiv(AS2, meta['TILE_M']), AS0]
+    grid = lambda meta: [triton.cdiv(AS2, meta['TILE_M']), width, AS0]
     # fmt: off
     _dds_kernel[grid](
         a, b, c,
@@ -425,8 +402,8 @@ def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, w
         b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
         c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3),
         AS2, BS2, lut,
-        TILE_M = TILE_M, TILE_N = block, TILE_K = min(block, 32), BLOCK = block, num_stages=3,
-        num_warps=4, GROUP_SIZE_M=8,
+        TILE_M = TILE_M, TILE_N = block, TILE_K = min(block, 32), BLOCK = block, num_stages=4,
+        num_warps=4, GROUP_SIZE_M=4,
     )
     return c

@@ -439,25 +416,23 @@ class _matmul(torch.autograd.Function):

     @staticmethod
     def forward(
-        ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block, c_lut, c_num_locks, c_width, c_packs, da_lut, da_num_locks,
-        da_width, da_packs, db_lut, db_num_locks, db_width, db_packs
+        ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block,
+        c_lut, c_width, da_lut, da_width, db_lut, db_width, out
     ):
-        c = _matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block, c_lut, c_num_locks, c_width, c_packs)
+        c = _matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block, c_lut, c_width, out=out)
         # save for backward
         ctx.save_for_backward(a, b)
-        ctx.da_num_locks = da_num_locks
         ctx.da_lut = da_lut
         ctx.da_width = da_width
-        ctx.da_packs = da_packs
         ctx.db_lut = db_lut
-        ctx.db_num_locks = db_num_locks
         ctx.db_width = db_width
-        ctx.db_packs = db_packs
         ctx.mode = mode
         ctx.spdims = spdims
         ctx.block = block
         ctx.trans_a = trans_a
         ctx.trans_b = trans_b
+        ctx.trans_c = trans_c
+        ctx.has_out = out is not None
         return c

     @staticmethod
@@ -466,155 +441,55 @@ class _matmul(torch.autograd.Function):
         a, b = ctx.saved_tensors
         da, db = None, None
         mode = ctx.mode

         # gradients w.r.t. a
         if ctx.needs_input_grad[0]:
             mode_da = mode[1] + mode[0] + mode[2]
             da = _matmul.fn[mode_da](
-                dc, b, False, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block, ctx.da_lut, ctx.da_num_locks, ctx.da_width,
-                ctx.da_packs
+                dc, b, ctx.trans_c, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block, ctx.da_lut, ctx.da_width,
             )
         # gradients w.r.t. b
         if ctx.needs_input_grad[1]:
             mode_db = mode[2] + mode[1] + mode[0]
             db = _matmul.fn[mode_db](
-                a, dc, not ctx.trans_a, False, ctx.trans_b, ctx.spdims, ctx.block, ctx.db_lut, ctx.db_num_locks, ctx.db_width,
-                ctx.db_packs
+                a, dc, not ctx.trans_a, ctx.trans_c, ctx.trans_b, ctx.spdims, ctx.block, ctx.db_lut, ctx.db_width,
             )
+        dout = dc if ctx.has_out else None
         return da, db, None, None, None,\
             None, None, None, None,\
-            None, None, None, None, None, None,\
-            None, None, None, None, None, None,\
-            None, None, None, None, None, None
+            None, None, None, None, None, dout
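The reshuffled trailing values are not cosmetic: `torch.autograd.Function.backward` must return exactly one value per `forward` argument (excluding `ctx`), in order. The new `forward` takes 15 arguments, so `backward` now returns 15 values, with `dout` sitting in the slot of the new `out` argument. A stripped-down illustration of that contract (the `Scale` function below is hypothetical, not part of this commit):

```python
import torch

class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, alpha, out):
        # `alpha` and `out` are non-differentiable bookkeeping arguments,
        # playing the role of the luts/widths/flags in _matmul.forward above
        ctx.alpha = alpha
        ctx.has_out = out is not None
        return x * alpha

    @staticmethod
    def backward(ctx, dres):
        dx = dres * ctx.alpha
        dout = dres if ctx.has_out else None
        # one return value per forward argument: (x, alpha, out)
        return dx, None, dout
```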


 class matmul:
-    def make_lut(self, dtype, device):
-        key = (dtype, device)
-        if key in self.lut_cache:
-            return self.lut_cache[key]
-        # C look-up table
-        layout, block = self.layout, self.block
-        step = min(block, 32)
-        if self.mode == 'sdd':
-            c_lut, c_num_locks, c_width, c_packs = sdd_lut(layout, block, device)
-        elif self.mode == 'dsd':
-            c_lut, c_num_locks, c_width, c_packs = dsd_lut(layout, block, step, not self.trans_a, device)
-        elif self.mode == 'dds':
-            c_lut, c_num_locks, c_width, c_packs = dsd_lut(layout, block, step, self.trans_b, device)
-        # DA look-up table
-        if self.mode == 'sdd':
-            da_lut, da_num_locks, da_width, da_packs = dsd_lut(layout, block, step, True, device)
-        elif self.mode == 'dsd':
-            da_lut, da_num_locks, da_width, da_packs = sdd_lut(layout, block, device)
-        elif self.mode == 'dds':
-            da_lut, da_num_locks, da_width, da_packs = dsd_lut(layout, block, step, not self.trans_b, device)
-        # DB look-up table
-        if self.mode == 'sdd':
-            db_lut, db_num_locks, db_width, db_packs = dsd_lut(layout, block, step, False, device)
-        elif self.mode == 'dsd':
-            db_lut, db_num_locks, db_width, db_packs = dsd_lut(layout, block, step, self.trans_a, device)
-        elif self.mode == 'dds':
-            db_lut, db_num_locks, db_width, db_packs = sdd_lut(layout, block, device)
-        self.lut_cache[key] = (c_lut, c_num_locks, c_width, c_packs,
-                               da_lut, da_num_locks, da_width, da_packs,
-                               db_lut, db_num_locks, db_width, db_packs)
-        return self.lut_cache[key]

-    def __init__(self, layout, block, mode, trans_a=False, trans_b=False):
+    def __init__(self, layout, block, mode, device, trans_a=False, trans_b=False, trans_c=False):
         if mode not in ['sdd', 'dsd', 'dds']:
             raise NotImplementedError('Supported modes are: sdd, dsd, dds')
-        # look-up table cache
-        self.lut_cache = dict()
-        # attributes
         self.block = block
         self.mode = mode
         self.trans_a = trans_a
         self.trans_b = trans_b
+        self.trans_c = trans_c
-        layout_dim = layout.ndim
-        assert layout_dim in (2, 3), "Layout should be a 2 or 3 dimensional tensor of 0s and 1s"
-
-        if not mode == 'sdd':
-            # Dims to be reduced on the 'inside' of the matmul, either -1 or -2
-            trans_dense, trans_sparse, sparse_inner = (trans_b, trans_a, -1) if mode == 'dsd' else (trans_a, trans_b, -2)
-            self.dense_inner_dim = -((sparse_inner % 2) + 1) if not trans_dense else sparse_inner
-            sparse_inner = sparse_inner if not trans_sparse else -((sparse_inner % 2) + 1)
-
-            # Inner dim of the dense input should be equal to the inner dim of the sparse input
-            self.dense_inner_size = layout.shape[sparse_inner] * block
-            # Expected shape for sparse inputs
-            self.sparse_shape = (layout.sum().item(), block, block)
-
-        # Support using the same layout across attention heads etc.
-        if layout_dim == 2:
-            layout = layout.unsqueeze(0)
-
-        layout = layout.long() # Above code assumes the layout tensor is an integral type
         self.layout = layout
         self.spdims = layout.shape
+        step = min(block, 32)
+        if self.mode == 'sdd':
+            self.c_lut, self.c_width = sdd_lut(layout, block, device)
+            self.da_lut, self.da_width = dsd_lut(layout, block, step, True, device)
+            self.db_lut, self.db_width = dsd_lut(layout, block, step, False, device)
+        if self.mode == 'dsd':
+            self.c_lut, self.c_width = dsd_lut(layout, block, step, not self.trans_a, device)
+            self.da_lut, self.da_width = sdd_lut(layout, block, device)
+            self.db_lut, self.db_width = dsd_lut(layout, block, step, self.trans_a, device)
+        if self.mode == 'dds':
+            self.c_lut, self.c_width = dsd_lut(layout, block, step, self.trans_b, device)
+            self.da_lut, self.da_width = dsd_lut(layout, block, step, not self.trans_b, device)
+            self.db_lut, self.db_width = sdd_lut(layout, block, device)

-    def __call__(self, a, b):
+    def __call__(self, a, b, out = None):
-        c_lut, c_num_locks, c_width, c_packs,\
-            da_lut, da_num_locks, da_width, da_packs,\
-            db_lut, db_num_locks, db_width, db_packs = self.make_lut(a.dtype, a.device)
-
-        # If we don't check for invalid shapes, devices, & dtypes here, they will lead to undefined behavior
-        # and potential illegal memory accesses
-        original_dims = max(a.ndim, b.ndim)
-        a, b = self._validate_inputs(a, b)
-
-        # execute
         c = _matmul.apply(
-            a, b, self.trans_a, self.trans_b, False, self.mode, self.spdims, self.block, c_lut, c_num_locks, c_width,
-            c_packs, da_lut, da_num_locks, da_width, da_packs, db_lut, db_num_locks, db_width, db_packs
+            a, b, self.trans_a, self.trans_b, self.trans_c, self.mode, self.spdims, self.block,
+            self.c_lut, self.c_width,
+            self.da_lut, self.da_width,
+            self.db_lut, self.db_width,
+            out
         )
-        # This removes any leading singleton dimensions we may have added to the tensor that weren't in the input
-        dims_to_trim = c.ndim - original_dims
-        for _ in range(dims_to_trim):
-            c = c.squeeze(0)
-
         return c

-    def _validate_inputs(self, a, b):
-        if a.device != b.device:
-            raise ValueError(f"Inputs must be on the same device; got {a.device} for tensor A "
-                             f"and {b.device} for tensor B")
-        if not a.is_cuda:
-            raise ValueError("Only GPU devices are supported for now")
-
-        # When autocast is enabled, torch.matmul autocasts to float16, so we do the same here
-        if torch.is_autocast_enabled():
-            a, b = a.half(), b.half()
-        elif a.dtype != b.dtype:
-            raise ValueError(f"Inputs must be the same dtype; got {a.dtype} for A and {b.dtype} for B")
-
-        mode, trans_a, trans_b = self.mode, self.trans_a, self.trans_b
-        if mode != 'sdd':
-            # One input is sparse
-            dense, dense_name, sparse, sparse_name = (a, 'A', b, 'B') if mode == 'dds' else (b, 'B', a, 'A')
-            dense_inner = dense.shape[self.dense_inner_dim]
-            if dense_inner != self.dense_inner_size:
-                raise ValueError(f"Expected tensor {dense_name} to have size {self.dense_inner_size} at dim "
-                                 f"{self.dense_inner_dim % dense.ndim}, got {dense_inner}.")
-
-            if sparse.shape[-len(self.sparse_shape):] != self.sparse_shape:
-                raise ValueError(f"Expected tensor with trailing dimensions of shape {self.sparse_shape} for argument "
-                                 f"{sparse_name}, got {sparse.shape}")
-
-        def add_extra_dims(x):
-            # Add extra leading singleton dimensions if needed
-            dims_needed = 4 - x.ndim
-            if dims_needed > 0:
-                singletons = [1] * dims_needed
-                x = x.view(*singletons, *x.shape)
-            elif dims_needed < 0:
-                raise ValueError("Tensors with more than 4 dimensions are not currently supported")
-
-            return x
-
-        # Pad shapes with leading singleton dimensions
-        a = add_extra_dims(a)
-        b = add_extra_dims(b)
-
-        return a, b
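With `make_lut` gone, all six look-up tables are built eagerly in `__init__`, so the op is constructed once per (layout, device) pair and then called many times without per-call caching. A usage sketch (the import path and tensor shapes are assumptions inferred from the code above, not a tested snippet):

```python
import torch
from triton.ops.blocksparse import matmul  # assumed import path

block, H, M, K, N = 16, 4, 256, 256, 256
# binary mask with one entry per (block-row, block-col) of the sparse operand
layout = torch.randint(0, 2, (H, K // block, N // block))
# LUTs for the forward pass and both backward passes are built right here
dot = matmul(layout, block, mode='dds', device='cuda')

a = torch.randn(1, H, M, K, dtype=torch.float16, device='cuda')      # dense
b = torch.randn(1, int(layout.sum()), block, block,                  # sparse blocks
                dtype=torch.float16, device='cuda')
c = dot(a, b)             # allocates the output
c = dot(a, b, out=c)      # or reuses a preallocated buffer via the new out=
```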
@@ -16,10 +16,9 @@ def num_warps(n):
 @triton.jit
 def _forward(
     X, scale, LUT, RPE, KP_M, ATTN_M, is_causal, sizemax, stride_zx, stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm,
-    **meta
+    TN: tl.constexpr, BLOCK: tl.constexpr, APPLY_SCALE: tl.constexpr, APPLY_RPE: tl.constexpr, APPLY_KP_MASK: tl.constexpr,
+    KP_MASK_MUL: tl.constexpr, APPLY_ATTN_MASK: tl.constexpr, ATTN_MASK_MUL: tl.constexpr,
 ):
-    TN = meta['TN']
-    BLOCK = meta['BLOCK']
     pidhm = tl.program_id(0)
     pidz = tl.program_id(1)
     # create index ranges
@@ -43,25 +42,25 @@ def _forward(
     x = tl.load(px, mask=check, other=-float('inf'))
     x = x.to(tl.float32)
     # apply scale
-    if meta['APPLY_SCALE']:
+    if APPLY_SCALE:
         x = x * scale
     # apply RPE
-    if meta['APPLY_RPE']:
+    if APPLY_RPE:
         prpe = RPE + pidz * stride_zrpe + headid * stride_hrpe + columnid * BLOCK + rowid * BLOCK * stride_srpe + rxm * stride_srpe + rxn
         rpe = tl.load(prpe, mask=check, other=0)
         x = x + rpe
     # apply key-padding mask
-    if meta['APPLY_KP_MASK']:
+    if APPLY_KP_MASK:
         pkp_m = KP_M + pidz * stride_zkpm + columnid * BLOCK + rxn
         kp_m = tl.load(pkp_m, mask=check, other=-float('inf'))
-        if meta['KP_MASK_MUL']:
+        if KP_MASK_MUL:
             kp_m = tl.where(kp_m == 0, -float('inf'), 0.)
         x = x + kp_m
     # apply attention mask
-    if meta['APPLY_ATTN_MASK']:
+    if APPLY_ATTN_MASK:
         pattn_m = ATTN_M + columnid * BLOCK + rowid * BLOCK * stride_zattnm + rxm * stride_zattnm + rxn
         attn_m = tl.load(pattn_m, mask=check, other=-float('inf'))
-        if meta['ATTN_MASK_MUL']:
+        if ATTN_MASK_MUL:
             attn_m = tl.where(attn_m == 0, -float('inf'), 0.)
         x = x + attn_m
     # apply causal mask
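Because the mask flags are now `tl.constexpr`, branches such as `if APPLY_RPE:` are resolved at compile time: each combination of flags yields its own specialization with the unused paths removed outright, rather than a runtime check. A toy demonstration of the mechanism (hypothetical kernel):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def _scale_maybe(X, Y, scale, n, APPLY_SCALE: tl.constexpr, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(X + offs, mask=mask)
    if APPLY_SCALE:  # compile-time branch: dead code is eliminated entirely
        x = x * scale
    tl.store(Y + offs, x, mask=mask)

x = torch.randn(512, device='cuda')
y = torch.empty_like(x)
# two launches, two distinct binaries generated from the same source
_scale_maybe[(1,)](x, y, 2.0, x.numel(), APPLY_SCALE=True, BLOCK=512)
_scale_maybe[(1,)](x, y, 2.0, x.numel(), APPLY_SCALE=False, BLOCK=512)
```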
@@ -75,11 +74,9 @@ def _forward(
 @triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[4] * meta['BLOCK'])})
 @triton.heuristics({'TN': lambda *args, **meta: triton.next_power_of_2(args[4]) * meta['BLOCK']})
 @triton.jit
-def _backward(X, scale, DX, LUT, sizemax, stride_zx, stride_zdx, **meta):
+def _backward(X, scale, DX, LUT, sizemax, stride_zx, stride_zdx, TN: tl.constexpr, BLOCK: tl.constexpr):
     pidhm = tl.program_id(0)
     pidz = tl.program_id(1)
-    TN = meta['TN']
-    BLOCK = meta['BLOCK']
     # create index ranges
     rxm = pidhm % BLOCK
     rbm = pidhm // BLOCK
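Note what does not change here: `TN` and `num_warps` are still computed by `@triton.heuristics` from the runtime arguments at launch; the constexpr annotation only changes how the kernel body receives the result. A small sketch of the same division of labor (the kernel and the argument index used by the heuristic are assumptions for illustration):

```python
import triton
import triton.language as tl

# the heuristic runs on the host at launch time; args[1] is `n` in this sketch
@triton.heuristics({'BLOCK': lambda *args, **meta: triton.next_power_of_2(args[1])})
@triton.jit
def _row_max(X, n, OUT, BLOCK: tl.constexpr):
    # BLOCK arrives as a constexpr, so tl.arange gets the static shape it needs
    offs = tl.arange(0, BLOCK)
    x = tl.load(X + offs, mask=offs < n, other=-float('inf'))
    tl.store(OUT, tl.max(x, axis=0))
```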
@@ -172,8 +169,7 @@ class _softmax(torch.autograd.Function):
                        APPLY_KP_MASK = apply_kp_mask,
                        APPLY_ATTN_MASK = apply_attn_mask,
                        KP_MASK_MUL = (kp_mask_mode == 'mul'),
-                       ATTN_MASK_MUL = (attn_mask_mode == 'mul'),
-                       force_nc_cache = True)
+                       ATTN_MASK_MUL = (attn_mask_mode == 'mul'))
         # save to context
         ctx.mark_dirty(x)
         ctx.save_for_backward(x, lut)
@@ -196,7 +192,7 @@ class _softmax(torch.autograd.Function):
         # run kernel
         M = x.shape[0]
         grid = lambda opt: [ctx.spdims[0] * ctx.spdims[1] * ctx.block, M]
-        _backward[grid](x, ctx.scale, dx, lut, ctx.maxlut, x.stride(0), dx.stride(0), force_nc_cache=True, BLOCK=ctx.block)
+        _backward[grid](x, ctx.scale, dx, lut, ctx.maxlut, x.stride(0), dx.stride(0), BLOCK=ctx.block)
         return dx, None, None, None, None, None, None, None, None, None, None, None, None, None, None

@@ -26,8 +26,7 @@ def num_warps(N):
 @triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[4])})
 @triton.heuristics({'BLOCK': lambda *args, **meta: next_power_of_2(args[4])})
 @triton.jit
-def _forward(LOGITS, PROBS, IDX, LOSS, N, **meta):
-    BLOCK = meta['BLOCK']
+def _forward(LOGITS, PROBS, IDX, LOSS, N, BLOCK: tl.constexpr):
     row = tl.program_id(0)
     cols = tl.arange(0, BLOCK)
     idx = tl.load(IDX + row)
@@ -52,8 +51,7 @@ def _forward(LOGITS, PROBS, IDX, LOSS, N, **meta):
 @triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[3])})
 @triton.heuristics({'BLOCK': lambda *args, **meta: next_power_of_2(args[3])})
 @triton.jit
-def _backward(PROBS, IDX, DPROBS, N, **meta):
-    BLOCK = meta['BLOCK']
+def _backward(PROBS, IDX, DPROBS, N, BLOCK: tl.constexpr):
     row = tl.program_id(0)
     cols = tl.arange(0, BLOCK)
     idx = tl.load(IDX + row)
@@ -26,13 +26,9 @@ def _kernel(A, B, C, M, N, K,
             stride_am, stride_ak,
             stride_bk, stride_bn,
             stride_cm, stride_cn,
-            LOCKS, **META):
-    # extract meta-parameters
-    BLOCK_M = META['BLOCK_M']
-    BLOCK_N = META['BLOCK_N']
-    BLOCK_K = META['BLOCK_K']
-    GROUP_M = META['GROUP_M']
-    SPLIT_K = META['SPLIT_K']
+            LOCKS,
+            BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+            GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr):
     # matrix multiplication
     pid = tl.program_id(0)
     pid_z = tl.program_id(1)
@@ -55,7 +51,7 @@ def _kernel(A, B, C, M, N, K,
     B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
     acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
     for k in range(K, 0, -BLOCK_K*SPLIT_K):
-        if META['EVEN_K']:
+        if EVEN_K:
             a = tl.load(A)
             b = tl.load(B)
         else:
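`EVEN_K` applies the same specialization trick to bounds checking: when `K` is an exact multiple of `BLOCK_K * SPLIT_K`, every tile along K is full and the unmasked fast path is compiled in. The flag has to be derived on the host; one plausible way, mirroring the `@triton.heuristics` decorators used elsewhere in this commit (`args[5]` would be `K` in `_kernel`'s signature):

```python
import triton

# sketch: derive EVEN_K from the problem shape and the chosen tiling, so that
# True and False each compile their own specialization of _kernel
even_k_heuristic = {
    'EVEN_K': lambda *args, **meta: args[5] % (meta['BLOCK_K'] * meta['SPLIT_K']) == 0
}
# applied as @triton.heuristics(even_k_heuristic) above @triton.jit
```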
@@ -113,14 +109,11 @@ class _matmul(torch.autograd.Function):
         locks = _matmul._locks[device]
         # launch kernel
         grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
-        _kernel[grid](a, b, c,
-                      M, N, K,
+        _kernel[grid](a, b, c, M, N, K,
                       a.stride(0), a.stride(1),
                       b.stride(0), b.stride(1),
                       c.stride(0), c.stride(1),
-                      locks,
-                      GROUP_M=8)
-        # done
+                      locks, GROUP_M=8)
         return c

     @staticmethod
@@ -12,6 +12,7 @@ In this tutorial, you will write a simple vector addition using Triton and learn
 # Compute Kernel
 # --------------------------

+from triton.language.core import constexpr
 import torch
 import triton
 import triton.language as tl
@@ -23,9 +24,9 @@ def add_kernel(
     y_ptr,  # *Pointer* to second input vector
     output_ptr,  # *Pointer* to output vector
     n_elements,  # Size of the vector
-    **meta,  # Optional meta-parameters for the kernel
+    BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process
+    # NOTE: `constexpr` so it can be used as a shape value
 ):
-    BLOCK_SIZE = meta['BLOCK_SIZE']  # How many inputs each program should process
     # There are multiple 'program's processing different data. We identify which program
     # we are here
     pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0
@@ -37,8 +38,8 @@ def add_kernel(
     offsets = block_start + tl.arange(0, BLOCK_SIZE)
     # Create a mask to guard memory operations against out-of-bounds accesses
     mask = offsets < n_elements
-    # Load x and y from DRAM, masking out any extar elements in case the input is not a
-    # multiple of the block size
+    # Load x and y from DRAM, masking out any extra elements in case
+    # the input is not a multiple of the block size
     x = tl.load(x_ptr + offsets, mask=mask)
     y = tl.load(y_ptr + offsets, mask=mask)
     output = x + y
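On the launch side nothing changes for the tutorial: `BLOCK_SIZE` is still passed as a keyword argument and remains visible to the grid callable. A host wrapper consistent with the kernel above (a sketch, not necessarily the tutorial's exact wrapper):

```python
def add(x: torch.Tensor, y: torch.Tensor):
    output = torch.empty_like(x)
    n_elements = output.numel()
    # the grid callable still receives every meta-parameter, constexpr included
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    return output
```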
@@ -65,11 +65,11 @@ import triton.language as tl

 @triton.jit
 def softmax_kernel(
-    output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, **meta
+    output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols,
+    BLOCK_SIZE: tl.constexpr
 ):
     # The rows of the softmax are independent, so we parallelize across those
     row_idx = tl.program_id(0)
-    BLOCK_SIZE = meta['BLOCK_SIZE']
     # The stride represents how much we need to increase the pointer to advance 1 row
     row_start_ptr = input_ptr + row_idx * input_row_stride
     # The block size is the next power of two greater than n_cols, so we can fit each
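Since `tl.arange(0, BLOCK_SIZE)` needs a compile-time power-of-two size that covers a whole row, the host picks it with `triton.next_power_of_2`. A launch sketch matching the new signature (shapes and strides assume a contiguous 2D input):

```python
def softmax(x: torch.Tensor):
    n_rows, n_cols = x.shape
    # smallest power of two >= n_cols, so one program spans an entire row
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    y = torch.empty_like(x)
    softmax_kernel[(n_rows,)](y, x, x.stride(0), y.stride(0), n_cols,
                              BLOCK_SIZE=BLOCK_SIZE)
    return y
```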
@@ -182,17 +182,13 @@ def matmul_kernel(
     stride_bk, stride_bn,
     stride_cm, stride_cn,
     # Meta-parameters
-    **meta,
+    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
+    ACTIVATION: tl.constexpr,
 ):
     """Kernel for computing the matmul C = A x B.
     A has shape (M, K), B has shape (K, N) and C has shape (M, N)
     """
-    # extract meta-parameters
-    BLOCK_SIZE_M = meta['BLOCK_SIZE_M']
-    BLOCK_SIZE_N = meta['BLOCK_SIZE_N']
-    BLOCK_SIZE_K = meta['BLOCK_SIZE_K']
-    GROUP_SIZE_M = 8

     # -----------------------------------------------------------
     # Map program ids `pid` to the block of C it should compute.
     # This is done in a grouped ordering to promote L2 data reuse
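`ACTIVATION` shows that constexpr parameters need not be integers: the tutorial threads a fused epilogue through it, selected when the kernel is compiled. A hedged sketch of how such an epilogue is typically wired up (the tutorial's exact code may differ):

```python
@triton.jit
def leaky_relu(x):
    # fused epilogue applied to the fp32 accumulator before the write-back
    return tl.where(x >= 0, x, 0.01 * x)

# inside matmul_kernel, after the K loop (sketch):
#     if ACTIVATION:
#         accumulator = ACTIVATION(accumulator)
# and at the launch site:
#     matmul_kernel[grid](..., ACTIVATION=leaky_relu)
```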
|