[OPTIMIZER] Made layout simplification pass efficient for fused attention kernels (#790)

Author:    Philippe Tillet
Date:      2022-10-21 16:52:15 -07:00
Committer: GitHub
Parent:    c4726333bf
Commit:    bb0f9235d1

26 changed files with 683 additions and 229 deletions


@@ -165,7 +165,13 @@ void init_triton_ir(py::module &&m) {
           else {
             /* issue a warning */
           }
-        });
+        })
+      .def("replace_all_uses_with",
+           [](mlir::Value &self, mlir::Value &newValue) {
+             self.replaceAllUsesWith(newValue);
+           })
+      ;
   py::class_<mlir::BlockArgument, mlir::Value>(m, "block_arguement");
   py::class_<mlir::Region>(m, "region")
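
The new replace_all_uses_with binding is a thin wrapper over mlir::Value::replaceAllUsesWith, which rewires every consumer of one SSA value to another. A minimal sketch of the semantics; the helper name is illustrative, not part of this commit:

#include "mlir/IR/Value.h"

// Illustrative helper: after the call, every operation that used
// `oldVal` uses `newVal` instead, and `oldVal` is left use-free.
static void redirectUses(mlir::Value oldVal, mlir::Value newVal) {
  oldVal.replaceAllUsesWith(newVal);
}

This is what lets the Python-side rewrites swap a recomputed value in for an old one without walking the use list by hand.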
@@ -189,7 +195,7 @@ void init_triton_ir(py::module &&m) {
             if (self.getNumArguments() != 0)
               throw std::runtime_error(
                   "This block has arguments, don't merge");
-            dst.getOperations().splice(dst.end(), self.getOperations());
+            dst.getOperations().splice(dst.begin(), self.getOperations());
             self.dropAllUses();
             self.erase();
           })
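
This merge fix is a one-word change with a real behavioral difference: splicing at dst.end() appends the merged operations after dst's existing ones, while dst.begin() prepends them, which is what merging a block *before* another requires. A hedged sketch of the two placements (helper name illustrative):

#include "mlir/IR/Block.h"

// Illustrative: move all of src's operations into dst.
// splice(dst.end(), ...)   -> src's ops run after dst's existing ops.
// splice(dst.begin(), ...) -> src's ops run first, matching the
//                             "merge before" contract of the binding.
static void mergeBefore(mlir::Block &src, mlir::Block &dst) {
  dst.getOperations().splice(dst.begin(), src.getOperations());
}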
@@ -262,7 +268,9 @@ void init_triton_ir(py::module &&m) {
             return mlir::succeeded(mlir::verify(self.getOperation()));
           });
   // scf Ops
-  py::class_<mlir::scf::ForOp, mlir::OpState>(m, "ForOp");
+  py::class_<mlir::scf::ForOp, mlir::OpState>(m, "ForOp")
+      .def("get_induction_var", &mlir::scf::ForOp::getInductionVar);
   py::class_<mlir::scf::IfOp, mlir::OpState>(m, "IfOp")
       .def("get_then_block", &mlir::scf::IfOp::thenBlock, ret::reference)
       .def("get_else_block", &mlir::scf::IfOp::elseBlock, ret::reference)
@@ -501,24 +509,18 @@ void init_triton_ir(py::module &&m) {
         })
       // Ops
-      .def("create_function",
-           [](mlir::OpBuilder &self, std::string name,
-              mlir::Type &funcType) -> mlir::FuncOp {
-             // TODO: loc
-             auto loc = self.getUnknownLoc();
-             if (auto funcTy = funcType.dyn_cast<mlir::FunctionType>()) {
-               return self.create<mlir::FuncOp>(loc, name, funcTy);
-             }
-             throw std::runtime_error("invalid function type");
-           })
       .def("get_or_insert_function",
            [](mlir::OpBuilder &self, mlir::ModuleOp &module,
-              std::string &funcName, mlir::Type &funcType) -> mlir::FuncOp {
+              std::string &funcName, mlir::Type &funcType,
+              std::string &visibility) -> mlir::FuncOp {
             if (mlir::Operation *funcOperation = module.lookupSymbol(funcName))
               return llvm::dyn_cast<mlir::FuncOp>(funcOperation);
             auto loc = self.getUnknownLoc();
             if (auto funcTy = funcType.dyn_cast<mlir::FunctionType>()) {
-              return self.create<mlir::FuncOp>(loc, funcName, funcTy);
+              mlir::ArrayRef<mlir::NamedAttribute> attrs = {
+                  mlir::NamedAttribute(self.getStringAttr("sym_visibility"),
+                                       self.getStringAttr(visibility))};
+              return self.create<mlir::FuncOp>(loc, funcName, funcTy, attrs);
             }
             throw std::runtime_error("invalid function type");
           })
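
get_or_insert_function now attaches a sym_visibility string attribute (e.g. "public") to freshly created functions, and create_function is folded into it. A hedged sketch of reading the attribute back; the helper and its fallback are assumptions, not part of the commit:

#include "mlir/IR/BuiltinOps.h"

// Illustrative: inspect the visibility the binding attached. A
// "private" function may be stripped by symbol DCE once unused,
// while "public" keeps it externally visible.
static llvm::StringRef symbolVisibility(mlir::FuncOp func) {
  if (auto attr = func->getAttrOfType<mlir::StringAttr>("sym_visibility"))
    return attr.getValue();
  return "public"; // assumed default when the attribute is absent
}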
@@ -648,6 +650,12 @@ void init_triton_ir(py::module &&m) {
             return self.create<mlir::arith::IndexCastOp>(loc, input,
                                                          self.getIndexType());
           })
+      .def("create_index_to_si",
+           [](mlir::OpBuilder &self, mlir::Value &input) -> mlir::Value {
+             auto loc = self.getUnknownLoc();
+             return self.create<mlir::arith::IndexCastOp>(loc, input,
+                                                          self.getI32Type());
+           })
       .def("create_fmul",
            [](mlir::OpBuilder &self, mlir::Value &lhs,
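
create_index_to_si complements the existing index cast by narrowing an index value back to i32 through the same arith::IndexCastOp. A minimal sketch with the operand order used in the binding above (helper name illustrative):

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"

// Illustrative: cast an index-typed value (e.g. a loop induction
// variable) to i32 for use in offset arithmetic.
static mlir::Value indexToI32(mlir::OpBuilder &b, mlir::Value idx) {
  return b.create<mlir::arith::IndexCastOp>(b.getUnknownLoc(), idx,
                                            b.getI32Type());
}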
@@ -1065,10 +1073,11 @@ void init_triton_ir(py::module &&m) {
           })
       .def("create_dot",
            [](mlir::OpBuilder &self, mlir::Value &a, mlir::Value &b,
-               mlir::Value &c, bool allowTF32) -> mlir::Value {
+               mlir::Value &c, bool allowTF32, bool transA,
+               bool transB) -> mlir::Value {
              auto loc = self.getUnknownLoc();
              return self.create<mlir::triton::DotOp>(loc, c.getType(), a, b, c,
-                                                     allowTF32);
+                                                     allowTF32, transA, transB);
            })
       .def("create_exp",
            [](mlir::OpBuilder &self, mlir::Value &val) -> mlir::Value {
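
Threading transA/transB through create_dot lets a fused attention kernel express q @ k^T directly on triton::DotOp instead of materializing a transposed tensor plus a layout conversion for it. A hedged sketch of the new call shape (the wrapper is illustrative):

#include "triton/Dialect/Triton/IR/Dialect.h"

// Illustrative: d = dot(a, b^T) + c with TF32 enabled; transB=true
// asks the op itself to treat b as transposed.
static mlir::Value dotTransB(mlir::OpBuilder &builder, mlir::Value a,
                             mlir::Value b, mlir::Value c) {
  return builder.create<mlir::triton::DotOp>(
      builder.getUnknownLoc(), c.getType(), a, b, c,
      /*allowTF32=*/true, /*transA=*/false, /*transB=*/true);
}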
@@ -1095,7 +1104,6 @@ void init_triton_ir(py::module &&m) {
             auto loc = self.getUnknownLoc();
             return self.create<mlir::math::SqrtOp>(loc, val);
           })
-      // .def("create_trans", &ir::builder::create_trans, ret::reference)
       .def("create_reduce",
            [](mlir::OpBuilder &self, mlir::Value &operand,
               mlir::triton::RedOp redOp, int axis) -> mlir::Value {
@@ -1118,13 +1126,7 @@ void init_triton_ir(py::module &&m) {
             auto loc = self.getUnknownLoc();
             return self.create<mlir::SelectOp>(loc, condition, trueValue,
                                                falseValue);
           })
-      // // Intrinsics
-      // // These have no place in the IR, and hopefully they can be removed at
-      // // some point .def("create_umulhi", &ir::builder::create_umulhi,
-      // // ret::reference) .def("create_barrier", &ir::builder::create_barrier,
-      // // ret::reference);
       ;
   });
py::class_<mlir::PassManager>(m, "pass_manager")
.def(py::init<mlir::MLIRContext *>())
@@ -1144,8 +1146,11 @@ void init_triton_ir(py::module &&m) {
                            printingFlags);
           })
       .def("run",
-           [](mlir::PassManager &self, mlir::ModuleOp &mod) -> bool {
-             return mlir::succeeded(self.run(mod.getOperation()));
+           [](mlir::PassManager &self, mlir::ModuleOp &mod) {
+             // TODO: maybe dump module to file and print error for better
+             // diagnostics
+             if (mlir::failed(self.run(mod.getOperation())))
+               throw std::runtime_error("PassManager::run failed");
           })
.def(
"add_sccp_pass",
@@ -1168,6 +1173,10 @@ void init_triton_ir(py::module &&m) {
           })
       .def("add_cse_pass",
            [](mlir::PassManager &self) { self.addPass(mlir::createCSEPass()); })
+      .def("add_licm_pass",
+           [](mlir::PassManager &self) {
+             self.addPass(mlir::createLoopInvariantCodeMotionPass());
+           })
       .def("add_triton_combine_pass",
            [](mlir::PassManager &self) {
              self.addPass(mlir::triton::createCombineOpsPass());
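
add_licm_pass exposes upstream loop-invariant code motion, which hoists computations that do not depend on the induction variable (e.g. base addresses, hoistable layout conversions) out of the attention kernel's inner loop. A hedged sketch of a pipeline using it (pass order is an assumption, not prescribed by this commit):

#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

// Illustrative: CSE then LICM before Triton's combiner, mirroring
// what the Python side can now assemble via the new binding.
static void buildCleanupPipeline(mlir::PassManager &pm) {
  pm.addPass(mlir::createCSEPass());
  pm.addPass(mlir::createLoopInvariantCodeMotionPass());
}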