Merge pull request #14 from ROCmSoftwarePlatform/fix_vectorization

fix test_vectorization and test_load_cache_modifier
This commit is contained in:
rsanthanam-amd
2022-10-28 16:12:57 -05:00
committed by GitHub

View File

@@ -1039,13 +1039,18 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
# compare
triton.testing.assert_almost_equal(z_tri, z_ref)
triton.testing.assert_almost_equal(z_tri_contiguous, z_ref)
# parse ptx to make sure ld/st are vectorized
ptx = pgm.asm['ptx']
assert 'ld.global.v4' in ptx
assert 'st.global.v4' in ptx
ptx = pgm_contiguous.asm['ptx']
assert 'ld.global.v4' in ptx
assert 'st.global.v4' in ptx
if torch.version.hip is None:
# parse ptx to make sure ld/st are vectorized
ptx = pgm.asm['ptx']
assert 'ld.global.v4' in ptx
assert 'st.global.v4' in ptx
ptx = pgm_contiguous.asm['ptx']
assert 'ld.global.v4' in ptx
assert 'st.global.v4' in ptx
else:
# TODO add rocm gcn assert
pass
# ---------------
# test dot
@@ -1306,16 +1311,20 @@ def test_load_cache_modifier(cache):
tl.store(dst + offsets, x)
pgm = _kernel[(1,)](dst, src, CACHE=cache)
ptx = pgm.asm['ptx']
if cache == '':
assert 'ld.global.ca' not in ptx
assert 'ld.global.cg' not in ptx
if cache == '.cg':
assert 'ld.global.cg' in ptx
assert 'ld.global.ca' not in ptx
if cache == '.ca':
assert 'ld.global.ca' in ptx
assert 'ld.global.cg' not in ptx
if torch.version.hip is None:
ptx = pgm.asm['ptx']
if cache == '':
assert 'ld.global.ca' not in ptx
assert 'ld.global.cg' not in ptx
if cache == '.cg':
assert 'ld.global.cg' in ptx
assert 'ld.global.ca' not in ptx
if cache == '.ca':
assert 'ld.global.ca' in ptx
assert 'ld.global.cg' not in ptx
else:
# TODO add rocm gcn assert
pass
@pytest.mark.parametrize("N", [16, 10, 11, 1024])
@@ -1329,11 +1338,15 @@ def test_vectorization(N):
x = tl.load(src + offsets, mask=offsets < N)
tl.store(dst + offsets, x, mask=offsets < N)
pgm = _kernel[(1,)](dst, src, N=N, BLOCK_SIZE=src.shape[0])
ptx = pgm.asm["ptx"]
if N % 16 == 0:
assert "ld.global.v4.b32" in ptx
if torch.version.hip is None:
ptx = pgm.asm["ptx"]
if N % 16 == 0:
assert "ld.global.v4.b32" in ptx
else:
assert "ld.global.b32" in ptx
else:
assert "ld.global.b32" in ptx
#TODO add rocm assert
pass
# triton.testing.assert_almost_equal(dst, src[:N])
# ---------------
# test store