TritonGPU combiner
This commit is contained in:
50
lib/Dialect/TritonGPU/Transforms/Combine.cpp
Normal file
50
lib/Dialect/TritonGPU/Transforms/Combine.cpp
Normal file
@@ -0,0 +1,50 @@
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
static bool isSharedLayout(Value v) {
|
||||
if (auto tensorType = v.getType().dyn_cast<RankedTensorType>()) {
|
||||
Attribute encoding = tensorType.getEncoding();
|
||||
return encoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>();
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
namespace {
|
||||
#include "TritonGPUCombine.inc"
|
||||
}
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
|
||||
|
||||
class TritonGPUCombineOpsPass
|
||||
: public TritonGPUCombineOpsBase<TritonGPUCombineOpsPass> {
|
||||
public:
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
ModuleOp m = getOperation();
|
||||
|
||||
mlir::RewritePatternSet patterns(context);
|
||||
|
||||
patterns.add<ConvertLayoutOptPattern>(context);
|
||||
patterns.add<CopyAsyncOptPattern>(context);
|
||||
patterns.add<RedundantConvertLayoutOptPattern>(context);
|
||||
|
||||
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
|
||||
signalPassFailure();
|
||||
}
|
||||
};
|
||||
|
||||
std::unique_ptr<Pass> triton::gpu::createCombineOpsPass() {
|
||||
return std::make_unique<TritonGPUCombineOpsPass>();
|
||||
}
|
Reference in New Issue
Block a user