[BACKEND] Alignment pass improvements (#503)

This commit is contained in:
Philippe Tillet
2022-04-25 21:16:00 -07:00
committed by GitHub
parent 7d544799a0
commit ae2a1ab225
4 changed files with 29 additions and 15 deletions

View File

@@ -937,13 +937,15 @@ def test_load_cache_modifier(cache):
assert 'ld.global.ca' in ptx
assert 'ld.global.cg' not in ptx
@pytest.mark.parametrize("N", [8, 10, 11, 1024])
def test_vectorization(N):
src = torch.empty(1024, device='cuda')
dst = torch.empty(1024, device='cuda')
@triton.jit
def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr):
offsets = tl.program_id(0)*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
x = tl.load(src + offsets, mask=offsets < N)
tl.store(dst + offsets, x, mask=offsets < N)
pgm = _kernel[(1,)](dst, src, N=N, BLOCK_SIZE=src.shape[0])