[BACKEND] float division is now approximate by default (#446)

This commit is contained in:
Philippe Tillet
2022-01-29 18:29:29 -08:00
committed by GitHub
parent bd52e530a0
commit bef76b142a
9 changed files with 44 additions and 3 deletions

View File

@@ -556,6 +556,7 @@ void init_triton_frontend(py::module &&m) {
m.def("mul", &ir::dispatch::mul, ret::reference);
m.def("truediv", &ir::dispatch::truediv, ret::reference);
m.def("floordiv", &ir::dispatch::floordiv, ret::reference);
m.def("fdiv", &ir::dispatch::fdiv, ret::reference);
m.def("mod", &ir::dispatch::mod, ret::reference);
m.def("and_", &ir::dispatch::and_, ret::reference);
m.def("or_", &ir::dispatch::or_, ret::reference);
@@ -691,6 +692,7 @@ void init_triton_ir(py::module &&m) {
.def("is_uint32", [](ir::type *self) { return self->is_integer_ty(32, ir::signedness::UNSIGNED); })
.def("is_uint64", [](ir::type *self) { return self->is_integer_ty(64, ir::signedness::UNSIGNED); })
.def("repr", &ir::type::repr)
.def_property_readonly("fp_mantissa_width", &ir::type::get_fp_mantissa_width)
.def_property_readonly("scalar", &ir::type::get_scalar_ty)
.def_property_readonly("context", &ir::type::get_context, ret::reference);
@@ -713,6 +715,8 @@ void init_triton_ir(py::module &&m) {
.def("get_value", (ir::value * (ir::module::*)(const std::string &)) & ir::module::get_value, ret::reference)
.def("get_values", &ir::module::get_values, ret::reference)
.def("set_values", &ir::module::set_values)
.def("get_types", &ir::module::get_types, ret::reference)
.def("set_types", &ir::module::set_types)
.def_property_readonly("builder", &ir::module::get_builder, ret::reference);
using eattr = ir::attribute_kind_t;

View File

@@ -986,12 +986,14 @@ class JITFunction:
gscope = generator.gscope.copy()
lscope = generator.lscope.copy()
values = generator.module.get_values().copy()
types = generator.module.get_types().copy()
generator.gscope = sys.modules[self.fn.__module__].__dict__
generator.lscope = dict()
ret = generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=arg_values)
generator.gscope = gscope
generator.lscope = lscope
generator.module.set_values(values)
generator.module.set_types(types)
return ret
except Exception as e:
node = generator.last_node

View File

@@ -699,6 +699,11 @@ def umulhi(x, y, _builder=None):
return frontend.umulhi(x, y, _builder)
@builtin
def fdiv(x, y, ieee_rounding=False, _builder=None):
return frontend.fdiv(x, y, ieee_rounding, _builder)
def _add_math_1arg_docstr(name):
def _decorator(func):
@@ -869,11 +874,11 @@ def sigmoid(x):
@triton.jit
@_add_math_1arg_docstr("softmax")
def softmax(x):
def softmax(x, ieee_rounding=False):
z = x - triton.language.max(x, 0)
num = triton.language.exp(z)
den = triton.language.sum(num, 0)
return num / den
return fdiv(num, den, ieee_rounding)
@triton.jit