[Backend] Add value cache in emitting indices calculation and some refinement (#1018)

1. Add an explicit value cache for the emitted index calculations.
2. Move the index-calculation emission logic into
   ConvertTritonGPUOpToLLVMPatternBase, to avoid the redundant build cost
   caused by template instantiation. See the discussion in this thread by
   @LyricZhao: https://triton-lang.slack.com/archives/C042VBSQWNS/p1671336755922969
Author: goostavz
Date: 2022-12-30 03:19:59 +08:00
Committed by: GitHub
Parent: 2ba74d2729
Commit: 1d3029faf8
11 changed files with 355 additions and 180 deletions
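
Before the diffs, a minimal standalone sketch of the two-level caching scheme described in the commit message above. This is a hedged simplification, not the patch's code: std::string stands in for the MLIR layout Attribute, std::vector<int64_t> for the tensor shape, and int for an emitted IR value; the real change keys llvm::DenseMaps with the custom CacheKeyDenseMapInfo introduced in TritonGPUToLLVMBase.h below.

#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Stand-ins for MLIR types (assumption: string ~ layout Attribute,
// int ~ an emitted LLVM IR Value). The real key is pair<Attribute, shape>.
using CacheKey = std::pair<std::string, std::vector<int64_t>>;

struct IndexEmitter {
  // Level 1: one base index per (layout, shape).
  std::map<CacheKey, int> baseIndexCache;
  // Level 2: a full [elemsPerThread x rank] index matrix per (layout, shape).
  std::map<CacheKey, std::vector<std::vector<int>>> indexCache;

  int emitBaseIndex(const CacheKey &key) {
    auto it = baseIndexCache.find(key);
    if (it != baseIndexCache.end())
      return it->second;                // hit: reuse the IR emitted earlier
    int value = computeBaseIndex(key);  // miss: emit once ...
    baseIndexCache.emplace(key, value); // ... and remember it
    return value;
  }

private:
  int computeBaseIndex(const CacheKey &key) {
    // Placeholder for the expensive emission (thread-id reads, udiv/urem ...).
    std::cout << "emitting base index for layout " << key.first << "\n";
    return static_cast<int>(key.second.size());
  }
};

int main() {
  IndexEmitter emitter;
  CacheKey key{"#blocked0", {256}};
  emitter.emitBaseIndex(key); // prints once
  emitter.emitBaseIndex(key); // served from the cache, prints nothing
  return 0;
}

One subtlety the sketch omits: a cached SSA value must dominate every later use. The patch therefore also threads a shared OpBuilder::InsertPoint through IndexCacheInfo and emits all cached index calculations at a single insertion point, starting at the function entry block; see restoreInsertionPointIfSet in TritonGPUToLLVMBase.h below.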


@@ -53,9 +53,6 @@ void storeBlockedToShared(Value src, Value llSrc, ArrayRef<Value> srcStrides,
   auto wordTy = vec_ty(elemTy, minVec);
   auto elemPtrTy = ptr_ty(elemTy);
-  // TODO: [goostavz] We should make a cache for the calculation of
-  // emitBaseIndexForBlockedLayout in case backend compiler not being able to
-  // optimize that
   SmallVector<unsigned> srcShapePerCTA = getShapePerCTA(srcBlockedLayout);
   SmallVector<unsigned> reps{ceil<unsigned>(srcShape[0], srcShapePerCTA[0]),
                              ceil<unsigned>(srcShape[1], srcShapePerCTA[1])};
@@ -182,7 +179,7 @@ private:
     unsigned rank = shape.size();
     if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
       auto multiDimOffsetFirstElem =
-          emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, shape);
+          emitBaseIndexForLayout(loc, rewriter, blockedLayout, shape);
       SmallVector<Value> multiDimOffset(rank);
       SmallVector<unsigned> multiDimElemId = getMultiDimIndex<unsigned>(
           elemId, getSizePerThread(layout), getOrder(layout));
@@ -501,8 +498,8 @@ private:
     auto srcStrides =
         getStridesFromShapeAndOrder(srcShape, inOrd, loc, rewriter);
-    auto srcIndices = emitBaseIndexForBlockedLayout(loc, rewriter,
-                                                    srcBlockedLayout, srcShape);
+    auto srcIndices =
+        emitBaseIndexForLayout(loc, rewriter, srcBlockedLayout, srcShape);
     storeBlockedToShared(src, adaptor.src(), srcStrides, srcIndices, dst,
                          smemBase, elemTy, loc, rewriter);
@@ -680,7 +677,9 @@ private:
 void populateConvertLayoutOpToLLVMPatterns(
     mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
     int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
-    const Allocation *allocation, Value smem, PatternBenefit benefit) {
+    const Allocation *allocation, Value smem,
+    ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
+    PatternBenefit benefit) {
   patterns.add<ConvertLayoutOpConversion>(typeConverter, allocation, smem,
-                                          benefit);
+                                          indexCacheInfo, benefit);
 }


@@ -19,6 +19,8 @@ void storeBlockedToShared(Value src, Value llSrc, ArrayRef<Value> srcStrides,
 void populateConvertLayoutOpToLLVMPatterns(
     mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
     int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
-    const Allocation *allocation, Value smem, PatternBenefit benefit);
+    const Allocation *allocation, Value smem,
+    ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
+    PatternBenefit benefit);
 #endif


@@ -14,7 +14,7 @@ using ::mlir::triton::gpu::getElemsPerThread;
 using ::mlir::triton::gpu::SharedEncodingAttr;
 
 // Contains some helper functions for both Load and Store conversions.
-struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase {
+struct LoadStoreConversionBase {
   explicit LoadStoreConversionBase(AxisInfoAnalysis &axisAnalysisPass)
       : axisAnalysisPass(axisAnalysisPass) {}
@@ -640,7 +640,7 @@ struct InsertSliceOpConversion
     auto llSrc = adaptor.source();
     auto srcIndices =
-        emitBaseIndexForBlockedLayout(loc, rewriter, srcLayout, srcShape);
+        emitBaseIndexForLayout(loc, rewriter, srcLayout, srcShape);
     storeBlockedToShared(src, llSrc, srcStrides, srcIndices, dst, smemBase,
                          elemTy, loc, rewriter);
     // Barrier is not necessary.
@@ -657,12 +657,12 @@ struct InsertSliceAsyncOpConversion
   using ConvertTritonGPUOpToLLVMPattern<
       triton::gpu::InsertSliceAsyncOp>::ConvertTritonGPUOpToLLVMPattern;
 
-  InsertSliceAsyncOpConversion(LLVMTypeConverter &converter,
-                               const Allocation *allocation, Value smem,
-                               AxisInfoAnalysis &axisAnalysisPass,
-                               PatternBenefit benefit)
+  InsertSliceAsyncOpConversion(
+      LLVMTypeConverter &converter, const Allocation *allocation, Value smem,
+      ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
+      AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit)
       : ConvertTritonGPUOpToLLVMPattern<triton::gpu::InsertSliceAsyncOp>(
-            converter, allocation, smem, benefit),
+            converter, allocation, smem, indexCacheInfo, benefit),
         LoadStoreConversionBase(axisAnalysisPass) {}
 
   LogicalResult
@@ -865,11 +865,11 @@ struct InsertSliceAsyncOpConversion
   }
 };
 
-void populateLoadStoreOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
-                                       RewritePatternSet &patterns,
-                                       int numWarps,
-                                       AxisInfoAnalysis &axisInfoAnalysis,
-                                       const Allocation *allocation, Value smem,
-                                       PatternBenefit benefit) {
+void populateLoadStoreOpToLLVMPatterns(
+    mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
+    int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
+    const Allocation *allocation, Value smem,
+    ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
+    PatternBenefit benefit) {
   patterns.add<LoadOpConversion>(typeConverter, axisInfoAnalysis, benefit);
   patterns.add<StoreOpConversion>(typeConverter, axisInfoAnalysis, benefit);
@@ -878,7 +878,8 @@ void populateLoadStoreOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
   patterns.add<AtomicRMWOpConversion>(typeConverter, allocation, smem,
                                       axisInfoAnalysis, benefit);
   patterns.add<InsertSliceOpConversion>(typeConverter, allocation, smem,
-                                        benefit);
+                                        indexCacheInfo, benefit);
   patterns.add<InsertSliceAsyncOpConversion>(typeConverter, allocation, smem,
-                                             axisInfoAnalysis, benefit);
+                                             indexCacheInfo, axisInfoAnalysis,
+                                             benefit);
 }


@@ -6,11 +6,11 @@
 using namespace mlir;
 using namespace mlir::triton;
 
-void populateLoadStoreOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
-                                       RewritePatternSet &patterns,
-                                       int numWarps,
-                                       AxisInfoAnalysis &axisInfoAnalysis,
-                                       const Allocation *allocation, Value smem,
-                                       PatternBenefit benefit);
+void populateLoadStoreOpToLLVMPatterns(
+    mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
+    int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
+    const Allocation *allocation, Value smem,
+    ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
+    PatternBenefit benefit);
 #endif


@@ -164,7 +164,7 @@ private:
     auto srcValues = getElementsFromStruct(loc, adaptor.operand(), rewriter);
     SmallVector<SmallVector<unsigned>> offset =
-        emitOffsetForBlockedLayout(srcLayout, srcShape);
+        emitOffsetForLayout(srcLayout, srcShape);
     std::map<SmallVector<unsigned>, Value> accs;
     std::map<SmallVector<unsigned>, Value> accIndices;
@@ -479,10 +479,12 @@ private:
   }
 };
 
-void populateReduceOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
-                                    RewritePatternSet &patterns, int numWarps,
-                                    AxisInfoAnalysis &axisInfoAnalysis,
-                                    const Allocation *allocation, Value smem,
-                                    PatternBenefit benefit) {
-  patterns.add<ReduceOpConversion>(typeConverter, allocation, smem, benefit);
+void populateReduceOpToLLVMPatterns(
+    mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
+    int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
+    const Allocation *allocation, Value smem,
+    ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
+    PatternBenefit benefit) {
+  patterns.add<ReduceOpConversion>(typeConverter, allocation, smem,
+                                   indexCacheInfo, benefit);
 }


@@ -6,10 +6,11 @@
 using namespace mlir;
 using namespace mlir::triton;
 
-void populateReduceOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
-                                    RewritePatternSet &patterns, int numWarps,
-                                    AxisInfoAnalysis &axisInfoAnalysis,
-                                    const Allocation *allocation, Value smem,
-                                    PatternBenefit benefit);
+void populateReduceOpToLLVMPatterns(
+    mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
+    int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
+    const Allocation *allocation, Value smem,
+    ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
+    PatternBenefit benefit);
 #endif


@@ -63,6 +63,7 @@ struct BroadcastOpConversion
     auto srcShape = srcTy.getShape();
     auto resultShape = resultTy.getShape();
     unsigned rank = srcTy.getRank();
     assert(rank == resultTy.getRank());
     auto order = triton::gpu::getOrder(srcLayout);
     auto srcOffsets = emitOffsetForLayout(srcLayout, srcShape);
@@ -272,8 +273,12 @@ struct PrintfOpConversion
 struct MakeRangeOpConversion
     : public ConvertTritonGPUOpToLLVMPattern<triton::MakeRangeOp> {
 
-  MakeRangeOpConversion(LLVMTypeConverter &converter, PatternBenefit benefit)
-      : ConvertTritonGPUOpToLLVMPattern<triton::MakeRangeOp>(converter,
-                                                             benefit) {}
+  MakeRangeOpConversion(
+      LLVMTypeConverter &converter,
+      ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
+      PatternBenefit benefit)
+      : ConvertTritonGPUOpToLLVMPattern<triton::MakeRangeOp>(
+            converter, /*Allocation*/ nullptr, Value{}, indexCacheInfo,
+            benefit) {}
 
   LogicalResult
@@ -500,10 +505,11 @@ struct AsyncWaitOpConversion
   }
 };
 
-void populateTritonGPUToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
-                                     RewritePatternSet &patterns, int numWarps,
-                                     AxisInfoAnalysis &axisInfoAnalysis,
-                                     const Allocation *allocation, Value smem,
-                                     PatternBenefit benefit) {
+void populateTritonGPUToLLVMPatterns(
+    mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
+    int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
+    const Allocation *allocation, Value smem,
+    ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
+    PatternBenefit benefit) {
   patterns.add<AddPtrOpConversion>(typeConverter, benefit);
   patterns.add<AllocTensorOpConversion>(typeConverter, allocation, smem,
@@ -515,7 +521,7 @@ void populateTritonGPUToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
                                         benefit);
   patterns.add<GetProgramIdOpConversion>(typeConverter, benefit);
   patterns.add<GetNumProgramsOpConversion>(typeConverter, benefit);
-  patterns.add<MakeRangeOpConversion>(typeConverter, benefit);
+  patterns.add<MakeRangeOpConversion>(typeConverter, indexCacheInfo, benefit);
   patterns.add<ReturnOpConversion>(typeConverter, benefit);
   patterns.add<PrintfOpConversion>(typeConverter, benefit);
 }


@@ -6,10 +6,11 @@
 using namespace mlir;
 using namespace mlir::triton;
 
-void populateTritonGPUToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
-                                     RewritePatternSet &patterns, int numWarps,
-                                     AxisInfoAnalysis &axisInfoAnalysis,
-                                     const Allocation *allocation, Value smem,
-                                     PatternBenefit benefit);
+void populateTritonGPUToLLVMPatterns(
+    mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
+    int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
+    const Allocation *allocation, Value smem,
+    ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
+    PatternBenefit benefit);
 #endif


@@ -18,7 +18,6 @@ using ::mlir::LLVM::SharedMemoryObject;
 using ::mlir::triton::gpu::BlockedEncodingAttr;
 using ::mlir::triton::gpu::MmaEncodingAttr;
 using ::mlir::triton::gpu::SliceEncodingAttr;
-
 // FuncOpConversion/FuncOpConversionBase is borrowed from
 // https://github.com/llvm/llvm-project/blob/fae656b2dd80246c3c6f01e9c77c49560368752c/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp#L276
 // since it is not exposed on header files in mlir v14
@@ -128,7 +127,60 @@ protected:
   }
 };
 
-struct ConvertTritonGPUOpToLLVMPatternBase {
+using IndexCacheKeyT = std::pair<Attribute, SmallVector<int64_t>>;
+
+struct CacheKeyDenseMapInfo {
+  static IndexCacheKeyT getEmptyKey() {
+    auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
+    return std::make_pair(
+        mlir::Attribute(static_cast<mlir::Attribute::ImplType *>(pointer)),
+        SmallVector<int64_t>{});
+  }
+  static IndexCacheKeyT getTombstoneKey() {
+    auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
+    return std::make_pair(
+        mlir::Attribute(static_cast<mlir::Attribute::ImplType *>(pointer)),
+        SmallVector<int64_t>{std::numeric_limits<int64_t>::max()});
+  }
+  static unsigned getHashValue(IndexCacheKeyT key) {
+    return llvm::hash_combine(
+        mlir::hash_value(key.first),
+        llvm::hash_combine_range(key.second.begin(), key.second.end()));
+  }
+  static bool isEqual(IndexCacheKeyT LHS, IndexCacheKeyT RHS) {
+    return LHS == RHS;
+  }
+};
+
+class ConvertTritonGPUOpToLLVMPatternBase {
+public:
+  // Two levels of value cache in emitting indices calculation:
+  // Key: pair<layout, shape>
+  struct IndexCacheInfo {
+    DenseMap<IndexCacheKeyT, SmallVector<Value>, CacheKeyDenseMapInfo>
+        *baseIndexCache;
+    DenseMap<IndexCacheKeyT, SmallVector<SmallVector<Value>>,
+             CacheKeyDenseMapInfo> *indexCache;
+    OpBuilder::InsertPoint *indexInsertPoint;
+  };
+
+  explicit ConvertTritonGPUOpToLLVMPatternBase(LLVMTypeConverter &typeConverter)
+      : converter(&typeConverter) {}
+
+  explicit ConvertTritonGPUOpToLLVMPatternBase(LLVMTypeConverter &typeConverter,
+                                               const Allocation *allocation,
+                                               Value smem)
+      : converter(&typeConverter), allocation(allocation), smem(smem) {}
+
+  explicit ConvertTritonGPUOpToLLVMPatternBase(LLVMTypeConverter &typeConverter,
+                                               const Allocation *allocation,
+                                               Value smem,
+                                               IndexCacheInfo indexCacheInfo)
+      : converter(&typeConverter), indexCacheInfo(indexCacheInfo),
+        allocation(allocation), smem(smem) {}
+
+  LLVMTypeConverter *getTypeConverter() const { return converter; }
+
   static Value
   getStructFromSharedMemoryObject(Location loc,
                                   const SharedMemoryObject &smemObj,
@@ -139,25 +191,6 @@ struct ConvertTritonGPUOpToLLVMPatternBase {
         LLVM::LLVMStructType::getLiteral(rewriter.getContext(), types);
     return getStructFromElements(loc, elems, rewriter, structTy);
   }
-};
-
-template <typename SourceOp>
-class ConvertTritonGPUOpToLLVMPattern
-    : public ConvertOpToLLVMPattern<SourceOp>,
-      public ConvertTritonGPUOpToLLVMPatternBase {
-public:
-  using OpAdaptor = typename SourceOp::Adaptor;
-
-  explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter,
-                                           PatternBenefit benefit = 1)
-      : ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}
-
-  explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter,
-                                           const Allocation *allocation,
-                                           Value smem,
-                                           PatternBenefit benefit = 1)
-      : ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
-        allocation(allocation), smem(smem) {}
-
   Value getThreadId(ConversionPatternRewriter &rewriter, Location loc) const {
     auto llvmIndexTy = this->getTypeConverter()->getIndexType();
@@ -169,6 +202,23 @@ public:
     return threadId;
   }
+
+  // -----------------------------------------------------------------------
+  // Shared memory utilities
+  // -----------------------------------------------------------------------
+  template <typename T>
+  Value getSharedMemoryBase(Location loc, ConversionPatternRewriter &rewriter,
+                            T value) const {
+    auto ptrTy = LLVM::LLVMPointerType::get(
+        this->getTypeConverter()->convertType(rewriter.getI8Type()), 3);
+    auto bufferId = allocation->getBufferId(value);
+    assert(bufferId != Allocation::InvalidBufferId && "BufferId not found");
+    size_t offset = allocation->getOffset(bufferId);
+    Value offVal = idx_val(offset);
+    Value base = gep(ptrTy, smem, offVal);
+    return base;
+  }
+
   // -----------------------------------------------------------------------
   // Utilities
   // -----------------------------------------------------------------------
@@ -242,6 +292,116 @@ public:
     return ret;
   }
+
+  struct SmallVectorKeyInfo {
+    static unsigned getHashValue(const SmallVector<unsigned> &key) {
+      return llvm::hash_combine_range(key.begin(), key.end());
+    }
+    static bool isEqual(const SmallVector<unsigned> &lhs,
+                        const SmallVector<unsigned> &rhs) {
+      return lhs == rhs;
+    }
+    static SmallVector<unsigned> getEmptyKey() {
+      return SmallVector<unsigned>();
+    }
+    static SmallVector<unsigned> getTombstoneKey() {
+      return {std::numeric_limits<unsigned>::max()};
+    }
+  };
+
+  // -----------------------------------------------------------------------
+  // Get offsets / indices for any layout
+  // -----------------------------------------------------------------------
+  SmallVector<Value> emitBaseIndexForLayout(Location loc,
+                                            ConversionPatternRewriter &rewriter,
+                                            const Attribute &layout,
+                                            ArrayRef<int64_t> shape) const {
+    IndexCacheKeyT key = std::make_pair(layout, llvm::to_vector(shape));
+    auto cache = indexCacheInfo.baseIndexCache;
+    assert(cache && "baseIndexCache is nullptr");
+    auto insertPt = indexCacheInfo.indexInsertPoint;
+    if (cache->count(key) > 0) {
+      return cache->lookup(key);
+    } else {
+      ConversionPatternRewriter::InsertionGuard guard(rewriter);
+      restoreInsertionPointIfSet(insertPt, rewriter);
+      SmallVector<Value> result;
+      if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
+        result =
+            emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, shape);
+      } else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
+        if (mmaLayout.isVolta())
+          result = emitBaseIndexForMmaLayoutV1(loc, rewriter, mmaLayout, shape);
+        if (mmaLayout.isAmpere())
+          result = emitBaseIndexForMmaLayoutV2(loc, rewriter, mmaLayout, shape);
+      } else {
+        llvm_unreachable("unsupported emitBaseIndexForLayout");
+      }
+      cache->insert(std::make_pair(key, result));
+      *insertPt = rewriter.saveInsertionPoint();
+      return result;
+    }
+  }
+
+  SmallVector<SmallVector<unsigned>>
+  emitOffsetForLayout(const Attribute &layout, ArrayRef<int64_t> shape) const {
+    if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>())
+      return emitOffsetForBlockedLayout(blockedLayout, shape);
+    if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
+      if (mmaLayout.isVolta())
+        return emitOffsetForMmaLayoutV1(mmaLayout, shape);
+      if (mmaLayout.isAmpere())
+        return emitOffsetForMmaLayoutV2(mmaLayout, shape);
+    }
+    llvm_unreachable("unsupported emitOffsetForLayout");
+  }
+
+  // -----------------------------------------------------------------------
+  // Emit indices
+  // -----------------------------------------------------------------------
+  SmallVector<SmallVector<Value>> emitIndices(Location loc,
+                                              ConversionPatternRewriter &b,
+                                              const Attribute &layout,
+                                              ArrayRef<int64_t> shape) const {
+    IndexCacheKeyT key(layout, llvm::to_vector(shape));
+    auto cache = indexCacheInfo.indexCache;
+    assert(cache && "indexCache is nullptr");
+    auto insertPt = indexCacheInfo.indexInsertPoint;
+    if (cache->count(key) > 0) {
+      return cache->lookup(key);
+    } else {
+      ConversionPatternRewriter::InsertionGuard guard(b);
+      restoreInsertionPointIfSet(insertPt, b);
+      SmallVector<SmallVector<Value>> result;
+      if (auto blocked = layout.dyn_cast<BlockedEncodingAttr>()) {
+        result = emitIndicesForDistributedLayout(loc, b, blocked, shape);
+      } else if (auto mma = layout.dyn_cast<MmaEncodingAttr>()) {
+        result = emitIndicesForDistributedLayout(loc, b, mma, shape);
+      } else if (auto slice = layout.dyn_cast<SliceEncodingAttr>()) {
+        result = emitIndicesForSliceLayout(loc, b, slice, shape);
+      } else {
+        llvm_unreachable(
+            "emitIndices for layouts other than blocked & slice not "
+            "implemented yet");
+      }
+      cache->insert(std::make_pair(key, result));
+      *insertPt = b.saveInsertionPoint();
+      return result;
+    }
+  }
+
+private:
+  void restoreInsertionPointIfSet(OpBuilder::InsertPoint *insertPt,
+                                  ConversionPatternRewriter &rewriter) const {
+    if (insertPt->isSet()) {
+      rewriter.restoreInsertionPoint(*insertPt);
+    } else {
+      auto func =
+          rewriter.getInsertionPoint()->getParentOfType<LLVM::LLVMFuncOp>();
+      rewriter.setInsertionPointToStart(&func.getBody().front());
+    }
+  }
+
   // -----------------------------------------------------------------------
   // Blocked layout indices
   // -----------------------------------------------------------------------
@@ -411,38 +571,6 @@ public:
     return ret;
   }
-
-  // -----------------------------------------------------------------------
-  // Get offsets / indices for any layout
-  // -----------------------------------------------------------------------
-  SmallVector<Value> emitBaseIndexForLayout(Location loc,
-                                            ConversionPatternRewriter &rewriter,
-                                            const Attribute &layout,
-                                            ArrayRef<int64_t> shape) const {
-    if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>())
-      return emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, shape);
-    if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
-      if (mmaLayout.isVolta())
-        return emitBaseIndexForMmaLayoutV1(loc, rewriter, mmaLayout, shape);
-      if (mmaLayout.isAmpere())
-        return emitBaseIndexForMmaLayoutV2(loc, rewriter, mmaLayout, shape);
-    }
-    llvm_unreachable("unsupported emitBaseIndexForLayout");
-  }
-
-  SmallVector<SmallVector<unsigned>>
-  emitOffsetForLayout(const Attribute &layout, ArrayRef<int64_t> shape) const {
-    if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>())
-      return emitOffsetForBlockedLayout(blockedLayout, shape);
-    if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
-      if (mmaLayout.isVolta())
-        return emitOffsetForMmaLayoutV1(mmaLayout, shape);
-      if (mmaLayout.isAmpere())
-        return emitOffsetForMmaLayoutV2(mmaLayout, shape);
-    }
-    llvm_unreachable("unsupported emitOffsetForLayout");
-  }
 
   // Emit indices calculation within each ConversionPattern, and returns a
   // [elemsPerThread X rank] index matrix.
@@ -470,22 +598,6 @@ public:
     return multiDimIdx;
   }
-
-  struct SmallVectorKeyInfo {
-    static unsigned getHashValue(const SmallVector<unsigned> &key) {
-      return llvm::hash_combine_range(key.begin(), key.end());
-    }
-    static bool isEqual(const SmallVector<unsigned> &lhs,
-                        const SmallVector<unsigned> &rhs) {
-      return lhs == rhs;
-    }
-    static SmallVector<unsigned> getEmptyKey() {
-      return SmallVector<unsigned>();
-    }
-    static SmallVector<unsigned> getTombstoneKey() {
-      return {std::numeric_limits<unsigned>::max()};
-    }
-  };
 
   SmallVector<SmallVector<Value>>
   emitIndicesForSliceLayout(Location loc, ConversionPatternRewriter &rewriter,
                             const SliceEncodingAttr &sliceLayout,
@@ -505,46 +617,45 @@ public:
     return resultIndices;
   }
-
-  // -----------------------------------------------------------------------
-  // Emit indices
-  // -----------------------------------------------------------------------
-  SmallVector<SmallVector<Value>> emitIndices(Location loc,
-                                              ConversionPatternRewriter &b,
-                                              const Attribute &layout,
-                                              ArrayRef<int64_t> shape) const {
-    if (auto blocked = layout.dyn_cast<BlockedEncodingAttr>()) {
-      return emitIndicesForDistributedLayout(loc, b, blocked, shape);
-    } else if (auto mma = layout.dyn_cast<MmaEncodingAttr>()) {
-      return emitIndicesForDistributedLayout(loc, b, mma, shape);
-    } else if (auto slice = layout.dyn_cast<SliceEncodingAttr>()) {
-      return emitIndicesForSliceLayout(loc, b, slice, shape);
-    } else {
-      assert(0 && "emitIndices for layouts other than blocked & slice not "
-                  "implemented yet");
-      return {};
-    }
-  }
-
-  // -----------------------------------------------------------------------
-  // Shared memory utilities
-  // -----------------------------------------------------------------------
-  template <typename T>
-  Value getSharedMemoryBase(Location loc, ConversionPatternRewriter &rewriter,
-                            T value) const {
-    auto ptrTy = LLVM::LLVMPointerType::get(
-        this->getTypeConverter()->convertType(rewriter.getI8Type()), 3);
-    auto bufferId = allocation->getBufferId(value);
-    assert(bufferId != Allocation::InvalidBufferId && "BufferId not found");
-    size_t offset = allocation->getOffset(bufferId);
-    Value offVal = idx_val(offset);
-    Value base = gep(ptrTy, smem, offVal);
-    return base;
-  }
 
 protected:
+  LLVMTypeConverter *converter;
   const Allocation *allocation;
   Value smem;
+  IndexCacheInfo indexCacheInfo;
+};
+
+template <typename SourceOp>
+class ConvertTritonGPUOpToLLVMPattern
+    : public ConvertOpToLLVMPattern<SourceOp>,
+      public ConvertTritonGPUOpToLLVMPatternBase {
+public:
+  using OpAdaptor = typename SourceOp::Adaptor;
+
+  explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter,
+                                           PatternBenefit benefit = 1)
+      : ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
+        ConvertTritonGPUOpToLLVMPatternBase(typeConverter) {}
+
+  explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter,
+                                           const Allocation *allocation,
+                                           Value smem,
+                                           PatternBenefit benefit = 1)
+      : ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
+        ConvertTritonGPUOpToLLVMPatternBase(typeConverter, allocation, smem) {}
+
+  explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter,
+                                           const Allocation *allocation,
+                                           Value smem,
+                                           IndexCacheInfo indexCacheInfo,
+                                           PatternBenefit benefit = 1)
+      : ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
+        ConvertTritonGPUOpToLLVMPatternBase(typeConverter, allocation, smem,
+                                            indexCacheInfo) {}
+
+protected:
+  LLVMTypeConverter *getTypeConverter() const {
+    return ((ConvertTritonGPUOpToLLVMPatternBase *)this)->getTypeConverter();
+  }
 };
 #endif


@@ -170,16 +170,20 @@ public:
     // We set a higher benefit here to ensure triton's patterns runs before
     // arith patterns for some encoding not supported by the community
     // patterns.
+    OpBuilder::InsertPoint indexInsertPoint;
+    ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo indexCacheInfo{
+        &baseIndexCache, &indexCache, &indexInsertPoint};
+
     RewritePatternSet patterns(context);
 
     // Normal conversions
     populateTritonGPUToLLVMPatterns(typeConverter, patterns, numWarps,
                                     axisInfoAnalysis, &allocation, smem,
-                                    /*benefit=*/10);
+                                    indexCacheInfo, /*benefit=*/10);
     // ConvertLayoutOp
     populateConvertLayoutOpToLLVMPatterns(typeConverter, patterns, numWarps,
                                           axisInfoAnalysis, &allocation, smem,
-                                          /*benefit=*/10);
+                                          indexCacheInfo, /*benefit=*/10);
     // DotOp
     populateDotOpToLLVMPatterns(typeConverter, patterns, numWarps,
                                 axisInfoAnalysis, &allocation, smem,
@@ -191,11 +195,11 @@ public:
     // LoadStoreOp
     populateLoadStoreOpToLLVMPatterns(typeConverter, patterns, numWarps,
                                       axisInfoAnalysis, &allocation, smem,
-                                      /*benefit=*/10);
+                                      indexCacheInfo, /*benefit=*/10);
     // ReduceOp
     populateReduceOpToLLVMPatterns(typeConverter, patterns, numWarps,
                                    axisInfoAnalysis, &allocation, smem,
-                                   /*benefit=*/10);
+                                   indexCacheInfo, /*benefit=*/10);
     // ViewOp
     populateViewOpToLLVMPatterns(typeConverter, patterns, numWarps,
                                  axisInfoAnalysis, &allocation, smem,
@@ -215,6 +219,13 @@ public:
 private:
   Value smem;
 
+  using IndexCacheKeyT = std::pair<Attribute, SmallVector<int64_t>>;
+  DenseMap<IndexCacheKeyT, SmallVector<Value>, CacheKeyDenseMapInfo>
+      baseIndexCache;
+  DenseMap<IndexCacheKeyT, SmallVector<SmallVector<Value>>,
+           CacheKeyDenseMapInfo>
+      indexCache;
+
   int computeCapability{};
 
   void initSharedMemory(size_t size,


@@ -997,8 +997,7 @@ func @test_get_program_id(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
 // -----
 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
 module attributes {"triton_gpu.num-warps" = 4 : i32} {
-
 func @test_get_num_program(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
   // CHECK: nvvm.read.ptx.sreg.nctaid.x
   // CHECK: nvvm.read.ptx.sreg.nctaid.y
   // CHECK: nvvm.read.ptx.sreg.nctaid.z
@@ -1011,6 +1010,48 @@ func @test_get_num_program(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
   tt.store %a, %0 : tensor<32xi32, #blocked0>
   return
+}
 }
+
+// -----
+#blocked0 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
+module attributes {"triton_gpu.num-warps" = 4 : i32} {
+  // CHECK-LABEL: test_index_cache
+  func @test_index_cache() {
+    // CHECK: nvvm.read.ptx.sreg.tid.x
+    %0 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0>
+    // CHECK-NOT: nvvm.read.ptx.sreg.tid.x
+    %1 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0>
+    return
+  }
+}
+
+// -----
+#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}>
+#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}>
+module attributes {"triton_gpu.num-warps" = 1 : i32} {
+  // CHECK-LABEL: test_base_index_cache
+  func @test_base_index_cache(%arg0: tensor<128x32xf32, #blocked0>) {
+    // CHECK: nvvm.read.ptx.sreg.tid.x
+    %0 = triton_gpu.convert_layout %arg0 : (tensor<128x32xf32, #blocked0>) -> tensor<128x32xf32, #shared0>
+    // CHECK-NOT: nvvm.read.ptx.sreg.tid.x
+    %1 = triton_gpu.convert_layout %arg0 : (tensor<128x32xf32, #blocked0>) -> tensor<128x32xf32, #shared0>
+    return
+  }
+}
+
+// -----
+#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}>
+#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}>
+module attributes {"triton_gpu.num-warps" = 1 : i32} {
+  // CHECK-LABEL: test_index_cache_different_block
+  func @test_index_cache_different_block(%arg0: tensor<128x32xf32, #blocked0>, %arg1: i1) {
+    // CHECK: nvvm.read.ptx.sreg.tid.x
+    %0 = triton_gpu.convert_layout %arg0 : (tensor<128x32xf32, #blocked0>) -> tensor<128x32xf32, #shared0>
+    scf.if %arg1 {
+      // CHECK-NOT: nvvm.read.ptx.sreg.tid.x
+      %1 = triton_gpu.convert_layout %arg0 : (tensor<128x32xf32, #blocked0>) -> tensor<128x32xf32, #shared0>
+    }
+    return
+  }
+}