From 3ca792043f244a5f28eb7a347e7757edf63ae0d5 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Sun, 24 Apr 2022 13:32:35 -0700 Subject: [PATCH] [TEST] Added test for vectorization --- python/test/unit/language/test_core.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 4570a5c61..a7f27eaba 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -937,6 +937,24 @@ def test_load_cache_modifier(cache): assert 'ld.global.ca' in ptx assert 'ld.global.cg' not in ptx +@pytest.mark.parametrize("N", [8, 10, 11, 1024]) +def test_vectorization(N): + src = torch.empty(1024, device='cuda') + dst = torch.empty(1024, device='cuda') + @triton.jit + def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(0)*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(src + offsets, mask=offsets < N) + tl.store(dst + offsets, x, mask=offsets < N) + pgm = _kernel[(1,)](dst, src, N=N, BLOCK_SIZE=src.shape[0]) + ptx = pgm.asm["ptx"] + if N % 4 == 0: + assert "ld.global.v4.b32" in ptx + elif N % 2 == 0: + assert "ld.global.v2.b32" in ptx + else: + assert "ld.global.b32" in ptx + # triton.testing.assert_almost_equal(dst, src[:N]) # --------------- # test store # ---------------