[TEST] Added test for vectorization

This commit is contained in:
Philippe Tillet
2022-04-24 13:32:35 -07:00
parent bda209002e
commit 3ca792043f

View File

@@ -937,6 +937,24 @@ def test_load_cache_modifier(cache):
assert 'ld.global.ca' in ptx
assert 'ld.global.cg' not in ptx
@pytest.mark.parametrize("N", [8, 10, 11, 1024])
def test_vectorization(N):
src = torch.empty(1024, device='cuda')
dst = torch.empty(1024, device='cuda')
@triton.jit
def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr):
offsets = tl.program_id(0)*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
x = tl.load(src + offsets, mask=offsets < N)
tl.store(dst + offsets, x, mask=offsets < N)
pgm = _kernel[(1,)](dst, src, N=N, BLOCK_SIZE=src.shape[0])
ptx = pgm.asm["ptx"]
if N % 4 == 0:
assert "ld.global.v4.b32" in ptx
elif N % 2 == 0:
assert "ld.global.v2.b32" in ptx
else:
assert "ld.global.b32" in ptx
# triton.testing.assert_almost_equal(dst, src[:N])
# ---------------
# test store
# ---------------