[Triton-MLIR][Backend] Support layout conversion between mmaLayout and blockedLayout (#693)

This commit is contained in:
goostavz
2022-09-27 11:58:47 +08:00
committed by GitHub
parent 1e91ed30d0
commit 61b61755e5
6 changed files with 205 additions and 104 deletions

View File

@@ -11,6 +11,7 @@
#include <numeric>
using ::mlir::triton::gpu::BlockedEncodingAttr;
using ::mlir::triton::gpu::getShapePerCTA;
using ::mlir::triton::gpu::MmaEncodingAttr;
using ::mlir::triton::gpu::SharedEncodingAttr;
@@ -32,39 +33,40 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
"Unexpect layout in getScratchConfigForCvtLayout()");
unsigned rank = dstTy.getRank();
SmallVector<unsigned> paddedRepShape(rank);
// TODO: move to TritonGPUAttrDefs.h.inc
auto getShapePerCTA = [&](const Attribute &layout, unsigned d) -> unsigned {
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
return blockedLayout.getSizePerThread()[d] *
blockedLayout.getThreadsPerWarp()[d] *
blockedLayout.getWarpsPerCTA()[d];
} else {
assert(0 && "Unimplemented usage of getShapePerCTA");
return 0;
}
};
if (srcLayout.isa<BlockedEncodingAttr>() &&
dstLayout.isa<BlockedEncodingAttr>()) {
auto srcBlockedLayout = srcLayout.cast<BlockedEncodingAttr>();
auto dstBlockedLayout = dstLayout.cast<BlockedEncodingAttr>();
auto inOrd = srcBlockedLayout.getOrder();
auto outOrd = dstBlockedLayout.getOrder();
// TODO: Fix the legacy issue that ourOrd[0] == 0 always means
// that we cannot do vectorization.
inVec = outOrd[0] == 0 ? 1
: inOrd[0] == 0 ? 1
: srcBlockedLayout.getSizePerThread()[inOrd[0]];
outVec =
outOrd[0] == 0 ? 1 : dstBlockedLayout.getSizePerThread()[outOrd[0]];
unsigned pad = std::max(inVec, outVec);
for (unsigned d = 0; d < rank; ++d) {
paddedRepShape[d] = std::max(
std::min<unsigned>(srcTy.getShape()[d], getShapePerCTA(srcLayout, d)),
std::min<unsigned>(dstTy.getShape()[d],
getShapePerCTA(dstLayout, d)));
}
paddedRepShape[outOrd[0]] += pad;
auto srcBlockedLayout = srcLayout.dyn_cast<BlockedEncodingAttr>();
auto srcMmaLayout = srcLayout.dyn_cast<MmaEncodingAttr>();
auto dstBlockedLayout = dstLayout.dyn_cast<BlockedEncodingAttr>();
auto dstMmaLayout = dstLayout.dyn_cast<MmaEncodingAttr>();
assert((srcBlockedLayout || srcMmaLayout) &&
"Unexpected srcLayout in getScratchConfigForCvtLayout");
assert((dstBlockedLayout || dstMmaLayout) &&
"Unexpected dstLayout in getScratchConfigForCvtLayout");
assert(!(srcMmaLayout && dstMmaLayout) &&
"Unexpected mma -> mma layout conversion");
auto inOrd =
srcMmaLayout ? dstBlockedLayout.getOrder() : srcBlockedLayout.getOrder();
auto outOrd =
dstMmaLayout ? srcBlockedLayout.getOrder() : dstBlockedLayout.getOrder();
unsigned srcContigPerThread =
srcBlockedLayout ? srcBlockedLayout.getSizePerThread()[inOrd[0]] : 2;
unsigned dstContigPerThread =
dstBlockedLayout ? dstBlockedLayout.getSizePerThread()[outOrd[0]] : 2;
// TODO: Fix the legacy issue that ourOrd[0] == 0 always means
// that we cannot do vectorization.
inVec = outOrd[0] == 0 ? 1 : inOrd[0] == 0 ? 1 : srcContigPerThread;
outVec = outOrd[0] == 0 ? 1 : dstContigPerThread;
unsigned pad = std::max(inVec, outVec);
for (unsigned d = 0; d < rank; ++d) {
paddedRepShape[d] = std::max(
std::min<unsigned>(srcTy.getShape()[d], getShapePerCTA(srcLayout, d)),
std::min<unsigned>(dstTy.getShape()[d], getShapePerCTA(dstLayout, d)));
}
unsigned paddedDim = 1;
if (auto dstBlockedLayout = dstLayout.dyn_cast<BlockedEncodingAttr>()) {
paddedDim = dstBlockedLayout.getOrder()[0];
}
paddedRepShape[paddedDim] += pad;
return paddedRepShape;
}