[BACKEND] Added support for 1D conversion blocked -> slice (#831)
This commit is contained in:
@@ -11,7 +11,9 @@
|
||||
#include <numeric>
|
||||
|
||||
using ::mlir::triton::gpu::BlockedEncodingAttr;
|
||||
using ::mlir::triton::gpu::getOrder;
|
||||
using ::mlir::triton::gpu::getShapePerCTA;
|
||||
using ::mlir::triton::gpu::getSizePerThread;
|
||||
using ::mlir::triton::gpu::MmaEncodingAttr;
|
||||
using ::mlir::triton::gpu::SharedEncodingAttr;
|
||||
using ::mlir::triton::gpu::SliceEncodingAttr;
|
||||
@@ -34,28 +36,16 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
|
||||
"Unexpect layout in getScratchConfigForCvtLayout()");
|
||||
unsigned rank = dstTy.getRank();
|
||||
SmallVector<unsigned> paddedRepShape(rank);
|
||||
if (auto srcSliceLayout = srcLayout.dyn_cast<SliceEncodingAttr>())
|
||||
srcLayout = srcSliceLayout.getParent();
|
||||
if (auto dstSliceLayout = dstLayout.dyn_cast<SliceEncodingAttr>())
|
||||
dstLayout = dstSliceLayout.getParent();
|
||||
auto srcBlockedLayout = srcLayout.dyn_cast<BlockedEncodingAttr>();
|
||||
auto srcMmaLayout = srcLayout.dyn_cast<MmaEncodingAttr>();
|
||||
auto dstBlockedLayout = dstLayout.dyn_cast<BlockedEncodingAttr>();
|
||||
auto dstMmaLayout = dstLayout.dyn_cast<MmaEncodingAttr>();
|
||||
assert((srcBlockedLayout || srcMmaLayout) &&
|
||||
"Unexpected srcLayout in getScratchConfigForCvtLayout");
|
||||
assert((dstBlockedLayout || dstMmaLayout) &&
|
||||
"Unexpected dstLayout in getScratchConfigForCvtLayout");
|
||||
assert(!(srcMmaLayout && dstMmaLayout) &&
|
||||
"Unexpected mma -> mma layout conversion");
|
||||
auto inOrd =
|
||||
srcMmaLayout ? dstBlockedLayout.getOrder() : srcBlockedLayout.getOrder();
|
||||
auto outOrd =
|
||||
dstMmaLayout ? srcBlockedLayout.getOrder() : dstBlockedLayout.getOrder();
|
||||
unsigned srcContigPerThread =
|
||||
srcBlockedLayout ? srcBlockedLayout.getSizePerThread()[inOrd[0]] : 2;
|
||||
unsigned dstContigPerThread =
|
||||
dstBlockedLayout ? dstBlockedLayout.getSizePerThread()[outOrd[0]] : 2;
|
||||
auto inOrd = srcMmaLayout ? getOrder(dstLayout) : getOrder(srcLayout);
|
||||
auto outOrd = dstMmaLayout ? getOrder(srcLayout) : getOrder(dstLayout);
|
||||
unsigned srcContigPerThread = getSizePerThread(srcLayout)[inOrd[0]];
|
||||
unsigned dstContigPerThread = getSizePerThread(dstLayout)[outOrd[0]];
|
||||
// TODO: Fix the legacy issue that ourOrd[0] == 0 always means
|
||||
// that we cannot do vectorization.
|
||||
inVec = outOrd[0] == 0 ? 1 : inOrd[0] == 0 ? 1 : srcContigPerThread;
|
||||
@@ -70,6 +60,8 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
|
||||
std::max(std::min<unsigned>(srcTy.getShape()[d], srcShapePerCTA[d]),
|
||||
std::min<unsigned>(dstTy.getShape()[d], dstShapePerCTA[d]));
|
||||
}
|
||||
if (rank == 1)
|
||||
return paddedRepShape;
|
||||
unsigned paddedDim = 1;
|
||||
if (auto dstBlockedLayout = dstLayout.dyn_cast<BlockedEncodingAttr>()) {
|
||||
paddedDim = dstBlockedLayout.getOrder()[0];
|
||||
|
Reference in New Issue
Block a user