[TritonGPU] Improved documentation and semantics of layout encodings (#30)
This commit is contained in:
@@ -48,17 +48,17 @@ AxisInfo AxisInfo::getPessimisticValueState(Value value) {
|
||||
divHint = attr.cast<IntegerAttr>().getValue().getZExtValue();
|
||||
}
|
||||
}
|
||||
ContiguityT contiguity(rank, 1);
|
||||
DivisibilityT divisibility(rank, divHint);
|
||||
ConstancyT constancy(rank, 1);
|
||||
DimVectorT contiguity(rank, 1);
|
||||
DimVectorT divisibility(rank, divHint);
|
||||
DimVectorT constancy(rank, 1);
|
||||
return AxisInfo(contiguity, divisibility, constancy);
|
||||
}
|
||||
|
||||
// The gcd of both arguments for each dimension
|
||||
AxisInfo AxisInfo::join(const AxisInfo &lhs, const AxisInfo &rhs) {
|
||||
ContiguityT retContiguity;
|
||||
DivisibilityT retDivisibility;
|
||||
ConstancyT retConstancy;
|
||||
DimVectorT retContiguity;
|
||||
DimVectorT retDivisibility;
|
||||
DimVectorT retConstancy;
|
||||
for (size_t d = 0; d < lhs.getRank(); d++) {
|
||||
retContiguity.push_back(gcd(lhs.getContiguity(d), rhs.getContiguity(d)));
|
||||
retDivisibility.push_back(
|
||||
@@ -78,9 +78,9 @@ AxisInfo AxisInfoAnalysis::visitBinaryOp(
|
||||
const std::function<int(AxisInfo, AxisInfo, int)> &getDivisibility,
|
||||
const std::function<int(AxisInfo, AxisInfo, int)> &getConstancy) {
|
||||
int rank = lhsInfo.getRank();
|
||||
AxisInfo::ContiguityT newContiguity;
|
||||
AxisInfo::DivisibilityT newDivisibility;
|
||||
AxisInfo::ConstancyT newConstancy;
|
||||
AxisInfo::DimVectorT newContiguity;
|
||||
AxisInfo::DimVectorT newDivisibility;
|
||||
AxisInfo::DimVectorT newConstancy;
|
||||
for (size_t d = 0; d < rank; d++) {
|
||||
newContiguity.push_back(getContiguity(lhsInfo, rhsInfo, d));
|
||||
newDivisibility.push_back(getDivisibility(lhsInfo, rhsInfo, d));
|
||||
@@ -101,9 +101,9 @@ ChangeResult AxisInfoAnalysis::visitOperation(
|
||||
llvm::dyn_cast<triton::MakeRangeOp>(op)) {
|
||||
int start = make_range.start();
|
||||
int end = make_range.end();
|
||||
AxisInfo::ContiguityT contiguity = {end - start};
|
||||
AxisInfo::DivisibilityT divisibility = {highestPowOf2Divisor(start)};
|
||||
AxisInfo::ConstancyT constancy = {1};
|
||||
AxisInfo::DimVectorT contiguity = {end - start};
|
||||
AxisInfo::DimVectorT divisibility = {highestPowOf2Divisor(start)};
|
||||
AxisInfo::DimVectorT constancy = {1};
|
||||
curr = AxisInfo(contiguity, divisibility, constancy);
|
||||
}
|
||||
// Constant
|
||||
@@ -119,9 +119,9 @@ ChangeResult AxisInfoAnalysis::visitOperation(
|
||||
auto value = splatAttr.getSplatValue<int>();
|
||||
TensorType ty = splatAttr.getType().cast<TensorType>();
|
||||
curr = AxisInfo(
|
||||
AxisInfo::ContiguityT(ty.getRank(), 1),
|
||||
AxisInfo::DivisibilityT(ty.getRank(), highestPowOf2Divisor(value)),
|
||||
AxisInfo::ConstancyT(ty.getShape().begin(), ty.getShape().end()));
|
||||
AxisInfo::DimVectorT(ty.getRank(), 1),
|
||||
AxisInfo::DimVectorT(ty.getRank(), highestPowOf2Divisor(value)),
|
||||
AxisInfo::DimVectorT(ty.getShape().begin(), ty.getShape().end()));
|
||||
}
|
||||
}
|
||||
// Addition
|
||||
@@ -156,9 +156,9 @@ ChangeResult AxisInfoAnalysis::visitOperation(
|
||||
Type _retTy = *op->result_type_begin();
|
||||
TensorType retTy = _retTy.cast<TensorType>();
|
||||
AxisInfo opInfo = operands[0]->getValue();
|
||||
AxisInfo::ContiguityT contiguity;
|
||||
AxisInfo::DivisibilityT divisibility;
|
||||
AxisInfo::ConstancyT constancy;
|
||||
AxisInfo::DimVectorT contiguity;
|
||||
AxisInfo::DimVectorT divisibility;
|
||||
AxisInfo::DimVectorT constancy;
|
||||
for (size_t d = 0; d < retTy.getRank(); d++) {
|
||||
contiguity.push_back(1);
|
||||
divisibility.push_back(opInfo.getDivisibility(0));
|
||||
@@ -167,8 +167,7 @@ ChangeResult AxisInfoAnalysis::visitOperation(
|
||||
curr = AxisInfo(contiguity, divisibility, constancy);
|
||||
}
|
||||
// Reshape
|
||||
// TODO: Replace by `unsqueeze`
|
||||
if (llvm::isa<triton::ReshapeOp>(op)) {
|
||||
if (llvm::isa<triton::ViewOp>(op)) {
|
||||
Type _retTy = *op->result_type_begin();
|
||||
Type _opTy = *op->operand_type_begin();
|
||||
TensorType retTy = _retTy.cast<TensorType>();
|
||||
@@ -176,9 +175,9 @@ ChangeResult AxisInfoAnalysis::visitOperation(
|
||||
ArrayRef<int64_t> retShape = retTy.getShape();
|
||||
ArrayRef<int64_t> opShape = opTy.getShape();
|
||||
AxisInfo opInfo = operands[0]->getValue();
|
||||
AxisInfo::ContiguityT contiguity;
|
||||
AxisInfo::DivisibilityT divisibility;
|
||||
AxisInfo::ConstancyT constancy;
|
||||
AxisInfo::DimVectorT contiguity;
|
||||
AxisInfo::DimVectorT divisibility;
|
||||
AxisInfo::DimVectorT constancy;
|
||||
bool is_skewed = false;
|
||||
size_t current = 0;
|
||||
for (size_t d = 0; d < retTy.getRank(); d++) {
|
||||
@@ -209,9 +208,9 @@ ChangeResult AxisInfoAnalysis::visitOperation(
|
||||
ArrayRef<int64_t> retShape = retTy.getShape();
|
||||
ArrayRef<int64_t> opShape = opTy.getShape();
|
||||
AxisInfo opInfo = operands[0]->getValue();
|
||||
AxisInfo::ContiguityT contiguity;
|
||||
AxisInfo::DivisibilityT divisibility;
|
||||
AxisInfo::ConstancyT constancy;
|
||||
AxisInfo::DimVectorT contiguity;
|
||||
AxisInfo::DimVectorT divisibility;
|
||||
AxisInfo::DimVectorT constancy;
|
||||
for (size_t d = 0; d < retTy.getRank(); d++) {
|
||||
contiguity.push_back(opShape[d] == 1 ? 1 : opInfo.getContiguity(d));
|
||||
divisibility.push_back(opInfo.getDivisibility(d));
|
||||
|
Reference in New Issue
Block a user