diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index dfb7f2535..b2c4a0516 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1039,13 +1039,18 @@ def test_permute(dtype_str, shape, perm, device='cuda'): # compare triton.testing.assert_almost_equal(z_tri, z_ref) triton.testing.assert_almost_equal(z_tri_contiguous, z_ref) - # parse ptx to make sure ld/st are vectorized - ptx = pgm.asm['ptx'] - assert 'ld.global.v4' in ptx - assert 'st.global.v4' in ptx - ptx = pgm_contiguous.asm['ptx'] - assert 'ld.global.v4' in ptx - assert 'st.global.v4' in ptx + + if torch.version.hip is None: + # parse ptx to make sure ld/st are vectorized + ptx = pgm.asm['ptx'] + assert 'ld.global.v4' in ptx + assert 'st.global.v4' in ptx + ptx = pgm_contiguous.asm['ptx'] + assert 'ld.global.v4' in ptx + assert 'st.global.v4' in ptx + else: + # TODO add rocm gcn assert + pass # --------------- # test dot @@ -1306,16 +1311,20 @@ def test_load_cache_modifier(cache): tl.store(dst + offsets, x) pgm = _kernel[(1,)](dst, src, CACHE=cache) - ptx = pgm.asm['ptx'] - if cache == '': - assert 'ld.global.ca' not in ptx - assert 'ld.global.cg' not in ptx - if cache == '.cg': - assert 'ld.global.cg' in ptx - assert 'ld.global.ca' not in ptx - if cache == '.ca': - assert 'ld.global.ca' in ptx - assert 'ld.global.cg' not in ptx + if torch.version.hip is None: + ptx = pgm.asm['ptx'] + if cache == '': + assert 'ld.global.ca' not in ptx + assert 'ld.global.cg' not in ptx + if cache == '.cg': + assert 'ld.global.cg' in ptx + assert 'ld.global.ca' not in ptx + if cache == '.ca': + assert 'ld.global.ca' in ptx + assert 'ld.global.cg' not in ptx + else: + # TODO add rocm gcn assert + pass @pytest.mark.parametrize("N", [16, 10, 11, 1024]) @@ -1329,11 +1338,15 @@ def test_vectorization(N): x = tl.load(src + offsets, mask=offsets < N) tl.store(dst + offsets, x, mask=offsets < N) pgm = _kernel[(1,)](dst, src, N=N, 
BLOCK_SIZE=src.shape[0]) - ptx = pgm.asm["ptx"] - if N % 16 == 0: - assert "ld.global.v4.b32" in ptx + if torch.version.hip is None: + ptx = pgm.asm["ptx"] + if N % 16 == 0: + assert "ld.global.v4.b32" in ptx + else: + assert "ld.global.b32" in ptx else: - assert "ld.global.b32" in ptx + # TODO add rocm gcn assert + pass # triton.testing.assert_almost_equal(dst, src[:N]) # --------------- # test store