more fixes on pipeline.cpp
This commit is contained in:
@@ -159,18 +159,10 @@ LogicalResult LoopPipeliner::initialize() {
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
// llvm::errs() << allLoads.size() << " loads inside the loop\n"
|
|
||||||
// << loads.size() << " loads to be pipelined\n";
|
|
||||||
|
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
void LoopPipeliner::emitPrologue() {
|
void LoopPipeliner::emitPrologue() {
|
||||||
// llvm::errs() << "to pipeline...\n";
|
|
||||||
// for (Value load : loads)
|
|
||||||
// llvm::errs() << load << "\n";
|
|
||||||
|
|
||||||
// TODO: should we use rewriter here?
|
|
||||||
OpBuilder builder(forOp);
|
OpBuilder builder(forOp);
|
||||||
for (BlockArgument &arg : forOp.getRegionIterArgs()) {
|
for (BlockArgument &arg : forOp.getRegionIterArgs()) {
|
||||||
OpOperand &operand = forOp.getOpOperandForRegionIterArg(arg);
|
OpOperand &operand = forOp.getOpOperandForRegionIterArg(arg);
|
||||||
@@ -214,7 +206,7 @@ void LoopPipeliner::emitPrologue() {
|
|||||||
llvm_unreachable("This should be LoadOp");
|
llvm_unreachable("This should be LoadOp");
|
||||||
} else
|
} else
|
||||||
newOp = builder.clone(*op);
|
newOp = builder.clone(*op);
|
||||||
// llvm::errs() << "cloning " << *op << "\n";
|
|
||||||
for (unsigned opIdx = 0; opIdx < op->getNumOperands(); ++opIdx) {
|
for (unsigned opIdx = 0; opIdx < op->getNumOperands(); ++opIdx) {
|
||||||
auto it = valueMapping.find(op->getOperand(opIdx));
|
auto it = valueMapping.find(op->getOperand(opIdx));
|
||||||
if (it != valueMapping.end()) {
|
if (it != valueMapping.end()) {
|
||||||
@@ -224,7 +216,20 @@ void LoopPipeliner::emitPrologue() {
|
|||||||
} // else, op at opIdx is a loop-invariant value
|
} // else, op at opIdx is a loop-invariant value
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: if this is a load, we need to update the mask
|
// If this is a load/async_copy, we need to update the mask
|
||||||
|
if (llvm::isa<triton::LoadOp, triton::gpu::CopyAsyncOp>(newOp)) {
|
||||||
|
Value mask = newOp->getOperand(1);
|
||||||
|
// assert(I1 or TensorOf<[I1]>);
|
||||||
|
OpBuilder::InsertionGuard g(builder);
|
||||||
|
builder.setInsertionPoint(newOp);
|
||||||
|
Value splatCond = builder.create<triton::BroadcastOp>(mask.getLoc(),
|
||||||
|
mask.getType(),
|
||||||
|
loopCond);
|
||||||
|
Value newMask = builder.create<arith::AndIOp>(mask.getLoc(),
|
||||||
|
mask,
|
||||||
|
splatCond);
|
||||||
|
newOp->setOperand(1, newMask);
|
||||||
|
}
|
||||||
|
|
||||||
// update mapping of results
|
// update mapping of results
|
||||||
for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) {
|
for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) {
|
||||||
@@ -273,8 +278,6 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
|||||||
for (size_t i = 0; i < newLoopArgs.size(); ++i)
|
for (size_t i = 0; i < newLoopArgs.size(); ++i)
|
||||||
assert(newLoopArgs[i]);
|
assert(newLoopArgs[i]);
|
||||||
|
|
||||||
// llvm::errs() << "mapped load is:\n" << newLoopArgs[loadIdx] << "\n\n";
|
|
||||||
|
|
||||||
// 1. signature of the new ForOp
|
// 1. signature of the new ForOp
|
||||||
auto newForOp = builder.create<scf::ForOp>(forOp.getLoc(),
|
auto newForOp = builder.create<scf::ForOp>(forOp.getLoc(),
|
||||||
forOp.getLowerBound(),
|
forOp.getLowerBound(),
|
||||||
@@ -295,7 +298,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
|||||||
mapping.map(op.getResult(dstIdx), newOp->getResult(dstIdx));
|
mapping.map(op.getResult(dstIdx), newOp->getResult(dstIdx));
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. replace loads with args
|
// 3. replace loads with block args (from prologue)
|
||||||
for (size_t idx = 0; idx < loads.size(); ++idx) {
|
for (size_t idx = 0; idx < loads.size(); ++idx) {
|
||||||
Value load = loads[idx];
|
Value load = loads[idx];
|
||||||
mapping.lookup(load).replaceAllUsesWith(
|
mapping.lookup(load).replaceAllUsesWith(
|
||||||
@@ -418,16 +421,10 @@ struct PipelinePass : public TritonGPUPipelineBase<PipelinePass> {
|
|||||||
if (pipeliner.initialize().failed())
|
if (pipeliner.initialize().failed())
|
||||||
return;
|
return;
|
||||||
|
|
||||||
// llvm::errs() << "find a loop to pipeline...\n";
|
|
||||||
pipeliner.emitPrologue();
|
pipeliner.emitPrologue();
|
||||||
// llvm::errs() << "\nprologue emitted\n"
|
|
||||||
// << *forOp->getParentOp();
|
|
||||||
|
|
||||||
scf::ForOp newForOp = pipeliner.createNewForOp();
|
scf::ForOp newForOp = pipeliner.createNewForOp();
|
||||||
|
|
||||||
// llvm::errs() << "new for created:\n" << newForOp << "\n"
|
|
||||||
// << "inside:\n" << *newForOp->getParentOp() << "\n";
|
|
||||||
|
|
||||||
// replace the original loop
|
// replace the original loop
|
||||||
for (unsigned i = 0; i < forOp->getNumResults(); ++i)
|
for (unsigned i = 0; i < forOp->getNumResults(); ++i)
|
||||||
forOp->getResult(i).replaceAllUsesWith(newForOp->getResult(i));
|
forOp->getResult(i).replaceAllUsesWith(newForOp->getResult(i));
|
||||||
|
Reference in New Issue
Block a user