TritonGPU combiner

This commit is contained in:
Yan Da
2022-05-16 19:17:15 +08:00
parent e3916c3a46
commit c3c4ac3733
3 changed files with 134 additions and 0 deletions

View File

@@ -0,0 +1,50 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include <memory>
using namespace mlir;
static bool isSharedLayout(Value v) {
if (auto tensorType = v.getType().dyn_cast<RankedTensorType>()) {
Attribute encoding = tensorType.getEncoding();
return encoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>();
}
return false;
}
namespace {
#include "TritonGPUCombine.inc"
}
#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
class TritonGPUCombineOpsPass
: public TritonGPUCombineOpsBase<TritonGPUCombineOpsPass> {
public:
void runOnOperation() override {
MLIRContext *context = &getContext();
ModuleOp m = getOperation();
mlir::RewritePatternSet patterns(context);
patterns.add<ConvertLayoutOptPattern>(context);
patterns.add<CopyAsyncOptPattern>(context);
patterns.add<RedundantConvertLayoutOptPattern>(context);
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
signalPassFailure();
}
};
std::unique_ptr<Pass> triton::gpu::createCombineOpsPass() {
return std::make_unique<TritonGPUCombineOpsPass>();
}