[BACKEND] Support of ConvertLayoutOp from blocked to blocked and SliceLayout with blocked parent (#658)

This commit is contained in:
goostavz
2022-09-18 05:58:42 +08:00
committed by GitHub
parent 13669b46a6
commit 15bfd0cb79
17 changed files with 1025 additions and 191 deletions

View File

@@ -11,7 +11,9 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Analysis/Allocation.h"
#include "triton/Analysis/AxisInfo.h"
#include "triton/Analysis/Utility.h"
#include "triton/Conversion/MLIRTypes.h"
#include "triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
@@ -26,8 +28,11 @@
using namespace mlir;
using namespace mlir::triton;
using ::mlir::triton::gpu::BlockedEncodingAttr;
using ::mlir::triton::gpu::getElemsPerThread;
using ::mlir::triton::gpu::getShapePerCTA;
using ::mlir::triton::gpu::MmaEncodingAttr;
using ::mlir::triton::gpu::SharedEncodingAttr;
using ::mlir::triton::gpu::SliceEncodingAttr;
namespace mlir {
namespace LLVM {
@@ -43,11 +48,6 @@ namespace type = mlir::triton::type;
class TritonGPUToLLVMTypeConverter;
// TODO(Superjomn) Move to somewhere general utilities locates.
template <typename Int> size_t product(llvm::ArrayRef<Int> arr) {
return std::accumulate(arr.begin(), arr.end(), 1, std::multiplies{});
}
// FuncOpConversion/FuncOpConversionBase is borrowed from
// https://github.com/llvm/llvm-project/blob/fae656b2dd80246c3c6f01e9c77c49560368752c/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp#L276
// since it is not exposed on header files in mlir v14
@@ -214,36 +214,6 @@ struct ReturnOpConversion : public ConvertOpToLLVMPattern<::mlir::ReturnOp> {
}
};
static int64_t getLinearIndex(std::vector<int64_t> multidim_index,
ArrayRef<int64_t> shape) {
assert(multidim_index.size() == shape.size());
// sizes {a, b, c, d} -> acc_mul {b*c*d, c*d, d, 1}
int64_t rank = shape.size();
int64_t acc_mul = 1;
for (int64_t i = 1; i < rank; ++i) {
acc_mul *= shape[i];
}
int64_t linear_index = 0;
for (int64_t i = 0; i < rank; ++i) {
linear_index += multidim_index[i] * acc_mul;
if (i != (rank - 1)) {
acc_mul = acc_mul / shape[i + 1];
}
}
return linear_index;
}
static unsigned getElemsPerThread(BlockedEncodingAttr layout,
ArrayRef<int64_t> shape) {
size_t rank = shape.size();
SmallVector<unsigned> elemsPerThreadPerDim(rank);
for (size_t i = 0; i < rank; ++i) {
unsigned t = layout.getThreadsPerWarp()[i] * layout.getWarpsPerCTA()[i];
elemsPerThreadPerDim[i] = (shape[i] + t - 1) / t;
}
return product<unsigned>(elemsPerThreadPerDim);
}
static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
Type resultType, int64_t value) {
return builder.create<LLVM::ConstantOp>(
@@ -309,9 +279,9 @@ static T getLinearIndex(ArrayRef<T> multidim_index, ArrayRef<T> shape) {
}
struct ConvertTritonGPUOpToLLVMPatternBase {
SmallVector<Value>
static SmallVector<Value>
getElementsFromStruct(Location loc, Value llvmStruct, unsigned elems,
ConversionPatternRewriter &rewriter) const {
ConversionPatternRewriter &rewriter) {
SmallVector<Value> results(elems);
for (unsigned i = 0; i < elems; ++i) {
Type type =
@@ -344,7 +314,12 @@ public:
for (unsigned i = 0; i < rank; ++i) {
reordered[i] = shape[order[i]];
}
return delinearize(rewriter, loc, linear, reordered);
auto reorderedMultiDim = delinearize(rewriter, loc, linear, reordered);
SmallVector<Value> multiDim(rank);
for (unsigned i = 0; i < rank; ++i) {
multiDim[order[i]] = reorderedMultiDim[i];
}
return multiDim;
}
SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
@@ -370,13 +345,29 @@ public:
return multiDim;
}
// Emit indices calculation within each ConversionPattern
// TODO: [goostavz] Double confirm the redundant indices calculations will
// be eliminated in the consequent MLIR/LLVM optimization
SmallVector<SmallVector<Value>>
emitIndicesForBlockedLayout(Location loc, ConversionPatternRewriter &b,
const BlockedEncodingAttr &blocked_layout,
ArrayRef<int64_t> shape) const {
Value linearize(ConversionPatternRewriter &rewriter, Location loc,
ArrayRef<Value> multiDim, ArrayRef<unsigned> shape) const {
int rank = multiDim.size();
Value linear = createIndexAttrConstant(
rewriter, loc, this->getTypeConverter()->getIndexType(), 0);
if (rank > 0) {
linear = multiDim.front();
for (auto &&z : llvm::zip(multiDim.drop_front(), shape.drop_front())) {
Value dimSize = createIndexAttrConstant(
rewriter, loc, this->getTypeConverter()->getIndexType(),
std::get<1>(z));
linear = rewriter.create<LLVM::AddOp>(
loc, rewriter.create<LLVM::MulOp>(loc, linear, dimSize),
std::get<0>(z));
}
}
return linear;
}
SmallVector<Value>
emitBaseIndexForBlockedLayout(Location loc, ConversionPatternRewriter &b,
const BlockedEncodingAttr &blocked_layout,
ArrayRef<int64_t> shape) const {
auto llvmIndexTy = this->getTypeConverter()->getIndexType();
auto cast = b.create<UnrealizedConversionCastOp>(
loc, TypeRange{llvmIndexTy},
@@ -391,7 +382,6 @@ public:
auto warpsPerCTA = blocked_layout.getWarpsPerCTA();
auto order = blocked_layout.getOrder();
unsigned rank = shape.size();
SmallVector<Value, 4> threadIds(rank);
// step 1, delinearize threadId to get the base index
SmallVector<Value> multiDimWarpId =
@@ -400,8 +390,19 @@ public:
delinearize(b, loc, laneId, threadsPerWarp, order);
SmallVector<Value> multiDimBase(rank);
for (unsigned k = 0; k < rank; ++k) {
// multiDimBase[k] = (multiDimThreadId[k] + multiDimWarpId[k] *
// threadsPerWarp[k]) *
// Wrap around multiDimWarpId/multiDimThreadId incase
// shape[k] > shapePerCTA[k]
unsigned maxWarps =
ceil<unsigned>(shape[k], sizePerThread[k] * threadsPerWarp[k]);
unsigned maxThreads = ceil<unsigned>(shape[k], sizePerThread[k]);
multiDimWarpId[k] = b.create<LLVM::URemOp>(
loc, multiDimWarpId[k],
createIndexAttrConstant(b, loc, llvmIndexTy, maxWarps));
multiDimThreadId[k] = b.create<LLVM::URemOp>(
loc, multiDimThreadId[k],
createIndexAttrConstant(b, loc, llvmIndexTy, maxThreads));
// multiDimBase[k] = (multiDimThreadId[k] +
// multiDimWarpId[k] * threadsPerWarp[k]) *
// sizePerThread[k];
Value threadsPerWarpK =
createIndexAttrConstant(b, loc, llvmIndexTy, threadsPerWarp[k]);
@@ -413,17 +414,100 @@ public:
loc, multiDimThreadId[k],
b.create<LLVM::MulOp>(loc, multiDimWarpId[k], threadsPerWarpK)));
}
return multiDimBase;
}
SmallVector<SmallVector<Value>> emitIndices(Location loc,
ConversionPatternRewriter &b,
const Attribute &layout,
ArrayRef<int64_t> shape) const {
if (auto blocked = layout.dyn_cast<BlockedEncodingAttr>()) {
return emitIndicesForBlockedLayout(loc, b, blocked, shape);
} else if (auto slice = layout.dyn_cast<SliceEncodingAttr>()) {
return emitIndicesForSliceLayout(loc, b, slice, shape);
} else {
assert(0 && "emitIndices for layouts other than blocked & slice not "
"implemented yet");
return {};
}
}
SmallVector<SmallVector<Value>>
emitIndicesForSliceLayout(Location loc, ConversionPatternRewriter &b,
const SliceEncodingAttr &sliceLayout,
ArrayRef<int64_t> shape) const {
auto parent = sliceLayout.getParent();
unsigned dim = sliceLayout.getDim();
size_t rank = shape.size();
if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
SmallVector<int64_t> paddedShape(rank + 1);
for (unsigned d = 0; d < rank + 1; ++d) {
if (d < dim) {
paddedShape[d] = shape[d];
} else if (d == dim) {
paddedShape[d] = 1;
} else {
paddedShape[d] = shape[d - 1];
}
}
auto paddedIndices =
emitIndicesForBlockedLayout(loc, b, blockedParent, paddedShape);
unsigned numIndices = paddedIndices.size();
SmallVector<SmallVector<Value>> resultIndices(numIndices);
for (unsigned i = 0; i < numIndices; ++i) {
for (unsigned d = 0; d < rank + 1; ++d) {
if (d != dim) {
resultIndices[i].push_back(paddedIndices[i][d]);
}
}
}
return resultIndices;
} else if (auto sliceParent = parent.dyn_cast<SliceEncodingAttr>()) {
assert(0 && "emitIndicesForSliceLayout with parent of sliceLayout"
"is not implemented yet");
return {};
} else {
assert(0 && "emitIndicesForSliceLayout with parent other than blocked & "
"slice not implemented yet");
return {};
}
}
// Emit indices calculation within each ConversionPattern
// TODO: [goostavz] Double confirm the redundant indices calculations will
// be eliminated in the consequent MLIR/LLVM optimization. We might
// implement a indiceCache if necessary.
SmallVector<SmallVector<Value>>
emitIndicesForBlockedLayout(Location loc, ConversionPatternRewriter &b,
const BlockedEncodingAttr &blockedLayout,
ArrayRef<int64_t> shape) const {
auto llvmIndexTy = this->getTypeConverter()->getIndexType();
auto sizePerThread = blockedLayout.getSizePerThread();
auto threadsPerWarp = blockedLayout.getThreadsPerWarp();
auto warpsPerCTA = blockedLayout.getWarpsPerCTA();
unsigned rank = shape.size();
SmallVector<unsigned> shapePerCTA(rank);
for (unsigned k = 0; k < rank; ++k) {
shapePerCTA[k] = sizePerThread[k] * threadsPerWarp[k] * warpsPerCTA[k];
}
// step 1, delinearize threadId to get the base index
auto multiDimBase =
emitBaseIndexForBlockedLayout(loc, b, blockedLayout, shape);
// step 2, get offset of each element
unsigned elemsPerThread = 1;
SmallVector<SmallVector<unsigned>> offset(rank);
SmallVector<unsigned> multiDimElemsPerThread(rank);
for (unsigned k = 0; k < rank; ++k) {
multiDimElemsPerThread[k] = shape[k] / threadsPerWarp[k] / warpsPerCTA[k];
multiDimElemsPerThread[k] =
ceil<unsigned>(shape[k], shapePerCTA[k]) * sizePerThread[k];
elemsPerThread *= multiDimElemsPerThread[k];
// 1 block in minimum if shape[k] is less than shapePerCTA[k]
for (unsigned blockOffset = 0;
blockOffset <
shape[k] / (sizePerThread[k] * threadsPerWarp[k] * warpsPerCTA[k]);
blockOffset < ceil<unsigned>(shape[k], shapePerCTA[k]);
++blockOffset)
for (unsigned warpOffset = 0; warpOffset < warpsPerCTA[k]; ++warpOffset)
for (unsigned threadOffset = 0; threadOffset < threadsPerWarp[k];
@@ -445,7 +529,7 @@ public:
std::multiplies<unsigned>());
SmallVector<unsigned> threadsPerDim(rank);
for (unsigned k = 0; k < rank; ++k) {
threadsPerDim[k] = shape[k] / sizePerThread[k];
threadsPerDim[k] = ceil<unsigned>(shape[k], sizePerThread[k]);
}
for (unsigned n = 0; n < elemsPerThread; ++n) {
unsigned linearNanoTileId = n / accumSizePerThread;
@@ -469,6 +553,20 @@ public:
return multiDimIdx;
}
Value getSharedMemoryBase(Location loc, ConversionPatternRewriter &rewriter,
Value smem, const Allocation *allocation,
Operation *op) const {
auto ptrTy = LLVM::LLVMPointerType::get(
this->getTypeConverter()->convertType(rewriter.getIntegerType(8)), 3);
auto bufferId = allocation->getBufferId(op);
assert(bufferId != Allocation::InvalidBufferId && "BufferId not found");
size_t offset = allocation->getOffset(bufferId);
auto llvmIndexTy = this->getTypeConverter()->getIndexType();
Value offVal = createIndexAttrConstant(rewriter, loc, llvmIndexTy, offset);
Value base = rewriter.create<LLVM::GEPOp>(loc, ptrTy, smem, offVal);
return base;
}
};
// Convert SplatOp or arith::ConstantOp with SplatElementsAttr to a
@@ -482,19 +580,10 @@ Value convertSplatLikeOp(Type elemType, Type resType, Value constVal,
ConversionPatternRewriter &rewriter, Location loc) {
auto tensorTy = resType.cast<RankedTensorType>();
auto layout = tensorTy.getEncoding().cast<BlockedEncodingAttr>();
auto layout = tensorTy.getEncoding();
auto srcType = typeConverter->convertType(elemType);
auto llSrc = rewriter.create<LLVM::BitcastOp>(loc, srcType, constVal);
auto numElems = layout.getSizePerThread();
size_t totalElems =
std::accumulate(tensorTy.getShape().begin(), tensorTy.getShape().end(), 1,
std::multiplies<>{});
size_t numThreads =
product(layout.getWarpsPerCTA()) * product(layout.getThreadsPerWarp());
// TODO(Superjomn) add numElemsPerThread to the layout encodings.
size_t numElemsPerThread = totalElems / numThreads;
size_t numElemsPerThread = getElemsPerThread(layout, tensorTy.getShape());
llvm::SmallVector<Value, 4> elems(numElemsPerThread, llSrc);
llvm::SmallVector<Type, 4> elemTypes(elems.size(), srcType);
auto structTy =
@@ -580,7 +669,7 @@ struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase {
auto shape = ty.getShape();
// Here, we assume that all inputs should have a blockedLayout
unsigned valueElems = getElemsPerThread(layout, shape);
unsigned valueElems = layout.getElemsPerThread(shape);
auto llvmElemTy = typeConverter->convertType(ty.getElementType());
auto llvmElemPtrPtrTy =
@@ -595,16 +684,15 @@ struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase {
auto ty = val.getType().cast<RankedTensorType>();
// Here, we assume that all inputs should have a blockedLayout
auto layout = ty.getEncoding().dyn_cast<BlockedEncodingAttr>();
assert(layout && "unexpected layout in getLayout");
auto shape = ty.getShape();
unsigned valueElems = getElemsPerThread(layout, shape);
unsigned valueElems = layout.getElemsPerThread(shape);
return std::make_tuple(layout, valueElems);
}
unsigned getAlignment(Value val, const BlockedEncodingAttr &layout) const {
auto axisInfo = getAxisInfo(val);
auto order = layout.getOrder();
unsigned maxMultiple = axisInfo->getDivisibility(order[0]);
unsigned maxContig = axisInfo->getContiguity(order[0]);
unsigned alignment = std::min(maxMultiple, maxContig);
@@ -614,22 +702,18 @@ struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase {
unsigned getVectorizeSize(Value ptr,
const BlockedEncodingAttr &layout) const {
auto axisInfo = getAxisInfo(ptr);
auto contig = axisInfo->getContiguity();
// Here order should be ordered by contiguous first, so the first element
// should have the largest contiguous.
auto order = layout.getOrder();
unsigned align = getAlignment(ptr, layout);
auto getTensorShape = [](Value val) -> ArrayRef<int64_t> {
auto ty = val.getType().cast<RankedTensorType>();
auto shape = ty.getShape();
return shape;
};
// unsigned contigPerThread = layout.getSizePerThread()[order[0]];
unsigned contigPerThread = getElemsPerThread(layout, getTensorShape(ptr));
auto ty = ptr.getType().dyn_cast<RankedTensorType>();
assert(ty);
auto shape = ty.getShape();
unsigned contigPerThread = layout.getSizePerThread()[order[0]];
unsigned vec = std::min(align, contigPerThread);
vec = std::min<unsigned>(shape[order[0]], vec);
return vec;
}
@@ -819,25 +903,22 @@ struct BroadcastOpConversion
auto srcShape = srcTy.getShape();
auto resultShape = resultTy.getShape();
unsigned rank = srcTy.getRank();
// TODO: [goostavz] double confirm the op semantics with Phil
assert(rank == resultTy.getRank());
SmallVector<int64_t, 4> srcLogicalShape(2 * rank);
SmallVector<int64_t, 4> resultLogicalShape(2 * rank);
SmallVector<unsigned, 2> broadcastDims;
SmallVector<int64_t, 2> broadcastSizes;
int64_t duplicates = 1;
for (unsigned d = 0; d < rank; ++d) {
int64_t numCtas = resultShape[d] / (resultLayout.getSizePerThread()[d] *
resultLayout.getThreadsPerWarp()[d] *
resultLayout.getWarpsPerCTA()[d]);
unsigned resultShapePerCTA = resultLayout.getSizePerThread()[d] *
resultLayout.getThreadsPerWarp()[d] *
resultLayout.getWarpsPerCTA()[d];
int64_t numCtas = ceil<unsigned>(resultShape[d], resultShapePerCTA);
if (srcShape[d] != resultShape[d]) {
assert(srcShape[d] == 1);
broadcastDims.push_back(d);
broadcastSizes.push_back(resultShape[d]);
srcLogicalShape[d] = 1;
srcLogicalShape[d + rank] = 1;
duplicates *= resultShape[d];
srcLogicalShape[d + rank] =
std::max(unsigned(1), srcLayout.getSizePerThread()[d]);
} else {
srcLogicalShape[d] = numCtas;
srcLogicalShape[d + rank] = resultLayout.getSizePerThread()[d];
@@ -845,18 +926,37 @@ struct BroadcastOpConversion
resultLogicalShape[d] = numCtas;
resultLogicalShape[d + rank] = resultLayout.getSizePerThread()[d];
}
unsigned srcElems = getElemsPerThread(srcLayout, srcShape);
int64_t duplicates = 1;
SmallVector<int64_t, 2> broadcastSizes(broadcastDims.size() * 2);
for (auto it : llvm::enumerate(broadcastDims)) {
// Incase there are multiple indices in the src that is actually
// calculating the same element, srcLogicalShape may not need to be 1.
// Such as the case when src of shape [256, 1], and with a blocked layout:
// sizePerThread: [1, 4]; threadsPerWarp: [1, 32]; warpsPerCTA: [1, 2]
int64_t d = resultLogicalShape[it.value()] / srcLogicalShape[it.value()];
broadcastSizes[it.index()] = d;
duplicates *= d;
d = resultLogicalShape[it.value() + rank] /
srcLogicalShape[it.value() + rank];
broadcastSizes[it.index() + broadcastDims.size()] = d;
duplicates *= d;
}
unsigned srcElems = srcLayout.getElemsPerThread(srcShape);
auto elemTy = resultTy.getElementType();
auto srcVals = getElementsFromStruct(loc, src, srcElems, rewriter);
unsigned resultElems = getElemsPerThread(resultLayout, resultShape);
unsigned resultElems = resultLayout.getElemsPerThread(resultShape);
SmallVector<Value> resultVals(resultElems);
for (unsigned i = 0; i < srcElems; ++i) {
auto srcMultiDim = getMultiDimIndex<int64_t>(i, srcLogicalShape);
auto resultMultiDim = srcMultiDim;
for (int64_t j = 0; j < duplicates; ++j) {
auto resultMultiDim = srcMultiDim;
auto bcastMultiDim = getMultiDimIndex<int64_t>(j, broadcastSizes);
for (auto bcastDim : llvm::enumerate(broadcastDims)) {
resultMultiDim[bcastDim.value()] = bcastMultiDim[bcastDim.index()];
resultMultiDim[bcastDim.value()] += bcastMultiDim[bcastDim.index()];
resultMultiDim[bcastDim.value() + rank] +=
bcastMultiDim[bcastDim.index() + broadcastDims.size()] *
srcLogicalShape[bcastDim.index() + broadcastDims.size()];
}
auto resultLinearIndex =
getLinearIndex<int64_t>(resultMultiDim, resultLogicalShape);
@@ -871,27 +971,29 @@ struct BroadcastOpConversion
}
};
struct ViewOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::ViewOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::ViewOp>::ConvertTritonGPUOpToLLVMPattern;
template <typename SourceOp>
struct ViewLikeOpConversion : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
using OpAdaptor = typename SourceOp::Adaptor;
explicit ViewLikeOpConversion(LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: ConvertTritonGPUOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}
LogicalResult
matchAndRewrite(triton::ViewOp op, OpAdaptor adaptor,
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// We cannot directly
// rewriter.replaceOp(op, adaptor.src());
// due to MLIR's restrictions
Location loc = op->getLoc();
auto resultTy = op.getType().cast<RankedTensorType>();
auto resultLayout = resultTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
auto resultTy = op.getType().template cast<RankedTensorType>();
auto resultShape = resultTy.getShape();
unsigned elems = getElemsPerThread(resultLayout, resultShape);
unsigned elems = getElemsPerThread(resultTy.getEncoding(), resultShape);
Type elemTy =
this->getTypeConverter()->convertType(resultTy.getElementType());
SmallVector<Type> types(elems, elemTy);
Type structTy = LLVM::LLVMStructType::getLiteral(getContext(), types);
auto vals = getElementsFromStruct(loc, adaptor.src(), elems, rewriter);
Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types);
auto vals =
this->getElementsFromStruct(loc, adaptor.src(), elems, rewriter);
Value view = getStructFromElements(loc, vals, rewriter, structTy);
rewriter.replaceOp(op, view);
return success();
@@ -911,12 +1013,12 @@ struct MakeRangeOpConversion
Location loc = op->getLoc();
auto rankedTy = op.result().getType().dyn_cast<RankedTensorType>();
auto shape = rankedTy.getShape();
auto layout = rankedTy.getEncoding().cast<BlockedEncodingAttr>();
auto layout = rankedTy.getEncoding();
auto elemTy = rankedTy.getElementType();
assert(elemTy.isInteger(32));
Value start = createIndexAttrConstant(rewriter, loc, elemTy, op.start());
auto idxs = emitIndicesForBlockedLayout(loc, rewriter, layout, shape);
auto idxs = emitIndices(loc, rewriter, layout, shape);
unsigned elems = idxs.size();
SmallVector<Value> retVals(elems);
for (auto multiDim : llvm::enumerate(idxs)) {
@@ -1184,8 +1286,9 @@ struct AddPtrOpConversion
Location loc = op->getLoc();
auto resultTy = op.getType().dyn_cast<RankedTensorType>();
auto resultLayout = resultTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
assert(resultLayout && "Unexpected resultLayout in AddPtrOpConversion");
auto resultShape = resultTy.getShape();
unsigned elems = getElemsPerThread(resultLayout, resultShape);
unsigned elems = resultLayout.getElemsPerThread(resultShape);
Type elemTy =
this->getTypeConverter()->convertType(resultTy.getElementType());
SmallVector<Type> types(elems, elemTy);
@@ -1225,7 +1328,8 @@ public:
auto resultLayout =
resultTy.getEncoding().template dyn_cast<BlockedEncodingAttr>();
auto resultShape = resultTy.getShape();
unsigned elems = getElemsPerThread(resultLayout, resultShape);
assert(resultLayout && "Unexpected resultLayout in BinaryOpConversion");
unsigned elems = resultLayout.getElemsPerThread(resultShape);
Type elemTy =
this->getTypeConverter()->convertType(resultTy.getElementType());
SmallVector<Type> types(elems, elemTy);
@@ -1244,6 +1348,228 @@ public:
}
};
struct ConvertLayoutOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::gpu::ConvertLayoutOp> {
public:
using ConvertTritonGPUOpToLLVMPattern<
triton::gpu::ConvertLayoutOp>::ConvertTritonGPUOpToLLVMPattern;
ConvertLayoutOpConversion(LLVMTypeConverter &converter,
const Allocation *allocation, Value smem,
PatternBenefit benefit)
: ConvertTritonGPUOpToLLVMPattern<triton::gpu::ConvertLayoutOp>(converter,
benefit),
allocation_(allocation), smem_(smem) {}
LogicalResult
matchAndRewrite(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
Value src = op.src();
Value dst = op.result();
auto srcTy = src.getType().cast<RankedTensorType>();
auto dstTy = dst.getType().cast<RankedTensorType>();
Attribute srcLayout = srcTy.getEncoding();
Attribute dstLayout = dstTy.getEncoding();
if ((!srcLayout.isa<BlockedEncodingAttr>()) ||
(!dstLayout.isa<BlockedEncodingAttr>())) {
// TODO: not implemented
assert(0 &&
"convert_layout except for blocked -> blocked is not implemented");
return failure();
}
auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType());
Value smemBase = getSharedMemoryBase(loc, rewriter, smem_, allocation_,
op.getOperation());
auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
smemBase = rewriter.create<LLVM::BitcastOp>(loc, elemPtrTy, smemBase);
auto shape = dstTy.getShape();
unsigned rank = dstTy.getRank();
auto getContigPerThread = [&](const Attribute &layout,
unsigned d) -> unsigned {
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
return blockedLayout.getSizePerThread()[d];
} else {
assert(0 && "Unimplemented usage of getContigPerThread");
return 0;
}
};
auto getAccumElemsPerThread = [&](const Attribute &layout) -> unsigned {
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
return product<unsigned>(blockedLayout.getSizePerThread());
} else {
assert(0 && "Unimplemented usage of getAccumElemsPerThread");
return 0;
}
};
auto getOrder = [&](const Attribute &layout) -> ArrayRef<unsigned> {
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
return blockedLayout.getOrder();
} else {
assert(0 && "Unimplemented usage of getAccumElemsPerThread");
return {};
}
};
SmallVector<unsigned> numReplicates(rank);
SmallVector<unsigned> inNumCTAsEachRep(rank);
SmallVector<unsigned> outNumCTAsEachRep(rank);
SmallVector<unsigned> inNumCTAs(rank);
SmallVector<unsigned> outNumCTAs(rank);
for (unsigned d = 0; d < rank; ++d) {
unsigned inPerCTA =
std::min(unsigned(shape[d]), getShapePerCTA(srcLayout, d));
unsigned outPerCTA =
std::min(unsigned(shape[d]), getShapePerCTA(dstLayout, d));
unsigned maxPerCTA = std::max(inPerCTA, outPerCTA);
numReplicates[d] = ceil<unsigned>(shape[d], maxPerCTA);
inNumCTAsEachRep[d] = maxPerCTA / inPerCTA;
outNumCTAsEachRep[d] = maxPerCTA / outPerCTA;
// TODO: confirm this
assert(maxPerCTA % inPerCTA == 0 && maxPerCTA % outPerCTA == 0);
inNumCTAs[d] = ceil<unsigned>(shape[d], inPerCTA);
outNumCTAs[d] = ceil<unsigned>(shape[d], outPerCTA);
}
// Potentially we need to store for multiple CTAs in this replication
unsigned accumNumReplicates = product<unsigned>(numReplicates);
unsigned accumInSizePerThread = getAccumElemsPerThread(srcLayout);
unsigned elems = getElemsPerThread(srcLayout, srcTy.getShape());
auto vals = getElementsFromStruct(loc, adaptor.src(), elems, rewriter);
unsigned inVec = 0;
unsigned outVec = 0;
auto paddedRepShape = getScratchConfigForCvtLayout(op, inVec, outVec);
unsigned outElems = getElemsPerThread(dstLayout, shape);
auto outOrd = getOrder(dstLayout);
SmallVector<Value> outVals(outElems);
for (unsigned repId = 0; repId < accumNumReplicates; ++repId) {
auto multiDimRepId = getMultiDimIndex<unsigned>(repId, numReplicates);
rewriter.create<mlir::gpu::BarrierOp>(loc);
if (auto srcBlockedLayout = srcLayout.dyn_cast<BlockedEncodingAttr>()) {
processReplicaBlocked(loc, rewriter, /*stNotRd*/ true, srcTy,
inNumCTAsEachRep, multiDimRepId, inVec,
paddedRepShape, outOrd, vals, smemBase);
} else {
assert(0 && "ConvertLayout with input layout not implemented");
return failure();
}
rewriter.create<mlir::gpu::BarrierOp>(loc);
if (auto dstBlockedLayout = dstLayout.dyn_cast<BlockedEncodingAttr>()) {
processReplicaBlocked(loc, rewriter, /*stNotRd*/ false, dstTy,
outNumCTAsEachRep, multiDimRepId, outVec,
paddedRepShape, outOrd, outVals, smemBase);
} else {
assert(0 && "ConvertLayout with output layout not implemented");
return failure();
}
}
SmallVector<Type> types(outElems, llvmElemTy);
Type structTy = LLVM::LLVMStructType::getLiteral(getContext(), types);
Value result = getStructFromElements(loc, outVals, rewriter, structTy);
rewriter.replaceOp(op, result);
return success();
}
private:
template <typename T>
SmallVector<T> reorder(ArrayRef<T> input, ArrayRef<unsigned> order) const {
size_t rank = order.size();
assert(input.size() == rank);
SmallVector<T> result(rank);
for (auto it : llvm::enumerate(order)) {
result[rank - 1 - it.value()] = input[it.index()];
}
return result;
};
void processReplicaBlocked(Location loc, ConversionPatternRewriter &rewriter,
bool stNotRd, RankedTensorType type,
ArrayRef<unsigned> numCTAsEachRep,
ArrayRef<unsigned> multiDimRepId, unsigned vec,
ArrayRef<unsigned> paddedRepShape,
ArrayRef<unsigned> outOrd,
SmallVector<Value> &vals, Value smemBase) const {
unsigned accumNumCTAsEachRep = product<unsigned>(numCTAsEachRep);
auto layout = type.getEncoding().cast<BlockedEncodingAttr>();
auto rank = type.getRank();
auto sizePerThread = layout.getSizePerThread();
auto accumSizePerThread = product<unsigned>(sizePerThread);
auto llvmIndexTy = getTypeConverter()->getIndexType();
SmallVector<unsigned> numCTAs(rank);
SmallVector<unsigned> shapePerCTA(rank);
for (unsigned d = 0; d < rank; ++d) {
shapePerCTA[d] = layout.getSizePerThread()[d] *
layout.getThreadsPerWarp()[d] *
layout.getWarpsPerCTA()[d];
numCTAs[d] = ceil<unsigned>(type.getShape()[d], shapePerCTA[d]);
}
auto llvmElemTy = getTypeConverter()->convertType(type.getElementType());
auto multiDimOffsetFirstElem =
emitBaseIndexForBlockedLayout(loc, rewriter, layout, type.getShape());
for (unsigned ctaId = 0; ctaId < accumNumCTAsEachRep; ++ctaId) {
auto multiDimCTAInRepId =
getMultiDimIndex<unsigned>(ctaId, numCTAsEachRep);
SmallVector<unsigned> multiDimCTAId(rank);
for (auto it : llvm::enumerate(multiDimCTAInRepId)) {
auto d = it.index();
multiDimCTAId[d] = multiDimRepId[d] * numCTAsEachRep[d] + it.value();
}
unsigned linearCTAId = getLinearIndex<unsigned>(multiDimCTAId, numCTAs);
// TODO: This is actually redundant index calculation, we should
// consider of caching the index calculation result in case
// of performance issue observed.
// for (unsigned elemId = linearCTAId * accumSizePerThread;
// elemId < (linearCTAId + 1) * accumSizePerThread; elemId += vec) {
for (unsigned elemId = 0; elemId < accumSizePerThread; elemId += vec) {
auto multiDimElemId =
getMultiDimIndex<unsigned>(elemId, layout.getSizePerThread());
SmallVector<Value> multiDimOffset(rank);
for (unsigned d = 0; d < rank; ++d) {
multiDimOffset[d] = rewriter.create<LLVM::AddOp>(
loc, multiDimOffsetFirstElem[d],
createIndexAttrConstant(rewriter, loc, llvmIndexTy,
multiDimCTAInRepId[d] * shapePerCTA[d] +
multiDimElemId[d]));
}
Value offset =
linearize(rewriter, loc, reorder<Value>(multiDimOffset, outOrd),
reorder<unsigned>(paddedRepShape, outOrd));
auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
Value ptr =
rewriter.create<LLVM::GEPOp>(loc, elemPtrTy, smemBase, offset);
auto vecTy = VectorType::get(vec, llvmElemTy);
ptr = rewriter.create<LLVM::BitcastOp>(
loc, LLVM::LLVMPointerType::get(vecTy, 3), ptr);
if (stNotRd) {
Value valVec = rewriter.create<LLVM::UndefOp>(loc, vecTy);
for (unsigned v = 0; v < vec; ++v) {
Value vVal = createIndexAttrConstant(
rewriter, loc, getTypeConverter()->getIndexType(), v);
valVec = rewriter.create<LLVM::InsertElementOp>(
loc, vecTy, valVec,
vals[elemId + linearCTAId * accumSizePerThread + v], vVal);
}
rewriter.create<LLVM::StoreOp>(loc, valVec, ptr);
} else {
Value valVec = rewriter.create<LLVM::LoadOp>(loc, ptr);
for (unsigned v = 0; v < vec; ++v) {
Value vVal = createIndexAttrConstant(
rewriter, loc, getTypeConverter()->getIndexType(), v);
vals[elemId + linearCTAId * accumSizePerThread + v] =
rewriter.create<LLVM::ExtractElementOp>(loc, llvmElemTy, valVec,
vVal);
}
}
}
}
}
const Allocation *allocation_;
Value smem_;
};
class TritonGPUToLLVMTypeConverter : public LLVMTypeConverter {
public:
using TypeConverter::convertType;
@@ -1266,9 +1592,10 @@ public:
llvm::Optional<Type> convertTritonTensorType(RankedTensorType type) {
Attribute layout = type.getEncoding();
if (auto blocked_layout = layout.dyn_cast<BlockedEncodingAttr>()) {
if (layout && (layout.isa<BlockedEncodingAttr>() ||
layout.isa<SliceEncodingAttr>())) {
unsigned numElementsPerThread =
getElemsPerThread(blocked_layout, type.getShape());
getElemsPerThread(layout, type.getShape());
SmallVector<Type, 4> types(numElementsPerThread,
convertType(type.getElementType()));
return LLVM::LLVMStructType::getLiteral(&getContext(), types);
@@ -1285,7 +1612,8 @@ public:
void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns, int numWarps,
AxisInfoAnalysis &analysis,
AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem,
PatternBenefit benefit = 1) {
patterns.add<ArithConstantSplatOpConversion>(typeConverter, benefit);
patterns.add<BinaryOpConversion<arith::AddIOp, LLVM::AddOp>>(typeConverter,
@@ -1296,17 +1624,19 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
benefit);
patterns.add<BinaryOpConversion<arith::MulFOp, LLVM::FMulOp>>(typeConverter,
benefit);
patterns.add<BroadcastOpConversion>(typeConverter, benefit);
patterns.add<FuncOpConversion>(typeConverter, numWarps, benefit);
patterns.add<AddPtrOpConversion>(typeConverter, benefit);
patterns.add<ConvertLayoutOpConversion>(typeConverter, allocation, smem,
benefit);
patterns.add<GetProgramIdOpConversion>(typeConverter, benefit);
patterns.add<LoadOpConversion>(typeConverter, analysis, benefit);
patterns.add<LoadOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<MakeRangeOpConversion>(typeConverter, benefit);
patterns.add<ReturnOpConversion>(typeConverter, benefit);
patterns.add<SplatOpConversion>(typeConverter, benefit);
patterns.add<StoreOpConversion>(typeConverter, analysis, benefit);
patterns.add<ViewOpConversion>(typeConverter, benefit);
patterns.add<StoreOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<ViewLikeOpConversion<triton::ViewOp>>(typeConverter, benefit);
patterns.add<ViewLikeOpConversion<triton::ExpandDimsOp>>(typeConverter,
benefit);
}
class ConvertTritonGPUToLLVM
@@ -1322,19 +1652,34 @@ public:
// TODO: need confirm
option.overrideIndexBitwidth(32);
TritonGPUToLLVMTypeConverter typeConverter(context, option);
TritonLLVMFunctionConversionTarget funcTarget(*context, typeConverter);
TritonLLVMConversionTarget target(*context, typeConverter);
RewritePatternSet patterns(context);
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
// step 1: Convert FuncOp to LLVMFuncOp via partial conversion
// step 2: Allocate for shared memories
// step 3: Convert the rest of ops via partial conversion
// The reason for a seperation between 1/3 is that, step 2 is out of
// the scope of Dialect Conversion, thus we need to make sure the smem_
// is not revised during the conversion of step 3.
RewritePatternSet func_patterns(context);
func_patterns.add<FuncOpConversion>(typeConverter, numWarps, 1 /*benefit*/);
if (failed(
applyPartialConversion(mod, funcTarget, std::move(func_patterns))))
return signalPassFailure();
Allocation allocation(mod);
auto axisAnalysis = runAxisAnalysis(mod);
initSharedMemory(allocation.getSharedMemorySize(), typeConverter);
// We set a higher benefit here to ensure triton's patterns runs before
// arith patterns for some encoding not supported by the community
// patterns.
RewritePatternSet patterns(context);
populateTritonToLLVMPatterns(typeConverter, patterns, numWarps,
*axisAnalysis, 10 /*benefit*/);
*axisAnalysis, &allocation, smem_,
10 /*benefit*/);
// Add arith/math's patterns to help convert scalar expression to LLVM.
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
@@ -1352,10 +1697,35 @@ protected:
auto axisAnalysisPass =
std::make_unique<AxisInfoAnalysis>(module->getContext());
axisAnalysisPass->run(module);
return axisAnalysisPass;
}
void initSharedMemory(size_t size,
TritonGPUToLLVMTypeConverter &typeConverter);
Value smem_;
};
void ConvertTritonGPUToLLVM::initSharedMemory(
size_t size, TritonGPUToLLVMTypeConverter &typeConverter) {
ModuleOp mod = getOperation();
OpBuilder b(mod.getBodyRegion());
auto loc = mod.getLoc();
auto elemTy = typeConverter.convertType(b.getIntegerType(8));
auto arrayTy = LLVM::LLVMArrayType::get(elemTy, size);
auto global = b.create<LLVM::GlobalOp>(
loc, arrayTy, /*isConstant=*/false, LLVM::Linkage::Internal,
"global_smem", /*value=*/Attribute(),
/*alignment=*/0, mlir::gpu::GPUDialect::getWorkgroupAddressSpace());
SmallVector<LLVM::LLVMFuncOp> funcs;
mod.walk([&](LLVM::LLVMFuncOp func) { funcs.push_back(func); });
assert(funcs.size() == 1 &&
"Inliner pass is expected before TritonGPUToLLVM");
b.setInsertionPointToStart(&funcs[0].getBody().front());
smem_ = b.create<LLVM::AddressOfOp>(loc, global);
}
} // namespace
namespace mlir {
@@ -1366,10 +1736,20 @@ TritonLLVMConversionTarget::TritonLLVMConversionTarget(
addLegalDialect<LLVM::LLVMDialect>();
addLegalDialect<NVVM::NVVMDialect>();
// addIllegalDialect<triton::TritonDialect>();
// addIllegalDialect<triton::gpu::TritonGPUDialect>();
addIllegalDialect<mlir::gpu::GPUDialect>();
addLegalOp<mlir::UnrealizedConversionCastOp>();
}
TritonLLVMFunctionConversionTarget::TritonLLVMFunctionConversionTarget(
MLIRContext &ctx, mlir::LLVMTypeConverter &typeConverter)
: ConversionTarget(ctx), typeConverter(typeConverter) {
addLegalDialect<LLVM::LLVMDialect>();
// addLegalDialect<NVVM::NVVMDialect>();
addIllegalOp<mlir::FuncOp>();
addLegalOp<mlir::UnrealizedConversionCastOp>();
}
namespace triton {
std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonGPUToLLVMPass() {