[BACKEND][CODEGEN] Fix reduce uint (#547)

This commit is contained in:
Keren Zhou
2022-06-13 16:43:57 -07:00
committed by GitHub
parent 58c8889235
commit 93209c07e0
6 changed files with 61 additions and 25 deletions

View File

@@ -688,60 +688,78 @@ def test_f16_to_f8_rounding():
# ---------------
@pytest.mark.parametrize("dtype_str, shape",
[(dtype, shape)
@pytest.mark.parametrize("op, dtype_str, shape",
[(op, dtype, shape)
for op in ['min', 'max', 'sum']
for dtype in dtypes
for shape in [32, 64, 128, 512]])
def test_reduce1d(dtype_str, shape, device='cuda'):
def test_reduce1d(op, dtype_str, shape, device='cuda'):
# triton kernel
@triton.jit
def kernel(X, Z, BLOCK: tl.constexpr):
x = tl.load(X + tl.arange(0, BLOCK))
tl.store(Z, tl.sum(x, axis=0))
tl.store(Z, GENERATE_TEST_HERE)
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis=0)'})
# input
rs = RandomState(17)
# limit the range of integers so that the sum does not overflow
x = numpy_random((shape,), dtype_str=dtype_str, rs=rs)
x[:] = 1
# numpy result
z_ref = np.sum(x).astype(getattr(np, dtype_str))
# triton result
x_tri = to_triton(x, device=device)
numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min}[op]
# numpy result
z_ref = numpy_op(x).astype(getattr(np, dtype_str))
# triton result
z_tri = to_triton(numpy_random((1,), dtype_str=dtype_str, rs=rs), device=device)
kernel[(1,)](x_tri, z_tri, BLOCK=shape)
# compare
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
if op == 'sum':
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
else:
np.testing.assert_equal(z_ref, to_numpy(z_tri))
reduce_configs1 = [
(dtype, (1, 1024), axis) for dtype in ['float32', 'uint32']
(op, dtype, (1, 1024), axis) for dtype in dtypes
for op in ['min', 'max', 'sum']
for axis in [1]
]
reduce_configs2 = [
('float32', shape, 1) for shape in [(2, 32), (4, 128), (32, 64), (64, 128), (128, 256), (32, 1024)]
(op, 'float32', shape, 1)
for op in ['min', 'max', 'sum']
for shape in [(2, 32), (4, 32), (4, 128), (32, 64), (64, 128), (128, 256), (32, 1024)]
]
@pytest.mark.parametrize("dtype_str, shape, axis", reduce_configs1 + reduce_configs2)
def test_reduce2d(dtype_str, shape, axis, device='cuda'):
@pytest.mark.parametrize("op, dtype_str, shape, axis", reduce_configs1 + reduce_configs2)
def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
# triton kernel
@triton.jit
def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr):
range_m = tl.arange(0, BLOCK_M)
range_n = tl.arange(0, BLOCK_N)
x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :])
z = tl.sum(x, axis=AXIS)
z = GENERATE_TEST_HERE
tl.store(Z + range_m, z)
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis=AXIS)'})
# input
x = numpy_random(shape, dtype_str=dtype_str)
# triton result
rs = RandomState(17)
# limit the range of integers so that the sum does not overflow
x = numpy_random(shape, dtype_str=dtype_str, rs=rs)
x_tri = to_triton(x)
z_tri = to_triton(np.empty((shape[0],), dtype=getattr(np, dtype_str)), device=device)
kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis)
# numpy reference result
z_ref = np.sum(x, axis=axis).astype(x.dtype)
numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min}[op]
# numpy result
z_ref = numpy_op(x, axis=axis).astype(getattr(np, dtype_str))
# triton result
z_tri = to_triton(numpy_random((shape[0],), dtype_str=dtype_str, rs=rs), device=device)
binary = kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis)
# compare
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
if op == 'sum':
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
else:
np.testing.assert_equal(z_ref, to_numpy(z_tri))
# ---------------
# test permute