more optimizations

This commit is contained in:
Phil Tillet
2023-01-06 20:27:49 -08:00
parent 18c7a72973
commit 600bcefb12
4 changed files with 262 additions and 28 deletions

View File

@@ -27,7 +27,32 @@ class TritonGPUDecomposeConversionsToDotOperandPass
public:
// Defaulted constructor: the pass sets up no state of its own here.
TritonGPUDecomposeConversionsToDotOperandPass() = default;
// No-op body: returns immediately without touching the module.
// NOTE(review): a second, full definition of runOnOperation follows directly
// after this one; this line looks like the pre-change side of the diff —
// confirm that only one definition is kept.
void runOnOperation() override { return; }
// Rewrites every ConvertLayoutOp whose operand carries a BlockedEncodingAttr
// and whose result carries a DotOperandEncodingAttr into a chain of two
// conversions that goes through an intermediate tensor type encoded with a
// SharedEncodingAttr. Conversions with any other source/result encoding pair
// are left untouched.
void runOnOperation() override {
  ModuleOp mod = getOperation();
  mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
    auto srcType = cvtOp.getOperand().getType().cast<RankedTensorType>();
    auto dstType = cvtOp.getType().cast<RankedTensorType>();
    auto srcBlocked =
        srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
    auto dstDotOp =
        dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
    // Only blocked -> dot-operand conversions are decomposed.
    if (!srcBlocked || !dstDotOp)
      return;

    // The builder is anchored on the conversion that is being replaced.
    OpBuilder builder(cvtOp);

    // Intermediate type: same shape and element type as the destination, with
    // a shared encoding derived from the dot-operand encoding, the source
    // shape, the order of the blocked source layout and the source element
    // type.
    auto sharedEncoding = triton::gpu::SharedEncodingAttr::get(
        mod.getContext(), dstDotOp, srcType.getShape(), getOrder(srcBlocked),
        srcType.getElementType());
    auto tmpType = RankedTensorType::get(
        dstType.getShape(), dstType.getElementType(), sharedEncoding);

    // First hop: original operand -> intermediate (shared-encoded) type.
    auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
        cvtOp.getLoc(), tmpType, cvtOp.getOperand());
    // Second hop: intermediate type -> original destination type.
    auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
        cvtOp.getLoc(), dstType, tmp);

    // Redirect all users to the new chain, then drop the original conversion.
    cvtOp.replaceAllUsesWith(newConvert.getResult());
    cvtOp.erase();
  });
}
};
std::unique_ptr<Pass>

View File

@@ -21,6 +21,17 @@
using namespace mlir;
// Predicate over a layout conversion: true when the operand's encoding is a
// SharedEncodingAttr or the result's encoding is a DotOperandEncodingAttr,
// false otherwise.
static inline bool willIncreaseRegisterPressure(triton::gpu::ConvertLayoutOp op) {
  auto inputEncoding =
      op.getOperand().getType().cast<RankedTensorType>().getEncoding();
  auto outputEncoding =
      op.getResult().getType().cast<RankedTensorType>().getEncoding();
  // Short-circuit keeps the source check ahead of the destination check.
  return inputEncoding.isa<triton::gpu::SharedEncodingAttr>() ||
         outputEncoding.isa<triton::gpu::DotOperandEncodingAttr>();
}
class TritonGPUSinkConversionsFromSharedPass
: public TritonGPUSinkConversionsFromSharedBase<TritonGPUSinkConversionsFromSharedPass> {
@@ -28,6 +39,31 @@ public:
// Defaulted constructor: the pass sets up no state of its own here.
TritonGPUSinkConversionsFromSharedPass() = default;
void runOnOperation() override {
MLIRContext *context = &getContext();
ModuleOp m = getOperation();
// Move convert(load) immediately after dependent load
m.walk([&](triton::gpu::ConvertLayoutOp op){
auto load = dyn_cast<triton::LoadOp>(op.getOperand().getDefiningOp());
if(!load)
return;
op->moveAfter(load);
});
// Sink conversions into loops when they will increase
// register pressure
DenseMap<triton::gpu::ConvertLayoutOp, Operation *> opToMove;
m.walk([&](triton::gpu::ConvertLayoutOp op){
if(!willIncreaseRegisterPressure(op))
return;
auto user_begin = op->user_begin();
auto user_end = op->user_end();
if(std::distance(user_begin, user_end) != 1)
return;
opToMove.insert({op, *user_begin});
});
for(auto &kv: opToMove)
kv.first->moveBefore(kv.second);
return;
}
};