diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
index 4bf30296f..93b6351de 100644
--- a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
@@ -32,6 +32,8 @@ class LoopPipeliner {
   SetVector<Value> loads;
   /// the value that each load will be mapped to (after layout conversion)
   DenseMap<Value, Value> loadsMapping;
+  /// load => buffer
+  DenseMap<Value, Value> loadsBuffer;
   /// value (in loop) => value at stage N
   DenseMap<Value, SmallVector<Value>> valueMapping;

@@ -46,9 +48,15 @@ class LoopPipeliner {

   void setValueMapping(Value origin, Value newValue, int stage);

+  Value lookupOrDefault(Value origin, int stage);
+
   /// return true if this op uses any of `loads`
   bool isDirectUserOfAsyncLoad(Operation &op);

+  /// returns a empty buffer of size <numStages, ...>
+  triton::gpu::AllocTensorOp allocateEmptyBuffer(Operation *op,
+                                                 OpBuilder &builder);
+
 public:
   LoopPipeliner(scf::ForOp forOp, int numStages)
       : forOp(forOp), numStages(numStages) {
@@ -75,6 +83,12 @@ void LoopPipeliner::setValueMapping(Value origin, Value newValue, int stage) {
   valueMapping[origin][stage] = newValue;
 }

+Value LoopPipeliner::lookupOrDefault(Value origin, int stage) {
+  if (valueMapping.find(origin) == valueMapping.end())
+    return origin;
+  return valueMapping[origin][stage];
+}
+
 void LoopPipeliner::collectDeps(Value v, int stages, DenseSet<Value> &deps) {
   // Loop-invarant value. skip
   if (v.getParentRegion() != &forOp.getLoopBody())
     return;
@@ -111,6 +125,25 @@ bool LoopPipeliner::isDirectUserOfAsyncLoad(Operation &op) {
   return false;
 }

+triton::gpu::AllocTensorOp
+LoopPipeliner::allocateEmptyBuffer(Operation *op, OpBuilder &builder) {
+  // allocate a buffer for each pipelined tensor
+  // shape: e.g. (numStages==4), <32x64xbf16> -> <4x32x64xbf16>
+  Value convertLayout = loadsMapping[op->getResult(0)];
+  if (auto tensorType = convertLayout.getType().dyn_cast<RankedTensorType>()) {
+    SmallVector<int64_t> shape(tensorType.getShape().begin(),
+                               tensorType.getShape().end());
+    shape.insert(shape.begin(), numStages);
+    Type elementType = tensorType.getElementType();
+    // The encoding of the buffer is similar to the original tensor
+    Attribute encoding = tensorType.getEncoding();
+    auto bufferType = RankedTensorType::get(shape, elementType, encoding);
+    return builder.create<triton::gpu::AllocTensorOp>(convertLayout.getLoc(),
+                                                      bufferType);
+  }
+  llvm_unreachable("Async copy's return should be of RankedTensorType");
+}
+
 /// A load instruction can be pipelined if:
 ///   - the load doesn't depend on any other loads (after loop peeling)
 ///   - (?) this load is not a loop-invariant value (we should run LICM before
@@ -137,13 +170,6 @@ LogicalResult LoopPipeliner::initialize() {
     loadDeps[loadOp] = deps;
   }

-  // for (triton::LoadOp loadOp : allLoads) {
-  //   llvm::errs() << loadOp << " depends on: #" << loadDeps[loadOp].size() <<
-  //   " values\n"; for (Value dep : loadDeps[loadOp])
-  //     llvm::errs() << dep << "\n";
-  //   llvm::errs() << "\n";
-  // }
-
   // Don't pipeline loads that depend on other loads
   // (Because if a load depends on another load, this load needs to wait on the
   // other load in the prologue, which is against the point of the pipeline
@@ -234,29 +260,40 @@ void LoopPipeliner::emitPrologue() {
     for (Operation *op : orderedDeps) {
       Operation *newOp = nullptr;
       if (loads.contains(op->getResult(0))) {
+        // Allocate empty buffer
+        if (stage == 0)
+          loadsBuffer[op->getResult(0)] = allocateEmptyBuffer(op, builder);
         // load => copy async
         // TODO: check if the hardware supports copyasync
         if (auto loadOp = llvm::dyn_cast<triton::LoadOp>(op)) {
-          newOp = builder.create<triton::gpu::CopyAsyncOp>(
-              op->getLoc(), loadsMapping[loadOp].getType(), loadOp.ptr(),
-              loadOp.mask(), loadOp.other(), loadOp.cache(), loadOp.evict(),
-              loadOp.isVolatile());
+          Value sliceIndex = builder.create<arith::IndexCastOp>(
+              iv.getLoc(), builder.getI32Type(), iv);
+          Value insertAsyncOp = builder.create<triton::gpu::InsertSliceAsyncOp>(
+              op->getLoc(), loadsBuffer[loadOp].getType(),
+              lookupOrDefault(loadOp.ptr(), stage), loadsBuffer[loadOp],
+              sliceIndex, lookupOrDefault(loadOp.mask(), stage),
+              lookupOrDefault(loadOp.other(), stage), loadOp.cache(),
+              loadOp.evict(), loadOp.isVolatile(), /*axis*/ 0);
+          newOp = builder.create<triton::gpu::ExtractSliceOp>(
+              op->getLoc(), loadsMapping[loadOp].getType(), insertAsyncOp,
+              sliceIndex, /*axis*/ 0);
         } else
           llvm_unreachable("This should be LoadOp");
-      } else
+      } else {
         newOp = builder.clone(*op);
-
-      for (unsigned opIdx = 0; opIdx < op->getNumOperands(); ++opIdx) {
-        auto it = valueMapping.find(op->getOperand(opIdx));
-        if (it != valueMapping.end()) {
-          Value v = it->second[stage];
-          assert(v);
-          newOp->setOperand(opIdx, v);
-        } // else, op at opIdx is a loop-invariant value
+        // Update loop-carried uses
+        for (unsigned opIdx = 0; opIdx < op->getNumOperands(); ++opIdx) {
+          auto it = valueMapping.find(op->getOperand(opIdx));
+          if (it != valueMapping.end()) {
+            Value v = it->second[stage];
+            assert(v);
+            newOp->setOperand(opIdx, v);
+          } // else, op at opIdx is a loop-invariant value
+        }
       }

       // If this is a load/async_copy, we need to update the mask
-      if (llvm::isa(newOp)) {
+      if (llvm::isa(newOp)) {
         Value mask = newOp->getOperand(1);
         // assert(I1 or TensorOf<[I1]>);
         OpBuilder::InsertionGuard g(builder);
@@ -265,7 +302,11 @@ void LoopPipeliner::emitPrologue() {
             mask.getLoc(), mask.getType(), loopCond);
         Value newMask = builder.create<arith::AndIOp>(
             mask.getLoc(), mask, splatCond);
-        newOp->setOperand(1, newMask);
+        // TODO: better way to do this?
+        if (llvm::isa(newOp))
+          newOp->setOperand(1, newMask);
+        else // InsertSliceAsyncOp
+          newOp->setOperand(3, newMask);
       }

       // update mapping of results
@@ -285,6 +326,17 @@ void LoopPipeliner::emitPrologue() {
       }
     }
   }
+
+  // async.wait & extract_slice
+  Operation *asyncWait = builder.create<triton::gpu::AsyncWaitOp>(
+      loads[0].getLoc(), loads.size() * (numStages - 2));
+  for (int i = numStages - 2; i >= 0; --i) {
+    for (auto it = loads.rbegin(); it != loads.rend(); ++it) {
+      // move extract_slice after asyncWait
+      Value load = *it;
+      valueMapping[loadsMapping[load]][i].getDefiningOp()->moveAfter(asyncWait);
+    }
+  }
 }

 scf::ForOp LoopPipeliner::createNewForOp() {
@@ -333,14 +385,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {

   // 2.1 clone the loop body, replace original args with args of the new ForOp
   // Insert async wait if necessary.
-  bool asyncWaitInserted = false;
   for (Operation &op : forOp.getBody()->without_terminator()) {
-    if (!asyncWaitInserted && isDirectUserOfAsyncLoad(op)) {
-      asyncWaitInserted = true;
-      assert(numStages >= 2);
-      builder.create<triton::gpu::AsyncWaitOp>(op.getLoc(),
-                                               loads.size() * (numStages - 2));
-    }
     Operation *newOp = builder.clone(op, mapping);
     // update mapping of results
     for (unsigned dstIdx : llvm::seq(unsigned(0), op.getNumResults()))
@@ -385,8 +430,17 @@ scf::ForOp LoopPipeliner::createNewForOp() {
   Value nextLoopCond = builder.create<arith::CmpIOp>(
       nextIV.getLoc(), arith::CmpIPredicate::slt, nextIV,
       newForOp.getUpperBound());
+
+  // slice index
+  SmallVector<Value> extractSlices;
+  Value sliceIndex = builder.create<arith::IndexCastOp>(
+      nextIV.getLoc(), builder.getI32Type(), nextIV);
+  sliceIndex = builder.create<arith::RemSIOp>(
+      nextIV.getLoc(), sliceIndex,
+      builder.create<arith::ConstantIntOp>(nextIV.getLoc(), numStages, 32));
   for (Operation *op : orderedDeps) {
     Operation *nextOp = nullptr;
+    // TODO(da): does this work if loadOp has no mask?
     // update loading mask
     if (loads.contains(op->getResult(0))) {
       auto loadOp = llvm::cast<triton::LoadOp>(op);
@@ -401,16 +455,18 @@ scf::ForOp LoopPipeliner::createNewForOp() {
         if (!(forOp.isDefinedOutsideOfLoop(mask) && nextMapping.contains(mask)))
           nextMapping.map(mask, newMask);
       }
-      // TODO: more elegant way to do this?
-      nextOp = builder.create<triton::gpu::CopyAsyncOp>(
-          op->getLoc(), loadsMapping[op->getResult(0)].getType(),
-          nextMapping.lookupOrDefault(loadOp.ptr()),
-          nextMapping.lookupOrDefault(loadOp.mask()),
+      Value insertAsyncOp = builder.create<triton::gpu::InsertSliceAsyncOp>(
+          op->getLoc(), loadsBuffer[loadOp].getType(),
+          nextMapping.lookupOrDefault(loadOp.ptr()), loadsBuffer[loadOp],
+          sliceIndex, nextMapping.lookupOrDefault(loadOp.mask()),
           nextMapping.lookupOrDefault(loadOp.other()), loadOp.cache(),
-          loadOp.evict(), loadOp.isVolatile());
+          loadOp.evict(), loadOp.isVolatile(), /*axis*/ 0);
+      nextOp = builder.create<triton::gpu::ExtractSliceOp>(
+          op->getLoc(), loadsMapping[loadOp].getType(), insertAsyncOp,
+          sliceIndex, /*axis*/ 0);
+      extractSlices.push_back(nextOp->getResult(0));
     } else
       nextOp = builder.clone(*op, nextMapping);
-    // llvm::errs() << "epilogue cloning...: " << *op << "\n";
     // update mapping of results
     for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) {
       nextMapping.map(op->getResult(dstIdx), nextOp->getResult(dstIdx));
@@ -427,6 +483,14 @@ scf::ForOp LoopPipeliner::createNewForOp() {
     }
   }

+  // async.wait & extract_slice
+  Operation *asyncWait = builder.create<triton::gpu::AsyncWaitOp>(
+      loads[0].getLoc(), loads.size() * (numStages - 2));
+  for (auto it = extractSlices.rbegin(); it != extractSlices.rend(); ++it) {
+    // move extract_slice after asyncWait
+    it->getDefiningOp()->moveAfter(asyncWait);
+  }
+
   // Finally, the YieldOp, need to sync with the order of newLoopArgs
   SmallVector<Value> yieldValues;
   for (Value v : forOp.getBody()->getTerminator()->getOperands())
diff --git a/test/TritonGPU/loop-pipeline.mlir b/test/TritonGPU/loop-pipeline.mlir
index d3fb45f39..d951e103d 100644
--- a/test/TritonGPU/loop-pipeline.mlir
+++ b/test/TritonGPU/loop-pipeline.mlir
@@ -9,15 +9,24 @@
 #C = #triton_gpu.mma<{version = 2, warpsPerCTA = [4, 1]}>

 // CHECK: func @matmul_loop
-// CHECK: %[[A0:.*]] = triton_gpu.copy_async
-// CHECK: %[[B0:.*]] = triton_gpu.copy_async
-// CHECK: %[[A1:.*]] = triton_gpu.copy_async
-// CHECK: %[[B1:.*]] = triton_gpu.copy_async
-// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_a1:.*]] = %[[A1]], %[[arg_b0:.*]] = %[[B0]], %[[arg_b1:.*]] = %[[B1]], {{.*}})
+// CHECK: %[[ABUFFER:.*]] = triton_gpu.alloc_tensor
+// CHECK: %[[A0BUFFER:.*]] = triton_gpu.insert_slice_async
+// CHECK: %[[BBUFFER:.*]] = triton_gpu.alloc_tensor
+// CHECK: %[[B0BUFFER:.*]] = triton_gpu.insert_slice_async
+// CHECK: %[[A1BUFFER:.*]] = triton_gpu.insert_slice_async
+// CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async
 // CHECK: triton_gpu.async_wait {num = 2 : i32}
+// CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A0BUFFER]]
+// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B0BUFFER]]
+// CHECK: %[[A1:.*]] = triton_gpu.extract_slice %[[A1BUFFER]]
+// CHECK: %[[B1:.*]] = triton_gpu.extract_slice %[[B1BUFFER]]
+// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_a1:.*]] = %[[A1]], %[[arg_b0:.*]] = %[[B0]], %[[arg_b1:.*]] = %[[B1]], {{.*}})
 // CHECK: tt.dot %[[arg_a0]], %[[arg_b0]], {{.*}}
-// CHECK: %[[NEXT_A:.*]] = triton_gpu.copy_async
-// CHECK: %[[NEXT_B:.*]] = triton_gpu.copy_async
+// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async
+// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async
+// CHECK: triton_gpu.async_wait {num = 2 : i32}
+// CHECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %[[NEXT_A_BUFFER]]
+// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]]
 // CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[arg_a1]], %[[NEXT_A]], %[[arg_b1]], %[[NEXT_B]]
 func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
   %a_ptr_init = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
@@ -50,15 +59,24 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B

 // CHECK: func @matmul_loop_nested
 // CHECK: scf.for
-// CHECK: %[[A0:.*]] = triton_gpu.copy_async
-// CHECK: %[[B0:.*]] = triton_gpu.copy_async
-// CHECK: %[[A1:.*]] = triton_gpu.copy_async
-// CHECK: %[[B1:.*]] = triton_gpu.copy_async
-// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_a1:.*]] = %[[A1]], %[[arg_b0:.*]] = %[[B0]], %[[arg_b1:.*]] = %[[B1]], {{.*}})
+// CHECK: %[[ABUFFER:.*]] = triton_gpu.alloc_tensor
+// CHECK: %[[A0BUFFER:.*]] = triton_gpu.insert_slice_async
+// CHECK: %[[BBUFFER:.*]] = triton_gpu.alloc_tensor
+// CHECK: %[[B0BUFFER:.*]] = triton_gpu.insert_slice_async
+// CHECK: %[[A1BUFFER:.*]] = triton_gpu.insert_slice_async
+// CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async
 // CHECK: triton_gpu.async_wait {num = 2 : i32}
+// CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A0BUFFER]]
+// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B0BUFFER]]
+// CHECK: %[[A1:.*]] = triton_gpu.extract_slice %[[A1BUFFER]]
+// CHECK: %[[B1:.*]] = triton_gpu.extract_slice %[[B1BUFFER]]
+// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_a1:.*]] = %[[A1]], %[[arg_b0:.*]] = %[[B0]], %[[arg_b1:.*]] = %[[B1]], {{.*}})
 // CHECK: tt.dot %[[arg_a0]], %[[arg_b0]], {{.*}}
-// CHECK: %[[NEXT_A:.*]] = triton_gpu.copy_async
-// CHECK: %[[NEXT_B:.*]] = triton_gpu.copy_async
+// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async
+// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async
+// CHECK: triton_gpu.async_wait {num = 2 : i32}
+// CHECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %[[NEXT_A_BUFFER]]
+// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]]
 // CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[arg_a1]], %[[NEXT_A]], %[[arg_b1]], %[[NEXT_B]]
 func @matmul_loop_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
   scf.for %iv0 = %lb to %ub step %step {
@@ -92,12 +110,17 @@ func @matmul_loop_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) {
   %a_ptr_init = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>