From 8776ad1a0eb7df5f6489f9fd3cfafb71f3ff7dfa Mon Sep 17 00:00:00 2001
From: Da Yan
Date: Fri, 19 Aug 2022 01:31:57 +0800
Subject: [PATCH] [OPTIMIZER] Let the pipeline pass insert async wait. (#63)

---
 .../Dialect/TritonGPU/IR/TritonGPUOps.td      |  2 ++
 lib/Dialect/TritonGPU/Transforms/Pipeline.cpp | 23 +++++++++++++++++++
 test/TritonGPU/loop-pipeline.mlir             |  3 +++
 3 files changed, 28 insertions(+)

diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
index a01a55e66..90e61dcb5 100644
--- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
+++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
@@ -34,6 +34,8 @@ def TTG_AsyncWaitOp : TTG_Op<"async_wait"> {
   let summary = "async wait";
 
   let arguments = (ins I32Attr:$num);
+
+  let assemblyFormat = "attr-dict";
 }
 
 def TTG_CopyAsyncOp : TTG_Op<"copy_async",
diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
index 4b3029a3a..102299757 100644
--- a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
@@ -47,6 +47,9 @@ class LoopPipeliner {
 
   void setValueMapping(Value origin, Value newValue, int stage);
 
+  /// return true if this op uses any of `loads`
+  bool isDirectUserOfAsyncLoad(Operation &op);
+
 public:
   LoopPipeliner(scf::ForOp forOp, int numStages)
       : forOp(forOp), numStages(numStages) {
@@ -96,6 +99,19 @@ void LoopPipeliner::collectDeps(Value v, int stages, DenseSet<Value> &deps) {
   }
 }
 
+bool LoopPipeliner::isDirectUserOfAsyncLoad(Operation &op) {
+  for (Value loadOp : loads) {
+    assert(loadOp.hasOneUse() &&
+           "load should only have one use (ConvertLayout)");
+    Value loadUseResult = loadOp.getUsers().begin()->getResult(0);
+    for (Value opOperand : op.getOperands()) {
+      if (opOperand == loadUseResult)
+        return true;
+    }
+  }
+  return false;
+}
+
 /// A load instruction can be pipelined if:
 /// - the load doesn't depend on any other loads (after loop peeling)
 /// - (?) this load is not a loop-invariant value (we should run LICM before
@@ -318,7 +334,14 @@ scf::ForOp LoopPipeliner::createNewForOp() {
     mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
 
   // 2.1 clone the loop body, replace original args with args of the new ForOp
+  // Insert async wait if necessary.
+  bool asyncWaitInserted = false;
   for (Operation &op : forOp.getBody()->without_terminator()) {
+    if (!asyncWaitInserted && isDirectUserOfAsyncLoad(op)) {
+      asyncWaitInserted = true;
+      builder.create<triton::gpu::AsyncWaitOp>(op.getLoc(),
+                                               loads.size() * (numStages - 1));
+    }
     Operation *newOp = builder.clone(op, mapping);
     // update mapping of results
     for (unsigned dstIdx : llvm::seq(unsigned(0), op.getNumResults()))
diff --git a/test/TritonGPU/loop-pipeline.mlir b/test/TritonGPU/loop-pipeline.mlir
index f90b9ed9b..cc73c154f 100644
--- a/test/TritonGPU/loop-pipeline.mlir
+++ b/test/TritonGPU/loop-pipeline.mlir
@@ -14,6 +14,7 @@
 // CHECK: %[[A1:.*]] = triton_gpu.copy_async
 // CHECK: %[[B1:.*]] = triton_gpu.copy_async
 // CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_a1:.*]] = %[[A1]], %[[arg_b0:.*]] = %[[B0]], %[[arg_b1:.*]] = %[[B1]], {{.*}})
+// CHECK: triton_gpu.async_wait {num = 4 : i32}
 // CHECK: tt.dot %[[arg_a0]], %[[arg_b0]], {{.*}}
 // CHECK: %[[NEXT_A:.*]] = triton_gpu.copy_async
 // CHECK: %[[NEXT_B:.*]] = triton_gpu.copy_async
@@ -54,6 +55,7 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
 // CHECK: %[[A1:.*]] = triton_gpu.copy_async
 // CHECK: %[[B1:.*]] = triton_gpu.copy_async
 // CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_a1:.*]] = %[[A1]], %[[arg_b0:.*]] = %[[B0]], %[[arg_b1:.*]] = %[[B1]], {{.*}})
+// CHECK: triton_gpu.async_wait {num = 4 : i32}
 // CHECK: tt.dot %[[arg_a0]], %[[arg_b0]], {{.*}}
 // CHECK: %[[NEXT_A:.*]] = triton_gpu.copy_async
 // CHECK: %[[NEXT_B:.*]] = triton_gpu.copy_async
@@ -93,6 +95,7 @@ func @matmul_loop_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr