[Triton-MLIR] Replace triton.extract_slice with tensor.extract_slice and support more general tensor slicing (#837)

## Features

- Allow taking a block of tensor slice, as long as each dimension is
contiguous (unit stride).
- Fix some problems in `insert_slice_async`'s semantic.
- More general verification for ops that return shared layout encoding.

## Known Limitations

- `insert_slice_async` still uses the old semantic. May submit another
PR later to support similar semantic like `tensor.extract_slice`.
- No encoding verification for `tensor.extract_slice`.
- 3d tensor ops are broken.
- Strided accesses are not allowed.
- May cause a little performance slowdown since we are passing strides
as values but not constants (e.g., int).
It would be difficult to pass strides as attributes when we have control
flows. A block argument is possible to accept tensors with different
strides.
This commit is contained in:
Keren Zhou
2022-11-06 22:59:03 -08:00
committed by GitHub
parent a4ff0c362c
commit fdd59900f7
26 changed files with 507 additions and 339 deletions

View File

@@ -111,37 +111,41 @@ public:
// cvt(insert_slice(x), type2) -> insert_slice(cvt(x, type2))
auto insert_slice = dyn_cast<triton::gpu::InsertSliceAsyncOp>(arg);
if (insert_slice) {
auto newType = op->getResult(0).getType();
auto newType = op->getResult(0).getType().cast<RankedTensorType>();
// Ensure that the new insert_slice op is placed in the same place as the
// old insert_slice op. Otherwise, the new insert_slice op may be placed
// after the async_wait op, which is not allowed.
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(insert_slice);
auto new_arg = rewriter.create<triton::gpu::ConvertLayoutOp>(
auto newArg = rewriter.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), newType, insert_slice.dst());
rewriter.replaceOpWithNewOp<triton::gpu::InsertSliceAsyncOp>(
op, newType, insert_slice.src(), new_arg.getResult(),
op, newType, insert_slice.src(), newArg.getResult(),
insert_slice.index(), insert_slice.mask(), insert_slice.other(),
insert_slice.cache(), insert_slice.evict(), insert_slice.isVolatile(),
insert_slice.axis());
insert_slice.cache(), insert_slice.evict(),
insert_slice.isVolatile(), insert_slice.axis());
return mlir::success();
}
// cvt(extract_slice(x), type2) ->extract_slice(cvt(x, type2))
auto extract_slice = dyn_cast<triton::gpu::ExtractSliceOp>(arg);
// cvt(extract_slice(x), type2) -> extract_slice(cvt(x, type2))
auto extract_slice = dyn_cast<tensor::ExtractSliceOp>(arg);
if (extract_slice) {
auto origType = extract_slice.src().getType().cast<RankedTensorType>();
auto origType = extract_slice.source().getType().cast<RankedTensorType>();
auto newType = RankedTensorType::get(
origType.getShape(), origType.getElementType(),
op->getResult(0).getType().cast<RankedTensorType>().getEncoding());
auto resType = op->getResult(0).getType().cast<RankedTensorType>();
// Ensure that the new extract_slice op is placed in the same place as the
// old extract_slice op. Otherwise, the new extract_slice op may be placed
// after the async_wait op, which is not allowed.
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(extract_slice);
auto newType = RankedTensorType::get(
origType.getShape(), origType.getElementType(),
op->getResult(0).getType().cast<RankedTensorType>().getEncoding());
auto new_arg = rewriter.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), newType, extract_slice.src());
rewriter.replaceOpWithNewOp<triton::gpu::ExtractSliceOp>(
op, new_arg.getResult(), extract_slice.index(), extract_slice.axis());
auto newArg = rewriter.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), newType, extract_slice.source());
rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
op, resType, newArg.getResult(), extract_slice.offsets(),
extract_slice.sizes(), extract_slice.strides(),
extract_slice.static_offsets(), extract_slice.static_sizes(),
extract_slice.static_strides());
return mlir::success();
}
// cvt(type2, x)
@@ -198,7 +202,7 @@ static LogicalResult invertEncoding(Attribute targetEncoding, Operation *op,
inline bool expensive_to_remat(Operation *op) {
if (!op)
return true;
if (isa<triton::gpu::ExtractSliceOp, triton::gpu::AllocTensorOp,
if (isa<tensor::ExtractSliceOp, triton::gpu::AllocTensorOp,
triton::gpu::InsertSliceAsyncOp, triton::LoadOp, triton::StoreOp,
triton::AtomicRMWOp, triton::AtomicCASOp, triton::DotOp>(op))
return true;

View File

@@ -339,14 +339,20 @@ void LoopPipeliner::emitPrologue() {
builder.create<arith::ConstantIntOp>(iv.getLoc(), 1, 32));
} // for (int stage = 0; stage < numStages - 1; ++stage)
auto intAttr = [&](int64_t v) { return builder.getI64IntegerAttr(v); };
// async.wait & extract_slice
builder.create<triton::gpu::AsyncWaitOp>(loads[0].getLoc(),
loads.size() * (numStages - 2));
loopIterIdx = builder.create<arith::ConstantIntOp>(iv.getLoc(), 0, 32);
for (Value loadOp : loads) {
Value extractSlice = builder.create<triton::gpu::ExtractSliceOp>(
loadOp.getLoc(), loadsMapping[loadOp].getType(),
loadStageBuffer[loadOp][numStages - 1], loopIterIdx, /*axis*/ 0);
auto sliceType = loadsMapping[loadOp].getType().cast<RankedTensorType>();
Value extractSlice = builder.create<tensor::ExtractSliceOp>(
loadOp.getLoc(), sliceType, loadStageBuffer[loadOp][numStages - 1],
SmallVector<OpFoldResult>{intAttr(0), intAttr(0), intAttr(0)},
SmallVector<OpFoldResult>{intAttr(1), intAttr(sliceType.getShape()[0]),
intAttr(sliceType.getShape()[1])},
SmallVector<OpFoldResult>{intAttr(1), intAttr(1), intAttr(1)});
loadsExtract[loadOp] = extractSlice;
}
// bump up loopIterIdx, this is used for getting the correct slice for the
@@ -477,6 +483,10 @@ scf::ForOp LoopPipeliner::createNewForOp() {
Value extractSliceIndex = builder.create<arith::RemSIOp>(
nextIV.getLoc(), loopIterIdx,
builder.create<arith::ConstantIntOp>(nextIV.getLoc(), numStages, 32));
extractSliceIndex = builder.create<arith::IndexCastOp>(
extractSliceIndex.getLoc(), builder.getIndexType(), extractSliceIndex);
auto intAttr = [&](int64_t v) { return builder.getI64IntegerAttr(v); };
for (Operation *op : orderedDeps) {
Operation *nextOp = nullptr;
@@ -503,9 +513,14 @@ scf::ForOp LoopPipeliner::createNewForOp() {
nextMapping.lookupOrDefault(loadOp.other()), loadOp.cache(),
loadOp.evict(), loadOp.isVolatile(), /*axis*/ 0);
nextBuffers.push_back(insertAsyncOp);
nextOp = builder.create<triton::gpu::ExtractSliceOp>(
op->getLoc(), loadsMapping[loadOp].getType(), insertAsyncOp,
extractSliceIndex, /*axis*/ 0);
auto sliceType = loadsMapping[loadOp].getType().cast<RankedTensorType>();
nextOp = builder.create<tensor::ExtractSliceOp>(
op->getLoc(), sliceType, insertAsyncOp,
SmallVector<OpFoldResult>{extractSliceIndex, intAttr(0), intAttr(0)},
SmallVector<OpFoldResult>{intAttr(1),
intAttr(sliceType.getShape()[0]),
intAttr(sliceType.getShape()[1])},
SmallVector<OpFoldResult>{intAttr(1), intAttr(1), intAttr(1)});
extractSlices.push_back(nextOp->getResult(0));
} else
nextOp = builder.clone(*op, nextMapping);