[Backend] Add value cache in emitting indices calculation and some refinement (#1018)

1, add explicit value cache in emitting indices calculation;
2, move the indices calculation emitting logics into
ConvertTritonGPUOpToLLVMPatternBase to avoid the redundant build cost by
templates. Refer to the discussion in this thread by @LyricZhao :
https://triton-lang.slack.com/archives/C042VBSQWNS/p1671336755922969
This commit is contained in:
goostavz
2022-12-30 03:19:59 +08:00
committed by GitHub
parent 2ba74d2729
commit 1d3029faf8
11 changed files with 355 additions and 180 deletions

View File

@@ -170,16 +170,20 @@ public:
// We set a higher benefit here to ensure triton's patterns runs before
// arith patterns for some encoding not supported by the community
// patterns.
OpBuilder::InsertPoint indexInsertPoint;
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo indexCacheInfo{
&baseIndexCache, &indexCache, &indexInsertPoint};
RewritePatternSet patterns(context);
// Normal conversions
populateTritonGPUToLLVMPatterns(typeConverter, patterns, numWarps,
axisInfoAnalysis, &allocation, smem,
/*benefit=*/10);
indexCacheInfo, /*benefit=*/10);
// ConvertLayoutOp
populateConvertLayoutOpToLLVMPatterns(typeConverter, patterns, numWarps,
axisInfoAnalysis, &allocation, smem,
/*benefit=*/10);
indexCacheInfo, /*benefit=*/10);
// DotOp
populateDotOpToLLVMPatterns(typeConverter, patterns, numWarps,
axisInfoAnalysis, &allocation, smem,
@@ -191,11 +195,11 @@ public:
// LoadStoreOp
populateLoadStoreOpToLLVMPatterns(typeConverter, patterns, numWarps,
axisInfoAnalysis, &allocation, smem,
/*benefit=*/10);
indexCacheInfo, /*benefit=*/10);
// ReduceOp
populateReduceOpToLLVMPatterns(typeConverter, patterns, numWarps,
axisInfoAnalysis, &allocation, smem,
/*benefit=*/10);
indexCacheInfo, /*benefit=*/10);
// ViewOp
populateViewOpToLLVMPatterns(typeConverter, patterns, numWarps,
axisInfoAnalysis, &allocation, smem,
@@ -215,6 +219,13 @@ public:
private:
Value smem;
using IndexCacheKeyT = std::pair<Attribute, SmallVector<int64_t>>;
DenseMap<IndexCacheKeyT, SmallVector<Value>, CacheKeyDenseMapInfo>
baseIndexCache;
DenseMap<IndexCacheKeyT, SmallVector<SmallVector<Value>>,
CacheKeyDenseMapInfo>
indexCache;
int computeCapability{};
void initSharedMemory(size_t size,