More fixes to pipeline.cpp

This commit is contained in:
Yan Da
2022-05-26 13:14:41 +08:00
parent 71d1c10e19
commit c529b462f5

View File

@@ -159,18 +159,10 @@ LogicalResult LoopPipeliner::initialize() {
return success();
}
// llvm::errs() << allLoads.size() << " loads inside the loop\n"
// << loads.size() << " loads to be pipelined\n";
return failure();
}
void LoopPipeliner::emitPrologue() {
// llvm::errs() << "to pipeline...\n";
// for (Value load : loads)
// llvm::errs() << load << "\n";
// TODO: should we use rewriter here?
OpBuilder builder(forOp);
for (BlockArgument &arg : forOp.getRegionIterArgs()) {
OpOperand &operand = forOp.getOpOperandForRegionIterArg(arg);
@@ -214,7 +206,7 @@ void LoopPipeliner::emitPrologue() {
llvm_unreachable("This should be LoadOp");
} else
newOp = builder.clone(*op);
// llvm::errs() << "cloning " << *op << "\n";
for (unsigned opIdx = 0; opIdx < op->getNumOperands(); ++opIdx) {
auto it = valueMapping.find(op->getOperand(opIdx));
if (it != valueMapping.end()) {
@@ -224,7 +216,20 @@ void LoopPipeliner::emitPrologue() {
} // else, op at opIdx is a loop-invariant value
}
// TODO: if this is a load, we need to update the mask
// If newOp is a load/async_copy, rewrite its mask (operand 1) so that it is
// the original mask AND-ed with loopCond broadcast to the mask's type.
if (llvm::isa<triton::LoadOp, triton::gpu::CopyAsyncOp>(newOp)) {
Value mask = newOp->getOperand(1);
// NOTE(review): the mask is presumably I1 or TensorOf<[I1]>; this is not
// checked here (the assert below is disabled) -- confirm against the op
// definitions, in particular whether BroadcastOp is valid for a scalar mask.
// assert(I1 or TensorOf<[I1]>);
// Scope the insertion-point change below to this block.
OpBuilder::InsertionGuard g(builder);
// Emit the helper ops right before newOp, which consumes newMask.
builder.setInsertionPoint(newOp);
// Broadcast loopCond to the mask's type so the two can be combined.
Value splatCond = builder.create<triton::BroadcastOp>(mask.getLoc(),
mask.getType(),
loopCond);
// Combine the original mask with the broadcast loop condition.
Value newMask = builder.create<arith::AndIOp>(mask.getLoc(),
mask,
splatCond);
// Replace the original mask operand with the combined mask.
newOp->setOperand(1, newMask);
}
// update mapping of results
for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) {
@@ -273,8 +278,6 @@ scf::ForOp LoopPipeliner::createNewForOp() {
for (size_t i = 0; i < newLoopArgs.size(); ++i)
assert(newLoopArgs[i]);
// llvm::errs() << "mapped load is:\n" << newLoopArgs[loadIdx] << "\n\n";
// 1. signature of the new ForOp
auto newForOp = builder.create<scf::ForOp>(forOp.getLoc(),
forOp.getLowerBound(),
@@ -295,7 +298,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
mapping.map(op.getResult(dstIdx), newOp->getResult(dstIdx));
}
// 3. replace loads with args
// 3. replace loads with block args (from prologue)
for (size_t idx = 0; idx < loads.size(); ++idx) {
Value load = loads[idx];
mapping.lookup(load).replaceAllUsesWith(
@@ -418,16 +421,10 @@ struct PipelinePass : public TritonGPUPipelineBase<PipelinePass> {
if (pipeliner.initialize().failed())
return;
// llvm::errs() << "find a loop to pipeline...\n";
pipeliner.emitPrologue();
// llvm::errs() << "\nprologue emitted\n"
// << *forOp->getParentOp();
scf::ForOp newForOp = pipeliner.createNewForOp();
// llvm::errs() << "new for created:\n" << newForOp << "\n"
// << "inside:\n" << *newForOp->getParentOp() << "\n";
// replace the original loop
for (unsigned i = 0; i < forOp->getNumResults(); ++i)
forOp->getResult(i).replaceAllUsesWith(newForOp->getResult(i));