[LANG] Added support for constexpr (#361)
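The pattern applied throughout this diff: compile-time values that used to be fetched from the implicit **meta keyword dictionary are now declared as kernel parameters annotated with tl.constexpr, while launch sites keep passing them as keyword arguments (e.g. SIZE=SIZE, CACHE=cache). A minimal, self-contained sketch of the new style follows; it assumes a CUDA device and the post-#361 API, and the kernel and tensor names are illustrative rather than taken from the diff.

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def copy_kernel(X, Z, SIZE: tl.constexpr):
        # SIZE is a compile-time constant, so it can be used where Triton
        # requires constexpr values, such as the bound of tl.arange
        off = tl.arange(0, SIZE)
        tl.store(Z + off, tl.load(X + off))

    x = torch.randn(128, device='cuda')
    z = torch.empty_like(x)
    # constexpr parameters are still supplied as keyword arguments at launch
    copy_kernel[(1,)](x, z, SIZE=128)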
@@ -40,7 +40,7 @@ def patch_kernel(template, to_replace):
 def test_empty_kernel(dtype_x, device='cuda'):
     SIZE = 128
     @triton.jit
-    def kernel(X, **meta):
+    def kernel(X, SIZE: tl.constexpr):
         pass
     x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device)
     kernel[(1, )](x, SIZE=SIZE, num_warps=4)
@@ -50,8 +50,8 @@ def _test_unary(dtype_x, expr, torch_expr=None, device='cuda'):
     SIZE = 128
     # define the kernel / launch-grid
     @triton.jit
-    def kernel(Z, X, **meta):
-        off = tl.arange(0, meta['SIZE'])
+    def kernel(Z, X, SIZE: tl.constexpr):
+        off = tl.arange(0, SIZE)
         x = tl.load(X + off)
         z = GENERATE_TEST_HERE
         tl.store(Z + off, z)
@@ -73,8 +73,8 @@ def _test_binary(dtype_x, dtype_y, expr, mode_x='real', mode_y='real', device='c
     SIZE = 128
     # define the kernel / launch-grid
     @triton.jit
-    def kernel(Z, X, Y, **meta):
-        off = tl.arange(0, meta['SIZE'])
+    def kernel(Z, X, Y, SIZE: tl.constexpr):
+        off = tl.arange(0, SIZE)
         x = tl.load(X + off)
         y = tl.load(Y + off)
         z = GENERATE_TEST_HERE
@@ -203,8 +203,7 @@ def test_index1d(expr, device='cuda'):
 
     # Triton kernel
     @triton.jit
-    def kernel(Z, X, **meta):
-        SIZE = meta['SIZE']
+    def kernel(Z, X, SIZE: tl.constexpr):
         m = tl.arange(0, SIZE)
         n = tl.arange(0, SIZE)
         x = tl.load(X_PTR_EXPR)
@@ -290,7 +289,7 @@ def test_atomic_rmw(op, dtype_x, mode, device='cuda'):
 
     # triton kernel
     @triton.jit
-    def kernel(X, Z, **meta):
+    def kernel(X, Z):
         pid = tl.program_id(0)
         x = tl.load(X + pid)
         old = GENERATE_TEST_HERE
@@ -344,9 +343,9 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
 
     # triton kernel
     @triton.jit
-    def kernel(X, Z, **meta):
+    def kernel(X, Z, BITCAST: tl.constexpr):
         x = tl.load(X)
-        z = x.to(Z.dtype.element_ty, bitcast=meta['BITCAST'])
+        z = x.to(Z.dtype.element_ty, bitcast = BITCAST)
         tl.store(Z, z)
 
     # triton result
@@ -373,8 +372,8 @@ def test_reduce1d(dtype, shape, device='cuda'):
 
     # triton kernel
     @triton.jit
-    def kernel(X, Z, **meta):
-        x = tl.load(X + tl.arange(0, meta['BLOCK']))
+    def kernel(X, Z, BLOCK: tl.constexpr):
+        x = tl.load(X + tl.arange(0, BLOCK))
         tl.store(Z, tl.sum(x, axis=0))
 
     x = triton.testing.random((shape,), dtype=dtype, device=device)
@@ -395,11 +394,11 @@ def test_reduce2d(dtype, shape, axis, device='cuda'):
     dtype = cvt[dtype]
     # triton kernel
     @triton.jit
-    def kernel(X, Z, **meta):
-        range_m = tl.arange(0, meta['BLOCK_M'])
-        range_n = tl.arange(0, meta['BLOCK_N'])
-        x = tl.load(X + range_m[:, None]*meta['BLOCK_N'] + range_n[None, :])
-        z = tl.sum(x, axis=meta['AXIS'])
+    def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr):
+        range_m = tl.arange(0, BLOCK_M)
+        range_n = tl.arange(0, BLOCK_N)
+        x = tl.load(X + range_m[:, None]*BLOCK_N + range_n[None, :])
+        z = tl.sum(x, axis=AXIS)
         tl.store(Z + range_m, z)
     # input
     x = triton.testing.random(shape, dtype=dtype, device=device)
@@ -429,9 +428,8 @@ def test_permute(dtype, shape, perm, device='cuda'):
     # triton kernel
     @triton.jit
     def kernel(X, stride_xm, stride_xn,
-               Z, stride_zm, stride_zn, **meta):
-        BLOCK_M = meta['BLOCK_M']
-        BLOCK_N = meta['BLOCK_N']
+               Z, stride_zm, stride_zn,
+               BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
         off_m = tl.arange(0, BLOCK_M)
         off_n = tl.arange(0, BLOCK_N)
         Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * stride_xn
@@ -464,10 +462,9 @@ def test_dot(epilogue, device='cuda'):
     @triton.jit
     def kernel(X, stride_xm, stride_xk,
                Y, stride_yk, stride_yn,
-               Z, stride_zm, stride_zn, **meta):
-        BLOCK_M = meta['BLOCK_M']
-        BLOCK_K = meta['BLOCK_K']
-        BLOCK_N = meta['BLOCK_N']
+               Z, stride_zm, stride_zn,
+               BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+               ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr):
         off_m = tl.arange(0, BLOCK_M)
         off_n = tl.arange(0, BLOCK_N)
         off_k = tl.arange(0, BLOCK_K)
@@ -475,12 +472,12 @@ def test_dot(epilogue, device='cuda'):
         Ys = Y + off_k[:, None] * stride_yk + off_n[None, :] * stride_yn
         Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
         z = tl.dot(tl.load(Xs), tl.load(Ys))
-        if meta['ADD_MATRIX']:
+        if ADD_MATRIX:
             z += tl.load(Zs)
-        if meta['ADD_ROWS']:
+        if ADD_ROWS:
             ZRs = Z + off_m * stride_zm
             z += tl.load(ZRs)[:, None]
-        if meta['ADD_COLS']:
+        if ADD_COLS:
             ZCs = Z + off_n * stride_zn
             z += tl.load(ZCs)[None, :]
         tl.store(Zs, z)
@@ -517,7 +514,7 @@ def test_dot(epilogue, device='cuda'):
 
 def test_dot_without_load():
     @triton.jit
-    def kernel(out, **meta):
+    def kernel(out):
         pid = tl.program_id(axis=0)
         a = tl.zeros((32, 32), tl.float32)
         b = tl.zeros((32, 32), tl.float32)
@@ -538,9 +535,10 @@ def test_arange(start, device='cuda'):
     BLOCK = 128
     z_tri = torch.empty(BLOCK, dtype=torch.int32, device=device)
     @triton.jit
-    def _kernel(z, **meta):
-        off = tl.arange(0, meta['BLOCK'])
-        val = tl.arange(meta['START'], meta['END'])
+    def _kernel(z, BLOCK: tl.constexpr,
+                START: tl.constexpr, END: tl.constexpr):
+        off = tl.arange(0, BLOCK)
+        val = tl.arange(START, END)
         tl.store(z + off, val)
     _kernel[(1,)](z_tri, START=start, END=start+BLOCK, BLOCK=BLOCK)
     z_ref = torch.arange(start, BLOCK+start, dtype=torch.int32, device=device)
@@ -564,10 +562,8 @@ def test_masked_load_shared_memory(dtype, device='cuda'):
     @triton.jit
     def _kernel(in1_ptr, in2_ptr, output_ptr,
                 in_stride, in2_stride, out_stride,
-                in_numel, in2_numel, out_numel, **meta):
-        M = meta['M']
-        N = meta['N']
-        K = meta['K']
+                in_numel, in2_numel, out_numel,
+                M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
 
         M_offsets = tl.arange(0, M)
         N_offsets = tl.arange(0, N)
@@ -605,14 +601,13 @@ def test_load_cache_modifier(cache):
     dst = torch.empty(128, device='cuda')
 
     @triton.jit
-    def _kernel(dst, src, **meta):
+    def _kernel(dst, src, CACHE: tl.constexpr):
         offsets = tl.arange(0, 128)
-        x = tl.load(src+offsets, cache_modifier=meta['CACHE'])
+        x = tl.load(src+offsets, cache_modifier=CACHE)
         tl.store(dst+offsets, x)
 
     pgm = _kernel[(1,)](dst, src, CACHE=cache)
     ptx = pgm.asm['ptx']
 
     if cache == '':
         assert 'ld.global.ca' not in ptx
         assert 'ld.global.cg' not in ptx
@@ -644,7 +639,7 @@ def test_load_cache_modifier(cache):
 #----------------
 def test_noop(device='cuda'):
     @triton.jit
-    def kernel(**meta):
+    def kernel(x):
         pass
     x = triton.testing.random((1,), dtype=torch.int32, device=device)
     kernel[(1, )](x)
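As the test_dot and test_cast hunks show, boolean options such as ADD_MATRIX and BITCAST also move from meta['...'] lookups to tl.constexpr parameters and can still be branched on inside the kernel. A small sketch of that flag pattern, under the same assumptions as the example above (illustrative names, post-#361 API):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def maybe_bias_kernel(X, Z, N: tl.constexpr, ENABLE_BIAS: tl.constexpr):
        off = tl.arange(0, N)
        x = tl.load(X + off)
        # constexpr flags can be tested with a plain `if` inside the kernel
        if ENABLE_BIAS:
            x = x + 1
        tl.store(Z + off, x)

    x = torch.randn(64, device='cuda')
    z = torch.empty_like(x)
    maybe_bias_kernel[(1,)](x, z, N=64, ENABLE_BIAS=True)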