Some progress on visit_If

This commit is contained in:
Yan Da
2022-04-03 22:34:46 +08:00
parent c71c50cd0c
commit 9df899b291
2 changed files with 64 additions and 10 deletions

View File

@@ -740,6 +740,9 @@ void init_triton_ir(py::module &&m) {
return self.getBody(idx);
}, ret::reference)
.def("dump", [](mlir::OpState &self) { self->dump(); })
.def("append_operand", [](mlir::OpState &self, mlir::Value &val) {
self->insertOperands(self->getNumOperands(), val);
})
;
// scf Ops
py::class_<mlir::scf::ForOp, mlir::OpState>(m, "ForOp");
@@ -889,9 +892,9 @@ void init_triton_ir(py::module &&m) {
auto loc = self.getUnknownLoc();
return self.create<mlir::scf::ForOp>(loc, lb, ub, step, initArgs);
})
.def("create_if_of", [](mlir::OpBuilder &self, mlir::Value &condition) -> mlir::scf::IfOp {
.def("create_if_op", [](mlir::OpBuilder &self, std::vector<mlir::Type> &retTypes, mlir::Value &condition, bool withElse) -> mlir::scf::IfOp {
auto loc = self.getUnknownLoc();
return self.create<mlir::scf::IfOp>(loc, condition);
return self.create<mlir::scf::IfOp>(loc, retTypes, condition, withElse);
})
.def("create_yield_op", [](mlir::OpBuilder &self, std::vector<mlir::Value> &yields) -> mlir::scf::YieldOp {
auto loc = self.getUnknownLoc();

View File

@@ -284,15 +284,66 @@ class CodeGenerator(ast.NodeVisitor):
# if not is_terminator:
# self.builder.br(endif_bb)
# self.builder.set_insert_block(endif_bb)
parent_values = self.lscope.copy()
self.visit_compound_statement(node.body)
then_values = self.lvalues.copy()
assert node.orelse
self.lscope = parent_values
self.visit_compound_statement(node.orelse)
else_values = self.lscope.copy()
cond = cond.to(triton.language.int1, _builder=self.builder)
liveins = self.lscope.copy()
parent_defs = self.local_defs.copy()
self.local_defs = {}
ip_block = self.builder.get_insertion_block()
then_block = self.builder.create_block()
self.builder.set_insertion_point_to_start(then_block)
self.visit_compound_statement(node.body)
then_defs = self.local_defs.copy()
if then_defs or node.orelse:
if node.orelse:
self.local_defs = {}
else_block = self.builder.create_block()
self.builder.set_insertion_point_to_end(else_block)
self.visit_compound_statement(node.orelse)
else_defs = self.local_defs.copy()
else:
# collect else_defs
else_defs = {}
for name in then_defs:
if name in liveins:
# TODO: what if this is constexpr?
assert self.is_triton_tensor(then_defs[name])
assert self.is_triton_tensor(liveins[name])
else_defs[name] = liveins[name]
# collect yields
names = []
ret_types = []
for then_name in then_defs:
for else_name in else_defs:
if then_name == else_name:
if then_defs[then_name].type == else_defs[else_name].type:
names.append(then_name)
ret_types.append(then_defs[then_name].type)
self.builder.set_insertion_point_to_end(ip_block)
if then_defs or node.orelse:
if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, True)
then_yield_op = if_op.get_then_yield()
else_yield_op = if_op.get_else_yield()
for name in names:
then_yield_op.append_operand(then_defs[name].handle)
else_yield_op.append_operand(else_defs[name].handle)
else:
if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, False)
self.builder.set_insertion_point_to_end(ip_block)
# restore values in the parent scope
self.lscope = liveins
self.local_defs = parent_defs
# update values yielded by IfOp
for i, name in enumerate(names):
new_tensor = triton.language.core.tensor(if_op.get_result(i), ret_types[i])
self.lscope[name] = new_tensor
self.local_defs[name] = new_tensor
self.lvalues = join_if_lvalues(then_values, else_values)
else:
if isinstance(cond, triton.language.constexpr):