51 lines
1.4 KiB
C++
51 lines
1.4 KiB
C++
#include "mlir/IR/BuiltinAttributes.h"
|
|
#include "mlir/IR/Matchers.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Support/LogicalResult.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
|
|
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
|
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
|
|
|
|
#include <memory>
|
|
|
|
using namespace mlir;
|
|
|
|
/// Returns true when `v` has a RankedTensorType whose encoding attribute is a
/// `triton::gpu::TritonGPUSharedEncodingAttr`.
///
/// Returns false for values that are not ranked tensors and for ranked
/// tensors that carry no encoding attribute at all.
///
/// NOTE(review): this helper is not called anywhere else in the visible file;
/// presumably it is referenced from the patterns pulled in via
/// "TritonGPUCombine.inc" -- confirm before renaming or removing it.
static bool isSharedLayout(Value v) {
  auto tensorType = v.getType().dyn_cast<RankedTensorType>();
  if (!tensorType)
    return false;

  // getEncoding() yields a null Attribute for tensors without an encoding, and
  // isa<> asserts when invoked on a null attribute, so check presence first.
  Attribute encoding = tensorType.getEncoding();
  return encoding &&
         encoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>();
}
|
|
|
|
namespace {
|
|
#include "TritonGPUCombine.inc"
|
|
}
|
|
|
|
#define GEN_PASS_CLASSES
|
|
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
|
|
|
|
class TritonGPUCombineOpsPass
|
|
: public TritonGPUCombineOpsBase<TritonGPUCombineOpsPass> {
|
|
public:
|
|
void runOnOperation() override {
|
|
MLIRContext *context = &getContext();
|
|
ModuleOp m = getOperation();
|
|
|
|
mlir::RewritePatternSet patterns(context);
|
|
|
|
patterns.add<ConvertLayoutOptPattern>(context);
|
|
patterns.add<CopyAsyncOptPattern>(context);
|
|
patterns.add<RedundantConvertLayoutOptPattern>(context);
|
|
|
|
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
|
|
signalPassFailure();
|
|
}
|
|
};
|
|
|
|
std::unique_ptr<Pass> triton::gpu::createCombineOpsPass() {
|
|
return std::make_unique<TritonGPUCombineOpsPass>();
|
|
}
|