[BACKEND] float division is now approximate by default (#446)
This commit is contained in:
@@ -556,6 +556,7 @@ void init_triton_frontend(py::module &&m) {
|
||||
m.def("mul", &ir::dispatch::mul, ret::reference);
|
||||
m.def("truediv", &ir::dispatch::truediv, ret::reference);
|
||||
m.def("floordiv", &ir::dispatch::floordiv, ret::reference);
|
||||
m.def("fdiv", &ir::dispatch::fdiv, ret::reference);
|
||||
m.def("mod", &ir::dispatch::mod, ret::reference);
|
||||
m.def("and_", &ir::dispatch::and_, ret::reference);
|
||||
m.def("or_", &ir::dispatch::or_, ret::reference);
|
||||
@@ -691,6 +692,7 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("is_uint32", [](ir::type *self) { return self->is_integer_ty(32, ir::signedness::UNSIGNED); })
|
||||
.def("is_uint64", [](ir::type *self) { return self->is_integer_ty(64, ir::signedness::UNSIGNED); })
|
||||
|
||||
.def("repr", &ir::type::repr)
|
||||
.def_property_readonly("fp_mantissa_width", &ir::type::get_fp_mantissa_width)
|
||||
.def_property_readonly("scalar", &ir::type::get_scalar_ty)
|
||||
.def_property_readonly("context", &ir::type::get_context, ret::reference);
|
||||
@@ -713,6 +715,8 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("get_value", (ir::value * (ir::module::*)(const std::string &)) & ir::module::get_value, ret::reference)
|
||||
.def("get_values", &ir::module::get_values, ret::reference)
|
||||
.def("set_values", &ir::module::set_values)
|
||||
.def("get_types", &ir::module::get_types, ret::reference)
|
||||
.def("set_types", &ir::module::set_types)
|
||||
.def_property_readonly("builder", &ir::module::get_builder, ret::reference);
|
||||
|
||||
using eattr = ir::attribute_kind_t;
|
||||
|
@@ -986,12 +986,14 @@ class JITFunction:
|
||||
gscope = generator.gscope.copy()
|
||||
lscope = generator.lscope.copy()
|
||||
values = generator.module.get_values().copy()
|
||||
types = generator.module.get_types().copy()
|
||||
generator.gscope = sys.modules[self.fn.__module__].__dict__
|
||||
generator.lscope = dict()
|
||||
ret = generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=arg_values)
|
||||
generator.gscope = gscope
|
||||
generator.lscope = lscope
|
||||
generator.module.set_values(values)
|
||||
generator.module.set_types(types)
|
||||
return ret
|
||||
except Exception as e:
|
||||
node = generator.last_node
|
||||
|
@@ -699,6 +699,11 @@ def umulhi(x, y, _builder=None):
|
||||
return frontend.umulhi(x, y, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
def fdiv(x, y, ieee_rounding=False, _builder=None):
|
||||
return frontend.fdiv(x, y, ieee_rounding, _builder)
|
||||
|
||||
|
||||
def _add_math_1arg_docstr(name):
|
||||
|
||||
def _decorator(func):
|
||||
@@ -869,11 +874,11 @@ def sigmoid(x):
|
||||
|
||||
@triton.jit
|
||||
@_add_math_1arg_docstr("softmax")
|
||||
def softmax(x):
|
||||
def softmax(x, ieee_rounding=False):
|
||||
z = x - triton.language.max(x, 0)
|
||||
num = triton.language.exp(z)
|
||||
den = triton.language.sum(num, 0)
|
||||
return num / den
|
||||
return fdiv(num, den, ieee_rounding)
|
||||
|
||||
|
||||
@triton.jit
|
||||
|
Reference in New Issue
Block a user