[LANG] Various (relatively minor) improvements (#320)

This commit is contained in:
Philippe Tillet
2021-10-04 18:39:40 -07:00
committed by GitHub
parent 12b6158c5c
commit 5123db0b7d
10 changed files with 59 additions and 16 deletions

View File

@@ -515,6 +515,22 @@ def test_dot(epilogue, device='cuda'):
assert 'ld.global.v4' in ptx
assert 'st.global.v4' in ptx
# ---------------
# test arange
# ---------------
@pytest.mark.parametrize("start", [0, 1, 7, 16])
def test_arange(start, device='cuda'):
BLOCK = 128
z_tri = torch.empty(BLOCK, dtype=torch.int32, device=device)
@triton.jit
def _kernel(z, **meta):
off = tl.arange(0, meta['BLOCK'])
val = tl.arange(meta['START'], meta['END'])
tl.store(z + off, val)
_kernel[(1,)](z_tri, START=start, END=start+BLOCK, BLOCK=BLOCK)
z_ref = torch.arange(start, BLOCK+start, dtype=torch.int32, device=device)
triton.testing.assert_almost_equal(z_tri, z_ref)
# ---------------
# test load

View File

@@ -112,7 +112,7 @@ BLOCK = 1024
# test generation of random uint32
@pytest.mark.parametrize('size, seed',
[(size, seed) for size in ['10', '4,53', '10000']\
for seed in [0, 42, 124, 54]]
for seed in [0, 42, 124, 54, 0xffffffff, 0xdeadbeefcafeb0ba]]
)
def test_randint(size, seed, device='cuda'):
size = list(map(int, size.split(',')))