Add TritonCombineOps
This commit is contained in:
@@ -179,6 +179,7 @@ add_library(triton SHARED ${PYTHON_SRC})
|
|||||||
target_link_libraries(triton
|
target_link_libraries(triton
|
||||||
${PYTHON_LIBRARIES}
|
${PYTHON_LIBRARIES}
|
||||||
TritonIR
|
TritonIR
|
||||||
|
TritonTransforms
|
||||||
TritonDriver
|
TritonDriver
|
||||||
# optimizations
|
# optimizations
|
||||||
MLIRPass
|
MLIRPass
|
||||||
|
@@ -1,10 +1,2 @@
|
|||||||
set(LLVM_TARGET_DEFINITIONS TritonOps.td)
|
add_subdirectory(IR)
|
||||||
mlir_tablegen(Ops.h.inc -gen-op-decls)
|
add_subdirectory(Transforms)
|
||||||
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
|
|
||||||
mlir_tablegen(Dialect.h.inc -gen-dialect-decls)
|
|
||||||
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs)
|
|
||||||
mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
|
|
||||||
mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
|
|
||||||
mlir_tablegen(Types.h.inc -gen-typedef-decls)
|
|
||||||
mlir_tablegen(Types.cpp.inc -gen-typedef-defs)
|
|
||||||
add_public_tablegen_target(TritonTableGen)
|
|
||||||
|
10
include/triton/Dialect/Triton/IR/CMakeLists.txt
Normal file
10
include/triton/Dialect/Triton/IR/CMakeLists.txt
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
set(LLVM_TARGET_DEFINITIONS TritonOps.td)
|
||||||
|
mlir_tablegen(Ops.h.inc -gen-op-decls)
|
||||||
|
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
|
||||||
|
mlir_tablegen(Dialect.h.inc -gen-dialect-decls)
|
||||||
|
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs)
|
||||||
|
mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
|
||||||
|
mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
|
||||||
|
mlir_tablegen(Types.h.inc -gen-typedef-decls)
|
||||||
|
mlir_tablegen(Types.cpp.inc -gen-typedef-defs)
|
||||||
|
add_public_tablegen_target(TritonTableGen)
|
@@ -8,11 +8,11 @@
|
|||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
#include "mlir/Dialect/SCF/SCF.h"
|
#include "mlir/Dialect/SCF/SCF.h"
|
||||||
|
|
||||||
#include "triton/Dialect/Triton/Traits.h"
|
#include "triton/Dialect/Triton/IR/Traits.h"
|
||||||
#include "triton/Dialect/Triton/Dialect.h.inc"
|
#include "triton/Dialect/Triton/Dialect.h.inc"
|
||||||
#include "triton/Dialect/Triton/OpsEnums.h.inc"
|
#include "triton/Dialect/Triton/OpsEnums.h.inc"
|
||||||
|
|
||||||
#define GET_OP_CLASSES
|
#define GET_OP_CLASSES
|
||||||
#include "triton/Dialect/Triton/Ops.h.inc"
|
#include "triton/Dialect/Triton/IR/Ops.h.inc"
|
||||||
|
|
||||||
#endif // TRITON_IR_DIALECT_H_
|
#endif // TRITON_IR_DIALECT_H_
|
@@ -5,6 +5,6 @@
|
|||||||
#include "mlir/IR/Types.h"
|
#include "mlir/IR/Types.h"
|
||||||
|
|
||||||
#define GET_TYPEDEF_CLASSES
|
#define GET_TYPEDEF_CLASSES
|
||||||
#include "triton/Dialect/Triton/Types.h.inc"
|
#include "triton/Dialect/Triton/IR/Types.h.inc"
|
||||||
|
|
||||||
#endif // TRITON_IR_TYPES_H_
|
#endif // TRITON_IR_TYPES_H_
|
3
include/triton/Dialect/Triton/Transforms/CMakeLists.txt
Normal file
3
include/triton/Dialect/Triton/Transforms/CMakeLists.txt
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
set(LLVM_TARGET_DEFINITIONS Passes.td)
|
||||||
|
mlir_tablegen(Passes.h.inc -gen-pass-decls -name Triton)
|
||||||
|
add_public_tablegen_target(TritonTransformsIncGen)
|
14
include/triton/Dialect/Triton/Transforms/Passes.h
Normal file
14
include/triton/Dialect/Triton/Transforms/Passes.h
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
#ifndef TRITON_DIALECT_TRITON_TRANSFORMS_PASSES_H_
|
||||||
|
#define TRITON_DIALECT_TRITON_TRANSFORMS_PASSES_H_
|
||||||
|
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace triton {
|
||||||
|
|
||||||
|
std::unique_ptr<Pass> createCombineOpsPass();
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
23
include/triton/Dialect/Triton/Transforms/Passes.td
Normal file
23
include/triton/Dialect/Triton/Transforms/Passes.td
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
#ifndef TRITON_PASSES
|
||||||
|
#define TRITON_PASSES
|
||||||
|
|
||||||
|
include "mlir/Pass/PassBase.td"
|
||||||
|
|
||||||
|
def TritonCombineOps : Pass</*cli-arg*/"triton-combine", /*Op*/"mlir::ModuleOp"> {
|
||||||
|
let summary = "combine ops";
|
||||||
|
let description = [{
|
||||||
|
dot(a, b, 0) + c => dot(a, b, c)
|
||||||
|
|
||||||
|
gep(gep(ptr, idx0), idx1) => gep(ptr, AddI(idx0, idx1))
|
||||||
|
|
||||||
|
select(cond, load(ptrs, broadcast(cond), ???), other) =>
|
||||||
|
load(ptrs, broadcast(cond), other)
|
||||||
|
}];
|
||||||
|
|
||||||
|
let constructor = "mlir::triton::createCombineOpsPass";
|
||||||
|
|
||||||
|
let dependentDialects = ["mlir::arith::ArithmeticDialect",
|
||||||
|
/*SelectOp*/"mlir::StandardOpsDialect"];
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
@@ -1,18 +1,2 @@
|
|||||||
add_mlir_dialect_library(TritonIR
|
add_subdirectory(IR)
|
||||||
Dialect.cpp
|
add_subdirectory(Transforms)
|
||||||
Ops.cpp
|
|
||||||
Types.cpp
|
|
||||||
|
|
||||||
DEPENDS
|
|
||||||
TritonTableGen
|
|
||||||
|
|
||||||
LINK_LIBS PUBLIC
|
|
||||||
MLIRIR
|
|
||||||
MLIRArithmetic
|
|
||||||
MLIRSCF
|
|
||||||
|
|
||||||
# Since LLVM 15
|
|
||||||
# MLIRFunc
|
|
||||||
# else
|
|
||||||
MLIRStandard
|
|
||||||
)
|
|
||||||
|
18
lib/Dialect/Triton/IR/CMakeLists.txt
Normal file
18
lib/Dialect/Triton/IR/CMakeLists.txt
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
add_mlir_dialect_library(TritonIR
|
||||||
|
Dialect.cpp
|
||||||
|
Ops.cpp
|
||||||
|
Types.cpp
|
||||||
|
|
||||||
|
DEPENDS
|
||||||
|
TritonTableGen
|
||||||
|
|
||||||
|
LINK_LIBS PUBLIC
|
||||||
|
MLIRIR
|
||||||
|
MLIRArithmetic
|
||||||
|
MLIRSCF
|
||||||
|
|
||||||
|
# Since LLVM 15
|
||||||
|
# MLIRFunc
|
||||||
|
# else
|
||||||
|
MLIRStandard
|
||||||
|
)
|
@@ -1,5 +1,5 @@
|
|||||||
#include "triton/Dialect/Triton/Dialect.h"
|
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||||
#include "triton/Dialect/Triton/Types.h"
|
#include "triton/Dialect/Triton/IR/Types.h"
|
||||||
|
|
||||||
#include "llvm/ADT/StringSwitch.h"
|
#include "llvm/ADT/StringSwitch.h"
|
||||||
#include "llvm/ADT/TypeSwitch.h"
|
#include "llvm/ADT/TypeSwitch.h"
|
@@ -1,5 +1,5 @@
|
|||||||
#include "triton/Dialect/Triton/Dialect.h"
|
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||||
#include "triton/Dialect/Triton/Types.h"
|
#include "triton/Dialect/Triton/IR/Types.h"
|
||||||
|
|
||||||
#include "mlir/IR/Builders.h"
|
#include "mlir/IR/Builders.h"
|
||||||
#include "mlir/IR/BuiltinAttributes.h"
|
#include "mlir/IR/BuiltinAttributes.h"
|
@@ -1,5 +1,5 @@
|
|||||||
#include "triton/Dialect/Triton/Dialect.h"
|
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||||
#include "triton/Dialect/Triton/Types.h"
|
#include "triton/Dialect/Triton/IR/Types.h"
|
||||||
#include "mlir/IR/DialectImplementation.h" // required by `Types.cpp.inc`
|
#include "mlir/IR/DialectImplementation.h" // required by `Types.cpp.inc`
|
||||||
#include "llvm/ADT/TypeSwitch.h" // required by `Types.cpp.inc`
|
#include "llvm/ADT/TypeSwitch.h" // required by `Types.cpp.inc`
|
||||||
|
|
||||||
@@ -7,7 +7,7 @@ using namespace mlir;
|
|||||||
using namespace mlir::triton;
|
using namespace mlir::triton;
|
||||||
|
|
||||||
#define GET_TYPEDEF_CLASSES
|
#define GET_TYPEDEF_CLASSES
|
||||||
#include "triton/Dialect/Triton/Types.cpp.inc"
|
#include "triton/Dialect/Triton/IR/Types.cpp.inc"
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Triton Dialect
|
// Triton Dialect
|
||||||
@@ -15,7 +15,7 @@ using namespace mlir::triton;
|
|||||||
void TritonDialect::registerTypes() {
|
void TritonDialect::registerTypes() {
|
||||||
addTypes<
|
addTypes<
|
||||||
#define GET_TYPEDEF_LIST
|
#define GET_TYPEDEF_LIST
|
||||||
#include "triton/Dialect/Triton/Types.cpp.inc"
|
#include "triton/Dialect/Triton/IR/Types.cpp.inc"
|
||||||
>();
|
>();
|
||||||
}
|
}
|
||||||
|
|
6
lib/Dialect/Triton/Transforms/CMakeLists.txt
Normal file
6
lib/Dialect/Triton/Transforms/CMakeLists.txt
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
add_mlir_dialect_library(TritonTransforms
|
||||||
|
Combine.cpp
|
||||||
|
|
||||||
|
DEPENDS
|
||||||
|
TritonTransformsIncGen
|
||||||
|
)
|
141
lib/Dialect/Triton/Transforms/Combine.cpp
Normal file
141
lib/Dialect/Triton/Transforms/Combine.cpp
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
#include "mlir/IR/BuiltinAttributes.h"
|
||||||
|
#include "mlir/IR/Matchers.h"
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
#include "mlir/Support/LogicalResult.h"
|
||||||
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
|
|
||||||
|
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||||
|
#include "triton/Dialect/Triton/Transforms/Passes.h"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
// using namespace mlir;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// dot(a, b, 0) + c => dot(a, b, c)
|
||||||
|
class CombineDotOp : public mlir::RewritePattern {
|
||||||
|
public:
|
||||||
|
CombineDotOp(mlir::MLIRContext *context)
|
||||||
|
: mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1, context) {}
|
||||||
|
mlir::LogicalResult matchAndRewrite(mlir::Operation *op,
|
||||||
|
mlir::PatternRewriter &rewriter) const override {
|
||||||
|
if (llvm::isa<mlir::arith::AddIOp, mlir::arith::AddFOp>(op)) {
|
||||||
|
if (isCandidate(op->getOperand(0)).succeeded()) {
|
||||||
|
auto dotOp = op->getOperand(0).getDefiningOp<mlir::triton::DotOp>();
|
||||||
|
rewriter.replaceOpWithNewOp<mlir::triton::DotOp>(
|
||||||
|
op, dotOp->getResultTypes().front(), dotOp.a(),
|
||||||
|
dotOp.b(), op->getOperand(1), dotOp.allowTF32());
|
||||||
|
return mlir::success();
|
||||||
|
} else if (isCandidate(op->getOperand(1)).succeeded()) {
|
||||||
|
auto dotOp = op->getOperand(1).getDefiningOp<mlir::triton::DotOp>();
|
||||||
|
rewriter.replaceOpWithNewOp<mlir::triton::DotOp>(
|
||||||
|
op, dotOp->getResultTypes().front(), dotOp.a(),
|
||||||
|
dotOp.b(), op->getOperand(0), dotOp.allowTF32());
|
||||||
|
return mlir::success();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return mlir::failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Is this value a dot and has 0 as `c`.
|
||||||
|
mlir::LogicalResult isCandidate(mlir::Value val) const {
|
||||||
|
if (auto dot = val.getDefiningOp<mlir::triton::DotOp>()) {
|
||||||
|
if (isZero(dot.c()))
|
||||||
|
return mlir::success();
|
||||||
|
}
|
||||||
|
return mlir::failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool isZero(mlir::Value val) const {
|
||||||
|
if (mlir::matchPattern(val, mlir::m_Zero()) ||
|
||||||
|
mlir::matchPattern(val, mlir::m_AnyZeroFloat()))
|
||||||
|
return true;
|
||||||
|
// broadcast(constant_0)
|
||||||
|
if (auto bc = val.getDefiningOp<mlir::triton::BroadcastOp>()) {
|
||||||
|
if (mlir::matchPattern(bc.src(), mlir::m_Zero()) ||
|
||||||
|
mlir::matchPattern(bc.src(), mlir::m_AnyZeroFloat()))
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// gep(gep(%ptr, %idx0), %idx1) => gep(%ptr, AddI(%idx0, %idx1))
|
||||||
|
// Note: leave (sub %c0, %c0) canceling to ArithmeticDialect
|
||||||
|
// (ref: ArithmeticCanonicalization.td)
|
||||||
|
class CombineGEPOp : public mlir::RewritePattern {
|
||||||
|
public:
|
||||||
|
CombineGEPOp(mlir::MLIRContext *context)
|
||||||
|
: mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1, context) {}
|
||||||
|
|
||||||
|
mlir::LogicalResult matchAndRewrite(mlir::Operation *op,
|
||||||
|
mlir::PatternRewriter &rewriter) const override {
|
||||||
|
if (llvm::isa<mlir::triton::GEPOp>(op)) {
|
||||||
|
if (auto gep2 = op->getOperand(0).getDefiningOp<mlir::triton::GEPOp>()) {
|
||||||
|
auto loc = op->getLoc();
|
||||||
|
mlir::Value newIdx = rewriter.create<mlir::arith::AddIOp>(
|
||||||
|
loc, op->getOperand(1), gep2->getOperand(1));
|
||||||
|
rewriter.replaceOpWithNewOp<mlir::triton::GEPOp>(
|
||||||
|
op, op->getResultTypes().front(), gep2->getOperand(0), newIdx
|
||||||
|
);
|
||||||
|
return mlir::success();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return mlir::failure();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// select(cond, load(ptrs, broadcast(cond), ???), other)
|
||||||
|
// => load(ptrs, broadcast(cond), other)
|
||||||
|
class CombineSelectMaskedLoadOp : public mlir::RewritePattern {
|
||||||
|
public:
|
||||||
|
CombineSelectMaskedLoadOp(mlir::MLIRContext *context)
|
||||||
|
: mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1, context) {}
|
||||||
|
|
||||||
|
mlir::LogicalResult matchAndRewrite(mlir::Operation *op,
|
||||||
|
mlir::PatternRewriter &rewriter) const override {
|
||||||
|
if (llvm::isa<mlir::SelectOp>(op)) {
|
||||||
|
if (auto load = op->getOperand(1).getDefiningOp<mlir::triton::LoadOp>()) {
|
||||||
|
mlir::Value cond = op->getOperand(0);
|
||||||
|
if (auto bc = load.mask().getDefiningOp<mlir::triton::BroadcastOp>()) {
|
||||||
|
if (bc.src().getDefiningOp() == cond.getDefiningOp()) {
|
||||||
|
rewriter.replaceOpWithNewOp<mlir::triton::LoadOp>(
|
||||||
|
op, op->getResultTypes().front(),
|
||||||
|
load.ptr(), load.mask(), op->getOperand(2),
|
||||||
|
load.cache(), load.evict(), load.isVolatile()
|
||||||
|
);
|
||||||
|
return mlir::success();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return mlir::failure();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // anonymous namespace
|
||||||
|
|
||||||
|
#define GEN_PASS_CLASSES
|
||||||
|
#include "triton/Dialect/Triton/Transforms/Passes.h.inc"
|
||||||
|
|
||||||
|
class CombineOpsPass : public TritonCombineOpsBase<CombineOpsPass> {
|
||||||
|
public:
|
||||||
|
void runOnOperation() override {
|
||||||
|
mlir::MLIRContext *context = &getContext();
|
||||||
|
mlir::RewritePatternSet patterns(context);
|
||||||
|
mlir::ModuleOp m = getOperation();
|
||||||
|
|
||||||
|
patterns.add<CombineDotOp>(context);
|
||||||
|
patterns.add<CombineSelectMaskedLoadOp>(context);
|
||||||
|
patterns.add<CombineGEPOp>(context);
|
||||||
|
// patterns.add<CombineReduceOp>(context);
|
||||||
|
|
||||||
|
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
|
||||||
|
signalPassFailure();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
std::unique_ptr<mlir::Pass> mlir::triton::createCombineOpsPass() {
|
||||||
|
return std::make_unique<CombineOpsPass>();
|
||||||
|
}
|
@@ -12,8 +12,10 @@
|
|||||||
#include "mlir/Transforms/Passes.h"
|
#include "mlir/Transforms/Passes.h"
|
||||||
|
|
||||||
|
|
||||||
#include "triton/Dialect/Triton/Dialect.h"
|
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||||
#include "triton/Dialect/Triton/Types.h"
|
#include "triton/Dialect/Triton/IR/Types.h"
|
||||||
|
|
||||||
|
#include "triton/Dialect/Triton/Transforms/Passes.h"
|
||||||
|
|
||||||
#include "llvm/IR/Module.h"
|
#include "llvm/IR/Module.h"
|
||||||
#include "llvm/IR/LegacyPassManager.h"
|
#include "llvm/IR/LegacyPassManager.h"
|
||||||
@@ -1332,6 +1334,9 @@ void init_triton_ir(py::module &&m) {
|
|||||||
.def("add_canonicalizer_pass", [](mlir::PassManager &self) {
|
.def("add_canonicalizer_pass", [](mlir::PassManager &self) {
|
||||||
self.addPass(mlir::createCanonicalizerPass());
|
self.addPass(mlir::createCanonicalizerPass());
|
||||||
})
|
})
|
||||||
|
.def("add_triton_combine_pass", [](mlir::PassManager &self) {
|
||||||
|
self.addPass(mlir::triton::createCombineOpsPass());
|
||||||
|
})
|
||||||
;
|
;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -99,6 +99,7 @@ mod.dump()
|
|||||||
|
|
||||||
pm = _triton.ir.pass_manager(ctx)
|
pm = _triton.ir.pass_manager(ctx)
|
||||||
pm.add_inliner_pass()
|
pm.add_inliner_pass()
|
||||||
|
pm.add_triton_combine_pass()
|
||||||
pm.add_canonicalizer_pass()
|
pm.add_canonicalizer_pass()
|
||||||
pm.run(mod)
|
pm.run(mod)
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user