diff --git a/include/triton/Analysis/Allocation.h b/include/triton/Analysis/Allocation.h
index b33ac5bf9..60151d5b6 100644
--- a/include/triton/Analysis/Allocation.h
+++ b/include/triton/Analysis/Allocation.h
@@ -20,8 +20,6 @@ SmallVector<unsigned>
 getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
                              unsigned &outVec);
 
-SmallVector<unsigned> getScratchConfigForReduce(triton::ReduceOp op);
-
 } // namespace triton
 
 /// Modified from llvm-15.0: llvm/ADT/AddressRanges.h
diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h
index 77ebb0eaf..852e54532 100644
--- a/include/triton/Analysis/Utility.h
+++ b/include/triton/Analysis/Utility.h
@@ -26,6 +26,12 @@ public:
 
   unsigned getThreadsReductionAxis();
 
+  SmallVector<unsigned> getScratchConfigBasic();
+
+  SmallVector<SmallVector<unsigned>> getScratchConfigsFast();
+
+  unsigned getScratchSizeInBytes();
+
 private:
   triton::ReduceOp op;
   RankedTensorType srcTy{};
@@ -39,6 +45,14 @@ bool maybeAliasOp(Operation *op);
 
 std::string getValueOperandName(Value value, AsmState &state);
 
+template <typename T_OUT, typename T_IN>
+inline SmallVector<T_OUT> convertType(ArrayRef<T_IN> in) {
+  SmallVector<T_OUT> out;
+  for (const T_IN &i : in)
+    out.push_back(T_OUT(i));
+  return out;
+}
+
 template <typename Int> Int product(llvm::ArrayRef<Int> arr) {
   return std::accumulate(arr.begin(), arr.end(), 1, std::multiplies{});
 }
diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td
index 97a015882..32512282b 100644
--- a/include/triton/Dialect/Triton/IR/TritonOps.td
+++ b/include/triton/Dialect/Triton/IR/TritonOps.td
@@ -351,6 +351,11 @@ def TT_ReduceOp : TT_Op<"reduce", [NoSideEffect,
 
   let assemblyFormat = "$operand attr-dict `:` type($operand) `->` type($result)";
 
+  let extraClassDeclaration = [{
+    // This member function is marked static because we need to call it before
+    // the ReduceOp is constructed, see the implementation of create_reduce in triton.cc.
+    static bool withIndex(mlir::triton::RedOp redOp);
+  }];
 }
 
 //
diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp
index a8576d060..068956697 100644
--- a/lib/Analysis/Allocation.cpp
+++ b/lib/Analysis/Allocation.cpp
@@ -88,25 +88,6 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
   return paddedRepShape;
 }
 
-SmallVector<unsigned> getScratchConfigForReduce(triton::ReduceOp op) {
-  ReduceOpHelper helper(op);
-
-  SmallVector<unsigned> smemShape;
-  auto srcShape = helper.getSrcShape();
-  for (auto d : srcShape)
-    smemShape.push_back(d);
-
-  auto axis = op.axis();
-  if (helper.isFastReduction()) {
-    smemShape[axis] = helper.getInterWarpSize();
-  } else {
-    smemShape[axis] =
-        std::min(smemShape[axis], helper.getThreadsReductionAxis());
-  }
-
-  return smemShape;
-}
-
 // TODO: extend beyond scalars
 SmallVector<unsigned> getScratchConfigForAtomicRMW(triton::AtomicRMWOp op) {
   SmallVector<unsigned> smemShape;
@@ -173,21 +154,9 @@ private:
   /// Initializes temporary shared memory for a given operation.
   void getScratchValueSize(Operation *op) {
     if (auto reduceOp = dyn_cast<triton::ReduceOp>(op)) {
-      // TODO(Keren): Reduce with index is not supported yet.
-      auto value = op->getOperand(0);
-      if (auto tensorType = value.getType().dyn_cast<RankedTensorType>()) {
-        bool fastReduce = ReduceOpHelper(reduceOp).isFastReduction();
-        auto smemShape = getScratchConfigForReduce(reduceOp);
-        unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
-                                         std::multiplies{});
-        if (fastReduce) {
-          auto mod = op->getParentOfType<ModuleOp>();
-          unsigned numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
-          elems = std::max(elems, numWarps * 32);
-        }
-        auto bytes = elems * tensorType.getElementTypeBitWidth() / 8;
-        allocation->addBuffer(op, bytes);
-      }
+      ReduceOpHelper helper(reduceOp);
+      unsigned bytes = helper.getScratchSizeInBytes();
+      allocation->addBuffer(op, bytes);
     } else if (auto cvtLayout = dyn_cast<triton::gpu::ConvertLayoutOp>(op)) {
       auto srcTy = cvtLayout.src().getType().cast<RankedTensorType>();
       auto dstTy = cvtLayout.result().getType().cast<RankedTensorType>();
diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp
index ab25a41bd..8458b5ee5 100644
--- a/lib/Analysis/Utility.cpp
+++ b/lib/Analysis/Utility.cpp
@@ -37,6 +37,55 @@ unsigned ReduceOpHelper::getThreadsReductionAxis() {
          triton::gpu::getWarpsPerCTA(srcLayout)[axis];
 }
 
+SmallVector<unsigned> ReduceOpHelper::getScratchConfigBasic() {
+  auto axis = op.axis();
+  auto smemShape = convertType<unsigned>(getSrcShape());
+  smemShape[axis] = std::min(smemShape[axis], getThreadsReductionAxis());
+  return smemShape;
+}
+
+SmallVector<SmallVector<unsigned>> ReduceOpHelper::getScratchConfigsFast() {
+  auto axis = op.axis();
+  SmallVector<SmallVector<unsigned>> smemShapes(3);
+
+  /// shared memory block0
+  smemShapes[0] = convertType<unsigned>(getSrcShape());
+  smemShapes[0][axis] = getInterWarpSize();
+
+  /// FIXME(Qingyi): This size is actually larger than required.
+  /// shared memory block1:
+  auto mod = op.getOperation()->getParentOfType<ModuleOp>();
+  unsigned numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
+  smemShapes[1].push_back(numWarps * 32);
+
+  /// FIXME(Qingyi): This requirement is actually not necessary, because it is
+  /// always smaller than smemShapes[0]
+  /// shared memory block2
+  smemShapes[2] = convertType<unsigned>(getSrcShape());
+  smemShapes[2].erase(smemShapes[2].begin() + axis);
+
+  return smemShapes;
+}
+
+unsigned ReduceOpHelper::getScratchSizeInBytes() {
+  unsigned elems = 0;
+  if (isFastReduction()) {
+    auto smemShapes = getScratchConfigsFast();
+    for (const auto &smemShape : smemShapes)
+      elems = std::max(elems, product(smemShape));
+  } else {
+    auto smemShape = getScratchConfigBasic();
+    elems = product(smemShape);
+  }
+
+  auto tensorType = op.operand().getType().cast<RankedTensorType>();
+  unsigned bytes = elems * tensorType.getElementTypeBitWidth() / 8;
+
+  if (triton::ReduceOp::withIndex(op.redOp()))
+    bytes += elems * sizeof(int32_t);
+
+  return bytes;
+}
+
 bool isSharedEncoding(Value value) {
   auto type = value.getType();
   if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index 883e718b9..322c43446 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -1338,6 +1338,10 @@ private:
   void accumulate(ConversionPatternRewriter &rewriter, Location loc,
                   RedOp redOp, Value &acc, Value cur, bool isFirst) const;
 
+  void accumulateWithIndex(ConversionPatternRewriter &rewriter, Location loc,
+                           RedOp redOp, Value &acc, Value &accIndex, Value cur,
+                           Value curIndex, bool isFirst) const;
+
   Value shflSync(ConversionPatternRewriter &rewriter, Location loc, Value val,
                  int i) const;
 
@@ -1366,7 +1370,6 @@ void ReduceOpConversion::accumulate(ConversionPatternRewriter &rewriter,
     acc = cur;
     return;
   }
-  auto type = cur.getType();
   switch (redOp) {
   case RedOp::ADD:
     acc = add(acc, cur);
     break;
@@ -1395,6 +1398,75 @@ void ReduceOpConversion::accumulate(ConversionPatternRewriter &rewriter,
   case RedOp::XOR:
     acc = xor_(acc, cur);
     break;
+  case RedOp::ARGMIN:
+  case RedOp::ARGMAX:
+  case RedOp::ARGUMIN:
+  case RedOp::ARGUMAX:
+  case RedOp::ARGFMIN:
+  case RedOp::ARGFMAX:
+    llvm::report_fatal_error(
+        "This accumulate implementation is not for argmin / argmax");
+  default:
+    llvm::report_fatal_error("Unsupported reduce op");
+  }
+}
+
+void ReduceOpConversion::accumulateWithIndex(
+    ConversionPatternRewriter &rewriter, Location loc, RedOp redOp, Value &acc,
+    Value &accIndex, Value cur, Value curIndex, bool isFirst) const {
+  if (isFirst) {
+    acc = cur;
+    accIndex = curIndex;
+    return;
+  }
+  switch (redOp) {
+  case RedOp::ARGMIN:
+    accIndex =
+        select(icmp_slt(acc, cur), accIndex,
+               select(icmp_sgt(acc, cur), curIndex, smin(accIndex, curIndex)));
+    acc = smin(acc, cur);
+    break;
+  case RedOp::ARGMAX:
+    accIndex =
+        select(icmp_sgt(acc, cur), accIndex,
+               select(icmp_slt(acc, cur), curIndex, smin(accIndex, curIndex)));
+    acc = smax(acc, cur);
+    break;
+  case RedOp::ARGUMIN:
+    accIndex =
+        select(icmp_ult(acc, cur), accIndex,
+               select(icmp_ugt(acc, cur), curIndex, smin(accIndex, curIndex)));
+    acc = umin(acc, cur);
+    break;
+  case RedOp::ARGUMAX:
+    accIndex =
+        select(icmp_ugt(acc, cur), accIndex,
+               select(icmp_ult(acc, cur), curIndex, smin(accIndex, curIndex)));
+    acc = umax(acc, cur);
+    break;
+  case RedOp::ARGFMIN:
+    accIndex =
+        select(fcmp_olt(acc, cur), accIndex,
+               select(fcmp_ogt(acc, cur), curIndex, smin(accIndex, curIndex)));
+    acc = fmin(acc, cur);
+    break;
+  case RedOp::ARGFMAX:
+    accIndex =
+        select(fcmp_ogt(acc, cur), accIndex,
+               select(fcmp_olt(acc, cur), curIndex, smin(accIndex, curIndex)));
+    acc = fmax(acc, cur);
+    break;
+  case RedOp::ADD:
+  case RedOp::FADD:
+  case RedOp::MIN:
+  case RedOp::MAX:
+  case RedOp::UMIN:
+  case RedOp::UMAX:
+  case RedOp::FMIN:
+  case RedOp::FMAX:
+  case RedOp::XOR:
+    llvm::report_fatal_error(
+        "This accumulate implementation is only for argmin / argmax");
   default:
     llvm::report_fatal_error("Unsupported reduce op");
   }
@@ -1433,6 +1505,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
     ConversionPatternRewriter &rewriter) const {
   Location loc = op->getLoc();
   unsigned axis = op.axis();
+  bool withIndex = triton::ReduceOp::withIndex(op.redOp());
 
   auto srcTy = op.operand().getType().cast<RankedTensorType>();
   auto srcLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
@@ -1440,11 +1513,17 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
   auto srcShape = srcTy.getShape();
 
   auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
+  auto llvmIndexTy = getTypeConverter()->getIndexType();
   auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
+  auto indexPtrTy = LLVM::LLVMPointerType::get(llvmIndexTy, 3);
   Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
   smemBase = bitcast(smemBase, elemPtrTy);
 
-  auto smemShape = getScratchConfigForReduce(op);
+  ReduceOpHelper helper(op);
+  auto smemShape = helper.getScratchConfigBasic();
+  unsigned elems = product(smemShape);
+  Value indexSmemBase = gep(elemPtrTy, smemBase, i32_val(elems));
+  indexSmemBase = bitcast(indexSmemBase, indexPtrTy);
 
   unsigned srcElems = getElemsPerThread(srcTy);
   auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcShape);
@@ -1454,6 +1533,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
       emitOffsetForBlockedLayout(srcLayout, srcShape);
 
   std::map<SmallVector<unsigned>, Value> accs;
+  std::map<SmallVector<unsigned>, Value> accIndices;
   std::map<SmallVector<unsigned>, SmallVector<Value>> indices;
 
   // reduce within threads
   for (unsigned i = 0; i < srcElems; ++i) {
     SmallVector<unsigned> key = offset[i];
     key[axis] = 0;
     bool isFirst = accs.find(key) == accs.end();
-    accumulate(rewriter, loc, op.redOp(), accs[key], srcValues[i], isFirst);
+    if (!withIndex) {
+      accumulate(rewriter, loc, op.redOp(), accs[key], srcValues[i], isFirst);
+    } else {
+      Value curIndex = srcIndices[i][axis];
+      accumulateWithIndex(rewriter, loc, op.redOp(), accs[key], accIndices[key],
+                          srcValues[i], curIndex, isFirst);
+    }
     if (isFirst)
       indices[key] = srcIndices[i];
   }
@@ -1477,12 +1563,18 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
   for (auto it : accs) {
     const SmallVector<unsigned> &key = it.first;
     Value acc = it.second;
+    Value accIndex;
+    if (withIndex)
+      accIndex = accIndices[key];
     SmallVector<Value> writeIdx = indices[key];
 
     writeIdx[axis] = udiv(writeIdx[axis], sizePerThread);
     Value writeOffset = linearize(rewriter, loc, writeIdx, smemShape, srcOrd);
     Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
+    Value indexWritePtr = gep(indexPtrTy, indexSmemBase, writeOffset);
     store(acc, writePtr);
+    if (withIndex)
+      store(accIndex, indexWritePtr);
 
     SmallVector<Value> readIdx(writeIdx.size(), ints[0]);
     for (int N = smemShape[axis] / 2; N > 0; N >>= 1) {
@@ -1493,11 +1585,24 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
                                ints[0]);
       Value readPtr = gep(elemPtrTy, writePtr, readOffset);
       barrier();
-      accumulate(rewriter, loc, op.redOp(), acc, load(readPtr), false);
-      store(acc, writePtr);
+      if (!withIndex) {
+        Value cur = load(readPtr);
+        accumulate(rewriter, loc, op.redOp(), acc, cur, false);
+        store(acc, writePtr);
+      } else {
+        Value cur = load(readPtr);
+        Value indexReadPtr = gep(indexPtrTy, indexWritePtr, readOffset);
+        Value curIndex = load(indexReadPtr);
+        accumulateWithIndex(rewriter, loc, op.redOp(), acc, accIndex, cur,
+                            curIndex, false);
+        store(acc, writePtr);
+        store(accIndex, indexWritePtr);
+      }
     }
   }
 
+  barrier();
+
   // set output values
   if (auto resultTy = op.getType().dyn_cast<RankedTensorType>()) {
     // nd-tensor where n >= 1
@@ -1508,25 +1613,25 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
     auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultShape);
     assert(resultIndices.size() == resultElems);
 
-    barrier();
     SmallVector<Value> resultVals(resultElems);
     for (unsigned i = 0; i < resultElems; ++i) {
       SmallVector<Value> readIdx = resultIndices[i];
       readIdx.insert(readIdx.begin() + axis, ints[0]);
       Value readOffset = linearize(rewriter, loc, readIdx, smemShape, srcOrd);
       Value readPtr = gep(elemPtrTy, smemBase, readOffset);
-      resultVals[i] = load(readPtr);
+      Value indexReadPtr = gep(indexPtrTy, indexSmemBase, readOffset);
+      resultVals[i] = withIndex ? load(indexReadPtr) : load(readPtr);
     }
 
-    SmallVector<Type> resultTypes(resultElems, llvmElemTy);
+    SmallVector<Type> resultTypes(resultElems,
+                                  withIndex ? llvmIndexTy : llvmElemTy);
     Type structTy =
         LLVM::LLVMStructType::getLiteral(this->getContext(), resultTypes);
     Value ret = getStructFromElements(loc, resultVals, rewriter, structTy);
     rewriter.replaceOp(op, ret);
   } else {
     // 0d-tensor -> scalar
-    barrier();
-    Value resultVal = load(smemBase);
+    Value resultVal = withIndex ? load(indexSmemBase) : load(smemBase);
     rewriter.replaceOp(op, resultVal);
   }
 
@@ -1538,25 +1643,35 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
     ConversionPatternRewriter &rewriter) const {
   Location loc = op->getLoc();
   unsigned axis = adaptor.axis();
+  bool withIndex = triton::ReduceOp::withIndex(op.redOp());
 
   auto srcTy = op.operand().getType().cast<RankedTensorType>();
   auto srcLayout = srcTy.getEncoding();
   auto srcShape = srcTy.getShape();
   auto srcRank = srcTy.getRank();
+  auto order = getOrder(srcLayout);
 
   auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcLayout);
   auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcLayout);
 
   auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
+  auto llvmIndexTy = getTypeConverter()->getIndexType();
   auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
+  auto indexPtrTy = LLVM::LLVMPointerType::get(llvmIndexTy, 3);
   Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
   smemBase = bitcast(smemBase, elemPtrTy);
 
   ReduceOpHelper helper(op);
+  auto smemShapes = helper.getScratchConfigsFast();
+  unsigned elems = product(smemShapes[0]);
+  unsigned maxElems = std::max(elems, product(smemShapes[1]));
+  maxElems = std::max(maxElems, product(smemShapes[2]));
+  Value indexSmemBase = gep(elemPtrTy, smemBase, i32_val(maxElems));
+  indexSmemBase = bitcast(indexSmemBase, indexPtrTy);
+
   unsigned sizeIntraWarps = helper.getIntraWarpSize();
   unsigned sizeInterWarps = helper.getInterWarpSize();
-  auto order = getOrder(srcLayout);
 
   unsigned srcElems = getElemsPerThread(srcTy);
   auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcShape);
   auto srcValues = getElementsFromStruct(loc, adaptor.operand(), rewriter);
@@ -1565,16 +1680,21 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
       emitOffsetForLayout(srcLayout, srcShape);
 
   std::map<SmallVector<unsigned>, Value> accs;
+  std::map<SmallVector<unsigned>, Value> accIndices;
   std::map<SmallVector<unsigned>, SmallVector<Value>> indices;
 
-  auto smemShape = getScratchConfigForReduce(op);
-
   // reduce within threads
   for (unsigned i = 0; i < srcElems; ++i) {
     SmallVector<unsigned> key = offset[i];
     key[axis] = 0;
     bool isFirst = accs.find(key) == accs.end();
-    accumulate(rewriter, loc, op.redOp(), accs[key], srcValues[i], isFirst);
+    if (!withIndex) {
+      accumulate(rewriter, loc, op.redOp(), accs[key], srcValues[i], isFirst);
+    } else {
+      Value curIndex = srcIndices[i][axis];
+      accumulateWithIndex(rewriter, loc, op.redOp(), accs[key], accIndices[key],
+                          srcValues[i], curIndex, isFirst);
+    }
     if (isFirst)
       indices[key] = srcIndices[i];
   }
@@ -1599,18 +1719,32 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
   for (auto it : accs) {
     const SmallVector<unsigned> &key = it.first;
     Value acc = it.second;
+    Value accIndex;
+    if (withIndex)
+      accIndex = accIndices[key];
 
     // reduce within warps
     for (unsigned N = sizeIntraWarps / 2; N > 0; N >>= 1) {
       Value shfl = shflSync(rewriter, loc, acc, N);
-      accumulate(rewriter, loc, op.redOp(), acc, shfl, false);
+      if (!withIndex) {
+        accumulate(rewriter, loc, op.redOp(), acc, shfl, false);
+      } else {
+        Value shflIndex = shflSync(rewriter, loc, accIndex, N);
+        accumulateWithIndex(rewriter, loc, op.redOp(), acc, accIndex, shfl,
+                            shflIndex, false);
+      }
     }
 
     SmallVector<Value> writeIdx = indices[key];
     writeIdx[axis] = (sizeInterWarps == 1) ? zero : warpIdAxis;
-    Value writeOffset = linearize(rewriter, loc, writeIdx, smemShape, order);
+    Value writeOffset =
+        linearize(rewriter, loc, writeIdx, smemShapes[0], order);
     Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
     storeShared(rewriter, loc, writePtr, acc, laneZero);
+    if (withIndex) {
+      Value indexWritePtr = gep(indexPtrTy, indexSmemBase, writeOffset);
+      storeShared(rewriter, loc, indexWritePtr, accIndex, laneZero);
+    }
   }
 
   barrier();
@@ -1622,7 +1756,6 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
   //
   //     each thread needs to process:
   //       elemsPerThread = sizeInterWarps * s1 * s2 .. Sn / numThreads
-  unsigned elems = product(smemShape);
   unsigned numThreads = product(triton::gpu::getWarpsPerCTA(srcLayout)) * 32;
   unsigned elemsPerThread = std::max(elems / numThreads, 1);
 
@@ -1630,10 +1763,21 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
   for (unsigned round = 0; round < elemsPerThread; ++round) {
     Value readPtr = gep(elemPtrTy, smemBase, readOffset);
     Value acc = load(readPtr);
+    Value accIndex;
+    if (withIndex) {
+      Value readIndexPtr = gep(indexPtrTy, indexSmemBase, readOffset);
+      accIndex = load(readIndexPtr);
+    }
 
     for (unsigned N = sizeInterWarps / 2; N > 0; N >>= 1) {
       Value shfl = shflSync(rewriter, loc, acc, N);
-      accumulate(rewriter, loc, op.redOp(), acc, shfl, false);
+      if (!withIndex) {
+        accumulate(rewriter, loc, op.redOp(), acc, shfl, false);
+      } else {
+        Value shflIndex = shflSync(rewriter, loc, accIndex, N);
+        accumulateWithIndex(rewriter, loc, op.redOp(), acc, accIndex, shfl,
+                            shflIndex, false);
+      }
     }
 
     Value writeOffset = udiv(readOffset, i32_val(sizeInterWarps));
@@ -1642,8 +1786,12 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
     Value laneIdModSizeInterWarps = urem(laneId, i32_val(sizeInterWarps));
     Value laneIdModSizeInterWarpsIsZero =
         icmp_eq(laneIdModSizeInterWarps, zero);
-    storeShared(rewriter, loc, writePtr, acc,
-                and_(threadIsNeeded, laneIdModSizeInterWarpsIsZero));
+    Value pred = and_(threadIsNeeded, laneIdModSizeInterWarpsIsZero);
+    storeShared(rewriter, loc, writePtr, acc, pred);
+    if (withIndex) {
+      Value writeIndexPtr = gep(indexPtrTy, indexSmemBase, writeOffset);
+      storeShared(rewriter, loc, writeIndexPtr, accIndex, pred);
+    }
 
     if (round != elemsPerThread - 1) {
       readOffset = add(readOffset, i32_val(numThreads));
@@ -1671,25 +1819,24 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
     assert(resultIndices.size() == resultElems);
 
     SmallVector<Value> resultVals(resultElems);
-    SmallVector<unsigned> resultShape;
-    std::copy(resultTy.getShape().begin(), resultTy.getShape().end(),
-              std::back_inserter(resultShape));
     for (size_t i = 0; i < resultElems; ++i) {
       SmallVector<Value> readIdx = resultIndices[i];
       Value readOffset =
-          linearize(rewriter, loc, readIdx, resultShape, resultOrd);
+          linearize(rewriter, loc, readIdx, smemShapes[2], resultOrd);
       Value readPtr = gep(elemPtrTy, smemBase, readOffset);
-      resultVals[i] = load(readPtr);
+      Value indexReadPtr = gep(indexPtrTy, indexSmemBase, readOffset);
+      resultVals[i] = withIndex ? load(indexReadPtr) : load(readPtr);
    }
 
-    SmallVector<Type> resultTypes(resultElems, llvmElemTy);
+    SmallVector<Type> resultTypes(resultElems,
+                                  withIndex ? llvmIndexTy : llvmElemTy);
     Type structTy =
         LLVM::LLVMStructType::getLiteral(this->getContext(), resultTypes);
     Value ret = getStructFromElements(loc, resultVals, rewriter, structTy);
     rewriter.replaceOp(op, ret);
   } else {
     // 0d-tensor -> scalar
-    Value resultVal = load(smemBase);
+    Value resultVal = withIndex ? load(indexSmemBase) : load(smemBase);
     rewriter.replaceOp(op, resultVal);
   }
 
diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.h b/lib/Conversion/TritonGPUToLLVM/Utility.h
index e9f791875..9deb7198f 100644
--- a/lib/Conversion/TritonGPUToLLVM/Utility.h
+++ b/lib/Conversion/TritonGPUToLLVM/Utility.h
@@ -60,12 +60,32 @@ rewriter.create(loc, __VA_ARGS__)
 #define load(...) rewriter.create<LLVM::LoadOp>(loc, __VA_ARGS__)
 #define store(val, ptr) rewriter.create<LLVM::StoreOp>(loc, val, ptr)
+#define fcmp_ogt(lhs, rhs)                                                     \
+  rewriter.create<LLVM::FCmpOp>(loc, rewriter.getI1Type(),                     \
+                                LLVM::FCmpPredicate::ogt, lhs, rhs)
+#define fcmp_olt(lhs, rhs)                                                     \
+  rewriter.create<LLVM::FCmpOp>(loc, rewriter.getI1Type(),                     \
+                                LLVM::FCmpPredicate::olt, lhs, rhs)
 #define icmp_eq(...)                                                           \
   rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, __VA_ARGS__)
 #define icmp_ne(...)                                                           \
   rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne, __VA_ARGS__)
 #define icmp_slt(...)                                                          \
   rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::slt, __VA_ARGS__)
+#define icmp_sle(...)                                                          \
+  rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sle, __VA_ARGS__)
+#define icmp_sgt(...)                                                          \
+  rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sgt, __VA_ARGS__)
+#define icmp_sge(...)                                                          \
+  rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sge, __VA_ARGS__)
+#define icmp_ult(...)                                                          \
+  rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ult, __VA_ARGS__)
+#define icmp_ule(...)                                                          \
+  rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ule, __VA_ARGS__)
+#define icmp_ugt(...)                                                          \
+  rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ugt, __VA_ARGS__)
+#define icmp_uge(...)                                                          \
+  rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::uge, __VA_ARGS__)
 #define select(...) rewriter.create<LLVM::SelectOp>(loc, __VA_ARGS__)
 #define address_of(...) rewriter.create<LLVM::AddressOfOp>(loc, __VA_ARGS__)
 #define barrier() rewriter.create<mlir::gpu::BarrierOp>(loc)
diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp
index 503bc84f3..88bdbdc38 100644
--- a/lib/Dialect/Triton/IR/Ops.cpp
+++ b/lib/Dialect/Triton/IR/Ops.cpp
@@ -240,12 +240,16 @@ mlir::LogicalResult mlir::triton::ReduceOp::inferReturnTypes(
   Value arg = operands[0];
   auto argTy = arg.getType().cast<RankedTensorType>();
   auto argEltTy = argTy.getElementType();
+  auto i32Ty = IntegerType::get(argEltTy.getContext(), 32);
+  auto redOp = attributes.get("redOp").cast<RedOpAttr>().getValue();
+  bool withIndex = mlir::triton::ReduceOp::withIndex(redOp);
+  auto retEltTy = withIndex ? i32Ty : argEltTy;
   auto retShape = argTy.getShape().vec();
   int axis = attributes.get("axis").cast<IntegerAttr>().getInt();
   retShape.erase(retShape.begin() + axis);
   if (retShape.empty()) {
     // 0d-tensor -> scalar
-    inferredReturnTypes.push_back(argEltTy);
+    inferredReturnTypes.push_back(retEltTy);
   } else {
     // nd-tensor where n >= 1
     // infer encoding
@@ -264,11 +268,20 @@ mlir::LogicalResult mlir::triton::ReduceOp::inferReturnTypes(
     }
     // create type
     inferredReturnTypes.push_back(
-        RankedTensorType::get(retShape, argEltTy, retEncoding));
+        RankedTensorType::get(retShape, retEltTy, retEncoding));
   }
   return mlir::success();
 }
 
+bool mlir::triton::ReduceOp::withIndex(mlir::triton::RedOp redOp) {
+  return redOp == mlir::triton::RedOp::ARGMIN ||
+         redOp == mlir::triton::RedOp::ARGMAX ||
+         redOp == mlir::triton::RedOp::ARGUMIN ||
+         redOp == mlir::triton::RedOp::ARGUMAX ||
+         redOp == mlir::triton::RedOp::ARGFMIN ||
+         redOp == mlir::triton::RedOp::ARGFMAX;
+}
+
 //-- SplatOp --
 OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
   auto constOperand = src().getDefiningOp<arith::ConstantOp>();
diff --git a/python/src/triton.cc b/python/src/triton.cc
index f06f7a476..003910af6 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -1195,10 +1195,11 @@ void init_triton_ir(py::module &&m) {
                    operand.getType().dyn_cast<mlir::RankedTensorType>();
                std::vector<int64_t> shape = inputTensorType.getShape();
                shape.erase(shape.begin() + axis);
-               mlir::Type resType = inputTensorType.getElementType();
+               bool withIndex = mlir::triton::ReduceOp::withIndex(redOp);
+               mlir::Type resType = withIndex ? self.getI32Type()
+                                              : inputTensorType.getElementType();
               if (!shape.empty()) {
-                 resType = mlir::RankedTensorType::get(
-                     shape, inputTensorType.getElementType());
+                 resType = mlir::RankedTensorType::get(shape, resType);
               }
               return self.create<mlir::triton::ReduceOp>(loc, resType, redOp,
                                                          operand, axis);
diff --git a/python/tests/test_reduce.py b/python/tests/test_reduce.py
index dd8b4ecba..55be1cafa 100644
--- a/python/tests/test_reduce.py
+++ b/python/tests/test_reduce.py
@@ -1,4 +1,5 @@
 import pytest
+import numpy as np
 import torch
 from torch.testing import assert_close
 
@@ -13,7 +14,9 @@ dtypes_with_bfloat16 = int_dtypes + uint_dtypes + float_dtypes
 dtype_mapping = {dtype_str: torch.__dict__[dtype_str] for dtype_str in dtypes}
 
 
-def get_reduced_dtype(dtype):
+def get_reduced_dtype(op, dtype):
+    if op in ['argmin', 'argmax']:
+        return torch.int32
     if dtype in [torch.int8, torch.int16, torch.uint8]:
         return torch.int32
     if dtype in [torch.bfloat16]:
         return torch.float32
@@ -48,7 +51,7 @@ def reduce2d_kernel(x_ptr, z_ptr, axis: tl.constexpr, block_m: tl.constexpr, blo
 
 reduce1d_configs = [
     (op, dtype, shape)
-    for op in ['sum', 'min', 'max']
+    for op in ['sum', 'min', 'max', 'argmin', 'argmax', 'xor_sum']
     for dtype in dtypes
     for shape in [4, 8, 16, 32, 64, 128, 512, 1024]
 ]
@@ -56,8 +59,11 @@
 
 @pytest.mark.parametrize('op, dtype, shape', reduce1d_configs)
 def test_reduce1d(op, dtype, shape):
+    if op == 'xor_sum' and dtype in float_dtypes:
+        return
+
     dtype = dtype_mapping[dtype]
-    reduced_dtype = get_reduced_dtype(dtype)
+    reduced_dtype = get_reduced_dtype(op, dtype)
 
     if dtype.is_floating_point:
         x = torch.randn((shape,), device='cuda', dtype=dtype)
@@ -79,8 +85,17 @@
         golden_z = torch.sum(x, dtype=reduced_dtype)
     elif op == 'min':
         golden_z = torch.min(x).to(reduced_dtype)
-    else:
+    elif op == 'max':
         golden_z = torch.max(x).to(reduced_dtype)
+    elif op == 'argmin':
+        golden_z = torch.argmin(x).to(reduced_dtype)
+    elif op == 'argmax':
+        golden_z = torch.argmax(x).to(reduced_dtype)
+    elif op == 'xor_sum':
+        sum_npy = np.bitwise_xor.reduce(x.cpu().numpy())
+        golden_z = torch.tensor(sum_npy, dtype=reduced_dtype).cuda()
+    else:
+        raise RuntimeError(f'Unknown reduce op {op}')
 
     if dtype.is_floating_point and op == 'sum':
         if shape >= 256:
@@ -95,7 +110,7 @@
 
 reduce2d_configs = [
     (op, dtype, shape, axis)
-    for op in ['sum', 'min', 'max']
+    for op in ['sum', 'min', 'max', 'argmin', 'argmax', 'xor_sum']
     for dtype in dtypes
     for shape in [(1, 4), (1, 8), (1, 16), (1, 32), (2, 32), (4, 32), (4, 128), (32, 64)]
     for axis in [0, 1]
 ]
@@ -104,8 +119,11 @@
 
 @pytest.mark.parametrize('op, dtype, shape, axis', reduce2d_configs)
 def test_reduce2d(op, dtype, shape, axis):
+    if op == 'xor_sum' and dtype in float_dtypes:
+        return
+
     dtype = dtype_mapping[dtype]
-    reduced_dtype = get_reduced_dtype(dtype)
+    reduced_dtype = get_reduced_dtype(op, dtype)
     reduced_shape = (shape[1 - axis],)
 
     if dtype.is_floating_point:
@@ -123,8 +141,18 @@
         golden_z = torch.sum(x, dim=axis, keepdim=False, dtype=reduced_dtype)
     elif op == 'min':
         golden_z = torch.min(x, dim=axis, keepdim=False)[0].to(reduced_dtype)
-    else:
+    elif op == 'max':
         golden_z = torch.max(x, dim=axis, keepdim=False)[0].to(reduced_dtype)
+    elif op == 'argmin':
+        golden_z = torch.argmin(x, dim=axis, keepdim=False).to(reduced_dtype)
+    elif op == 'argmax':
+        golden_z = torch.argmax(x, dim=axis, keepdim=False).to(reduced_dtype)
+    elif op == 'xor_sum':
+        sum_npy = np.bitwise_xor.reduce(x.cpu().numpy(), axis=axis, keepdims=False)
+        golden_z = torch.tensor(sum_npy, dtype=reduced_dtype).cuda()
+    else:
+        raise RuntimeError(f'Unknown reduce op {op}')
+
     if dtype.is_floating_point and op == 'sum':
         if shape[axis] >= 256:
             assert_close(z, golden_z, rtol=0.05, atol=0.1)
diff --git a/python/triton/language/core.py b/python/triton/language/core.py
index a5e8166e6..7bd05e291 100644
--- a/python/triton/language/core.py
+++ b/python/triton/language/core.py
@@ -1041,6 +1041,13 @@ def max(input, axis, _builder=None):
     return semantic.max(input, axis, _builder)
 
 
+@builtin
+@_add_reduction_docstr("maximum index")
+def argmax(input, axis, _builder=None):
+    axis = _constexpr_to_value(axis)
+    return semantic.argmax(input, axis, _builder)
+
+
 @builtin
 @_add_reduction_docstr("minimum")
 def min(input, axis, _builder=None):
@@ -1048,6 +1055,13 @@
     return semantic.min(input, axis, _builder)
 
 
+@builtin
+@_add_reduction_docstr("minimum index")
+def argmin(input, axis, _builder=None):
+    axis = _constexpr_to_value(axis)
+    return semantic.argmin(input, axis, _builder)
+
+
 @builtin
 @_add_reduction_docstr("sum")
 def sum(input, axis, _builder=None):
diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py
index df2d6a3af..741356e1a 100644
--- a/python/triton/language/semantic.py
+++ b/python/triton/language/semantic.py
@@ -1061,10 +1061,18 @@ def min(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
     return reduce_impl(input, axis, builder, "min", ir.REDUCE_OP.FMIN, ir.REDUCE_OP.MIN)
 
 
+def argmin(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
+    return reduce_impl(input, axis, builder, "argmin", ir.REDUCE_OP.ARGFMIN, ir.REDUCE_OP.ARGMIN)
+
+
 def max(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
     return reduce_impl(input, axis, builder, "max", ir.REDUCE_OP.FMAX, ir.REDUCE_OP.MAX)
 
 
+def argmax(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
+    return reduce_impl(input, axis, builder, "argmax", ir.REDUCE_OP.ARGFMAX, ir.REDUCE_OP.ARGMAX)
"argmax", ir.REDUCE_OP.ARGFMAX, ir.REDUCE_OP.ARGMAX) + + def sum(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: return reduce_impl(input, axis, builder, "sum", ir.REDUCE_OP.FADD, ir.REDUCE_OP.ADD)