[BACKEND] Alignment pass improvements (#503)

Philippe Tillet
2022-04-25 21:16:00 -07:00
committed by GitHub
parent 7d544799a0
commit ae2a1ab225
4 changed files with 29 additions and 15 deletions


@@ -937,13 +937,15 @@ def test_load_cache_modifier(cache):
     assert 'ld.global.ca' in ptx
     assert 'ld.global.cg' not in ptx

+@pytest.mark.parametrize("N", [8, 10, 11, 1024])
 def test_vectorization(N):
     src = torch.empty(1024, device='cuda')
     dst = torch.empty(1024, device='cuda')

     @triton.jit
     def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr):
-        offsets = tl.program_id(0)*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+        offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
         x = tl.load(src + offsets, mask=offsets < N)
         tl.store(dst + offsets, x, mask=offsets < N)
     pgm = _kernel[(1,)](dst, src, N=N, BLOCK_SIZE=src.shape[0])
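The hunk above is truncated before the test's assertions, so they are not shown here. As a rough, hedged sketch of what such a check can look like, the PTX emitted for the kernel can be inspected through pgm.asm['ptx'] (the same accessor that produces the ptx string used in test_load_cache_modifier above); the vector-width condition and opcode strings below are illustrative assumptions, not the test's real expectations:

    ptx = pgm.asm['ptx']
    # Assumption: sizes that fill whole 4-element float32 vectors let the
    # alignment pass emit vectorized 16-byte loads, while the ragged sizes
    # (N=10, 11) fall back to narrower accesses.
    if N % 4 == 0:
        assert 'ld.global.v4' in ptx
    else:
        assert 'ld.global.v4' not in ptx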


@@ -942,9 +942,9 @@ class Kernel:
             assert _type == triton.language.constexpr, "only constexpr annotations are supported for now"
             wargs[pos] = _type(wargs[pos])
         # check that tensors are on GPU.
-        for arg in wargs:
-            if hasattr(arg, 'data_ptr'):
-                assert arg.is_cuda, "All tensors must be on GPU!"
+        # for arg in wargs:
+        #     if hasattr(arg, 'data_ptr'):
+        #         assert arg.is_cuda, "All tensors must be on GPU!"
         # set device (i.e., make sure torch has the context initialized)
         device = torch.cuda.current_device()
         torch.cuda.set_device(device)
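
With the launch-time device check commented out above, a kernel invoked with CPU tensors will typically fail later (e.g., with an invalid pointer error at launch) instead of raising this assertion. Callers that want the old behavior can run an equivalent check themselves; a minimal sketch, where the helper name check_on_gpu is hypothetical and not part of Triton:

def check_on_gpu(*args):
    # Mirrors the commented-out launcher check: anything exposing
    # data_ptr() is treated as a tensor and must live on the GPU.
    for arg in args:
        if hasattr(arg, 'data_ptr'):
            assert arg.is_cuda, "All tensors must be on GPU!"

# e.g., before launching:
# check_on_gpu(dst, src)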