[Analysis] Added Axis Info Analysis (#8)
This commit is contained in:
@@ -788,6 +788,11 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("add_entry_block", [](mlir::FuncOp &self) -> mlir::Block* {
|
||||
return self.addEntryBlock();
|
||||
}, ret::reference)
|
||||
.def("set_arg_attr", [](mlir::FuncOp &self, int arg_no, const std::string& name, int val){
|
||||
// set arg attributes "name" to value "val"
|
||||
auto attrTy = mlir::IntegerType::get(self.getContext(), 32);
|
||||
self.setArgAttr(arg_no, name, mlir::IntegerAttr::get(attrTy, val));
|
||||
}, ret::reference)
|
||||
.def("reset_type", &mlir::FuncOp::setType)
|
||||
;
|
||||
|
||||
@@ -1265,9 +1270,10 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("create_splat", [](mlir::OpBuilder &self, mlir::Value &arg, std::vector<int64_t> &shape) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
auto argType = arg.getType();
|
||||
return self.create<mlir::triton::BroadcastOp>(
|
||||
auto ret = self.createOrFold<mlir::triton::SplatOp>(
|
||||
loc, mlir::RankedTensorType::get(shape, argType), arg
|
||||
);
|
||||
return ret;
|
||||
})
|
||||
// // atomic
|
||||
.def("create_atomic_cas", [](mlir::OpBuilder &self, mlir::Value &ptr,
|
||||
@@ -1337,6 +1343,12 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("run", [](mlir::PassManager &self, mlir::ModuleOp &mod) -> bool {
|
||||
return mlir::succeeded(self.run(mod.getOperation()));
|
||||
})
|
||||
.def("add_sccp_pass", [](mlir::PassManager &self) {
|
||||
self.addPass(mlir::createSCCPPass());
|
||||
})
|
||||
.def("add_symbol_dce_pass", [](mlir::PassManager &self) {
|
||||
self.addPass(mlir::createSymbolDCEPass());
|
||||
})
|
||||
.def("add_inliner_pass", [](mlir::PassManager &self) {
|
||||
self.addPass(mlir::createInlinerPass());
|
||||
})
|
||||
|
Reference in New Issue
Block a user