More on scf Ops
This commit is contained in:
@@ -711,6 +711,17 @@ void init_triton_ir(py::module &&m) {
|
|||||||
.def("dump", [](mlir::FuncOp &self) { self.dump(); })
|
.def("dump", [](mlir::FuncOp &self) { self.dump(); })
|
||||||
;
|
;
|
||||||
|
|
||||||
|
// Ops
|
||||||
|
py::class_<mlir::scf::ForOp>(m, "ForOp")
|
||||||
|
.def("get_body", &mlir::scf::ForOp::getBody, ret::reference);
|
||||||
|
py::class_<mlir::scf::IfOp>(m, "IfOp")
|
||||||
|
.def("get_then_block", &mlir::scf::IfOp::thenBlock, ret::reference)
|
||||||
|
.def("get_else_block", &mlir::scf::IfOp::elseBlock, ret::reference)
|
||||||
|
.def("get_then_yield", &mlir::scf::IfOp::thenYield)
|
||||||
|
.def("get_else_yield", &mlir::scf::IfOp::elseYield)
|
||||||
|
;
|
||||||
|
py::class_<mlir::scf::YieldOp>(m, "YieldOp");
|
||||||
|
|
||||||
py::class_<mlir::OpBuilder>(m, "builder", py::dynamic_attr())
|
py::class_<mlir::OpBuilder>(m, "builder", py::dynamic_attr())
|
||||||
.def(py::init<mlir::MLIRContext *>())
|
.def(py::init<mlir::MLIRContext *>())
|
||||||
// // getters
|
// // getters
|
||||||
@@ -839,21 +850,20 @@ void init_triton_ir(py::module &&m) {
|
|||||||
}
|
}
|
||||||
throw std::runtime_error("invalid function type");
|
throw std::runtime_error("invalid function type");
|
||||||
})
|
})
|
||||||
// // Structured control flow
|
// Structured control flow
|
||||||
// .def("create_for", [](mlir::OpBuilder &self, mlir::Value &lb, mlir::Value &ub,
|
.def("create_for_op", [](mlir::OpBuilder &self, mlir::Value &lb, mlir::Value &ub,
|
||||||
// mlir::Value &step, std::vector<mlir::Value> &initArgs) -> MlirOperation {
|
mlir::Value &step, std::vector<mlir::Value> &initArgs) -> mlir::scf::ForOp {
|
||||||
// auto loc = self.getUnknownLoc();
|
auto loc = self.getUnknownLoc();
|
||||||
// return wrap(self.create<mlir::scf::ForOp>(
|
return self.create<mlir::scf::ForOp>(loc, lb, ub, step);
|
||||||
// loc, unwrap(lb), unwrap(ub), unwrap(step)).getOperation());
|
})
|
||||||
// })
|
.def("create_if_of", [](mlir::OpBuilder &self, mlir::Value &condition) -> mlir::scf::IfOp {
|
||||||
// .def("create_if", [](mlir::OpBuilder &self, mlir::Value &condition) -> MlirOperation {
|
auto loc = self.getUnknownLoc();
|
||||||
// auto loc = self.getUnknownLoc();
|
return self.create<mlir::scf::IfOp>(loc, condition);
|
||||||
// return wrap(self.create<mlir::scf::IfOp>(loc, unwrap(condition)).getOperation());
|
})
|
||||||
// })
|
.def("create_yield_op", [](mlir::OpBuilder &self) -> mlir::scf::YieldOp {
|
||||||
// .def("create_yield", [](mlir::OpBuilder &self) -> MlirOperation {
|
auto loc = self.getUnknownLoc();
|
||||||
// auto loc = self.getUnknownLoc();
|
return self.create<mlir::scf::YieldOp>(loc);
|
||||||
// return wrap(self.create<mlir::scf::YieldOp>(loc).getOperation());
|
})
|
||||||
// })
|
|
||||||
// // .def("create_while")
|
// // .def("create_while")
|
||||||
|
|
||||||
// miscellious
|
// miscellious
|
||||||
|
@@ -362,30 +362,40 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
def visit_If(self, node):
|
def visit_If(self, node):
|
||||||
cond = self.visit(node.test)
|
cond = self.visit(node.test)
|
||||||
if isinstance(cond, triton.language.tensor):
|
if isinstance(cond, triton.language.tensor):
|
||||||
cond = cond.to(triton.language.int1, _builder=self.builder)
|
# cond = cond.to(triton.language.int1, _builder=self.builder)
|
||||||
current_bb = self.builder.get_insertion_block()
|
# current_bb = self.builder.get_insertion_block()
|
||||||
then_bb = _triton.ir.basic_block.create(self.builder.context, "then", current_bb.parent)
|
# then_bb = _triton.ir.basic_block.create(self.builder.context, "then", current_bb.parent)
|
||||||
else_bb = _triton.ir.basic_block.create(self.builder.context, "else", current_bb.parent) if node.orelse else None
|
# else_bb = _triton.ir.basic_block.create(self.builder.context, "else", current_bb.parent) if node.orelse else None
|
||||||
endif_bb = _triton.ir.basic_block.create(self.builder.context, "endif", current_bb.parent)
|
# endif_bb = _triton.ir.basic_block.create(self.builder.context, "endif", current_bb.parent)
|
||||||
self._seal_block(then_bb)
|
# self._seal_block(then_bb)
|
||||||
if else_bb:
|
# if else_bb:
|
||||||
self._seal_block(else_bb)
|
# self._seal_block(else_bb)
|
||||||
self.builder.cond_br(cond.handle, then_bb, else_bb)
|
# self.builder.cond_br(cond.handle, then_bb, else_bb)
|
||||||
else:
|
# else:
|
||||||
self.builder.cond_br(cond.handle, then_bb, endif_bb)
|
# self.builder.cond_br(cond.handle, then_bb, endif_bb)
|
||||||
self.builder.set_insert_block(then_bb)
|
# self.builder.set_insert_block(then_bb)
|
||||||
is_terminator = self.visit_compound_statement(node.body)
|
# is_terminator = self.visit_compound_statement(node.body)
|
||||||
# TODO: last statement is a terminator?
|
# # TODO: last statement is a terminator?
|
||||||
if not is_terminator:
|
# if not is_terminator:
|
||||||
self.builder.br(endif_bb)
|
# self.builder.br(endif_bb)
|
||||||
if else_bb:
|
# if else_bb:
|
||||||
self.builder.set_insert_block(else_bb)
|
# self.builder.set_insert_block(else_bb)
|
||||||
is_terminator = self.visit_compound_statement(node.orelse)
|
# is_terminator = self.visit_compound_statement(node.orelse)
|
||||||
# TODO: last statement is a terminator?
|
# # TODO: last statement is a terminator?
|
||||||
if not is_terminator:
|
# if not is_terminator:
|
||||||
self.builder.br(endif_bb)
|
# self.builder.br(endif_bb)
|
||||||
self._seal_block(endif_bb)
|
# self._seal_block(endif_bb)
|
||||||
self.builder.set_insert_block(endif_bb)
|
# self.builder.set_insert_block(endif_bb)
|
||||||
|
parent_lvalues = self.lvalues.copy()
|
||||||
|
self.visit_compound_statement(node.body)
|
||||||
|
then_lvalues = self.lvalues.copy()
|
||||||
|
assert node.orelse
|
||||||
|
self.lvalues = parent_lvalues
|
||||||
|
self.visit_compound_statement(node.orelse)
|
||||||
|
else_lvalues = self.lvalues.copy()
|
||||||
|
|
||||||
|
self.lvalues = join_if_lvalues(then_lvalues, else_lvalues)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if isinstance(cond, triton.language.constexpr):
|
if isinstance(cond, triton.language.constexpr):
|
||||||
cond = cond.value
|
cond = cond.value
|
||||||
@@ -500,47 +510,69 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
for stmt in node.orelse:
|
for stmt in node.orelse:
|
||||||
ast.NodeVisitor.generic_visit(self, stmt)
|
ast.NodeVisitor.generic_visit(self, stmt)
|
||||||
return
|
return
|
||||||
# create nodes
|
# # create nodes
|
||||||
st_target = ast.Name(id=node.target.id, ctx=ast.Store())
|
# st_target = ast.Name(id=node.target.id, ctx=ast.Store())
|
||||||
ld_target = ast.Name(id=node.target.id, ctx=ast.Load())
|
# ld_target = ast.Name(id=node.target.id, ctx=ast.Load())
|
||||||
arg_0 = node.iter.args[0] if len(node.iter.args) > 1 else ast.Num(0)
|
# arg_0 = node.iter.args[0] if len(node.iter.args) > 1 else ast.Num(0)
|
||||||
arg_1 = node.iter.args[1] if len(node.iter.args) > 1 else node.iter.args[0]
|
# arg_1 = node.iter.args[1] if len(node.iter.args) > 1 else node.iter.args[0]
|
||||||
arg_2 = node.iter.args[2] if len(node.iter.args) > 2 else ast.Num(1)
|
# arg_2 = node.iter.args[2] if len(node.iter.args) > 2 else ast.Num(1)
|
||||||
init_node = ast.Assign(targets=[st_target], value=arg_0)
|
# init_node = ast.Assign(targets=[st_target], value=arg_0)
|
||||||
pos_cond_node = ast.Compare(ld_target, [ast.Lt()], [arg_1])
|
# pos_cond_node = ast.Compare(ld_target, [ast.Lt()], [arg_1])
|
||||||
neg_cond_node = ast.Compare(ld_target, [ast.Gt()], [arg_1])
|
# neg_cond_node = ast.Compare(ld_target, [ast.Gt()], [arg_1])
|
||||||
pos_step_node = ast.Compare(arg_2, [ast.Gt()], [ast.Num(0)])
|
# pos_step_node = ast.Compare(arg_2, [ast.Gt()], [ast.Num(0)])
|
||||||
build_cond = lambda: triton.language.where(self.visit(pos_step_node),
|
# build_cond = lambda: triton.language.where(self.visit(pos_step_node),
|
||||||
self.visit(pos_cond_node),
|
# self.visit(pos_cond_node),
|
||||||
self.visit(neg_cond_node),
|
# self.visit(neg_cond_node),
|
||||||
_builder=self.builder)
|
# _builder=self.builder)
|
||||||
# cond_node = neg_cond_node
|
# # cond_node = neg_cond_node
|
||||||
step_node = ast.AugAssign(target=st_target, op=ast.Add(), value=arg_2)
|
# step_node = ast.AugAssign(target=st_target, op=ast.Add(), value=arg_2)
|
||||||
# code generation
|
# # code generation
|
||||||
current_bb = self.builder.get_insertion_block()
|
# current_bb = self.builder.get_insertion_block()
|
||||||
loop_bb = _triton.ir.basic_block.create(self.builder.context, "loop", current_bb.parent)
|
# loop_bb = _triton.ir.basic_block.create(self.builder.context, "loop", current_bb.parent)
|
||||||
next_bb = _triton.ir.basic_block.create(self.builder.context, "postloop", current_bb.parent)
|
# next_bb = _triton.ir.basic_block.create(self.builder.context, "postloop", current_bb.parent)
|
||||||
|
|
||||||
def continue_fn():
|
# def continue_fn():
|
||||||
self.visit(step_node)
|
# self.visit(step_node)
|
||||||
cond = build_cond()
|
# cond = build_cond()
|
||||||
return self.builder.cond_br(cond.handle, loop_bb, next_bb)
|
# return self.builder.cond_br(cond.handle, loop_bb, next_bb)
|
||||||
|
|
||||||
self.visit(init_node)
|
# self.visit(init_node)
|
||||||
cond = build_cond()
|
# cond = build_cond()
|
||||||
self.builder.cond_br(cond.handle, loop_bb, next_bb)
|
# self.builder.cond_br(cond.handle, loop_bb, next_bb)
|
||||||
self.builder.set_insert_block(loop_bb)
|
# self.builder.set_insert_block(loop_bb)
|
||||||
self.visit_compound_statement(node.body)
|
# self.visit_compound_statement(node.body)
|
||||||
# TODO: handle case where body breaks control flow
|
# # TODO: handle case where body breaks control flow
|
||||||
continue_fn()
|
# continue_fn()
|
||||||
stop_bb = self.builder.get_insertion_block()
|
# stop_bb = self.builder.get_insertion_block()
|
||||||
self._seal_block(stop_bb)
|
# self._seal_block(stop_bb)
|
||||||
self._seal_block(loop_bb)
|
# self._seal_block(loop_bb)
|
||||||
self._seal_block(next_bb)
|
# self._seal_block(next_bb)
|
||||||
self.builder.set_insert_block(next_bb)
|
# self.builder.set_insert_block(next_bb)
|
||||||
|
|
||||||
for stmt in node.orelse:
|
# for stmt in node.orelse:
|
||||||
ast.NodeVisitor.generic_visit(self, stmt)
|
# ast.NodeVisitor.generic_visit(self, stmt)
|
||||||
|
lb = self.visit(node.iter.args[0] if len(node.iter.args) > 1 else ast.Num(0))
|
||||||
|
ub = self.visit(node.iter.args[1] if len(node.iter.args) > 1 else node.iter.args[0])
|
||||||
|
step = self.visit(node.iter.args[2] if len(node.iter.args) > 2 else ast.Num(1))
|
||||||
|
|
||||||
|
loop_body = self.builder.create_block()
|
||||||
|
self.builder.set_insertion_point_to_start(loop_body)
|
||||||
|
|
||||||
|
prev_defs = self.local_defs.copy()
|
||||||
|
self.local_defs = set()
|
||||||
|
|
||||||
|
# visit loop body
|
||||||
|
parent_lvalues = self.lvalues.copy()
|
||||||
|
self.visit_compound_statement()
|
||||||
|
loop_lvalues = self.lvalues.copy()
|
||||||
|
|
||||||
|
# TODO: update insertion point
|
||||||
|
# TODO: create scf.forOp
|
||||||
|
# self.lvalues = join_loop_lvalues(parent_lvalues, loop_lvalues)
|
||||||
|
# self.make_for_op(parent_lvalues, loop_lvalues, lb, ub, step)
|
||||||
|
for_op = self.builder.create_for_op(lb, ub, step, [loop_init_args])
|
||||||
|
|
||||||
|
self.local_defs = prev_defs
|
||||||
|
|
||||||
def visit_Slice(self, node):
|
def visit_Slice(self, node):
|
||||||
lower = self.visit(node.lower)
|
lower = self.visit(node.lower)
|
||||||
|
Reference in New Issue
Block a user