Add a TritonGPU combiner pass (layout-conversion folding and copy-async rewrite patterns)

This commit is contained in:
Yan Da
2022-05-16 19:17:15 +08:00
parent e3916c3a46
commit c3c4ac3733
3 changed files with 134 additions and 0 deletions

View File

@@ -0,0 +1,59 @@
#ifndef TRITON_ATTR_DEFS
#define TRITON_ATTR_DEFS
include "mlir/IR/EnumAttr.td"
// Attrs for LoadOp
// Cache-operation modifier carried on LoadOp (none / ca / cg).
// The third I32EnumAttrCase argument is the spelling used in the textual IR.
// NOTE(review): removed the trailing comma after the last case for
// consistency with the sibling enum defs below (older TableGen versions
// reject trailing commas inside list literals).
def TT_CacheModifierAttr : I32EnumAttr<
"CacheModifier", "",
[
I32EnumAttrCase<"NONE", 1, "none">,
I32EnumAttrCase<"CA", 2, "ca">,
I32EnumAttrCase<"CG", 3, "cg">
]> {
// Generated C++ enum lives in the triton namespace.
let cppNamespace = "::mlir::triton";
}
// Cache-eviction policy carried on LoadOp; the third I32EnumAttrCase
// argument is the spelling used in the textual IR.
// NOTE(review): names suggest these map to hardware eviction hints
// (e.g. evict_first / evict_last) — confirm against the lowering.
def TT_EvictionPolicyAttr : I32EnumAttr<
"EvictionPolicy", "",
[
I32EnumAttrCase<"NORMAL", 1, "normal">,
I32EnumAttrCase<"EVICT_FIRST", 2, "evict_first">,
I32EnumAttrCase<"EVICT_LAST", 3, "evict_last">
]> {
// Generated C++ enum lives in the triton namespace.
let cppNamespace = "::mlir::triton";
}
// reduction
// Reduction operator kinds: ADD/MAX/MIN/XOR are integer ops,
// the F-prefixed cases are the floating-point counterparts.
def TT_RedOpAttr : I32EnumAttr<
/*name*/"RedOp", /*summary*/"",
/*case*/
[
I32EnumAttrCase</*sym*/"ADD", 1, /*str*/"add">,
I32EnumAttrCase<"MAX", 2, "max">,
I32EnumAttrCase<"MIN", 3, "min">,
I32EnumAttrCase<"FADD", 4, "fadd">,
I32EnumAttrCase<"FMAX", 5, "fmax">,
I32EnumAttrCase<"FMIN", 6, "fmin">,
I32EnumAttrCase<"XOR", 7, "xor">
]> {
// Generated C++ enum lives in the triton namespace.
let cppNamespace = "::mlir::triton";
}
// atomic
// Atomic read-modify-write operator kinds. MAX/MIN are signed;
// UMAX/UMIN are the unsigned variants; FADD is floating-point add.
def TT_AtomicRMWAttr : I32EnumAttr<
"RMWOp", "",
[
I32EnumAttrCase<"AND", 1, "and">,
I32EnumAttrCase<"OR", 2, "or">,
I32EnumAttrCase<"XOR", 3, "xor">,
I32EnumAttrCase<"ADD", 4, "add">,
I32EnumAttrCase<"FADD", 5, "fadd">,
I32EnumAttrCase<"MAX", 6, "max">,
I32EnumAttrCase<"MIN", 7, "min">,
I32EnumAttrCase<"UMAX", 8, "umax">,
I32EnumAttrCase<"UMIN", 9, "umin">
]> {
// Generated C++ enum lives in the triton namespace.
let cppNamespace = "::mlir::triton";
}
#endif

View File

@@ -0,0 +1,50 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include <memory>
using namespace mlir;
/// Returns true iff `v` is a ranked tensor whose layout encoding is the
/// TritonGPU shared-memory encoding; false for any other value.
static bool isSharedLayout(Value v) {
  if (auto tensorType = v.getType().dyn_cast<RankedTensorType>()) {
    Attribute encoding = tensorType.getEncoding();
    // getEncoding() may return a null Attribute when the tensor has no
    // layout assigned; isa<> on a null Attribute asserts, so guard first.
    return encoding &&
           encoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>();
  }
  return false;
}
namespace {
#include "TritonGPUCombine.inc"
}
#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
/// Pass that greedily applies the TritonGPU combine patterns (generated
/// into TritonGPUCombine.inc from the .td pattern file) over a module.
class TritonGPUCombineOpsPass
    : public TritonGPUCombineOpsBase<TritonGPUCombineOpsPass> {
public:
  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    ModuleOp module = getOperation();
    // Register all three DRR-generated patterns in a single variadic add.
    mlir::RewritePatternSet patterns(ctx);
    patterns.add<ConvertLayoutOptPattern, CopyAsyncOptPattern,
                 RedundantConvertLayoutOptPattern>(ctx);
    // Fixed-point greedy rewrite; a failure to converge fails the pass.
    if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns))))
      signalPassFailure();
  }
};
/// Factory for the TritonGPU combine-ops pass; ownership passes to the caller.
std::unique_ptr<Pass> triton::gpu::createCombineOpsPass() {
  auto pass = std::make_unique<TritonGPUCombineOpsPass>();
  return pass; // implicit move-conversion to unique_ptr<Pass>
}

View File

@@ -0,0 +1,25 @@
#ifndef TRITONGPU_PATTERNS
#define TRITONGPU_PATTERNS
include "triton/Dialect/TritonGPU/IR/TritonGPUOps.td"
include "triton/Dialect/Triton/IR/TritonOps.td"
// convert_layout(load(...), #L) => copy_async(...); barrier
// if #L is smem_layout
// convert_layout(load(...), #L) => copy_async(...); barrier
// if #L is smem_layout
// Constraint calls the C++ helper isSharedLayout on the convert's result
// ($res), so the rewrite only fires when the destination layout is the
// shared-memory encoding.
// NOTE(review): the comment mentions a barrier but the replacement emits
// only copy_async — confirm the barrier is inserted elsewhere.
def CopyAsyncOptPattern : Pat<
(TTG_ConvertLayoutOp:$res (TT_LoadOp $ptr, $mask, $other, $cache, $evict, $isVolatile)),
(TTG_CopyAsyncOp $ptr, $mask, $other, $cache, $evict, $isVolatile),
[(Constraint<CPred<"isSharedLayout($0)">> $res)]>;
// ConvertLayout(ConvertLayout(x, #L0), #L1) => ConvertLayout(x, #L1)
// ConvertLayout(ConvertLayout(x, #L0), #L1) => ConvertLayout(x, #L1)
// Collapses back-to-back layout conversions: the intermediate layout #L0
// is dead, so convert directly from x's layout to the outer target #L1.
def ConvertLayoutOptPattern : Pat<
(TTG_ConvertLayoutOp (TTG_ConvertLayoutOp $x)),
(TTG_ConvertLayoutOp $x)>;
// TODO: can we replace this with ConvertLayoutOp's folder?
// ConvertLayout(x, #L) => x if x.layout() == #L
// TODO: can we replace this with ConvertLayoutOp's folder?
// ConvertLayout(x, #L) => x if x.layout() == #L
// Identity-conversion elimination: when source and result types (hence
// layouts) already match, replace the op with its operand.
def RedundantConvertLayoutOptPattern : Pat<
(TTG_ConvertLayoutOp:$res $x), (replaceWithValue $x),
[(Constraint<CPred<"$0.getType() == $1.getType()">> $res, $x)]>;
#endif