[CODEGEN] Fixed bug for visit_reduce1d with 64-bit data-types (#207)

This commit is contained in:
Philippe Tillet
2021-08-14 21:07:01 -07:00
committed by GitHub
parent 6e7593b446
commit bb1eebb4b4
3 changed files with 45 additions and 4 deletions

View File

@@ -337,6 +337,33 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
z_ref = x.to(z_tri.dtype)
assert z_tri == z_ref
# ---------------
# test reduce
# ---------------
@pytest.mark.parametrize("dtype, shape",
[(dtype, shape) \
for dtype in dtypes\
for shape in [128, 512]])
def test_reduce1d(dtype, shape, device='cuda'):
dtype = cvt[dtype]
# triton kernel
@triton.jit
def kernel(X, Z, **meta):
x = tl.load(X + tl.arange(0, meta['BLOCK']))
tl.store(Z, tl.sum(x, axis=0))
x = triton.testing.random((shape,), dtype=dtype, device=device)
# triton result
z_tri = triton.testing.random((1,), dtype=dtype, device=device)
kernel[(1,)](x, z_tri, BLOCK=shape)
# torch result
z_ref = torch.sum(x).to(dtype)
# compare
triton.testing.assert_almost_equal(z_tri, z_ref)
# ---------------
# test load