[CI] run clang-format (#24)
This commit is contained in:
@@ -8,24 +8,23 @@
|
||||
|
||||
namespace mlir {
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AxisInfo
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Function for extended Euclidean Algorithm
|
||||
static int gcd_impl(int a, int b, int *x, int *y){
|
||||
// Function for extended Euclidean Algorithm
|
||||
static int gcd_impl(int a, int b, int *x, int *y) {
|
||||
// Base Case
|
||||
if (a == 0) {
|
||||
*x = 0;
|
||||
*y = 1;
|
||||
return b;
|
||||
*x = 0;
|
||||
*y = 1;
|
||||
return b;
|
||||
}
|
||||
int x1, y1; // To store results of recursive call
|
||||
int gcd = gcd_impl(b%a, a, &x1, &y1);
|
||||
int gcd = gcd_impl(b % a, a, &x1, &y1);
|
||||
// Update x and y using results of
|
||||
// recursive call
|
||||
*x = y1 - (b/a) * x1;
|
||||
*x = y1 - (b / a) * x1;
|
||||
*y = x1;
|
||||
return gcd;
|
||||
}
|
||||
@@ -35,17 +34,17 @@ static int gcd(int a, int b) {
|
||||
return gcd_impl(a, b, &x, &y);
|
||||
}
|
||||
|
||||
|
||||
AxisInfo AxisInfo::getPessimisticValueState(Value value) {
|
||||
size_t rank = 1;
|
||||
if(TensorType ty = value.getType().dyn_cast<TensorType>())
|
||||
if (TensorType ty = value.getType().dyn_cast<TensorType>())
|
||||
rank = ty.getRank();
|
||||
int divHint = 1;
|
||||
if(BlockArgument blockArg = value.dyn_cast<BlockArgument>()){
|
||||
Operation* op = blockArg.getOwner()->getParentOp();
|
||||
if(FuncOp fun = dyn_cast<FuncOp>(op)){
|
||||
Attribute attr = fun.getArgAttr(blockArg.getArgNumber(), "tt.divisibility");
|
||||
if(attr)
|
||||
if (BlockArgument blockArg = value.dyn_cast<BlockArgument>()) {
|
||||
Operation *op = blockArg.getOwner()->getParentOp();
|
||||
if (FuncOp fun = dyn_cast<FuncOp>(op)) {
|
||||
Attribute attr =
|
||||
fun.getArgAttr(blockArg.getArgNumber(), "tt.divisibility");
|
||||
if (attr)
|
||||
divHint = attr.cast<IntegerAttr>().getValue().getZExtValue();
|
||||
}
|
||||
}
|
||||
@@ -55,51 +54,51 @@ AxisInfo AxisInfo::getPessimisticValueState(Value value) {
|
||||
return AxisInfo(contiguity, divisibility, constancy);
|
||||
}
|
||||
|
||||
|
||||
// The gcd of both arguments for each dimension
|
||||
AxisInfo AxisInfo::join(const AxisInfo &lhs,
|
||||
const AxisInfo &rhs) {
|
||||
AxisInfo AxisInfo::join(const AxisInfo &lhs, const AxisInfo &rhs) {
|
||||
ContiguityT retContiguity;
|
||||
DivisibilityT retDivisibility;
|
||||
ConstancyT retConstancy;
|
||||
for(size_t d = 0; d < lhs.getRank(); d++){
|
||||
for (size_t d = 0; d < lhs.getRank(); d++) {
|
||||
retContiguity.push_back(gcd(lhs.getContiguity(d), rhs.getContiguity(d)));
|
||||
retDivisibility.push_back(gcd(lhs.getDivisibility(d), rhs.getDivisibility(d)));
|
||||
retDivisibility.push_back(
|
||||
gcd(lhs.getDivisibility(d), rhs.getDivisibility(d)));
|
||||
retConstancy.push_back(gcd(lhs.getConstancy(d), rhs.getConstancy(d)));
|
||||
}
|
||||
return AxisInfo(retContiguity, retDivisibility, retConstancy);
|
||||
}
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AxisInfoAnalysis
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
AxisInfo AxisInfoAnalysis::visitBinaryOp(Operation* op, AxisInfo lhsInfo, AxisInfo rhsInfo,
|
||||
const std::function<int(AxisInfo,AxisInfo,int)>& getContiguity,
|
||||
const std::function<int(AxisInfo,AxisInfo,int)>& getDivisibility,
|
||||
const std::function<int(AxisInfo,AxisInfo,int)>& getConstancy) {
|
||||
int rank = lhsInfo.getRank();
|
||||
AxisInfo::ContiguityT newContiguity;
|
||||
AxisInfo::DivisibilityT newDivisibility;
|
||||
AxisInfo::ConstancyT newConstancy;
|
||||
for(size_t d = 0; d < rank; d++){
|
||||
newContiguity.push_back(getContiguity(lhsInfo, rhsInfo, d));
|
||||
newDivisibility.push_back(getDivisibility(lhsInfo, rhsInfo, d));
|
||||
newConstancy.push_back(getConstancy(lhsInfo, rhsInfo, d));
|
||||
}
|
||||
return AxisInfo(newContiguity, newDivisibility, newConstancy);
|
||||
AxisInfo AxisInfoAnalysis::visitBinaryOp(
|
||||
Operation *op, AxisInfo lhsInfo, AxisInfo rhsInfo,
|
||||
const std::function<int(AxisInfo, AxisInfo, int)> &getContiguity,
|
||||
const std::function<int(AxisInfo, AxisInfo, int)> &getDivisibility,
|
||||
const std::function<int(AxisInfo, AxisInfo, int)> &getConstancy) {
|
||||
int rank = lhsInfo.getRank();
|
||||
AxisInfo::ContiguityT newContiguity;
|
||||
AxisInfo::DivisibilityT newDivisibility;
|
||||
AxisInfo::ConstancyT newConstancy;
|
||||
for (size_t d = 0; d < rank; d++) {
|
||||
newContiguity.push_back(getContiguity(lhsInfo, rhsInfo, d));
|
||||
newDivisibility.push_back(getDivisibility(lhsInfo, rhsInfo, d));
|
||||
newConstancy.push_back(getConstancy(lhsInfo, rhsInfo, d));
|
||||
}
|
||||
return AxisInfo(newContiguity, newDivisibility, newConstancy);
|
||||
}
|
||||
|
||||
ChangeResult AxisInfoAnalysis::visitOperation(Operation *op,
|
||||
ArrayRef<LatticeElement<AxisInfo> *> operands) {
|
||||
ChangeResult AxisInfoAnalysis::visitOperation(
|
||||
Operation *op, ArrayRef<LatticeElement<AxisInfo> *> operands) {
|
||||
AxisInfo curr;
|
||||
// This preserves the input axes (e.g., cast):
|
||||
if (llvm::isa<arith::ExtSIOp, arith::ExtUIOp, arith::TruncIOp,
|
||||
triton::PtrToIntOp, triton::IntToPtrOp>(op))
|
||||
curr = operands[0]->getValue();
|
||||
// Constant ranges
|
||||
if (triton::MakeRangeOp make_range = llvm::dyn_cast<triton::MakeRangeOp>(op)){
|
||||
if (triton::MakeRangeOp make_range =
|
||||
llvm::dyn_cast<triton::MakeRangeOp>(op)) {
|
||||
int start = make_range.start();
|
||||
int end = make_range.end();
|
||||
AxisInfo::ContiguityT contiguity = {end - start};
|
||||
@@ -108,61 +107,59 @@ ChangeResult AxisInfoAnalysis::visitOperation(Operation *op,
|
||||
curr = AxisInfo(contiguity, divisibility, constancy);
|
||||
}
|
||||
// Constant
|
||||
if (arith::ConstantOp constant = llvm::dyn_cast<arith::ConstantOp>(op)){
|
||||
if (arith::ConstantOp constant = llvm::dyn_cast<arith::ConstantOp>(op)) {
|
||||
auto intAttr = constant.getValue().dyn_cast<IntegerAttr>();
|
||||
if(intAttr){
|
||||
if (intAttr) {
|
||||
size_t val = intAttr.getValue().getZExtValue();
|
||||
curr = AxisInfo({1}, {highestPowOf2Divisor(val)}, {1});
|
||||
}
|
||||
// TODO: generalize to dense attr
|
||||
auto splatAttr = constant.getValue().dyn_cast<SplatElementsAttr>();
|
||||
if(splatAttr && splatAttr.getElementType().isInteger(32)){
|
||||
if (splatAttr && splatAttr.getElementType().isInteger(32)) {
|
||||
auto value = splatAttr.getSplatValue<int>();
|
||||
TensorType ty = splatAttr.getType().cast<TensorType>();
|
||||
curr = AxisInfo(AxisInfo::ContiguityT(ty.getRank(), 1),
|
||||
AxisInfo::DivisibilityT(ty.getRank(), highestPowOf2Divisor(value)),
|
||||
AxisInfo::ConstancyT(ty.getShape().begin(), ty.getShape().end()));
|
||||
|
||||
curr = AxisInfo(
|
||||
AxisInfo::ContiguityT(ty.getRank(), 1),
|
||||
AxisInfo::DivisibilityT(ty.getRank(), highestPowOf2Divisor(value)),
|
||||
AxisInfo::ConstancyT(ty.getShape().begin(), ty.getShape().end()));
|
||||
}
|
||||
}
|
||||
// Addition
|
||||
if (llvm::isa<arith::AddIOp, triton::GEPOp>(op)){
|
||||
auto newContiguity = [&](AxisInfo lhs, AxisInfo rhs, int d){
|
||||
if (llvm::isa<arith::AddIOp, triton::GEPOp>(op)) {
|
||||
auto newContiguity = [&](AxisInfo lhs, AxisInfo rhs, int d) {
|
||||
return std::max(gcd(lhs.getContiguity(d), rhs.getConstancy(d)),
|
||||
gcd(lhs.getConstancy(d), rhs.getContiguity(d)));
|
||||
};
|
||||
auto newConstancy = [&](AxisInfo lhs, AxisInfo rhs, int d){
|
||||
auto newConstancy = [&](AxisInfo lhs, AxisInfo rhs, int d) {
|
||||
return gcd(lhs.getConstancy(d), rhs.getConstancy(d));
|
||||
};
|
||||
auto newDivisibility = [&](AxisInfo lhs, AxisInfo rhs, int d){
|
||||
auto newDivisibility = [&](AxisInfo lhs, AxisInfo rhs, int d) {
|
||||
return gcd(lhs.getDivisibility(d), rhs.getDivisibility(d));
|
||||
};
|
||||
curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(),
|
||||
newContiguity, newDivisibility, newConstancy);
|
||||
newContiguity, newDivisibility, newConstancy);
|
||||
}
|
||||
// Multiplication
|
||||
if (llvm::isa<arith::MulIOp>(op)){
|
||||
auto newContiguity = [](AxisInfo lhs, AxisInfo rhs, int d){
|
||||
return 1;
|
||||
};
|
||||
auto newConstancy = [](AxisInfo lhs, AxisInfo rhs, int d){
|
||||
if (llvm::isa<arith::MulIOp>(op)) {
|
||||
auto newContiguity = [](AxisInfo lhs, AxisInfo rhs, int d) { return 1; };
|
||||
auto newConstancy = [](AxisInfo lhs, AxisInfo rhs, int d) {
|
||||
return gcd(lhs.getConstancy(d), rhs.getConstancy(d));
|
||||
};
|
||||
auto newDivisibility = [](AxisInfo lhs, AxisInfo rhs, int d){
|
||||
return lhs.getDivisibility(d)*rhs.getDivisibility(d);
|
||||
auto newDivisibility = [](AxisInfo lhs, AxisInfo rhs, int d) {
|
||||
return lhs.getDivisibility(d) * rhs.getDivisibility(d);
|
||||
};
|
||||
curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(),
|
||||
newContiguity, newDivisibility, newConstancy);
|
||||
newContiguity, newDivisibility, newConstancy);
|
||||
}
|
||||
// Splat
|
||||
if (llvm::isa<triton::SplatOp>(op)){
|
||||
if (llvm::isa<triton::SplatOp>(op)) {
|
||||
Type _retTy = *op->result_type_begin();
|
||||
TensorType retTy = _retTy.cast<TensorType>();
|
||||
AxisInfo opInfo = operands[0]->getValue();
|
||||
AxisInfo::ContiguityT contiguity;
|
||||
AxisInfo::DivisibilityT divisibility;
|
||||
AxisInfo::ConstancyT constancy;
|
||||
for(size_t d = 0; d < retTy.getRank(); d++){
|
||||
for (size_t d = 0; d < retTy.getRank(); d++) {
|
||||
contiguity.push_back(1);
|
||||
divisibility.push_back(opInfo.getDivisibility(0));
|
||||
constancy.push_back(retTy.getShape()[d]);
|
||||
@@ -171,7 +168,7 @@ ChangeResult AxisInfoAnalysis::visitOperation(Operation *op,
|
||||
}
|
||||
// Reshape
|
||||
// TODO: Replace by `unsqueeze`
|
||||
if (llvm::isa<triton::ReshapeOp>(op)){
|
||||
if (llvm::isa<triton::ReshapeOp>(op)) {
|
||||
Type _retTy = *op->result_type_begin();
|
||||
Type _opTy = *op->operand_type_begin();
|
||||
TensorType retTy = _retTy.cast<TensorType>();
|
||||
@@ -184,20 +181,17 @@ ChangeResult AxisInfoAnalysis::visitOperation(Operation *op,
|
||||
AxisInfo::ConstancyT constancy;
|
||||
bool is_skewed = false;
|
||||
size_t current = 0;
|
||||
for(size_t d = 0; d < retTy.getRank(); d++){
|
||||
if(retShape[d] == 1){
|
||||
for (size_t d = 0; d < retTy.getRank(); d++) {
|
||||
if (retShape[d] == 1) {
|
||||
contiguity.push_back(1);
|
||||
divisibility.push_back(1);
|
||||
constancy.push_back(1);
|
||||
}
|
||||
else if(!is_skewed
|
||||
&& retShape[d] == opShape[current]){
|
||||
} else if (!is_skewed && retShape[d] == opShape[current]) {
|
||||
contiguity.push_back(opInfo.getContiguity()[current]);
|
||||
divisibility.push_back(opInfo.getDivisibility()[current]);
|
||||
constancy.push_back(opInfo.getConstancy()[current]);
|
||||
current++;
|
||||
}
|
||||
else {
|
||||
} else {
|
||||
is_skewed = true;
|
||||
contiguity.push_back(1);
|
||||
divisibility.push_back(1);
|
||||
@@ -207,7 +201,7 @@ ChangeResult AxisInfoAnalysis::visitOperation(Operation *op,
|
||||
curr = AxisInfo(contiguity, divisibility, constancy);
|
||||
}
|
||||
// Broadcast
|
||||
if (llvm::isa<triton::BroadcastOp>(op)){
|
||||
if (llvm::isa<triton::BroadcastOp>(op)) {
|
||||
Type _retTy = *op->result_type_begin();
|
||||
Type _opTy = *op->operand_type_begin();
|
||||
TensorType retTy = _retTy.cast<TensorType>();
|
||||
@@ -218,14 +212,14 @@ ChangeResult AxisInfoAnalysis::visitOperation(Operation *op,
|
||||
AxisInfo::ContiguityT contiguity;
|
||||
AxisInfo::DivisibilityT divisibility;
|
||||
AxisInfo::ConstancyT constancy;
|
||||
for(size_t d = 0; d < retTy.getRank(); d++){
|
||||
for (size_t d = 0; d < retTy.getRank(); d++) {
|
||||
contiguity.push_back(opShape[d] == 1 ? 1 : opInfo.getContiguity(d));
|
||||
divisibility.push_back(opInfo.getDivisibility(d));
|
||||
constancy.push_back(opShape[d] == 1 ? retShape[d] : 1);
|
||||
}
|
||||
curr = AxisInfo(contiguity, divisibility, constancy);
|
||||
}
|
||||
if(curr.getRank() == 0){
|
||||
if (curr.getRank() == 0) {
|
||||
return markAllPessimisticFixpoint(op->getResults());
|
||||
}
|
||||
// join all latice elements
|
||||
@@ -236,4 +230,4 @@ ChangeResult AxisInfoAnalysis::visitOperation(Operation *op,
|
||||
return result;
|
||||
}
|
||||
|
||||
}
|
||||
} // namespace mlir
|
Reference in New Issue
Block a user