More progress on WhileOp
This commit is contained in:
@@ -14,6 +14,7 @@
|
||||
#include "llvm/IR/LegacyPassManager.h"
|
||||
#include "llvm/IR/Verifier.h"
|
||||
|
||||
#include <llvm-6.0/llvm/ADT/SmallVector.h>
|
||||
#include <optional>
|
||||
#include <pybind11/buffer_info.h>
|
||||
#include <pybind11/functional.h>
|
||||
@@ -653,12 +654,14 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("size", [](mlir::Region &self) {
|
||||
return self.getBlocks().size();
|
||||
})
|
||||
.def("empty", &mlir::Region::empty)
|
||||
;
|
||||
|
||||
py::class_<mlir::Block>(m, "block")
|
||||
.def("arg", [](mlir::Block &self, int index) -> mlir::BlockArgument {
|
||||
return self.getArgument(index);
|
||||
})
|
||||
.def("get_num_arguments", &mlir::Block::getNumArguments)
|
||||
.def("dump", &mlir::Block::dump)
|
||||
.def("move_before", &mlir::Block::moveBefore)
|
||||
.def("insert_before", &mlir::Block::insertBefore)
|
||||
@@ -674,7 +677,14 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("replace_use_in_block_with", [](mlir::Block &self, mlir::Value &v, mlir::Value &newVal) {
|
||||
v.replaceUsesWithIf(newVal, [&](mlir::OpOperand &operand){
|
||||
mlir::Operation *user = operand.getOwner();
|
||||
return user->getBlock() == &self;
|
||||
mlir::Block *currentBlock = user->getBlock();
|
||||
while (currentBlock) {
|
||||
if (currentBlock == &self)
|
||||
return true;
|
||||
// Move up one level
|
||||
currentBlock = currentBlock->getParent()->getParentOp()->getBlock();
|
||||
}
|
||||
return false;
|
||||
});
|
||||
})
|
||||
;
|
||||
@@ -886,8 +896,11 @@ void init_triton_ir(py::module &&m) {
|
||||
mlir::Region *parent = self.getBlock()->getParent();
|
||||
return self.createBlock(parent);
|
||||
}, ret::reference)
|
||||
.def("create_block_with_parent", [](mlir::OpBuilder &self, mlir::Region &parent) -> mlir::Block* {
|
||||
return self.createBlock(&parent);
|
||||
.def("create_block_with_parent", [](mlir::OpBuilder &self, mlir::Region &parent,
|
||||
std::vector<mlir::Type> &argTypes) -> mlir::Block* {
|
||||
auto argLoc = self.getUnknownLoc();
|
||||
llvm::SmallVector<mlir::Location, 8> argLocs(argTypes.size(), argLoc);
|
||||
return self.createBlock(&parent, {}, argTypes, argLocs);
|
||||
}, ret::reference)
|
||||
.def("new_block", [](mlir::OpBuilder &self) -> mlir::Block* {
|
||||
return new mlir::Block();
|
||||
|
@@ -425,23 +425,32 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
if loop_defs[name].type == liveins[name].type:
|
||||
# these are loop-carried values
|
||||
names.append(name)
|
||||
ret_types.append(loop_defs[name].type.to_ir(self.builder))
|
||||
ret_types.append(loop_defs[name].type)
|
||||
init_args.append(liveins[name])
|
||||
yields.append(loop_defs[name])
|
||||
|
||||
self.builder.set_insertion_point_to_end(insert_block)
|
||||
while_op = self.builder.create_while_op(ret_types, init_args)
|
||||
while_op = self.builder.create_while_op([ty.to_ir(self.builder) for ty in ret_types],
|
||||
[arg.handle for arg in init_args])
|
||||
# merge the condition region
|
||||
before_block = self.builder.create_block_with_parent(while_op.get_before())
|
||||
before_block = self.builder.create_block_with_parent(while_op.get_before(),
|
||||
[ty.to_ir(self.builder) for ty in ret_types])
|
||||
cond_block.merge_block_before(before_block)
|
||||
self.builder.set_insertion_point_to_end(before_block)
|
||||
self.builder.create_condtion_op(cond.handle, [])
|
||||
# create ConditionOp: e.g., scf.condition(%cond) %arg0, %arg1, ...
|
||||
self.builder.create_condtion_op(cond.handle, [before_block.arg(i) for i in range(len(init_args))])
|
||||
# merge the loop body
|
||||
after_block = self.builder.create_block_with_parent(while_op.get_after())
|
||||
after_block = self.builder.create_block_with_parent(while_op.get_after(),
|
||||
[ty.to_ir(self.builder) for ty in ret_types])
|
||||
loop_block.merge_block_before(after_block)
|
||||
self.builder.set_insertion_point_to_end(after_block)
|
||||
self.builder.create_yield_op([y.handle for y in yields])
|
||||
|
||||
# update global_uses in while_op
|
||||
for i, name in enumerate(names):
|
||||
before_block.replace_use_in_block_with(init_args[i].handle, before_block.arg(i))
|
||||
after_block.replace_use_in_block_with(init_args[i].handle, after_block.arg(i))
|
||||
|
||||
self.builder.set_insertion_point_to_end(insert_block)
|
||||
self.lscope = liveins
|
||||
self.local_defs = prev_defs
|
||||
|
@@ -47,6 +47,5 @@ z = torch.empty_like(x)
|
||||
# add_kernel[(1,)](x, y, z, size, 256)
|
||||
# print(add_kernel[(1,)].kernel.compile_to_ttir())
|
||||
mod, ctx = add_kernel.compile_to_ttir(x, y, z, size, 128, 8, grid=(1,))
|
||||
mod.get_context()
|
||||
mod.dump()
|
||||
# print(mod)
|
||||
|
@@ -16,3 +16,6 @@ def generic_while(lb, value):
|
||||
locks = torch.zeros(32, dtype=torch.int32, device='cuda')
|
||||
mod_atomic, ctx_atomic = atomic.compile_to_ttir(locks, grid=(1,))
|
||||
mod_atomic.dump()
|
||||
|
||||
mod_generic_while, ctx_generic_while = generic_while.compile_to_ttir(8, 9, grid=(1,))
|
||||
mod_generic_while.dump()
|
||||
|
Reference in New Issue
Block a user