[OPTIMIZER] Improved layout conversion simplification algorithm (#97)

This PR both simplifies the layout conversion simplification algorithm, and also improves it to make it work with vectorized element-wise ops. The conversion optimizer still has a lot of room for improvements, and other PRs will address its limitations (ideally via some sort of explicit cost model)
This commit is contained in:
Philippe Tillet
2022-09-02 16:52:44 -07:00
committed by GitHub
parent 3c635449e5
commit d0b4c67b05
2 changed files with 75 additions and 57 deletions

View File

@@ -96,51 +96,7 @@ class PullConversionToSource : public mlir::RewritePattern {
public:
PullConversionToSource(mlir::MLIRContext *context)
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
3, context) {}
void getReachableNotThroughMemOp(
ArrayRef<Value> operands,
SmallVectorImpl<Operation *> &postOrderRet) const {
struct State {
Value value;
unsigned operandIndex;
};
SmallVector<State, 4> worklist;
for (auto operand : operands)
worklist.push_back({operand, 0});
while (!worklist.empty()) {
State &state = worklist.back();
auto *opInst = state.value.getDefiningOp();
// Note: getDefiningOp will return nullptr if the operand is not an
// Operation (i.e., block arguments) which is a terminator for the search.
if (opInst == nullptr) {
worklist.pop_back();
continue;
}
// if we encounter a memory operation, then
// we can assume it's not worth doing any
// rematerialization: layout conversion
// will be cheaper
if (isa<triton::gpu::CopyAsyncOp, triton::LoadOp, triton::StoreOp>(
opInst))
return;
// we don't want to rematerialize conversions
if (isa<triton::gpu::ConvertLayoutOp, scf::YieldOp, scf::ForOp>(opInst))
return;
// visit operands
if (state.operandIndex < opInst->getNumOperands()) {
auto nextOperand = opInst->getOperand(state.operandIndex);
++state.operandIndex;
worklist.push_back({nextOperand, 0});
} else {
// Post-visit: done visiting operand, pop off stack.
// and add to post-order result
worklist.pop_back();
postOrderRet.push_back(opInst);
}
}
}
2, context) {}
Attribute invertEncoding(Type targetType, Operation *op) const {
RankedTensorType targetTensorType = targetType.cast<RankedTensorType>();
@@ -161,18 +117,51 @@ public:
Operation *op = cvt->getOperand(0).getDefiningOp();
if (!op)
return mlir::failure();
if (isa<arith::ConstantOp, triton::MakeRangeOp, triton::SplatOp>(op))
return mlir::failure();
// DFS through all operands
// auto filter = [](Operation *op) {
// return !isa<triton::LoadOp, triton::StoreOp,
// triton::gpu::ConvertLayoutOp>(op);
// };
SmallVector<Operation *, 4> postOrderOps;
getReachableNotThroughMemOp({cvt->getOperand(0)}, postOrderOps);
if (postOrderOps.empty())
auto blacklist = [](Operation *op) {
if (isa<triton::gpu::CopyAsyncOp, triton::LoadOp, triton::StoreOp>(op))
return true;
if (isa<scf::YieldOp, scf::ForOp>(op))
return true;
return false;
};
// find all the ops that the conversion depends on
SetVector<Operation *> depIntoOps;
mlir::getBackwardSlice(cvt, &depIntoOps, [&](Operation *op) {
return !blacklist(op) && op->getResult(0) &&
(op->getResult(0).getParentBlock() ==
cvt->getResult(0).getParentBlock());
});
// find all the ops that depend on something expensive
// to rematerialize and break the dependency
// chain there
SetVector<Operation *> blacklistedOps;
mlir::getBackwardSlice(cvt, &blacklistedOps, blacklist);
for (Operation *op : blacklistedOps) {
SetVector<Operation *> toRemove;
mlir::getBackwardSlice(op, &toRemove);
depIntoOps.set_subtract(toRemove);
}
// if there is nothing that can benefit from moving conversions
// in the remaining op, we don't do anything
auto it = llvm::find_if(depIntoOps, [&](Operation *op) {
if (isa<triton::gpu::ConvertLayoutOp>(op)) {
// conversions in for loops interfere with the
// push-to-sink pass. Need a better cost model if how many conversions
// we can actually remove by moving them to the beginning of the block
auto forOp = dyn_cast<scf::ForOp>(cvt->getParentOp());
if (!forOp &&
(cvt->getResult(0).getType() == op->getOperand(0).getType()))
return true;
}
if (isa<arith::ConstantOp, triton::MakeRangeOp, triton::SplatOp>(op))
return true;
return false;
});
if (it == depIntoOps.end()) {
return mlir::failure();
}
// We convert cvt(op(arg_0, arg_1, ..., arg_n))
// into op(cvt_0(arg_0), cvt_1(arg_1), ..., cvt_n(arg_n))
@@ -229,8 +218,6 @@ bool tryLegalizeOp(Operation *op, DenseSet<Value> toPreserve,
op->getResult(0).setType(op->getOperand(0).getType());
return true;
}
// i
return false;
}