From 136668bac3cd6589bfb092bc29bfe95f3e3aa8e9 Mon Sep 17 00:00:00 2001
From: Yan Chunwei
Date: Mon, 21 Nov 2022 16:00:46 +0800
Subject: [PATCH] [Triton-MLIR][BACKEND] tiny code cleanup (#899)

- Remove the unnecessary `static` from functions in the anonymous namespace
- Remove several unnecessary functions
- Several simple rewrites to make the code clearer

---
 .../TritonGPUToLLVM/TritonGPUToLLVM.cpp       | 171 +++++++-----------
 1 file changed, 65 insertions(+), 106 deletions(-)

diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index 8338388be..dbc057c85 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -51,8 +51,7 @@ static StringRef getStructAttrsAttrName() { return "llvm.struct_attrs"; }
 namespace {
 
 // Create a 32-bit integer constant.
-static Value createConstantI32(Location loc, PatternRewriter &rewriter,
-                               int32_t v) {
+Value createConstantI32(Location loc, PatternRewriter &rewriter, int32_t v) {
   auto i32ty = rewriter.getIntegerType(32);
   return rewriter.create<LLVM::ConstantOp>(loc, i32ty,
                                            IntegerAttr::get(i32ty, v));
@@ -71,16 +70,16 @@ Value createConstantF64(Location loc, PatternRewriter &rewriter, float v) {
 }
 
 // Create an index type constant.
-static Value createIndexConstant(OpBuilder &builder, Location loc,
-                                 TypeConverter *converter, int64_t value) {
+Value createIndexConstant(OpBuilder &builder, Location loc,
+                          TypeConverter *converter, int64_t value) {
   Type ty = converter->convertType(builder.getIndexType());
   return builder.create<LLVM::ConstantOp>(loc, ty,
                                           builder.getIntegerAttr(ty, value));
 }
 
 // Create an integer constant of \param width bits.
-static Value createLLVMIntegerConstant(OpBuilder &builder, Location loc,
-                                       short width, int64_t value) {
+Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
+                                int64_t value) {
   Type ty = builder.getIntegerType(width);
   return builder.create<LLVM::ConstantOp>(loc, ty,
                                           builder.getIntegerAttr(ty, value));
@@ -187,9 +186,8 @@ template <typename T> void printScalar(const T &e, const std::string &info) {
 /// Only retain those attributes that are not constructed by
 /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
 /// attributes.
-static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs,
-                                 bool filterArgAttrs,
-                                 SmallVectorImpl<NamedAttribute> &result) {
+void filterFuncAttributes(ArrayRef<NamedAttribute> attrs, bool filterArgAttrs,
+                          SmallVectorImpl<NamedAttribute> &result) {
   for (const auto &attr : attrs) {
     if (attr.getName() == SymbolTable::getSymbolAttrName() ||
         attr.getName() == FunctionOpInterface::getTypeAttrName() ||
@@ -202,7 +200,7 @@ static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs,
 }
 
 /// Helper function for wrapping all attributes into a single DictionaryAttr
-static auto wrapAsStructAttrs(OpBuilder &b, ArrayAttr attrs) {
+auto wrapAsStructAttrs(OpBuilder &b, ArrayAttr attrs) {
   return DictionaryAttr::get(
       b.getContext(), b.getNamedAttr(LLVM::getStructAttrsAttrName(), attrs));
 }
@@ -359,7 +357,7 @@ Value getStructFromElements(Location loc, ValueRange resultVals,
 
 // delinearize supposing order is [0, 1, .. , n]
 template <typename T>
-static SmallVector<T> getMultiDimIndexImpl(T linearIndex, ArrayRef<T> shape) {
+SmallVector<T> getMultiDimIndexImpl(T linearIndex, ArrayRef<T> shape) {
   // shape: {a, b, c, d} -> accMul: {1, a, a*b, a*b*c}
   size_t rank = shape.size();
   T accMul = product(shape.drop_back());
@@ -376,8 +374,8 @@ static SmallVector<T> getMultiDimIndexImpl(T linearIndex, ArrayRef<T> shape) {
 }
 
 template <typename T>
-static SmallVector<T> getMultiDimIndex(T linearIndex, ArrayRef<T> shape,
-                                       ArrayRef<unsigned> order) {
+SmallVector<T> getMultiDimIndex(T linearIndex, ArrayRef<T> shape,
+                                ArrayRef<unsigned> order) {
   size_t rank = shape.size();
   assert(rank == order.size());
   auto reordered = reorder(shape, order);
@@ -391,7 +389,7 @@ static SmallVector<T> getMultiDimIndex(T linearIndex, ArrayRef<T> shape,
 
 // linearize supposing order is [0, 1, .. , n]
 template <typename T>
-static T getLinearIndexImpl(ArrayRef<T> multiDimIndex, ArrayRef<T> shape) {
+T getLinearIndexImpl(ArrayRef<T> multiDimIndex, ArrayRef<T> shape) {
   assert(multiDimIndex.size() == shape.size());
   // shape: {a, b, c, d} -> accMul: {1, a, a*b, a*b*c}
   size_t rank = shape.size();
@@ -407,15 +405,15 @@ static T getLinearIndexImpl(ArrayRef<T> multiDimIndex, ArrayRef<T> shape) {
 }
 
 template <typename T>
-static T getLinearIndex(ArrayRef<T> multiDimIndex, ArrayRef<T> shape,
-                        ArrayRef<unsigned> order) {
+T getLinearIndex(ArrayRef<T> multiDimIndex, ArrayRef<T> shape,
+                 ArrayRef<unsigned> order) {
   assert(shape.size() == order.size());
   return getLinearIndexImpl<T>(reorder(multiDimIndex, order),
                                reorder(shape, order));
 }
 
-static Value storeShared(ConversionPatternRewriter &rewriter, Location loc,
-                         Value ptr, Value val, Value pred) {
+Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr,
+                  Value val, Value pred) {
   MLIRContext *ctx = rewriter.getContext();
   unsigned bits = val.getType().getIntOrFloatBitWidth();
   const char *c = bits == 64 ? "l" : (bits == 16 ? "h" : "r");
@@ -517,10 +515,9 @@ struct ConvertTritonGPUOpToLLVMPatternBase {
                                   ConversionPatternRewriter &rewriter) {
     auto elems = getElementsFromStruct(loc, llvmStruct, rewriter);
     auto rank = (elems.size() - 1) / 2;
-    return SharedMemoryObject(
-        /*base=*/elems[0],
-        /*strides=*/{elems.begin() + 1, elems.begin() + 1 + rank},
-        /*offsets=*/{elems.begin() + 1 + rank, elems.end()});
+    return {/*base=*/elems[0],
+            /*strides=*/{elems.begin() + 1, elems.begin() + 1 + rank},
+            /*offsets=*/{elems.begin() + 1 + rank, elems.end()}};
   }
 
   static Value
@@ -1018,13 +1015,13 @@ struct ArithConstantSplatOpConversion
 
 // Contains some helper functions for both Load and Store conversions.
 struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase {
-  LoadStoreConversionBase(AxisInfoAnalysis &axisAnalysisPass)
+  explicit LoadStoreConversionBase(AxisInfoAnalysis &axisAnalysisPass)
      : AxisAnalysisPass(axisAnalysisPass) {}
 
   // Get corresponding LLVM element values of \param value.
-  SmallVector<Value> getLLVMElems(Value value, Value llValue,
-                                  ConversionPatternRewriter &rewriter,
-                                  Location loc) const {
+  static SmallVector<Value> getLLVMElems(Value value, Value llValue,
+                                         ConversionPatternRewriter &rewriter,
+                                         Location loc) {
     if (!value)
       return {};
     if (!llValue.getType().isa<LLVM::LLVMStructType>())
@@ -1601,7 +1598,7 @@ void ReduceOpConversion::accumulate(ConversionPatternRewriter &rewriter,
   default:
     llvm::report_fatal_error("Unsupported reduce op");
   }
-};
+}
 
 Value ReduceOpConversion::shflSync(ConversionPatternRewriter &rewriter,
                                    Location loc, Value val, int i) const {
@@ -1946,7 +1943,7 @@ struct PrintfOpConversion
     std::string formatStr;
     llvm::raw_string_ostream os(formatStr);
     os << op.prefix();
-    if (operands.size() > 0) {
+    if (!operands.empty()) {
       os << getFormatSubstr(operands[0]);
     }
 
@@ -2130,7 +2127,7 @@ struct MakeRangeOpConversion
     auto idxs = emitIndices(loc, rewriter, layout, shape);
     unsigned elems = idxs.size();
     SmallVector<Value> retVals(elems);
-    for (auto multiDim : llvm::enumerate(idxs)) {
+    for (const auto &multiDim : llvm::enumerate(idxs)) {
      assert(multiDim.value().size() == 1);
      retVals[multiDim.index()] = add(multiDim.value()[0], start);
     }
@@ -2633,7 +2630,7 @@ struct FpToFpOpConversion
 };
 
 // A CRTP style of base class.
-template <typename SourceOp, typename DestOp, typename ConcreteT>
+template <typename SourceOp, typename ConcreteT>
 class ElementwiseOpConversionBase
     : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
 public:
@@ -2688,16 +2685,16 @@ protected:
 
 template <typename SourceOp, typename DestOp>
 struct ElementwiseOpConversion
     : public ElementwiseOpConversionBase<
-          SourceOp, DestOp, ElementwiseOpConversion<SourceOp, DestOp>> {
+          SourceOp, ElementwiseOpConversion<SourceOp, DestOp>> {
   using Base =
-      ElementwiseOpConversionBase<SourceOp, DestOp,
-                                  ElementwiseOpConversion<SourceOp, DestOp>>;
+      ElementwiseOpConversionBase<SourceOp,
+                                  ElementwiseOpConversion<SourceOp, DestOp>>;
   using Base::Base;
   using OpAdaptor = typename Base::OpAdaptor;
 
   explicit ElementwiseOpConversion(LLVMTypeConverter &typeConverter,
                                    PatternBenefit benefit = 1)
-      : ElementwiseOpConversionBase<SourceOp, DestOp, ElementwiseOpConversion>(
+      : ElementwiseOpConversionBase<SourceOp, ElementwiseOpConversion>(
           typeConverter, benefit) {}
 
   // An interface to support variant DestOp builder.
@@ -2714,10 +2711,10 @@ struct ElementwiseOpConversion
 
 //
 struct CmpIOpConversion
-    : public ElementwiseOpConversionBase<triton::gpu::CmpIOp, LLVM::ICmpOp,
-                                         CmpIOpConversion> {
-  using Base = ElementwiseOpConversionBase<triton::gpu::CmpIOp, LLVM::ICmpOp,
-                                           CmpIOpConversion>;
+    : public ElementwiseOpConversionBase<triton::gpu::CmpIOp,
+                                         CmpIOpConversion> {
+  using Base =
+      ElementwiseOpConversionBase<triton::gpu::CmpIOp, CmpIOpConversion>;
   using Base::Base;
   using Adaptor = typename Base::OpAdaptor;
 
@@ -2755,17 +2752,18 @@ struct CmpIOpConversion
 };
 
 struct CmpFOpConversion
-    : public ElementwiseOpConversionBase<triton::gpu::CmpFOp, LLVM::FCmpOp,
-                                         CmpFOpConversion> {
-  using Base = ElementwiseOpConversionBase<triton::gpu::CmpFOp, LLVM::FCmpOp,
-                                           CmpFOpConversion>;
+    : public ElementwiseOpConversionBase<triton::gpu::CmpFOp,
+                                         CmpFOpConversion> {
+  using Base =
+      ElementwiseOpConversionBase<triton::gpu::CmpFOp, CmpFOpConversion>;
   using Base::Base;
   using Adaptor = typename Base::OpAdaptor;
 
   // An interface to support variant DestOp builder.
-  LLVM::FCmpOp createDestOp(triton::gpu::CmpFOp op, OpAdaptor adaptor,
-                            ConversionPatternRewriter &rewriter, Type elemTy,
-                            ValueRange operands, Location loc) const {
+  static LLVM::FCmpOp createDestOp(triton::gpu::CmpFOp op, OpAdaptor adaptor,
+                                   ConversionPatternRewriter &rewriter,
+                                   Type elemTy, ValueRange operands,
+                                   Location loc) {
     return rewriter.create<LLVM::FCmpOp>(
         loc, elemTy, ArithCmpFPredicteToLLVM(op.predicate()), operands[0],
        operands[1]);
@@ -2945,13 +2943,6 @@ private:
       triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
       ConversionPatternRewriter &rewriter, const MmaEncodingAttr &mmaLayout,
       const DotOperandEncodingAttr &dotOperandLayout, bool isOuter) const;
-
-  // shared -> dot_operand if the result layout is blocked
-  Value lowerSharedToDotOperandBlocked(
-      triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
-      ConversionPatternRewriter &rewriter,
-      const BlockedEncodingAttr &blockedLayout,
-      const DotOperandEncodingAttr &dotOperandLayout, bool isOuter) const;
 };
 
 void ConvertLayoutOpConversion::processReplica(
@@ -2960,7 +2951,7 @@ void ConvertLayoutOpConversion::processReplica(
     ArrayRef<unsigned> multiDimRepId, unsigned vec,
     ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> outOrd,
     SmallVector<Value> &vals, Value smemBase) const {
-  unsigned accumNumCTAsEachRep = product<unsigned>(numCTAsEachRep);
+  auto accumNumCTAsEachRep = product<unsigned>(numCTAsEachRep);
   auto layout = type.getEncoding();
   auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>();
   auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>();
@@ -2989,13 +2980,12 @@ void ConvertLayoutOpConversion::processReplica(
     auto multiDimCTAInRepId =
         getMultiDimIndex<unsigned>(ctaId, numCTAsEachRep, order);
     SmallVector<unsigned> multiDimCTAId(rank);
-    for (auto it : llvm::enumerate(multiDimCTAInRepId)) {
+    for (const auto &it : llvm::enumerate(multiDimCTAInRepId)) {
      auto d = it.index();
      multiDimCTAId[d] = multiDimRepId[d] * numCTAsEachRep[d] + it.value();
    }
 
-    unsigned linearCTAId =
-        getLinearIndex<unsigned>(multiDimCTAId, numCTAs, order);
+    auto linearCTAId = getLinearIndex<unsigned>(multiDimCTAId, numCTAs, order);
     // TODO: This is actually redundant index calculation, we should
     // consider of caching the index calculation result in case
     // of performance issue observed.
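// ---------------------------------------------------------------------------
// Aside, for illustration only (not part of the diff): a standalone sketch of
// the index helpers the code above relies on. The real getMultiDimIndex /
// getLinearIndex also apply an `order` permutation; this simplified version
// assumes order = [0, 1, .., n], i.e. dimension 0 is fastest-varying, which
// is exactly the convention the *Impl variants implement.
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdio>

template <typename T, std::size_t N>
std::array<T, N> delinearize(T linear, const std::array<T, N> &shape) {
  std::array<T, N> multi{};
  for (std::size_t i = 0; i < N; ++i) {
    multi[i] = linear % shape[i]; // coordinate along dimension i
    linear /= shape[i];
  }
  return multi;
}

template <typename T, std::size_t N>
T linearize(const std::array<T, N> &multi, const std::array<T, N> &shape) {
  T linear = 0;
  for (std::size_t i = N; i-- > 0;) // fold from the slowest dimension down
    linear = linear * shape[i] + multi[i];
  return linear;
}

int main() {
  std::array<unsigned, 3> shape{4, 5, 6};
  for (unsigned i = 0; i < 4 * 5 * 6; ++i)
    assert(linearize(delinearize(i, shape), shape) == i); // round-trip holds
  std::puts("round-trip ok");
}
// ---------------------------------------------------------------------------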
@@ -3073,7 +3063,7 @@ LogicalResult ConvertLayoutOpConversion::lowerDistributedToDistributed(
     outNumCTAs[d] = ceil<unsigned>(shape[d], outPerCTA);
   }
   // Potentially we need to store for multiple CTAs in this replication
-  unsigned accumNumReplicates = product<unsigned>(numReplicates);
+  auto accumNumReplicates = product<unsigned>(numReplicates);
   // unsigned elems = getElemsPerThread(srcTy);
   auto vals = getElementsFromStruct(loc, adaptor.src(), rewriter);
   unsigned inVec = 0;
@@ -3118,7 +3108,7 @@ LogicalResult ConvertLayoutOpConversion::lowerDistributedToDistributed(
 
   rewriter.replaceOp(op, result);
   return success();
-};
+}
 
 LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
     triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
@@ -3144,7 +3134,7 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
   unsigned maxPhase = dstSharedLayout.getMaxPhase();
   unsigned numElems = getElemsPerThread(srcTy);
   auto inVals = getElementsFromStruct(loc, adaptor.src(), rewriter);
-  unsigned srcAccumSizeInThreads =
+  auto srcAccumSizeInThreads =
       product<unsigned>(srcBlockedLayout.getSizePerThread());
   auto elemTy = srcTy.getElementType();
   auto wordTy = vec_ty(elemTy, minVec);
@@ -3177,7 +3167,7 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
   smemBase = bitcast(smemBase, elemPtrTy);
   auto smemObj = SharedMemoryObject(smemBase, dstShape, outOrd, loc, rewriter);
   auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
-  unsigned numWordsEachRep = product<unsigned>(wordsInEachRep);
+  auto numWordsEachRep = product<unsigned>(wordsInEachRep);
   SmallVector<Value> wordVecs(numWordsEachRep);
   // TODO: We should get less barriers if it is handled by membar pass
   //       instead of the backend, since the later can only handle it in
@@ -3196,7 +3186,7 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
           linearIdxInNanoTile, srcBlockedLayout.getSizePerThread(), inOrd);
       unsigned pos = multiDimIdxInNanoTile[inOrd[0]] % minVec;
       multiDimIdxInNanoTile[inOrd[0]] /= minVec;
-      unsigned wordVecIdx =
+      auto wordVecIdx =
          getLinearIndex<unsigned>(multiDimIdxInNanoTile, wordsInEachRep, inOrd);
       wordVecs[wordVecIdx] =
          insert_element(wordTy, wordVecs[wordVecIdx], inVals[i], idx_val(pos));
@@ -3265,7 +3255,6 @@ public:
     cMatShape = matShape[order[0]];
     sMatShape = matShape[order[1]];
 
-    cStride = smemStrides[order[0]];
     sStride = smemStrides[order[1]];
 
     // rule: k must be the fast-changing axis.
@@ -3636,7 +3625,6 @@ private:
   int cMatShape;
   int sMatShape;
 
-  Value cStride;
   Value sStride;
 
   bool needTrans;
@@ -3651,13 +3639,6 @@ private:
   int warpOffStride;
 };
 
-bool isSplatLike(Value value) {
-  if (auto constv = dyn_cast<arith::ConstantOp>(value.getDefiningOp()))
-    if (auto attr = constv.getValue().dyn_cast<SplatElementsAttr>())
-      return attr.isSplat();
-  return false;
-}
-
 struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
   enum class TensorCoreType : uint8_t {
     // floating-point tensor core instr
@@ -3790,7 +3771,6 @@ struct DotOpMmaV1ConversionHelper {
   int getRepN(int N) const {
     return std::max(N / (wpt[1] * instrShape[1]), 1);
   }
-  int getRepK(int K) const { return std::max(K / instrShape[2], 1); }
 
   static ArrayRef<int> getMmaInstrShape() { return instrShape; }
 
@@ -3857,9 +3837,6 @@ struct DotOpMmaV1ConversionHelper {
   Value loadB(Value B, const SharedMemoryObject &smemObj, Value thread,
               Location loc, ConversionPatternRewriter &rewriter) const;
 
-  // Loading $c to registers, returns a LLVM::Struct.
-  Value loadC(Value C, Value llC, ConversionPatternRewriter &rewriter) const;
-
   static ArrayRef<unsigned> getOrder() { return mmaOrder; }
 
   // Compute the offset of the matrix to load.
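// ---------------------------------------------------------------------------
// Aside, for illustration only (not part of the diff): what the
// insert_element loop in lowerBlockedToShared above is doing, as a plain C++
// stand-in. Sizes are hypothetical and the `order` permutation the real code
// applies is ignored; the point is just the scalars-into-words packing.
#include <cstdio>
#include <vector>

int main() {
  constexpr unsigned minVec = 4;  // scalars packed per shared-memory word
  std::vector<float> inVals(16);  // per-thread values, fastest dim contiguous
  for (unsigned i = 0; i < inVals.size(); ++i)
    inVals[i] = static_cast<float>(i);

  // One word per minVec consecutive scalars, mirroring wordVecs above.
  std::vector<std::vector<float>> wordVecs(inVals.size() / minVec,
                                           std::vector<float>(minVec));
  for (unsigned i = 0; i < inVals.size(); ++i) {
    unsigned pos = i % minVec;         // lane within the word (insert_element)
    unsigned wordVecIdx = i / minVec;  // which word receives this scalar
    wordVecs[wordVecIdx][pos] = inVals[i];
  }
  std::printf("word 1, lane 2 = %g\n", wordVecs[1][2]); // prints 6
}
// ---------------------------------------------------------------------------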
@@ -3900,13 +3877,6 @@ struct DotOpMmaV2ConversionHelper {
     mmaType = getTensorCoreTypeFromOperand(operandTy);
   }
 
-  // Get the M and N of mat instruction shape.
-  static std::tuple<int, int> getMatShapeMN() {
-    // According to DotOpMmaV2ConversionHelper::mmaMatShape, all the matrix
-    // shape's M,N are {8,8}
-    return {8, 8};
-  }
-
   // Get the M and N of mma instruction shape.
   static std::tuple<int, int> getInstrShapeMN() {
     // According to DotOpConversionHelper::mmaInstrShape, all the M,N are
@@ -4561,7 +4531,7 @@ struct DotOpFMAConversionHelper {
                             ConversionPatternRewriter &rewriter,
                             Location loc) const;
 
-  Value getStructFromValueTable(ValueTable vals,
+  Value getStructFromValueTable(const ValueTable &vals,
                                 ConversionPatternRewriter &rewriter,
                                 Location loc) const {
     SmallVector<Type> elemTypes(vals.size(), f32_ty);
@@ -4838,7 +4808,7 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
   auto mma = builder.create("mma.sync.aligned.m8n8k4")
                  ->o(isARow ? "row" : "col")
                  .o(isBRow ? "row" : "col")
-                 .o(".f32.f16.f16.f32");
+                 .o("f32.f16.f16.f32");
 
   mma(resOprs, AOprs, BOprs, COprs);
 
@@ -5095,11 +5065,6 @@ Value DotOpMmaV1ConversionHelper::loadB(
   return res;
 }
 
-Value DotOpMmaV1ConversionHelper::loadC(
-    Value tensor, Value llTensor, ConversionPatternRewriter &rewriter) const {
-  return llTensor;
-}
-
 std::tuple<Value, Value, Value, Value>
 DotOpMmaV1ConversionHelper::computeOffsets(Value threadId, bool isARow,
                                            bool isBRow, ArrayRef<int> fpw,
@@ -5847,11 +5812,10 @@ struct InsertSliceAsyncOpConversion
 };
 
 struct ExtElemwiseOpConversion
-    : public ElementwiseOpConversionBase<
-          triton::ExtElemwiseOp, LLVM::LLVMFuncOp, ExtElemwiseOpConversion> {
-  using Base =
-      ElementwiseOpConversionBase<triton::ExtElemwiseOp, LLVM::LLVMFuncOp,
-                                  ExtElemwiseOpConversion>;
+    : public ElementwiseOpConversionBase<triton::ExtElemwiseOp,
+                                         ExtElemwiseOpConversion> {
+  using Base = ElementwiseOpConversionBase<triton::ExtElemwiseOp,
+                                           ExtElemwiseOpConversion>;
   using Base::Base;
   using Adaptor = typename Base::OpAdaptor;
 
@@ -5895,10 +5859,9 @@ private:
 };
 
 struct FDivOpConversion
-    : ElementwiseOpConversionBase<mlir::arith::DivFOp, LLVM::InlineAsmOp,
-                                  FDivOpConversion> {
-  using Base = ElementwiseOpConversionBase<
-      mlir::arith::DivFOp, LLVM::InlineAsmOp, FDivOpConversion>;
+    : ElementwiseOpConversionBase<mlir::arith::DivFOp, FDivOpConversion> {
+  using Base =
+      ElementwiseOpConversionBase<mlir::arith::DivFOp, FDivOpConversion>;
   using Base::Base;
   using Adaptor = typename Base::OpAdaptor;
 
@@ -5911,30 +5874,26 @@ struct FDivOpConversion
     unsigned bitwidth = elemTy.getIntOrFloatBitWidth();
     if (32 == bitwidth) {
       fdiv.o("full").o("f32");
-      auto res = ptxBuilder.newOperand("=r");
-      auto lhs = ptxBuilder.newOperand(operands[0], "r");
-      auto rhs = ptxBuilder.newOperand(operands[1], "r");
-      fdiv(res, lhs, rhs);
     } else if (64 == bitwidth) {
       fdiv.o("rn").o("f64");
-      auto res = ptxBuilder.newOperand("=l");
-      auto lhs = ptxBuilder.newOperand(operands[0], "l");
-      auto rhs = ptxBuilder.newOperand(operands[1], "l");
-      fdiv(res, lhs, rhs);
     } else {
       assert(0 && bitwidth && "not supported");
     }
 
+    auto res = ptxBuilder.newOperand(bitwidth == 32 ? "=r" : "=l");
+    auto lhs = ptxBuilder.newOperand(operands[0], bitwidth == 32 ? "r" : "l");
+    auto rhs = ptxBuilder.newOperand(operands[1], bitwidth == 32 ? "r" : "l");
+    fdiv(res, lhs, rhs);
+
     Value ret = ptxBuilder.launch(rewriter, loc, elemTy, false);
     return ret;
   }
 };
 
 struct ExpOpConversionApprox
-    : ElementwiseOpConversionBase<mlir::math::ExpOp, LLVM::InlineAsmOp,
-                                  ExpOpConversionApprox> {
-  using Base = ElementwiseOpConversionBase<
-      mlir::math::ExpOp, LLVM::InlineAsmOp, ExpOpConversionApprox>;
+    : ElementwiseOpConversionBase<mlir::math::ExpOp, ExpOpConversionApprox> {
+  using Base =
+      ElementwiseOpConversionBase<mlir::math::ExpOp, ExpOpConversionApprox>;
   using Base::Base;
   using Adaptor = typename Base::OpAdaptor;
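The recurring edit in the CmpI/CmpF, ExtElemwise, FDiv, and Exp hunks above is dropping the `DestOp` template parameter from `ElementwiseOpConversionBase`. Under CRTP the derived converter is itself a template parameter, and it alone decides which destination op to build, so the base class never needs to name `DestOp`. A minimal standalone sketch of that pattern (illustrative names only, not Triton's actual classes):

#include <iostream>

// CRTP base: knows the source op and the derived converter, nothing else.
template <typename SourceOp, typename ConcreteT>
struct ConversionBase {
  void convert(const SourceOp &op) const {
    // Dispatch to the derived class; no DestOp parameter is needed here.
    static_cast<const ConcreteT *>(this)->createDestOp(op);
  }
};

// Only the derived class carries the SourceOp -> DestOp mapping.
template <typename SourceOp, typename DestOp>
struct OpConversion
    : ConversionBase<SourceOp, OpConversion<SourceOp, DestOp>> {
  void createDestOp(const SourceOp &op) const {
    std::cout << "rewriting " << op.name << " -> " << DestOp::name << "\n";
  }
};

struct AddIOp {
  const char *name = "arith.addi";
};
struct LLVMAddOp {
  static constexpr const char *name = "llvm.add";
};

int main() {
  OpConversion<AddIOp, LLVMAddOp> conv;
  conv.convert(AddIOp{}); // prints: rewriting arith.addi -> llvm.add
}

The same reasoning backs the first bullet of the commit message: functions inside `namespace { ... }` already have internal linkage, so a `static` there adds nothing.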