diff --git a/lib/ir/module.cc b/lib/ir/module.cc index 1665bffb7..626065224 100644 --- a/lib/ir/module.cc +++ b/lib/ir/module.cc @@ -107,9 +107,12 @@ ir::value *module::get_value_recursive(const std::string& name, ir::basic_block ir::phi_node* phi = make_phi(ty, 1, block); set_value(name, block, phi); result = add_phi_operands(name, phi); + if(auto *phi = dynamic_cast(result)) + result = try_remove_trivial_phis(phi); } - if(auto *phi = dynamic_cast(result)) + if(auto *phi = dynamic_cast(result)){ result = try_remove_trivial_phis(phi); + } set_value(name, block, result); return result; } diff --git a/python/src/triton.cc b/python/src/triton.cc index 97fd20a40..c102f4a35 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -181,6 +181,9 @@ void init_triton_ir(py::module &&m) { py::class_(m, "constant_float") .def_property_readonly("value", &ir::constant_fp::get_value); + py::class_(m, "instruction"); + py::class_(m, "phi_node"); + py::class_(m, "type") .def("is_ptr", &ir::type::is_pointer_ty) .def("is_int", static_cast(&ir::type::is_integer_ty)) diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index c2f6fdfd8..3a43acd16 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -51,7 +51,7 @@ class CodeGenerator(ast.NodeVisitor): break if add_scope: self.module.pop_scope() - return self.last_ret + return stmts and isinstance(stmt, ast.Return) def __init__(self, context, prototype, gscope, attributes, constants, kwargs): self.builder = _triton.ir.builder(context) @@ -85,7 +85,10 @@ class CodeGenerator(ast.NodeVisitor): # By design, only non-kernel functions can return def visit_Return(self, node): - return self.visit(node.value) + ret = self.visit(node.value) + if ret is None: + return self.builder.ret_void() + return ret def visit_FunctionDef(self, node, inline=False, arg_values=None): arg_names, kwarg_names = self.visit(node.args) @@ -112,7 +115,8 @@ class CodeGenerator(ast.NodeVisitor): for arg_name, arg_value in zip(arg_names, arg_values): self.set_value(arg_name, arg_value) if inline: - return self.visit_compound_statement(node.body, add_scope=True) + self.visit_compound_statement(node.body, add_scope=True) + return self.last_ret else: entry = _triton.ir.basic_block.create(self.builder.context, "entry", fn) self.module.seal_block(entry) @@ -140,6 +144,8 @@ class CodeGenerator(ast.NodeVisitor): assert len(names) == 1 name = names[0] value = self.visit(node.value) + if not isinstance(value, triton.language.block): + value = triton.language._to_ir(value, self.builder) self.set_value(names[0], value) def visit_AugAssign(self, node): @@ -208,14 +214,16 @@ class CodeGenerator(ast.NodeVisitor): else: self.builder.cond_br(cond.handle, then_bb, endif_bb) self.builder.set_insert_block(then_bb) - self.visit_compound_statement(node.body, add_scope=True) + is_terminator = self.visit_compound_statement(node.body, add_scope=True) # TODO: last statement is a terminator? - self.builder.br(endif_bb) + if not is_terminator: + self.builder.br(endif_bb) if else_bb: self.builder.set_insert_block(else_bb) - self.visit_compound_statement(node.orelse, add_scope=True) + is_terminator = self.visit_compound_statement(node.orelse, add_scope=True) #TODO: last statement is a terminator? - self.builder.br(endif_bb) + if not is_terminator: + self.builder.br(endif_bb) self.module.seal_block(endif_bb) self.builder.set_insert_block(endif_bb) else: diff --git a/python/triton/language.py b/python/triton/language.py index 0ee34151f..ccfda885a 100644 --- a/python/triton/language.py +++ b/python/triton/language.py @@ -4,22 +4,22 @@ from triton._C.libtriton.triton import frontend from functools import wraps +# convert block/dtype to ir values +def _to_ir(x, builder): + if isinstance(x, bool): + return builder.get_int1(x) + elif isinstance(x, int): + return builder.get_int32(x) + elif isinstance(x, float): + return builder.get_float32(x) + if isinstance(x, block): + return x.handle + if isinstance(x, dtype): + return x.handle(builder) + return x + + def _patch(fn): - - # convert block/dtype to ir values - def _to_ir(x, builder): - if isinstance(x, bool): - return builder.get_int1(x) - elif isinstance(x, int): - return builder.get_int32(x) - elif isinstance(x, float): - return builder.get_float32(x) - if isinstance(x, block): - return x.handle - if isinstance(x, dtype): - return x.handle(builder) - return x - def _from_ir(x): if isinstance(x, ir.value): if x.type.is_void(): @@ -306,6 +306,7 @@ def zeros(shape, dtype, builder=None): :param dtype: Data-type of the new array, e.g., :code:`triton.float16` :type dtype: DType """ + shape = [int(x.handle) if isinstance(x, block) else x for x in shape] return frontend.zeros(shape, dtype, builder)