[BACKEND/CODE_GEN] Fixed float32 matmul problem (#380)
This commit is contained in:
@@ -455,8 +455,8 @@ def test_permute(dtype, shape, perm, device='cuda'):
|
||||
# test dot
|
||||
# ---------------
|
||||
|
||||
@pytest.mark.parametrize("epilogue", ['none', 'add-matrix', 'add-rows', 'add-cols'])
|
||||
def test_dot(epilogue, device='cuda'):
|
||||
@pytest.mark.parametrize("epilogue", ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols'])
|
||||
def test_dot(epilogue, dtype=torch.float32, device='cuda'):
|
||||
torch.manual_seed(0)
|
||||
# triton kernel
|
||||
@triton.jit
|
||||
@@ -483,11 +483,13 @@ def test_dot(epilogue, device='cuda'):
|
||||
tl.store(Zs, z)
|
||||
# input
|
||||
M, N, K = 64, 64, 32
|
||||
x = triton.testing.random((M, K), dtype=torch.float16, device=device)
|
||||
y = triton.testing.random((K, N), dtype=torch.float16, device=device)
|
||||
x = triton.testing.random((M, K), dtype=dtype, device=device)
|
||||
y = triton.testing.random((K, N), dtype=dtype, device=device)
|
||||
# triton result
|
||||
z = triton.testing.random((M, N), dtype=torch.float16, device=device)
|
||||
z = triton.testing.random((M, N), dtype=dtype, device=device)
|
||||
z_tri = z.clone()
|
||||
if epilogue == 'trans':
|
||||
z_tri = torch.as_strided(z_tri, (M, N), z_tri.stride()[::-1])
|
||||
pgm = kernel[(1, 1)](x, x.stride(0), x.stride(1),
|
||||
y, y.stride(0), y.stride(1),
|
||||
z_tri, z_tri.stride(0), z_tri.stride(1),
|
||||
@@ -505,10 +507,9 @@ def test_dot(epilogue, device='cuda'):
|
||||
z_ref += z[0,:][None, :]
|
||||
z_ref = z_ref.to(torch.float16)
|
||||
# compare
|
||||
ptx = pgm.asm['ptx']
|
||||
# print(ptx)
|
||||
triton.testing.assert_almost_equal(z_tri, z_ref)
|
||||
# make sure ld/st are vectorized
|
||||
ptx = pgm.asm['ptx']
|
||||
assert 'ld.global.v4' in ptx
|
||||
assert 'st.global.v4' in ptx
|
||||
|
||||
|
Reference in New Issue
Block a user