diff --git a/CMakeLists.txt b/CMakeLists.txt index ff4b9ceee..efb47f7df 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -179,6 +179,7 @@ add_library(triton SHARED ${PYTHON_SRC}) target_link_libraries(triton ${PYTHON_LIBRARIES} TritonIR + TritonTransforms TritonDriver # optimizations MLIRPass diff --git a/include/triton/Dialect/Triton/CMakeLists.txt b/include/triton/Dialect/Triton/CMakeLists.txt index 46573add6..9f57627c3 100644 --- a/include/triton/Dialect/Triton/CMakeLists.txt +++ b/include/triton/Dialect/Triton/CMakeLists.txt @@ -1,10 +1,2 @@ -set(LLVM_TARGET_DEFINITIONS TritonOps.td) -mlir_tablegen(Ops.h.inc -gen-op-decls) -mlir_tablegen(Ops.cpp.inc -gen-op-defs) -mlir_tablegen(Dialect.h.inc -gen-dialect-decls) -mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs) -mlir_tablegen(OpsEnums.h.inc -gen-enum-decls) -mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs) -mlir_tablegen(Types.h.inc -gen-typedef-decls) -mlir_tablegen(Types.cpp.inc -gen-typedef-defs) -add_public_tablegen_target(TritonTableGen) +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/include/triton/Dialect/Triton/IR/CMakeLists.txt b/include/triton/Dialect/Triton/IR/CMakeLists.txt new file mode 100644 index 000000000..46573add6 --- /dev/null +++ b/include/triton/Dialect/Triton/IR/CMakeLists.txt @@ -0,0 +1,10 @@ +set(LLVM_TARGET_DEFINITIONS TritonOps.td) +mlir_tablegen(Ops.h.inc -gen-op-decls) +mlir_tablegen(Ops.cpp.inc -gen-op-defs) +mlir_tablegen(Dialect.h.inc -gen-dialect-decls) +mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs) +mlir_tablegen(OpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs) +mlir_tablegen(Types.h.inc -gen-typedef-decls) +mlir_tablegen(Types.cpp.inc -gen-typedef-defs) +add_public_tablegen_target(TritonTableGen) diff --git a/include/triton/Dialect/Triton/Dialect.h b/include/triton/Dialect/Triton/IR/Dialect.h similarity index 81% rename from include/triton/Dialect/Triton/Dialect.h rename to include/triton/Dialect/Triton/IR/Dialect.h index 0cb893e2f..fc25667b9 100644 --- a/include/triton/Dialect/Triton/Dialect.h +++ b/include/triton/Dialect/Triton/IR/Dialect.h @@ -8,11 +8,11 @@ #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/SCF/SCF.h" -#include "triton/Dialect/Triton/Traits.h" +#include "triton/Dialect/Triton/IR/Traits.h" #include "triton/Dialect/Triton/Dialect.h.inc" #include "triton/Dialect/Triton/OpsEnums.h.inc" #define GET_OP_CLASSES -#include "triton/Dialect/Triton/Ops.h.inc" +#include "triton/Dialect/Triton/IR/Ops.h.inc" #endif // TRITON_IR_DIALECT_H_ diff --git a/include/triton/Dialect/Triton/Traits.h b/include/triton/Dialect/Triton/IR/Traits.h similarity index 100% rename from include/triton/Dialect/Triton/Traits.h rename to include/triton/Dialect/Triton/IR/Traits.h diff --git a/include/triton/Dialect/Triton/TritonDialect.td b/include/triton/Dialect/Triton/IR/TritonDialect.td similarity index 100% rename from include/triton/Dialect/Triton/TritonDialect.td rename to include/triton/Dialect/Triton/IR/TritonDialect.td diff --git a/include/triton/Dialect/Triton/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td similarity index 100% rename from include/triton/Dialect/Triton/TritonOps.td rename to include/triton/Dialect/Triton/IR/TritonOps.td diff --git a/include/triton/Dialect/Triton/Types.h b/include/triton/Dialect/Triton/IR/Types.h similarity index 78% rename from include/triton/Dialect/Triton/Types.h rename to include/triton/Dialect/Triton/IR/Types.h index dad9ae091..5ffd7db35 100644 --- a/include/triton/Dialect/Triton/Types.h +++ b/include/triton/Dialect/Triton/IR/Types.h @@ -5,6 +5,6 @@ #include "mlir/IR/Types.h" #define GET_TYPEDEF_CLASSES -#include "triton/Dialect/Triton/Types.h.inc" +#include "triton/Dialect/Triton/IR/Types.h.inc" #endif // TRITON_IR_TYPES_H_ diff --git a/include/triton/Dialect/Triton/Transforms/CMakeLists.txt b/include/triton/Dialect/Triton/Transforms/CMakeLists.txt new file mode 100644 index 000000000..372a9ec11 --- /dev/null +++ b/include/triton/Dialect/Triton/Transforms/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name Triton) +add_public_tablegen_target(TritonTransformsIncGen) diff --git a/include/triton/Dialect/Triton/Transforms/Passes.h b/include/triton/Dialect/Triton/Transforms/Passes.h new file mode 100644 index 000000000..71909734a --- /dev/null +++ b/include/triton/Dialect/Triton/Transforms/Passes.h @@ -0,0 +1,14 @@ +#ifndef TRITON_DIALECT_TRITON_TRANSFORMS_PASSES_H_ +#define TRITON_DIALECT_TRITON_TRANSFORMS_PASSES_H_ + +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace triton { + +std::unique_ptr createCombineOpsPass(); + +} +} + +#endif diff --git a/include/triton/Dialect/Triton/Transforms/Passes.td b/include/triton/Dialect/Triton/Transforms/Passes.td new file mode 100644 index 000000000..bd568cb3f --- /dev/null +++ b/include/triton/Dialect/Triton/Transforms/Passes.td @@ -0,0 +1,23 @@ +#ifndef TRITON_PASSES +#define TRITON_PASSES + +include "mlir/Pass/PassBase.td" + +def TritonCombineOps : Pass { + let summary = "combine ops"; + let description = [{ + dot(a, b, 0) + c => dot(a, b, c) + + gep(gep(ptr, idx0), idx1) => gep(ptr, AddI(idx0, idx1)) + + select(cond, load(ptrs, broadcast(cond), ???), other) => + load(ptrs, broadcast(cond), other) + }]; + + let constructor = "mlir::triton::createCombineOpsPass"; + + let dependentDialects = ["mlir::arith::ArithmeticDialect", + /*SelectOp*/"mlir::StandardOpsDialect"]; +} + +#endif diff --git a/lib/Dialect/Triton/CMakeLists.txt b/lib/Dialect/Triton/CMakeLists.txt index 2fa15a9b9..9f57627c3 100644 --- a/lib/Dialect/Triton/CMakeLists.txt +++ b/lib/Dialect/Triton/CMakeLists.txt @@ -1,18 +1,2 @@ -add_mlir_dialect_library(TritonIR - Dialect.cpp - Ops.cpp - Types.cpp - - DEPENDS - TritonTableGen - - LINK_LIBS PUBLIC - MLIRIR - MLIRArithmetic - MLIRSCF - - # Since LLVM 15 - # MLIRFunc - # else - MLIRStandard -) +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/lib/Dialect/Triton/IR/CMakeLists.txt b/lib/Dialect/Triton/IR/CMakeLists.txt new file mode 100644 index 000000000..2fa15a9b9 --- /dev/null +++ b/lib/Dialect/Triton/IR/CMakeLists.txt @@ -0,0 +1,18 @@ +add_mlir_dialect_library(TritonIR + Dialect.cpp + Ops.cpp + Types.cpp + + DEPENDS + TritonTableGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRArithmetic + MLIRSCF + + # Since LLVM 15 + # MLIRFunc + # else + MLIRStandard +) diff --git a/lib/Dialect/Triton/Dialect.cpp b/lib/Dialect/Triton/IR/Dialect.cpp similarity index 83% rename from lib/Dialect/Triton/Dialect.cpp rename to lib/Dialect/Triton/IR/Dialect.cpp index 33188f157..e21384d47 100644 --- a/lib/Dialect/Triton/Dialect.cpp +++ b/lib/Dialect/Triton/IR/Dialect.cpp @@ -1,5 +1,5 @@ -#include "triton/Dialect/Triton/Dialect.h" -#include "triton/Dialect/Triton/Types.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" #include "llvm/ADT/StringSwitch.h" #include "llvm/ADT/TypeSwitch.h" diff --git a/lib/Dialect/Triton/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp similarity index 97% rename from lib/Dialect/Triton/Ops.cpp rename to lib/Dialect/Triton/IR/Ops.cpp index ec64927b3..acb4adf6c 100644 --- a/lib/Dialect/Triton/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -1,5 +1,5 @@ -#include "triton/Dialect/Triton/Dialect.h" -#include "triton/Dialect/Triton/Types.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" diff --git a/lib/Dialect/Triton/Types.cpp b/lib/Dialect/Triton/IR/Types.cpp similarity index 81% rename from lib/Dialect/Triton/Types.cpp rename to lib/Dialect/Triton/IR/Types.cpp index 5aa8c8773..5ad151b19 100644 --- a/lib/Dialect/Triton/Types.cpp +++ b/lib/Dialect/Triton/IR/Types.cpp @@ -1,5 +1,5 @@ -#include "triton/Dialect/Triton/Dialect.h" -#include "triton/Dialect/Triton/Types.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" #include "mlir/IR/DialectImplementation.h" // required by `Types.cpp.inc` #include "llvm/ADT/TypeSwitch.h" // required by `Types.cpp.inc` @@ -7,7 +7,7 @@ using namespace mlir; using namespace mlir::triton; #define GET_TYPEDEF_CLASSES -#include "triton/Dialect/Triton/Types.cpp.inc" +#include "triton/Dialect/Triton/IR/Types.cpp.inc" //===----------------------------------------------------------------------===// // Triton Dialect @@ -15,7 +15,7 @@ using namespace mlir::triton; void TritonDialect::registerTypes() { addTypes< #define GET_TYPEDEF_LIST -#include "triton/Dialect/Triton/Types.cpp.inc" +#include "triton/Dialect/Triton/IR/Types.cpp.inc" >(); } diff --git a/lib/Dialect/Triton/Transforms/CMakeLists.txt b/lib/Dialect/Triton/Transforms/CMakeLists.txt new file mode 100644 index 000000000..61e1a97a2 --- /dev/null +++ b/lib/Dialect/Triton/Transforms/CMakeLists.txt @@ -0,0 +1,6 @@ +add_mlir_dialect_library(TritonTransforms + Combine.cpp + + DEPENDS + TritonTransformsIncGen +) diff --git a/lib/Dialect/Triton/Transforms/Combine.cpp b/lib/Dialect/Triton/Transforms/Combine.cpp new file mode 100644 index 000000000..aaed058b7 --- /dev/null +++ b/lib/Dialect/Triton/Transforms/Combine.cpp @@ -0,0 +1,141 @@ +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/Transforms/Passes.h" + +#include + +// using namespace mlir; + +namespace { +// dot(a, b, 0) + c => dot(a, b, c) +class CombineDotOp : public mlir::RewritePattern { +public: + CombineDotOp(mlir::MLIRContext *context) + : mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1, context) {} + mlir::LogicalResult matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const override { + if (llvm::isa(op)) { + if (isCandidate(op->getOperand(0)).succeeded()) { + auto dotOp = op->getOperand(0).getDefiningOp(); + rewriter.replaceOpWithNewOp( + op, dotOp->getResultTypes().front(), dotOp.a(), + dotOp.b(), op->getOperand(1), dotOp.allowTF32()); + return mlir::success(); + } else if (isCandidate(op->getOperand(1)).succeeded()) { + auto dotOp = op->getOperand(1).getDefiningOp(); + rewriter.replaceOpWithNewOp( + op, dotOp->getResultTypes().front(), dotOp.a(), + dotOp.b(), op->getOperand(0), dotOp.allowTF32()); + return mlir::success(); + } + } + return mlir::failure(); + } + +private: + // Is this value a dot and has 0 as `c`. + mlir::LogicalResult isCandidate(mlir::Value val) const { + if (auto dot = val.getDefiningOp()) { + if (isZero(dot.c())) + return mlir::success(); + } + return mlir::failure(); + } + + bool isZero(mlir::Value val) const { + if (mlir::matchPattern(val, mlir::m_Zero()) || + mlir::matchPattern(val, mlir::m_AnyZeroFloat())) + return true; + // broadcast(constant_0) + if (auto bc = val.getDefiningOp()) { + if (mlir::matchPattern(bc.src(), mlir::m_Zero()) || + mlir::matchPattern(bc.src(), mlir::m_AnyZeroFloat())) + return true; + } + return false; + } +}; + +// gep(gep(%ptr, %idx0), %idx1) => gep(%ptr, AddI(%idx0, %idx1)) +// Note: leave (sub %c0, %c0) canceling to ArithmeticDialect +// (ref: ArithmeticCanonicalization.td) +class CombineGEPOp : public mlir::RewritePattern { +public: + CombineGEPOp(mlir::MLIRContext *context) + : mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1, context) {} + + mlir::LogicalResult matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const override { + if (llvm::isa(op)) { + if (auto gep2 = op->getOperand(0).getDefiningOp()) { + auto loc = op->getLoc(); + mlir::Value newIdx = rewriter.create( + loc, op->getOperand(1), gep2->getOperand(1)); + rewriter.replaceOpWithNewOp( + op, op->getResultTypes().front(), gep2->getOperand(0), newIdx + ); + return mlir::success(); + } + } + return mlir::failure(); + } +}; + +// select(cond, load(ptrs, broadcast(cond), ???), other) +// => load(ptrs, broadcast(cond), other) +class CombineSelectMaskedLoadOp : public mlir::RewritePattern { +public: + CombineSelectMaskedLoadOp(mlir::MLIRContext *context) + : mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1, context) {} + + mlir::LogicalResult matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const override { + if (llvm::isa(op)) { + if (auto load = op->getOperand(1).getDefiningOp()) { + mlir::Value cond = op->getOperand(0); + if (auto bc = load.mask().getDefiningOp()) { + if (bc.src().getDefiningOp() == cond.getDefiningOp()) { + rewriter.replaceOpWithNewOp( + op, op->getResultTypes().front(), + load.ptr(), load.mask(), op->getOperand(2), + load.cache(), load.evict(), load.isVolatile() + ); + return mlir::success(); + } + } + } + } + return mlir::failure(); + } +}; +} // anonymous namespace + +#define GEN_PASS_CLASSES +#include "triton/Dialect/Triton/Transforms/Passes.h.inc" + +class CombineOpsPass : public TritonCombineOpsBase { +public: + void runOnOperation() override { + mlir::MLIRContext *context = &getContext(); + mlir::RewritePatternSet patterns(context); + mlir::ModuleOp m = getOperation(); + + patterns.add(context); + patterns.add(context); + patterns.add(context); + // patterns.add(context); + + if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) + signalPassFailure(); + } +}; + +std::unique_ptr mlir::triton::createCombineOpsPass() { + return std::make_unique(); +} diff --git a/python/src/triton.cc b/python/src/triton.cc index e69b0842a..97964394f 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -12,8 +12,10 @@ #include "mlir/Transforms/Passes.h" -#include "triton/Dialect/Triton/Dialect.h" -#include "triton/Dialect/Triton/Types.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" + +#include "triton/Dialect/Triton/Transforms/Passes.h" #include "llvm/IR/Module.h" #include "llvm/IR/LegacyPassManager.h" @@ -1332,6 +1334,9 @@ void init_triton_ir(py::module &&m) { .def("add_canonicalizer_pass", [](mlir::PassManager &self) { self.addPass(mlir::createCanonicalizerPass()); }) + .def("add_triton_combine_pass", [](mlir::PassManager &self) { + self.addPass(mlir::triton::createCombineOpsPass()); + }) ; } diff --git a/rewrite-test/jit/matmul/matmul.py b/rewrite-test/jit/matmul/matmul.py index 69897a8d1..10d2b7da5 100644 --- a/rewrite-test/jit/matmul/matmul.py +++ b/rewrite-test/jit/matmul/matmul.py @@ -99,6 +99,7 @@ mod.dump() pm = _triton.ir.pass_manager(ctx) pm.add_inliner_pass() +pm.add_triton_combine_pass() pm.add_canonicalizer_pass() pm.run(mod)