diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index 9c1fd7bc8..b2c4a0516 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -1039,13 +1039,18 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
     # compare
     triton.testing.assert_almost_equal(z_tri, z_ref)
     triton.testing.assert_almost_equal(z_tri_contiguous, z_ref)
-    # parse ptx to make sure ld/st are vectorized
-    ptx = pgm.asm['ptx']
-    assert 'ld.global.v4' in ptx
-    assert 'st.global.v4' in ptx
-    ptx = pgm_contiguous.asm['ptx']
-    assert 'ld.global.v4' in ptx
-    assert 'st.global.v4' in ptx
+
+    if torch.version.hip is None:
+        # parse ptx to make sure ld/st are vectorized
+        ptx = pgm.asm['ptx']
+        assert 'ld.global.v4' in ptx
+        assert 'st.global.v4' in ptx
+        ptx = pgm_contiguous.asm['ptx']
+        assert 'ld.global.v4' in ptx
+        assert 'st.global.v4' in ptx
+    else:
+        # TODO add rocm gcn assert
+        pass
 
 # ---------------
 # test dot
@@ -1306,16 +1311,20 @@ def test_load_cache_modifier(cache):
         tl.store(dst + offsets, x)
 
     pgm = _kernel[(1,)](dst, src, CACHE=cache)
-    ptx = pgm.asm['ptx']
-    if cache == '':
-        assert 'ld.global.ca' not in ptx
-        assert 'ld.global.cg' not in ptx
-    if cache == '.cg':
-        assert 'ld.global.cg' in ptx
-        assert 'ld.global.ca' not in ptx
-    if cache == '.ca':
-        assert 'ld.global.ca' in ptx
-        assert 'ld.global.cg' not in ptx
+    if torch.version.hip is None:
+        ptx = pgm.asm['ptx']
+        if cache == '':
+            assert 'ld.global.ca' not in ptx
+            assert 'ld.global.cg' not in ptx
+        if cache == '.cg':
+            assert 'ld.global.cg' in ptx
+            assert 'ld.global.ca' not in ptx
+        if cache == '.ca':
+            assert 'ld.global.ca' in ptx
+            assert 'ld.global.cg' not in ptx
+    else:
+        # TODO add rocm gcn assert
+        pass
 
 
 @pytest.mark.parametrize("N", [16, 10, 11, 1024])
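
Note: both hunks apply the same backend gate. torch.version.hip is None on CUDA builds of PyTorch and a version string on ROCm builds, so the PTX-level assertions only run where PTX is actually emitted. A minimal sketch of that gate as a reusable test helper follows; the helper names (is_hip, check_vectorized) are illustrative and not part of this diff, and the ROCm branch is left open just like the TODO above.

# Sketch only: is_hip and check_vectorized are hypothetical helpers,
# not introduced by the diff above.
import torch


def is_hip():
    # torch.version.hip is None on CUDA builds and a version string on ROCm.
    return torch.version.hip is not None


def check_vectorized(pgm):
    # Assert vectorized global loads/stores by inspecting the generated PTX.
    # On ROCm the compiled artifact is GCN assembly rather than PTX, so the
    # check is skipped for now, mirroring the TODO in the diff.
    if is_hip():
        return
    ptx = pgm.asm['ptx']
    assert 'ld.global.v4' in ptx
    assert 'st.global.v4' in ptx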