Files
triton/lib/Dialect/TritonGPU/Transforms/SinkConversionsFromShared.cpp
2023-01-06 20:27:49 -08:00

75 lines
2.5 KiB
C++

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/RegionUtils.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
using namespace mlir;
/// Decides whether a layout conversion is a candidate for being sunk down to
/// its consumer.
///
/// Returns true when the conversion's source tensor carries a
/// `SharedEncodingAttr`, or (checked only otherwise) when its result tensor
/// carries a `DotOperandEncodingAttr`; false in every other case.
/// Both operand and result types are expected to be `RankedTensorType`
/// (the casts assert otherwise).
static inline bool willIncreaseRegisterPressure(triton::gpu::ConvertLayoutOp op) {
  RankedTensorType fromTy = op.getOperand().getType().cast<RankedTensorType>();
  RankedTensorType toTy = op.getResult().getType().cast<RankedTensorType>();
  Attribute fromLayout = fromTy.getEncoding();
  Attribute toLayout = toTy.getEncoding();
  // `||` short-circuits, so the destination encoding is only inspected when
  // the source is not shared-encoded.
  return fromLayout.isa<triton::gpu::SharedEncodingAttr>() ||
         toLayout.isa<triton::gpu::DotOperandEncodingAttr>();
}
/// Module pass that repositions `triton::gpu::ConvertLayoutOp` operations in
/// two phases:
///  1. Every conversion whose operand is produced by a `triton::LoadOp` is
///     moved to sit immediately after that load.
///  2. Every conversion for which `willIncreaseRegisterPressure` returns true
///     and whose result has exactly one use is moved to sit immediately
///     before that single user (which may place it inside a loop body).
class TritonGPUSinkConversionsFromSharedPass
    : public TritonGPUSinkConversionsFromSharedBase<
          TritonGPUSinkConversionsFromSharedPass> {
public:
  TritonGPUSinkConversionsFromSharedPass() = default;

  void runOnOperation() override {
    ModuleOp m = getOperation();

    // Phase 1: move convert(load) immediately after the load it depends on.
    // The operand may be a block argument (e.g. a loop iter_arg) with no
    // defining op; getDefiningOp<OpTy>() returns null in that case instead of
    // asserting the way dyn_cast on a null Operation* does.
    m.walk([&](triton::gpu::ConvertLayoutOp op) {
      if (auto load = op.getOperand().getDefiningOp<triton::LoadOp>())
        op->moveAfter(load);
    });

    // Phase 2: sink conversions that increase register pressure down to
    // their single user. Candidates are recorded in walk order in a vector
    // (not a pointer-keyed map) so the resulting IR order is deterministic
    // when several conversions are sunk before the same user.
    SmallVector<std::pair<triton::gpu::ConvertLayoutOp, Operation *>, 8> toSink;
    m.walk([&](triton::gpu::ConvertLayoutOp op) {
      if (!willIncreaseRegisterPressure(op))
        return;
      // Exactly one use of the result (the original counted
      // user_begin()..user_end(), which also iterates once per use).
      if (!op->hasOneUse())
        return;
      toSink.emplace_back(op, *op->user_begin());
    });

    // Apply moves in reverse walk order: a producer is visited before its
    // consumer, so reversing sinks the consumer first. In a chain
    // convert A -> convert B -> user, B moves next to its user and A then
    // follows B all the way down.
    for (auto it = toSink.rbegin(), e = toSink.rend(); it != e; ++it)
      it->first->moveBefore(it->second);
  }
};
/// Factory used by pass registration: builds a fresh, default-constructed
/// TritonGPUSinkConversionsFromSharedPass and hands ownership to the caller.
std::unique_ptr<Pass>
mlir::createTritonGPUSinkConversionsFromSharedPass() {
  std::unique_ptr<Pass> pass =
      std::make_unique<TritonGPUSinkConversionsFromSharedPass>();
  return pass;
}