From c529b462f5fb0ba42f8cbfa28f7da10736b267ba Mon Sep 17 00:00:00 2001 From: Yan Da Date: Thu, 26 May 2022 13:14:41 +0800 Subject: [PATCH] more fixes on pipeline.cpp --- lib/Dialect/TritonGPU/Transforms/Pipeline.cpp | 35 +++++++++---------- 1 file changed, 16 insertions(+), 19 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp index 00da768b9..a99ae7e4b 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp @@ -159,18 +159,10 @@ LogicalResult LoopPipeliner::initialize() { return success(); } - // llvm::errs() << allLoads.size() << " loads inside the loop\n" - // << loads.size() << " loads to be pipelined\n"; - return failure(); } void LoopPipeliner::emitPrologue() { - // llvm::errs() << "to pipeline...\n"; - // for (Value load : loads) - // llvm::errs() << load << "\n"; - - // TODO: should we use rewriter here? OpBuilder builder(forOp); for (BlockArgument &arg : forOp.getRegionIterArgs()) { OpOperand &operand = forOp.getOpOperandForRegionIterArg(arg); @@ -214,7 +206,7 @@ void LoopPipeliner::emitPrologue() { llvm_unreachable("This should be LoadOp"); } else newOp = builder.clone(*op); - // llvm::errs() << "cloning " << *op << "\n"; + for (unsigned opIdx = 0; opIdx < op->getNumOperands(); ++opIdx) { auto it = valueMapping.find(op->getOperand(opIdx)); if (it != valueMapping.end()) { @@ -224,7 +216,20 @@ void LoopPipeliner::emitPrologue() { } // else, op at opIdx is a loop-invariant value } - // TODO: if this is a load, we need to update the mask + // If this is a load/async_copy, we need to update the mask + if (llvm::isa(newOp)) { + Value mask = newOp->getOperand(1); + // assert(I1 or TensorOf<[I1]>); + OpBuilder::InsertionGuard g(builder); + builder.setInsertionPoint(newOp); + Value splatCond = builder.create(mask.getLoc(), + mask.getType(), + loopCond); + Value newMask = builder.create(mask.getLoc(), + mask, + splatCond); + newOp->setOperand(1, newMask); + } // update mapping of results for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) { @@ -273,8 +278,6 @@ scf::ForOp LoopPipeliner::createNewForOp() { for (size_t i = 0; i < newLoopArgs.size(); ++i) assert(newLoopArgs[i]); - // llvm::errs() << "mapped load is:\n" << newLoopArgs[loadIdx] << "\n\n"; - // 1. signature of the new ForOp auto newForOp = builder.create(forOp.getLoc(), forOp.getLowerBound(), @@ -295,7 +298,7 @@ scf::ForOp LoopPipeliner::createNewForOp() { mapping.map(op.getResult(dstIdx), newOp->getResult(dstIdx)); } - // 3. replace loads with args + // 3. replace loads with block args (from prologue) for (size_t idx = 0; idx < loads.size(); ++idx) { Value load = loads[idx]; mapping.lookup(load).replaceAllUsesWith( @@ -418,16 +421,10 @@ struct PipelinePass : public TritonGPUPipelineBase { if (pipeliner.initialize().failed()) return; - // llvm::errs() << "find a loop to pipeline...\n"; pipeliner.emitPrologue(); - // llvm::errs() << "\nprologue emitted\n" - // << *forOp->getParentOp(); scf::ForOp newForOp = pipeliner.createNewForOp(); - // llvm::errs() << "new for created:\n" << newForOp << "\n" - // << "inside:\n" << *newForOp->getParentOp() << "\n"; - // replace the original loop for (unsigned i = 0; i < forOp->getNumResults(); ++i) forOp->getResult(i).replaceAllUsesWith(newForOp->getResult(i));