add similar fixes two addition tests

This commit is contained in:
Michael Melesse
2022-10-28 20:34:58 +00:00
parent ffb30cdc52
commit 8d9572bc63

View File

@@ -1039,6 +1039,8 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
# compare
triton.testing.assert_almost_equal(z_tri, z_ref)
triton.testing.assert_almost_equal(z_tri_contiguous, z_ref)
if torch.version.hip is None:
# parse ptx to make sure ld/st are vectorized
ptx = pgm.asm['ptx']
assert 'ld.global.v4' in ptx
@@ -1046,6 +1048,9 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
ptx = pgm_contiguous.asm['ptx']
assert 'ld.global.v4' in ptx
assert 'st.global.v4' in ptx
else:
# TODO add rocm gcn assert
pass
# ---------------
# test dot
@@ -1306,6 +1311,7 @@ def test_load_cache_modifier(cache):
tl.store(dst + offsets, x)
pgm = _kernel[(1,)](dst, src, CACHE=cache)
if torch.version.hip is None:
ptx = pgm.asm['ptx']
if cache == '':
assert 'ld.global.ca' not in ptx
@@ -1316,6 +1322,9 @@ def test_load_cache_modifier(cache):
if cache == '.ca':
assert 'ld.global.ca' in ptx
assert 'ld.global.cg' not in ptx
else:
# TODO add rocm gcn assert
pass
@pytest.mark.parametrize("N", [16, 10, 11, 1024])