[OPTIMIZER] Improved layout conversion simplification algorithm (#97)
This PR both simplifies the layout conversion simplification algorithm, and also improves it to make it work with vectorized element-wise ops. The conversion optimizer still has a lot of room for improvements, and other PRs will address its limitations (ideally via some sort of explicit cost model)
This commit is contained in:
@@ -96,51 +96,7 @@ class PullConversionToSource : public mlir::RewritePattern {
|
||||
public:
|
||||
PullConversionToSource(mlir::MLIRContext *context)
|
||||
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
|
||||
3, context) {}
|
||||
|
||||
void getReachableNotThroughMemOp(
|
||||
ArrayRef<Value> operands,
|
||||
SmallVectorImpl<Operation *> &postOrderRet) const {
|
||||
struct State {
|
||||
Value value;
|
||||
unsigned operandIndex;
|
||||
};
|
||||
SmallVector<State, 4> worklist;
|
||||
for (auto operand : operands)
|
||||
worklist.push_back({operand, 0});
|
||||
|
||||
while (!worklist.empty()) {
|
||||
State &state = worklist.back();
|
||||
auto *opInst = state.value.getDefiningOp();
|
||||
// Note: getDefiningOp will return nullptr if the operand is not an
|
||||
// Operation (i.e., block arguments) which is a terminator for the search.
|
||||
if (opInst == nullptr) {
|
||||
worklist.pop_back();
|
||||
continue;
|
||||
}
|
||||
// if we encounter a memory operation, then
|
||||
// we can assume it's not worth doing any
|
||||
// rematerialization: layout conversion
|
||||
// will be cheaper
|
||||
if (isa<triton::gpu::CopyAsyncOp, triton::LoadOp, triton::StoreOp>(
|
||||
opInst))
|
||||
return;
|
||||
// we don't want to rematerialize conversions
|
||||
if (isa<triton::gpu::ConvertLayoutOp, scf::YieldOp, scf::ForOp>(opInst))
|
||||
return;
|
||||
// visit operands
|
||||
if (state.operandIndex < opInst->getNumOperands()) {
|
||||
auto nextOperand = opInst->getOperand(state.operandIndex);
|
||||
++state.operandIndex;
|
||||
worklist.push_back({nextOperand, 0});
|
||||
} else {
|
||||
// Post-visit: done visiting operand, pop off stack.
|
||||
// and add to post-order result
|
||||
worklist.pop_back();
|
||||
postOrderRet.push_back(opInst);
|
||||
}
|
||||
}
|
||||
}
|
||||
2, context) {}
|
||||
|
||||
Attribute invertEncoding(Type targetType, Operation *op) const {
|
||||
RankedTensorType targetTensorType = targetType.cast<RankedTensorType>();
|
||||
@@ -161,18 +117,51 @@ public:
|
||||
Operation *op = cvt->getOperand(0).getDefiningOp();
|
||||
if (!op)
|
||||
return mlir::failure();
|
||||
if (isa<arith::ConstantOp, triton::MakeRangeOp, triton::SplatOp>(op))
|
||||
return mlir::failure();
|
||||
// DFS through all operands
|
||||
// auto filter = [](Operation *op) {
|
||||
// return !isa<triton::LoadOp, triton::StoreOp,
|
||||
// triton::gpu::ConvertLayoutOp>(op);
|
||||
// };
|
||||
|
||||
SmallVector<Operation *, 4> postOrderOps;
|
||||
getReachableNotThroughMemOp({cvt->getOperand(0)}, postOrderOps);
|
||||
if (postOrderOps.empty())
|
||||
auto blacklist = [](Operation *op) {
|
||||
if (isa<triton::gpu::CopyAsyncOp, triton::LoadOp, triton::StoreOp>(op))
|
||||
return true;
|
||||
if (isa<scf::YieldOp, scf::ForOp>(op))
|
||||
return true;
|
||||
return false;
|
||||
};
|
||||
|
||||
// find all the ops that the conversion depends on
|
||||
SetVector<Operation *> depIntoOps;
|
||||
mlir::getBackwardSlice(cvt, &depIntoOps, [&](Operation *op) {
|
||||
return !blacklist(op) && op->getResult(0) &&
|
||||
(op->getResult(0).getParentBlock() ==
|
||||
cvt->getResult(0).getParentBlock());
|
||||
});
|
||||
// find all the ops that depend on something expensive
|
||||
// to rematerialize and break the dependency
|
||||
// chain there
|
||||
SetVector<Operation *> blacklistedOps;
|
||||
mlir::getBackwardSlice(cvt, &blacklistedOps, blacklist);
|
||||
for (Operation *op : blacklistedOps) {
|
||||
SetVector<Operation *> toRemove;
|
||||
mlir::getBackwardSlice(op, &toRemove);
|
||||
depIntoOps.set_subtract(toRemove);
|
||||
}
|
||||
// if there is nothing that can benefit from moving conversions
|
||||
// in the remaining op, we don't do anything
|
||||
auto it = llvm::find_if(depIntoOps, [&](Operation *op) {
|
||||
if (isa<triton::gpu::ConvertLayoutOp>(op)) {
|
||||
// conversions in for loops interfere with the
|
||||
// push-to-sink pass. Need a better cost model if how many conversions
|
||||
// we can actually remove by moving them to the beginning of the block
|
||||
auto forOp = dyn_cast<scf::ForOp>(cvt->getParentOp());
|
||||
if (!forOp &&
|
||||
(cvt->getResult(0).getType() == op->getOperand(0).getType()))
|
||||
return true;
|
||||
}
|
||||
if (isa<arith::ConstantOp, triton::MakeRangeOp, triton::SplatOp>(op))
|
||||
return true;
|
||||
return false;
|
||||
});
|
||||
if (it == depIntoOps.end()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
// We convert cvt(op(arg_0, arg_1, ..., arg_n))
|
||||
// into op(cvt_0(arg_0), cvt_1(arg_1), ..., cvt_n(arg_n))
|
||||
@@ -229,8 +218,6 @@ bool tryLegalizeOp(Operation *op, DenseSet<Value> toPreserve,
|
||||
op->getResult(0).setType(op->getOperand(0).getType());
|
||||
return true;
|
||||
}
|
||||
|
||||
// i
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@@ -173,3 +173,34 @@ func @loop(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: !tt.ptr<f32>, %arg3: i32, %ar
|
||||
tt.store %20, %21, %22 : tensor<64x64xf32, #blocked1>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: vecadd
|
||||
func @vecadd(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32) {
|
||||
// CHECK-NOT: triton_gpu.convert_layout
|
||||
%c256_i32 = arith.constant 256 : i32
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%1 = arith.muli %0, %c256_i32 : i32
|
||||
%2 = tt.splat %1 : (i32) -> tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
|
||||
%3 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
|
||||
%4 = tt.splat %1 : (i32) -> tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
|
||||
%5 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
|
||||
%6 = tt.splat %1 : (i32) -> tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
|
||||
%7 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
|
||||
%8 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
|
||||
%9 = arith.addi %6, %7 : tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
|
||||
%10 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
|
||||
%11 = arith.addi %4, %5 : tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
|
||||
%12 = tt.getelementptr %8, %9 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
|
||||
%13 = tt.load %12 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
|
||||
%14 = triton_gpu.convert_layout %13 : (tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>) -> tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>>
|
||||
%15 = tt.getelementptr %10, %11 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
|
||||
%16 = tt.load %15 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
|
||||
%17 = triton_gpu.convert_layout %16 : (tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>) -> tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>>
|
||||
%18 = arith.addf %14, %17 : tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>>
|
||||
%19 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
|
||||
%20 = arith.addi %2, %3 : tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
|
||||
%21 = tt.getelementptr %19, %20 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
|
||||
%22 = triton_gpu.convert_layout %18 : (tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>>) -> tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
|
||||
tt.store %21, %22 : tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
|
||||
return
|
||||
}
|
Reference in New Issue
Block a user