More progress on WhileOp

This commit is contained in:
Yan Da
2022-04-05 17:55:43 +08:00
parent d7fbddc7d4
commit 39fad2b18a
4 changed files with 33 additions and 9 deletions

View File

@@ -14,6 +14,7 @@
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Verifier.h"
#include <llvm-6.0/llvm/ADT/SmallVector.h>
#include <optional>
#include <pybind11/buffer_info.h>
#include <pybind11/functional.h>
@@ -653,12 +654,14 @@ void init_triton_ir(py::module &&m) {
.def("size", [](mlir::Region &self) {
return self.getBlocks().size();
})
.def("empty", &mlir::Region::empty)
;
py::class_<mlir::Block>(m, "block")
.def("arg", [](mlir::Block &self, int index) -> mlir::BlockArgument {
return self.getArgument(index);
})
.def("get_num_arguments", &mlir::Block::getNumArguments)
.def("dump", &mlir::Block::dump)
.def("move_before", &mlir::Block::moveBefore)
.def("insert_before", &mlir::Block::insertBefore)
@@ -674,7 +677,14 @@ void init_triton_ir(py::module &&m) {
.def("replace_use_in_block_with", [](mlir::Block &self, mlir::Value &v, mlir::Value &newVal) {
v.replaceUsesWithIf(newVal, [&](mlir::OpOperand &operand){
mlir::Operation *user = operand.getOwner();
return user->getBlock() == &self;
mlir::Block *currentBlock = user->getBlock();
while (currentBlock) {
if (currentBlock == &self)
return true;
// Move up one level
currentBlock = currentBlock->getParent()->getParentOp()->getBlock();
}
return false;
});
})
;
@@ -886,8 +896,11 @@ void init_triton_ir(py::module &&m) {
mlir::Region *parent = self.getBlock()->getParent();
return self.createBlock(parent);
}, ret::reference)
.def("create_block_with_parent", [](mlir::OpBuilder &self, mlir::Region &parent) -> mlir::Block* {
return self.createBlock(&parent);
.def("create_block_with_parent", [](mlir::OpBuilder &self, mlir::Region &parent,
std::vector<mlir::Type> &argTypes) -> mlir::Block* {
auto argLoc = self.getUnknownLoc();
llvm::SmallVector<mlir::Location, 8> argLocs(argTypes.size(), argLoc);
return self.createBlock(&parent, {}, argTypes, argLocs);
}, ret::reference)
.def("new_block", [](mlir::OpBuilder &self) -> mlir::Block* {
return new mlir::Block();

View File

@@ -425,23 +425,32 @@ class CodeGenerator(ast.NodeVisitor):
if loop_defs[name].type == liveins[name].type:
# these are loop-carried values
names.append(name)
ret_types.append(loop_defs[name].type.to_ir(self.builder))
ret_types.append(loop_defs[name].type)
init_args.append(liveins[name])
yields.append(loop_defs[name])
self.builder.set_insertion_point_to_end(insert_block)
while_op = self.builder.create_while_op(ret_types, init_args)
while_op = self.builder.create_while_op([ty.to_ir(self.builder) for ty in ret_types],
[arg.handle for arg in init_args])
# merge the condition region
before_block = self.builder.create_block_with_parent(while_op.get_before())
before_block = self.builder.create_block_with_parent(while_op.get_before(),
[ty.to_ir(self.builder) for ty in ret_types])
cond_block.merge_block_before(before_block)
self.builder.set_insertion_point_to_end(before_block)
self.builder.create_condtion_op(cond.handle, [])
# create CondtionOp: e.g., scf.condition(%cond) %arg0, %arg1, ...
self.builder.create_condtion_op(cond.handle, [before_block.arg(i) for i in range(len(init_args))])
# merge the loop body
after_block = self.builder.create_block_with_parent(while_op.get_after())
after_block = self.builder.create_block_with_parent(while_op.get_after(),
[ty.to_ir(self.builder) for ty in ret_types])
loop_block.merge_block_before(after_block)
self.builder.set_insertion_point_to_end(after_block)
self.builder.create_yield_op([y.handle for y in yields])
# update global_uses in while_op
for i, name in enumerate(names):
before_block.replace_use_in_block_with(init_args[i].handle, before_block.arg(i))
after_block.replace_use_in_block_with(init_args[i].handle, after_block.arg(i))
self.builder.set_insertion_point_to_end(insert_block)
self.lscope = liveins
self.local_defs = prev_defs