@@ -9,10 +9,12 @@ using ::mlir::LLVM::getStridesFromShapeAndOrder;
 using ::mlir::LLVM::getStructFromElements;
 using ::mlir::LLVM::MMA16816ConversionHelper;
 using ::mlir::triton::gpu::DotOperandEncodingAttr;
+using ::mlir::triton::gpu::getContigPerThread;
 using ::mlir::triton::gpu::getElemsPerThread;
 using ::mlir::triton::gpu::getOrder;
 using ::mlir::triton::gpu::getShapePerCTA;
 using ::mlir::triton::gpu::getSizePerThread;
+using ::mlir::triton::gpu::isaDistributedLayout;
 using ::mlir::triton::gpu::SharedEncodingAttr;

 bool isMmaToDotShortcut(MmaEncodingAttr &mmaLayout,
@@ -24,108 +26,63 @@ bool isMmaToDotShortcut(MmaEncodingAttr &mmaLayout,
          dotOperandLayout.getParent() == mmaLayout;
 }

-void storeBlockedToShared(Value src, Value llSrc, ArrayRef<Value> srcStrides,
-                          ArrayRef<Value> srcIndices, Value dst, Value smemBase,
-                          Type elemTy, Location loc,
-                          ConversionPatternRewriter &rewriter) {
+void storeDistributedToShared(Value src, Value llSrc,
+                              ArrayRef<Value> dstStrides,
+                              ArrayRef<SmallVector<Value>> srcIndices,
+                              Value dst, Value smemBase, Type elemTy,
+                              Location loc,
+                              ConversionPatternRewriter &rewriter) {
   auto srcTy = src.getType().cast<RankedTensorType>();
   auto srcShape = srcTy.getShape();
-  assert(srcShape.size() == 2 && "Unexpected rank of insertSlice");
-
+  assert(srcShape.size() == 2 && "Unexpected rank of storeDistributedToShared");
   auto dstTy = dst.getType().cast<RankedTensorType>();
-  auto srcBlockedLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
+  auto srcDistributedLayout = srcTy.getEncoding();
+  if (auto mmaLayout = srcDistributedLayout.dyn_cast<MmaEncodingAttr>()) {
+    assert((!mmaLayout.isVolta()) &&
"ConvertLayout MMAv1->Shared is not suppported yet");
+  }
   auto dstSharedLayout = dstTy.getEncoding().cast<SharedEncodingAttr>();
-  auto inOrd = srcBlockedLayout.getOrder();
+  auto inOrd = getOrder(srcDistributedLayout);
   auto outOrd = dstSharedLayout.getOrder();
-  if (inOrd != outOrd)
-    llvm_unreachable(
-        "blocked -> shared with different order not yet implemented");
   unsigned inVec =
-      inOrd == outOrd ? srcBlockedLayout.getSizePerThread()[inOrd[0]] : 1;
+      inOrd == outOrd ? getContigPerThread(srcDistributedLayout)[inOrd[0]] : 1;
   unsigned outVec = dstSharedLayout.getVec();
   unsigned minVec = std::min(outVec, inVec);
   unsigned perPhase = dstSharedLayout.getPerPhase();
   unsigned maxPhase = dstSharedLayout.getMaxPhase();
   unsigned numElems = getElemsPerThread(srcTy);
+  assert(numElems == srcIndices.size());
   auto inVals = getElementsFromStruct(loc, llSrc, rewriter);
-  auto srcAccumSizeInThreads =
-      product<unsigned>(srcBlockedLayout.getSizePerThread());
   auto wordTy = vec_ty(elemTy, minVec);
   auto elemPtrTy = ptr_ty(elemTy);
-
-  SmallVector<unsigned> srcShapePerCTA = getShapePerCTA(srcBlockedLayout);
-  SmallVector<unsigned> reps{ceil<unsigned>(srcShape[0], srcShapePerCTA[0]),
-                             ceil<unsigned>(srcShape[1], srcShapePerCTA[1])};
-
-  // Visit each input value in the order they are placed in inVals
-  //
-  // Please note that the order was not awaring of blockLayout.getOrder(),
-  // thus the adjacent elems may not belong to a same word. This could be
-  // improved if we update the elements order by emitIndicesForBlockedLayout()
-  SmallVector<unsigned> wordsInEachRep(2);
-  wordsInEachRep[0] = inOrd[0] == 0
-                          ? srcBlockedLayout.getSizePerThread()[0] / minVec
-                          : srcBlockedLayout.getSizePerThread()[0];
-  wordsInEachRep[1] = inOrd[0] == 0
-                          ? srcBlockedLayout.getSizePerThread()[1]
-                          : srcBlockedLayout.getSizePerThread()[1] / minVec;
   Value outVecVal = i32_val(outVec);
   Value minVecVal = i32_val(minVec);
-  auto numWordsEachRep = product<unsigned>(wordsInEachRep);
-  SmallVector<Value> wordVecs(numWordsEachRep);
+  Value word;
   for (unsigned i = 0; i < numElems; ++i) {
-    if (i % srcAccumSizeInThreads == 0) {
-      // start of a replication
-      for (unsigned w = 0; w < numWordsEachRep; ++w) {
-        wordVecs[w] = undef(wordTy);
-      }
-    }
-    unsigned linearIdxInNanoTile = i % srcAccumSizeInThreads;
-    auto multiDimIdxInNanoTile = getMultiDimIndex<unsigned>(
-        linearIdxInNanoTile, srcBlockedLayout.getSizePerThread(), inOrd);
-    unsigned pos = multiDimIdxInNanoTile[inOrd[0]] % minVec;
-    multiDimIdxInNanoTile[inOrd[0]] /= minVec;
-    auto wordVecIdx =
-        getLinearIndex<unsigned>(multiDimIdxInNanoTile, wordsInEachRep, inOrd);
-    wordVecs[wordVecIdx] =
-        insert_element(wordTy, wordVecs[wordVecIdx], inVals[i], i32_val(pos));
+    if (i % minVec == 0)
+      word = undef(wordTy);
+    word = insert_element(wordTy, word, inVals[i], i32_val(i % minVec));
+    if (i % minVec == minVec - 1) {
+      // step 1: recover the multidim_index from the index of input_elements
+      SmallVector<Value> multiDimIdx = srcIndices[i];
+      SmallVector<Value> dbgVal = srcIndices[i];

-    if (i % srcAccumSizeInThreads == srcAccumSizeInThreads - 1) {
-      // end of replication, store the vectors into shared memory
-      unsigned linearRepIdx = i / srcAccumSizeInThreads;
-      auto multiDimRepIdx =
-          getMultiDimIndex<unsigned>(linearRepIdx, reps, inOrd);
-      for (unsigned linearWordIdx = 0; linearWordIdx < numWordsEachRep;
-           ++linearWordIdx) {
-        // step 1: recover the multidim_index from the index of
-        // input_elements
-        auto multiDimWordIdx =
-            getMultiDimIndex<unsigned>(linearWordIdx, wordsInEachRep, inOrd);
-        SmallVector<Value> multiDimIdx(2);
-        auto wordOffset0 = multiDimRepIdx[0] * srcShapePerCTA[0] +
-                           multiDimWordIdx[0] * (inOrd[0] == 0 ? minVec : 1);
-        auto wordOffset1 = multiDimRepIdx[1] * srcShapePerCTA[1] +
-                           multiDimWordIdx[1] * (inOrd[0] == 1 ? minVec : 1);
-        multiDimIdx[0] = add(srcIndices[0], i32_val(wordOffset0));
-        multiDimIdx[1] = add(srcIndices[1], i32_val(wordOffset1));
+      // step 2: do swizzling
+      Value remained = urem(multiDimIdx[outOrd[0]], outVecVal);
+      multiDimIdx[outOrd[0]] = udiv(multiDimIdx[outOrd[0]], outVecVal);
+      Value off_1 = mul(multiDimIdx[outOrd[1]], dstStrides[outOrd[1]]);
+      Value phaseId = udiv(multiDimIdx[outOrd[1]], i32_val(perPhase));
+      phaseId = urem(phaseId, i32_val(maxPhase));
+      Value off_0 = xor_(multiDimIdx[outOrd[0]], phaseId);
+      off_0 = mul(off_0, outVecVal);
+      remained = udiv(remained, minVecVal);
+      off_0 = add(off_0, mul(remained, minVecVal));
+      Value offset = add(off_1, mul(off_0, dstStrides[outOrd[0]]));

-        // step 2: do swizzling
-        Value remained = urem(multiDimIdx[outOrd[0]], outVecVal);
-        multiDimIdx[outOrd[0]] = udiv(multiDimIdx[outOrd[0]], outVecVal);
-        Value off_1 = mul(multiDimIdx[outOrd[1]], srcStrides[outOrd[1]]);
-        Value phaseId = udiv(multiDimIdx[outOrd[1]], i32_val(perPhase));
-        phaseId = urem(phaseId, i32_val(maxPhase));
-        Value off_0 = xor_(multiDimIdx[outOrd[0]], phaseId);
-        off_0 = mul(off_0, outVecVal);
-        remained = udiv(remained, minVecVal);
-        off_0 = add(off_0, mul(remained, minVecVal));
-        Value offset = add(off_1, off_0);
-
-        // step 3: store
-        Value smemAddr = gep(elemPtrTy, smemBase, offset);
-        smemAddr = bitcast(smemAddr, ptr_ty(wordTy, 3));
-        store(wordVecs[linearWordIdx], smemAddr);
-      }
+      // step 3: store
+      Value smemAddr = gep(elemPtrTy, smemBase, offset);
+      smemAddr = bitcast(smemAddr, ptr_ty(wordTy, 3));
+      store(word, smemAddr);
     }
   }
 }
@@ -145,20 +102,15 @@ public:
     auto dstTy = dst.getType().cast<RankedTensorType>();
     Attribute srcLayout = srcTy.getEncoding();
     Attribute dstLayout = dstTy.getEncoding();
-    if (srcLayout.isa<BlockedEncodingAttr>() &&
+    if (isaDistributedLayout(srcLayout) &&
         dstLayout.isa<SharedEncodingAttr>()) {
-      return lowerBlockedToShared(op, adaptor, rewriter);
+      return lowerDistributedToShared(op, adaptor, rewriter);
     }
     if (srcLayout.isa<SharedEncodingAttr>() &&
         dstLayout.isa<DotOperandEncodingAttr>()) {
       return lowerSharedToDotOperand(op, adaptor, rewriter);
     }
-    if ((srcLayout.isa<BlockedEncodingAttr>() ||
-         srcLayout.isa<MmaEncodingAttr>() ||
-         srcLayout.isa<SliceEncodingAttr>()) &&
-        (dstLayout.isa<BlockedEncodingAttr>() ||
-         dstLayout.isa<MmaEncodingAttr>() ||
-         dstLayout.isa<SliceEncodingAttr>())) {
+    if (isaDistributedLayout(srcLayout) && isaDistributedLayout(dstLayout)) {
       return lowerDistributedToDistributed(op, adaptor, rewriter);
     }
     if (srcLayout.isa<MmaEncodingAttr>() &&
@@ -476,8 +428,8 @@ private:
   // Swizzling in shared memory to avoid bank conflict. Normally used for
   // A/B operands of dots.
   LogicalResult
-  lowerBlockedToShared(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
-                       ConversionPatternRewriter &rewriter) const {
+  lowerDistributedToShared(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
+                           ConversionPatternRewriter &rewriter) const {
     auto loc = op.getLoc();
     Value src = op.src();
     Value dst = op.result();
@@ -487,22 +439,20 @@ private:
     auto dstShape = dstTy.getShape();
     assert(srcShape.size() == 2 &&
            "Unexpected rank of ConvertLayout(blocked->shared)");
-    auto srcBlockedLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
+    auto srcLayout = srcTy.getEncoding();
     auto dstSharedLayout = dstTy.getEncoding().cast<SharedEncodingAttr>();
-    auto inOrd = srcBlockedLayout.getOrder();
+    auto inOrd = getOrder(srcLayout);
     auto outOrd = dstSharedLayout.getOrder();
     Value smemBase = getSharedMemoryBase(loc, rewriter, dst);
     auto elemTy = getTypeConverter()->convertType(srcTy.getElementType());
     auto elemPtrTy = ptr_ty(getTypeConverter()->convertType(elemTy), 3);
     smemBase = bitcast(smemBase, elemPtrTy);

-    auto srcStrides =
-        getStridesFromShapeAndOrder(srcShape, inOrd, loc, rewriter);
-    auto srcIndices =
-        emitBaseIndexForLayout(loc, rewriter, srcBlockedLayout, srcShape);
-    storeBlockedToShared(src, adaptor.src(), srcStrides, srcIndices, dst,
-                         smemBase, elemTy, loc, rewriter);
-
+    auto dstStrides =
+        getStridesFromShapeAndOrder(dstShape, outOrd, loc, rewriter);
+    auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcShape);
+    storeDistributedToShared(src, adaptor.src(), dstStrides, srcIndices, dst,
+                             smemBase, elemTy, loc, rewriter);
     auto smemObj =
         SharedMemoryObject(smemBase, dstShape, outOrd, loc, rewriter);
     auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);