.
This commit is contained in:
@@ -41,13 +41,6 @@ public:
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
ModuleOp m = getOperation();
|
||||
// Move convert(load) immediately after dependent load
|
||||
m.walk([&](triton::gpu::ConvertLayoutOp op){
|
||||
auto load = dyn_cast<triton::LoadOp>(op.getOperand().getDefiningOp());
|
||||
if(!load)
|
||||
return;
|
||||
op->moveAfter(load);
|
||||
});
|
||||
// Sink conversions into loops when they will increase
|
||||
// register pressure
|
||||
DenseMap<Operation*, Operation *> opToMove;
|
||||
@@ -62,7 +55,18 @@ public:
|
||||
});
|
||||
for(auto &kv: opToMove)
|
||||
kv.first->moveBefore(kv.second);
|
||||
|
||||
// Move conversions to a shared layout immediately after the op defining their operand
|
||||
m.walk([&](triton::gpu::ConvertLayoutOp op){
|
||||
auto dstType = op.getResult().getType().cast<RankedTensorType>();
|
||||
auto dstEncoding = dstType.getEncoding();
|
||||
if(!dstEncoding.isa<triton::gpu::SharedEncodingAttr>())
|
||||
return;
|
||||
Operation* argOp = op.getOperand().getDefiningOp();
|
||||
if(!argOp)
|
||||
return;
|
||||
llvm::outs() << "moving " << *op << "\n";
|
||||
op->moveAfter(argOp);
|
||||
});
|
||||
// Move transpositions just after their definition
|
||||
opToMove.clear();
|
||||
m.walk([&](triton::TransOp op){
|
||||
|
Reference in New Issue
Block a user