@@ -363,6 +363,133 @@ def test_reduce1d(dtype, shape, device='cuda'):
    triton.testing.assert_almost_equal(z_tri, z_ref)


@pytest.mark.parametrize("dtype, shape, axis",
                         [(dtype, shape, 1)
                          for dtype in ['float32']
                          for shape in [(1, 1024)]])
def test_reduce2d(dtype, shape, axis, device='cuda'):
    dtype = cvt[dtype]
    # triton kernel
    @triton.jit
    def kernel(X, Z, **meta):
        range_m = tl.arange(0, meta['BLOCK_M'])
        range_n = tl.arange(0, meta['BLOCK_N'])
        x = tl.load(X + range_m[:, None] * meta['BLOCK_N'] + range_n[None, :])
        z = tl.sum(x, axis=meta['AXIS'])
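        # the store below writes one partial sum per row; this assumes AXIS == 1,
        # the only axis exercised by the parametrization above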
        tl.store(Z + range_m, z)
    # input
    x = triton.testing.random(shape, dtype=dtype, device=device)
    # triton result
    z_tri = torch.empty((shape[0],), dtype=dtype, device=device)
    kernel[(1,)](x, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis)
    # torch result
    z_ref = torch.sum(x, axis=axis).to(dtype)
    # compare
    triton.testing.assert_almost_equal(z_tri, z_ref)


# ---------------
# test permute
# ---------------
@pytest.mark.parametrize("dtype, shape, perm",
                         [(dtype, shape, perm)
                          for dtype in ['float32']
                          for shape in [(128, 128)]
                          for perm in [(1, 0)]])
def test_permute(dtype, shape, perm, device='cuda'):
    dtype = cvt[dtype]
    # triton kernel
    @triton.jit
    def kernel(X, stride_xm, stride_xn,
               Z, stride_zm, stride_zn, **meta):
        BLOCK_M = meta['BLOCK_M']
        BLOCK_N = meta['BLOCK_N']
        off_m = tl.arange(0, BLOCK_M)
        off_n = tl.arange(0, BLOCK_N)
        Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * stride_xn
        Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
        tl.store(Zs, tl.load(Xs))
    # input
    x = triton.testing.random(shape, dtype=dtype, device=device)
    # triton result
    z_tri = torch.empty_like(x)
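    # the (1, 0) permutation is realized by passing z_tri's strides swapped,
    # so the copy kernel writes X[m, n] into Z[n, m]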
    pgm = kernel[(1, 1)](x, x.stride(0), x.stride(1),
                         z_tri, z_tri.stride(1), z_tri.stride(0),
                         BLOCK_M=shape[0], BLOCK_N=shape[1])
    # torch result
    z_ref = x.permute(*perm).contiguous()
    # compare
    triton.testing.assert_almost_equal(z_tri, z_ref)
    # parse ptx to make sure ld/st are vectorized
    ptx = pgm.asm('ptx')
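    # (loads of x are contiguous along n and, with the swapped strides, stores of
    # z_tri are contiguous along m, so both should compile to 128-bit v4 accesses)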
    assert 'ld.global.v4' in ptx
    assert 'st.global.v4' in ptx


# ---------------
# test dot
# ---------------
@pytest.mark.parametrize("epilogue", ['none', 'add-matrix', 'add-rows', 'add-cols'])
def test_dot(epilogue, device='cuda'):
    torch.manual_seed(0)
    # triton kernel
    @triton.jit
    def kernel(X, stride_xm, stride_xk,
               Y, stride_yk, stride_yn,
               Z, stride_zm, stride_zn, **meta):
        BLOCK_M = meta['BLOCK_M']
        BLOCK_K = meta['BLOCK_K']
        BLOCK_N = meta['BLOCK_N']
        off_m = tl.arange(0, BLOCK_M)
        off_n = tl.arange(0, BLOCK_N)
        off_k = tl.arange(0, BLOCK_K)
        Xs = X + off_m[:, None] * stride_xm + off_k[None, :] * stride_xk
        Ys = Y + off_k[:, None] * stride_yk + off_n[None, :] * stride_yn
        Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
        z = tl.dot(tl.load(Xs), tl.load(Ys))
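        # optional fused epilogues: ADD_MATRIX adds all of Z; ADD_ROWS adds Z's
        # first column to every column; ADD_COLS adds Z's first row to every row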
        if meta['ADD_MATRIX']:
            z += tl.load(Zs)
        if meta['ADD_ROWS']:
            ZRs = Z + off_m * stride_zm
            z += tl.load(ZRs)[:, None]
        if meta['ADD_COLS']:
            ZCs = Z + off_n * stride_zn
            z += tl.load(ZCs)[None, :]
        tl.store(Zs, z)
    # input
    M, N, K = 64, 64, 32
    x = triton.testing.random((M, K), dtype=torch.float16, device=device)
    y = triton.testing.random((K, N), dtype=torch.float16, device=device)
    # triton result
    z = triton.testing.random((M, N), dtype=torch.float16, device=device)
    z_tri = z.clone()
    pgm = kernel[(1, 1)](x, x.stride(0), x.stride(1),
                         y, y.stride(0), y.stride(1),
                         z_tri, z_tri.stride(0), z_tri.stride(1),
                         BLOCK_M=M, BLOCK_K=K, BLOCK_N=N,
                         ADD_MATRIX=epilogue == 'add-matrix',
                         ADD_ROWS=epilogue == 'add-rows',
                         ADD_COLS=epilogue == 'add-cols')
    # torch result
    z_ref = torch.matmul(x.float(), y.float())
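    # (the fp32 reference mirrors tl.dot, which accumulates fp16 inputs in fp32)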
    if epilogue == 'add-matrix':
        z_ref += z
    if epilogue == 'add-rows':
        z_ref += z[:, 0][:, None]
    if epilogue == 'add-cols':
        z_ref += z[0, :][None, :]
    z_ref = z_ref.to(torch.float16)
    # compare
    ptx = pgm.asm('ptx')
    triton.testing.assert_almost_equal(z_tri, z_ref)
    # make sure ld/st are vectorized
    assert 'ld.global.v4' in ptx
    assert 'st.global.v4' in ptx


# ---------------