[Triton-MLIR][Backend] Add ReduceOpConversion into TritonGPUToLLVM conversion (#774)
What is done in this PR:

- [x] Add `ConvertLayout`, `getSizePerThread` and `getShapePerCTA` implementations for `SliceEncodingAttr`
- [x] Split `emitIndices` into two phases: `emitBaseIndexForBlockedLayout` and `emitOffsetForBlockedLayout`
- [x] Add `ReduceOpConversion::matchAndRewriteBasic` implementation
- [x] Add `ReduceOpConversion::matchAndRewriteFast` implementation using the PTX instruction `shfl.sync`
- [x] Add support for scalar values in `StoreOpConversion`
- [x] Add Reduce1d and Reduce2d unit tests; all unit tests pass

Co-authored-by: Qingyi Liu <liuqingyi1993@gmail.com>
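For orientation, the kind of kernel this conversion now lowers end-to-end is a Triton reduction along one axis, in the spirit of the new unit tests added below. The following is an illustrative sketch only; the kernel name, shapes, and launch parameters are made up and not part of this PR:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def row_sum_kernel(x_ptr, z_ptr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # Load a BLOCK_M x BLOCK_N tile and reduce along axis 1, the
    # fastest-changing dimension, which takes the shfl.sync fast path.
    range_m = tl.arange(0, BLOCK_M)
    range_n = tl.arange(0, BLOCK_N)
    x = tl.load(x_ptr + range_m[:, None] * BLOCK_N + range_n[None, :])
    tl.store(z_ptr + range_m, tl.sum(x, axis=1))


x = torch.randn((4, 128), device='cuda')
z = torch.empty((4,), device='cuda')
row_sum_kernel[(1,)](x, z, BLOCK_M=4, BLOCK_N=128)
```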
@@ -6,6 +6,7 @@
 #include "llvm/ADT/SetVector.h"
 #include "llvm/Support/raw_ostream.h"
 
+#include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include <atomic>
 #include <limits>
@@ -19,6 +20,8 @@ SmallVector<unsigned>
 getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
 unsigned &outVec);
 
+SmallVector<unsigned> getScratchConfigForReduce(triton::ReduceOp op);
+
 } // namespace triton
 
 /// Modified from llvm-15.0: llvm/ADT/AddressRanges.h
@@ -250,6 +250,12 @@ struct PTXIOInstr : public PTXInstrBase<PTXIOInstr> {
 return *this;
 }
 
+// Add ".shared" suffix to instruction
+PTXIOInstr &shared(bool predicate = true) {
+o("shared", predicate);
+return *this;
+}
+
 // Add ".v" suffix to instruction
 PTXIOInstr &v(int vecWidth, bool predicate = true) {
 if (vecWidth > 1) {
@@ -324,7 +324,9 @@ def SliceEncodingAttr : DistributedEncoding<"SliceEncoding"> {
 "Attribute":$parent
 );
 
-let extraClassDeclaration = extraBaseClassDeclaration;
+let extraClassDeclaration = extraBaseClassDeclaration # [{
+SmallVector<int64_t> paddedShape(ArrayRef<int64_t> shape) const;
+}];
 }
 
 
@@ -14,6 +14,7 @@ using ::mlir::triton::gpu::BlockedEncodingAttr;
 using ::mlir::triton::gpu::getShapePerCTA;
 using ::mlir::triton::gpu::MmaEncodingAttr;
 using ::mlir::triton::gpu::SharedEncodingAttr;
+using ::mlir::triton::gpu::SliceEncodingAttr;
 
 namespace mlir {
 
@@ -33,6 +34,10 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
 "Unexpect layout in getScratchConfigForCvtLayout()");
 unsigned rank = dstTy.getRank();
 SmallVector<unsigned> paddedRepShape(rank);
+if (auto srcSliceLayout = srcLayout.dyn_cast<SliceEncodingAttr>())
+srcLayout = srcSliceLayout.getParent();
+if (auto dstSliceLayout = dstLayout.dyn_cast<SliceEncodingAttr>())
+dstLayout = dstSliceLayout.getParent();
 auto srcBlockedLayout = srcLayout.dyn_cast<BlockedEncodingAttr>();
 auto srcMmaLayout = srcLayout.dyn_cast<MmaEncodingAttr>();
 auto dstBlockedLayout = dstLayout.dyn_cast<BlockedEncodingAttr>();
@@ -73,6 +78,31 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
 return paddedRepShape;
 }
 
+SmallVector<unsigned> getScratchConfigForReduce(triton::ReduceOp op) {
+auto srcTy = op.operand().getType().cast<RankedTensorType>();
+auto srcLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
+auto srcShape = srcTy.getShape();
+auto rank = srcShape.size();
+auto axis = op.axis();
+
+bool fast_reduce = axis == 1; // FIXME(Qingyi): The fastest-changing dimension
+
+SmallVector<unsigned> smemShape;
+for (auto d : srcShape)
+smemShape.push_back(d);
+
+if (fast_reduce) {
+unsigned sizeInterWarps = srcLayout.getWarpsPerCTA()[axis];
+smemShape[axis] = sizeInterWarps;
+} else {
+unsigned threadsPerCTAAxis =
+srcLayout.getThreadsPerWarp()[axis] * srcLayout.getWarpsPerCTA()[axis];
+smemShape[axis] = threadsPerCTAAxis;
+}
+
+return smemShape;
+}
+
 class AllocationAnalysis {
 public:
 AllocationAnalysis(Operation *operation, Allocation *allocation)
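A quick way to read `getScratchConfigForReduce` above: the scratch buffer keeps the source shape except along the reduced axis, which shrinks to one slot per participating warp on the fast path and one slot per participating thread otherwise. A small Python mirror with a made-up layout (the threadsPerWarp/warpsPerCTA values here are illustrative, not taken from the PR):

```python
def scratch_shape_for_reduce(src_shape, axis, threads_per_warp, warps_per_cta):
    # Mirror of getScratchConfigForReduce: only the reduced axis changes.
    smem_shape = list(src_shape)
    if axis == 1:  # fast path: warp shuffles cover the intra-warp reduction
        smem_shape[axis] = warps_per_cta[axis]
    else:          # basic path: every thread along the axis writes a slot
        smem_shape[axis] = threads_per_warp[axis] * warps_per_cta[axis]
    return smem_shape


# e.g. a 32x128 tile with threadsPerWarp = [1, 32] and warpsPerCTA = [1, 4]:
print(scratch_shape_for_reduce([32, 128], 1, [1, 32], [1, 4]))  # [32, 4]
print(scratch_shape_for_reduce([32, 128], 0, [1, 32], [1, 4]))  # [1, 128]
```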
@@ -127,9 +157,16 @@ private:
 // TODO(Keren): Reduce with index is not supported yet.
 auto value = op->getOperand(0);
 if (auto tensorType = value.getType().dyn_cast<RankedTensorType>()) {
-auto bytes = tensorType.getNumElements() *
-tensorType.getElementTypeBitWidth() / 8;
-allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
+if (tensorType.getEncoding().isa<BlockedEncodingAttr>()) {
+auto smemShape = getScratchConfigForReduce(reduceOp);
+unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(),
+1, std::multiplies{});
+auto bytes = elems * tensorType.getElementTypeBitWidth() / 8;
+allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
+} else {
+assert(0 && "ReduceOp with input layout other than blocked layout is "
+"not implemented yet");
+}
 }
 } else if (auto cvtLayout = dyn_cast<triton::gpu::ConvertLayoutOp>(op)) {
 auto srcTy = cvtLayout.src().getType().cast<RankedTensorType>();
@@ -76,7 +76,15 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
 #define udiv(...) rewriter.create<LLVM::UDivOp>(loc, __VA_ARGS__)
 #define urem(...) rewriter.create<LLVM::URemOp>(loc, __VA_ARGS__)
 #define add(...) rewriter.create<LLVM::AddOp>(loc, __VA_ARGS__)
+#define fadd(...) rewriter.create<LLVM::FAddOp>(loc, __VA_ARGS__)
 #define mul(...) rewriter.create<LLVM::MulOp>(loc, __VA_ARGS__)
+#define smax(...) rewriter.create<LLVM::SMaxOp>(loc, __VA_ARGS__)
+#define umax(...) rewriter.create<LLVM::UMaxOp>(loc, __VA_ARGS__)
+#define fmax(...) rewriter.create<LLVM::MaxNumOp>(loc, __VA_ARGS__)
+#define smin(...) rewriter.create<LLVM::SMinOp>(loc, __VA_ARGS__)
+#define umin(...) rewriter.create<LLVM::UMinOp>(loc, __VA_ARGS__)
+#define fmin(...) rewriter.create<LLVM::MinNumOp>(loc, __VA_ARGS__)
+#define and_(...) rewriter.create<LLVM::AndOp>(loc, __VA_ARGS__)
 #define xor_(...) rewriter.create<LLVM::XOrOp>(loc, __VA_ARGS__)
 #define bitcast(...) rewriter.create<LLVM::BitcastOp>(loc, __VA_ARGS__)
 #define gep(...) rewriter.create<LLVM::GEPOp>(loc, __VA_ARGS__)
@@ -89,11 +97,16 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
 rewriter.create<LLVM::ExtractElementOp>(loc, __VA_ARGS__)
 #define load(...) rewriter.create<LLVM::LoadOp>(loc, __VA_ARGS__)
 #define store(val, ptr) rewriter.create<LLVM::StoreOp>(loc, val, ptr)
+#define icmp_eq(...) \
+rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, __VA_ARGS__)
+#define icmp_slt(...) \
+rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::slt, __VA_ARGS__)
 #define select(...) rewriter.create<LLVM::SelectOp>(loc, __VA_ARGS__)
 #define address_of(...) rewriter.create<LLVM::AddressOfOp>(loc, __VA_ARGS__)
 #define barrier() rewriter.create<mlir::gpu::BarrierOp>(loc)
 #define undef(...) rewriter.create<LLVM::UndefOp>(loc, __VA_ARGS__)
 #define i32_ty rewriter.getIntegerType(32)
+#define f32_ty rewriter.getF32Type()
 #define vec_ty(type, num) VectorType::get(num, type)
 #define void_ty LLVM::LLVMVoidType::get(ctx)
 #define struct_ty(...) LLVM::LLVMStructType::getLiteral(__VA_ARGS__)
@@ -336,6 +349,20 @@ static T getLinearIndex(ArrayRef<T> multiDimIndex, ArrayRef<T> shape) {
 return linearIndex;
 }
 
+Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr,
+Value val, Value pred) {
+MLIRContext *ctx = rewriter.getContext();
+unsigned bits = val.getType().getIntOrFloatBitWidth();
+const char *c = bits == 64 ? "l" : (bits == 16 ? "h" : "r");
+
+PTXBuilder builder;
+auto &st = builder.create<PTXIOInstr>("st")->shared().b(bits);
+auto *ptrOpr = builder.newAddrOperand(ptr, "r");
+auto *valOpr = builder.newOperand(val, c);
+st(ptrOpr, valOpr).predicate(pred, "b");
+return builder.launch(rewriter, loc, void_ty);
+}
+
 struct ConvertTritonGPUOpToLLVMPatternBase {
 static SmallVector<Value>
 getElementsFromStruct(Location loc, Value llvmStruct,
@@ -504,17 +531,8 @@ public:
 unsigned dim = sliceLayout.getDim();
 size_t rank = shape.size();
 if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
-SmallVector<int64_t> paddedShape(rank + 1);
-for (unsigned d = 0; d < rank + 1; ++d) {
-if (d < dim)
-paddedShape[d] = shape[d];
-else if (d == dim)
-paddedShape[d] = 1;
-else
-paddedShape[d] = shape[d - 1];
-}
 auto paddedIndices = emitIndicesForBlockedLayout(
-loc, rewriter, blockedParent, paddedShape);
+loc, rewriter, blockedParent, sliceLayout.paddedShape(shape));
 unsigned numIndices = paddedIndices.size();
 SmallVector<SmallVector<Value>> resultIndices(numIndices);
 for (unsigned i = 0; i < numIndices; ++i)
@@ -536,31 +554,19 @@ public:
 }
 }
 
-// Emit indices calculation within each ConversionPattern, and returns a
-// [elemsPerThread X rank] index matrix.
-// TODO: [goostavz] Double confirm the redundant indices calculations will
-// be eliminated in the consequent MLIR/LLVM optimization. We might
-// implement a indiceCache if necessary.
-SmallVector<SmallVector<Value>>
-emitIndicesForBlockedLayout(Location loc, ConversionPatternRewriter &rewriter,
-const BlockedEncodingAttr &blockedLayout,
-ArrayRef<int64_t> shape) const {
-auto llvmIndexTy = this->getTypeConverter()->getIndexType();
+SmallVector<SmallVector<unsigned>>
+emitOffsetForBlockedLayout(const BlockedEncodingAttr &blockedLayout,
+ArrayRef<int64_t> shape) const {
 auto sizePerThread = blockedLayout.getSizePerThread();
 auto threadsPerWarp = blockedLayout.getThreadsPerWarp();
 auto warpsPerCTA = blockedLayout.getWarpsPerCTA();
 
 unsigned rank = shape.size();
 SmallVector<unsigned> shapePerCTA = getShapePerCTA(blockedLayout);
 SmallVector<unsigned> tilesPerDim(rank);
 for (unsigned k = 0; k < rank; ++k)
 tilesPerDim[k] = ceil<unsigned>(shape[k], shapePerCTA[k]);
 
-// step 1, delinearize threadId to get the base index
-auto multiDimBase =
-emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, shape);
-
-// step 2, get offset of each element
-unsigned elemsPerThread = blockedLayout.getElemsPerThread(shape);
 SmallVector<SmallVector<unsigned>> offset(rank);
 for (unsigned k = 0; k < rank; ++k) {
 // 1 block in minimum if shape[k] is less than shapePerCTA[k]
@@ -577,12 +583,10 @@ public:
 threadsPerWarp[k] +
 threadOffset * sizePerThread[k] + elemOffset);
 }
-// step 3, add offset to base, and reorder the sequence of indices to
-// guarantee that elems in the same sizePerThread are adjacent in order
-SmallVector<SmallVector<Value>> multiDimIdx(elemsPerThread,
-SmallVector<Value>(rank));
-unsigned totalSizePerThread = product<unsigned>(sizePerThread);
 
+unsigned elemsPerThread = blockedLayout.getElemsPerThread(shape);
+unsigned totalSizePerThread = product<unsigned>(sizePerThread);
+SmallVector<SmallVector<unsigned>> reorderedOffset(elemsPerThread);
 for (unsigned n = 0; n < elemsPerThread; ++n) {
 unsigned linearNanoTileId = n / totalSizePerThread;
 unsigned linearNanoTileElemId = n % totalSizePerThread;
@@ -595,10 +599,38 @@ public:
 multiDimNanoTileId[k] *
 (sizePerThread[k] * threadsPerWarp[k] * warpsPerCTA[k]) +
 multiDimNanoTileElemId[k];
-multiDimIdx[n][k] =
-add(multiDimBase[k], idx_val(offset[k][reorderedMultiDimId]));
+reorderedOffset[n].push_back(offset[k][reorderedMultiDimId]);
 }
 }
+return reorderedOffset;
+}
+
+// Emit indices calculation within each ConversionPattern, and returns a
+// [elemsPerThread X rank] index matrix.
+// TODO: [goostavz] Double confirm the redundant indices calculations will
+// be eliminated in the consequent MLIR/LLVM optimization. We might
+// implement a indiceCache if necessary.
+SmallVector<SmallVector<Value>>
+emitIndicesForBlockedLayout(Location loc, ConversionPatternRewriter &rewriter,
+const BlockedEncodingAttr &blockedLayout,
+ArrayRef<int64_t> shape) const {
+// step 1, delinearize threadId to get the base index
+auto multiDimBase =
+emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, shape);
+
+// step 2, get offset of each element
+SmallVector<SmallVector<unsigned>> offset =
+emitOffsetForBlockedLayout(blockedLayout, shape);
+
+// step 3, add offset to base, and reorder the sequence of indices to
+// guarantee that elems in the same sizePerThread are adjacent in order
+unsigned rank = shape.size();
+unsigned elemsPerThread = offset.size();
+SmallVector<SmallVector<Value>> multiDimIdx(elemsPerThread,
+SmallVector<Value>(rank));
+for (unsigned n = 0; n < elemsPerThread; ++n)
+for (unsigned k = 0; k < rank; ++k)
+multiDimIdx[n][k] = add(multiDimBase[k], idx_val(offset[n][k]));
+
 return multiDimIdx;
 }
@@ -1027,8 +1059,12 @@ struct StoreOpConversion
 MLIRContext *ctx = rewriter.getContext();
 
 auto valueTy = value.getType().dyn_cast<RankedTensorType>();
-if (!valueTy)
-return failure();
+if (!valueTy) {
+store(llValue, llPtr);
+rewriter.eraseOp(op);
+return success();
+}
+
 Type valueElemTy =
 getTypeConverter()->convertType(valueTy.getElementType());
 
@@ -1225,6 +1261,360 @@ struct BroadcastOpConversion
 }
 };
 
+/// ====================== reduce codegen begin ==========================
+
+struct ReduceOpConversion
+: public ConvertTritonGPUOpToLLVMPattern<triton::ReduceOp> {
+public:
+using ConvertTritonGPUOpToLLVMPattern<
+triton::ReduceOp>::ConvertTritonGPUOpToLLVMPattern;
+
+LogicalResult
+matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
+ConversionPatternRewriter &rewriter) const override;
+
+private:
+void accumulate(ConversionPatternRewriter &rewriter, Location loc,
+RedOp redOp, Value &acc, Value cur, bool isFirst) const;
+
+Value shflSync(ConversionPatternRewriter &rewriter, Location loc, Value val,
+int i) const;
+
+// Use shared memory for reduction within warps and across warps
+LogicalResult matchAndRewriteBasic(triton::ReduceOp op, OpAdaptor adaptor,
+ConversionPatternRewriter &rewriter) const;
+
+// Use warp shuffle for reduction within warps and shared memory for data
+// exchange across warps
+LogicalResult matchAndRewriteFast(triton::ReduceOp op, OpAdaptor adaptor,
+ConversionPatternRewriter &rewriter) const;
+};
+
+LogicalResult
+ReduceOpConversion::matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
+ConversionPatternRewriter &rewriter) const {
+auto srcTy = op.operand().getType().cast<RankedTensorType>();
+auto rank = srcTy.getShape().size();
+if (op.axis() == 1) // FIXME(Qingyi): The fastest-changing dimension
+return matchAndRewriteFast(op, adaptor, rewriter);
+return matchAndRewriteBasic(op, adaptor, rewriter);
+}
+
+void ReduceOpConversion::accumulate(ConversionPatternRewriter &rewriter,
+Location loc, RedOp redOp, Value &acc,
+Value cur, bool isFirst) const {
+if (isFirst) {
+acc = cur;
+return;
+}
+auto type = cur.getType();
+switch (redOp) {
+case RedOp::ADD:
+acc = add(acc, cur);
+break;
+case RedOp::MAX:
+if (type.isUnsignedInteger())
+acc = umax(acc, cur);
+else
+acc = smax(acc, cur);
+break;
+case RedOp::MIN:
+if (type.isUnsignedInteger())
+acc = umin(acc, cur);
+else
+acc = smin(acc, cur);
+break;
+case RedOp::FADD:
+acc = fadd(acc.getType(), acc, cur);
+break;
+case RedOp::FMAX:
+acc = fmax(acc, cur);
+break;
+case RedOp::FMIN:
+acc = fmin(acc, cur);
+break;
+case RedOp::XOR:
+acc = xor_(acc, cur);
+break;
+default:
+llvm::report_fatal_error("Unsupported reduce op");
+}
+};
+
+Value ReduceOpConversion::shflSync(ConversionPatternRewriter &rewriter,
+Location loc, Value val, int i) const {
+MLIRContext *ctx = rewriter.getContext();
+unsigned bits = val.getType().getIntOrFloatBitWidth();
+
+if (bits == 64) {
+Type vecTy = vec_ty(f32_ty, 2);
+Value vec = bitcast(vecTy, val);
+Value val0 = extract_element(f32_ty, vec, i32_val(0));
+Value val1 = extract_element(f32_ty, vec, i32_val(1));
+val0 = shflSync(rewriter, loc, val0, i);
+val1 = shflSync(rewriter, loc, val1, i);
+vec = undef(vecTy);
+vec = insert_element(vecTy, vec, val0, i32_val(0));
+vec = insert_element(vecTy, vec, val1, i32_val(1));
+return bitcast(val.getType(), vec);
+}
+
+PTXBuilder builder;
+auto &shfl = builder.create("shfl.sync")->o("bfly").o("b32");
+auto *dOpr = builder.newOperand("=r");
+auto *aOpr = builder.newOperand(val, "r");
+auto *bOpr = builder.newConstantOperand(i);
+auto *cOpr = builder.newConstantOperand("0x1f");
+auto *maskOpr = builder.newConstantOperand("0xffffffff");
+shfl(dOpr, aOpr, bOpr, cOpr, maskOpr);
+return builder.launch(rewriter, loc, val.getType(), false);
+}
+
+LogicalResult ReduceOpConversion::matchAndRewriteBasic(
+triton::ReduceOp op, OpAdaptor adaptor,
+ConversionPatternRewriter &rewriter) const {
+Location loc = op->getLoc();
+unsigned axis = op.axis();
+
+auto srcTy = op.operand().getType().cast<RankedTensorType>();
+auto srcLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
+auto srcShape = srcTy.getShape();
+
+auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
+auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
+Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
+smemBase = bitcast(elemPtrTy, smemBase);
+
+auto smemShape = getScratchConfigForReduce(op);
+
+unsigned srcElems = getElemsPerThread(srcLayout, srcShape);
+auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcShape);
+auto srcValues = getElementsFromStruct(loc, adaptor.operand(), rewriter);
+
+SmallVector<SmallVector<unsigned>> offset =
+emitOffsetForBlockedLayout(srcLayout, srcShape);
+
+std::map<SmallVector<unsigned>, Value> accs;
+std::map<SmallVector<unsigned>, SmallVector<Value>> indices;
+
+// reduce within threads
+for (unsigned i = 0; i < srcElems; ++i) {
+SmallVector<unsigned> key = offset[i];
+key[axis] = 0;
+bool isFirst = accs.find(key) == accs.end();
+accumulate(rewriter, loc, op.redOp(), accs[key], srcValues[i], isFirst);
+if (isFirst)
+indices[key] = srcIndices[i];
+}
+
+// cached int32 constants
+std::map<int, Value> ints;
+ints[0] = i32_val(0);
+for (int N = smemShape[axis] / 2; N > 0; N >>= 1)
+ints[N] = i32_val(N);
+Value sizePerThread = i32_val(srcLayout.getSizePerThread()[axis]);
+
+// reduce across threads
+for (auto it : accs) {
+const SmallVector<unsigned> &key = it.first;
+Value acc = it.second;
+SmallVector<Value> writeIdx = indices[key];
+
+writeIdx[axis] = udiv(writeIdx[axis], sizePerThread);
+Value writeOffset = linearize(rewriter, loc, writeIdx, smemShape);
+Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
+store(acc, writePtr);
+
+SmallVector<Value> readIdx(writeIdx.size(), ints[0]);
+for (int N = smemShape[axis] / 2; N > 0; N >>= 1) {
+readIdx[axis] = ints[N];
+Value readMask = icmp_slt(writeIdx[axis], ints[N]);
+Value readOffset = select(
+readMask, linearize(rewriter, loc, readIdx, smemShape), ints[0]);
+Value readPtr = gep(elemPtrTy, writePtr, readOffset);
+barrier();
+accumulate(rewriter, loc, op.redOp(), acc, load(readPtr), false);
+store(acc, writePtr);
+}
+}
+
+// set output values
+if (auto resultTy = op.getType().dyn_cast<RankedTensorType>()) {
+// nd-tensor where n >= 1
+auto resultLayout = resultTy.getEncoding();
+auto resultShape = resultTy.getShape();
+
+unsigned resultElems = getElemsPerThread(resultLayout, resultShape);
+auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultShape);
+assert(resultIndices.size() == resultElems);
+
+barrier();
+SmallVector<Value> resultVals(resultElems);
+for (int i = 0; i < resultElems; i++) {
+SmallVector<Value> readIdx = resultIndices[i];
+readIdx.insert(readIdx.begin() + axis, ints[0]);
+Value readOffset = linearize(rewriter, loc, readIdx, smemShape);
+Value readPtr = gep(elemPtrTy, smemBase, readOffset);
+resultVals[i] = load(readPtr);
+}
+
+SmallVector<Type> resultTypes(resultElems, llvmElemTy);
+Type structTy =
+LLVM::LLVMStructType::getLiteral(this->getContext(), resultTypes);
+Value ret = getStructFromElements(loc, resultVals, rewriter, structTy);
+rewriter.replaceOp(op, ret);
+} else {
+// 0d-tensor -> scalar
+barrier();
+Value resultVal = load(smemBase);
+rewriter.replaceOp(op, resultVal);
+}
+
+return success();
+}
+
+LogicalResult ReduceOpConversion::matchAndRewriteFast(
+triton::ReduceOp op, OpAdaptor adaptor,
+ConversionPatternRewriter &rewriter) const {
+Location loc = op->getLoc();
+unsigned axis = adaptor.axis();
+
+auto srcTy = op.operand().getType().cast<RankedTensorType>();
+auto srcLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
+auto srcShape = srcTy.getShape();
+auto srcOrder = srcLayout.getOrder();
+
+auto threadsPerWarp = srcLayout.getThreadsPerWarp();
+auto warpsPerCTA = srcLayout.getWarpsPerCTA();
+
+auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
+auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
+Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
+smemBase = bitcast(elemPtrTy, smemBase);
+
+auto order = srcLayout.getOrder();
+unsigned sizeIntraWarps = threadsPerWarp[axis];
+unsigned sizeInterWarps = warpsPerCTA[axis];
+
+unsigned srcElems = getElemsPerThread(srcLayout, srcShape);
+auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcShape);
+auto srcValues = getElementsFromStruct(loc, adaptor.operand(), rewriter);
+
+SmallVector<SmallVector<unsigned>> offset =
+emitOffsetForBlockedLayout(srcLayout, srcShape);
+
+std::map<SmallVector<unsigned>, Value> accs;
+std::map<SmallVector<unsigned>, SmallVector<Value>> indices;
+
+auto smemShape = getScratchConfigForReduce(op);
+
+// reduce within threads
+for (unsigned i = 0; i < srcElems; ++i) {
+SmallVector<unsigned> key = offset[i];
+key[axis] = 0;
+bool isFirst = accs.find(key) == accs.end();
+accumulate(rewriter, loc, op.redOp(), accs[key], srcValues[i], isFirst);
+if (isFirst)
+indices[key] = srcIndices[i];
+}
+
+Value threadId = getThreadId(rewriter, loc);
+Value warpSize = i32_val(32);
+Value warpId = udiv(threadId, warpSize);
+Value laneId = urem(threadId, warpSize);
+
+SmallVector<Value> multiDimLaneId =
+delinearize(rewriter, loc, laneId, threadsPerWarp, order);
+SmallVector<Value> multiDimWarpId =
+delinearize(rewriter, loc, warpId, warpsPerCTA, order);
+Value laneIdAxis = multiDimLaneId[axis];
+Value warpIdAxis = multiDimWarpId[axis];
+
+Value zero = i32_val(0);
+Value laneZero = icmp_eq(laneIdAxis, zero);
+Value warpZero = icmp_eq(warpIdAxis, zero);
+
+for (auto it : accs) {
+const SmallVector<unsigned> &key = it.first;
+Value acc = it.second;
+
+// reduce within warps
+for (unsigned N = sizeIntraWarps / 2; N > 0; N >>= 1) {
+Value shfl = shflSync(rewriter, loc, acc, N);
+accumulate(rewriter, loc, op.redOp(), acc, shfl, false);
+}
+
+if (sizeInterWarps == 1) {
+SmallVector<Value> writeIdx = indices[key];
+writeIdx[axis] = zero;
+Value writeOffset = linearize(rewriter, loc, writeIdx, smemShape);
+Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
+storeShared(rewriter, loc, writePtr, acc, laneZero);
+} else {
+SmallVector<Value> writeIdx = indices[key];
+writeIdx[axis] =
+warpIdAxis; // axis must be the fastest-changing dimension
+Value writeOffset = linearize(rewriter, loc, writeIdx, smemShape);
+Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
+storeShared(rewriter, loc, writePtr, acc, laneZero);
+barrier();
+
+SmallVector<Value> readIdx = writeIdx;
+readIdx[axis] = urem(laneId, i32_val(sizeInterWarps));
+Value readOffset = linearize(rewriter, loc, readIdx, smemShape);
+Value readPtr = gep(elemPtrTy, smemBase, readOffset);
+acc = load(readPtr);
+
+// reduce across warps
+for (unsigned N = sizeInterWarps / 2; N > 0; N >>= 1) {
+Value shfl = shflSync(rewriter, loc, acc, N);
+accumulate(rewriter, loc, op.redOp(), acc, shfl, false);
+}
+
+writeIdx[axis] = zero;
+writeOffset = linearize(rewriter, loc, writeIdx, smemShape);
+writePtr = gep(elemPtrTy, smemBase, writeOffset);
+storeShared(rewriter, loc, writePtr, acc, and_(laneZero, warpZero));
+}
+}
+
+// set output values
+if (auto resultTy = op.getType().dyn_cast<RankedTensorType>()) {
+// nd-tensor where n >= 1
+auto resultLayout = resultTy.getEncoding().cast<SliceEncodingAttr>();
+auto resultShape = resultTy.getShape();
+
+unsigned resultElems = getElemsPerThread(resultLayout, resultShape);
+auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultShape);
+assert(resultIndices.size() == resultElems);
+
+barrier();
+SmallVector<Value> resultVals(resultElems);
+for (int i = 0; i < resultElems; i++) {
+SmallVector<Value> readIdx = resultIndices[i];
+readIdx.insert(readIdx.begin() + axis, i32_val(0));
+Value readOffset = linearize(rewriter, loc, readIdx, smemShape);
+Value readPtr = gep(elemPtrTy, smemBase, readOffset);
+resultVals[i] = load(readPtr);
+}
+
+SmallVector<Type> resultTypes(resultElems, llvmElemTy);
+Type structTy =
+LLVM::LLVMStructType::getLiteral(this->getContext(), resultTypes);
+Value ret = getStructFromElements(loc, resultVals, rewriter, structTy);
+rewriter.replaceOp(op, ret);
+} else {
+// 0d-tensor -> scalar
+barrier();
+Value resultVal = load(smemBase);
+rewriter.replaceOp(op, resultVal);
+}
+
+return success();
+}
+
+/// ====================== reduce codegen end ==========================
+
 template <typename SourceOp>
 struct ViewLikeOpConversion : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
 using OpAdaptor = typename SourceOp::Adaptor;
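For readers unfamiliar with the `shfl.sync.bfly` idiom used by `shflSync` and `matchAndRewriteFast` above: each round exchanges values between lanes whose indices differ in one bit, so the number of rounds is log2 of the warp (or inter-warp) size and every lane ends up with the full result. A plain-Python model of that butterfly schedule, with a list standing in for one warp; this is an explanatory sketch, not code from the PR:

```python
def butterfly_reduce(vals, combine):
    # Models: for (N = size / 2; N > 0; N >>= 1) acc = combine(acc, shfl_bfly(acc, N))
    size = len(vals)  # e.g. 32 lanes of a warp; assumed to be a power of two
    n = size // 2
    while n > 0:
        # shfl.sync.bfly with offset n pairs lane i with lane i ^ n
        vals = [combine(vals[i], vals[i ^ n]) for i in range(size)]
        n //= 2
    return vals  # every lane now holds the reduction of all inputs


assert butterfly_reduce(list(range(32)), max) == [31] * 32
```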
@@ -1738,15 +2128,16 @@ public:
 dstLayout.isa<DotOperandEncodingAttr>()) {
 return lowerSharedToDotOperand(op, adaptor, rewriter);
 }
-if ((!srcLayout.isa<BlockedEncodingAttr>() &&
-!srcLayout.isa<MmaEncodingAttr>()) ||
-(!dstLayout.isa<BlockedEncodingAttr>() &&
-!dstLayout.isa<MmaEncodingAttr>())) {
-// TODO: to be implemented
-return failure();
+if ((srcLayout.isa<BlockedEncodingAttr>() ||
+srcLayout.isa<MmaEncodingAttr>() ||
+srcLayout.isa<SliceEncodingAttr>()) &&
+(dstLayout.isa<BlockedEncodingAttr>() ||
+dstLayout.isa<MmaEncodingAttr>() ||
+dstLayout.isa<SliceEncodingAttr>())) {
+return lowerDistributedToDistributed(op, adaptor, rewriter);
 }
-return lowerDistributedToDistributed(op, adaptor, rewriter);
+// TODO: to be implemented
+return failure();
 }
 
 private:
@@ -1799,6 +2190,7 @@ void ConvertLayoutOpConversion::processReplica(
 unsigned accumNumCTAsEachRep = product<unsigned>(numCTAsEachRep);
 auto layout = type.getEncoding();
 auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>();
+auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>();
 auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>();
 auto rank = type.getRank();
 auto sizePerThread = getSizePerThread(layout);
|
|||||||
if (blockedLayout) {
|
if (blockedLayout) {
|
||||||
multiDimOffsetFirstElem = emitBaseIndexForBlockedLayout(
|
multiDimOffsetFirstElem = emitBaseIndexForBlockedLayout(
|
||||||
loc, rewriter, blockedLayout, type.getShape());
|
loc, rewriter, blockedLayout, type.getShape());
|
||||||
|
} else if (sliceLayout) {
|
||||||
|
unsigned dim = sliceLayout.getDim();
|
||||||
|
auto parent = sliceLayout.getParent();
|
||||||
|
if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
|
||||||
|
SmallVector<int64_t> paddedShape =
|
||||||
|
sliceLayout.paddedShape(type.getShape());
|
||||||
|
multiDimOffsetFirstElem = emitBaseIndexForBlockedLayout(
|
||||||
|
loc, rewriter, blockedParent, paddedShape);
|
||||||
|
} else {
|
||||||
|
assert(0 && "SliceEncodingAttr with parent other than "
|
||||||
|
"BlockedEncodingAttr not implemented");
|
||||||
|
}
|
||||||
} else if (mmaLayout) {
|
} else if (mmaLayout) {
|
||||||
Value threadId = getThreadId(rewriter, loc);
|
Value threadId = getThreadId(rewriter, loc);
|
||||||
Value warpSize = idx_val(32);
|
Value warpSize = idx_val(32);
|
||||||
@@ -1863,6 +2267,25 @@ void ConvertLayoutOpConversion::processReplica(
 idx_val(multiDimCTAInRepId[d] * shapePerCTA[d] +
 multiDimElemId[d]));
 }
+} else if (sliceLayout) {
+unsigned dim = sliceLayout.getDim();
+auto parent = sliceLayout.getParent();
+if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
+SmallVector<unsigned> multiDimElemId = getMultiDimIndex<unsigned>(
+elemId, blockedParent.getSizePerThread());
+for (unsigned d = 0; d < rank + 1; ++d) {
+if (d == dim)
+continue;
+unsigned slicedD = d < dim ? d : (d - 1);
+multiDimOffset[slicedD] =
+add(multiDimOffsetFirstElem[d],
+idx_val(multiDimCTAInRepId[slicedD] * shapePerCTA[slicedD] +
+multiDimElemId[d]));
+}
+} else {
+assert(0 && "SliceEncodingAttr with parent other than "
+"BlockedEncodingAttr not implemented");
+}
 } else if (mmaLayout) {
 assert(rank == 2);
 assert(mmaLayout.getVersion() == 2 &&
@@ -1952,6 +2375,7 @@ LogicalResult ConvertLayoutOpConversion::lowerDistributedToDistributed(
 auto multiDimRepId = getMultiDimIndex<unsigned>(repId, numReplicates);
 barrier();
 if (srcLayout.isa<BlockedEncodingAttr>() ||
+srcLayout.isa<SliceEncodingAttr>() ||
 srcLayout.isa<MmaEncodingAttr>()) {
 processReplica(loc, rewriter, /*stNotRd*/ true, srcTy, inNumCTAsEachRep,
 multiDimRepId, inVec, paddedRepShape, outOrd, vals,
@@ -3710,6 +4134,7 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
 #undef POPULATE_UNARY_OP
 
 patterns.add<BroadcastOpConversion>(typeConverter, benefit);
+patterns.add<ReduceOpConversion>(typeConverter, allocation, smem, benefit);
 patterns.add<ConvertLayoutOpConversion>(typeConverter, allocation, smem,
 benefit);
 
@@ -63,6 +63,19 @@ SmallVector<unsigned> getSizePerThread(Attribute layout) {
 if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
 return SmallVector<unsigned>(blockedLayout.getSizePerThread().begin(),
 blockedLayout.getSizePerThread().end());
+} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
+unsigned dim = sliceLayout.getDim();
+auto parent = sliceLayout.getParent();
+if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
+SmallVector<unsigned> sizePerThread(
+blockedParent.getSizePerThread().begin(),
+blockedParent.getSizePerThread().end());
+sizePerThread.erase(sizePerThread.begin() + dim);
+return sizePerThread;
+} else {
+assert(0 && "SliceEncodingAttr with parent other than "
+"BlockedEncodingAttr not implemented");
+}
 } else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
 assert(mmaLayout.getVersion() == 2 &&
 "mmaLayout version = 1 is not implemented yet");
@@ -95,6 +108,21 @@ SmallVector<unsigned> getShapePerCTA(const Attribute &layout) {
 shape.push_back(blockedLayout.getSizePerThread()[d] *
 blockedLayout.getThreadsPerWarp()[d] *
 blockedLayout.getWarpsPerCTA()[d]);
+} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
+unsigned dim = sliceLayout.getDim();
+auto parent = sliceLayout.getParent();
+if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
+for (int d = 0, n = blockedParent.getOrder().size(); d < n; ++d) {
+if (d == dim)
+continue;
+shape.push_back(blockedParent.getSizePerThread()[d] *
+blockedParent.getThreadsPerWarp()[d] *
+blockedParent.getWarpsPerCTA()[d]);
+}
+} else {
+assert(0 && "SliceEncodingAttr with parent other than "
+"BlockedEncodingAttr not implemented");
+}
 } else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
 assert(mmaLayout.getVersion() == 2 &&
 "mmaLayout version = 1 is not implemented yet");
@@ -206,6 +234,22 @@ unsigned BlockedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
 return product<unsigned>(elemsPerThread);
 }
 
+SmallVector<int64_t>
+SliceEncodingAttr::paddedShape(ArrayRef<int64_t> shape) const {
+size_t rank = shape.size();
+unsigned dim = getDim();
+SmallVector<int64_t> retShape(rank + 1);
+for (unsigned d = 0; d < rank + 1; ++d) {
+if (d < dim)
+retShape[d] = shape[d];
+else if (d == dim)
+retShape[d] = 1;
+else
+retShape[d] = shape[d - 1];
+}
+return retShape;
+}
+
 unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
 size_t rank = shape.size();
 auto parent = getParent();
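The `paddedShape` helper added above is what lets a slice layout reuse its parent blocked layout's index math: it re-inserts the sliced dimension with extent 1. In Python terms (illustrative sketch only):

```python
def padded_shape(shape, dim):
    # Mirror of SliceEncodingAttr::paddedShape: put a unit dimension back at `dim`.
    return list(shape[:dim]) + [1] + list(shape[dim:])


assert padded_shape([32, 64], dim=1) == [32, 1, 64]
```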
@@ -213,16 +257,7 @@ unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
 if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
 assert(rank == blockedParent.getSizePerThread().size() - 1 &&
 "unexpected rank in SliceEncodingAttr::getElemsPerThread");
-SmallVector<int64_t> paddedShape(rank + 1);
-for (unsigned d = 0; d < rank + 1; ++d) {
-if (d < dim)
-paddedShape[d] = shape[d];
-else if (d == dim)
-paddedShape[d] = 1;
-else
-paddedShape[d] = shape[d - 1];
-}
-return blockedParent.getElemsPerThread(paddedShape);
+return blockedParent.getElemsPerThread(paddedShape(shape));
 } else {
 assert(0 && "getElemsPerThread not implemented");
 return 0;
python/tests/test_reduce.py (new file, 115 lines)
@@ -0,0 +1,115 @@
+import pytest
+import torch
+from torch.testing import assert_close
+
+import triton
+import triton.language as tl
+
+dtype_mapping = {
+    'float16': torch.float16,
+    'float32': torch.float32,
+    'float64': torch.float64,
+}
+
+
+def patch_kernel(template, to_replace):
+    kernel = triton.JITFunction(template.fn)
+    for key, value in to_replace.items():
+        kernel.src = kernel.src.replace(key, value)
+    return kernel
+
+
+@triton.jit
+def reduce1d_kernel(x_ptr, z_ptr, block: tl.constexpr):
+    x = tl.load(x_ptr + tl.arange(0, block))
+    tl.store(z_ptr, tl.OP(x, axis=0))
+
+
+@triton.jit
+def reduce2d_kernel(x_ptr, z_ptr, axis: tl.constexpr, block_m: tl.constexpr, block_n: tl.constexpr):
+    range_m = tl.arange(0, block_m)
+    range_n = tl.arange(0, block_n)
+    x = tl.load(x_ptr + range_m[:, None] * block_n + range_n[None, :])
+    z = tl.OP(x, axis=axis)
+    if axis == 0:
+        tl.store(z_ptr + range_n, z)
+    else:
+        tl.store(z_ptr + range_m, z)
+
+
+reduce1d_configs = [
+    (op, dtype, shape)
+    for op in ['sum', 'min', 'max']
+    for dtype in ['float16', 'float32', 'float64']
+    for shape in [4, 8, 16, 32, 64, 128, 512, 1024]
+]
+
+
+@pytest.mark.parametrize('op, dtype, shape', reduce1d_configs)
+def test_reduce1d(op, dtype, shape):
+    dtype = dtype_mapping[dtype]
+    x = torch.randn((shape,), device='cuda', dtype=dtype)
+    z = torch.empty(
+        tuple(),
+        device=x.device,
+        dtype=dtype,
+    )
+
+    kernel = patch_kernel(reduce1d_kernel, {'OP': op})
+    grid = (1,)
+    kernel[grid](x_ptr=x, z_ptr=z, block=shape)
+
+    if op == 'sum':
+        golden_z = torch.sum(x, dtype=dtype)
+    elif op == 'min':
+        golden_z = torch.min(x)
+    else:
+        golden_z = torch.max(x)
+
+    if op == 'sum':
+        if shape >= 256:
+            assert_close(z, golden_z, rtol=0.05, atol=0.1)
+        elif shape >= 32:
+            assert_close(z, golden_z, rtol=0.05, atol=0.02)
+        else:
+            assert_close(z, golden_z, rtol=0.01, atol=0.01)
+    else:
+        assert_close(z, golden_z, rtol=0.001, atol=0.001)
+
+
+reduce2d_configs = [
+    (op, dtype, shape, axis)
+    for op in ['sum', 'min', 'max']
+    for dtype in ['float16', 'float32', 'float64']
+    for shape in [(1, 4), (1, 8), (1, 16), (1, 32), (2, 32), (4, 32), (4, 128), (32, 64)]
+    for axis in [0, 1]
+]
+
+
+@pytest.mark.parametrize('op, dtype, shape, axis', reduce2d_configs)
+def test_reduce2d(op, dtype, shape, axis):
+    dtype = dtype_mapping[dtype]
+    x = torch.randn(shape, device='cuda', dtype=dtype)
+    reduced_shape = (shape[1 - axis],)
+    z = torch.empty(reduced_shape, device=x.device, dtype=dtype)
+
+    kernel = patch_kernel(reduce2d_kernel, {'OP': op})
+    grid = (1,)
+    kernel[grid](x_ptr=x, z_ptr=z, axis=axis, block_m=shape[0], block_n=shape[1])
+
+    if op == 'sum':
+        golden_z = torch.sum(x, dim=axis, keepdim=False, dtype=dtype)
+    elif op == 'min':
+        golden_z = torch.min(x, dim=axis, keepdim=False)[0]
+    else:
+        golden_z = torch.max(x, dim=axis, keepdim=False)[0]
+
+    if op == 'sum':
+        if shape[axis] >= 256:
+            assert_close(z, golden_z, rtol=0.05, atol=0.1)
+        elif shape[axis] >= 32:
+            assert_close(z, golden_z, rtol=0.05, atol=0.02)
+        else:
+            assert_close(z, golden_z, rtol=0.01, atol=0.01)
+    else:
+        assert_close(z, golden_z, rtol=0.001, atol=0.001)