[OPTIMIZER] Improved layout conversion simplification algorithm (#97)

This PR both simplifies the layout conversion simplification algorithm, and also improves it to make it work with vectorized element-wise ops. The conversion optimizer still has a lot of room for improvements, and other PRs will address its limitations (ideally via some sort of explicit cost model)
This commit is contained in:
Philippe Tillet
2022-09-02 16:52:44 -07:00
committed by GitHub
parent 3c635449e5
commit d0b4c67b05
2 changed files with 75 additions and 57 deletions

View File

@@ -96,51 +96,7 @@ class PullConversionToSource : public mlir::RewritePattern {
public: public:
PullConversionToSource(mlir::MLIRContext *context) PullConversionToSource(mlir::MLIRContext *context)
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
3, context) {} 2, context) {}
void getReachableNotThroughMemOp(
ArrayRef<Value> operands,
SmallVectorImpl<Operation *> &postOrderRet) const {
struct State {
Value value;
unsigned operandIndex;
};
SmallVector<State, 4> worklist;
for (auto operand : operands)
worklist.push_back({operand, 0});
while (!worklist.empty()) {
State &state = worklist.back();
auto *opInst = state.value.getDefiningOp();
// Note: getDefiningOp will return nullptr if the operand is not an
// Operation (i.e., block arguments) which is a terminator for the search.
if (opInst == nullptr) {
worklist.pop_back();
continue;
}
// if we encounter a memory operation, then
// we can assume it's not worth doing any
// rematerialization: layout conversion
// will be cheaper
if (isa<triton::gpu::CopyAsyncOp, triton::LoadOp, triton::StoreOp>(
opInst))
return;
// we don't want to rematerialize conversions
if (isa<triton::gpu::ConvertLayoutOp, scf::YieldOp, scf::ForOp>(opInst))
return;
// visit operands
if (state.operandIndex < opInst->getNumOperands()) {
auto nextOperand = opInst->getOperand(state.operandIndex);
++state.operandIndex;
worklist.push_back({nextOperand, 0});
} else {
// Post-visit: done visiting operand, pop off stack.
// and add to post-order result
worklist.pop_back();
postOrderRet.push_back(opInst);
}
}
}
Attribute invertEncoding(Type targetType, Operation *op) const { Attribute invertEncoding(Type targetType, Operation *op) const {
RankedTensorType targetTensorType = targetType.cast<RankedTensorType>(); RankedTensorType targetTensorType = targetType.cast<RankedTensorType>();
@@ -161,18 +117,51 @@ public:
Operation *op = cvt->getOperand(0).getDefiningOp(); Operation *op = cvt->getOperand(0).getDefiningOp();
if (!op) if (!op)
return mlir::failure(); return mlir::failure();
if (isa<arith::ConstantOp, triton::MakeRangeOp, triton::SplatOp>(op))
return mlir::failure();
// DFS through all operands
// auto filter = [](Operation *op) {
// return !isa<triton::LoadOp, triton::StoreOp,
// triton::gpu::ConvertLayoutOp>(op);
// };
SmallVector<Operation *, 4> postOrderOps; auto blacklist = [](Operation *op) {
getReachableNotThroughMemOp({cvt->getOperand(0)}, postOrderOps); if (isa<triton::gpu::CopyAsyncOp, triton::LoadOp, triton::StoreOp>(op))
if (postOrderOps.empty()) return true;
if (isa<scf::YieldOp, scf::ForOp>(op))
return true;
return false;
};
// find all the ops that the conversion depends on
SetVector<Operation *> depIntoOps;
mlir::getBackwardSlice(cvt, &depIntoOps, [&](Operation *op) {
return !blacklist(op) && op->getResult(0) &&
(op->getResult(0).getParentBlock() ==
cvt->getResult(0).getParentBlock());
});
// find all the ops that depend on something expensive
// to rematerialize and break the dependency
// chain there
SetVector<Operation *> blacklistedOps;
mlir::getBackwardSlice(cvt, &blacklistedOps, blacklist);
for (Operation *op : blacklistedOps) {
SetVector<Operation *> toRemove;
mlir::getBackwardSlice(op, &toRemove);
depIntoOps.set_subtract(toRemove);
}
// if there is nothing that can benefit from moving conversions
// in the remaining op, we don't do anything
auto it = llvm::find_if(depIntoOps, [&](Operation *op) {
if (isa<triton::gpu::ConvertLayoutOp>(op)) {
// conversions in for loops interfere with the
// push-to-sink pass. Need a better cost model if how many conversions
// we can actually remove by moving them to the beginning of the block
auto forOp = dyn_cast<scf::ForOp>(cvt->getParentOp());
if (!forOp &&
(cvt->getResult(0).getType() == op->getOperand(0).getType()))
return true;
}
if (isa<arith::ConstantOp, triton::MakeRangeOp, triton::SplatOp>(op))
return true;
return false;
});
if (it == depIntoOps.end()) {
return mlir::failure(); return mlir::failure();
}
// We convert cvt(op(arg_0, arg_1, ..., arg_n)) // We convert cvt(op(arg_0, arg_1, ..., arg_n))
// into op(cvt_0(arg_0), cvt_1(arg_1), ..., cvt_n(arg_n)) // into op(cvt_0(arg_0), cvt_1(arg_1), ..., cvt_n(arg_n))
@@ -229,8 +218,6 @@ bool tryLegalizeOp(Operation *op, DenseSet<Value> toPreserve,
op->getResult(0).setType(op->getOperand(0).getType()); op->getResult(0).setType(op->getOperand(0).getType());
return true; return true;
} }
// i
return false; return false;
} }

View File

@@ -173,3 +173,34 @@ func @loop(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: !tt.ptr<f32>, %arg3: i32, %ar
tt.store %20, %21, %22 : tensor<64x64xf32, #blocked1> tt.store %20, %21, %22 : tensor<64x64xf32, #blocked1>
return return
} }
// CHECK-LABEL: vecadd
func @vecadd(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32) {
// CHECK-NOT: triton_gpu.convert_layout
%c256_i32 = arith.constant 256 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = arith.muli %0, %c256_i32 : i32
%2 = tt.splat %1 : (i32) -> tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%3 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%4 = tt.splat %1 : (i32) -> tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%5 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%6 = tt.splat %1 : (i32) -> tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%7 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%8 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%9 = arith.addi %6, %7 : tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%10 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%11 = arith.addi %4, %5 : tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%12 = tt.getelementptr %8, %9 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%13 = tt.load %12 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%14 = triton_gpu.convert_layout %13 : (tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>) -> tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>>
%15 = tt.getelementptr %10, %11 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%16 = tt.load %15 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%17 = triton_gpu.convert_layout %16 : (tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>) -> tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>>
%18 = arith.addf %14, %17 : tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>>
%19 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%20 = arith.addi %2, %3 : tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%21 = tt.getelementptr %19, %20 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%22 = triton_gpu.convert_layout %18 : (tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>>) -> tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
tt.store %21, %22 : tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
return
}