[TritonIR] Convert Triton dialect's Combine pass to MLIR DRR based (#16)

This commit is contained in:
Yan Chunwei
2022-07-28 03:50:08 +08:00
committed by GitHub
parent 432c3df265
commit e02c82c765
4 changed files with 162 additions and 132 deletions

View File

@@ -0,0 +1,53 @@
#ifndef TRITON_PATTERNS
#define TRITON_PATTERNS
include "mlir/Dialect/StandardOps/IR/Ops.td"
include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.td"
include "triton/Dialect/Triton/IR/TritonOps.td"
// AddIOp(DotOp(a, b, c), d) and c==0 => DotOp(a, b, d)
// AddFOp(DotOp(a, b, c), d) and c==0 => DotOp(a, b, d)
// AddIOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d)
// AddFOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d)
def CombineDotAddIPattern : Pat<
(Arith_AddIOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32)),
(TT_DotOp $a, $b, $d, $allowTF32),
[(Constraint<CPred<"isZero($0)">> $c)]>;
def CombineDotAddFPattern : Pat<
(Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32)),
(TT_DotOp $a, $b, $d, $allowTF32),
[(Constraint<CPred<"isZero($0)">> $c)]>;
def CombineDotAddIRevPattern : Pat<
(Arith_AddIOp (TT_DotOp:$res $a, $b, $c, $allowTF32), $d),
(TT_DotOp $a, $b, $d, $allowTF32),
[(Constraint<CPred<"isZero($0)">> $c)]>;
def CombineDotAddFRevPattern : Pat<
(Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $allowTF32), $d),
(TT_DotOp $a, $b, $d, $allowTF32),
[(Constraint<CPred<"isZero($0)">> $c)]>;
// gep(gep(%ptr, %idx0), %idx1) => gep(%ptr, AddI(%idx0, %idx1))
// Note: leave (sub %c0, %c0) canceling to ArithmeticDialect
// (ref: ArithmeticCanonicalization.td)
def CombineGEPPattern : Pat<
(TT_GEPOp (TT_GEPOp $ptr, $idx0), $idx1),
(TT_GEPOp $ptr, (Arith_AddIOp $idx0, $idx1))>;
// select(cond, load(ptrs, broadcast(cond), ???), other)
// => load(ptrs, broadcast(cond), other)
def CombineSelectMaskedLoadPattern : Pat<
(SelectOp $cond, (TT_LoadOp $ptrs, (TT_BroadcastOp:$bcast_res $cond), $other, $cache, $evict, $isVolatile), $falseValue),
(TT_LoadOp $ptrs, $bcast_res, $falseValue, $cache, $evict, $isVolatile)>;
// broadcast(cst) => cst
def getConstantValue : NativeCodeCall<"getConstantValue($_builder, $0, $1)">;
def CombineBroadcastConstantPattern : Pat<
(TT_BroadcastOp:$bcast_res (Arith_ConstantOp $value)),
(Arith_ConstantOp (getConstantValue $value, $bcast_res)),
[(Constraint<CPred<"isBroadcastConstantCombinable($0)">> $value)]>;
#endif