[TEST] Added test for vectorization
This commit is contained in:
@@ -937,6 +937,24 @@ def test_load_cache_modifier(cache):
|
|||||||
assert 'ld.global.ca' in ptx
|
assert 'ld.global.ca' in ptx
|
||||||
assert 'ld.global.cg' not in ptx
|
assert 'ld.global.cg' not in ptx
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("N", [8, 10, 11, 1024])
def test_vectorization(N):
    """Check the load instruction emitted for a masked copy of ``N`` elements.

    A single program copies the first ``N`` elements of ``src`` into ``dst``
    using loads/stores masked by ``offsets < N``. The generated PTX is then
    inspected: the test expects a 4-wide vector load when ``N`` is a multiple
    of 4, a 2-wide vector load when it is only a multiple of 2, and a scalar
    load otherwise. Finally the copied prefix is compared against the input.
    """
    # Finite random input (instead of uninitialized memory) so the copied
    # prefix can be compared exactly without tripping over stray NaNs.
    src = torch.rand(1024, device='cuda')
    dst = torch.empty(1024, device='cuda')

    @triton.jit
    def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr):
        offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        x = tl.load(src + offsets, mask=offsets < N)
        tl.store(dst + offsets, x, mask=offsets < N)

    # One program whose block spans the whole buffer; N only drives the mask.
    pgm = _kernel[(1,)](dst, src, N=N, BLOCK_SIZE=src.shape[0])
    ptx = pgm.asm["ptx"]
    if N % 4 == 0:
        assert "ld.global.v4.b32" in ptx
    elif N % 2 == 0:
        assert "ld.global.v2.b32" in ptx
    else:
        assert "ld.global.b32" in ptx
    # Only the first N elements are written by the masked store, so compare
    # just that prefix (the remainder of dst is left uninitialized).
    assert torch.equal(dst[:N], src[:N])
|
||||||
# ---------------
|
# ---------------
|
||||||
# test store
|
# test store
|
||||||
# ---------------
|
# ---------------
|
||||||
|
Reference in New Issue
Block a user