From 080b4addf8c0c9569206e01ecc1549a22a069932 Mon Sep 17 00:00:00 2001 From: goostavz <109190422+goostavz@users.noreply.github.com> Date: Wed, 9 Nov 2022 02:10:09 +0800 Subject: [PATCH] [Triton-MLIR][Backend] Fix the order in linear/delinear and a few bugs in reduce conversion (#851) 1, fix the order in linearize/delinearize, which fix the error of order in emitIndices; 2, fix the selecting of fast implementation in reduce codegen; 3, fix the redundant barrier in reduce codegen; 4, fix the index mapping of the second round of warp_shuffle in shuffle version of reduce codegen. Co-authored-by: Keren Zhou --- include/triton/Analysis/Utility.h | 6 +- lib/Analysis/Allocation.cpp | 2 +- .../TritonGPUToLLVM/TritonGPUToLLVM.cpp | 515 ++++++++++-------- python/tests/test_reduce.py | 5 +- 4 files changed, 282 insertions(+), 246 deletions(-) diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h index 6152c11f5..0053776c4 100644 --- a/include/triton/Analysis/Utility.h +++ b/include/triton/Analysis/Utility.h @@ -21,11 +21,11 @@ template Int product(llvm::ArrayRef arr) { template Int ceil(Int m, Int n) { return (m + n - 1) / n; } // output[i] = input[order[i]] -template -SmallVector reorder(ArrayRef input, ArrayRef order) { +template +SmallVector reorder(ArrayRef input, ArrayRef order) { size_t rank = order.size(); assert(input.size() == rank); - SmallVector result(rank); + SmallVector result(rank); for (auto it : llvm::enumerate(order)) { result[it.index()] = input[it.value()]; } diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index 51d4e0e3b..35795376b 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -77,7 +77,7 @@ SmallVector getScratchConfigForReduce(triton::ReduceOp op) { auto srcShape = srcTy.getShape(); auto axis = op.axis(); - bool fastReduce = axis == 1; // FIXME(Qingyi): The fastest-changing dimension + bool fastReduce = axis == srcLayout.getOrder()[0]; SmallVector smemShape; for (auto d : srcShape) diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index 362f78677..69ca5fc22 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -83,6 +83,11 @@ static Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, } // namespace +// A helper function for using printf in LLVM conversion. +void llPrintf(StringRef msg, ValueRange args, + ConversionPatternRewriter &rewriter); + +// Shortcuts for some commonly used LLVM ops to keep code simple and intuitive// // Shortcuts for some commonly used LLVM ops to keep code simple and intuitive #define zext(...) rewriter.create(loc, __VA_ARGS__) #define udiv(...) rewriter.create(loc, __VA_ARGS__) @@ -338,6 +343,7 @@ Value getStructFromElements(Location loc, ValueRange resultVals, return llvmStruct; } +// Delinearize on compile-time consts, assuming the order is [n, .. 2, 1, 0] template static SmallVector getMultiDimIndex(T linearIndex, ArrayRef shape) { // shape: {a, b, c, d} -> accMul: {b*c*d, c*d, d, 1} @@ -355,6 +361,7 @@ static SmallVector getMultiDimIndex(T linearIndex, ArrayRef shape) { return multiDimIndex; } +// Linearize on compile-time consts, assuming the order is [n, .. 2, 1, 0] template static T getLinearIndex(ArrayRef multiDimIndex, ArrayRef shape) { assert(multiDimIndex.size() == shape.size()); @@ -510,12 +517,12 @@ public: multiDim[0] = linear; } else { Value remained = linear; - for (auto &&en : llvm::enumerate(llvm::reverse(shape.drop_front()))) { + for (auto &&en : llvm::enumerate(shape.drop_back())) { Value dimSize = idx_val(en.value()); - multiDim[rank - 1 - en.index()] = urem(remained, dimSize); + multiDim[en.index()] = urem(remained, dimSize); remained = udiv(remained, dimSize); } - multiDim[0] = remained; + multiDim[rank - 1] = remained; } return multiDim; } @@ -525,9 +532,9 @@ public: int rank = multiDim.size(); Value linear = idx_val(0); if (rank > 0) { - linear = multiDim.front(); + linear = multiDim.back(); for (auto [dim, shape] : - llvm::zip(multiDim.drop_front(), shape.drop_front())) { + llvm::reverse(llvm::zip(multiDim.drop_back(), shape.drop_back()))) { Value dimSize = idx_val(shape); linear = add(mul(linear, dimSize), dim); } @@ -566,6 +573,7 @@ public: delinearize(rewriter, loc, warpId, warpsPerCTA, order); SmallVector multiDimThreadId = delinearize(rewriter, loc, laneId, threadsPerWarp, order); + SmallVector multiDimBase(rank); for (unsigned k = 0; k < rank; ++k) { // Wrap around multiDimWarpId/multiDimThreadId incase @@ -1362,7 +1370,9 @@ private: LogicalResult ReduceOpConversion::matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - if (op.axis() == 1) // FIXME(Qingyi): The fastest-changing dimension + auto srcTy = op.operand().getType().cast(); + auto srcLayout = srcTy.getEncoding().cast(); + if (op.axis() == srcLayout.getOrder()[0]) return matchAndRewriteFast(op, adaptor, rewriter); return matchAndRewriteBasic(op, adaptor, rewriter); } @@ -1444,6 +1454,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic( auto srcTy = op.operand().getType().cast(); auto srcLayout = srcTy.getEncoding().cast(); + auto srcOrd = srcLayout.getOrder(); auto srcShape = srcTy.getShape(); auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType()); @@ -1487,7 +1498,9 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic( SmallVector writeIdx = indices[key]; writeIdx[axis] = udiv(writeIdx[axis], sizePerThread); - Value writeOffset = linearize(rewriter, loc, writeIdx, smemShape); + Value writeOffset = + linearize(rewriter, loc, reorder(writeIdx, srcOrd), + reorder(smemShape, srcOrd)); Value writePtr = gep(elemPtrTy, smemBase, writeOffset); store(acc, writePtr); @@ -1495,8 +1508,11 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic( for (int N = smemShape[axis] / 2; N > 0; N >>= 1) { readIdx[axis] = ints[N]; Value readMask = icmp_slt(writeIdx[axis], ints[N]); - Value readOffset = select( - readMask, linearize(rewriter, loc, readIdx, smemShape), ints[0]); + Value readOffset = + select(readMask, + linearize(rewriter, loc, reorder(readIdx, srcOrd), + reorder(smemShape, srcOrd)), + ints[0]); Value readPtr = gep(elemPtrTy, writePtr, readOffset); barrier(); accumulate(rewriter, loc, op.redOp(), acc, load(readPtr), false); @@ -1519,7 +1535,9 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic( for (unsigned i = 0; i < resultElems; ++i) { SmallVector readIdx = resultIndices[i]; readIdx.insert(readIdx.begin() + axis, ints[0]); - Value readOffset = linearize(rewriter, loc, readIdx, smemShape); + Value readOffset = + linearize(rewriter, loc, reorder(readIdx, srcOrd), + reorder(smemShape, srcOrd)); Value readPtr = gep(elemPtrTy, smemBase, readOffset); resultVals[i] = load(readPtr); } @@ -1548,6 +1566,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast( auto srcTy = op.operand().getType().cast(); auto srcLayout = srcTy.getEncoding().cast(); auto srcShape = srcTy.getShape(); + auto srcRank = srcTy.getRank(); auto threadsPerWarp = srcLayout.getThreadsPerWarp(); auto warpsPerCTA = srcLayout.getWarpsPerCTA(); @@ -1592,6 +1611,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast( delinearize(rewriter, loc, laneId, threadsPerWarp, order); SmallVector multiDimWarpId = delinearize(rewriter, loc, warpId, warpsPerCTA, order); + Value laneIdAxis = multiDimLaneId[axis]; Value warpIdAxis = multiDimWarpId[axis]; @@ -1609,56 +1629,77 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast( accumulate(rewriter, loc, op.redOp(), acc, shfl, false); } - if (sizeInterWarps == 1) { - SmallVector writeIdx = indices[key]; - writeIdx[axis] = zero; - Value writeOffset = linearize(rewriter, loc, writeIdx, smemShape); - Value writePtr = gep(elemPtrTy, smemBase, writeOffset); - storeShared(rewriter, loc, writePtr, acc, laneZero); - } else { - SmallVector writeIdx = indices[key]; - writeIdx[axis] = - warpIdAxis; // axis must be the fastest-changing dimension - Value writeOffset = linearize(rewriter, loc, writeIdx, smemShape); - Value writePtr = gep(elemPtrTy, smemBase, writeOffset); - storeShared(rewriter, loc, writePtr, acc, laneZero); - barrier(); + SmallVector writeIdx = indices[key]; + writeIdx[axis] = (sizeInterWarps == 1) ? zero : warpIdAxis; + Value writeOffset = + linearize(rewriter, loc, reorder(writeIdx, order), + reorder(smemShape, order)); + Value writePtr = gep(elemPtrTy, smemBase, writeOffset); + storeShared(rewriter, loc, writePtr, acc, laneZero); + } - SmallVector readIdx = writeIdx; - readIdx[axis] = urem(laneId, i32_val(sizeInterWarps)); - Value readOffset = linearize(rewriter, loc, readIdx, smemShape); - Value readPtr = gep(elemPtrTy, smemBase, readOffset); - acc = load(readPtr); + barrier(); - // reduce across warps - for (unsigned N = sizeInterWarps / 2; N > 0; N >>= 1) { - Value shfl = shflSync(rewriter, loc, acc, N); - accumulate(rewriter, loc, op.redOp(), acc, shfl, false); - } + // the second round of shuffle reduction + // now the problem size: sizeInterWarps, s1, s2, .. , sn => + // 1, s1, s2, .. , sn + // where sizeInterWarps is 2^m + // + // each thread needs to process: + // elemsPerThread = sizeInterWarps * s1 * s2 .. Sn / numThreads + unsigned elems = product(smemShape); + unsigned numThreads = product(srcLayout.getWarpsPerCTA()) * 32; + unsigned elemsPerThread = std::max(elems / numThreads, 1); + Value readOffset = threadId; + for (unsigned round = 0; round < elemsPerThread; ++round) { + Value readPtr = gep(elemPtrTy, smemBase, readOffset); + Value acc = load(readPtr); - writeIdx[axis] = zero; - writeOffset = linearize(rewriter, loc, writeIdx, smemShape); - writePtr = gep(elemPtrTy, smemBase, writeOffset); - storeShared(rewriter, loc, writePtr, acc, and_(laneZero, warpZero)); + for (unsigned N = sizeInterWarps / 2; N > 0; N >>= 1) { + Value shfl = shflSync(rewriter, loc, acc, N); + accumulate(rewriter, loc, op.redOp(), acc, shfl, false); + } + + Value writeOffset = udiv(readOffset, i32_val(sizeInterWarps)); + Value writePtr = gep(elemPtrTy, smemBase, writeOffset); + Value threadIsNeeded = icmp_slt(threadId, i32_val(elems)); + Value laneIdModSizeInterWarps = urem(laneId, i32_val(sizeInterWarps)); + Value laneIdModSizeInterWarpsIsZero = + icmp_eq(laneIdModSizeInterWarps, zero); + storeShared(rewriter, loc, writePtr, acc, + and_(threadIsNeeded, laneIdModSizeInterWarpsIsZero)); + + if (round != elemsPerThread - 1) { + readOffset = add(readOffset, i32_val(numThreads)); } } + // We could avoid this barrier in some of the layouts, however this is not + // the general case. TODO: optimize the barrier incase the layouts are + // accepted. + barrier(); + // set output values if (auto resultTy = op.getType().dyn_cast()) { // nd-tensor where n >= 1 auto resultLayout = resultTy.getEncoding().cast(); auto resultShape = resultTy.getShape(); + SmallVector resultOrd; + for (auto ord : order) { + if (ord != 0) + resultOrd.push_back(ord - 1); + } unsigned resultElems = getElemsPerThread(resultTy); auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultShape); assert(resultIndices.size() == resultElems); - barrier(); SmallVector resultVals(resultElems); for (size_t i = 0; i < resultElems; ++i) { SmallVector readIdx = resultIndices[i]; - readIdx.insert(readIdx.begin() + axis, i32_val(0)); - Value readOffset = linearize(rewriter, loc, readIdx, smemShape); + Value readOffset = + linearize(rewriter, loc, reorder(readIdx, resultOrd), + reorder(resultShape, resultOrd)); Value readPtr = gep(elemPtrTy, smemBase, readOffset); resultVals[i] = load(readPtr); } @@ -1670,7 +1711,6 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast( rewriter.replaceOp(op, ret); } else { // 0d-tensor -> scalar - barrier(); Value resultVal = load(smemBase); rewriter.replaceOp(op, resultVal); } @@ -1707,6 +1747,191 @@ struct ViewLikeOpConversion : public ConvertTritonGPUOpToLLVMPattern { } }; +struct PrintfOpConversion + : public ConvertTritonGPUOpToLLVMPattern { + using ConvertTritonGPUOpToLLVMPattern< + triton::PrintfOp>::ConvertTritonGPUOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::PrintfOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + SmallVector operands; + for (auto operand : adaptor.getOperands()) { + auto sub_operands = this->getElementsFromStruct(loc, operand, rewriter); + for (auto elem : sub_operands) { + operands.push_back(elem); + } + } + std::string formatStr; + llvm::raw_string_ostream os(formatStr); + os << op.prefix(); + if (operands.size() > 0) { + os << getFormatSubstr(operands[0]); + } + + for (size_t i = 1; i < operands.size(); ++i) { + os << ", " << getFormatSubstr(operands[i]); + } + llPrintf(formatStr, operands, rewriter); + rewriter.eraseOp(op); + return success(); + } + // get format specific for each input value + // currently support pointer, i8, i16, i32, i64, f16, bf16, f32, f64 + std::string getFormatSubstr(Value value) const { + Type type = value.getType(); + unsigned width = type.getIntOrFloatBitWidth(); + + if (type.isa()) { + return "%p"; + } else if (type.isBF16() || type.isF16() || type.isF32() || type.isF64()) { + return "%f"; + } else if (type.isSignedInteger()) { + return "%i"; + } else if (type.isUnsignedInteger() || type.isSignlessInteger()) { + return "%u"; + } + assert(false && "not supported type"); + } + + // declare vprintf(i8*, i8*) as external function + static LLVM::LLVMFuncOp + getVprintfDeclaration(ConversionPatternRewriter &rewriter) { + auto moduleOp = + rewriter.getBlock()->getParent()->getParentOfType(); + StringRef funcName("vprintf"); + Operation *funcOp = moduleOp.lookupSymbol(funcName); + if (funcOp) + return cast(*funcOp); + + auto *context = rewriter.getContext(); + + SmallVector argsType{ptr_ty(IntegerType::get(context, 8)), + ptr_ty(IntegerType::get(context, 8))}; + auto funcType = LLVM::LLVMFunctionType::get(i32_ty, argsType); + + ConversionPatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + + return rewriter.create(UnknownLoc::get(context), funcName, + funcType); + } + + // extend integer to int32, extend float to float64 + // this comes from vprintf alignment requirements. + static std::pair + promoteValue(ConversionPatternRewriter &rewriter, Value value) { + auto *context = rewriter.getContext(); + auto type = value.getType(); + type.dump(); + unsigned width = type.getIntOrFloatBitWidth(); + Value newOp = value; + Type newType = type; + + bool bUnsigned = type.isUnsignedInteger(); + if (type.isIntOrIndex() && width < 32) { + if (bUnsigned) { + newType = ui32_ty; + newOp = rewriter.create(UnknownLoc::get(context), newType, + value); + } else { + newType = i32_ty; + newOp = rewriter.create(UnknownLoc::get(context), newType, + value); + } + } else if (type.isBF16() || type.isF16() || type.isF32()) { + newType = f64_ty; + newOp = rewriter.create(UnknownLoc::get(context), newType, + value); + } + + return {newType, newOp}; + } + + static void llPrintf(StringRef msg, ValueRange args, + ConversionPatternRewriter &rewriter) { + static const char formatStringPrefix[] = "printfFormat_"; + assert(!msg.empty() && "printf with empty string not support"); + Type int8Ptr = ptr_ty(i8_ty); + + auto *context = rewriter.getContext(); + auto moduleOp = + rewriter.getBlock()->getParent()->getParentOfType(); + auto funcOp = getVprintfDeclaration(rewriter); + + Value one = rewriter.create( + UnknownLoc::get(context), i32_ty, rewriter.getI32IntegerAttr(1)); + Value zero = rewriter.create( + UnknownLoc::get(context), i32_ty, rewriter.getI32IntegerAttr(0)); + + unsigned stringNumber = 0; + SmallString<16> stringConstName; + do { + stringConstName.clear(); + (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName); + } while (moduleOp.lookupSymbol(stringConstName)); + + llvm::SmallString<64> formatString(msg); + formatString.push_back('\n'); + formatString.push_back('\0'); + size_t formatStringSize = formatString.size_in_bytes(); + auto globalType = LLVM::LLVMArrayType::get(i8_ty, formatStringSize); + + LLVM::GlobalOp global; + { + ConversionPatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + global = rewriter.create( + UnknownLoc::get(context), globalType, + /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName, + rewriter.getStringAttr(formatString)); + } + + Value globalPtr = + rewriter.create(UnknownLoc::get(context), global); + Value stringStart = + rewriter.create(UnknownLoc::get(context), int8Ptr, + globalPtr, mlir::ValueRange({zero, zero})); + + Value bufferPtr = + rewriter.create(UnknownLoc::get(context), int8Ptr); + + SmallVector newArgs; + if (args.size() >= 1) { + SmallVector argTypes; + for (auto arg : args) { + Type newType; + Value newArg; + std::tie(newType, newArg) = promoteValue(rewriter, arg); + argTypes.push_back(newType); + newArgs.push_back(newArg); + } + + Type structTy = LLVM::LLVMStructType::getLiteral(context, argTypes); + auto allocated = rewriter.create(UnknownLoc::get(context), + ptr_ty(structTy), one, + /*alignment=*/0); + + for (const auto &entry : llvm::enumerate(newArgs)) { + auto index = rewriter.create( + UnknownLoc::get(context), i32_ty, + rewriter.getI32IntegerAttr(entry.index())); + auto fieldPtr = rewriter.create( + UnknownLoc::get(context), ptr_ty(argTypes[entry.index()]), + allocated, ArrayRef{zero, index}); + rewriter.create(UnknownLoc::get(context), entry.value(), + fieldPtr); + } + bufferPtr = rewriter.create(UnknownLoc::get(context), + int8Ptr, allocated); + } + + ValueRange operands{stringStart, bufferPtr}; + rewriter.create(UnknownLoc::get(context), funcOp, operands); + } +}; + struct MakeRangeOpConversion : public ConvertTritonGPUOpToLLVMPattern { @@ -2070,17 +2295,6 @@ public: } private: - template - SmallVector reorder(ArrayRef input, ArrayRef order) const { - size_t rank = order.size(); - assert(input.size() == rank); - SmallVector result(rank); - for (auto it : llvm::enumerate(order)) { - result[rank - 1 - it.value()] = input[it.index()]; - } - return result; - }; - // shared memory rd/st for blocked or mma layout with data padding void processReplica(Location loc, ConversionPatternRewriter &rewriter, bool stNotRd, RankedTensorType type, @@ -4483,7 +4697,7 @@ struct InsertSliceAsyncOpConversion auto numVecCols = std::max(inVec / outVec, 1); auto srcIndices = emitIndices(loc, rewriter, srcBlockedLayout, srcShape); - // <, TileOffset> + // <, TileOffset> DenseMap, Value> tileOffsetMap; for (unsigned elemIdx = 0; elemIdx < numElems; elemIdx += minVec) { // minVec = 2, inVec = 4, outVec = 2 @@ -4674,190 +4888,6 @@ struct FDivOpConversion } }; -struct PrintfOpConversion - : public ConvertTritonGPUOpToLLVMPattern { - using ConvertTritonGPUOpToLLVMPattern< - triton::PrintfOp>::ConvertTritonGPUOpToLLVMPattern; - - LogicalResult - matchAndRewrite(triton::PrintfOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op->getLoc(); - SmallVector operands; - for (auto operand : adaptor.getOperands()) { - auto sub_operands = this->getElementsFromStruct(loc, operand, rewriter); - for (auto elem : sub_operands) { - operands.push_back(elem); - } - } - std::string formatStr; - llvm::raw_string_ostream os(formatStr); - os << op.prefix(); - if (operands.size() > 0) { - os << getFormatSubstr(operands[0]); - } - - for (size_t i = 1; i < operands.size(); ++i) { - os << ", " << getFormatSubstr(operands[i]); - } - llPrintf(formatStr, operands, rewriter); - rewriter.eraseOp(op); - return success(); - } - // get format specific for each input value - // currently support pointer, i8, i16, i32, i64, f16, bf16, f32, f64 - std::string getFormatSubstr(Value value) const { - Type type = value.getType(); - unsigned width = type.getIntOrFloatBitWidth(); - - if (type.isa()) { - return "%p"; - } else if (type.isBF16() || type.isF16() || type.isF32() || type.isF64()) { - return "%f"; - } else if (type.isSignedInteger()) { - return "%i"; - } else if (type.isUnsignedInteger() || type.isSignlessInteger()) { - return "%u"; - } - assert(false && "not supported type"); - } - - // declare vprintf(i8*, i8*) as external function - LLVM::LLVMFuncOp - getVprintfDeclaration(ConversionPatternRewriter &rewriter) const { - auto moduleOp = - rewriter.getBlock()->getParent()->getParentOfType(); - StringRef funcName("vprintf"); - Operation *funcOp = moduleOp.lookupSymbol(funcName); - if (funcOp) - return cast(*funcOp); - - auto *context = rewriter.getContext(); - - SmallVector argsType{ptr_ty(IntegerType::get(context, 8)), - ptr_ty(IntegerType::get(context, 8))}; - auto funcType = LLVM::LLVMFunctionType::get(i32_ty, argsType); - - ConversionPatternRewriter::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(moduleOp.getBody()); - - return rewriter.create(UnknownLoc::get(context), funcName, - funcType); - } - - // extend integer to int32, extend float to float64 - // this comes from vprintf alignment requirements. - std::pair promoteValue(ConversionPatternRewriter &rewriter, - Value value) const { - auto *context = rewriter.getContext(); - auto type = value.getType(); - unsigned width = type.getIntOrFloatBitWidth(); - Value newOp = value; - Type newType = type; - - bool bUnsigned = type.isUnsignedInteger(); - if (type.isIntOrIndex() && width < 32) { - if (bUnsigned) { - newType = ui32_ty; - newOp = rewriter.create(UnknownLoc::get(context), newType, - value); - } else { - newType = i32_ty; - newOp = rewriter.create(UnknownLoc::get(context), newType, - value); - } - } else if (type.isBF16() || type.isF16() || type.isF32()) { - newType = f64_ty; - newOp = rewriter.create(UnknownLoc::get(context), newType, - value); - } - - return {newType, newOp}; - } - - void llPrintf(StringRef msg, ValueRange args, - ConversionPatternRewriter &rewriter) const { - static const char formatStringPrefix[] = "printfFormat_"; - assert(!msg.empty() && "printf with empty string not support"); - Type int8Ptr = ptr_ty(i8_ty); - - auto *context = rewriter.getContext(); - auto moduleOp = - rewriter.getBlock()->getParent()->getParentOfType(); - auto funcOp = getVprintfDeclaration(rewriter); - - Value one = rewriter.create( - UnknownLoc::get(context), i32_ty, rewriter.getI32IntegerAttr(1)); - Value zero = rewriter.create( - UnknownLoc::get(context), i32_ty, rewriter.getI32IntegerAttr(0)); - - unsigned stringNumber = 0; - SmallString<16> stringConstName; - do { - stringConstName.clear(); - (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName); - } while (moduleOp.lookupSymbol(stringConstName)); - - llvm::SmallString<64> formatString(msg); - formatString.push_back('\n'); - formatString.push_back('\0'); - size_t formatStringSize = formatString.size_in_bytes(); - auto globalType = LLVM::LLVMArrayType::get(i8_ty, formatStringSize); - - LLVM::GlobalOp global; - { - ConversionPatternRewriter::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(moduleOp.getBody()); - global = rewriter.create( - UnknownLoc::get(context), globalType, - /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName, - rewriter.getStringAttr(formatString)); - } - - Value globalPtr = - rewriter.create(UnknownLoc::get(context), global); - Value stringStart = - rewriter.create(UnknownLoc::get(context), int8Ptr, - globalPtr, mlir::ValueRange({zero, zero})); - - Value bufferPtr = - rewriter.create(UnknownLoc::get(context), int8Ptr); - - SmallVector newArgs; - if (args.size() >= 1) { - SmallVector argTypes; - for (auto arg : args) { - Type newType; - Value newArg; - std::tie(newType, newArg) = promoteValue(rewriter, arg); - argTypes.push_back(newType); - newArgs.push_back(newArg); - } - - Type structTy = LLVM::LLVMStructType::getLiteral(context, argTypes); - auto allocated = rewriter.create(UnknownLoc::get(context), - ptr_ty(structTy), one, - /*alignment=*/0); - - for (const auto &entry : llvm::enumerate(newArgs)) { - auto index = rewriter.create( - UnknownLoc::get(context), i32_ty, - rewriter.getI32IntegerAttr(entry.index())); - auto fieldPtr = rewriter.create( - UnknownLoc::get(context), ptr_ty(argTypes[entry.index()]), - allocated, ArrayRef{zero, index}); - rewriter.create(UnknownLoc::get(context), entry.value(), - fieldPtr); - } - bufferPtr = rewriter.create(UnknownLoc::get(context), - int8Ptr, allocated); - } - - ValueRange operands{stringStart, bufferPtr}; - rewriter.create(UnknownLoc::get(context), funcOp, operands); - } -}; - void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, int numWarps, AxisInfoAnalysis &axisInfoAnalysis, @@ -5062,6 +5092,15 @@ void ConvertTritonGPUToLLVM::initSharedMemory( namespace mlir { +namespace LLVM { + +void llPrintf(StringRef msg, ValueRange args, + ConversionPatternRewriter &rewriter) { + PrintfOpConversion::llPrintf(msg, args, rewriter); +} + +} // namespace LLVM + TritonLLVMConversionTarget::TritonLLVMConversionTarget( MLIRContext &ctx, mlir::LLVMTypeConverter &typeConverter) : ConversionTarget(ctx) { diff --git a/python/tests/test_reduce.py b/python/tests/test_reduce.py index f00d2b764..5e9a7445a 100644 --- a/python/tests/test_reduce.py +++ b/python/tests/test_reduce.py @@ -97,9 +97,7 @@ reduce2d_configs = [ (op, dtype, shape, axis) for op in ['sum', 'min', 'max'] for dtype in dtypes - for shape in [(1, 4), (1, 8), (1, 16), (1, 32), (2, 32), (4, 32)] - # TODO: fix and uncomment - #, (4, 128), (32, 64)] + for shape in [(1, 4), (1, 8), (1, 16), (1, 32), (2, 32), (4, 32), (4, 128), (32, 64)] for axis in [0, 1] ] @@ -128,7 +126,6 @@ def test_reduce2d(op, dtype, shape, axis): golden_z = torch.min(x, dim=axis, keepdim=False)[0].to(reduced_dtype) else: golden_z = torch.max(x, dim=axis, keepdim=False)[0].to(reduced_dtype) - if dtype.is_floating_point and op == 'sum': if shape[axis] >= 256: assert_close(z, golden_z, rtol=0.05, atol=0.1)