[Backend] Add value cache in emitting indices calculation and some refinement (#1018)

1, add explicit value cache in emitting indices calculation;
2, move the indices calculation emitting logics into
ConvertTritonGPUOpToLLVMPatternBase to avoid the redundant build cost by
templates. Refer to the discussion in this thread by @LyricZhao :
https://triton-lang.slack.com/archives/C042VBSQWNS/p1671336755922969
This commit is contained in:
goostavz
2022-12-30 03:19:59 +08:00
committed by GitHub
parent 2ba74d2729
commit 1d3029faf8
11 changed files with 355 additions and 180 deletions

View File

@@ -53,9 +53,6 @@ void storeBlockedToShared(Value src, Value llSrc, ArrayRef<Value> srcStrides,
auto wordTy = vec_ty(elemTy, minVec);
auto elemPtrTy = ptr_ty(elemTy);
// TODO: [goostavz] We should make a cache for the calculation of
// emitBaseIndexForBlockedLayout in case backend compiler not being able to
// optimize that
SmallVector<unsigned> srcShapePerCTA = getShapePerCTA(srcBlockedLayout);
SmallVector<unsigned> reps{ceil<unsigned>(srcShape[0], srcShapePerCTA[0]),
ceil<unsigned>(srcShape[1], srcShapePerCTA[1])};
@@ -182,7 +179,7 @@ private:
unsigned rank = shape.size();
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
auto multiDimOffsetFirstElem =
emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, shape);
emitBaseIndexForLayout(loc, rewriter, blockedLayout, shape);
SmallVector<Value> multiDimOffset(rank);
SmallVector<unsigned> multiDimElemId = getMultiDimIndex<unsigned>(
elemId, getSizePerThread(layout), getOrder(layout));
@@ -501,8 +498,8 @@ private:
auto srcStrides =
getStridesFromShapeAndOrder(srcShape, inOrd, loc, rewriter);
auto srcIndices = emitBaseIndexForBlockedLayout(loc, rewriter,
srcBlockedLayout, srcShape);
auto srcIndices =
emitBaseIndexForLayout(loc, rewriter, srcBlockedLayout, srcShape);
storeBlockedToShared(src, adaptor.src(), srcStrides, srcIndices, dst,
smemBase, elemTy, loc, rewriter);
@@ -680,7 +677,9 @@ private:
void populateConvertLayoutOpToLLVMPatterns(
mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem, PatternBenefit benefit) {
const Allocation *allocation, Value smem,
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
PatternBenefit benefit) {
patterns.add<ConvertLayoutOpConversion>(typeConverter, allocation, smem,
benefit);
indexCacheInfo, benefit);
}