[LANG] Added support for constexpr (#361)
This commit is contained in:
@@ -40,7 +40,7 @@ def patch_kernel(template, to_replace):
|
||||
def test_empty_kernel(dtype_x, device='cuda'):
|
||||
SIZE = 128
|
||||
@triton.jit
|
||||
def kernel(X, **meta):
|
||||
def kernel(X, SIZE: tl.constexpr):
|
||||
pass
|
||||
x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device)
|
||||
kernel[(1, )](x, SIZE=SIZE, num_warps=4)
|
||||
@@ -50,8 +50,8 @@ def _test_unary(dtype_x, expr, torch_expr=None, device='cuda'):
|
||||
SIZE = 128
|
||||
# define the kernel / launch-grid
|
||||
@triton.jit
|
||||
def kernel(Z, X, **meta):
|
||||
off = tl.arange(0, meta['SIZE'])
|
||||
def kernel(Z, X, SIZE: tl.constexpr):
|
||||
off = tl.arange(0, SIZE)
|
||||
x = tl.load(X + off)
|
||||
z = GENERATE_TEST_HERE
|
||||
tl.store(Z + off, z)
|
||||
@@ -73,8 +73,8 @@ def _test_binary(dtype_x, dtype_y, expr, mode_x='real', mode_y='real', device='c
|
||||
SIZE = 128
|
||||
# define the kernel / launch-grid
|
||||
@triton.jit
|
||||
def kernel(Z, X, Y, **meta):
|
||||
off = tl.arange(0, meta['SIZE'])
|
||||
def kernel(Z, X, Y, SIZE: tl.constexpr):
|
||||
off = tl.arange(0, SIZE)
|
||||
x = tl.load(X + off)
|
||||
y = tl.load(Y + off)
|
||||
z = GENERATE_TEST_HERE
|
||||
@@ -203,8 +203,7 @@ def test_index1d(expr, device='cuda'):
|
||||
|
||||
# Triton kernel
|
||||
@triton.jit
|
||||
def kernel(Z, X, **meta):
|
||||
SIZE = meta['SIZE']
|
||||
def kernel(Z, X, SIZE: tl.constexpr):
|
||||
m = tl.arange(0, SIZE)
|
||||
n = tl.arange(0, SIZE)
|
||||
x = tl.load(X_PTR_EXPR)
|
||||
@@ -290,7 +289,7 @@ def test_atomic_rmw(op, dtype_x, mode, device='cuda'):
|
||||
|
||||
# triton kernel
|
||||
@triton.jit
|
||||
def kernel(X, Z, **meta):
|
||||
def kernel(X, Z):
|
||||
pid = tl.program_id(0)
|
||||
x = tl.load(X + pid)
|
||||
old = GENERATE_TEST_HERE
|
||||
@@ -344,9 +343,9 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
|
||||
|
||||
# triton kernel
|
||||
@triton.jit
|
||||
def kernel(X, Z, **meta):
|
||||
def kernel(X, Z, BITCAST: tl.constexpr):
|
||||
x = tl.load(X)
|
||||
z = x.to(Z.dtype.element_ty, bitcast=meta['BITCAST'])
|
||||
z = x.to(Z.dtype.element_ty, bitcast = BITCAST)
|
||||
tl.store(Z, z)
|
||||
|
||||
# triton result
|
||||
@@ -373,8 +372,8 @@ def test_reduce1d(dtype, shape, device='cuda'):
|
||||
|
||||
# triton kernel
|
||||
@triton.jit
|
||||
def kernel(X, Z, **meta):
|
||||
x = tl.load(X + tl.arange(0, meta['BLOCK']))
|
||||
def kernel(X, Z, BLOCK: tl.constexpr):
|
||||
x = tl.load(X + tl.arange(0, BLOCK))
|
||||
tl.store(Z, tl.sum(x, axis=0))
|
||||
|
||||
x = triton.testing.random((shape,), dtype=dtype, device=device)
|
||||
@@ -395,11 +394,11 @@ def test_reduce2d(dtype, shape, axis, device='cuda'):
|
||||
dtype = cvt[dtype]
|
||||
# triton kernel
|
||||
@triton.jit
|
||||
def kernel(X, Z, **meta):
|
||||
range_m = tl.arange(0, meta['BLOCK_M'])
|
||||
range_n = tl.arange(0, meta['BLOCK_N'])
|
||||
x = tl.load(X + range_m[:, None]*meta['BLOCK_N'] + range_n[None, :])
|
||||
z = tl.sum(x, axis=meta['AXIS'])
|
||||
def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr):
|
||||
range_m = tl.arange(0, BLOCK_M)
|
||||
range_n = tl.arange(0, BLOCK_N)
|
||||
x = tl.load(X + range_m[:, None]*BLOCK_N + range_n[None, :])
|
||||
z = tl.sum(x, axis=AXIS)
|
||||
tl.store(Z + range_m, z)
|
||||
# input
|
||||
x = triton.testing.random(shape, dtype=dtype, device=device)
|
||||
@@ -429,9 +428,8 @@ def test_permute(dtype, shape, perm, device='cuda'):
|
||||
# triton kernel
|
||||
@triton.jit
|
||||
def kernel(X, stride_xm, stride_xn,
|
||||
Z, stride_zm, stride_zn, **meta):
|
||||
BLOCK_M = meta['BLOCK_M']
|
||||
BLOCK_N = meta['BLOCK_N']
|
||||
Z, stride_zm, stride_zn,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
|
||||
off_m = tl.arange(0, BLOCK_M)
|
||||
off_n = tl.arange(0, BLOCK_N)
|
||||
Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * stride_xn
|
||||
@@ -464,10 +462,9 @@ def test_dot(epilogue, device='cuda'):
|
||||
@triton.jit
|
||||
def kernel(X, stride_xm, stride_xk,
|
||||
Y, stride_yk, stride_yn,
|
||||
Z, stride_zm, stride_zn, **meta):
|
||||
BLOCK_M = meta['BLOCK_M']
|
||||
BLOCK_K = meta['BLOCK_K']
|
||||
BLOCK_N = meta['BLOCK_N']
|
||||
Z, stride_zm, stride_zn,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
||||
ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr):
|
||||
off_m = tl.arange(0, BLOCK_M)
|
||||
off_n = tl.arange(0, BLOCK_N)
|
||||
off_k = tl.arange(0, BLOCK_K)
|
||||
@@ -475,12 +472,12 @@ def test_dot(epilogue, device='cuda'):
|
||||
Ys = Y + off_k[:, None] * stride_yk + off_n[None, :] * stride_yn
|
||||
Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
|
||||
z = tl.dot(tl.load(Xs), tl.load(Ys))
|
||||
if meta['ADD_MATRIX']:
|
||||
if ADD_MATRIX:
|
||||
z += tl.load(Zs)
|
||||
if meta['ADD_ROWS']:
|
||||
if ADD_ROWS:
|
||||
ZRs = Z + off_m * stride_zm
|
||||
z += tl.load(ZRs)[:, None]
|
||||
if meta['ADD_COLS']:
|
||||
if ADD_COLS:
|
||||
ZCs = Z + off_n * stride_zn
|
||||
z += tl.load(ZCs)[None, :]
|
||||
tl.store(Zs, z)
|
||||
@@ -517,7 +514,7 @@ def test_dot(epilogue, device='cuda'):
|
||||
|
||||
def test_dot_without_load():
|
||||
@triton.jit
|
||||
def kernel(out, **meta):
|
||||
def kernel(out):
|
||||
pid = tl.program_id(axis=0)
|
||||
a = tl.zeros((32, 32), tl.float32)
|
||||
b = tl.zeros((32, 32), tl.float32)
|
||||
@@ -538,9 +535,10 @@ def test_arange(start, device='cuda'):
|
||||
BLOCK = 128
|
||||
z_tri = torch.empty(BLOCK, dtype=torch.int32, device=device)
|
||||
@triton.jit
|
||||
def _kernel(z, **meta):
|
||||
off = tl.arange(0, meta['BLOCK'])
|
||||
val = tl.arange(meta['START'], meta['END'])
|
||||
def _kernel(z, BLOCK: tl.constexpr,
|
||||
START: tl.constexpr, END: tl.constexpr):
|
||||
off = tl.arange(0, BLOCK)
|
||||
val = tl.arange(START, END)
|
||||
tl.store(z + off, val)
|
||||
_kernel[(1,)](z_tri, START=start, END=start+BLOCK, BLOCK=BLOCK)
|
||||
z_ref = torch.arange(start, BLOCK+start, dtype=torch.int32, device=device)
|
||||
@@ -564,10 +562,8 @@ def test_masked_load_shared_memory(dtype, device='cuda'):
|
||||
@triton.jit
|
||||
def _kernel(in1_ptr, in2_ptr, output_ptr,
|
||||
in_stride, in2_stride, out_stride,
|
||||
in_numel, in2_numel, out_numel, **meta):
|
||||
M = meta['M']
|
||||
N = meta['N']
|
||||
K = meta['K']
|
||||
in_numel, in2_numel, out_numel,
|
||||
M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
|
||||
|
||||
M_offsets = tl.arange(0, M)
|
||||
N_offsets = tl.arange(0, N)
|
||||
@@ -605,14 +601,13 @@ def test_load_cache_modifier(cache):
|
||||
dst = torch.empty(128, device='cuda')
|
||||
|
||||
@triton.jit
|
||||
def _kernel(dst, src, **meta):
|
||||
def _kernel(dst, src, CACHE: tl.constexpr):
|
||||
offsets = tl.arange(0, 128)
|
||||
x = tl.load(src+offsets, cache_modifier=meta['CACHE'])
|
||||
x = tl.load(src+offsets, cache_modifier=CACHE)
|
||||
tl.store(dst+offsets, x)
|
||||
|
||||
pgm = _kernel[(1,)](dst, src, CACHE=cache)
|
||||
ptx = pgm.asm['ptx']
|
||||
|
||||
if cache == '':
|
||||
assert 'ld.global.ca' not in ptx
|
||||
assert 'ld.global.cg' not in ptx
|
||||
@@ -644,7 +639,7 @@ def test_load_cache_modifier(cache):
|
||||
#----------------
|
||||
def test_noop(device='cuda'):
|
||||
@triton.jit
|
||||
def kernel(**meta):
|
||||
def kernel(x):
|
||||
pass
|
||||
x = triton.testing.random((1,), dtype=torch.int32, device=device)
|
||||
kernel[(1, )](x)
|
@@ -21,7 +21,7 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=
|
||||
}[MODE]
|
||||
layout = torch.randint(2, (H, shape[0] // BLOCK, shape[1] // BLOCK))
|
||||
# triton result
|
||||
op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B)
|
||||
op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B, device="cuda")
|
||||
ra = triton.testing.sparsify_tensor(a, layout, BLOCK) if MODE == "dsd" else a
|
||||
rb = triton.testing.sparsify_tensor(b, layout, BLOCK) if MODE == "dds" else b
|
||||
rc = triton.testing.catch_oor(lambda: op(ra, rb), pytest)
|
||||
@@ -151,8 +151,8 @@ def triton_attention(
|
||||
value: torch.Tensor,
|
||||
scale: float,
|
||||
):
|
||||
sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(layout, block, "sdd", trans_a=False, trans_b=True)
|
||||
sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(layout, block, "dsd", trans_a=False, trans_b=False)
|
||||
sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(layout, block, "sdd", trans_a=False, trans_b=True, device=value.device)
|
||||
sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(layout, block, "dsd", trans_a=False, trans_b=False, device=value.device)
|
||||
sparse_softmax = triton.ops.blocksparse.softmax(
|
||||
layout,
|
||||
block,
|
||||
|
@@ -66,8 +66,8 @@ import torch
|
||||
def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE):
|
||||
torch.manual_seed(0)
|
||||
# nuke kernel decorators -- will set meta-parameters manually
|
||||
META = {'BLOCK_M': BLOCK_M, 'BLOCK_N': BLOCK_N, 'BLOCK_K': BLOCK_K, 'SPLIT_K': SPLIT_K}
|
||||
configs = [triton.Config(meta=META, num_warps=NWARP, num_stages=NSTAGE)]
|
||||
kwargs = {'BLOCK_M': BLOCK_M, 'BLOCK_N': BLOCK_N, 'BLOCK_K': BLOCK_K, 'SPLIT_K': SPLIT_K}
|
||||
configs = [triton.Config(kwargs=kwargs, num_warps=NWARP, num_stages=NSTAGE)]
|
||||
kernel = triton.ops._matmul.kernel
|
||||
decorators = kernel.kernel_decorators
|
||||
kernel.kernel_decorators = []
|
||||
|
Reference in New Issue
Block a user