[BACKEND] Added support for 1D conversion blocked -> slice (#831)

This commit is contained in:
Philippe Tillet
2022-11-01 13:19:58 -07:00
committed by GitHub
parent c9d84237e8
commit 12d60cb4a3
5 changed files with 103 additions and 78 deletions

View File

@@ -11,7 +11,9 @@
#include <numeric>
using ::mlir::triton::gpu::BlockedEncodingAttr;
using ::mlir::triton::gpu::getOrder;
using ::mlir::triton::gpu::getShapePerCTA;
using ::mlir::triton::gpu::getSizePerThread;
using ::mlir::triton::gpu::MmaEncodingAttr;
using ::mlir::triton::gpu::SharedEncodingAttr;
using ::mlir::triton::gpu::SliceEncodingAttr;
@@ -34,28 +36,16 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
"Unexpect layout in getScratchConfigForCvtLayout()");
unsigned rank = dstTy.getRank();
SmallVector<unsigned> paddedRepShape(rank);
if (auto srcSliceLayout = srcLayout.dyn_cast<SliceEncodingAttr>())
srcLayout = srcSliceLayout.getParent();
if (auto dstSliceLayout = dstLayout.dyn_cast<SliceEncodingAttr>())
dstLayout = dstSliceLayout.getParent();
auto srcBlockedLayout = srcLayout.dyn_cast<BlockedEncodingAttr>();
auto srcMmaLayout = srcLayout.dyn_cast<MmaEncodingAttr>();
auto dstBlockedLayout = dstLayout.dyn_cast<BlockedEncodingAttr>();
auto dstMmaLayout = dstLayout.dyn_cast<MmaEncodingAttr>();
assert((srcBlockedLayout || srcMmaLayout) &&
"Unexpected srcLayout in getScratchConfigForCvtLayout");
assert((dstBlockedLayout || dstMmaLayout) &&
"Unexpected dstLayout in getScratchConfigForCvtLayout");
assert(!(srcMmaLayout && dstMmaLayout) &&
"Unexpected mma -> mma layout conversion");
auto inOrd =
srcMmaLayout ? dstBlockedLayout.getOrder() : srcBlockedLayout.getOrder();
auto outOrd =
dstMmaLayout ? srcBlockedLayout.getOrder() : dstBlockedLayout.getOrder();
unsigned srcContigPerThread =
srcBlockedLayout ? srcBlockedLayout.getSizePerThread()[inOrd[0]] : 2;
unsigned dstContigPerThread =
dstBlockedLayout ? dstBlockedLayout.getSizePerThread()[outOrd[0]] : 2;
auto inOrd = srcMmaLayout ? getOrder(dstLayout) : getOrder(srcLayout);
auto outOrd = dstMmaLayout ? getOrder(srcLayout) : getOrder(dstLayout);
unsigned srcContigPerThread = getSizePerThread(srcLayout)[inOrd[0]];
unsigned dstContigPerThread = getSizePerThread(dstLayout)[outOrd[0]];
// TODO: Fix the legacy issue that ourOrd[0] == 0 always means
// that we cannot do vectorization.
inVec = outOrd[0] == 0 ? 1 : inOrd[0] == 0 ? 1 : srcContigPerThread;
@@ -70,6 +60,8 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
std::max(std::min<unsigned>(srcTy.getShape()[d], srcShapePerCTA[d]),
std::min<unsigned>(dstTy.getShape()[d], dstShapePerCTA[d]));
}
if (rank == 1)
return paddedRepShape;
unsigned paddedDim = 1;
if (auto dstBlockedLayout = dstLayout.dyn_cast<BlockedEncodingAttr>()) {
paddedDim = dstBlockedLayout.getOrder()[0];

View File

@@ -1197,7 +1197,7 @@ struct BroadcastOpConversion
broadcastDims.push_back(d);
srcLogicalShape[d] = 1;
srcLogicalShape[d + rank] =
std::max(unsigned(1), srcLayout.getSizePerThread()[d]);
std::max<unsigned>(1, srcLayout.getSizePerThread()[d]);
} else {
srcLogicalShape[d] = numCtas;
srcLogicalShape[d + rank] = resultLayout.getSizePerThread()[d];
@@ -2231,6 +2231,7 @@ LogicalResult ConvertLayoutOpConversion::lowerDistributedToDistributed(
}
barrier();
if (dstLayout.isa<BlockedEncodingAttr>() ||
dstLayout.isa<SliceEncodingAttr>() ||
dstLayout.isa<MmaEncodingAttr>()) {
processReplica(loc, rewriter, /*stNotRd*/ false, dstTy, outNumCTAsEachRep,
multiDimRepId, outVec, paddedRepShape, outOrd, outVals,

View File

@@ -68,18 +68,7 @@ SmallVector<unsigned> getSizePerThread(Attribute layout) {
return SmallVector<unsigned>(blockedLayout.getSizePerThread().begin(),
blockedLayout.getSizePerThread().end());
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
unsigned dim = sliceLayout.getDim();
auto parent = sliceLayout.getParent();
if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
SmallVector<unsigned> sizePerThread(
blockedParent.getSizePerThread().begin(),
blockedParent.getSizePerThread().end());
sizePerThread.erase(sizePerThread.begin() + dim);
return sizePerThread;
} else {
assert(0 && "SliceEncodingAttr with parent other than "
"BlockedEncodingAttr not implemented");
}
return getSizePerThread(sliceLayout.getParent());
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
assert(mmaLayout.getVersion() == 2 &&
"mmaLayout version = 1 is not implemented yet");
@@ -144,6 +133,19 @@ SmallVector<unsigned> getOrder(const Attribute &layout) {
blockedLayout.getOrder().end());
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
return SmallVector<unsigned>{1, 0};
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
SmallVector<unsigned> parentOrder = getOrder(sliceLayout.getParent());
unsigned dim = sliceLayout.getDim();
SmallVector<unsigned> order;
for (unsigned d : parentOrder) {
if (d == dim)
continue;
else if (d > dim)
order.push_back(d - 1);
else
order.push_back(d);
}
return order;
} else if (auto sharedLayout = layout.dyn_cast<SharedEncodingAttr>()) {
return SmallVector<unsigned>(sharedLayout.getOrder().begin(),
sharedLayout.getOrder().end());