Merge pull request #14 from ROCmSoftwarePlatform/fix_vectorization
fix test_vectorization and test_load_cache_modifier
This commit is contained in:
@@ -1039,13 +1039,18 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
|
|||||||
# compare
|
# compare
|
||||||
triton.testing.assert_almost_equal(z_tri, z_ref)
|
triton.testing.assert_almost_equal(z_tri, z_ref)
|
||||||
triton.testing.assert_almost_equal(z_tri_contiguous, z_ref)
|
triton.testing.assert_almost_equal(z_tri_contiguous, z_ref)
|
||||||
# parse ptx to make sure ld/st are vectorized
|
|
||||||
ptx = pgm.asm['ptx']
|
if torch.version.hip is None:
|
||||||
assert 'ld.global.v4' in ptx
|
# parse ptx to make sure ld/st are vectorized
|
||||||
assert 'st.global.v4' in ptx
|
ptx = pgm.asm['ptx']
|
||||||
ptx = pgm_contiguous.asm['ptx']
|
assert 'ld.global.v4' in ptx
|
||||||
assert 'ld.global.v4' in ptx
|
assert 'st.global.v4' in ptx
|
||||||
assert 'st.global.v4' in ptx
|
ptx = pgm_contiguous.asm['ptx']
|
||||||
|
assert 'ld.global.v4' in ptx
|
||||||
|
assert 'st.global.v4' in ptx
|
||||||
|
else:
|
||||||
|
# TODO add rocm gcn assert
|
||||||
|
pass
|
||||||
|
|
||||||
# ---------------
|
# ---------------
|
||||||
# test dot
|
# test dot
|
||||||
@@ -1306,16 +1311,20 @@ def test_load_cache_modifier(cache):
|
|||||||
tl.store(dst + offsets, x)
|
tl.store(dst + offsets, x)
|
||||||
|
|
||||||
pgm = _kernel[(1,)](dst, src, CACHE=cache)
|
pgm = _kernel[(1,)](dst, src, CACHE=cache)
|
||||||
ptx = pgm.asm['ptx']
|
if torch.version.hip is None:
|
||||||
if cache == '':
|
ptx = pgm.asm['ptx']
|
||||||
assert 'ld.global.ca' not in ptx
|
if cache == '':
|
||||||
assert 'ld.global.cg' not in ptx
|
assert 'ld.global.ca' not in ptx
|
||||||
if cache == '.cg':
|
assert 'ld.global.cg' not in ptx
|
||||||
assert 'ld.global.cg' in ptx
|
if cache == '.cg':
|
||||||
assert 'ld.global.ca' not in ptx
|
assert 'ld.global.cg' in ptx
|
||||||
if cache == '.ca':
|
assert 'ld.global.ca' not in ptx
|
||||||
assert 'ld.global.ca' in ptx
|
if cache == '.ca':
|
||||||
assert 'ld.global.cg' not in ptx
|
assert 'ld.global.ca' in ptx
|
||||||
|
assert 'ld.global.cg' not in ptx
|
||||||
|
else:
|
||||||
|
# TODO add rocm gcn assert
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("N", [16, 10, 11, 1024])
|
@pytest.mark.parametrize("N", [16, 10, 11, 1024])
|
||||||
@@ -1329,11 +1338,15 @@ def test_vectorization(N):
|
|||||||
x = tl.load(src + offsets, mask=offsets < N)
|
x = tl.load(src + offsets, mask=offsets < N)
|
||||||
tl.store(dst + offsets, x, mask=offsets < N)
|
tl.store(dst + offsets, x, mask=offsets < N)
|
||||||
pgm = _kernel[(1,)](dst, src, N=N, BLOCK_SIZE=src.shape[0])
|
pgm = _kernel[(1,)](dst, src, N=N, BLOCK_SIZE=src.shape[0])
|
||||||
ptx = pgm.asm["ptx"]
|
if torch.version.hip is None:
|
||||||
if N % 16 == 0:
|
ptx = pgm.asm["ptx"]
|
||||||
assert "ld.global.v4.b32" in ptx
|
if N % 16 == 0:
|
||||||
|
assert "ld.global.v4.b32" in ptx
|
||||||
|
else:
|
||||||
|
assert "ld.global.b32" in ptx
|
||||||
else:
|
else:
|
||||||
assert "ld.global.b32" in ptx
|
#TODO add rocm assert
|
||||||
|
pass
|
||||||
# triton.testing.assert_almost_equal(dst, src[:N])
|
# triton.testing.assert_almost_equal(dst, src[:N])
|
||||||
# ---------------
|
# ---------------
|
||||||
# test store
|
# test store
|
||||||
|
Reference in New Issue
Block a user