[OPTIMIZER] Minor bugfixes that affected matmul codegen performance (#834)
This commit is contained in:
@@ -159,6 +159,16 @@ ChangeResult AxisInfoAnalysis::visitOperation(
|
||||
curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(),
|
||||
newContiguity, newDivisibility, newConstancy);
|
||||
}
|
||||
// TODO: All other binary ops
|
||||
if (llvm::isa<arith::AndIOp, arith::OrIOp>(op)) {
|
||||
auto newContiguity = [](AxisInfo lhs, AxisInfo rhs, int d) { return 1; };
|
||||
auto newDivisibility = [](AxisInfo lhs, AxisInfo rhs, int d) { return 1; };
|
||||
auto newConstancy = [](AxisInfo lhs, AxisInfo rhs, int d) {
|
||||
return gcd(lhs.getConstancy(d), rhs.getConstancy(d));
|
||||
};
|
||||
curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(),
|
||||
newContiguity, newDivisibility, newConstancy);
|
||||
}
|
||||
// Splat
|
||||
if (llvm::isa<triton::SplatOp>(op)) {
|
||||
Type _retTy = *op->result_type_begin();
|
||||
@@ -200,7 +210,8 @@ ChangeResult AxisInfoAnalysis::visitOperation(
|
||||
for (int d = 0; d < retTy.getRank(); ++d) {
|
||||
contiguity.push_back(opShape[d] == 1 ? 1 : opInfo.getContiguity(d));
|
||||
divisibility.push_back(opInfo.getDivisibility(d));
|
||||
constancy.push_back(opShape[d] == 1 ? retShape[d] : 1);
|
||||
constancy.push_back(opShape[d] == 1 ? retShape[d]
|
||||
: opInfo.getConstancy(d));
|
||||
}
|
||||
curr = AxisInfo(contiguity, divisibility, constancy);
|
||||
}
|
||||
|
Reference in New Issue
Block a user