[BACKEND] float division is now approximate by default (#446)

This commit is contained in:
Philippe Tillet
2022-01-29 18:29:29 -08:00
committed by GitHub
parent bd52e530a0
commit bef76b142a
9 changed files with 44 additions and 3 deletions

View File

@@ -34,6 +34,7 @@ struct dispatch{
static ir::value *mul(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *truediv(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *floordiv(ir::value *input, ir::value *other, ir::builder *builder);
// Floating-point division of `input` by `other`. `ieee_rounding` is a constant
// whose value is stored on the resulting binary operator as its
// IEEE-rounding flag.
static ir::value *fdiv(ir::value *input, ir::value *other, ir::constant_int* ieee_rounding, ir::builder *builder);
static ir::value *mod(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *and_(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *or_(ir::value *input, ir::value *other, ir::builder *builder);

View File

@@ -117,6 +117,7 @@ private:
//===----------------------------------------------------------------------===//
// binary_operator classes
//===----------------------------------------------------------------------===//
class binary_operator: public instruction {
public:
typedef binary_op_t op_t;
@@ -145,6 +146,10 @@ public:
bool is_shl() const;
bool is_shr() const;
// Approx: flag telling code generation whether a floating-point division must
// use IEEE-compliant rounding. When it is false, an fp32 FDiv is lowered to an
// approximate division instead.
// NOTE(review): the getter does not modify the object but is not
// const-qualified -- consider adding `const` if no caller depends on this.
void set_fdiv_ieee_rounding(bool rnd) { fdiv_ieee_rnd_ = rnd; }
bool get_fdiv_ieee_rounding() { return fdiv_ieee_rnd_; }
// Wraps
void set_has_no_unsigned_wrap(bool b = true) { has_no_unsigned_wrap_ = b; }
void set_has_no_signed_wrap(bool b = true) { has_no_signed_wrap_ = b; }
@@ -163,6 +168,8 @@ public:
binary_op_t op_;
bool has_no_unsigned_wrap_;
bool has_no_signed_wrap_;
bool fdiv_ieee_rnd_;
};

View File

@@ -66,7 +66,10 @@ public:
void set_continue_fn(std::function<ir::value*()> fn);
// Getters: expose the value table and the type table by const reference
// (no copy is made here).
const std::map<val_key_t, value*>& get_values() { return values_; }
const std::map<std::string, type*>& get_types() { return types_; }
// Bulk setters: overwrite the whole table with a copy of the argument.
void set_values(const std::map<val_key_t, value*>& values) { values_ = values; }
void set_types(const std::map<std::string, type*>& types) { types_ = types; }
value *get_value(const std::string& name, basic_block* block);
value *get_value(const std::string& name);
void set_type(const std::string& name, ir::type* ty) { types_[name] = ty; }

View File

@@ -319,6 +319,13 @@ void generator::visit_binary_operator(ir::binary_operator*x) {
vals_[x][idx] = add(lhs, rhs);
else if(op == ll::Mul)
vals_[x][idx] = mul(lhs, rhs);
// Approximate fp32 division: when the instruction does not request IEEE
// rounding and its scalar type is fp32, emit the PTX `div.full.f32`
// instruction through inline asm rather than going through the generic
// bin_op path below.
else if(op == ll::FDiv && !x->get_fdiv_ieee_rounding() &&
x->get_type()->get_scalar_ty()->is_fp32_ty()){
// Inline-asm callee of type (f32, f32) -> f32; $0 is the result, $1/$2 the operands.
InlineAsm *ptx = InlineAsm::get(FunctionType::get(f32_ty, {f32_ty, f32_ty}, false),
" div.full.f32 $0, $1, $2;", "=r,r,r", false);
vals_[x][idx] = builder_->CreateCall(ptx, {lhs, rhs});
}
else
vals_[x][idx] = bin_op(op, lhs, rhs);
}

View File

@@ -204,6 +204,18 @@ ir::value *dispatch::floordiv(ir::value *input, ir::value *other, ir::builder *b
throw_unreachable("floordiv");
}
// Builds a floating-point division `input / other`.
//
// Parameters:
//   input, other   - operands; both must have a floating-point scalar type.
//   ieee_rounding  - constant whose value is stored as the IEEE-rounding flag
//                    of the created instruction. Must not be null.
//   builder        - IR builder used to create the division.
//
// Returns the value produced by builder->create_fdiv(). The rounding flag is
// recorded only when that value is an ir::binary_operator; any other kind of
// value is returned unchanged.
//
// Throws semantic_error when an operand is not floating point, or when
// `ieee_rounding` is null.
ir::value *dispatch::fdiv(ir::value *input, ir::value *other, constant_int *ieee_rounding, ir::builder *builder){
  ir::type *input_scalar_ty = input->get_type()->get_scalar_ty();
  ir::type *other_scalar_ty = other->get_type()->get_scalar_ty();
  if(!input_scalar_ty->is_floating_point_ty() || !other_scalar_ty->is_floating_point_ty())
    throw semantic_error("both operands of fdiv must have floating point scalar type");
  // This function is exposed to Python, where a pointer argument may arrive as
  // null (None). Reject it before any instruction is created instead of
  // dereferencing it below.
  if(ieee_rounding == nullptr)
    throw semantic_error("ieee_rounding argument of fdiv must be a constant integer");
  binary_op_type_checking(input, other, builder, false, false, false, DivOrMod::YES);
  ir::value* ret = builder->create_fdiv(input, other);
  // Only a binary_operator can carry the rounding flag.
  if(ir::binary_operator* binop = dynamic_cast<ir::binary_operator*>(ret))
    binop->set_fdiv_ieee_rounding(ieee_rounding->get_value());
  return ret;
}
ir::value *dispatch::mod(ir::value *input, ir::value *other, ir::builder *builder) {
binary_op_type_checking(input, other, builder, false, false, true, DivOrMod::YES);
ir::type *scalar_ty = input->get_type()->get_scalar_ty();

View File

@@ -134,7 +134,7 @@ bool binary_operator::is_int_add_sub() const {
// Creates a two-operand instruction of kind INST_BINOP with operation `op`;
// `lhs` becomes operand 0 and `rhs` operand 1. The fdiv IEEE-rounding flag
// starts out false.
// NOTE(review): has_no_unsigned_wrap_ / has_no_signed_wrap_ do not appear in
// the initializer list shown here -- confirm they are initialized elsewhere.
binary_operator::binary_operator(binary_op_t op, value *lhs, value *rhs, type *ty, const std::string &name, instruction *next)
: instruction(ty, INST_BINOP, 2, name, next), op_(op){
: instruction(ty, INST_BINOP, 2, name, next), op_(op), fdiv_ieee_rnd_(false){
set_operand(0, lhs);
set_operand(1, rhs);
}

View File

@@ -556,6 +556,7 @@ void init_triton_frontend(py::module &&m) {
m.def("mul", &ir::dispatch::mul, ret::reference);
m.def("truediv", &ir::dispatch::truediv, ret::reference);
m.def("floordiv", &ir::dispatch::floordiv, ret::reference);
m.def("fdiv", &ir::dispatch::fdiv, ret::reference);
m.def("mod", &ir::dispatch::mod, ret::reference);
m.def("and_", &ir::dispatch::and_, ret::reference);
m.def("or_", &ir::dispatch::or_, ret::reference);
@@ -691,6 +692,7 @@ void init_triton_ir(py::module &&m) {
.def("is_uint32", [](ir::type *self) { return self->is_integer_ty(32, ir::signedness::UNSIGNED); })
.def("is_uint64", [](ir::type *self) { return self->is_integer_ty(64, ir::signedness::UNSIGNED); })
.def("repr", &ir::type::repr)
.def_property_readonly("fp_mantissa_width", &ir::type::get_fp_mantissa_width)
.def_property_readonly("scalar", &ir::type::get_scalar_ty)
.def_property_readonly("context", &ir::type::get_context, ret::reference);
@@ -713,6 +715,8 @@ void init_triton_ir(py::module &&m) {
.def("get_value", (ir::value * (ir::module::*)(const std::string &)) & ir::module::get_value, ret::reference)
.def("get_values", &ir::module::get_values, ret::reference)
.def("set_values", &ir::module::set_values)
.def("get_types", &ir::module::get_types, ret::reference)
.def("set_types", &ir::module::set_types)
.def_property_readonly("builder", &ir::module::get_builder, ret::reference);
using eattr = ir::attribute_kind_t;

View File

@@ -986,12 +986,14 @@ class JITFunction:
gscope = generator.gscope.copy()
lscope = generator.lscope.copy()
values = generator.module.get_values().copy()
types = generator.module.get_types().copy()
generator.gscope = sys.modules[self.fn.__module__].__dict__
generator.lscope = dict()
ret = generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=arg_values)
generator.gscope = gscope
generator.lscope = lscope
generator.module.set_values(values)
generator.module.set_types(types)
return ret
except Exception as e:
node = generator.last_node

View File

@@ -699,6 +699,11 @@ def umulhi(x, y, _builder=None):
return frontend.umulhi(x, y, _builder)
# Floating-point division of `x` by `y`, forwarded to frontend.fdiv.
# `ieee_rounding` (default False) is passed through unchanged together with
# the IR builder supplied as `_builder`.
@builtin
def fdiv(x, y, ieee_rounding=False, _builder=None):
return frontend.fdiv(x, y, ieee_rounding, _builder)
def _add_math_1arg_docstr(name):
def _decorator(func):
@@ -869,11 +874,11 @@ def sigmoid(x):
# Softmax along axis 0: subtracts the axis-0 maximum from `x`, exponentiates,
# and divides by the axis-0 sum of the exponentials. The final division goes
# through fdiv so that `ieee_rounding` can be forwarded to it.
@triton.jit
@_add_math_1arg_docstr("softmax")
def softmax(x):
def softmax(x, ieee_rounding=False):
z = x - triton.language.max(x, 0)
num = triton.language.exp(z)
den = triton.language.sum(num, 0)
return num / den
return fdiv(num, den, ieee_rounding)
@triton.jit