diff --git a/python/src/triton.cc b/python/src/triton.cc index 66cece243..953c62a40 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -740,6 +740,9 @@ void init_triton_ir(py::module &&m) { return self.getBody(idx); }, ret::reference) .def("dump", [](mlir::OpState &self) { self->dump(); }) + .def("append_operand", [](mlir::OpState &self, mlir::Value &val) { + self->insertOperands(self->getNumOperands(), val); + }) ; // scf Ops py::class_(m, "ForOp"); @@ -889,9 +892,9 @@ void init_triton_ir(py::module &&m) { auto loc = self.getUnknownLoc(); return self.create(loc, lb, ub, step, initArgs); }) - .def("create_if_of", [](mlir::OpBuilder &self, mlir::Value &condition) -> mlir::scf::IfOp { + .def("create_if_op", [](mlir::OpBuilder &self, std::vector &retTypes, mlir::Value &condition, bool withElse) -> mlir::scf::IfOp { auto loc = self.getUnknownLoc(); - return self.create(loc, condition); + return self.create(loc, retTypes, condition, withElse); }) .def("create_yield_op", [](mlir::OpBuilder &self, std::vector &yields) -> mlir::scf::YieldOp { auto loc = self.getUnknownLoc(); diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index dcf038136..cf1594757 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -284,15 +284,66 @@ class CodeGenerator(ast.NodeVisitor): # if not is_terminator: # self.builder.br(endif_bb) # self.builder.set_insert_block(endif_bb) - parent_values = self.lscope.copy() - self.visit_compound_statement(node.body) - then_values = self.lvalues.copy() - assert node.orelse - self.lscope = parent_values - self.visit_compound_statement(node.orelse) - else_values = self.lscope.copy() + cond = cond.to(triton.language.int1, _builder=self.builder) + liveins = self.lscope.copy() + parent_defs = self.local_defs.copy() + self.local_defs = {} + + ip_block = self.builder.get_insertion_block() + + then_block = self.builder.create_block() + self.builder.set_insertion_point_to_start(then_block) + self.visit_compound_statement(node.body) + then_defs = self.local_defs.copy() + + if then_defs or node.orelse: + if node.orelse: + self.local_defs = {} + else_block = self.builder.create_block() + self.builder.set_insertion_point_to_end(else_block) + self.visit_compound_statement(node.orelse) + else_defs = self.local_defs.copy() + else: + # collect else_defs + else_defs = {} + for name in then_defs: + if name in liveins: + # TODO: what if this is constexpr? + assert self.is_triton_tensor(then_defs[name]) + assert self.is_triton_tensor(liveins[name]) + else_defs[name] = liveins[name] + # collect yields + names = [] + ret_types = [] + for then_name in then_defs: + for else_name in else_defs: + if then_name == else_name: + if then_defs[then_name].type == else_defs[else_name].type: + names.append(then_name) + ret_types.append(then_defs[then_name].type) + + self.builder.set_insertion_point_to_end(ip_block) + + if then_defs or node.orelse: + if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, True) + then_yield_op = if_op.get_then_yield() + else_yield_op = if_op.get_else_yield() + for name in names: + then_yield_op.append_operand(then_defs[name].handle) + else_yield_op.append_operand(else_defs[name].handle) + else: + if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, False) + + self.builder.set_insertion_point_to_end(ip_block) + # restore values in the parent scope + self.lscope = liveins + self.local_defs = parent_defs + # update values yielded by IfOp + for i, name in enumerate(names): + new_tensor = triton.language.core.tensor(if_op.get_result(i), ret_types[i]) + self.lscope[name] = new_tensor + self.local_defs[name] = new_tensor - self.lvalues = join_if_lvalues(then_values, else_values) else: if isinstance(cond, triton.language.constexpr):