Some progress on visit_If
This commit is contained in:
@@ -740,6 +740,9 @@ void init_triton_ir(py::module &&m) {
|
||||
return self.getBody(idx);
|
||||
}, ret::reference)
|
||||
.def("dump", [](mlir::OpState &self) { self->dump(); })
|
||||
.def("append_operand", [](mlir::OpState &self, mlir::Value &val) {
|
||||
self->insertOperands(self->getNumOperands(), val);
|
||||
})
|
||||
;
|
||||
// scf Ops
|
||||
py::class_<mlir::scf::ForOp, mlir::OpState>(m, "ForOp");
|
||||
@@ -889,9 +892,9 @@ void init_triton_ir(py::module &&m) {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<mlir::scf::ForOp>(loc, lb, ub, step, initArgs);
|
||||
})
|
||||
.def("create_if_of", [](mlir::OpBuilder &self, mlir::Value &condition) -> mlir::scf::IfOp {
|
||||
.def("create_if_op", [](mlir::OpBuilder &self, std::vector<mlir::Type> &retTypes, mlir::Value &condition, bool withElse) -> mlir::scf::IfOp {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<mlir::scf::IfOp>(loc, condition);
|
||||
return self.create<mlir::scf::IfOp>(loc, retTypes, condition, withElse);
|
||||
})
|
||||
.def("create_yield_op", [](mlir::OpBuilder &self, std::vector<mlir::Value> &yields) -> mlir::scf::YieldOp {
|
||||
auto loc = self.getUnknownLoc();
|
||||
|
@@ -284,15 +284,66 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
# if not is_terminator:
|
||||
# self.builder.br(endif_bb)
|
||||
# self.builder.set_insert_block(endif_bb)
|
||||
parent_values = self.lscope.copy()
|
||||
self.visit_compound_statement(node.body)
|
||||
then_values = self.lvalues.copy()
|
||||
assert node.orelse
|
||||
self.lscope = parent_values
|
||||
self.visit_compound_statement(node.orelse)
|
||||
else_values = self.lscope.copy()
|
||||
cond = cond.to(triton.language.int1, _builder=self.builder)
|
||||
liveins = self.lscope.copy()
|
||||
parent_defs = self.local_defs.copy()
|
||||
self.local_defs = {}
|
||||
|
||||
ip_block = self.builder.get_insertion_block()
|
||||
|
||||
then_block = self.builder.create_block()
|
||||
self.builder.set_insertion_point_to_start(then_block)
|
||||
self.visit_compound_statement(node.body)
|
||||
then_defs = self.local_defs.copy()
|
||||
|
||||
if then_defs or node.orelse:
|
||||
if node.orelse:
|
||||
self.local_defs = {}
|
||||
else_block = self.builder.create_block()
|
||||
self.builder.set_insertion_point_to_end(else_block)
|
||||
self.visit_compound_statement(node.orelse)
|
||||
else_defs = self.local_defs.copy()
|
||||
else:
|
||||
# collect else_defs
|
||||
else_defs = {}
|
||||
for name in then_defs:
|
||||
if name in liveins:
|
||||
# TODO: what if this is constexpr?
|
||||
assert self.is_triton_tensor(then_defs[name])
|
||||
assert self.is_triton_tensor(liveins[name])
|
||||
else_defs[name] = liveins[name]
|
||||
# collect yields
|
||||
names = []
|
||||
ret_types = []
|
||||
for then_name in then_defs:
|
||||
for else_name in else_defs:
|
||||
if then_name == else_name:
|
||||
if then_defs[then_name].type == else_defs[else_name].type:
|
||||
names.append(then_name)
|
||||
ret_types.append(then_defs[then_name].type)
|
||||
|
||||
self.builder.set_insertion_point_to_end(ip_block)
|
||||
|
||||
if then_defs or node.orelse:
|
||||
if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, True)
|
||||
then_yield_op = if_op.get_then_yield()
|
||||
else_yield_op = if_op.get_else_yield()
|
||||
for name in names:
|
||||
then_yield_op.append_operand(then_defs[name].handle)
|
||||
else_yield_op.append_operand(else_defs[name].handle)
|
||||
else:
|
||||
if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, False)
|
||||
|
||||
self.builder.set_insertion_point_to_end(ip_block)
|
||||
# restore values in the parent scope
|
||||
self.lscope = liveins
|
||||
self.local_defs = parent_defs
|
||||
# update values yielded by IfOp
|
||||
for i, name in enumerate(names):
|
||||
new_tensor = triton.language.core.tensor(if_op.get_result(i), ret_types[i])
|
||||
self.lscope[name] = new_tensor
|
||||
self.local_defs[name] = new_tensor
|
||||
|
||||
self.lvalues = join_if_lvalues(then_values, else_values)
|
||||
|
||||
else:
|
||||
if isinstance(cond, triton.language.constexpr):
|
||||
|
Reference in New Issue
Block a user