diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h
index 8c24a5777..3ccd2da32 100644
--- a/include/triton/Dialect/TritonGPU/IR/Dialect.h
+++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h
@@ -39,6 +39,8 @@ SmallVector<unsigned> getShapePerCTA(const Attribute &layout);
 
 SmallVector<unsigned> getOrder(const Attribute &layout);
 
+bool isaDistributedLayout(const Attribute &layout);
+
 } // namespace gpu
 } // namespace triton
 } // namespace mlir
diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
index cd18ed751..5f4c803e8 100644
--- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -9,10 +9,12 @@ using ::mlir::LLVM::getStridesFromShapeAndOrder;
 using ::mlir::LLVM::getStructFromElements;
 using ::mlir::LLVM::MMA16816ConversionHelper;
 using ::mlir::triton::gpu::DotOperandEncodingAttr;
+using ::mlir::triton::gpu::getContigPerThread;
 using ::mlir::triton::gpu::getElemsPerThread;
 using ::mlir::triton::gpu::getOrder;
 using ::mlir::triton::gpu::getShapePerCTA;
 using ::mlir::triton::gpu::getSizePerThread;
+using ::mlir::triton::gpu::isaDistributedLayout;
 using ::mlir::triton::gpu::SharedEncodingAttr;
 
 bool isMmaToDotShortcut(MmaEncodingAttr &mmaLayout,
@@ -24,108 +26,63 @@ bool isMmaToDotShortcut(MmaEncodingAttr &mmaLayout,
          dotOperandLayout.getParent() == mmaLayout;
 }
 
-void storeBlockedToShared(Value src, Value llSrc, ArrayRef<Value> srcStrides,
-                          ArrayRef<Value> srcIndices, Value dst, Value smemBase,
-                          Type elemTy, Location loc,
-                          ConversionPatternRewriter &rewriter) {
+void storeDistributedToShared(Value src, Value llSrc,
+                              ArrayRef<Value> dstStrides,
+                              ArrayRef<SmallVector<Value>> srcIndices,
+                              Value dst, Value smemBase, Type elemTy,
+                              Location loc,
+                              ConversionPatternRewriter &rewriter) {
   auto srcTy = src.getType().cast<RankedTensorType>();
   auto srcShape = srcTy.getShape();
-  assert(srcShape.size() == 2 && "Unexpected rank of insertSlice");
-
+  assert(srcShape.size() == 2 &&
+         "Unexpected rank of storeDistributedToShared");
   auto dstTy = dst.getType().cast<RankedTensorType>();
-  auto srcBlockedLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
+  auto srcDistributedLayout = srcTy.getEncoding();
+  if (auto mmaLayout = srcDistributedLayout.dyn_cast<MmaEncodingAttr>()) {
+    assert((!mmaLayout.isVolta()) &&
+           "ConvertLayout MMAv1->Shared is not supported yet");
+  }
   auto dstSharedLayout = dstTy.getEncoding().cast<SharedEncodingAttr>();
-  auto inOrd = srcBlockedLayout.getOrder();
+  auto inOrd = getOrder(srcDistributedLayout);
   auto outOrd = dstSharedLayout.getOrder();
-  if (inOrd != outOrd)
-    llvm_unreachable(
-        "blocked -> shared with different order not yet implemented");
   unsigned inVec =
-      inOrd == outOrd ? srcBlockedLayout.getSizePerThread()[inOrd[0]] : 1;
+      inOrd == outOrd ? getContigPerThread(srcDistributedLayout)[inOrd[0]] : 1;
   unsigned outVec = dstSharedLayout.getVec();
   unsigned minVec = std::min(outVec, inVec);
   unsigned perPhase = dstSharedLayout.getPerPhase();
   unsigned maxPhase = dstSharedLayout.getMaxPhase();
   unsigned numElems = getElemsPerThread(srcTy);
+  assert(numElems == srcIndices.size());
   auto inVals = getElementsFromStruct(loc, llSrc, rewriter);
-  auto srcAccumSizeInThreads =
-      product<unsigned>(srcBlockedLayout.getSizePerThread());
   auto wordTy = vec_ty(elemTy, minVec);
   auto elemPtrTy = ptr_ty(elemTy);
-
-  SmallVector<unsigned> srcShapePerCTA = getShapePerCTA(srcBlockedLayout);
-  SmallVector<unsigned> reps{ceil<unsigned>(srcShape[0], srcShapePerCTA[0]),
-                             ceil<unsigned>(srcShape[1], srcShapePerCTA[1])};
-
-  // Visit each input value in the order they are placed in inVals
-  //
-  // Please note that the order was not awaring of blockLayout.getOrder(),
-  // thus the adjacent elems may not belong to a same word. This could be
-  // improved if we update the elements order by emitIndicesForBlockedLayout()
-  SmallVector<unsigned> wordsInEachRep(2);
-  wordsInEachRep[0] = inOrd[0] == 0
-                          ? srcBlockedLayout.getSizePerThread()[0] / minVec
-                          : srcBlockedLayout.getSizePerThread()[0];
-  wordsInEachRep[1] = inOrd[0] == 0
-                          ? srcBlockedLayout.getSizePerThread()[1]
-                          : srcBlockedLayout.getSizePerThread()[1] / minVec;
   Value outVecVal = i32_val(outVec);
   Value minVecVal = i32_val(minVec);
-  auto numWordsEachRep = product<unsigned>(wordsInEachRep);
-  SmallVector<Value> wordVecs(numWordsEachRep);
+  Value word;
   for (unsigned i = 0; i < numElems; ++i) {
-    if (i % srcAccumSizeInThreads == 0) {
-      // start of a replication
-      for (unsigned w = 0; w < numWordsEachRep; ++w) {
-        wordVecs[w] = undef(wordTy);
-      }
-    }
-    unsigned linearIdxInNanoTile = i % srcAccumSizeInThreads;
-    auto multiDimIdxInNanoTile = getMultiDimIndex<unsigned>(
-        linearIdxInNanoTile, srcBlockedLayout.getSizePerThread(), inOrd);
-    unsigned pos = multiDimIdxInNanoTile[inOrd[0]] % minVec;
-    multiDimIdxInNanoTile[inOrd[0]] /= minVec;
-    auto wordVecIdx =
-        getLinearIndex<unsigned>(multiDimIdxInNanoTile, wordsInEachRep, inOrd);
-    wordVecs[wordVecIdx] =
-        insert_element(wordTy, wordVecs[wordVecIdx], inVals[i], i32_val(pos));
+    if (i % minVec == 0)
+      word = undef(wordTy);
+    word = insert_element(wordTy, word, inVals[i], i32_val(i % minVec));
+    if (i % minVec == minVec - 1) {
+      // step 1: recover the multidim_index from the index of input_elements
+      SmallVector<Value> multiDimIdx = srcIndices[i];
+      SmallVector<Value> dbgVal = srcIndices[i];
 
-    if (i % srcAccumSizeInThreads == srcAccumSizeInThreads - 1) {
-      // end of replication, store the vectors into shared memory
-      unsigned linearRepIdx = i / srcAccumSizeInThreads;
-      auto multiDimRepIdx =
-          getMultiDimIndex<unsigned>(linearRepIdx, reps, inOrd);
-      for (unsigned linearWordIdx = 0; linearWordIdx < numWordsEachRep;
-           ++linearWordIdx) {
-        // step 1: recover the multidim_index from the index of
-        // input_elements
-        auto multiDimWordIdx =
-            getMultiDimIndex<unsigned>(linearWordIdx, wordsInEachRep, inOrd);
-        SmallVector<Value> multiDimIdx(2);
-        auto wordOffset0 = multiDimRepIdx[0] * srcShapePerCTA[0] +
-                           multiDimWordIdx[0] * (inOrd[0] == 0 ? minVec : 1);
-        auto wordOffset1 = multiDimRepIdx[1] * srcShapePerCTA[1] +
-                           multiDimWordIdx[1] * (inOrd[0] == 1 ? minVec : 1);
-        multiDimIdx[0] = add(srcIndices[0], i32_val(wordOffset0));
-        multiDimIdx[1] = add(srcIndices[1], i32_val(wordOffset1));
+      // step 2: do swizzling
+      Value remained = urem(multiDimIdx[outOrd[0]], outVecVal);
+      multiDimIdx[outOrd[0]] = udiv(multiDimIdx[outOrd[0]], outVecVal);
+      Value off_1 = mul(multiDimIdx[outOrd[1]], dstStrides[outOrd[1]]);
+      Value phaseId = udiv(multiDimIdx[outOrd[1]], i32_val(perPhase));
+      phaseId = urem(phaseId, i32_val(maxPhase));
+      Value off_0 = xor_(multiDimIdx[outOrd[0]], phaseId);
+      off_0 = mul(off_0, outVecVal);
+      remained = udiv(remained, minVecVal);
+      off_0 = add(off_0, mul(remained, minVecVal));
+      Value offset = add(off_1, mul(off_0, dstStrides[outOrd[0]]));
 
-        // step 2: do swizzling
-        Value remained = urem(multiDimIdx[outOrd[0]], outVecVal);
-        multiDimIdx[outOrd[0]] = udiv(multiDimIdx[outOrd[0]], outVecVal);
-        Value off_1 = mul(multiDimIdx[outOrd[1]], srcStrides[outOrd[1]]);
-        Value phaseId = udiv(multiDimIdx[outOrd[1]], i32_val(perPhase));
-        phaseId = urem(phaseId, i32_val(maxPhase));
-        Value off_0 = xor_(multiDimIdx[outOrd[0]], phaseId);
-        off_0 = mul(off_0, outVecVal);
-        remained = udiv(remained, minVecVal);
-        off_0 = add(off_0, mul(remained, minVecVal));
-        Value offset = add(off_1, off_0);
-
-        // step 3: store
-        Value smemAddr = gep(elemPtrTy, smemBase, offset);
-        smemAddr = bitcast(smemAddr, ptr_ty(wordTy, 3));
-        store(wordVecs[linearWordIdx], smemAddr);
-      }
+      // step 3: store
+      Value smemAddr = gep(elemPtrTy, smemBase, offset);
+      smemAddr = bitcast(smemAddr, ptr_ty(wordTy, 3));
+      store(word, smemAddr);
     }
   }
 }
@@ -145,20 +102,15 @@ public:
     auto dstTy = dst.getType().cast<RankedTensorType>();
     Attribute srcLayout = srcTy.getEncoding();
    Attribute dstLayout = dstTy.getEncoding();
-    if (srcLayout.isa<BlockedEncodingAttr>() &&
+    if (isaDistributedLayout(srcLayout) &&
         dstLayout.isa<SharedEncodingAttr>()) {
-      return lowerBlockedToShared(op, adaptor, rewriter);
+      return lowerDistributedToShared(op, adaptor, rewriter);
     }
     if (srcLayout.isa<SharedEncodingAttr>() &&
         dstLayout.isa<DotOperandEncodingAttr>()) {
       return lowerSharedToDotOperand(op, adaptor, rewriter);
     }
-    if ((srcLayout.isa<BlockedEncodingAttr>() ||
-         srcLayout.isa<MmaEncodingAttr>() ||
-         srcLayout.isa<SliceEncodingAttr>()) &&
-        (dstLayout.isa<BlockedEncodingAttr>() ||
-         dstLayout.isa<MmaEncodingAttr>() ||
-         dstLayout.isa<SliceEncodingAttr>())) {
+    if (isaDistributedLayout(srcLayout) && isaDistributedLayout(dstLayout)) {
      return lowerDistributedToDistributed(op, adaptor, rewriter);
     }
     if (srcLayout.isa<MmaEncodingAttr>() &&
@@ -476,8 +428,8 @@ private:
   // Swizzling in shared memory to avoid bank conflict. Normally used for
   // A/B operands of dots.
   LogicalResult
-  lowerBlockedToShared(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
-                       ConversionPatternRewriter &rewriter) const {
+  lowerDistributedToShared(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
+                           ConversionPatternRewriter &rewriter) const {
     auto loc = op.getLoc();
     Value src = op.src();
     Value dst = op.result();
@@ -487,22 +439,20 @@ private:
     auto dstShape = dstTy.getShape();
     assert(srcShape.size() == 2 &&
           "Unexpected rank of ConvertLayout(blocked->shared)");
-    auto srcBlockedLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
+    auto srcLayout = srcTy.getEncoding();
     auto dstSharedLayout = dstTy.getEncoding().cast<SharedEncodingAttr>();
-    auto inOrd = srcBlockedLayout.getOrder();
+    auto inOrd = getOrder(srcLayout);
     auto outOrd = dstSharedLayout.getOrder();
     Value smemBase = getSharedMemoryBase(loc, rewriter, dst);
     auto elemTy = getTypeConverter()->convertType(srcTy.getElementType());
     auto elemPtrTy = ptr_ty(getTypeConverter()->convertType(elemTy), 3);
     smemBase = bitcast(smemBase, elemPtrTy);
-    auto srcStrides =
-        getStridesFromShapeAndOrder(srcShape, inOrd, loc, rewriter);
-    auto srcIndices =
-        emitBaseIndexForLayout(loc, rewriter, srcBlockedLayout, srcShape);
-    storeBlockedToShared(src, adaptor.src(), srcStrides, srcIndices, dst,
-                         smemBase, elemTy, loc, rewriter);
-
+    auto dstStrides =
+        getStridesFromShapeAndOrder(dstShape, outOrd, loc, rewriter);
+    auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcShape);
+    storeDistributedToShared(src, adaptor.src(), dstStrides, srcIndices, dst,
+                             smemBase, elemTy, loc, rewriter);
     auto smemObj =
         SharedMemoryObject(smemBase, dstShape, outOrd, loc, rewriter);
     auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.h b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.h
index ec435b2ab..d5b866a44 100644
--- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.h
+++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.h
@@ -11,10 +11,12 @@ using ::mlir::triton::gpu::DotOperandEncodingAttr;
 bool isMmaToDotShortcut(MmaEncodingAttr &mmaLayout,
                         DotOperandEncodingAttr &dotOperandLayout);
 
-void storeBlockedToShared(Value src, Value llSrc, ArrayRef<Value> srcStrides,
-                          ArrayRef<Value> srcIndices, Value dst, Value smemBase,
-                          Type elemPtrTy, Location loc,
-                          ConversionPatternRewriter &rewriter);
+void storeDistributedToShared(Value src, Value llSrc,
+                              ArrayRef<Value> srcStrides,
+                              ArrayRef<SmallVector<Value>> srcIndices,
+                              Value dst, Value smemBase, Type elemPtrTy,
+                              Location loc,
+                              ConversionPatternRewriter &rewriter);
 
 void populateConvertLayoutOpToLLVMPatterns(
     mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
diff --git a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp
index 92b11a94c..a06910277 100644
--- a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp
@@ -639,10 +639,9 @@ struct InsertSliceOpConversion
     auto smemBase = gep(elemPtrTy, smemObj.base, offset);
 
     auto llSrc = adaptor.source();
-    auto srcIndices =
-        emitBaseIndexForLayout(loc, rewriter, srcLayout, srcShape);
-    storeBlockedToShared(src, llSrc, srcStrides, srcIndices, dst, smemBase,
-                         elemTy, loc, rewriter);
+    auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcShape);
+    storeDistributedToShared(src, llSrc, srcStrides, srcIndices, dst, smemBase,
+                             elemTy, loc, rewriter);
     // Barrier is not necessary.
     // The membar pass knows that it writes to shared memory and will handle it
     // properly.
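
Note (explanatory sketch, not part of the patch): both call sites above now
pass per-element indices -- emitIndices yields one SmallVector<Value> per
element where emitBaseIndexForLayout yielded a single per-thread base index,
which is what the new ArrayRef<SmallVector<Value>> parameter reflects. Freed
from reconstructing indices itself, storeDistributedToShared no longer reasons
about replication tiles: it packs every minVec consecutive per-thread values
into one vector word and issues one swizzled store per word. A minimal scalar
model of that control flow, with std::vector<float> standing in for LLVM IR
values and the addressing elided:

#include <cassert>
#include <vector>

// Pack each run of minVec consecutive per-thread values into one "word",
// mirroring the undef/insert_element/store sequence in the new loop. The
// container types are stand-ins; the real code builds LLVM vector values.
std::vector<std::vector<float>> packWords(const std::vector<float> &inVals,
                                          unsigned minVec) {
  assert(minVec > 0 && inVals.size() % minVec == 0);
  std::vector<std::vector<float>> words;
  std::vector<float> word;
  for (unsigned i = 0; i < inVals.size(); ++i) {
    if (i % minVec == 0)
      word.clear();              // word = undef(wordTy)
    word.push_back(inVals[i]);   // insert_element(wordTy, word, inVals[i], ...)
    if (i % minVec == minVec - 1)
      words.push_back(word);     // step 3: store(word, smemAddr)
  }
  return words;
}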
diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index d671f377d..a4573ca82 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -254,6 +254,11 @@ SmallVector<unsigned> getOrder(const Attribute &layout) {
   }
 };
 
+bool isaDistributedLayout(const Attribute &layout) {
+  return layout.isa<BlockedEncodingAttr>() || layout.isa<MmaEncodingAttr>() ||
+         layout.isa<SliceEncodingAttr>();
+}
+
 } // namespace gpu
 } // namespace triton
 } // namespace mlir
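
Note (explanatory sketch, not part of the patch): the "step 2: do swizzling"
arithmetic can be sanity-checked on plain integers. The helper below is
hypothetical; it mirrors the urem/udiv/xor_/mul/add sequence the patch emits
as LLVM IR, and the #shared parameters in main (vec = 8, perPhase = 1,
maxPhase = 8 for a row-major 16x64 fp16 tile) are made-up examples, not values
taken from this change.

#include <cstdio>

// idx0/idx1 are the element's coordinates along outOrd[0] (the contiguous
// dimension) and outOrd[1]; stride0/stride1 are the matching dstStrides.
unsigned swizzledOffset(unsigned idx0, unsigned idx1, unsigned stride0,
                        unsigned stride1, unsigned outVec, unsigned minVec,
                        unsigned perPhase, unsigned maxPhase) {
  unsigned remained = idx0 % outVec;         // urem(multiDimIdx[outOrd[0]], outVecVal)
  idx0 /= outVec;                            // udiv(multiDimIdx[outOrd[0]], outVecVal)
  unsigned off1 = idx1 * stride1;            // off_1
  unsigned phaseId = (idx1 / perPhase) % maxPhase;
  unsigned off0 = (idx0 ^ phaseId) * outVec; // xor_ with the phase, then rescale
  off0 += (remained / minVec) * minVec;
  return off1 + off0 * stride0;              // add(off_1, mul(off_0, dstStrides[outOrd[0]]))
}

int main() {
  const unsigned outVec = 8, minVec = 8, perPhase = 1, maxPhase = 8;
  for (unsigned row = 0; row < 8; ++row) {
    unsigned off = swizzledOffset(/*idx0=*/0, /*idx1=*/row, /*stride0=*/1,
                                  /*stride1=*/64, outVec, minVec, perPhase,
                                  maxPhase);
    printf("row %u: word 0 -> element offset %u (row-local %u)\n", row, off,
           off - row * 64);
  }
  return 0;
}

The xor with phaseId rotates each row's packed words into a different
8-element slot of the row (row-local offsets 0, 8, 16, ...), so stores from
consecutive rows hit different shared-memory bank groups; for a valid
configuration (maxPhase * vec no larger than the row length) the swizzle stays
within the row.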