More fixes to pipeline.cpp
This commit is contained in:
@@ -159,18 +159,10 @@ LogicalResult LoopPipeliner::initialize() {
|
||||
return success();
|
||||
}
|
||||
|
||||
// llvm::errs() << allLoads.size() << " loads inside the loop\n"
|
||||
// << loads.size() << " loads to be pipelined\n";
|
||||
|
||||
return failure();
|
||||
}
|
||||
|
||||
void LoopPipeliner::emitPrologue() {
|
||||
// llvm::errs() << "to pipeline...\n";
|
||||
// for (Value load : loads)
|
||||
// llvm::errs() << load << "\n";
|
||||
|
||||
// TODO: should we use rewriter here?
|
||||
OpBuilder builder(forOp);
|
||||
for (BlockArgument &arg : forOp.getRegionIterArgs()) {
|
||||
OpOperand &operand = forOp.getOpOperandForRegionIterArg(arg);
|
||||
@@ -214,7 +206,7 @@ void LoopPipeliner::emitPrologue() {
|
||||
llvm_unreachable("This should be LoadOp");
|
||||
} else
|
||||
newOp = builder.clone(*op);
|
||||
// llvm::errs() << "cloning " << *op << "\n";
|
||||
|
||||
for (unsigned opIdx = 0; opIdx < op->getNumOperands(); ++opIdx) {
|
||||
auto it = valueMapping.find(op->getOperand(opIdx));
|
||||
if (it != valueMapping.end()) {
|
||||
@@ -224,7 +216,20 @@ void LoopPipeliner::emitPrologue() {
|
||||
} // else, op at opIdx is a loop-invariant value
|
||||
}
|
||||
|
||||
// TODO: if this is a load, we need to update the mask
|
||||
// If this is a load/async_copy, we need to update the mask
|
||||
if (llvm::isa<triton::LoadOp, triton::gpu::CopyAsyncOp>(newOp)) {
|
||||
Value mask = newOp->getOperand(1);
|
||||
// TODO: assert that the mask is I1 or TensorOf<[I1]>
|
||||
OpBuilder::InsertionGuard g(builder);
|
||||
builder.setInsertionPoint(newOp);
|
||||
Value splatCond = builder.create<triton::BroadcastOp>(mask.getLoc(),
|
||||
mask.getType(),
|
||||
loopCond);
|
||||
Value newMask = builder.create<arith::AndIOp>(mask.getLoc(),
|
||||
mask,
|
||||
splatCond);
|
||||
newOp->setOperand(1, newMask);
|
||||
}
|
||||
|
||||
// update mapping of results
|
||||
for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) {
|
||||
@@ -273,8 +278,6 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
for (size_t i = 0; i < newLoopArgs.size(); ++i)
|
||||
assert(newLoopArgs[i]);
|
||||
|
||||
// llvm::errs() << "mapped load is:\n" << newLoopArgs[loadIdx] << "\n\n";
|
||||
|
||||
// 1. signature of the new ForOp
|
||||
auto newForOp = builder.create<scf::ForOp>(forOp.getLoc(),
|
||||
forOp.getLowerBound(),
|
||||
@@ -295,7 +298,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
mapping.map(op.getResult(dstIdx), newOp->getResult(dstIdx));
|
||||
}
|
||||
|
||||
// 3. replace loads with args
|
||||
// 3. replace loads with block args (from prologue)
|
||||
for (size_t idx = 0; idx < loads.size(); ++idx) {
|
||||
Value load = loads[idx];
|
||||
mapping.lookup(load).replaceAllUsesWith(
|
||||
@@ -418,16 +421,10 @@ struct PipelinePass : public TritonGPUPipelineBase<PipelinePass> {
|
||||
if (pipeliner.initialize().failed())
|
||||
return;
|
||||
|
||||
// llvm::errs() << "find a loop to pipeline...\n";
|
||||
pipeliner.emitPrologue();
|
||||
// llvm::errs() << "\nprologue emitted\n"
|
||||
// << *forOp->getParentOp();
|
||||
|
||||
scf::ForOp newForOp = pipeliner.createNewForOp();
|
||||
|
||||
// llvm::errs() << "new for created:\n" << newForOp << "\n"
|
||||
// << "inside:\n" << *newForOp->getParentOp() << "\n";
|
||||
|
||||
// replace the original loop
|
||||
for (unsigned i = 0; i < forOp->getNumResults(); ++i)
|
||||
forOp->getResult(i).replaceAllUsesWith(newForOp->getResult(i));
|
||||
|
Reference in New Issue
Block a user