TritonGPU combiner

This commit is contained in:
Yan Da
2022-05-16 19:17:15 +08:00
parent e3916c3a46
commit c3c4ac3733
3 changed files with 134 additions and 0 deletions

View File

@@ -0,0 +1,25 @@
#ifndef TRITONGPU_PATTERNS
#define TRITONGPU_PATTERNS
include "triton/Dialect/TritonGPU/IR/TritonGPUOps.td"
include "triton/Dialect/Triton/IR/TritonOps.td"
// Fuse a load that feeds a layout conversion into a single async copy:
//   convert_layout(load(ptr, mask, other, cache, evict, isVolatile), #L)
//     => copy_async(ptr, mask, other, cache, evict, isVolatile)
// Only applies when isSharedLayout() accepts the convert_layout result,
// i.e. when #L is a shared (smem) layout.
// NOTE(review): an earlier version of this comment mentioned a trailing
// barrier; this pattern itself only creates the copy_async op -- confirm
// the barrier is inserted elsewhere.
def CopyAsyncOptPattern : Pat<
  // Source: convert_layout whose operand is produced by a load.
  (TTG_ConvertLayoutOp:$res
    (TT_LoadOp $ptr, $mask, $other, $cache, $evict, $isVolatile)),
  // Result: copy_async taking the load's operands/attributes unchanged.
  (TTG_CopyAsyncOp $ptr, $mask, $other, $cache, $evict, $isVolatile),
  // Guard: checked against the matched convert_layout ($res).
  [(Constraint<CPred<"isSharedLayout($0)">> $res)]
>;
// Collapse two back-to-back layout conversions into one:
//   convert_layout(convert_layout(x, #L0), #L1) => convert_layout(x, #L1)
def ConvertLayoutOptPattern : Pat<
  // Source: a convert_layout applied to the result of another one.
  (TTG_ConvertLayoutOp
    (TTG_ConvertLayoutOp $x)),
  // Result: a single convert_layout applied directly to the inner operand.
  (TTG_ConvertLayoutOp $x)
>;
// Remove a conversion that leaves the type unchanged:
//   convert_layout(x) => x   when the result type equals the type of x
// TODO: could this be implemented as ConvertLayoutOp's folder instead?
def RedundantConvertLayoutOptPattern : Pat<
  // Source: any convert_layout; $res names the op, $x its operand.
  (TTG_ConvertLayoutOp:$res $x),
  // Result: forward the operand in place of the op's result.
  (replaceWithValue $x),
  // Guard: $0 is $res and $1 is $x; their types must be identical.
  [(Constraint<CPred<"$0.getType() == $1.getType()">> $res, $x)]
>;
#endif