diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index 165fb86fb..be3a48314 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -67,8 +67,6 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
                                  builder.getIntegerAttr(ty, value));
 }
 
-// Add other specification if needed...
-
 } // namespace
 
 // Shortcuts for some commonly used LLVM ops to keep code simple and intuitive
@@ -89,8 +87,12 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
 #define load(...) rewriter.create<LLVM::LoadOp>(loc, __VA_ARGS__)
 #define store(val, ptr) rewriter.create<LLVM::StoreOp>(loc, val, ptr)
 #define address_of(...) rewriter.create<LLVM::AddressOfOp>(loc, __VA_ARGS__)
+#define barrier rewriter.create<mlir::gpu::BarrierOp>(loc)
+#define undef(...) rewriter.create<LLVM::UndefOp>(loc, __VA_ARGS__)
 #define i32_ty rewriter.getIntegerType(32)
 #define vec_ty(type, num) VectorType::get(num, type)
+#define void_ty LLVM::LLVMVoidType::get(ctx)
+#define struct_ty(...) LLVM::LLVMStructType::getLiteral(__VA_ARGS__)
 
 // Creator for constant
 #define i32_val(...) LLVM::createConstantI32(loc, rewriter, __VA_ARGS__)
@@ -109,6 +111,18 @@
 namespace type = mlir::triton::type;
 
 class TritonGPUToLLVMTypeConverter;
 
+// TODO: keep these until we have better debug log utilities
+template <typename T>
+void print_array(ArrayRef<T> array, const std::string &str) {
+  std::cout << str << ": ";
+  for (const T &e : array)
+    std::cout << e << ",";
+  std::cout << std::endl;
+}
+template <typename T> void print_scalar(const T &e, const std::string &str) {
+  std::cout << str << ": " << e << std::endl;
+}
+
 // FuncOpConversion/FuncOpConversionBase is borrowed from
 // https://github.com/llvm/llvm-project/blob/fae656b2dd80246c3c6f01e9c77c49560368752c/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp#L276
 // since it is not exposed on header files in mlir v14
@@ -360,6 +374,13 @@ public:
     return threadId;
   }
 
+  Value createIndexConst(ConversionPatternRewriter &rewriter, Location loc,
+                         int64_t value) const {
+    return rewriter.create<LLVM::ConstantOp>(
+        loc, this->getTypeConverter()->getIndexType(),
+        rewriter.getIntegerAttr(rewriter.getIndexType(), value));
+  }
+
   // Convert an \param index to a multi-dim coordinate given \param shape and
   // \param order.
   SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
@@ -859,7 +880,7 @@ struct StoreOpConversion
       asmArgList->listAppend(ptxBuilder.newOperand(llWord, constraint));
     }
 
-    // TODO(Superjomn) Need to check masks before vectorize the load for all
+    // TODO(Superjomn) Need to check masks before vectorizing the load: do
     // the values share one predicate? Here assume all the mask values are
     // the same.
     Value maskVal = llMask ? maskElems[vecStart] : int_val(1, 1);
@@ -1060,7 +1081,6 @@ struct LoadOpConversion
   LogicalResult
   matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-
     Value ptr = op.ptr();
     Value mask = op.mask();
     Value other = op.other();
@@ -1413,87 +1433,24 @@ public:
   LogicalResult
   matchAndRewrite(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto loc = op.getLoc();
     Value src = op.src();
     Value dst = op.result();
     auto srcTy = src.getType().cast<RankedTensorType>();
     auto dstTy = dst.getType().cast<RankedTensorType>();
     Attribute srcLayout = srcTy.getEncoding();
     Attribute dstLayout = dstTy.getEncoding();
+    if (srcLayout.isa<BlockedEncodingAttr>() &&
+        dstLayout.isa<SharedEncodingAttr>()) {
+      return lowerBlockedToShared(op, adaptor, rewriter);
+    }
     if ((!srcLayout.isa<BlockedEncodingAttr>() &&
          !srcLayout.isa<MmaEncodingAttr>()) ||
         (!dstLayout.isa<BlockedEncodingAttr>() &&
         !dstLayout.isa<MmaEncodingAttr>())) {
       // TODO: to be implemented
-      llvm::errs() << "Unsupported ConvertLayout found";
       return failure();
     }
-    auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType());
-    Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
-    auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
-    smemBase = bit_cast(elemPtrTy, smemBase);
-
-    auto shape = dstTy.getShape();
-    unsigned rank = dstTy.getRank();
-    SmallVector<unsigned> numReplicates(rank);
-    SmallVector<unsigned> inNumCTAsEachRep(rank);
-    SmallVector<unsigned> outNumCTAsEachRep(rank);
-    SmallVector<unsigned> inNumCTAs(rank);
-    SmallVector<unsigned> outNumCTAs(rank);
-    auto srcShapePerCTA = getShapePerCTA(srcLayout);
-    auto dstShapePerCTA = getShapePerCTA(dstLayout);
-    for (unsigned d = 0; d < rank; ++d) {
-      unsigned inPerCTA = std::min<unsigned>(shape[d], srcShapePerCTA[d]);
-      unsigned outPerCTA = std::min<unsigned>(shape[d], dstShapePerCTA[d]);
-      unsigned maxPerCTA = std::max(inPerCTA, outPerCTA);
-      numReplicates[d] = ceil<unsigned>(shape[d], maxPerCTA);
-      inNumCTAsEachRep[d] = maxPerCTA / inPerCTA;
-      outNumCTAsEachRep[d] = maxPerCTA / outPerCTA;
-      // TODO: confirm this
-      assert(maxPerCTA % inPerCTA == 0 && maxPerCTA % outPerCTA == 0);
-      inNumCTAs[d] = ceil<unsigned>(shape[d], inPerCTA);
-      outNumCTAs[d] = ceil<unsigned>(shape[d], outPerCTA);
-    }
-    // Potentially we need to store for multiple CTAs in this replication
-    unsigned accumNumReplicates = product(numReplicates);
-    unsigned elems = getElemsPerThread(srcLayout, srcTy.getShape());
-    auto vals = getElementsFromStruct(loc, adaptor.src(), elems, rewriter);
-    unsigned inVec = 0;
-    unsigned outVec = 0;
-    auto paddedRepShape = getScratchConfigForCvtLayout(op, inVec, outVec);
-
-    unsigned outElems = getElemsPerThread(dstLayout, shape);
-    auto outOrd = getOrder(dstLayout);
-    SmallVector<Value> outVals(outElems);
-    for (unsigned repId = 0; repId < accumNumReplicates; ++repId) {
-      auto multiDimRepId = getMultiDimIndex<unsigned>(repId, numReplicates);
-      rewriter.create<mlir::gpu::BarrierOp>(loc);
-      if (srcLayout.isa<BlockedEncodingAttr>() ||
-          srcLayout.isa<MmaEncodingAttr>()) {
-        processReplica(loc, rewriter, /*stNotRd*/ true, srcTy, inNumCTAsEachRep,
-                       multiDimRepId, inVec, paddedRepShape, outOrd, vals,
-                       smemBase);
-      } else {
-        assert(0 && "ConvertLayout with input layout not implemented");
-        return failure();
-      }
-      rewriter.create<mlir::gpu::BarrierOp>(loc);
-      if (dstLayout.isa<BlockedEncodingAttr>() ||
-          dstLayout.isa<MmaEncodingAttr>()) {
-        processReplica(loc, rewriter, /*stNotRd*/ false, dstTy,
-                       outNumCTAsEachRep, multiDimRepId, outVec, paddedRepShape,
-                       outOrd, outVals, smemBase);
-      } else {
-        assert(0 && "ConvertLayout with output layout not implemented");
-        return failure();
-      }
-    }
-
-    SmallVector<Type> types(outElems, llvmElemTy);
-    Type structTy = LLVM::LLVMStructType::getLiteral(getContext(), types);
-    Value result = getStructFromElements(loc, outVals, rewriter, structTy);
-    rewriter.replaceOp(op, result);
-    return success();
+    return lowerDistributedToDistributed(op, adaptor, rewriter);
   }
 
 private:
@@ -1508,122 +1465,334 @@ private:
     return result;
   };
 
-  // shared memory access for blocked or mma layout
+  // shared memory rd/st for blocked or mma layout with data padding
   void processReplica(Location loc, ConversionPatternRewriter &rewriter,
                       bool stNotRd, RankedTensorType type,
                       ArrayRef<unsigned> numCTAsEachRep,
                       ArrayRef<unsigned> multiDimRepId, unsigned vec,
                      ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> outOrd,
                       SmallVector<Value> &vals,
-                      Value smemBase) const {
-    unsigned accumNumCTAsEachRep = product(numCTAsEachRep);
-    auto layout = type.getEncoding();
-    auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>();
-    auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>();
-    auto rank = type.getRank();
-    auto sizePerThread = getSizePerThread(layout);
-    auto accumSizePerThread = product(sizePerThread);
-    auto llvmIndexTy = getTypeConverter()->getIndexType();
-    SmallVector<unsigned> numCTAs(rank);
-    auto shapePerCTA = getShapePerCTA(layout);
-    for (unsigned d = 0; d < rank; ++d) {
-      numCTAs[d] = ceil<unsigned>(type.getShape()[d], shapePerCTA[d]);
-    }
-    auto llvmElemTy = getTypeConverter()->convertType(type.getElementType());
-    SmallVector<Value> multiDimOffsetFirstElem;
-    Value mmaGrpId;
-    Value mmaGrpIdP8;
-    Value mmaThreadIdInGrpM2;
-    Value mmaThreadIdInGrpM2P1;
-    if (blockedLayout) {
-      multiDimOffsetFirstElem = emitBaseIndexForBlockedLayout(
-          loc, rewriter, blockedLayout, type.getShape());
-    } else if (mmaLayout) {
-      // TODO: simplify these
-      auto cast = rewriter.create<UnrealizedConversionCastOp>(
-          loc, TypeRange{llvmIndexTy},
-          ValueRange{rewriter.create<::mlir::gpu::ThreadIdOp>(
-              loc, rewriter.getIndexType(), ::mlir::gpu::Dimension::x)});
-      Value threadId = cast.getResult(0);
-      Value warpSize = createIndexAttrConstant(
-          rewriter, loc, this->getTypeConverter()->getIndexType(), 32);
-      Value laneId = rewriter.create<LLVM::URemOp>(loc, threadId, warpSize);
-      Value fourVal = idx_val(4);
-      mmaGrpId = rewriter.create<LLVM::UDivOp>(loc, laneId, fourVal);
-      mmaGrpIdP8 = rewriter.create<LLVM::AddOp>(loc, mmaGrpId, idx_val(8));
-      Value mmaThreadIdInGrp =
-          rewriter.create<LLVM::URemOp>(loc, laneId, fourVal);
-      mmaThreadIdInGrpM2 =
-          rewriter.create<LLVM::MulOp>(loc, mmaThreadIdInGrp, idx_val(2));
-      mmaThreadIdInGrpM2P1 =
-          rewriter.create<LLVM::AddOp>(loc, mmaThreadIdInGrpM2, idx_val(1));
-    }
-    for (unsigned ctaId = 0; ctaId < accumNumCTAsEachRep; ++ctaId) {
-      auto multiDimCTAInRepId =
-          getMultiDimIndex<unsigned>(ctaId, numCTAsEachRep);
-      SmallVector<unsigned> multiDimCTAId(rank);
-      for (auto it : llvm::enumerate(multiDimCTAInRepId)) {
-        auto d = it.index();
-        multiDimCTAId[d] = multiDimRepId[d] * numCTAsEachRep[d] + it.value();
-      }
+                      Value smemBase) const;
-      unsigned linearCTAId = getLinearIndex<unsigned>(multiDimCTAId, numCTAs);
-      // TODO: This is actually redundant index calculation, we should
-      // consider of caching the index calculation result in case
-      // of performance issue observed.
-      for (unsigned elemId = 0; elemId < accumSizePerThread; elemId += vec) {
-        SmallVector<Value> multiDimOffset(rank);
-        if (blockedLayout) {
-          SmallVector<unsigned> multiDimElemId = getMultiDimIndex<unsigned>(
-              elemId, blockedLayout.getSizePerThread());
-          for (unsigned d = 0; d < rank; ++d) {
-            multiDimOffset[d] = rewriter.create<LLVM::AddOp>(
-                loc, multiDimOffsetFirstElem[d],
-                createIndexAttrConstant(rewriter, loc, llvmIndexTy,
-                                        multiDimCTAInRepId[d] * shapePerCTA[d] +
-                                            multiDimElemId[d]));
-          }
-        } else if (mmaLayout) {
-          assert(rank == 2);
-          assert(mmaLayout.getVersion() == 2 &&
-                 "mmaLayout ver1 not implemented yet");
-          multiDimOffset[0] = elemId < 2 ? mmaGrpId : mmaGrpIdP8;
-          multiDimOffset[1] =
-              elemId % 2 == 0 ? mmaThreadIdInGrpM2 : mmaThreadIdInGrpM2P1;
-        } else {
-          assert(0 && "unexpected layout in processReplica");
+
+  // blocked/mma -> blocked/mma.
+  // Data padding in shared memory to avoid bank conflict.
+  LogicalResult
+  lowerDistributedToDistributed(triton::gpu::ConvertLayoutOp op,
+                                OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const;
+
+  // blocked -> shared.
+  // Swizzling in shared memory to avoid bank conflict. Normally used for
+  // A/B operands of dots.
+  LogicalResult lowerBlockedToShared(triton::gpu::ConvertLayoutOp op,
+                                     OpAdaptor adaptor,
+                                     ConversionPatternRewriter &rewriter) const;
+};
+
+void ConvertLayoutOpConversion::processReplica(
+    Location loc, ConversionPatternRewriter &rewriter, bool stNotRd,
+    RankedTensorType type, ArrayRef<unsigned> numCTAsEachRep,
+    ArrayRef<unsigned> multiDimRepId, unsigned vec,
+    ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> outOrd,
+    SmallVector<Value> &vals, Value smemBase) const {
+  unsigned accumNumCTAsEachRep = product(numCTAsEachRep);
+  auto layout = type.getEncoding();
+  auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>();
+  auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>();
+  auto rank = type.getRank();
+  auto sizePerThread = getSizePerThread(layout);
+  auto accumSizePerThread = product(sizePerThread);
+  auto llvmIndexTy = getTypeConverter()->getIndexType();
+  SmallVector<unsigned> numCTAs(rank);
+  auto shapePerCTA = getShapePerCTA(layout);
+  for (unsigned d = 0; d < rank; ++d) {
+    numCTAs[d] = ceil<unsigned>(type.getShape()[d], shapePerCTA[d]);
+  }
+  auto llvmElemTy = getTypeConverter()->convertType(type.getElementType());
+  SmallVector<Value> multiDimOffsetFirstElem;
+  SmallVector<Value> mmaColIdx(2);
+  SmallVector<Value> mmaRowIdx(2);
+  if (blockedLayout) {
+    multiDimOffsetFirstElem = emitBaseIndexForBlockedLayout(
+        loc, rewriter, blockedLayout, type.getShape());
+  } else if (mmaLayout) {
+    Value threadId = getThreadId(rewriter, loc);
+    Value warpSize = idx_val(32);
+    Value laneId = urem(threadId, warpSize);
+    Value warpId = udiv(threadId, warpSize);
+    // auto multiDimWarpId =
+    //     delinearize(rewriter, loc, warpId, mmaLayout.getWarpsPerCTA());
+    // TODO: confirm whether this is a documentation bug or a bug in
+    //       DotConversion
+    SmallVector<Value> multiDimWarpId(2);
+    multiDimWarpId[0] = urem(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0]));
+    multiDimWarpId[1] = udiv(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0]));
+    Value four = idx_val(4);
+    Value mmaGrpId = udiv(laneId, four);
+    Value mmaGrpIdP8 = add(mmaGrpId, idx_val(8));
+    Value mmaThreadIdInGrp = urem(laneId, four);
+    Value mmaThreadIdInGrpM2 = mul(mmaThreadIdInGrp, idx_val(2));
+    Value mmaThreadIdInGrpM2P1 = add(mmaThreadIdInGrpM2, idx_val(1));
+    Value colWarpOffset = mul(multiDimWarpId[0], idx_val(16));
+    mmaColIdx[0] = add(mmaGrpId, colWarpOffset);
+    mmaColIdx[1] = add(mmaGrpIdP8, colWarpOffset);
+    Value rowWarpOffset = mul(multiDimWarpId[1], idx_val(8));
+    mmaRowIdx[0] = add(mmaThreadIdInGrpM2, rowWarpOffset);
+    mmaRowIdx[1] = add(mmaThreadIdInGrpM2P1, rowWarpOffset);
+  }
+  for (unsigned ctaId = 0; ctaId < accumNumCTAsEachRep; ++ctaId) {
+    auto multiDimCTAInRepId = getMultiDimIndex<unsigned>(ctaId, numCTAsEachRep);
+    SmallVector<unsigned> multiDimCTAId(rank);
+    for (auto it : llvm::enumerate(multiDimCTAInRepId)) {
+      auto d = it.index();
+      multiDimCTAId[d] = multiDimRepId[d] * numCTAsEachRep[d] + it.value();
+    }
+
+    unsigned linearCTAId = getLinearIndex<unsigned>(multiDimCTAId, numCTAs);
+    // TODO: This is actually redundant index calculation; we should
+    // consider caching the index calculation result in case
+    // a performance issue is observed.
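+    // For each (replica, CTA) pair, walk this thread's elements in groups of
+    // `vec`: compute the element's multi-dim offset, linearize it into the
+    // padded shared-memory shape, and issue one vectorized store (stNotRd) or
+    // load per group.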
+    for (unsigned elemId = 0; elemId < accumSizePerThread; elemId += vec) {
+      SmallVector<Value> multiDimOffset(rank);
+      if (blockedLayout) {
+        SmallVector<unsigned> multiDimElemId = getMultiDimIndex<unsigned>(
+            elemId, blockedLayout.getSizePerThread());
+        for (unsigned d = 0; d < rank; ++d) {
+          multiDimOffset[d] =
+              add(multiDimOffsetFirstElem[d],
+                  idx_val(multiDimCTAInRepId[d] * shapePerCTA[d] +
+                          multiDimElemId[d]));
        }
-        Value offset =
-            linearize(rewriter, loc, reorder(multiDimOffset, outOrd),
-                      reorder(paddedRepShape, outOrd));
-        auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
-        Value ptr = gep(elemPtrTy, smemBase, offset);
-        auto vecTy = vec_ty(llvmElemTy, vec);
-        ptr = bit_cast(LLVM::LLVMPointerType::get(vecTy, 3), ptr);
-        if (stNotRd) {
-          Value valVec = rewriter.create<LLVM::UndefOp>(loc, vecTy);
-          for (unsigned v = 0; v < vec; ++v) {
-            Value vVal = createIndexAttrConstant(
-                rewriter, loc, getTypeConverter()->getIndexType(), v);
-            valVec = insert_element(
-                vecTy, valVec,
-                vals[elemId + linearCTAId * accumSizePerThread + v], vVal);
-          }
-          store(valVec, ptr);
-        } else {
-          Value valVec = load(ptr);
-          for (unsigned v = 0; v < vec; ++v) {
-            Value vVal = createIndexAttrConstant(
-                rewriter, loc, getTypeConverter()->getIndexType(), v);
-            vals[elemId + linearCTAId * accumSizePerThread + v] =
-                extract_element(llvmElemTy, valVec, vVal);
-          }
+      } else if (mmaLayout) {
+        assert(rank == 2);
+        assert(mmaLayout.getVersion() == 2 &&
+               "mmaLayout ver1 not implemented yet");
+        multiDimOffset[0] = elemId < 2 ? mmaColIdx[0] : mmaColIdx[1];
+        multiDimOffset[1] = elemId % 2 == 0 ? mmaRowIdx[0] : mmaRowIdx[1];
+        multiDimOffset[0] = add(
+            multiDimOffset[0], idx_val(multiDimCTAInRepId[0] * shapePerCTA[0]));
+        multiDimOffset[1] = add(
+            multiDimOffset[1], idx_val(multiDimCTAInRepId[1] * shapePerCTA[1]));
+      } else {
+        assert(0 && "unexpected layout in processReplica");
+      }
+      Value offset =
+          linearize(rewriter, loc, reorder(multiDimOffset, outOrd),
+                    reorder(paddedRepShape, outOrd));
+      auto elemPtrTy = ptr_ty(llvmElemTy, 3);
+      Value ptr = gep(elemPtrTy, smemBase, offset);
+      auto vecTy = vec_ty(llvmElemTy, vec);
+      ptr = bit_cast(ptr_ty(vecTy, 3), ptr);
+      if (stNotRd) {
+        Value valVec = undef(vecTy);
+        for (unsigned v = 0; v < vec; ++v) {
+          valVec = insert_element(
+              vecTy, valVec,
+              vals[elemId + linearCTAId * accumSizePerThread + v], idx_val(v));
+        }
+        store(valVec, ptr);
+      } else {
+        Value valVec = load(ptr);
+        for (unsigned v = 0; v < vec; ++v) {
+          vals[elemId + linearCTAId * accumSizePerThread + v] =
+              extract_element(llvmElemTy, valVec, idx_val(v));
        }
      }
    }
  }
+}
+
+LogicalResult ConvertLayoutOpConversion::lowerDistributedToDistributed(
+    triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  auto loc = op.getLoc();
+  Value src = op.src();
+  Value dst = op.result();
+  auto srcTy = src.getType().cast<RankedTensorType>();
+  auto dstTy = dst.getType().cast<RankedTensorType>();
+  Attribute srcLayout = srcTy.getEncoding();
+  Attribute dstLayout = dstTy.getEncoding();
+  auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType());
+  Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
+  auto elemPtrTy = ptr_ty(llvmElemTy, 3);
+  smemBase = bit_cast(elemPtrTy, smemBase);
+  auto shape = dstTy.getShape();
+  unsigned rank = dstTy.getRank();
+  SmallVector<unsigned> numReplicates(rank);
+  SmallVector<unsigned> inNumCTAsEachRep(rank);
+  SmallVector<unsigned> outNumCTAsEachRep(rank);
+  SmallVector<unsigned> inNumCTAs(rank);
+  SmallVector<unsigned> outNumCTAs(rank);
+  auto srcShapePerCTA = getShapePerCTA(srcLayout);
+  auto dstShapePerCTA = getShapePerCTA(dstLayout);
+  for (unsigned d = 0; d < rank; ++d) {
+    unsigned inPerCTA = std::min<unsigned>(shape[d], srcShapePerCTA[d]);
+    unsigned outPerCTA = std::min<unsigned>(shape[d], dstShapePerCTA[d]);
+    unsigned maxPerCTA = std::max(inPerCTA, outPerCTA);
+    numReplicates[d] = ceil<unsigned>(shape[d], maxPerCTA);
+    inNumCTAsEachRep[d] = maxPerCTA / inPerCTA;
+    outNumCTAsEachRep[d] = maxPerCTA / outPerCTA;
+    assert(maxPerCTA % inPerCTA == 0 && maxPerCTA % outPerCTA == 0);
+    inNumCTAs[d] = ceil<unsigned>(shape[d], inPerCTA);
+    outNumCTAs[d] = ceil<unsigned>(shape[d], outPerCTA);
+  }
+  // Potentially we need to store for multiple CTAs in this replication
+  unsigned accumNumReplicates = product(numReplicates);
+  unsigned elems = getElemsPerThread(srcLayout, srcTy.getShape());
+  auto vals = getElementsFromStruct(loc, adaptor.src(), elems, rewriter);
+  unsigned inVec = 0;
+  unsigned outVec = 0;
+  auto paddedRepShape = getScratchConfigForCvtLayout(op, inVec, outVec);
+
+  unsigned outElems = getElemsPerThread(dstLayout, shape);
+  auto outOrd = getOrder(dstLayout);
+  SmallVector<Value> outVals(outElems);
+
+  for (unsigned repId = 0; repId < accumNumReplicates; ++repId) {
+    auto multiDimRepId = getMultiDimIndex<unsigned>(repId, numReplicates);
+    barrier;
+    if (srcLayout.isa<BlockedEncodingAttr>() ||
+        srcLayout.isa<MmaEncodingAttr>()) {
+      processReplica(loc, rewriter, /*stNotRd*/ true, srcTy, inNumCTAsEachRep,
+                     multiDimRepId, inVec, paddedRepShape, outOrd, vals,
+                     smemBase);
+    } else {
+      assert(0 && "ConvertLayout with input layout not implemented");
+      return failure();
+    }
+    barrier;
+    if (dstLayout.isa<BlockedEncodingAttr>() ||
+        dstLayout.isa<MmaEncodingAttr>()) {
+      processReplica(loc, rewriter, /*stNotRd*/ false, dstTy, outNumCTAsEachRep,
+                     multiDimRepId, outVec, paddedRepShape, outOrd, outVals,
+                     smemBase);
    } else {
+      assert(0 && "ConvertLayout with output layout not implemented");
+      return failure();
+    }
+  }
+
+  SmallVector<Type> types(outElems, llvmElemTy);
+  Type structTy = struct_ty(getContext(), types);
+  Value result = getStructFromElements(loc, outVals, rewriter, structTy);
+  rewriter.replaceOp(op, result);
+
+  return success();
+};
+
+LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
+    triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  auto loc = op.getLoc();
+  Value src = op.src();
+  Value dst = op.result();
+  auto srcTy = src.getType().cast<RankedTensorType>();
+  auto dstTy = dst.getType().cast<RankedTensorType>();
+  auto srcShape = srcTy.getShape();
+  assert(srcShape.size() == 2 &&
+         "Unexpected rank of ConvertLayout(blocked->shared)");
+  auto srcBlockedLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
+  auto dstSharedLayout = dstTy.getEncoding().cast<SharedEncodingAttr>();
+  auto inOrd = srcBlockedLayout.getOrder();
+  auto outOrd = dstSharedLayout.getOrder();
+  unsigned inVec =
+      inOrd == outOrd ? srcBlockedLayout.getSizePerThread()[inOrd[0]] : 1;
+  unsigned outVec = dstSharedLayout.getVec();
+  unsigned minVec = std::min(outVec, inVec);
+  unsigned perPhase = dstSharedLayout.getPerPhase();
+  unsigned maxPhase = dstSharedLayout.getMaxPhase();
+  unsigned numElems = getElemsPerThread(srcBlockedLayout, srcShape);
+  auto inVals = getElementsFromStruct(loc, adaptor.src(), numElems, rewriter);
+  unsigned srcAccumSizeInThreads =
+      product(srcBlockedLayout.getSizePerThread());
+  auto elemTy = srcTy.getElementType();
+  auto wordTy = vec_ty(elemTy, minVec);
+
+  // TODO: [goostavz] We should make a cache for the calculation of
+  // emitBaseIndexForBlockedLayout in case the backend compiler is not able
+  // to optimize it.
+  SmallVector<Value> multiDimOffsetFirstElem =
+      emitBaseIndexForBlockedLayout(loc, rewriter, srcBlockedLayout, srcShape);
+  SmallVector<unsigned> srcShapePerCTA = getShapePerCTA(srcBlockedLayout);
+  SmallVector<unsigned> reps{ceil<unsigned>(srcShape[0], srcShapePerCTA[0]),
+                             ceil<unsigned>(srcShape[1], srcShapePerCTA[1])};
+
+  // Visit each input value in the order they are placed in inVals
+  //
+  // Note that this order is not aware of blockedLayout.getOrder(), so
+  // adjacent elements may not belong to the same word. This could be
+  // improved by reordering the elements via emitIndicesForBlockedLayout()
+  SmallVector<unsigned> wordsInEachRep(2);
+  wordsInEachRep[0] = inOrd[0] == 0
+                          ? srcBlockedLayout.getSizePerThread()[0] / minVec
+                          : srcBlockedLayout.getSizePerThread()[0];
+  wordsInEachRep[1] = inOrd[0] == 0
+                          ? srcBlockedLayout.getSizePerThread()[1]
+                          : srcBlockedLayout.getSizePerThread()[1] / minVec;
+  Value outVecVal = idx_val(outVec);
+  Value minVecVal = idx_val(minVec);
+  Value smemBase = getSharedMemoryBase(loc, rewriter, dst);
+  auto elemPtrTy = ptr_ty(getTypeConverter()->convertType(elemTy), 3);
+  smemBase = bit_cast(elemPtrTy, smemBase);
+  unsigned numWordsEachRep = product(wordsInEachRep);
+  SmallVector<Value> wordVecs(numWordsEachRep);
+  for (unsigned i = 0; i < numElems; ++i) {
+    if (i % srcAccumSizeInThreads == 0) {
+      // start of a replication
+      for (unsigned w = 0; w < numWordsEachRep; ++w) {
+        wordVecs[w] = undef(wordTy);
+      }
+    }
+    unsigned linearIdxInNanoTile = i % srcAccumSizeInThreads;
+    auto multiDimIdxInNanoTile = getMultiDimIndex<unsigned>(
+        linearIdxInNanoTile, srcBlockedLayout.getSizePerThread());
+    unsigned pos = multiDimIdxInNanoTile[inOrd[0]] % minVec;
+    multiDimIdxInNanoTile[inOrd[0]] /= minVec;
+    unsigned wordVecIdx =
+        getLinearIndex<unsigned>(multiDimIdxInNanoTile, wordsInEachRep);
+    wordVecs[wordVecIdx] =
+        insert_element(wordTy, wordVecs[wordVecIdx], inVals[i], idx_val(pos));
+
+    if (i % srcAccumSizeInThreads == srcAccumSizeInThreads - 1) {
+      // end of replication, store the vectors into shared memory
+      unsigned linearRepIdx = i / srcAccumSizeInThreads;
+      auto multiDimRepIdx = getMultiDimIndex<unsigned>(linearRepIdx, reps);
+      for (unsigned linearWordIdx = 0; linearWordIdx < numWordsEachRep;
+           ++linearWordIdx) {
+        // step 1: recover the multi-dim index from the linear index of the
+        //         input elements
+        auto multiDimWordIdx =
+            getMultiDimIndex<unsigned>(linearWordIdx, wordsInEachRep);
+        SmallVector<Value> multiDimIdx(2);
+        auto wordOffset0 = multiDimRepIdx[0] * srcShapePerCTA[0] +
+                           multiDimWordIdx[0] * (inOrd[0] == 0 ? minVec : 1);
+        auto wordOffset1 = multiDimRepIdx[1] * srcShapePerCTA[1] +
+                           multiDimWordIdx[1] * (inOrd[0] == 1 ? minVec : 1);
+        multiDimIdx[0] = add(multiDimOffsetFirstElem[0], idx_val(wordOffset0));
+        multiDimIdx[1] = add(multiDimOffsetFirstElem[1], idx_val(wordOffset1));
+
+        // step 2: do swizzling
+        Value remained = urem(multiDimIdx[inOrd[0]], outVecVal);
+        multiDimIdx[inOrd[0]] = udiv(multiDimIdx[inOrd[0]], outVecVal);
+        Value off_1 = mul(multiDimIdx[inOrd[1]], idx_val(srcShape[inOrd[0]]));
+        Value phaseId = udiv(multiDimIdx[inOrd[1]], idx_val(perPhase));
+        phaseId = urem(phaseId, idx_val(maxPhase));
+        Value off_0 = xor_(multiDimIdx[inOrd[0]], phaseId);
+        off_0 = mul(off_0, outVecVal);
+        remained = udiv(remained, minVecVal);
+        off_0 = add(off_0, mul(remained, minVecVal));
+        Value offset = add(off_1, off_0);
+
+        // step 3: store
+        Value smemAddr = gep(elemPtrTy, smemBase, offset);
+        smemAddr = bit_cast(ptr_ty(wordTy, 3), smemAddr);
+        store(wordVecs[linearWordIdx], smemAddr);
+      }
+    }
+  }
+  // TODO: confirm whether the barrier is necessary here
+  barrier;
+  rewriter.replaceOp(op, smemBase);
+  return success();
+}
 
 /// ====================== dot codegen begin ==========================
 
 // Data loader for mma.16816 instruction.
@@ -1843,7 +2012,7 @@ public:
     if (canUseLdmatrix)
       ptrIdx = matIdx[order[0]] / (instrShape[order[0]] / matShape[order[0]]);
-    else if (elemBytes == 4 && needTrans) // tf32 & trans
+    else if (elemBytes == 4 && needTrans)
       ptrIdx = matIdx[order[0]];
    else if (elemBytes == 1 && needTrans)
       ptrIdx = matIdx[order[0]] * 4;
@@ -2127,10 +2296,6 @@ struct DotOpConversionHelper {
                     .cast<RankedTensorType>()
                     .getEncoding()
                     .cast<MmaEncodingAttr>();
-
-    ATensorTy = A.getType().cast<RankedTensorType>();
-    BTensorTy = B.getType().cast<RankedTensorType>();
-    DTensorTy = D.getType().cast<RankedTensorType>();
  }
 
   // Load SplatLike C which contains a constVal. It simply returns 4 fp32
@@ -2469,7 +2634,7 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter,
   bool needTrans = kOrder != order[0];
 
   // (a, b) is the coordinate.
-  auto load = [&, loader, ptrs, offs, needTrans](int a, int b) {
+  auto load = [=, &vals, &helper, &ld2](int a, int b) {
     auto [ha0, ha1, ha2, ha3] = loader.loadX4(
         (kOrder == 1) ? a : b /*mat0*/, (kOrder == 1) ? b : a /*mat1*/, offs,
         ptrs, helper.getMatType(), helper.getShemPtrTy());
@@ -2490,78 +2655,68 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter,
   };
 
   std::function<void(int, int)> loadA;
-  std::function<void(int, int)> loadB = getLoadMatrixFn(
-      B, adapter.b() /*llTensor*/, mmaLayout.getWarpsPerCTA()[1] /*wpt*/,
-      0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShpae*/,
-      {matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/);
-
-  if (aTensorTy.getEncoding()
-          .dyn_cast<SharedEncodingAttr>()) { // load from smem
+  if (aTensorTy.getEncoding().isa<SharedEncodingAttr>()) {
+    // load from smem
     loadA = getLoadMatrixFn(
         A, adapter.a() /*llTensor*/, mmaLayout.getWarpsPerCTA()[0] /*wpt*/,
         1 /*kOrder*/, {mmaInstrM, mmaInstrK} /*instrShpae*/,
         {matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, ha /*vals*/);
-  } else if (auto blockedLayout =
-                 aTensorTy.getEncoding()
-                     .dyn_cast<BlockedEncodingAttr>()) { // load from registers,
-                                                         // used in gemm fuse
+  } else if (aTensorTy.getEncoding().isa<BlockedEncodingAttr>()) {
+    // load from registers, used in gemm fuse
    // TODO(Superjomn) Port the logic.
    assert(false && "Loading A from register is not supported yet.");
  } else {
    assert(false && "A's layout is not supported.");
  }
 
-  const unsigned mStride = numRepN * 2;
-  SmallVector<Value> fc(numRepM * mStride + numRepN * 2);
+  std::function<void(int, int)> loadB = getLoadMatrixFn(
+      B, adapter.b() /*llTensor*/, mmaLayout.getWarpsPerCTA()[1] /*wpt*/,
+      0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShpae*/,
+      {matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/);
+
+  const int fcSize = 4 * numRepM * numRepN;
+  SmallVector<Value> fc(fcSize);
+
+  // Currently, we only support a SplatLike C. For the other cases, e.g., C in
+  // shared layout or blocked layout, we will support them by expanding
+  // convert_layout.
+  auto hc = helper.loadSplatLikeC(C, loc, rewriter);
+  assert(hc.size() == 4UL && "Only splat-like C is supported now");
+  for (int i = 0; i < fc.size(); i++)
+    fc[i] = hc[0];
 
   auto callMma = [&](unsigned m, unsigned n, unsigned k) {
+    unsigned colsPerThread = numRepN * 2;
     PTXBuilder builder;
-
     auto &mma = *builder.create(helper.getMmaInstr().str());
-
     auto retArgs = builder.newListOperand(4, "=r");
-
     auto aArgs = builder.newListOperand({
        {ha[{m, k}], "r"},
        {ha[{m + 1, k}], "r"},
        {ha[{m, k + 1}], "r"},
        {ha[{m + 1, k + 1}], "r"},
    });
-
    auto bArgs =
        builder.newListOperand({{hb[{n, k}], "r"}, {hb[{n, k + 1}], "r"}});
-
-    // Currently, we only support a SplatLike C. For the other cases, e.g., C in
-    // shared layout or blocked layout, we will support them by expanding
-    // convert_layout.
-    auto hc = helper.loadSplatLikeC(C, loc, rewriter);
-    assert(hc.size() == 4UL && "Only splat-like C is supported now");
-
    auto cArgs = builder.newListOperand();
-    for (int i = 0; i < hc.size(); ++i) {
-      cArgs->listAppend(builder.newOperand(
-          hc[i], std::to_string(i))); // reuse the output registers
+    for (int i = 0; i < 4; ++i) {
+      cArgs->listAppend(builder.newOperand(fc[m * colsPerThread + 4 * n + i],
+                                           std::to_string(i)));
+      // reuse the output registers
    }
-
    mma(retArgs, aArgs, bArgs, cArgs);
-
    Value mmaOut = builder.launch(rewriter, loc, helper.getMmaRetType());
 
    auto getIntAttr = [&](int v) {
      return ArrayAttr::get(ctx, {IntegerAttr::get(i32_ty, v)});
    };
 
-    fc[(m + 0) * mStride + (n * 2 + 0)] =
-        extract_val(type::f32Ty(ctx), mmaOut, getIntAttr(0));
-    fc[(m + 0) * mStride + (n * 2 + 1)] =
-        extract_val(type::f32Ty(ctx), mmaOut, getIntAttr(1));
-    fc[(m + 1) * mStride + (n * 2 + 0)] =
-        extract_val(type::f32Ty(ctx), mmaOut, getIntAttr(2));
-    fc[(m + 1) * mStride + (n * 2 + 1)] =
-        extract_val(type::f32Ty(ctx), mmaOut, getIntAttr(3));
+    for (int i = 0; i < 4; i++)
+      fc[m * colsPerThread + 4 * n + i] =
+          extract_val(type::f32Ty(ctx), mmaOut, getIntAttr(i));
  };
 
   // Main program
-
   for (unsigned k = 0; k < numRepK; ++k) {
    for (unsigned m = 0; m < numRepM; ++m)
      loadA(2 * m, 2 * k);
@@ -2741,6 +2896,9 @@ void ConvertTritonGPUToLLVM::initSharedMemory(
          "Inliner pass is expected before TritonGPUToLLVM");
  b.setInsertionPointToStart(&funcs[0].getBody().front());
  smem = b.create<LLVM::AddressOfOp>(loc, global);
+  auto ptrTy =
+      LLVM::LLVMPointerType::get(typeConverter.convertType(b.getI8Type()), 3);
+  smem = b.create<LLVM::BitcastOp>(loc, ptrTy, smem);
 }
 
 } // namespace
diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index 9050b0485..d531c0fca 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -87,7 +86,6 @@ SmallVector<unsigned> getShapePerCTA(const Attribute &layout) {
  } else {
    assert(0 && "Unimplemented usage of getShapePerCTA");
  }
-
  return shape;
 }
 
@@ -104,7 +103,7 @@ SmallVector<unsigned> getOrder(const Attribute &layout) {
    assert(0 && "Unimplemented usage of getOrder");
    return {};
  }
-}
+};
 
 } // namespace gpu
 } // namespace triton
@@ -215,9 +214,12 @@ unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
 }
 
 unsigned MmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
-  int threads = product(getWarpsPerCTA());
-  int numElem = product(shape);
-  return numElem / threads;
+  size_t rank = shape.size();
+  assert(rank == 2 && "Unexpected rank of mma layout");
+  assert(getVersion() == 2 && "mmaLayout version = 1 is not implemented yet");
+  unsigned elemsCol = ceil<unsigned>(shape[0], 16 * getWarpsPerCTA()[0]) * 2;
+  unsigned elemsRow = ceil<unsigned>(shape[1], 8 * getWarpsPerCTA()[1]) * 2;
+  return elemsCol * elemsRow;
 }
 
 unsigned SharedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
diff --git a/python/tests/test_gemm.py b/python/tests/test_gemm.py
new file mode 100644
index 000000000..2d3e88170
--- /dev/null
+++ b/python/tests/test_gemm.py
@@ -0,0 +1,52 @@
+import pytest
+import torch
+from torch.testing import assert_allclose
+
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def matmul_kernel(
+    a_ptr, b_ptr, c_ptr,
+    stride_am, stride_ak,
+    stride_bk, stride_bn,
+    stride_cm, stride_cn,
+    M: tl.constexpr, N: tl.constexpr, K: tl.constexpr
+):
+    offs_m = tl.arange(0, M)
+    offs_n = tl.arange(0, N)
+    offs_k = tl.arange(0, K)
+    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
+    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
+    a = tl.load(a_ptrs)
+    b = tl.load(b_ptrs)
+
+    c = tl.dot(a, b)
+
+    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
+    tl.store(c_ptrs, c)
+
+# TODO: only num_warps = 4 is supported for now
+
+
+@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS', [
+    [128, 256, 32, 4],
+    [256, 128, 16, 4],
+    [128, 16, 32, 4],
+    [32, 128, 64, 4],
+])
+def test_gemm_impl(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS):
+    a = torch.randn((SIZE_M, SIZE_K), device='cuda', dtype=torch.float16)
+    b = torch.randn((SIZE_K, SIZE_N), device='cuda', dtype=torch.float16)
+    c = torch.empty((SIZE_M, SIZE_N), device=a.device, dtype=torch.float32)
+    grid = lambda META: (1, )
+    matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c,
+                        stride_am=a.stride(0), stride_ak=a.stride(1),
+                        stride_bk=b.stride(0), stride_bn=b.stride(1),
+                        stride_cm=c.stride(0), stride_cn=c.stride(1),
+                        M=SIZE_M, N=SIZE_N, K=SIZE_K,
+                        num_warps=NUM_WARPS)
+    golden = torch.matmul(a, b)
+    torch.set_printoptions(profile="full")
+    assert_allclose(c, golden, rtol=1e-3, atol=1e-3)
diff --git a/python/triton/compiler.py b/python/triton/compiler.py
index d0441d921..562d59526 100644
--- a/python/triton/compiler.py
+++ b/python/triton/compiler.py
@@ -910,7 +910,7 @@ def ptx_get_version(cuda_version) -> int:
 
 
 def path_to_ptxas():
-    prefixes = [os.environ.get("TRITON_PTXAS_PATH", ""), "", "/usr/local/cuda/"]
+    prefixes = [os.environ.get("TRITON_PTXAS_PATH", ""), "", os.environ.get('CUDA_PATH', default_cuda_dir())]
    for prefix in prefixes:
        ptxas = os.path.join(prefix, "bin", "ptxas")
        if os.path.exists(ptxas):
diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir
index d687a857f..f54a8d983 100644
--- a/test/Conversion/tritongpu_to_llvm.mlir
+++ b/test/Conversion/tritongpu_to_llvm.mlir
@@ -299,6 +299,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
  // CHECK-LABEL: basic_alloc_tensor
  func @basic_alloc_tensor() {
    // CHECK: llvm.mlir.addressof @global_smem
+    // CHECK-NEXT: llvm.bitcast
    // CHECK-NEXT: llvm.mlir.constant
    // CHECK-NEXT: llvm.getelementptr
    // CHECK-NEXT: llvm.bitcast
@@ -315,13 +316,14 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
  // CHECK-LABEL: basic_extract_slice
  func @basic_extract_slice() {
    // CHECK: %[[BASE0:.*]] = llvm.mlir.addressof @global_smem
+    // CHECK-NEXT: %[[BASE1:.*]] = llvm.bitcast %[[BASE0]]
    // CHECK-NEXT: %[[OFFSET0:.*]] = llvm.mlir.constant
    // CHECK-NEXT: %[[OFFSET1:.*]] = llvm.mlir.constant
-    // CHECK-NEXT: llvm.getelementptr %[[BASE0]][%[[OFFSET1]]]
-    // CHECK-NEXT: %[[BASE1:.*]] = llvm.bitcast
+    // CHECK-NEXT: llvm.getelementptr %[[BASE1]][%[[OFFSET1]]]
+    // CHECK-NEXT: %[[BASE2:.*]] = llvm.bitcast
    // CHECK-NEXT: %[[OFFSET2:.*]] = llvm.mlir.constant
    // CHECK-NEXT: %[[OFFSET3:.*]] = llvm.mul %[[OFFSET0]], %[[OFFSET2]]
-    // CHECK-NEXT: llvm.getelementptr %[[BASE1]][%[[OFFSET3]]]
+    // CHECK-NEXT: llvm.getelementptr %[[BASE2]][%[[OFFSET3]]]
    %index = arith.constant 1 : i32
    %0 = triton_gpu.alloc_tensor : tensor<128x16x32xf32, #shared0>
    %1 = triton_gpu.extract_slice %0, %index {axis = 0: i32} : tensor<128x16x32xf32, #shared0> -> tensor<16x32xf32, #shared0>
@@ -515,3 +517,20 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
    return
  }
 }
+
+// -----
+
+#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}>
+#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}>
+module attributes {"triton_gpu.num-warps" = 1 : i32} {
+  // CHECK: llvm.mlir.global internal @global_smem() {addr_space = 3 : i32} : !llvm.array<16384 x i8>
+  // CHECK-LABEL: convert_layout_blocked_shared
+  func @convert_layout_blocked_shared(%arg0: tensor<128x32xf32, #blocked0>) {
+    // CHECK: llvm.store
+    // CHECK-SAME: !llvm.ptr<vector<8xf32>, 3>
+    // CHECK: llvm.store
+    // CHECK-SAME: !llvm.ptr<vector<8xf32>, 3>
+    %0 = triton_gpu.convert_layout %arg0 : (tensor<128x32xf32, #blocked0>) -> tensor<128x32xf32, #shared0>
+    return
+  }
+}
\ No newline at end of file