[OPTIMIZER] Improved layout conversion simplification algorithm (#97)
This PR simplifies the layout conversion simplification algorithm and improves it to handle vectorized element-wise ops. The conversion optimizer still has a lot of room for improvement; follow-up PRs will address its remaining limitations (ideally via some sort of explicit cost model).
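For context, the pattern touched here pulls a layout conversion above the element-wise op that produces its operand, i.e. it rewrites cvt(op(arg_0, ..., arg_n)) into op(cvt_0(arg_0), ..., cvt_n(arg_n)) so the conversions can be folded into cheaper producers (see the comment in the second hunk below). The following is only a minimal sketch of that rewrite step using generic MLIR rewriter APIs; the helper name and structure are illustrative, not the PR's actual code, and it assumes the defining op is a single-result element-wise op on ranked tensors.

// Sketch only: pull a convert_layout above a single-result element-wise op.
// Assumes the usual includes for mlir/IR/PatternMatch.h and the TritonGPU dialect.
static mlir::LogicalResult
pullConversionAboveElementwise(mlir::Operation *cvt,
                               mlir::PatternRewriter &rewriter) {
  using namespace mlir;
  Operation *def = cvt->getOperand(0).getDefiningOp();
  if (!def || def->getNumResults() != 1)
    return failure();
  auto dstType = cvt->getResult(0).getType().cast<RankedTensorType>();
  Attribute dstEncoding = dstType.getEncoding();
  // Clone the element-wise op and route each of its operands through its own
  // conversion into the destination encoding.
  Operation *newOp = rewriter.clone(*def);
  for (unsigned i = 0; i < newOp->getNumOperands(); ++i) {
    Value operand = newOp->getOperand(i);
    auto oldTy = operand.getType().cast<RankedTensorType>();
    auto newTy = RankedTensorType::get(oldTy.getShape(),
                                       oldTy.getElementType(), dstEncoding);
    Value newCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
        cvt->getLoc(), newTy, operand);
    newOp->setOperand(i, newCvt);
  }
  // The cloned op now produces a tensor in the destination encoding, so the
  // original conversion can be replaced by its result.
  newOp->getResult(0).setType(dstType);
  rewriter.replaceOp(cvt, newOp->getResults());
  return success();
}

The actual PR first uses mlir::getBackwardSlice to decide whether such a rewrite is worthwhile before applying it, as shown in the hunks below.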
@@ -96,51 +96,7 @@ class PullConversionToSource : public mlir::RewritePattern {
public:
  PullConversionToSource(mlir::MLIRContext *context)
      : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
                             3, context) {}

  void getReachableNotThroughMemOp(
      ArrayRef<Value> operands,
      SmallVectorImpl<Operation *> &postOrderRet) const {
    struct State {
      Value value;
      unsigned operandIndex;
    };
    SmallVector<State, 4> worklist;
    for (auto operand : operands)
      worklist.push_back({operand, 0});

    while (!worklist.empty()) {
      State &state = worklist.back();
      auto *opInst = state.value.getDefiningOp();
      // Note: getDefiningOp returns nullptr if the operand is not defined by
      // an Operation (i.e., it is a block argument), which terminates the
      // search.
      if (opInst == nullptr) {
        worklist.pop_back();
        continue;
      }
      // If we encounter a memory operation, we can assume rematerialization
      // is not worth it: the layout conversion will be cheaper.
      if (isa<triton::gpu::CopyAsyncOp, triton::LoadOp, triton::StoreOp>(
              opInst))
        return;
      // We don't want to rematerialize conversions.
      if (isa<triton::gpu::ConvertLayoutOp, scf::YieldOp, scf::ForOp>(opInst))
        return;
      // Visit operands.
      if (state.operandIndex < opInst->getNumOperands()) {
        auto nextOperand = opInst->getOperand(state.operandIndex);
        ++state.operandIndex;
        worklist.push_back({nextOperand, 0});
      } else {
        // Post-visit: done visiting operands, pop off the stack
        // and add to the post-order result.
        worklist.pop_back();
        postOrderRet.push_back(opInst);
      }
    }
  }
                             2, context) {}

  Attribute invertEncoding(Type targetType, Operation *op) const {
    RankedTensorType targetTensorType = targetType.cast<RankedTensorType>();
@@ -161,18 +117,51 @@ public:
    Operation *op = cvt->getOperand(0).getDefiningOp();
    if (!op)
      return mlir::failure();
    if (isa<arith::ConstantOp, triton::MakeRangeOp, triton::SplatOp>(op))
      return mlir::failure();
    // DFS through all operands
    // auto filter = [](Operation *op) {
    //   return !isa<triton::LoadOp, triton::StoreOp,
    //               triton::gpu::ConvertLayoutOp>(op);
    // };

    SmallVector<Operation *, 4> postOrderOps;
    getReachableNotThroughMemOp({cvt->getOperand(0)}, postOrderOps);
    if (postOrderOps.empty())
    auto blacklist = [](Operation *op) {
      if (isa<triton::gpu::CopyAsyncOp, triton::LoadOp, triton::StoreOp>(op))
        return true;
      if (isa<scf::YieldOp, scf::ForOp>(op))
        return true;
      return false;
    };

    // Find all the ops that the conversion depends on.
    SetVector<Operation *> depIntoOps;
    mlir::getBackwardSlice(cvt, &depIntoOps, [&](Operation *op) {
      return !blacklist(op) && op->getResult(0) &&
             (op->getResult(0).getParentBlock() ==
              cvt->getResult(0).getParentBlock());
    });
    // Find all the ops that depend on something expensive to rematerialize,
    // and break the dependency chain there.
    SetVector<Operation *> blacklistedOps;
    mlir::getBackwardSlice(cvt, &blacklistedOps, blacklist);
    for (Operation *op : blacklistedOps) {
      SetVector<Operation *> toRemove;
      mlir::getBackwardSlice(op, &toRemove);
      depIntoOps.set_subtract(toRemove);
    }
    // If nothing in the remaining ops can benefit from moving conversions,
    // we don't do anything.
    auto it = llvm::find_if(depIntoOps, [&](Operation *op) {
      if (isa<triton::gpu::ConvertLayoutOp>(op)) {
        // Conversions in for loops interfere with the push-to-sink pass.
        // We need a better cost model of how many conversions we can
        // actually remove by moving them to the beginning of the block.
        auto forOp = dyn_cast<scf::ForOp>(cvt->getParentOp());
        if (!forOp &&
            (cvt->getResult(0).getType() == op->getOperand(0).getType()))
          return true;
      }
      if (isa<arith::ConstantOp, triton::MakeRangeOp, triton::SplatOp>(op))
        return true;
      return false;
    });
    if (it == depIntoOps.end()) {
      return mlir::failure();
    }

    // We convert cvt(op(arg_0, arg_1, ..., arg_n))
    // into op(cvt_0(arg_0), cvt_1(arg_1), ..., cvt_n(arg_n))
@@ -229,8 +218,6 @@ bool tryLegalizeOp(Operation *op, DenseSet<Value> toPreserve,
    op->getResult(0).setType(op->getOperand(0).getType());
    return true;
  }

  // i
  return false;
}