From 38a80664b5e062af2e8a2af90d96b166e972a0be Mon Sep 17 00:00:00 2001
From: Philippe Tillet
Date: Sun, 16 Oct 2022 21:19:42 -0700
Subject: [PATCH] [OPTIMIZER] Updated TritonGPU-combine pass (#784)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

WIP but should work in the cases we need so far
---
 lib/Dialect/TritonGPU/Transforms/Combine.cpp | 265 +++++++++++--------
 test/TritonGPU/combine.mlir                  |  36 +--
 2 files changed, 165 insertions(+), 136 deletions(-)

diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.cpp b/lib/Dialect/TritonGPU/Transforms/Combine.cpp
index 5574b3ffd..af1cee904 100644
--- a/lib/Dialect/TritonGPU/Transforms/Combine.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Combine.cpp
@@ -5,9 +5,13 @@
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/Verifier.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/Passes.h"
+#include "mlir/Transforms/RegionUtils.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/Transforms/Passes.h"
 
@@ -36,7 +40,7 @@ class SimplifyConversion : public mlir::RewritePattern {
 public:
   SimplifyConversion(mlir::MLIRContext *context)
       : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
-                             2, context) {}
+                             4, context) {}
 
   mlir::LogicalResult
   matchAndRewrite(mlir::Operation *op,
@@ -86,106 +90,165 @@ public:
 //
 // -----------------------------------------------------------------------------
 
+static LogicalResult invertEncoding(Attribute targetEncoding, Operation *op,
+                                    Attribute &ret) {
+  ret = targetEncoding;
+  if (auto expand_dims = dyn_cast<triton::ExpandDimsOp>(op)) {
+    ret = triton::gpu::SliceEncodingAttr::get(
+        op->getContext(), expand_dims.axis(), targetEncoding);
+  }
+  if (auto reduce = dyn_cast<triton::ReduceOp>(op)) {
+    auto sliceEncoding =
+        targetEncoding.dyn_cast<triton::gpu::SliceEncodingAttr>();
+    if (!sliceEncoding)
+      return failure();
+    ret = sliceEncoding.getParent();
+  }
+  return success();
+}
+
+inline bool expensive_to_remat(Operation *op) {
+  if (!op)
+    return true;
+  if (isa(op))
+    return true;
+  if (isa(op))
+    return true;
+  return false;
+};
+
+Operation *cloneWithInferType(mlir::PatternRewriter &rewriter, Operation *op,
+                              BlockAndValueMapping &mapping) {
+  Operation *newOp = rewriter.clone(*op, mapping);
+  auto origType = op->getResult(0).getType().cast<RankedTensorType>();
+  auto newType = RankedTensorType::get(
+      origType.getShape(), origType.getElementType(),
+      newOp->getOperand(0).getType().cast<RankedTensorType>().getEncoding());
+  newOp->getResult(0).setType(newType);
+  auto typeInfer = dyn_cast<InferTypeOpInterface>(newOp);
+  if (typeInfer) {
+    SmallVector<Type, 1> newTypes;
+    auto success = typeInfer.inferReturnTypes(
+        newOp->getContext(), newOp->getLoc(), newOp->getOperands(),
+        newOp->getAttrDictionary(), newOp->getRegions(), newTypes);
+    if (succeeded(success))
+      newOp->getResult(0).setType(newTypes.front());
+  }
+  return newOp;
+}
+
 // Layout conversions are expensive. They require going through
 // shared memory, which is orders of magnitude slower than
 // other non-i/o operations in the dialect.
 // It therefore makes sense to remove them whenever possible,
 // even if it means rematerializing all values whose definitions
 // are reachable from it without passing through any memory operation.
-class PullConversionToSource : public mlir::RewritePattern {
+class RematerializeBackward : public mlir::RewritePattern {
 public:
-  PullConversionToSource(mlir::MLIRContext *context)
+  RematerializeBackward(mlir::MLIRContext *context)
       : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
                              2, context) {}
 
-  Attribute invertEncoding(Type targetType, Operation *op) const {
-    RankedTensorType targetTensorType = targetType.cast<RankedTensorType>();
-    if (auto expand_dims = dyn_cast<triton::ExpandDimsOp>(op)) {
-      return targetTensorType.getEncoding()
-          .cast()
-          .squeeze(expand_dims.axis());
-    }
-    return targetTensorType.getEncoding();
-  }
-
   mlir::LogicalResult
   matchAndRewrite(mlir::Operation *cvt,
                   mlir::PatternRewriter &rewriter) const override {
     if (!llvm::isa<triton::gpu::ConvertLayoutOp>(cvt))
       return mlir::failure();
-    // constants/splat are handled separately
+    // we don't touch block arguments
     Operation *op = cvt->getOperand(0).getDefiningOp();
     if (!op)
       return mlir::failure();
-
-    auto blacklist = [](Operation *op) {
-      if (isa(op))
-        return true;
-      if (isa(op))
-        return true;
-      return false;
-    };
-
-    // find all the ops that the conversion depends on
-    SetVector<Operation *> depIntoOps;
-    mlir::getBackwardSlice(cvt, &depIntoOps, [&](Operation *op) {
-      return !blacklist(op) && op->getResult(0) &&
-             (op->getResult(0).getParentBlock() ==
-              cvt->getResult(0).getParentBlock());
-    });
-    // find all the ops that depend on something expensive
-    // to rematerialize and break the dependency
-    // chain there
-    SetVector<Operation *> blacklistedOps;
-    mlir::getBackwardSlice(cvt, &blacklistedOps, blacklist);
-    for (Operation *op : blacklistedOps) {
-      SetVector<Operation *> toRemove;
-      mlir::getBackwardSlice(op, &toRemove);
-      depIntoOps.set_subtract(toRemove);
-    }
-    // if there is nothing that can benefit from moving conversions
-    // in the remaining op, we don't do anything
-    auto it = llvm::find_if(depIntoOps, [&](Operation *op) {
-      if (isa<triton::gpu::ConvertLayoutOp>(op)) {
-        // conversions in for loops interfere with the
-        // push-to-sink pass. Need a better cost model if how many conversions
-        // we can actually remove by moving them to the beginning of the block
-        auto forOp = dyn_cast<scf::ForOp>(cvt->getParentOp());
-        if (!forOp &&
-            (cvt->getResult(0).getType() == op->getOperand(0).getType()))
-          return true;
-      }
-      if (isa(op))
-        return true;
-      return false;
-    });
-    if (it == depIntoOps.end()) {
+    // we don't want to rematerialize any conversion to/from shared
+    if (isSharedLayout(cvt->getResults()[0]) ||
+        isSharedLayout(cvt->getOperand(0)))
       return mlir::failure();
+    auto targetType = cvt->getResultTypes()[0].cast<RankedTensorType>();
+    // DFS
+    SetVector<Operation *> processed;
+    SetVector<Attribute> layout;
+    std::vector<std::pair<Operation *, Attribute>> queue;
+    std::vector<std::pair<Value, Attribute>> toConvert;
+    queue.push_back({cvt, targetType.getEncoding()});
+    int numCvts = 1;
+    while (!queue.empty()) {
+      Operation *currOp;
+      Attribute currLayout;
+      std::tie(currOp, currLayout) = queue.back();
+      queue.pop_back();
+      // If the current operation is expensive to rematerialize,
+      // we stop everything
+      if (expensive_to_remat(currOp))
+        break;
+      // a conversion will be removed here (i.e. transferred to operands)
+      numCvts -= 1;
+      // done processing
+      processed.insert(currOp);
+      layout.insert(currLayout);
+      // add all operands to the queue
+      for (Value argI : currOp->getOperands()) {
+        Attribute newEncoding;
+        if (failed(invertEncoding(currLayout, currOp, newEncoding)))
+          return mlir::failure();
+        toConvert.push_back({argI, newEncoding});
+        Operation *opArgI = argI.getDefiningOp();
+        if (!opArgI)
+          continue;
+        if (!opArgI || processed.contains(opArgI) ||
+            (opArgI->getBlock() != cvt->getBlock()))
+          continue;
+        // if the conversion can be folded into opArgI then
+        // we actually haven't added any conversion
+        if (isa(*opArgI))
+          continue;
+        // we add one expensive conversion for the current operand
+        numCvts += 1;
+        queue.push_back({opArgI, newEncoding});
+      }
     }
+    // if rematerialization would add more conversions than it removes
+    // then we don't do it
+    if (numCvts > 0)
+      return mlir::failure();
+
+    FuncOp parentFunc = cvt->getParentOfType<FuncOp>();
+    bool test = cvt->getResult(0)
+                    .getType()
+                    .cast<RankedTensorType>()
+                    .getEncoding()
+                    .isa();
+    // if (test)
+    //   llvm::outs() << "--------\nConverting " << *cvt << "\n---------\n";
 
-    // We convert cvt(op(arg_0, arg_1, ..., arg_n))
-    // into op(cvt_0(arg_0), cvt_1(arg_1), ..., cvt_n(arg_n))
     BlockAndValueMapping mapping;
-    for (Value argI : op->getOperands()) {
-      // Compute new argument types
-      auto oldArgType = argI.getType().dyn_cast<RankedTensorType>();
-      if (!oldArgType)
-        continue;
-      auto newEncoding = invertEncoding(cvt->getResultTypes()[0], op);
-      auto newArgType = RankedTensorType::get(
-          oldArgType.getShape(), oldArgType.getElementType(), newEncoding);
-      // Create new argument
-      auto cvtI = rewriter.create<triton::gpu::ConvertLayoutOp>(
-          op->getLoc(), newArgType, argI);
-      cvtI->moveBefore(op);
-      mapping.map(argI, cvtI);
+    for (int i = toConvert.size() - 1; i >= 0; i--) {
+      // unpack information
+      Value currOperand;
+      Attribute targetLayout;
+      std::tie(currOperand, targetLayout) = toConvert[i];
+      // if (test)
+      //   llvm::outs() << "current " << currOperand << "\n";
+      // rematerialize the operand if necessary
+      Operation *currOperation = currOperand.getDefiningOp();
+      if (processed.contains(currOperation)) {
+        currOperation = cloneWithInferType(rewriter, currOperation, mapping);
+        currOperand = currOperation->getResult(0);
+      }
+      if (i == 0)
+        break;
+      // compute target type for the layout cast
+      auto currType = currOperand.getType().cast<RankedTensorType>();
+      auto newType = RankedTensorType::get(
+          currType.getShape(), currType.getElementType(), targetLayout);
+      auto newOperand = rewriter.create<triton::gpu::ConvertLayoutOp>(
+          currOperand.getLoc(), newType, currOperand);
+      if (currOperation)
+        newOperand->moveAfter(currOperation);
+      mapping.map(currOperand, newOperand);
     }
-    Operation *newOp = rewriter.clone(*op, mapping);
-    newOp->getResult(0).setType(cvt->getResult(0).getType());
-    rewriter.replaceOp(cvt, newOp->getResults());
-
+    rewriter.replaceOp(cvt, mapping.lookup(cvt->getOperand(0)));
     return mlir::success();
   }
 };
@@ -226,6 +289,7 @@ bool tryLegalizeOp(Operation *op, DenseSet toPreserve,
 std::pair<SmallVector<Value, 4>, scf::ForOp>
 tryConvertIterArg(scf::ForOp &forOp, mlir::PatternRewriter &rewriter, size_t i,
                   Type newType) {
+  forOp.getInductionVar();
   auto newEncoding = newType.cast<RankedTensorType>().getEncoding();
   auto ctx = forOp.getContext();
   auto isInLoop = [&](Operation *op) { return op->getParentOp() == forOp; };
@@ -243,6 +307,7 @@ tryConvertIterArg(scf::ForOp &forOp, mlir::PatternRewriter &rewriter, size_t i,
   BlockAndValueMapping mapping;
   for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
     mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
+  mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());
   // traverse all ops in the loop
   for (Operation &op : forOp.getBody()->without_terminator()) {
     // we clone the op
@@ -278,9 +343,9 @@ tryConvertIterArg(scf::ForOp &forOp, mlir::PatternRewriter &rewriter, size_t i,
   return {newResults, newForOp};
 }
 
-class MoveArgConvertOutOfLoop : public mlir::RewritePattern {
+class MoveConvertOutOfLoop : public mlir::RewritePattern {
 public:
-  MoveArgConvertOutOfLoop(mlir::MLIRContext *context)
+  MoveConvertOutOfLoop(mlir::MLIRContext *context)
       : mlir::RewritePattern(scf::ForOp::getOperationName(), 1, context) {}
 
   mlir::LogicalResult
   matchAndRewrite(mlir::Operation *op,
@@ -290,27 +355,14 @@ public:
     auto isInLoop = [&](Operation *op) { return op->getParentOp() == forOp; };
     auto iterArgs = forOp.getRegionIterArgs();
     for (auto iterArg : llvm::enumerate(iterArgs)) {
+      // skip non-tensor types
+      if (!iterArg.value().getType().isa<RankedTensorType>())
+        continue;
+      // check
       for (auto op : iterArg.value().getUsers()) {
-        auto currOps = mlir::getSlice(op, isInLoop);
-        auto pred = [&](Operation *op) {
-          return isa(op);
-        };
-        auto isCvt = [&](Operation *op) {
-          return isa<triton::gpu::ConvertLayoutOp>(op);
-        };
-        auto isYield = [&](Operation *op) { return isa<scf::YieldOp>(op); };
-        auto opIt = std::find(currOps.begin(), currOps.end(), op);
-        auto yieldIt = std::find_if(currOps.begin(), currOps.end(), isYield);
-        auto fwdEndIt = std::find_if(opIt, currOps.end(), pred);
-        auto bwdBeginIt = std::find_if(currOps.begin(), opIt, pred);
-        auto fwdCvtIt = std::find_if(opIt, fwdEndIt, isCvt);
-        auto bwdCvtIt = std::find_if(bwdBeginIt, opIt, isCvt);
-
-        if (!iterArg.value().getType().isa<RankedTensorType>())
-          continue;
-        if (fwdCvtIt != fwdEndIt) {
+        if (isa<triton::gpu::ConvertLayoutOp>(op)) {
           auto newFor = tryConvertIterArg(forOp, rewriter, iterArg.index(),
-                                          (*fwdCvtIt)->getResult(0).getType());
+                                          op->getResult(0).getType());
           rewriter.replaceOp(forOp, newFor.first);
           return success();
         }
@@ -324,9 +376,9 @@ public:
 //
 // -----------------------------------------------------------------------------
 
-class PushConversionToSink : public mlir::RewritePattern {
+class RematerializeForward : public mlir::RewritePattern {
 public:
-  PushConversionToSink(mlir::MLIRContext *context)
+  RematerializeForward(mlir::MLIRContext *context)
       : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
                              2, context) {}
 
@@ -430,16 +482,17 @@ public:
     mlir::RewritePatternSet patterns(context);
 
    patterns.add(context);
-    patterns.add<PullConversionToSource>(context);
-    patterns.add<MoveArgConvertOutOfLoop>(context);
-    patterns.add<PushConversionToSink>(context);
+    patterns.add<RematerializeBackward>(context);
+    patterns.add<MoveConvertOutOfLoop>(context);
+    patterns.add<RematerializeForward>(context);
    patterns.add(context);
 
-    if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
+    if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) {
       signalPassFailure();
+    }
   }
 };
 
 std::unique_ptr<Pass> mlir::createTritonGPUCombineOpsPass() {
   return std::make_unique<TritonGPUCombineOpsPass>();
-}
+}
\ No newline at end of file
diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir
index 91ad2703b..64025d1bf 100644
--- a/test/TritonGPU/combine.mlir
+++ b/test/TritonGPU/combine.mlir
@@ -45,8 +45,8 @@ func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> {
   // CHECK: %1 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, [[target_layout]]>
   // CHECK: %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, [[target_layout]]>
   // CHECK: %3 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, [[target_layout]]>
-  // CHECK: %4 = arith.muli %2, %3 : tensor<1024xi32, [[target_layout]]>
-  // CHECK: %5 = arith.muli %0, %1 : tensor<1024xi32, [[target_layout]]>
+  // CHECK: %4 = arith.muli %0, %2 : tensor<1024xi32, [[target_layout]]>
+  // CHECK: %5 = arith.muli %1, %3 : tensor<1024xi32, [[target_layout]]>
   // CHECK: %6 = arith.addi %4, %5 : tensor<1024xi32, [[target_layout]]>
   // CHECK: return %6 : tensor<1024xi32, [[target_layout]]>
 }
@@ -61,34 +61,10 @@ func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> {
 
 // CHECK-LABEL: transpose
 func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) {
-  // CHECK: %cst = arith.constant dense<true> : tensor<64x64xi1, [[row_layout]]>
-  // CHECK: %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, [[row_layout]]>
-  // CHECK: %cst_1 = arith.constant dense<true> : tensor<64x64xi1, [[col_layout]]>
-  // CHECK: %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = [[col_layout]]}>>
-  // CHECK: %1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = [[row_layout]]}>>
-  // CHECK: %2 = tt.expand_dims %1 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = [[row_layout]]}>>) -> tensor<64x1xi32, [[row_layout]]>
-  // CHECK: %3 = tt.splat %arg1 : (i32) -> tensor<64x1xi32, [[row_layout]]>
-  // CHECK: %4 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, [[row_layout]]>
-  // CHECK: %5 = arith.muli %2, %3 : tensor<64x1xi32, [[row_layout]]>
-  // CHECK: %6 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = [[col_layout]]}>>
-  // CHECK: %7 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = [[row_layout]]}>>
-  // CHECK: %8 = tt.addptr %4, %5 : tensor<64x1x!tt.ptr<f32>, [[row_layout]]>
-  // CHECK: %9 = tt.expand_dims %7 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = [[row_layout]]}>>) -> tensor<1x64xi32, [[row_layout]]>
-  // CHECK: %10 = tt.broadcast %8 : (tensor<64x1x!tt.ptr<f32>, [[row_layout]]>) -> tensor<64x64x!tt.ptr<f32>, [[row_layout]]>
-  // CHECK: %11 = tt.broadcast %9 : (tensor<1x64xi32, [[row_layout]]>) -> tensor<64x64xi32, [[row_layout]]>
-  // CHECK: %12 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, [[col_layout]]>
-  // CHECK: %13 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = [[col_layout]]}>>) -> tensor<64x1xi32, [[col_layout]]>
-  // CHECK: %14 = tt.expand_dims %6 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = [[col_layout]]}>>) -> tensor<1x64xi32, [[col_layout]]>
-  // CHECK: %15 = tt.splat %arg3 : (i32) -> tensor<1x64xi32, [[col_layout]]>
-  // CHECK: %16 = tt.addptr %12, %13 : tensor<64x1x!tt.ptr<f32>, [[col_layout]]>
-  // CHECK: %17 = arith.muli %14, %15 : tensor<1x64xi32, [[col_layout]]>
-  // CHECK: %18 = tt.broadcast %16 : (tensor<64x1x!tt.ptr<f32>, [[col_layout]]>) -> tensor<64x64x!tt.ptr<f32>, [[col_layout]]>
-  // CHECK: %19 = tt.broadcast %17 : (tensor<1x64xi32, [[col_layout]]>) -> tensor<64x64xi32, [[col_layout]]>
-  // CHECK: %20 = tt.addptr %10, %11 : tensor<64x64x!tt.ptr<f32>, [[row_layout]]>
-  // CHECK: %21 = tt.load %20, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<64x64xf32, [[row_layout]]>
-  // CHECK: %22 = tt.addptr %18, %19 : tensor<64x64x!tt.ptr<f32>, [[col_layout]]>
-  // CHECK: %23 = triton_gpu.convert_layout %21 : (tensor<64x64xf32, [[row_layout]]>) -> tensor<64x64xf32, [[col_layout]]>
-  // CHECK: tt.store %22, %23, %cst_1 : tensor<64x64xf32, [[col_layout]]>
+  // CHECK-NOT: triton_gpu.convert_layout
+  // CHECK: [[loaded_val:%.*]] = tt.load {{.*}}, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<64x64xf32, [[row_layout]]>
+  // CHECK: [[cvt_val:%.*]] = triton_gpu.convert_layout [[loaded_val]] : (tensor<64x64xf32, [[row_layout]]>) -> tensor<64x64xf32, [[col_layout]]>
+  // CHECK: tt.store {{.*}}, [[cvt_val]], %cst_1 : tensor<64x64xf32, [[col_layout]]>
   // CHECK: return
   %cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked1>
   %cst_0 = arith.constant dense<true> : tensor<64x64xi1, #blocked1>
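
For readers skimming the patch, here is a minimal before/after sketch of the backward rematerialization that RematerializeBackward performs and that the remat test above checks. The #blocked0/#blocked1 encodings, shapes, and the function name are illustrative placeholders, not taken from this patch:

#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>

// Before: the whole def chain lives in #blocked0 and is converted at the end.
func @remat_sketch() -> tensor<1024xi32, #blocked1> {
  %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked0>
  %1 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked0>
  %2 = arith.muli %0, %1 : tensor<1024xi32, #blocked0>
  %3 = triton_gpu.convert_layout %2 : (tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked1>
  return %3 : tensor<1024xi32, #blocked1>
}

// After: the cheap defs (make_range, muli) are cloned directly in #blocked1,
// so the convert_layout and its shared-memory round trip disappear.
func @remat_sketch() -> tensor<1024xi32, #blocked1> {
  %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked1>
  %1 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked1>
  %2 = arith.muli %0, %1 : tensor<1024xi32, #blocked1>
  return %2 : tensor<1024xi32, #blocked1>
}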
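Similarly, a rough sketch of what MoveConvertOutOfLoop and tryConvertIterArg aim for on loop-carried values; encodings, shapes, and value names are again placeholders rather than code or IR from this patch:

// Before: the iter_arg is carried in #blocked0, so every iteration pays for
// layout conversions around its use in #blocked1.
%res0 = scf.for %iv = %lb to %ub step %step iter_args(%acc = %init0) -> (tensor<128xf32, #blocked0>) {
  %a = triton_gpu.convert_layout %acc : (tensor<128xf32, #blocked0>) -> tensor<128xf32, #blocked1>
  %b = arith.addf %a, %a : tensor<128xf32, #blocked1>
  %c = triton_gpu.convert_layout %b : (tensor<128xf32, #blocked1>) -> tensor<128xf32, #blocked0>
  scf.yield %c : tensor<128xf32, #blocked0>
}

// After: the loop is rebuilt with the iter_arg directly in #blocked1; the only
// remaining conversion is a single one on the init value, outside the loop.
%init1 = triton_gpu.convert_layout %init0 : (tensor<128xf32, #blocked0>) -> tensor<128xf32, #blocked1>
%res1 = scf.for %iv = %lb to %ub step %step iter_args(%acc = %init1) -> (tensor<128xf32, #blocked1>) {
  %b = arith.addf %acc, %acc : tensor<128xf32, #blocked1>
  scf.yield %b : tensor<128xf32, #blocked1>
}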