From 61b61755e5d0006b4229361e26b609b59bdcf629 Mon Sep 17 00:00:00 2001
From: goostavz <109190422+goostavz@users.noreply.github.com>
Date: Tue, 27 Sep 2022 11:58:47 +0800
Subject: [PATCH] [Triton-MLIR][Backend] Support layout conversion between
 mmaLayout and blockedLayout (#693)

---
 include/triton/Dialect/TritonGPU/IR/Dialect.h |   4 +
 lib/Analysis/Allocation.cpp                   |  66 +++----
 .../TritonGPUToLLVM/TritonGPUToLLVM.cpp       | 163 +++++++++++-------
 lib/Dialect/TritonGPU/IR/Dialect.cpp          |  47 ++++-
 python/triton/compiler.py                     |   7 +-
 test/Conversion/tritongpu_to_llvm.mlir        |  22 ++-
 6 files changed, 205 insertions(+), 104 deletions(-)

diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h
index 66fae4de3..c36b0e501 100644
--- a/include/triton/Dialect/TritonGPU/IR/Dialect.h
+++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h
@@ -22,8 +22,12 @@ namespace gpu {
 
 unsigned getElemsPerThread(Attribute layout, ArrayRef<int64_t> shape);
 
+SmallVector<unsigned> getSizePerThread(Attribute layout);
+
 unsigned getShapePerCTA(const Attribute &layout, unsigned d);
 
+SmallVector<unsigned> getOrder(const Attribute &layout);
+
 } // namespace gpu
 } // namespace triton
 } // namespace mlir

diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp
index 6afa7ea1a..1a23dca6d 100644
--- a/lib/Analysis/Allocation.cpp
+++ b/lib/Analysis/Allocation.cpp
@@ -11,6 +11,7 @@
 #include
 
 using ::mlir::triton::gpu::BlockedEncodingAttr;
+using ::mlir::triton::gpu::getShapePerCTA;
 using ::mlir::triton::gpu::MmaEncodingAttr;
 using ::mlir::triton::gpu::SharedEncodingAttr;
 
@@ -32,39 +33,40 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
          "Unexpect layout in getScratchConfigForCvtLayout()");
   unsigned rank = dstTy.getRank();
   SmallVector<unsigned> paddedRepShape(rank);
-  // TODO: move to TritonGPUAttrDefs.h.inc
-  auto getShapePerCTA = [&](const Attribute &layout, unsigned d) -> unsigned {
-    if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
-      return blockedLayout.getSizePerThread()[d] *
-             blockedLayout.getThreadsPerWarp()[d] *
-             blockedLayout.getWarpsPerCTA()[d];
-    } else {
-      assert(0 && "Unimplemented usage of getShapePerCTA");
-      return 0;
-    }
-  };
-  if (srcLayout.isa<BlockedEncodingAttr>() &&
-      dstLayout.isa<BlockedEncodingAttr>()) {
-    auto srcBlockedLayout = srcLayout.cast<BlockedEncodingAttr>();
-    auto dstBlockedLayout = dstLayout.cast<BlockedEncodingAttr>();
-    auto inOrd = srcBlockedLayout.getOrder();
-    auto outOrd = dstBlockedLayout.getOrder();
-    // TODO: Fix the legacy issue that ourOrd[0] == 0 always means
-    // that we cannot do vectorization.
-    inVec = outOrd[0] == 0 ? 1
-            : inOrd[0] == 0 ? 1
-                            : srcBlockedLayout.getSizePerThread()[inOrd[0]];
-    outVec =
-        outOrd[0] == 0 ? 1 : dstBlockedLayout.getSizePerThread()[outOrd[0]];
-    unsigned pad = std::max(inVec, outVec);
-    for (unsigned d = 0; d < rank; ++d) {
-      paddedRepShape[d] = std::max(
-          std::min<unsigned>(srcTy.getShape()[d], getShapePerCTA(srcLayout, d)),
-          std::min<unsigned>(dstTy.getShape()[d],
-                             getShapePerCTA(dstLayout, d)));
-    }
-    paddedRepShape[outOrd[0]] += pad;
+  auto srcBlockedLayout = srcLayout.dyn_cast<BlockedEncodingAttr>();
+  auto srcMmaLayout = srcLayout.dyn_cast<MmaEncodingAttr>();
+  auto dstBlockedLayout = dstLayout.dyn_cast<BlockedEncodingAttr>();
+  auto dstMmaLayout = dstLayout.dyn_cast<MmaEncodingAttr>();
+  assert((srcBlockedLayout || srcMmaLayout) &&
+         "Unexpected srcLayout in getScratchConfigForCvtLayout");
+  assert((dstBlockedLayout || dstMmaLayout) &&
+         "Unexpected dstLayout in getScratchConfigForCvtLayout");
+  assert(!(srcMmaLayout && dstMmaLayout) &&
+         "Unexpected mma -> mma layout conversion");
+  auto inOrd =
+      srcMmaLayout ? dstBlockedLayout.getOrder() : srcBlockedLayout.getOrder();
+  auto outOrd =
+      dstMmaLayout ? srcBlockedLayout.getOrder() : dstBlockedLayout.getOrder();
+  unsigned srcContigPerThread =
+      srcBlockedLayout ? srcBlockedLayout.getSizePerThread()[inOrd[0]] : 2;
+  unsigned dstContigPerThread =
+      dstBlockedLayout ? dstBlockedLayout.getSizePerThread()[outOrd[0]] : 2;
+  // TODO: Fix the legacy issue that ourOrd[0] == 0 always means
+  // that we cannot do vectorization.
+  inVec = outOrd[0] == 0 ? 1 : inOrd[0] == 0 ? 1 : srcContigPerThread;
+  outVec = outOrd[0] == 0 ? 1 : dstContigPerThread;
+
+  unsigned pad = std::max(inVec, outVec);
+  for (unsigned d = 0; d < rank; ++d) {
+    paddedRepShape[d] = std::max(
+        std::min<unsigned>(srcTy.getShape()[d], getShapePerCTA(srcLayout, d)),
+        std::min<unsigned>(dstTy.getShape()[d], getShapePerCTA(dstLayout, d)));
   }
+  unsigned paddedDim = 1;
+  if (auto dstBlockedLayout = dstLayout.dyn_cast<BlockedEncodingAttr>()) {
+    paddedDim = dstBlockedLayout.getOrder()[0];
+  }
+  paddedRepShape[paddedDim] += pad;
   return paddedRepShape;
 }

diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index 8b9db71a7..4d1ba7330 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -29,7 +29,9 @@ using namespace mlir;
 using namespace mlir::triton;
 using ::mlir::triton::gpu::BlockedEncodingAttr;
 using ::mlir::triton::gpu::getElemsPerThread;
+using ::mlir::triton::gpu::getOrder;
 using ::mlir::triton::gpu::getShapePerCTA;
+using ::mlir::triton::gpu::getSizePerThread;
 using ::mlir::triton::gpu::MmaEncodingAttr;
 using ::mlir::triton::gpu::SharedEncodingAttr;
 using ::mlir::triton::gpu::SliceEncodingAttr;
@@ -350,6 +352,13 @@ public:
     return threadId;
   }
 
+  Value createIndexConst(ConversionPatternRewriter &rewriter, Location loc,
+                         int64_t value) const {
+    return rewriter.create(
+        loc, this->getTypeConverter()->getIndexType(),
+        rewriter.getIntegerAttr(rewriter.getIndexType(), value));
+  }
+
   SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
                                  Location loc, Value linear,
                                  ArrayRef<unsigned> shape,
@@ -423,7 +432,7 @@ public:
     auto order = blocked_layout.getOrder();
     unsigned rank = shape.size();
 
-    // step 1, delinearize threadId to get the base index
+    // delinearize threadId to get the base index
     SmallVector<Value> multiDimWarpId =
         delinearize(rewriter, loc, warpId, warpsPerCTA, order);
     SmallVector<Value> multiDimThreadId =
@@ -455,6 +464,13 @@ public:
     return multiDimBase;
   }
 
+  SmallVector<Value>
+  emitBaseIndexForBlockedLayout(Location loc, ConversionPatternRewriter &b,
+                                const MmaEncodingAttr &mmaLayout,
+                                ArrayRef<int64_t> shape) const {
+    // ongoing
+  }
+
   SmallVector<SmallVector<Value>>
   emitIndices(Location loc, ConversionPatternRewriter &b,
               const Attribute &layout,
@@ -1459,9 +1475,11 @@
     auto dstTy = dst.getType().cast<RankedTensorType>();
     Attribute srcLayout = srcTy.getEncoding();
     Attribute dstLayout = dstTy.getEncoding();
-    if ((!srcLayout.isa<BlockedEncodingAttr>()) ||
-        (!dstLayout.isa<BlockedEncodingAttr>())) {
-      // TODO: not implemented
+    if ((!srcLayout.isa<BlockedEncodingAttr>() &&
+         !srcLayout.isa<MmaEncodingAttr>()) ||
+        (!dstLayout.isa<BlockedEncodingAttr>() &&
+         !dstLayout.isa<MmaEncodingAttr>())) {
+      // TODO: to be implemented
       return failure();
     }
     auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType());
@@ -1471,31 +1489,6 @@
     auto shape = dstTy.getShape();
     unsigned rank = dstTy.getRank();
-    auto getContigPerThread = [&](const Attribute &layout,
-                                  unsigned d) -> unsigned {
-      if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
-        return blockedLayout.getSizePerThread()[d];
-      } else {
-        assert(0 && "Unimplemented usage of getContigPerThread");
-        return 0;
-      }
-    };
-    auto getAccumElemsPerThread = [&](const Attribute &layout) -> unsigned {
-      if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
-        return product(blockedLayout.getSizePerThread());
-      } else {
-        assert(0 && "Unimplemented usage of getAccumElemsPerThread");
-        return 0;
-      }
-    };
-    auto getOrder = [&](const Attribute &layout) -> ArrayRef<unsigned> {
-      if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
-        return blockedLayout.getOrder();
-      } else {
-        assert(0 && "Unimplemented usage of getAccumElemsPerThread");
-        return {};
-      }
-    };
     SmallVector<unsigned> numReplicates(rank);
     SmallVector<unsigned> inNumCTAsEachRep(rank);
     SmallVector<unsigned> outNumCTAsEachRep(rank);
@@ -1517,7 +1510,6 @@
     }
     // Potentially we need to store for multiple CTAs in this replication
     unsigned accumNumReplicates = product(numReplicates);
-    unsigned accumInSizePerThread = getAccumElemsPerThread(srcLayout);
     unsigned elems = getElemsPerThread(srcLayout, srcTy.getShape());
     auto vals = getElementsFromStruct(loc, adaptor.src(), elems, rewriter);
     unsigned inVec = 0;
@@ -1530,19 +1522,21 @@
     for (unsigned repId = 0; repId < accumNumReplicates; ++repId) {
       auto multiDimRepId = getMultiDimIndex(repId, numReplicates);
       rewriter.create(loc);
-      if (auto srcBlockedLayout = srcLayout.dyn_cast<BlockedEncodingAttr>()) {
-        processReplicaBlocked(loc, rewriter, /*stNotRd*/ true, srcTy,
-                              inNumCTAsEachRep, multiDimRepId, inVec,
-                              paddedRepShape, outOrd, vals, smemBase);
+      if (srcLayout.isa<BlockedEncodingAttr>() ||
+          srcLayout.isa<MmaEncodingAttr>()) {
+        processReplica(loc, rewriter, /*stNotRd*/ true, srcTy, inNumCTAsEachRep,
+                       multiDimRepId, inVec, paddedRepShape, outOrd, vals,
+                       smemBase);
       } else {
         assert(0 && "ConvertLayout with input layout not implemented");
         return failure();
       }
       rewriter.create(loc);
-      if (auto dstBlockedLayout = dstLayout.dyn_cast<BlockedEncodingAttr>()) {
-        processReplicaBlocked(loc, rewriter, /*stNotRd*/ false, dstTy,
-                              outNumCTAsEachRep, multiDimRepId, outVec,
-                              paddedRepShape, outOrd, outVals, smemBase);
+      if (dstLayout.isa<BlockedEncodingAttr>() ||
+          dstLayout.isa<MmaEncodingAttr>()) {
+        processReplica(loc, rewriter, /*stNotRd*/ false, dstTy,
+                       outNumCTAsEachRep, multiDimRepId, outVec, paddedRepShape,
+                       outOrd, outVals, smemBase);
       } else {
         assert(0 && "ConvertLayout with output layout not implemented");
         return failure();
       }
@@ -1568,30 +1562,58 @@ private:
     return result;
   };
 
-  void processReplicaBlocked(Location loc, ConversionPatternRewriter &rewriter,
-                             bool stNotRd, RankedTensorType type,
-                             ArrayRef<unsigned> numCTAsEachRep,
-                             ArrayRef<unsigned> multiDimRepId, unsigned vec,
-                             ArrayRef<unsigned> paddedRepShape,
-                             ArrayRef<unsigned> outOrd,
-                             SmallVector<Value> &vals, Value smemBase) const {
+  // shared memory access for blocked or mma layout
+  void processReplica(Location loc, ConversionPatternRewriter &rewriter,
+                      bool stNotRd, RankedTensorType type,
+                      ArrayRef<unsigned> numCTAsEachRep,
+                      ArrayRef<unsigned> multiDimRepId, unsigned vec,
+                      ArrayRef<unsigned> paddedRepShape,
+                      ArrayRef<unsigned> outOrd, SmallVector<Value> &vals,
+                      Value smemBase) const {
     unsigned accumNumCTAsEachRep = product(numCTAsEachRep);
-    auto layout = type.getEncoding().cast<BlockedEncodingAttr>();
+    auto layout = type.getEncoding();
+    auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>();
+    auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>();
     auto rank = type.getRank();
-    auto sizePerThread = layout.getSizePerThread();
+    auto sizePerThread = getSizePerThread(layout);
     auto accumSizePerThread = product(sizePerThread);
     auto llvmIndexTy = getTypeConverter()->getIndexType();
     SmallVector<unsigned> numCTAs(rank);
     SmallVector<unsigned> shapePerCTA(rank);
     for (unsigned d = 0; d < rank; ++d) {
-      shapePerCTA[d] = layout.getSizePerThread()[d] *
-                       layout.getThreadsPerWarp()[d] *
-                       layout.getWarpsPerCTA()[d];
+      shapePerCTA[d] = getShapePerCTA(layout, d);
       numCTAs[d] = ceil(type.getShape()[d], shapePerCTA[d]);
     }
     auto llvmElemTy = getTypeConverter()->convertType(type.getElementType());
-    auto multiDimOffsetFirstElem =
-        emitBaseIndexForBlockedLayout(loc, rewriter, layout, type.getShape());
+    SmallVector<Value> multiDimOffsetFirstElem;
+    Value mmaGrpId;
+    Value mmaGrpIdP8;
+    Value mmaThreadIdInGrpM2;
+    Value mmaThreadIdInGrpM2P1;
+    if (blockedLayout) {
+      multiDimOffsetFirstElem = emitBaseIndexForBlockedLayout(
+          loc, rewriter, blockedLayout, type.getShape());
+    } else if (mmaLayout) {
+      // TODO: simplify these
+      auto cast = rewriter.create(
+          loc, TypeRange{llvmIndexTy},
+          ValueRange{rewriter.create<::mlir::gpu::ThreadIdOp>(
+              loc, rewriter.getIndexType(), ::mlir::gpu::Dimension::x)});
+      Value threadId = cast.getResult(0);
+      Value warpSize = createIndexAttrConstant(
+          rewriter, loc, this->getTypeConverter()->getIndexType(), 32);
+      Value laneId = rewriter.create(loc, threadId, warpSize);
+      Value fourVal = createIndexConst(rewriter, loc, 4);
+      mmaGrpId = rewriter.create(loc, laneId, fourVal);
+      mmaGrpIdP8 = rewriter.create(
+          loc, mmaGrpId, createIndexConst(rewriter, loc, 8));
+      Value mmaThreadIdInGrp =
+          rewriter.create(loc, laneId, fourVal);
+      mmaThreadIdInGrpM2 = rewriter.create(
+          loc, mmaThreadIdInGrp, createIndexConst(rewriter, loc, 2));
+      mmaThreadIdInGrpM2P1 = rewriter.create(
+          loc, mmaThreadIdInGrpM2, createIndexConst(rewriter, loc, 1));
+    }
 
     for (unsigned ctaId = 0; ctaId < accumNumCTAsEachRep; ++ctaId) {
       auto multiDimCTAInRepId = getMultiDimIndex(ctaId, numCTAsEachRep);
@@ -1605,18 +1627,27 @@
         // TODO: This is actually redundant index calculation, we should
         //       consider of caching the index calculation result in case
        //       of performance issue observed.
-        // for (unsigned elemId = linearCTAId * accumSizePerThread;
-        //      elemId < (linearCTAId + 1) * accumSizePerThread; elemId += vec) {
         for (unsigned elemId = 0; elemId < accumSizePerThread; elemId += vec) {
-          auto multiDimElemId =
-              getMultiDimIndex(elemId, layout.getSizePerThread());
           SmallVector<Value> multiDimOffset(rank);
-          for (unsigned d = 0; d < rank; ++d) {
-            multiDimOffset[d] = add(
-                multiDimOffsetFirstElem[d],
-                createIndexAttrConstant(rewriter, loc, llvmIndexTy,
-                                        multiDimCTAInRepId[d] * shapePerCTA[d] +
-                                            multiDimElemId[d]));
+          if (blockedLayout) {
+            SmallVector<unsigned> multiDimElemId = getMultiDimIndex(
+                elemId, blockedLayout.getSizePerThread());
+            for (unsigned d = 0; d < rank; ++d) {
+              multiDimOffset[d] = rewriter.create(
+                  loc, multiDimOffsetFirstElem[d],
+                  createIndexAttrConstant(rewriter, loc, llvmIndexTy,
+                                          multiDimCTAInRepId[d] * shapePerCTA[d] +
+                                              multiDimElemId[d]));
+            }
+          } else if (mmaLayout) {
+            assert(rank == 2);
+            assert(mmaLayout.getVersion() == 2 &&
+                   "mmaLayout ver1 not implemented yet");
+            multiDimOffset[0] = elemId < 2 ? mmaGrpId : mmaGrpIdP8;
+            multiDimOffset[1] =
+                elemId % 2 == 0 ? mmaThreadIdInGrpM2 : mmaThreadIdInGrpM2P1;
+          } else {
+            assert(0 && "unexpected layout in processReplica");
           }
           Value offset =
               linearize(rewriter, loc, reorder(multiDimOffset, outOrd),
@@ -2517,16 +2548,14 @@ public:
   llvm::Optional<Type> convertTritonTensorType(RankedTensorType type) {
     Attribute layout = type.getEncoding();
-    if (layout && (layout.isa<BlockedEncodingAttr>() ||
-                   layout.isa<SliceEncodingAttr>())) {
+    if (layout &&
+        (layout.isa<BlockedEncodingAttr>() || layout.isa<SliceEncodingAttr>() ||
+         layout.isa<MmaEncodingAttr>())) {
       unsigned numElementsPerThread =
           getElemsPerThread(layout, type.getShape());
       SmallVector<Type> types(numElementsPerThread,
                               convertType(type.getElementType()));
       return LLVM::LLVMStructType::getLiteral(&getContext(), types);
-    } else if (auto mma_layout = layout.dyn_cast_or_null<MmaEncodingAttr>()) {
-      // TODO: Not implemented
-      return type;
     } else if (auto shared_layout =
                    layout.dyn_cast_or_null<SharedEncodingAttr>()) {
       return LLVM::LLVMPointerType::get(convertType(type.getElementType()), 3);

diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index a590e9137..bcfa3176f 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -58,17 +58,56 @@ unsigned getElemsPerThread(Attribute layout, ArrayRef<int64_t> shape) {
   }
 }
 
+SmallVector<unsigned> getSizePerThread(Attribute layout) {
+  if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
+    return SmallVector<unsigned>(blockedLayout.getSizePerThread().begin(),
+                                 blockedLayout.getSizePerThread().end());
+  } else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
+    assert(mmaLayout.getVersion() == 2 &&
+           "mmaLayout version = 1 is not implemented yet");
+    return SmallVector<unsigned>{2, 2};
+  } else {
+    assert(0 && "getSizePerThread not implemented");
+    return {};
+  }
+}
+
 unsigned getShapePerCTA(const Attribute &layout, unsigned d) {
   if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
     return blockedLayout.getSizePerThread()[d] *
            blockedLayout.getThreadsPerWarp()[d] *
            blockedLayout.getWarpsPerCTA()[d];
+  } else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
+    assert(mmaLayout.getVersion() == 2 &&
+           "mmaLayout version = 1 is not implemented yet");
+    assert(d < 2 && "Unexpected usage of getShapePerCTA");
+    if (d == 0) {
+      return 16 * mmaLayout.getWarpsPerCTA()[0];
+    } else {
+      // d == 1
+      return 8 * mmaLayout.getWarpsPerCTA()[1];
+    }
   } else {
     assert(0 && "Unimplemented usage of getShapePerCTA");
     return 0;
   }
 };
 
+SmallVector<unsigned> getOrder(const Attribute &layout) {
+  if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
+    return SmallVector<unsigned>(blockedLayout.getOrder().begin(),
+                                 blockedLayout.getOrder().end());
+  } else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
+    return SmallVector<unsigned>{1, 0};
+  } else if (auto sharedLayout = layout.dyn_cast<SharedEncodingAttr>()) {
+    return SmallVector<unsigned>(sharedLayout.getOrder().begin(),
+                                 sharedLayout.getOrder().end());
+  } else {
+    assert(0 && "Unimplemented usage of getOrder");
+    return {};
+  }
+};
+
 } // namespace gpu
 } // namespace triton
 } // namespace mlir
@@ -177,9 +216,11 @@ unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
 }
 
 unsigned MmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
-  int threads = product(getWarpsPerCTA());
-  int numElem = product(shape);
-  return numElem / threads;
+  size_t rank = shape.size();
+  assert(rank == 2 && "Unexpected rank of mma layout");
+  unsigned elemsCol = ceil(shape[0], 16 * getWarpsPerCTA()[0]) * 2;
+  unsigned elemsRow = ceil(shape[1], 8 * getWarpsPerCTA()[1]) * 2;
+  return elemsCol * elemsRow;
 }
 
 unsigned SharedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {

diff --git a/python/triton/compiler.py b/python/triton/compiler.py
index e7d0b1318..d0441d921 100644
--- a/python/triton/compiler.py
+++ b/python/triton/compiler.py
@@ -1127,6 +1127,10 @@ def default_cache_dir():
     return os.path.join(os.environ["HOME"], ".triton", "cache")
 
 
+def default_cuda_dir():
+    return os.path.join("/usr", "local", "cuda")
+
+
 class CacheManager:
 
     def __init__(self, key):
@@ -1181,7 +1185,8 @@ def quiet():
 
 def _build(name, src, srcdir):
     cuda_lib_dir = libcuda_dir()
-    cu_include_dir = "/usr/local/cuda/include"
+    cuda_path = os.environ.get('CUDA_PATH', default_cuda_dir())
+    cu_include_dir = os.path.join(cuda_path, "include")
     suffix = sysconfig.get_config_var('EXT_SUFFIX')
     so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix))
     # try to avoid setuptools if possible

diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir
index 7f5050ec5..d687a857f 100644
--- a/test/Conversion/tritongpu_to_llvm.mlir
+++ b/test/Conversion/tritongpu_to_llvm.mlir
@@ -486,7 +486,6 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
   }
 }
 
-
 // TODO: problems in MLIR's parser on slice layout
 // #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}>
 // module attributes {"triton_gpu.num-warps" = 1 : i32} {
@@ -495,3 +494,24 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
 //     return
 //   }
 // }
+
+// -----
+
+#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}>
+#mma = #triton_gpu.mma<{version = 2, warpsPerCTA = [2, 2]}>
+module attributes {"triton_gpu.num-warps" = 1 : i32} {
+  // CHECK: llvm.mlir.global internal @global_smem() {addr_space = 3 : i32} : !llvm.array<2560 x i8>
+  // CHECK-LABEL: convert_layout_mma_block
+  func @convert_layout_mma_blocked(%arg0: tensor<32x16xf32, #mma>) {
+    // CHECK: nvvm.barrier0
+    // CHECK: llvm.store
+    // CHECK-SAME: !llvm.ptr, 3>
+    // CHECK: llvm.store
+    // CHECK-SAME: !llvm.ptr, 3>
+    // CHECK: nvvm.barrier0
+    // CHECK: llvm.load
+    // CHECK-SAME: !llvm.ptr, 3>
+    %0 = triton_gpu.convert_layout %arg0 : (tensor<32x16xf32, #mma>) -> tensor<32x16xf32, #blocked0>
+    return
+  }
+}