From d0b4c67b050423b74c50d651c3523351003c87e0 Mon Sep 17 00:00:00 2001
From: Philippe Tillet
Date: Fri, 2 Sep 2022 16:52:44 -0700
Subject: [PATCH] [OPTIMIZER] Improved layout conversion simplification
 algorithm (#97)

This PR both simplifies the layout conversion simplification algorithm and
improves it to work with vectorized element-wise ops. The conversion
optimizer still has a lot of room for improvement, and other PRs will
address its limitations (ideally via some sort of explicit cost model).
---
 lib/Dialect/TritonGPU/Transforms/Combine.cpp | 101 ++++++++-----------
 test/TritonGPU/combine.mlir                  |  31 ++++++
 2 files changed, 75 insertions(+), 57 deletions(-)

diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.cpp b/lib/Dialect/TritonGPU/Transforms/Combine.cpp
index 13061e65e..32ff7e89e 100644
--- a/lib/Dialect/TritonGPU/Transforms/Combine.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Combine.cpp
@@ -96,51 +96,7 @@ class PullConversionToSource : public mlir::RewritePattern {
 public:
   PullConversionToSource(mlir::MLIRContext *context)
       : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
-                             3, context) {}
-
-  void getReachableNotThroughMemOp(
-      ArrayRef<Value> operands,
-      SmallVectorImpl<Operation *> &postOrderRet) const {
-    struct State {
-      Value value;
-      unsigned operandIndex;
-    };
-    SmallVector<State, 4> worklist;
-    for (auto operand : operands)
-      worklist.push_back({operand, 0});
-
-    while (!worklist.empty()) {
-      State &state = worklist.back();
-      auto *opInst = state.value.getDefiningOp();
-      // Note: getDefiningOp will return nullptr if the operand is not an
-      // Operation (i.e., a block argument), which terminates the search.
-      if (opInst == nullptr) {
-        worklist.pop_back();
-        continue;
-      }
-      // If we encounter a memory operation, then
-      // we can assume it's not worth doing any
-      // rematerialization: the layout conversion
-      // will be cheaper.
-      if (isa<triton::gpu::CopyAsyncOp, triton::LoadOp, triton::StoreOp>(
-              opInst))
-        return;
-      // We don't want to rematerialize conversions.
-      if (isa<triton::gpu::ConvertLayoutOp>(opInst))
-        return;
-      // Visit operands.
-      if (state.operandIndex < opInst->getNumOperands()) {
-        auto nextOperand = opInst->getOperand(state.operandIndex);
-        ++state.operandIndex;
-        worklist.push_back({nextOperand, 0});
-      } else {
-        // Post-visit: done visiting operands, pop off the stack
-        // and add to the post-order result.
-        worklist.pop_back();
-        postOrderRet.push_back(opInst);
-      }
-    }
-  }
+                             2, context) {}
 
   Attribute invertEncoding(Type targetType, Operation *op) const {
     RankedTensorType targetTensorType = targetType.cast<RankedTensorType>();
@@ -161,18 +117,51 @@ public:
     Operation *op = cvt->getOperand(0).getDefiningOp();
     if (!op)
       return mlir::failure();
-    if (isa<triton::gpu::ConvertLayoutOp>(op))
-      return mlir::failure();
-
-    // DFS through all operands
-    // auto filter = [](Operation *op) {
-    //   return !isa<triton::gpu::ConvertLayoutOp>(op);
-    // };
-
-    SmallVector<Operation *> postOrderOps;
-    getReachableNotThroughMemOp({cvt->getOperand(0)}, postOrderOps);
-    if (postOrderOps.empty())
+    auto blacklist = [](Operation *op) {
+      if (isa<triton::gpu::CopyAsyncOp, triton::LoadOp, triton::StoreOp>(op))
+        return true;
+      if (isa<scf::ForOp, scf::YieldOp>(op))
+        return true;
+      return false;
+    };
+
+    // Find all the ops that the conversion depends on.
+    SetVector<Operation *> depIntoOps;
+    mlir::getBackwardSlice(cvt, &depIntoOps, [&](Operation *op) {
+      return !blacklist(op) && op->getResult(0) &&
+             (op->getResult(0).getParentBlock() ==
+              cvt->getResult(0).getParentBlock());
+    });
+    // Find all the ops that depend on something expensive
+    // to rematerialize, and break the dependency
+    // chain there.
+    SetVector<Operation *> blacklistedOps;
+    mlir::getBackwardSlice(cvt, &blacklistedOps, blacklist);
+    for (Operation *op : blacklistedOps) {
+      SetVector<Operation *> toRemove;
+      mlir::getBackwardSlice(op, &toRemove);
+      depIntoOps.set_subtract(toRemove);
+    }
+    // If nothing in the remaining ops can benefit from moving
+    // conversions, we don't do anything.
+    auto it = llvm::find_if(depIntoOps, [&](Operation *op) {
+      if (isa<triton::gpu::ConvertLayoutOp>(op)) {
+        // Conversions in for loops interfere with the push-to-sink
+        // pattern. We need a better cost model to know how many
+        // conversions we can actually remove by moving them to the
+        // beginning of the block.
+        auto forOp = dyn_cast<scf::ForOp>(cvt->getParentOp());
+        if (!forOp &&
+            (cvt->getResult(0).getType() == op->getOperand(0).getType()))
+          return true;
+      }
+      if (isa<triton::MakeRangeOp, triton::SplatOp>(op))
+        return true;
+      return false;
+    });
+    if (it == depIntoOps.end()) {
       return mlir::failure();
+    }
 
     // We convert cvt(op(arg_0, arg_1, ..., arg_n))
     // into op(cvt_0(arg_0), cvt_1(arg_1), ..., cvt_n(arg_n))
@@ -229,8 +218,6 @@ bool tryLegalizeOp(Operation *op, DenseSet<Value> toPreserve,
     op->getResult(0).setType(op->getOperand(0).getType());
     return true;
   }
-
-  // i
   return false;
 }
 
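The heart of the pattern is the rewrite cvt(op(arg_0, ..., arg_n)) =>
op(cvt_0(arg_0), ..., cvt_n(arg_n)): element-wise ops are cheap to
rematerialize in another layout, so the conversion is pulled up to the
operands, where it can cancel against existing conversions. Below is a
minimal MLIR sketch of the effect on a vecadd-style kernel; the value
names and the layout aliases #L (load layout) and #S are illustrative
placeholders, not taken from the patch:

  // Before: the loads produce #L, the add runs in #S, and the result
  // is converted back to #L for the store.
  %a1 = triton_gpu.convert_layout %a : (tensor<256xf32, #L>) -> tensor<256xf32, #S>
  %b1 = triton_gpu.convert_layout %b : (tensor<256xf32, #L>) -> tensor<256xf32, #S>
  %s = arith.addf %a1, %b1 : tensor<256xf32, #S>
  %o = triton_gpu.convert_layout %s : (tensor<256xf32, #S>) -> tensor<256xf32, #L>

  // After: pulling %o's conversion through the addf inserts an
  // #S -> #L conversion on each operand, which folds with the existing
  // #L -> #S conversion; the add now runs in #L and no conversion remains.
  %o = arith.addf %a, %b : tensor<256xf32, #L>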
diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir
index 5bdc3c10f..4f1b20bf3 100644
--- a/test/TritonGPU/combine.mlir
+++ b/test/TritonGPU/combine.mlir
@@ -173,3 +173,34 @@ func @loop(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: !tt.ptr<f32>, %arg3: i32, %ar
   tt.store %20, %21, %22 : tensor<64x64xf32, #blocked1>
   return
 }
+
+// CHECK-LABEL: vecadd
+func @vecadd(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32) {
+  // CHECK-NOT: triton_gpu.convert_layout
+  %c256_i32 = arith.constant 256 : i32
+  %0 = tt.get_program_id {axis = 0 : i32} : i32
+  %1 = arith.muli %0, %c256_i32 : i32
+  %2 = tt.splat %1 : (i32) -> tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
+  %3 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
+  %4 = tt.splat %1 : (i32) -> tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
+  %5 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
+  %6 = tt.splat %1 : (i32) -> tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
+  %7 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
+  %8 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
+  %9 = arith.addi %6, %7 : tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
+  %10 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
+  %11 = arith.addi %4, %5 : tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
+  %12 = tt.getelementptr %8, %9 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
+  %13 = tt.load %12 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
+  %14 = triton_gpu.convert_layout %13 : (tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>) -> tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>>
+  %15 = tt.getelementptr %10, %11 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
+  %16 = tt.load %15 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
+  %17 = triton_gpu.convert_layout %16 : (tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>) -> tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>>
+  %18 = arith.addf %14, %17 : tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>>
+  %19 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
+  %20 = arith.addi %2, %3 : tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
+  %21 = tt.getelementptr %19, %20 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
+  %22 = triton_gpu.convert_layout %18 : (tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>>) -> tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
+  tt.store %21, %22 : tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
+  return
+}
\ No newline at end of file
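The new test exercises exactly the situation sketched earlier: both tt.load
results are converted from the vectorized sizePerThread = [4] layout to a
sizePerThread = [1] layout, added, and converted back before the store, and
the CHECK-NOT asserts that no triton_gpu.convert_layout survives the pass.
Assuming the file's usual RUN line (something like
// RUN: triton-opt %s -tritongpu-combine | FileCheck %s; the RUN line itself
is outside this hunk), the expected output is roughly the following, with
#v4 used here only as an illustrative shorthand for the sizePerThread = [4]
blocked layout:

  %13 = tt.load %12 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #v4>
  %16 = tt.load %15 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #v4>
  %18 = arith.addf %13, %16 : tensor<256xf32, #v4>
  tt.store %21, %18 : tensor<256xf32, #v4>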