From 42db3538e4257cac70c6e9c209214bef0a43ca98 Mon Sep 17 00:00:00 2001 From: Qingyi Liu Date: Fri, 28 Oct 2022 11:07:45 +0800 Subject: [PATCH] [Triton-MLIR][Backend] Add ReduceOpConversion into TritonGPUToLLVM conversion (#774) What is done in this PR: - [x] Add `ConvertLayout`, `getSizePerThread` and `getShapePerCTA` implementation for `SliceEncodingAttr` - [x] Split `emitIndices` into two phases: `emitBaseIndexForBlockedLayout` and `emitOffsetForBlockedLayout` - [x] Add `ReduceOpConversion::matchAndRewriteBasic` implementation - [x] Add `ReduceOpConversion::matchAndRewriteFast` implementation with ptx instruction `shfl.sync` - [x] Add support for scalar value in `StoreOpConversion` - [x] Add Reduce1d and Reduce2d unit tests and pass all unit tests Co-authored-by: Qingyi Liu --- include/triton/Analysis/Allocation.h | 3 + .../Conversion/TritonGPUToLLVM/PtxAsmFormat.h | 6 + .../Dialect/TritonGPU/IR/TritonGPUAttrDefs.td | 4 +- lib/Analysis/Allocation.cpp | 43 +- .../TritonGPUToLLVM/TritonGPUToLLVM.cpp | 511 ++++++++++++++++-- lib/Dialect/TritonGPU/IR/Dialect.cpp | 55 +- python/tests/test_reduce.py | 115 ++++ 7 files changed, 680 insertions(+), 57 deletions(-) create mode 100644 python/tests/test_reduce.py diff --git a/include/triton/Analysis/Allocation.h b/include/triton/Analysis/Allocation.h index cb4e77228..f4d6d102f 100644 --- a/include/triton/Analysis/Allocation.h +++ b/include/triton/Analysis/Allocation.h @@ -6,6 +6,7 @@ #include "llvm/ADT/SetVector.h" #include "llvm/Support/raw_ostream.h" +#include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include #include @@ -19,6 +20,8 @@ SmallVector getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec, unsigned &outVec); +SmallVector getScratchConfigForReduce(triton::ReduceOp op); + } // namespace triton /// Modified from llvm-15.0: llvm/ADT/AddressRanges.h diff --git a/include/triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h b/include/triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h index ed051f522..82c20a639 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h +++ b/include/triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h @@ -250,6 +250,12 @@ struct PTXIOInstr : public PTXInstrBase { return *this; } + // Add ".shared" suffix to instruction + PTXIOInstr &shared(bool predicate = true) { + o("shared", predicate); + return *this; + } + // Add ".v" suffix to instruction PTXIOInstr &v(int vecWidth, bool predicate = true) { if (vecWidth > 1) { diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index dd301b904..0b2ec56c5 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -324,7 +324,9 @@ def SliceEncodingAttr : DistributedEncoding<"SliceEncoding"> { "Attribute":$parent ); - let extraClassDeclaration = extraBaseClassDeclaration; + let extraClassDeclaration = extraBaseClassDeclaration # [{ + SmallVector paddedShape(ArrayRef shape) const; + }]; } diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index e4efee3ce..f204a6ade 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -14,6 +14,7 @@ using ::mlir::triton::gpu::BlockedEncodingAttr; using ::mlir::triton::gpu::getShapePerCTA; using ::mlir::triton::gpu::MmaEncodingAttr; using ::mlir::triton::gpu::SharedEncodingAttr; +using ::mlir::triton::gpu::SliceEncodingAttr; namespace mlir { @@ -33,6 +34,10 @@ 
getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec, "Unexpect layout in getScratchConfigForCvtLayout()"); unsigned rank = dstTy.getRank(); SmallVector paddedRepShape(rank); + if (auto srcSliceLayout = srcLayout.dyn_cast()) + srcLayout = srcSliceLayout.getParent(); + if (auto dstSliceLayout = dstLayout.dyn_cast()) + dstLayout = dstSliceLayout.getParent(); auto srcBlockedLayout = srcLayout.dyn_cast(); auto srcMmaLayout = srcLayout.dyn_cast(); auto dstBlockedLayout = dstLayout.dyn_cast(); @@ -73,6 +78,31 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec, return paddedRepShape; } +SmallVector getScratchConfigForReduce(triton::ReduceOp op) { + auto srcTy = op.operand().getType().cast(); + auto srcLayout = srcTy.getEncoding().cast(); + auto srcShape = srcTy.getShape(); + auto rank = srcShape.size(); + auto axis = op.axis(); + + bool fast_reduce = axis == 1; // FIXME(Qingyi): The fastest-changing dimension + + SmallVector smemShape; + for (auto d : srcShape) + smemShape.push_back(d); + + if (fast_reduce) { + unsigned sizeInterWarps = srcLayout.getWarpsPerCTA()[axis]; + smemShape[axis] = sizeInterWarps; + } else { + unsigned threadsPerCTAAxis = + srcLayout.getThreadsPerWarp()[axis] * srcLayout.getWarpsPerCTA()[axis]; + smemShape[axis] = threadsPerCTAAxis; + } + + return smemShape; +} + class AllocationAnalysis { public: AllocationAnalysis(Operation *operation, Allocation *allocation) @@ -127,9 +157,16 @@ private: // TODO(Keren): Reduce with index is not supported yet. auto value = op->getOperand(0); if (auto tensorType = value.getType().dyn_cast()) { - auto bytes = tensorType.getNumElements() * - tensorType.getElementTypeBitWidth() / 8; - allocation->addBuffer(op, bytes); + if (tensorType.getEncoding().isa()) { + auto smemShape = getScratchConfigForReduce(reduceOp); + unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), + 1, std::multiplies{}); + auto bytes = elems * tensorType.getElementTypeBitWidth() / 8; + allocation->addBuffer(op, bytes); + } else { + assert(0 && "ReduceOp with input layout other than blocked layout is " + "not implemented yet"); + } } } else if (auto cvtLayout = dyn_cast(op)) { auto srcTy = cvtLayout.src().getType().cast(); diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index ba98f0bba..fda2f4a61 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -76,7 +76,15 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width, #define udiv(...) rewriter.create(loc, __VA_ARGS__) #define urem(...) rewriter.create(loc, __VA_ARGS__) #define add(...) rewriter.create(loc, __VA_ARGS__) +#define fadd(...) rewriter.create(loc, __VA_ARGS__) #define mul(...) rewriter.create(loc, __VA_ARGS__) +#define smax(...) rewriter.create(loc, __VA_ARGS__) +#define umax(...) rewriter.create(loc, __VA_ARGS__) +#define fmax(...) rewriter.create(loc, __VA_ARGS__) +#define smin(...) rewriter.create(loc, __VA_ARGS__) +#define umin(...) rewriter.create(loc, __VA_ARGS__) +#define fmin(...) rewriter.create(loc, __VA_ARGS__) +#define and_(...) rewriter.create(loc, __VA_ARGS__) #define xor_(...) rewriter.create(loc, __VA_ARGS__) #define bitcast(...) rewriter.create(loc, __VA_ARGS__) #define gep(...) rewriter.create(loc, __VA_ARGS__) @@ -89,11 +97,16 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width, rewriter.create(loc, __VA_ARGS__) #define load(...) 
rewriter.create(loc, __VA_ARGS__) #define store(val, ptr) rewriter.create(loc, val, ptr) +#define icmp_eq(...) \ + rewriter.create(loc, LLVM::ICmpPredicate::eq, __VA_ARGS__) +#define icmp_slt(...) \ + rewriter.create(loc, LLVM::ICmpPredicate::slt, __VA_ARGS__) #define select(...) rewriter.create(loc, __VA_ARGS__) #define address_of(...) rewriter.create(loc, __VA_ARGS__) #define barrier() rewriter.create(loc) #define undef(...) rewriter.create(loc, __VA_ARGS__) #define i32_ty rewriter.getIntegerType(32) +#define f32_ty rewriter.getF32Type() #define vec_ty(type, num) VectorType::get(num, type) #define void_ty LLVM::LLVMVoidType::get(ctx) #define struct_ty(...) LLVM::LLVMStructType::getLiteral(__VA_ARGS__) @@ -336,6 +349,20 @@ static T getLinearIndex(ArrayRef multiDimIndex, ArrayRef shape) { return linearIndex; } +Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr, + Value val, Value pred) { + MLIRContext *ctx = rewriter.getContext(); + unsigned bits = val.getType().getIntOrFloatBitWidth(); + const char *c = bits == 64 ? "l" : (bits == 16 ? "h" : "r"); + + PTXBuilder builder; + auto &st = builder.create("st")->shared().b(bits); + auto *ptrOpr = builder.newAddrOperand(ptr, "r"); + auto *valOpr = builder.newOperand(val, c); + st(ptrOpr, valOpr).predicate(pred, "b"); + return builder.launch(rewriter, loc, void_ty); +} + struct ConvertTritonGPUOpToLLVMPatternBase { static SmallVector getElementsFromStruct(Location loc, Value llvmStruct, @@ -504,17 +531,8 @@ public: unsigned dim = sliceLayout.getDim(); size_t rank = shape.size(); if (auto blockedParent = parent.dyn_cast()) { - SmallVector paddedShape(rank + 1); - for (unsigned d = 0; d < rank + 1; ++d) { - if (d < dim) - paddedShape[d] = shape[d]; - else if (d == dim) - paddedShape[d] = 1; - else - paddedShape[d] = shape[d - 1]; - } auto paddedIndices = emitIndicesForBlockedLayout( - loc, rewriter, blockedParent, paddedShape); + loc, rewriter, blockedParent, sliceLayout.paddedShape(shape)); unsigned numIndices = paddedIndices.size(); SmallVector> resultIndices(numIndices); for (unsigned i = 0; i < numIndices; ++i) @@ -536,31 +554,19 @@ public: } } - // Emit indices calculation within each ConversionPattern, and returns a - // [elemsPerThread X rank] index matrix. - // TODO: [goostavz] Double confirm the redundant indices calculations will - // be eliminated in the consequent MLIR/LLVM optimization. We might - // implement a indiceCache if necessary. 
- SmallVector> - emitIndicesForBlockedLayout(Location loc, ConversionPatternRewriter &rewriter, - const BlockedEncodingAttr &blockedLayout, - ArrayRef shape) const { - auto llvmIndexTy = this->getTypeConverter()->getIndexType(); + SmallVector> + emitOffsetForBlockedLayout(const BlockedEncodingAttr &blockedLayout, + ArrayRef shape) const { auto sizePerThread = blockedLayout.getSizePerThread(); auto threadsPerWarp = blockedLayout.getThreadsPerWarp(); auto warpsPerCTA = blockedLayout.getWarpsPerCTA(); + unsigned rank = shape.size(); SmallVector shapePerCTA = getShapePerCTA(blockedLayout); SmallVector tilesPerDim(rank); for (unsigned k = 0; k < rank; ++k) tilesPerDim[k] = ceil(shape[k], shapePerCTA[k]); - // step 1, delinearize threadId to get the base index - auto multiDimBase = - emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, shape); - - // step 2, get offset of each element - unsigned elemsPerThread = blockedLayout.getElemsPerThread(shape); SmallVector> offset(rank); for (unsigned k = 0; k < rank; ++k) { // 1 block in minimum if shape[k] is less than shapePerCTA[k] @@ -577,12 +583,10 @@ public: threadsPerWarp[k] + threadOffset * sizePerThread[k] + elemOffset); } - // step 3, add offset to base, and reorder the sequence of indices to - // guarantee that elems in the same sizePerThread are adjacent in order - SmallVector> multiDimIdx(elemsPerThread, - SmallVector(rank)); - unsigned totalSizePerThread = product(sizePerThread); + unsigned elemsPerThread = blockedLayout.getElemsPerThread(shape); + unsigned totalSizePerThread = product(sizePerThread); + SmallVector> reorderedOffset(elemsPerThread); for (unsigned n = 0; n < elemsPerThread; ++n) { unsigned linearNanoTileId = n / totalSizePerThread; unsigned linearNanoTileElemId = n % totalSizePerThread; @@ -595,10 +599,38 @@ public: multiDimNanoTileId[k] * (sizePerThread[k] * threadsPerWarp[k] * warpsPerCTA[k]) + multiDimNanoTileElemId[k]; - multiDimIdx[n][k] = - add(multiDimBase[k], idx_val(offset[k][reorderedMultiDimId])); + reorderedOffset[n].push_back(offset[k][reorderedMultiDimId]); } } + return reorderedOffset; + } + + // Emit indices calculation within each ConversionPattern, and returns a + // [elemsPerThread X rank] index matrix. + // TODO: [goostavz] Double confirm the redundant indices calculations will + // be eliminated in the consequent MLIR/LLVM optimization. We might + // implement a indiceCache if necessary. 
+ SmallVector> + emitIndicesForBlockedLayout(Location loc, ConversionPatternRewriter &rewriter, + const BlockedEncodingAttr &blockedLayout, + ArrayRef shape) const { + // step 1, delinearize threadId to get the base index + auto multiDimBase = + emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, shape); + + // step 2, get offset of each element + SmallVector> offset = + emitOffsetForBlockedLayout(blockedLayout, shape); + + // step 3, add offset to base, and reorder the sequence of indices to + // guarantee that elems in the same sizePerThread are adjacent in order + unsigned rank = shape.size(); + unsigned elemsPerThread = offset.size(); + SmallVector> multiDimIdx(elemsPerThread, + SmallVector(rank)); + for (unsigned n = 0; n < elemsPerThread; ++n) + for (unsigned k = 0; k < rank; ++k) + multiDimIdx[n][k] = add(multiDimBase[k], idx_val(offset[n][k])); return multiDimIdx; } @@ -1027,8 +1059,12 @@ struct StoreOpConversion MLIRContext *ctx = rewriter.getContext(); auto valueTy = value.getType().dyn_cast(); - if (!valueTy) - return failure(); + if (!valueTy) { + store(llValue, llPtr); + rewriter.eraseOp(op); + return success(); + } + Type valueElemTy = getTypeConverter()->convertType(valueTy.getElementType()); @@ -1225,6 +1261,360 @@ struct BroadcastOpConversion } }; +/// ====================== reduce codegen begin ========================== + +struct ReduceOpConversion + : public ConvertTritonGPUOpToLLVMPattern { +public: + using ConvertTritonGPUOpToLLVMPattern< + triton::ReduceOp>::ConvertTritonGPUOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; + +private: + void accumulate(ConversionPatternRewriter &rewriter, Location loc, + RedOp redOp, Value &acc, Value cur, bool isFirst) const; + + Value shflSync(ConversionPatternRewriter &rewriter, Location loc, Value val, + int i) const; + + // Use shared memory for reduction within warps and across warps + LogicalResult matchAndRewriteBasic(triton::ReduceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const; + + // Use warp shuffle for reduction within warps and shared memory for data + // exchange across warps + LogicalResult matchAndRewriteFast(triton::ReduceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const; +}; + +LogicalResult +ReduceOpConversion::matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto srcTy = op.operand().getType().cast(); + auto rank = srcTy.getShape().size(); + if (op.axis() == 1) // FIXME(Qingyi): The fastest-changing dimension + return matchAndRewriteFast(op, adaptor, rewriter); + return matchAndRewriteBasic(op, adaptor, rewriter); +} + +void ReduceOpConversion::accumulate(ConversionPatternRewriter &rewriter, + Location loc, RedOp redOp, Value &acc, + Value cur, bool isFirst) const { + if (isFirst) { + acc = cur; + return; + } + auto type = cur.getType(); + switch (redOp) { + case RedOp::ADD: + acc = add(acc, cur); + break; + case RedOp::MAX: + if (type.isUnsignedInteger()) + acc = umax(acc, cur); + else + acc = smax(acc, cur); + break; + case RedOp::MIN: + if (type.isUnsignedInteger()) + acc = umin(acc, cur); + else + acc = smin(acc, cur); + break; + case RedOp::FADD: + acc = fadd(acc.getType(), acc, cur); + break; + case RedOp::FMAX: + acc = fmax(acc, cur); + break; + case RedOp::FMIN: + acc = fmin(acc, cur); + break; + case RedOp::XOR: + acc = xor_(acc, cur); + break; + default: + llvm::report_fatal_error("Unsupported reduce 
op"); + } +}; + +Value ReduceOpConversion::shflSync(ConversionPatternRewriter &rewriter, + Location loc, Value val, int i) const { + MLIRContext *ctx = rewriter.getContext(); + unsigned bits = val.getType().getIntOrFloatBitWidth(); + + if (bits == 64) { + Type vecTy = vec_ty(f32_ty, 2); + Value vec = bitcast(vecTy, val); + Value val0 = extract_element(f32_ty, vec, i32_val(0)); + Value val1 = extract_element(f32_ty, vec, i32_val(1)); + val0 = shflSync(rewriter, loc, val0, i); + val1 = shflSync(rewriter, loc, val1, i); + vec = undef(vecTy); + vec = insert_element(vecTy, vec, val0, i32_val(0)); + vec = insert_element(vecTy, vec, val1, i32_val(1)); + return bitcast(val.getType(), vec); + } + + PTXBuilder builder; + auto &shfl = builder.create("shfl.sync")->o("bfly").o("b32"); + auto *dOpr = builder.newOperand("=r"); + auto *aOpr = builder.newOperand(val, "r"); + auto *bOpr = builder.newConstantOperand(i); + auto *cOpr = builder.newConstantOperand("0x1f"); + auto *maskOpr = builder.newConstantOperand("0xffffffff"); + shfl(dOpr, aOpr, bOpr, cOpr, maskOpr); + return builder.launch(rewriter, loc, val.getType(), false); +} + +LogicalResult ReduceOpConversion::matchAndRewriteBasic( + triton::ReduceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Location loc = op->getLoc(); + unsigned axis = op.axis(); + + auto srcTy = op.operand().getType().cast(); + auto srcLayout = srcTy.getEncoding().cast(); + auto srcShape = srcTy.getShape(); + + auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType()); + auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3); + Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation()); + smemBase = bitcast(elemPtrTy, smemBase); + + auto smemShape = getScratchConfigForReduce(op); + + unsigned srcElems = getElemsPerThread(srcLayout, srcShape); + auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcShape); + auto srcValues = getElementsFromStruct(loc, adaptor.operand(), rewriter); + + SmallVector> offset = + emitOffsetForBlockedLayout(srcLayout, srcShape); + + std::map, Value> accs; + std::map, SmallVector> indices; + + // reduce within threads + for (unsigned i = 0; i < srcElems; ++i) { + SmallVector key = offset[i]; + key[axis] = 0; + bool isFirst = accs.find(key) == accs.end(); + accumulate(rewriter, loc, op.redOp(), accs[key], srcValues[i], isFirst); + if (isFirst) + indices[key] = srcIndices[i]; + } + + // cached int32 constants + std::map ints; + ints[0] = i32_val(0); + for (int N = smemShape[axis] / 2; N > 0; N >>= 1) + ints[N] = i32_val(N); + Value sizePerThread = i32_val(srcLayout.getSizePerThread()[axis]); + + // reduce across threads + for (auto it : accs) { + const SmallVector &key = it.first; + Value acc = it.second; + SmallVector writeIdx = indices[key]; + + writeIdx[axis] = udiv(writeIdx[axis], sizePerThread); + Value writeOffset = linearize(rewriter, loc, writeIdx, smemShape); + Value writePtr = gep(elemPtrTy, smemBase, writeOffset); + store(acc, writePtr); + + SmallVector readIdx(writeIdx.size(), ints[0]); + for (int N = smemShape[axis] / 2; N > 0; N >>= 1) { + readIdx[axis] = ints[N]; + Value readMask = icmp_slt(writeIdx[axis], ints[N]); + Value readOffset = select( + readMask, linearize(rewriter, loc, readIdx, smemShape), ints[0]); + Value readPtr = gep(elemPtrTy, writePtr, readOffset); + barrier(); + accumulate(rewriter, loc, op.redOp(), acc, load(readPtr), false); + store(acc, writePtr); + } + } + + // set output values + if (auto resultTy = op.getType().dyn_cast()) { + // nd-tensor where n >= 
1 + auto resultLayout = resultTy.getEncoding(); + auto resultShape = resultTy.getShape(); + + unsigned resultElems = getElemsPerThread(resultLayout, resultShape); + auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultShape); + assert(resultIndices.size() == resultElems); + + barrier(); + SmallVector resultVals(resultElems); + for (int i = 0; i < resultElems; i++) { + SmallVector readIdx = resultIndices[i]; + readIdx.insert(readIdx.begin() + axis, ints[0]); + Value readOffset = linearize(rewriter, loc, readIdx, smemShape); + Value readPtr = gep(elemPtrTy, smemBase, readOffset); + resultVals[i] = load(readPtr); + } + + SmallVector resultTypes(resultElems, llvmElemTy); + Type structTy = + LLVM::LLVMStructType::getLiteral(this->getContext(), resultTypes); + Value ret = getStructFromElements(loc, resultVals, rewriter, structTy); + rewriter.replaceOp(op, ret); + } else { + // 0d-tensor -> scalar + barrier(); + Value resultVal = load(smemBase); + rewriter.replaceOp(op, resultVal); + } + + return success(); +} + +LogicalResult ReduceOpConversion::matchAndRewriteFast( + triton::ReduceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Location loc = op->getLoc(); + unsigned axis = adaptor.axis(); + + auto srcTy = op.operand().getType().cast(); + auto srcLayout = srcTy.getEncoding().cast(); + auto srcShape = srcTy.getShape(); + auto srcOrder = srcLayout.getOrder(); + + auto threadsPerWarp = srcLayout.getThreadsPerWarp(); + auto warpsPerCTA = srcLayout.getWarpsPerCTA(); + + auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType()); + auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3); + Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation()); + smemBase = bitcast(elemPtrTy, smemBase); + + auto order = srcLayout.getOrder(); + unsigned sizeIntraWarps = threadsPerWarp[axis]; + unsigned sizeInterWarps = warpsPerCTA[axis]; + + unsigned srcElems = getElemsPerThread(srcLayout, srcShape); + auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcShape); + auto srcValues = getElementsFromStruct(loc, adaptor.operand(), rewriter); + + SmallVector> offset = + emitOffsetForBlockedLayout(srcLayout, srcShape); + + std::map, Value> accs; + std::map, SmallVector> indices; + + auto smemShape = getScratchConfigForReduce(op); + + // reduce within threads + for (unsigned i = 0; i < srcElems; ++i) { + SmallVector key = offset[i]; + key[axis] = 0; + bool isFirst = accs.find(key) == accs.end(); + accumulate(rewriter, loc, op.redOp(), accs[key], srcValues[i], isFirst); + if (isFirst) + indices[key] = srcIndices[i]; + } + + Value threadId = getThreadId(rewriter, loc); + Value warpSize = i32_val(32); + Value warpId = udiv(threadId, warpSize); + Value laneId = urem(threadId, warpSize); + + SmallVector multiDimLaneId = + delinearize(rewriter, loc, laneId, threadsPerWarp, order); + SmallVector multiDimWarpId = + delinearize(rewriter, loc, warpId, warpsPerCTA, order); + Value laneIdAxis = multiDimLaneId[axis]; + Value warpIdAxis = multiDimWarpId[axis]; + + Value zero = i32_val(0); + Value laneZero = icmp_eq(laneIdAxis, zero); + Value warpZero = icmp_eq(warpIdAxis, zero); + + for (auto it : accs) { + const SmallVector &key = it.first; + Value acc = it.second; + + // reduce within warps + for (unsigned N = sizeIntraWarps / 2; N > 0; N >>= 1) { + Value shfl = shflSync(rewriter, loc, acc, N); + accumulate(rewriter, loc, op.redOp(), acc, shfl, false); + } + + if (sizeInterWarps == 1) { + SmallVector writeIdx = indices[key]; + writeIdx[axis] = zero; + Value 
writeOffset = linearize(rewriter, loc, writeIdx, smemShape); + Value writePtr = gep(elemPtrTy, smemBase, writeOffset); + storeShared(rewriter, loc, writePtr, acc, laneZero); + } else { + SmallVector writeIdx = indices[key]; + writeIdx[axis] = + warpIdAxis; // axis must be the fastest-changing dimension + Value writeOffset = linearize(rewriter, loc, writeIdx, smemShape); + Value writePtr = gep(elemPtrTy, smemBase, writeOffset); + storeShared(rewriter, loc, writePtr, acc, laneZero); + barrier(); + + SmallVector readIdx = writeIdx; + readIdx[axis] = urem(laneId, i32_val(sizeInterWarps)); + Value readOffset = linearize(rewriter, loc, readIdx, smemShape); + Value readPtr = gep(elemPtrTy, smemBase, readOffset); + acc = load(readPtr); + + // reduce across warps + for (unsigned N = sizeInterWarps / 2; N > 0; N >>= 1) { + Value shfl = shflSync(rewriter, loc, acc, N); + accumulate(rewriter, loc, op.redOp(), acc, shfl, false); + } + + writeIdx[axis] = zero; + writeOffset = linearize(rewriter, loc, writeIdx, smemShape); + writePtr = gep(elemPtrTy, smemBase, writeOffset); + storeShared(rewriter, loc, writePtr, acc, and_(laneZero, warpZero)); + } + } + + // set output values + if (auto resultTy = op.getType().dyn_cast()) { + // nd-tensor where n >= 1 + auto resultLayout = resultTy.getEncoding().cast(); + auto resultShape = resultTy.getShape(); + + unsigned resultElems = getElemsPerThread(resultLayout, resultShape); + auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultShape); + assert(resultIndices.size() == resultElems); + + barrier(); + SmallVector resultVals(resultElems); + for (int i = 0; i < resultElems; i++) { + SmallVector readIdx = resultIndices[i]; + readIdx.insert(readIdx.begin() + axis, i32_val(0)); + Value readOffset = linearize(rewriter, loc, readIdx, smemShape); + Value readPtr = gep(elemPtrTy, smemBase, readOffset); + resultVals[i] = load(readPtr); + } + + SmallVector resultTypes(resultElems, llvmElemTy); + Type structTy = + LLVM::LLVMStructType::getLiteral(this->getContext(), resultTypes); + Value ret = getStructFromElements(loc, resultVals, rewriter, structTy); + rewriter.replaceOp(op, ret); + } else { + // 0d-tensor -> scalar + barrier(); + Value resultVal = load(smemBase); + rewriter.replaceOp(op, resultVal); + } + + return success(); +} + +/// ====================== reduce codegen end ========================== + template struct ViewLikeOpConversion : public ConvertTritonGPUOpToLLVMPattern { using OpAdaptor = typename SourceOp::Adaptor; @@ -1738,15 +2128,16 @@ public: dstLayout.isa()) { return lowerSharedToDotOperand(op, adaptor, rewriter); } - if ((!srcLayout.isa() && - !srcLayout.isa()) || - (!dstLayout.isa() && - !dstLayout.isa())) { - // TODO: to be implemented - return failure(); + if ((srcLayout.isa() || + srcLayout.isa() || + srcLayout.isa()) && + (dstLayout.isa() || + dstLayout.isa() || + dstLayout.isa())) { + return lowerDistributedToDistributed(op, adaptor, rewriter); } - - return lowerDistributedToDistributed(op, adaptor, rewriter); + // TODO: to be implemented + return failure(); } private: @@ -1799,6 +2190,7 @@ void ConvertLayoutOpConversion::processReplica( unsigned accumNumCTAsEachRep = product(numCTAsEachRep); auto layout = type.getEncoding(); auto blockedLayout = layout.dyn_cast(); + auto sliceLayout = layout.dyn_cast(); auto mmaLayout = layout.dyn_cast(); auto rank = type.getRank(); auto sizePerThread = getSizePerThread(layout); @@ -1816,6 +2208,18 @@ void ConvertLayoutOpConversion::processReplica( if (blockedLayout) { multiDimOffsetFirstElem = 
emitBaseIndexForBlockedLayout( loc, rewriter, blockedLayout, type.getShape()); + } else if (sliceLayout) { + unsigned dim = sliceLayout.getDim(); + auto parent = sliceLayout.getParent(); + if (auto blockedParent = parent.dyn_cast()) { + SmallVector paddedShape = + sliceLayout.paddedShape(type.getShape()); + multiDimOffsetFirstElem = emitBaseIndexForBlockedLayout( + loc, rewriter, blockedParent, paddedShape); + } else { + assert(0 && "SliceEncodingAttr with parent other than " + "BlockedEncodingAttr not implemented"); + } } else if (mmaLayout) { Value threadId = getThreadId(rewriter, loc); Value warpSize = idx_val(32); @@ -1863,6 +2267,25 @@ void ConvertLayoutOpConversion::processReplica( idx_val(multiDimCTAInRepId[d] * shapePerCTA[d] + multiDimElemId[d])); } + } else if (sliceLayout) { + unsigned dim = sliceLayout.getDim(); + auto parent = sliceLayout.getParent(); + if (auto blockedParent = parent.dyn_cast()) { + SmallVector multiDimElemId = getMultiDimIndex( + elemId, blockedParent.getSizePerThread()); + for (unsigned d = 0; d < rank + 1; ++d) { + if (d == dim) + continue; + unsigned slicedD = d < dim ? d : (d - 1); + multiDimOffset[slicedD] = + add(multiDimOffsetFirstElem[d], + idx_val(multiDimCTAInRepId[slicedD] * shapePerCTA[slicedD] + + multiDimElemId[d])); + } + } else { + assert(0 && "SliceEncodingAttr with parent other than " + "BlockedEncodingAttr not implemented"); + } } else if (mmaLayout) { assert(rank == 2); assert(mmaLayout.getVersion() == 2 && @@ -1952,6 +2375,7 @@ LogicalResult ConvertLayoutOpConversion::lowerDistributedToDistributed( auto multiDimRepId = getMultiDimIndex(repId, numReplicates); barrier(); if (srcLayout.isa() || + srcLayout.isa() || srcLayout.isa()) { processReplica(loc, rewriter, /*stNotRd*/ true, srcTy, inNumCTAsEachRep, multiDimRepId, inVec, paddedRepShape, outOrd, vals, @@ -3710,6 +4134,7 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter, #undef POPULATE_UNARY_OP patterns.add(typeConverter, benefit); + patterns.add(typeConverter, allocation, smem, benefit); patterns.add(typeConverter, allocation, smem, benefit); diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 0d9bc568d..2049a18b1 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -63,6 +63,19 @@ SmallVector getSizePerThread(Attribute layout) { if (auto blockedLayout = layout.dyn_cast()) { return SmallVector(blockedLayout.getSizePerThread().begin(), blockedLayout.getSizePerThread().end()); + } else if (auto sliceLayout = layout.dyn_cast()) { + unsigned dim = sliceLayout.getDim(); + auto parent = sliceLayout.getParent(); + if (auto blockedParent = parent.dyn_cast()) { + SmallVector sizePerThread( + blockedParent.getSizePerThread().begin(), + blockedParent.getSizePerThread().end()); + sizePerThread.erase(sizePerThread.begin() + dim); + return sizePerThread; + } else { + assert(0 && "SliceEncodingAttr with parent other than " + "BlockedEncodingAttr not implemented"); + } } else if (auto mmaLayout = layout.dyn_cast()) { assert(mmaLayout.getVersion() == 2 && "mmaLayout version = 1 is not implemented yet"); @@ -95,6 +108,21 @@ SmallVector getShapePerCTA(const Attribute &layout) { shape.push_back(blockedLayout.getSizePerThread()[d] * blockedLayout.getThreadsPerWarp()[d] * blockedLayout.getWarpsPerCTA()[d]); + } else if (auto sliceLayout = layout.dyn_cast()) { + unsigned dim = sliceLayout.getDim(); + auto parent = sliceLayout.getParent(); + if (auto blockedParent = parent.dyn_cast()) { + for (int 
d = 0, n = blockedParent.getOrder().size(); d < n; ++d) { + if (d == dim) + continue; + shape.push_back(blockedParent.getSizePerThread()[d] * + blockedParent.getThreadsPerWarp()[d] * + blockedParent.getWarpsPerCTA()[d]); + } + } else { + assert(0 && "SliceEncodingAttr with parent other than " + "BlockedEncodingAttr not implemented"); + } } else if (auto mmaLayout = layout.dyn_cast()) { assert(mmaLayout.getVersion() == 2 && "mmaLayout version = 1 is not implemented yet"); @@ -206,6 +234,22 @@ unsigned BlockedEncodingAttr::getElemsPerThread(ArrayRef shape) const { return product(elemsPerThread); } +SmallVector +SliceEncodingAttr::paddedShape(ArrayRef shape) const { + size_t rank = shape.size(); + unsigned dim = getDim(); + SmallVector retShape(rank + 1); + for (unsigned d = 0; d < rank + 1; ++d) { + if (d < dim) + retShape[d] = shape[d]; + else if (d == dim) + retShape[d] = 1; + else + retShape[d] = shape[d - 1]; + } + return retShape; +} + unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef shape) const { size_t rank = shape.size(); auto parent = getParent(); @@ -213,16 +257,7 @@ unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef shape) const { if (auto blockedParent = parent.dyn_cast()) { assert(rank == blockedParent.getSizePerThread().size() - 1 && "unexpected rank in SliceEncodingAttr::getElemsPerThread"); - SmallVector paddedShape(rank + 1); - for (unsigned d = 0; d < rank + 1; ++d) { - if (d < dim) - paddedShape[d] = shape[d]; - else if (d == dim) - paddedShape[d] = 1; - else - paddedShape[d] = shape[d - 1]; - } - return blockedParent.getElemsPerThread(paddedShape); + return blockedParent.getElemsPerThread(paddedShape(shape)); } else { assert(0 && "getElemsPerThread not implemented"); return 0; diff --git a/python/tests/test_reduce.py b/python/tests/test_reduce.py new file mode 100644 index 000000000..0c92b8a91 --- /dev/null +++ b/python/tests/test_reduce.py @@ -0,0 +1,115 @@ +import pytest +import torch +from torch.testing import assert_close + +import triton +import triton.language as tl + +dtype_mapping = { + 'float16': torch.float16, + 'float32': torch.float32, + 'float64': torch.float64, +} + + +def patch_kernel(template, to_replace): + kernel = triton.JITFunction(template.fn) + for key, value in to_replace.items(): + kernel.src = kernel.src.replace(key, value) + return kernel + + +@triton.jit +def reduce1d_kernel(x_ptr, z_ptr, block: tl.constexpr): + x = tl.load(x_ptr + tl.arange(0, block)) + tl.store(z_ptr, tl.OP(x, axis=0)) + + +@triton.jit +def reduce2d_kernel(x_ptr, z_ptr, axis: tl.constexpr, block_m: tl.constexpr, block_n: tl.constexpr): + range_m = tl.arange(0, block_m) + range_n = tl.arange(0, block_n) + x = tl.load(x_ptr + range_m[:, None] * block_n + range_n[None, :]) + z = tl.OP(x, axis=axis) + if axis == 0: + tl.store(z_ptr + range_n, z) + else: + tl.store(z_ptr + range_m, z) + + +reduce1d_configs = [ + (op, dtype, shape) + for op in ['sum', 'min', 'max'] + for dtype in ['float16', 'float32', 'float64'] + for shape in [4, 8, 16, 32, 64, 128, 512, 1024] +] + + +@pytest.mark.parametrize('op, dtype, shape', reduce1d_configs) +def test_reduce1d(op, dtype, shape): + dtype = dtype_mapping[dtype] + x = torch.randn((shape,), device='cuda', dtype=dtype) + z = torch.empty( + tuple(), + device=x.device, + dtype=dtype, + ) + + kernel = patch_kernel(reduce1d_kernel, {'OP': op}) + grid = (1,) + kernel[grid](x_ptr=x, z_ptr=z, block=shape) + + if op == 'sum': + golden_z = torch.sum(x, dtype=dtype) + elif op == 'min': + golden_z = torch.min(x) + else: + golden_z = torch.max(x) + 
+ if op == 'sum': + if shape >= 256: + assert_close(z, golden_z, rtol=0.05, atol=0.1) + elif shape >= 32: + assert_close(z, golden_z, rtol=0.05, atol=0.02) + else: + assert_close(z, golden_z, rtol=0.01, atol=0.01) + else: + assert_close(z, golden_z, rtol=0.001, atol=0.001) + + +reduce2d_configs = [ + (op, dtype, shape, axis) + for op in ['sum', 'min', 'max'] + for dtype in ['float16', 'float32', 'float64'] + for shape in [(1, 4), (1, 8), (1, 16), (1, 32), (2, 32), (4, 32), (4, 128), (32, 64)] + for axis in [0, 1] +] + + +@pytest.mark.parametrize('op, dtype, shape, axis', reduce2d_configs) +def test_reduce2d(op, dtype, shape, axis): + dtype = dtype_mapping[dtype] + x = torch.randn(shape, device='cuda', dtype=dtype) + reduced_shape = (shape[1 - axis],) + z = torch.empty(reduced_shape, device=x.device, dtype=dtype) + + kernel = patch_kernel(reduce2d_kernel, {'OP': op}) + grid = (1,) + kernel[grid](x_ptr=x, z_ptr=z, axis=axis, block_m=shape[0], block_n=shape[1]) + + if op == 'sum': + golden_z = torch.sum(x, dim=axis, keepdim=False, dtype=dtype) + elif op == 'min': + golden_z = torch.min(x, dim=axis, keepdim=False)[0] + else: + golden_z = torch.max(x, dim=axis, keepdim=False)[0] + + if op == 'sum': + if shape[axis] >= 256: + assert_close(z, golden_z, rtol=0.05, atol=0.1) + elif shape[axis] >= 32: + assert_close(z, golden_z, rtol=0.05, atol=0.02) + else: + assert_close(z, golden_z, rtol=0.01, atol=0.01) + else: + assert_close(z, golden_z, rtol=0.001, atol=0.001)
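
A note on exercising the new lowering (not part of the patch itself): reductions along the fastest-changing dimension (`axis == 1`) take the `ReduceOpConversion::matchAndRewriteFast` path, which combines `shfl.sync`-based intra-warp reduction with a shared-memory exchange across warps, while other axes fall back to `matchAndRewriteBasic`. The sketch below is a hypothetical usage example that mirrors the style of `python/tests/test_reduce.py`; the kernel and variable names are illustrative only, and the authoritative coverage remains the pytest suite added by this patch.

    # Hypothetical sketch: a row-wise sum (axis=1) expected to hit the
    # warp-shuffle ("fast") reduce path added in this PR; mirrors test_reduce.py.
    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def row_sum_kernel(x_ptr, z_ptr, block_m: tl.constexpr, block_n: tl.constexpr):
        range_m = tl.arange(0, block_m)
        range_n = tl.arange(0, block_n)
        x = tl.load(x_ptr + range_m[:, None] * block_n + range_n[None, :])
        z = tl.sum(x, axis=1)  # lowered through ReduceOpConversion
        tl.store(z_ptr + range_m, z)

    x = torch.randn((4, 128), device='cuda', dtype=torch.float32)
    z = torch.empty((4,), device='cuda', dtype=torch.float32)
    row_sum_kernel[(1,)](x_ptr=x, z_ptr=z, block_m=4, block_n=128)
    torch.testing.assert_close(z, torch.sum(x, dim=1), rtol=0.05, atol=0.02)

The new tests themselves can be run with `pytest python/tests/test_reduce.py` once a Triton build containing this patch is installed.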