2022-07-28 03:50:08 +08:00
|
|
|
#ifndef TRITON_PATTERNS
|
|
|
|
#define TRITON_PATTERNS
|
|
|
|
|
|
|
|
include "mlir/Dialect/StandardOps/IR/Ops.td"
|
|
|
|
include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.td"
|
|
|
|
include "triton/Dialect/Triton/IR/TritonOps.td"
|
|
|
|
|
|
|
|
|
|
|
|
// AddIOp(DotOp(a, b, c), d) and c==0 => DotOp(a, b, d)
|
|
|
|
// AddFOp(DotOp(a, b, c), d) and c==0 => DotOp(a, b, d)
|
|
|
|
|
|
|
|
// AddIOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d)
|
|
|
|
// AddFOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d)
|
|
|
|
// AddIOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d)
//
// Folds an integer add into the third operand of the dot when isZero (a C++
// helper defined outside this file) accepts the dot's original third operand.
// The fold is limited to dots whose result has exactly one use: if the result
// had other users the original dot would have to stay alive, and the rewrite
// would create a second dot instead of removing the add.
def CombineDotAddIPattern : Pat<
        (Arith_AddIOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32)),
        (TT_DotOp $a, $b, $d, $allowTF32),
        [(Constraint<CPred<"isZero($0)">> $c),
         (Constraint<CPred<"$0.hasOneUse()">,
                     "dot result has a single use"> $res)]>;
|
|
|
|
// AddFOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d)
//
// Floating-point variant: folds the add into the third operand of the dot
// when isZero (a C++ helper defined outside this file) accepts the dot's
// original third operand. The fold is limited to dots whose result has
// exactly one use: if the result had other users the original dot would have
// to stay alive, and the rewrite would create a second dot instead of
// removing the add.
def CombineDotAddFPattern : Pat<
        (Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32)),
        (TT_DotOp $a, $b, $d, $allowTF32),
        [(Constraint<CPred<"isZero($0)">> $c),
         (Constraint<CPred<"$0.hasOneUse()">,
                     "dot result has a single use"> $res)]>;
|
|
|
|
|
|
|
|
// AddIOp(DotOp(a, b, c), d) and c==0 => DotOp(a, b, d)
//
// Same fold as the integer add/dot combine, with the dot appearing as the
// left operand of the add. isZero (a C++ helper defined outside this file)
// must accept the dot's original third operand. The fold is limited to dots
// whose result has exactly one use: if the result had other users the
// original dot would have to stay alive, and the rewrite would create a
// second dot instead of removing the add.
def CombineDotAddIRevPattern : Pat<
        (Arith_AddIOp (TT_DotOp:$res $a, $b, $c, $allowTF32), $d),
        (TT_DotOp $a, $b, $d, $allowTF32),
        [(Constraint<CPred<"isZero($0)">> $c),
         (Constraint<CPred<"$0.hasOneUse()">,
                     "dot result has a single use"> $res)]>;
|
|
|
|
// AddFOp(DotOp(a, b, c), d) and c==0 => DotOp(a, b, d)
//
// Same fold as the floating-point add/dot combine, with the dot appearing as
// the left operand of the add. isZero (a C++ helper defined outside this
// file) must accept the dot's original third operand. The fold is limited to
// dots whose result has exactly one use: if the result had other users the
// original dot would have to stay alive, and the rewrite would create a
// second dot instead of removing the add.
def CombineDotAddFRevPattern : Pat<
        (Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $allowTF32), $d),
        (TT_DotOp $a, $b, $d, $allowTF32),
        [(Constraint<CPred<"isZero($0)">> $c),
         (Constraint<CPred<"$0.hasOneUse()">,
                     "dot result has a single use"> $res)]>;
|
|
|
|
|
2022-12-06 23:29:50 -08:00
|
|
|
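// TODO: this fails for addptr(addptr(ptr, i32), i64): the two offsets have
// different integer widths, while arith.addi requires both operands to have
// the same type, so the AddI built by the rewrite would be ill-typed.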
|
|
|
|
// Commented out until fixed
|
2022-09-15 16:12:52 -07:00
|
|
|
// addptr(addptr(%ptr, %idx0), %idx1) => addptr(%ptr, AddI(%idx0, %idx1))
|
2022-07-28 03:50:08 +08:00
|
|
|
// Note: leave (sub %c0, %c0) canceling to ArithmeticDialect
|
|
|
|
// (ref: ArithmeticCanonicalization.td)
|
2022-12-06 23:29:50 -08:00
|
|
|
// def CombineAddPtrPattern : Pat<
|
|
|
|
// (TT_AddPtrOp (TT_AddPtrOp $ptr, $idx0), $idx1),
|
|
|
|
// (TT_AddPtrOp $ptr, (Arith_AddIOp $idx0, $idx1))>;
|
2022-07-28 03:50:08 +08:00
|
|
|
|
|
|
|
// broadcast(cst) => cst
|
|
|
|
// Native-code hook for result patterns: expands to a call to the C++ function
// getConstantValue, passing the pattern's builder followed by the two
// arguments supplied at the use site. The C++ function itself is defined
// outside this file.
def getConstantValue : NativeCodeCall<"getConstantValue($_builder, $0, $1)">;
|
|
|
|
// Guard for the broadcast-of-constant fold: holds when the C++ helper
// isBroadcastConstantCombinable (defined outside this file) returns true for
// the constant's value attribute.
def IsBroadcastConstantCombinable
    : Constraint<CPred<"isBroadcastConstantCombinable($0)">>;

// Replaces a broadcast whose operand is an arith constant with a single
// arith constant. The new constant's value is computed by getConstantValue
// from the original value attribute and the broadcast result.
def CombineBroadcastConstantPattern : Pat<
    (TT_BroadcastOp:$bcast_res (Arith_ConstantOp $value)),
    (Arith_ConstantOp (getConstantValue $value, $bcast_res)),
    [(IsBroadcastConstantCombinable $value)]>;
|
|
|
|
|
|
|
|
#endif
|