[TEST] Added test for vectorization
This commit is contained in:
@@ -937,6 +937,24 @@ def test_load_cache_modifier(cache):
|
||||
assert 'ld.global.ca' in ptx
|
||||
assert 'ld.global.cg' not in ptx
|
||||
|
||||
@pytest.mark.parametrize("N", [8, 10, 11, 1024])
def test_vectorization(N):
    """Check the load width emitted for a masked copy of ``N`` elements.

    A single program instance covers the whole 1024-element buffer and
    copies only the elements whose offset is below ``N``. The generated
    PTX is expected to contain a 4-wide load when ``N`` is a multiple of
    4, a 2-wide load when ``N`` is only a multiple of 2, and a scalar
    load otherwise. The copied prefix is then compared with the input.
    """
    # Random rather than uninitialized input, so that the value comparison
    # at the end of the test is meaningful.
    src = torch.rand(1024, device='cuda')
    dst = torch.empty(1024, device='cuda')

    @triton.jit
    def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr):
        offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        # The same bound guards both the load and the store, so elements
        # at offsets >= N are neither read nor written.
        x = tl.load(src + offsets, mask=offsets < N)
        tl.store(dst + offsets, x, mask=offsets < N)

    # One program instance whose block spans the entire buffer.
    pgm = _kernel[(1,)](dst, src, N=N, BLOCK_SIZE=src.shape[0])
    ptx = pgm.asm["ptx"]
    if N % 4 == 0:
        assert "ld.global.v4.b32" in ptx
    elif N % 2 == 0:
        assert "ld.global.v2.b32" in ptx
    else:
        assert "ld.global.b32" in ptx
    # The kernel is a plain copy, so the first N elements must match
    # exactly; the remainder of dst is never written and is not compared.
    assert torch.equal(dst[:N], src[:N])
|
||||
# ---------------
|
||||
# test store
|
||||
# ---------------
|
||||
|
Reference in New Issue
Block a user