[Triton-MLIR] Replace triton.extract_slice with tensor.extract_slice and support more general tensor slicing (#837)
## Features

- Allow taking a block of a tensor slice, as long as each dimension is contiguous (unit stride).
- Fix some problems in `insert_slice_async`'s semantics.
- More general verification for ops that return a shared layout encoding.

## Known Limitations

- `insert_slice_async` still uses the old semantics. A later PR may bring it in line with `tensor.extract_slice`.
- No encoding verification for `tensor.extract_slice`.
- 3D tensor ops are broken.
- Strided accesses are not allowed (see the sketch after this list).
- May cause a small performance slowdown, since strides are passed as values rather than constants (e.g., `int`). It would be difficult to pass strides as attributes when control flow is involved; a block argument, by contrast, can accept tensors with different strides.
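For illustration, a minimal sketch of what the unit-stride restriction amounts to. This is not code from this diff; `hasUnitStrides` is a hypothetical helper name, and the check simply requires every stride of a `tensor.extract_slice` to fold to the constant 1:

```cpp
#include "mlir/Dialect/Tensor/IR/Tensor.h"

// Hypothetical sketch of the "contiguous (unit stride)" restriction:
// every stride of the slice must be statically known to equal 1.
bool hasUnitStrides(mlir::tensor::ExtractSliceOp op) {
  for (mlir::OpFoldResult stride : op.getMixedStrides()) {
    auto attr = stride.dyn_cast<mlir::Attribute>();
    // A stride carried as an SSA value cannot be statically proven to be 1.
    if (!attr || attr.cast<mlir::IntegerAttr>().getInt() != 1)
      return false;
  }
  return true;
}
```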
```diff
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(TritonGPUIR
   Dialect.cpp
+  Traits.cpp
 
   DEPENDS
   TritonGPUTableGen
```
```diff
@@ -474,7 +474,7 @@ void SharedEncodingAttr::print(AsmPrinter &printer) const {
 
 ParseResult parseInsertSliceAsyncOp(OpAsmParser &parser,
                                     OperationState &result) {
-  SmallVector<OpAsmParser::OperandType, 4> allOperands;
+  SmallVector<OpAsmParser::OperandType, 8> allOperands;
   Type srcType, dstType;
   SMLoc allOperandLoc = parser.getCurrentLocation();
   if (parser.parseOperandList(allOperands) ||
```
```diff
@@ -489,14 +489,27 @@ ParseResult parseInsertSliceAsyncOp(OpAsmParser &parser,
   operandTypes.push_back(dstType); // dst
   operandTypes.push_back(
       IntegerType::get(parser.getBuilder().getContext(), 32)); // index
-  if (allOperands.size() >= 4)
-    operandTypes.push_back(triton::getI1SameShape(srcType)); // mask
-  if (allOperands.size() >= 5)
-    operandTypes.push_back(triton::getPointeeType(srcType)); // other
+
+  int hasMask = 0, hasOther = 0;
+  if (allOperands.size() >= 4) {
+    operandTypes.push_back(triton::getI1SameShape(srcType)); // mask
+    hasMask = 1;
+  }
+  if (allOperands.size() >= 5) {
+    operandTypes.push_back(triton::getPointeeType(srcType)); // other
+    hasOther = 1;
+  }
 
   if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
                              result.operands))
     return failure();
 
+  // Deduce operand_segment_sizes from the number of the operands.
+  auto operand_segment_sizesAttrName =
+      InsertSliceAsyncOp::operand_segment_sizesAttrName(result.name);
+  result.addAttribute(
+      operand_segment_sizesAttrName,
+      parser.getBuilder().getI32VectorAttr({1, 1, 1, hasMask, hasOther}));
   return success();
 }
```
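The deduced segment-size attribute is what lets MLIR's `AttrSizedOperandSegments` trait split the flat operand list back into named operands. A minimal sketch of the deduction, assuming an in-scope `mlir::Builder` (the helper name is illustrative, not part of this diff):

```cpp
#include "mlir/IR/Builders.h"

// Illustrative helper: insert_slice_async's operands are
// (src, dst, index, mask?, other?); the first three are always present,
// the last two only when the parser saw 4 or 5 operands.
mlir::DenseIntElementsAttr makeSegmentSizes(mlir::Builder &builder,
                                            bool hasMask, bool hasOther) {
  return builder.getI32VectorAttr(
      {1, 1, 1, hasMask ? 1 : 0, hasOther ? 1 : 0});
}
```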
```diff
@@ -504,39 +517,16 @@ void printInsertSliceAsyncOp(OpAsmPrinter &printer,
                              InsertSliceAsyncOp insertSliceAsyncOp) {
   printer << " ";
   printer << insertSliceAsyncOp.getOperation()->getOperands();
-  printer.printOptionalAttrDict(insertSliceAsyncOp->getAttrs(),
-                                /*elidedAttrs=*/{});
+  // "operand_segment_sizes" can be deduced, so we don't print it.
+  printer.printOptionalAttrDict(
+      insertSliceAsyncOp->getAttrs(),
+      {insertSliceAsyncOp.operand_segment_sizesAttrName()});
   printer << " : ";
   printer.printStrippedAttrOrType(insertSliceAsyncOp.src().getType());
   printer << " -> ";
   printer.printStrippedAttrOrType(insertSliceAsyncOp.result().getType());
 }
 
-//===----------------------------------------------------------------------===//
-// ExtractSliceOp
-//===----------------------------------------------------------------------===//
-
-mlir::LogicalResult ExtractSliceOp::inferReturnTypes(
-    ::mlir::MLIRContext *context, llvm::Optional<::mlir::Location> location,
-    ::mlir::ValueRange operands, mlir::DictionaryAttr attributes,
-    ::mlir::RegionRange regions,
-    llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
-  auto srcType = operands[0].getType().cast<RankedTensorType>();
-  auto encoding = srcType.getEncoding();
-  auto srcShape = srcType.getShape();
-  auto axis = attributes.get("axis").cast<IntegerAttr>().getInt();
-  if (axis < 0 || (size_t)axis > srcShape.size())
-    return failure();
-  SmallVector<int64_t, 4> dstShape;
-  for (size_t i = 0; i < srcShape.size(); i++)
-    if (i != (size_t)axis)
-      dstShape.push_back(srcShape[i]);
-  auto returnType =
-      RankedTensorType::get(dstShape, srcType.getElementType(), encoding);
-  inferredReturnTypes.assign({returnType});
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // DotOperand Encoding
 //===----------------------------------------------------------------------===//
```
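The deleted axis-based `inferReturnTypes` needs no direct replacement because `tensor.extract_slice` derives its result type from offsets, sizes, and strides. A hedged sketch of the equivalent query, assuming MLIR's `tensor::ExtractSliceOp::inferResultType` helper and illustrative `numStages`, `M`, `N`, `elemTy`, `sharedEnc` parameters:

```cpp
#include "mlir/Dialect/Tensor/IR/Tensor.h"

// Sketch: infer the type of slicing one stage out of a 3-D staging buffer.
mlir::RankedTensorType inferStageSliceType(int64_t numStages, int64_t M,
                                           int64_t N, mlir::Type elemTy,
                                           mlir::Attribute sharedEnc) {
  auto bufferTy =
      mlir::RankedTensorType::get({numStages, M, N}, elemTy, sharedEnc);
  // The inferred type keeps the unit-extent leading dim; the pipeliner
  // instead requests the rank-reduced 2-D type explicitly when building.
  return mlir::tensor::ExtractSliceOp::inferResultType(
      bufferTy, /*staticOffsets=*/{0, 0, 0}, /*staticSizes=*/{1, M, N},
      /*staticStrides=*/{1, 1, 1});
}
```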
```diff
@@ -631,32 +621,6 @@ void TritonGPUDialect::initialize() {
   addInterfaces<TritonGPUInferLayoutInterface>();
 }
 
-//===----------------------------------------------------------------------===//
-// Verification
-//===----------------------------------------------------------------------===//
-
-static LogicalResult verify(InsertSliceAsyncOp op) {
-  if (!isSharedEncoding(op.getResult())) {
-    return op.emitOpError(
-        "insert_slice_async should return a shared memory tensor");
-  }
-  return success();
-}
-
-static LogicalResult verify(ExtractSliceOp op) {
-  if (!isSharedEncoding(op.getResult())) {
-    return op.emitOpError("extract_slice should return a shared memory tensor");
-  }
-  return success();
-}
-
-static LogicalResult verify(AllocTensorOp op) {
-  if (!isSharedEncoding(op.getResult())) {
-    return op.emitOpError("alloc_tensor should return a shared memory tensor");
-  }
-  return success();
-}
-
 #define GET_OP_CLASSES
 #include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
```
lib/Dialect/TritonGPU/IR/Traits.cpp (new file, 14 lines)
```diff
@@ -0,0 +1,14 @@
+#include "triton/Dialect/TritonGPU/IR/Traits.h"
+#include "triton/Analysis/Utility.h"
+
+mlir::LogicalResult
+mlir::OpTrait::impl::verifyResultsAreSharedEncoding(Operation *op) {
+  if (failed(verifyAtLeastNResults(op, 1)))
+    return failure();
+
+  for (auto result : op->getResults())
+    if (!isSharedEncoding(result))
+      return op->emitOpError() << "requires all results to be shared encoding";
+
+  return success();
+};
```
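The new trait leans on `isSharedEncoding` from `triton/Analysis/Utility.h`. A minimal sketch of what that predicate presumably checks; the body is an assumption for illustration, not code from this diff, hence the distinct name:

```cpp
#include "mlir/IR/Value.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

// Assumed shape of the predicate: a value "is shared encoding" when its
// type is a ranked tensor whose encoding is a triton_gpu shared layout.
bool isSharedEncodingSketch(mlir::Value value) {
  auto type = value.getType().dyn_cast<mlir::RankedTensorType>();
  return type && type.getEncoding() &&
         type.getEncoding().isa<mlir::triton::gpu::SharedEncodingAttr>();
}
```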
```diff
@@ -111,37 +111,41 @@ public:
     // cvt(insert_slice(x), type2) -> insert_slice(cvt(x, type2))
     auto insert_slice = dyn_cast<triton::gpu::InsertSliceAsyncOp>(arg);
     if (insert_slice) {
-      auto newType = op->getResult(0).getType();
+      auto newType = op->getResult(0).getType().cast<RankedTensorType>();
       // Ensure that the new insert_slice op is placed in the same place as the
       // old insert_slice op. Otherwise, the new insert_slice op may be placed
       // after the async_wait op, which is not allowed.
       OpBuilder::InsertionGuard guard(rewriter);
       rewriter.setInsertionPoint(insert_slice);
-      auto new_arg = rewriter.create<triton::gpu::ConvertLayoutOp>(
+      auto newArg = rewriter.create<triton::gpu::ConvertLayoutOp>(
           op->getLoc(), newType, insert_slice.dst());
       rewriter.replaceOpWithNewOp<triton::gpu::InsertSliceAsyncOp>(
-          op, newType, insert_slice.src(), new_arg.getResult(),
+          op, newType, insert_slice.src(), newArg.getResult(),
           insert_slice.index(), insert_slice.mask(), insert_slice.other(),
-          insert_slice.cache(), insert_slice.evict(), insert_slice.isVolatile(),
-          insert_slice.axis());
+          insert_slice.cache(), insert_slice.evict(),
+          insert_slice.isVolatile(), insert_slice.axis());
       return mlir::success();
     }
-    // cvt(extract_slice(x), type2) ->extract_slice(cvt(x, type2))
-    auto extract_slice = dyn_cast<triton::gpu::ExtractSliceOp>(arg);
+    // cvt(extract_slice(x), type2) -> extract_slice(cvt(x, type2))
+    auto extract_slice = dyn_cast<tensor::ExtractSliceOp>(arg);
     if (extract_slice) {
-      auto origType = extract_slice.src().getType().cast<RankedTensorType>();
-      auto newType = RankedTensorType::get(
-          origType.getShape(), origType.getElementType(),
-          op->getResult(0).getType().cast<RankedTensorType>().getEncoding());
+      auto origType = extract_slice.source().getType().cast<RankedTensorType>();
+      auto resType = op->getResult(0).getType().cast<RankedTensorType>();
       // Ensure that the new extract_slice op is placed in the same place as the
       // old extract_slice op. Otherwise, the new extract_slice op may be placed
       // after the async_wait op, which is not allowed.
       OpBuilder::InsertionGuard guard(rewriter);
       rewriter.setInsertionPoint(extract_slice);
-      auto new_arg = rewriter.create<triton::gpu::ConvertLayoutOp>(
-          op->getLoc(), newType, extract_slice.src());
-      rewriter.replaceOpWithNewOp<triton::gpu::ExtractSliceOp>(
-          op, new_arg.getResult(), extract_slice.index(), extract_slice.axis());
+      auto newType = RankedTensorType::get(
+          origType.getShape(), origType.getElementType(),
+          op->getResult(0).getType().cast<RankedTensorType>().getEncoding());
+      auto newArg = rewriter.create<triton::gpu::ConvertLayoutOp>(
+          op->getLoc(), newType, extract_slice.source());
+      rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
+          op, resType, newArg.getResult(), extract_slice.offsets(),
+          extract_slice.sizes(), extract_slice.strides(),
+          extract_slice.static_offsets(), extract_slice.static_sizes(),
+          extract_slice.static_strides());
       return mlir::success();
     }
     // cvt(type2, x)
```
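Both rewrites hinge on the insertion-point comments above. A minimal sketch of the RAII pattern, with illustrative names:

```cpp
#include "mlir/IR/PatternMatch.h"

// InsertionGuard saves the rewriter's insertion point and restores it on
// scope exit, so replacement ops are created where the old slice op lived
// rather than after a later triton_gpu.async_wait.
void rewriteAtOriginalPoint(mlir::PatternRewriter &rewriter,
                            mlir::Operation *sliceOp) {
  mlir::OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(sliceOp);
  // ... build the convert_layout and the new slice op here ...
} // guard's destructor restores the previous insertion point
```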
```diff
@@ -198,7 +202,7 @@ static LogicalResult invertEncoding(Attribute targetEncoding, Operation *op,
 inline bool expensive_to_remat(Operation *op) {
   if (!op)
     return true;
-  if (isa<triton::gpu::ExtractSliceOp, triton::gpu::AllocTensorOp,
+  if (isa<tensor::ExtractSliceOp, triton::gpu::AllocTensorOp,
       triton::gpu::InsertSliceAsyncOp, triton::LoadOp, triton::StoreOp,
       triton::AtomicRMWOp, triton::AtomicCASOp, triton::DotOp>(op))
     return true;
```
```diff
@@ -339,14 +339,20 @@ void LoopPipeliner::emitPrologue() {
         builder.create<arith::ConstantIntOp>(iv.getLoc(), 1, 32));
   } // for (int stage = 0; stage < numStages - 1; ++stage)
 
+  auto intAttr = [&](int64_t v) { return builder.getI64IntegerAttr(v); };
+
   // async.wait & extract_slice
   builder.create<triton::gpu::AsyncWaitOp>(loads[0].getLoc(),
                                            loads.size() * (numStages - 2));
   loopIterIdx = builder.create<arith::ConstantIntOp>(iv.getLoc(), 0, 32);
   for (Value loadOp : loads) {
-    Value extractSlice = builder.create<triton::gpu::ExtractSliceOp>(
-        loadOp.getLoc(), loadsMapping[loadOp].getType(),
-        loadStageBuffer[loadOp][numStages - 1], loopIterIdx, /*axis*/ 0);
+    auto sliceType = loadsMapping[loadOp].getType().cast<RankedTensorType>();
+    Value extractSlice = builder.create<tensor::ExtractSliceOp>(
+        loadOp.getLoc(), sliceType, loadStageBuffer[loadOp][numStages - 1],
+        SmallVector<OpFoldResult>{intAttr(0), intAttr(0), intAttr(0)},
+        SmallVector<OpFoldResult>{intAttr(1), intAttr(sliceType.getShape()[0]),
+                                  intAttr(sliceType.getShape()[1])},
+        SmallVector<OpFoldResult>{intAttr(1), intAttr(1), intAttr(1)});
     loadsExtract[loadOp] = extractSlice;
   }
   // bump up loopIterIdx, this is used for getting the correct slice for the
```
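Worth noting for the `intAttr` lambda: `tensor.extract_slice` takes offsets, sizes, and strides as `OpFoldResult`, which holds either an `Attribute` (a static value, folded into `static_offsets`/`static_sizes`/`static_strides`) or a `Value` (a dynamic one). A sketch, assuming an in-scope `builder` and a dynamic `extractSliceIndex`:

```cpp
// Static operands become attributes; dynamic ones stay SSA values.
auto intAttr = [&](int64_t v) -> mlir::OpFoldResult {
  return builder.getI64IntegerAttr(v);
};
mlir::OpFoldResult staticOffset = intAttr(0);          // folded into static_offsets
mlir::OpFoldResult dynamicOffset = extractSliceIndex;  // kept as an operand
```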
```diff
@@ -477,6 +483,10 @@ scf::ForOp LoopPipeliner::createNewForOp() {
     Value extractSliceIndex = builder.create<arith::RemSIOp>(
         nextIV.getLoc(), loopIterIdx,
         builder.create<arith::ConstantIntOp>(nextIV.getLoc(), numStages, 32));
+    extractSliceIndex = builder.create<arith::IndexCastOp>(
+        extractSliceIndex.getLoc(), builder.getIndexType(), extractSliceIndex);
+
+    auto intAttr = [&](int64_t v) { return builder.getI64IntegerAttr(v); };
 
     for (Operation *op : orderedDeps) {
       Operation *nextOp = nullptr;
```
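The added `index_cast` is needed because dynamic offsets of `tensor.extract_slice` must have MLIR's `index` type, while the pipeliner carries `loopIterIdx` as `i32`. The slice offset simply cycles through the staging buffers, `sliceIdx = loopIterIdx % numStages`; a sketch with the surrounding names (`builder`, `loc`, `numStages`, `loopIterIdx`) assumed in scope:

```cpp
// i32 arithmetic for the cyclic stage index, then a cast to `index` so the
// result can be used as a dynamic offset of tensor.extract_slice.
mlir::Value stages = builder.create<mlir::arith::ConstantIntOp>(
    loc, /*value=*/numStages, /*width=*/32);
mlir::Value idxI32 =
    builder.create<mlir::arith::RemSIOp>(loc, loopIterIdx, stages);
mlir::Value sliceIdx = builder.create<mlir::arith::IndexCastOp>(
    loc, builder.getIndexType(), idxI32);
```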
```diff
@@ -503,9 +513,14 @@ scf::ForOp LoopPipeliner::createNewForOp() {
             nextMapping.lookupOrDefault(loadOp.other()), loadOp.cache(),
             loadOp.evict(), loadOp.isVolatile(), /*axis*/ 0);
         nextBuffers.push_back(insertAsyncOp);
-        nextOp = builder.create<triton::gpu::ExtractSliceOp>(
-            op->getLoc(), loadsMapping[loadOp].getType(), insertAsyncOp,
-            extractSliceIndex, /*axis*/ 0);
+        auto sliceType = loadsMapping[loadOp].getType().cast<RankedTensorType>();
+        nextOp = builder.create<tensor::ExtractSliceOp>(
+            op->getLoc(), sliceType, insertAsyncOp,
+            SmallVector<OpFoldResult>{extractSliceIndex, intAttr(0), intAttr(0)},
+            SmallVector<OpFoldResult>{intAttr(1),
+                                      intAttr(sliceType.getShape()[0]),
+                                      intAttr(sliceType.getShape()[1])},
+            SmallVector<OpFoldResult>{intAttr(1), intAttr(1), intAttr(1)});
         extractSlices.push_back(nextOp->getResult(0));
       } else
         nextOp = builder.clone(*op, nextMapping);
```
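Both pipeliner call sites rely on `tensor.extract_slice`'s rank-reducing form: the sizes are `{1, M, N}` against a 3-D staging buffer, and passing an explicit 2-D result type drops the unit-extent leading dimension. A sketch, with `builder`, `loc`, `buffer`, `sliceIdx`, `intAttr`, and the 2-D `sliceType` assumed in scope:

```cpp
// Rank-reducing slice: tensor<numStages x M x N> -> tensor<M x N>.
auto slice = builder.create<mlir::tensor::ExtractSliceOp>(
    loc, sliceType, buffer,
    llvm::SmallVector<mlir::OpFoldResult>{sliceIdx, intAttr(0), intAttr(0)},
    llvm::SmallVector<mlir::OpFoldResult>{
        intAttr(1), intAttr(sliceType.getShape()[0]),
        intAttr(sliceType.getShape()[1])},
    llvm::SmallVector<mlir::OpFoldResult>{intAttr(1), intAttr(1), intAttr(1)});
```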