Merge triton-mlir
branch - Complete rewrite of the backend from scratch (#1004)
This PR merges the `triton-mlir` branch, in which we have been quietly rewriting the Triton backend from scratch to increase maintainability, stability and ultimately performance. Changes to the runtime are minimal, and this new version aims to remain backward-compatible with the previous commit. The legacy backend is now officially deprecated, but can still be accessed via the `legacy-backend` tag. Co-authored-by: Keren Zhou <kerenzhou@openai.com> Co-authored-by: Yan Chunwei <yanchunwei@outlook.com> Co-authored-by: goostavz <109190422+goostavz@users.noreply.github.com> Co-authored-by: Shintaro Iwasaki <siwasaki@fb.com> Co-authored-by: Yan Da <dyanab@connect.ust.hk> Co-authored-by: Jun Yang <yangjunpro@gmail.com> Co-authored-by: Ian Bearman <ianb@microsoft.com> Co-authored-by: Jason Ansel <jansel@jansel.net> Co-authored-by: Qingyi Liu <qingyil@nvidia.com> Co-authored-by: ben-zhang-609 <110140741+ben-zhang-609@users.noreply.github.com> Co-authored-by: Chenggang Zhao <lyricz@yeah.net> Co-authored-by: ben-zhang-609 <benzh609@gmail.com> Co-authored-by: dongdongl <dongdongl@nvidia.com>
This commit is contained in:
550
lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h
Normal file
550
lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h
Normal file
@@ -0,0 +1,550 @@
|
||||
#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_BASE_H
|
||||
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_BASE_H
|
||||
|
||||
// TODO: refactor so that it doesn't fail if Allocation.h
|
||||
// is included after utility.h (due to conflict in `store` macro
|
||||
// and <atomic>
|
||||
#include "triton/Analysis/Allocation.h"
|
||||
|
||||
//
|
||||
#include "Utility.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "triton/Analysis/AxisInfo.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
using ::mlir::LLVM::SharedMemoryObject;
|
||||
using ::mlir::triton::gpu::BlockedEncodingAttr;
|
||||
using ::mlir::triton::gpu::MmaEncodingAttr;
|
||||
using ::mlir::triton::gpu::SliceEncodingAttr;
|
||||
|
||||
// FuncOpConversion/FuncOpConversionBase is borrowed from
|
||||
// https://github.com/llvm/llvm-project/blob/fae656b2dd80246c3c6f01e9c77c49560368752c/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp#L276
|
||||
// since it is not exposed on header files in mlir v14
|
||||
// TODO(Superjomn): remove the code when MLIR v15.0 is included.
|
||||
// All the rights are reserved by the LLVM community.
|
||||
|
||||
struct FuncOpConversionBase : public ConvertOpToLLVMPattern<FuncOp> {
|
||||
private:
|
||||
/// Only retain those attributes that are not constructed by
|
||||
/// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
|
||||
/// attributes.
|
||||
static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs,
|
||||
bool filterArgAttrs,
|
||||
SmallVectorImpl<NamedAttribute> &result) {
|
||||
for (const auto &attr : attrs) {
|
||||
if (attr.getName() == SymbolTable::getSymbolAttrName() ||
|
||||
attr.getName() == FunctionOpInterface::getTypeAttrName() ||
|
||||
attr.getName() == "std.varargs" ||
|
||||
(filterArgAttrs &&
|
||||
attr.getName() == FunctionOpInterface::getArgDictAttrName()))
|
||||
continue;
|
||||
result.push_back(attr);
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper function for wrapping all attributes into a single DictionaryAttr
|
||||
static auto wrapAsStructAttrs(OpBuilder &b, ArrayAttr attrs) {
|
||||
return DictionaryAttr::get(b.getContext(),
|
||||
b.getNamedAttr("llvm.struct_attrs", attrs));
|
||||
}
|
||||
|
||||
protected:
|
||||
using ConvertOpToLLVMPattern<FuncOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
// Convert input FuncOp to LLVMFuncOp by using the LLVMTypeConverter provided
|
||||
// to this legalization pattern.
|
||||
LLVM::LLVMFuncOp
|
||||
convertFuncOpToLLVMFuncOp(FuncOp funcOp,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
// Convert the original function arguments. They are converted using the
|
||||
// LLVMTypeConverter provided to this legalization pattern.
|
||||
auto varargsAttr = funcOp->getAttrOfType<BoolAttr>("func.varargs");
|
||||
TypeConverter::SignatureConversion result(funcOp.getNumArguments());
|
||||
auto llvmType = getTypeConverter()->convertFunctionSignature(
|
||||
funcOp.getType(), varargsAttr && varargsAttr.getValue(), result);
|
||||
if (!llvmType)
|
||||
return nullptr;
|
||||
|
||||
// Propagate argument/result attributes to all converted arguments/result
|
||||
// obtained after converting a given original argument/result.
|
||||
SmallVector<NamedAttribute, 4> attributes;
|
||||
filterFuncAttributes(funcOp->getAttrs(), /*filterArgAttrs=*/true,
|
||||
attributes);
|
||||
if (ArrayAttr resAttrDicts = funcOp.getAllResultAttrs()) {
|
||||
assert(!resAttrDicts.empty() && "expected array to be non-empty");
|
||||
auto newResAttrDicts =
|
||||
(funcOp.getNumResults() == 1)
|
||||
? resAttrDicts
|
||||
: rewriter.getArrayAttr(
|
||||
{wrapAsStructAttrs(rewriter, resAttrDicts)});
|
||||
attributes.push_back(rewriter.getNamedAttr(
|
||||
FunctionOpInterface::getResultDictAttrName(), newResAttrDicts));
|
||||
}
|
||||
if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) {
|
||||
SmallVector<Attribute, 4> newArgAttrs(
|
||||
llvmType.cast<LLVM::LLVMFunctionType>().getNumParams());
|
||||
for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
|
||||
auto mapping = result.getInputMapping(i);
|
||||
assert(mapping && "unexpected deletion of function argument");
|
||||
for (size_t j = 0; j < mapping->size; ++j)
|
||||
newArgAttrs[mapping->inputNo + j] = argAttrDicts[i];
|
||||
}
|
||||
attributes.push_back(
|
||||
rewriter.getNamedAttr(FunctionOpInterface::getArgDictAttrName(),
|
||||
rewriter.getArrayAttr(newArgAttrs)));
|
||||
}
|
||||
for (const auto &pair : llvm::enumerate(attributes)) {
|
||||
if (pair.value().getName() == "llvm.linkage") {
|
||||
attributes.erase(attributes.begin() + pair.index());
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Create an LLVM function, use external linkage by default until MLIR
|
||||
// functions have linkage.
|
||||
LLVM::Linkage linkage = LLVM::Linkage::External;
|
||||
if (funcOp->hasAttr("llvm.linkage")) {
|
||||
auto attr =
|
||||
funcOp->getAttr("llvm.linkage").dyn_cast<mlir::LLVM::LinkageAttr>();
|
||||
if (!attr) {
|
||||
funcOp->emitError()
|
||||
<< "Contains llvm.linkage attribute not of type LLVM::LinkageAttr";
|
||||
return nullptr;
|
||||
}
|
||||
linkage = attr.getLinkage();
|
||||
}
|
||||
auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
|
||||
funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
|
||||
/*dsoLocal*/ false, attributes);
|
||||
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
|
||||
newFuncOp.end());
|
||||
if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), *typeConverter,
|
||||
&result)))
|
||||
return nullptr;
|
||||
|
||||
return newFuncOp;
|
||||
}
|
||||
};
|
||||
|
||||
struct ConvertTritonGPUOpToLLVMPatternBase {
|
||||
static Value
|
||||
getStructFromSharedMemoryObject(Location loc,
|
||||
const SharedMemoryObject &smemObj,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
auto elems = smemObj.getElems();
|
||||
auto types = smemObj.getTypes();
|
||||
auto structTy =
|
||||
LLVM::LLVMStructType::getLiteral(rewriter.getContext(), types);
|
||||
return getStructFromElements(loc, elems, rewriter, structTy);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename SourceOp>
|
||||
class ConvertTritonGPUOpToLLVMPattern
|
||||
: public ConvertOpToLLVMPattern<SourceOp>,
|
||||
public ConvertTritonGPUOpToLLVMPatternBase {
|
||||
public:
|
||||
using OpAdaptor = typename SourceOp::Adaptor;
|
||||
|
||||
explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter,
|
||||
PatternBenefit benefit = 1)
|
||||
: ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}
|
||||
|
||||
explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter,
|
||||
const Allocation *allocation,
|
||||
Value smem,
|
||||
PatternBenefit benefit = 1)
|
||||
: ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
|
||||
allocation(allocation), smem(smem) {}
|
||||
|
||||
Value getThreadId(ConversionPatternRewriter &rewriter, Location loc) const {
|
||||
auto llvmIndexTy = this->getTypeConverter()->getIndexType();
|
||||
auto cast = rewriter.create<UnrealizedConversionCastOp>(
|
||||
loc, TypeRange{llvmIndexTy},
|
||||
ValueRange{rewriter.create<::mlir::gpu::ThreadIdOp>(
|
||||
loc, rewriter.getIndexType(), ::mlir::gpu::Dimension::x)});
|
||||
Value threadId = cast.getResult(0);
|
||||
return threadId;
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Utilities
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
// Convert an \param index to a multi-dim coordinate given \param shape and
|
||||
// \param order.
|
||||
SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
|
||||
Location loc, Value linear,
|
||||
ArrayRef<unsigned> shape,
|
||||
ArrayRef<unsigned> order) const {
|
||||
unsigned rank = shape.size();
|
||||
assert(rank == order.size());
|
||||
auto reordered = reorder(shape, order);
|
||||
auto reorderedMultiDim = delinearize(rewriter, loc, linear, reordered);
|
||||
SmallVector<Value> multiDim(rank);
|
||||
for (unsigned i = 0; i < rank; ++i) {
|
||||
multiDim[order[i]] = reorderedMultiDim[i];
|
||||
}
|
||||
return multiDim;
|
||||
}
|
||||
|
||||
SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
|
||||
Location loc, Value linear,
|
||||
ArrayRef<unsigned> shape) const {
|
||||
unsigned rank = shape.size();
|
||||
assert(rank > 0);
|
||||
SmallVector<Value> multiDim(rank);
|
||||
if (rank == 1) {
|
||||
multiDim[0] = linear;
|
||||
} else {
|
||||
Value remained = linear;
|
||||
for (auto &&en : llvm::enumerate(shape.drop_back())) {
|
||||
Value dimSize = idx_val(en.value());
|
||||
multiDim[en.index()] = urem(remained, dimSize);
|
||||
remained = udiv(remained, dimSize);
|
||||
}
|
||||
multiDim[rank - 1] = remained;
|
||||
}
|
||||
return multiDim;
|
||||
}
|
||||
|
||||
Value linearize(ConversionPatternRewriter &rewriter, Location loc,
|
||||
ArrayRef<Value> multiDim, ArrayRef<unsigned> shape,
|
||||
ArrayRef<unsigned> order) const {
|
||||
return linearize(rewriter, loc, reorder<Value>(multiDim, order),
|
||||
reorder<unsigned>(shape, order));
|
||||
}
|
||||
|
||||
Value linearize(ConversionPatternRewriter &rewriter, Location loc,
|
||||
ArrayRef<Value> multiDim, ArrayRef<unsigned> shape) const {
|
||||
auto rank = multiDim.size();
|
||||
Value linear = idx_val(0);
|
||||
if (rank > 0) {
|
||||
linear = multiDim.back();
|
||||
for (auto [dim, dimShape] :
|
||||
llvm::reverse(llvm::zip(multiDim.drop_back(), shape.drop_back()))) {
|
||||
Value dimSize = idx_val(dimShape);
|
||||
linear = add(mul(linear, dimSize), dim);
|
||||
}
|
||||
}
|
||||
return linear;
|
||||
}
|
||||
|
||||
Value dot(ConversionPatternRewriter &rewriter, Location loc,
|
||||
ArrayRef<Value> offsets, ArrayRef<Value> strides) const {
|
||||
assert(offsets.size() == strides.size());
|
||||
Value ret = idx_val(0);
|
||||
for (auto [offset, stride] : llvm::zip(offsets, strides)) {
|
||||
ret = add(ret, mul(offset, stride));
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Blocked layout indices
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
// Get an index-base for each dimension for a \param blocked_layout.
|
||||
SmallVector<Value>
|
||||
emitBaseIndexForBlockedLayout(Location loc,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
const BlockedEncodingAttr &blocked_layout,
|
||||
ArrayRef<int64_t> shape) const {
|
||||
Value threadId = getThreadId(rewriter, loc);
|
||||
Value warpSize = idx_val(32);
|
||||
Value laneId = urem(threadId, warpSize);
|
||||
Value warpId = udiv(threadId, warpSize);
|
||||
auto sizePerThread = blocked_layout.getSizePerThread();
|
||||
auto threadsPerWarp = blocked_layout.getThreadsPerWarp();
|
||||
auto warpsPerCTA = blocked_layout.getWarpsPerCTA();
|
||||
auto order = blocked_layout.getOrder();
|
||||
unsigned rank = shape.size();
|
||||
|
||||
// delinearize threadId to get the base index
|
||||
SmallVector<Value> multiDimWarpId =
|
||||
delinearize(rewriter, loc, warpId, warpsPerCTA, order);
|
||||
SmallVector<Value> multiDimThreadId =
|
||||
delinearize(rewriter, loc, laneId, threadsPerWarp, order);
|
||||
|
||||
SmallVector<Value> multiDimBase(rank);
|
||||
for (unsigned k = 0; k < rank; ++k) {
|
||||
// Wrap around multiDimWarpId/multiDimThreadId incase
|
||||
// shape[k] > shapePerCTA[k]
|
||||
auto maxWarps =
|
||||
ceil<unsigned>(shape[k], sizePerThread[k] * threadsPerWarp[k]);
|
||||
auto maxThreads = ceil<unsigned>(shape[k], sizePerThread[k]);
|
||||
multiDimWarpId[k] = urem(multiDimWarpId[k], idx_val(maxWarps));
|
||||
multiDimThreadId[k] = urem(multiDimThreadId[k], idx_val(maxThreads));
|
||||
// multiDimBase[k] = (multiDimThreadId[k] +
|
||||
// multiDimWarpId[k] * threadsPerWarp[k]) *
|
||||
// sizePerThread[k];
|
||||
Value threadsPerWarpK = idx_val(threadsPerWarp[k]);
|
||||
Value sizePerThreadK = idx_val(sizePerThread[k]);
|
||||
multiDimBase[k] =
|
||||
mul(sizePerThreadK, add(multiDimThreadId[k],
|
||||
mul(multiDimWarpId[k], threadsPerWarpK)));
|
||||
}
|
||||
return multiDimBase;
|
||||
}
|
||||
|
||||
SmallVector<SmallVector<unsigned>>
|
||||
emitOffsetForBlockedLayout(const BlockedEncodingAttr &blockedLayout,
|
||||
ArrayRef<int64_t> shape) const {
|
||||
auto sizePerThread = blockedLayout.getSizePerThread();
|
||||
auto threadsPerWarp = blockedLayout.getThreadsPerWarp();
|
||||
auto warpsPerCTA = blockedLayout.getWarpsPerCTA();
|
||||
auto order = blockedLayout.getOrder();
|
||||
|
||||
unsigned rank = shape.size();
|
||||
SmallVector<unsigned> shapePerCTA = getShapePerCTA(blockedLayout);
|
||||
SmallVector<unsigned> tilesPerDim(rank);
|
||||
for (unsigned k = 0; k < rank; ++k)
|
||||
tilesPerDim[k] = ceil<unsigned>(shape[k], shapePerCTA[k]);
|
||||
|
||||
SmallVector<SmallVector<unsigned>> offset(rank);
|
||||
for (unsigned k = 0; k < rank; ++k) {
|
||||
// 1 block in minimum if shape[k] is less than shapePerCTA[k]
|
||||
for (unsigned blockOffset = 0; blockOffset < tilesPerDim[k];
|
||||
++blockOffset)
|
||||
for (unsigned warpOffset = 0; warpOffset < warpsPerCTA[k]; ++warpOffset)
|
||||
for (unsigned threadOffset = 0; threadOffset < threadsPerWarp[k];
|
||||
++threadOffset)
|
||||
for (unsigned elemOffset = 0; elemOffset < sizePerThread[k];
|
||||
++elemOffset)
|
||||
offset[k].push_back(blockOffset * sizePerThread[k] *
|
||||
threadsPerWarp[k] * warpsPerCTA[k] +
|
||||
warpOffset * sizePerThread[k] *
|
||||
threadsPerWarp[k] +
|
||||
threadOffset * sizePerThread[k] + elemOffset);
|
||||
}
|
||||
|
||||
unsigned elemsPerThread = blockedLayout.getElemsPerThread(shape);
|
||||
unsigned totalSizePerThread = product<unsigned>(sizePerThread);
|
||||
SmallVector<SmallVector<unsigned>> reorderedOffset(elemsPerThread);
|
||||
for (unsigned n = 0; n < elemsPerThread; ++n) {
|
||||
unsigned linearNanoTileId = n / totalSizePerThread;
|
||||
unsigned linearNanoTileElemId = n % totalSizePerThread;
|
||||
SmallVector<unsigned> multiDimNanoTileId =
|
||||
getMultiDimIndex<unsigned>(linearNanoTileId, tilesPerDim, order);
|
||||
SmallVector<unsigned> multiDimNanoTileElemId = getMultiDimIndex<unsigned>(
|
||||
linearNanoTileElemId, sizePerThread, order);
|
||||
for (unsigned k = 0; k < rank; ++k) {
|
||||
unsigned reorderedMultiDimId =
|
||||
multiDimNanoTileId[k] *
|
||||
(sizePerThread[k] * threadsPerWarp[k] * warpsPerCTA[k]) +
|
||||
multiDimNanoTileElemId[k];
|
||||
reorderedOffset[n].push_back(offset[k][reorderedMultiDimId]);
|
||||
}
|
||||
}
|
||||
return reorderedOffset;
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Mma layout indices
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
SmallVector<Value>
|
||||
emitBaseIndexForMmaLayoutV1(Location loc, ConversionPatternRewriter &rewriter,
|
||||
const MmaEncodingAttr &mmaLayout,
|
||||
ArrayRef<int64_t> shape) const {
|
||||
llvm_unreachable("emitIndicesForMmaLayoutV1 not implemented");
|
||||
}
|
||||
|
||||
SmallVector<SmallVector<unsigned>>
|
||||
emitOffsetForMmaLayoutV1(const MmaEncodingAttr &mmaLayout,
|
||||
ArrayRef<int64_t> shape) const {
|
||||
SmallVector<SmallVector<unsigned>> ret;
|
||||
|
||||
for (unsigned i = 0; i < shape[0]; i += getShapePerCTA(mmaLayout)[0]) {
|
||||
for (unsigned j = 0; j < shape[1]; j += getShapePerCTA(mmaLayout)[1]) {
|
||||
ret.push_back({i, j});
|
||||
ret.push_back({i, j + 1});
|
||||
ret.push_back({i + 2, j});
|
||||
ret.push_back({i + 2, j + 1});
|
||||
ret.push_back({i, j + 8});
|
||||
ret.push_back({i, j + 9});
|
||||
ret.push_back({i + 2, j + 8});
|
||||
ret.push_back({i + 2, j + 9});
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
SmallVector<Value>
|
||||
emitBaseIndexForMmaLayoutV2(Location loc, ConversionPatternRewriter &rewriter,
|
||||
const MmaEncodingAttr &mmaLayout,
|
||||
ArrayRef<int64_t> shape) const {
|
||||
auto _warpsPerCTA = mmaLayout.getWarpsPerCTA();
|
||||
assert(_warpsPerCTA.size() == 2);
|
||||
SmallVector<Value> warpsPerCTA = {idx_val(_warpsPerCTA[0]),
|
||||
idx_val(_warpsPerCTA[1])};
|
||||
Value threadId = getThreadId(rewriter, loc);
|
||||
Value warpSize = idx_val(32);
|
||||
Value laneId = urem(threadId, warpSize);
|
||||
Value warpId = udiv(threadId, warpSize);
|
||||
Value warpId0 = urem(warpId, warpsPerCTA[0]);
|
||||
Value warpId1 = urem(udiv(warpId, warpsPerCTA[0]), warpsPerCTA[1]);
|
||||
Value offWarp0 = mul(warpId0, idx_val(16));
|
||||
Value offWarp1 = mul(warpId1, idx_val(8));
|
||||
|
||||
SmallVector<Value> multiDimBase(2);
|
||||
multiDimBase[0] = add(udiv(laneId, idx_val(4)), offWarp0);
|
||||
multiDimBase[1] = add(mul(idx_val(2), urem(laneId, idx_val(4))), offWarp1);
|
||||
return multiDimBase;
|
||||
}
|
||||
|
||||
SmallVector<SmallVector<unsigned>>
|
||||
emitOffsetForMmaLayoutV2(const MmaEncodingAttr &mmaLayout,
|
||||
ArrayRef<int64_t> shape) const {
|
||||
SmallVector<SmallVector<unsigned>> ret;
|
||||
|
||||
for (unsigned i = 0; i < shape[0]; i += getShapePerCTA(mmaLayout)[0]) {
|
||||
for (unsigned j = 0; j < shape[1]; j += getShapePerCTA(mmaLayout)[1]) {
|
||||
ret.push_back({i, j});
|
||||
ret.push_back({i, j + 1});
|
||||
ret.push_back({i + 8, j});
|
||||
ret.push_back({i + 8, j + 1});
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Get offsets / indices for any layout
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
SmallVector<Value> emitBaseIndexForLayout(Location loc,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
const Attribute &layout,
|
||||
ArrayRef<int64_t> shape) const {
|
||||
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>())
|
||||
return emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, shape);
|
||||
if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
|
||||
if (mmaLayout.isVolta())
|
||||
return emitBaseIndexForMmaLayoutV1(loc, rewriter, mmaLayout, shape);
|
||||
if (mmaLayout.isAmpere())
|
||||
return emitBaseIndexForMmaLayoutV2(loc, rewriter, mmaLayout, shape);
|
||||
}
|
||||
llvm_unreachable("unsupported emitBaseIndexForLayout");
|
||||
}
|
||||
|
||||
SmallVector<SmallVector<unsigned>>
|
||||
emitOffsetForLayout(const Attribute &layout, ArrayRef<int64_t> shape) const {
|
||||
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>())
|
||||
return emitOffsetForBlockedLayout(blockedLayout, shape);
|
||||
if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
|
||||
if (mmaLayout.isVolta())
|
||||
return emitOffsetForMmaLayoutV1(mmaLayout, shape);
|
||||
if (mmaLayout.isAmpere())
|
||||
return emitOffsetForMmaLayoutV2(mmaLayout, shape);
|
||||
}
|
||||
llvm_unreachable("unsupported emitOffsetForLayout");
|
||||
}
|
||||
|
||||
// Emit indices calculation within each ConversionPattern, and returns a
|
||||
// [elemsPerThread X rank] index matrix.
|
||||
|
||||
// TODO: [phil] redundant indices computation do not appear to hurt
|
||||
// performance much, but they could still significantly slow down
|
||||
// computations.
|
||||
SmallVector<SmallVector<Value>> emitIndicesForDistributedLayout(
|
||||
Location loc, ConversionPatternRewriter &rewriter,
|
||||
const Attribute &layout, ArrayRef<int64_t> shape) const {
|
||||
|
||||
// step 1, delinearize threadId to get the base index
|
||||
auto multiDimBase = emitBaseIndexForLayout(loc, rewriter, layout, shape);
|
||||
// step 2, get offset of each element
|
||||
auto offset = emitOffsetForLayout(layout, shape);
|
||||
// step 3, add offset to base, and reorder the sequence of indices to
|
||||
// guarantee that elems in the same sizePerThread are adjacent in order
|
||||
unsigned rank = shape.size();
|
||||
unsigned elemsPerThread = offset.size();
|
||||
SmallVector<SmallVector<Value>> multiDimIdx(elemsPerThread,
|
||||
SmallVector<Value>(rank));
|
||||
for (unsigned n = 0; n < elemsPerThread; ++n)
|
||||
for (unsigned k = 0; k < rank; ++k)
|
||||
multiDimIdx[n][k] = add(multiDimBase[k], idx_val(offset[n][k]));
|
||||
|
||||
return multiDimIdx;
|
||||
}
|
||||
|
||||
struct SmallVectorKeyInfo {
|
||||
static unsigned getHashValue(const SmallVector<unsigned> &key) {
|
||||
return llvm::hash_combine_range(key.begin(), key.end());
|
||||
}
|
||||
static bool isEqual(const SmallVector<unsigned> &lhs,
|
||||
const SmallVector<unsigned> &rhs) {
|
||||
return lhs == rhs;
|
||||
}
|
||||
static SmallVector<unsigned> getEmptyKey() {
|
||||
return SmallVector<unsigned>();
|
||||
}
|
||||
static SmallVector<unsigned> getTombstoneKey() {
|
||||
return {std::numeric_limits<unsigned>::max()};
|
||||
}
|
||||
};
|
||||
|
||||
SmallVector<SmallVector<Value>>
|
||||
emitIndicesForSliceLayout(Location loc, ConversionPatternRewriter &rewriter,
|
||||
const SliceEncodingAttr &sliceLayout,
|
||||
ArrayRef<int64_t> shape) const {
|
||||
auto parent = sliceLayout.getParent();
|
||||
unsigned dim = sliceLayout.getDim();
|
||||
size_t rank = shape.size();
|
||||
auto parentIndices =
|
||||
emitIndices(loc, rewriter, parent, sliceLayout.paddedShape(shape));
|
||||
unsigned numIndices = parentIndices.size();
|
||||
SmallVector<SmallVector<Value>> resultIndices;
|
||||
for (unsigned i = 0; i < numIndices; ++i) {
|
||||
SmallVector<Value> indices = parentIndices[i];
|
||||
indices.erase(indices.begin() + dim);
|
||||
resultIndices.push_back(indices);
|
||||
}
|
||||
return resultIndices;
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Emit indices
|
||||
// -----------------------------------------------------------------------
|
||||
SmallVector<SmallVector<Value>> emitIndices(Location loc,
|
||||
ConversionPatternRewriter &b,
|
||||
const Attribute &layout,
|
||||
ArrayRef<int64_t> shape) const {
|
||||
if (auto blocked = layout.dyn_cast<BlockedEncodingAttr>()) {
|
||||
return emitIndicesForDistributedLayout(loc, b, blocked, shape);
|
||||
} else if (auto mma = layout.dyn_cast<MmaEncodingAttr>()) {
|
||||
return emitIndicesForDistributedLayout(loc, b, mma, shape);
|
||||
} else if (auto slice = layout.dyn_cast<SliceEncodingAttr>()) {
|
||||
return emitIndicesForSliceLayout(loc, b, slice, shape);
|
||||
} else {
|
||||
assert(0 && "emitIndices for layouts other than blocked & slice not "
|
||||
"implemented yet");
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Shared memory utilities
|
||||
// -----------------------------------------------------------------------
|
||||
template <typename T>
|
||||
Value getSharedMemoryBase(Location loc, ConversionPatternRewriter &rewriter,
|
||||
T value) const {
|
||||
|
||||
auto ptrTy = LLVM::LLVMPointerType::get(
|
||||
this->getTypeConverter()->convertType(rewriter.getI8Type()), 3);
|
||||
auto bufferId = allocation->getBufferId(value);
|
||||
assert(bufferId != Allocation::InvalidBufferId && "BufferId not found");
|
||||
size_t offset = allocation->getOffset(bufferId);
|
||||
Value offVal = idx_val(offset);
|
||||
Value base = gep(ptrTy, smem, offVal);
|
||||
return base;
|
||||
}
|
||||
|
||||
protected:
|
||||
const Allocation *allocation;
|
||||
Value smem;
|
||||
};
|
||||
|
||||
#endif
|
Reference in New Issue
Block a user