[BACKEND] float division is now approximate by default (#446)
This commit is contained in:
@@ -34,6 +34,7 @@ struct dispatch{
|
||||
static ir::value *mul(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *truediv(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *floordiv(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
// Floating-point division of `input` by `other`. `ieee_rounding` is a constant
// integer flag whose value is recorded on the resulting binary operator.
static ir::value *fdiv(ir::value *input, ir::value *other, ir::constant_int* ieee_rounding, ir::builder *builder);
|
||||
static ir::value *mod(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *and_(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *or_(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
|
@@ -117,6 +117,7 @@ private:
|
||||
//===----------------------------------------------------------------------===//
|
||||
// binary_operator classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class binary_operator: public instruction {
|
||||
public:
|
||||
typedef binary_op_t op_t;
|
||||
@@ -145,6 +146,10 @@ public:
|
||||
bool is_shl() const;
|
||||
bool is_shr() const;
|
||||
|
||||
// Approx
|
||||
void set_fdiv_ieee_rounding(bool rnd) { fdiv_ieee_rnd_ = rnd; }
|
||||
bool get_fdiv_ieee_rounding() { return fdiv_ieee_rnd_; }
|
||||
|
||||
// Wraps
|
||||
void set_has_no_unsigned_wrap(bool b = true) { has_no_unsigned_wrap_ = b; }
|
||||
void set_has_no_signed_wrap(bool b = true) { has_no_signed_wrap_ = b; }
|
||||
@@ -163,6 +168,8 @@ public:
|
||||
binary_op_t op_;
|
||||
bool has_no_unsigned_wrap_;
|
||||
bool has_no_signed_wrap_;
|
||||
|
||||
bool fdiv_ieee_rnd_;
|
||||
};
|
||||
|
||||
|
||||
|
@@ -66,7 +66,10 @@ public:
|
||||
void set_continue_fn(std::function<ir::value*()> fn);
|
||||
// Getters
|
||||
const std::map<val_key_t, value*>& get_values() { return values_; }
|
||||
const std::map<std::string, type*>& get_types() { return types_; }
|
||||
void set_values(const std::map<val_key_t, value*>& values) { values_ = values; }
|
||||
void set_types(const std::map<std::string, type*>& types) { types_ = types; }
|
||||
|
||||
value *get_value(const std::string& name, basic_block* block);
|
||||
value *get_value(const std::string& name);
|
||||
void set_type(const std::string& name, ir::type* ty) { types_[name] = ty; }
|
||||
|
@@ -319,6 +319,13 @@ void generator::visit_binary_operator(ir::binary_operator*x) {
|
||||
vals_[x][idx] = add(lhs, rhs);
|
||||
else if(op == ll::Mul)
|
||||
vals_[x][idx] = mul(lhs, rhs);
|
||||
else if(op == ll::FDiv && !x->get_fdiv_ieee_rounding() &&
|
||||
x->get_type()->get_scalar_ty()->is_fp32_ty()){
|
||||
InlineAsm *ptx = InlineAsm::get(FunctionType::get(f32_ty, {f32_ty, f32_ty}, false),
|
||||
" div.full.f32 $0, $1, $2;", "=r,r,r", false);
|
||||
vals_[x][idx] = builder_->CreateCall(ptx, {lhs, rhs});
|
||||
|
||||
}
|
||||
else
|
||||
vals_[x][idx] = bin_op(op, lhs, rhs);
|
||||
}
|
||||
|
@@ -204,6 +204,18 @@ ir::value *dispatch::floordiv(ir::value *input, ir::value *other, ir::builder *b
|
||||
throw_unreachable("floordiv");
|
||||
}
|
||||
|
||||
// Builds a floating-point division `input / other`.
//
// Both operands must have a floating-point scalar type; otherwise a
// semantic_error is thrown. After the usual binary-operand type checking the
// division is created through the builder, and the value of `ieee_rounding`
// (non-zero means IEEE rounding requested) is recorded on the resulting
// binary operator. A null `ieee_rounding` is treated as "not requested"
// rather than being dereferenced.
ir::value *dispatch::fdiv(ir::value *input, ir::value *other, constant_int *ieee_rounding, ir::builder *builder){
  ir::type *input_scalar_ty = input->get_type()->get_scalar_ty();
  ir::type *other_scalar_ty = other->get_type()->get_scalar_ty();
  if(!input_scalar_ty->is_floating_point_ty() || !other_scalar_ty->is_floating_point_ty())
    throw semantic_error("both operands of fdiv must have floating point scalar type");
  binary_op_type_checking(input, other, builder, false, false, false, DivOrMod::YES);
  ir::value* ret = builder->create_fdiv(input, other);
  // The rounding flag lives on binary_operator; if the builder handed back
  // some other kind of value there is nothing to annotate.
  if(ir::binary_operator* binop = dynamic_cast<ir::binary_operator*>(ret)){
    const bool use_ieee = ieee_rounding != nullptr && static_cast<bool>(ieee_rounding->get_value());
    binop->set_fdiv_ieee_rounding(use_ieee);
  }
  return ret;
}
|
||||
|
||||
ir::value *dispatch::mod(ir::value *input, ir::value *other, ir::builder *builder) {
|
||||
binary_op_type_checking(input, other, builder, false, false, true, DivOrMod::YES);
|
||||
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
|
||||
|
@@ -134,7 +134,7 @@ bool binary_operator::is_int_add_sub() const {
|
||||
|
||||
|
||||
// Creates a binary instruction applying `op` to `lhs` and `rhs`.
//
// Every flag member starts out false: no unsigned/signed no-wrap guarantee,
// and no IEEE rounding requested for fdiv. Callers opt in through the
// corresponding setters.
binary_operator::binary_operator(binary_op_t op, value *lhs, value *rhs, type *ty, const std::string &name, instruction *next)
    : instruction(ty, INST_BINOP, 2, name, next), op_(op),
      has_no_unsigned_wrap_(false), has_no_signed_wrap_(false),
      fdiv_ieee_rnd_(false){
  set_operand(0, lhs);
  set_operand(1, rhs);
}
|
||||
|
@@ -556,6 +556,7 @@ void init_triton_frontend(py::module &&m) {
|
||||
m.def("mul", &ir::dispatch::mul, ret::reference);
|
||||
m.def("truediv", &ir::dispatch::truediv, ret::reference);
|
||||
m.def("floordiv", &ir::dispatch::floordiv, ret::reference);
|
||||
m.def("fdiv", &ir::dispatch::fdiv, ret::reference);
|
||||
m.def("mod", &ir::dispatch::mod, ret::reference);
|
||||
m.def("and_", &ir::dispatch::and_, ret::reference);
|
||||
m.def("or_", &ir::dispatch::or_, ret::reference);
|
||||
@@ -691,6 +692,7 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("is_uint32", [](ir::type *self) { return self->is_integer_ty(32, ir::signedness::UNSIGNED); })
|
||||
.def("is_uint64", [](ir::type *self) { return self->is_integer_ty(64, ir::signedness::UNSIGNED); })
|
||||
|
||||
.def("repr", &ir::type::repr)
|
||||
.def_property_readonly("fp_mantissa_width", &ir::type::get_fp_mantissa_width)
|
||||
.def_property_readonly("scalar", &ir::type::get_scalar_ty)
|
||||
.def_property_readonly("context", &ir::type::get_context, ret::reference);
|
||||
@@ -713,6 +715,8 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("get_value", (ir::value * (ir::module::*)(const std::string &)) & ir::module::get_value, ret::reference)
|
||||
.def("get_values", &ir::module::get_values, ret::reference)
|
||||
.def("set_values", &ir::module::set_values)
|
||||
.def("get_types", &ir::module::get_types, ret::reference)
|
||||
.def("set_types", &ir::module::set_types)
|
||||
.def_property_readonly("builder", &ir::module::get_builder, ret::reference);
|
||||
|
||||
using eattr = ir::attribute_kind_t;
|
||||
|
@@ -986,12 +986,14 @@ class JITFunction:
|
||||
gscope = generator.gscope.copy()
|
||||
lscope = generator.lscope.copy()
|
||||
values = generator.module.get_values().copy()
|
||||
types = generator.module.get_types().copy()
|
||||
generator.gscope = sys.modules[self.fn.__module__].__dict__
|
||||
generator.lscope = dict()
|
||||
ret = generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=arg_values)
|
||||
generator.gscope = gscope
|
||||
generator.lscope = lscope
|
||||
generator.module.set_values(values)
|
||||
generator.module.set_types(types)
|
||||
return ret
|
||||
except Exception as e:
|
||||
node = generator.last_node
|
||||
|
@@ -699,6 +699,11 @@ def umulhi(x, y, _builder=None):
|
||||
return frontend.umulhi(x, y, _builder)
|
||||
|
||||
|
||||
@builtin
def fdiv(x, y, ieee_rounding=False, _builder=None):
    """
    Returns the floating-point division of :code:`x` by :code:`y`.

    This is a thin wrapper that forwards all arguments to
    :code:`frontend.fdiv`.

    :param x: the dividend
    :param y: the divisor
    :param ieee_rounding: rounding flag forwarded to the frontend;
        defaults to :code:`False`
    """
    quotient = frontend.fdiv(x, y, ieee_rounding, _builder)
    return quotient
|
||||
|
||||
|
||||
def _add_math_1arg_docstr(name):
|
||||
|
||||
def _decorator(func):
|
||||
@@ -869,11 +874,11 @@ def sigmoid(x):
|
||||
|
||||
@triton.jit
@_add_math_1arg_docstr("softmax")
def softmax(x, ieee_rounding=False):
    # Shift by the maximum along axis 0 before exponentiating.
    z = x - triton.language.max(x, 0)
    num = triton.language.exp(z)
    den = triton.language.sum(num, 0)
    # Divide through fdiv so the caller's ieee_rounding choice is forwarded.
    return fdiv(num, den, ieee_rounding)
|
||||
|
||||
|
||||
@triton.jit
|
||||
|
Reference in New Issue
Block a user