[BACKEND] Add generic support of convert_layout from distributed to shared (#1025)

Author:    goostavz
Date:      2022-12-31 03:29:58 +08:00
Committer: GitHub
Parent:    194ba103b1
Commit:    0e8590f1c9
5 changed files with 68 additions and 110 deletions


@@ -39,6 +39,8 @@ SmallVector<unsigned> getShapePerCTA(const Attribute &layout);
 SmallVector<unsigned> getOrder(const Attribute &layout);
 
+bool isaDistributedLayout(const Attribute &layout);
+
 } // namespace gpu
 } // namespace triton
 } // namespace mlir


@@ -9,10 +9,12 @@ using ::mlir::LLVM::getStridesFromShapeAndOrder;
 using ::mlir::LLVM::getStructFromElements;
 using ::mlir::LLVM::MMA16816ConversionHelper;
 using ::mlir::triton::gpu::DotOperandEncodingAttr;
+using ::mlir::triton::gpu::getContigPerThread;
 using ::mlir::triton::gpu::getElemsPerThread;
 using ::mlir::triton::gpu::getOrder;
 using ::mlir::triton::gpu::getShapePerCTA;
 using ::mlir::triton::gpu::getSizePerThread;
+using ::mlir::triton::gpu::isaDistributedLayout;
 using ::mlir::triton::gpu::SharedEncodingAttr;
 
 bool isMmaToDotShortcut(MmaEncodingAttr &mmaLayout,
@@ -24,108 +26,63 @@ bool isMmaToDotShortcut(MmaEncodingAttr &mmaLayout,
          dotOperandLayout.getParent() == mmaLayout;
 }
 
-void storeBlockedToShared(Value src, Value llSrc, ArrayRef<Value> srcStrides,
-                          ArrayRef<Value> srcIndices, Value dst, Value smemBase,
-                          Type elemTy, Location loc,
-                          ConversionPatternRewriter &rewriter) {
+void storeDistributedToShared(Value src, Value llSrc,
+                              ArrayRef<Value> dstStrides,
+                              ArrayRef<SmallVector<Value>> srcIndices,
+                              Value dst, Value smemBase, Type elemTy,
+                              Location loc,
+                              ConversionPatternRewriter &rewriter) {
   auto srcTy = src.getType().cast<RankedTensorType>();
   auto srcShape = srcTy.getShape();
-  assert(srcShape.size() == 2 && "Unexpected rank of insertSlice");
+  assert(srcShape.size() == 2 && "Unexpected rank of storeDistributedToShared");
   auto dstTy = dst.getType().cast<RankedTensorType>();
-  auto srcBlockedLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
+  auto srcDistributedLayout = srcTy.getEncoding();
+  if (auto mmaLayout = srcDistributedLayout.dyn_cast<MmaEncodingAttr>()) {
+    assert((!mmaLayout.isVolta()) &&
+           "ConvertLayout MMAv1->Shared is not suppported yet");
+  }
   auto dstSharedLayout = dstTy.getEncoding().cast<SharedEncodingAttr>();
-  auto inOrd = srcBlockedLayout.getOrder();
+  auto inOrd = getOrder(srcDistributedLayout);
   auto outOrd = dstSharedLayout.getOrder();
-  if (inOrd != outOrd)
-    llvm_unreachable(
-        "blocked -> shared with different order not yet implemented");
   unsigned inVec =
-      inOrd == outOrd ? srcBlockedLayout.getSizePerThread()[inOrd[0]] : 1;
+      inOrd == outOrd ? getContigPerThread(srcDistributedLayout)[inOrd[0]] : 1;
   unsigned outVec = dstSharedLayout.getVec();
   unsigned minVec = std::min(outVec, inVec);
   unsigned perPhase = dstSharedLayout.getPerPhase();
   unsigned maxPhase = dstSharedLayout.getMaxPhase();
   unsigned numElems = getElemsPerThread(srcTy);
+  assert(numElems == srcIndices.size());
   auto inVals = getElementsFromStruct(loc, llSrc, rewriter);
-  auto srcAccumSizeInThreads =
-      product<unsigned>(srcBlockedLayout.getSizePerThread());
   auto wordTy = vec_ty(elemTy, minVec);
   auto elemPtrTy = ptr_ty(elemTy);
-  SmallVector<unsigned> srcShapePerCTA = getShapePerCTA(srcBlockedLayout);
-  SmallVector<unsigned> reps{ceil<unsigned>(srcShape[0], srcShapePerCTA[0]),
-                             ceil<unsigned>(srcShape[1], srcShapePerCTA[1])};
-
-  // Visit each input value in the order they are placed in inVals
-  //
-  // Please note that the order was not awaring of blockLayout.getOrder(),
-  // thus the adjacent elems may not belong to a same word. This could be
-  // improved if we update the elements order by emitIndicesForBlockedLayout()
-  SmallVector<unsigned> wordsInEachRep(2);
-  wordsInEachRep[0] = inOrd[0] == 0
-                          ? srcBlockedLayout.getSizePerThread()[0] / minVec
-                          : srcBlockedLayout.getSizePerThread()[0];
-  wordsInEachRep[1] = inOrd[0] == 0
-                          ? srcBlockedLayout.getSizePerThread()[1]
-                          : srcBlockedLayout.getSizePerThread()[1] / minVec;
   Value outVecVal = i32_val(outVec);
   Value minVecVal = i32_val(minVec);
-  auto numWordsEachRep = product<unsigned>(wordsInEachRep);
-  SmallVector<Value> wordVecs(numWordsEachRep);
+  Value word;
   for (unsigned i = 0; i < numElems; ++i) {
-    if (i % srcAccumSizeInThreads == 0) {
-      // start of a replication
-      for (unsigned w = 0; w < numWordsEachRep; ++w) {
-        wordVecs[w] = undef(wordTy);
-      }
-    }
-    unsigned linearIdxInNanoTile = i % srcAccumSizeInThreads;
-    auto multiDimIdxInNanoTile = getMultiDimIndex<unsigned>(
-        linearIdxInNanoTile, srcBlockedLayout.getSizePerThread(), inOrd);
-    unsigned pos = multiDimIdxInNanoTile[inOrd[0]] % minVec;
-    multiDimIdxInNanoTile[inOrd[0]] /= minVec;
-    auto wordVecIdx =
-        getLinearIndex<unsigned>(multiDimIdxInNanoTile, wordsInEachRep, inOrd);
-    wordVecs[wordVecIdx] =
-        insert_element(wordTy, wordVecs[wordVecIdx], inVals[i], i32_val(pos));
-
-    if (i % srcAccumSizeInThreads == srcAccumSizeInThreads - 1) {
-      // end of replication, store the vectors into shared memory
-      unsigned linearRepIdx = i / srcAccumSizeInThreads;
-      auto multiDimRepIdx =
-          getMultiDimIndex<unsigned>(linearRepIdx, reps, inOrd);
-      for (unsigned linearWordIdx = 0; linearWordIdx < numWordsEachRep;
-           ++linearWordIdx) {
-        // step 1: recover the multidim_index from the index of
-        // input_elements
-        auto multiDimWordIdx =
-            getMultiDimIndex<unsigned>(linearWordIdx, wordsInEachRep, inOrd);
-        SmallVector<Value> multiDimIdx(2);
-        auto wordOffset0 = multiDimRepIdx[0] * srcShapePerCTA[0] +
-                           multiDimWordIdx[0] * (inOrd[0] == 0 ? minVec : 1);
-        auto wordOffset1 = multiDimRepIdx[1] * srcShapePerCTA[1] +
-                           multiDimWordIdx[1] * (inOrd[0] == 1 ? minVec : 1);
-        multiDimIdx[0] = add(srcIndices[0], i32_val(wordOffset0));
-        multiDimIdx[1] = add(srcIndices[1], i32_val(wordOffset1));
-
-        // step 2: do swizzling
-        Value remained = urem(multiDimIdx[outOrd[0]], outVecVal);
-        multiDimIdx[outOrd[0]] = udiv(multiDimIdx[outOrd[0]], outVecVal);
-        Value off_1 = mul(multiDimIdx[outOrd[1]], srcStrides[outOrd[1]]);
-        Value phaseId = udiv(multiDimIdx[outOrd[1]], i32_val(perPhase));
-        phaseId = urem(phaseId, i32_val(maxPhase));
-        Value off_0 = xor_(multiDimIdx[outOrd[0]], phaseId);
-        off_0 = mul(off_0, outVecVal);
-        remained = udiv(remained, minVecVal);
-        off_0 = add(off_0, mul(remained, minVecVal));
-        Value offset = add(off_1, off_0);
-
-        // step 3: store
-        Value smemAddr = gep(elemPtrTy, smemBase, offset);
-        smemAddr = bitcast(smemAddr, ptr_ty(wordTy, 3));
-        store(wordVecs[linearWordIdx], smemAddr);
-      }
-    }
+    if (i % minVec == 0)
+      word = undef(wordTy);
+    word = insert_element(wordTy, word, inVals[i], i32_val(i % minVec));
+    if (i % minVec == minVec - 1) {
+      // step 1: recover the multidim_index from the index of input_elements
+      SmallVector<Value> multiDimIdx = srcIndices[i];
+      SmallVector<Value> dbgVal = srcIndices[i];
+
+      // step 2: do swizzling
+      Value remained = urem(multiDimIdx[outOrd[0]], outVecVal);
+      multiDimIdx[outOrd[0]] = udiv(multiDimIdx[outOrd[0]], outVecVal);
+      Value off_1 = mul(multiDimIdx[outOrd[1]], dstStrides[outOrd[1]]);
+      Value phaseId = udiv(multiDimIdx[outOrd[1]], i32_val(perPhase));
+      phaseId = urem(phaseId, i32_val(maxPhase));
+      Value off_0 = xor_(multiDimIdx[outOrd[0]], phaseId);
+      off_0 = mul(off_0, outVecVal);
+      remained = udiv(remained, minVecVal);
+      off_0 = add(off_0, mul(remained, minVecVal));
+      Value offset = add(off_1, mul(off_0, dstStrides[outOrd[0]]));
+
+      // step 3: store
+      Value smemAddr = gep(elemPtrTy, smemBase, offset);
+      smemAddr = bitcast(smemAddr, ptr_ty(wordTy, 3));
+      store(word, smemAddr);
+    }
   }
 }
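
Note on the hunk above: the old path reconstructed each element's coordinates from blocked-layout-specific bookkeeping (reps, wordsInEachRep, nano-tile indices); the new path receives a ready multi-dim index for every element via srcIndices, so the loop only packs minVec consecutive values into a vector word and swizzles the destination. The following standalone sketch re-derives "step 2: do swizzling" with plain unsigned arithmetic in place of the urem/udiv/xor_ IR builders; the function name, the example parameters, and the assumption that the contiguous-dimension stride is 1 are illustrative, not part of the commit.

    #include <cstdio>

    // Standalone model of the swizzled-offset computation. `row` indexes
    // dim outOrd[1] and `col` the contiguous dim outOrd[0]; rowStride
    // models dstStrides[outOrd[1]], and the contiguous dim's stride is
    // assumed to be 1.
    unsigned swizzledOffset(unsigned row, unsigned col, unsigned rowStride,
                            unsigned outVec, unsigned minVec,
                            unsigned perPhase, unsigned maxPhase) {
      unsigned remained = col % outVec;   // position inside an outVec word
      unsigned vecIdx = col / outVec;     // which outVec word in the row
      unsigned off1 = row * rowStride;    // offset contributed by the row
      unsigned phaseId = (row / perPhase) % maxPhase;
      unsigned off0 = (vecIdx ^ phaseId) * outVec   // swizzled word base
                      + (remained / minVec) * minVec;
      return off1 + off0;
    }

    int main() {
      // Hypothetical 4x8 tile, outVec = minVec = 4, perPhase = 1, maxPhase = 4.
      for (unsigned r = 0; r < 4; ++r) {
        for (unsigned c = 0; c < 8; c += 4)
          std::printf("(%u,%u) -> %2u  ", r, c,
                      swizzledOffset(r, c, 8, 4, 4, 1, 4));
        std::printf("\n");
      }
      return 0;
    }

Running it for this small tile shows consecutive rows XOR-ing their words into different slots, which is the bank-conflict avoidance the shared encoding's vec/perPhase/maxPhase parameters encode.
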
@@ -145,20 +102,15 @@ public:
     auto dstTy = dst.getType().cast<RankedTensorType>();
     Attribute srcLayout = srcTy.getEncoding();
     Attribute dstLayout = dstTy.getEncoding();
-    if (srcLayout.isa<BlockedEncodingAttr>() &&
+    if (isaDistributedLayout(srcLayout) &&
         dstLayout.isa<SharedEncodingAttr>()) {
-      return lowerBlockedToShared(op, adaptor, rewriter);
+      return lowerDistributedToShared(op, adaptor, rewriter);
     }
     if (srcLayout.isa<SharedEncodingAttr>() &&
         dstLayout.isa<DotOperandEncodingAttr>()) {
       return lowerSharedToDotOperand(op, adaptor, rewriter);
     }
-    if ((srcLayout.isa<BlockedEncodingAttr>() ||
-         srcLayout.isa<MmaEncodingAttr>() ||
-         srcLayout.isa<SliceEncodingAttr>()) &&
-        (dstLayout.isa<BlockedEncodingAttr>() ||
-         dstLayout.isa<MmaEncodingAttr>() ||
-         dstLayout.isa<SliceEncodingAttr>())) {
+    if (isaDistributedLayout(srcLayout) && isaDistributedLayout(dstLayout)) {
       return lowerDistributedToDistributed(op, adaptor, rewriter);
     }
     if (srcLayout.isa<MmaEncodingAttr>() &&
@@ -476,7 +428,7 @@ private:
   // Swizzling in shared memory to avoid bank conflict. Normally used for
   // A/B operands of dots.
   LogicalResult
-  lowerBlockedToShared(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
-                       ConversionPatternRewriter &rewriter) const {
+  lowerDistributedToShared(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
+                           ConversionPatternRewriter &rewriter) const {
     auto loc = op.getLoc();
     Value src = op.src();
@@ -487,22 +439,20 @@ private:
     auto dstShape = dstTy.getShape();
     assert(srcShape.size() == 2 &&
            "Unexpected rank of ConvertLayout(blocked->shared)");
-    auto srcBlockedLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
+    auto srcLayout = srcTy.getEncoding();
     auto dstSharedLayout = dstTy.getEncoding().cast<SharedEncodingAttr>();
-    auto inOrd = srcBlockedLayout.getOrder();
+    auto inOrd = getOrder(srcLayout);
     auto outOrd = dstSharedLayout.getOrder();
     Value smemBase = getSharedMemoryBase(loc, rewriter, dst);
     auto elemTy = getTypeConverter()->convertType(srcTy.getElementType());
     auto elemPtrTy = ptr_ty(getTypeConverter()->convertType(elemTy), 3);
     smemBase = bitcast(smemBase, elemPtrTy);
-    auto srcStrides =
-        getStridesFromShapeAndOrder(srcShape, inOrd, loc, rewriter);
-    auto srcIndices =
-        emitBaseIndexForLayout(loc, rewriter, srcBlockedLayout, srcShape);
-    storeBlockedToShared(src, adaptor.src(), srcStrides, srcIndices, dst,
-                         smemBase, elemTy, loc, rewriter);
+    auto dstStrides =
+        getStridesFromShapeAndOrder(dstShape, outOrd, loc, rewriter);
+    auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcShape);
+    storeDistributedToShared(src, adaptor.src(), dstStrides, srcIndices, dst,
+                             smemBase, elemTy, loc, rewriter);
     auto smemObj =
         SharedMemoryObject(smemBase, dstShape, outOrd, loc, rewriter);
     auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
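
An easy-to-miss fix in this hunk: the strides passed to the store now describe the destination shared-memory tile (dstShape with outOrd) rather than the source tensor, which is part of what lets the old inOrd != outOrd llvm_unreachable guard go away. Below is a hedged sketch of what getStridesFromShapeAndOrder is assumed to compute, with plain unsigneds instead of emitted IR Values: order[0] names the fastest-varying, stride-1 dimension.

    #include <cstdio>
    #include <vector>

    // Assumed semantics: walk the dims from fastest-varying (order[0])
    // to slowest, accumulating the stride as the product of shapes.
    std::vector<unsigned>
    stridesFromShapeAndOrder(const std::vector<unsigned> &shape,
                             const std::vector<unsigned> &order) {
      std::vector<unsigned> strides(shape.size());
      unsigned stride = 1;
      for (unsigned d : order) {
        strides[d] = stride;
        stride *= shape[d];
      }
      return strides;
    }

    int main() {
      // A 16x32 tile with order {1, 0} (dim 1 contiguous) -> strides [32, 1].
      auto s = stridesFromShapeAndOrder({16, 32}, {1, 0});
      std::printf("strides = [%u, %u]\n", s[0], s[1]);
      return 0;
    }
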


@@ -11,9 +11,11 @@ using ::mlir::triton::gpu::DotOperandEncodingAttr;
 bool isMmaToDotShortcut(MmaEncodingAttr &mmaLayout,
                         DotOperandEncodingAttr &dotOperandLayout);
 
-void storeBlockedToShared(Value src, Value llSrc, ArrayRef<Value> srcStrides,
-                          ArrayRef<Value> srcIndices, Value dst, Value smemBase,
-                          Type elemPtrTy, Location loc,
-                          ConversionPatternRewriter &rewriter);
+void storeDistributedToShared(Value src, Value llSrc,
+                              ArrayRef<Value> srcStrides,
+                              ArrayRef<SmallVector<Value>> srcIndices,
+                              Value dst, Value smemBase, Type elemPtrTy,
+                              Location loc,
+                              ConversionPatternRewriter &rewriter);
 
 void populateConvertLayoutOpToLLVMPatterns(


@@ -639,9 +639,8 @@ struct InsertSliceOpConversion
     auto smemBase = gep(elemPtrTy, smemObj.base, offset);
     auto llSrc = adaptor.source();
-    auto srcIndices =
-        emitBaseIndexForLayout(loc, rewriter, srcLayout, srcShape);
-    storeBlockedToShared(src, llSrc, srcStrides, srcIndices, dst, smemBase,
-                         elemTy, loc, rewriter);
+    auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcShape);
+    storeDistributedToShared(src, llSrc, srcStrides, srcIndices, dst, smemBase,
+                             elemTy, loc, rewriter);
     // Barrier is not necessary.
     // The membar pass knows that it writes to shared memory and will handle it


@@ -254,6 +254,11 @@ SmallVector<unsigned> getOrder(const Attribute &layout) {
   }
 };
 
+bool isaDistributedLayout(const Attribute &layout) {
+  return layout.isa<BlockedEncodingAttr>() || layout.isa<MmaEncodingAttr>() ||
+         layout.isa<SliceEncodingAttr>();
+}
+
 } // namespace gpu
 } // namespace triton
 } // namespace mlir
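
For readers new to the encoding taxonomy: a minimal, self-contained model of what this predicate groups together. The enum and helper below are illustrative only (the real check is over MLIR attribute types, not an enum); per this commit, "distributed" covers exactly the blocked, MMA, and slice encodings, as opposed to the shared and dot-operand ones.

    #include <cassert>

    // Illustrative model only; not part of the Triton codebase.
    enum class Encoding { Blocked, Mma, Slice, Shared, DotOperand };

    bool isaDistributedLayoutModel(Encoding e) {
      // Distributed = elements are assigned to threads' registers
      // (blocked/MMA/slice), as opposed to shared-memory or
      // dot-operand staging encodings.
      return e == Encoding::Blocked || e == Encoding::Mma ||
             e == Encoding::Slice;
    }

    int main() {
      assert(isaDistributedLayoutModel(Encoding::Mma));
      assert(isaDistributedLayoutModel(Encoding::Slice));
      assert(!isaDistributedLayoutModel(Encoding::Shared));
      assert(!isaDistributedLayoutModel(Encoding::DotOperand));
      return 0;
    }
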