[OPTIMIZER] Updated TritonGPU-combine pass (#784)

WIP, but should work in the cases we need so far
This commit is contained in:
Philippe Tillet
2022-10-16 21:19:42 -07:00
committed by GitHub
parent e948a618b3
commit 38a80664b5
2 changed files with 165 additions and 136 deletions


@@ -5,9 +5,13 @@
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/RegionUtils.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
@@ -36,7 +40,7 @@ class SimplifyConversion : public mlir::RewritePattern {
public:
SimplifyConversion(mlir::MLIRContext *context)
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
2, context) {}
4, context) {}
mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
@@ -86,106 +90,165 @@ public:
//
// -----------------------------------------------------------------------------
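// Given the encoding that an op's result must have (`targetEncoding`),
// compute the encoding its operands must carry. For example (a sketch with
// abbreviated encodings, not lines from this diff), to produce
//   %b = tt.expand_dims %a {axis = 0 : i32}
//        : (tensor<128xi32, #operand>) -> tensor<1x128xi32, #blocked>
// with target encoding #blocked, the operand %a must be rematerialized with
//   #operand = #triton_gpu.slice<{dim = 0, parent = #blocked}>
// which is the slice encoding constructed below for ExpandDimsOp.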
static LogicalResult invertEncoding(Attribute targetEncoding, Operation *op,
Attribute &ret) {
ret = targetEncoding;
if (auto expand_dims = dyn_cast<triton::ExpandDimsOp>(op)) {
ret = triton::gpu::SliceEncodingAttr::get(
op->getContext(), expand_dims.axis(), targetEncoding);
}
if (auto reduce = dyn_cast<triton::ReduceOp>(op)) {
auto sliceEncoding =
targetEncoding.dyn_cast<triton::gpu::SliceEncodingAttr>();
if (!sliceEncoding)
return failure();
ret = sliceEncoding.getParent();
}
return success();
}
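// Memory and tensor-core ops (as well as control flow) are treated as too
// expensive to duplicate when rematerializing a layout.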
inline bool expensive_to_remat(Operation *op) {
if (!op)
return true;
if (isa<triton::gpu::ExtractSliceOp, triton::gpu::AllocTensorOp,
triton::gpu::InsertSliceAsyncOp, triton::LoadOp, triton::StoreOp,
triton::DotOp>(op))
return true;
if (isa<scf::YieldOp, scf::ForOp>(op))
return true;
return false;
}
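// Clones `op` with `mapping` applied to its operands and patches the clone's
// result type so that its encoding matches the already-converted first
// operand; ops implementing InferTypeOpInterface recompute their result
// types themselves.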
Operation *cloneWithInferType(mlir::PatternRewriter &rewriter, Operation *op,
BlockAndValueMapping &mapping) {
Operation *newOp = rewriter.clone(*op, mapping);
auto origType = op->getResult(0).getType().cast<RankedTensorType>();
auto newType = RankedTensorType::get(
origType.getShape(), origType.getElementType(),
newOp->getOperand(0).getType().cast<RankedTensorType>().getEncoding());
newOp->getResult(0).setType(newType);
auto typeInfer = dyn_cast<InferTypeOpInterface>(newOp);
if (typeInfer) {
SmallVector<Type, 1> newTypes;
auto result = typeInfer.inferReturnTypes(
newOp->getContext(), newOp->getLoc(), newOp->getOperands(),
newOp->getAttrDictionary(), newOp->getRegions(), newTypes);
if (succeeded(result))
newOp->getResult(0).setType(newTypes.front());
}
return newOp;
}
// Layout conversions are expensive. They require going through
// shared memory, which is orders of magnitude slower than
// other non-i/o operations in the dialect.
// It therefore makes sense to remove them whenever possible,
// even if it means rematerializing all values whose definitions
// are reachable from it without passing through any memory operation.
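// Schematically (a sketch with abbreviated encodings), the pattern below
// rewrites
//   %y = op %a, %b : tensor<..., #src>
//   %z = triton_gpu.convert_layout %y : (tensor<..., #src>) -> tensor<..., #dst>
// into
//   %a2 = triton_gpu.convert_layout %a -> #dst'
//   %b2 = triton_gpu.convert_layout %b -> #dst'
//   %z = op %a2, %b2 : tensor<..., #dst>
// where #dst' is #dst inverted through op, in the hope that the new
// conversions fold into cheap producers (constants, splat, make_range).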
class PullConversionToSource : public mlir::RewritePattern {
class RematerializeBackward : public mlir::RewritePattern {
public:
PullConversionToSource(mlir::MLIRContext *context)
RematerializeBackward(mlir::MLIRContext *context)
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
2, context) {}
Attribute invertEncoding(Type targetType, Operation *op) const {
RankedTensorType targetTensorType = targetType.cast<RankedTensorType>();
if (auto expand_dims = dyn_cast<triton::ExpandDimsOp>(op)) {
return targetTensorType.getEncoding()
.cast<triton::gpu::BlockedEncodingAttr>()
.squeeze(expand_dims.axis());
}
return targetTensorType.getEncoding();
}
mlir::LogicalResult
matchAndRewrite(mlir::Operation *cvt,
mlir::PatternRewriter &rewriter) const override {
if (!llvm::isa<triton::gpu::ConvertLayoutOp>(cvt))
return mlir::failure();
// constants/splat are handled separately
// we don't touch block arguments
Operation *op = cvt->getOperand(0).getDefiningOp();
if (!op)
return mlir::failure();
// same criterion as expensive_to_remat above
auto blacklist = [](Operation *op) { return expensive_to_remat(op); };
// find all the ops that the conversion depends on
SetVector<Operation *> depIntoOps;
mlir::getBackwardSlice(cvt, &depIntoOps, [&](Operation *op) {
return !blacklist(op) && op->getResult(0) &&
(op->getResult(0).getParentBlock() ==
cvt->getResult(0).getParentBlock());
});
// find all the ops that depend on something expensive
// to rematerialize and break the dependency
// chain there
SetVector<Operation *> blacklistedOps;
mlir::getBackwardSlice(cvt, &blacklistedOps, blacklist);
for (Operation *op : blacklistedOps) {
SetVector<Operation *> toRemove;
mlir::getBackwardSlice(op, &toRemove);
depIntoOps.set_subtract(toRemove);
}
// if none of the remaining ops can benefit from moving the conversion,
// we don't do anything
auto it = llvm::find_if(depIntoOps, [&](Operation *op) {
if (isa<triton::gpu::ConvertLayoutOp>(op)) {
// conversions in for loops interfere with the
// push-to-sink pass. We need a better cost model of how many conversions
// we can actually remove by moving them to the beginning of the block
auto forOp = dyn_cast<scf::ForOp>(cvt->getParentOp());
if (!forOp &&
(cvt->getResult(0).getType() == op->getOperand(0).getType()))
return true;
}
if (isa<arith::ConstantOp, triton::MakeRangeOp, triton::SplatOp>(op))
return true;
return false;
});
if (it == depIntoOps.end()) {
// we don't want to rematerialize any conversion to/from shared
if (isSharedLayout(cvt->getResults()[0]) ||
isSharedLayout(cvt->getOperand(0)))
return mlir::failure();
auto targetType = cvt->getResultTypes()[0].cast<RankedTensorType>();
// DFS
SetVector<Operation *> processed;
SetVector<Attribute> layout;
std::vector<std::pair<Operation *, Attribute>> queue;
std::vector<std::pair<Value, Attribute>> toConvert;
queue.push_back({cvt, targetType.getEncoding()});
int numCvts = 1;
while (!queue.empty()) {
Operation *currOp;
Attribute currLayout;
std::tie(currOp, currLayout) = queue.back();
queue.pop_back();
// If the current operation is expensive to rematerialize,
// we stop everything
if (expensive_to_remat(currOp))
break;
// a conversion will be removed here (i.e. transferred to the operands)
numCvts -= 1;
// done processing
processed.insert(currOp);
layout.insert(currLayout);
// add all operands to the queue
for (Value argI : currOp->getOperands()) {
Attribute newEncoding;
if (failed(invertEncoding(currLayout, currOp, newEncoding)))
return mlir::failure();
toConvert.push_back({argI, newEncoding});
Operation *opArgI = argI.getDefiningOp();
if (!opArgI || processed.contains(opArgI) ||
(opArgI->getBlock() != cvt->getBlock()))
continue;
// if the conversion can be folded into opArgI then
// we actually haven't added any conversion
if (isa<triton::gpu::ConvertLayoutOp, arith::ConstantOp,
triton::MakeRangeOp, triton::SplatOp>(*opArgI))
continue;
// we add one expensive conversion for the current operand
numCvts += 1;
queue.push_back({opArgI, newEncoding});
}
}
// if rematerialization would add more conversions than it removes
// then we don't do it
if (numCvts > 0)
return mlir::failure();
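// e.g. for cvt(add(make_range, splat(%x))): the root conversion starts the
// count at 1; visiting cvt and then add subtracts 2; queueing add accounts
// for one new conversion (+1); and the conversions needed on make_range and
// splat fold away for free, so numCvts ends at 0 and the rewrite proceeds.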
// We convert cvt(op(arg_0, arg_1, ..., arg_n))
// into op(cvt_0(arg_0), cvt_1(arg_1), ..., cvt_n(arg_n))
BlockAndValueMapping mapping;
for (Value argI : op->getOperands()) {
// Compute new argument types
auto oldArgType = argI.getType().dyn_cast<RankedTensorType>();
if (!oldArgType)
continue;
auto newEncoding = invertEncoding(cvt->getResultTypes()[0], op);
auto newArgType = RankedTensorType::get(
oldArgType.getShape(), oldArgType.getElementType(), newEncoding);
// Create new argument
auto cvtI = rewriter.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), newArgType, argI);
cvtI->moveBefore(op);
mapping.map(argI, cvtI);
for (int i = toConvert.size() - 1; i >= 0; i--) {
// unpack information
Value currOperand;
Attribute targetLayout;
std::tie(currOperand, targetLayout) = toConvert[i];
// rematerialize the operand if necessary
Operation *currOperation = currOperand.getDefiningOp();
if (processed.contains(currOperation)) {
currOperation = cloneWithInferType(rewriter, currOperation, mapping);
currOperand = currOperation->getResult(0);
}
if (i == 0)
break;
// compute target type for the layout cast
auto currType = currOperand.getType().cast<RankedTensorType>();
auto newType = RankedTensorType::get(
currType.getShape(), currType.getElementType(), targetLayout);
auto newOperand = rewriter.create<triton::gpu::ConvertLayoutOp>(
currOperand.getLoc(), newType, currOperand);
if (currOperation)
newOperand->moveAfter(currOperation);
mapping.map(currOperand, newOperand);
}
Operation *newOp = rewriter.clone(*op, mapping);
newOp->getResult(0).setType(cvt->getResult(0).getType());
rewriter.replaceOp(cvt, newOp->getResults());
rewriter.replaceOp(cvt, mapping.lookup(cvt->getOperand(0)));
return mlir::success();
}
};
@@ -226,6 +289,7 @@ bool tryLegalizeOp(Operation *op, DenseSet<Value> toPreserve,
std::pair<SmallVector<Value, 4>, scf::ForOp>
tryConvertIterArg(scf::ForOp &forOp, mlir::PatternRewriter &rewriter, size_t i,
Type newType) {
auto newEncoding = newType.cast<RankedTensorType>().getEncoding();
auto ctx = forOp.getContext();
auto isInLoop = [&](Operation *op) { return op->getParentOp() == forOp; };
@@ -243,6 +307,7 @@ tryConvertIterArg(scf::ForOp &forOp, mlir::PatternRewriter &rewriter, size_t i,
BlockAndValueMapping mapping;
for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());
// traverse all ops in the loop
for (Operation &op : forOp.getBody()->without_terminator()) {
// we clone the op
@@ -278,9 +343,9 @@ tryConvertIterArg(scf::ForOp &forOp, mlir::PatternRewriter &rewriter, size_t i,
return {newResults, newForOp};
}
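// Schematically (a sketch), the pattern below turns
//   scf.for ... iter_args(%arg = %init : tensor<..., #blocked>) {
//     %v = triton_gpu.convert_layout %arg : (...) -> tensor<..., #mma>
//     ...
//   }
// into a loop whose iter arg already carries #mma, so the conversion is
// paid once on %init outside the loop instead of on every iteration.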
class MoveArgConvertOutOfLoop : public mlir::RewritePattern {
class MoveConvertOutOfLoop : public mlir::RewritePattern {
public:
MoveArgConvertOutOfLoop(mlir::MLIRContext *context)
MoveConvertOutOfLoop(mlir::MLIRContext *context)
: mlir::RewritePattern(scf::ForOp::getOperationName(), 1, context) {}
mlir::LogicalResult matchAndRewrite(mlir::Operation *op,
@@ -290,27 +355,14 @@ public:
auto isInLoop = [&](Operation *op) { return op->getParentOp() == forOp; };
auto iterArgs = forOp.getRegionIterArgs();
for (auto iterArg : llvm::enumerate(iterArgs)) {
// skip non-tensor types
if (!iterArg.value().getType().isa<RankedTensorType>())
continue;
// check the users of this iter arg inside the loop
for (auto op : iterArg.value().getUsers()) {
auto currOps = mlir::getSlice(op, isInLoop);
auto pred = [&](Operation *op) {
return isa<triton::LoadOp, triton::StoreOp>(op);
};
auto isCvt = [&](Operation *op) {
return isa<triton::gpu::ConvertLayoutOp>(op);
};
auto isYield = [&](Operation *op) { return isa<scf::YieldOp>(op); };
auto opIt = std::find(currOps.begin(), currOps.end(), op);
auto yieldIt = std::find_if(currOps.begin(), currOps.end(), isYield);
auto fwdEndIt = std::find_if(opIt, currOps.end(), pred);
auto bwdBeginIt = std::find_if(currOps.begin(), opIt, pred);
auto fwdCvtIt = std::find_if(opIt, fwdEndIt, isCvt);
auto bwdCvtIt = std::find_if(bwdBeginIt, opIt, isCvt);
if (fwdCvtIt != fwdEndIt) {
if (isa<triton::gpu::ConvertLayoutOp>(op)) {
auto newFor = tryConvertIterArg(forOp, rewriter, iterArg.index(),
(*fwdCvtIt)->getResult(0).getType());
op->getResult(0).getType());
rewriter.replaceOp(forOp, newFor.first);
return success();
}
@@ -324,9 +376,9 @@ public:
//
// -----------------------------------------------------------------------------
class PushConversionToSink : public mlir::RewritePattern {
class RematerializeForward : public mlir::RewritePattern {
public:
PushConversionToSink(mlir::MLIRContext *context)
RematerializeForward(mlir::MLIRContext *context)
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
2, context) {}
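// Dual of RematerializeBackward (a sketch of the intent; the pattern body is
// outside this hunk): push a conversion down toward its users, e.g. rewrite
// op(convert_layout(%x)) as convert_layout'(op(%x)), so that conversions
// accumulate at sinks where they are more likely to cancel out.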
@@ -430,16 +482,17 @@ public:
mlir::RewritePatternSet patterns(context);
patterns.add<SimplifyConversion>(context);
patterns.add<PullConversionToSource>(context);
patterns.add<PushConversionToSink>(context);
patterns.add<MoveArgConvertOutOfLoop>(context);
patterns.add<RematerializeBackward>(context);
patterns.add<RematerializeForward>(context);
patterns.add<MoveConvertOutOfLoop>(context);
patterns.add<BlockedToMMA>(context);
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) {
signalPassFailure();
}
}
};
std::unique_ptr<Pass> mlir::createTritonGPUCombineOpsPass() {
return std::make_unique<TritonGPUCombineOpsPass>();
}
}
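For context, here is a minimal sketch of driving this pass standalone over a module. Only createTritonGPUCombineOpsPass comes from this file; the harness around it is an assumption:

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"

// Run the combine pass on `module` (hypothetical helper, not part of the diff).
static mlir::LogicalResult runCombine(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  // schedules the greedy driver over the five patterns registered above
  pm.addPass(mlir::createTritonGPUCombineOpsPass());
  return pm.run(module);
}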


@@ -45,8 +45,8 @@ func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> {
// CHECK: %1 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, [[target_layout]]>
// CHECK: %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, [[target_layout]]>
// CHECK: %3 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, [[target_layout]]>
// CHECK: %4 = arith.muli %2, %3 : tensor<1024xi32, [[target_layout]]>
// CHECK: %5 = arith.muli %0, %1 : tensor<1024xi32, [[target_layout]]>
// CHECK: %4 = arith.muli %0, %2 : tensor<1024xi32, [[target_layout]]>
// CHECK: %5 = arith.muli %1, %3 : tensor<1024xi32, [[target_layout]]>
// CHECK: %6 = arith.addi %4, %5 : tensor<1024xi32, [[target_layout]]>
// CHECK: return %6 : tensor<1024xi32, [[target_layout]]>
}
@@ -61,34 +61,10 @@ func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> {
// CHECK-LABEL: transpose
func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) {
// CHECK: %cst = arith.constant dense<true> : tensor<64x64xi1, [[row_layout]]>
// CHECK: %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, [[row_layout]]>
// CHECK: %cst_1 = arith.constant dense<true> : tensor<64x64xi1, [[col_layout]]>
// CHECK: %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = [[col_layout]]}>>
// CHECK: %1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = [[row_layout]]}>>
// CHECK: %2 = tt.expand_dims %1 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = [[row_layout]]}>>) -> tensor<64x1xi32, [[row_layout]]>
// CHECK: %3 = tt.splat %arg1 : (i32) -> tensor<64x1xi32, [[row_layout]]>
// CHECK: %4 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, [[row_layout]]>
// CHECK: %5 = arith.muli %2, %3 : tensor<64x1xi32, [[row_layout]]>
// CHECK: %6 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = [[col_layout]]}>>
// CHECK: %7 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = [[row_layout]]}>>
// CHECK: %8 = tt.addptr %4, %5 : tensor<64x1x!tt.ptr<f32>, [[row_layout]]>
// CHECK: %9 = tt.expand_dims %7 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = [[row_layout]]}>>) -> tensor<1x64xi32, [[row_layout]]>
// CHECK: %10 = tt.broadcast %8 : (tensor<64x1x!tt.ptr<f32>, [[row_layout]]>) -> tensor<64x64x!tt.ptr<f32>, [[row_layout]]>
// CHECK: %11 = tt.broadcast %9 : (tensor<1x64xi32, [[row_layout]]>) -> tensor<64x64xi32, [[row_layout]]>
// CHECK: %12 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, [[col_layout]]>
// CHECK: %13 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = [[col_layout]]}>>) -> tensor<64x1xi32, [[col_layout]]>
// CHECK: %14 = tt.expand_dims %6 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = [[col_layout]]}>>) -> tensor<1x64xi32, [[col_layout]]>
// CHECK: %15 = tt.splat %arg3 : (i32) -> tensor<1x64xi32, [[col_layout]]>
// CHECK: %16 = tt.addptr %12, %13 : tensor<64x1x!tt.ptr<f32>, [[col_layout]]>
// CHECK: %17 = arith.muli %14, %15 : tensor<1x64xi32, [[col_layout]]>
// CHECK: %18 = tt.broadcast %16 : (tensor<64x1x!tt.ptr<f32>, [[col_layout]]>) -> tensor<64x64x!tt.ptr<f32>, [[col_layout]]>
// CHECK: %19 = tt.broadcast %17 : (tensor<1x64xi32, [[col_layout]]>) -> tensor<64x64xi32, [[col_layout]]>
// CHECK: %20 = tt.addptr %10, %11 : tensor<64x64x!tt.ptr<f32>, [[row_layout]]>
// CHECK: %21 = tt.load %20, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<64x64xf32, [[row_layout]]>
// CHECK: %22 = tt.addptr %18, %19 : tensor<64x64x!tt.ptr<f32>, [[col_layout]]>
// CHECK: %23 = triton_gpu.convert_layout %21 : (tensor<64x64xf32, [[row_layout]]>) -> tensor<64x64xf32, [[col_layout]]>
// CHECK: tt.store %22, %23, %cst_1 : tensor<64x64xf32, [[col_layout]]>
// CHECK-NOT: triton_gpu.convert_layout
// CHECK: [[loaded_val:%.*]] = tt.load {{.*}}, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<64x64xf32, [[row_layout]]>
// CHECK: [[cvt_val:%.*]] = triton_gpu.convert_layout [[loaded_val]] : (tensor<64x64xf32, [[row_layout]]>) -> tensor<64x64xf32, [[col_layout]]>
// CHECK: tt.store {{.*}}, [[cvt_val]], %cst_1 : tensor<64x64xf32, [[col_layout]]>
// CHECK: return
%cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked1>
%cst_0 = arith.constant dense<true> : tensor<64x64xi1, #blocked1>