[LANG] Various (relatively minor) improvements (#320)
This commit is contained in:
@@ -515,6 +515,22 @@ def test_dot(epilogue, device='cuda'):
|
||||
assert 'ld.global.v4' in ptx
|
||||
assert 'st.global.v4' in ptx
|
||||
|
||||
# ---------------
|
||||
# test arange
|
||||
# ---------------
|
||||
|
||||
@pytest.mark.parametrize("start", [0, 1, 7, 16])
def test_arange(start, device='cuda'):
    """Compare ``tl.arange(start, start + BLOCK)`` against ``torch.arange``.

    A kernel launched on a one-element grid stores the Triton range into an
    int32 buffer of ``BLOCK`` elements; the buffer is then checked against
    the equivalent ``torch.arange`` result.
    """
    BLOCK = 128
    stop = start + BLOCK
    actual = torch.empty(BLOCK, dtype=torch.int32, device=device)

    @triton.jit
    def _kernel(z, **meta):
        # Output slots are indexed from 0; the stored values begin at START.
        offsets = tl.arange(0, meta['BLOCK'])
        values = tl.arange(meta['START'], meta['END'])
        tl.store(z + offsets, values)

    _kernel[(1,)](actual, START=start, END=stop, BLOCK=BLOCK)
    expected = torch.arange(start, stop, dtype=torch.int32, device=device)
    triton.testing.assert_almost_equal(actual, expected)
|
||||
|
||||
# ---------------
|
||||
# test load
|
||||
|
Reference in New Issue
Block a user