[CODEGEN] Fixed bug for visit_reduce1d with 64-bit data-types (#207)
This commit is contained in:
@@ -337,6 +337,33 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
|
||||
z_ref = x.to(z_tri.dtype)
|
||||
assert z_tri == z_ref
|
||||
|
||||
# ---------------
|
||||
# test reduce
|
||||
# ---------------
|
||||
@pytest.mark.parametrize(
    "dtype, shape",
    [(dtype, shape) for dtype in dtypes for shape in [128, 512]],
)
def test_reduce1d(dtype, shape, device='cuda'):
    """Check a 1-D `tl.sum` reduction against `torch.sum`.

    A single Triton program loads `shape` contiguous elements, sums them
    along axis 0 and stores the scalar result; the value is then compared
    with the torch reduction cast to the same dtype.

    Parameters are supplied by the parametrization above: `dtype` is a key
    into the module-level `cvt` table and `shape` is the number of elements
    reduced (also used as the kernel's BLOCK size).
    """
    torch_dtype = cvt[dtype]

    # Triton kernel: one program reduces the whole block.
    @triton.jit
    def kernel(X, Z, **meta):
        offsets = tl.arange(0, meta['BLOCK'])
        values = tl.load(X + offsets)
        total = tl.sum(values, axis=0)
        tl.store(Z, total)

    # Input is allocated first, then the one-element output buffer.
    inputs = triton.testing.random((shape,), dtype=torch_dtype, device=device)
    z_tri = triton.testing.random((1,), dtype=torch_dtype, device=device)

    # Triton result: a single program instance, BLOCK spans the input.
    launch_grid = (1,)
    kernel[launch_grid](inputs, z_tri, BLOCK=shape)

    # Torch reference, cast back to the tested dtype.
    z_ref = torch.sum(inputs).to(torch_dtype)

    # Compare both results.
    triton.testing.assert_almost_equal(z_tri, z_ref)
|
||||
|
||||
|
||||
|
||||
|
||||
# ---------------
|
||||
# test load
|
||||
|
Reference in New Issue
Block a user