diff --git a/include/triton/ir/dispatch.h b/include/triton/ir/dispatch.h index d8293d231..7dc8ab0ed 100644 --- a/include/triton/ir/dispatch.h +++ b/include/triton/ir/dispatch.h @@ -34,6 +34,7 @@ struct dispatch{ static ir::value *mul(ir::value *input, ir::value *other, ir::builder *builder); static ir::value *truediv(ir::value *input, ir::value *other, ir::builder *builder); static ir::value *floordiv(ir::value *input, ir::value *other, ir::builder *builder); + static ir::value *fdiv(ir::value *input, ir::value *other, ir::constant_int* ieee_rounding, ir::builder *builder); static ir::value *mod(ir::value *input, ir::value *other, ir::builder *builder); static ir::value *and_(ir::value *input, ir::value *other, ir::builder *builder); static ir::value *or_(ir::value *input, ir::value *other, ir::builder *builder); diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h index ca1416f48..5af077e8f 100644 --- a/include/triton/ir/instructions.h +++ b/include/triton/ir/instructions.h @@ -117,6 +117,7 @@ private: //===----------------------------------------------------------------------===// // binary_operator classes //===----------------------------------------------------------------------===// + class binary_operator: public instruction { public: typedef binary_op_t op_t; @@ -145,6 +146,10 @@ public: bool is_shl() const; bool is_shr() const; + // Approx + void set_fdiv_ieee_rounding(bool rnd) { fdiv_ieee_rnd_ = rnd; } + bool get_fdiv_ieee_rounding() { return fdiv_ieee_rnd_; } + // Wraps void set_has_no_unsigned_wrap(bool b = true) { has_no_unsigned_wrap_ = b; } void set_has_no_signed_wrap(bool b = true) { has_no_signed_wrap_ = b; } @@ -163,6 +168,8 @@ public: binary_op_t op_; bool has_no_unsigned_wrap_; bool has_no_signed_wrap_; + + bool fdiv_ieee_rnd_; }; diff --git a/include/triton/ir/module.h b/include/triton/ir/module.h index b350e3cc9..30881fd49 100644 --- a/include/triton/ir/module.h +++ b/include/triton/ir/module.h @@ -66,7 +66,10 @@ 
public: void set_continue_fn(std::function fn); // Getters const std::map& get_values() { return values_; } + const std::map<std::string, ir::type*>& get_types() { return types_; } void set_values(const std::map& values) { values_ = values; } + void set_types(const std::map<std::string, ir::type*>& types) { types_ = types; } + value *get_value(const std::string& name, basic_block* block); value *get_value(const std::string& name); void set_type(const std::string& name, ir::type* ty) { types_[name] = ty; } diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index a55991475..986e8212e 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -319,6 +319,13 @@ void generator::visit_binary_operator(ir::binary_operator*x) { vals_[x][idx] = add(lhs, rhs); else if(op == ll::Mul) vals_[x][idx] = mul(lhs, rhs); + else if(op == ll::FDiv && !x->get_fdiv_ieee_rounding() && + x->get_type()->get_scalar_ty()->is_fp32_ty()){ + InlineAsm *ptx = InlineAsm::get(FunctionType::get(f32_ty, {f32_ty, f32_ty}, false), + " div.full.f32 $0, $1, $2;", "=r,r,r", false); + vals_[x][idx] = builder_->CreateCall(ptx, {lhs, rhs}); + + } else vals_[x][idx] = bin_op(op, lhs, rhs); } diff --git a/lib/ir/dispatch.cc b/lib/ir/dispatch.cc index e3168b375..5a2d33a98 100644 --- a/lib/ir/dispatch.cc +++ b/lib/ir/dispatch.cc @@ -204,6 +204,18 @@ ir::value *dispatch::floordiv(ir::value *input, ir::value *other, ir::builder *b throw_unreachable("floordiv"); } +ir::value *dispatch::fdiv(ir::value *input, ir::value *other, constant_int *ieee_rounding, ir::builder *builder){ + ir::type *input_scalar_ty = input->get_type()->get_scalar_ty(); + ir::type *other_scalar_ty = other->get_type()->get_scalar_ty(); + if(!input_scalar_ty->is_floating_point_ty() || !other_scalar_ty->is_floating_point_ty()) + throw semantic_error("both operands of fdiv must have floating point scalar type"); + binary_op_type_checking(input, other, builder, false, false, false, DivOrMod::YES); + ir::value* ret = 
builder->create_fdiv(input, other); + if(ir::binary_operator* binop = dynamic_cast<ir::binary_operator*>(ret)) + binop->set_fdiv_ieee_rounding(ieee_rounding->get_value()); + return ret; +} + ir::value *dispatch::mod(ir::value *input, ir::value *other, ir::builder *builder) { binary_op_type_checking(input, other, builder, false, false, true, DivOrMod::YES); ir::type *scalar_ty = input->get_type()->get_scalar_ty(); diff --git a/lib/ir/instructions.cc b/lib/ir/instructions.cc index 6e416a43e..ac5a17289 100644 --- a/lib/ir/instructions.cc +++ b/lib/ir/instructions.cc @@ -134,7 +134,7 @@ bool binary_operator::is_int_add_sub() const { binary_operator::binary_operator(binary_op_t op, value *lhs, value *rhs, type *ty, const std::string &name, instruction *next) - : instruction(ty, INST_BINOP, 2, name, next), op_(op){ + : instruction(ty, INST_BINOP, 2, name, next), op_(op), fdiv_ieee_rnd_(false){ set_operand(0, lhs); set_operand(1, rhs); } diff --git a/python/src/triton.cc b/python/src/triton.cc index 8ceb14200..77edf791f 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -556,6 +556,7 @@ void init_triton_frontend(py::module &&m) { m.def("mul", &ir::dispatch::mul, ret::reference); m.def("truediv", &ir::dispatch::truediv, ret::reference); m.def("floordiv", &ir::dispatch::floordiv, ret::reference); + m.def("fdiv", &ir::dispatch::fdiv, ret::reference); m.def("mod", &ir::dispatch::mod, ret::reference); m.def("and_", &ir::dispatch::and_, ret::reference); m.def("or_", &ir::dispatch::or_, ret::reference); @@ -691,6 +692,7 @@ void init_triton_ir(py::module &&m) { .def("is_uint32", [](ir::type *self) { return self->is_integer_ty(32, ir::signedness::UNSIGNED); }) .def("is_uint64", [](ir::type *self) { return self->is_integer_ty(64, ir::signedness::UNSIGNED); }) + .def("repr", &ir::type::repr) .def_property_readonly("fp_mantissa_width", &ir::type::get_fp_mantissa_width) .def_property_readonly("scalar", &ir::type::get_scalar_ty) .def_property_readonly("context", &ir::type::get_context, 
ret::reference); @@ -713,6 +715,8 @@ void init_triton_ir(py::module &&m) { .def("get_value", (ir::value * (ir::module::*)(const std::string &)) & ir::module::get_value, ret::reference) .def("get_values", &ir::module::get_values, ret::reference) .def("set_values", &ir::module::set_values) + .def("get_types", &ir::module::get_types, ret::reference) + .def("set_types", &ir::module::set_types) .def_property_readonly("builder", &ir::module::get_builder, ret::reference); using eattr = ir::attribute_kind_t; diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 347635e32..48057b770 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -986,12 +986,14 @@ class JITFunction: gscope = generator.gscope.copy() lscope = generator.lscope.copy() values = generator.module.get_values().copy() + types = generator.module.get_types().copy() generator.gscope = sys.modules[self.fn.__module__].__dict__ generator.lscope = dict() ret = generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=arg_values) generator.gscope = gscope generator.lscope = lscope generator.module.set_values(values) + generator.module.set_types(types) return ret except Exception as e: node = generator.last_node diff --git a/python/triton/language/core.py b/python/triton/language/core.py index df3b1f4cf..425e12e01 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -699,6 +699,11 @@ def umulhi(x, y, _builder=None): return frontend.umulhi(x, y, _builder) +@builtin +def fdiv(x, y, ieee_rounding=False, _builder=None): + return frontend.fdiv(x, y, ieee_rounding, _builder) + + def _add_math_1arg_docstr(name): def _decorator(func): @@ -869,11 +874,11 @@ def sigmoid(x): @triton.jit @_add_math_1arg_docstr("softmax") -def softmax(x): +def softmax(x, ieee_rounding=False): z = x - triton.language.max(x, 0) num = triton.language.exp(z) den = triton.language.sum(num, 0) - return num / den + return fdiv(num, den, ieee_rounding) @triton.jit