[Backend] Add value cache in emitting indices calculation and some refinement (#1018)
1. Add an explicit value cache for the emitted index calculations. 2. Move the logic that emits the index calculations into ConvertTritonGPUOpToLLVMPatternBase, to avoid the redundant build cost caused by templates. See the discussion in this thread by @LyricZhao: https://triton-lang.slack.com/archives/C042VBSQWNS/p1671336755922969
This commit is contained in:
@ -53,9 +53,6 @@ void storeBlockedToShared(Value src, Value llSrc, ArrayRef<Value> srcStrides,
|
|||||||
auto wordTy = vec_ty(elemTy, minVec);
|
auto wordTy = vec_ty(elemTy, minVec);
|
||||||
auto elemPtrTy = ptr_ty(elemTy);
|
auto elemPtrTy = ptr_ty(elemTy);
|
||||||
|
|
||||||
// TODO: [goostavz] We should make a cache for the calculation of
|
|
||||||
// emitBaseIndexForBlockedLayout in case backend compiler not being able to
|
|
||||||
// optimize that
|
|
||||||
SmallVector<unsigned> srcShapePerCTA = getShapePerCTA(srcBlockedLayout);
|
SmallVector<unsigned> srcShapePerCTA = getShapePerCTA(srcBlockedLayout);
|
||||||
SmallVector<unsigned> reps{ceil<unsigned>(srcShape[0], srcShapePerCTA[0]),
|
SmallVector<unsigned> reps{ceil<unsigned>(srcShape[0], srcShapePerCTA[0]),
|
||||||
ceil<unsigned>(srcShape[1], srcShapePerCTA[1])};
|
ceil<unsigned>(srcShape[1], srcShapePerCTA[1])};
|
||||||
@ -182,7 +179,7 @@ private:
|
|||||||
unsigned rank = shape.size();
|
unsigned rank = shape.size();
|
||||||
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
|
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
|
||||||
auto multiDimOffsetFirstElem =
|
auto multiDimOffsetFirstElem =
|
||||||
emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, shape);
|
emitBaseIndexForLayout(loc, rewriter, blockedLayout, shape);
|
||||||
SmallVector<Value> multiDimOffset(rank);
|
SmallVector<Value> multiDimOffset(rank);
|
||||||
SmallVector<unsigned> multiDimElemId = getMultiDimIndex<unsigned>(
|
SmallVector<unsigned> multiDimElemId = getMultiDimIndex<unsigned>(
|
||||||
elemId, getSizePerThread(layout), getOrder(layout));
|
elemId, getSizePerThread(layout), getOrder(layout));
|
||||||
@ -501,8 +498,8 @@ private:
|
|||||||
|
|
||||||
auto srcStrides =
|
auto srcStrides =
|
||||||
getStridesFromShapeAndOrder(srcShape, inOrd, loc, rewriter);
|
getStridesFromShapeAndOrder(srcShape, inOrd, loc, rewriter);
|
||||||
auto srcIndices = emitBaseIndexForBlockedLayout(loc, rewriter,
|
auto srcIndices =
|
||||||
srcBlockedLayout, srcShape);
|
emitBaseIndexForLayout(loc, rewriter, srcBlockedLayout, srcShape);
|
||||||
storeBlockedToShared(src, adaptor.src(), srcStrides, srcIndices, dst,
|
storeBlockedToShared(src, adaptor.src(), srcStrides, srcIndices, dst,
|
||||||
smemBase, elemTy, loc, rewriter);
|
smemBase, elemTy, loc, rewriter);
|
||||||
|
|
||||||
@ -680,7 +677,9 @@ private:
|
|||||||
void populateConvertLayoutOpToLLVMPatterns(
|
void populateConvertLayoutOpToLLVMPatterns(
|
||||||
mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||||
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
|
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
|
||||||
const Allocation *allocation, Value smem, PatternBenefit benefit) {
|
const Allocation *allocation, Value smem,
|
||||||
|
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
|
||||||
|
PatternBenefit benefit) {
|
||||||
patterns.add<ConvertLayoutOpConversion>(typeConverter, allocation, smem,
|
patterns.add<ConvertLayoutOpConversion>(typeConverter, allocation, smem,
|
||||||
benefit);
|
indexCacheInfo, benefit);
|
||||||
}
|
}
|
||||||
|
@ -19,6 +19,8 @@ void storeBlockedToShared(Value src, Value llSrc, ArrayRef<Value> srcStrides,
|
|||||||
void populateConvertLayoutOpToLLVMPatterns(
|
void populateConvertLayoutOpToLLVMPatterns(
|
||||||
mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||||
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
|
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
|
||||||
const Allocation *allocation, Value smem, PatternBenefit benefit);
|
const Allocation *allocation, Value smem,
|
||||||
|
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
|
||||||
|
PatternBenefit benefit);
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
@ -14,7 +14,7 @@ using ::mlir::triton::gpu::getElemsPerThread;
|
|||||||
using ::mlir::triton::gpu::SharedEncodingAttr;
|
using ::mlir::triton::gpu::SharedEncodingAttr;
|
||||||
|
|
||||||
// Contains some helper functions for both Load and Store conversions.
|
// Contains some helper functions for both Load and Store conversions.
|
||||||
struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase {
|
struct LoadStoreConversionBase {
|
||||||
explicit LoadStoreConversionBase(AxisInfoAnalysis &axisAnalysisPass)
|
explicit LoadStoreConversionBase(AxisInfoAnalysis &axisAnalysisPass)
|
||||||
: axisAnalysisPass(axisAnalysisPass) {}
|
: axisAnalysisPass(axisAnalysisPass) {}
|
||||||
|
|
||||||
@ -640,7 +640,7 @@ struct InsertSliceOpConversion
|
|||||||
|
|
||||||
auto llSrc = adaptor.source();
|
auto llSrc = adaptor.source();
|
||||||
auto srcIndices =
|
auto srcIndices =
|
||||||
emitBaseIndexForBlockedLayout(loc, rewriter, srcLayout, srcShape);
|
emitBaseIndexForLayout(loc, rewriter, srcLayout, srcShape);
|
||||||
storeBlockedToShared(src, llSrc, srcStrides, srcIndices, dst, smemBase,
|
storeBlockedToShared(src, llSrc, srcStrides, srcIndices, dst, smemBase,
|
||||||
elemTy, loc, rewriter);
|
elemTy, loc, rewriter);
|
||||||
// Barrier is not necessary.
|
// Barrier is not necessary.
|
||||||
@ -657,12 +657,12 @@ struct InsertSliceAsyncOpConversion
|
|||||||
using ConvertTritonGPUOpToLLVMPattern<
|
using ConvertTritonGPUOpToLLVMPattern<
|
||||||
triton::gpu::InsertSliceAsyncOp>::ConvertTritonGPUOpToLLVMPattern;
|
triton::gpu::InsertSliceAsyncOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||||
|
|
||||||
InsertSliceAsyncOpConversion(LLVMTypeConverter &converter,
|
InsertSliceAsyncOpConversion(
|
||||||
const Allocation *allocation, Value smem,
|
LLVMTypeConverter &converter, const Allocation *allocation, Value smem,
|
||||||
AxisInfoAnalysis &axisAnalysisPass,
|
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
|
||||||
PatternBenefit benefit)
|
AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit)
|
||||||
: ConvertTritonGPUOpToLLVMPattern<triton::gpu::InsertSliceAsyncOp>(
|
: ConvertTritonGPUOpToLLVMPattern<triton::gpu::InsertSliceAsyncOp>(
|
||||||
converter, allocation, smem, benefit),
|
converter, allocation, smem, indexCacheInfo, benefit),
|
||||||
LoadStoreConversionBase(axisAnalysisPass) {}
|
LoadStoreConversionBase(axisAnalysisPass) {}
|
||||||
|
|
||||||
LogicalResult
|
LogicalResult
|
||||||
@ -865,11 +865,11 @@ struct InsertSliceAsyncOpConversion
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
void populateLoadStoreOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
|
void populateLoadStoreOpToLLVMPatterns(
|
||||||
RewritePatternSet &patterns,
|
mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||||
int numWarps,
|
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
|
||||||
AxisInfoAnalysis &axisInfoAnalysis,
|
|
||||||
const Allocation *allocation, Value smem,
|
const Allocation *allocation, Value smem,
|
||||||
|
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
|
||||||
PatternBenefit benefit) {
|
PatternBenefit benefit) {
|
||||||
patterns.add<LoadOpConversion>(typeConverter, axisInfoAnalysis, benefit);
|
patterns.add<LoadOpConversion>(typeConverter, axisInfoAnalysis, benefit);
|
||||||
patterns.add<StoreOpConversion>(typeConverter, axisInfoAnalysis, benefit);
|
patterns.add<StoreOpConversion>(typeConverter, axisInfoAnalysis, benefit);
|
||||||
@ -878,7 +878,8 @@ void populateLoadStoreOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
|
|||||||
patterns.add<AtomicRMWOpConversion>(typeConverter, allocation, smem,
|
patterns.add<AtomicRMWOpConversion>(typeConverter, allocation, smem,
|
||||||
axisInfoAnalysis, benefit);
|
axisInfoAnalysis, benefit);
|
||||||
patterns.add<InsertSliceOpConversion>(typeConverter, allocation, smem,
|
patterns.add<InsertSliceOpConversion>(typeConverter, allocation, smem,
|
||||||
benefit);
|
indexCacheInfo, benefit);
|
||||||
patterns.add<InsertSliceAsyncOpConversion>(typeConverter, allocation, smem,
|
patterns.add<InsertSliceAsyncOpConversion>(typeConverter, allocation, smem,
|
||||||
axisInfoAnalysis, benefit);
|
indexCacheInfo, axisInfoAnalysis,
|
||||||
|
benefit);
|
||||||
}
|
}
|
||||||
|
@ -6,11 +6,11 @@
|
|||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::triton;
|
using namespace mlir::triton;
|
||||||
|
|
||||||
void populateLoadStoreOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
|
void populateLoadStoreOpToLLVMPatterns(
|
||||||
RewritePatternSet &patterns,
|
mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||||
int numWarps,
|
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
|
||||||
AxisInfoAnalysis &axisInfoAnalysis,
|
|
||||||
const Allocation *allocation, Value smem,
|
const Allocation *allocation, Value smem,
|
||||||
|
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
|
||||||
PatternBenefit benefit);
|
PatternBenefit benefit);
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
@ -164,7 +164,7 @@ private:
|
|||||||
auto srcValues = getElementsFromStruct(loc, adaptor.operand(), rewriter);
|
auto srcValues = getElementsFromStruct(loc, adaptor.operand(), rewriter);
|
||||||
|
|
||||||
SmallVector<SmallVector<unsigned>> offset =
|
SmallVector<SmallVector<unsigned>> offset =
|
||||||
emitOffsetForBlockedLayout(srcLayout, srcShape);
|
emitOffsetForLayout(srcLayout, srcShape);
|
||||||
|
|
||||||
std::map<SmallVector<unsigned>, Value> accs;
|
std::map<SmallVector<unsigned>, Value> accs;
|
||||||
std::map<SmallVector<unsigned>, Value> accIndices;
|
std::map<SmallVector<unsigned>, Value> accIndices;
|
||||||
@ -479,10 +479,12 @@ private:
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
void populateReduceOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
|
void populateReduceOpToLLVMPatterns(
|
||||||
RewritePatternSet &patterns, int numWarps,
|
mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||||
AxisInfoAnalysis &axisInfoAnalysis,
|
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
|
||||||
const Allocation *allocation, Value smem,
|
const Allocation *allocation, Value smem,
|
||||||
|
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
|
||||||
PatternBenefit benefit) {
|
PatternBenefit benefit) {
|
||||||
patterns.add<ReduceOpConversion>(typeConverter, allocation, smem, benefit);
|
patterns.add<ReduceOpConversion>(typeConverter, allocation, smem,
|
||||||
|
indexCacheInfo, benefit);
|
||||||
}
|
}
|
||||||
|
@ -6,10 +6,11 @@
|
|||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::triton;
|
using namespace mlir::triton;
|
||||||
|
|
||||||
void populateReduceOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
|
void populateReduceOpToLLVMPatterns(
|
||||||
RewritePatternSet &patterns, int numWarps,
|
mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||||
AxisInfoAnalysis &axisInfoAnalysis,
|
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
|
||||||
const Allocation *allocation, Value smem,
|
const Allocation *allocation, Value smem,
|
||||||
|
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
|
||||||
PatternBenefit benefit);
|
PatternBenefit benefit);
|
||||||
|
|
||||||
#endif
|
#endif
|
@ -63,6 +63,7 @@ struct BroadcastOpConversion
|
|||||||
auto srcShape = srcTy.getShape();
|
auto srcShape = srcTy.getShape();
|
||||||
auto resultShape = resultTy.getShape();
|
auto resultShape = resultTy.getShape();
|
||||||
unsigned rank = srcTy.getRank();
|
unsigned rank = srcTy.getRank();
|
||||||
|
|
||||||
assert(rank == resultTy.getRank());
|
assert(rank == resultTy.getRank());
|
||||||
auto order = triton::gpu::getOrder(srcLayout);
|
auto order = triton::gpu::getOrder(srcLayout);
|
||||||
auto srcOffsets = emitOffsetForLayout(srcLayout, srcShape);
|
auto srcOffsets = emitOffsetForLayout(srcLayout, srcShape);
|
||||||
@ -272,8 +273,12 @@ struct PrintfOpConversion
|
|||||||
struct MakeRangeOpConversion
|
struct MakeRangeOpConversion
|
||||||
: public ConvertTritonGPUOpToLLVMPattern<triton::MakeRangeOp> {
|
: public ConvertTritonGPUOpToLLVMPattern<triton::MakeRangeOp> {
|
||||||
|
|
||||||
MakeRangeOpConversion(LLVMTypeConverter &converter, PatternBenefit benefit)
|
MakeRangeOpConversion(
|
||||||
: ConvertTritonGPUOpToLLVMPattern<triton::MakeRangeOp>(converter,
|
LLVMTypeConverter &converter,
|
||||||
|
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
|
||||||
|
PatternBenefit benefit)
|
||||||
|
: ConvertTritonGPUOpToLLVMPattern<triton::MakeRangeOp>(
|
||||||
|
converter, /*Allocation*/ nullptr, Value{}, indexCacheInfo,
|
||||||
benefit) {}
|
benefit) {}
|
||||||
|
|
||||||
LogicalResult
|
LogicalResult
|
||||||
@ -500,10 +505,11 @@ struct AsyncWaitOpConversion
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
void populateTritonGPUToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
|
void populateTritonGPUToLLVMPatterns(
|
||||||
RewritePatternSet &patterns, int numWarps,
|
mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||||
AxisInfoAnalysis &axisInfoAnalysis,
|
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
|
||||||
const Allocation *allocation, Value smem,
|
const Allocation *allocation, Value smem,
|
||||||
|
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
|
||||||
PatternBenefit benefit) {
|
PatternBenefit benefit) {
|
||||||
patterns.add<AddPtrOpConversion>(typeConverter, benefit);
|
patterns.add<AddPtrOpConversion>(typeConverter, benefit);
|
||||||
patterns.add<AllocTensorOpConversion>(typeConverter, allocation, smem,
|
patterns.add<AllocTensorOpConversion>(typeConverter, allocation, smem,
|
||||||
@ -515,7 +521,7 @@ void populateTritonGPUToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
|
|||||||
benefit);
|
benefit);
|
||||||
patterns.add<GetProgramIdOpConversion>(typeConverter, benefit);
|
patterns.add<GetProgramIdOpConversion>(typeConverter, benefit);
|
||||||
patterns.add<GetNumProgramsOpConversion>(typeConverter, benefit);
|
patterns.add<GetNumProgramsOpConversion>(typeConverter, benefit);
|
||||||
patterns.add<MakeRangeOpConversion>(typeConverter, benefit);
|
patterns.add<MakeRangeOpConversion>(typeConverter, indexCacheInfo, benefit);
|
||||||
patterns.add<ReturnOpConversion>(typeConverter, benefit);
|
patterns.add<ReturnOpConversion>(typeConverter, benefit);
|
||||||
patterns.add<PrintfOpConversion>(typeConverter, benefit);
|
patterns.add<PrintfOpConversion>(typeConverter, benefit);
|
||||||
}
|
}
|
@ -6,10 +6,11 @@
|
|||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::triton;
|
using namespace mlir::triton;
|
||||||
|
|
||||||
void populateTritonGPUToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
|
void populateTritonGPUToLLVMPatterns(
|
||||||
RewritePatternSet &patterns, int numWarps,
|
mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||||
AxisInfoAnalysis &axisInfoAnalysis,
|
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
|
||||||
const Allocation *allocation, Value smem,
|
const Allocation *allocation, Value smem,
|
||||||
|
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
|
||||||
PatternBenefit benefit);
|
PatternBenefit benefit);
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
@ -18,7 +18,6 @@ using ::mlir::LLVM::SharedMemoryObject;
|
|||||||
using ::mlir::triton::gpu::BlockedEncodingAttr;
|
using ::mlir::triton::gpu::BlockedEncodingAttr;
|
||||||
using ::mlir::triton::gpu::MmaEncodingAttr;
|
using ::mlir::triton::gpu::MmaEncodingAttr;
|
||||||
using ::mlir::triton::gpu::SliceEncodingAttr;
|
using ::mlir::triton::gpu::SliceEncodingAttr;
|
||||||
|
|
||||||
// FuncOpConversion/FuncOpConversionBase is borrowed from
|
// FuncOpConversion/FuncOpConversionBase is borrowed from
|
||||||
// https://github.com/llvm/llvm-project/blob/fae656b2dd80246c3c6f01e9c77c49560368752c/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp#L276
|
// https://github.com/llvm/llvm-project/blob/fae656b2dd80246c3c6f01e9c77c49560368752c/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp#L276
|
||||||
// since it is not exposed on header files in mlir v14
|
// since it is not exposed on header files in mlir v14
|
||||||
@ -128,7 +127,60 @@ protected:
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ConvertTritonGPUOpToLLVMPatternBase {
|
using IndexCacheKeyT = std::pair<Attribute, SmallVector<int64_t>>;
|
||||||
|
|
||||||
|
struct CacheKeyDenseMapInfo {
|
||||||
|
static IndexCacheKeyT getEmptyKey() {
|
||||||
|
auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
|
||||||
|
return std::make_pair(
|
||||||
|
mlir::Attribute(static_cast<mlir::Attribute::ImplType *>(pointer)),
|
||||||
|
SmallVector<int64_t>{});
|
||||||
|
}
|
||||||
|
static IndexCacheKeyT getTombstoneKey() {
|
||||||
|
auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
|
||||||
|
return std::make_pair(
|
||||||
|
mlir::Attribute(static_cast<mlir::Attribute::ImplType *>(pointer)),
|
||||||
|
SmallVector<int64_t>{std::numeric_limits<int64_t>::max()});
|
||||||
|
}
|
||||||
|
static unsigned getHashValue(IndexCacheKeyT key) {
|
||||||
|
return llvm::hash_combine(
|
||||||
|
mlir::hash_value(key.first),
|
||||||
|
llvm::hash_combine_range(key.second.begin(), key.second.end()));
|
||||||
|
}
|
||||||
|
static bool isEqual(IndexCacheKeyT LHS, IndexCacheKeyT RHS) {
|
||||||
|
return LHS == RHS;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class ConvertTritonGPUOpToLLVMPatternBase {
|
||||||
|
public:
|
||||||
|
// Two levels of value cache in emitting indices calculation:
|
||||||
|
// Key: pair<layout, shape>
|
||||||
|
struct IndexCacheInfo {
|
||||||
|
DenseMap<IndexCacheKeyT, SmallVector<Value>, CacheKeyDenseMapInfo>
|
||||||
|
*baseIndexCache;
|
||||||
|
DenseMap<IndexCacheKeyT, SmallVector<SmallVector<Value>>,
|
||||||
|
CacheKeyDenseMapInfo> *indexCache;
|
||||||
|
OpBuilder::InsertPoint *indexInsertPoint;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Builds a pattern base that only knows its type converter. `allocation` is
// explicitly nulled and `indexCacheInfo` is value-initialized (all pointers
// null), so no member is left holding an indeterminate pointer.
explicit ConvertTritonGPUOpToLLVMPatternBase(LLVMTypeConverter &typeConverter)
    : converter(&typeConverter), allocation(nullptr), smem(),
      indexCacheInfo() {}
|
||||||
|
|
||||||
|
// Builds a pattern base with shared-memory information but without an index
// cache. `indexCacheInfo` is value-initialized so that its pointers are null
// rather than indeterminate; the emit* helpers assert on those pointers.
explicit ConvertTritonGPUOpToLLVMPatternBase(LLVMTypeConverter &typeConverter,
                                             const Allocation *allocation,
                                             Value smem)
    : converter(&typeConverter), allocation(allocation), smem(smem),
      indexCacheInfo() {}
|
||||||
|
|
||||||
|
// Builds a pattern base with shared-memory information and an index cache.
// The mem-initializers are listed in member declaration order (converter,
// allocation, smem, indexCacheInfo), which is the order in which they
// actually run; this avoids a -Wreorder warning.
explicit ConvertTritonGPUOpToLLVMPatternBase(LLVMTypeConverter &typeConverter,
                                             const Allocation *allocation,
                                             Value smem,
                                             IndexCacheInfo indexCacheInfo)
    : converter(&typeConverter), allocation(allocation), smem(smem),
      indexCacheInfo(indexCacheInfo) {}
|
||||||
|
|
||||||
|
// Returns the LLVM type converter this pattern base was constructed with.
LLVMTypeConverter *getTypeConverter() const { return converter; }
|
||||||
|
|
||||||
static Value
|
static Value
|
||||||
getStructFromSharedMemoryObject(Location loc,
|
getStructFromSharedMemoryObject(Location loc,
|
||||||
const SharedMemoryObject &smemObj,
|
const SharedMemoryObject &smemObj,
|
||||||
@ -139,25 +191,6 @@ struct ConvertTritonGPUOpToLLVMPatternBase {
|
|||||||
LLVM::LLVMStructType::getLiteral(rewriter.getContext(), types);
|
LLVM::LLVMStructType::getLiteral(rewriter.getContext(), types);
|
||||||
return getStructFromElements(loc, elems, rewriter, structTy);
|
return getStructFromElements(loc, elems, rewriter, structTy);
|
||||||
}
|
}
|
||||||
};
|
|
||||||
|
|
||||||
template <typename SourceOp>
|
|
||||||
class ConvertTritonGPUOpToLLVMPattern
|
|
||||||
: public ConvertOpToLLVMPattern<SourceOp>,
|
|
||||||
public ConvertTritonGPUOpToLLVMPatternBase {
|
|
||||||
public:
|
|
||||||
using OpAdaptor = typename SourceOp::Adaptor;
|
|
||||||
|
|
||||||
explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter,
|
|
||||||
PatternBenefit benefit = 1)
|
|
||||||
: ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}
|
|
||||||
|
|
||||||
explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter,
|
|
||||||
const Allocation *allocation,
|
|
||||||
Value smem,
|
|
||||||
PatternBenefit benefit = 1)
|
|
||||||
: ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
|
|
||||||
allocation(allocation), smem(smem) {}
|
|
||||||
|
|
||||||
Value getThreadId(ConversionPatternRewriter &rewriter, Location loc) const {
|
Value getThreadId(ConversionPatternRewriter &rewriter, Location loc) const {
|
||||||
auto llvmIndexTy = this->getTypeConverter()->getIndexType();
|
auto llvmIndexTy = this->getTypeConverter()->getIndexType();
|
||||||
@ -169,6 +202,23 @@ public:
|
|||||||
return threadId;
|
return threadId;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
// Shared memory utilities
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
template <typename T>
|
||||||
|
Value getSharedMemoryBase(Location loc, ConversionPatternRewriter &rewriter,
|
||||||
|
T value) const {
|
||||||
|
|
||||||
|
auto ptrTy = LLVM::LLVMPointerType::get(
|
||||||
|
this->getTypeConverter()->convertType(rewriter.getI8Type()), 3);
|
||||||
|
auto bufferId = allocation->getBufferId(value);
|
||||||
|
assert(bufferId != Allocation::InvalidBufferId && "BufferId not found");
|
||||||
|
size_t offset = allocation->getOffset(bufferId);
|
||||||
|
Value offVal = idx_val(offset);
|
||||||
|
Value base = gep(ptrTy, smem, offVal);
|
||||||
|
return base;
|
||||||
|
}
|
||||||
|
|
||||||
// -----------------------------------------------------------------------
|
// -----------------------------------------------------------------------
|
||||||
// Utilities
|
// Utilities
|
||||||
// -----------------------------------------------------------------------
|
// -----------------------------------------------------------------------
|
||||||
@ -242,6 +292,116 @@ public:
|
|||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct SmallVectorKeyInfo {
|
||||||
|
static unsigned getHashValue(const SmallVector<unsigned> &key) {
|
||||||
|
return llvm::hash_combine_range(key.begin(), key.end());
|
||||||
|
}
|
||||||
|
static bool isEqual(const SmallVector<unsigned> &lhs,
|
||||||
|
const SmallVector<unsigned> &rhs) {
|
||||||
|
return lhs == rhs;
|
||||||
|
}
|
||||||
|
static SmallVector<unsigned> getEmptyKey() {
|
||||||
|
return SmallVector<unsigned>();
|
||||||
|
}
|
||||||
|
static SmallVector<unsigned> getTombstoneKey() {
|
||||||
|
return {std::numeric_limits<unsigned>::max()};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
// Get offsets / indices for any layout
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
|
||||||
|
// Returns the base index values for `layout` and `shape`, computing them at
// most once per <layout, shape> key.
//
// On a cache hit the stored values are returned and nothing is emitted. On a
// miss the values are emitted at the shared index insertion point (or at the
// start of the enclosing LLVM function's entry block if that point is not set
// yet), stored in the cache, and the shared insertion point is advanced past
// the new ops. The caller's own insertion point is restored on return by the
// InsertionGuard.
//
// Supported layouts: blocked, and MMA (Volta or Ampere). Anything else is
// unreachable.
SmallVector<Value> emitBaseIndexForLayout(Location loc,
                                          ConversionPatternRewriter &rewriter,
                                          const Attribute &layout,
                                          ArrayRef<int64_t> shape) const {
  auto *cache = indexCacheInfo.baseIndexCache;
  assert(cache && "baseIndexCache is nullptr");
  auto *insertPt = indexCacheInfo.indexInsertPoint;
  // insertPt is dereferenced below, so check it like the cache pointer.
  assert(insertPt && "indexInsertPoint is nullptr");

  IndexCacheKeyT key = std::make_pair(layout, llvm::to_vector(shape));
  // Single lookup instead of count() followed by lookup().
  auto cached = cache->find(key);
  if (cached != cache->end())
    return cached->second;

  ConversionPatternRewriter::InsertionGuard guard(rewriter);
  restoreInsertionPointIfSet(insertPt, rewriter);

  SmallVector<Value> result;
  if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
    result = emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, shape);
  } else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
    if (mmaLayout.isVolta()) {
      result = emitBaseIndexForMmaLayoutV1(loc, rewriter, mmaLayout, shape);
    } else if (mmaLayout.isAmpere()) {
      result = emitBaseIndexForMmaLayoutV2(loc, rewriter, mmaLayout, shape);
    } else {
      // Do not cache an empty result for an MMA version with no emitter.
      llvm_unreachable("unsupported emitBaseIndexForLayout");
    }
  } else {
    llvm_unreachable("unsupported emitBaseIndexForLayout");
  }

  cache->insert(std::make_pair(key, result));
  // Later cache misses continue emitting right after the ops created here.
  *insertPt = rewriter.saveInsertionPoint();
  return result;
}
|
||||||
|
|
||||||
|
// Dispatches to the offset emitter that matches `layout`: blocked, MMA v1
// (Volta) or MMA v2 (Ampere). Any other layout is unreachable.
SmallVector<SmallVector<unsigned>>
emitOffsetForLayout(const Attribute &layout, ArrayRef<int64_t> shape) const {
  if (auto blocked = layout.dyn_cast<BlockedEncodingAttr>())
    return emitOffsetForBlockedLayout(blocked, shape);

  auto mma = layout.dyn_cast<MmaEncodingAttr>();
  if (mma && mma.isVolta())
    return emitOffsetForMmaLayoutV1(mma, shape);
  if (mma && mma.isAmpere())
    return emitOffsetForMmaLayoutV2(mma, shape);

  llvm_unreachable("unsupported emitOffsetForLayout");
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
// Emit indices
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
// Returns the per-element index values for `layout` and `shape`, computing
// them at most once per <layout, shape> key.
//
// On a cache hit the stored values are returned and nothing is emitted. On a
// miss the values are emitted at the shared index insertion point (or at the
// start of the enclosing LLVM function's entry block if that point is not set
// yet), stored in the cache, and the shared insertion point is advanced past
// the new ops. The caller's own insertion point is restored on return by the
// InsertionGuard.
//
// Supported layouts: blocked, MMA and slice. Anything else is unreachable.
SmallVector<SmallVector<Value>> emitIndices(Location loc,
                                            ConversionPatternRewriter &b,
                                            const Attribute &layout,
                                            ArrayRef<int64_t> shape) const {
  auto *cache = indexCacheInfo.indexCache;
  assert(cache && "indexCache is nullptr");
  auto *insertPt = indexCacheInfo.indexInsertPoint;
  // insertPt is dereferenced below, so check it like the cache pointer.
  assert(insertPt && "indexInsertPoint is nullptr");

  IndexCacheKeyT key(layout, llvm::to_vector(shape));
  // Single lookup instead of count() followed by lookup().
  auto cached = cache->find(key);
  if (cached != cache->end())
    return cached->second;

  ConversionPatternRewriter::InsertionGuard guard(b);
  restoreInsertionPointIfSet(insertPt, b);

  SmallVector<SmallVector<Value>> result;
  if (auto blocked = layout.dyn_cast<BlockedEncodingAttr>()) {
    result = emitIndicesForDistributedLayout(loc, b, blocked, shape);
  } else if (auto mma = layout.dyn_cast<MmaEncodingAttr>()) {
    result = emitIndicesForDistributedLayout(loc, b, mma, shape);
  } else if (auto slice = layout.dyn_cast<SliceEncodingAttr>()) {
    result = emitIndicesForSliceLayout(loc, b, slice, shape);
  } else {
    llvm_unreachable(
        "emitIndices for layouts other than blocked & slice not "
        "implemented yet");
  }

  cache->insert(std::make_pair(key, result));
  // Later cache misses continue emitting right after the ops created here.
  *insertPt = b.saveInsertionPoint();
  return result;
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Moves `rewriter` to where cached index computations are emitted: the saved
// point `*insertPt` when it is set, otherwise the start of the first block of
// the LLVM function that encloses the rewriter's current insertion point.
void restoreInsertionPointIfSet(OpBuilder::InsertPoint *insertPt,
                                ConversionPatternRewriter &rewriter) const {
  if (!insertPt->isSet()) {
    // Nothing saved yet: fall back to the top of the enclosing function.
    auto enclosingFunc =
        rewriter.getInsertionPoint()->getParentOfType<LLVM::LLVMFuncOp>();
    auto *entryBlock = &enclosingFunc.getBody().front();
    rewriter.setInsertionPointToStart(entryBlock);
    return;
  }
  rewriter.restoreInsertionPoint(*insertPt);
}
|
||||||
|
|
||||||
// -----------------------------------------------------------------------
|
// -----------------------------------------------------------------------
|
||||||
// Blocked layout indices
|
// Blocked layout indices
|
||||||
// -----------------------------------------------------------------------
|
// -----------------------------------------------------------------------
|
||||||
@ -411,38 +571,6 @@ public:
|
|||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----------------------------------------------------------------------
|
|
||||||
// Get offsets / indices for any layout
|
|
||||||
// -----------------------------------------------------------------------
|
|
||||||
|
|
||||||
SmallVector<Value> emitBaseIndexForLayout(Location loc,
|
|
||||||
ConversionPatternRewriter &rewriter,
|
|
||||||
const Attribute &layout,
|
|
||||||
ArrayRef<int64_t> shape) const {
|
|
||||||
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>())
|
|
||||||
return emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, shape);
|
|
||||||
if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
|
|
||||||
if (mmaLayout.isVolta())
|
|
||||||
return emitBaseIndexForMmaLayoutV1(loc, rewriter, mmaLayout, shape);
|
|
||||||
if (mmaLayout.isAmpere())
|
|
||||||
return emitBaseIndexForMmaLayoutV2(loc, rewriter, mmaLayout, shape);
|
|
||||||
}
|
|
||||||
llvm_unreachable("unsupported emitBaseIndexForLayout");
|
|
||||||
}
|
|
||||||
|
|
||||||
SmallVector<SmallVector<unsigned>>
|
|
||||||
emitOffsetForLayout(const Attribute &layout, ArrayRef<int64_t> shape) const {
|
|
||||||
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>())
|
|
||||||
return emitOffsetForBlockedLayout(blockedLayout, shape);
|
|
||||||
if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
|
|
||||||
if (mmaLayout.isVolta())
|
|
||||||
return emitOffsetForMmaLayoutV1(mmaLayout, shape);
|
|
||||||
if (mmaLayout.isAmpere())
|
|
||||||
return emitOffsetForMmaLayoutV2(mmaLayout, shape);
|
|
||||||
}
|
|
||||||
llvm_unreachable("unsupported emitOffsetForLayout");
|
|
||||||
}
|
|
||||||
|
|
||||||
// Emit indices calculation within each ConversionPattern, and returns a
|
// Emit indices calculation within each ConversionPattern, and returns a
|
||||||
// [elemsPerThread X rank] index matrix.
|
// [elemsPerThread X rank] index matrix.
|
||||||
|
|
||||||
@ -470,22 +598,6 @@ public:
|
|||||||
return multiDimIdx;
|
return multiDimIdx;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct SmallVectorKeyInfo {
|
|
||||||
static unsigned getHashValue(const SmallVector<unsigned> &key) {
|
|
||||||
return llvm::hash_combine_range(key.begin(), key.end());
|
|
||||||
}
|
|
||||||
static bool isEqual(const SmallVector<unsigned> &lhs,
|
|
||||||
const SmallVector<unsigned> &rhs) {
|
|
||||||
return lhs == rhs;
|
|
||||||
}
|
|
||||||
static SmallVector<unsigned> getEmptyKey() {
|
|
||||||
return SmallVector<unsigned>();
|
|
||||||
}
|
|
||||||
static SmallVector<unsigned> getTombstoneKey() {
|
|
||||||
return {std::numeric_limits<unsigned>::max()};
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
SmallVector<SmallVector<Value>>
|
SmallVector<SmallVector<Value>>
|
||||||
emitIndicesForSliceLayout(Location loc, ConversionPatternRewriter &rewriter,
|
emitIndicesForSliceLayout(Location loc, ConversionPatternRewriter &rewriter,
|
||||||
const SliceEncodingAttr &sliceLayout,
|
const SliceEncodingAttr &sliceLayout,
|
||||||
@ -505,46 +617,45 @@ public:
|
|||||||
return resultIndices;
|
return resultIndices;
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----------------------------------------------------------------------
|
|
||||||
// Emit indices
|
|
||||||
// -----------------------------------------------------------------------
|
|
||||||
SmallVector<SmallVector<Value>> emitIndices(Location loc,
|
|
||||||
ConversionPatternRewriter &b,
|
|
||||||
const Attribute &layout,
|
|
||||||
ArrayRef<int64_t> shape) const {
|
|
||||||
if (auto blocked = layout.dyn_cast<BlockedEncodingAttr>()) {
|
|
||||||
return emitIndicesForDistributedLayout(loc, b, blocked, shape);
|
|
||||||
} else if (auto mma = layout.dyn_cast<MmaEncodingAttr>()) {
|
|
||||||
return emitIndicesForDistributedLayout(loc, b, mma, shape);
|
|
||||||
} else if (auto slice = layout.dyn_cast<SliceEncodingAttr>()) {
|
|
||||||
return emitIndicesForSliceLayout(loc, b, slice, shape);
|
|
||||||
} else {
|
|
||||||
assert(0 && "emitIndices for layouts other than blocked & slice not "
|
|
||||||
"implemented yet");
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----------------------------------------------------------------------
|
|
||||||
// Shared memory utilities
|
|
||||||
// -----------------------------------------------------------------------
|
|
||||||
template <typename T>
|
|
||||||
Value getSharedMemoryBase(Location loc, ConversionPatternRewriter &rewriter,
|
|
||||||
T value) const {
|
|
||||||
|
|
||||||
auto ptrTy = LLVM::LLVMPointerType::get(
|
|
||||||
this->getTypeConverter()->convertType(rewriter.getI8Type()), 3);
|
|
||||||
auto bufferId = allocation->getBufferId(value);
|
|
||||||
assert(bufferId != Allocation::InvalidBufferId && "BufferId not found");
|
|
||||||
size_t offset = allocation->getOffset(bufferId);
|
|
||||||
Value offVal = idx_val(offset);
|
|
||||||
Value base = gep(ptrTy, smem, offVal);
|
|
||||||
return base;
|
|
||||||
}
|
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
LLVMTypeConverter *converter;
|
||||||
const Allocation *allocation;
|
const Allocation *allocation;
|
||||||
Value smem;
|
Value smem;
|
||||||
|
IndexCacheInfo indexCacheInfo;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename SourceOp>
|
||||||
|
class ConvertTritonGPUOpToLLVMPattern
|
||||||
|
: public ConvertOpToLLVMPattern<SourceOp>,
|
||||||
|
public ConvertTritonGPUOpToLLVMPatternBase {
|
||||||
|
public:
|
||||||
|
using OpAdaptor = typename SourceOp::Adaptor;
|
||||||
|
|
||||||
|
explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter,
|
||||||
|
PatternBenefit benefit = 1)
|
||||||
|
: ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
|
||||||
|
ConvertTritonGPUOpToLLVMPatternBase(typeConverter) {}
|
||||||
|
|
||||||
|
explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter,
|
||||||
|
const Allocation *allocation,
|
||||||
|
Value smem,
|
||||||
|
PatternBenefit benefit = 1)
|
||||||
|
: ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
|
||||||
|
ConvertTritonGPUOpToLLVMPatternBase(typeConverter, allocation, smem) {}
|
||||||
|
|
||||||
|
explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter,
|
||||||
|
const Allocation *allocation,
|
||||||
|
Value smem,
|
||||||
|
IndexCacheInfo indexCacheInfo,
|
||||||
|
PatternBenefit benefit = 1)
|
||||||
|
: ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
|
||||||
|
ConvertTritonGPUOpToLLVMPatternBase(typeConverter, allocation, smem,
|
||||||
|
indexCacheInfo) {}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
LLVMTypeConverter *getTypeConverter() const {
|
||||||
|
return ((ConvertTritonGPUOpToLLVMPatternBase *)this)->getTypeConverter();
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
@ -170,16 +170,20 @@ public:
|
|||||||
// We set a higher benefit here to ensure triton's patterns runs before
|
// We set a higher benefit here to ensure triton's patterns runs before
|
||||||
// arith patterns for some encoding not supported by the community
|
// arith patterns for some encoding not supported by the community
|
||||||
// patterns.
|
// patterns.
|
||||||
|
OpBuilder::InsertPoint indexInsertPoint;
|
||||||
|
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo indexCacheInfo{
|
||||||
|
&baseIndexCache, &indexCache, &indexInsertPoint};
|
||||||
|
|
||||||
RewritePatternSet patterns(context);
|
RewritePatternSet patterns(context);
|
||||||
|
|
||||||
// Normal conversions
|
// Normal conversions
|
||||||
populateTritonGPUToLLVMPatterns(typeConverter, patterns, numWarps,
|
populateTritonGPUToLLVMPatterns(typeConverter, patterns, numWarps,
|
||||||
axisInfoAnalysis, &allocation, smem,
|
axisInfoAnalysis, &allocation, smem,
|
||||||
/*benefit=*/10);
|
indexCacheInfo, /*benefit=*/10);
|
||||||
// ConvertLayoutOp
|
// ConvertLayoutOp
|
||||||
populateConvertLayoutOpToLLVMPatterns(typeConverter, patterns, numWarps,
|
populateConvertLayoutOpToLLVMPatterns(typeConverter, patterns, numWarps,
|
||||||
axisInfoAnalysis, &allocation, smem,
|
axisInfoAnalysis, &allocation, smem,
|
||||||
/*benefit=*/10);
|
indexCacheInfo, /*benefit=*/10);
|
||||||
// DotOp
|
// DotOp
|
||||||
populateDotOpToLLVMPatterns(typeConverter, patterns, numWarps,
|
populateDotOpToLLVMPatterns(typeConverter, patterns, numWarps,
|
||||||
axisInfoAnalysis, &allocation, smem,
|
axisInfoAnalysis, &allocation, smem,
|
||||||
@ -191,11 +195,11 @@ public:
|
|||||||
// LoadStoreOp
|
// LoadStoreOp
|
||||||
populateLoadStoreOpToLLVMPatterns(typeConverter, patterns, numWarps,
|
populateLoadStoreOpToLLVMPatterns(typeConverter, patterns, numWarps,
|
||||||
axisInfoAnalysis, &allocation, smem,
|
axisInfoAnalysis, &allocation, smem,
|
||||||
/*benefit=*/10);
|
indexCacheInfo, /*benefit=*/10);
|
||||||
// ReduceOp
|
// ReduceOp
|
||||||
populateReduceOpToLLVMPatterns(typeConverter, patterns, numWarps,
|
populateReduceOpToLLVMPatterns(typeConverter, patterns, numWarps,
|
||||||
axisInfoAnalysis, &allocation, smem,
|
axisInfoAnalysis, &allocation, smem,
|
||||||
/*benefit=*/10);
|
indexCacheInfo, /*benefit=*/10);
|
||||||
// ViewOp
|
// ViewOp
|
||||||
populateViewOpToLLVMPatterns(typeConverter, patterns, numWarps,
|
populateViewOpToLLVMPatterns(typeConverter, patterns, numWarps,
|
||||||
axisInfoAnalysis, &allocation, smem,
|
axisInfoAnalysis, &allocation, smem,
|
||||||
@ -215,6 +219,13 @@ public:
|
|||||||
private:
|
private:
|
||||||
Value smem;
|
Value smem;
|
||||||
|
|
||||||
|
using IndexCacheKeyT = std::pair<Attribute, SmallVector<int64_t>>;
|
||||||
|
DenseMap<IndexCacheKeyT, SmallVector<Value>, CacheKeyDenseMapInfo>
|
||||||
|
baseIndexCache;
|
||||||
|
DenseMap<IndexCacheKeyT, SmallVector<SmallVector<Value>>,
|
||||||
|
CacheKeyDenseMapInfo>
|
||||||
|
indexCache;
|
||||||
|
|
||||||
int computeCapability{};
|
int computeCapability{};
|
||||||
|
|
||||||
void initSharedMemory(size_t size,
|
void initSharedMemory(size_t size,
|
||||||
|
@ -997,8 +997,7 @@ func @test_get_program_id(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
|
|||||||
// -----
|
// -----
|
||||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
||||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||||
|
func @test_get_num_program(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
|
||||||
func @test_get_num_program(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
|
|
||||||
// CHECK: nvvm.read.ptx.sreg.nctaid.x
|
// CHECK: nvvm.read.ptx.sreg.nctaid.x
|
||||||
// CHECK: nvvm.read.ptx.sreg.nctaid.y
|
// CHECK: nvvm.read.ptx.sreg.nctaid.y
|
||||||
// CHECK: nvvm.read.ptx.sreg.nctaid.z
|
// CHECK: nvvm.read.ptx.sreg.nctaid.z
|
||||||
@ -1011,6 +1010,48 @@ func @test_get_num_program(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
|
|||||||
tt.store %a, %0 : tensor<32xi32, #blocked0>
|
tt.store %a, %0 : tensor<32xi32, #blocked0>
|
||||||
|
|
||||||
return
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
#blocked0 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
|
||||||
|
// Index cache test: both make_range ops below use the same #blocked0 layout
// and the same 256-element shape. The thread-id register read is expected for
// the first op and must not appear again for the second one.
// NOTE(review): this module declares num-warps = 4 while #blocked0 uses
// warpsPerCTA = [1]; confirm the mismatch is intentional.
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||||
|
// CHECK-LABEL: test_index_cache
|
||||||
|
func @test_index_cache() {
|
||||||
|
// CHECK: nvvm.read.ptx.sreg.tid.x
|
||||||
|
%0 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0>
|
||||||
|
// CHECK-NOT: nvvm.read.ptx.sreg.tid.x
|
||||||
|
%1 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}>
|
||||||
|
#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}>
|
||||||
|
// Base index cache test: two identical blocked -> shared convert_layout ops on
// the same argument. The thread-id register read is expected for the first
// conversion and must not appear again for the second one.
// NOTE(review): this module declares num-warps = 1 while #blocked0 uses
// warpsPerCTA = [8, 1]; confirm the mismatch is intentional.
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||||
|
// CHECK-LABEL: test_base_index_cache
|
||||||
|
func @test_base_index_cache(%arg0: tensor<128x32xf32, #blocked0>) {
|
||||||
|
// CHECK: nvvm.read.ptx.sreg.tid.x
|
||||||
|
%0 = triton_gpu.convert_layout %arg0 : (tensor<128x32xf32, #blocked0>) -> tensor<128x32xf32, #shared0>
|
||||||
|
// CHECK-NOT: nvvm.read.ptx.sreg.tid.x
|
||||||
|
%1 = triton_gpu.convert_layout %arg0 : (tensor<128x32xf32, #blocked0>) -> tensor<128x32xf32, #shared0>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}>
|
||||||
|
#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}>
|
||||||
|
// Cross-block variant of the base index cache test: the second convert_layout
// sits inside an scf.if region. The thread-id register read is expected for
// the conversion in the function body and must not appear again for the one
// nested in the scf.if.
// NOTE(review): this module declares num-warps = 1 while #blocked0 uses
// warpsPerCTA = [8, 1]; confirm the mismatch is intentional.
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||||
|
// CHECK-LABEL: test_index_cache_different_block
|
||||||
|
func @test_index_cache_different_block(%arg0: tensor<128x32xf32, #blocked0>, %arg1: i1) {
|
||||||
|
// CHECK: nvvm.read.ptx.sreg.tid.x
|
||||||
|
%0 = triton_gpu.convert_layout %arg0 : (tensor<128x32xf32, #blocked0>) -> tensor<128x32xf32, #shared0>
|
||||||
|
scf.if %arg1 {
|
||||||
|
// CHECK-NOT: nvvm.read.ptx.sreg.tid.x
|
||||||
|
%1 = triton_gpu.convert_layout %arg0 : (tensor<128x32xf32, #blocked0>) -> tensor<128x32xf32, #shared0>
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
Reference in New Issue
Block a user