[Triton-MLIR][Backend] Fix reduce conversion and unit tests for int dtypes (#826)
@@ -870,121 +870,142 @@ def test_store_bool():
# # ---------------


# @pytest.mark.parametrize("op, dtype_str, shape",
#                          [(op, dtype, shape)
#                           for op in ['min', 'max', 'argmin', 'argmax', 'sum']
#                           for dtype in dtypes_with_bfloat16
#                           for shape in [32, 64, 128, 512]])
# def test_reduce1d(op, dtype_str, shape, device='cuda'):
#     check_type_supported(dtype_str)  # bfloat16 on cc < 80 will not be tested

#     # triton kernel
#     @triton.jit
#     def kernel(X, Z, BLOCK: tl.constexpr):
#         x = tl.load(X + tl.arange(0, BLOCK))
#         tl.store(Z, GENERATE_TEST_HERE)

#     kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis=0)'})
#     # input
#     rs = RandomState(17)
#     # limit the range of integers so that the sum does not overflow
#     x = numpy_random((shape,), dtype_str=dtype_str, rs=rs)
#     x_tri = to_triton(x, device=device)
#     numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min,
#                 'argmin': np.argmin, 'argmax': np.argmax}[op]
#     # numpy result
#     z_dtype_str = 'int32' if op == 'argmin' or op == 'argmax' else dtype_str
#     z_tri_dtype_str = z_dtype_str
#     if op not in ['argmin', 'argmax'] and dtype_str == 'bfloat16':
#         z_dtype_str = 'float32'
#         z_ref = numpy_op(x).astype(getattr(np, z_dtype_str))
#         # trunc mantissa for a fair comparison of accuracy
#         z_ref = (z_ref.view('uint32') & np.uint32(0xffff0000)).view('float32')
#         z_tri_dtype_str = 'bfloat16'
#     else:
#         z_ref = numpy_op(x).astype(getattr(np, z_dtype_str))
#     # triton result
#     z_tri = to_triton(numpy_random((1,), dtype_str=z_dtype_str, rs=rs),
#                       device=device, dst_type=z_tri_dtype_str)
#     kernel[(1,)](x_tri, z_tri, BLOCK=shape)
#     z_tri = to_numpy(z_tri)
#     # compare
#     if op == 'sum':
#         np.testing.assert_allclose(z_ref, z_tri, rtol=0.01)
#     else:
#         if op == 'argmin' or op == 'argmax':
#             # argmin and argmax can have multiple valid indices.
#             # so instead we compare the values pointed by indices
#             np.testing.assert_equal(x[z_ref], x[z_tri])
#         else:
#             np.testing.assert_equal(z_ref, z_tri)
def get_reduced_dtype(dtype_str, op):
    if op == 'argmin' or op == 'argmax':
        return 'int32'
    if dtype_str in ['int8', 'uint8', 'int16', 'uint16']:
        return 'int32'
    if dtype_str == 'bfloat16':
        return 'float32'
    return dtype_str
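
# Illustrative expectations for the helper above (a sketch, not exercised by the
# test-suite itself): integer dtypes narrower than 32 bits reduce to int32,
# bfloat16 accumulates in float32, and argmin/argmax always yield int32 indices:
#   get_reduced_dtype('int8', 'sum')       -> 'int32'
#   get_reduced_dtype('bfloat16', 'max')   -> 'float32'
#   get_reduced_dtype('float32', 'argmax') -> 'int32'
#   get_reduced_dtype('float16', 'min')    -> 'float16'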


# reduce_configs1 = [
#     (op, dtype, (1, 1024), axis) for dtype in dtypes_with_bfloat16
#     for op in ['min', 'max', 'argmin', 'argmax', 'sum']
#     for axis in [1]
# ]
# reduce_configs2 = [
#     (op, 'float32', shape, axis)
#     for op in ['min', 'max', 'argmin', 'argmax', 'sum']
#     for shape in [(2, 32), (4, 32), (4, 128), (32, 64), (64, 128), (128, 256), (32, 1024)]
#     for axis in [0, 1]
# ]
# TODO: [Qingyi] Fix argmin / argmax
@pytest.mark.parametrize("op, dtype_str, shape",
                         [(op, dtype, shape)
                          for op in ['min', 'max', 'sum']
                          for dtype in dtypes_with_bfloat16
                          for shape in [32, 64, 128, 512]])
def test_reduce1d(op, dtype_str, shape, device='cuda'):
    check_type_supported(dtype_str)  # bfloat16 on cc < 80 will not be tested

    # triton kernel
    @triton.jit
    def kernel(X, Z, BLOCK: tl.constexpr):
        x = tl.load(X + tl.arange(0, BLOCK))
        tl.store(Z, GENERATE_TEST_HERE)

    kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis=0)'})
    # input
    rs = RandomState(17)
    # limit the range of integers so that the sum does not overflow
    x = numpy_random((shape,), dtype_str=dtype_str, rs=rs)
    x_tri = to_triton(x, device=device)
    numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min,
                'argmin': np.argmin, 'argmax': np.argmax}[op]
    # numpy result
    z_dtype_str = get_reduced_dtype(dtype_str, op)
    z_tri_dtype_str = z_dtype_str
    if op not in ['argmin', 'argmax'] and dtype_str == 'bfloat16':
        z_dtype_str = 'float32'
        z_ref = numpy_op(x).astype(getattr(np, z_dtype_str))
        # trunc mantissa for a fair comparison of accuracy
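        # (bfloat16 keeps only the upper 16 bits of a float32, so masking off the
        #  low 16 mantissa bits truncates the float32 reference to bfloat16 precision)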
        z_ref = (z_ref.view('uint32') & np.uint32(0xffff0000)).view('float32')
        z_tri_dtype_str = 'bfloat16'
    else:
        z_ref = numpy_op(x).astype(getattr(np, z_dtype_str))
    # triton result
    z_tri = to_triton(numpy_random((1,), dtype_str=z_dtype_str, rs=rs),
                      device=device, dst_type=z_tri_dtype_str)
    kernel[(1,)](x_tri, z_tri, BLOCK=shape)
    z_tri = to_numpy(z_tri)
    # compare
    if op == 'sum':
        np.testing.assert_allclose(z_ref, z_tri, rtol=0.01)
    else:
        if op == 'argmin' or op == 'argmax':
            # argmin and argmax can have multiple valid indices.
            # so instead we compare the values pointed by indices
            np.testing.assert_equal(x[z_ref], x[z_tri])
        else:
            np.testing.assert_equal(z_ref, z_tri)


# @pytest.mark.parametrize("op, dtype_str, shape, axis", reduce_configs1 + reduce_configs2)
# def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
#     # triton kernel
#     @triton.jit
#     def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr):
#         range_m = tl.arange(0, BLOCK_M)
#         range_n = tl.arange(0, BLOCK_N)
#         x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :])
#         z = GENERATE_TEST_HERE
#         if AXIS == 1:
#             tl.store(Z + range_m, z)
#         else:
#             tl.store(Z + range_n, z)
# TODO: [Qingyi] Fix argmin / argmax
reduce_configs1 = [
    (op, dtype, (1, 1024), axis) for dtype in dtypes_with_bfloat16
    for op in ['min', 'max', 'sum']
    for axis in [1]
]
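# (illustration: one of the generated configurations is ('max', 'float32', (1, 1024), 1))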

#     kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis=AXIS)'})
#     # input
#     rs = RandomState(17)
#     # limit the range of integers so that the sum does not overflow
#     x = numpy_random(shape, dtype_str=dtype_str, rs=rs)
#     x_tri = to_triton(x)
#     numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min,
#                 'argmin': np.argmin, 'argmax': np.argmax}[op]
#     z_dtype_str = 'int32' if op == 'argmin' or op == 'argmax' else dtype_str
#     z_tri_dtype_str = z_dtype_str
#     # numpy result
#     if op not in ['argmin', 'argmax'] and dtype_str == 'bfloat16':
#         z_dtype_str = 'float32'
#         z_tri_dtype_str = 'bfloat16'
#         z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str))
#         # trunc mantissa for a fair comparison of accuracy
#         z_ref = (z_ref.view('uint32') & np.uint32(0xffff0000)).view('float32')
#     else:
#         z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str))
#     # triton result
#     z_tri = to_triton(numpy_random((shape[1 - axis],), dtype_str=z_dtype_str, rs=rs),
#                       device=device, dst_type=z_tri_dtype_str)
#     kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis)
#     z_tri = to_numpy(z_tri)
#     # compare
#     if op == 'sum':
#         np.testing.assert_allclose(z_ref, z_tri, rtol=0.01)
#     else:
#         if op == 'argmin' or op == 'argmax':
#             # argmin and argmax can have multiple valid indices.
#             # so instead we compare the values pointed by indices
#             z_ref_index = np.expand_dims(z_ref, axis=axis)
#             z_tri_index = np.expand_dims(z_tri, axis=axis)
#             z_ref_value = np.take_along_axis(x, z_ref_index, axis=axis)
#             z_tri_value = np.take_along_axis(x, z_tri_index, axis=axis)
#             np.testing.assert_equal(z_ref_value, z_tri_value)
#         else:
#             np.testing.assert_equal(z_ref, z_tri)

# shapes (128, 256) and (32, 1024) are not enabled on sm86 because the required shared memory
# exceeds the limit of 99KB
reduce2d_shapes = [(2, 32), (4, 32), (4, 128), (32, 64), (64, 128)]
if 'V100' in torch.cuda.get_device_name(0):
    reduce2d_shapes += [(128, 256), (32, 1024)]
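# (note: the device-name check above runs at import/collection time, so it assumes
#  a CUDA device is visible when this test module is collected)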


reduce_configs2 = [
    (op, 'float32', shape, axis)
    for op in ['min', 'max', 'sum']
    for shape in reduce2d_shapes
    for axis in [0, 1]
]


@pytest.mark.parametrize("op, dtype_str, shape, axis", reduce_configs1 + reduce_configs2)
def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
    # triton kernel
    @triton.jit
    def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr):
        range_m = tl.arange(0, BLOCK_M)
        range_n = tl.arange(0, BLOCK_N)
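        # the input is read as a row-major BLOCK_M x BLOCK_N tile:
        # element (m, n) lives at offset m * BLOCK_N + n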
        x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :])
        z = GENERATE_TEST_HERE
        if AXIS == 1:
            tl.store(Z + range_m, z)
        else:
            tl.store(Z + range_n, z)

    kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis=AXIS)'})
    # input
    rs = RandomState(17)
    # limit the range of integers so that the sum does not overflow
    x = numpy_random(shape, dtype_str=dtype_str, rs=rs)
    x_tri = to_triton(x)
    numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min,
                'argmin': np.argmin, 'argmax': np.argmax}[op]
    z_dtype_str = get_reduced_dtype(dtype_str, op)
    z_tri_dtype_str = z_dtype_str
    # numpy result
    if op not in ['argmin', 'argmax'] and dtype_str == 'bfloat16':
        z_dtype_str = 'float32'
        z_tri_dtype_str = 'bfloat16'
        z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str))
        # trunc mantissa for a fair comparison of accuracy
        z_ref = (z_ref.view('uint32') & np.uint32(0xffff0000)).view('float32')
    else:
        z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str))
    # triton result
    z_tri = to_triton(numpy_random((shape[1 - axis],), dtype_str=z_dtype_str, rs=rs),
                      device=device, dst_type=z_tri_dtype_str)
    kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis)
    z_tri = to_numpy(z_tri)
    # compare
    if op == 'sum':
        np.testing.assert_allclose(z_ref, z_tri, rtol=0.01)
    else:
        if op == 'argmin' or op == 'argmax':
            # argmin and argmax can have multiple valid indices.
            # so instead we compare the values pointed by indices
            z_ref_index = np.expand_dims(z_ref, axis=axis)
            z_tri_index = np.expand_dims(z_tri, axis=axis)
            z_ref_value = np.take_along_axis(x, z_ref_index, axis=axis)
            z_tri_value = np.take_along_axis(x, z_tri_index, axis=axis)
            np.testing.assert_equal(z_ref_value, z_tri_value)
        else:
            np.testing.assert_equal(z_ref, z_tri)
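
# (illustrative, not part of the test-suite: individual parametrizations can be
#  selected with pytest's -k filter, e.g. `pytest -k "test_reduce1d and bfloat16"`
#  or `pytest -k "test_reduce2d and float32"`)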


# # ---------------
# # test permute