[Triton-MLIR][Backend] Support layout conversion between mmaLayout and blockedLayout (#693)
This commit is contained in:
@@ -11,6 +11,7 @@
|
||||
#include <numeric>
|
||||
|
||||
using ::mlir::triton::gpu::BlockedEncodingAttr;
|
||||
using ::mlir::triton::gpu::getShapePerCTA;
|
||||
using ::mlir::triton::gpu::MmaEncodingAttr;
|
||||
using ::mlir::triton::gpu::SharedEncodingAttr;
|
||||
|
||||
@@ -32,39 +33,40 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
|
||||
"Unexpect layout in getScratchConfigForCvtLayout()");
|
||||
unsigned rank = dstTy.getRank();
|
||||
SmallVector<unsigned> paddedRepShape(rank);
|
||||
// TODO: move to TritonGPUAttrDefs.h.inc
|
||||
auto getShapePerCTA = [&](const Attribute &layout, unsigned d) -> unsigned {
|
||||
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
|
||||
return blockedLayout.getSizePerThread()[d] *
|
||||
blockedLayout.getThreadsPerWarp()[d] *
|
||||
blockedLayout.getWarpsPerCTA()[d];
|
||||
} else {
|
||||
assert(0 && "Unimplemented usage of getShapePerCTA");
|
||||
return 0;
|
||||
}
|
||||
};
|
||||
if (srcLayout.isa<BlockedEncodingAttr>() &&
|
||||
dstLayout.isa<BlockedEncodingAttr>()) {
|
||||
auto srcBlockedLayout = srcLayout.cast<BlockedEncodingAttr>();
|
||||
auto dstBlockedLayout = dstLayout.cast<BlockedEncodingAttr>();
|
||||
auto inOrd = srcBlockedLayout.getOrder();
|
||||
auto outOrd = dstBlockedLayout.getOrder();
|
||||
// TODO: Fix the legacy issue that outOrd[0] == 0 always means
|
||||
// that we cannot do vectorization.
|
||||
inVec = outOrd[0] == 0 ? 1
|
||||
: inOrd[0] == 0 ? 1
|
||||
: srcBlockedLayout.getSizePerThread()[inOrd[0]];
|
||||
outVec =
|
||||
outOrd[0] == 0 ? 1 : dstBlockedLayout.getSizePerThread()[outOrd[0]];
|
||||
unsigned pad = std::max(inVec, outVec);
|
||||
for (unsigned d = 0; d < rank; ++d) {
|
||||
paddedRepShape[d] = std::max(
|
||||
std::min<unsigned>(srcTy.getShape()[d], getShapePerCTA(srcLayout, d)),
|
||||
std::min<unsigned>(dstTy.getShape()[d],
|
||||
getShapePerCTA(dstLayout, d)));
|
||||
}
|
||||
paddedRepShape[outOrd[0]] += pad;
|
||||
auto srcBlockedLayout = srcLayout.dyn_cast<BlockedEncodingAttr>();
|
||||
auto srcMmaLayout = srcLayout.dyn_cast<MmaEncodingAttr>();
|
||||
auto dstBlockedLayout = dstLayout.dyn_cast<BlockedEncodingAttr>();
|
||||
auto dstMmaLayout = dstLayout.dyn_cast<MmaEncodingAttr>();
|
||||
assert((srcBlockedLayout || srcMmaLayout) &&
|
||||
"Unexpected srcLayout in getScratchConfigForCvtLayout");
|
||||
assert((dstBlockedLayout || dstMmaLayout) &&
|
||||
"Unexpected dstLayout in getScratchConfigForCvtLayout");
|
||||
assert(!(srcMmaLayout && dstMmaLayout) &&
|
||||
"Unexpected mma -> mma layout conversion");
|
||||
auto inOrd =
|
||||
srcMmaLayout ? dstBlockedLayout.getOrder() : srcBlockedLayout.getOrder();
|
||||
auto outOrd =
|
||||
dstMmaLayout ? srcBlockedLayout.getOrder() : dstBlockedLayout.getOrder();
|
||||
unsigned srcContigPerThread =
|
||||
srcBlockedLayout ? srcBlockedLayout.getSizePerThread()[inOrd[0]] : 2;
|
||||
unsigned dstContigPerThread =
|
||||
dstBlockedLayout ? dstBlockedLayout.getSizePerThread()[outOrd[0]] : 2;
|
||||
// TODO: Fix the legacy issue that outOrd[0] == 0 always means
|
||||
// that we cannot do vectorization.
|
||||
inVec = outOrd[0] == 0 ? 1 : inOrd[0] == 0 ? 1 : srcContigPerThread;
|
||||
outVec = outOrd[0] == 0 ? 1 : dstContigPerThread;
|
||||
|
||||
unsigned pad = std::max(inVec, outVec);
|
||||
for (unsigned d = 0; d < rank; ++d) {
|
||||
paddedRepShape[d] = std::max(
|
||||
std::min<unsigned>(srcTy.getShape()[d], getShapePerCTA(srcLayout, d)),
|
||||
std::min<unsigned>(dstTy.getShape()[d], getShapePerCTA(dstLayout, d)));
|
||||
}
|
||||
unsigned paddedDim = 1;
|
||||
if (auto dstBlockedLayout = dstLayout.dyn_cast<BlockedEncodingAttr>()) {
|
||||
paddedDim = dstBlockedLayout.getOrder()[0];
|
||||
}
|
||||
paddedRepShape[paddedDim] += pad;
|
||||
return paddedRepShape;
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user