diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h
index 9f4e18da8..ee7897e03 100644
--- a/include/triton/ir/instructions.h
+++ b/include/triton/ir/instructions.h
@@ -913,7 +913,7 @@ public:
 class reduce_inst: public builtin_inst {
 public:
   enum op_t{
-    ADD, SUB, MAX, MIN,
+    ADD, SUB, MAX, MIN, UMAX, UMIN,
     FADD, FSUB, FMAX, FMIN,
     XOR
   };
diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc
index f88ecf833..53cfb70fc 100644
--- a/lib/codegen/selection/generator.cc
+++ b/lib/codegen/selection/generator.cc
@@ -119,6 +119,8 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
 #define icmp_eq(...) builder_->CreateICmpEQ(__VA_ARGS__)
 #define icmp_sge(...) builder_->CreateICmpSGE(__VA_ARGS__)
 #define icmp_sle(...) builder_->CreateICmpSLE(__VA_ARGS__)
+#define icmp_uge(...) builder_->CreateICmpUGE(__VA_ARGS__)
+#define icmp_ule(...) builder_->CreateICmpULE(__VA_ARGS__)
 #define icmp_ult(...) builder_->CreateICmpULT(__VA_ARGS__)
 #define insert_elt(...) builder_->CreateInsertElement(__VA_ARGS__)
 #define intrinsic(...) builder_->CreateIntrinsic(__VA_ARGS__)
@@ -2498,6 +2500,8 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
     case ir::reduce_inst::SUB: return sub(x, y);
     case ir::reduce_inst::MAX: return select(icmp_sge(x, y), x, y);
     case ir::reduce_inst::MIN: return select(icmp_sle(x, y), x, y);
+    case ir::reduce_inst::UMAX: return select(icmp_uge(x, y), x, y);
+    case ir::reduce_inst::UMIN: return select(icmp_ule(x, y), x, y);
     case ir::reduce_inst::FADD: return fadd(x, y);
     case ir::reduce_inst::FSUB: return fsub(x, y);
     case ir::reduce_inst::FMAX: return max_num(x, y);
@@ -2510,9 +2514,11 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
   Value *neutral;
   switch(op) {
     case ir::reduce_inst::ADD: neutral = ConstantInt::get(ty, 0); break;
-    case ir::reduce_inst::SUB: neutral = ConstantInt::get(ty, 0); break;
-    case ir::reduce_inst::MAX: neutral = ConstantInt::get(ty, INT32_MIN); break;
-    case ir::reduce_inst::MIN: neutral = ConstantInt::get(ty, INT32_MAX); break;
+    case ir::reduce_inst::SUB:  neutral = ConstantInt::get(ty, 0); break;
+    case ir::reduce_inst::MAX:  neutral = ConstantInt::get(ty, INT32_MIN); break;
+    case ir::reduce_inst::MIN:  neutral = ConstantInt::get(ty, INT32_MAX); break;
+    case ir::reduce_inst::UMAX: neutral = ConstantInt::get(ty, 0); break;
+    case ir::reduce_inst::UMIN: neutral = ConstantInt::get(ty, UINT32_MAX); break;
     case ir::reduce_inst::FADD: neutral = ConstantFP::get(ty, 0); break;
     case ir::reduce_inst::FSUB: neutral = ConstantFP::get(ty, 0); break;
     case ir::reduce_inst::FMAX: neutral = ConstantFP::get(ty, -INFINITY); break;
diff --git a/python/src/triton.cc b/python/src/triton.cc
index 7f9e7e752..7ebd6b9b9 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -571,6 +571,8 @@ void init_triton_ir(py::module &&m) {
       .value("FADD", ir::reduce_inst::FADD)
      .value("MIN", ir::reduce_inst::MIN)
      .value("MAX", ir::reduce_inst::MAX)
+      .value("UMIN", ir::reduce_inst::UMIN)
+      .value("UMAX", ir::reduce_inst::UMAX)
      .value("FMIN", ir::reduce_inst::FMIN)
      .value("FMAX", ir::reduce_inst::FMAX)
      .value("XOR", ir::reduce_inst::XOR);
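Note on the neutral elements above (an explanatory aside, not part of the patch): each padding value must be the identity of its reduction. For unsigned integers, 0 is the identity of max (no value is smaller) and UINT32_MAX is the identity of min, mirroring INT32_MIN/INT32_MAX in the signed cases. A minimal NumPy sketch of the invariant the neutral-padded lanes rely on:

    import numpy as np

    xs = np.random.randint(0, 2**32, 1000, dtype=np.uint32)
    # Folding the neutral element into any lane must be a no-op.
    assert (np.maximum(xs, np.uint32(0)) == xs).all()          # UMAX: neutral is 0
    assert (np.minimum(xs, np.uint32(2**32 - 1)) == xs).all()  # UMIN: neutral is UINT32_MAX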
+@pytest.mark.parametrize("op, dtype_str, shape", + [(op, dtype, shape) + for op in ['min', 'max', 'sum'] for dtype in dtypes for shape in [32, 64, 128, 512]]) -def test_reduce1d(dtype_str, shape, device='cuda'): +def test_reduce1d(op, dtype_str, shape, device='cuda'): # triton kernel @triton.jit def kernel(X, Z, BLOCK: tl.constexpr): x = tl.load(X + tl.arange(0, BLOCK)) - tl.store(Z, tl.sum(x, axis=0)) + tl.store(Z, GENERATE_TEST_HERE) + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis=0)'}) + # input rs = RandomState(17) + # limit the range of integers so that the sum does not overflow x = numpy_random((shape,), dtype_str=dtype_str, rs=rs) - x[:] = 1 - # numpy result - z_ref = np.sum(x).astype(getattr(np, dtype_str)) - # triton result x_tri = to_triton(x, device=device) + numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min}[op] + # numpy result + z_ref = numpy_op(x).astype(getattr(np, dtype_str)) + # triton result z_tri = to_triton(numpy_random((1,), dtype_str=dtype_str, rs=rs), device=device) kernel[(1,)](x_tri, z_tri, BLOCK=shape) # compare - np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01) + if op == 'sum': + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01) + else: + np.testing.assert_equal(z_ref, to_numpy(z_tri)) reduce_configs1 = [ - (dtype, (1, 1024), axis) for dtype in ['float32', 'uint32'] + (op, dtype, (1, 1024), axis) for dtype in dtypes + for op in ['min', 'max', 'sum'] for axis in [1] ] reduce_configs2 = [ - ('float32', shape, 1) for shape in [(2, 32), (4, 128), (32, 64), (64, 128), (128, 256), (32, 1024)] + (op, 'float32', shape, 1) + for op in ['min', 'max', 'sum'] + for shape in [(2, 32), (4, 32), (4, 128), (32, 64), (64, 128), (128, 256), (32, 1024)] ] -@pytest.mark.parametrize("dtype_str, shape, axis", reduce_configs1 + reduce_configs2) -def test_reduce2d(dtype_str, shape, axis, device='cuda'): +@pytest.mark.parametrize("op, dtype_str, shape, axis", reduce_configs1 + reduce_configs2) +def test_reduce2d(op, dtype_str, shape, axis, device='cuda'): # triton kernel @triton.jit def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr): range_m = tl.arange(0, BLOCK_M) range_n = tl.arange(0, BLOCK_N) x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :]) - z = tl.sum(x, axis=AXIS) + z = GENERATE_TEST_HERE tl.store(Z + range_m, z) + + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis=AXIS)'}) # input - x = numpy_random(shape, dtype_str=dtype_str) - # triton result + rs = RandomState(17) + # limit the range of integers so that the sum does not overflow + x = numpy_random(shape, dtype_str=dtype_str, rs=rs) x_tri = to_triton(x) - z_tri = to_triton(np.empty((shape[0],), dtype=getattr(np, dtype_str)), device=device) - kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis) - # numpy reference result - z_ref = np.sum(x, axis=axis).astype(x.dtype) + numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min}[op] + # numpy result + z_ref = numpy_op(x, axis=axis).astype(getattr(np, dtype_str)) + # triton result + z_tri = to_triton(numpy_random((shape[0],), dtype_str=dtype_str, rs=rs), device=device) + binary = kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis) # compare - np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01) + if op == 'sum': + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01) + else: + np.testing.assert_equal(z_ref, to_numpy(z_tri)) # --------------- # test permute diff --git a/python/triton/language/core.py 
diff --git a/python/triton/language/core.py b/python/triton/language/core.py
index f0cc02e66..fa6f190e3 100644
--- a/python/triton/language/core.py
+++ b/python/triton/language/core.py
@@ -136,6 +136,9 @@ class dtype:
     def is_int_signed(self):
         return self.name in dtype.SINT_TYPES
 
+    def is_int_unsigned(self):
+        return self.name in dtype.UINT_TYPES
+
     def is_int(self):
         return self.name in dtype.SINT_TYPES + dtype.UINT_TYPES
 
diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py
index 753944285..e57faa5ec 100644
--- a/python/triton/language/semantic.py
+++ b/python/triton/language/semantic.py
@@ -959,6 +959,13 @@ def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str,
     if scalar_ty.is_int() and scalar_ty.int_bitwidth <= 32:
         input = cast(input, tl.int32, builder)
 
+    # choose the right unsigned operation
+    if scalar_ty.is_int_unsigned():
+        if INT_OP is ir.REDUCE_OP.MIN:
+            INT_OP = ir.REDUCE_OP.UMIN
+        elif INT_OP is ir.REDUCE_OP.MAX:
+            INT_OP = ir.REDUCE_OP.UMAX
+
     # get result type
     shape = input.type.shape
     ret_shape = []
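Taken together, these changes make tl.min/tl.max respect unsigned semantics end to end: semantic.py checks the original scalar type (before the widening cast to int32) and swaps MIN/MAX for UMIN/UMAX, which the codegen then lowers with unsigned comparisons. A hypothetical usage sketch, not part of the patch: it assumes a build that includes these changes and, like the test helpers, views a signed torch tensor as uint32 via triton.reinterpret since torch has no native uint32 dtype; the .base attribute of the resulting wrapper is also an assumption about that helper.

    import numpy as np
    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def umax_kernel(X, Z, BLOCK: tl.constexpr):
        x = tl.load(X + tl.arange(0, BLOCK))
        # X is uint32 here, so reduce_impl swaps MAX for UMAX and the
        # generated code compares with icmp_uge instead of icmp_sge.
        tl.store(Z, tl.max(x, axis=0))

    x = np.random.randint(0, 2**32, (128,), dtype=np.uint32)
    x_tri = triton.reinterpret(torch.tensor(x.view(np.int32), device='cuda'), tl.uint32)
    z_tri = triton.reinterpret(torch.zeros((1,), dtype=torch.int32, device='cuda'), tl.uint32)
    umax_kernel[(1,)](x_tri, z_tri, BLOCK=128)
    # .base is assumed to be the underlying torch tensor of the wrapper.
    assert z_tri.base.cpu().numpy().view(np.uint32)[0] == x.max()

Without the patch, inputs above 2**31 - 1 compare as negative int32 values inside the reduction, so the signed MAX could pick the wrong element; with UMAX the result matches np.max.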