[BACKEND] Added support for 1D conversion blocked -> slice (#831)

This commit is contained in:
Philippe Tillet
2022-11-01 13:19:58 -07:00
committed by GitHub
parent c9d84237e8
commit 12d60cb4a3
5 changed files with 103 additions and 78 deletions

View File

@@ -11,7 +11,9 @@
#include <numeric>
using ::mlir::triton::gpu::BlockedEncodingAttr;
using ::mlir::triton::gpu::getOrder;
using ::mlir::triton::gpu::getShapePerCTA;
using ::mlir::triton::gpu::getSizePerThread;
using ::mlir::triton::gpu::MmaEncodingAttr;
using ::mlir::triton::gpu::SharedEncodingAttr;
using ::mlir::triton::gpu::SliceEncodingAttr;
@@ -34,28 +36,16 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
"Unexpect layout in getScratchConfigForCvtLayout()");
unsigned rank = dstTy.getRank();
SmallVector<unsigned> paddedRepShape(rank);
if (auto srcSliceLayout = srcLayout.dyn_cast<SliceEncodingAttr>())
srcLayout = srcSliceLayout.getParent();
if (auto dstSliceLayout = dstLayout.dyn_cast<SliceEncodingAttr>())
dstLayout = dstSliceLayout.getParent();
auto srcBlockedLayout = srcLayout.dyn_cast<BlockedEncodingAttr>();
auto srcMmaLayout = srcLayout.dyn_cast<MmaEncodingAttr>();
auto dstBlockedLayout = dstLayout.dyn_cast<BlockedEncodingAttr>();
auto dstMmaLayout = dstLayout.dyn_cast<MmaEncodingAttr>();
assert((srcBlockedLayout || srcMmaLayout) &&
"Unexpected srcLayout in getScratchConfigForCvtLayout");
assert((dstBlockedLayout || dstMmaLayout) &&
"Unexpected dstLayout in getScratchConfigForCvtLayout");
assert(!(srcMmaLayout && dstMmaLayout) &&
"Unexpected mma -> mma layout conversion");
auto inOrd =
srcMmaLayout ? dstBlockedLayout.getOrder() : srcBlockedLayout.getOrder();
auto outOrd =
dstMmaLayout ? srcBlockedLayout.getOrder() : dstBlockedLayout.getOrder();
unsigned srcContigPerThread =
srcBlockedLayout ? srcBlockedLayout.getSizePerThread()[inOrd[0]] : 2;
unsigned dstContigPerThread =
dstBlockedLayout ? dstBlockedLayout.getSizePerThread()[outOrd[0]] : 2;
auto inOrd = srcMmaLayout ? getOrder(dstLayout) : getOrder(srcLayout);
auto outOrd = dstMmaLayout ? getOrder(srcLayout) : getOrder(dstLayout);
unsigned srcContigPerThread = getSizePerThread(srcLayout)[inOrd[0]];
unsigned dstContigPerThread = getSizePerThread(dstLayout)[outOrd[0]];
// TODO: Fix the legacy issue that ourOrd[0] == 0 always means
// that we cannot do vectorization.
inVec = outOrd[0] == 0 ? 1 : inOrd[0] == 0 ? 1 : srcContigPerThread;
@@ -70,6 +60,8 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
std::max(std::min<unsigned>(srcTy.getShape()[d], srcShapePerCTA[d]),
std::min<unsigned>(dstTy.getShape()[d], dstShapePerCTA[d]));
}
if (rank == 1)
return paddedRepShape;
unsigned paddedDim = 1;
if (auto dstBlockedLayout = dstLayout.dyn_cast<BlockedEncodingAttr>()) {
paddedDim = dstBlockedLayout.getOrder()[0];