More progress on WhileOp codegen
This commit is contained in:
@@ -638,17 +638,6 @@ void init_triton_ir(py::module &&m) {
|
||||
// // py::class_<ir::undef_value, ir::constant>(m, "undef")
|
||||
// // .def("get", &ir::undef_value::get, ret::reference);
|
||||
|
||||
py::class_<mlir::ModuleOp>(m, "module")
|
||||
// .def("set_attr")
|
||||
.def("dump", [](mlir::ModuleOp &self) -> void {
|
||||
self.dump();
|
||||
})
|
||||
.def("push_back", [](mlir::ModuleOp &self, mlir::FuncOp &funcOp) -> void {
|
||||
self.push_back(funcOp);
|
||||
})
|
||||
.def("get_context", &mlir::ModuleOp::getContext)
|
||||
;
|
||||
|
||||
py::class_<mlir::Type>(m, "type")
|
||||
.def("is_integer", &mlir::Type::isInteger)
|
||||
.def("is_fp16", &mlir::Type::isF16)
|
||||
@@ -753,13 +742,27 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("get_else_yield", &mlir::scf::IfOp::elseYield)
|
||||
;
|
||||
py::class_<mlir::scf::YieldOp, mlir::OpState>(m, "YieldOp");
|
||||
py::class_<mlir::scf::WhileOp, mlir::OpState>(m, "WhileOp")
|
||||
.def("get_before", &mlir::scf::WhileOp::getBefore, ret::reference)
|
||||
.def("get_after", &mlir::scf::WhileOp::getAfter, ret::reference);
|
||||
py::class_<mlir::scf::ConditionOp, mlir::OpState>(m, "CondtionOp");
|
||||
|
||||
py::class_<mlir::ModuleOp, mlir::OpState>(m, "module")
|
||||
// .def("set_attr")
|
||||
.def("dump", [](mlir::ModuleOp &self) -> void {
|
||||
self.dump();
|
||||
})
|
||||
.def("push_back", [](mlir::ModuleOp &self, mlir::FuncOp &funcOp) -> void {
|
||||
self.push_back(funcOp);
|
||||
})
|
||||
;
|
||||
|
||||
py::class_<mlir::OpBuilder::InsertPoint>(m, "InsertPoint");
|
||||
|
||||
py::class_<mlir::OpBuilder>(m, "builder", py::dynamic_attr())
|
||||
.def(py::init<mlir::MLIRContext *>())
|
||||
// // getters
|
||||
// .def_property_readonly("context", &ir::builder::get_context, ret::reference);
|
||||
.def_property_readonly("context", &mlir::OpBuilder::getContext, ret::reference)
|
||||
.def("create_module", [](mlir::OpBuilder &self) -> mlir::ModuleOp {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<mlir::ModuleOp>(loc);
|
||||
@@ -883,6 +886,9 @@ void init_triton_ir(py::module &&m) {
|
||||
mlir::Region *parent = self.getBlock()->getParent();
|
||||
return self.createBlock(parent);
|
||||
}, ret::reference)
|
||||
.def("create_block_with_parent", [](mlir::OpBuilder &self, mlir::Region &parent) -> mlir::Block* {
|
||||
return self.createBlock(&parent);
|
||||
})
|
||||
.def("new_block", [](mlir::OpBuilder &self) -> mlir::Block* {
|
||||
return new mlir::Block();
|
||||
}, ret::reference)
|
||||
@@ -900,7 +906,16 @@ void init_triton_ir(py::module &&m) {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<mlir::scf::YieldOp>(loc, yields);
|
||||
})
|
||||
// // .def("create_while")
|
||||
.def("create_while_op", [](mlir::OpBuilder &self, std::vector<mlir::Type> &retTypes,
|
||||
std::vector<mlir::Value> &initArgs) -> mlir::scf::WhileOp {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<mlir::scf::WhileOp>(loc, retTypes, initArgs);
|
||||
})
|
||||
.def("create_condtion_op", [](mlir::OpBuilder &self, mlir::Value &cond,
|
||||
std::vector<mlir::Value> &args) -> mlir::scf::ConditionOp {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<mlir::scf::ConditionOp>(loc, cond, args);
|
||||
})
|
||||
|
||||
// miscellious
|
||||
.def("create_make_range", [](mlir::OpBuilder &self, int start, int end) -> mlir::Value {
|
||||
|
@@ -46,6 +46,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
# SSA-construction
|
||||
# name => triton.language.tensor
|
||||
self.local_defs: Dict[str, triton.language.tensor] = {}
|
||||
self.global_uses: Dict[str, triton.language.tensor] = {}
|
||||
|
||||
def get_value(self, name):
|
||||
''' This function:
|
||||
@@ -57,6 +58,8 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
ret = None
|
||||
if name in self.lscope:
|
||||
ret = self.lscope[name]
|
||||
if name not in self.local_defs:
|
||||
self.global_uses[name] = ret
|
||||
# search node.id in global scope
|
||||
elif name in self.gscope:
|
||||
ret = self.gscope[name]
|
||||
@@ -263,27 +266,6 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
def visit_If(self, node):
|
||||
cond = self.visit(node.test)
|
||||
if isinstance(cond, triton.language.tensor):
|
||||
# cond = cond.to(triton.language.int1, _builder=self.builder)
|
||||
# current_bb = self.builder.get_insertion_block()
|
||||
# then_bb = _triton.ir.basic_block.create(self.builder.context, "then", current_bb.parent)
|
||||
# else_bb = _triton.ir.basic_block.create(self.builder.context, "else", current_bb.parent) if node.orelse else None
|
||||
# endif_bb = _triton.ir.basic_block.create(self.builder.context, "endif", current_bb.parent)
|
||||
# if else_bb:
|
||||
# self.builder.cond_br(cond.handle, then_bb, else_bb)
|
||||
# else:
|
||||
# self.builder.cond_br(cond.handle, then_bb, endif_bb)
|
||||
# self.builder.set_insert_block(then_bb)
|
||||
# is_terminator = self.visit_compound_statement(node.body)
|
||||
# # TODO: last statement is a terminator?
|
||||
# if not is_terminator:
|
||||
# self.builder.br(endif_bb)
|
||||
# if else_bb:
|
||||
# self.builder.set_insert_block(else_bb)
|
||||
# is_terminator = self.visit_compound_statement(node.orelse)
|
||||
# # TODO: last statement is a terminator?
|
||||
# if not is_terminator:
|
||||
# self.builder.br(endif_bb)
|
||||
# self.builder.set_insert_block(endif_bb)
|
||||
cond = cond.to(triton.language.int1, _builder=self.builder)
|
||||
liveins = self.lscope.copy()
|
||||
parent_defs = self.local_defs.copy()
|
||||
@@ -413,22 +395,64 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
return getattr(op, fn)()
|
||||
|
||||
def visit_While(self, node):
|
||||
current_bb = self.builder.get_insertion_block()
|
||||
loop_bb = _triton.ir.basic_block.create(self.builder.context, "loop", current_bb.parent)
|
||||
next_bb = _triton.ir.basic_block.create(self.builder.context, "postloop", current_bb.parent)
|
||||
liveins = self.lscope.copy()
|
||||
prev_defs = self.local_defs.copy()
|
||||
self.local_defs = {}
|
||||
|
||||
def continue_fn():
|
||||
cond = self.visit(node.test)
|
||||
return self.builder.cond_br(cond.handle, loop_bb, next_bb)
|
||||
insert_block = self.builder.get_insertion_block()
|
||||
|
||||
continue_fn()
|
||||
self.builder.set_insert_block(loop_bb)
|
||||
# condtion (the before region)
|
||||
cond_block = self.builder.create_block()
|
||||
self.builder.set_insertion_point_to_start(cond_block)
|
||||
cond = self.visit(node.test)
|
||||
|
||||
# loop body (the after region)
|
||||
loop_block = self.builder.create_block()
|
||||
self.builder.set_insertion_point_to_start(loop_block)
|
||||
self.visit_compound_statement(node.body)
|
||||
continue_fn()
|
||||
stop_bb = self.builder.get_insertion_block()
|
||||
self.builder.set_insert_block(next_bb)
|
||||
loop_defs = self.local_defs
|
||||
|
||||
# collect loop-carried values
|
||||
names = []
|
||||
ret_types = []
|
||||
init_args = []
|
||||
yields = []
|
||||
for name in loop_defs:
|
||||
if name in liveins:
|
||||
# We should not def new constexpr (?)
|
||||
assert self.is_triton_tensor(loop_defs[name])
|
||||
assert self.is_triton_tensor(liveins[name])
|
||||
if loop_defs[name].type == liveins[name].type:
|
||||
# these are loop-carried values
|
||||
names.append(name)
|
||||
ret_types.append(loop_defs[name].type.to_ir(self.builder))
|
||||
init_args.append(liveins[name])
|
||||
yields.append(loop_defs[name])
|
||||
|
||||
self.builder.set_insertion_point_to_end(insert_block)
|
||||
while_op = self.builder.create_while_op(ret_types, init_args)
|
||||
# merge the condition region
|
||||
before_block = self.builder.create_block_with_parent(while_op.get_before())
|
||||
cond_block.merge_block_before(before_block)
|
||||
self.builder.set_insertion_point_to_end(before_block)
|
||||
self.builder.create_condtion_op(cond.handle, [])
|
||||
# merge the loop body
|
||||
after_block = self.builder.create_block_with_parent(while_op.get_after())
|
||||
loop_block.merge_block_before(after_block)
|
||||
self.builder.set_insertion_point_to_end(after_block)
|
||||
self.builder.create_yield_op([y.handle for y in yields])
|
||||
|
||||
self.builder.set_insertion_point_to_end(insert_block)
|
||||
self.lscope = liveins
|
||||
self.local_defs = prev_defs
|
||||
# WhileOp defines new values, update the symbol table (lscope, local_defs)
|
||||
for i, name in enumerate(names):
|
||||
new_def = triton.language.core.tensor(while_op.get_result(i), ret_types[i])
|
||||
self.lscope[name] = new_def
|
||||
self.local_defs[name] = new_def
|
||||
|
||||
for stmt in node.orelse:
|
||||
assert False, "Not implemented"
|
||||
ast.NodeVisitor.generic_visit(self, stmt)
|
||||
|
||||
def visit_Subscript(self, node):
|
||||
|
Reference in New Issue
Block a user