diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
index 93b6351de..ea39bff54 100644
--- a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
@@ -34,6 +34,14 @@ class LoopPipeliner {
   DenseMap<Value, Value> loadsMapping;
   /// load => buffer
   DenseMap<Value, Value> loadsBuffer;
+  /// load => buffer at stage N
+  DenseMap<Value, SmallVector<Value>> loadStageBuffer;
+  /// load => after extract
+  DenseMap<Value, Value> loadsExtract;
+  /// pipeline iteration index
+  Value pipelineIterIdx;
+  /// loop iteration index
+  Value loopIterIdx;
   /// value (in loop) => value at stage N
   DenseMap<Value, SmallVector<Value>> valueMapping;
@@ -237,6 +245,7 @@ void LoopPipeliner::emitPrologue() {
   // prologue from [0, numStage-1)
   Value iv = forOp.getLowerBound();
+  pipelineIterIdx = builder.create<arith::ConstantIntOp>(iv.getLoc(), 0, 32);
   for (int stage = 0; stage < numStages - 1; ++stage) {
     // special handling for induction variable as the increment is implicit
     if (stage != 0)
@@ -261,22 +270,21 @@ void LoopPipeliner::emitPrologue() {
       Operation *newOp = nullptr;
       if (loads.contains(op->getResult(0))) {
         // Allocate empty buffer
-        if (stage == 0)
+        if (stage == 0) {
           loadsBuffer[op->getResult(0)] = allocateEmptyBuffer(op, builder);
+          loadStageBuffer[op->getResult(0)] = {loadsBuffer[op->getResult(0)]};
+        }
         // load => copy async
         // TODO: check if the hardware supports copyasync
         if (auto loadOp = llvm::dyn_cast<triton::LoadOp>(op)) {
-          Value sliceIndex = builder.create<arith::IndexCastOp>(
-              iv.getLoc(), builder.getI32Type(), iv);
-          Value insertAsyncOp = builder.create<triton::gpu::InsertSliceAsyncOp>(
+          newOp = builder.create<triton::gpu::InsertSliceAsyncOp>(
               op->getLoc(), loadsBuffer[loadOp].getType(),
-              lookupOrDefault(loadOp.ptr(), stage), loadsBuffer[loadOp],
-              sliceIndex, lookupOrDefault(loadOp.mask(), stage),
+              lookupOrDefault(loadOp.ptr(), stage),
+              loadStageBuffer[loadOp][stage], pipelineIterIdx,
+              lookupOrDefault(loadOp.mask(), stage),
               lookupOrDefault(loadOp.other(), stage), loadOp.cache(),
               loadOp.evict(), loadOp.isVolatile(), /*axis*/ 0);
-          newOp = builder.create<triton::gpu::ExtractSliceOp>(
-              op->getLoc(), loadsMapping[loadOp].getType(), insertAsyncOp,
-              sliceIndex, /*axis*/ 0);
+          loadStageBuffer[loadOp].push_back(newOp->getResult(0));
         } else
           llvm_unreachable("This should be LoadOp");
       } else {
@@ -294,9 +302,11 @@ void LoopPipeliner::emitPrologue() {
       // If this is a load/async_copy, we need to update the mask
       if (llvm::isa<triton::LoadOp, triton::gpu::InsertSliceAsyncOp>(newOp)) {
-        Value mask = newOp->getOperand(1);
+        Value mask = llvm::isa<triton::LoadOp>(newOp) ? newOp->getOperand(1)
+                                                      : newOp->getOperand(3);
         // assert(I1 or TensorOf<[I1]>);
         OpBuilder::InsertionGuard g(builder);
+        // TODO: move this out of the loop
         builder.setInsertionPoint(newOp);
         Value splatCond = builder.create<triton::SplatOp>(
             mask.getLoc(), mask.getType(), loopCond);
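Editor's note: a minimal standalone sketch (not part of the patch) of the prologue bookkeeping above. It assumes `numStages = 3`, so `emitPrologue` issues stages 0 and 1; each stage's `insert_slice_async` writes slice `pipelineIterIdx` of the load's buffer, and the per-stage increment of `pipelineIterIdx` appears in a later hunk below.

```cpp
// Sketch only: models which buffer slot each prologue stage writes.
// numStages = 3 is an assumed pipeline depth, not taken from the patch.
#include <cstdio>

int main() {
  const int numStages = 3;
  int pipelineIterIdx = 0; // mirrors the arith.constant 0 : i32 above
  for (int stage = 0; stage < numStages - 1; ++stage) {
    // insert_slice_async writes slice `pipelineIterIdx` of the buffer,
    // and loadStageBuffer[load] grows by one entry per stage.
    std::printf("stage %d writes buffer slot %d\n", stage, pipelineIterIdx);
    ++pipelineIterIdx; // mirrors the arith.addi bump at the end of each stage
  }
  return 0;
}
```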
@@ -313,8 +323,11 @@ void LoopPipeliner::emitPrologue() {
     for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) {
       Value originalResult = op->getResult(dstIdx);
       // copy_async will update the value of its only use
-      if (loads.contains(originalResult))
-        originalResult = loadsMapping[originalResult];
+      // TODO: loads should not be used in the preheader?
+      if (loads.contains(originalResult)) {
+        break;
+        // originalResult = loadsMapping[originalResult];
+      }
       setValueMapping(originalResult, newOp->getResult(dstIdx), stage);
       // update mapping for loop-carried values (args)
       for (OpOperand &operand : yieldOp->getOpOperands()) {
@@ -325,18 +338,25 @@ void LoopPipeliner::emitPrologue() {
       }
     }
   }
-  }
+
+    pipelineIterIdx = builder.create<arith::AddIOp>(
+        iv.getLoc(), pipelineIterIdx,
+        builder.create<arith::ConstantIntOp>(iv.getLoc(), 1, 32));
+  } // for (int stage = 0; stage < numStages - 1; ++stage)
 
   // async.wait & extract_slice
   Operation *asyncWait = builder.create<triton::gpu::AsyncWaitOp>(
       loads[0].getLoc(), loads.size() * (numStages - 2));
-  for (int i = numStages - 2; i >= 0; --i) {
-    for (auto it = loads.rbegin(); it != loads.rend(); ++it) {
-      // move extract_slice after asyncWait
-      Value load = *it;
-      valueMapping[loadsMapping[load]][i].getDefiningOp()->moveAfter(asyncWait);
-    }
+  for (Value loadOp : loads) {
+    Value extractSlice = builder.create<triton::gpu::ExtractSliceOp>(
+        loadOp.getLoc(), loadsMapping[loadOp].getType(),
+        loadStageBuffer[loadOp][numStages - 1],
+        builder.create<arith::ConstantIntOp>(loadOp.getLoc(), 0, 32),
+        /*axis*/ 0);
+    loadsExtract[loadOp] = extractSlice;
   }
+
+  loopIterIdx = builder.create<arith::ConstantIntOp>(iv.getLoc(), 0, 32);
 }
 
 scf::ForOp LoopPipeliner::createNewForOp() {
@@ -344,10 +364,12 @@ scf::ForOp LoopPipeliner::createNewForOp() {
   // order of new args:
   //   (original args),
-  //   for each load result (after layout conversion) x:
-  //     (x at stage[0, numStages-1))
+  //   (insertSliceAsync buffer at stage numStages - 1) for each load
+  //   (extracted tensor) for each load
   //   (depArgs at stage numStages-1)
   //   (iv at stage numStages-1)
+  //   (pipeline iteration index)
+  //   (loop iteration index)
   SmallVector<Value> newLoopArgs;
   // We need this to update operands for yield
   // original block arg => new arg's idx
@@ -355,10 +377,12 @@ scf::ForOp LoopPipeliner::createNewForOp() {
   for (auto v : forOp.getIterOperands())
     newLoopArgs.push_back(v);
 
+  size_t bufferIdx = newLoopArgs.size();
+  for (Value loadOp : loads)
+    newLoopArgs.push_back(loadStageBuffer[loadOp].back());
   size_t loadIdx = newLoopArgs.size();
   for (Value loadOp : loads)
-    for (int i = 0; i < numStages - 1; ++i)
-      newLoopArgs.push_back(valueMapping[loadsMapping[loadOp]][i]);
+    newLoopArgs.push_back(loadsExtract[loadOp]);
 
   size_t depArgsBeginIdx = newLoopArgs.size();
   for (BlockArgument depArg : depArgs) {
@@ -368,6 +392,8 @@ scf::ForOp LoopPipeliner::createNewForOp() {
   size_t nextIVIdx = newLoopArgs.size();
   newLoopArgs.push_back(valueMapping[forOp.getInductionVar()][numStages - 2]);
+  newLoopArgs.push_back(pipelineIterIdx);
+  newLoopArgs.push_back(loopIterIdx);
 
   for (size_t i = 0; i < newLoopArgs.size(); ++i)
     assert(newLoopArgs[i]);
@@ -399,7 +425,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
          "we assume that this load has one use (ConvertLayout)");
   Value loadUse = load.getUsers().begin()->getResult(0);
   mapping.lookup(loadUse).replaceAllUsesWith(
-      newForOp.getRegionIterArgs()[loadIdx + idx * (numStages - 1)]);
+      newForOp.getRegionIterArgs()[loadIdx + idx]);
   // delete old load and layout conversion
   mapping.lookup(loadUse).getDefiningOp()->erase();
   mapping.lookup(load).getDefiningOp()->erase();
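Editor's note: a runnable sketch (not part of the patch) of the iter_args layout the comments above describe. The counts (three original args, two loads, one dep arg) are made-up illustration values; the index variables mirror `bufferIdx`, `loadIdx`, `depArgsBeginIdx`, and `nextIVIdx` in `createNewForOp`, with `pipelineIterIdx` and `loopIterIdx` carried at `nextIVIdx + 1` and `nextIVIdx + 2`.

```cpp
// Sketch only: the new iter_args index bookkeeping for numLoads loads.
#include <cstddef>
#include <cstdio>

int main() {
  const size_t numOrigArgs = 3, numLoads = 2, numDeps = 1; // assumed sizes
  size_t bufferIdx = numOrigArgs;               // per-load buffers
  size_t loadIdx = bufferIdx + numLoads;        // per-load extracted tensors
  size_t depArgsBeginIdx = loadIdx + numLoads;  // loop-carried deps
  size_t nextIVIdx = depArgsBeginIdx + numDeps; // induction variable
  // pipelineIterIdx and loopIterIdx ride at nextIVIdx + 1 and nextIVIdx + 2
  std::printf("bufferIdx=%zu loadIdx=%zu depArgsBeginIdx=%zu nextIVIdx=%zu\n",
              bufferIdx, loadIdx, depArgsBeginIdx, nextIVIdx);
  return 0;
}
```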
@@ -432,12 +458,18 @@ scf::ForOp LoopPipeliner::createNewForOp() {
       nextIV, newForOp.getUpperBound());
 
   // slice index
+  SmallVector<Value> nextBuffers;
   SmallVector<Value> extractSlices;
-  Value sliceIndex = builder.create<arith::IndexCastOp>(
-      nextIV.getLoc(), builder.getI32Type(), nextIV);
-  sliceIndex = builder.create<arith::RemSIOp>(
-      nextIV.getLoc(), sliceIndex,
+
+  pipelineIterIdx = newForOp.getRegionIterArgs()[nextIVIdx + 1];
+  Value insertSliceIndex = builder.create<arith::RemSIOp>(
+      nextIV.getLoc(), pipelineIterIdx,
       builder.create<arith::ConstantIntOp>(nextIV.getLoc(), numStages, 32));
+  loopIterIdx = newForOp.getRegionIterArgs()[nextIVIdx + 2];
+  Value extractSliceIndex = builder.create<arith::RemSIOp>(
+      nextIV.getLoc(), loopIterIdx,
+      builder.create<arith::ConstantIntOp>(nextIV.getLoc(), numStages, 32));
+
   for (Operation *op : orderedDeps) {
     Operation *nextOp = nullptr;
     // TODO(da): does this work if loadOp has no mask?
@@ -457,13 +489,15 @@ scf::ForOp LoopPipeliner::createNewForOp() {
       }
       Value insertAsyncOp = builder.create<triton::gpu::InsertSliceAsyncOp>(
           op->getLoc(), loadsBuffer[loadOp].getType(),
-          nextMapping.lookupOrDefault(loadOp.ptr()), loadsBuffer[loadOp],
-          sliceIndex, nextMapping.lookupOrDefault(loadOp.mask()),
+          nextMapping.lookupOrDefault(loadOp.ptr()),
+          newForOp.getRegionIterArgs()[bufferIdx + nextBuffers.size()],
+          insertSliceIndex, nextMapping.lookupOrDefault(loadOp.mask()),
           nextMapping.lookupOrDefault(loadOp.other()), loadOp.cache(),
           loadOp.evict(), loadOp.isVolatile(), /*axis*/ 0);
+      nextBuffers.push_back(insertAsyncOp);
       nextOp = builder.create<triton::gpu::ExtractSliceOp>(
           op->getLoc(), loadsMapping[loadOp].getType(), insertAsyncOp,
-          sliceIndex, /*axis*/ 0);
+          extractSliceIndex, /*axis*/ 0);
       extractSlices.push_back(nextOp->getResult(0));
     } else
       nextOp = builder.clone(*op, nextMapping);
@@ -491,25 +525,29 @@ scf::ForOp LoopPipeliner::createNewForOp() {
       it->getDefiningOp()->moveAfter(asyncWait);
   }
 
+  // bump iteration count
+  pipelineIterIdx = builder.create<arith::AddIOp>(
+      nextIV.getLoc(), pipelineIterIdx,
+      builder.create<arith::ConstantIntOp>(nextIV.getLoc(), 1, 32));
+  loopIterIdx = builder.create<arith::AddIOp>(
+      nextIV.getLoc(), loopIterIdx,
+      builder.create<arith::ConstantIntOp>(nextIV.getLoc(), 1, 32));
+
   // Finally, the YieldOp, need to sync with the order of newLoopArgs
   SmallVector<Value> yieldValues;
   for (Value v : forOp.getBody()->getTerminator()->getOperands())
     yieldValues.push_back(mapping.lookup(v));
-  // shift pipelined args by 1
-  for (size_t idx = 0; idx < loads.size(); ++idx) {
-    Value load = loads[idx];
-    for (int stage = 1; stage < numStages - 1; ++stage) {
-      yieldValues.push_back(
-          newForOp
-              .getRegionIterArgs()[loadIdx + idx * (numStages - 1) + stage]);
-    }
-    yieldValues.push_back(nextMapping.lookup(load));
-  }
+  for (Value nextBuffer : nextBuffers)
+    yieldValues.push_back(nextBuffer);
+  for (Value nextSlice : extractSlices)
+    yieldValues.push_back(nextSlice);
 
   for (size_t i = depArgsBeginIdx; i < nextIVIdx; ++i)
     yieldValues.push_back(
         depArgsMapping.lookup(newForOp.getRegionIterArgs()[i]));
   yieldValues.push_back(nextIV);
+  yieldValues.push_back(pipelineIterIdx);
+  yieldValues.push_back(loopIterIdx);
 
   builder.setInsertionPointToEnd(newForOp.getBody());
   auto test = builder.create<scf::YieldOp>(
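Editor's note: a standalone sketch (not part of the patch) of the steady-state index arithmetic, assuming `numStages = 3`. After the prologue, `pipelineIterIdx` has been bumped to `numStages - 1` while `loopIterIdx` starts at 0, so the two `arith.remsi` indices keep the `insert_slice_async` writes `numStages - 1` slots ahead (modulo `numStages`) of the `extract_slice` reads.

```cpp
// Sketch only: simulates insertSliceIndex / extractSliceIndex for a few
// iterations; numStages = 3 and the initial counter values are assumptions
// read off the prologue code above.
#include <cstdio>

int main() {
  const int numStages = 3;
  int pipelineIterIdx = numStages - 1; // after numStages - 1 prologue bumps
  int loopIterIdx = 0;                 // initialized to 0 in emitPrologue
  for (int iter = 0; iter < 4; ++iter) {
    int insertSliceIndex = pipelineIterIdx % numStages; // arith.remsi
    int extractSliceIndex = loopIterIdx % numStages;    // arith.remsi
    std::printf("iter %d: insert into slot %d, extract from slot %d\n", iter,
                insertSliceIndex, extractSliceIndex);
    ++pipelineIterIdx; // both counters are bumped and carried via scf.yield
    ++loopIterIdx;
  }
  return 0;
}
```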
diff --git a/test/TritonGPU/loop-pipeline.mlir b/test/TritonGPU/loop-pipeline.mlir
index d951e103d..ae5f9191f 100644
--- a/test/TritonGPU/loop-pipeline.mlir
+++ b/test/TritonGPU/loop-pipeline.mlir
@@ -16,18 +16,16 @@
 // CHECK: %[[A1BUFFER:.*]] = triton_gpu.insert_slice_async
 // CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async
 // CHECK: triton_gpu.async_wait {num = 2 : i32}
-// CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A0BUFFER]]
-// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B0BUFFER]]
-// CHECK: %[[A1:.*]] = triton_gpu.extract_slice %[[A1BUFFER]]
-// CHECK: %[[B1:.*]] = triton_gpu.extract_slice %[[B1BUFFER]]
-// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_a1:.*]] = %[[A1]], %[[arg_b0:.*]] = %[[B0]], %[[arg_b1:.*]] = %[[B1]], {{.*}})
+// CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A1BUFFER]]
+// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]]
+// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}})
 // CHECK: tt.dot %[[arg_a0]], %[[arg_b0]], {{.*}}
 // CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async
 // CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async
 // CHECK: triton_gpu.async_wait {num = 2 : i32}
 // CHECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %[[NEXT_A_BUFFER]]
 // CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]]
-// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[arg_a1]], %[[NEXT_A]], %[[arg_b1]], %[[NEXT_B]]
+// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]]
 func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
   %a_ptr_init = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
   %b_ptr_init = tt.broadcast %B : (!tt.ptr<f16>) -> tensor<32x128x!tt.ptr<f16>, #BL>
@@ -66,18 +64,16 @@
 // CHECK: %[[A1BUFFER:.*]] = triton_gpu.insert_slice_async
 // CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async
 // CHECK: triton_gpu.async_wait {num = 2 : i32}
-// CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A0BUFFER]]
-// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B0BUFFER]]
-// CHECK: %[[A1:.*]] = triton_gpu.extract_slice %[[A1BUFFER]]
-// CHECK: %[[B1:.*]] = triton_gpu.extract_slice %[[B1BUFFER]]
-// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_a1:.*]] = %[[A1]], %[[arg_b0:.*]] = %[[B0]], %[[arg_b1:.*]] = %[[B1]], {{.*}})
+// CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A1BUFFER]]
+// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]]
+// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}})
 // CHECK: tt.dot %[[arg_a0]], %[[arg_b0]], {{.*}}
 // CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async
 // CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async
 // CHECK: triton_gpu.async_wait {num = 2 : i32}
 // CHECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %[[NEXT_A_BUFFER]]
 // CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]]
-// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[arg_a1]], %[[NEXT_A]], %[[arg_b1]], %[[NEXT_B]]
+// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]]
 func @matmul_loop_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
   scf.for %iv0 = %lb to %ub step %step {
     %a_ptr_init = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
@@ -114,14 +110,13 @@
 func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
   %a_ptr_init = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
   %b_ptr_init = tt.broadcast %B : (!tt.ptr<f16>) -> tensor<32x128x!tt.ptr<f16>, #BL>
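Editor's note: the `{num = 2 : i32}` in the CHECK lines above is the operand the pass computes as `loads.size() * (numStages - 2)`. The matmul tests pipeline two loads (A and B); `numStages = 3` is assumed here as the depth the test is run with. A worked check:

```cpp
// Worked example for the async_wait operand; the two loads come from the
// tests above, numStages = 3 is an assumption about how they are run.
#include <cassert>

int main() {
  const int numLoads = 2;  // A and B
  const int numStages = 3; // assumed pipeline depth
  const int num = numLoads * (numStages - 2);
  assert(num == 2); // matches triton_gpu.async_wait {num = 2 : i32}
  return 0;
}
```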