[BACKEND] Fix layout convert for non-contiguous input (#564)

This commit is contained in:
Keren Zhou
2022-06-25 23:12:03 -07:00
committed by GitHub
parent d345ddf837
commit 87413bc925
2 changed files with 20 additions and 7 deletions

View File

@@ -793,8 +793,8 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
@pytest.mark.parametrize("dtype_str, shape, perm",
[(dtype, shape, perm)
for dtype in ['float32']
for shape in [(128, 128)]
for dtype in ['float16', 'float32']
for shape in [(64, 64), (128, 128)]
for perm in [(1, 0)]])
def test_permute(dtype_str, shape, perm, device='cuda'):
@@ -812,18 +812,26 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
x = numpy_random(shape, dtype_str=dtype_str)
# triton result
z_tri = to_triton(np.empty_like(x), device=device)
z_tri_contiguous = to_triton(np.empty_like(x), device=device)
x_tri = to_triton(x, device=device)
pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1),
z_tri, z_tri.stride(1), z_tri.stride(0),
BLOCK_M=shape[0], BLOCK_N=shape[1])
pgm_contiguous = kernel[(1, 1)](x_tri, x_tri.stride(1), x_tri.stride(0),
z_tri_contiguous, z_tri_contiguous.stride(0), z_tri_contiguous.stride(1),
BLOCK_M=shape[0], BLOCK_N=shape[1])
# torch result
z_ref = x.transpose(*perm)
# compare
triton.testing.assert_almost_equal(z_tri, z_ref)
triton.testing.assert_almost_equal(z_tri_contiguous, z_ref)
# parse ptx to make sure ld/st are vectorized
ptx = pgm.asm['ptx']
assert 'ld.global.v4' in ptx
assert 'st.global.v4' in ptx
ptx = pgm_contiguous.asm['ptx']
assert 'ld.global.v4' in ptx
assert 'st.global.v4' in ptx
# ---------------
# test dot