[LANG] Various (relatively minor) improvements (#320)
This commit is contained in:
@@ -515,6 +515,22 @@ def test_dot(epilogue, device='cuda'):
|
||||
assert 'ld.global.v4' in ptx
|
||||
assert 'st.global.v4' in ptx
|
||||
|
||||
# ---------------
|
||||
# test arange
|
||||
# ---------------
|
||||
|
||||
@pytest.mark.parametrize("start", [0, 1, 7, 16])
def test_arange(start, device='cuda'):
    """Compare a kernel-side tl.arange(start, start + BLOCK) with torch.arange.

    A single program instance stores the generated range into an int32
    buffer of BLOCK elements; the buffer is then compared against the
    equivalent torch.arange result for each parametrized `start`.
    """
    BLOCK = 128
    stop = start + BLOCK

    # The kernel reads its bounds from the launch-time meta-parameters and
    # writes the range under test to positions 0..BLOCK-1 of `z`.
    @triton.jit
    def _kernel(z, **meta):
        idx = tl.arange(0, meta['BLOCK'])
        rng = tl.arange(meta['START'], meta['END'])
        tl.store(z + idx, rng)

    expected = torch.arange(start, stop, dtype=torch.int32, device=device)
    actual = torch.empty(BLOCK, dtype=torch.int32, device=device)
    # Launch on a one-element grid: (1,).
    _kernel[(1,)](actual, START=start, END=stop, BLOCK=BLOCK)
    triton.testing.assert_almost_equal(actual, expected)
|
||||
|
||||
# ---------------
|
||||
# test load
|
||||
|
@@ -112,7 +112,7 @@ BLOCK = 1024
|
||||
# test generation of random uint32
|
||||
@pytest.mark.parametrize('size, seed',
|
||||
[(size, seed) for size in ['10', '4,53', '10000']\
|
||||
for seed in [0, 42, 124, 54]]
|
||||
for seed in [0, 42, 124, 54, 0xffffffff, 0xdeadbeefcafeb0ba]]
|
||||
)
|
||||
def test_randint(size, seed, device='cuda'):
|
||||
size = list(map(int, size.split(',')))
|
||||
|
Reference in New Issue
Block a user