More python bindings
This commit is contained in:
@@ -664,6 +664,7 @@ void init_triton_ir(py::module &&m) {
|
|||||||
return self.getArgument(index);
|
return self.getArgument(index);
|
||||||
})
|
})
|
||||||
.def("dump", &mlir::Block::dump)
|
.def("dump", &mlir::Block::dump)
|
||||||
|
.def("move_before", &mlir::Block::moveBefore)
|
||||||
;
|
;
|
||||||
|
|
||||||
// py::class_<mlir::ModuleOp>(m, "module")
|
// py::class_<mlir::ModuleOp>(m, "module")
|
||||||
@@ -705,15 +706,28 @@ void init_triton_ir(py::module &&m) {
|
|||||||
;
|
;
|
||||||
|
|
||||||
// Ops
|
// Ops
|
||||||
py::class_<mlir::scf::ForOp>(m, "ForOp")
|
py::class_<mlir::OpState>(m, "OpState")
|
||||||
.def("get_body", &mlir::scf::ForOp::getBody, ret::reference);
|
.def("get_result", [](mlir::OpState &self, unsigned idx) -> mlir::Value {
|
||||||
py::class_<mlir::scf::IfOp>(m, "IfOp")
|
return self->getResult(idx);
|
||||||
|
})
|
||||||
|
.def("get_region", [](mlir::OpState &self, unsigned idx) -> mlir::Region& {
|
||||||
|
return self->getRegion(idx);
|
||||||
|
}, ret::reference)
|
||||||
|
;
|
||||||
|
// scf Ops
|
||||||
|
py::class_<mlir::scf::ForOp, mlir::OpState>(m, "ForOp")
|
||||||
|
.def("get_body", [](mlir::scf::ForOp &self) -> mlir::Block* {
|
||||||
|
return self.getBody();
|
||||||
|
}, ret::reference);
|
||||||
|
py::class_<mlir::scf::IfOp, mlir::OpState>(m, "IfOp")
|
||||||
.def("get_then_block", &mlir::scf::IfOp::thenBlock, ret::reference)
|
.def("get_then_block", &mlir::scf::IfOp::thenBlock, ret::reference)
|
||||||
.def("get_else_block", &mlir::scf::IfOp::elseBlock, ret::reference)
|
.def("get_else_block", &mlir::scf::IfOp::elseBlock, ret::reference)
|
||||||
.def("get_then_yield", &mlir::scf::IfOp::thenYield)
|
.def("get_then_yield", &mlir::scf::IfOp::thenYield)
|
||||||
.def("get_else_yield", &mlir::scf::IfOp::elseYield)
|
.def("get_else_yield", &mlir::scf::IfOp::elseYield)
|
||||||
;
|
;
|
||||||
py::class_<mlir::scf::YieldOp>(m, "YieldOp");
|
py::class_<mlir::scf::YieldOp, mlir::OpState>(m, "YieldOp");
|
||||||
|
|
||||||
|
py::class_<mlir::OpBuilder::InsertPoint>(m, "InsertPoint");
|
||||||
|
|
||||||
py::class_<mlir::OpBuilder>(m, "builder", py::dynamic_attr())
|
py::class_<mlir::OpBuilder>(m, "builder", py::dynamic_attr())
|
||||||
.def(py::init<mlir::MLIRContext *>())
|
.def(py::init<mlir::MLIRContext *>())
|
||||||
@@ -740,12 +754,8 @@ void init_triton_ir(py::module &&m) {
|
|||||||
.def("get_insertion_block", [](mlir::OpBuilder &self) -> mlir::Block* {
|
.def("get_insertion_block", [](mlir::OpBuilder &self) -> mlir::Block* {
|
||||||
return self.getInsertionBlock();
|
return self.getInsertionBlock();
|
||||||
}, ret::reference)
|
}, ret::reference)
|
||||||
// .def("get_insert_point", [](ir::builder *self) {
|
.def("get_insertion_point", &mlir::OpBuilder::saveInsertionPoint)
|
||||||
// ir::basic_block *bb = self->get_insert_block();
|
.def("restore_insertion_point", &mlir::OpBuilder::restoreInsertionPoint)
|
||||||
// ir::basic_block::iterator it = self->get_insert_point();
|
|
||||||
// ir::instruction *instr = it == bb->end() ? nullptr : *it;
|
|
||||||
// return std::make_pair(bb, instr);
|
|
||||||
// }, ret::reference)
|
|
||||||
// .def("set_insert_point", [](ir::builder *self, std::pair<ir::basic_block*, ir::instruction*> pt) {
|
// .def("set_insert_point", [](ir::builder *self, std::pair<ir::basic_block*, ir::instruction*> pt) {
|
||||||
// ir::basic_block *bb = pt.first;
|
// ir::basic_block *bb = pt.first;
|
||||||
// ir::instruction *instr = pt.second;
|
// ir::instruction *instr = pt.second;
|
||||||
@@ -856,9 +866,9 @@ void init_triton_ir(py::module &&m) {
|
|||||||
auto loc = self.getUnknownLoc();
|
auto loc = self.getUnknownLoc();
|
||||||
return self.create<mlir::scf::IfOp>(loc, condition);
|
return self.create<mlir::scf::IfOp>(loc, condition);
|
||||||
})
|
})
|
||||||
.def("create_yield_op", [](mlir::OpBuilder &self) -> mlir::scf::YieldOp {
|
.def("create_yield_op", [](mlir::OpBuilder &self, std::vector<mlir::Value> &yields) -> mlir::scf::YieldOp {
|
||||||
auto loc = self.getUnknownLoc();
|
auto loc = self.getUnknownLoc();
|
||||||
return self.create<mlir::scf::YieldOp>(loc);
|
return self.create<mlir::scf::YieldOp>(loc, yields);
|
||||||
})
|
})
|
||||||
// // .def("create_while")
|
// // .def("create_while")
|
||||||
|
|
||||||
|
@@ -415,8 +415,10 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
ub = triton.language.core._to_tensor(ub, self.builder).handle
|
ub = triton.language.core._to_tensor(ub, self.builder).handle
|
||||||
step = triton.language.core._to_tensor(step, self.builder).handle
|
step = triton.language.core._to_tensor(step, self.builder).handle
|
||||||
|
|
||||||
loop_body = self.builder.create_block()
|
insert_block = self.builder.get_insertion_block()
|
||||||
self.builder.set_insertion_point_to_start(loop_body)
|
|
||||||
|
block = self.builder.create_block()
|
||||||
|
self.builder.set_insertion_point_to_start(block)
|
||||||
|
|
||||||
liveins = self.lscope.copy()
|
liveins = self.lscope.copy()
|
||||||
prev_defs = self.local_defs.copy()
|
prev_defs = self.local_defs.copy()
|
||||||
@@ -425,20 +427,44 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
# visit loop body
|
# visit loop body
|
||||||
self.visit_compound_statement(node.body)
|
self.visit_compound_statement(node.body)
|
||||||
|
|
||||||
# TODO: update insertion point
|
|
||||||
init_args = []
|
init_args = []
|
||||||
yields = []
|
yields = []
|
||||||
|
names = []
|
||||||
for name in self.local_defs:
|
for name in self.local_defs:
|
||||||
if name in liveins:
|
if name in liveins:
|
||||||
assert self.is_triton_tensor(self.local_defs[name])
|
assert self.is_triton_tensor(self.local_defs[name])
|
||||||
assert self.is_triton_tensor(liveins[name])
|
assert self.is_triton_tensor(liveins[name])
|
||||||
if self.local_defs[name].type == liveins[name].type:
|
if self.local_defs[name].type == liveins[name].type:
|
||||||
init_args.append(triton.language.core._to_tensor(liveins[name], self.builder).handle)
|
# TODO: better way to do this?
|
||||||
yields.append(triton.language.core._to_tensor(self.local_defs[name], self.builder).handle)
|
names.append(name)
|
||||||
for_op = self.builder.create_for_op(lb, ub, step, init_args)
|
init_args.append(triton.language.core._to_tensor(liveins[name], self.builder))
|
||||||
|
yields.append(triton.language.core._to_tensor(self.local_defs[name], self.builder))
|
||||||
|
|
||||||
|
print(f'names: {names}')
|
||||||
|
print("After insert block body")
|
||||||
|
self.module.dump()
|
||||||
|
self.builder.set_insertion_point_to_end(insert_block)
|
||||||
|
for_op = self.builder.create_for_op(lb, ub, step, [arg.handle for arg in init_args])
|
||||||
|
print("After create for op")
|
||||||
|
self.module.dump()
|
||||||
|
# TODO: revisit blocks & regions of ForOp
|
||||||
|
block.move_before(for_op.get_body())
|
||||||
|
self.builder.set_insertion_point_to_end(for_op.get_body())
|
||||||
|
self.builder.create_yield_op([y.handle for y in yields])
|
||||||
|
print("After create yield op")
|
||||||
|
self.module.dump()
|
||||||
|
|
||||||
|
|
||||||
|
self.builder.set_insertion_point_to_end(insert_block)
|
||||||
|
print("After restoring insert point")
|
||||||
|
self.module.dump()
|
||||||
|
|
||||||
self.lscope = liveins
|
self.lscope = liveins
|
||||||
self.local_defs = prev_defs
|
self.local_defs = prev_defs
|
||||||
|
# ForOp defines new values
|
||||||
|
for i, name in enumerate(names):
|
||||||
|
self.lscope[name] = for_op.get_result(i)
|
||||||
|
self.local_defs[name] = for_op.get_result(i)
|
||||||
|
|
||||||
for stmt in node.orelse:
|
for stmt in node.orelse:
|
||||||
assert False
|
assert False
|
||||||
|
Reference in New Issue
Block a user