[Triton-MLIR][BACKEND] Refine dot conversion (#710)
This PR (1) refines the dot conversion and (2) makes some other minor code refinements.
This commit is contained in:
@@ -56,11 +56,14 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
|
||||
inVec = outOrd[0] == 0 ? 1 : inOrd[0] == 0 ? 1 : srcContigPerThread;
|
||||
outVec = outOrd[0] == 0 ? 1 : dstContigPerThread;
|
||||
|
||||
auto srcShapePerCTA = getShapePerCTA(srcLayout);
|
||||
auto dstShapePerCTA = getShapePerCTA(dstLayout);
|
||||
|
||||
unsigned pad = std::max(inVec, outVec);
|
||||
for (unsigned d = 0; d < rank; ++d) {
|
||||
paddedRepShape[d] = std::max(
|
||||
std::min<unsigned>(srcTy.getShape()[d], getShapePerCTA(srcLayout, d)),
|
||||
std::min<unsigned>(dstTy.getShape()[d], getShapePerCTA(dstLayout, d)));
|
||||
paddedRepShape[d] =
|
||||
std::max(std::min<unsigned>(srcTy.getShape()[d], srcShapePerCTA[d]),
|
||||
std::min<unsigned>(dstTy.getShape()[d], dstShapePerCTA[d]));
|
||||
}
|
||||
unsigned paddedDim = 1;
|
||||
if (auto dstBlockedLayout = dstLayout.dyn_cast<BlockedEncodingAttr>()) {
|
||||
|
@@ -65,7 +65,7 @@ AxisInfo AxisInfo::join(const AxisInfo &lhs, const AxisInfo &rhs) {
|
||||
DimVectorT retContiguity;
|
||||
DimVectorT retDivisibility;
|
||||
DimVectorT retConstancy;
|
||||
for (size_t d = 0; d < lhs.getRank(); d++) {
|
||||
for (size_t d = 0; d < lhs.getRank(); ++d) {
|
||||
retContiguity.push_back(gcd(lhs.getContiguity(d), rhs.getContiguity(d)));
|
||||
retDivisibility.push_back(
|
||||
gcd(lhs.getDivisibility(d), rhs.getDivisibility(d)));
|
||||
@@ -87,7 +87,7 @@ AxisInfo AxisInfoAnalysis::visitBinaryOp(
|
||||
AxisInfo::DimVectorT newContiguity;
|
||||
AxisInfo::DimVectorT newDivisibility;
|
||||
AxisInfo::DimVectorT newConstancy;
|
||||
for (size_t d = 0; d < rank; d++) {
|
||||
for (size_t d = 0; d < rank; ++d) {
|
||||
newContiguity.push_back(getContiguity(lhsInfo, rhsInfo, d));
|
||||
newDivisibility.push_back(getDivisibility(lhsInfo, rhsInfo, d));
|
||||
newConstancy.push_back(getConstancy(lhsInfo, rhsInfo, d));
|
||||
@@ -166,7 +166,7 @@ ChangeResult AxisInfoAnalysis::visitOperation(
|
||||
AxisInfo::DimVectorT contiguity;
|
||||
AxisInfo::DimVectorT divisibility;
|
||||
AxisInfo::DimVectorT constancy;
|
||||
for (size_t d = 0; d < retTy.getRank(); d++) {
|
||||
for (size_t d = 0; d < retTy.getRank(); ++d) {
|
||||
contiguity.push_back(1);
|
||||
divisibility.push_back(opInfo.getDivisibility(0));
|
||||
constancy.push_back(retTy.getShape()[d]);
|
||||
@@ -202,7 +202,7 @@ ChangeResult AxisInfoAnalysis::visitOperation(
|
||||
AxisInfo::DimVectorT contiguity;
|
||||
AxisInfo::DimVectorT divisibility;
|
||||
AxisInfo::DimVectorT constancy;
|
||||
for (size_t d = 0; d < retTy.getRank(); d++) {
|
||||
for (size_t d = 0; d < retTy.getRank(); ++d) {
|
||||
contiguity.push_back(opShape[d] == 1 ? 1 : opInfo.getContiguity(d));
|
||||
divisibility.push_back(opInfo.getDivisibility(d));
|
||||
constancy.push_back(opShape[d] == 1 ? retShape[d] : 1);
|
||||
|
Reference in New Issue
Block a user