[Triton-MLIR][Backend] Fix many problems to get the pipeline working (#809)

1. Rewrite code generation of insert_slice_async.
2. Correct the wrong index passed to extract_slice in pipeline.
3. Add a prologue in pipeline to wait for dangling cp.asyncs.  
4. Move scf to cf conversion inside TritonGPUToLLVM because we need to
perform membar before scf to cf. It shouldn't be a technical limitation
and could be improved by a more general membar analysis.
5. Use an attribute to memoize the shared memory size and support
dynamic shared memory.
6. Prevent the combine pass to reorder insert_slice and extract_slice
across async_wait

Co-authored-by: Superjomn <yanchunwei@outlook.com>
This commit is contained in:
Keren Zhou
2022-10-27 22:09:06 -07:00
committed by GitHub
parent 42db3538e4
commit 3b80801dff
10 changed files with 122 additions and 70 deletions

View File

@@ -78,6 +78,9 @@ public:
/// emit pipelined loads (before loop body)
void emitPrologue();
/// emit pipelined loads (after loop body)
void emitEpilogue();
/// create the new ForOp (add new args & insert prefetched ops)
scf::ForOp createNewForOp();
@@ -362,6 +365,23 @@ void LoopPipeliner::emitPrologue() {
loadStageBuffer[loadOp][numStages - 1], loopIterIdx, /*axis*/ 0);
loadsExtract[loadOp] = extractSlice;
}
// bump up loopIterIdx, this is used for getting the correct slice for the
// *next* iteration
loopIterIdx = builder.create<arith::AddIOp>(
loopIterIdx.getLoc(), loopIterIdx,
builder.create<arith::ConstantIntOp>(loopIterIdx.getLoc(), 1, 32));
}
void LoopPipeliner::emitEpilogue() {
// If there's any outstanding async copies, we need to wait for them.
// TODO(Keren): We may want to completely avoid the async copies in the last
// few iterations by setting is_masked attribute to true. We don't want to use
// the mask operand because it's a tensor but not a scalar.
OpBuilder builder(forOp);
OpBuilder::InsertionGuard g(builder);
builder.setInsertionPointAfter(forOp);
Operation *asyncWait =
builder.create<triton::gpu::AsyncWaitOp>(forOp.getLoc(), 0);
}
scf::ForOp LoopPipeliner::createNewForOp() {
@@ -581,6 +601,8 @@ struct PipelinePass : public TritonGPUPipelineBase<PipelinePass> {
scf::ForOp newForOp = pipeliner.createNewForOp();
pipeliner.emitEpilogue();
// replace the original loop
for (unsigned i = 0; i < forOp->getNumResults(); ++i)
forOp->getResult(i).replaceAllUsesWith(newForOp->getResult(i));