[Backend] Add value cache in emitting indices calculation and some refinement (#1018)
1. Add an explicit value cache for the emitted index calculations. 2. Move the index-calculation emission logic into ConvertTritonGPUOpToLLVMPatternBase to avoid the redundant build cost caused by templates. Refer to the discussion in this thread by @LyricZhao: https://triton-lang.slack.com/archives/C042VBSQWNS/p1671336755922969
This commit is contained in:
@@ -170,16 +170,20 @@ public:
|
||||
// We set a higher benefit here to ensure triton's patterns run before
|
||||
// arith patterns for some encoding not supported by the community
|
||||
// patterns.
|
||||
OpBuilder::InsertPoint indexInsertPoint;
|
||||
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo indexCacheInfo{
|
||||
&baseIndexCache, &indexCache, &indexInsertPoint};
|
||||
|
||||
RewritePatternSet patterns(context);
|
||||
|
||||
// Normal conversions
|
||||
populateTritonGPUToLLVMPatterns(typeConverter, patterns, numWarps,
|
||||
axisInfoAnalysis, &allocation, smem,
|
||||
/*benefit=*/10);
|
||||
indexCacheInfo, /*benefit=*/10);
|
||||
// ConvertLayoutOp
|
||||
populateConvertLayoutOpToLLVMPatterns(typeConverter, patterns, numWarps,
|
||||
axisInfoAnalysis, &allocation, smem,
|
||||
/*benefit=*/10);
|
||||
indexCacheInfo, /*benefit=*/10);
|
||||
// DotOp
|
||||
populateDotOpToLLVMPatterns(typeConverter, patterns, numWarps,
|
||||
axisInfoAnalysis, &allocation, smem,
|
||||
@@ -191,11 +195,11 @@ public:
|
||||
// LoadStoreOp
|
||||
populateLoadStoreOpToLLVMPatterns(typeConverter, patterns, numWarps,
|
||||
axisInfoAnalysis, &allocation, smem,
|
||||
/*benefit=*/10);
|
||||
indexCacheInfo, /*benefit=*/10);
|
||||
// ReduceOp
|
||||
populateReduceOpToLLVMPatterns(typeConverter, patterns, numWarps,
|
||||
axisInfoAnalysis, &allocation, smem,
|
||||
/*benefit=*/10);
|
||||
indexCacheInfo, /*benefit=*/10);
|
||||
// ViewOp
|
||||
populateViewOpToLLVMPatterns(typeConverter, patterns, numWarps,
|
||||
axisInfoAnalysis, &allocation, smem,
|
||||
@@ -215,6 +219,13 @@ public:
|
||||
private:
|
||||
Value smem;
|
||||
|
||||
using IndexCacheKeyT = std::pair<Attribute, SmallVector<int64_t>>;
|
||||
DenseMap<IndexCacheKeyT, SmallVector<Value>, CacheKeyDenseMapInfo>
|
||||
baseIndexCache;
|
||||
DenseMap<IndexCacheKeyT, SmallVector<SmallVector<Value>>,
|
||||
CacheKeyDenseMapInfo>
|
||||
indexCache;
|
||||
|
||||
int computeCapability{};
|
||||
|
||||
void initSharedMemory(size_t size,
|
||||
|
Reference in New Issue
Block a user