[TritonIR] Convert Triton dialect's Combine
pass to MLIR DRR based (#16)
This commit is contained in:
@@ -1,6 +1,11 @@
|
|||||||
|
set(LLVM_TARGET_DEFINITIONS Combine.td)
|
||||||
|
mlir_tablegen(TritonCombine.inc -gen-rewriters)
|
||||||
|
add_public_tablegen_target(TritonCombineIncGen)
|
||||||
|
|
||||||
add_mlir_dialect_library(TritonTransforms
|
add_mlir_dialect_library(TritonTransforms
|
||||||
Combine.cpp
|
Combine.cpp
|
||||||
|
|
||||||
DEPENDS
|
DEPENDS
|
||||||
TritonTransformsIncGen
|
TritonTransformsIncGen
|
||||||
|
TritonCombineIncGen
|
||||||
)
|
)
|
||||||
|
@@ -13,44 +13,8 @@
|
|||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
// dot(a, b, 0) + c => dot(a, b, c)
|
|
||||||
class CombineDotOp : public mlir::RewritePattern {
|
|
||||||
public:
|
|
||||||
CombineDotOp(mlir::MLIRContext *context)
|
|
||||||
: mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1,
|
|
||||||
context) {}
|
|
||||||
mlir::LogicalResult
|
|
||||||
matchAndRewrite(mlir::Operation *op,
|
|
||||||
mlir::PatternRewriter &rewriter) const override {
|
|
||||||
if (llvm::isa<mlir::arith::AddIOp, mlir::arith::AddFOp>(op)) {
|
|
||||||
if (isCandidate(op->getOperand(0)).succeeded()) {
|
|
||||||
auto dotOp = op->getOperand(0).getDefiningOp<mlir::triton::DotOp>();
|
|
||||||
rewriter.replaceOpWithNewOp<mlir::triton::DotOp>(
|
|
||||||
op, dotOp->getResultTypes().front(), dotOp.a(), dotOp.b(),
|
|
||||||
op->getOperand(1), dotOp.allowTF32());
|
|
||||||
return mlir::success();
|
|
||||||
} else if (isCandidate(op->getOperand(1)).succeeded()) {
|
|
||||||
auto dotOp = op->getOperand(1).getDefiningOp<mlir::triton::DotOp>();
|
|
||||||
rewriter.replaceOpWithNewOp<mlir::triton::DotOp>(
|
|
||||||
op, dotOp->getResultTypes().front(), dotOp.a(), dotOp.b(),
|
|
||||||
op->getOperand(0), dotOp.allowTF32());
|
|
||||||
return mlir::success();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return mlir::failure();
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
bool isZero(mlir::Value val) {
|
||||||
// Is this value a dot and has 0 as `c`.
|
|
||||||
mlir::LogicalResult isCandidate(mlir::Value val) const {
|
|
||||||
if (auto dot = val.getDefiningOp<mlir::triton::DotOp>()) {
|
|
||||||
if (isZero(dot.c()))
|
|
||||||
return mlir::success();
|
|
||||||
}
|
|
||||||
return mlir::failure();
|
|
||||||
}
|
|
||||||
|
|
||||||
bool isZero(mlir::Value val) const {
|
|
||||||
if (mlir::matchPattern(val, mlir::m_Zero()) ||
|
if (mlir::matchPattern(val, mlir::m_Zero()) ||
|
||||||
mlir::matchPattern(val, mlir::m_AnyZeroFloat()))
|
mlir::matchPattern(val, mlir::m_AnyZeroFloat()))
|
||||||
return true;
|
return true;
|
||||||
@@ -62,94 +26,30 @@ private:
|
|||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
};
|
|
||||||
|
|
||||||
// gep(gep(%ptr, %idx0), %idx1) => gep(%ptr, AddI(%idx0, %idx1))
|
bool isBroadcastConstantCombinable(Attribute value) {
|
||||||
// Note: leave (sub %c0, %c0) canceling to ArithmeticDialect
|
|
||||||
// (ref: ArithmeticCanonicalization.td)
|
|
||||||
class CombineGEPOp : public mlir::RewritePattern {
|
|
||||||
public:
|
|
||||||
CombineGEPOp(mlir::MLIRContext *context)
|
|
||||||
: mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1,
|
|
||||||
context) {}
|
|
||||||
|
|
||||||
mlir::LogicalResult
|
|
||||||
matchAndRewrite(mlir::Operation *op,
|
|
||||||
mlir::PatternRewriter &rewriter) const override {
|
|
||||||
if (llvm::isa<mlir::triton::GEPOp>(op)) {
|
|
||||||
if (auto gep2 = op->getOperand(0).getDefiningOp<mlir::triton::GEPOp>()) {
|
|
||||||
auto loc = op->getLoc();
|
|
||||||
mlir::Value newIdx = rewriter.create<mlir::arith::AddIOp>(
|
|
||||||
loc, op->getOperand(1), gep2->getOperand(1));
|
|
||||||
rewriter.replaceOpWithNewOp<mlir::triton::GEPOp>(
|
|
||||||
op, op->getResultTypes().front(), gep2->getOperand(0), newIdx);
|
|
||||||
return mlir::success();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return mlir::failure();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// select(cond, load(ptrs, broadcast(cond), ???), other)
|
|
||||||
// => load(ptrs, broadcast(cond), other)
|
|
||||||
class CombineSelectMaskedLoadOp : public mlir::RewritePattern {
|
|
||||||
public:
|
|
||||||
CombineSelectMaskedLoadOp(mlir::MLIRContext *context)
|
|
||||||
: mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1,
|
|
||||||
context) {}
|
|
||||||
|
|
||||||
mlir::LogicalResult
|
|
||||||
matchAndRewrite(mlir::Operation *op,
|
|
||||||
mlir::PatternRewriter &rewriter) const override {
|
|
||||||
if (llvm::isa<mlir::SelectOp>(op)) {
|
|
||||||
if (auto load = op->getOperand(1).getDefiningOp<mlir::triton::LoadOp>()) {
|
|
||||||
mlir::Value cond = op->getOperand(0);
|
|
||||||
if (auto bc = load.mask().getDefiningOp<mlir::triton::BroadcastOp>()) {
|
|
||||||
if (bc.src().getDefiningOp() == cond.getDefiningOp()) {
|
|
||||||
rewriter.replaceOpWithNewOp<mlir::triton::LoadOp>(
|
|
||||||
op, op->getResultTypes().front(), load.ptr(), load.mask(),
|
|
||||||
op->getOperand(2), load.cache(), load.evict(),
|
|
||||||
load.isVolatile());
|
|
||||||
return mlir::success();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return mlir::failure();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// broadcast(cst) => cst
|
|
||||||
// TODO: move this to .td file
|
|
||||||
class CombineBroadcastConstantOp : public mlir::RewritePattern {
|
|
||||||
public:
|
|
||||||
CombineBroadcastConstantOp(mlir::MLIRContext *context)
|
|
||||||
: mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1,
|
|
||||||
context) {}
|
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(Operation *op,
|
|
||||||
PatternRewriter &rewriter) const override {
|
|
||||||
if (auto broadcast = llvm::dyn_cast<triton::BroadcastOp>(op)) {
|
|
||||||
if (auto cst = broadcast.src().getDefiningOp<arith::ConstantOp>()) {
|
|
||||||
Attribute value = cst.getValue();
|
|
||||||
Type resType = broadcast.getResult().getType();
|
|
||||||
if (auto denseValue = value.dyn_cast<DenseElementsAttr>()) {
|
if (auto denseValue = value.dyn_cast<DenseElementsAttr>()) {
|
||||||
if (!denseValue.isSplat())
|
return denseValue.isSplat();
|
||||||
return failure();
|
}
|
||||||
value = DenseElementsAttr::get(resType,
|
return value.isa<FloatAttr, IntegerAttr>();
|
||||||
denseValue.getSplatValue<Attribute>());
|
}
|
||||||
|
|
||||||
|
DenseElementsAttr getConstantValue(Builder &builder, Attribute value,
|
||||||
|
Value bcast_res) {
|
||||||
|
|
||||||
|
Type resType = bcast_res.getType();
|
||||||
|
DenseElementsAttr res;
|
||||||
|
if (auto denseValue = value.dyn_cast<DenseElementsAttr>()) {
|
||||||
|
res =
|
||||||
|
DenseElementsAttr::get(resType, denseValue.getSplatValue<Attribute>());
|
||||||
} else {
|
} else {
|
||||||
if (!value.isa<FloatAttr, IntegerAttr>())
|
res = DenseElementsAttr::get(resType, value);
|
||||||
return failure();
|
|
||||||
value = DenseElementsAttr::get(resType, value);
|
|
||||||
}
|
}
|
||||||
rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, value, resType);
|
return res;
|
||||||
return success();
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
return failure();
|
#include "TritonCombine.inc"
|
||||||
}
|
|
||||||
};
|
|
||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
|
|
||||||
#define GEN_PASS_CLASSES
|
#define GEN_PASS_CLASSES
|
||||||
@@ -162,11 +62,15 @@ public:
|
|||||||
mlir::RewritePatternSet patterns(context);
|
mlir::RewritePatternSet patterns(context);
|
||||||
mlir::ModuleOp m = getOperation();
|
mlir::ModuleOp m = getOperation();
|
||||||
|
|
||||||
patterns.add<CombineDotOp>(context);
|
// Dot Add %{
|
||||||
patterns.add<CombineSelectMaskedLoadOp>(context);
|
patterns.add<CombineDotAddIPattern>(context);
|
||||||
patterns.add<CombineGEPOp>(context);
|
patterns.add<CombineDotAddFPattern>(context);
|
||||||
patterns.add<CombineBroadcastConstantOp>(context);
|
patterns.add<CombineDotAddIRevPattern>(context);
|
||||||
// patterns.add<CombineReduceOp>(context);
|
patterns.add<CombineDotAddFRevPattern>(context);
|
||||||
|
// %}
|
||||||
|
patterns.add<CombineSelectMaskedLoadPattern>(context);
|
||||||
|
patterns.add<CombineGEPPattern>(context);
|
||||||
|
patterns.add<CombineBroadcastConstantPattern>(context);
|
||||||
|
|
||||||
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
|
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
|
53
lib/Dialect/Triton/Transforms/Combine.td
Normal file
53
lib/Dialect/Triton/Transforms/Combine.td
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
#ifndef TRITON_PATTERNS
|
||||||
|
#define TRITON_PATTERNS
|
||||||
|
|
||||||
|
include "mlir/Dialect/StandardOps/IR/Ops.td"
|
||||||
|
include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.td"
|
||||||
|
include "triton/Dialect/Triton/IR/TritonOps.td"
|
||||||
|
|
||||||
|
|
||||||
|
// AddIOp(DotOp(a, b, c), d) and c==0 => DotOp(a, b, d)
|
||||||
|
// AddFOp(DotOp(a, b, c), d) and c==0 => DotOp(a, b, d)
|
||||||
|
|
||||||
|
// AddIOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d)
|
||||||
|
// AddFOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d)
|
||||||
|
def CombineDotAddIPattern : Pat<
|
||||||
|
(Arith_AddIOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32)),
|
||||||
|
(TT_DotOp $a, $b, $d, $allowTF32),
|
||||||
|
[(Constraint<CPred<"isZero($0)">> $c)]>;
|
||||||
|
def CombineDotAddFPattern : Pat<
|
||||||
|
(Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32)),
|
||||||
|
(TT_DotOp $a, $b, $d, $allowTF32),
|
||||||
|
[(Constraint<CPred<"isZero($0)">> $c)]>;
|
||||||
|
|
||||||
|
def CombineDotAddIRevPattern : Pat<
|
||||||
|
(Arith_AddIOp (TT_DotOp:$res $a, $b, $c, $allowTF32), $d),
|
||||||
|
(TT_DotOp $a, $b, $d, $allowTF32),
|
||||||
|
[(Constraint<CPred<"isZero($0)">> $c)]>;
|
||||||
|
def CombineDotAddFRevPattern : Pat<
|
||||||
|
(Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $allowTF32), $d),
|
||||||
|
(TT_DotOp $a, $b, $d, $allowTF32),
|
||||||
|
[(Constraint<CPred<"isZero($0)">> $c)]>;
|
||||||
|
|
||||||
|
|
||||||
|
// gep(gep(%ptr, %idx0), %idx1) => gep(%ptr, AddI(%idx0, %idx1))
|
||||||
|
// Note: leave (sub %c0, %c0) canceling to ArithmeticDialect
|
||||||
|
// (ref: ArithmeticCanonicalization.td)
|
||||||
|
def CombineGEPPattern : Pat<
|
||||||
|
(TT_GEPOp (TT_GEPOp $ptr, $idx0), $idx1),
|
||||||
|
(TT_GEPOp $ptr, (Arith_AddIOp $idx0, $idx1))>;
|
||||||
|
|
||||||
|
// select(cond, load(ptrs, broadcast(cond), ???), other)
|
||||||
|
// => load(ptrs, broadcast(cond), other)
|
||||||
|
def CombineSelectMaskedLoadPattern : Pat<
|
||||||
|
(SelectOp $cond, (TT_LoadOp $ptrs, (TT_BroadcastOp:$bcast_res $cond), $other, $cache, $evict, $isVolatile), $falseValue),
|
||||||
|
(TT_LoadOp $ptrs, $bcast_res, $falseValue, $cache, $evict, $isVolatile)>;
|
||||||
|
|
||||||
|
// broadcast(cst) => cst
|
||||||
|
def getConstantValue : NativeCodeCall<"getConstantValue($_builder, $0, $1)">;
|
||||||
|
def CombineBroadcastConstantPattern : Pat<
|
||||||
|
(TT_BroadcastOp:$bcast_res (Arith_ConstantOp $value)),
|
||||||
|
(Arith_ConstantOp (getConstantValue $value, $bcast_res)),
|
||||||
|
[(Constraint<CPred<"isBroadcastConstantCombinable($0)">> $value)]>;
|
||||||
|
|
||||||
|
#endif
|
68
test/Triton/combine.mlir
Normal file
68
test/Triton/combine.mlir
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
// RUN: triton-opt %s -split-input-file -canonicalize -triton-combine
|
||||||
|
// RUN: triton-opt %s -split-input-file -canonicalize -triton-combine | FileCheck %s
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_combine_dot_add_pattern
|
||||||
|
func @test_combine_dot_add_pattern() -> (tensor<128x128xf32>, tensor<128x128xf32>) {
|
||||||
|
// CHECK: %[[d:.*]] = arith.constant dense<3.000000e+00> : tensor<128x128xf32>
|
||||||
|
// CHECK: %[[b:.*]] = arith.constant dense<2.000000e+00> : tensor<128x128xf32>
|
||||||
|
// CHECK: %[[a:.*]] = arith.constant dense<1.000000e+00> : tensor<128x128xf32>
|
||||||
|
%a = arith.constant dense<1.0> : tensor<128x128xf32>
|
||||||
|
%b = arith.constant dense<2.0> : tensor<128x128xf32>
|
||||||
|
%zero = arith.constant dense<0.0> : tensor<128x128xf32>
|
||||||
|
%d = arith.constant dense<3.0> : tensor<128x128xf32>
|
||||||
|
|
||||||
|
%dot_out = tt.dot %a, %b, %zero {allowTF32 = true} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
|
||||||
|
|
||||||
|
// CHECK-NEXT: %[[res0:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
|
||||||
|
%res0 = arith.addf %dot_out, %d : tensor<128x128xf32>
|
||||||
|
|
||||||
|
// CHECK-NEXT: %[[res1:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
|
||||||
|
%res1 = arith.addf %d, %dot_out : tensor<128x128xf32>
|
||||||
|
|
||||||
|
return %res0, %res1 : tensor<128x128xf32>, tensor<128x128xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_combine_gep_pattern
|
||||||
|
func @test_combine_gep_pattern(%base: !tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>> {
|
||||||
|
%off0 = arith.constant 10 : i32
|
||||||
|
%off1 = arith.constant 15 : i32
|
||||||
|
|
||||||
|
// 10 + 15 = 25
|
||||||
|
// CHECK-NEXT: %[[cst:.*]] = arith.constant dense<25> : tensor<8xi32>
|
||||||
|
|
||||||
|
%base_ = tt.broadcast %base : (!tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>>
|
||||||
|
|
||||||
|
// CHECK-NEXT: %[[tmp0:.*]] = tt.broadcast %{{.*}} : (!tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>>
|
||||||
|
|
||||||
|
%idx0 = tt.broadcast %off0 : (i32) -> tensor<8xi32>
|
||||||
|
%idx1 = tt.broadcast %off1 : (i32) -> tensor<8xi32>
|
||||||
|
|
||||||
|
// CHECK-NEXT: %1 = tt.getelementptr %[[tmp0]], %[[cst]] : tensor<8x!tt.ptr<f32>>
|
||||||
|
%ptr0 = tt.getelementptr %base_, %idx0 : tensor<8x!tt.ptr<f32>>
|
||||||
|
%ptr1 = tt.getelementptr %ptr0, %idx1 : tensor<8x!tt.ptr<f32>>
|
||||||
|
|
||||||
|
return %ptr1 : tensor<8x!tt.ptr<f32>>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_combine_select_masked_load_pattern
|
||||||
|
func @test_combine_select_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %cond: i1) -> tensor<8xf32> {
|
||||||
|
%mask = tt.broadcast %cond : (i1) -> tensor<8xi1>
|
||||||
|
%false_val = arith.constant dense<0.0> : tensor<8xf32>
|
||||||
|
|
||||||
|
// CHECK: %[[res:.*]] = tt.load %{{.*}}, %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
|
||||||
|
%x = tt.load %ptr, %mask, %false_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
|
||||||
|
%0 = select %cond, %x, %false_val : tensor<8xf32>
|
||||||
|
|
||||||
|
// CHECK: return %[[res]] : tensor<8xf32>
|
||||||
|
return %0 : tensor<8xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_combine_broadcast_constant_pattern
|
||||||
|
func @test_combine_broadcast_constant_pattern(%cst : f32) -> tensor<8x2xf32> {
|
||||||
|
// CHECK: %[[cst:.*]] = arith.constant dense<0.000000e+00> : tensor<8x2xf32>
|
||||||
|
%const = arith.constant dense<1.0> : tensor<8xf32>
|
||||||
|
%bst_out = tt.broadcast %const : (tensor<8xf32>) -> tensor<8x2xf32>
|
||||||
|
|
||||||
|
// CHECK-NEXT: return %[[cst]] : tensor<8x2xf32>
|
||||||
|
return %bst_out : tensor<8x2xf32>
|
||||||
|
}
|
Reference in New Issue
Block a user