Add argmin argmax (#552)
This commit is contained in:
@@ -353,9 +353,6 @@ ir::value *sqrt(ir::value *input, ir::builder *builder) {
|
||||
return builder->create_sqrt(input);
|
||||
};
|
||||
|
||||
/*----------------------------------------------
|
||||
definition of triton.min
|
||||
----------------------------------------------*/
|
||||
ir::value *reduce_impl(ir::value *input, unsigned int axis, ir::builder *builder, const std::string &name,
|
||||
ir::reduce_inst::op_t FLOAT_OP, ir::reduce_inst::op_t INT_OP) {
|
||||
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
|
||||
@@ -367,6 +364,9 @@ ir::value *reduce_impl(ir::value *input, unsigned int axis, ir::builder *builder
|
||||
throw_not_int_or_float(name);
|
||||
}
|
||||
|
||||
/*----------------------------------------------
|
||||
definition of triton.min
|
||||
----------------------------------------------*/
|
||||
std::string min_docstr = R"pbdoc(
|
||||
Returns the minimum value of `input`.
|
||||
)pbdoc";
|
||||
@@ -374,6 +374,16 @@ ir::value *min(ir::value *input, unsigned int axis, ir::builder *builder) {
|
||||
return reduce_impl(input, axis, builder, "min", ir::reduce_inst::FMIN, ir::reduce_inst::MIN);
|
||||
};
|
||||
|
||||
/*----------------------------------------------
|
||||
definition of triton.arg_min
|
||||
----------------------------------------------*/
|
||||
std::string min_docstr = R"pbdoc(
|
||||
Returns the minimum value's index of `input`.
|
||||
)pbdoc";
|
||||
ir::value *argmin(ir::value *input, unsigned int axis, ir::builder *builder) {
|
||||
return reduce_impl(input, axis, builder, "argmin", ir::reduce_inst::ARGFMIN, ir::reduce_inst::ARGMIN);
|
||||
};
|
||||
|
||||
/*----------------------------------------------
|
||||
definition of triton.max
|
||||
----------------------------------------------*/
|
||||
@@ -384,6 +394,16 @@ ir::value *max(ir::value *input, unsigned int axis, ir::builder *builder) {
|
||||
return reduce_impl(input, axis, builder, "max", ir::reduce_inst::FMAX, ir::reduce_inst::MAX);
|
||||
};
|
||||
|
||||
/*----------------------------------------------
|
||||
definition of triton.arg_max
|
||||
----------------------------------------------*/
|
||||
std::string max_docstr = R"pbdoc(
|
||||
Returns the maximum value's index of `input`.
|
||||
)pbdoc";
|
||||
ir::value *argmax(ir::value *input, unsigned int axis, ir::builder *builder) {
|
||||
return reduce_impl(input, axis, builder, "argmax", ir::reduce_inst::ARGFMAX, ir::reduce_inst::ARGMAX);
|
||||
};
|
||||
|
||||
/*----------------------------------------------
|
||||
definition of triton.sum
|
||||
----------------------------------------------*/
|
||||
|
@@ -573,8 +573,14 @@ void init_triton_ir(py::module &&m) {
|
||||
.value("MAX", ir::reduce_inst::MAX)
|
||||
.value("UMIN", ir::reduce_inst::UMIN)
|
||||
.value("UMAX", ir::reduce_inst::UMAX)
|
||||
.value("ARGMIN", ir::reduce_inst::ARGMIN)
|
||||
.value("ARGMAX", ir::reduce_inst::ARGMAX)
|
||||
.value("ARGUMIN", ir::reduce_inst::ARGUMIN)
|
||||
.value("ARGUMAX", ir::reduce_inst::ARGUMAX)
|
||||
.value("FMIN", ir::reduce_inst::FMIN)
|
||||
.value("FMAX", ir::reduce_inst::FMAX)
|
||||
.value("ARGFMIN", ir::reduce_inst::ARGFMIN)
|
||||
.value("ARGFMAX", ir::reduce_inst::ARGFMAX)
|
||||
.value("XOR", ir::reduce_inst::XOR);
|
||||
|
||||
py::enum_<ir::atomic_rmw_op_t>(m, "ATOMIC_OP")
|
||||
|
Reference in New Issue
Block a user