Add Triton CombineOps
This commit is contained in:
@@ -1254,7 +1254,7 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("create_broadcast", [](mlir::OpBuilder &self, mlir::Value &arg, std::vector<int64_t> &shape) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
if (auto argType = arg.getType().dyn_cast<mlir::RankedTensorType>())
|
||||
return self.create<mlir::triton::BroadcastOp>(
|
||||
return self.createOrFold<mlir::triton::BroadcastOp>(
|
||||
loc, mlir::RankedTensorType::get(shape, argType.getElementType()), arg
|
||||
);
|
||||
throw std::runtime_error("arg is not of RankedTensorType, use create_splat");
|
||||
@@ -1323,12 +1323,15 @@ void init_triton_ir(py::module &&m) {
|
||||
|
||||
py::class_<mlir::PassManager>(m, "pass_manager")
|
||||
.def(py::init<mlir::MLIRContext *>())
|
||||
.def("run", [](mlir::PassManager &self, mlir::ModuleOp &mod) {
|
||||
self.run(mod.getOperation());
|
||||
.def("run", [](mlir::PassManager &self, mlir::ModuleOp &mod) -> bool {
|
||||
return mlir::succeeded(self.run(mod.getOperation()));
|
||||
})
|
||||
.def("add_inliner_pass", [](mlir::PassManager &self) {
|
||||
self.addPass(mlir::createInlinerPass());
|
||||
})
|
||||
.def("add_canonicalizer_pass", [](mlir::PassManager &self) {
|
||||
self.addPass(mlir::createCanonicalizerPass());
|
||||
})
|
||||
;
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user