[Triton-MLIR][Backend] Add ReduceOpConversion into TritonGPUToLLVM conversion (#774)
What is done in this PR:

- [x] Add `ConvertLayout`, `getSizePerThread`, and `getShapePerCTA` implementations for `SliceEncodingAttr`
- [x] Split `emitIndices` into two phases: `emitBaseIndexForBlockedLayout` and `emitOffsetForBlockedLayout`
- [x] Add the `ReduceOpConversion::matchAndRewriteBasic` implementation
- [x] Add the `ReduceOpConversion::matchAndRewriteFast` implementation using the PTX instruction `shfl.sync`
- [x] Add support for scalar values in `StoreOpConversion`
- [x] Add Reduce1d and Reduce2d unit tests; all unit tests pass

Co-authored-by: Qingyi Liu <liuqingyi1993@gmail.com>
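For context on the fast path: `matchAndRewriteFast` reduces within each warp by repeatedly exchanging partial results with `shfl.sync` in butterfly (`bfly`) mode and accumulating, then writes one partial per warp to shared memory and repeats the pattern across warps. A minimal CUDA sketch of the intra-warp butterfly step (illustrative only; `warp_reduce_sum` is a hypothetical helper, not code from this PR):

```cuda
// Hypothetical sketch of the butterfly pattern that shflSync() emits:
// each step exchanges values between lanes whose IDs differ by `offset`,
// halving the number of distinct partials until every lane holds the sum.
__device__ float warp_reduce_sum(float acc) {
  for (int offset = 16; offset > 0; offset >>= 1)
    acc += __shfl_xor_sync(0xffffffffu, acc, offset);  // ~ shfl.sync.bfly.b32
  return acc;
}
```

The lowering below generalizes this to arbitrary reduce ops via `accumulate()` and handles the inter-warp step through the shared-memory scratch buffer sized by `getScratchConfigForReduce`.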
@@ -6,6 +6,7 @@
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/raw_ostream.h"

#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include <atomic>
#include <limits>
@@ -19,6 +20,8 @@ SmallVector<unsigned>
getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
                             unsigned &outVec);

SmallVector<unsigned> getScratchConfigForReduce(triton::ReduceOp op);

} // namespace triton

/// Modified from llvm-15.0: llvm/ADT/AddressRanges.h
@@ -250,6 +250,12 @@ struct PTXIOInstr : public PTXInstrBase<PTXIOInstr> {
    return *this;
  }

  // Add ".shared" suffix to instruction
  PTXIOInstr &shared(bool predicate = true) {
    o("shared", predicate);
    return *this;
  }

  // Add ".v" suffix to instruction
  PTXIOInstr &v(int vecWidth, bool predicate = true) {
    if (vecWidth > 1) {
@@ -324,7 +324,9 @@ def SliceEncodingAttr : DistributedEncoding<"SliceEncoding"> {
    "Attribute":$parent
  );

  let extraClassDeclaration = extraBaseClassDeclaration;
  let extraClassDeclaration = extraBaseClassDeclaration # [{
    SmallVector<int64_t> paddedShape(ArrayRef<int64_t> shape) const;
  }];
}

@@ -14,6 +14,7 @@ using ::mlir::triton::gpu::BlockedEncodingAttr;
using ::mlir::triton::gpu::getShapePerCTA;
using ::mlir::triton::gpu::MmaEncodingAttr;
using ::mlir::triton::gpu::SharedEncodingAttr;
using ::mlir::triton::gpu::SliceEncodingAttr;

namespace mlir {

@@ -33,6 +34,10 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
         "Unexpect layout in getScratchConfigForCvtLayout()");
  unsigned rank = dstTy.getRank();
  SmallVector<unsigned> paddedRepShape(rank);
  if (auto srcSliceLayout = srcLayout.dyn_cast<SliceEncodingAttr>())
    srcLayout = srcSliceLayout.getParent();
  if (auto dstSliceLayout = dstLayout.dyn_cast<SliceEncodingAttr>())
    dstLayout = dstSliceLayout.getParent();
  auto srcBlockedLayout = srcLayout.dyn_cast<BlockedEncodingAttr>();
  auto srcMmaLayout = srcLayout.dyn_cast<MmaEncodingAttr>();
  auto dstBlockedLayout = dstLayout.dyn_cast<BlockedEncodingAttr>();
@@ -73,6 +78,31 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
  return paddedRepShape;
}

SmallVector<unsigned> getScratchConfigForReduce(triton::ReduceOp op) {
  auto srcTy = op.operand().getType().cast<RankedTensorType>();
  auto srcLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
  auto srcShape = srcTy.getShape();
  auto rank = srcShape.size();
  auto axis = op.axis();

  bool fast_reduce = axis == 1; // FIXME(Qingyi): The fastest-changing dimension

  SmallVector<unsigned> smemShape;
  for (auto d : srcShape)
    smemShape.push_back(d);

  if (fast_reduce) {
    unsigned sizeInterWarps = srcLayout.getWarpsPerCTA()[axis];
    smemShape[axis] = sizeInterWarps;
  } else {
    unsigned threadsPerCTAAxis =
        srcLayout.getThreadsPerWarp()[axis] * srcLayout.getWarpsPerCTA()[axis];
    smemShape[axis] = threadsPerCTAAxis;
  }

  return smemShape;
}

class AllocationAnalysis {
public:
  AllocationAnalysis(Operation *operation, Allocation *allocation)
@@ -127,9 +157,16 @@ private:
    // TODO(Keren): Reduce with index is not supported yet.
    auto value = op->getOperand(0);
    if (auto tensorType = value.getType().dyn_cast<RankedTensorType>()) {
      auto bytes = tensorType.getNumElements() *
                   tensorType.getElementTypeBitWidth() / 8;
      allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
      if (tensorType.getEncoding().isa<BlockedEncodingAttr>()) {
        auto smemShape = getScratchConfigForReduce(reduceOp);
        unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(),
                                         1, std::multiplies{});
        auto bytes = elems * tensorType.getElementTypeBitWidth() / 8;
        allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
      } else {
        assert(0 && "ReduceOp with input layout other than blocked layout is "
                    "not implemented yet");
      }
    }
  } else if (auto cvtLayout = dyn_cast<triton::gpu::ConvertLayoutOp>(op)) {
    auto srcTy = cvtLayout.src().getType().cast<RankedTensorType>();
@@ -76,7 +76,15 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
#define udiv(...) rewriter.create<LLVM::UDivOp>(loc, __VA_ARGS__)
#define urem(...) rewriter.create<LLVM::URemOp>(loc, __VA_ARGS__)
#define add(...) rewriter.create<LLVM::AddOp>(loc, __VA_ARGS__)
#define fadd(...) rewriter.create<LLVM::FAddOp>(loc, __VA_ARGS__)
#define mul(...) rewriter.create<LLVM::MulOp>(loc, __VA_ARGS__)
#define smax(...) rewriter.create<LLVM::SMaxOp>(loc, __VA_ARGS__)
#define umax(...) rewriter.create<LLVM::UMaxOp>(loc, __VA_ARGS__)
#define fmax(...) rewriter.create<LLVM::MaxNumOp>(loc, __VA_ARGS__)
#define smin(...) rewriter.create<LLVM::SMinOp>(loc, __VA_ARGS__)
#define umin(...) rewriter.create<LLVM::UMinOp>(loc, __VA_ARGS__)
#define fmin(...) rewriter.create<LLVM::MinNumOp>(loc, __VA_ARGS__)
#define and_(...) rewriter.create<LLVM::AndOp>(loc, __VA_ARGS__)
#define xor_(...) rewriter.create<LLVM::XOrOp>(loc, __VA_ARGS__)
#define bitcast(...) rewriter.create<LLVM::BitcastOp>(loc, __VA_ARGS__)
#define gep(...) rewriter.create<LLVM::GEPOp>(loc, __VA_ARGS__)
@@ -89,11 +97,16 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
  rewriter.create<LLVM::ExtractElementOp>(loc, __VA_ARGS__)
#define load(...) rewriter.create<LLVM::LoadOp>(loc, __VA_ARGS__)
#define store(val, ptr) rewriter.create<LLVM::StoreOp>(loc, val, ptr)
#define icmp_eq(...) \
  rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, __VA_ARGS__)
#define icmp_slt(...) \
  rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::slt, __VA_ARGS__)
#define select(...) rewriter.create<LLVM::SelectOp>(loc, __VA_ARGS__)
#define address_of(...) rewriter.create<LLVM::AddressOfOp>(loc, __VA_ARGS__)
#define barrier() rewriter.create<mlir::gpu::BarrierOp>(loc)
#define undef(...) rewriter.create<LLVM::UndefOp>(loc, __VA_ARGS__)
#define i32_ty rewriter.getIntegerType(32)
#define f32_ty rewriter.getF32Type()
#define vec_ty(type, num) VectorType::get(num, type)
#define void_ty LLVM::LLVMVoidType::get(ctx)
#define struct_ty(...) LLVM::LLVMStructType::getLiteral(__VA_ARGS__)
@@ -336,6 +349,20 @@ static T getLinearIndex(ArrayRef<T> multiDimIndex, ArrayRef<T> shape) {
  return linearIndex;
}

Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr,
                  Value val, Value pred) {
  MLIRContext *ctx = rewriter.getContext();
  unsigned bits = val.getType().getIntOrFloatBitWidth();
  const char *c = bits == 64 ? "l" : (bits == 16 ? "h" : "r");

  PTXBuilder builder;
  auto &st = builder.create<PTXIOInstr>("st")->shared().b(bits);
  auto *ptrOpr = builder.newAddrOperand(ptr, "r");
  auto *valOpr = builder.newOperand(val, c);
  st(ptrOpr, valOpr).predicate(pred, "b");
  return builder.launch(rewriter, loc, void_ty);
}

struct ConvertTritonGPUOpToLLVMPatternBase {
  static SmallVector<Value>
  getElementsFromStruct(Location loc, Value llvmStruct,
@@ -504,17 +531,8 @@ public:
    unsigned dim = sliceLayout.getDim();
    size_t rank = shape.size();
    if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
      SmallVector<int64_t> paddedShape(rank + 1);
      for (unsigned d = 0; d < rank + 1; ++d) {
        if (d < dim)
          paddedShape[d] = shape[d];
        else if (d == dim)
          paddedShape[d] = 1;
        else
          paddedShape[d] = shape[d - 1];
      }
      auto paddedIndices = emitIndicesForBlockedLayout(
          loc, rewriter, blockedParent, paddedShape);
          loc, rewriter, blockedParent, sliceLayout.paddedShape(shape));
      unsigned numIndices = paddedIndices.size();
      SmallVector<SmallVector<Value>> resultIndices(numIndices);
      for (unsigned i = 0; i < numIndices; ++i)
@@ -536,31 +554,19 @@ public:
    }
  }

  // Emit indices calculation within each ConversionPattern, and returns a
  // [elemsPerThread X rank] index matrix.
  // TODO: [goostavz] Double confirm the redundant indices calculations will
  // be eliminated in the consequent MLIR/LLVM optimization. We might
  // implement a indiceCache if necessary.
  SmallVector<SmallVector<Value>>
  emitIndicesForBlockedLayout(Location loc, ConversionPatternRewriter &rewriter,
                              const BlockedEncodingAttr &blockedLayout,
                              ArrayRef<int64_t> shape) const {
    auto llvmIndexTy = this->getTypeConverter()->getIndexType();
  SmallVector<SmallVector<unsigned>>
  emitOffsetForBlockedLayout(const BlockedEncodingAttr &blockedLayout,
                             ArrayRef<int64_t> shape) const {
    auto sizePerThread = blockedLayout.getSizePerThread();
    auto threadsPerWarp = blockedLayout.getThreadsPerWarp();
    auto warpsPerCTA = blockedLayout.getWarpsPerCTA();

    unsigned rank = shape.size();
    SmallVector<unsigned> shapePerCTA = getShapePerCTA(blockedLayout);
    SmallVector<unsigned> tilesPerDim(rank);
    for (unsigned k = 0; k < rank; ++k)
      tilesPerDim[k] = ceil<unsigned>(shape[k], shapePerCTA[k]);

    // step 1, delinearize threadId to get the base index
    auto multiDimBase =
        emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, shape);

    // step 2, get offset of each element
    unsigned elemsPerThread = blockedLayout.getElemsPerThread(shape);
    SmallVector<SmallVector<unsigned>> offset(rank);
    for (unsigned k = 0; k < rank; ++k) {
      // 1 block in minimum if shape[k] is less than shapePerCTA[k]
@@ -577,12 +583,10 @@ public:
                                 threadsPerWarp[k] +
                             threadOffset * sizePerThread[k] + elemOffset);
    }
    // step 3, add offset to base, and reorder the sequence of indices to
    // guarantee that elems in the same sizePerThread are adjacent in order
    SmallVector<SmallVector<Value>> multiDimIdx(elemsPerThread,
                                                SmallVector<Value>(rank));
    unsigned totalSizePerThread = product<unsigned>(sizePerThread);

    unsigned elemsPerThread = blockedLayout.getElemsPerThread(shape);
    unsigned totalSizePerThread = product<unsigned>(sizePerThread);
    SmallVector<SmallVector<unsigned>> reorderedOffset(elemsPerThread);
    for (unsigned n = 0; n < elemsPerThread; ++n) {
      unsigned linearNanoTileId = n / totalSizePerThread;
      unsigned linearNanoTileElemId = n % totalSizePerThread;
@@ -595,10 +599,38 @@ public:
            multiDimNanoTileId[k] *
                (sizePerThread[k] * threadsPerWarp[k] * warpsPerCTA[k]) +
            multiDimNanoTileElemId[k];
        multiDimIdx[n][k] =
            add(multiDimBase[k], idx_val(offset[k][reorderedMultiDimId]));
        reorderedOffset[n].push_back(offset[k][reorderedMultiDimId]);
      }
    }
    return reorderedOffset;
  }

  // Emit indices calculation within each ConversionPattern, and returns a
  // [elemsPerThread X rank] index matrix.
  // TODO: [goostavz] Double confirm the redundant indices calculations will
  // be eliminated in the consequent MLIR/LLVM optimization. We might
  // implement a indiceCache if necessary.
  SmallVector<SmallVector<Value>>
  emitIndicesForBlockedLayout(Location loc, ConversionPatternRewriter &rewriter,
                              const BlockedEncodingAttr &blockedLayout,
                              ArrayRef<int64_t> shape) const {
    // step 1, delinearize threadId to get the base index
    auto multiDimBase =
        emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, shape);

    // step 2, get offset of each element
    SmallVector<SmallVector<unsigned>> offset =
        emitOffsetForBlockedLayout(blockedLayout, shape);

    // step 3, add offset to base, and reorder the sequence of indices to
    // guarantee that elems in the same sizePerThread are adjacent in order
    unsigned rank = shape.size();
    unsigned elemsPerThread = offset.size();
    SmallVector<SmallVector<Value>> multiDimIdx(elemsPerThread,
                                                SmallVector<Value>(rank));
    for (unsigned n = 0; n < elemsPerThread; ++n)
      for (unsigned k = 0; k < rank; ++k)
        multiDimIdx[n][k] = add(multiDimBase[k], idx_val(offset[n][k]));

    return multiDimIdx;
  }
@@ -1027,8 +1059,12 @@ struct StoreOpConversion
    MLIRContext *ctx = rewriter.getContext();

    auto valueTy = value.getType().dyn_cast<RankedTensorType>();
    if (!valueTy)
      return failure();
    if (!valueTy) {
      store(llValue, llPtr);
      rewriter.eraseOp(op);
      return success();
    }

    Type valueElemTy =
        getTypeConverter()->convertType(valueTy.getElementType());

@@ -1225,6 +1261,360 @@ struct BroadcastOpConversion
  }
};

/// ====================== reduce codegen begin ==========================

struct ReduceOpConversion
    : public ConvertTritonGPUOpToLLVMPattern<triton::ReduceOp> {
public:
  using ConvertTritonGPUOpToLLVMPattern<
      triton::ReduceOp>::ConvertTritonGPUOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;

private:
  void accumulate(ConversionPatternRewriter &rewriter, Location loc,
                  RedOp redOp, Value &acc, Value cur, bool isFirst) const;

  Value shflSync(ConversionPatternRewriter &rewriter, Location loc, Value val,
                 int i) const;

  // Use shared memory for reduction within warps and across warps
  LogicalResult matchAndRewriteBasic(triton::ReduceOp op, OpAdaptor adaptor,
                                     ConversionPatternRewriter &rewriter) const;

  // Use warp shuffle for reduction within warps and shared memory for data
  // exchange across warps
  LogicalResult matchAndRewriteFast(triton::ReduceOp op, OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const;
};

LogicalResult
ReduceOpConversion::matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
  auto srcTy = op.operand().getType().cast<RankedTensorType>();
  auto rank = srcTy.getShape().size();
  if (op.axis() == 1) // FIXME(Qingyi): The fastest-changing dimension
    return matchAndRewriteFast(op, adaptor, rewriter);
  return matchAndRewriteBasic(op, adaptor, rewriter);
}

void ReduceOpConversion::accumulate(ConversionPatternRewriter &rewriter,
                                    Location loc, RedOp redOp, Value &acc,
                                    Value cur, bool isFirst) const {
  if (isFirst) {
    acc = cur;
    return;
  }
  auto type = cur.getType();
  switch (redOp) {
  case RedOp::ADD:
    acc = add(acc, cur);
    break;
  case RedOp::MAX:
    if (type.isUnsignedInteger())
      acc = umax(acc, cur);
    else
      acc = smax(acc, cur);
    break;
  case RedOp::MIN:
    if (type.isUnsignedInteger())
      acc = umin(acc, cur);
    else
      acc = smin(acc, cur);
    break;
  case RedOp::FADD:
    acc = fadd(acc.getType(), acc, cur);
    break;
  case RedOp::FMAX:
    acc = fmax(acc, cur);
    break;
  case RedOp::FMIN:
    acc = fmin(acc, cur);
    break;
  case RedOp::XOR:
    acc = xor_(acc, cur);
    break;
  default:
    llvm::report_fatal_error("Unsupported reduce op");
  }
};

Value ReduceOpConversion::shflSync(ConversionPatternRewriter &rewriter,
                                   Location loc, Value val, int i) const {
  MLIRContext *ctx = rewriter.getContext();
  unsigned bits = val.getType().getIntOrFloatBitWidth();

  if (bits == 64) {
    Type vecTy = vec_ty(f32_ty, 2);
    Value vec = bitcast(vecTy, val);
    Value val0 = extract_element(f32_ty, vec, i32_val(0));
    Value val1 = extract_element(f32_ty, vec, i32_val(1));
    val0 = shflSync(rewriter, loc, val0, i);
    val1 = shflSync(rewriter, loc, val1, i);
    vec = undef(vecTy);
    vec = insert_element(vecTy, vec, val0, i32_val(0));
    vec = insert_element(vecTy, vec, val1, i32_val(1));
    return bitcast(val.getType(), vec);
  }

  PTXBuilder builder;
  auto &shfl = builder.create("shfl.sync")->o("bfly").o("b32");
  auto *dOpr = builder.newOperand("=r");
  auto *aOpr = builder.newOperand(val, "r");
  auto *bOpr = builder.newConstantOperand(i);
  auto *cOpr = builder.newConstantOperand("0x1f");
  auto *maskOpr = builder.newConstantOperand("0xffffffff");
  shfl(dOpr, aOpr, bOpr, cOpr, maskOpr);
  return builder.launch(rewriter, loc, val.getType(), false);
}

LogicalResult ReduceOpConversion::matchAndRewriteBasic(
    triton::ReduceOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op->getLoc();
  unsigned axis = op.axis();

  auto srcTy = op.operand().getType().cast<RankedTensorType>();
  auto srcLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
  auto srcShape = srcTy.getShape();

  auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
  auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
  Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
  smemBase = bitcast(elemPtrTy, smemBase);

  auto smemShape = getScratchConfigForReduce(op);

  unsigned srcElems = getElemsPerThread(srcLayout, srcShape);
  auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcShape);
  auto srcValues = getElementsFromStruct(loc, adaptor.operand(), rewriter);

  SmallVector<SmallVector<unsigned>> offset =
      emitOffsetForBlockedLayout(srcLayout, srcShape);

  std::map<SmallVector<unsigned>, Value> accs;
  std::map<SmallVector<unsigned>, SmallVector<Value>> indices;

  // reduce within threads
  for (unsigned i = 0; i < srcElems; ++i) {
    SmallVector<unsigned> key = offset[i];
    key[axis] = 0;
    bool isFirst = accs.find(key) == accs.end();
    accumulate(rewriter, loc, op.redOp(), accs[key], srcValues[i], isFirst);
    if (isFirst)
      indices[key] = srcIndices[i];
  }

  // cached int32 constants
  std::map<int, Value> ints;
  ints[0] = i32_val(0);
  for (int N = smemShape[axis] / 2; N > 0; N >>= 1)
    ints[N] = i32_val(N);
  Value sizePerThread = i32_val(srcLayout.getSizePerThread()[axis]);

  // reduce across threads
  for (auto it : accs) {
    const SmallVector<unsigned> &key = it.first;
    Value acc = it.second;
    SmallVector<Value> writeIdx = indices[key];

    writeIdx[axis] = udiv(writeIdx[axis], sizePerThread);
    Value writeOffset = linearize(rewriter, loc, writeIdx, smemShape);
    Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
    store(acc, writePtr);

    SmallVector<Value> readIdx(writeIdx.size(), ints[0]);
    for (int N = smemShape[axis] / 2; N > 0; N >>= 1) {
      readIdx[axis] = ints[N];
      Value readMask = icmp_slt(writeIdx[axis], ints[N]);
      Value readOffset = select(
          readMask, linearize(rewriter, loc, readIdx, smemShape), ints[0]);
      Value readPtr = gep(elemPtrTy, writePtr, readOffset);
      barrier();
      accumulate(rewriter, loc, op.redOp(), acc, load(readPtr), false);
      store(acc, writePtr);
    }
  }

  // set output values
  if (auto resultTy = op.getType().dyn_cast<RankedTensorType>()) {
    // nd-tensor where n >= 1
    auto resultLayout = resultTy.getEncoding();
    auto resultShape = resultTy.getShape();

    unsigned resultElems = getElemsPerThread(resultLayout, resultShape);
    auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultShape);
    assert(resultIndices.size() == resultElems);

    barrier();
    SmallVector<Value> resultVals(resultElems);
    for (int i = 0; i < resultElems; i++) {
      SmallVector<Value> readIdx = resultIndices[i];
      readIdx.insert(readIdx.begin() + axis, ints[0]);
      Value readOffset = linearize(rewriter, loc, readIdx, smemShape);
      Value readPtr = gep(elemPtrTy, smemBase, readOffset);
      resultVals[i] = load(readPtr);
    }

    SmallVector<Type> resultTypes(resultElems, llvmElemTy);
    Type structTy =
        LLVM::LLVMStructType::getLiteral(this->getContext(), resultTypes);
    Value ret = getStructFromElements(loc, resultVals, rewriter, structTy);
    rewriter.replaceOp(op, ret);
  } else {
    // 0d-tensor -> scalar
    barrier();
    Value resultVal = load(smemBase);
    rewriter.replaceOp(op, resultVal);
  }

  return success();
}

LogicalResult ReduceOpConversion::matchAndRewriteFast(
    triton::ReduceOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op->getLoc();
  unsigned axis = adaptor.axis();

  auto srcTy = op.operand().getType().cast<RankedTensorType>();
  auto srcLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
  auto srcShape = srcTy.getShape();
  auto srcOrder = srcLayout.getOrder();

  auto threadsPerWarp = srcLayout.getThreadsPerWarp();
  auto warpsPerCTA = srcLayout.getWarpsPerCTA();

  auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
  auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
  Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
  smemBase = bitcast(elemPtrTy, smemBase);

  auto order = srcLayout.getOrder();
  unsigned sizeIntraWarps = threadsPerWarp[axis];
  unsigned sizeInterWarps = warpsPerCTA[axis];

  unsigned srcElems = getElemsPerThread(srcLayout, srcShape);
  auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcShape);
  auto srcValues = getElementsFromStruct(loc, adaptor.operand(), rewriter);

  SmallVector<SmallVector<unsigned>> offset =
      emitOffsetForBlockedLayout(srcLayout, srcShape);

  std::map<SmallVector<unsigned>, Value> accs;
  std::map<SmallVector<unsigned>, SmallVector<Value>> indices;

  auto smemShape = getScratchConfigForReduce(op);

  // reduce within threads
  for (unsigned i = 0; i < srcElems; ++i) {
    SmallVector<unsigned> key = offset[i];
    key[axis] = 0;
    bool isFirst = accs.find(key) == accs.end();
    accumulate(rewriter, loc, op.redOp(), accs[key], srcValues[i], isFirst);
    if (isFirst)
      indices[key] = srcIndices[i];
  }

  Value threadId = getThreadId(rewriter, loc);
  Value warpSize = i32_val(32);
  Value warpId = udiv(threadId, warpSize);
  Value laneId = urem(threadId, warpSize);

  SmallVector<Value> multiDimLaneId =
      delinearize(rewriter, loc, laneId, threadsPerWarp, order);
  SmallVector<Value> multiDimWarpId =
      delinearize(rewriter, loc, warpId, warpsPerCTA, order);
  Value laneIdAxis = multiDimLaneId[axis];
  Value warpIdAxis = multiDimWarpId[axis];

  Value zero = i32_val(0);
  Value laneZero = icmp_eq(laneIdAxis, zero);
  Value warpZero = icmp_eq(warpIdAxis, zero);

  for (auto it : accs) {
    const SmallVector<unsigned> &key = it.first;
    Value acc = it.second;

    // reduce within warps
    for (unsigned N = sizeIntraWarps / 2; N > 0; N >>= 1) {
      Value shfl = shflSync(rewriter, loc, acc, N);
      accumulate(rewriter, loc, op.redOp(), acc, shfl, false);
    }

    if (sizeInterWarps == 1) {
      SmallVector<Value> writeIdx = indices[key];
      writeIdx[axis] = zero;
      Value writeOffset = linearize(rewriter, loc, writeIdx, smemShape);
      Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
      storeShared(rewriter, loc, writePtr, acc, laneZero);
    } else {
      SmallVector<Value> writeIdx = indices[key];
      writeIdx[axis] =
          warpIdAxis; // axis must be the fastest-changing dimension
      Value writeOffset = linearize(rewriter, loc, writeIdx, smemShape);
      Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
      storeShared(rewriter, loc, writePtr, acc, laneZero);
      barrier();

      SmallVector<Value> readIdx = writeIdx;
      readIdx[axis] = urem(laneId, i32_val(sizeInterWarps));
      Value readOffset = linearize(rewriter, loc, readIdx, smemShape);
      Value readPtr = gep(elemPtrTy, smemBase, readOffset);
      acc = load(readPtr);

      // reduce across warps
      for (unsigned N = sizeInterWarps / 2; N > 0; N >>= 1) {
        Value shfl = shflSync(rewriter, loc, acc, N);
        accumulate(rewriter, loc, op.redOp(), acc, shfl, false);
      }

      writeIdx[axis] = zero;
      writeOffset = linearize(rewriter, loc, writeIdx, smemShape);
      writePtr = gep(elemPtrTy, smemBase, writeOffset);
      storeShared(rewriter, loc, writePtr, acc, and_(laneZero, warpZero));
    }
  }

  // set output values
  if (auto resultTy = op.getType().dyn_cast<RankedTensorType>()) {
    // nd-tensor where n >= 1
    auto resultLayout = resultTy.getEncoding().cast<SliceEncodingAttr>();
    auto resultShape = resultTy.getShape();

    unsigned resultElems = getElemsPerThread(resultLayout, resultShape);
    auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultShape);
    assert(resultIndices.size() == resultElems);

    barrier();
    SmallVector<Value> resultVals(resultElems);
    for (int i = 0; i < resultElems; i++) {
      SmallVector<Value> readIdx = resultIndices[i];
      readIdx.insert(readIdx.begin() + axis, i32_val(0));
      Value readOffset = linearize(rewriter, loc, readIdx, smemShape);
      Value readPtr = gep(elemPtrTy, smemBase, readOffset);
      resultVals[i] = load(readPtr);
    }

    SmallVector<Type> resultTypes(resultElems, llvmElemTy);
    Type structTy =
        LLVM::LLVMStructType::getLiteral(this->getContext(), resultTypes);
    Value ret = getStructFromElements(loc, resultVals, rewriter, structTy);
    rewriter.replaceOp(op, ret);
  } else {
    // 0d-tensor -> scalar
    barrier();
    Value resultVal = load(smemBase);
    rewriter.replaceOp(op, resultVal);
  }

  return success();
}

/// ====================== reduce codegen end ==========================

template <typename SourceOp>
struct ViewLikeOpConversion : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
  using OpAdaptor = typename SourceOp::Adaptor;
@@ -1738,15 +2128,16 @@ public:
        dstLayout.isa<DotOperandEncodingAttr>()) {
      return lowerSharedToDotOperand(op, adaptor, rewriter);
    }
    if ((!srcLayout.isa<BlockedEncodingAttr>() &&
         !srcLayout.isa<MmaEncodingAttr>()) ||
        (!dstLayout.isa<BlockedEncodingAttr>() &&
         !dstLayout.isa<MmaEncodingAttr>())) {
      // TODO: to be implemented
      return failure();
    if ((srcLayout.isa<BlockedEncodingAttr>() ||
         srcLayout.isa<MmaEncodingAttr>() ||
         srcLayout.isa<SliceEncodingAttr>()) &&
        (dstLayout.isa<BlockedEncodingAttr>() ||
         dstLayout.isa<MmaEncodingAttr>() ||
         dstLayout.isa<SliceEncodingAttr>())) {
      return lowerDistributedToDistributed(op, adaptor, rewriter);
    }

    return lowerDistributedToDistributed(op, adaptor, rewriter);
    // TODO: to be implemented
    return failure();
  }

private:
@@ -1799,6 +2190,7 @@ void ConvertLayoutOpConversion::processReplica(
  unsigned accumNumCTAsEachRep = product<unsigned>(numCTAsEachRep);
  auto layout = type.getEncoding();
  auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>();
  auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>();
  auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>();
  auto rank = type.getRank();
  auto sizePerThread = getSizePerThread(layout);
@@ -1816,6 +2208,18 @@ void ConvertLayoutOpConversion::processReplica(
  if (blockedLayout) {
    multiDimOffsetFirstElem = emitBaseIndexForBlockedLayout(
        loc, rewriter, blockedLayout, type.getShape());
  } else if (sliceLayout) {
    unsigned dim = sliceLayout.getDim();
    auto parent = sliceLayout.getParent();
    if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
      SmallVector<int64_t> paddedShape =
          sliceLayout.paddedShape(type.getShape());
      multiDimOffsetFirstElem = emitBaseIndexForBlockedLayout(
          loc, rewriter, blockedParent, paddedShape);
    } else {
      assert(0 && "SliceEncodingAttr with parent other than "
                  "BlockedEncodingAttr not implemented");
    }
  } else if (mmaLayout) {
    Value threadId = getThreadId(rewriter, loc);
    Value warpSize = idx_val(32);
@@ -1863,6 +2267,25 @@ void ConvertLayoutOpConversion::processReplica(
            idx_val(multiDimCTAInRepId[d] * shapePerCTA[d] +
                    multiDimElemId[d]));
      }
    } else if (sliceLayout) {
      unsigned dim = sliceLayout.getDim();
      auto parent = sliceLayout.getParent();
      if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
        SmallVector<unsigned> multiDimElemId = getMultiDimIndex<unsigned>(
            elemId, blockedParent.getSizePerThread());
        for (unsigned d = 0; d < rank + 1; ++d) {
          if (d == dim)
            continue;
          unsigned slicedD = d < dim ? d : (d - 1);
          multiDimOffset[slicedD] =
              add(multiDimOffsetFirstElem[d],
                  idx_val(multiDimCTAInRepId[slicedD] * shapePerCTA[slicedD] +
                          multiDimElemId[d]));
        }
      } else {
        assert(0 && "SliceEncodingAttr with parent other than "
                    "BlockedEncodingAttr not implemented");
      }
    } else if (mmaLayout) {
      assert(rank == 2);
      assert(mmaLayout.getVersion() == 2 &&
@@ -1952,6 +2375,7 @@ LogicalResult ConvertLayoutOpConversion::lowerDistributedToDistributed(
    auto multiDimRepId = getMultiDimIndex<unsigned>(repId, numReplicates);
    barrier();
    if (srcLayout.isa<BlockedEncodingAttr>() ||
        srcLayout.isa<SliceEncodingAttr>() ||
        srcLayout.isa<MmaEncodingAttr>()) {
      processReplica(loc, rewriter, /*stNotRd*/ true, srcTy, inNumCTAsEachRep,
                     multiDimRepId, inVec, paddedRepShape, outOrd, vals,
@@ -3710,6 +4134,7 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
#undef POPULATE_UNARY_OP

  patterns.add<BroadcastOpConversion>(typeConverter, benefit);
  patterns.add<ReduceOpConversion>(typeConverter, allocation, smem, benefit);
  patterns.add<ConvertLayoutOpConversion>(typeConverter, allocation, smem,
                                          benefit);

@@ -63,6 +63,19 @@ SmallVector<unsigned> getSizePerThread(Attribute layout) {
  if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
    return SmallVector<unsigned>(blockedLayout.getSizePerThread().begin(),
                                 blockedLayout.getSizePerThread().end());
  } else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
    unsigned dim = sliceLayout.getDim();
    auto parent = sliceLayout.getParent();
    if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
      SmallVector<unsigned> sizePerThread(
          blockedParent.getSizePerThread().begin(),
          blockedParent.getSizePerThread().end());
      sizePerThread.erase(sizePerThread.begin() + dim);
      return sizePerThread;
    } else {
      assert(0 && "SliceEncodingAttr with parent other than "
                  "BlockedEncodingAttr not implemented");
    }
  } else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
    assert(mmaLayout.getVersion() == 2 &&
           "mmaLayout version = 1 is not implemented yet");
@@ -95,6 +108,21 @@ SmallVector<unsigned> getShapePerCTA(const Attribute &layout) {
      shape.push_back(blockedLayout.getSizePerThread()[d] *
                      blockedLayout.getThreadsPerWarp()[d] *
                      blockedLayout.getWarpsPerCTA()[d]);
  } else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
    unsigned dim = sliceLayout.getDim();
    auto parent = sliceLayout.getParent();
    if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
      for (int d = 0, n = blockedParent.getOrder().size(); d < n; ++d) {
        if (d == dim)
          continue;
        shape.push_back(blockedParent.getSizePerThread()[d] *
                        blockedParent.getThreadsPerWarp()[d] *
                        blockedParent.getWarpsPerCTA()[d]);
      }
    } else {
      assert(0 && "SliceEncodingAttr with parent other than "
                  "BlockedEncodingAttr not implemented");
    }
  } else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
    assert(mmaLayout.getVersion() == 2 &&
           "mmaLayout version = 1 is not implemented yet");
@@ -206,6 +234,22 @@ unsigned BlockedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
  return product<unsigned>(elemsPerThread);
}

SmallVector<int64_t>
SliceEncodingAttr::paddedShape(ArrayRef<int64_t> shape) const {
  size_t rank = shape.size();
  unsigned dim = getDim();
  SmallVector<int64_t> retShape(rank + 1);
  for (unsigned d = 0; d < rank + 1; ++d) {
    if (d < dim)
      retShape[d] = shape[d];
    else if (d == dim)
      retShape[d] = 1;
    else
      retShape[d] = shape[d - 1];
  }
  return retShape;
}

unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
  size_t rank = shape.size();
  auto parent = getParent();
@@ -213,16 +257,7 @@ unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
  if (auto blockedParent = parent.dyn_cast<BlockedEncodingAttr>()) {
    assert(rank == blockedParent.getSizePerThread().size() - 1 &&
           "unexpected rank in SliceEncodingAttr::getElemsPerThread");
    SmallVector<int64_t> paddedShape(rank + 1);
    for (unsigned d = 0; d < rank + 1; ++d) {
      if (d < dim)
        paddedShape[d] = shape[d];
      else if (d == dim)
        paddedShape[d] = 1;
      else
        paddedShape[d] = shape[d - 1];
    }
    return blockedParent.getElemsPerThread(paddedShape);
    return blockedParent.getElemsPerThread(paddedShape(shape));
  } else {
    assert(0 && "getElemsPerThread not implemented");
    return 0;
python/tests/test_reduce.py (new file, 115 lines)
@@ -0,0 +1,115 @@
import pytest
import torch
from torch.testing import assert_close

import triton
import triton.language as tl

dtype_mapping = {
    'float16': torch.float16,
    'float32': torch.float32,
    'float64': torch.float64,
}


def patch_kernel(template, to_replace):
    kernel = triton.JITFunction(template.fn)
    for key, value in to_replace.items():
        kernel.src = kernel.src.replace(key, value)
    return kernel


@triton.jit
def reduce1d_kernel(x_ptr, z_ptr, block: tl.constexpr):
    x = tl.load(x_ptr + tl.arange(0, block))
    tl.store(z_ptr, tl.OP(x, axis=0))


@triton.jit
def reduce2d_kernel(x_ptr, z_ptr, axis: tl.constexpr, block_m: tl.constexpr, block_n: tl.constexpr):
    range_m = tl.arange(0, block_m)
    range_n = tl.arange(0, block_n)
    x = tl.load(x_ptr + range_m[:, None] * block_n + range_n[None, :])
    z = tl.OP(x, axis=axis)
    if axis == 0:
        tl.store(z_ptr + range_n, z)
    else:
        tl.store(z_ptr + range_m, z)


reduce1d_configs = [
    (op, dtype, shape)
    for op in ['sum', 'min', 'max']
    for dtype in ['float16', 'float32', 'float64']
    for shape in [4, 8, 16, 32, 64, 128, 512, 1024]
]


@pytest.mark.parametrize('op, dtype, shape', reduce1d_configs)
def test_reduce1d(op, dtype, shape):
    dtype = dtype_mapping[dtype]
    x = torch.randn((shape,), device='cuda', dtype=dtype)
    z = torch.empty(
        tuple(),
        device=x.device,
        dtype=dtype,
    )

    kernel = patch_kernel(reduce1d_kernel, {'OP': op})
    grid = (1,)
    kernel[grid](x_ptr=x, z_ptr=z, block=shape)

    if op == 'sum':
        golden_z = torch.sum(x, dtype=dtype)
    elif op == 'min':
        golden_z = torch.min(x)
    else:
        golden_z = torch.max(x)

    if op == 'sum':
        if shape >= 256:
            assert_close(z, golden_z, rtol=0.05, atol=0.1)
        elif shape >= 32:
            assert_close(z, golden_z, rtol=0.05, atol=0.02)
        else:
            assert_close(z, golden_z, rtol=0.01, atol=0.01)
    else:
        assert_close(z, golden_z, rtol=0.001, atol=0.001)


reduce2d_configs = [
    (op, dtype, shape, axis)
    for op in ['sum', 'min', 'max']
    for dtype in ['float16', 'float32', 'float64']
    for shape in [(1, 4), (1, 8), (1, 16), (1, 32), (2, 32), (4, 32), (4, 128), (32, 64)]
    for axis in [0, 1]
]


@pytest.mark.parametrize('op, dtype, shape, axis', reduce2d_configs)
def test_reduce2d(op, dtype, shape, axis):
    dtype = dtype_mapping[dtype]
    x = torch.randn(shape, device='cuda', dtype=dtype)
    reduced_shape = (shape[1 - axis],)
    z = torch.empty(reduced_shape, device=x.device, dtype=dtype)

    kernel = patch_kernel(reduce2d_kernel, {'OP': op})
    grid = (1,)
    kernel[grid](x_ptr=x, z_ptr=z, axis=axis, block_m=shape[0], block_n=shape[1])

    if op == 'sum':
        golden_z = torch.sum(x, dim=axis, keepdim=False, dtype=dtype)
    elif op == 'min':
        golden_z = torch.min(x, dim=axis, keepdim=False)[0]
    else:
        golden_z = torch.max(x, dim=axis, keepdim=False)[0]

    if op == 'sum':
        if shape[axis] >= 256:
            assert_close(z, golden_z, rtol=0.05, atol=0.1)
        elif shape[axis] >= 32:
            assert_close(z, golden_z, rtol=0.05, atol=0.02)
        else:
            assert_close(z, golden_z, rtol=0.01, atol=0.01)
    else:
        assert_close(z, golden_z, rtol=0.001, atol=0.001)