[ROCM] enable matmul(dot) and others (#391)
This commit is contained in:
@@ -45,6 +45,35 @@ def test_empty_kernel(dtype_x, device='cuda'):
|
||||
x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device)
|
||||
kernel[(1, )](x, SIZE=SIZE, num_warps=4)
|
||||
|
||||
# ---------------
|
||||
# test load and store op
|
||||
# ---------------
|
||||
@pytest.mark.parametrize("dtype,size", [
    (dtype, size)
    for dtype in dtypes
    for size in [128, 256, 512, 1024, 2048, 4096]
])
def test_load_and_store_op(dtype, size, device='cuda'):
    """Copy a vector through a Triton load/store kernel and compare with a clone.

    Parametrized over every entry of ``dtypes`` and the listed vector lengths.
    The kernel is launched on a ``(1, )`` grid with ``num_warps=4``; ``size``
    is forwarded to it as the ``SIZE`` meta-parameter.
    """

    # kernel: load meta['SIZE'] elements starting at X and store them at Z
    @triton.jit
    def kernel(Z, X, **meta):
        offsets = tl.arange(0, meta['SIZE'])
        values = tl.load(X + offsets)
        tl.store(Z + offsets, values)

    # input produced by triton.testing.random for the requested dtype/length
    src = triton.testing.random(size, dtype=cvt[dtype], device=device)

    # the reference is the input itself; `dst` is what the kernel writes into
    expected = src.clone()
    dst = torch.empty_like(src)

    # run the load/store kernel
    kernel[(1, )](dst, src, SIZE=size, num_warps=4)

    # compare reference against the Triton result
    triton.testing.assert_almost_equal(expected, dst)
|
||||
|
||||
# generic test functions
|
||||
def _test_unary(dtype_x, expr, torch_expr=None, device='cuda'):
|
||||
SIZE = 128
|
||||
@@ -340,18 +369,23 @@ def test_atomic_rmw(op, dtype_x, mode, device='cuda'):
|
||||
('float32', 'int32', True)
|
||||
])
|
||||
def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
|
||||
x = torch.tensor([43.5], dtype=cvt[dtype_x], device=device)
|
||||
if torch.version.hip is not None:
|
||||
assert 'bfloat' not in dtype_x
|
||||
assert 'bfloat' not in dtype_z
|
||||
|
||||
SIZE = 1024
|
||||
x = triton.testing.random((SIZE, ), dtype=cvt[dtype_x], device=device)
|
||||
# triton kernel
|
||||
@triton.jit
|
||||
def kernel(X, Z, **meta):
|
||||
x = tl.load(X)
|
||||
off = tl.arange(0, meta['SIZE'])
|
||||
x = tl.load(X+ off)
|
||||
z = x.to(Z.dtype.element_ty, bitcast=meta['BITCAST'])
|
||||
tl.store(Z, z)
|
||||
tl.store(Z+ off, z)
|
||||
|
||||
# triton result
|
||||
z_tri = torch.empty((1, ), dtype=cvt[dtype_z], device=device)
|
||||
kernel[(1, )](x, z_tri, BITCAST=bitcast)
|
||||
z_tri = torch.empty((SIZE, ), dtype=cvt[dtype_z], device=device)
|
||||
kernel[(1, )](x, z_tri, SIZE=SIZE, BITCAST=bitcast)
|
||||
# torch result
|
||||
if bitcast:
|
||||
import numpy as np
|
||||
@@ -359,7 +393,7 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
|
||||
z_ref = torch.from_numpy(z_ref).to(device)
|
||||
else:
|
||||
z_ref = x.to(z_tri.dtype)
|
||||
assert z_tri == z_ref
|
||||
triton.testing.assert_almost_equal(z_ref, z_tri)
|
||||
|
||||
# ---------------
|
||||
# test reduce
|
||||
@@ -448,17 +482,23 @@ def test_permute(dtype, shape, perm, device='cuda'):
|
||||
z_ref = x.permute(*perm).contiguous()
|
||||
# compare
|
||||
triton.testing.assert_almost_equal(z_tri, z_ref)
|
||||
# parse ptx to make sure ld/st are vectorized
|
||||
ptx = pgm.asm['ptx']
|
||||
assert 'ld.global.v4' in ptx
|
||||
assert 'st.global.v4' in ptx
|
||||
|
||||
if torch.version.hip is None:
|
||||
# parse ptx to make sure ld/st are vectorized
|
||||
ptx = pgm.asm['ptx']
|
||||
assert 'ld.global.v4' in ptx
|
||||
assert 'st.global.v4' in ptx
|
||||
|
||||
# ---------------
|
||||
# test dot
|
||||
# ---------------
|
||||
|
||||
@pytest.mark.parametrize("epilogue", ['none', 'add-matrix', 'add-rows', 'add-cols'])
|
||||
def test_dot(epilogue, device='cuda'):
|
||||
@pytest.mark.parametrize("dtype, epilogue", [(dtype, epilogue)\
|
||||
for dtype in ['float16','float32'] \
|
||||
for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols']])
|
||||
def test_dot(dtype, epilogue, device='cuda'):
|
||||
dtype = cvt[dtype]
|
||||
|
||||
torch.manual_seed(0)
|
||||
# triton kernel
|
||||
@triton.jit
|
||||
@@ -486,10 +526,10 @@ def test_dot(epilogue, device='cuda'):
|
||||
tl.store(Zs, z)
|
||||
# input
|
||||
M, N, K = 64, 64, 32
|
||||
x = triton.testing.random((M, K), dtype=torch.float16, device=device)
|
||||
y = triton.testing.random((K, N), dtype=torch.float16, device=device)
|
||||
x = triton.testing.random((M, K), dtype=dtype, device=device)
|
||||
y = triton.testing.random((K, N), dtype=dtype, device=device)
|
||||
# triton result
|
||||
z = triton.testing.random((M, N), dtype=torch.float16, device=device)
|
||||
z = triton.testing.random((M, N), dtype=dtype, device=device)
|
||||
z_tri = z.clone()
|
||||
pgm = kernel[(1, 1)](x, x.stride(0), x.stride(1),
|
||||
y, y.stride(0), y.stride(1),
|
||||
@@ -508,12 +548,14 @@ def test_dot(epilogue, device='cuda'):
|
||||
z_ref += z[0,:][None, :]
|
||||
z_ref = z_ref.to(torch.float16)
|
||||
# compare
|
||||
ptx = pgm.asm['ptx']
|
||||
# print(ptx)
|
||||
triton.testing.assert_almost_equal(z_tri, z_ref)
|
||||
# make sure ld/st are vectorized
|
||||
assert 'ld.global.v4' in ptx
|
||||
assert 'st.global.v4' in ptx
|
||||
|
||||
# print(ptx)
|
||||
if torch.version.hip is None:
|
||||
ptx = pgm.asm['ptx']
|
||||
# make sure ld/st are vectorized
|
||||
assert 'ld.global.v4' in ptx
|
||||
assert 'st.global.v4' in ptx
|
||||
|
||||
def test_dot_without_load():
|
||||
@triton.jit
|
||||
@@ -611,17 +653,18 @@ def test_load_cache_modifier(cache):
|
||||
tl.store(dst+offsets, x)
|
||||
|
||||
pgm = _kernel[(1,)](dst, src, CACHE=cache)
|
||||
ptx = pgm.asm['ptx']
|
||||
if torch.version.hip is None:
|
||||
ptx = pgm.asm['ptx']
|
||||
|
||||
if cache == '':
|
||||
assert 'ld.global.ca' not in ptx
|
||||
assert 'ld.global.cg' not in ptx
|
||||
if cache == '.cg':
|
||||
assert 'ld.global.cg' in ptx
|
||||
assert 'ld.global.ca' not in ptx
|
||||
if cache == '.ca':
|
||||
assert 'ld.global.ca' in ptx
|
||||
assert 'ld.global.cg' not in ptx
|
||||
if cache == '':
|
||||
assert 'ld.global.ca' not in ptx
|
||||
assert 'ld.global.cg' not in ptx
|
||||
if cache == '.cg':
|
||||
assert 'ld.global.cg' in ptx
|
||||
assert 'ld.global.ca' not in ptx
|
||||
if cache == '.ca':
|
||||
assert 'ld.global.ca' in ptx
|
||||
assert 'ld.global.cg' not in ptx
|
||||
|
||||
# ---------------
|
||||
# test store
|
||||
@@ -647,4 +690,4 @@ def test_noop(device='cuda'):
|
||||
def kernel(**meta):
|
||||
pass
|
||||
x = triton.testing.random((1,), dtype=torch.int32, device=device)
|
||||
kernel[(1, )](x)
|
||||
kernel[(1, )](x)
|
||||
|
Reference in New Issue
Block a user