diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index f819b1710..d65931b48 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -1057,6 +1057,13 @@ def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str, if INT_OP in int_op_to_unit: INT_OP = int_op_to_unit[INT_OP] + # If we are doing an argmin or argmax we want to use an int32 output type + out_scalar_ty = scalar_ty + if FLOAT_OP is ir.REDUCE_OP.ARGFMAX or INT_OP is ir.REDUCE_OP.ARGMAX: + out_scalar_ty = tl.int32 + elif FLOAT_OP is ir.REDUCE_OP.ARGFMIN or INT_OP is ir.REDUCE_OP.ARGMIN: + out_scalar_ty = tl.int32 + # get result type shape = input.type.shape ret_shape = [] @@ -1064,10 +1071,10 @@ def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str, if i != axis: ret_shape.append(s) if ret_shape: - res_ty = tl.block_type(scalar_ty, ret_shape) + res_ty = tl.block_type(out_scalar_ty, ret_shape) else: # 0d-tensor -> scalar - res_ty = scalar_ty + res_ty = out_scalar_ty if scalar_ty.is_floating(): return tl.tensor(builder.create_reduce(input.handle, FLOAT_OP, axis), res_ty)