[FRONTEND] Added xor_sum
This commit is contained in:
@@ -89,6 +89,7 @@ struct dispatch{
|
|||||||
static ir::value *min(ir::value *input, unsigned int axis, ir::builder *builder);
|
static ir::value *min(ir::value *input, unsigned int axis, ir::builder *builder);
|
||||||
static ir::value *max(ir::value *input, unsigned int axis, ir::builder *builder);
|
static ir::value *max(ir::value *input, unsigned int axis, ir::builder *builder);
|
||||||
static ir::value *sum(ir::value *input, unsigned int axis, ir::builder *builder);
|
static ir::value *sum(ir::value *input, unsigned int axis, ir::builder *builder);
|
||||||
|
static ir::value *xor_sum(ir::value *input, unsigned axis, ir::builder *builder);
|
||||||
|
|
||||||
// math
|
// math
|
||||||
static ir::value *umulhi(ir::value *x, ir::value *y, ir::builder *builder);
|
static ir::value *umulhi(ir::value *x, ir::value *y, ir::builder *builder);
|
||||||
|
@@ -788,7 +788,8 @@ class reduce_inst: public builtin_inst {
|
|||||||
public:
|
public:
|
||||||
enum op_t{
|
enum op_t{
|
||||||
ADD, SUB, MAX, MIN,
|
ADD, SUB, MAX, MIN,
|
||||||
FADD, FSUB, FMAX, FMIN
|
FADD, FSUB, FMAX, FMIN,
|
||||||
|
XOR
|
||||||
};
|
};
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@@ -1890,6 +1890,7 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
|
|||||||
case ir::reduce_inst::FSUB: return fsub(x, y);
|
case ir::reduce_inst::FSUB: return fsub(x, y);
|
||||||
case ir::reduce_inst::FMAX: return max_num(x, y);
|
case ir::reduce_inst::FMAX: return max_num(x, y);
|
||||||
case ir::reduce_inst::FMIN: return min_num(x, y);
|
case ir::reduce_inst::FMIN: return min_num(x, y);
|
||||||
|
case ir::reduce_inst::XOR: return xor_(x, y);
|
||||||
default: throw std::runtime_error("unreachable");
|
default: throw std::runtime_error("unreachable");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -1904,6 +1905,7 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
|
|||||||
case ir::reduce_inst::FSUB: neutral = ConstantFP::get(ty, 0); break;
|
case ir::reduce_inst::FSUB: neutral = ConstantFP::get(ty, 0); break;
|
||||||
case ir::reduce_inst::FMAX: neutral = ConstantFP::get(ty, -INFINITY); break;
|
case ir::reduce_inst::FMAX: neutral = ConstantFP::get(ty, -INFINITY); break;
|
||||||
case ir::reduce_inst::FMIN: neutral = ConstantFP::get(ty, INFINITY); break;
|
case ir::reduce_inst::FMIN: neutral = ConstantFP::get(ty, INFINITY); break;
|
||||||
|
case ir::reduce_inst::XOR: neutral = neutral = ConstantInt::get(ty, 0); break;
|
||||||
default: throw std::runtime_error("unreachable");
|
default: throw std::runtime_error("unreachable");
|
||||||
}
|
}
|
||||||
ir::value *arg = x->get_operand(0);
|
ir::value *arg = x->get_operand(0);
|
||||||
|
@@ -714,6 +714,13 @@ ir::value *dispatch::sum(ir::value *input, unsigned int axis, ir::builder *build
|
|||||||
return reduce_impl(input, axis, builder, "sum", ir::reduce_inst::FADD, ir::reduce_inst::ADD);
|
return reduce_impl(input, axis, builder, "sum", ir::reduce_inst::FADD, ir::reduce_inst::ADD);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ir::value *dispatch::xor_sum(ir::value *input, unsigned int axis, ir::builder *builder) {
|
||||||
|
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
|
||||||
|
if (!scalar_ty->is_integer_ty())
|
||||||
|
throw semantic_error("xor_sum only supported for integers");
|
||||||
|
return reduce_impl(input, axis, builder, "sum", ir::reduce_inst::XOR, ir::reduce_inst::XOR);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Math
|
// Math
|
||||||
|
@@ -520,6 +520,7 @@ void init_triton_frontend(py::module &&m) {
|
|||||||
m.def("min", &ir::dispatch::min, ret::reference);
|
m.def("min", &ir::dispatch::min, ret::reference);
|
||||||
m.def("max", &ir::dispatch::max, ret::reference);
|
m.def("max", &ir::dispatch::max, ret::reference);
|
||||||
m.def("sum", &ir::dispatch::sum, ret::reference);
|
m.def("sum", &ir::dispatch::sum, ret::reference);
|
||||||
|
m.def("xor_sum", &ir::dispatch::xor_sum, ret::reference);
|
||||||
// math
|
// math
|
||||||
m.def("umulhi", &ir::dispatch::umulhi, ret::reference);
|
m.def("umulhi", &ir::dispatch::umulhi, ret::reference);
|
||||||
m.def("exp", &ir::dispatch::exp, ret::reference);
|
m.def("exp", &ir::dispatch::exp, ret::reference);
|
||||||
|
@@ -719,6 +719,11 @@ def min(input, axis, _builder=None):
|
|||||||
def sum(input, axis, _builder=None):
|
def sum(input, axis, _builder=None):
|
||||||
return frontend.sum(input, axis, _builder)
|
return frontend.sum(input, axis, _builder)
|
||||||
|
|
||||||
|
@builtin
|
||||||
|
@_add_reduction_docstr("xor sum")
|
||||||
|
def xor_sum(input, axis, _builder=None):
|
||||||
|
return frontend.xor_sum(input, axis, _builder)
|
||||||
|
|
||||||
|
|
||||||
# -----------------------
|
# -----------------------
|
||||||
# Internal for debugging
|
# Internal for debugging
|
||||||
|
Reference in New Issue
Block a user