[BACKEND] Alignment pass improvements (#503)
@@ -937,13 +937,15 @@ def test_load_cache_modifier(cache):
     assert 'ld.global.ca' in ptx
     assert 'ld.global.cg' not in ptx


 @pytest.mark.parametrize("N", [8, 10, 11, 1024])
 def test_vectorization(N):
     src = torch.empty(1024, device='cuda')
     dst = torch.empty(1024, device='cuda')

     @triton.jit
     def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr):
-        offsets = tl.program_id(0)*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+        offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
         x = tl.load(src + offsets, mask=offsets < N)
         tl.store(dst + offsets, x, mask=offsets < N)
     pgm = _kernel[(1,)](dst, src, N=N, BLOCK_SIZE=src.shape[0])
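The shown context ends at the kernel launch, so the PTX assertions that test_vectorization presumably makes fall outside this hunk. Purely as an illustrative sketch (not taken from this diff), the generated PTX could be inspected the same way the cache-modifier test above does, assuming the launch object exposes it via pgm.asm['ptx']:

    # Illustrative sketch only; assumes pgm.asm['ptx'] holds the generated PTX,
    # as in the cache-modifier test above. Not part of this commit.
    ptx = pgm.asm['ptx']
    assert 'ld.global' in ptx  # a global load should be emitted for tl.load(src + offsets)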
@@ -942,9 +942,9 @@ class Kernel:
         assert _type == triton.language.constexpr, "only constexpr annotations are supported for now"
         wargs[pos] = _type(wargs[pos])
         # check that tensors are on GPU.
-        for arg in wargs:
-            if hasattr(arg, 'data_ptr'):
-                assert arg.is_cuda, "All tensors must be on GPU!"
+        # for arg in wargs:
+        #     if hasattr(arg, 'data_ptr'):
+        #         assert arg.is_cuda, "All tensors must be on GPU!"
         # set device (i.e., make sure torch has the context initialized)
         device = torch.cuda.current_device()
         torch.cuda.set_device(device)
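With the launcher-side residency assert commented out above, a caller that still wants that safety net can run the same check before launching. A minimal sketch reusing the data_ptr/is_cuda test from the commented-out lines (the helper name is hypothetical, not part of this commit):

    # Hypothetical helper mirroring the commented-out launcher check.
    def assert_args_on_gpu(*wargs):
        for arg in wargs:
            if hasattr(arg, 'data_ptr'):  # tensor-like arguments expose data_ptr
                assert arg.is_cuda, "All tensors must be on GPU!"

    # e.g. before launching: assert_args_on_gpu(dst, src)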