Add argmin argmax (#552)

This commit is contained in:
Keren Zhou
2022-06-15 13:55:20 -07:00
committed by GitHub
parent 6b9756532f
commit b5e728cb14
11 changed files with 345 additions and 101 deletions

View File

@@ -353,9 +353,6 @@ ir::value *sqrt(ir::value *input, ir::builder *builder) {
return builder->create_sqrt(input);
};
/*----------------------------------------------
definition of triton.min
----------------------------------------------*/
ir::value *reduce_impl(ir::value *input, unsigned int axis, ir::builder *builder, const std::string &name,
ir::reduce_inst::op_t FLOAT_OP, ir::reduce_inst::op_t INT_OP) {
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
@@ -367,6 +364,9 @@ ir::value *reduce_impl(ir::value *input, unsigned int axis, ir::builder *builder
throw_not_int_or_float(name);
}
/*----------------------------------------------
definition of triton.min
----------------------------------------------*/
std::string min_docstr = R"pbdoc(
Returns the minimum value of `input`.
)pbdoc";
@@ -374,6 +374,16 @@ ir::value *min(ir::value *input, unsigned int axis, ir::builder *builder) {
return reduce_impl(input, axis, builder, "min", ir::reduce_inst::FMIN, ir::reduce_inst::MIN);
};
/*----------------------------------------------
definition of triton.arg_min
----------------------------------------------*/
std::string min_docstr = R"pbdoc(
Returns the minimum value's index of `input`.
)pbdoc";
ir::value *argmin(ir::value *input, unsigned int axis, ir::builder *builder) {
return reduce_impl(input, axis, builder, "argmin", ir::reduce_inst::ARGFMIN, ir::reduce_inst::ARGMIN);
};
/*----------------------------------------------
definition of triton.max
----------------------------------------------*/
@@ -384,6 +394,16 @@ ir::value *max(ir::value *input, unsigned int axis, ir::builder *builder) {
return reduce_impl(input, axis, builder, "max", ir::reduce_inst::FMAX, ir::reduce_inst::MAX);
};
/*----------------------------------------------
definition of triton.arg_max
----------------------------------------------*/
std::string max_docstr = R"pbdoc(
Returns the maximum value's index of `input`.
)pbdoc";
ir::value *argmax(ir::value *input, unsigned int axis, ir::builder *builder) {
return reduce_impl(input, axis, builder, "argmax", ir::reduce_inst::ARGFMAX, ir::reduce_inst::ARGMAX);
};
/*----------------------------------------------
definition of triton.sum
----------------------------------------------*/

View File

@@ -573,8 +573,14 @@ void init_triton_ir(py::module &&m) {
.value("MAX", ir::reduce_inst::MAX)
.value("UMIN", ir::reduce_inst::UMIN)
.value("UMAX", ir::reduce_inst::UMAX)
.value("ARGMIN", ir::reduce_inst::ARGMIN)
.value("ARGMAX", ir::reduce_inst::ARGMAX)
.value("ARGUMIN", ir::reduce_inst::ARGUMIN)
.value("ARGUMAX", ir::reduce_inst::ARGUMAX)
.value("FMIN", ir::reduce_inst::FMIN)
.value("FMAX", ir::reduce_inst::FMAX)
.value("ARGFMIN", ir::reduce_inst::ARGFMIN)
.value("ARGFMAX", ir::reduce_inst::ARGFMAX)
.value("XOR", ir::reduce_inst::XOR);
py::enum_<ir::atomic_rmw_op_t>(m, "ATOMIC_OP")