[OPTIMIZER] Made layout simplification pass efficient for fused attention kernels (#790)
This commit is contained in:
@@ -165,7 +165,13 @@ void init_triton_ir(py::module &&m) {
|
||||
else {
|
||||
/* issue a warning */
|
||||
}
|
||||
});
|
||||
})
|
||||
.def("replace_all_uses_with",
|
||||
[](mlir::Value &self, mlir::Value &newValue) {
|
||||
self.replaceAllUsesWith(newValue);
|
||||
})
|
||||
|
||||
;
|
||||
py::class_<mlir::BlockArgument, mlir::Value>(m, "block_arguement");
|
||||
|
||||
py::class_<mlir::Region>(m, "region")
|
||||
@@ -189,7 +195,7 @@ void init_triton_ir(py::module &&m) {
|
||||
if (self.getNumArguments() != 0)
|
||||
throw std::runtime_error(
|
||||
"This block has arguments, don't merge");
|
||||
dst.getOperations().splice(dst.end(), self.getOperations());
|
||||
dst.getOperations().splice(dst.begin(), self.getOperations());
|
||||
self.dropAllUses();
|
||||
self.erase();
|
||||
})
|
||||
@@ -262,7 +268,9 @@ void init_triton_ir(py::module &&m) {
|
||||
return mlir::succeeded(mlir::verify(self.getOperation()));
|
||||
});
|
||||
// scf Ops
|
||||
py::class_<mlir::scf::ForOp, mlir::OpState>(m, "ForOp");
|
||||
py::class_<mlir::scf::ForOp, mlir::OpState>(m, "ForOp")
|
||||
.def("get_induction_var", &mlir::scf::ForOp::getInductionVar);
|
||||
|
||||
py::class_<mlir::scf::IfOp, mlir::OpState>(m, "IfOp")
|
||||
.def("get_then_block", &mlir::scf::IfOp::thenBlock, ret::reference)
|
||||
.def("get_else_block", &mlir::scf::IfOp::elseBlock, ret::reference)
|
||||
@@ -501,24 +509,18 @@ void init_triton_ir(py::module &&m) {
|
||||
})
|
||||
|
||||
// Ops
|
||||
.def("create_function",
|
||||
[](mlir::OpBuilder &self, std::string name,
|
||||
mlir::Type &funcType) -> mlir::FuncOp {
|
||||
// TODO: loc
|
||||
auto loc = self.getUnknownLoc();
|
||||
if (auto funcTy = funcType.dyn_cast<mlir::FunctionType>()) {
|
||||
return self.create<mlir::FuncOp>(loc, name, funcTy);
|
||||
}
|
||||
throw std::runtime_error("invalid function type");
|
||||
})
|
||||
.def("get_or_insert_function",
|
||||
[](mlir::OpBuilder &self, mlir::ModuleOp &module,
|
||||
std::string &funcName, mlir::Type &funcType) -> mlir::FuncOp {
|
||||
std::string &funcName, mlir::Type &funcType,
|
||||
std::string &visibility) -> mlir::FuncOp {
|
||||
if (mlir::Operation *funcOperation = module.lookupSymbol(funcName))
|
||||
return llvm::dyn_cast<mlir::FuncOp>(funcOperation);
|
||||
auto loc = self.getUnknownLoc();
|
||||
if (auto funcTy = funcType.dyn_cast<mlir::FunctionType>()) {
|
||||
return self.create<mlir::FuncOp>(loc, funcName, funcTy);
|
||||
mlir::ArrayRef<mlir::NamedAttribute> attrs = {
|
||||
mlir::NamedAttribute(self.getStringAttr("sym_visibility"),
|
||||
self.getStringAttr(visibility))};
|
||||
return self.create<mlir::FuncOp>(loc, funcName, funcTy, attrs);
|
||||
}
|
||||
throw std::runtime_error("invalid function type");
|
||||
})
|
||||
@@ -648,6 +650,12 @@ void init_triton_ir(py::module &&m) {
|
||||
return self.create<mlir::arith::IndexCastOp>(loc, input,
|
||||
self.getIndexType());
|
||||
})
|
||||
.def("create_index_to_si",
|
||||
[](mlir::OpBuilder &self, mlir::Value &input) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<mlir::arith::IndexCastOp>(loc, input,
|
||||
self.getI32Type());
|
||||
})
|
||||
|
||||
.def("create_fmul",
|
||||
[](mlir::OpBuilder &self, mlir::Value &lhs,
|
||||
@@ -1065,10 +1073,11 @@ void init_triton_ir(py::module &&m) {
|
||||
})
|
||||
.def("create_dot",
|
||||
[](mlir::OpBuilder &self, mlir::Value &a, mlir::Value &b,
|
||||
mlir::Value &c, bool allowTF32) -> mlir::Value {
|
||||
mlir::Value &c, bool allowTF32, bool transA,
|
||||
bool transB) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<mlir::triton::DotOp>(loc, c.getType(), a, b, c,
|
||||
allowTF32);
|
||||
allowTF32, transA, transB);
|
||||
})
|
||||
.def("create_exp",
|
||||
[](mlir::OpBuilder &self, mlir::Value &val) -> mlir::Value {
|
||||
@@ -1095,7 +1104,6 @@ void init_triton_ir(py::module &&m) {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<mlir::math::SqrtOp>(loc, val);
|
||||
})
|
||||
// .def("create_trans", &ir::builder::create_trans, ret::reference)
|
||||
.def("create_reduce",
|
||||
[](mlir::OpBuilder &self, mlir::Value &operand,
|
||||
mlir::triton::RedOp redOp, int axis) -> mlir::Value {
|
||||
@@ -1118,13 +1126,7 @@ void init_triton_ir(py::module &&m) {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<mlir::SelectOp>(loc, condition, trueValue,
|
||||
falseValue);
|
||||
})
|
||||
// // Intrinsics
|
||||
// // These have no place in the IR, and hopefully they can be removed at
|
||||
// some point .def("create_umulhi", &ir::builder::create_umulhi,
|
||||
// ret::reference) .def("create_barrier", &ir::builder::create_barrier,
|
||||
// ret::reference);
|
||||
;
|
||||
});
|
||||
|
||||
py::class_<mlir::PassManager>(m, "pass_manager")
|
||||
.def(py::init<mlir::MLIRContext *>())
|
||||
@@ -1144,8 +1146,11 @@ void init_triton_ir(py::module &&m) {
|
||||
printingFlags);
|
||||
})
|
||||
.def("run",
|
||||
[](mlir::PassManager &self, mlir::ModuleOp &mod) -> bool {
|
||||
return mlir::succeeded(self.run(mod.getOperation()));
|
||||
[](mlir::PassManager &self, mlir::ModuleOp &mod) {
|
||||
// TODO: maybe dump module to file and print error for better
|
||||
// diagnostics
|
||||
if (mlir::failed(self.run(mod.getOperation())))
|
||||
throw std::runtime_error("PassManager::run failed");
|
||||
})
|
||||
.def(
|
||||
"add_sccp_pass",
|
||||
@@ -1168,6 +1173,10 @@ void init_triton_ir(py::module &&m) {
|
||||
})
|
||||
.def("add_cse_pass",
|
||||
[](mlir::PassManager &self) { self.addPass(mlir::createCSEPass()); })
|
||||
.def("add_licm_pass",
|
||||
[](mlir::PassManager &self) {
|
||||
self.addPass(mlir::createLoopInvariantCodeMotionPass());
|
||||
})
|
||||
.def("add_triton_combine_pass",
|
||||
[](mlir::PassManager &self) {
|
||||
self.addPass(mlir::triton::createCombineOpsPass());
|
||||
|
Reference in New Issue
Block a user