[OPTIMIZER] Minor bugfixes that affected matmul codegen performance (#834)

This commit is contained in:
Philippe Tillet
2022-11-02 22:58:09 -07:00
committed by GitHub
parent 847a318a03
commit 91a9773b38
4 changed files with 59 additions and 29 deletions

View File

@@ -159,6 +159,16 @@ ChangeResult AxisInfoAnalysis::visitOperation(
curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(),
newContiguity, newDivisibility, newConstancy);
}
// TODO: All other binary ops
if (llvm::isa<arith::AndIOp, arith::OrIOp>(op)) {
auto newContiguity = [](AxisInfo lhs, AxisInfo rhs, int d) { return 1; };
auto newDivisibility = [](AxisInfo lhs, AxisInfo rhs, int d) { return 1; };
auto newConstancy = [](AxisInfo lhs, AxisInfo rhs, int d) {
return gcd(lhs.getConstancy(d), rhs.getConstancy(d));
};
curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(),
newContiguity, newDivisibility, newConstancy);
}
// Splat
if (llvm::isa<triton::SplatOp>(op)) {
Type _retTy = *op->result_type_begin();
@@ -200,7 +210,8 @@ ChangeResult AxisInfoAnalysis::visitOperation(
for (int d = 0; d < retTy.getRank(); ++d) {
contiguity.push_back(opShape[d] == 1 ? 1 : opInfo.getContiguity(d));
divisibility.push_back(opInfo.getDivisibility(d));
constancy.push_back(opShape[d] == 1 ? retShape[d] : 1);
constancy.push_back(opShape[d] == 1 ? retShape[d]
: opInfo.getConstancy(d));
}
curr = AxisInfo(contiguity, divisibility, constancy);
}