[Backend] Add value cache in emitting indices calculation and some refinement (#1018)
1, add explicit value cache in emitting indices calculation; 2, move the indices calculation emitting logics into ConvertTritonGPUOpToLLVMPatternBase to avoid the redundant build cost by templates. Refer to the discussion in this thread by @LyricZhao : https://triton-lang.slack.com/archives/C042VBSQWNS/p1671336755922969
This commit is contained in:
@@ -14,7 +14,7 @@ using ::mlir::triton::gpu::getElemsPerThread;
|
||||
using ::mlir::triton::gpu::SharedEncodingAttr;
|
||||
|
||||
// Contains some helper functions for both Load and Store conversions.
|
||||
struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase {
|
||||
struct LoadStoreConversionBase {
|
||||
explicit LoadStoreConversionBase(AxisInfoAnalysis &axisAnalysisPass)
|
||||
: axisAnalysisPass(axisAnalysisPass) {}
|
||||
|
||||
@@ -640,7 +640,7 @@ struct InsertSliceOpConversion
|
||||
|
||||
auto llSrc = adaptor.source();
|
||||
auto srcIndices =
|
||||
emitBaseIndexForBlockedLayout(loc, rewriter, srcLayout, srcShape);
|
||||
emitBaseIndexForLayout(loc, rewriter, srcLayout, srcShape);
|
||||
storeBlockedToShared(src, llSrc, srcStrides, srcIndices, dst, smemBase,
|
||||
elemTy, loc, rewriter);
|
||||
// Barrier is not necessary.
|
||||
@@ -657,12 +657,12 @@ struct InsertSliceAsyncOpConversion
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::gpu::InsertSliceAsyncOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
InsertSliceAsyncOpConversion(LLVMTypeConverter &converter,
|
||||
const Allocation *allocation, Value smem,
|
||||
AxisInfoAnalysis &axisAnalysisPass,
|
||||
PatternBenefit benefit)
|
||||
InsertSliceAsyncOpConversion(
|
||||
LLVMTypeConverter &converter, const Allocation *allocation, Value smem,
|
||||
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
|
||||
AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit)
|
||||
: ConvertTritonGPUOpToLLVMPattern<triton::gpu::InsertSliceAsyncOp>(
|
||||
converter, allocation, smem, benefit),
|
||||
converter, allocation, smem, indexCacheInfo, benefit),
|
||||
LoadStoreConversionBase(axisAnalysisPass) {}
|
||||
|
||||
LogicalResult
|
||||
@@ -865,12 +865,12 @@ struct InsertSliceAsyncOpConversion
|
||||
}
|
||||
};
|
||||
|
||||
void populateLoadStoreOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns,
|
||||
int numWarps,
|
||||
AxisInfoAnalysis &axisInfoAnalysis,
|
||||
const Allocation *allocation, Value smem,
|
||||
PatternBenefit benefit) {
|
||||
void populateLoadStoreOpToLLVMPatterns(
|
||||
mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
|
||||
const Allocation *allocation, Value smem,
|
||||
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
|
||||
PatternBenefit benefit) {
|
||||
patterns.add<LoadOpConversion>(typeConverter, axisInfoAnalysis, benefit);
|
||||
patterns.add<StoreOpConversion>(typeConverter, axisInfoAnalysis, benefit);
|
||||
patterns.add<AtomicCASOpConversion>(typeConverter, allocation, smem,
|
||||
@@ -878,7 +878,8 @@ void populateLoadStoreOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
|
||||
patterns.add<AtomicRMWOpConversion>(typeConverter, allocation, smem,
|
||||
axisInfoAnalysis, benefit);
|
||||
patterns.add<InsertSliceOpConversion>(typeConverter, allocation, smem,
|
||||
benefit);
|
||||
indexCacheInfo, benefit);
|
||||
patterns.add<InsertSliceAsyncOpConversion>(typeConverter, allocation, smem,
|
||||
axisInfoAnalysis, benefit);
|
||||
indexCacheInfo, axisInfoAnalysis,
|
||||
benefit);
|
||||
}
|
||||
|
Reference in New Issue
Block a user