From cdc0ec50778fdf539e17f328e63bc95f48863ced Mon Sep 17 00:00:00 2001
From: Qingyi Liu
Date: Tue, 1 Nov 2022 17:42:59 +0800
Subject: [PATCH] [Triton-MLIR][Backend] Fix reduce conversion and unit tests
 for int dtypes (#826)

---
 .../Dialect/Triton/IR/TritonAttrDefs.td |  18 +-
 .../TritonGPUToLLVM/TritonGPUToLLVM.cpp |  28 +-
 python/src/triton.cc                    |   8 +
 python/tests/test_core.py               | 241 ++++++++++--------
 python/tests/test_reduce.py             |  61 +++--
 5 files changed, 208 insertions(+), 148 deletions(-)

diff --git a/include/triton/Dialect/Triton/IR/TritonAttrDefs.td b/include/triton/Dialect/Triton/IR/TritonAttrDefs.td
index cd4423041..cdcb69576 100644
--- a/include/triton/Dialect/Triton/IR/TritonAttrDefs.td
+++ b/include/triton/Dialect/Triton/IR/TritonAttrDefs.td
@@ -29,12 +29,20 @@ def TT_RedOpAttr : I32EnumAttr<
     /*case*/
     [
       I32EnumAttrCase<"ADD", 1, "add">,
-      I32EnumAttrCase<"MAX", 2, "max">,
+      I32EnumAttrCase<"FADD", 2, "fadd">,
       I32EnumAttrCase<"MIN", 3, "min">,
-      I32EnumAttrCase<"FADD", 4, "fadd">,
-      I32EnumAttrCase<"FMAX", 5, "fmax">,
-      I32EnumAttrCase<"FMIN", 6, "fmin">,
-      I32EnumAttrCase<"XOR", 7, "xor">
+      I32EnumAttrCase<"MAX", 4, "max">,
+      I32EnumAttrCase<"UMIN", 5, "umin">,
+      I32EnumAttrCase<"UMAX", 6, "umax">,
+      I32EnumAttrCase<"ARGMIN", 7, "argmin">,
+      I32EnumAttrCase<"ARGMAX", 8, "argmax">,
+      I32EnumAttrCase<"ARGUMIN", 9, "argumin">,
+      I32EnumAttrCase<"ARGUMAX", 10, "argumax">,
+      I32EnumAttrCase<"FMIN", 11, "fmin">,
+      I32EnumAttrCase<"FMAX", 12, "fmax">,
+      I32EnumAttrCase<"ARGFMIN", 13, "argfmin">,
+      I32EnumAttrCase<"ARGFMAX", 14, "argfmax">,
+      I32EnumAttrCase<"XOR", 15, "xor">
     ]> {
   let cppNamespace = "::mlir::triton";
 }
diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index f1641f05c..85f20a5eb 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -1301,27 +1301,27 @@ void ReduceOpConversion::accumulate(ConversionPatternRewriter &rewriter,
   case RedOp::ADD:
     acc = add(acc, cur);
     break;
-  case RedOp::MAX:
-    if (type.isUnsignedInteger())
-      acc = umax(acc, cur);
-    else
-      acc = smax(acc, cur);
-    break;
-  case RedOp::MIN:
-    if (type.isUnsignedInteger())
-      acc = umin(acc, cur);
-    else
-      acc = smin(acc, cur);
-    break;
   case RedOp::FADD:
     acc = fadd(acc.getType(), acc, cur);
     break;
-  case RedOp::FMAX:
-    acc = fmax(acc, cur);
+  case RedOp::MIN:
+    acc = smin(acc, cur);
+    break;
+  case RedOp::MAX:
+    acc = smax(acc, cur);
+    break;
+  case RedOp::UMIN:
+    acc = umin(acc, cur);
+    break;
+  case RedOp::UMAX:
+    acc = umax(acc, cur);
     break;
   case RedOp::FMIN:
     acc = fmin(acc, cur);
     break;
+  case RedOp::FMAX:
+    acc = fmax(acc, cur);
+    break;
   case RedOp::XOR:
     acc = xor_(acc, cur);
     break;
diff --git a/python/src/triton.cc b/python/src/triton.cc
index 24fc6406a..56ee90b18 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -87,8 +87,16 @@ void init_triton_ir(py::module &&m) {
       .value("FADD", mlir::triton::RedOp::FADD)
       .value("MIN", mlir::triton::RedOp::MIN)
       .value("MAX", mlir::triton::RedOp::MAX)
+      .value("UMIN", mlir::triton::RedOp::UMIN)
+      .value("UMAX", mlir::triton::RedOp::UMAX)
+      .value("ARGMIN", mlir::triton::RedOp::ARGMIN)
+      .value("ARGMAX", mlir::triton::RedOp::ARGMAX)
+      .value("ARGUMIN", mlir::triton::RedOp::ARGUMIN)
+      .value("ARGUMAX", mlir::triton::RedOp::ARGUMAX)
       .value("FMIN", mlir::triton::RedOp::FMIN)
       .value("FMAX", mlir::triton::RedOp::FMAX)
+      .value("ARGFMIN", mlir::triton::RedOp::ARGFMIN)
+      .value("ARGFMAX", mlir::triton::RedOp::ARGFMAX)
       .value("XOR", mlir::triton::RedOp::XOR);

   py::enum_(m, "ATOMIC_OP")
diff --git a/python/tests/test_core.py b/python/tests/test_core.py
index 1a6a15c82..afcc67d9d 100644
--- a/python/tests/test_core.py
+++ b/python/tests/test_core.py
@@ -870,121 +870,142 @@ def test_store_bool():
 # # ---------------

-# @pytest.mark.parametrize("op, dtype_str, shape",
-#                          [(op, dtype, shape)
-#                           for op in ['min', 'max', 'argmin', 'argmax', 'sum']
-#                           for dtype in dtypes_with_bfloat16
-#                           for shape in [32, 64, 128, 512]])
-# def test_reduce1d(op, dtype_str, shape, device='cuda'):
-#     check_type_supported(dtype_str)  # bfloat16 on cc < 80 will not be tested
-#     # triton kernel
-#     @triton.jit
-#     def kernel(X, Z, BLOCK: tl.constexpr):
-#         x = tl.load(X + tl.arange(0, BLOCK))
-#         tl.store(Z, GENERATE_TEST_HERE)
-#     kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis=0)'})
-#     # input
-#     rs = RandomState(17)
-#     # limit the range of integers so that the sum does not overflow
-#     x = numpy_random((shape,), dtype_str=dtype_str, rs=rs)
-#     x_tri = to_triton(x, device=device)
-#     numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min,
-#                 'argmin': np.argmin, 'argmax': np.argmax}[op]
-#     # numpy result
-#     z_dtype_str = 'int32' if op == 'argmin' or op == 'argmax' else dtype_str
-#     z_tri_dtype_str = z_dtype_str
-#     if op not in ['argmin', 'argmax'] and dtype_str == 'bfloat16':
-#         z_dtype_str = 'float32'
-#         z_ref = numpy_op(x).astype(getattr(np, z_dtype_str))
-#         # trunc mantissa for a fair comparison of accuracy
-#         z_ref = (z_ref.view('uint32') & np.uint32(0xffff0000)).view('float32')
-#         z_tri_dtype_str = 'bfloat16'
-#     else:
-#         z_ref = numpy_op(x).astype(getattr(np, z_dtype_str))
-#     # triton result
-#     z_tri = to_triton(numpy_random((1,), dtype_str=z_dtype_str, rs=rs),
-#                       device=device, dst_type=z_tri_dtype_str)
-#     kernel[(1,)](x_tri, z_tri, BLOCK=shape)
-#     z_tri = to_numpy(z_tri)
-#     # compare
-#     if op == 'sum':
-#         np.testing.assert_allclose(z_ref, z_tri, rtol=0.01)
-#     else:
-#         if op == 'argmin' or op == 'argmax':
-#             # argmin and argmax can have multiple valid indices.
-#             # so instead we compare the values pointed by indices
-#             np.testing.assert_equal(x[z_ref], x[z_tri])
-#         else:
-#             np.testing.assert_equal(z_ref, z_tri)

+def get_reduced_dtype(dtype_str, op):
+    if op == 'argmin' or op == 'argmax':
+        return 'int32'
+    if dtype_str in ['int8', 'uint8', 'int16', 'uint16']:
+        return 'int32'
+    if dtype_str == 'bfloat16':
+        return 'float32'
+    return dtype_str

-# reduce_configs1 = [
-#     (op, dtype, (1, 1024), axis) for dtype in dtypes_with_bfloat16
-#     for op in ['min', 'max', 'argmin', 'argmax', 'sum']
-#     for axis in [1]
-# ]
-# reduce_configs2 = [
-#     (op, 'float32', shape, axis)
-#     for op in ['min', 'max', 'argmin', 'argmax', 'sum']
-#     for shape in [(2, 32), (4, 32), (4, 128), (32, 64), (64, 128), (128, 256), (32, 1024)]
-#     for axis in [0, 1]
-# ]

+# TODO: [Qingyi] Fix argmin / argmax
+@pytest.mark.parametrize("op, dtype_str, shape",
+                         [(op, dtype, shape)
+                          for op in ['min', 'max', 'sum']
+                          for dtype in dtypes_with_bfloat16
+                          for shape in [32, 64, 128, 512]])
+def test_reduce1d(op, dtype_str, shape, device='cuda'):
+    check_type_supported(dtype_str)  # bfloat16 on cc < 80 will not be tested
+
+    # triton kernel
+    @triton.jit
+    def kernel(X, Z, BLOCK: tl.constexpr):
+        x = tl.load(X + tl.arange(0, BLOCK))
+        tl.store(Z, GENERATE_TEST_HERE)
+
+    kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis=0)'})
+    # input
+    rs = RandomState(17)
+    # limit the range of integers so that the sum does not overflow
+    x = numpy_random((shape,), dtype_str=dtype_str, rs=rs)
+    x_tri = to_triton(x, device=device)
+    numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min,
+                'argmin': np.argmin, 'argmax': np.argmax}[op]
+    # numpy result
+    z_dtype_str = get_reduced_dtype(dtype_str, op)
+    z_tri_dtype_str = z_dtype_str
+    if op not in ['argmin', 'argmax'] and dtype_str == 'bfloat16':
+        z_dtype_str = 'float32'
+        z_ref = numpy_op(x).astype(getattr(np, z_dtype_str))
+        # trunc mantissa for a fair comparison of accuracy
+        z_ref = (z_ref.view('uint32') & np.uint32(0xffff0000)).view('float32')
+        z_tri_dtype_str = 'bfloat16'
+    else:
+        z_ref = numpy_op(x).astype(getattr(np, z_dtype_str))
+    # triton result
+    z_tri = to_triton(numpy_random((1,), dtype_str=z_dtype_str, rs=rs),
+                      device=device, dst_type=z_tri_dtype_str)
+    kernel[(1,)](x_tri, z_tri, BLOCK=shape)
+    z_tri = to_numpy(z_tri)
+    # compare
+    if op == 'sum':
+        np.testing.assert_allclose(z_ref, z_tri, rtol=0.01)
+    else:
+        if op == 'argmin' or op == 'argmax':
+            # argmin and argmax can have multiple valid indices.
+            # so instead we compare the values pointed by indices
+            np.testing.assert_equal(x[z_ref], x[z_tri])
+        else:
+            np.testing.assert_equal(z_ref, z_tri)

-# @pytest.mark.parametrize("op, dtype_str, shape, axis", reduce_configs1 + reduce_configs2)
-# def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
-#     # triton kernel
-#     @triton.jit
-#     def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr):
-#         range_m = tl.arange(0, BLOCK_M)
-#         range_n = tl.arange(0, BLOCK_N)
-#         x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :])
-#         z = GENERATE_TEST_HERE
-#         if AXIS == 1:
-#             tl.store(Z + range_m, z)
-#         else:
-#             tl.store(Z + range_n, z)

+# TODO: [Qingyi] Fix argmin / argmax
+reduce_configs1 = [
+    (op, dtype, (1, 1024), axis) for dtype in dtypes_with_bfloat16
+    for op in ['min', 'max', 'sum']
+    for axis in [1]
+]

-#     kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis=AXIS)'})
-#     # input
-#     rs = RandomState(17)
-#     # limit the range of integers so that the sum does not overflow
-#     x = numpy_random(shape, dtype_str=dtype_str, rs=rs)
-#     x_tri = to_triton(x)
-#     numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min,
-#                 'argmin': np.argmin, 'argmax': np.argmax}[op]
-#     z_dtype_str = 'int32' if op == 'argmin' or op == 'argmax' else dtype_str
-#     z_tri_dtype_str = z_dtype_str
-#     # numpy result
-#     if op not in ['argmin', 'argmax'] and dtype_str == 'bfloat16':
-#         z_dtype_str = 'float32'
-#         z_tri_dtype_str = 'bfloat16'
-#         z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str))
-#         # trunc mantissa for a fair comparison of accuracy
-#         z_ref = (z_ref.view('uint32') & np.uint32(0xffff0000)).view('float32')
-#     else:
-#         z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str))
-#     # triton result
-#     z_tri = to_triton(numpy_random((shape[1 - axis],), dtype_str=z_dtype_str, rs=rs),
-#                       device=device, dst_type=z_tri_dtype_str)
-#     kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis)
-#     z_tri = to_numpy(z_tri)
-#     # compare
-#     if op == 'sum':
-#         np.testing.assert_allclose(z_ref, z_tri, rtol=0.01)
-#     else:
-#         if op == 'argmin' or op == 'argmax':
-#             # argmin and argmax can have multiple valid indices.
-#             # so instead we compare the values pointed by indices
-#             z_ref_index = np.expand_dims(z_ref, axis=axis)
-#             z_tri_index = np.expand_dims(z_tri, axis=axis)
-#             z_ref_value = np.take_along_axis(x, z_ref_index, axis=axis)
-#             z_tri_value = np.take_along_axis(x, z_tri_index, axis=axis)
-#             np.testing.assert_equal(z_ref_value, z_tri_value)
-#         else:
-#             np.testing.assert_equal(z_ref, z_tri)
+
+# shape (128, 256) and (32, 1024) are not enabled on sm86 because the required shared memory
+# exceeds the limit of 99KB
+reduce2d_shapes = [(2, 32), (4, 32), (4, 128), (32, 64), (64, 128)]
+if 'V100' in torch.cuda.get_device_name(0):
+    reduce2d_shapes += [(128, 256), (32, 1024)]
+
+
+reduce_configs2 = [
+    (op, 'float32', shape, axis)
+    for op in ['min', 'max', 'sum']
+    for shape in reduce2d_shapes
+    for axis in [0, 1]
+]
+
+
+@pytest.mark.parametrize("op, dtype_str, shape, axis", reduce_configs1 + reduce_configs2)
+def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
+    # triton kernel
+    @triton.jit
+    def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr):
+        range_m = tl.arange(0, BLOCK_M)
+        range_n = tl.arange(0, BLOCK_N)
+        x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :])
+        z = GENERATE_TEST_HERE
+        if AXIS == 1:
+            tl.store(Z + range_m, z)
+        else:
+            tl.store(Z + range_n, z)
+
+    kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis=AXIS)'})
+    # input
+    rs = RandomState(17)
+    # limit the range of integers so that the sum does not overflow
+    x = numpy_random(shape, dtype_str=dtype_str, rs=rs)
+    x_tri = to_triton(x)
+    numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min,
+                'argmin': np.argmin, 'argmax': np.argmax}[op]
+    z_dtype_str = get_reduced_dtype(dtype_str, op)
+    z_tri_dtype_str = z_dtype_str
+    # numpy result
+    if op not in ['argmin', 'argmax'] and dtype_str == 'bfloat16':
+        z_dtype_str = 'float32'
+        z_tri_dtype_str = 'bfloat16'
+        z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str))
+        # trunc mantissa for a fair comparison of accuracy
+        z_ref = (z_ref.view('uint32') & np.uint32(0xffff0000)).view('float32')
+    else:
+        z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str))
+    # triton result
+    z_tri = to_triton(numpy_random((shape[1 - axis],), dtype_str=z_dtype_str, rs=rs),
+                      device=device, dst_type=z_tri_dtype_str)
+    kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis)
+    z_tri = to_numpy(z_tri)
+    # compare
+    if op == 'sum':
+        np.testing.assert_allclose(z_ref, z_tri, rtol=0.01)
+    else:
+        if op == 'argmin' or op == 'argmax':
+            # argmin and argmax can have multiple valid indices.
+            # so instead we compare the values pointed by indices
+            z_ref_index = np.expand_dims(z_ref, axis=axis)
+            z_tri_index = np.expand_dims(z_tri, axis=axis)
+            z_ref_value = np.take_along_axis(x, z_ref_index, axis=axis)
+            z_tri_value = np.take_along_axis(x, z_tri_index, axis=axis)
+            np.testing.assert_equal(z_ref_value, z_tri_value)
+        else:
+            np.testing.assert_equal(z_ref, z_tri)

 # # ---------------
 # # test permute
diff --git a/python/tests/test_reduce.py b/python/tests/test_reduce.py
index 0c92b8a91..01c16ac0f 100644
--- a/python/tests/test_reduce.py
+++ b/python/tests/test_reduce.py
@@ -5,11 +5,20 @@ from torch.testing import assert_close
 import triton
 import triton.language as tl

-dtype_mapping = {
-    'float16': torch.float16,
-    'float32': torch.float32,
-    'float64': torch.float64,
-}
+int_dtypes = ['int8', 'int16', 'int32', 'int64']
+uint_dtypes = ['uint8']  # PyTorch does not support uint16/uint32/uint64
+float_dtypes = ['float16', 'float32', 'float64']
+dtypes = int_dtypes + uint_dtypes + float_dtypes
+dtypes_with_bfloat16 = int_dtypes + uint_dtypes + float_dtypes
+dtype_mapping = {dtype_str: torch.__dict__[dtype_str] for dtype_str in dtypes}
+
+
+def get_reduced_dtype(dtype):
+    if dtype in [torch.int8, torch.int16, torch.uint8]:
+        return torch.int32
+    if dtype in [torch.bfloat16]:
+        return torch.float32
+    return dtype


 def patch_kernel(template, to_replace):
@@ -40,7 +49,7 @@ def reduce2d_kernel(x_ptr, z_ptr, axis: tl.constexpr, block_m: tl.constexpr, blo
 reduce1d_configs = [
     (op, dtype, shape)
     for op in ['sum', 'min', 'max']
-    for dtype in ['float16', 'float32', 'float64']
+    for dtype in dtypes
     for shape in [4, 8, 16, 32, 64, 128, 512, 1024]
 ]

@@ -48,11 +57,18 @@ reduce1d_configs = [
 @pytest.mark.parametrize('op, dtype, shape', reduce1d_configs)
 def test_reduce1d(op, dtype, shape):
     dtype = dtype_mapping[dtype]
-    x = torch.randn((shape,), device='cuda', dtype=dtype)
+    reduced_dtype = get_reduced_dtype(dtype)
+
+    if dtype.is_floating_point:
+        x = torch.randn((shape,), device='cuda', dtype=dtype)
+    elif dtype is torch.uint8:
+        x = torch.randint(0, 20, (shape,), device='cuda', dtype=dtype)
+    else:
+        x = torch.randint(-20, 20, (shape,), device='cuda', dtype=dtype)
     z = torch.empty(
         tuple(),
         device=x.device,
-        dtype=dtype,
+        dtype=reduced_dtype,
     )

     kernel = patch_kernel(reduce1d_kernel, {'OP': op})
@@ -60,13 +76,13 @@ def test_reduce1d(op, dtype, shape):
     kernel[grid](x_ptr=x, z_ptr=z, block=shape)

     if op == 'sum':
-        golden_z = torch.sum(x, dtype=dtype)
+        golden_z = torch.sum(x, dtype=reduced_dtype)
     elif op == 'min':
-        golden_z = torch.min(x)
+        golden_z = torch.min(x).to(reduced_dtype)
     else:
-        golden_z = torch.max(x)
+        golden_z = torch.max(x).to(reduced_dtype)

-    if op == 'sum':
+    if dtype.is_floating_point and op == 'sum':
         if shape >= 256:
             assert_close(z, golden_z, rtol=0.05, atol=0.1)
         elif shape >= 32:
@@ -80,7 +96,7 @@ def test_reduce1d(op, dtype, shape):
 reduce2d_configs = [
     (op, dtype, shape, axis)
     for op in ['sum', 'min', 'max']
-    for dtype in ['float16', 'float32', 'float64']
+    for dtype in dtypes
     for shape in [(1, 4), (1, 8), (1, 16), (1, 32), (2, 32), (4, 32), (4, 128), (32, 64)]
     for axis in [0, 1]
 ]
@@ -89,22 +105,29 @@ reduce2d_configs = [
 @pytest.mark.parametrize('op, dtype, shape, axis', reduce2d_configs)
 def test_reduce2d(op, dtype, shape, axis):
     dtype = dtype_mapping[dtype]
-    x = torch.randn(shape, device='cuda', dtype=dtype)
+    reduced_dtype = get_reduced_dtype(dtype)
     reduced_shape = (shape[1 - axis],)
-    z = torch.empty(reduced_shape, device=x.device, dtype=dtype)
+
+    if dtype.is_floating_point:
+        x = torch.randn(shape, device='cuda', dtype=dtype)
+    elif dtype is torch.uint8:
+        x = torch.randint(0, 20, shape, device='cuda', dtype=dtype)
+    else:
+        x = torch.randint(-20, 20, shape, device='cuda', dtype=dtype)
+    z = torch.empty(reduced_shape, device=x.device, dtype=reduced_dtype)

     kernel = patch_kernel(reduce2d_kernel, {'OP': op})
     grid = (1,)
     kernel[grid](x_ptr=x, z_ptr=z, axis=axis, block_m=shape[0], block_n=shape[1])

     if op == 'sum':
-        golden_z = torch.sum(x, dim=axis, keepdim=False, dtype=dtype)
+        golden_z = torch.sum(x, dim=axis, keepdim=False, dtype=reduced_dtype)
     elif op == 'min':
-        golden_z = torch.min(x, dim=axis, keepdim=False)[0]
+        golden_z = torch.min(x, dim=axis, keepdim=False)[0].to(reduced_dtype)
     else:
-        golden_z = torch.max(x, dim=axis, keepdim=False)[0]
+        golden_z = torch.max(x, dim=axis, keepdim=False)[0].to(reduced_dtype)

-    if op == 'sum':
+    if dtype.is_floating_point and op == 'sum':
         if shape[axis] >= 256:
             assert_close(z, golden_z, rtol=0.05, atol=0.1)
         elif shape[axis] >= 32: