[Triton-MLIR][Backend] Fix many problems to get the pipeline working (#809)
1. Rewrite code generation of insert_slice_async. 2. Correct the wrong index passed to extract_slice in pipeline. 3. Add a prologue in pipeline to wait for dangling cp.asyncs. 4. Move scf to cf conversion inside TritonGPUToLLVM because we need to perform membar before scf to cf. It shouldn't be a technical limitation and could be improved by a more general membar analysis. 5. Use an attribute to memoize the shared memory size and support dynamic shared memory. 6. Prevent the combine pass to reorder insert_slice and extract_slice across async_wait Co-authored-by: Superjomn <yanchunwei@outlook.com>
This commit is contained in:
@@ -78,6 +78,9 @@ public:
|
||||
/// emit pipelined loads (before loop body)
|
||||
void emitPrologue();
|
||||
|
||||
/// emit pipelined loads (after loop body)
|
||||
void emitEpilogue();
|
||||
|
||||
/// create the new ForOp (add new args & insert prefetched ops)
|
||||
scf::ForOp createNewForOp();
|
||||
|
||||
@@ -362,6 +365,23 @@ void LoopPipeliner::emitPrologue() {
|
||||
loadStageBuffer[loadOp][numStages - 1], loopIterIdx, /*axis*/ 0);
|
||||
loadsExtract[loadOp] = extractSlice;
|
||||
}
|
||||
// bump up loopIterIdx, this is used for getting the correct slice for the
|
||||
// *next* iteration
|
||||
loopIterIdx = builder.create<arith::AddIOp>(
|
||||
loopIterIdx.getLoc(), loopIterIdx,
|
||||
builder.create<arith::ConstantIntOp>(loopIterIdx.getLoc(), 1, 32));
|
||||
}
|
||||
|
||||
void LoopPipeliner::emitEpilogue() {
|
||||
// If there's any outstanding async copies, we need to wait for them.
|
||||
// TODO(Keren): We may want to completely avoid the async copies in the last
|
||||
// few iterations by setting is_masked attribute to true. We don't want to use
|
||||
// the mask operand because it's a tensor but not a scalar.
|
||||
OpBuilder builder(forOp);
|
||||
OpBuilder::InsertionGuard g(builder);
|
||||
builder.setInsertionPointAfter(forOp);
|
||||
Operation *asyncWait =
|
||||
builder.create<triton::gpu::AsyncWaitOp>(forOp.getLoc(), 0);
|
||||
}
|
||||
|
||||
scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
@@ -581,6 +601,8 @@ struct PipelinePass : public TritonGPUPipelineBase<PipelinePass> {
|
||||
|
||||
scf::ForOp newForOp = pipeliner.createNewForOp();
|
||||
|
||||
pipeliner.emitEpilogue();
|
||||
|
||||
// replace the original loop
|
||||
for (unsigned i = 0; i < forOp->getNumResults(); ++i)
|
||||
forOp->getResult(i).replaceAllUsesWith(newForOp->getResult(i));
|
||||
|
Reference in New Issue
Block a user