[BACKEND] Support of ConvertLayoutOp from blocked to blocked and SliceLayout with blocked parent (#658)
@@ -11,7 +11,9 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Analysis/Allocation.h"
#include "triton/Analysis/AxisInfo.h"
#include "triton/Analysis/Utility.h"
#include "triton/Conversion/MLIRTypes.h"
#include "triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
@@ -26,8 +28,11 @@
using namespace mlir;
using namespace mlir::triton;
using ::mlir::triton::gpu::BlockedEncodingAttr;
using ::mlir::triton::gpu::getElemsPerThread;
using ::mlir::triton::gpu::getShapePerCTA;
using ::mlir::triton::gpu::MmaEncodingAttr;
using ::mlir::triton::gpu::SharedEncodingAttr;
using ::mlir::triton::gpu::SliceEncodingAttr;

namespace mlir {
namespace LLVM {
@@ -43,11 +48,6 @@ namespace type = mlir::triton::type;

class TritonGPUToLLVMTypeConverter;

// TODO(Superjomn): move this to wherever the general utilities live.
template <typename Int> size_t product(llvm::ArrayRef<Int> arr) {
  return std::accumulate(arr.begin(), arr.end(), 1, std::multiplies{});
}

// FuncOpConversion/FuncOpConversionBase is borrowed from
// https://github.com/llvm/llvm-project/blob/fae656b2dd80246c3c6f01e9c77c49560368752c/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp#L276
// since it is not exposed in public headers in MLIR v14.
@@ -214,36 +214,6 @@ struct ReturnOpConversion : public ConvertOpToLLVMPattern<::mlir::ReturnOp> {
  }
};

-static int64_t getLinearIndex(std::vector<int64_t> multidim_index,
-                              ArrayRef<int64_t> shape) {
-  assert(multidim_index.size() == shape.size());
-  // sizes {a, b, c, d} -> acc_mul {b*c*d, c*d, d, 1}
-  int64_t rank = shape.size();
-  int64_t acc_mul = 1;
-  for (int64_t i = 1; i < rank; ++i) {
-    acc_mul *= shape[i];
-  }
-  int64_t linear_index = 0;
-  for (int64_t i = 0; i < rank; ++i) {
-    linear_index += multidim_index[i] * acc_mul;
-    if (i != (rank - 1)) {
-      acc_mul = acc_mul / shape[i + 1];
-    }
-  }
-  return linear_index;
-}
-
-static unsigned getElemsPerThread(BlockedEncodingAttr layout,
-                                  ArrayRef<int64_t> shape) {
-  size_t rank = shape.size();
-  SmallVector<unsigned> elemsPerThreadPerDim(rank);
-  for (size_t i = 0; i < rank; ++i) {
-    unsigned t = layout.getThreadsPerWarp()[i] * layout.getWarpsPerCTA()[i];
-    elemsPerThreadPerDim[i] = (shape[i] + t - 1) / t;
-  }
-  return product<unsigned>(elemsPerThreadPerDim);
-}
-
static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
                                     Type resultType, int64_t value) {
  return builder.create<LLVM::ConstantOp>(
@@ -309,9 +279,9 @@ static T getLinearIndex(ArrayRef<T> multidim_index, ArrayRef<T> shape) {
}

struct ConvertTritonGPUOpToLLVMPatternBase {
-  SmallVector<Value>
+  static SmallVector<Value>
  getElementsFromStruct(Location loc, Value llvmStruct, unsigned elems,
-                        ConversionPatternRewriter &rewriter) const {
+                        ConversionPatternRewriter &rewriter) {
    SmallVector<Value> results(elems);
    for (unsigned i = 0; i < elems; ++i) {
      Type type =
@@ -344,7 +314,12 @@ public:
    for (unsigned i = 0; i < rank; ++i) {
      reordered[i] = shape[order[i]];
    }
-    return delinearize(rewriter, loc, linear, reordered);
+    auto reorderedMultiDim = delinearize(rewriter, loc, linear, reordered);
+    SmallVector<Value> multiDim(rank);
+    for (unsigned i = 0; i < rank; ++i) {
+      multiDim[order[i]] = reorderedMultiDim[i];
+    }
+    return multiDim;
  }

  SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
@@ -370,13 +345,29 @@ public:
    return multiDim;
  }

-  // Emit indices calculation within each ConversionPattern
-  // TODO: [goostavz] Double confirm the redundant indices calculations will
-  // be eliminated in the consequent MLIR/LLVM optimization
-  SmallVector<SmallVector<Value>>
-  emitIndicesForBlockedLayout(Location loc, ConversionPatternRewriter &b,
-                              const BlockedEncodingAttr &blocked_layout,
-                              ArrayRef<int64_t> shape) const {
+  Value linearize(ConversionPatternRewriter &rewriter, Location loc,
+                  ArrayRef<Value> multiDim, ArrayRef<unsigned> shape) const {
+    int rank = multiDim.size();
+    Value linear = createIndexAttrConstant(
+        rewriter, loc, this->getTypeConverter()->getIndexType(), 0);
+    if (rank > 0) {
+      linear = multiDim.front();
+      for (auto &&z : llvm::zip(multiDim.drop_front(), shape.drop_front())) {
+        Value dimSize = createIndexAttrConstant(
+            rewriter, loc, this->getTypeConverter()->getIndexType(),
+            std::get<1>(z));
+        linear = rewriter.create<LLVM::AddOp>(
+            loc, rewriter.create<LLVM::MulOp>(loc, linear, dimSize),
+            std::get<0>(z));
+      }
+    }
+    return linear;
+  }
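  // Worked example (illustrative, not part of the commit): for multiDim =
  // {i, j, k} and shape = {A, B, C}, linearize emits IR computing
  //   ((i * B) + j) * C + k,
  // the inverse of delinearize above; e.g. {1, 2, 3} with shape {2, 3, 4}
  // gives (1 * 3 + 2) * 4 + 3 = 23.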

+  SmallVector<Value>
+  emitBaseIndexForBlockedLayout(Location loc, ConversionPatternRewriter &b,
+                                const BlockedEncodingAttr &blocked_layout,
+                                ArrayRef<int64_t> shape) const {
    auto llvmIndexTy = this->getTypeConverter()->getIndexType();
    auto cast = b.create<UnrealizedConversionCastOp>(
        loc, TypeRange{llvmIndexTy},
@@ -391,7 +382,6 @@ public:
    auto warpsPerCTA = blocked_layout.getWarpsPerCTA();
    auto order = blocked_layout.getOrder();
    unsigned rank = shape.size();
-    SmallVector<Value, 4> threadIds(rank);

    // step 1, delinearize threadId to get the base index
    SmallVector<Value> multiDimWarpId =
@@ -400,8 +390,19 @@ public:
        delinearize(b, loc, laneId, threadsPerWarp, order);
    SmallVector<Value> multiDimBase(rank);
    for (unsigned k = 0; k < rank; ++k) {
-      // multiDimBase[k] = (multiDimThreadId[k] + multiDimWarpId[k] *
-      //                    threadsPerWarp[k]) *
+      // Wrap multiDimWarpId/multiDimThreadId around in case
+      // shape[k] > shapePerCTA[k].
+      unsigned maxWarps =
+          ceil<unsigned>(shape[k], sizePerThread[k] * threadsPerWarp[k]);
+      unsigned maxThreads = ceil<unsigned>(shape[k], sizePerThread[k]);
+      multiDimWarpId[k] = b.create<LLVM::URemOp>(
+          loc, multiDimWarpId[k],
+          createIndexAttrConstant(b, loc, llvmIndexTy, maxWarps));
+      multiDimThreadId[k] = b.create<LLVM::URemOp>(
+          loc, multiDimThreadId[k],
+          createIndexAttrConstant(b, loc, llvmIndexTy, maxThreads));
+      // multiDimBase[k] = (multiDimThreadId[k] +
+      //                    multiDimWarpId[k] * threadsPerWarp[k]) *
+      //                   sizePerThread[k];
      Value threadsPerWarpK =
          createIndexAttrConstant(b, loc, llvmIndexTy, threadsPerWarp[k]);
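      // Note (illustrative, not part of the commit): ceil<T>(a, b) used above
      // is presumably integer ceiling division, i.e. (a + b - 1) / b, so each
      // dimension is covered even when shape[k] is not a multiple of the
      // per-warp/per-thread tile size.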
@@ -413,17 +414,100 @@ public:
          loc, multiDimThreadId[k],
          b.create<LLVM::MulOp>(loc, multiDimWarpId[k], threadsPerWarpK)));
    }
    return multiDimBase;
  }

  SmallVector<SmallVector<Value>> emitIndices(Location loc,
                                              ConversionPatternRewriter &b,
                                              const Attribute &layout,
                                              ArrayRef<int64_t> shape) const {
    if (auto blocked = layout.dyn_cast<BlockedEncodingAttr>()) {
      return emitIndicesForBlockedLayout(loc, b, blocked, shape);
    } else if (auto slice = layout.dyn_cast<SliceEncodingAttr>()) {
      return emitIndicesForSliceLayout(loc, b, slice, shape);
    } else {
      assert(0 && "emitIndices for layouts other than blocked & slice not "
                  "implemented yet");
      return {};
    }
  }

  SmallVector<SmallVector<Value>>
  emitIndicesForSliceLayout(Location loc, ConversionPatternRewriter &b,
                            const SliceEncodingAttr &sliceLayout,
                            ArrayRef<int64_t> shape) const {
    auto parent = sliceLayout.getParent();
    unsigned dim = sliceLayout.getDim();
    size_t rank = shape.size();
    if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
      SmallVector<int64_t> paddedShape(rank + 1);
      for (unsigned d = 0; d < rank + 1; ++d) {
        if (d < dim) {
          paddedShape[d] = shape[d];
        } else if (d == dim) {
          paddedShape[d] = 1;
        } else {
          paddedShape[d] = shape[d - 1];
        }
      }
      auto paddedIndices =
          emitIndicesForBlockedLayout(loc, b, blockedParent, paddedShape);
      unsigned numIndices = paddedIndices.size();
      SmallVector<SmallVector<Value>> resultIndices(numIndices);
      for (unsigned i = 0; i < numIndices; ++i) {
        for (unsigned d = 0; d < rank + 1; ++d) {
          if (d != dim) {
            resultIndices[i].push_back(paddedIndices[i][d]);
          }
        }
      }
      return resultIndices;

    } else if (auto sliceParent = parent.dyn_cast<SliceEncodingAttr>()) {
      assert(0 && "emitIndicesForSliceLayout with a slice-layout parent "
                  "is not implemented yet");
      return {};

    } else {
      assert(0 && "emitIndicesForSliceLayout with parent other than blocked & "
                  "slice not implemented yet");
      return {};
    }
  }

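  // Illustrative example (not part of the commit): for a slice layout with
  // dim = 0 over a blocked parent and shape = {N}, paddedShape becomes
  // {1, N}; indices are emitted for the padded blocked layout, and the
  // inserted dimension is then dropped from every index.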
  // Emit indices calculation within each ConversionPattern.
  // TODO: [goostavz] Double check that the redundant index calculations will
  // be eliminated by subsequent MLIR/LLVM optimizations. We might
  // implement an index cache if necessary.
  SmallVector<SmallVector<Value>>
  emitIndicesForBlockedLayout(Location loc, ConversionPatternRewriter &b,
                              const BlockedEncodingAttr &blockedLayout,
                              ArrayRef<int64_t> shape) const {
    auto llvmIndexTy = this->getTypeConverter()->getIndexType();
    auto sizePerThread = blockedLayout.getSizePerThread();
    auto threadsPerWarp = blockedLayout.getThreadsPerWarp();
    auto warpsPerCTA = blockedLayout.getWarpsPerCTA();
    unsigned rank = shape.size();
    SmallVector<unsigned> shapePerCTA(rank);
    for (unsigned k = 0; k < rank; ++k) {
      shapePerCTA[k] = sizePerThread[k] * threadsPerWarp[k] * warpsPerCTA[k];
    }

    // step 1, delinearize threadId to get the base index
    auto multiDimBase =
        emitBaseIndexForBlockedLayout(loc, b, blockedLayout, shape);
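    // Illustrative numbers (not part of the commit): with sizePerThread =
    // {2, 2}, threadsPerWarp = {8, 4}, and warpsPerCTA = {2, 2}, shapePerCTA
    // is {32, 16}; for shape = {64, 32} each thread then owns
    // ceil(64, 32) * 2 = 4 elements along dim 0 and ceil(32, 16) * 2 = 4
    // along dim 1 (see step 2 below).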

    // step 2, get the offset of each element
    unsigned elemsPerThread = 1;
    SmallVector<SmallVector<unsigned>> offset(rank);
    SmallVector<unsigned> multiDimElemsPerThread(rank);
    for (unsigned k = 0; k < rank; ++k) {
-      multiDimElemsPerThread[k] = shape[k] / threadsPerWarp[k] / warpsPerCTA[k];
+      multiDimElemsPerThread[k] =
+          ceil<unsigned>(shape[k], shapePerCTA[k]) * sizePerThread[k];
      elemsPerThread *= multiDimElemsPerThread[k];
      // at least 1 block if shape[k] is less than shapePerCTA[k]
      for (unsigned blockOffset = 0;
-           blockOffset <
-               shape[k] / (sizePerThread[k] * threadsPerWarp[k] * warpsPerCTA[k]);
+           blockOffset < ceil<unsigned>(shape[k], shapePerCTA[k]);
           ++blockOffset)
        for (unsigned warpOffset = 0; warpOffset < warpsPerCTA[k]; ++warpOffset)
          for (unsigned threadOffset = 0; threadOffset < threadsPerWarp[k];
@@ -445,7 +529,7 @@ public:
        std::multiplies<unsigned>());
    SmallVector<unsigned> threadsPerDim(rank);
    for (unsigned k = 0; k < rank; ++k) {
-      threadsPerDim[k] = shape[k] / sizePerThread[k];
+      threadsPerDim[k] = ceil<unsigned>(shape[k], sizePerThread[k]);
    }
    for (unsigned n = 0; n < elemsPerThread; ++n) {
      unsigned linearNanoTileId = n / accumSizePerThread;
@@ -469,6 +553,20 @@ public:

    return multiDimIdx;
  }

  Value getSharedMemoryBase(Location loc, ConversionPatternRewriter &rewriter,
                            Value smem, const Allocation *allocation,
                            Operation *op) const {
    auto ptrTy = LLVM::LLVMPointerType::get(
        this->getTypeConverter()->convertType(rewriter.getIntegerType(8)), 3);
    auto bufferId = allocation->getBufferId(op);
    assert(bufferId != Allocation::InvalidBufferId && "BufferId not found");
    size_t offset = allocation->getOffset(bufferId);
    auto llvmIndexTy = this->getTypeConverter()->getIndexType();
    Value offVal = createIndexAttrConstant(rewriter, loc, llvmIndexTy, offset);
    Value base = rewriter.create<LLVM::GEPOp>(loc, ptrTy, smem, offVal);
    return base;
  }
};

// Convert SplatOp or arith::ConstantOp with SplatElementsAttr to a
@@ -482,19 +580,10 @@ Value convertSplatLikeOp(Type elemType, Type resType, Value constVal,
                         ConversionPatternRewriter &rewriter, Location loc) {

  auto tensorTy = resType.cast<RankedTensorType>();
-  auto layout = tensorTy.getEncoding().cast<BlockedEncodingAttr>();
+  auto layout = tensorTy.getEncoding();
  auto srcType = typeConverter->convertType(elemType);
  auto llSrc = rewriter.create<LLVM::BitcastOp>(loc, srcType, constVal);

-  auto numElems = layout.getSizePerThread();
-  size_t totalElems =
-      std::accumulate(tensorTy.getShape().begin(), tensorTy.getShape().end(),
-                      1, std::multiplies<>{});
-  size_t numThreads =
-      product(layout.getWarpsPerCTA()) * product(layout.getThreadsPerWarp());
-  // TODO(Superjomn) add numElemsPerThread to the layout encodings.
-  size_t numElemsPerThread = totalElems / numThreads;

+  size_t numElemsPerThread = getElemsPerThread(layout, tensorTy.getShape());
  llvm::SmallVector<Value, 4> elems(numElemsPerThread, llSrc);
  llvm::SmallVector<Type, 4> elemTypes(elems.size(), srcType);
  auto structTy =
@@ -580,7 +669,7 @@ struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase {
    auto shape = ty.getShape();
    // Here, we assume that all inputs have a blocked layout

-    unsigned valueElems = getElemsPerThread(layout, shape);
+    unsigned valueElems = layout.getElemsPerThread(shape);

    auto llvmElemTy = typeConverter->convertType(ty.getElementType());
    auto llvmElemPtrPtrTy =
@@ -595,16 +684,15 @@ struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase {
    auto ty = val.getType().cast<RankedTensorType>();
    // Here, we assume that all inputs have a blocked layout
    auto layout = ty.getEncoding().dyn_cast<BlockedEncodingAttr>();
    assert(layout && "unexpected layout in getLayout");
    auto shape = ty.getShape();
-    unsigned valueElems = getElemsPerThread(layout, shape);
+    unsigned valueElems = layout.getElemsPerThread(shape);
    return std::make_tuple(layout, valueElems);
  }

  unsigned getAlignment(Value val, const BlockedEncodingAttr &layout) const {
    auto axisInfo = getAxisInfo(val);

    auto order = layout.getOrder();

    unsigned maxMultiple = axisInfo->getDivisibility(order[0]);
    unsigned maxContig = axisInfo->getContiguity(order[0]);
    unsigned alignment = std::min(maxMultiple, maxContig);
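    // Illustrative (not part of the commit): if pointers along the fastest
    // dimension are divisible by 16 elements but only contiguous in runs of
    // 8, the resulting alignment is min(16, 8) = 8.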
@@ -614,22 +702,18 @@ struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase {
  unsigned getVectorizeSize(Value ptr,
                            const BlockedEncodingAttr &layout) const {
    auto axisInfo = getAxisInfo(ptr);
    auto contig = axisInfo->getContiguity();
    // Here `order` is sorted by contiguity, so its first element is the
    // most contiguous dimension.
    auto order = layout.getOrder();
    unsigned align = getAlignment(ptr, layout);

-    auto getTensorShape = [](Value val) -> ArrayRef<int64_t> {
-      auto ty = val.getType().cast<RankedTensorType>();
-      auto shape = ty.getShape();
-      return shape;
-    };
-
-    // unsigned contigPerThread = layout.getSizePerThread()[order[0]];
-    unsigned contigPerThread = getElemsPerThread(layout, getTensorShape(ptr));
+    auto ty = ptr.getType().dyn_cast<RankedTensorType>();
+    assert(ty);
+    auto shape = ty.getShape();

+    unsigned contigPerThread = layout.getSizePerThread()[order[0]];
    unsigned vec = std::min(align, contigPerThread);
    vec = std::min<unsigned>(shape[order[0]], vec);

    return vec;
  }
@@ -819,25 +903,22 @@ struct BroadcastOpConversion
    auto srcShape = srcTy.getShape();
    auto resultShape = resultTy.getShape();
    unsigned rank = srcTy.getRank();
    // TODO: [goostavz] double confirm the op semantics with Phil
    assert(rank == resultTy.getRank());

    SmallVector<int64_t, 4> srcLogicalShape(2 * rank);
    SmallVector<int64_t, 4> resultLogicalShape(2 * rank);
    SmallVector<unsigned, 2> broadcastDims;
-    SmallVector<int64_t, 2> broadcastSizes;
-    int64_t duplicates = 1;
    for (unsigned d = 0; d < rank; ++d) {
-      int64_t numCtas = resultShape[d] / (resultLayout.getSizePerThread()[d] *
-                                          resultLayout.getThreadsPerWarp()[d] *
-                                          resultLayout.getWarpsPerCTA()[d]);
+      unsigned resultShapePerCTA = resultLayout.getSizePerThread()[d] *
+                                   resultLayout.getThreadsPerWarp()[d] *
+                                   resultLayout.getWarpsPerCTA()[d];
+      int64_t numCtas = ceil<unsigned>(resultShape[d], resultShapePerCTA);
      if (srcShape[d] != resultShape[d]) {
        assert(srcShape[d] == 1);
        broadcastDims.push_back(d);
-        broadcastSizes.push_back(resultShape[d]);
        srcLogicalShape[d] = 1;
-        srcLogicalShape[d + rank] = 1;
-        duplicates *= resultShape[d];
+        srcLogicalShape[d + rank] =
+            std::max(unsigned(1), srcLayout.getSizePerThread()[d]);
      } else {
        srcLogicalShape[d] = numCtas;
        srcLogicalShape[d + rank] = resultLayout.getSizePerThread()[d];
@@ -845,18 +926,37 @@ struct BroadcastOpConversion
      resultLogicalShape[d] = numCtas;
      resultLogicalShape[d + rank] = resultLayout.getSizePerThread()[d];
    }
-    unsigned srcElems = getElemsPerThread(srcLayout, srcShape);
+    int64_t duplicates = 1;
+    SmallVector<int64_t, 2> broadcastSizes(broadcastDims.size() * 2);
+    for (auto it : llvm::enumerate(broadcastDims)) {
+      // In case multiple indices in the src actually compute the same
+      // element, srcLogicalShape may not need to be 1. For example, a src
+      // of shape [256, 1] with a blocked layout of
+      // sizePerThread: [1, 4]; threadsPerWarp: [1, 32]; warpsPerCTA: [1, 2]
+      int64_t d = resultLogicalShape[it.value()] / srcLogicalShape[it.value()];
+      broadcastSizes[it.index()] = d;
+      duplicates *= d;
+      d = resultLogicalShape[it.value() + rank] /
+          srcLogicalShape[it.value() + rank];
+      broadcastSizes[it.index() + broadcastDims.size()] = d;
+      duplicates *= d;
+    }

+    unsigned srcElems = srcLayout.getElemsPerThread(srcShape);
    auto elemTy = resultTy.getElementType();
    auto srcVals = getElementsFromStruct(loc, src, srcElems, rewriter);
-    unsigned resultElems = getElemsPerThread(resultLayout, resultShape);
+    unsigned resultElems = resultLayout.getElemsPerThread(resultShape);
    SmallVector<Value> resultVals(resultElems);
    for (unsigned i = 0; i < srcElems; ++i) {
      auto srcMultiDim = getMultiDimIndex<int64_t>(i, srcLogicalShape);
-      auto resultMultiDim = srcMultiDim;
      for (int64_t j = 0; j < duplicates; ++j) {
+        auto resultMultiDim = srcMultiDim;
        auto bcastMultiDim = getMultiDimIndex<int64_t>(j, broadcastSizes);
        for (auto bcastDim : llvm::enumerate(broadcastDims)) {
-          resultMultiDim[bcastDim.value()] = bcastMultiDim[bcastDim.index()];
+          resultMultiDim[bcastDim.value()] += bcastMultiDim[bcastDim.index()];
+          resultMultiDim[bcastDim.value() + rank] +=
+              bcastMultiDim[bcastDim.index() + broadcastDims.size()] *
+              srcLogicalShape[bcastDim.index() + broadcastDims.size()];
        }
        auto resultLinearIndex =
            getLinearIndex<int64_t>(resultMultiDim, resultLogicalShape);
@@ -871,27 +971,29 @@ struct BroadcastOpConversion
  }
};

-struct ViewOpConversion
-    : public ConvertTritonGPUOpToLLVMPattern<triton::ViewOp> {
-  using ConvertTritonGPUOpToLLVMPattern<
-      triton::ViewOp>::ConvertTritonGPUOpToLLVMPattern;
+template <typename SourceOp>
+struct ViewLikeOpConversion : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
+  using OpAdaptor = typename SourceOp::Adaptor;
+  explicit ViewLikeOpConversion(LLVMTypeConverter &typeConverter,
+                                PatternBenefit benefit = 1)
+      : ConvertTritonGPUOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}

  LogicalResult
-  matchAndRewrite(triton::ViewOp op, OpAdaptor adaptor,
+  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // We cannot directly call
    //   rewriter.replaceOp(op, adaptor.src());
    // due to MLIR's restrictions.
    Location loc = op->getLoc();
-    auto resultTy = op.getType().cast<RankedTensorType>();
-    auto resultLayout = resultTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
+    auto resultTy = op.getType().template cast<RankedTensorType>();
    auto resultShape = resultTy.getShape();
-    unsigned elems = getElemsPerThread(resultLayout, resultShape);
+    unsigned elems = getElemsPerThread(resultTy.getEncoding(), resultShape);
    Type elemTy =
        this->getTypeConverter()->convertType(resultTy.getElementType());
    SmallVector<Type> types(elems, elemTy);
-    Type structTy = LLVM::LLVMStructType::getLiteral(getContext(), types);
-    auto vals = getElementsFromStruct(loc, adaptor.src(), elems, rewriter);
+    Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types);
+    auto vals =
+        this->getElementsFromStruct(loc, adaptor.src(), elems, rewriter);
    Value view = getStructFromElements(loc, vals, rewriter, structTy);
    rewriter.replaceOp(op, view);
    return success();
@@ -911,12 +1013,12 @@ struct MakeRangeOpConversion
    Location loc = op->getLoc();
    auto rankedTy = op.result().getType().dyn_cast<RankedTensorType>();
    auto shape = rankedTy.getShape();
-    auto layout = rankedTy.getEncoding().cast<BlockedEncodingAttr>();
+    auto layout = rankedTy.getEncoding();

    auto elemTy = rankedTy.getElementType();
    assert(elemTy.isInteger(32));
    Value start = createIndexAttrConstant(rewriter, loc, elemTy, op.start());
-    auto idxs = emitIndicesForBlockedLayout(loc, rewriter, layout, shape);
+    auto idxs = emitIndices(loc, rewriter, layout, shape);
    unsigned elems = idxs.size();
    SmallVector<Value> retVals(elems);
    for (auto multiDim : llvm::enumerate(idxs)) {
@@ -1184,8 +1286,9 @@ struct AddPtrOpConversion
    Location loc = op->getLoc();
    auto resultTy = op.getType().dyn_cast<RankedTensorType>();
    auto resultLayout = resultTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
    assert(resultLayout && "Unexpected resultLayout in AddPtrOpConversion");
    auto resultShape = resultTy.getShape();
-    unsigned elems = getElemsPerThread(resultLayout, resultShape);
+    unsigned elems = resultLayout.getElemsPerThread(resultShape);
    Type elemTy =
        this->getTypeConverter()->convertType(resultTy.getElementType());
    SmallVector<Type> types(elems, elemTy);
@@ -1225,7 +1328,8 @@ public:
    auto resultLayout =
        resultTy.getEncoding().template dyn_cast<BlockedEncodingAttr>();
    auto resultShape = resultTy.getShape();
-    unsigned elems = getElemsPerThread(resultLayout, resultShape);
+    assert(resultLayout && "Unexpected resultLayout in BinaryOpConversion");
+    unsigned elems = resultLayout.getElemsPerThread(resultShape);
    Type elemTy =
        this->getTypeConverter()->convertType(resultTy.getElementType());
    SmallVector<Type> types(elems, elemTy);
@@ -1244,6 +1348,228 @@ public:
  }
};

struct ConvertLayoutOpConversion
    : public ConvertTritonGPUOpToLLVMPattern<triton::gpu::ConvertLayoutOp> {
public:
  using ConvertTritonGPUOpToLLVMPattern<
      triton::gpu::ConvertLayoutOp>::ConvertTritonGPUOpToLLVMPattern;

  ConvertLayoutOpConversion(LLVMTypeConverter &converter,
                            const Allocation *allocation, Value smem,
                            PatternBenefit benefit)
      : ConvertTritonGPUOpToLLVMPattern<triton::gpu::ConvertLayoutOp>(converter,
                                                                      benefit),
        allocation_(allocation), smem_(smem) {}

  LogicalResult
  matchAndRewrite(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    Value src = op.src();
    Value dst = op.result();
    auto srcTy = src.getType().cast<RankedTensorType>();
    auto dstTy = dst.getType().cast<RankedTensorType>();
    Attribute srcLayout = srcTy.getEncoding();
    Attribute dstLayout = dstTy.getEncoding();
    if ((!srcLayout.isa<BlockedEncodingAttr>()) ||
        (!dstLayout.isa<BlockedEncodingAttr>())) {
      // TODO: not implemented
      assert(0 &&
             "convert_layout except for blocked -> blocked is not implemented");
      return failure();
    }
    auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType());
    Value smemBase = getSharedMemoryBase(loc, rewriter, smem_, allocation_,
                                         op.getOperation());
    auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
    smemBase = rewriter.create<LLVM::BitcastOp>(loc, elemPtrTy, smemBase);

    auto shape = dstTy.getShape();
    unsigned rank = dstTy.getRank();
    auto getContigPerThread = [&](const Attribute &layout,
                                  unsigned d) -> unsigned {
      if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
        return blockedLayout.getSizePerThread()[d];
      } else {
        assert(0 && "Unimplemented usage of getContigPerThread");
        return 0;
      }
    };
    auto getAccumElemsPerThread = [&](const Attribute &layout) -> unsigned {
      if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
        return product<unsigned>(blockedLayout.getSizePerThread());
      } else {
        assert(0 && "Unimplemented usage of getAccumElemsPerThread");
        return 0;
      }
    };
    auto getOrder = [&](const Attribute &layout) -> ArrayRef<unsigned> {
      if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
        return blockedLayout.getOrder();
      } else {
        assert(0 && "Unimplemented usage of getOrder");
        return {};
      }
    };
    SmallVector<unsigned> numReplicates(rank);
    SmallVector<unsigned> inNumCTAsEachRep(rank);
    SmallVector<unsigned> outNumCTAsEachRep(rank);
    SmallVector<unsigned> inNumCTAs(rank);
    SmallVector<unsigned> outNumCTAs(rank);
    for (unsigned d = 0; d < rank; ++d) {
      unsigned inPerCTA =
          std::min(unsigned(shape[d]), getShapePerCTA(srcLayout, d));
      unsigned outPerCTA =
          std::min(unsigned(shape[d]), getShapePerCTA(dstLayout, d));
      unsigned maxPerCTA = std::max(inPerCTA, outPerCTA);
      numReplicates[d] = ceil<unsigned>(shape[d], maxPerCTA);
      inNumCTAsEachRep[d] = maxPerCTA / inPerCTA;
      outNumCTAsEachRep[d] = maxPerCTA / outPerCTA;
      // TODO: confirm this
      assert(maxPerCTA % inPerCTA == 0 && maxPerCTA % outPerCTA == 0);
      inNumCTAs[d] = ceil<unsigned>(shape[d], inPerCTA);
      outNumCTAs[d] = ceil<unsigned>(shape[d], outPerCTA);
    }
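    // Illustrative numbers (not part of the commit): for shape[d] = 128 with
    // inPerCTA = 32 and outPerCTA = 64, maxPerCTA = 64, so numReplicates[d] =
    // ceil(128, 64) = 2, inNumCTAsEachRep[d] = 2, and outNumCTAsEachRep[d] = 1:
    // each replication stores two source CTA tiles and reads back one
    // destination tile.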
    // Potentially we need to store for multiple CTAs in this replication
    unsigned accumNumReplicates = product<unsigned>(numReplicates);
    unsigned accumInSizePerThread = getAccumElemsPerThread(srcLayout);
    unsigned elems = getElemsPerThread(srcLayout, srcTy.getShape());
    auto vals = getElementsFromStruct(loc, adaptor.src(), elems, rewriter);
    unsigned inVec = 0;
    unsigned outVec = 0;
    auto paddedRepShape = getScratchConfigForCvtLayout(op, inVec, outVec);

    unsigned outElems = getElemsPerThread(dstLayout, shape);
    auto outOrd = getOrder(dstLayout);
    SmallVector<Value> outVals(outElems);
    for (unsigned repId = 0; repId < accumNumReplicates; ++repId) {
      auto multiDimRepId = getMultiDimIndex<unsigned>(repId, numReplicates);
      rewriter.create<mlir::gpu::BarrierOp>(loc);
      if (auto srcBlockedLayout = srcLayout.dyn_cast<BlockedEncodingAttr>()) {
        processReplicaBlocked(loc, rewriter, /*stNotRd*/ true, srcTy,
                              inNumCTAsEachRep, multiDimRepId, inVec,
                              paddedRepShape, outOrd, vals, smemBase);
      } else {
        assert(0 && "ConvertLayout with input layout not implemented");
        return failure();
      }
      rewriter.create<mlir::gpu::BarrierOp>(loc);
      if (auto dstBlockedLayout = dstLayout.dyn_cast<BlockedEncodingAttr>()) {
        processReplicaBlocked(loc, rewriter, /*stNotRd*/ false, dstTy,
                              outNumCTAsEachRep, multiDimRepId, outVec,
                              paddedRepShape, outOrd, outVals, smemBase);
      } else {
        assert(0 && "ConvertLayout with output layout not implemented");
        return failure();
      }
    }

    SmallVector<Type> types(outElems, llvmElemTy);
    Type structTy = LLVM::LLVMStructType::getLiteral(getContext(), types);
    Value result = getStructFromElements(loc, outVals, rewriter, structTy);
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  template <typename T>
  SmallVector<T> reorder(ArrayRef<T> input, ArrayRef<unsigned> order) const {
    size_t rank = order.size();
    assert(input.size() == rank);
    SmallVector<T> result(rank);
    for (auto it : llvm::enumerate(order)) {
      result[rank - 1 - it.value()] = input[it.index()];
    }
    return result;
  };
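  // Illustrative example (not part of the commit): with order = {0, 1}, i.e.
  // dimension 0 fastest-varying, and input = {x, y}, reorder returns {y, x}:
  // the values are permuted so the fastest-varying dimension ends up last.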

  void processReplicaBlocked(Location loc, ConversionPatternRewriter &rewriter,
                             bool stNotRd, RankedTensorType type,
                             ArrayRef<unsigned> numCTAsEachRep,
                             ArrayRef<unsigned> multiDimRepId, unsigned vec,
                             ArrayRef<unsigned> paddedRepShape,
                             ArrayRef<unsigned> outOrd,
                             SmallVector<Value> &vals, Value smemBase) const {
    unsigned accumNumCTAsEachRep = product<unsigned>(numCTAsEachRep);
    auto layout = type.getEncoding().cast<BlockedEncodingAttr>();
    auto rank = type.getRank();
    auto sizePerThread = layout.getSizePerThread();
    auto accumSizePerThread = product<unsigned>(sizePerThread);
    auto llvmIndexTy = getTypeConverter()->getIndexType();
    SmallVector<unsigned> numCTAs(rank);
    SmallVector<unsigned> shapePerCTA(rank);
    for (unsigned d = 0; d < rank; ++d) {
      shapePerCTA[d] = layout.getSizePerThread()[d] *
                       layout.getThreadsPerWarp()[d] *
                       layout.getWarpsPerCTA()[d];
      numCTAs[d] = ceil<unsigned>(type.getShape()[d], shapePerCTA[d]);
    }
    auto llvmElemTy = getTypeConverter()->convertType(type.getElementType());
    auto multiDimOffsetFirstElem =
        emitBaseIndexForBlockedLayout(loc, rewriter, layout, type.getShape());
    for (unsigned ctaId = 0; ctaId < accumNumCTAsEachRep; ++ctaId) {
      auto multiDimCTAInRepId =
          getMultiDimIndex<unsigned>(ctaId, numCTAsEachRep);
      SmallVector<unsigned> multiDimCTAId(rank);
      for (auto it : llvm::enumerate(multiDimCTAInRepId)) {
        auto d = it.index();
        multiDimCTAId[d] = multiDimRepId[d] * numCTAsEachRep[d] + it.value();
      }

      unsigned linearCTAId = getLinearIndex<unsigned>(multiDimCTAId, numCTAs);
      // TODO: this is redundant index calculation; we should consider
      // caching the index calculation results if a performance issue
      // is observed.
      // for (unsigned elemId = linearCTAId * accumSizePerThread;
      //      elemId < (linearCTAId + 1) * accumSizePerThread; elemId += vec) {
      for (unsigned elemId = 0; elemId < accumSizePerThread; elemId += vec) {
        auto multiDimElemId =
            getMultiDimIndex<unsigned>(elemId, layout.getSizePerThread());
        SmallVector<Value> multiDimOffset(rank);
        for (unsigned d = 0; d < rank; ++d) {
          multiDimOffset[d] = rewriter.create<LLVM::AddOp>(
              loc, multiDimOffsetFirstElem[d],
              createIndexAttrConstant(rewriter, loc, llvmIndexTy,
                                      multiDimCTAInRepId[d] * shapePerCTA[d] +
                                          multiDimElemId[d]));
        }
        Value offset =
            linearize(rewriter, loc, reorder<Value>(multiDimOffset, outOrd),
                      reorder<unsigned>(paddedRepShape, outOrd));
        auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
        Value ptr =
            rewriter.create<LLVM::GEPOp>(loc, elemPtrTy, smemBase, offset);
        auto vecTy = VectorType::get(vec, llvmElemTy);
        ptr = rewriter.create<LLVM::BitcastOp>(
            loc, LLVM::LLVMPointerType::get(vecTy, 3), ptr);
        if (stNotRd) {
          Value valVec = rewriter.create<LLVM::UndefOp>(loc, vecTy);
          for (unsigned v = 0; v < vec; ++v) {
            Value vVal = createIndexAttrConstant(
                rewriter, loc, getTypeConverter()->getIndexType(), v);
            valVec = rewriter.create<LLVM::InsertElementOp>(
                loc, vecTy, valVec,
                vals[elemId + linearCTAId * accumSizePerThread + v], vVal);
          }
          rewriter.create<LLVM::StoreOp>(loc, valVec, ptr);
        } else {
          Value valVec = rewriter.create<LLVM::LoadOp>(loc, ptr);
          for (unsigned v = 0; v < vec; ++v) {
            Value vVal = createIndexAttrConstant(
                rewriter, loc, getTypeConverter()->getIndexType(), v);
            vals[elemId + linearCTAId * accumSizePerThread + v] =
                rewriter.create<LLVM::ExtractElementOp>(loc, llvmElemTy, valVec,
                                                        vVal);
          }
        }
      }
    }
  }

  const Allocation *allocation_;
  Value smem_;
};

class TritonGPUToLLVMTypeConverter : public LLVMTypeConverter {
public:
  using TypeConverter::convertType;
@@ -1266,9 +1592,10 @@ public:

  llvm::Optional<Type> convertTritonTensorType(RankedTensorType type) {
    Attribute layout = type.getEncoding();
-    if (auto blocked_layout = layout.dyn_cast<BlockedEncodingAttr>()) {
+    if (layout && (layout.isa<BlockedEncodingAttr>() ||
+                   layout.isa<SliceEncodingAttr>())) {
      unsigned numElementsPerThread =
-          getElemsPerThread(blocked_layout, type.getShape());
+          getElemsPerThread(layout, type.getShape());
      SmallVector<Type, 4> types(numElementsPerThread,
                                 convertType(type.getElementType()));
      return LLVM::LLVMStructType::getLiteral(&getContext(), types);
@@ -1285,7 +1612,8 @@ public:

void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
                                  RewritePatternSet &patterns, int numWarps,
-                                  AxisInfoAnalysis &analysis,
+                                  AxisInfoAnalysis &axisInfoAnalysis,
+                                  const Allocation *allocation, Value smem,
                                  PatternBenefit benefit = 1) {
  patterns.add<ArithConstantSplatOpConversion>(typeConverter, benefit);
  patterns.add<BinaryOpConversion<arith::AddIOp, LLVM::AddOp>>(typeConverter,
@@ -1296,17 +1624,19 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
                                                               benefit);
  patterns.add<BinaryOpConversion<arith::MulFOp, LLVM::FMulOp>>(typeConverter,
                                                                benefit);

  patterns.add<BroadcastOpConversion>(typeConverter, benefit);
  patterns.add<FuncOpConversion>(typeConverter, numWarps, benefit);
  patterns.add<AddPtrOpConversion>(typeConverter, benefit);
+  patterns.add<ConvertLayoutOpConversion>(typeConverter, allocation, smem,
+                                          benefit);
  patterns.add<GetProgramIdOpConversion>(typeConverter, benefit);
-  patterns.add<LoadOpConversion>(typeConverter, analysis, benefit);
+  patterns.add<LoadOpConversion>(typeConverter, axisInfoAnalysis, benefit);
  patterns.add<MakeRangeOpConversion>(typeConverter, benefit);
  patterns.add<ReturnOpConversion>(typeConverter, benefit);
  patterns.add<SplatOpConversion>(typeConverter, benefit);
-  patterns.add<StoreOpConversion>(typeConverter, analysis, benefit);
-  patterns.add<ViewOpConversion>(typeConverter, benefit);
+  patterns.add<StoreOpConversion>(typeConverter, axisInfoAnalysis, benefit);
+  patterns.add<ViewLikeOpConversion<triton::ViewOp>>(typeConverter, benefit);
+  patterns.add<ViewLikeOpConversion<triton::ExpandDimsOp>>(typeConverter,
+                                                           benefit);
}

class ConvertTritonGPUToLLVM
@@ -1322,19 +1652,34 @@ public:
    // TODO: needs confirmation
    option.overrideIndexBitwidth(32);
    TritonGPUToLLVMTypeConverter typeConverter(context, option);
    TritonLLVMFunctionConversionTarget funcTarget(*context, typeConverter);
    TritonLLVMConversionTarget target(*context, typeConverter);

-    RewritePatternSet patterns(context);

    int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);

+    // step 1: Convert FuncOp to LLVMFuncOp via partial conversion
+    // step 2: Allocate shared memories
+    // step 3: Convert the rest of the ops via partial conversion
+    // The reason for separating step 1 from step 3 is that step 2 is out of
+    // the scope of dialect conversion, so we need to make sure smem_
+    // is not modified during the conversion of step 3.
+    RewritePatternSet func_patterns(context);
+    func_patterns.add<FuncOpConversion>(typeConverter, numWarps, 1 /*benefit*/);
+    if (failed(
+            applyPartialConversion(mod, funcTarget, std::move(func_patterns))))
+      return signalPassFailure();

+    Allocation allocation(mod);
    auto axisAnalysis = runAxisAnalysis(mod);
+    initSharedMemory(allocation.getSharedMemorySize(), typeConverter);

    // We set a higher benefit here to ensure triton's patterns run before
    // arith patterns for some encodings not supported by the community
    // patterns.
+    RewritePatternSet patterns(context);
    populateTritonToLLVMPatterns(typeConverter, patterns, numWarps,
-                                 *axisAnalysis, 10 /*benefit*/);
+                                 *axisAnalysis, &allocation, smem_,
+                                 10 /*benefit*/);

    // Add arith/math patterns to help convert scalar expressions to LLVM.
    mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
@@ -1352,10 +1697,35 @@ protected:
    auto axisAnalysisPass =
        std::make_unique<AxisInfoAnalysis>(module->getContext());
    axisAnalysisPass->run(module);

    return axisAnalysisPass;
  }

  void initSharedMemory(size_t size,
                        TritonGPUToLLVMTypeConverter &typeConverter);

  Value smem_;
};

void ConvertTritonGPUToLLVM::initSharedMemory(
    size_t size, TritonGPUToLLVMTypeConverter &typeConverter) {
  ModuleOp mod = getOperation();
  OpBuilder b(mod.getBodyRegion());
  auto loc = mod.getLoc();
  auto elemTy = typeConverter.convertType(b.getIntegerType(8));
  auto arrayTy = LLVM::LLVMArrayType::get(elemTy, size);
  auto global = b.create<LLVM::GlobalOp>(
      loc, arrayTy, /*isConstant=*/false, LLVM::Linkage::Internal,
      "global_smem", /*value=*/Attribute(),
      /*alignment=*/0, mlir::gpu::GPUDialect::getWorkgroupAddressSpace());
  SmallVector<LLVM::LLVMFuncOp> funcs;
  mod.walk([&](LLVM::LLVMFuncOp func) { funcs.push_back(func); });
  assert(funcs.size() == 1 &&
         "Inliner pass is expected before TritonGPUToLLVM");
  b.setInsertionPointToStart(&funcs[0].getBody().front());
  smem_ = b.create<LLVM::AddressOfOp>(loc, global);
}
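// Note (illustrative, not part of the commit): getWorkgroupAddressSpace()
// corresponds to LLVM/NVVM address space 3, i.e. CUDA shared memory, so
// "global_smem" serves as a single shared-memory arena that the Allocation
// analysis partitions via getSharedMemoryBase above.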

} // namespace

namespace mlir {

@@ -1366,10 +1736,20 @@ TritonLLVMConversionTarget::TritonLLVMConversionTarget(
  addLegalDialect<LLVM::LLVMDialect>();
  addLegalDialect<NVVM::NVVMDialect>();
  // addIllegalDialect<triton::TritonDialect>();
  // addIllegalDialect<triton::gpu::TritonGPUDialect>();
  addIllegalDialect<mlir::gpu::GPUDialect>();
  addLegalOp<mlir::UnrealizedConversionCastOp>();
}

TritonLLVMFunctionConversionTarget::TritonLLVMFunctionConversionTarget(
    MLIRContext &ctx, mlir::LLVMTypeConverter &typeConverter)
    : ConversionTarget(ctx), typeConverter(typeConverter) {
  addLegalDialect<LLVM::LLVMDialect>();
  // addLegalDialect<NVVM::NVVMDialect>();
  addIllegalOp<mlir::FuncOp>();
  addLegalOp<mlir::UnrealizedConversionCastOp>();
}

namespace triton {

std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonGPUToLLVMPass() {