[FRONTEND] Fix argmin/max output type (#1012)

Currently, Triton's argmax/argmin reductions return tensors with the
input's element type rather than i32; this commit makes them return i32.
This commit is contained in:
Sharad Vikram
2023-01-04 15:12:16 +08:00
committed by GitHub
parent 8460ea3df1
commit bc73bbb12c

View File

@@ -1057,6 +1057,13 @@ def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str,
if INT_OP in int_op_to_unit:
INT_OP = int_op_to_unit[INT_OP]
# If we are doing an argmin or argmax we want to use an int32 output type
out_scalar_ty = scalar_ty
if FLOAT_OP is ir.REDUCE_OP.ARGFMAX or INT_OP is ir.REDUCE_OP.ARGMAX:
out_scalar_ty = tl.int32
elif FLOAT_OP is ir.REDUCE_OP.ARGFMIN or INT_OP is ir.REDUCE_OP.ARGMIN:
out_scalar_ty = tl.int32
# get result type
shape = input.type.shape
ret_shape = []
@@ -1064,10 +1071,10 @@ def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str,
if i != axis:
ret_shape.append(s)
if ret_shape:
res_ty = tl.block_type(scalar_ty, ret_shape)
res_ty = tl.block_type(out_scalar_ty, ret_shape)
else:
# 0d-tensor -> scalar
res_ty = scalar_ty
res_ty = out_scalar_ty
if scalar_ty.is_floating():
return tl.tensor(builder.create_reduce(input.handle, FLOAT_OP, axis), res_ty)