#include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Verifier.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/Passes.h" #include "mlir/Transforms/RegionUtils.h" #include "triton/Analysis/Utility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" #include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" #include using namespace mlir; namespace { #include "TritonGPUCombine.inc" using triton::DotOp; using triton::gpu::ConvertLayoutOp; using triton::gpu::DotOperandEncodingAttr; using triton::gpu::MmaEncodingAttr; // ----------------------------------------------------------------------------- // // ----------------------------------------------------------------------------- // convert(blocked, dot_operand) -> // convert(blocked, mma) + convert(mma, dot_operand) // if this value is itself the result of a dot operation // this is a heuristic to accommodate some pattern seen in fused attention // kernels. // TODO: replace this by something more generic, i.e. layout-aware CSE class DecomposeDotOperand : public mlir::RewritePattern { public: explicit DecomposeDotOperand(mlir::MLIRContext *context) : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1, context) {} mlir::LogicalResult matchAndRewrite(mlir::Operation *op, mlir::PatternRewriter &rewriter) const override { if (!llvm::isa(op)) return mlir::failure(); auto convert = llvm::cast(op); auto srcType = convert.getOperand().getType().cast(); auto dstType = convert.getType().cast(); if (srcType.getEncoding().isa() && dstType.getEncoding().isa()) { auto dstDotOperand = dstType.getEncoding().cast(); auto dstParent = dstDotOperand.getParent(); if (dstDotOperand.getOpIdx() == 1 || !dstParent.isa()) return mlir::failure(); auto dstParentMma = dstParent.cast(); if (dstParentMma.isVolta() || dstParentMma.getWarpsPerCTA()[1] > 1) return mlir::failure(); SetVector bwdSlices; mlir::getBackwardSlice(convert.getResult(), &bwdSlices); if (llvm::find_if(bwdSlices, [](Operation *op) { return isa(op); }) == bwdSlices.end()) return mlir::failure(); auto tmpType = RankedTensorType::get( dstType.getShape(), dstType.getElementType(), dstParentMma); auto tmp = rewriter.create( convert.getLoc(), tmpType, convert.getOperand()); auto newConvert = rewriter.create( convert.getLoc(), dstType, tmp); rewriter.replaceOp(op, {newConvert}); return mlir::success(); } return mlir::failure(); } }; class SimplifyReduceCvt : public mlir::RewritePattern { public: explicit SimplifyReduceCvt(mlir::MLIRContext *context) : mlir::RewritePattern(triton::ReduceOp::getOperationName(), 2, context) { } mlir::LogicalResult matchAndRewrite(mlir::Operation *op, mlir::PatternRewriter &rewriter) const override { auto reduce = cast(*op); auto reduceArg = dyn_cast( reduce.getOperand().getDefiningOp()); if (!reduceArg) return mlir::failure(); // this may generate unsupported conversions in the LLVM codegen if (reduceArg.getOperand() .getType() .cast() .getEncoding() .isa()) return mlir::failure(); auto newReduce = rewriter.create( op->getLoc(), reduce.redOp(), reduceArg.getOperand(), reduce.axis()); if (isa( 
*reduceArg.getOperand().getDefiningOp())) return mlir::failure(); Value newRet = newReduce.getResult(); // it's still beneficial to move the conversion // to after the reduce if necessary since it will be // done on a rank-reduced tensor hence cheaper if (newRet.getType() != reduce.getResult().getType()) newRet = rewriter.create( op->getLoc(), reduce.getResult().getType(), newRet); rewriter.replaceOp(op, newRet); return success(); } }; // Layout conversions can't deduce their return type automatically. // IIUC they are therefore not handled by DRR right now class SimplifyConversion : public mlir::RewritePattern { public: explicit SimplifyConversion(mlir::MLIRContext *context) : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 4, context) {} mlir::LogicalResult matchAndRewrite(mlir::Operation *op, mlir::PatternRewriter &rewriter) const override { if (!llvm::isa(op)) return mlir::failure(); auto convert = llvm::cast(op); // we don't handle conversions to DotOperandEncodingAttr // this is a heuristics to accommodate fused attention auto srcType = convert.getOperand().getType().cast(); auto dstType = convert.getType().cast(); if (dstType.getEncoding().isa() && srcType.getEncoding().isa()) return mlir::failure(); // convert to the same layout -- we can delete if (op->getResultTypes() == op->getOperandTypes()) { rewriter.replaceOp(op, op->getOperands()); return mlir::success(); } Operation *arg = op->getOperand(0).getDefiningOp(); // block argument if (!arg) return mlir::failure(); // cvt(alloc_tensor(x), type2) -> alloc_tensor(x, type2) auto alloc_tensor = dyn_cast(arg); if (alloc_tensor) { if (!isSharedEncoding(op->getResult(0))) { return mlir::failure(); } rewriter.replaceOpWithNewOp( op, op->getResult(0).getType()); return mlir::success(); } // cvt(insert_slice(x), type2) -> insert_slice(cvt(x, type2)) auto insert_slice = dyn_cast(arg); if (insert_slice) { if (!isSharedEncoding(op->getResult(0))) { return mlir::failure(); } auto newType = op->getResult(0).getType().cast(); // Ensure that the new insert_slice op is placed in the same place as the // old insert_slice op. Otherwise, the new insert_slice op may be placed // after the async_wait op, which is not allowed. OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(insert_slice); auto newArg = rewriter.create( op->getLoc(), newType, insert_slice.dst()); rewriter.replaceOpWithNewOp( op, newType, insert_slice.src(), newArg.getResult(), insert_slice.index(), insert_slice.mask(), insert_slice.other(), insert_slice.cache(), insert_slice.evict(), insert_slice.isVolatile(), insert_slice.axis()); return mlir::success(); } // cvt(extract_slice(x), type2) -> extract_slice(cvt(x, type2)) auto extract_slice = dyn_cast(arg); if (extract_slice) { if (!isSharedEncoding(op->getResult(0))) { return mlir::failure(); } auto origType = extract_slice.source().getType().cast(); auto newType = RankedTensorType::get( origType.getShape(), origType.getElementType(), op->getResult(0).getType().cast().getEncoding()); auto origResType = op->getResult(0).getType().cast(); auto resType = RankedTensorType::get( origResType.getShape(), origResType.getElementType(), extract_slice.getType().cast().getEncoding()); // Ensure that the new extract_slice op is placed in the same place as the // old extract_slice op. Otherwise, the new extract_slice op may be placed // after the async_wait op, which is not allowed. 
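// Illustrative effect of SimplifyReduceCvt (encodings abbreviated):
//   %a = triton_gpu.convert_layout %x : #blocked1 -> #blocked2
//   %r = tt.reduce %a {axis = 1}
// becomes
//   %r0 = tt.reduce %x {axis = 1}
//   %r  = triton_gpu.convert_layout %r0   // on a rank-reduced tensor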
// Layout conversions can't deduce their return type automatically.
// IIUC they are therefore not handled by DRR right now.
class SimplifyConversion : public mlir::RewritePattern {
public:
  explicit SimplifyConversion(mlir::MLIRContext *context)
      : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
                             4, context) {}

  mlir::LogicalResult
  matchAndRewrite(mlir::Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    if (!llvm::isa<triton::gpu::ConvertLayoutOp>(op))
      return mlir::failure();
    auto convert = llvm::cast<triton::gpu::ConvertLayoutOp>(op);
    // we don't handle conversions to DotOperandEncodingAttr;
    // this is a heuristic to accommodate fused attention
    auto srcType = convert.getOperand().getType().cast<RankedTensorType>();
    auto dstType = convert.getType().cast<RankedTensorType>();
    if (dstType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>() &&
        srcType.getEncoding().isa<triton::gpu::MmaEncodingAttr>())
      return mlir::failure();
    // convert to the same layout -- we can delete
    if (op->getResultTypes() == op->getOperandTypes()) {
      rewriter.replaceOp(op, op->getOperands());
      return mlir::success();
    }
    Operation *arg = op->getOperand(0).getDefiningOp();
    // block argument
    if (!arg)
      return mlir::failure();
    // cvt(alloc_tensor(x), type2) -> alloc_tensor(x, type2)
    auto alloc_tensor = dyn_cast<triton::gpu::AllocTensorOp>(arg);
    if (alloc_tensor) {
      if (!isSharedEncoding(op->getResult(0))) {
        return mlir::failure();
      }
      rewriter.replaceOpWithNewOp<triton::gpu::AllocTensorOp>(
          op, op->getResult(0).getType());
      return mlir::success();
    }
    // cvt(insert_slice(x), type2) -> insert_slice(cvt(x, type2))
    auto insert_slice = dyn_cast<triton::gpu::InsertSliceAsyncOp>(arg);
    if (insert_slice) {
      if (!isSharedEncoding(op->getResult(0))) {
        return mlir::failure();
      }
      auto newType = op->getResult(0).getType().cast<RankedTensorType>();
      // Ensure that the new insert_slice op is placed in the same place as
      // the old insert_slice op. Otherwise, the new insert_slice op may be
      // placed after the async_wait op, which is not allowed.
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPoint(insert_slice);
      auto newArg = rewriter.create<triton::gpu::ConvertLayoutOp>(
          op->getLoc(), newType, insert_slice.dst());
      rewriter.replaceOpWithNewOp<triton::gpu::InsertSliceAsyncOp>(
          op, newType, insert_slice.src(), newArg.getResult(),
          insert_slice.index(), insert_slice.mask(), insert_slice.other(),
          insert_slice.cache(), insert_slice.evict(),
          insert_slice.isVolatile(), insert_slice.axis());
      return mlir::success();
    }
    // cvt(extract_slice(x), type2) -> extract_slice(cvt(x, type2))
    auto extract_slice = dyn_cast<tensor::ExtractSliceOp>(arg);
    if (extract_slice) {
      if (!isSharedEncoding(op->getResult(0))) {
        return mlir::failure();
      }
      auto origType =
          extract_slice.source().getType().cast<RankedTensorType>();
      auto newType = RankedTensorType::get(
          origType.getShape(), origType.getElementType(),
          op->getResult(0).getType().cast<RankedTensorType>().getEncoding());
      auto origResType = op->getResult(0).getType().cast<RankedTensorType>();
      auto resType = RankedTensorType::get(
          origResType.getShape(), origResType.getElementType(),
          extract_slice.getType().cast<RankedTensorType>().getEncoding());
      // Ensure that the new extract_slice op is placed in the same place as
      // the old extract_slice op. Otherwise, the new extract_slice op may be
      // placed after the async_wait op, which is not allowed.
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPoint(extract_slice);
      auto newArg = rewriter.create<triton::gpu::ConvertLayoutOp>(
          op->getLoc(), newType, extract_slice.source());
      rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
          op, resType, newArg.getResult(), extract_slice.offsets(),
          extract_slice.sizes(), extract_slice.strides(),
          extract_slice.static_offsets(), extract_slice.static_sizes(),
          extract_slice.static_strides());
      return mlir::success();
    }
    // cvt(cvt(x, type1), type2) -> cvt(x, type2)
    if (llvm::isa<triton::gpu::ConvertLayoutOp>(arg)) {
      if (arg->getOperand(0).getDefiningOp() &&
          !isSharedEncoding(arg->getOperand(0)) &&
          isSharedEncoding(convert.getOperand()) &&
          !isSharedEncoding(convert.getResult())) {
        return mlir::failure();
      }
      if (isSharedEncoding(convert.getOperand()) &&
          isSharedEncoding(convert.getResult())) {
        return mlir::failure();
      }
      auto srcType = convert.getOperand().getType().cast<RankedTensorType>();
      auto srcShared =
          srcType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
      if (srcShared && srcShared.getVec() > 1)
        return mlir::failure();
      rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(
          op, op->getResultTypes().front(), arg->getOperand(0));
      return mlir::success();
    }
    // cvt(type1, splat(type2, x)) -> splat(type1, x)
    if (auto splat = llvm::dyn_cast<triton::SplatOp>(arg)) {
      rewriter.replaceOpWithNewOp<triton::SplatOp>(op, op->getResultTypes(),
                                                   splat.src());
      return mlir::success();
    }
    // cvt(type1, make_range(type2, x)) -> make_range(type1, x)
    if (auto range = llvm::dyn_cast<triton::MakeRangeOp>(arg)) {
      rewriter.replaceOpWithNewOp<triton::MakeRangeOp>(
          op, op->getResultTypes(), range.start(), range.end());
      return mlir::success();
    }
    // cvt(type, constant) -> constant
    if (auto cst = llvm::dyn_cast<arith::ConstantOp>(arg))
      if (auto ret = cst.getValue().dyn_cast<SplatElementsAttr>()) {
        auto newRet = SplatElementsAttr::get(op->getResultTypes().front(),
                                             ret.getSplatValue<Attribute>());
        rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newRet);
        return mlir::success();
      }
    return mlir::failure();
  }
};
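// In summary, SimplifyConversion applies the following folds, in the order
// they are tried above:
//   cvt(x, same_type)             -> x
//   cvt(alloc_tensor(x), type2)   -> alloc_tensor(x, type2)
//   cvt(insert_slice(x), type2)   -> insert_slice(cvt(x, type2))
//   cvt(extract_slice(x), type2)  -> extract_slice(cvt(x, type2))
//   cvt(cvt(x, type1), type2)     -> cvt(x, type2)
//   cvt(splat(type2, x), type1)   -> splat(type1, x)
//   cvt(make_range(type2), type1) -> make_range(type1)
//   cvt(constant, type)           -> constant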
// -----------------------------------------------------------------------------
//
// -----------------------------------------------------------------------------

// TODO: Interface
LogicalResult invertEncoding(Attribute targetEncoding, Operation *op,
                             Attribute &ret) {
  ret = targetEncoding;
  if (auto expand_dims = dyn_cast<triton::ExpandDimsOp>(op)) {
    ret = triton::gpu::SliceEncodingAttr::get(
        op->getContext(), expand_dims.axis(), targetEncoding);
  }
  if (auto reduce = dyn_cast<triton::ReduceOp>(op)) {
    auto sliceEncoding =
        targetEncoding.dyn_cast<triton::gpu::SliceEncodingAttr>();
    if (!sliceEncoding)
      return failure();
    ret = sliceEncoding.getParent();
  }
  return success();
}

// TODO: Interface
LogicalResult getForwardEncoding(Attribute sourceEncoding, Operation *op,
                                 Attribute &ret) {
  if (op->hasTrait<mlir::OpTrait::Elementwise>()) {
    ret = sourceEncoding;
    return success();
  }
  if (isa<triton::ReduceOp>(op)) {
    ret = Attribute();
    return success();
  }
  return failure();
}

inline bool expensive_to_remat(Operation *op) {
  if (!op)
    return true;
  if (isa<tensor::ExtractSliceOp, triton::gpu::AllocTensorOp,
          triton::gpu::InsertSliceAsyncOp, triton::LoadOp, triton::StoreOp,
          triton::AtomicRMWOp, triton::AtomicCASOp, triton::DotOp>(op))
    return true;
  if (isa<scf::YieldOp, scf::ForOp, scf::IfOp, scf::WhileOp>(op))
    return true;
  return false;
}

LogicalResult simulateBackwardRematerialization(
    Operation *initOp, SetVector<Operation *> &processed,
    SetVector<Attribute> &layout, llvm::MapVector<Value, Attribute> &toConvert,
    Attribute targetEncoding) {
  // DFS
  std::vector<std::pair<Operation *, Attribute>> queue;
  queue.emplace_back(initOp, targetEncoding);
  // We want to see the effect of converting `initOp` to a new layout
  // so we initialize `numCvts = 1`.
  int numCvts = 1;
  while (!queue.empty()) {
    Operation *currOp;
    Attribute currLayout;
    std::tie(currOp, currLayout) = queue.back();
    queue.pop_back();
    // If the current operation is expensive to rematerialize,
    // we stop everything
    if (expensive_to_remat(currOp))
      return mlir::failure();
    // we would propagate the conversion here
    numCvts -= 1;
    // check if the conversion could be folded at this operation
    if (isa<triton::gpu::ConvertLayoutOp, arith::ConstantOp,
            triton::MakeRangeOp, triton::SplatOp>(*currOp))
      continue;
    // done processing
    processed.insert(currOp);
    layout.insert(currLayout);
    // add all operands to the queue
    for (Value argI : currOp->getOperands()) {
      Attribute newEncoding;
      // cannot invert the current encoding for this operand
      // we stop everything
      if (failed(invertEncoding(currLayout, currOp, newEncoding))) {
        return mlir::failure();
      }
      if (toConvert.count(argI) && toConvert[argI] != newEncoding)
        return mlir::failure();
      Operation *opArgI = argI.getDefiningOp();
      toConvert.insert({argI, newEncoding});
      if (!opArgI || processed.contains(opArgI) ||
          (opArgI->getBlock() != initOp->getBlock()))
        continue;
      // we add one expensive conversion for the current operand
      numCvts += 1;
      queue.emplace_back(opArgI, newEncoding);
    }
  }
  // if rematerialization would add more conversions than it removes
  // then we don't do it
  if (numCvts > 0)
    return mlir::failure();
  return mlir::success();
}
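// Accounting example for `numCvts` above, simulating a target encoding on
//   %r = arith.addf %s, %l   with   %s = tt.splat ...,  %l = tt.load ...
// visiting addf pays back one conversion (numCvts: 1 -> 0) and charges one
// per operand (numCvts -> 2); popping the splat refunds its charge since
// cvt(splat) folds away, but popping the load hits expensive_to_remat and
// the whole simulation fails.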
Operation *cloneWithInferType(mlir::PatternRewriter &rewriter, Operation *op,
                              BlockAndValueMapping &mapping) {
  Operation *newOp = rewriter.clone(*op, mapping);
  auto origType = op->getResult(0).getType().cast<RankedTensorType>();
  auto newType = RankedTensorType::get(
      origType.getShape(), origType.getElementType(),
      newOp->getOperand(0).getType().cast<RankedTensorType>().getEncoding());
  newOp->getResult(0).setType(newType);
  auto typeInfer = dyn_cast<InferTypeOpInterface>(newOp);
  if (typeInfer) {
    SmallVector<Type, 1> newTypes;
    auto success = typeInfer.inferReturnTypes(
        newOp->getContext(), newOp->getLoc(), newOp->getOperands(),
        newOp->getAttrDictionary(), newOp->getRegions(), newTypes);
    if (succeeded(success))
      newOp->getResult(0).setType(newTypes.front());
  }
  return newOp;
}

class MoveConvertOutOfIf : public mlir::RewritePattern {
public:
  explicit MoveConvertOutOfIf(mlir::MLIRContext *context)
      : mlir::RewritePattern(scf::IfOp::getOperationName(), 2, context) {}

  mlir::LogicalResult
  matchAndRewrite(mlir::Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    auto ifOp = cast<scf::IfOp>(*op);
    auto thenYield = ifOp.thenYield();
    auto elseYield = ifOp.elseYield();
    size_t numOps = thenYield.getNumOperands();
    SetVector<Operation *> thenCvts;
    SetVector<Operation *> elseCvts;
    SmallVector<Type> newRetTypes;

    BlockAndValueMapping mapping;
    for (size_t i = 0; i < numOps; i++) {
      auto thenCvt = dyn_cast_or_null<triton::gpu::ConvertLayoutOp>(
          thenYield.getOperand(i).getDefiningOp());
      auto elseCvt = dyn_cast_or_null<triton::gpu::ConvertLayoutOp>(
          elseYield.getOperand(i).getDefiningOp());
      if (thenCvt && elseCvt &&
          std::distance(thenCvt->user_begin(), thenCvt->user_end()) == 1 &&
          std::distance(elseCvt->user_begin(), elseCvt->user_end()) == 1 &&
          thenCvt.getOperand().getType() == elseCvt.getOperand().getType()) {
        mapping.map(thenCvt.getResult(), thenCvt.getOperand());
        mapping.map(elseCvt.getResult(), elseCvt.getOperand());
        newRetTypes.push_back(thenCvt.getOperand().getType());
        thenCvts.insert((Operation *)thenCvt);
        elseCvts.insert((Operation *)elseCvt);
      } else
        newRetTypes.push_back(thenYield.getOperand(i).getType());
    }
    if (mapping.getValueMap().empty())
      return mlir::failure();

    rewriter.setInsertionPoint(op);
    auto newIfOp = rewriter.create<scf::IfOp>(ifOp.getLoc(), newRetTypes,
                                              ifOp.getCondition(), true);
    // rematerialize `then` block
    rewriter.setInsertionPointToEnd(newIfOp.thenBlock());
    for (Operation &op : ifOp.thenBlock()->getOperations()) {
      if (thenCvts.contains(&op)) {
        mapping.map(op.getResult(0), mapping.lookup(op.getOperand(0)));
        continue;
      }
      rewriter.clone(op, mapping);
    }
    // rematerialize `else` block
    rewriter.setInsertionPointToEnd(newIfOp.elseBlock());
    for (Operation &op : ifOp.elseBlock()->getOperations()) {
      if (elseCvts.contains(&op)) {
        mapping.map(op.getResult(0), mapping.lookup(op.getOperand(0)));
        continue;
      }
      rewriter.clone(op, mapping);
    }

    rewriter.setInsertionPointAfter(newIfOp);
    SmallVector<Value> newRetValues = newIfOp.getResults();
    for (size_t i = 0; i < numOps; i++) {
      if (newIfOp.getResult(i).getType() != ifOp.getResult(i).getType()) {
        newRetValues[i] = rewriter.create<triton::gpu::ConvertLayoutOp>(
            newIfOp.getLoc(), ifOp.getResult(i).getType(),
            newIfOp.getResult(i));
      }
    }

    rewriter.replaceOp(op, newRetValues);
    return mlir::success();
  }
};
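// Illustrative effect of MoveConvertOutOfIf (encodings abbreviated): when
// both yields feed matching single-use conversions,
//   %y = scf.if %cond -> tensor<..., #blocked> {
//     scf.yield %t : ...   // %t = convert_layout %t0 : #mma -> #blocked
//   } else {
//     scf.yield %e : ...   // %e = convert_layout %e0 : #mma -> #blocked
//   }
// is rewritten to yield %t0/%e0 in #mma directly, with a single
// convert_layout to #blocked applied to the scf.if result.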
class FoldConvertAndReduce : public mlir::RewritePattern {
public:
  explicit FoldConvertAndReduce(mlir::MLIRContext *context)
      : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
                             1, context) {}

  mlir::LogicalResult
  matchAndRewrite(mlir::Operation *cvtOp,
                  mlir::PatternRewriter &rewriter) const override {
    auto cvt = dyn_cast<triton::gpu::ConvertLayoutOp>(*cvtOp);
    auto srcEncoding =
        cvt.getOperand().getType().cast<RankedTensorType>().getEncoding();
    if (srcEncoding.isa<triton::gpu::SharedEncodingAttr>())
      return failure();

    SetVector<Operation *> cvtSlices;
    auto filter = [&](Operation *op) {
      return op->getBlock() == cvt->getBlock() &&
             !(isa<triton::ReduceOp>(op) &&
               !op->getResult(0).getType().isa<RankedTensorType>()) &&
             !isa<triton::gpu::ConvertLayoutOp>(op);
    };
    mlir::getForwardSlice(cvt.getResult(), &cvtSlices, filter);
    if (cvtSlices.empty())
      return failure();

    for (Operation *op : cvtSlices) {
      // don't rematerialize anything expensive
      if (expensive_to_remat(op))
        return failure();
      // don't rematerialize non-element-wise
      if (!op->hasTrait<mlir::OpTrait::Elementwise>())
        return failure();
      // don't rematerialize if it adds an extra conversion that can't
      // be removed
      for (Value arg : op->getOperands()) {
        Operation *argOp = arg.getDefiningOp();
        SetVector<Operation *> processed;
        SetVector<Attribute> layout;
        llvm::MapVector<Value, Attribute> toConvert;
        if (argOp && (argOp != cvt) && cvtSlices.count(argOp) == 0 &&
            failed(simulateBackwardRematerialization(
                argOp, processed, layout, toConvert, srcEncoding))) {
          return failure();
        }
      }
    }

    BlockAndValueMapping mapping;
    auto op = cvtSlices.front();
    for (Value arg : op->getOperands()) {
      if (arg.getDefiningOp() == cvt)
        mapping.map(arg, cvt.getOperand());
      else {
        auto cvtI = rewriter.create<triton::gpu::ConvertLayoutOp>(
            arg.getLoc(), cvt.getOperand().getType(), arg);
        if (Operation *argOp = arg.getDefiningOp())
          cvtI->moveAfter(argOp);
        mapping.map(arg, cvtI);
      }
    }
    rewriter.setInsertionPoint(op);
    Operation *newOp = rewriter.clone(*op, mapping);
    auto oldType = op->getResult(0).getType().cast<RankedTensorType>();
    auto newType = RankedTensorType::get(
        oldType.getShape(), oldType.getElementType(),
        cvt.getOperand().getType().cast<RankedTensorType>().getEncoding());
    newOp->getResult(0).setType(newType);
    auto newCvtType = RankedTensorType::get(
        oldType.getShape(), oldType.getElementType(),
        cvt.getResult().getType().cast<RankedTensorType>().getEncoding());
    auto newCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
        newOp->getLoc(), newCvtType, newOp->getResult(0));
    rewriter.replaceOp(op, newCvt->getResults());
    return success();
  }
};
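// Illustrative effect of FoldConvertAndReduce: for an elementwise user f,
//   f(cvt(x, #dst), y)  ->  cvt(f(x, cvt'(y)), #dst)
// i.e. f is recomputed in the conversion's source layout, extra operands
// are converted into that layout, and a single conversion is emitted on the
// result (only where simulateBackwardRematerialization says those extra
// operand conversions can later be removed).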
// Layout conversions are expensive. They require going through
// shared memory, which is orders of magnitude slower than
// other non-i/o operations in the dialect.
// It therefore makes sense to remove them whenever possible,
// even if it means rematerializing all values whose definitions
// are reachable from them without passing through any memory operation.
class RematerializeBackward : public mlir::RewritePattern {
public:
  explicit RematerializeBackward(mlir::MLIRContext *context)
      : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
                             2, context) {}

  mlir::LogicalResult
  matchAndRewrite(mlir::Operation *cvt,
                  mlir::PatternRewriter &rewriter) const override {
    if (!llvm::isa<triton::gpu::ConvertLayoutOp>(cvt))
      return mlir::failure();
    // we don't touch block arguments
    Operation *op = cvt->getOperand(0).getDefiningOp();
    if (!op)
      return mlir::failure();
    // we don't want to rematerialize any conversion to/from shared
    if (isSharedEncoding(cvt->getResults()[0]) ||
        isSharedEncoding(cvt->getOperand(0)))
      return mlir::failure();
    // we don't handle conversions to DotOperandEncodingAttr;
    // this is a heuristic to accommodate fused attention
    auto targetType = cvt->getResultTypes()[0].cast<RankedTensorType>();
    if (targetType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>())
      return mlir::failure();
    // DFS
    SetVector<Operation *> processed;
    SetVector<Attribute> layout;
    llvm::MapVector<Value, Attribute> toConvert;
    std::vector<std::pair<Operation *, Attribute>> queue;
    queue.emplace_back(cvt, targetType.getEncoding());
    int numCvts = 1;
    while (!queue.empty()) {
      Operation *currOp;
      Attribute currLayout;
      std::tie(currOp, currLayout) = queue.back();
      queue.pop_back();
      // If the current operation is expensive to rematerialize,
      // we stop everything
      if (expensive_to_remat(currOp))
        break;
      // a conversion will be removed here (i.e. transferred to operands)
      numCvts -= 1;
      // done processing
      processed.insert(currOp);
      layout.insert(currLayout);
      // add all operands to the queue
      for (Value argI : currOp->getOperands()) {
        Attribute newEncoding;
        // cannot invert the current encoding for this operand
        // we stop everything
        if (failed(invertEncoding(currLayout, currOp, newEncoding)))
          return mlir::failure();
        if (toConvert.count(argI) && toConvert[argI] != newEncoding)
          return mlir::failure();
        Operation *opArgI = argI.getDefiningOp();
        toConvert.insert({argI, newEncoding});
        if (!opArgI || processed.contains(opArgI) ||
            (opArgI->getBlock() != cvt->getBlock()))
          continue;
        // if the conversion can be folded into opArgI then
        // we don't count this conversion as expensive
        if (isa<triton::gpu::ConvertLayoutOp, arith::ConstantOp,
                triton::MakeRangeOp, triton::SplatOp>(*opArgI))
          continue;
        // we add one expensive conversion for the current operand
        numCvts += 1;
        queue.emplace_back(opArgI, newEncoding);
      }
    }
    // if rematerialization would add more conversions than it removes
    // then we don't do it
    if (numCvts > 0)
      return mlir::failure();

    SmallVector<Value, 4> sortedValues;
    SetVector<Operation *> tmp;
    for (auto &item : toConvert) {
      Value v = item.first;
      if (v.getDefiningOp())
        tmp.insert(v.getDefiningOp());
      else
        sortedValues.push_back(v);
    }
    tmp = mlir::topologicalSort(tmp);
    for (Operation *op : tmp)
      sortedValues.push_back(op->getResult(0));

    BlockAndValueMapping mapping;
    for (Value currOperand : sortedValues) {
      // unpack information
      Attribute targetLayout = toConvert.lookup(currOperand);
      // rematerialize the operand if necessary
      Operation *currOperation = currOperand.getDefiningOp();
      if (processed.contains(currOperation)) {
        currOperation = cloneWithInferType(rewriter, currOperation, mapping);
        currOperand = currOperation->getResult(0);
      }
      // compute target type for the layout cast
      auto currType = currOperand.getType().cast<RankedTensorType>();
      auto newType = RankedTensorType::get(
          currType.getShape(), currType.getElementType(), targetLayout);
      auto newOperand = rewriter.create<triton::gpu::ConvertLayoutOp>(
          currOperand.getLoc(), newType, currOperand);
      if (currOperation)
        newOperand->moveAfter(currOperation);
      mapping.map(currOperand, newOperand);
    }
    rewriter.replaceOp(cvt, mapping.lookup(cvt->getOperand(0)));
    return mlir::success();
  }
};
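// Worked example for RematerializeBackward (encodings abbreviated):
//   %a = tt.splat %s   : tensor<128x128xf32, #blocked>
//   %b = arith.negf %a : tensor<128x128xf32, #blocked>
//   %c = triton_gpu.convert_layout %b : #blocked -> #mma
// Removing %c pays for one conversion (numCvts nets to 0), and %a is free
// because cvt(splat) folds away, so %b is recomputed directly in #mma and
// the explicit conversion disappears.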
// -----------------------------------------------------------------------------
//
// -----------------------------------------------------------------------------

class MoveConvertOutOfLoop : public mlir::RewritePattern {
public:
  explicit MoveConvertOutOfLoop(mlir::MLIRContext *context)
      : mlir::RewritePattern(scf::ForOp::getOperationName(), 1, context) {}

  SmallVector<Value, 4>
  rematerializeForLoop(mlir::PatternRewriter &rewriter, scf::ForOp &forOp,
                       size_t i, RankedTensorType newType,
                       triton::gpu::ConvertLayoutOp origConversion) const {
    // Rewrite init argument
    Type origType = forOp.getInitArgs()[i].getType();
    SmallVector<Value, 4> newInitArgs = forOp.getInitArgs();
    newInitArgs[i] = rewriter.create<triton::gpu::ConvertLayoutOp>(
        newInitArgs[i].getLoc(), newType, newInitArgs[i]);
    // Clone for loop
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), newInitArgs);
    newForOp->moveBefore(forOp);
    rewriter.setInsertionPointToStart(newForOp.getBody());
    BlockAndValueMapping mapping;
    for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
      mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
    mapping.map(origConversion.getResult(), newForOp.getRegionIterArgs()[i]);
    // the iter arg of interest may have other uses than the conversion
    // we're hoisting out of the loop. If that's the case we will
    // need to add extra conversions for all uses... which is only useful
    // if these extra conversions can be removed by another pattern
    auto oldArg = forOp.getRegionIterArgs()[i];
    auto newArg = newForOp.getRegionIterArgs()[i];
    auto newArgFallback = rewriter.create<triton::gpu::ConvertLayoutOp>(
        newForOp.getLoc(), origType, newArg);

    mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());
    for (Operation &op : forOp.getBody()->without_terminator()) {
      if (&op == (Operation *)(&origConversion))
        continue;
      Operation *newOp = rewriter.clone(op, mapping);
      if (llvm::find(oldArg.getUsers(), &op) != oldArg.getUsers().end())
        newOp->replaceUsesOfWith(newArg, newArgFallback);
    }

    // create yield, inserting conversions if necessary
    auto yieldOp = forOp.getBody()->getTerminator();
    SmallVector<Value, 4> newYieldArgs;
    for (Value arg : yieldOp->getOperands())
      newYieldArgs.push_back(mapping.lookup(arg));
    newYieldArgs[i] = rewriter.create<triton::gpu::ConvertLayoutOp>(
        yieldOp->getLoc(), newType, newYieldArgs[i]);
    rewriter.create<scf::YieldOp>(forOp.getLoc(), newYieldArgs);

    // replace
    SmallVector<Value, 4> newResults = newForOp->getResults();
    newResults[i] = rewriter.create<triton::gpu::ConvertLayoutOp>(
        rewriter.getUnknownLoc(), origType, newForOp->getResult(i));
    newResults[i].getDefiningOp()->moveAfter(newForOp);
    return newResults;
  }

  mlir::LogicalResult
  matchAndRewrite(mlir::Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    auto forOp = cast<scf::ForOp>(op);
    auto iterArgs = forOp.getRegionIterArgs();
    for (const auto &iterArg : llvm::enumerate(iterArgs)) {
      // skip non-tensor types
      if (!iterArg.value().getType().isa<RankedTensorType>())
        continue;
      // we only move `iterArg` out of the loop if
      //   - there is only a single conversion use
      //   - moving this conversion out of the loop will not generate
      //     any extra non-removable conversion
      auto users = iterArg.value().getUsers();
      // check first condition
      SetVector<Type> cvtTargetTypes;
      for (auto user : users) {
        if (isa<triton::gpu::ConvertLayoutOp>(user)) {
          auto newType =
              user->getResults()[0].getType().cast<RankedTensorType>();
          auto oldType =
              user->getOperand(0).getType().cast<RankedTensorType>();
          if (oldType.getEncoding().isa<triton::gpu::SharedEncodingAttr>() &&
              newType.getEncoding()
                  .isa<triton::gpu::DotOperandEncodingAttr>()) {
            continue;
          }
          if (newType.getEncoding().isa<triton::gpu::SharedEncodingAttr>()) {
            if (newType.getEncoding()
                    .cast<triton::gpu::SharedEncodingAttr>()
                    .getVec() == 1)
              continue;
          }
          cvtTargetTypes.insert(newType);
        }
      }
      if (cvtTargetTypes.size() != 1)
        continue;
      // TODO: check second condition
      for (auto user : users) {
        if (isa<triton::gpu::ConvertLayoutOp>(user))
          continue;
      }
      // check
      for (auto op : iterArg.value().getUsers()) {
        auto cvt = dyn_cast<triton::gpu::ConvertLayoutOp>(op);
        if (!cvt)
          continue;
        auto targetType = op->getResultTypes()[0].cast<RankedTensorType>();
        auto newFor = rematerializeForLoop(rewriter, forOp, iterArg.index(),
                                           targetType, cvt);
        rewriter.replaceOp(forOp, newFor);
        return success();
      }
    }
    return failure();
  }
};
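// Illustrative effect of MoveConvertOutOfLoop: if an scf.for iter-arg has a
// single in-loop convert_layout user with target type T, the loop is
// rewritten to carry the value in T instead: the init arg is converted to T
// once before the loop, the in-loop conversion disappears, the yield (and
// any other in-loop use, via a fallback conversion) converts as needed, and
// the loop result is converted back to the original type afterwards.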
// -----------------------------------------------------------------------------
//
// -----------------------------------------------------------------------------

class RematerializeForward : public mlir::RewritePattern {
public:
  explicit RematerializeForward(mlir::MLIRContext *context)
      : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
                             2, context) {}

  mlir::LogicalResult
  matchAndRewrite(mlir::Operation *_cvtOp,
                  mlir::PatternRewriter &rewriter) const override {
    auto cvt = cast<triton::gpu::ConvertLayoutOp>(_cvtOp);
    auto forOp = dyn_cast<scf::ForOp>(cvt->getParentOp());
    if (!forOp)
      return mlir::failure();
    auto isInLoop = [&](Operation *op) { return op->getParentOp() == forOp; };

    SetVector<Operation *> cvtSlices;
    auto filter = [&](Operation *op) {
      return isInLoop(op) && !isa<triton::LoadOp>(op) &&
             !isa<triton::StoreOp>(op) && !isa<triton::AtomicRMWOp>(op) &&
             !isa<triton::AtomicCASOp>(op);
    };
    mlir::getForwardSlice(cvt.getResult(), &cvtSlices, filter);
    if (cvtSlices.empty())
      return failure();

    for (Operation *op : cvtSlices) {
      if (!op->hasTrait<mlir::OpTrait::SameOperandsAndResultEncoding>() &&
          !op->hasTrait<mlir::OpTrait::SameOperandsAndResultType>())
        return failure();
      for (Value arg : op->getOperands()) {
        Operation *argOp = arg.getDefiningOp();
        if (argOp && (argOp != cvt) && !isa<arith::ConstantOp>(argOp)) {
          return failure();
        }
      }
    }

    // otherwise, we push the conversion forward
    // since we'll be able to move it out of
    // the loop once it reaches the yield op
    // op(cvt(arg_0), arg_1, ..., arg_n)
    // -> cvt(op(arg_0, cvt(arg_1), ..., cvt(arg_n)))
    BlockAndValueMapping mapping;
    auto op = cvtSlices.front();
    for (Value arg : op->getOperands()) {
      if (arg.getDefiningOp() == cvt)
        mapping.map(arg, cvt.getOperand());
      else {
        auto cvtI = rewriter.create<triton::gpu::ConvertLayoutOp>(
            arg.getLoc(), cvt.getOperand().getType(), arg);
        mapping.map(arg, cvtI);
      }
    }
    Operation *newOp = rewriter.clone(*op, mapping);
    newOp->getResult(0).setType(cvt.getOperand().getType());
    auto newCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
        newOp->getLoc(), cvt.getResult().getType(), newOp->getResult(0));
    rewriter.replaceOp(op, newCvt->getResults());
    return success();
  }
};

// -----------------------------------------------------------------------------
//
// -----------------------------------------------------------------------------

namespace {
int computeCapabilityToMMAVersion(int computeCapability) {
  if (computeCapability < 70) {
    return 0;
  } else if (computeCapability < 80) {
    return 1;
  } else if (computeCapability < 90) {
    return 2;
  } else {
    assert(false && "computeCapability >= 90 not supported");
    return 3;
  }
}

SmallVector<unsigned, 2> mmaVersionToShapePerWarp(int version) {
  if (version == 1)
    return {16, 16};
  else if (version == 2)
    return {16, 8};
  else {
    assert(false && "version not supported");
    return {0, 0};
  }
}

SmallVector<unsigned, 2> warpsPerTileV1(const ArrayRef<int64_t> shape,
                                        int numWarps) {
  SmallVector<unsigned, 2> ret = {1, 1};
  SmallVector<unsigned, 2> shapePerWarp =
      mmaVersionToShapePerWarp(1 /*version*/);
  bool changed = false;
  do {
    changed = false;
    unsigned pre = ret[0];
    if (ret[0] * ret[1] < numWarps) {
      ret[0] = std::clamp<unsigned>(ret[0] * 2, 1, shape[0] / shapePerWarp[0]);
      changed = pre != ret[0];
    }
    if (ret[0] * ret[1] < numWarps) {
      pre = ret[1];
      ret[1] = std::clamp<unsigned>(ret[1] * 2, 1, shape[1] / shapePerWarp[1]);
      changed = pre != ret[1];
    }
  } while (changed);
  return ret;
}

SmallVector<unsigned, 2> warpsPerTileV2(triton::DotOp dotOp,
                                        const ArrayRef<int64_t> shape,
                                        int numWarps) {
  SetVector<Operation *> slices;
  mlir::getForwardSlice(dotOp.getResult(), &slices);
  if (llvm::find_if(slices, [](Operation *op) {
        return isa<triton::DotOp>(op);
      }) != slices.end())
    return {(unsigned)numWarps, 1};

  SmallVector<unsigned, 2> ret = {1, 1};
  SmallVector<int64_t, 2> shapePerWarp = {16, 8};
  // TODO (@daadaada): double-check.
  // original logic in
  // https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L252
  // seems buggy for shape = [32, 16] ?
  do {
    if (ret[0] * ret[1] >= numWarps)
      break;
    if (shape[0] / shapePerWarp[0] / ret[0] >=
        shape[1] / (shapePerWarp[1] * 2) / ret[1]) {
      if (ret[0] < shape[0] / shapePerWarp[0]) {
        ret[0] *= 2;
      } else
        ret[1] *= 2;
    } else {
      ret[1] *= 2;
    }
  } while (true);
  return ret;
}
} // namespace
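// Worked example for warpsPerTileV2: shape = [128, 128], numWarps = 4,
// shapePerWarp = {16, 8}:
//   ret = {1, 1}: 128/16/1 = 8 >= 128/16/1 = 8 and ret[0] < 8 -> ret = {2, 1}
//   ret = {2, 1}: 128/16/2 = 4 <  128/16/1 = 8                -> ret = {2, 2}
//   ret[0] * ret[1] = 4 >= numWarps -> done, warpsPerCTA = {2, 2}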
class OptimizeBlockedToShared : public mlir::RewritePattern {
public:
  explicit OptimizeBlockedToShared(mlir::MLIRContext *context)
      : RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1,
                       context) {}

  mlir::LogicalResult
  matchAndRewrite(mlir::Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    auto cvt = cast<triton::gpu::ConvertLayoutOp>(op);
    auto srcType = cvt.getOperand().getType().cast<RankedTensorType>();
    auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
    auto srcBlockedLayout =
        srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
    auto dstSharedLayout =
        dstType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
    if (!srcBlockedLayout || !dstSharedLayout)
      return failure();
    if (srcBlockedLayout.getOrder() == dstSharedLayout.getOrder())
      return failure();
    // For now only works if single use is transpose
    // TODO: rematerialize #shared uses
    auto users = op->getUsers();
    if (std::distance(users.begin(), users.end()) != 1 ||
        !isa<triton::TransOp>(*users.begin()))
      return failure();

    auto tmpShared = triton::gpu::SharedEncodingAttr::get(
        op->getContext(), dstSharedLayout.getVec(),
        dstSharedLayout.getPerPhase(), dstSharedLayout.getMaxPhase(),
        srcBlockedLayout.getOrder());
    auto tmpType = RankedTensorType::get(srcType.getShape(),
                                         srcType.getElementType(), tmpShared);
    auto tmpCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
        op->getLoc(), tmpType, cvt.getOperand());

    auto newDstType = RankedTensorType::get(
        users.begin()->getResultTypes()[0].cast<RankedTensorType>().getShape(),
        srcType.getElementType(), dstSharedLayout);
    auto newTrans = rewriter.create<triton::TransOp>(op->getLoc(), newDstType,
                                                     tmpCvt.getResult());
    rewriter.replaceOp(*users.begin(), newTrans.getResult());
    return success();
  }
};

class OptimizeConvertToDotOperand : public mlir::RewritePattern {
public:
  explicit OptimizeConvertToDotOperand(mlir::MLIRContext *context)
      : RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1,
                       context) {}

  mlir::LogicalResult
  matchAndRewrite(mlir::Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    auto cvt = cast<triton::gpu::ConvertLayoutOp>(op);
    auto srcType = cvt.getOperand().getType().cast<RankedTensorType>();
    auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
    // order
    ArrayRef<unsigned> order;
    if (auto srcBlockedLayout =
            srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>())
      order = srcBlockedLayout.getOrder();
    else if (auto srcSharedLayout =
                 srcType.getEncoding()
                     .dyn_cast<triton::gpu::SharedEncodingAttr>())
      order = srcSharedLayout.getOrder();
    else
      return failure();
    // dot operand output
    auto dstDotOperandLayout =
        dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
    if (!dstDotOperandLayout)
      return failure();
    if (!dstDotOperandLayout.getIsMMAv1Row())
      return failure();
    bool isMMAv1Row =
        dstDotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
    if ((order[0] == 1 && isMMAv1Row) || (order[0] == 0 && !isMMAv1Row))
      return failure();
    auto newIsRow = BoolAttr::get(op->getContext(), !isMMAv1Row);
    auto newDstEncoding = triton::gpu::DotOperandEncodingAttr::get(
        op->getContext(), dstDotOperandLayout.getOpIdx(),
        dstDotOperandLayout.getParent(), newIsRow);
    auto newDstType = RankedTensorType::get(
        dstType.getShape(), dstType.getElementType(), newDstEncoding);
    auto newCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
        op->getLoc(), newDstType, cvt.getOperand());
    rewriter.replaceOp(op, newCvt.getResult());
    return success();
  }
};
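// Example (inferred from the checks above): if the source order is [1, 0]
// (row-major) but the dot operand is tagged isMMAv1Row = false, the
// conversion is re-emitted with isMMAv1Row = true so the MMAv1 operand tag
// agrees with the data order of the source layout.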
class BlockedToMMA : public mlir::RewritePattern {
  int computeCapability;

public:
  BlockedToMMA(mlir::MLIRContext *context, int computeCapability)
      : mlir::RewritePattern(triton::DotOp::getOperationName(), 2, context),
        computeCapability(computeCapability) {}

  static SmallVector<unsigned, 2>
  getWarpsPerTile(triton::DotOp dotOp, const ArrayRef<int64_t> shape,
                  int version, int numWarps) {
    switch (version) {
    case 1:
      return warpsPerTileV1(shape, numWarps);
    case 2:
      return warpsPerTileV2(dotOp, shape, numWarps);
    default:
      assert(false && "not supported version");
      return {0, 0};
    }
  }

  mlir::LogicalResult
  matchAndRewrite(mlir::Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    auto dotOp = cast<triton::DotOp>(op);
    // TODO: Check data-types and SM compatibility
    auto oldRetType = dotOp.getResult().getType().cast<RankedTensorType>();
    if (!oldRetType.getEncoding() ||
        oldRetType.getEncoding().isa<triton::gpu::MmaEncodingAttr>())
      return failure();

    auto AType = dotOp.getOperand(0).getType().cast<RankedTensorType>();
    auto BType = dotOp.getOperand(1).getType().cast<RankedTensorType>();

    // for FMA, should retain the blocked layout.
    int versionMajor = computeCapabilityToMMAVersion(computeCapability);
    if (!supportMMA(dotOp, versionMajor))
      return failure();

    auto AOrder = AType.getEncoding()
                      .cast<triton::gpu::DotOperandEncodingAttr>()
                      .getParent()
                      .cast<triton::gpu::BlockedEncodingAttr>()
                      .getOrder();
    auto BOrder = BType.getEncoding()
                      .cast<triton::gpu::DotOperandEncodingAttr>()
                      .getParent()
                      .cast<triton::gpu::BlockedEncodingAttr>()
                      .getOrder();

    // get MMA encoding for the given number of warps
    auto retShape = oldRetType.getShape();
    auto mod = op->getParentOfType<mlir::ModuleOp>();
    int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
    auto warpsPerTile =
        getWarpsPerTile(dotOp, retShape, versionMajor, numWarps);
    triton::gpu::MmaEncodingAttr mmaEnc;
    if (versionMajor == 1) {
      auto shapeA = AType.getShape();
      auto shapeB = BType.getShape();
      bool isARow = AOrder[0] != 0;
      bool isBRow = BOrder[0] != 0;
      mmaEnc = triton::gpu::MmaEncodingAttr::get(
          oldRetType.getContext(), versionMajor, warpsPerTile, shapeA, shapeB,
          isARow, isBRow);
    } else if (versionMajor == 2) {
      mmaEnc = triton::gpu::MmaEncodingAttr::get(
          oldRetType.getContext(), versionMajor, 0 /*versionMinor*/,
          warpsPerTile);
    } else {
      assert(false && "Mma layout only supports versionMajor of 1 or 2");
    }
    auto newRetType =
        RankedTensorType::get(retShape, oldRetType.getElementType(), mmaEnc);

    // convert accumulator
    auto oldAcc = dotOp.getOperand(2);
    auto newAcc = rewriter.create<triton::gpu::ConvertLayoutOp>(
        oldAcc.getLoc(), newRetType, oldAcc);
    Value a = dotOp.a();
    Value b = dotOp.b();
    auto oldAType = a.getType().cast<RankedTensorType>();
    auto oldBType = b.getType().cast<RankedTensorType>();
    auto oldAOrder = oldAType.getEncoding()
                         .cast<triton::gpu::DotOperandEncodingAttr>()
                         .getParent()
                         .cast<triton::gpu::BlockedEncodingAttr>()
                         .getOrder();
    auto oldBOrder = oldBType.getEncoding()
                         .cast<triton::gpu::DotOperandEncodingAttr>()
                         .getParent()
                         .cast<triton::gpu::BlockedEncodingAttr>()
                         .getOrder();
    Attribute isMMAv1RowA;
    Attribute isMMAv1RowB;
    if (versionMajor == 1) {
      isMMAv1RowA = BoolAttr::get(getContext(), oldAOrder[0] == 1);
      isMMAv1RowB = BoolAttr::get(getContext(), oldBOrder[0] == 1);
    }

    auto newAType = RankedTensorType::get(
        oldAType.getShape(), oldAType.getElementType(),
        triton::gpu::DotOperandEncodingAttr::get(
            oldAType.getContext(), 0, newRetType.getEncoding(), isMMAv1RowA));
    auto newBType = RankedTensorType::get(
        oldBType.getShape(), oldBType.getElementType(),
        triton::gpu::DotOperandEncodingAttr::get(
            oldBType.getContext(), 1, newRetType.getEncoding(), isMMAv1RowB));

    a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), newAType, a);
    b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), newBType, b);
    auto newDot = rewriter.create<triton::DotOp>(dotOp.getLoc(), newRetType, a,
                                                 b, newAcc, dotOp.allowTF32());

    rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(
        op, oldRetType, newDot.getResult());
    return success();
  }
};

class FixupLoop : public mlir::RewritePattern {
public:
  explicit FixupLoop(mlir::MLIRContext *context)
      : mlir::RewritePattern(scf::ForOp::getOperationName(), 2, context) {}

  mlir::LogicalResult
  matchAndRewrite(mlir::Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    auto forOp = cast<scf::ForOp>(op);

    // Rewrite init argument
    SmallVector<Value, 4> newInitArgs = forOp.getInitArgs();
    bool shouldRematerialize = false;
    for (size_t i = 0; i < newInitArgs.size(); i++) {
      if (newInitArgs[i].getType() != forOp.getRegionIterArgs()[i].getType() ||
          newInitArgs[i].getType() != forOp.getResultTypes()[i]) {
        shouldRematerialize = true;
        break;
      }
    }
    if (!shouldRematerialize)
      return failure();

    scf::ForOp newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), newInitArgs);
    newForOp->moveBefore(forOp);
    rewriter.setInsertionPointToStart(newForOp.getBody());
    BlockAndValueMapping mapping;
    for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
      mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
    mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());

    for (Operation &op : forOp.getBody()->getOperations()) {
      rewriter.clone(op, mapping);
    }
    rewriter.replaceOp(forOp, newForOp.getResults());
    return success();
  }
};

// This pattern collects the MMA layouts that need updating and creates a
// corrected layout for each.
class CollectMmaToUpdateForVolta : public mlir::RewritePattern {
  DenseMap<MmaEncodingAttr, MmaEncodingAttr> &mmaToUpdate;

public:
  CollectMmaToUpdateForVolta(
      mlir::MLIRContext *ctx,
      DenseMap<MmaEncodingAttr, MmaEncodingAttr> &mmaToUpdate)
      : mlir::RewritePattern(triton::DotOp::getOperationName(), 1, ctx),
        mmaToUpdate(mmaToUpdate) {}

  mlir::LogicalResult
  matchAndRewrite(mlir::Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    auto dotOp = cast<triton::DotOp>(op);
    auto *ctx = dotOp->getContext();
    auto AT = dotOp.a().getType().cast<RankedTensorType>();
    auto BT = dotOp.b().getType().cast<RankedTensorType>();
    auto DT = dotOp.d().getType().cast<RankedTensorType>();
    if (!DT.getEncoding())
      return failure();
    auto mmaLayout = DT.getEncoding().dyn_cast<MmaEncodingAttr>();
    if (!(mmaLayout && mmaLayout.isVolta()))
      return failure();

    // Already processed.
    if (mmaToUpdate.count(mmaLayout))
      return failure();

    auto dotOperandA = AT.getEncoding().cast<DotOperandEncodingAttr>();
    auto dotOperandB = BT.getEncoding().cast<DotOperandEncodingAttr>();
    bool isARow = dotOperandA.getIsMMAv1Row().cast<BoolAttr>().getValue();
    bool isBRow = dotOperandB.getIsMMAv1Row().cast<BoolAttr>().getValue();
    auto [isARow_, isBRow_, isAVec4, isBVec4] =
        mmaLayout.decodeVoltaLayoutStates();
    if (isARow_ == isARow && isBRow_ == isBRow) {
      return failure(); // No need to update
    }

    auto newMmaLayout = MmaEncodingAttr::get(
        ctx, mmaLayout.getVersionMajor(), mmaLayout.getWarpsPerCTA(),
        AT.getShape(), BT.getShape(), isARow, isBRow);

    // Collect the wrong MMA layouts, and mark them for update.
    mmaToUpdate.try_emplace(mmaLayout, newMmaLayout);

    return failure();
  }
};
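// The Volta (MMAv1) fixup runs in two steps: CollectMmaToUpdateForVolta
// above only records (wrongMma -> correctedMma) pairs in `mmaToUpdate` (it
// always returns failure, so nothing is rewritten in place), and
// UpdateMMAVersionMinorForVolta below then rewrites every op whose result
// type refers to a collected layout.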
// Correct the versionMinor field in MmaEncodingAttr for Volta.
class UpdateMMAVersionMinorForVolta : public mlir::RewritePattern {
  const DenseMap<MmaEncodingAttr, MmaEncodingAttr> &mmaToUpdate;
  enum class Kind {
    kUnk,
    kCvtToMma,
    kCvtToDotOp,
    kDot,
    kConstant,
  };
  mutable Kind rewriteKind{Kind::kUnk};

public:
  UpdateMMAVersionMinorForVolta(
      mlir::MLIRContext *ctx, llvm::StringRef opName,
      const DenseMap<MmaEncodingAttr, MmaEncodingAttr> &mmaToUpdate)
      : RewritePattern(opName, 1 /*benefit*/, ctx), mmaToUpdate(mmaToUpdate) {}

  LogicalResult match(Operation *op) const override {
    MmaEncodingAttr mma;
    if (mmaToUpdate.empty())
      return failure();
    if (op->getNumResults() != 1)
      return failure();
    auto tensorTy = op->getResult(0).getType().dyn_cast<RankedTensorType>();
    if (!tensorTy)
      return failure();

    // ConvertLayoutOp
    if (auto cvt = llvm::dyn_cast<ConvertLayoutOp>(op)) {
      // cvt X -> dot_operand
      if (auto dotOperand =
              tensorTy.getEncoding().dyn_cast<DotOperandEncodingAttr>()) {
        mma = dotOperand.getParent().dyn_cast<MmaEncodingAttr>();
        rewriteKind = Kind::kCvtToDotOp;
        if (mma && mmaToUpdate.count(mma))
          return success();
      }
      if ((mma = tensorTy.getEncoding().dyn_cast<MmaEncodingAttr>())) {
        // cvt X -> mma
        rewriteKind = Kind::kCvtToMma;
        if (mma && mmaToUpdate.count(mma))
          return success();
      }
    } else if (auto dot = llvm::dyn_cast<DotOp>(op)) {
      // DotOp
      mma = dot.d()
                .getType()
                .cast<RankedTensorType>()
                .getEncoding()
                .dyn_cast<MmaEncodingAttr>();
      rewriteKind = Kind::kDot;
    } else if (auto constant = llvm::dyn_cast<arith::ConstantOp>(op)) {
      // ConstantOp
      mma = tensorTy.getEncoding().dyn_cast<MmaEncodingAttr>();
      rewriteKind = Kind::kConstant;
    }

    return success(mma && mmaToUpdate.count(mma));
  }

  void rewrite(Operation *op, PatternRewriter &rewriter) const override {
    switch (rewriteKind) {
    case Kind::kDot:
      rewriteDot(op, rewriter);
      break;
    case Kind::kConstant:
      rewriteConstant(op, rewriter);
      break;
    case Kind::kCvtToDotOp:
      rewriteCvtDotOp(op, rewriter);
      break;
    case Kind::kCvtToMma:
      rewriteCvtToMma(op, rewriter);
      break;
    default:
      llvm::report_fatal_error("Not supported rewrite kind");
    }
  }

private:
  void rewriteCvtDotOp(Operation *op, PatternRewriter &rewriter) const {
    auto *ctx = op->getContext();
    auto cvt = llvm::cast<ConvertLayoutOp>(op);
    auto tensorTy = cvt.result().getType().cast<RankedTensorType>();
    auto dotOperand = tensorTy.getEncoding().cast<DotOperandEncodingAttr>();
    MmaEncodingAttr newMma =
        mmaToUpdate.lookup(dotOperand.getParent().cast<MmaEncodingAttr>());
    auto newDotOperand = DotOperandEncodingAttr::get(
        ctx, dotOperand.getOpIdx(), newMma, dotOperand.getIsMMAv1Row());
    auto newTensorTy = RankedTensorType::get(
        tensorTy.getShape(), tensorTy.getElementType(), newDotOperand);
    rewriter.replaceOpWithNewOp<ConvertLayoutOp>(op, newTensorTy,
                                                 cvt.getOperand());
  }

  void rewriteDot(Operation *op, PatternRewriter &rewriter) const {
    auto dot = llvm::cast<DotOp>(op);
    auto tensorTy = dot.d().getType().cast<RankedTensorType>();
    auto mma = tensorTy.getEncoding().cast<MmaEncodingAttr>();
    auto newMma = mmaToUpdate.lookup(mma);
    auto newTensorTy = RankedTensorType::get(
        tensorTy.getShape(), tensorTy.getElementType(), newMma);
    rewriter.replaceOpWithNewOp<DotOp>(op, newTensorTy, dot.a(), dot.b(),
                                       dot.c(), dot.allowTF32());
  }

  void rewriteCvtToMma(Operation *op, PatternRewriter &rewriter) const {
    auto cvt = llvm::cast<ConvertLayoutOp>(op);
    auto tensorTy = cvt.result().getType().cast<RankedTensorType>();
    auto mma = tensorTy.getEncoding().cast<MmaEncodingAttr>();
    auto newMma = mmaToUpdate.lookup(mma);
    auto newTensorTy = RankedTensorType::get(
        tensorTy.getShape(), tensorTy.getElementType(), newMma);
    rewriter.replaceOpWithNewOp<ConvertLayoutOp>(op, newTensorTy,
                                                 cvt.getOperand());
  }

  void rewriteConstant(Operation *op, PatternRewriter &rewriter) const {
    auto constant = llvm::cast<arith::ConstantOp>(op);
    auto tensorTy =
        constant.getResult().getType().dyn_cast<RankedTensorType>();
    auto mma = tensorTy.getEncoding().cast<MmaEncodingAttr>();
    auto newMma = mmaToUpdate.lookup(mma);
    auto newTensorTy = RankedTensorType::get(
        tensorTy.getShape(), tensorTy.getElementType(), newMma);
    if (auto attr = constant.getValue().dyn_cast<SplatElementsAttr>()) {
      auto newRet =
          SplatElementsAttr::get(newTensorTy, attr.getSplatValue<Attribute>());
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newTensorTy, newRet);
      return;
    }

    assert(false && "Not supported ConstantOp value type");
  }
};
} // namespace

#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"

class TritonGPUCombineOpsPass
    : public TritonGPUCombineOpsBase<TritonGPUCombineOpsPass> {
public:
  TritonGPUCombineOpsPass() = default;
  TritonGPUCombineOpsPass(int computeCapability) {
    this->computeCapability = computeCapability;
  }

  void runOnOperation() override {
    MLIRContext *context = &getContext();
    ModuleOp m = getOperation();

    mlir::RewritePatternSet patterns(context);
    patterns.add<OptimizeBlockedToShared>(context);
    patterns.add<OptimizeConvertToDotOperand>(context);
    patterns.add<SimplifyConversion>(context);
    patterns.add<SimplifyReduceCvt>(context);
    patterns.add<FoldConvertAndReduce>(context);
    patterns.add<DecomposeDotOperand>(context);
    patterns.add<RematerializeBackward>(context);
    patterns.add<RematerializeForward>(context);
    patterns.add<MoveConvertOutOfLoop>(context);
    patterns.add<MoveConvertOutOfIf>(context);
    patterns.add<BlockedToMMA>(context, computeCapability);

    if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) {
      signalPassFailure();
    }

    llvm::DenseMap<MmaEncodingAttr, MmaEncodingAttr> mmaToUpdate;
    {
      mlir::RewritePatternSet patterns(context);
      patterns.add<CollectMmaToUpdateForVolta>(context, mmaToUpdate);
      if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
        signalPassFailure();
    }
    {
      mlir::RewritePatternSet patterns(context);
      patterns.add<UpdateMMAVersionMinorForVolta>(
          context, DotOp::getOperationName(), mmaToUpdate);
      patterns.add<UpdateMMAVersionMinorForVolta>(
          context, ConvertLayoutOp::getOperationName(), mmaToUpdate);
      patterns.add<UpdateMMAVersionMinorForVolta>(
          context, arith::ConstantOp::getOperationName(), mmaToUpdate);
      mlir::GreedyRewriteConfig config;
      config.useTopDownTraversal = true;
      if (applyPatternsAndFoldGreedily(m, std::move(patterns), config)
              .failed())
        signalPassFailure();
    }

    mlir::RewritePatternSet loopFixup(context);
    loopFixup.add<FixupLoop>(context);
    if (applyPatternsAndFoldGreedily(m, std::move(loopFixup)).failed()) {
      signalPassFailure();
    }
  }
};

std::unique_ptr<Pass>
mlir::createTritonGPUCombineOpsPass(int computeCapability) {
  return std::make_unique<TritonGPUCombineOpsPass>(computeCapability);
}
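// Illustrative usage (a sketch, not part of the original file): the pass is
// typically scheduled from a pass manager with the target compute
// capability, e.g. for sm_80:
//
//   mlir::PassManager pm(module.getContext());
//   pm.addPass(mlir::createTritonGPUCombineOpsPass(80));
//   bool ok = succeeded(pm.run(module));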