Some progress on visit_If
This commit is contained in:
@@ -740,6 +740,9 @@ void init_triton_ir(py::module &&m) {
|
|||||||
return self.getBody(idx);
|
return self.getBody(idx);
|
||||||
}, ret::reference)
|
}, ret::reference)
|
||||||
.def("dump", [](mlir::OpState &self) { self->dump(); })
|
.def("dump", [](mlir::OpState &self) { self->dump(); })
|
||||||
|
.def("append_operand", [](mlir::OpState &self, mlir::Value &val) {
|
||||||
|
self->insertOperands(self->getNumOperands(), val);
|
||||||
|
})
|
||||||
;
|
;
|
||||||
// scf Ops
|
// scf Ops
|
||||||
py::class_<mlir::scf::ForOp, mlir::OpState>(m, "ForOp");
|
py::class_<mlir::scf::ForOp, mlir::OpState>(m, "ForOp");
|
||||||
@@ -889,9 +892,9 @@ void init_triton_ir(py::module &&m) {
|
|||||||
auto loc = self.getUnknownLoc();
|
auto loc = self.getUnknownLoc();
|
||||||
return self.create<mlir::scf::ForOp>(loc, lb, ub, step, initArgs);
|
return self.create<mlir::scf::ForOp>(loc, lb, ub, step, initArgs);
|
||||||
})
|
})
|
||||||
.def("create_if_of", [](mlir::OpBuilder &self, mlir::Value &condition) -> mlir::scf::IfOp {
|
.def("create_if_op", [](mlir::OpBuilder &self, std::vector<mlir::Type> &retTypes, mlir::Value &condition, bool withElse) -> mlir::scf::IfOp {
|
||||||
auto loc = self.getUnknownLoc();
|
auto loc = self.getUnknownLoc();
|
||||||
return self.create<mlir::scf::IfOp>(loc, condition);
|
return self.create<mlir::scf::IfOp>(loc, retTypes, condition, withElse);
|
||||||
})
|
})
|
||||||
.def("create_yield_op", [](mlir::OpBuilder &self, std::vector<mlir::Value> &yields) -> mlir::scf::YieldOp {
|
.def("create_yield_op", [](mlir::OpBuilder &self, std::vector<mlir::Value> &yields) -> mlir::scf::YieldOp {
|
||||||
auto loc = self.getUnknownLoc();
|
auto loc = self.getUnknownLoc();
|
||||||
|
@@ -284,15 +284,66 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
# if not is_terminator:
|
# if not is_terminator:
|
||||||
# self.builder.br(endif_bb)
|
# self.builder.br(endif_bb)
|
||||||
# self.builder.set_insert_block(endif_bb)
|
# self.builder.set_insert_block(endif_bb)
|
||||||
parent_values = self.lscope.copy()
|
cond = cond.to(triton.language.int1, _builder=self.builder)
|
||||||
self.visit_compound_statement(node.body)
|
liveins = self.lscope.copy()
|
||||||
then_values = self.lvalues.copy()
|
parent_defs = self.local_defs.copy()
|
||||||
assert node.orelse
|
self.local_defs = {}
|
||||||
self.lscope = parent_values
|
|
||||||
self.visit_compound_statement(node.orelse)
|
ip_block = self.builder.get_insertion_block()
|
||||||
else_values = self.lscope.copy()
|
|
||||||
|
then_block = self.builder.create_block()
|
||||||
|
self.builder.set_insertion_point_to_start(then_block)
|
||||||
|
self.visit_compound_statement(node.body)
|
||||||
|
then_defs = self.local_defs.copy()
|
||||||
|
|
||||||
|
if then_defs or node.orelse:
|
||||||
|
if node.orelse:
|
||||||
|
self.local_defs = {}
|
||||||
|
else_block = self.builder.create_block()
|
||||||
|
self.builder.set_insertion_point_to_end(else_block)
|
||||||
|
self.visit_compound_statement(node.orelse)
|
||||||
|
else_defs = self.local_defs.copy()
|
||||||
|
else:
|
||||||
|
# collect else_defs
|
||||||
|
else_defs = {}
|
||||||
|
for name in then_defs:
|
||||||
|
if name in liveins:
|
||||||
|
# TODO: what if this is constexpr?
|
||||||
|
assert self.is_triton_tensor(then_defs[name])
|
||||||
|
assert self.is_triton_tensor(liveins[name])
|
||||||
|
else_defs[name] = liveins[name]
|
||||||
|
# collect yields
|
||||||
|
names = []
|
||||||
|
ret_types = []
|
||||||
|
for then_name in then_defs:
|
||||||
|
for else_name in else_defs:
|
||||||
|
if then_name == else_name:
|
||||||
|
if then_defs[then_name].type == else_defs[else_name].type:
|
||||||
|
names.append(then_name)
|
||||||
|
ret_types.append(then_defs[then_name].type)
|
||||||
|
|
||||||
|
self.builder.set_insertion_point_to_end(ip_block)
|
||||||
|
|
||||||
|
if then_defs or node.orelse:
|
||||||
|
if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, True)
|
||||||
|
then_yield_op = if_op.get_then_yield()
|
||||||
|
else_yield_op = if_op.get_else_yield()
|
||||||
|
for name in names:
|
||||||
|
then_yield_op.append_operand(then_defs[name].handle)
|
||||||
|
else_yield_op.append_operand(else_defs[name].handle)
|
||||||
|
else:
|
||||||
|
if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, False)
|
||||||
|
|
||||||
|
self.builder.set_insertion_point_to_end(ip_block)
|
||||||
|
# restore values in the parent scope
|
||||||
|
self.lscope = liveins
|
||||||
|
self.local_defs = parent_defs
|
||||||
|
# update values yielded by IfOp
|
||||||
|
for i, name in enumerate(names):
|
||||||
|
new_tensor = triton.language.core.tensor(if_op.get_result(i), ret_types[i])
|
||||||
|
self.lscope[name] = new_tensor
|
||||||
|
self.local_defs[name] = new_tensor
|
||||||
|
|
||||||
self.lvalues = join_if_lvalues(then_values, else_values)
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if isinstance(cond, triton.language.constexpr):
|
if isinstance(cond, triton.language.constexpr):
|
||||||
|
Reference in New Issue
Block a user