ForOp's SSA construction
This commit is contained in:
@@ -659,12 +659,35 @@ void init_triton_ir(py::module &&m) {
|
||||
py::class_<mlir::BlockArgument, mlir::Value>(m, "block_arguement")
|
||||
;
|
||||
|
||||
py::class_<mlir::Region>(m, "region")
|
||||
.def("get_parent_region", &mlir::Region::getParentRegion, ret::reference)
|
||||
.def("size", [](mlir::Region &self) {
|
||||
return self.getBlocks().size();
|
||||
})
|
||||
;
|
||||
|
||||
py::class_<mlir::Block>(m, "block")
|
||||
.def("arg", [](mlir::Block &self, int index) -> mlir::BlockArgument {
|
||||
return self.getArgument(index);
|
||||
})
|
||||
.def("dump", &mlir::Block::dump)
|
||||
.def("move_before", &mlir::Block::moveBefore)
|
||||
.def("insert_before", &mlir::Block::insertBefore)
|
||||
.def("get_parent", &mlir::Block::getParent, ret::reference)
|
||||
.def("merge_block_before", [](mlir::Block &self, mlir::Block &dst) {
|
||||
// ref: RewriterBase::mergeBlocks()
|
||||
if (self.getNumArguments() != 0)
|
||||
throw std::runtime_error("This block has arguments, don't merge");
|
||||
dst.getOperations().splice(dst.end(), self.getOperations());
|
||||
self.dropAllUses();
|
||||
self.erase();
|
||||
})
|
||||
.def("replace_use_in_block_with", [](mlir::Block &self, mlir::Value &v, mlir::Value &newVal) {
|
||||
v.replaceUsesWithIf(newVal, [&](mlir::OpOperand &operand){
|
||||
mlir::Operation *user = operand.getOwner();
|
||||
return user->getBlock() == &self;
|
||||
});
|
||||
})
|
||||
;
|
||||
|
||||
// py::class_<mlir::ModuleOp>(m, "module")
|
||||
@@ -713,12 +736,13 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("get_region", [](mlir::OpState &self, unsigned idx) -> mlir::Region& {
|
||||
return self->getRegion(idx);
|
||||
}, ret::reference)
|
||||
.def("get_body", [](mlir::scf::ForOp &self, unsigned idx) -> mlir::Block* {
|
||||
return self.getBody(idx);
|
||||
}, ret::reference)
|
||||
.def("dump", [](mlir::OpState &self) { self->dump(); })
|
||||
;
|
||||
// scf Ops
|
||||
py::class_<mlir::scf::ForOp, mlir::OpState>(m, "ForOp")
|
||||
.def("get_body", [](mlir::scf::ForOp &self) -> mlir::Block* {
|
||||
return self.getBody();
|
||||
}, ret::reference);
|
||||
py::class_<mlir::scf::ForOp, mlir::OpState>(m, "ForOp");
|
||||
py::class_<mlir::scf::IfOp, mlir::OpState>(m, "IfOp")
|
||||
.def("get_then_block", &mlir::scf::IfOp::thenBlock, ret::reference)
|
||||
.def("get_else_block", &mlir::scf::IfOp::elseBlock, ret::reference)
|
||||
@@ -856,11 +880,14 @@ void init_triton_ir(py::module &&m) {
|
||||
mlir::Region *parent = self.getBlock()->getParent();
|
||||
return self.createBlock(parent);
|
||||
}, ret::reference)
|
||||
.def("new_block", [](mlir::OpBuilder &self) -> mlir::Block* {
|
||||
return new mlir::Block();
|
||||
}, ret::reference)
|
||||
// Structured control flow
|
||||
.def("create_for_op", [](mlir::OpBuilder &self, mlir::Value &lb, mlir::Value &ub,
|
||||
mlir::Value &step, std::vector<mlir::Value> &initArgs) -> mlir::scf::ForOp {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<mlir::scf::ForOp>(loc, lb, ub, step);
|
||||
return self.create<mlir::scf::ForOp>(loc, lb, ub, step, initArgs);
|
||||
})
|
||||
.def("create_if_of", [](mlir::OpBuilder &self, mlir::Value &condition) -> mlir::scf::IfOp {
|
||||
auto loc = self.getUnknownLoc();
|
||||
|
@@ -195,7 +195,6 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
assert len(_names) == 1
|
||||
names = _names[0]
|
||||
values = self.visit(node.value)
|
||||
print(f'visit_Assign({names}, {values})')
|
||||
if not isinstance(names, tuple):
|
||||
names = [names]
|
||||
if not isinstance(values, tuple):
|
||||
@@ -440,34 +439,30 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
init_args.append(triton.language.core._to_tensor(liveins[name], self.builder))
|
||||
yields.append(triton.language.core._to_tensor(self.local_defs[name], self.builder))
|
||||
|
||||
print(f'names: {names}')
|
||||
print("After insert block body")
|
||||
self.module.dump()
|
||||
self.builder.set_insertion_point_to_end(insert_block)
|
||||
for_op = self.builder.create_for_op(lb, ub, step, [arg.handle for arg in init_args])
|
||||
print("After create for op")
|
||||
self.module.dump()
|
||||
# TODO: revisit blocks & regions of ForOp
|
||||
block.move_before(for_op.get_body())
|
||||
self.builder.set_insertion_point_to_end(for_op.get_body())
|
||||
# FIXME: the body should be a region (?)
|
||||
# FIXME: this won't work for nested control flow
|
||||
block.merge_block_before(for_op.get_body(0))
|
||||
self.builder.set_insertion_point_to_end(for_op.get_body(0))
|
||||
self.builder.create_yield_op([y.handle for y in yields])
|
||||
print("After create yield op")
|
||||
self.module.dump()
|
||||
|
||||
for_op_region = for_op.get_body(0).get_parent()
|
||||
assert for_op_region.size() == 1, "(For developer) Should use region here"
|
||||
for i, name in enumerate(names):
|
||||
# arg0 is the induction variable
|
||||
for_op.get_body(0).replace_use_in_block_with(init_args[i].handle, for_op.get_body(0).arg(i+1))
|
||||
|
||||
self.builder.set_insertion_point_to_end(insert_block)
|
||||
print("After restoring insert point")
|
||||
self.module.dump()
|
||||
|
||||
self.lscope = liveins
|
||||
self.local_defs = prev_defs
|
||||
# ForOp defines new values
|
||||
for i, name in enumerate(names):
|
||||
self.lscope[name] = for_op.get_result(i)
|
||||
self.local_defs[name] = for_op.get_result(i)
|
||||
self.lscope[name] = triton.language.core.tensor(for_op.get_result(i), yields[i].type)
|
||||
self.local_defs[name] = triton.language.core.tensor(for_op.get_result(i), yields[i].type)
|
||||
|
||||
for stmt in node.orelse:
|
||||
assert False
|
||||
assert False, "Don't know what to do with else after for"
|
||||
ast.NodeVisitor.generic_visit(self, stmt)
|
||||
|
||||
def visit_Slice(self, node):
|
||||
|
Reference in New Issue
Block a user