Add TritonCombineOps

This commit is contained in:
Yan Da
2022-04-27 19:28:21 +08:00
parent c70f6b666e
commit 8dfe78f6cf
20 changed files with 239 additions and 41 deletions

View File

@@ -179,6 +179,7 @@ add_library(triton SHARED ${PYTHON_SRC})
target_link_libraries(triton
${PYTHON_LIBRARIES}
TritonIR
TritonTransforms
TritonDriver
# optimizations
MLIRPass

View File

@@ -1,10 +1,2 @@
set(LLVM_TARGET_DEFINITIONS TritonOps.td)
mlir_tablegen(Ops.h.inc -gen-op-decls)
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
mlir_tablegen(Dialect.h.inc -gen-dialect-decls)
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs)
mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
mlir_tablegen(Types.h.inc -gen-typedef-decls)
mlir_tablegen(Types.cpp.inc -gen-typedef-defs)
add_public_tablegen_target(TritonTableGen)
add_subdirectory(IR)
add_subdirectory(Transforms)

View File

@@ -0,0 +1,10 @@
# TableGen rules for the Triton dialect IR. Every mlir_tablegen() call below
# reads TritonOps.td and emits one generated include file.
set(LLVM_TARGET_DEFINITIONS TritonOps.td)
# Operation class declarations / definitions.
mlir_tablegen(Ops.h.inc -gen-op-decls)
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
# Dialect class declaration / definition.
mlir_tablegen(Dialect.h.inc -gen-dialect-decls)
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs)
# Enum declarations / definitions.
mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
# Type (typedef) declarations / definitions.
mlir_tablegen(Types.h.inc -gen-typedef-decls)
mlir_tablegen(Types.cpp.inc -gen-typedef-defs)
# Target that groups the generated files; the TritonIR library lists it under
# DEPENDS.
add_public_tablegen_target(TritonTableGen)

View File

@@ -8,11 +8,11 @@
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "triton/Dialect/Triton/Traits.h"
#include "triton/Dialect/Triton/IR/Traits.h"
#include "triton/Dialect/Triton/Dialect.h.inc"
#include "triton/Dialect/Triton/OpsEnums.h.inc"
#define GET_OP_CLASSES
#include "triton/Dialect/Triton/Ops.h.inc"
#include "triton/Dialect/Triton/IR/Ops.h.inc"
#endif // TRITON_IR_DIALECT_H_

View File

@@ -5,6 +5,6 @@
#include "mlir/IR/Types.h"
#define GET_TYPEDEF_CLASSES
#include "triton/Dialect/Triton/Types.h.inc"
#include "triton/Dialect/Triton/IR/Types.h.inc"
#endif // TRITON_IR_TYPES_H_

View File

@@ -0,0 +1,3 @@
# TableGen rules for the Triton transformation passes: Passes.td is turned
# into Passes.h.inc (pass declarations, generated with `-name Triton`).
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name Triton)
# Target that the TritonTransforms library lists under DEPENDS.
add_public_tablegen_target(TritonTransformsIncGen)

View File

@@ -0,0 +1,14 @@
// Public entry points for the Triton dialect transformation passes.
#ifndef TRITON_DIALECT_TRITON_TRANSFORMS_PASSES_H_
#define TRITON_DIALECT_TRITON_TRANSFORMS_PASSES_H_
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace triton {
/// Creates a new instance of the Triton "combine ops" pass and hands its
/// ownership to the caller (typically to be added to a pass manager).
std::unique_ptr<Pass> createCombineOpsPass();
} // namespace triton
} // namespace mlir
#endif // TRITON_DIALECT_TRITON_TRANSFORMS_PASSES_H_

View File

@@ -0,0 +1,23 @@
#ifndef TRITON_PASSES
#define TRITON_PASSES
include "mlir/Pass/PassBase.td"
// Declaration of the Triton "combine ops" pass. It is registered under the
// command-line argument `triton-combine` and is anchored on `mlir::ModuleOp`.
def TritonCombineOps : Pass</*cli-arg*/"triton-combine", /*Op*/"mlir::ModuleOp"> {
let summary = "combine ops";
// The rewrites performed by the pass. In the third rewrite, `???` denotes the
// original `other` operand of the load, which is dropped in favour of the
// select's false value.
let description = [{
dot(a, b, 0) + c => dot(a, b, c)
gep(gep(ptr, idx0), idx1) => gep(ptr, AddI(idx0, idx1))
select(cond, load(ptrs, broadcast(cond), ???), other) =>
load(ptrs, broadcast(cond), other)
}];
// Factory function used to instantiate the pass.
let constructor = "mlir::triton::createCombineOpsPass";
// Dialects that must be loaded before the pass runs: the gep rewrite creates
// an arith AddI op, and the select rewrite matches the standard dialect's
// SelectOp.
let dependentDialects = ["mlir::arith::ArithmeticDialect",
/*SelectOp*/"mlir::StandardOpsDialect"];
}
#endif

View File

@@ -1,18 +1,2 @@
add_mlir_dialect_library(TritonIR
Dialect.cpp
Ops.cpp
Types.cpp
DEPENDS
TritonTableGen
LINK_LIBS PUBLIC
MLIRIR
MLIRArithmetic
MLIRSCF
# Since LLVM 15
# MLIRFunc
# else
MLIRStandard
)
add_subdirectory(IR)
add_subdirectory(Transforms)

View File

@@ -0,0 +1,18 @@
# Library holding the Triton dialect IR: the dialect itself, its operations
# and its types.
add_mlir_dialect_library(TritonIR
Dialect.cpp
Ops.cpp
Types.cpp
# The sources include the TableGen-generated headers, so this library lists
# the tablegen target as a dependency.
DEPENDS
TritonTableGen
# MLIR libraries linked publicly (propagated to users of TritonIR).
LINK_LIBS PUBLIC
MLIRIR
MLIRArithmetic
MLIRSCF
# Since LLVM 15
# MLIRFunc
# else
MLIRStandard
)

View File

@@ -1,5 +1,5 @@
#include "triton/Dialect/Triton/Dialect.h"
#include "triton/Dialect/Triton/Types.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Types.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/TypeSwitch.h"

View File

@@ -1,5 +1,5 @@
#include "triton/Dialect/Triton/Dialect.h"
#include "triton/Dialect/Triton/Types.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Types.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"

View File

@@ -1,5 +1,5 @@
#include "triton/Dialect/Triton/Dialect.h"
#include "triton/Dialect/Triton/Types.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Types.h"
#include "mlir/IR/DialectImplementation.h" // required by `Types.cpp.inc`
#include "llvm/ADT/TypeSwitch.h" // required by `Types.cpp.inc`
@@ -7,7 +7,7 @@ using namespace mlir;
using namespace mlir::triton;
#define GET_TYPEDEF_CLASSES
#include "triton/Dialect/Triton/Types.cpp.inc"
#include "triton/Dialect/Triton/IR/Types.cpp.inc"
//===----------------------------------------------------------------------===//
// Triton Dialect
@@ -15,7 +15,7 @@ using namespace mlir::triton;
void TritonDialect::registerTypes() {
addTypes<
#define GET_TYPEDEF_LIST
#include "triton/Dialect/Triton/Types.cpp.inc"
#include "triton/Dialect/Triton/IR/Types.cpp.inc"
>();
}

View File

@@ -0,0 +1,6 @@
# Library holding the Triton dialect transformation passes.
add_mlir_dialect_library(TritonTransforms
  Combine.cpp

  DEPENDS
  # Combine.cpp includes the generated pass declarations (Passes.h.inc) ...
  TritonTransformsIncGen
  # ... and, through triton/Dialect/Triton/IR/Dialect.h, the generated
  # dialect/op headers, so those must be generated before it is compiled.
  TritonTableGen

  # Combine.cpp uses Triton ops, the pass infrastructure and the greedy
  # pattern rewrite driver; declare them here instead of relying on the link
  # line of whichever target consumes this library.
  LINK_LIBS PUBLIC
  TritonIR
  MLIRIR
  MLIRPass
  MLIRTransformUtils
)

View File

@@ -0,0 +1,141 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/Transforms/Passes.h"
#include <memory>
// using namespace mlir;
namespace {
// Folds an addition into the accumulator of a dot whose accumulator is zero:
//   dot(a, b, 0) + c  =>  dot(a, b, c)
// The pattern is registered for any operation and filters for arith AddI/AddF
// itself. Operand 0 of the add is tried first, then operand 1.
// NOTE(review): the matched dot is not required to have a single use, so the
// original dot stays alive if it has other users.
class CombineDotOp : public mlir::RewritePattern {
public:
  CombineDotOp(mlir::MLIRContext *context)
      : mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1,
                             context) {}

  mlir::LogicalResult
  matchAndRewrite(mlir::Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    if (!llvm::isa<mlir::arith::AddIOp, mlir::arith::AddFOp>(op))
      return mlir::failure();

    for (unsigned dotIdx = 0; dotIdx < 2; ++dotIdx) {
      mlir::triton::DotOp dotOp = zeroAccumulatorDot(op->getOperand(dotIdx));
      if (!dotOp)
        continue;
      // The other operand of the add becomes the new accumulator.
      mlir::Value addend = op->getOperand(1 - dotIdx);
      rewriter.replaceOpWithNewOp<mlir::triton::DotOp>(
          op, dotOp->getResultTypes().front(), dotOp.a(), dotOp.b(), addend,
          dotOp.allowTF32());
      return mlir::success();
    }
    return mlir::failure();
  }

private:
  // Returns the dot defining `val` if its `c` operand is zero, otherwise a
  // null op.
  mlir::triton::DotOp zeroAccumulatorDot(mlir::Value val) const {
    auto dotOp = val.getDefiningOp<mlir::triton::DotOp>();
    if (dotOp && isZero(dotOp.c()))
      return dotOp;
    return nullptr;
  }

  // True if `val` matches m_Zero()/m_AnyZeroFloat().
  static bool isZeroConstant(mlir::Value val) {
    return mlir::matchPattern(val, mlir::m_Zero()) ||
           mlir::matchPattern(val, mlir::m_AnyZeroFloat());
  }

  // True if `val` is a zero constant, or a broadcast of a zero constant.
  bool isZero(mlir::Value val) const {
    if (isZeroConstant(val))
      return true;
    auto broadcast = val.getDefiningOp<mlir::triton::BroadcastOp>();
    return broadcast && isZeroConstant(broadcast.src());
  }
};
// Collapses a chain of two GEPs into a single one:
//   gep(gep(%ptr, %idx0), %idx1) => gep(%ptr, AddI(%idx1, %idx0))
// Note: cancelling patterns such as (sub %c0, %c0) are left to the
// ArithmeticDialect canonicalizations (ref: ArithmeticCanonicalization.td).
class CombineGEPOp : public mlir::RewritePattern {
public:
  CombineGEPOp(mlir::MLIRContext *context)
      : mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1,
                             context) {}

  mlir::LogicalResult
  matchAndRewrite(mlir::Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    // Only an outer GEP whose pointer operand is itself produced by a GEP
    // is rewritten.
    if (!llvm::isa<mlir::triton::GEPOp>(op))
      return mlir::failure();
    auto innerGep = op->getOperand(0).getDefiningOp<mlir::triton::GEPOp>();
    if (!innerGep)
      return mlir::failure();

    mlir::Value basePtr = innerGep->getOperand(0);
    mlir::Value outerIdx = op->getOperand(1);
    mlir::Value innerIdx = innerGep->getOperand(1);

    // Sum both offsets, then index the innermost pointer once.
    mlir::Value combinedIdx = rewriter.create<mlir::arith::AddIOp>(
        op->getLoc(), outerIdx, innerIdx);
    rewriter.replaceOpWithNewOp<mlir::triton::GEPOp>(
        op, op->getResultTypes().front(), basePtr, combinedIdx);
    return mlir::success();
  }
};
// Folds a select over a masked load into the load's `other` operand:
//   select(cond, load(ptrs, broadcast(cond), ???), other)
//     => load(ptrs, broadcast(cond), other)
// `???` is the load's original `other` operand; it is dropped because the
// select never observes it (the select picks `other` wherever the mask is
// false).
class CombineSelectMaskedLoadOp : public mlir::RewritePattern {
public:
  CombineSelectMaskedLoadOp(mlir::MLIRContext *context)
      : mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1,
                             context) {}

  mlir::LogicalResult
  matchAndRewrite(mlir::Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    if (!llvm::isa<mlir::SelectOp>(op))
      return mlir::failure();

    // select operands: 0 = condition, 1 = true value, 2 = false value.
    mlir::Value cond = op->getOperand(0);
    auto load = op->getOperand(1).getDefiningOp<mlir::triton::LoadOp>();
    if (!load)
      return mlir::failure();

    mlir::Value mask = load.mask();
    if (!mask)
      return mlir::failure();
    auto bc = mask.getDefiningOp<mlir::triton::BroadcastOp>();
    if (!bc)
      return mlir::failure();

    // The mask must be a broadcast of *the same SSA value* that drives the
    // select. Compare the values themselves, not their defining ops:
    // getDefiningOp() is null for every block argument (so two unrelated
    // arguments would compare equal) and is shared by all results of a
    // multi-result op.
    if (bc.src() != cond)
      return mlir::failure();

    rewriter.replaceOpWithNewOp<mlir::triton::LoadOp>(
        op, op->getResultTypes().front(), load.ptr(), mask,
        op->getOperand(2), load.cache(), load.evict(), load.isVolatile());
    return mlir::success();
  }
};
} // anonymous namespace
#define GEN_PASS_CLASSES
#include "triton/Dialect/Triton/Transforms/Passes.h.inc"
// Module-level pass that greedily applies the combine patterns defined in
// this file. Only the dot, select/masked-load and gep patterns are
// registered; no CombineReduceOp pattern is registered.
class CombineOpsPass : public TritonCombineOpsBase<CombineOpsPass> {
public:
  void runOnOperation() override {
    mlir::MLIRContext *ctx = &getContext();
    mlir::ModuleOp module = getOperation();

    // Registration order: dot, select/masked-load, gep.
    mlir::RewritePatternSet combinePatterns(ctx);
    combinePatterns
        .add<CombineDotOp, CombineSelectMaskedLoadOp, CombineGEPOp>(ctx);

    // Report a pass failure if the greedy driver reports failure.
    if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
            module, std::move(combinePatterns))))
      signalPassFailure();
  }
};
// Factory declared in triton/Dialect/Triton/Transforms/Passes.h: returns a
// freshly allocated CombineOpsPass owned by the caller.
std::unique_ptr<mlir::Pass> mlir::triton::createCombineOpsPass() {
return std::make_unique<CombineOpsPass>();
}

View File

@@ -12,8 +12,10 @@
#include "mlir/Transforms/Passes.h"
#include "triton/Dialect/Triton/Dialect.h"
#include "triton/Dialect/Triton/Types.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Types.h"
#include "triton/Dialect/Triton/Transforms/Passes.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/LegacyPassManager.h"
@@ -1332,6 +1334,9 @@ void init_triton_ir(py::module &&m) {
.def("add_canonicalizer_pass", [](mlir::PassManager &self) {
self.addPass(mlir::createCanonicalizerPass());
})
.def("add_triton_combine_pass", [](mlir::PassManager &self) {
self.addPass(mlir::triton::createCombineOpsPass());
})
;
}

View File

@@ -99,6 +99,7 @@ mod.dump()
pm = _triton.ir.pass_manager(ctx)
pm.add_inliner_pass()
pm.add_triton_combine_pass()
pm.add_canonicalizer_pass()
pm.run(mod)