[Backend] Add value cache in emitting indices calculation and some refinement (#1018)
1, add explicit value cache in emitting indices calculation; 2, move the indices calculation emitting logics into ConvertTritonGPUOpToLLVMPatternBase to avoid the redundant build cost by templates. Refer to the discussion in this thread by @LyricZhao : https://triton-lang.slack.com/archives/C042VBSQWNS/p1671336755922969
This commit is contained in:
@@ -53,9 +53,6 @@ void storeBlockedToShared(Value src, Value llSrc, ArrayRef<Value> srcStrides,
|
||||
auto wordTy = vec_ty(elemTy, minVec);
|
||||
auto elemPtrTy = ptr_ty(elemTy);
|
||||
|
||||
// TODO: [goostavz] We should make a cache for the calculation of
|
||||
// emitBaseIndexForBlockedLayout in case backend compiler not being able to
|
||||
// optimize that
|
||||
SmallVector<unsigned> srcShapePerCTA = getShapePerCTA(srcBlockedLayout);
|
||||
SmallVector<unsigned> reps{ceil<unsigned>(srcShape[0], srcShapePerCTA[0]),
|
||||
ceil<unsigned>(srcShape[1], srcShapePerCTA[1])};
|
||||
@@ -182,7 +179,7 @@ private:
|
||||
unsigned rank = shape.size();
|
||||
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
|
||||
auto multiDimOffsetFirstElem =
|
||||
emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, shape);
|
||||
emitBaseIndexForLayout(loc, rewriter, blockedLayout, shape);
|
||||
SmallVector<Value> multiDimOffset(rank);
|
||||
SmallVector<unsigned> multiDimElemId = getMultiDimIndex<unsigned>(
|
||||
elemId, getSizePerThread(layout), getOrder(layout));
|
||||
@@ -501,8 +498,8 @@ private:
|
||||
|
||||
auto srcStrides =
|
||||
getStridesFromShapeAndOrder(srcShape, inOrd, loc, rewriter);
|
||||
auto srcIndices = emitBaseIndexForBlockedLayout(loc, rewriter,
|
||||
srcBlockedLayout, srcShape);
|
||||
auto srcIndices =
|
||||
emitBaseIndexForLayout(loc, rewriter, srcBlockedLayout, srcShape);
|
||||
storeBlockedToShared(src, adaptor.src(), srcStrides, srcIndices, dst,
|
||||
smemBase, elemTy, loc, rewriter);
|
||||
|
||||
@@ -680,7 +677,9 @@ private:
|
||||
void populateConvertLayoutOpToLLVMPatterns(
|
||||
mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
|
||||
const Allocation *allocation, Value smem, PatternBenefit benefit) {
|
||||
const Allocation *allocation, Value smem,
|
||||
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
|
||||
PatternBenefit benefit) {
|
||||
patterns.add<ConvertLayoutOpConversion>(typeConverter, allocation, smem,
|
||||
benefit);
|
||||
indexCacheInfo, benefit);
|
||||
}
|
||||
|
Reference in New Issue
Block a user