more optimizations
This commit is contained in:
@@ -27,7 +27,32 @@ class TritonGPUDecomposeConversionsToDotOperandPass
|
||||
public:
|
||||
TritonGPUDecomposeConversionsToDotOperandPass() = default;
|
||||
|
||||
void runOnOperation() override { return; }
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
ModuleOp mod = getOperation();
|
||||
mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
|
||||
OpBuilder builder(cvtOp);
|
||||
auto srcType = cvtOp.getOperand().getType().cast<RankedTensorType>();
|
||||
auto dstType = cvtOp.getType().cast<RankedTensorType>();
|
||||
auto srcBlocked =
|
||||
srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
|
||||
auto dstDotOp =
|
||||
dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
|
||||
if (srcBlocked && dstDotOp) {
|
||||
auto tmpType = RankedTensorType::get(
|
||||
dstType.getShape(), dstType.getElementType(),
|
||||
triton::gpu::SharedEncodingAttr::get(
|
||||
mod.getContext(), dstDotOp, srcType.getShape(),
|
||||
getOrder(srcBlocked), srcType.getElementType()));
|
||||
auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
|
||||
cvtOp.getLoc(), tmpType, cvtOp.getOperand());
|
||||
auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
|
||||
cvtOp.getLoc(), dstType, tmp);
|
||||
cvtOp.replaceAllUsesWith(newConvert.getResult());
|
||||
cvtOp.erase();
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
std::unique_ptr<Pass>
|
||||
|
@@ -21,6 +21,17 @@
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
static inline bool willIncreaseRegisterPressure(triton::gpu::ConvertLayoutOp op) {
|
||||
auto srcType = op.getOperand().getType().cast<RankedTensorType>();
|
||||
auto dstType = op.getResult().getType().cast<RankedTensorType>();
|
||||
auto srcEncoding = srcType.getEncoding();
|
||||
auto dstEncoding = dstType.getEncoding();
|
||||
if(srcEncoding.isa<triton::gpu::SharedEncodingAttr>())
|
||||
return true;
|
||||
if(dstEncoding.isa<triton::gpu::DotOperandEncodingAttr>())
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
class TritonGPUSinkConversionsFromSharedPass
|
||||
: public TritonGPUSinkConversionsFromSharedBase<TritonGPUSinkConversionsFromSharedPass> {
|
||||
@@ -28,6 +39,31 @@ public:
|
||||
TritonGPUSinkConversionsFromSharedPass() = default;
|
||||
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
ModuleOp m = getOperation();
|
||||
// Move convert(load) immediately after dependent load
|
||||
m.walk([&](triton::gpu::ConvertLayoutOp op){
|
||||
auto load = dyn_cast<triton::LoadOp>(op.getOperand().getDefiningOp());
|
||||
if(!load)
|
||||
return;
|
||||
op->moveAfter(load);
|
||||
});
|
||||
// Sink conversions into loops when they will increase
|
||||
// register pressure
|
||||
DenseMap<triton::gpu::ConvertLayoutOp, Operation *> opToMove;
|
||||
m.walk([&](triton::gpu::ConvertLayoutOp op){
|
||||
if(!willIncreaseRegisterPressure(op))
|
||||
return;
|
||||
auto user_begin = op->user_begin();
|
||||
auto user_end = op->user_end();
|
||||
if(std::distance(user_begin, user_end) != 1)
|
||||
return;
|
||||
opToMove.insert({op, *user_begin});
|
||||
});
|
||||
for(auto &kv: opToMove)
|
||||
kv.first->moveBefore(kv.second);
|
||||
|
||||
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
Reference in New Issue
Block a user