[OPTIMIZER] Improved layout conversion simplification algorithm (#97)
This PR simplifies the layout conversion simplification algorithm and improves it to handle vectorized element-wise ops. The conversion optimizer still has a lot of room for improvement, and follow-up PRs will address its limitations (ideally via some sort of explicit cost model).
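At a high level, the new algorithm replaces a hand-rolled post-order DFS with MLIR's backward-slice utility: mlir::getBackwardSlice collects the operations that transitively define an op's operands, and its optional filter bounds the traversal. A condensed C++ sketch of the new slicing logic (names match the diff below; the parent-block check is omitted for brevity):

    // Ops that are too expensive to rematerialize terminate the search.
    auto blacklist = [](Operation *op) {
      return isa<triton::gpu::CopyAsyncOp, triton::LoadOp, triton::StoreOp,
                 scf::YieldOp, scf::ForOp>(op);
    };
    // 1) Everything the conversion transitively depends on ...
    SetVector<Operation *> depIntoOps;
    mlir::getBackwardSlice(cvt, &depIntoOps,
                           [&](Operation *op) { return !blacklist(op); });
    // 2) ... minus anything that also depends on a blacklisted op.
    SetVector<Operation *> blacklistedOps;
    mlir::getBackwardSlice(cvt, &blacklistedOps, blacklist);
    for (Operation *op : blacklistedOps) {
      SetVector<Operation *> toRemove;
      mlir::getBackwardSlice(op, &toRemove);
      depIntoOps.set_subtract(toRemove);
    }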
@@ -96,51 +96,7 @@ class PullConversionToSource : public mlir::RewritePattern {
 public:
   PullConversionToSource(mlir::MLIRContext *context)
       : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
-                             3, context) {}
-
-  void getReachableNotThroughMemOp(
-      ArrayRef<Value> operands,
-      SmallVectorImpl<Operation *> &postOrderRet) const {
-    struct State {
-      Value value;
-      unsigned operandIndex;
-    };
-    SmallVector<State, 4> worklist;
-    for (auto operand : operands)
-      worklist.push_back({operand, 0});
-
-    while (!worklist.empty()) {
-      State &state = worklist.back();
-      auto *opInst = state.value.getDefiningOp();
-      // Note: getDefiningOp will return nullptr if the operand is not an
-      // Operation (i.e., block arguments) which is a terminator for the search.
-      if (opInst == nullptr) {
-        worklist.pop_back();
-        continue;
-      }
-      // if we encounter a memory operation, then
-      // we can assume it's not worth doing any
-      // rematerialization: layout conversion
-      // will be cheaper
-      if (isa<triton::gpu::CopyAsyncOp, triton::LoadOp, triton::StoreOp>(
-              opInst))
-        return;
-      // we don't want to rematerialize conversions
-      if (isa<triton::gpu::ConvertLayoutOp, scf::YieldOp, scf::ForOp>(opInst))
-        return;
-      // visit operands
-      if (state.operandIndex < opInst->getNumOperands()) {
-        auto nextOperand = opInst->getOperand(state.operandIndex);
-        ++state.operandIndex;
-        worklist.push_back({nextOperand, 0});
-      } else {
-        // Post-visit: done visiting operands, pop off stack
-        // and add to post-order result
-        worklist.pop_back();
-        postOrderRet.push_back(opInst);
-      }
-    }
-  }
+                             2, context) {}
 
   Attribute invertEncoding(Type targetType, Operation *op) const {
     RankedTensorType targetTensorType = targetType.cast<RankedTensorType>();
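For context on the constructor change above: the second argument to mlir::RewritePattern is the PatternBenefit, which the greedy rewrite driver uses to order patterns that match the same operation (higher benefit is tried first), so lowering it from 3 to 2 demotes this pattern relative to any benefit-3 siblings.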
@@ -161,18 +117,51 @@ public:
     Operation *op = cvt->getOperand(0).getDefiningOp();
     if (!op)
       return mlir::failure();
-    if (isa<arith::ConstantOp, triton::MakeRangeOp, triton::SplatOp>(op))
-      return mlir::failure();
-    // DFS through all operands
-    // auto filter = [](Operation *op) {
-    //   return !isa<triton::LoadOp, triton::StoreOp,
-    //               triton::gpu::ConvertLayoutOp>(op);
-    // };
-
-    SmallVector<Operation *, 4> postOrderOps;
-    getReachableNotThroughMemOp({cvt->getOperand(0)}, postOrderOps);
-    if (postOrderOps.empty())
+    auto blacklist = [](Operation *op) {
+      if (isa<triton::gpu::CopyAsyncOp, triton::LoadOp, triton::StoreOp>(op))
+        return true;
+      if (isa<scf::YieldOp, scf::ForOp>(op))
+        return true;
+      return false;
+    };
+
+    // find all the ops that the conversion depends on
+    SetVector<Operation *> depIntoOps;
+    mlir::getBackwardSlice(cvt, &depIntoOps, [&](Operation *op) {
+      return !blacklist(op) && op->getResult(0) &&
+             (op->getResult(0).getParentBlock() ==
+              cvt->getResult(0).getParentBlock());
+    });
+    // find all the ops that depend on something expensive
+    // to rematerialize and break the dependency
+    // chain there
+    SetVector<Operation *> blacklistedOps;
+    mlir::getBackwardSlice(cvt, &blacklistedOps, blacklist);
+    for (Operation *op : blacklistedOps) {
+      SetVector<Operation *> toRemove;
+      mlir::getBackwardSlice(op, &toRemove);
+      depIntoOps.set_subtract(toRemove);
+    }
+    // if there is nothing that can benefit from moving conversions
+    // in the remaining ops, we don't do anything
+    auto it = llvm::find_if(depIntoOps, [&](Operation *op) {
+      if (isa<triton::gpu::ConvertLayoutOp>(op)) {
+        // conversions in for loops interfere with the
+        // push-to-sink pass. Need a better cost model of how many conversions
+        // we can actually remove by moving them to the beginning of the block
+        auto forOp = dyn_cast<scf::ForOp>(cvt->getParentOp());
+        if (!forOp &&
+            (cvt->getResult(0).getType() == op->getOperand(0).getType()))
+          return true;
+      }
+      if (isa<arith::ConstantOp, triton::MakeRangeOp, triton::SplatOp>(op))
+        return true;
+      return false;
+    });
+    if (it == depIntoOps.end()) {
       return mlir::failure();
+    }
 
     // We convert cvt(op(arg_0, arg_1, ..., arg_n))
     // into op(cvt_0(arg_0), cvt_1(arg_1), ..., cvt_n(arg_n))
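The rewrite the comment above describes, cvt(op(arg_0, ..., arg_n)) into op(cvt_0(arg_0), ..., cvt_n(arg_n)), can be built with the standard MLIR rewriter API. A minimal sketch of that step; rematerialize is a hypothetical helper name, not the PR's exact code, and the usual pass includes plus `using namespace mlir;` are assumed:

    // Sketch: clone the element-wise op, wrap each tensor operand in a
    // convert_layout to the target encoding, and retype the result so the
    // clone computes directly in the new layout.
    Operation *rematerialize(PatternRewriter &rewriter, Operation *op,
                             Attribute newEncoding) {
      Operation *newOp = rewriter.clone(*op);
      for (unsigned i = 0, e = newOp->getNumOperands(); i < e; ++i) {
        Value operand = newOp->getOperand(i);
        auto tensorType = operand.getType().dyn_cast<RankedTensorType>();
        if (!tensorType)
          continue;  // non-tensor operands need no conversion
        auto cvtType = RankedTensorType::get(
            tensorType.getShape(), tensorType.getElementType(), newEncoding);
        Value cvted = rewriter.create<triton::gpu::ConvertLayoutOp>(
            op->getLoc(), cvtType, operand);
        newOp->setOperand(i, cvted);
      }
      // The cloned op now produces its result directly in the new encoding.
      auto retType = newOp->getResult(0).getType().cast<RankedTensorType>();
      newOp->getResult(0).setType(RankedTensorType::get(
          retType.getShape(), retType.getElementType(), newEncoding));
      return newOp;
    }

Once every operand is converted up front, the element-wise op itself computes in the target layout, which is what lets the original conversion disappear.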
@@ -229,8 +218,6 @@ bool tryLegalizeOp(Operation *op, DenseSet<Value> toPreserve,
     op->getResult(0).setType(op->getOperand(0).getType());
     return true;
   }
-
-  // i
   return false;
 }
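For context, tryLegalizeOp (shown partially above) tries to rewrite an op in place so that its result type matches its already-converted operand type, returning false when it cannot. A hypothetical caller, with every identifier below assumed for illustration rather than taken from this codebase:

    // Hypothetical usage sketch: ops that cannot be legalized in place would
    // need to keep an explicit layout conversion instead. opsToRematerialize,
    // valuesToPreserve, rewriter (the truncated third parameter is assumed to
    // be a PatternRewriter), and insertExplicitConversion are all assumed.
    for (Operation *op : opsToRematerialize)
      if (!tryLegalizeOp(op, valuesToPreserve, rewriter))
        insertExplicitConversion(op, rewriter);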
@@ -173,3 +173,34 @@ func @loop(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: !tt.ptr<f32>, %arg3: i32, %ar
   tt.store %20, %21, %22 : tensor<64x64xf32, #blocked1>
   return
 }
+
+// CHECK-LABEL: vecadd
+func @vecadd(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32) {
+  // CHECK-NOT: triton_gpu.convert_layout
+  %c256_i32 = arith.constant 256 : i32
+  %0 = tt.get_program_id {axis = 0 : i32} : i32
+  %1 = arith.muli %0, %c256_i32 : i32
+  %2 = tt.splat %1 : (i32) -> tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
+  %3 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
+  %4 = tt.splat %1 : (i32) -> tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
+  %5 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
+  %6 = tt.splat %1 : (i32) -> tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
+  %7 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
+  %8 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
+  %9 = arith.addi %6, %7 : tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
+  %10 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
+  %11 = arith.addi %4, %5 : tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
+  %12 = tt.getelementptr %8, %9 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
+  %13 = tt.load %12 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
+  %14 = triton_gpu.convert_layout %13 : (tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>) -> tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>>
+  %15 = tt.getelementptr %10, %11 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
+  %16 = tt.load %15 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
+  %17 = triton_gpu.convert_layout %16 : (tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>) -> tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>>
+  %18 = arith.addf %14, %17 : tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>>
+  %19 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
+  %20 = arith.addi %2, %3 : tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
+  %21 = tt.getelementptr %19, %20 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
+  %22 = triton_gpu.convert_layout %18 : (tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>>) -> tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
+  tt.store %21, %22 : tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
+  return
+}
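The new vecadd test exercises exactly the vectorized element-wise case called out in the description: both tt.load results are produced in a blocked layout with sizePerThread = [4], converted to a sizePerThread = [1] layout for the arith.addf, and converted back for the tt.store. The // CHECK-NOT: triton_gpu.convert_layout line asserts that after the pass runs, all three conversions are gone, i.e. the addf has been rematerialized directly in the loads' vectorized layout.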