ForOp's SSA construction

This commit is contained in:
Yan Da
2022-04-03 19:11:47 +08:00
parent 61413b8a97
commit c71c50cd0c
2 changed files with 44 additions and 22 deletions

View File

@@ -659,12 +659,35 @@ void init_triton_ir(py::module &&m) {
py::class_<mlir::BlockArgument, mlir::Value>(m, "block_arguement")
;
py::class_<mlir::Region>(m, "region")
.def("get_parent_region", &mlir::Region::getParentRegion, ret::reference)
.def("size", [](mlir::Region &self) {
return self.getBlocks().size();
})
;
py::class_<mlir::Block>(m, "block")
.def("arg", [](mlir::Block &self, int index) -> mlir::BlockArgument {
return self.getArgument(index);
})
.def("dump", &mlir::Block::dump)
.def("move_before", &mlir::Block::moveBefore)
.def("insert_before", &mlir::Block::insertBefore)
.def("get_parent", &mlir::Block::getParent, ret::reference)
.def("merge_block_before", [](mlir::Block &self, mlir::Block &dst) {
// ref: RewriterBase::mergeBlocks()
if (self.getNumArguments() != 0)
throw std::runtime_error("This block has arguments, don't merge");
dst.getOperations().splice(dst.end(), self.getOperations());
self.dropAllUses();
self.erase();
})
.def("replace_use_in_block_with", [](mlir::Block &self, mlir::Value &v, mlir::Value &newVal) {
v.replaceUsesWithIf(newVal, [&](mlir::OpOperand &operand){
mlir::Operation *user = operand.getOwner();
return user->getBlock() == &self;
});
})
;
// py::class_<mlir::ModuleOp>(m, "module")
@@ -713,12 +736,13 @@ void init_triton_ir(py::module &&m) {
.def("get_region", [](mlir::OpState &self, unsigned idx) -> mlir::Region& {
return self->getRegion(idx);
}, ret::reference)
.def("get_body", [](mlir::scf::ForOp &self, unsigned idx) -> mlir::Block* {
return self.getBody(idx);
}, ret::reference)
.def("dump", [](mlir::OpState &self) { self->dump(); })
;
// scf Ops
py::class_<mlir::scf::ForOp, mlir::OpState>(m, "ForOp")
.def("get_body", [](mlir::scf::ForOp &self) -> mlir::Block* {
return self.getBody();
}, ret::reference);
py::class_<mlir::scf::ForOp, mlir::OpState>(m, "ForOp");
py::class_<mlir::scf::IfOp, mlir::OpState>(m, "IfOp")
.def("get_then_block", &mlir::scf::IfOp::thenBlock, ret::reference)
.def("get_else_block", &mlir::scf::IfOp::elseBlock, ret::reference)
@@ -856,11 +880,14 @@ void init_triton_ir(py::module &&m) {
mlir::Region *parent = self.getBlock()->getParent();
return self.createBlock(parent);
}, ret::reference)
.def("new_block", [](mlir::OpBuilder &self) -> mlir::Block* {
return new mlir::Block();
}, ret::reference)
// Structured control flow
.def("create_for_op", [](mlir::OpBuilder &self, mlir::Value &lb, mlir::Value &ub,
mlir::Value &step, std::vector<mlir::Value> &initArgs) -> mlir::scf::ForOp {
auto loc = self.getUnknownLoc();
return self.create<mlir::scf::ForOp>(loc, lb, ub, step);
return self.create<mlir::scf::ForOp>(loc, lb, ub, step, initArgs);
})
.def("create_if_of", [](mlir::OpBuilder &self, mlir::Value &condition) -> mlir::scf::IfOp {
auto loc = self.getUnknownLoc();

View File

@@ -195,7 +195,6 @@ class CodeGenerator(ast.NodeVisitor):
assert len(_names) == 1
names = _names[0]
values = self.visit(node.value)
print(f'visit_Assign({names}, {values})')
if not isinstance(names, tuple):
names = [names]
if not isinstance(values, tuple):
@@ -440,34 +439,30 @@ class CodeGenerator(ast.NodeVisitor):
init_args.append(triton.language.core._to_tensor(liveins[name], self.builder))
yields.append(triton.language.core._to_tensor(self.local_defs[name], self.builder))
print(f'names: {names}')
print("After insert block body")
self.module.dump()
self.builder.set_insertion_point_to_end(insert_block)
for_op = self.builder.create_for_op(lb, ub, step, [arg.handle for arg in init_args])
print("After create for op")
self.module.dump()
# TODO: revisit blocks & regions of ForOp
block.move_before(for_op.get_body())
self.builder.set_insertion_point_to_end(for_op.get_body())
# FIXME: the body should be a region (?)
# FIXME: this won't work for nested control flow
block.merge_block_before(for_op.get_body(0))
self.builder.set_insertion_point_to_end(for_op.get_body(0))
self.builder.create_yield_op([y.handle for y in yields])
print("After create yield op")
self.module.dump()
for_op_region = for_op.get_body(0).get_parent()
assert for_op_region.size() == 1, "(For developer) Should use region here"
for i, name in enumerate(names):
# arg0 is the induction variable
for_op.get_body(0).replace_use_in_block_with(init_args[i].handle, for_op.get_body(0).arg(i+1))
self.builder.set_insertion_point_to_end(insert_block)
print("After restoring insert point")
self.module.dump()
self.lscope = liveins
self.local_defs = prev_defs
# ForOp defines new values
for i, name in enumerate(names):
self.lscope[name] = for_op.get_result(i)
self.local_defs[name] = for_op.get_result(i)
self.lscope[name] = triton.language.core.tensor(for_op.get_result(i), yields[i].type)
self.local_defs[name] = triton.language.core.tensor(for_op.get_result(i), yields[i].type)
for stmt in node.orelse:
assert False
assert False, "Don't know what to do with else after for"
ast.NodeVisitor.generic_visit(self, stmt)
def visit_Slice(self, node):