[FRONTEND] Fix argmin/max output type (#1012)
Currently, Triton returns tensors with the input's element type rather than int32 when performing an argmax/argmin reduction.
This commit is contained in:
@@ -1057,6 +1057,13 @@ def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str,
|
|||||||
if INT_OP in int_op_to_unit:
|
if INT_OP in int_op_to_unit:
|
||||||
INT_OP = int_op_to_unit[INT_OP]
|
INT_OP = int_op_to_unit[INT_OP]
|
||||||
|
|
||||||
|
# If we are doing an argmin or argmax we want to use an int32 output type
|
||||||
|
out_scalar_ty = scalar_ty
|
||||||
|
if FLOAT_OP is ir.REDUCE_OP.ARGFMAX or INT_OP is ir.REDUCE_OP.ARGMAX:
|
||||||
|
out_scalar_ty = tl.int32
|
||||||
|
elif FLOAT_OP is ir.REDUCE_OP.ARGFMIN or INT_OP is ir.REDUCE_OP.ARGMIN:
|
||||||
|
out_scalar_ty = tl.int32
|
||||||
|
|
||||||
# get result type
|
# get result type
|
||||||
shape = input.type.shape
|
shape = input.type.shape
|
||||||
ret_shape = []
|
ret_shape = []
|
||||||
@@ -1064,10 +1071,10 @@ def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str,
|
|||||||
if i != axis:
|
if i != axis:
|
||||||
ret_shape.append(s)
|
ret_shape.append(s)
|
||||||
if ret_shape:
|
if ret_shape:
|
||||||
res_ty = tl.block_type(scalar_ty, ret_shape)
|
res_ty = tl.block_type(out_scalar_ty, ret_shape)
|
||||||
else:
|
else:
|
||||||
# 0d-tensor -> scalar
|
# 0d-tensor -> scalar
|
||||||
res_ty = scalar_ty
|
res_ty = out_scalar_ty
|
||||||
|
|
||||||
if scalar_ty.is_floating():
|
if scalar_ty.is_floating():
|
||||||
return tl.tensor(builder.create_reduce(input.handle, FLOAT_OP, axis), res_ty)
|
return tl.tensor(builder.create_reduce(input.handle, FLOAT_OP, axis), res_ty)
|
||||||
|
Reference in New Issue
Block a user