[Backend] Add a value cache for emitted index calculations, plus some refinements (#1018)
1. Add an explicit value cache for the emitted index calculations. 2. Move the index-calculation emission logic into ConvertTritonGPUOpToLLVMPatternBase to avoid the redundant build cost caused by templates. See the discussion in this thread by @LyricZhao: https://triton-lang.slack.com/archives/C042VBSQWNS/p1671336755922969
This commit is contained in:
@@ -63,6 +63,7 @@ struct BroadcastOpConversion
|
||||
auto srcShape = srcTy.getShape();
|
||||
auto resultShape = resultTy.getShape();
|
||||
unsigned rank = srcTy.getRank();
|
||||
|
||||
assert(rank == resultTy.getRank());
|
||||
auto order = triton::gpu::getOrder(srcLayout);
|
||||
auto srcOffsets = emitOffsetForLayout(srcLayout, srcShape);
|
||||
@@ -272,9 +273,13 @@ struct PrintfOpConversion
|
||||
struct MakeRangeOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::MakeRangeOp> {
|
||||
|
||||
MakeRangeOpConversion(LLVMTypeConverter &converter, PatternBenefit benefit)
|
||||
: ConvertTritonGPUOpToLLVMPattern<triton::MakeRangeOp>(converter,
|
||||
benefit) {}
|
||||
MakeRangeOpConversion(
|
||||
LLVMTypeConverter &converter,
|
||||
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
|
||||
PatternBenefit benefit)
|
||||
: ConvertTritonGPUOpToLLVMPattern<triton::MakeRangeOp>(
|
||||
converter, /*Allocation*/ nullptr, Value{}, indexCacheInfo,
|
||||
benefit) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor,
|
||||
@@ -500,11 +505,12 @@ struct AsyncWaitOpConversion
|
||||
}
|
||||
};
|
||||
|
||||
void populateTritonGPUToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns, int numWarps,
|
||||
AxisInfoAnalysis &axisInfoAnalysis,
|
||||
const Allocation *allocation, Value smem,
|
||||
PatternBenefit benefit) {
|
||||
void populateTritonGPUToLLVMPatterns(
|
||||
mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
|
||||
const Allocation *allocation, Value smem,
|
||||
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
|
||||
PatternBenefit benefit) {
|
||||
patterns.add<AddPtrOpConversion>(typeConverter, benefit);
|
||||
patterns.add<AllocTensorOpConversion>(typeConverter, allocation, smem,
|
||||
benefit);
|
||||
@@ -515,7 +521,7 @@ void populateTritonGPUToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
|
||||
benefit);
|
||||
patterns.add<GetProgramIdOpConversion>(typeConverter, benefit);
|
||||
patterns.add<GetNumProgramsOpConversion>(typeConverter, benefit);
|
||||
patterns.add<MakeRangeOpConversion>(typeConverter, benefit);
|
||||
patterns.add<MakeRangeOpConversion>(typeConverter, indexCacheInfo, benefit);
|
||||
patterns.add<ReturnOpConversion>(typeConverter, benefit);
|
||||
patterns.add<PrintfOpConversion>(typeConverter, benefit);
|
||||
}
|
Reference in New Issue
Block a user