[Backend] Add value cache in emitting indices calculation and some refinement (#1018)
1. Add an explicit value cache for the emitted index calculations; 2. move the logic that emits the index calculations into ConvertTritonGPUOpToLLVMPatternBase, to avoid the redundant build cost of instantiating it in every template. Refer to the discussion in this thread by @LyricZhao: https://triton-lang.slack.com/archives/C042VBSQWNS/p1671336755922969
This commit is contained in:
@@ -18,7 +18,6 @@ using ::mlir::LLVM::SharedMemoryObject;
|
||||
using ::mlir::triton::gpu::BlockedEncodingAttr;
|
||||
using ::mlir::triton::gpu::MmaEncodingAttr;
|
||||
using ::mlir::triton::gpu::SliceEncodingAttr;
|
||||
|
||||
// FuncOpConversion/FuncOpConversionBase is borrowed from
|
||||
// https://github.com/llvm/llvm-project/blob/fae656b2dd80246c3c6f01e9c77c49560368752c/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp#L276
|
||||
// since it is not exposed on header files in mlir v14
|
||||
@@ -128,7 +127,60 @@ protected:
|
||||
}
|
||||
};
|
||||
|
||||
struct ConvertTritonGPUOpToLLVMPatternBase {
|
||||
using IndexCacheKeyT = std::pair<Attribute, SmallVector<int64_t>>;
|
||||
|
||||
struct CacheKeyDenseMapInfo {
|
||||
static IndexCacheKeyT getEmptyKey() {
|
||||
auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
|
||||
return std::make_pair(
|
||||
mlir::Attribute(static_cast<mlir::Attribute::ImplType *>(pointer)),
|
||||
SmallVector<int64_t>{});
|
||||
}
|
||||
static IndexCacheKeyT getTombstoneKey() {
|
||||
auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
|
||||
return std::make_pair(
|
||||
mlir::Attribute(static_cast<mlir::Attribute::ImplType *>(pointer)),
|
||||
SmallVector<int64_t>{std::numeric_limits<int64_t>::max()});
|
||||
}
|
||||
static unsigned getHashValue(IndexCacheKeyT key) {
|
||||
return llvm::hash_combine(
|
||||
mlir::hash_value(key.first),
|
||||
llvm::hash_combine_range(key.second.begin(), key.second.end()));
|
||||
}
|
||||
static bool isEqual(IndexCacheKeyT LHS, IndexCacheKeyT RHS) {
|
||||
return LHS == RHS;
|
||||
}
|
||||
};
|
||||
|
||||
class ConvertTritonGPUOpToLLVMPatternBase {
|
||||
public:
|
||||
// Two levels of value cache in emitting indices calculation:
|
||||
// Key: pair<layout, shape>
|
||||
struct IndexCacheInfo {
|
||||
DenseMap<IndexCacheKeyT, SmallVector<Value>, CacheKeyDenseMapInfo>
|
||||
*baseIndexCache;
|
||||
DenseMap<IndexCacheKeyT, SmallVector<SmallVector<Value>>,
|
||||
CacheKeyDenseMapInfo> *indexCache;
|
||||
OpBuilder::InsertPoint *indexInsertPoint;
|
||||
};
|
||||
|
||||
explicit ConvertTritonGPUOpToLLVMPatternBase(LLVMTypeConverter &typeConverter)
|
||||
: converter(&typeConverter) {}
|
||||
|
||||
explicit ConvertTritonGPUOpToLLVMPatternBase(LLVMTypeConverter &typeConverter,
|
||||
const Allocation *allocation,
|
||||
Value smem)
|
||||
: converter(&typeConverter), allocation(allocation), smem(smem) {}
|
||||
|
||||
explicit ConvertTritonGPUOpToLLVMPatternBase(LLVMTypeConverter &typeConverter,
|
||||
const Allocation *allocation,
|
||||
Value smem,
|
||||
IndexCacheInfo indexCacheInfo)
|
||||
: converter(&typeConverter), indexCacheInfo(indexCacheInfo),
|
||||
allocation(allocation), smem(smem) {}
|
||||
|
||||
LLVMTypeConverter *getTypeConverter() const { return converter; }
|
||||
|
||||
static Value
|
||||
getStructFromSharedMemoryObject(Location loc,
|
||||
const SharedMemoryObject &smemObj,
|
||||
@@ -139,25 +191,6 @@ struct ConvertTritonGPUOpToLLVMPatternBase {
|
||||
LLVM::LLVMStructType::getLiteral(rewriter.getContext(), types);
|
||||
return getStructFromElements(loc, elems, rewriter, structTy);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename SourceOp>
|
||||
class ConvertTritonGPUOpToLLVMPattern
|
||||
: public ConvertOpToLLVMPattern<SourceOp>,
|
||||
public ConvertTritonGPUOpToLLVMPatternBase {
|
||||
public:
|
||||
using OpAdaptor = typename SourceOp::Adaptor;
|
||||
|
||||
explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter,
|
||||
PatternBenefit benefit = 1)
|
||||
: ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}
|
||||
|
||||
explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter,
|
||||
const Allocation *allocation,
|
||||
Value smem,
|
||||
PatternBenefit benefit = 1)
|
||||
: ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
|
||||
allocation(allocation), smem(smem) {}
|
||||
|
||||
Value getThreadId(ConversionPatternRewriter &rewriter, Location loc) const {
|
||||
auto llvmIndexTy = this->getTypeConverter()->getIndexType();
|
||||
@@ -169,6 +202,23 @@ public:
|
||||
return threadId;
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Shared memory utilities
|
||||
// -----------------------------------------------------------------------
|
||||
template <typename T>
|
||||
Value getSharedMemoryBase(Location loc, ConversionPatternRewriter &rewriter,
|
||||
T value) const {
|
||||
|
||||
auto ptrTy = LLVM::LLVMPointerType::get(
|
||||
this->getTypeConverter()->convertType(rewriter.getI8Type()), 3);
|
||||
auto bufferId = allocation->getBufferId(value);
|
||||
assert(bufferId != Allocation::InvalidBufferId && "BufferId not found");
|
||||
size_t offset = allocation->getOffset(bufferId);
|
||||
Value offVal = idx_val(offset);
|
||||
Value base = gep(ptrTy, smem, offVal);
|
||||
return base;
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Utilities
|
||||
// -----------------------------------------------------------------------
|
||||
@@ -242,6 +292,116 @@ public:
|
||||
return ret;
|
||||
}
|
||||
|
||||
struct SmallVectorKeyInfo {
|
||||
static unsigned getHashValue(const SmallVector<unsigned> &key) {
|
||||
return llvm::hash_combine_range(key.begin(), key.end());
|
||||
}
|
||||
static bool isEqual(const SmallVector<unsigned> &lhs,
|
||||
const SmallVector<unsigned> &rhs) {
|
||||
return lhs == rhs;
|
||||
}
|
||||
static SmallVector<unsigned> getEmptyKey() {
|
||||
return SmallVector<unsigned>();
|
||||
}
|
||||
static SmallVector<unsigned> getTombstoneKey() {
|
||||
return {std::numeric_limits<unsigned>::max()};
|
||||
}
|
||||
};
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Get offsets / indices for any layout
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
/// Returns the base index values for `layout` and `shape`, emitting them
/// only on the first request.
///
/// Results are memoized in `indexCacheInfo.baseIndexCache` under the key
/// <layout, shape>. On a cache miss the computation is emitted at the
/// position chosen by restoreInsertionPointIfSet (not at the caller's
/// current position); the position after the emitted code is then written
/// back to `indexCacheInfo.indexInsertPoint`. The InsertionGuard restores
/// the caller's insertion point on return.
///
/// Supports blocked layouts and MMA layouts (Volta / Ampere); any other
/// layout, including an MMA layout of another version, is unreachable.
SmallVector<Value> emitBaseIndexForLayout(Location loc,
                                          ConversionPatternRewriter &rewriter,
                                          const Attribute &layout,
                                          ArrayRef<int64_t> shape) const {
  auto *cache = indexCacheInfo.baseIndexCache;
  assert(cache && "baseIndexCache is nullptr");
  auto *insertPt = indexCacheInfo.indexInsertPoint;
  assert(insertPt && "indexInsertPoint is nullptr");

  IndexCacheKeyT key = std::make_pair(layout, llvm::to_vector(shape));

  // One probe instead of count() followed by lookup().
  auto cached = cache->find(key);
  if (cached != cache->end())
    return cached->second;

  ConversionPatternRewriter::InsertionGuard guard(rewriter);
  restoreInsertionPointIfSet(insertPt, rewriter);

  SmallVector<Value> result;
  if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
    result = emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, shape);
  } else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
    if (mmaLayout.isVolta()) {
      result = emitBaseIndexForMmaLayoutV1(loc, rewriter, mmaLayout, shape);
    } else if (mmaLayout.isAmpere()) {
      result = emitBaseIndexForMmaLayoutV2(loc, rewriter, mmaLayout, shape);
    } else {
      // Do not cache an empty result for an MMA version we cannot handle.
      llvm_unreachable("unsupported emitBaseIndexForLayout");
    }
  } else {
    llvm_unreachable("unsupported emitBaseIndexForLayout");
  }

  cache->insert(std::make_pair(key, result));
  // Later cache misses continue emitting after the code produced here.
  *insertPt = rewriter.saveInsertionPoint();
  return result;
}
|
||||
|
||||
/// Dispatches to the layout-specific offset emitter for `layout` and
/// `shape`. Handles blocked layouts and MMA layouts (Volta / Ampere); every
/// other case is unreachable.
SmallVector<SmallVector<unsigned>>
emitOffsetForLayout(const Attribute &layout, ArrayRef<int64_t> shape) const {
  if (auto blocked = layout.dyn_cast<BlockedEncodingAttr>()) {
    return emitOffsetForBlockedLayout(blocked, shape);
  }

  auto mma = layout.dyn_cast<MmaEncodingAttr>();
  if (mma && mma.isVolta()) {
    return emitOffsetForMmaLayoutV1(mma, shape);
  }
  if (mma && mma.isAmpere()) {
    return emitOffsetForMmaLayoutV2(mma, shape);
  }

  llvm_unreachable("unsupported emitOffsetForLayout");
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Emit indices
|
||||
// -----------------------------------------------------------------------
|
||||
/// Returns the index values for `layout` and `shape` as a vector of
/// per-element index vectors, emitting them only on the first request.
///
/// Results are memoized in `indexCacheInfo.indexCache` under the key
/// <layout, shape>. On a cache miss the computation is emitted at the
/// position chosen by restoreInsertionPointIfSet (not at the caller's
/// current position); the position after the emitted code is then written
/// back to `indexCacheInfo.indexInsertPoint`. The InsertionGuard restores
/// the caller's insertion point on return.
///
/// Supports blocked, MMA and slice layouts; any other layout is unreachable.
SmallVector<SmallVector<Value>> emitIndices(Location loc,
                                            ConversionPatternRewriter &b,
                                            const Attribute &layout,
                                            ArrayRef<int64_t> shape) const {
  auto *cache = indexCacheInfo.indexCache;
  assert(cache && "indexCache is nullptr");
  auto *insertPt = indexCacheInfo.indexInsertPoint;
  assert(insertPt && "indexInsertPoint is nullptr");

  IndexCacheKeyT key(layout, llvm::to_vector(shape));

  // One probe instead of count() followed by lookup().
  auto cached = cache->find(key);
  if (cached != cache->end())
    return cached->second;

  ConversionPatternRewriter::InsertionGuard guard(b);
  restoreInsertionPointIfSet(insertPt, b);

  SmallVector<SmallVector<Value>> result;
  if (auto blocked = layout.dyn_cast<BlockedEncodingAttr>()) {
    result = emitIndicesForDistributedLayout(loc, b, blocked, shape);
  } else if (auto mma = layout.dyn_cast<MmaEncodingAttr>()) {
    result = emitIndicesForDistributedLayout(loc, b, mma, shape);
  } else if (auto slice = layout.dyn_cast<SliceEncodingAttr>()) {
    result = emitIndicesForSliceLayout(loc, b, slice, shape);
  } else {
    llvm_unreachable(
        "emitIndices for layouts other than blocked & slice not "
        "implemented yet");
  }

  cache->insert(std::make_pair(key, result));
  // Later cache misses continue emitting after the code produced here.
  *insertPt = b.saveInsertionPoint();
  return result;
}
|
||||
|
||||
private:
|
||||
/// Positions `rewriter` where cached index computations should be emitted.
/// If `insertPt` has been set, the rewriter is moved there. Otherwise it is
/// moved to the start of the first block of the LLVM function that encloses
/// the rewriter's current insertion point.
void restoreInsertionPointIfSet(OpBuilder::InsertPoint *insertPt,
                                ConversionPatternRewriter &rewriter) const {
  if (!insertPt->isSet()) {
    // Nothing recorded yet: start at the top of the enclosing function.
    auto enclosingFunc =
        rewriter.getInsertionPoint()->getParentOfType<LLVM::LLVMFuncOp>();
    auto *entryBlock = &enclosingFunc.getBody().front();
    rewriter.setInsertionPointToStart(entryBlock);
    return;
  }
  rewriter.restoreInsertionPoint(*insertPt);
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Blocked layout indices
|
||||
// -----------------------------------------------------------------------
|
||||
@@ -411,38 +571,6 @@ public:
|
||||
return ret;
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Get offsets / indices for any layout
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
/// Dispatches to the layout-specific base-index emitter for `layout` and
/// `shape`. Handles blocked layouts and MMA layouts (Volta / Ampere); every
/// other case is unreachable. Nothing is cached: each call emits anew.
SmallVector<Value> emitBaseIndexForLayout(Location loc,
                                          ConversionPatternRewriter &rewriter,
                                          const Attribute &layout,
                                          ArrayRef<int64_t> shape) const {
  if (auto blocked = layout.dyn_cast<BlockedEncodingAttr>()) {
    return emitBaseIndexForBlockedLayout(loc, rewriter, blocked, shape);
  }

  auto mma = layout.dyn_cast<MmaEncodingAttr>();
  if (mma && mma.isVolta()) {
    return emitBaseIndexForMmaLayoutV1(loc, rewriter, mma, shape);
  }
  if (mma && mma.isAmpere()) {
    return emitBaseIndexForMmaLayoutV2(loc, rewriter, mma, shape);
  }

  llvm_unreachable("unsupported emitBaseIndexForLayout");
}
|
||||
|
||||
/// Selects the offset emitter matching the concrete type of `layout` and
/// forwards `shape` to it. Blocked layouts and MMA layouts (Volta / Ampere)
/// are handled; anything else is unreachable.
SmallVector<SmallVector<unsigned>>
emitOffsetForLayout(const Attribute &layout, ArrayRef<int64_t> shape) const {
  if (auto asBlocked = layout.dyn_cast<BlockedEncodingAttr>()) {
    return emitOffsetForBlockedLayout(asBlocked, shape);
  }

  auto asMma = layout.dyn_cast<MmaEncodingAttr>();
  if (asMma && asMma.isVolta()) {
    return emitOffsetForMmaLayoutV1(asMma, shape);
  }
  if (asMma && asMma.isAmpere()) {
    return emitOffsetForMmaLayoutV2(asMma, shape);
  }

  llvm_unreachable("unsupported emitOffsetForLayout");
}
|
||||
|
||||
// Emit indices calculation within each ConversionPattern, and returns a
|
||||
// [elemsPerThread X rank] index matrix.
|
||||
|
||||
@@ -470,22 +598,6 @@ public:
|
||||
return multiDimIdx;
|
||||
}
|
||||
|
||||
struct SmallVectorKeyInfo {
|
||||
static unsigned getHashValue(const SmallVector<unsigned> &key) {
|
||||
return llvm::hash_combine_range(key.begin(), key.end());
|
||||
}
|
||||
static bool isEqual(const SmallVector<unsigned> &lhs,
|
||||
const SmallVector<unsigned> &rhs) {
|
||||
return lhs == rhs;
|
||||
}
|
||||
static SmallVector<unsigned> getEmptyKey() {
|
||||
return SmallVector<unsigned>();
|
||||
}
|
||||
static SmallVector<unsigned> getTombstoneKey() {
|
||||
return {std::numeric_limits<unsigned>::max()};
|
||||
}
|
||||
};
|
||||
|
||||
SmallVector<SmallVector<Value>>
|
||||
emitIndicesForSliceLayout(Location loc, ConversionPatternRewriter &rewriter,
|
||||
const SliceEncodingAttr &sliceLayout,
|
||||
@@ -505,46 +617,45 @@ public:
|
||||
return resultIndices;
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Emit indices
|
||||
// -----------------------------------------------------------------------
|
||||
/// Dispatches to the index emitter for the concrete type of `layout`.
/// Blocked and MMA layouts go through emitIndicesForDistributedLayout,
/// slice layouts through emitIndicesForSliceLayout. For any other layout
/// the assert fires; with asserts disabled an empty result is returned.
SmallVector<SmallVector<Value>> emitIndices(Location loc,
                                            ConversionPatternRewriter &b,
                                            const Attribute &layout,
                                            ArrayRef<int64_t> shape) const {
  if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>())
    return emitIndicesForDistributedLayout(loc, b, blockedLayout, shape);

  if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>())
    return emitIndicesForDistributedLayout(loc, b, mmaLayout, shape);

  if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>())
    return emitIndicesForSliceLayout(loc, b, sliceLayout, shape);

  assert(0 && "emitIndices for layouts other than blocked & slice not "
              "implemented yet");
  return {};
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Shared memory utilities
|
||||
// -----------------------------------------------------------------------
|
||||
template <typename T>
|
||||
Value getSharedMemoryBase(Location loc, ConversionPatternRewriter &rewriter,
|
||||
T value) const {
|
||||
|
||||
auto ptrTy = LLVM::LLVMPointerType::get(
|
||||
this->getTypeConverter()->convertType(rewriter.getI8Type()), 3);
|
||||
auto bufferId = allocation->getBufferId(value);
|
||||
assert(bufferId != Allocation::InvalidBufferId && "BufferId not found");
|
||||
size_t offset = allocation->getOffset(bufferId);
|
||||
Value offVal = idx_val(offset);
|
||||
Value base = gep(ptrTy, smem, offVal);
|
||||
return base;
|
||||
}
|
||||
|
||||
protected:
|
||||
LLVMTypeConverter *converter;
|
||||
const Allocation *allocation;
|
||||
Value smem;
|
||||
IndexCacheInfo indexCacheInfo;
|
||||
};
|
||||
|
||||
template <typename SourceOp>
|
||||
class ConvertTritonGPUOpToLLVMPattern
|
||||
: public ConvertOpToLLVMPattern<SourceOp>,
|
||||
public ConvertTritonGPUOpToLLVMPatternBase {
|
||||
public:
|
||||
using OpAdaptor = typename SourceOp::Adaptor;
|
||||
|
||||
explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter,
|
||||
PatternBenefit benefit = 1)
|
||||
: ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
|
||||
ConvertTritonGPUOpToLLVMPatternBase(typeConverter) {}
|
||||
|
||||
explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter,
|
||||
const Allocation *allocation,
|
||||
Value smem,
|
||||
PatternBenefit benefit = 1)
|
||||
: ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
|
||||
ConvertTritonGPUOpToLLVMPatternBase(typeConverter, allocation, smem) {}
|
||||
|
||||
explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter,
|
||||
const Allocation *allocation,
|
||||
Value smem,
|
||||
IndexCacheInfo indexCacheInfo,
|
||||
PatternBenefit benefit = 1)
|
||||
: ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
|
||||
ConvertTritonGPUOpToLLVMPatternBase(typeConverter, allocation, smem,
|
||||
indexCacheInfo) {}
|
||||
|
||||
protected:
|
||||
LLVMTypeConverter *getTypeConverter() const {
|
||||
return ((ConvertTritonGPUOpToLLVMPatternBase *)this)->getTypeConverter();
|
||||
}
|
||||
};
|
||||
|
||||
#endif
|
||||
|
Reference in New Issue
Block a user