[FRONTEND] Added ExpandDimsOp
primitive (#36)
This commit is contained in:
@@ -167,8 +167,8 @@ ChangeResult AxisInfoAnalysis::visitOperation(
|
||||
}
|
||||
curr = AxisInfo(contiguity, divisibility, constancy);
|
||||
}
|
||||
// Reshape
|
||||
if (llvm::isa<triton::ViewOp>(op)) {
|
||||
// expandDims
|
||||
if (auto expandDims = llvm::dyn_cast<triton::ExpandDimsOp>(op)) {
|
||||
Type _retTy = *op->result_type_begin();
|
||||
Type _opTy = *op->operand_type_begin();
|
||||
TensorType retTy = _retTy.cast<TensorType>();
|
||||
@@ -176,28 +176,12 @@ ChangeResult AxisInfoAnalysis::visitOperation(
|
||||
ArrayRef<int64_t> retShape = retTy.getShape();
|
||||
ArrayRef<int64_t> opShape = opTy.getShape();
|
||||
AxisInfo opInfo = operands[0]->getValue();
|
||||
AxisInfo::DimVectorT contiguity;
|
||||
AxisInfo::DimVectorT divisibility;
|
||||
AxisInfo::DimVectorT constancy;
|
||||
bool is_skewed = false;
|
||||
size_t current = 0;
|
||||
for (size_t d = 0; d < retTy.getRank(); d++) {
|
||||
if (retShape[d] == 1) {
|
||||
contiguity.push_back(1);
|
||||
divisibility.push_back(1);
|
||||
constancy.push_back(1);
|
||||
} else if (!is_skewed && retShape[d] == opShape[current]) {
|
||||
contiguity.push_back(opInfo.getContiguity()[current]);
|
||||
divisibility.push_back(opInfo.getDivisibility()[current]);
|
||||
constancy.push_back(opInfo.getConstancy()[current]);
|
||||
current++;
|
||||
} else {
|
||||
is_skewed = true;
|
||||
contiguity.push_back(1);
|
||||
divisibility.push_back(1);
|
||||
constancy.push_back(1);
|
||||
}
|
||||
}
|
||||
AxisInfo::DimVectorT contiguity = opInfo.getContiguity();
|
||||
AxisInfo::DimVectorT divisibility = opInfo.getDivisibility();
|
||||
AxisInfo::DimVectorT constancy = opInfo.getConstancy();
|
||||
contiguity.insert(contiguity.begin() + expandDims.axis(), 1);
|
||||
divisibility.insert(divisibility.begin() + expandDims.axis(), 1);
|
||||
constancy.insert(constancy.begin() + expandDims.axis(), 1);
|
||||
curr = AxisInfo(contiguity, divisibility, constancy);
|
||||
}
|
||||
// Broadcast
|
||||
|
Reference in New Issue
Block a user