[Triton-MLIR][Backend] Fix reduce conversion and unit tests for int dtypes (#826)

This commit is contained in:
Qingyi Liu
2022-11-01 17:42:59 +08:00
committed by GitHub
parent 031c2ae77b
commit cdc0ec5077
5 changed files with 208 additions and 148 deletions

View File

@@ -29,12 +29,20 @@ def TT_RedOpAttr : I32EnumAttr<
/*case*/ /*case*/
[ [
I32EnumAttrCase</*sym*/"ADD", 1, /*str*/"add">, I32EnumAttrCase</*sym*/"ADD", 1, /*str*/"add">,
I32EnumAttrCase<"MAX", 2, "max">, I32EnumAttrCase<"FADD", 2, "fadd">,
I32EnumAttrCase<"MIN", 3, "min">, I32EnumAttrCase<"MIN", 3, "min">,
I32EnumAttrCase<"FADD", 4, "fadd">, I32EnumAttrCase<"MAX", 4, "max">,
I32EnumAttrCase<"FMAX", 5, "fmax">, I32EnumAttrCase<"UMIN", 5, "umin">,
I32EnumAttrCase<"FMIN", 6, "fmin">, I32EnumAttrCase<"UMAX", 6, "umax">,
I32EnumAttrCase<"XOR", 7, "xor"> I32EnumAttrCase<"ARGMIN", 7, "argmin">,
I32EnumAttrCase<"ARGMAX", 8, "argmax">,
I32EnumAttrCase<"ARGUMIN", 9, "argumin">,
I32EnumAttrCase<"ARGUMAX", 10, "argumax">,
I32EnumAttrCase<"FMIN", 11, "fmin">,
I32EnumAttrCase<"FMAX", 12, "fmax">,
I32EnumAttrCase<"ARGFMIN", 13, "argfmin">,
I32EnumAttrCase<"ARGFMAX", 14, "argfmax">,
I32EnumAttrCase<"XOR", 15, "xor">
]> { ]> {
let cppNamespace = "::mlir::triton"; let cppNamespace = "::mlir::triton";
} }

View File

@@ -1301,27 +1301,27 @@ void ReduceOpConversion::accumulate(ConversionPatternRewriter &rewriter,
case RedOp::ADD: case RedOp::ADD:
acc = add(acc, cur); acc = add(acc, cur);
break; break;
case RedOp::MAX:
if (type.isUnsignedInteger())
acc = umax(acc, cur);
else
acc = smax(acc, cur);
break;
case RedOp::MIN:
if (type.isUnsignedInteger())
acc = umin(acc, cur);
else
acc = smin(acc, cur);
break;
case RedOp::FADD: case RedOp::FADD:
acc = fadd(acc.getType(), acc, cur); acc = fadd(acc.getType(), acc, cur);
break; break;
case RedOp::FMAX: case RedOp::MIN:
acc = fmax(acc, cur); acc = smin(acc, cur);
break;
case RedOp::MAX:
acc = smax(acc, cur);
break;
case RedOp::UMIN:
acc = umin(acc, cur);
break;
case RedOp::UMAX:
acc = umax(acc, cur);
break; break;
case RedOp::FMIN: case RedOp::FMIN:
acc = fmin(acc, cur); acc = fmin(acc, cur);
break; break;
case RedOp::FMAX:
acc = fmax(acc, cur);
break;
case RedOp::XOR: case RedOp::XOR:
acc = xor_(acc, cur); acc = xor_(acc, cur);
break; break;

View File

@@ -87,8 +87,16 @@ void init_triton_ir(py::module &&m) {
.value("FADD", mlir::triton::RedOp::FADD) .value("FADD", mlir::triton::RedOp::FADD)
.value("MIN", mlir::triton::RedOp::MIN) .value("MIN", mlir::triton::RedOp::MIN)
.value("MAX", mlir::triton::RedOp::MAX) .value("MAX", mlir::triton::RedOp::MAX)
.value("UMIN", mlir::triton::RedOp::UMIN)
.value("UMAX", mlir::triton::RedOp::UMAX)
.value("ARGMIN", mlir::triton::RedOp::ARGMIN)
.value("ARGMAX", mlir::triton::RedOp::ARGMAX)
.value("ARGUMIN", mlir::triton::RedOp::ARGUMIN)
.value("ARGUMAX", mlir::triton::RedOp::ARGUMAX)
.value("FMIN", mlir::triton::RedOp::FMIN) .value("FMIN", mlir::triton::RedOp::FMIN)
.value("FMAX", mlir::triton::RedOp::FMAX) .value("FMAX", mlir::triton::RedOp::FMAX)
.value("ARGFMIN", mlir::triton::RedOp::ARGFMIN)
.value("ARGFMAX", mlir::triton::RedOp::ARGFMAX)
.value("XOR", mlir::triton::RedOp::XOR); .value("XOR", mlir::triton::RedOp::XOR);
py::enum_<mlir::triton::RMWOp>(m, "ATOMIC_OP") py::enum_<mlir::triton::RMWOp>(m, "ATOMIC_OP")

View File

@@ -870,121 +870,142 @@ def test_store_bool():
# # --------------- # # ---------------
# @pytest.mark.parametrize("op, dtype_str, shape", def get_reduced_dtype(dtype_str, op):
# [(op, dtype, shape) if op == 'argmin' or op == 'argmax':
# for op in ['min', 'max', 'argmin', 'argmax', 'sum'] return 'int32'
# for dtype in dtypes_with_bfloat16 if dtype_str in ['int8', 'uint8', 'int16', 'uint16']:
# for shape in [32, 64, 128, 512]]) return 'int32'
# def test_reduce1d(op, dtype_str, shape, device='cuda'): if dtype_str == 'bfloat16':
# check_type_supported(dtype_str) # bfloat16 on cc < 80 will not be tested return 'float32'
return dtype_str
# # triton kernel
# @triton.jit
# def kernel(X, Z, BLOCK: tl.constexpr):
# x = tl.load(X + tl.arange(0, BLOCK))
# tl.store(Z, GENERATE_TEST_HERE)
# kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis=0)'})
# # input
# rs = RandomState(17)
# # limit the range of integers so that the sum does not overflow
# x = numpy_random((shape,), dtype_str=dtype_str, rs=rs)
# x_tri = to_triton(x, device=device)
# numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min,
# 'argmin': np.argmin, 'argmax': np.argmax}[op]
# # numpy result
# z_dtype_str = 'int32' if op == 'argmin' or op == 'argmax' else dtype_str
# z_tri_dtype_str = z_dtype_str
# if op not in ['argmin', 'argmax'] and dtype_str == 'bfloat16':
# z_dtype_str = 'float32'
# z_ref = numpy_op(x).astype(getattr(np, z_dtype_str))
# # trunc mantissa for a fair comparison of accuracy
# z_ref = (z_ref.view('uint32') & np.uint32(0xffff0000)).view('float32')
# z_tri_dtype_str = 'bfloat16'
# else:
# z_ref = numpy_op(x).astype(getattr(np, z_dtype_str))
# # triton result
# z_tri = to_triton(numpy_random((1,), dtype_str=z_dtype_str, rs=rs),
# device=device, dst_type=z_tri_dtype_str)
# kernel[(1,)](x_tri, z_tri, BLOCK=shape)
# z_tri = to_numpy(z_tri)
# # compare
# if op == 'sum':
# np.testing.assert_allclose(z_ref, z_tri, rtol=0.01)
# else:
# if op == 'argmin' or op == 'argmax':
# # argmin and argmax can have multiple valid indices.
# # so instead we compare the values pointed by indices
# np.testing.assert_equal(x[z_ref], x[z_tri])
# else:
# np.testing.assert_equal(z_ref, z_tri)
# reduce_configs1 = [ # TODO: [Qingyi] Fix argmin / argmax
# (op, dtype, (1, 1024), axis) for dtype in dtypes_with_bfloat16 @pytest.mark.parametrize("op, dtype_str, shape",
# for op in ['min', 'max', 'argmin', 'argmax', 'sum'] [(op, dtype, shape)
# for axis in [1] for op in ['min', 'max', 'sum']
# ] for dtype in dtypes_with_bfloat16
# reduce_configs2 = [ for shape in [32, 64, 128, 512]])
# (op, 'float32', shape, axis) def test_reduce1d(op, dtype_str, shape, device='cuda'):
# for op in ['min', 'max', 'argmin', 'argmax', 'sum'] check_type_supported(dtype_str) # bfloat16 on cc < 80 will not be tested
# for shape in [(2, 32), (4, 32), (4, 128), (32, 64), (64, 128), (128, 256), (32, 1024)]
# for axis in [0, 1] # triton kernel
# ] @triton.jit
def kernel(X, Z, BLOCK: tl.constexpr):
x = tl.load(X + tl.arange(0, BLOCK))
tl.store(Z, GENERATE_TEST_HERE)
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis=0)'})
# input
rs = RandomState(17)
# limit the range of integers so that the sum does not overflow
x = numpy_random((shape,), dtype_str=dtype_str, rs=rs)
x_tri = to_triton(x, device=device)
numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min,
'argmin': np.argmin, 'argmax': np.argmax}[op]
# numpy result
z_dtype_str = get_reduced_dtype(dtype_str, op)
z_tri_dtype_str = z_dtype_str
if op not in ['argmin', 'argmax'] and dtype_str == 'bfloat16':
z_dtype_str = 'float32'
z_ref = numpy_op(x).astype(getattr(np, z_dtype_str))
# trunc mantissa for a fair comparison of accuracy
z_ref = (z_ref.view('uint32') & np.uint32(0xffff0000)).view('float32')
z_tri_dtype_str = 'bfloat16'
else:
z_ref = numpy_op(x).astype(getattr(np, z_dtype_str))
# triton result
z_tri = to_triton(numpy_random((1,), dtype_str=z_dtype_str, rs=rs),
device=device, dst_type=z_tri_dtype_str)
kernel[(1,)](x_tri, z_tri, BLOCK=shape)
z_tri = to_numpy(z_tri)
# compare
if op == 'sum':
np.testing.assert_allclose(z_ref, z_tri, rtol=0.01)
else:
if op == 'argmin' or op == 'argmax':
# argmin and argmax can have multiple valid indices.
# so instead we compare the values pointed by indices
np.testing.assert_equal(x[z_ref], x[z_tri])
else:
np.testing.assert_equal(z_ref, z_tri)
# @pytest.mark.parametrize("op, dtype_str, shape, axis", reduce_configs1 + reduce_configs2) # TODO: [Qingyi] Fix argmin / argmax
# def test_reduce2d(op, dtype_str, shape, axis, device='cuda'): reduce_configs1 = [
# # triton kernel (op, dtype, (1, 1024), axis) for dtype in dtypes_with_bfloat16
# @triton.jit for op in ['min', 'max', 'sum']
# def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr): for axis in [1]
# range_m = tl.arange(0, BLOCK_M) ]
# range_n = tl.arange(0, BLOCK_N)
# x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :])
# z = GENERATE_TEST_HERE
# if AXIS == 1:
# tl.store(Z + range_m, z)
# else:
# tl.store(Z + range_n, z)
# kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis=AXIS)'})
# # input # shape (128, 256) and (32, 1024) are not enabled on sm86 because the required shared memory
# rs = RandomState(17) # exceeds the limit of 99KB
# # limit the range of integers so that the sum does not overflow reduce2d_shapes = [(2, 32), (4, 32), (4, 128), (32, 64), (64, 128)]
# x = numpy_random(shape, dtype_str=dtype_str, rs=rs) if 'V100' in torch.cuda.get_device_name(0):
# x_tri = to_triton(x) reduce2d_shapes += [(128, 256) and (32, 1024)]
# numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min,
# 'argmin': np.argmin, 'argmax': np.argmax}[op]
# z_dtype_str = 'int32' if op == 'argmin' or op == 'argmax' else dtype_str reduce_configs2 = [
# z_tri_dtype_str = z_dtype_str (op, 'float32', shape, axis)
# # numpy result for op in ['min', 'max', 'sum']
# if op not in ['argmin', 'argmax'] and dtype_str == 'bfloat16': for shape in reduce2d_shapes
# z_dtype_str = 'float32' for axis in [0, 1]
# z_tri_dtype_str = 'bfloat16' ]
# z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str))
# # trunc mantissa for a fair comparison of accuracy
# z_ref = (z_ref.view('uint32') & np.uint32(0xffff0000)).view('float32') @pytest.mark.parametrize("op, dtype_str, shape, axis", reduce_configs1 + reduce_configs2)
# else: def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
# z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str)) # triton kernel
# # triton result @triton.jit
# z_tri = to_triton(numpy_random((shape[1 - axis],), dtype_str=z_dtype_str, rs=rs), def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr):
# device=device, dst_type=z_tri_dtype_str) range_m = tl.arange(0, BLOCK_M)
# kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis) range_n = tl.arange(0, BLOCK_N)
# z_tri = to_numpy(z_tri) x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :])
# # compare z = GENERATE_TEST_HERE
# if op == 'sum': if AXIS == 1:
# np.testing.assert_allclose(z_ref, z_tri, rtol=0.01) tl.store(Z + range_m, z)
# else: else:
# if op == 'argmin' or op == 'argmax': tl.store(Z + range_n, z)
# # argmin and argmax can have multiple valid indices.
# # so instead we compare the values pointed by indices kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis=AXIS)'})
# z_ref_index = np.expand_dims(z_ref, axis=axis) # input
# z_tri_index = np.expand_dims(z_tri, axis=axis) rs = RandomState(17)
# z_ref_value = np.take_along_axis(x, z_ref_index, axis=axis) # limit the range of integers so that the sum does not overflow
# z_tri_value = np.take_along_axis(x, z_tri_index, axis=axis) x = numpy_random(shape, dtype_str=dtype_str, rs=rs)
# np.testing.assert_equal(z_ref_value, z_tri_value) x_tri = to_triton(x)
# else: numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min,
# np.testing.assert_equal(z_ref, z_tri) 'argmin': np.argmin, 'argmax': np.argmax}[op]
z_dtype_str = get_reduced_dtype(dtype_str, op)
z_tri_dtype_str = z_dtype_str
# numpy result
if op not in ['argmin', 'argmax'] and dtype_str == 'bfloat16':
z_dtype_str = 'float32'
z_tri_dtype_str = 'bfloat16'
z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str))
# trunc mantissa for a fair comparison of accuracy
z_ref = (z_ref.view('uint32') & np.uint32(0xffff0000)).view('float32')
else:
z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str))
# triton result
z_tri = to_triton(numpy_random((shape[1 - axis],), dtype_str=z_dtype_str, rs=rs),
device=device, dst_type=z_tri_dtype_str)
kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis)
z_tri = to_numpy(z_tri)
# compare
if op == 'sum':
np.testing.assert_allclose(z_ref, z_tri, rtol=0.01)
else:
if op == 'argmin' or op == 'argmax':
# argmin and argmax can have multiple valid indices.
# so instead we compare the values pointed by indices
z_ref_index = np.expand_dims(z_ref, axis=axis)
z_tri_index = np.expand_dims(z_tri, axis=axis)
z_ref_value = np.take_along_axis(x, z_ref_index, axis=axis)
z_tri_value = np.take_along_axis(x, z_tri_index, axis=axis)
np.testing.assert_equal(z_ref_value, z_tri_value)
else:
np.testing.assert_equal(z_ref, z_tri)
# # --------------- # # ---------------
# # test permute # # test permute

View File

@@ -5,11 +5,20 @@ from torch.testing import assert_close
import triton import triton
import triton.language as tl import triton.language as tl
dtype_mapping = { int_dtypes = ['int8', 'int16', 'int32', 'int64']
'float16': torch.float16, uint_dtypes = ['uint8'] # PyTorch does not support uint16/uint32/uint64
'float32': torch.float32, float_dtypes = ['float16', 'float32', 'float64']
'float64': torch.float64, dtypes = int_dtypes + uint_dtypes + float_dtypes
} dtypes_with_bfloat16 = int_dtypes + uint_dtypes + float_dtypes
dtype_mapping = {dtype_str: torch.__dict__[dtype_str] for dtype_str in dtypes}
def get_reduced_dtype(dtype):
if dtype in [torch.int8, torch.int16, torch.uint8]:
return torch.int32
if dtype in [torch.bfloat16]:
return torch.float32
return dtype
def patch_kernel(template, to_replace): def patch_kernel(template, to_replace):
@@ -40,7 +49,7 @@ def reduce2d_kernel(x_ptr, z_ptr, axis: tl.constexpr, block_m: tl.constexpr, blo
reduce1d_configs = [ reduce1d_configs = [
(op, dtype, shape) (op, dtype, shape)
for op in ['sum', 'min', 'max'] for op in ['sum', 'min', 'max']
for dtype in ['float16', 'float32', 'float64'] for dtype in dtypes
for shape in [4, 8, 16, 32, 64, 128, 512, 1024] for shape in [4, 8, 16, 32, 64, 128, 512, 1024]
] ]
@@ -48,11 +57,18 @@ reduce1d_configs = [
@pytest.mark.parametrize('op, dtype, shape', reduce1d_configs) @pytest.mark.parametrize('op, dtype, shape', reduce1d_configs)
def test_reduce1d(op, dtype, shape): def test_reduce1d(op, dtype, shape):
dtype = dtype_mapping[dtype] dtype = dtype_mapping[dtype]
x = torch.randn((shape,), device='cuda', dtype=dtype) reduced_dtype = get_reduced_dtype(dtype)
if dtype.is_floating_point:
x = torch.randn((shape,), device='cuda', dtype=dtype)
elif dtype is torch.uint8:
x = torch.randint(0, 20, (shape,), device='cuda', dtype=dtype)
else:
x = torch.randint(-20, 20, (shape,), device='cuda', dtype=dtype)
z = torch.empty( z = torch.empty(
tuple(), tuple(),
device=x.device, device=x.device,
dtype=dtype, dtype=reduced_dtype,
) )
kernel = patch_kernel(reduce1d_kernel, {'OP': op}) kernel = patch_kernel(reduce1d_kernel, {'OP': op})
@@ -60,13 +76,13 @@ def test_reduce1d(op, dtype, shape):
kernel[grid](x_ptr=x, z_ptr=z, block=shape) kernel[grid](x_ptr=x, z_ptr=z, block=shape)
if op == 'sum': if op == 'sum':
golden_z = torch.sum(x, dtype=dtype) golden_z = torch.sum(x, dtype=reduced_dtype)
elif op == 'min': elif op == 'min':
golden_z = torch.min(x) golden_z = torch.min(x).to(reduced_dtype)
else: else:
golden_z = torch.max(x) golden_z = torch.max(x).to(reduced_dtype)
if op == 'sum': if dtype.is_floating_point and op == 'sum':
if shape >= 256: if shape >= 256:
assert_close(z, golden_z, rtol=0.05, atol=0.1) assert_close(z, golden_z, rtol=0.05, atol=0.1)
elif shape >= 32: elif shape >= 32:
@@ -80,7 +96,7 @@ def test_reduce1d(op, dtype, shape):
reduce2d_configs = [ reduce2d_configs = [
(op, dtype, shape, axis) (op, dtype, shape, axis)
for op in ['sum', 'min', 'max'] for op in ['sum', 'min', 'max']
for dtype in ['float16', 'float32', 'float64'] for dtype in dtypes
for shape in [(1, 4), (1, 8), (1, 16), (1, 32), (2, 32), (4, 32), (4, 128), (32, 64)] for shape in [(1, 4), (1, 8), (1, 16), (1, 32), (2, 32), (4, 32), (4, 128), (32, 64)]
for axis in [0, 1] for axis in [0, 1]
] ]
@@ -89,22 +105,29 @@ reduce2d_configs = [
@pytest.mark.parametrize('op, dtype, shape, axis', reduce2d_configs) @pytest.mark.parametrize('op, dtype, shape, axis', reduce2d_configs)
def test_reduce2d(op, dtype, shape, axis): def test_reduce2d(op, dtype, shape, axis):
dtype = dtype_mapping[dtype] dtype = dtype_mapping[dtype]
x = torch.randn(shape, device='cuda', dtype=dtype) reduced_dtype = get_reduced_dtype(dtype)
reduced_shape = (shape[1 - axis],) reduced_shape = (shape[1 - axis],)
z = torch.empty(reduced_shape, device=x.device, dtype=dtype)
if dtype.is_floating_point:
x = torch.randn(shape, device='cuda', dtype=dtype)
elif dtype is torch.uint8:
x = torch.randint(0, 20, shape, device='cuda', dtype=dtype)
else:
x = torch.randint(-20, 20, shape, device='cuda', dtype=dtype)
z = torch.empty(reduced_shape, device=x.device, dtype=reduced_dtype)
kernel = patch_kernel(reduce2d_kernel, {'OP': op}) kernel = patch_kernel(reduce2d_kernel, {'OP': op})
grid = (1,) grid = (1,)
kernel[grid](x_ptr=x, z_ptr=z, axis=axis, block_m=shape[0], block_n=shape[1]) kernel[grid](x_ptr=x, z_ptr=z, axis=axis, block_m=shape[0], block_n=shape[1])
if op == 'sum': if op == 'sum':
golden_z = torch.sum(x, dim=axis, keepdim=False, dtype=dtype) golden_z = torch.sum(x, dim=axis, keepdim=False, dtype=reduced_dtype)
elif op == 'min': elif op == 'min':
golden_z = torch.min(x, dim=axis, keepdim=False)[0] golden_z = torch.min(x, dim=axis, keepdim=False)[0].to(reduced_dtype)
else: else:
golden_z = torch.max(x, dim=axis, keepdim=False)[0] golden_z = torch.max(x, dim=axis, keepdim=False)[0].to(reduced_dtype)
if op == 'sum': if dtype.is_floating_point and op == 'sum':
if shape[axis] >= 256: if shape[axis] >= 256:
assert_close(z, golden_z, rtol=0.05, atol=0.1) assert_close(z, golden_z, rtol=0.05, atol=0.1)
elif shape[axis] >= 32: elif shape[axis] >= 32: