diff --git a/include/triton/ir/dispatch.h b/include/triton/ir/dispatch.h index 0c8295948..b6cc7dcc1 100644 --- a/include/triton/ir/dispatch.h +++ b/include/triton/ir/dispatch.h @@ -89,6 +89,7 @@ struct dispatch{ static ir::value *min(ir::value *input, unsigned int axis, ir::builder *builder); static ir::value *max(ir::value *input, unsigned int axis, ir::builder *builder); static ir::value *sum(ir::value *input, unsigned int axis, ir::builder *builder); + static ir::value *xor_sum(ir::value *input, unsigned int axis, ir::builder *builder); // math static ir::value *umulhi(ir::value *x, ir::value *y, ir::builder *builder); diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h index fdb2fd411..28ff9f3d6 100644 --- a/include/triton/ir/instructions.h +++ b/include/triton/ir/instructions.h @@ -788,7 +788,8 @@ class reduce_inst: public builtin_inst { public: enum op_t{ ADD, SUB, MAX, MIN, - FADD, FSUB, FMAX, FMIN + FADD, FSUB, FMAX, FMIN, + XOR }; private: diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index d5c5c4902..a838bfa5a 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -1890,6 +1890,7 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) { case ir::reduce_inst::FSUB: return fsub(x, y); case ir::reduce_inst::FMAX: return max_num(x, y); case ir::reduce_inst::FMIN: return min_num(x, y); + case ir::reduce_inst::XOR: return xor_(x, y); default: throw std::runtime_error("unreachable"); } }; @@ -1904,6 +1905,7 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) { case ir::reduce_inst::FSUB: neutral = ConstantFP::get(ty, 0); break; case ir::reduce_inst::FMAX: neutral = ConstantFP::get(ty, -INFINITY); break; case ir::reduce_inst::FMIN: neutral = ConstantFP::get(ty, INFINITY); break; + case ir::reduce_inst::XOR: neutral = ConstantInt::get(ty, 0); break; default: throw std::runtime_error("unreachable"); } ir::value *arg = x->get_operand(0); diff --git 
a/lib/ir/dispatch.cc b/lib/ir/dispatch.cc index 811e5c819..477f4dce0 100644 --- a/lib/ir/dispatch.cc +++ b/lib/ir/dispatch.cc @@ -714,6 +714,13 @@ ir::value *dispatch::sum(ir::value *input, unsigned int axis, ir::builder *build return reduce_impl(input, axis, builder, "sum", ir::reduce_inst::FADD, ir::reduce_inst::ADD); } +ir::value *dispatch::xor_sum(ir::value *input, unsigned int axis, ir::builder *builder) { + ir::type *scalar_ty = input->get_type()->get_scalar_ty(); + if (!scalar_ty->is_integer_ty()) + throw semantic_error("xor_sum only supported for integers"); + return reduce_impl(input, axis, builder, "xor_sum", ir::reduce_inst::XOR, ir::reduce_inst::XOR); +} + //===----------------------------------------------------------------------===// // Math diff --git a/python/src/triton.cc b/python/src/triton.cc index aca171dae..92df2ae27 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -520,6 +520,7 @@ void init_triton_frontend(py::module &&m) { m.def("min", &ir::dispatch::min, ret::reference); m.def("max", &ir::dispatch::max, ret::reference); m.def("sum", &ir::dispatch::sum, ret::reference); + m.def("xor_sum", &ir::dispatch::xor_sum, ret::reference); // math m.def("umulhi", &ir::dispatch::umulhi, ret::reference); m.def("exp", &ir::dispatch::exp, ret::reference); diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 55b5bc0d9..5db77efdc 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -719,6 +719,11 @@ def min(input, axis, _builder=None): def sum(input, axis, _builder=None): return frontend.sum(input, axis, _builder) +@builtin +@_add_reduction_docstr("xor sum") +def xor_sum(input, axis, _builder=None): + return frontend.xor_sum(input, axis, _builder) + # ----------------------- # Internal for debugging