#include "ReduceOpToLLVM.h" using namespace mlir; using namespace mlir::triton; using ::mlir::LLVM::getElementsFromStruct; using ::mlir::LLVM::getStructFromElements; using ::mlir::LLVM::shflSync; using ::mlir::LLVM::storeShared; using ::mlir::triton::gpu::getElemsPerThread; using ::mlir::triton::gpu::getOrder; struct ReduceOpConversion : public ConvertTritonGPUOpToLLVMPattern { public: using ConvertTritonGPUOpToLLVMPattern< triton::ReduceOp>::ConvertTritonGPUOpToLLVMPattern; LogicalResult matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (ReduceOpHelper(op).isFastReduction()) return matchAndRewriteFast(op, adaptor, rewriter); return matchAndRewriteBasic(op, adaptor, rewriter); } private: void accumulate(ConversionPatternRewriter &rewriter, Location loc, RedOp redOp, Value &acc, Value cur, bool isFirst) const { if (isFirst) { acc = cur; return; } switch (redOp) { case RedOp::ADD: acc = add(acc, cur); break; case RedOp::FADD: acc = fadd(acc.getType(), acc, cur); break; case RedOp::MIN: acc = smin(acc, cur); break; case RedOp::MAX: acc = smax(acc, cur); break; case RedOp::UMIN: acc = umin(acc, cur); break; case RedOp::UMAX: acc = umax(acc, cur); break; case RedOp::FMIN: acc = fmin(acc, cur); break; case RedOp::FMAX: acc = fmax(acc, cur); break; case RedOp::XOR: acc = xor_(acc, cur); break; case RedOp::ARGMIN: case RedOp::ARGMAX: case RedOp::ARGUMIN: case RedOp::ARGUMAX: case RedOp::ARGFMIN: case RedOp::ARGFMAX: llvm::report_fatal_error( "This accumulate implementation is not for argmin / argmax"); default: llvm::report_fatal_error("Unsupported reduce op"); } } void accumulateWithIndex(ConversionPatternRewriter &rewriter, Location loc, RedOp redOp, Value &acc, Value &accIndex, Value cur, Value curIndex, bool isFirst) const { if (isFirst) { acc = cur; accIndex = curIndex; return; } switch (redOp) { case RedOp::ARGMIN: accIndex = select( icmp_slt(acc, cur), accIndex, select(icmp_sgt(acc, cur), curIndex, smin(accIndex, curIndex))); acc = smin(acc, cur); break; case RedOp::ARGMAX: accIndex = select( icmp_sgt(acc, cur), accIndex, select(icmp_slt(acc, cur), curIndex, smin(accIndex, curIndex))); acc = smax(acc, cur); break; case RedOp::ARGUMIN: accIndex = select( icmp_ult(acc, cur), accIndex, select(icmp_ugt(acc, cur), curIndex, smin(accIndex, curIndex))); acc = umin(acc, cur); break; case RedOp::ARGUMAX: accIndex = select( icmp_ugt(acc, cur), accIndex, select(icmp_ult(acc, cur), curIndex, smin(accIndex, curIndex))); acc = umax(acc, cur); break; case RedOp::ARGFMIN: accIndex = select( fcmp_olt(acc, cur), accIndex, select(fcmp_ogt(acc, cur), curIndex, smin(accIndex, curIndex))); acc = fmin(acc, cur); break; case RedOp::ARGFMAX: accIndex = select( fcmp_ogt(acc, cur), accIndex, select(fcmp_olt(acc, cur), curIndex, smin(accIndex, curIndex))); acc = fmax(acc, cur); break; case RedOp::ADD: case RedOp::FADD: case RedOp::MIN: case RedOp::MAX: case RedOp::UMIN: case RedOp::UMAX: case RedOp::FMIN: case RedOp::FMAX: case RedOp::XOR: llvm::report_fatal_error( "This accumulate implementation is only for argmin / argmax"); default: llvm::report_fatal_error("Unsupported reduce op"); } } // Use shared memory for reduction within warps and across warps LogicalResult matchAndRewriteBasic(triton::ReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Location loc = op->getLoc(); unsigned axis = op.axis(); bool withIndex = triton::ReduceOp::withIndex(op.redOp()); auto srcTy = op.operand().getType().cast(); auto srcLayout = 
        srcTy.getEncoding().cast<BlockedEncodingAttr>();
    auto srcOrd = srcLayout.getOrder();
    auto srcShape = srcTy.getShape();

    auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
    auto llvmIndexTy = getTypeConverter()->getIndexType();
    auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
    auto indexPtrTy = LLVM::LLVMPointerType::get(llvmIndexTy, 3);
    Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
    smemBase = bitcast(smemBase, elemPtrTy);

    ReduceOpHelper helper(op);
    auto smemShape = helper.getScratchConfigBasic();
    unsigned elems = product<unsigned>(smemShape);
    Value indexSmemBase = gep(elemPtrTy, smemBase, i32_val(elems));
    indexSmemBase = bitcast(indexSmemBase, indexPtrTy);

    unsigned srcElems = getElemsPerThread(srcTy);
    auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcShape);
    auto srcValues = getElementsFromStruct(loc, adaptor.operand(), rewriter);

    SmallVector<SmallVector<unsigned>> offset =
        emitOffsetForLayout(srcLayout, srcShape);

    std::map<SmallVector<unsigned>, Value> accs;
    std::map<SmallVector<unsigned>, Value> accIndices;
    std::map<SmallVector<unsigned>, SmallVector<Value>> indices;

    // reduce within threads
    for (unsigned i = 0; i < srcElems; ++i) {
      SmallVector<unsigned> key = offset[i];
      key[axis] = 0;
      bool isFirst = accs.find(key) == accs.end();
      if (!withIndex) {
        accumulate(rewriter, loc, op.redOp(), accs[key], srcValues[i], isFirst);
      } else {
        Value curIndex = srcIndices[i][axis];
        accumulateWithIndex(rewriter, loc, op.redOp(), accs[key],
                            accIndices[key], srcValues[i], curIndex, isFirst);
      }
      if (isFirst)
        indices[key] = srcIndices[i];
    }

    // cached int32 constants
    std::map<int, Value> ints;
    ints[0] = i32_val(0);
    for (int N = smemShape[axis] / 2; N > 0; N >>= 1)
      ints[N] = i32_val(N);
    Value sizePerThread = i32_val(srcLayout.getSizePerThread()[axis]);

    // reduce across threads
    for (auto it : accs) {
      const SmallVector<unsigned> &key = it.first;
      Value acc = it.second;
      Value accIndex;
      if (withIndex)
        accIndex = accIndices[key];
      SmallVector<Value> writeIdx = indices[key];

      writeIdx[axis] = udiv(writeIdx[axis], sizePerThread);
      Value writeOffset = linearize(rewriter, loc, writeIdx, smemShape, srcOrd);
      Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
      Value indexWritePtr = gep(indexPtrTy, indexSmemBase, writeOffset);
      store(acc, writePtr);
      if (withIndex)
        store(accIndex, indexWritePtr);

      // tree reduction over shared memory along the reduce axis
      SmallVector<Value> readIdx(writeIdx.size(), ints[0]);
      for (int N = smemShape[axis] / 2; N > 0; N >>= 1) {
        readIdx[axis] = ints[N];
        Value readMask = icmp_slt(writeIdx[axis], ints[N]);
        Value readOffset = select(
            readMask, linearize(rewriter, loc, readIdx, smemShape, srcOrd),
            ints[0]);
        Value readPtr = gep(elemPtrTy, writePtr, readOffset);
        barrier();
        if (!withIndex) {
          Value cur = load(readPtr);
          accumulate(rewriter, loc, op.redOp(), acc, cur, false);
          barrier();
          store(acc, writePtr);
        } else {
          Value cur = load(readPtr);
          Value indexReadPtr = gep(indexPtrTy, indexWritePtr, readOffset);
          Value curIndex = load(indexReadPtr);
          accumulateWithIndex(rewriter, loc, op.redOp(), acc, accIndex, cur,
                              curIndex, false);
          barrier();
          store(acc, writePtr);
          store(accIndex, indexWritePtr);
        }
      }
    }

    barrier();

    // set output values
    if (auto resultTy = op.getType().dyn_cast<RankedTensorType>()) {
      // nd-tensor where n >= 1
      auto resultLayout = resultTy.getEncoding();
      auto resultShape = resultTy.getShape();
      unsigned resultElems = getElemsPerThread(resultTy);
      auto resultIndices =
          emitIndices(loc, rewriter, resultLayout, resultShape);
      assert(resultIndices.size() == resultElems);
      SmallVector<Value> resultVals(resultElems);
      for (unsigned i = 0; i < resultElems; ++i) {
        SmallVector<Value> readIdx = resultIndices[i];
        readIdx.insert(readIdx.begin() + axis, ints[0]);
        Value readOffset = linearize(rewriter, loc,
                                     readIdx, smemShape, srcOrd);
        Value readPtr = gep(elemPtrTy, smemBase, readOffset);
        Value indexReadPtr = gep(indexPtrTy, indexSmemBase, readOffset);
        resultVals[i] = withIndex ? load(indexReadPtr) : load(readPtr);
      }

      SmallVector<Type> resultTypes(resultElems,
                                    withIndex ? llvmIndexTy : llvmElemTy);
      Type structTy =
          LLVM::LLVMStructType::getLiteral(this->getContext(), resultTypes);
      Value ret = getStructFromElements(loc, resultVals, rewriter, structTy);
      rewriter.replaceOp(op, ret);
    } else {
      // 0d-tensor -> scalar
      Value resultVal = withIndex ? load(indexSmemBase) : load(smemBase);
      rewriter.replaceOp(op, resultVal);
    }

    return success();
  }

  // Use warp shuffle for reduction within warps and shared memory for data
  // exchange across warps
  LogicalResult
  matchAndRewriteFast(triton::ReduceOp op, OpAdaptor adaptor,
                      ConversionPatternRewriter &rewriter) const {
    Location loc = op->getLoc();
    unsigned axis = adaptor.axis();
    bool withIndex = triton::ReduceOp::withIndex(op.redOp());

    auto srcTy = op.operand().getType().cast<RankedTensorType>();
    auto srcLayout = srcTy.getEncoding();
    auto srcShape = srcTy.getShape();
    auto srcRank = srcTy.getRank();
    auto order = getOrder(srcLayout);

    auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcLayout);
    auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcLayout);

    auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
    auto llvmIndexTy = getTypeConverter()->getIndexType();
    auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
    auto indexPtrTy = LLVM::LLVMPointerType::get(llvmIndexTy, 3);
    Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
    smemBase = bitcast(smemBase, elemPtrTy);

    ReduceOpHelper helper(op);
    auto smemShapes = helper.getScratchConfigsFast();
    unsigned elems = product<unsigned>(smemShapes[0]);
    unsigned maxElems = std::max(elems, product<unsigned>(smemShapes[1]));
    Value indexSmemBase = gep(elemPtrTy, smemBase, i32_val(maxElems));
    indexSmemBase = bitcast(indexSmemBase, indexPtrTy);

    unsigned sizeIntraWarps = helper.getIntraWarpSize();
    unsigned sizeInterWarps = helper.getInterWarpSize();

    unsigned srcElems = getElemsPerThread(srcTy);
    auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcShape);
    auto srcValues = getElementsFromStruct(loc, adaptor.operand(), rewriter);

    SmallVector<SmallVector<unsigned>> offset =
        emitOffsetForLayout(srcLayout, srcShape);

    std::map<SmallVector<unsigned>, Value> accs;
    std::map<SmallVector<unsigned>, Value> accIndices;
    std::map<SmallVector<unsigned>, SmallVector<Value>> indices;

    // reduce within threads
    for (unsigned i = 0; i < srcElems; ++i) {
      SmallVector<unsigned> key = offset[i];
      key[axis] = 0;
      bool isFirst = accs.find(key) == accs.end();
      if (!withIndex) {
        accumulate(rewriter, loc, op.redOp(), accs[key], srcValues[i], isFirst);
      } else {
        Value curIndex = srcIndices[i][axis];
        accumulateWithIndex(rewriter, loc, op.redOp(), accs[key],
                            accIndices[key], srcValues[i], curIndex, isFirst);
      }
      if (isFirst)
        indices[key] = srcIndices[i];
    }

    Value threadId = getThreadId(rewriter, loc);
    Value warpSize = i32_val(32);
    Value warpId = udiv(threadId, warpSize);
    Value laneId = urem(threadId, warpSize);

    SmallVector<Value> multiDimLaneId =
        delinearize(rewriter, loc, laneId, threadsPerWarp, order);
    SmallVector<Value> multiDimWarpId =
        delinearize(rewriter, loc, warpId, warpsPerCTA, order);

    Value laneIdAxis = multiDimLaneId[axis];
    Value warpIdAxis = multiDimWarpId[axis];

    Value zero = i32_val(0);
    Value laneZero = icmp_eq(laneIdAxis, zero);
    Value warpZero = icmp_eq(warpIdAxis, zero);

    for (auto it : accs) {
      const SmallVector<unsigned> &key = it.first;
      Value acc = it.second;
      Value accIndex;
      if (withIndex)
        accIndex = accIndices[key];

      // Reduce within warps
      for (unsigned N = sizeIntraWarps / 2;
           N > 0; N >>= 1) {
        Value shfl = shflSync(loc, rewriter, acc, N);
        if (!withIndex) {
          accumulate(rewriter, loc, op.redOp(), acc, shfl, false);
        } else {
          Value shflIndex = shflSync(loc, rewriter, accIndex, N);
          accumulateWithIndex(rewriter, loc, op.redOp(), acc, accIndex, shfl,
                              shflIndex, false);
        }
      }

      SmallVector<Value> writeIdx = indices[key];
      writeIdx[axis] = (sizeInterWarps == 1) ? zero : warpIdAxis;
      Value writeOffset =
          linearize(rewriter, loc, writeIdx, smemShapes[0], order);
      Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
      storeShared(rewriter, loc, writePtr, acc, laneZero);
      if (withIndex) {
        Value indexWritePtr = gep(indexPtrTy, indexSmemBase, writeOffset);
        storeShared(rewriter, loc, indexWritePtr, accIndex, laneZero);
      }
    }

    barrier();

    // The second round of shuffle reduction
    //   now the problem size: sizeInterWarps, s1, s2, .. , sn
    //   where sizeInterWarps is 2^m
    //
    // Each thread needs to process:
    //   elemsPerThread = sizeInterWarps * s1 * s2 .. sn / numThreads
    // e.g. with sizeInterWarps = 4, s1 * s2 .. sn = 1024 and 4 warps
    // (numThreads = 128), each thread handles 4096 / 128 = 32 elements.
    unsigned numThreads =
        product<unsigned>(triton::gpu::getWarpsPerCTA(srcLayout)) * 32;
    unsigned elemsPerThread = std::max<unsigned>(elems / numThreads, 1);
    Value readOffset = threadId;
    for (unsigned round = 0; round < elemsPerThread; ++round) {
      Value readPtr = gep(elemPtrTy, smemBase, readOffset);
      // FIXME(Qingyi): need predicate icmp_slt(threadId,
      // i32_val(sizeInterWarps))
      Value acc = load(readPtr);
      Value accIndex;
      if (withIndex) {
        Value readIndexPtr = gep(indexPtrTy, indexSmemBase, readOffset);
        accIndex = load(readIndexPtr);
      }

      for (unsigned N = sizeInterWarps / 2; N > 0; N >>= 1) {
        Value shfl = shflSync(loc, rewriter, acc, N);
        if (!withIndex) {
          accumulate(rewriter, loc, op.redOp(), acc, shfl, false);
        } else {
          Value shflIndex = shflSync(loc, rewriter, accIndex, N);
          accumulateWithIndex(rewriter, loc, op.redOp(), acc, accIndex, shfl,
                              shflIndex, false);
        }
      }

      // only the first thread in each sizeInterWarps group writes the result
      Value writeOffset = readOffset;
      Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
      Value threadIsNeeded = icmp_slt(threadId, i32_val(elems));
      Value laneIdModSizeInterWarps = urem(laneId, i32_val(sizeInterWarps));
      Value laneIdModSizeInterWarpsIsZero =
          icmp_eq(laneIdModSizeInterWarps, zero);
      Value pred = and_(threadIsNeeded, laneIdModSizeInterWarpsIsZero);
      storeShared(rewriter, loc, writePtr, acc, pred);
      if (withIndex) {
        Value writeIndexPtr = gep(indexPtrTy, indexSmemBase, writeOffset);
        storeShared(rewriter, loc, writeIndexPtr, accIndex, pred);
      }

      if (round != elemsPerThread - 1) {
        readOffset = add(readOffset, i32_val(numThreads));
      }
    }

    // We could avoid this barrier for some layouts, but this is not the
    // general case.
    // TODO: optimize the barrier in case such layouts are accepted.
    barrier();

    // set output values
    if (auto resultTy = op.getType().dyn_cast<RankedTensorType>()) {
      // nd-tensor where n >= 1
      auto resultLayout = resultTy.getEncoding().cast<SliceEncodingAttr>();
      auto resultShape = resultTy.getShape();
      unsigned resultElems = getElemsPerThread(resultTy);
      auto resultIndices =
          emitIndices(loc, rewriter, resultLayout, resultShape);
      assert(resultIndices.size() == resultElems);
      SmallVector<Value> resultVals(resultElems);
      for (size_t i = 0; i < resultElems; ++i) {
        SmallVector<Value> readIdx = resultIndices[i];
        readIdx.insert(readIdx.begin() + axis, i32_val(0));
        Value readOffset =
            linearize(rewriter, loc, readIdx, smemShapes[0], order);
        Value readPtr = gep(elemPtrTy, smemBase, readOffset);
        Value indexReadPtr = gep(indexPtrTy, indexSmemBase, readOffset);
        resultVals[i] = withIndex ? load(indexReadPtr) : load(readPtr);
      }

      SmallVector<Type> resultTypes(resultElems, withIndex ?
                                                    llvmIndexTy : llvmElemTy);
      Type structTy =
          LLVM::LLVMStructType::getLiteral(this->getContext(), resultTypes);
      Value ret = getStructFromElements(loc, resultVals, rewriter, structTy);
      rewriter.replaceOp(op, ret);
    } else {
      // 0d-tensor -> scalar
      Value resultVal = withIndex ? load(indexSmemBase) : load(smemBase);
      rewriter.replaceOp(op, resultVal);
    }

    return success();
  }
};

void populateReduceOpToLLVMPatterns(
    mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
    int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
    const Allocation *allocation, Value smem,
    ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
    PatternBenefit benefit) {
  patterns.add<ReduceOpConversion>(typeConverter, allocation, smem,
                                   indexCacheInfo, benefit);
}