reorder conversions to dot operand
This commit is contained in:
@@ -64,7 +64,6 @@ public:
|
||||
Operation* argOp = op.getOperand().getDefiningOp();
|
||||
if(!argOp)
|
||||
return;
|
||||
llvm::outs() << "moving " << *op << "\n";
|
||||
op->moveAfter(argOp);
|
||||
});
|
||||
// Move transpositions just after their definition
|
||||
@@ -75,7 +74,20 @@ public:
|
||||
return;
|
||||
op->moveAfter(argOp);
|
||||
});
|
||||
|
||||
// Move `dot` operand so that conversions to opIdx=0 happens before conversions to opIdx=1
|
||||
m.walk([&](triton::gpu::ConvertLayoutOp op){
|
||||
auto dstType = op.getResult().getType().cast<RankedTensorType>();
|
||||
auto dstEncoding = dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
|
||||
if(!dstEncoding)
|
||||
return;
|
||||
int opIdx = dstEncoding.getOpIdx();
|
||||
if(opIdx != 1)
|
||||
return;
|
||||
if(op->getUsers().empty())
|
||||
return;
|
||||
auto user_begin = op->user_begin();
|
||||
op->moveBefore(*user_begin);
|
||||
});
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
Reference in New Issue
Block a user