From 1d3029faf8e4bff85b1d203028cf72203ad8f7e5 Mon Sep 17 00:00:00 2001 From: goostavz <109190422+goostavz@users.noreply.github.com> Date: Fri, 30 Dec 2022 03:19:59 +0800 Subject: [PATCH] [Backend] Add value cache in emitting indices calculation and some refinement (#1018) 1. Add an explicit value cache for the emitted indices calculation; 2. Move the logic that emits the indices calculation into ConvertTritonGPUOpToLLVMPatternBase, so that it is no longer re-instantiated by every pattern template, which avoids the redundant build cost. Refer to the discussion in this thread by @LyricZhao: https://triton-lang.slack.com/archives/C042VBSQWNS/p1671336755922969 --- .../TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp | 15 +- .../TritonGPUToLLVM/ConvertLayoutOpToLLVM.h | 4 +- .../TritonGPUToLLVM/LoadStoreOpToLLVM.cpp | 31 +- .../TritonGPUToLLVM/LoadStoreOpToLLVM.h | 12 +- .../TritonGPUToLLVM/ReduceOpToLLVM.cpp | 16 +- .../TritonGPUToLLVM/ReduceOpToLLVM.h | 11 +- .../TritonGPUToLLVM/TritonGPUToLLVM.cpp | 24 +- .../TritonGPUToLLVM/TritonGPUToLLVM.h | 11 +- .../TritonGPUToLLVM/TritonGPUToLLVMBase.h | 323 ++++++++++++------ .../TritonGPUToLLVM/TritonGPUToLLVMPass.cpp | 19 +- test/Conversion/tritongpu_to_llvm.mlir | 69 +++- 11 files changed, 355 insertions(+), 180 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 1a85ca5da..cd18ed751 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -53,9 +53,6 @@ void storeBlockedToShared(Value src, Value llSrc, ArrayRef srcStrides, auto wordTy = vec_ty(elemTy, minVec); auto elemPtrTy = ptr_ty(elemTy); - // TODO: [goostavz] We should make a cache for the calculation of - // emitBaseIndexForBlockedLayout in case backend compiler not being able to - // optimize that SmallVector srcShapePerCTA = getShapePerCTA(srcBlockedLayout); SmallVector reps{ceil(srcShape[0], srcShapePerCTA[0]), ceil(srcShape[1], srcShapePerCTA[1])}; @@ -182,7 +179,7
@@ private: unsigned rank = shape.size(); if (auto blockedLayout = layout.dyn_cast()) { auto multiDimOffsetFirstElem = - emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, shape); + emitBaseIndexForLayout(loc, rewriter, blockedLayout, shape); SmallVector multiDimOffset(rank); SmallVector multiDimElemId = getMultiDimIndex( elemId, getSizePerThread(layout), getOrder(layout)); @@ -501,8 +498,8 @@ private: auto srcStrides = getStridesFromShapeAndOrder(srcShape, inOrd, loc, rewriter); - auto srcIndices = emitBaseIndexForBlockedLayout(loc, rewriter, - srcBlockedLayout, srcShape); + auto srcIndices = + emitBaseIndexForLayout(loc, rewriter, srcBlockedLayout, srcShape); storeBlockedToShared(src, adaptor.src(), srcStrides, srcIndices, dst, smemBase, elemTy, loc, rewriter); @@ -680,7 +677,9 @@ private: void populateConvertLayoutOpToLLVMPatterns( mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, int numWarps, AxisInfoAnalysis &axisInfoAnalysis, - const Allocation *allocation, Value smem, PatternBenefit benefit) { + const Allocation *allocation, Value smem, + ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo, + PatternBenefit benefit) { patterns.add(typeConverter, allocation, smem, - benefit); + indexCacheInfo, benefit); } diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.h b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.h index ebf943b6f..ec435b2ab 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.h +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.h @@ -19,6 +19,8 @@ void storeBlockedToShared(Value src, Value llSrc, ArrayRef srcStrides, void populateConvertLayoutOpToLLVMPatterns( mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, int numWarps, AxisInfoAnalysis &axisInfoAnalysis, - const Allocation *allocation, Value smem, PatternBenefit benefit); + const Allocation *allocation, Value smem, + ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo, + 
PatternBenefit benefit); #endif diff --git a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp index 97ce9457a..92b11a94c 100644 --- a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -14,7 +14,7 @@ using ::mlir::triton::gpu::getElemsPerThread; using ::mlir::triton::gpu::SharedEncodingAttr; // Contains some helper functions for both Load and Store conversions. -struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase { +struct LoadStoreConversionBase { explicit LoadStoreConversionBase(AxisInfoAnalysis &axisAnalysisPass) : axisAnalysisPass(axisAnalysisPass) {} @@ -640,7 +640,7 @@ struct InsertSliceOpConversion auto llSrc = adaptor.source(); auto srcIndices = - emitBaseIndexForBlockedLayout(loc, rewriter, srcLayout, srcShape); + emitBaseIndexForLayout(loc, rewriter, srcLayout, srcShape); storeBlockedToShared(src, llSrc, srcStrides, srcIndices, dst, smemBase, elemTy, loc, rewriter); // Barrier is not necessary. 
@@ -657,12 +657,12 @@ struct InsertSliceAsyncOpConversion using ConvertTritonGPUOpToLLVMPattern< triton::gpu::InsertSliceAsyncOp>::ConvertTritonGPUOpToLLVMPattern; - InsertSliceAsyncOpConversion(LLVMTypeConverter &converter, - const Allocation *allocation, Value smem, - AxisInfoAnalysis &axisAnalysisPass, - PatternBenefit benefit) + InsertSliceAsyncOpConversion( + LLVMTypeConverter &converter, const Allocation *allocation, Value smem, + ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo, + AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit) : ConvertTritonGPUOpToLLVMPattern( - converter, allocation, smem, benefit), + converter, allocation, smem, indexCacheInfo, benefit), LoadStoreConversionBase(axisAnalysisPass) {} LogicalResult @@ -865,12 +865,12 @@ struct InsertSliceAsyncOpConversion } }; -void populateLoadStoreOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter, - RewritePatternSet &patterns, - int numWarps, - AxisInfoAnalysis &axisInfoAnalysis, - const Allocation *allocation, Value smem, - PatternBenefit benefit) { +void populateLoadStoreOpToLLVMPatterns( + mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + int numWarps, AxisInfoAnalysis &axisInfoAnalysis, + const Allocation *allocation, Value smem, + ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo, + PatternBenefit benefit) { patterns.add(typeConverter, axisInfoAnalysis, benefit); patterns.add(typeConverter, axisInfoAnalysis, benefit); patterns.add(typeConverter, allocation, smem, @@ -878,7 +878,8 @@ void populateLoadStoreOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter, patterns.add(typeConverter, allocation, smem, axisInfoAnalysis, benefit); patterns.add(typeConverter, allocation, smem, - benefit); + indexCacheInfo, benefit); patterns.add(typeConverter, allocation, smem, - axisInfoAnalysis, benefit); + indexCacheInfo, axisInfoAnalysis, + benefit); } diff --git a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.h 
b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.h index 96c2f1afd..b5042019e 100644 --- a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.h +++ b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.h @@ -6,11 +6,11 @@ using namespace mlir; using namespace mlir::triton; -void populateLoadStoreOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter, - RewritePatternSet &patterns, - int numWarps, - AxisInfoAnalysis &axisInfoAnalysis, - const Allocation *allocation, Value smem, - PatternBenefit benefit); +void populateLoadStoreOpToLLVMPatterns( + mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + int numWarps, AxisInfoAnalysis &axisInfoAnalysis, + const Allocation *allocation, Value smem, + ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo, + PatternBenefit benefit); #endif diff --git a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp index 5f055fa6f..69abd889b 100644 --- a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp @@ -164,7 +164,7 @@ private: auto srcValues = getElementsFromStruct(loc, adaptor.operand(), rewriter); SmallVector> offset = - emitOffsetForBlockedLayout(srcLayout, srcShape); + emitOffsetForLayout(srcLayout, srcShape); std::map, Value> accs; std::map, Value> accIndices; @@ -479,10 +479,12 @@ private: } }; -void populateReduceOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter, - RewritePatternSet &patterns, int numWarps, - AxisInfoAnalysis &axisInfoAnalysis, - const Allocation *allocation, Value smem, - PatternBenefit benefit) { - patterns.add(typeConverter, allocation, smem, benefit); +void populateReduceOpToLLVMPatterns( + mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + int numWarps, AxisInfoAnalysis &axisInfoAnalysis, + const Allocation *allocation, Value smem, + ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo, + PatternBenefit benefit) { + 
patterns.add(typeConverter, allocation, smem, + indexCacheInfo, benefit); } diff --git a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.h b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.h index fc4c5145c..f2c0af463 100644 --- a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.h +++ b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.h @@ -6,10 +6,11 @@ using namespace mlir; using namespace mlir::triton; -void populateReduceOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter, - RewritePatternSet &patterns, int numWarps, - AxisInfoAnalysis &axisInfoAnalysis, - const Allocation *allocation, Value smem, - PatternBenefit benefit); +void populateReduceOpToLLVMPatterns( + mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + int numWarps, AxisInfoAnalysis &axisInfoAnalysis, + const Allocation *allocation, Value smem, + ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo, + PatternBenefit benefit); #endif \ No newline at end of file diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index 3fcb83d95..2261688f0 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -63,6 +63,7 @@ struct BroadcastOpConversion auto srcShape = srcTy.getShape(); auto resultShape = resultTy.getShape(); unsigned rank = srcTy.getRank(); + assert(rank == resultTy.getRank()); auto order = triton::gpu::getOrder(srcLayout); auto srcOffsets = emitOffsetForLayout(srcLayout, srcShape); @@ -272,9 +273,13 @@ struct PrintfOpConversion struct MakeRangeOpConversion : public ConvertTritonGPUOpToLLVMPattern { - MakeRangeOpConversion(LLVMTypeConverter &converter, PatternBenefit benefit) - : ConvertTritonGPUOpToLLVMPattern(converter, - benefit) {} + MakeRangeOpConversion( + LLVMTypeConverter &converter, + ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo, + PatternBenefit benefit) + : ConvertTritonGPUOpToLLVMPattern( + converter, 
/*Allocation*/ nullptr, Value{}, indexCacheInfo, + benefit) {} LogicalResult matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor, @@ -500,11 +505,12 @@ struct AsyncWaitOpConversion } }; -void populateTritonGPUToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter, - RewritePatternSet &patterns, int numWarps, - AxisInfoAnalysis &axisInfoAnalysis, - const Allocation *allocation, Value smem, - PatternBenefit benefit) { +void populateTritonGPUToLLVMPatterns( + mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + int numWarps, AxisInfoAnalysis &axisInfoAnalysis, + const Allocation *allocation, Value smem, + ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo, + PatternBenefit benefit) { patterns.add(typeConverter, benefit); patterns.add(typeConverter, allocation, smem, benefit); @@ -515,7 +521,7 @@ void populateTritonGPUToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter, benefit); patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); - patterns.add(typeConverter, benefit); + patterns.add(typeConverter, indexCacheInfo, benefit); patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); } \ No newline at end of file diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h index e96330176..2a6e22bf0 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h @@ -6,10 +6,11 @@ using namespace mlir; using namespace mlir::triton; -void populateTritonGPUToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter, - RewritePatternSet &patterns, int numWarps, - AxisInfoAnalysis &axisInfoAnalysis, - const Allocation *allocation, Value smem, - PatternBenefit benefit); +void populateTritonGPUToLLVMPatterns( + mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + int numWarps, AxisInfoAnalysis &axisInfoAnalysis, + const Allocation *allocation, Value smem, + 
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo, + PatternBenefit benefit); #endif diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h index 6020c9617..7f11a762e 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h @@ -18,7 +18,6 @@ using ::mlir::LLVM::SharedMemoryObject; using ::mlir::triton::gpu::BlockedEncodingAttr; using ::mlir::triton::gpu::MmaEncodingAttr; using ::mlir::triton::gpu::SliceEncodingAttr; - // FuncOpConversion/FuncOpConversionBase is borrowed from // https://github.com/llvm/llvm-project/blob/fae656b2dd80246c3c6f01e9c77c49560368752c/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp#L276 // since it is not exposed on header files in mlir v14 @@ -128,7 +127,60 @@ protected: } }; -struct ConvertTritonGPUOpToLLVMPatternBase { +using IndexCacheKeyT = std::pair>; + +struct CacheKeyDenseMapInfo { + static IndexCacheKeyT getEmptyKey() { + auto *pointer = llvm::DenseMapInfo::getEmptyKey(); + return std::make_pair( + mlir::Attribute(static_cast(pointer)), + SmallVector{}); + } + static IndexCacheKeyT getTombstoneKey() { + auto *pointer = llvm::DenseMapInfo::getTombstoneKey(); + return std::make_pair( + mlir::Attribute(static_cast(pointer)), + SmallVector{std::numeric_limits::max()}); + } + static unsigned getHashValue(IndexCacheKeyT key) { + return llvm::hash_combine( + mlir::hash_value(key.first), + llvm::hash_combine_range(key.second.begin(), key.second.end())); + } + static bool isEqual(IndexCacheKeyT LHS, IndexCacheKeyT RHS) { + return LHS == RHS; + } +}; + +class ConvertTritonGPUOpToLLVMPatternBase { +public: + // Two levels of value cache in emitting indices calculation: + // Key: pair + struct IndexCacheInfo { + DenseMap, CacheKeyDenseMapInfo> + *baseIndexCache; + DenseMap>, + CacheKeyDenseMapInfo> *indexCache; + OpBuilder::InsertPoint *indexInsertPoint; + }; + + explicit 
ConvertTritonGPUOpToLLVMPatternBase(LLVMTypeConverter &typeConverter) + : converter(&typeConverter) {} + + explicit ConvertTritonGPUOpToLLVMPatternBase(LLVMTypeConverter &typeConverter, + const Allocation *allocation, + Value smem) + : converter(&typeConverter), allocation(allocation), smem(smem) {} + + explicit ConvertTritonGPUOpToLLVMPatternBase(LLVMTypeConverter &typeConverter, + const Allocation *allocation, + Value smem, + IndexCacheInfo indexCacheInfo) + : converter(&typeConverter), indexCacheInfo(indexCacheInfo), + allocation(allocation), smem(smem) {} + + LLVMTypeConverter *getTypeConverter() const { return converter; } + static Value getStructFromSharedMemoryObject(Location loc, const SharedMemoryObject &smemObj, @@ -139,25 +191,6 @@ struct ConvertTritonGPUOpToLLVMPatternBase { LLVM::LLVMStructType::getLiteral(rewriter.getContext(), types); return getStructFromElements(loc, elems, rewriter, structTy); } -}; - -template -class ConvertTritonGPUOpToLLVMPattern - : public ConvertOpToLLVMPattern, - public ConvertTritonGPUOpToLLVMPatternBase { -public: - using OpAdaptor = typename SourceOp::Adaptor; - - explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter, - PatternBenefit benefit = 1) - : ConvertOpToLLVMPattern(typeConverter, benefit) {} - - explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter, - const Allocation *allocation, - Value smem, - PatternBenefit benefit = 1) - : ConvertOpToLLVMPattern(typeConverter, benefit), - allocation(allocation), smem(smem) {} Value getThreadId(ConversionPatternRewriter &rewriter, Location loc) const { auto llvmIndexTy = this->getTypeConverter()->getIndexType(); @@ -169,6 +202,23 @@ public: return threadId; } + // ----------------------------------------------------------------------- + // Shared memory utilities + // ----------------------------------------------------------------------- + template + Value getSharedMemoryBase(Location loc, ConversionPatternRewriter &rewriter, + T value) 
const { + + auto ptrTy = LLVM::LLVMPointerType::get( + this->getTypeConverter()->convertType(rewriter.getI8Type()), 3); + auto bufferId = allocation->getBufferId(value); + assert(bufferId != Allocation::InvalidBufferId && "BufferId not found"); + size_t offset = allocation->getOffset(bufferId); + Value offVal = idx_val(offset); + Value base = gep(ptrTy, smem, offVal); + return base; + } + // ----------------------------------------------------------------------- // Utilities // ----------------------------------------------------------------------- @@ -242,6 +292,116 @@ public: return ret; } + struct SmallVectorKeyInfo { + static unsigned getHashValue(const SmallVector &key) { + return llvm::hash_combine_range(key.begin(), key.end()); + } + static bool isEqual(const SmallVector &lhs, + const SmallVector &rhs) { + return lhs == rhs; + } + static SmallVector getEmptyKey() { + return SmallVector(); + } + static SmallVector getTombstoneKey() { + return {std::numeric_limits::max()}; + } + }; + + // ----------------------------------------------------------------------- + // Get offsets / indices for any layout + // ----------------------------------------------------------------------- + + SmallVector emitBaseIndexForLayout(Location loc, + ConversionPatternRewriter &rewriter, + const Attribute &layout, + ArrayRef shape) const { + IndexCacheKeyT key = std::make_pair(layout, llvm::to_vector(shape)); + auto cache = indexCacheInfo.baseIndexCache; + assert(cache && "baseIndexCache is nullptr"); + auto insertPt = indexCacheInfo.indexInsertPoint; + if (cache->count(key) > 0) { + return cache->lookup(key); + } else { + ConversionPatternRewriter::InsertionGuard guard(rewriter); + restoreInsertionPointIfSet(insertPt, rewriter); + SmallVector result; + if (auto blockedLayout = layout.dyn_cast()) { + result = + emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, shape); + } else if (auto mmaLayout = layout.dyn_cast()) { + if (mmaLayout.isVolta()) + result = 
emitBaseIndexForMmaLayoutV1(loc, rewriter, mmaLayout, shape); + if (mmaLayout.isAmpere()) + result = emitBaseIndexForMmaLayoutV2(loc, rewriter, mmaLayout, shape); + } else { + llvm_unreachable("unsupported emitBaseIndexForLayout"); + } + cache->insert(std::make_pair(key, result)); + *insertPt = rewriter.saveInsertionPoint(); + return result; + } + } + + SmallVector> + emitOffsetForLayout(const Attribute &layout, ArrayRef shape) const { + if (auto blockedLayout = layout.dyn_cast()) + return emitOffsetForBlockedLayout(blockedLayout, shape); + if (auto mmaLayout = layout.dyn_cast()) { + if (mmaLayout.isVolta()) + return emitOffsetForMmaLayoutV1(mmaLayout, shape); + if (mmaLayout.isAmpere()) + return emitOffsetForMmaLayoutV2(mmaLayout, shape); + } + llvm_unreachable("unsupported emitOffsetForLayout"); + } + + // ----------------------------------------------------------------------- + // Emit indices + // ----------------------------------------------------------------------- + SmallVector> emitIndices(Location loc, + ConversionPatternRewriter &b, + const Attribute &layout, + ArrayRef shape) const { + IndexCacheKeyT key(layout, llvm::to_vector(shape)); + auto cache = indexCacheInfo.indexCache; + assert(cache && "indexCache is nullptr"); + auto insertPt = indexCacheInfo.indexInsertPoint; + if (cache->count(key) > 0) { + return cache->lookup(key); + } else { + ConversionPatternRewriter::InsertionGuard guard(b); + restoreInsertionPointIfSet(insertPt, b); + SmallVector> result; + if (auto blocked = layout.dyn_cast()) { + result = emitIndicesForDistributedLayout(loc, b, blocked, shape); + } else if (auto mma = layout.dyn_cast()) { + result = emitIndicesForDistributedLayout(loc, b, mma, shape); + } else if (auto slice = layout.dyn_cast()) { + result = emitIndicesForSliceLayout(loc, b, slice, shape); + } else { + llvm_unreachable( + "emitIndices for layouts other than blocked & slice not " + "implemented yet"); + } + cache->insert(std::make_pair(key, result)); + *insertPt = 
b.saveInsertionPoint(); + return result; + } + } + +private: + void restoreInsertionPointIfSet(OpBuilder::InsertPoint *insertPt, + ConversionPatternRewriter &rewriter) const { + if (insertPt->isSet()) { + rewriter.restoreInsertionPoint(*insertPt); + } else { + auto func = + rewriter.getInsertionPoint()->getParentOfType(); + rewriter.setInsertionPointToStart(&func.getBody().front()); + } + } + // ----------------------------------------------------------------------- // Blocked layout indices // ----------------------------------------------------------------------- @@ -411,38 +571,6 @@ public: return ret; } - // ----------------------------------------------------------------------- - // Get offsets / indices for any layout - // ----------------------------------------------------------------------- - - SmallVector emitBaseIndexForLayout(Location loc, - ConversionPatternRewriter &rewriter, - const Attribute &layout, - ArrayRef shape) const { - if (auto blockedLayout = layout.dyn_cast()) - return emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, shape); - if (auto mmaLayout = layout.dyn_cast()) { - if (mmaLayout.isVolta()) - return emitBaseIndexForMmaLayoutV1(loc, rewriter, mmaLayout, shape); - if (mmaLayout.isAmpere()) - return emitBaseIndexForMmaLayoutV2(loc, rewriter, mmaLayout, shape); - } - llvm_unreachable("unsupported emitBaseIndexForLayout"); - } - - SmallVector> - emitOffsetForLayout(const Attribute &layout, ArrayRef shape) const { - if (auto blockedLayout = layout.dyn_cast()) - return emitOffsetForBlockedLayout(blockedLayout, shape); - if (auto mmaLayout = layout.dyn_cast()) { - if (mmaLayout.isVolta()) - return emitOffsetForMmaLayoutV1(mmaLayout, shape); - if (mmaLayout.isAmpere()) - return emitOffsetForMmaLayoutV2(mmaLayout, shape); - } - llvm_unreachable("unsupported emitOffsetForLayout"); - } - // Emit indices calculation within each ConversionPattern, and returns a // [elemsPerThread X rank] index matrix. 
@@ -470,22 +598,6 @@ public: return multiDimIdx; } - struct SmallVectorKeyInfo { - static unsigned getHashValue(const SmallVector &key) { - return llvm::hash_combine_range(key.begin(), key.end()); - } - static bool isEqual(const SmallVector &lhs, - const SmallVector &rhs) { - return lhs == rhs; - } - static SmallVector getEmptyKey() { - return SmallVector(); - } - static SmallVector getTombstoneKey() { - return {std::numeric_limits::max()}; - } - }; - SmallVector> emitIndicesForSliceLayout(Location loc, ConversionPatternRewriter &rewriter, const SliceEncodingAttr &sliceLayout, @@ -505,46 +617,45 @@ public: return resultIndices; } - // ----------------------------------------------------------------------- - // Emit indices - // ----------------------------------------------------------------------- - SmallVector> emitIndices(Location loc, - ConversionPatternRewriter &b, - const Attribute &layout, - ArrayRef shape) const { - if (auto blocked = layout.dyn_cast()) { - return emitIndicesForDistributedLayout(loc, b, blocked, shape); - } else if (auto mma = layout.dyn_cast()) { - return emitIndicesForDistributedLayout(loc, b, mma, shape); - } else if (auto slice = layout.dyn_cast()) { - return emitIndicesForSliceLayout(loc, b, slice, shape); - } else { - assert(0 && "emitIndices for layouts other than blocked & slice not " - "implemented yet"); - return {}; - } - } - - // ----------------------------------------------------------------------- - // Shared memory utilities - // ----------------------------------------------------------------------- - template - Value getSharedMemoryBase(Location loc, ConversionPatternRewriter &rewriter, - T value) const { - - auto ptrTy = LLVM::LLVMPointerType::get( - this->getTypeConverter()->convertType(rewriter.getI8Type()), 3); - auto bufferId = allocation->getBufferId(value); - assert(bufferId != Allocation::InvalidBufferId && "BufferId not found"); - size_t offset = allocation->getOffset(bufferId); - Value offVal = idx_val(offset); - 
Value base = gep(ptrTy, smem, offVal); - return base; - } - protected: + LLVMTypeConverter *converter; const Allocation *allocation; Value smem; + IndexCacheInfo indexCacheInfo; +}; + +template +class ConvertTritonGPUOpToLLVMPattern + : public ConvertOpToLLVMPattern, + public ConvertTritonGPUOpToLLVMPatternBase { +public: + using OpAdaptor = typename SourceOp::Adaptor; + + explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit), + ConvertTritonGPUOpToLLVMPatternBase(typeConverter) {} + + explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter, + const Allocation *allocation, + Value smem, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit), + ConvertTritonGPUOpToLLVMPatternBase(typeConverter, allocation, smem) {} + + explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter, + const Allocation *allocation, + Value smem, + IndexCacheInfo indexCacheInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit), + ConvertTritonGPUOpToLLVMPatternBase(typeConverter, allocation, smem, + indexCacheInfo) {} + +protected: + LLVMTypeConverter *getTypeConverter() const { + return ((ConvertTritonGPUOpToLLVMPatternBase *)this)->getTypeConverter(); + } }; #endif diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp index 797cd6f6d..897ab913d 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp @@ -170,16 +170,20 @@ public: // We set a higher benefit here to ensure triton's patterns runs before // arith patterns for some encoding not supported by the community // patterns. 
+ OpBuilder::InsertPoint indexInsertPoint; + ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo indexCacheInfo{ + &baseIndexCache, &indexCache, &indexInsertPoint}; + RewritePatternSet patterns(context); // Normal conversions populateTritonGPUToLLVMPatterns(typeConverter, patterns, numWarps, axisInfoAnalysis, &allocation, smem, - /*benefit=*/10); + indexCacheInfo, /*benefit=*/10); // ConvertLayoutOp populateConvertLayoutOpToLLVMPatterns(typeConverter, patterns, numWarps, axisInfoAnalysis, &allocation, smem, - /*benefit=*/10); + indexCacheInfo, /*benefit=*/10); // DotOp populateDotOpToLLVMPatterns(typeConverter, patterns, numWarps, axisInfoAnalysis, &allocation, smem, @@ -191,11 +195,11 @@ public: // LoadStoreOp populateLoadStoreOpToLLVMPatterns(typeConverter, patterns, numWarps, axisInfoAnalysis, &allocation, smem, - /*benefit=*/10); + indexCacheInfo, /*benefit=*/10); // ReduceOp populateReduceOpToLLVMPatterns(typeConverter, patterns, numWarps, axisInfoAnalysis, &allocation, smem, - /*benefit=*/10); + indexCacheInfo, /*benefit=*/10); // ViewOp populateViewOpToLLVMPatterns(typeConverter, patterns, numWarps, axisInfoAnalysis, &allocation, smem, @@ -215,6 +219,13 @@ public: private: Value smem; + using IndexCacheKeyT = std::pair>; + DenseMap, CacheKeyDenseMapInfo> + baseIndexCache; + DenseMap>, + CacheKeyDenseMapInfo> + indexCache; + int computeCapability{}; void initSharedMemory(size_t size, diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index abc5e9a31..e49a231bf 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -997,20 +997,61 @@ func @test_get_program_id(%a: tensor<32x!tt.ptr, #blocked0>) { // ----- #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> module attributes {"triton_gpu.num-warps" = 4 : i32} { - -func @test_get_num_program(%a: tensor<32x!tt.ptr, #blocked0>) { - // CHECK: nvvm.read.ptx.sreg.nctaid.x - 
// CHECK: nvvm.read.ptx.sreg.nctaid.y - // CHECK: nvvm.read.ptx.sreg.nctaid.z - %blockdimx = tt.get_num_programs {axis=0:i32} : i32 - %blockdimy = tt.get_num_programs {axis=1:i32} : i32 - %blockdimz = tt.get_num_programs {axis=2:i32} : i32 - %v0 = arith.addi %blockdimx, %blockdimy : i32 - %v1 = arith.addi %v0, %blockdimz : i32 - %0 = tt.splat %v1 : (i32) -> tensor<32xi32, #blocked0> - tt.store %a, %0 : tensor<32xi32, #blocked0> - - return + func @test_get_num_program(%a: tensor<32x!tt.ptr, #blocked0>) { + // CHECK: nvvm.read.ptx.sreg.nctaid.x + // CHECK: nvvm.read.ptx.sreg.nctaid.y + // CHECK: nvvm.read.ptx.sreg.nctaid.z + %blockdimx = tt.get_num_programs {axis=0:i32} : i32 + %blockdimy = tt.get_num_programs {axis=1:i32} : i32 + %blockdimz = tt.get_num_programs {axis=2:i32} : i32 + %v0 = arith.addi %blockdimx, %blockdimy : i32 + %v1 = arith.addi %v0, %blockdimz : i32 + %0 = tt.splat %v1 : (i32) -> tensor<32xi32, #blocked0> + tt.store %a, %0 : tensor<32xi32, #blocked0> + + return + } } +// ----- +#blocked0 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +module attributes {"triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: test_index_cache + func @test_index_cache() { + // CHECK: nvvm.read.ptx.sreg.tid.x + %0 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0> + // CHECK-NOT: nvvm.read.ptx.sreg.tid.x + %1 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0> + return + } } + +// ----- +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}> +#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}> +module attributes {"triton_gpu.num-warps" = 1 : i32} { + // CHECK-LABEL: test_base_index_cache + func @test_base_index_cache(%arg0: tensor<128x32xf32, #blocked0>) { + // CHECK: nvvm.read.ptx.sreg.tid.x + %0 = triton_gpu.convert_layout %arg0 : (tensor<128x32xf32, #blocked0>) 
-> tensor<128x32xf32, #shared0> + // CHECK-NOT: nvvm.read.ptx.sreg.tid.x + %1 = triton_gpu.convert_layout %arg0 : (tensor<128x32xf32, #blocked0>) -> tensor<128x32xf32, #shared0> + return + } +} + +// ----- +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}> +#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}> +module attributes {"triton_gpu.num-warps" = 1 : i32} { + // CHECK-LABEL: test_index_cache_different_block + func @test_index_cache_different_block(%arg0: tensor<128x32xf32, #blocked0>, %arg1: i1) { + // CHECK: nvvm.read.ptx.sreg.tid.x + %0 = triton_gpu.convert_layout %arg0 : (tensor<128x32xf32, #blocked0>) -> tensor<128x32xf32, #shared0> + scf.if %arg1 { + // CHECK-NOT: nvvm.read.ptx.sreg.tid.x + %1 = triton_gpu.convert_layout %arg0 : (tensor<128x32xf32, #blocked0>) -> tensor<128x32xf32, #shared0> + } + return + } +} \ No newline at end of file