[FRONTEND][BACKEND] Small fixes to multiple_of, num_programs, axisinfo; enable block-sparse tests (#927)
This commit is contained in:
@@ -132,6 +132,7 @@ ChangeResult AxisInfoAnalysis::visitOperation(
|
||||
AxisInfo::DimVectorT(ty.getShape().begin(), ty.getShape().end()));
|
||||
}
|
||||
}
|
||||
// TODO: refactor & complete binary ops
|
||||
// Addition
|
||||
if (llvm::isa<arith::AddIOp, triton::AddPtrOp>(op)) {
|
||||
auto newContiguity = [&](AxisInfo lhs, AxisInfo rhs, int d) {
|
||||
@@ -159,6 +160,20 @@ ChangeResult AxisInfoAnalysis::visitOperation(
|
||||
curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(),
|
||||
newContiguity, newDivisibility, newConstancy);
|
||||
}
|
||||
// Remainder
|
||||
if (llvm::isa<arith::RemSIOp, arith::RemUIOp>(op)) {
|
||||
auto newContiguity = [](AxisInfo lhs, AxisInfo rhs, int d) {
|
||||
return gcd(lhs.getContiguity(d), rhs.getDivisibility(d));
|
||||
};
|
||||
auto newDivisibility = [](AxisInfo lhs, AxisInfo rhs, int d) {
|
||||
return gcd(lhs.getDivisibility(d), rhs.getDivisibility(d));
|
||||
};
|
||||
auto newConstancy = [](AxisInfo lhs, AxisInfo rhs, int d) {
|
||||
return gcd(lhs.getConstancy(d), rhs.getConstancy(d));
|
||||
};
|
||||
curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(),
|
||||
newContiguity, newDivisibility, newConstancy);
|
||||
}
|
||||
// TODO: All other binary ops
|
||||
if (llvm::isa<arith::AndIOp, arith::OrIOp>(op)) {
|
||||
auto newContiguity = [](AxisInfo lhs, AxisInfo rhs, int d) { return 1; };
|
||||
|
Reference in New Issue
Block a user