[BACKEND/CODE_GEN] Fixed float32 matmul problem (#380)

This commit is contained in:
Philippe Tillet
2021-11-30 22:00:56 -08:00
committed by GitHub
parent c86ad9c9ab
commit 8ec9f037bb
2 changed files with 17 additions and 13 deletions

View File

@@ -455,8 +455,8 @@ def test_permute(dtype, shape, perm, device='cuda'):
# test dot
# ---------------
@pytest.mark.parametrize("epilogue", ['none', 'add-matrix', 'add-rows', 'add-cols'])
def test_dot(epilogue, device='cuda'):
@pytest.mark.parametrize("epilogue", ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols'])
def test_dot(epilogue, dtype=torch.float32, device='cuda'):
torch.manual_seed(0)
# triton kernel
@triton.jit
@@ -483,11 +483,13 @@ def test_dot(epilogue, device='cuda'):
tl.store(Zs, z)
# input
M, N, K = 64, 64, 32
x = triton.testing.random((M, K), dtype=torch.float16, device=device)
y = triton.testing.random((K, N), dtype=torch.float16, device=device)
x = triton.testing.random((M, K), dtype=dtype, device=device)
y = triton.testing.random((K, N), dtype=dtype, device=device)
# triton result
z = triton.testing.random((M, N), dtype=torch.float16, device=device)
z = triton.testing.random((M, N), dtype=dtype, device=device)
z_tri = z.clone()
if epilogue == 'trans':
z_tri = torch.as_strided(z_tri, (M, N), z_tri.stride()[::-1])
pgm = kernel[(1, 1)](x, x.stride(0), x.stride(1),
y, y.stride(0), y.stride(1),
z_tri, z_tri.stride(0), z_tri.stride(1),
@@ -505,10 +507,9 @@ def test_dot(epilogue, device='cuda'):
z_ref += z[0,:][None, :]
z_ref = z_ref.to(torch.float16)
# compare
ptx = pgm.asm['ptx']
# print(ptx)
triton.testing.assert_almost_equal(z_tri, z_ref)
# make sure ld/st are vectorized
ptx = pgm.asm['ptx']
assert 'ld.global.v4' in ptx
assert 'st.global.v4' in ptx