skip ptx assert
This commit is contained in:
@@ -1329,11 +1329,15 @@ def test_vectorization(N):
|
|||||||
x = tl.load(src + offsets, mask=offsets < N)
|
x = tl.load(src + offsets, mask=offsets < N)
|
||||||
tl.store(dst + offsets, x, mask=offsets < N)
|
tl.store(dst + offsets, x, mask=offsets < N)
|
||||||
pgm = _kernel[(1,)](dst, src, N=N, BLOCK_SIZE=src.shape[0])
|
pgm = _kernel[(1,)](dst, src, N=N, BLOCK_SIZE=src.shape[0])
|
||||||
ptx = pgm.asm["ptx"]
|
if torch.version.hip is None:
|
||||||
if N % 16 == 0:
|
ptx = pgm.asm["ptx"]
|
||||||
assert "ld.global.v4.b32" in ptx
|
if N % 16 == 0:
|
||||||
|
assert "ld.global.v4.b32" in ptx
|
||||||
|
else:
|
||||||
|
assert "ld.global.b32" in ptx
|
||||||
else:
|
else:
|
||||||
assert "ld.global.b32" in ptx
|
#TODO add rocm assert
|
||||||
|
pass
|
||||||
# triton.testing.assert_almost_equal(dst, src[:N])
|
# triton.testing.assert_almost_equal(dst, src[:N])
|
||||||
# ---------------
|
# ---------------
|
||||||
# test store
|
# test store
|
||||||
|
Reference in New Issue
Block a user