[Triton-MLIR][Backend] Fix many problems to get the pipeline working (#809)
1. Rewrite code generation of insert_slice_async. 2. Correct the wrong index passed to extract_slice in pipeline. 3. Add a prologue in pipeline to wait for dangling cp.asyncs. 4. Move scf to cf conversion inside TritonGPUToLLVM because we need to perform membar before scf to cf. It shouldn't be a technical limitation and could be improved by a more general membar analysis. 5. Use an attribute to memoize the shared memory size and support dynamic shared memory. 6. Prevent the combine pass to reorder insert_slice and extract_slice across async_wait Co-authored-by: Superjomn <yanchunwei@outlook.com>
This commit is contained in:
@@ -103,16 +103,21 @@ public:
|
||||
if (!arg)
|
||||
return mlir::failure();
|
||||
// cvt(alloc_tensor(x), type2) -> alloc_tensor(x, type2)
|
||||
// cvt(insert_slice(x), type2) -> extract_slice(cvt(x, type2))
|
||||
auto alloc_tensor = dyn_cast<triton::gpu::AllocTensorOp>(arg);
|
||||
if (alloc_tensor) {
|
||||
rewriter.replaceOpWithNewOp<triton::gpu::AllocTensorOp>(
|
||||
op, op->getResult(0).getType());
|
||||
return mlir::success();
|
||||
}
|
||||
// cvt(insert_slice(x), type2) -> insert_slice(cvt(x, type2))
|
||||
auto insert_slice = dyn_cast<triton::gpu::InsertSliceAsyncOp>(arg);
|
||||
if (insert_slice) {
|
||||
auto newType = op->getResult(0).getType();
|
||||
// Ensure that the new insert_slice op is placed in the same place as the
|
||||
// old insert_slice op. Otherwise, the new insert_slice op may be placed
|
||||
// after the async_wait op, which is not allowed.
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPoint(insert_slice);
|
||||
auto new_arg = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||
op->getLoc(), newType, insert_slice.dst());
|
||||
rewriter.replaceOpWithNewOp<triton::gpu::InsertSliceAsyncOp>(
|
||||
@@ -126,6 +131,11 @@ public:
|
||||
auto extract_slice = dyn_cast<triton::gpu::ExtractSliceOp>(arg);
|
||||
if (extract_slice) {
|
||||
auto origType = extract_slice.src().getType().cast<RankedTensorType>();
|
||||
// Ensure that the new extract_slice op is placed in the same place as the
|
||||
// old extract_slice op. Otherwise, the new extract_slice op may be placed
|
||||
// after the async_wait op, which is not allowed.
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPoint(extract_slice);
|
||||
auto newType = RankedTensorType::get(
|
||||
origType.getShape(), origType.getElementType(),
|
||||
op->getResult(0).getType().cast<RankedTensorType>().getEncoding());
|
||||
|
@@ -78,6 +78,9 @@ public:
|
||||
/// emit pipelined loads (before loop body)
|
||||
void emitPrologue();
|
||||
|
||||
/// emit pipelined loads (after loop body)
|
||||
void emitEpilogue();
|
||||
|
||||
/// create the new ForOp (add new args & insert prefetched ops)
|
||||
scf::ForOp createNewForOp();
|
||||
|
||||
@@ -362,6 +365,23 @@ void LoopPipeliner::emitPrologue() {
|
||||
loadStageBuffer[loadOp][numStages - 1], loopIterIdx, /*axis*/ 0);
|
||||
loadsExtract[loadOp] = extractSlice;
|
||||
}
|
||||
// bump up loopIterIdx, this is used for getting the correct slice for the
|
||||
// *next* iteration
|
||||
loopIterIdx = builder.create<arith::AddIOp>(
|
||||
loopIterIdx.getLoc(), loopIterIdx,
|
||||
builder.create<arith::ConstantIntOp>(loopIterIdx.getLoc(), 1, 32));
|
||||
}
|
||||
|
||||
void LoopPipeliner::emitEpilogue() {
|
||||
// If there's any outstanding async copies, we need to wait for them.
|
||||
// TODO(Keren): We may want to completely avoid the async copies in the last
|
||||
// few iterations by setting is_masked attribute to true. We don't want to use
|
||||
// the mask operand because it's a tensor but not a scalar.
|
||||
OpBuilder builder(forOp);
|
||||
OpBuilder::InsertionGuard g(builder);
|
||||
builder.setInsertionPointAfter(forOp);
|
||||
Operation *asyncWait =
|
||||
builder.create<triton::gpu::AsyncWaitOp>(forOp.getLoc(), 0);
|
||||
}
|
||||
|
||||
scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
@@ -581,6 +601,8 @@ struct PipelinePass : public TritonGPUPipelineBase<PipelinePass> {
|
||||
|
||||
scf::ForOp newForOp = pipeliner.createNewForOp();
|
||||
|
||||
pipeliner.emitEpilogue();
|
||||
|
||||
// replace the original loop
|
||||
for (unsigned i = 0; i < forOp->getNumResults(); ++i)
|
||||
forOp->getResult(i).replaceAllUsesWith(newForOp->getResult(i));
|
||||
|
Reference in New Issue
Block a user