From 0cfe909df84f63ac431429b0c1732a62749c49cf Mon Sep 17 00:00:00 2001 From: Yan Chunwei Date: Mon, 12 Dec 2022 17:46:16 +0800 Subject: [PATCH] [Triton-MLIR][BACKEND] some code clean on the backend (#978) --- lib/Conversion/TritonGPUToLLVM/DotHelpers.h | 20 ++-- .../TritonGPUToLLVM/TritonGPUToLLVM.cpp | 95 +++++---------- lib/Conversion/TritonGPUToLLVM/Utility.h | 7 +- lib/Dialect/TritonGPU/Transforms/Combine.cpp | 112 +++++++++--------- 4 files changed, 97 insertions(+), 137 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/DotHelpers.h b/lib/Conversion/TritonGPUToLLVM/DotHelpers.h index 8703bebcb..311f0ec91 100644 --- a/lib/Conversion/TritonGPUToLLVM/DotHelpers.h +++ b/lib/Conversion/TritonGPUToLLVM/DotHelpers.h @@ -105,11 +105,9 @@ struct DotOpMmaV1ConversionHelper { } // Get the number of fp16x2 elements for $a. - // \param shapeTransed: the shape or reordered shape if transpose needed. + // \param shapeTransed: A's shape or reordered shape if transpose needed. // \param orderTransed: the order or reordered order if transpose needed. - unsigned getNumM(ArrayRef shapeTransed, - ArrayRef orderTransed) const { - bool isARow = orderTransed[0] != 0; + unsigned getNumM(ArrayRef shapeTransed, bool isARow) const { AParam param(isARow); unsigned numM = param.rep[0] * shapeTransed[0] / (param.spw[0] * wpt[0]); @@ -117,11 +115,9 @@ struct DotOpMmaV1ConversionHelper { } // Get the number of fp16x2 elements for $b. - // \param shapeTransed: the shape or reordered shape if transpose needed. + // \param shapeTransed: B' shape or reordered shape if transpose needed. // \param orderTransed: the order or reordered order if transpose needed. - unsigned getNumN(ArrayRef shapeTransed, - ArrayRef orderTransed) const { - bool isBRow = orderTransed[0] != 0; + unsigned getNumN(ArrayRef shapeTransed, bool isBRow) const { BParam param(isBRow); unsigned numN = param.rep[1] * shapeTransed[1] / (param.spw[1] * wpt[1]); @@ -130,7 +126,7 @@ struct DotOpMmaV1ConversionHelper { int numElemsPerThreadA(ArrayRef shapeTransed, ArrayRef orderTransed) const { - int numM = getNumM(shapeTransed, orderTransed); + int numM = getNumM(shapeTransed, orderTransed[0] == 1); int NK = shapeTransed[1]; // NOTE: We couldn't get the vec from the shared layout. @@ -143,7 +139,7 @@ struct DotOpMmaV1ConversionHelper { int numElemsPerThreadB(ArrayRef shapeTransed, ArrayRef orderTransed) const { - unsigned numN = getNumN(shapeTransed, orderTransed); + unsigned numN = getNumN(shapeTransed, orderTransed[0] == 1); int NK = shapeTransed[0]; // NOTE: We couldn't get the vec from the shared layout. // int vecB = sharedLayout.getVec(); @@ -1451,7 +1447,7 @@ Value DotOpMmaV1ConversionHelper::loadA( } }; - unsigned numM = getNumM(shape, order); + unsigned numM = getNumM(shape, order[0] == 1); for (unsigned k = 0; k < NK; k += 4) for (unsigned m = 0; m < numM / 2; ++m) loadA(m, k); @@ -1563,7 +1559,7 @@ Value DotOpMmaV1ConversionHelper::loadB( } }; - unsigned numN = getNumN(shape, order); + unsigned numN = getNumN(shape, order[0] == 1); for (unsigned k = 0; k < NK; k += 4) for (unsigned n = 0; n < numN / 2; ++n) { if (!hbs.count({n, k})) diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index 9526e1a5a..f6d158625 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -739,7 +739,6 @@ Value convertSplatLikeOp(Type elemType, Type resType, Value constVal, auto tensorTy = resType.cast(); if (tensorTy.getEncoding().isa() || tensorTy.getEncoding().isa()) { - auto tensorTy = resType.cast(); auto srcType = typeConverter->convertType(elemType); auto llSrc = bitcast(constVal, srcType); size_t elemsPerThread = getElemsPerThread(tensorTy); @@ -981,7 +980,7 @@ struct LoadOpConversion size_t size = width / valueElemNbits; auto vecTy = LLVM::getFixedVectorType(valueElemTy, size); - Value v = rewriter.create(loc, vecTy); + Value v = undef(vecTy); for (size_t s = 0; s < size; ++s) { Value falseVal = otherElems[vecStart + ii * size + s]; Value sVal = createIndexAttrConstant( @@ -1118,7 +1117,7 @@ struct StoreOpConversion SmallVector> asmArgs; for (size_t wordIdx = 0; wordIdx < nWords; ++wordIdx) { // llWord is a width-len composition - Value llWord = rewriter.create(loc, wordTy); + Value llWord = undef(wordTy); // Insert each value element to the composition for (size_t elemIdx = 0; elemIdx < wordNElems; ++elemIdx) { const size_t elemOffset = vecStart + wordIdx * wordNElems + elemIdx; @@ -1129,10 +1128,7 @@ struct StoreOpConversion elem = bitcast(elem, valueElemTy); Type u32Ty = typeConverter->convertType(type::u32Ty(ctx)); - llWord = - insert_element(wordTy, llWord, elem, - rewriter.create( - loc, u32Ty, IntegerAttr::get(u32Ty, elemIdx))); + llWord = insert_element(wordTy, llWord, elem, i32_val(elemIdx)); } llWord = bitcast(llWord, valArgTy); std::string constraint = @@ -3570,43 +3566,27 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor, auto ATensorTy = A.getType().cast(); auto BTensorTy = B.getType().cast(); auto DTensorTy = D.getType().cast(); - SmallVector AShape(ATensorTy.getShape().begin(), - ATensorTy.getShape().end()); - SmallVector BShape(BTensorTy.getShape().begin(), - BTensorTy.getShape().end()); + auto AShape = ATensorTy.getShape(); + auto BShape = BTensorTy.getShape(); auto DShape = DTensorTy.getShape(); auto wpt = mmaLayout.getWarpsPerCTA(); bool isARow = ALayout.getIsMMAv1Row().cast().getValue(); bool isBRow = BLayout.getIsMMAv1Row().cast().getValue(); - bool isAVec4 = !isARow && AShape[isARow] <= 16; // fp16*4 = 16bytes - bool isBVec4 = isBRow && BShape[isBRow] <= 16; - // TODO[Superjomn]: ld.v4 is not supported. - isAVec4 = true; - isBVec4 = true; - int packSize0 = (isARow || isAVec4) ? 1 : 2; - int packSize1 = (isBRow && !isBVec4) ? 2 : 1; - SmallVector fpw({2, 2, 1}); - SmallVector rep({2 * packSize0, 2 * packSize1, 1}); - SmallVector spw({fpw[0] * 4 * rep[0], fpw[1] * 4 * rep[1], 1}); - - Value loadedA = adaptor.a(); - Value loadedB = adaptor.b(); - Value loadedC = adaptor.c(); DotOpMmaV1ConversionHelper helper(mmaLayout); - unsigned numM = rep[0] * DShape[0] / (spw[0] * wpt[0]); - unsigned numN = rep[1] * DShape[1] / (spw[1] * wpt[1]); + unsigned numM = helper.getNumM(AShape, isARow); + unsigned numN = helper.getNumN(BShape, isBRow); unsigned NK = AShape[1]; - auto has = helper.extractLoadedOperand(loadedA, NK, rewriter); - auto hbs = helper.extractLoadedOperand(loadedB, NK, rewriter); + auto has = helper.extractLoadedOperand(adaptor.a(), NK, rewriter); + auto hbs = helper.extractLoadedOperand(adaptor.b(), NK, rewriter); // Initialize accumulators with external values, the acc holds the accumulator // value that is shared between the MMA instructions inside a DotOp, we can // call the order of the values the accumulator-internal order. - SmallVector acc = getElementsFromStruct(loc, loadedC, rewriter); + SmallVector acc = getElementsFromStruct(loc, adaptor.c(), rewriter); size_t resSize = acc.size(); // The resVals holds the final result of the DotOp. @@ -3719,38 +3699,19 @@ DotOpConversion::convertFMADot(triton::DotOp op, OpAdaptor adaptor, auto bShape = bTensorTy.getShape(); auto cShape = cTensorTy.getShape(); - ValueTable has, hbs; - int mShapePerCTA{-1}, nShapePerCTA{-1}; - int mSizePerThread{-1}, nSizePerThread{-1}; - ArrayRef aOrder, bOrder; - Value llA, llB; BlockedEncodingAttr dLayout = dTensorTy.getEncoding().cast(); auto order = dLayout.getOrder(); auto cc = getElementsFromStruct(loc, adaptor.c(), rewriter); DotOpFMAConversionHelper helper(dLayout); - if (auto aDotOpLayout = - aTensorTy.getEncoding() - .dyn_cast()) { // get input from - // convert_layout - auto bDotOpLayout = - bTensorTy.getEncoding().dyn_cast(); - auto aLayout = aDotOpLayout.getParent().cast(); - auto bLayout = bDotOpLayout.getParent().cast(); + auto aDotOpLayout = aTensorTy.getEncoding().cast(); + auto bDotOpLayout = bTensorTy.getEncoding().cast(); + auto aLayout = aDotOpLayout.getParent().cast(); + auto bLayout = bDotOpLayout.getParent().cast(); - assert(bLayout); - llA = adaptor.a(); - llB = adaptor.b(); - } else if (auto aLayout = - aTensorTy.getEncoding() - .dyn_cast()) { // load input from smem - auto bLayout = bTensorTy.getEncoding().dyn_cast(); - assert(bLayout); - Value thread = getThreadId(rewriter, loc); - llA = helper.loadA(A, adaptor.a(), dLayout, thread, loc, rewriter); - llB = helper.loadB(B, adaptor.b(), dLayout, thread, loc, rewriter); - } + Value llA = adaptor.a(); + Value llB = adaptor.b(); auto sizePerThread = getSizePerThread(dLayout); auto shapePerCTA = getShapePerCTA(dLayout); @@ -3759,17 +3720,19 @@ DotOpConversion::convertFMADot(triton::DotOp op, OpAdaptor adaptor, int M = aShape[0]; int N = bShape[1]; - mShapePerCTA = order[0] == 1 ? shapePerCTA[order[1]] : shapePerCTA[order[0]]; - mSizePerThread = + int mShapePerCTA = + order[0] == 1 ? shapePerCTA[order[1]] : shapePerCTA[order[0]]; + int mSizePerThread = order[0] == 1 ? sizePerThread[order[1]] : sizePerThread[order[0]]; - nShapePerCTA = order[0] == 0 ? shapePerCTA[order[1]] : shapePerCTA[order[0]]; - nSizePerThread = + int nShapePerCTA = + order[0] == 0 ? shapePerCTA[order[1]] : shapePerCTA[order[0]]; + int nSizePerThread = order[0] == 0 ? sizePerThread[order[1]] : sizePerThread[order[0]]; - has = helper.getValueTableFromStruct(llA, K, M, mShapePerCTA, mSizePerThread, - rewriter, loc); - hbs = helper.getValueTableFromStruct(llB, K, N, nShapePerCTA, nSizePerThread, - rewriter, loc); + auto has = helper.getValueTableFromStruct(llA, K, M, mShapePerCTA, + mSizePerThread, rewriter, loc); + auto hbs = helper.getValueTableFromStruct(llB, K, N, nShapePerCTA, + nSizePerThread, rewriter, loc); SmallVector ret = cc; for (unsigned k = 0; k < K; k++) { @@ -3780,7 +3743,6 @@ DotOpConversion::convertFMADot(triton::DotOp op, OpAdaptor adaptor, for (unsigned nn = 0; nn < nSizePerThread; ++nn) { ret[z] = rewriter.create(loc, has[{m + mm, k}], hbs[{n + nn, k}], ret[z]); - ++z; } } @@ -4310,9 +4272,10 @@ struct ExpOpConversionApprox // For FP64 input, call __nv_expf for higher-precision calculation if (elemTy.getIntOrFloatBitWidth() == 64) return {}; + const double log2e = 1.4426950408889634; - Value prod = - rewriter.create(loc, f32_ty, operands[0], f32_val(log2e)); + Value prod = fmul(f32_ty, operands[0], f32_val(log2e)); + PTXBuilder ptxBuilder; auto &exp2 = ptxBuilder.create("ex2")->o("approx").o("f32"); auto output = ptxBuilder.newOperand("=f"); diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.h b/lib/Conversion/TritonGPUToLLVM/Utility.h index 82ccc3fe6..16df14ef9 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.h +++ b/lib/Conversion/TritonGPUToLLVM/Utility.h @@ -31,6 +31,7 @@ #include // Shortcuts for some commonly used LLVM ops to keep code simple and intuitive +// Operators #define inttoptr(...) rewriter.create(loc, __VA_ARGS__) #define ptrtoint(...) rewriter.create(loc, __VA_ARGS__) #define zext(...) rewriter.create(loc, __VA_ARGS__) @@ -40,6 +41,7 @@ #define sub(...) rewriter.create(loc, __VA_ARGS__) #define fadd(...) rewriter.create(loc, __VA_ARGS__) #define mul(...) rewriter.create(loc, __VA_ARGS__) +#define fmul(...) rewriter.create(loc, __VA_ARGS__) #define smax(...) rewriter.create(loc, __VA_ARGS__) #define umax(...) rewriter.create(loc, __VA_ARGS__) #define fmax(...) rewriter.create(loc, __VA_ARGS__) @@ -90,6 +92,8 @@ #define address_of(...) rewriter.create(loc, __VA_ARGS__) #define barrier() rewriter.create(loc) #define undef(...) rewriter.create(loc, __VA_ARGS__) + +// Types #define i32_ty rewriter.getIntegerType(32) #define ui32_ty rewriter.getIntegerType(32, false) #define f16_ty rewriter.getF16Type() @@ -102,8 +106,9 @@ #define f64_val(...) LLVM::createConstantF64(loc, rewriter, __VA_ARGS__) #define void_ty(ctx) LLVM::LLVMVoidType::get(ctx) #define struct_ty(...) LLVM::LLVMStructType::getLiteral(ctx, __VA_ARGS__) +#define array_ty(elemTy, count) LLVM::LLVMArrayType::get(elemTy, count) -// Creator for constant +// Constants #define i32_val(...) LLVM::createConstantI32(loc, rewriter, __VA_ARGS__) #define int_val(width, val) \ LLVM::createLLVMIntegerConstant(rewriter, loc, width, val) diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.cpp b/lib/Dialect/TritonGPU/Transforms/Combine.cpp index 53bab8a7e..3bc68feba 100644 --- a/lib/Dialect/TritonGPU/Transforms/Combine.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Combine.cpp @@ -36,7 +36,7 @@ namespace { class DecomposeDotOperand : public mlir::RewritePattern { public: - DecomposeDotOperand(mlir::MLIRContext *context) + explicit DecomposeDotOperand(mlir::MLIRContext *context) : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1, context) {} @@ -84,7 +84,7 @@ public: // IIUC they are therefore not handled by DRR right now class SimplifyConversion : public mlir::RewritePattern { public: - SimplifyConversion(mlir::MLIRContext *context) + explicit SimplifyConversion(mlir::MLIRContext *context) : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 4, context) {} @@ -219,8 +219,8 @@ public: // // ----------------------------------------------------------------------------- -static LogicalResult invertEncoding(Attribute targetEncoding, Operation *op, - Attribute &ret) { +LogicalResult invertEncoding(Attribute targetEncoding, Operation *op, + Attribute &ret) { ret = targetEncoding; if (auto expand_dims = dyn_cast(op)) { ret = triton::gpu::SliceEncodingAttr::get( @@ -246,7 +246,7 @@ inline bool expensive_to_remat(Operation *op) { if (isa(op)) return true; return false; -}; +} Operation *cloneWithInferType(mlir::PatternRewriter &rewriter, Operation *op, BlockAndValueMapping &mapping) { @@ -276,7 +276,7 @@ Operation *cloneWithInferType(mlir::PatternRewriter &rewriter, Operation *op, // are reachable from it without passing through any memory operation. class RematerializeBackward : public mlir::RewritePattern { public: - RematerializeBackward(mlir::MLIRContext *context) + explicit RematerializeBackward(mlir::MLIRContext *context) : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 2, context) {} @@ -303,7 +303,7 @@ public: SetVector layout; llvm::MapVector toConvert; std::vector> queue; - queue.push_back({cvt, targetType.getEncoding()}); + queue.emplace_back(cvt, targetType.getEncoding()); int numCvts = 1; while (!queue.empty()) { Operation *currOp; @@ -341,7 +341,7 @@ public: continue; // we add one expensive conversion for the current operand numCvts += 1; - queue.push_back({opArgI, newEncoding}); + queue.emplace_back(opArgI, newEncoding); } } // if rematerialization would add more conversions than it removes @@ -351,8 +351,8 @@ public: SmallVector sortedValues; SetVector tmp; - for (auto it = toConvert.begin(); it != toConvert.end(); ++it) { - Value v = it->first; + for (auto &item : toConvert) { + Value v = item.first; if (v.getDefiningOp()) tmp.insert(v.getDefiningOp()); else @@ -393,7 +393,7 @@ public: class MoveConvertOutOfLoop : public mlir::RewritePattern { public: - MoveConvertOutOfLoop(mlir::MLIRContext *context) + explicit MoveConvertOutOfLoop(mlir::MLIRContext *context) : mlir::RewritePattern(scf::ForOp::getOperationName(), 1, context) {} SmallVector @@ -406,7 +406,7 @@ public: newInitArgs[i] = rewriter.create( newInitArgs[i].getLoc(), newType, newInitArgs[i]); // Clone for loop - scf::ForOp newForOp = rewriter.create( + auto newForOp = rewriter.create( forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep(), newInitArgs); newForOp->moveBefore(forOp); @@ -455,7 +455,7 @@ public: mlir::PatternRewriter &rewriter) const override { auto forOp = cast(op); auto iterArgs = forOp.getRegionIterArgs(); - for (auto iterArg : llvm::enumerate(iterArgs)) { + for (const auto &iterArg : llvm::enumerate(iterArgs)) { // if (iterArg.index() != 1) // continue; // skip non-tensor types @@ -517,7 +517,7 @@ public: class RematerializeForward : public mlir::RewritePattern { public: - RematerializeForward(mlir::MLIRContext *context) + explicit RematerializeForward(mlir::MLIRContext *context) : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 2, context) {} @@ -584,7 +584,7 @@ public: // // ----------------------------------------------------------------------------- namespace { -static int computeCapabilityToMMAVersion(int computeCapability) { +int computeCapabilityToMMAVersion(int computeCapability) { if (computeCapability < 80) { return 1; } else if (computeCapability < 90) { @@ -595,9 +595,7 @@ static int computeCapabilityToMMAVersion(int computeCapability) { } } -static SmallVector -mmaVersionToShapePerWarp(int version, const ArrayRef &shape, - int numWarps) { +SmallVector mmaVersionToShapePerWarp(int version) { if (version == 1) return {16, 16}; else if (version == 2) @@ -608,12 +606,11 @@ mmaVersionToShapePerWarp(int version, const ArrayRef &shape, } } -SmallVector warpsPerTileV1(triton::DotOp dotOp, - const ArrayRef shape, +SmallVector warpsPerTileV1(const ArrayRef shape, int numWarps) { SmallVector ret = {1, 1}; SmallVector shapePerWarp = - mmaVersionToShapePerWarp(1, shape, numWarps); + mmaVersionToShapePerWarp(1 /*version*/); bool changed = false; do { changed = false; @@ -669,7 +666,7 @@ SmallVector warpsPerTileV2(triton::DotOp dotOp, class OptimizeBlockedToShared : public mlir::RewritePattern { public: - OptimizeBlockedToShared(mlir::MLIRContext *context) + explicit OptimizeBlockedToShared(mlir::MLIRContext *context) : RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1, context) {} @@ -717,7 +714,7 @@ public: class OptimizeConvertToDotOperand : public mlir::RewritePattern { public: - OptimizeConvertToDotOperand(mlir::MLIRContext *context) + explicit OptimizeConvertToDotOperand(mlir::MLIRContext *context) : RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1, context) {} @@ -729,11 +726,12 @@ public: auto dstType = cvt.getResult().getType().cast(); // order ArrayRef order; - if(auto srcBlockedLayout = - srcType.getEncoding().dyn_cast()) + if (auto srcBlockedLayout = + srcType.getEncoding().dyn_cast()) order = srcBlockedLayout.getOrder(); - else if(auto srcSharedLayout = - srcType.getEncoding().dyn_cast()) + else if (auto srcSharedLayout = + srcType.getEncoding() + .dyn_cast()) order = srcSharedLayout.getOrder(); else return failure(); @@ -742,20 +740,18 @@ public: dstType.getEncoding().dyn_cast(); if (!dstDotOperandLayout) return failure(); - unsigned opIdx = dstDotOperandLayout.getOpIdx(); - if(!dstDotOperandLayout.getIsMMAv1Row()) + if (!dstDotOperandLayout.getIsMMAv1Row()) return failure(); - bool isMMAv1Row = dstDotOperandLayout.getIsMMAv1Row().cast().getValue(); - if((order[0] == 1 && isMMAv1Row) || - (order[0] == 0 && !isMMAv1Row)) + bool isMMAv1Row = + dstDotOperandLayout.getIsMMAv1Row().cast().getValue(); + if ((order[0] == 1 && isMMAv1Row) || (order[0] == 0 && !isMMAv1Row)) return failure(); auto newIsRow = BoolAttr::get(op->getContext(), !isMMAv1Row); auto newDstEncoding = triton::gpu::DotOperandEncodingAttr::get( - op->getContext(), dstDotOperandLayout.getOpIdx(), dstDotOperandLayout.getParent(), - newIsRow); + op->getContext(), dstDotOperandLayout.getOpIdx(), + dstDotOperandLayout.getParent(), newIsRow); auto newDstType = RankedTensorType::get( - dstType.getShape(), - dstType.getElementType(), newDstEncoding); + dstType.getShape(), dstType.getElementType(), newDstEncoding); auto newCvt = rewriter.create( op->getLoc(), newDstType, cvt.getOperand()); rewriter.replaceOp(op, newCvt.getResult()); @@ -763,7 +759,6 @@ public: } }; - class BlockedToMMA : public mlir::RewritePattern { int computeCapability; @@ -777,7 +772,7 @@ public: int version, int numWarps) { switch (version) { case 1: - return warpsPerTileV1(dotOp, shape, numWarps); + return warpsPerTileV1(shape, numWarps); case 2: return warpsPerTileV2(dotOp, shape, numWarps); default: @@ -821,27 +816,31 @@ public: Value b = dotOp.b(); auto oldAType = a.getType().cast(); auto oldBType = b.getType().cast(); - auto oldAOrder = oldAType.getEncoding().cast() - .getParent().cast().getOrder(); - auto oldBOrder = oldBType.getEncoding().cast() - .getParent().cast().getOrder(); + auto oldAOrder = oldAType.getEncoding() + .cast() + .getParent() + .cast() + .getOrder(); + auto oldBOrder = oldBType.getEncoding() + .cast() + .getParent() + .cast() + .getOrder(); Attribute isMMAv1RowA; Attribute isMMAv1RowB; - if(version == 1){ + if (version == 1) { isMMAv1RowA = BoolAttr::get(getContext(), oldAOrder[0] == 1); isMMAv1RowB = BoolAttr::get(getContext(), oldBOrder[0] == 1); } auto newAType = RankedTensorType::get( oldAType.getShape(), oldAType.getElementType(), - triton::gpu::DotOperandEncodingAttr::get(oldAType.getContext(), 0, - newRetType.getEncoding(), - isMMAv1RowA)); + triton::gpu::DotOperandEncodingAttr::get( + oldAType.getContext(), 0, newRetType.getEncoding(), isMMAv1RowA)); auto newBType = RankedTensorType::get( oldBType.getShape(), oldBType.getElementType(), - triton::gpu::DotOperandEncodingAttr::get(oldBType.getContext(), 1, - newRetType.getEncoding(), - isMMAv1RowB)); + triton::gpu::DotOperandEncodingAttr::get( + oldBType.getContext(), 1, newRetType.getEncoding(), isMMAv1RowB)); a = rewriter.create(a.getLoc(), newAType, a); b = rewriter.create(b.getLoc(), newBType, b); @@ -857,9 +856,8 @@ public: class FixupLoop : public mlir::RewritePattern { public: - FixupLoop(mlir::MLIRContext *context) - : mlir::RewritePattern(scf::ForOp::getOperationName(), 2, - context) {} + explicit FixupLoop(mlir::MLIRContext *context) + : mlir::RewritePattern(scf::ForOp::getOperationName(), 2, context) {} mlir::LogicalResult matchAndRewrite(mlir::Operation *op, @@ -869,17 +867,17 @@ public: // Rewrite init argument SmallVector newInitArgs = forOp.getInitArgs(); bool shouldRematerialize = false; - for(size_t i = 0; i < newInitArgs.size(); i++){ + for (size_t i = 0; i < newInitArgs.size(); i++) { auto initArg = newInitArgs[i]; auto regionArg = forOp.getRegionIterArgs()[i]; - if(newInitArgs[i].getType() != forOp.getRegionIterArgs()[i].getType()){ + if (newInitArgs[i].getType() != forOp.getRegionIterArgs()[i].getType()) { shouldRematerialize = true; break; } } - if(!shouldRematerialize) + if (!shouldRematerialize) return failure(); - + scf::ForOp newForOp = rewriter.create( forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep(), newInitArgs); @@ -894,8 +892,6 @@ public: } rewriter.replaceOp(forOp, newForOp.getResults()); return success(); - - } };