Keren/tensor slice insert alloc (#94)

This branch defines three new triton_gpu operations to partially solve #87. Below is an overview:

```
%tensor = triton_gpu.alloc_tensor : tensor<2x16x16xf16, #A>
%b = triton_gpu.insert_slice_async %a_ptr, %tensor, %offset {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16x!tt.ptr<f16>, #AL> -> tensor<2x16x16xf16, #A>
%c = triton_gpu.extract_slice %b, %offset {axis = 0 : i32} : tensor<2x16x16xf16, #A> -> tensor<16x16xf16, #A>
```

We plan to fully replace `copy_async` with `insert_slice_async`. **This hasn't been done yet.**
This commit is contained in:
Keren Zhou
2022-09-01 12:37:17 -07:00
committed by GitHub
parent d01353de07
commit 328b87aec6
10 changed files with 260 additions and 40 deletions

View File

@@ -2,6 +2,7 @@
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "llvm/ADT/TypeSwitch.h"
@@ -329,6 +330,72 @@ void printCopyAsyncOp(OpAsmPrinter &printer, CopyAsyncOp copyAsyncOp) {
printer.printStrippedAttrOrType(copyAsyncOp.result().getType());
}
//===----------------------------------------------------------------------===//
// InsertSliceAsyncOp
//===----------------------------------------------------------------------===//
ParseResult parseInsertSliceAsyncOp(OpAsmParser &parser,
OperationState &result) {
SmallVector<OpAsmParser::OperandType, 4> allOperands;
Type srcType, dstType;
SMLoc allOperandLoc = parser.getCurrentLocation();
if (parser.parseOperandList(allOperands) ||
parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
parser.parseCustomTypeWithFallback(srcType) || parser.parseArrow() ||
parser.parseCustomTypeWithFallback(dstType))
return failure();
result.addTypes(dstType);
SmallVector<Type> operandTypes;
operandTypes.push_back(srcType); // src
operandTypes.push_back(dstType); // dst
operandTypes.push_back(
IntegerType::get(parser.getBuilder().getContext(), 32)); // offset
if (allOperands.size() >= 4)
operandTypes.push_back(triton::getI1SameShape(srcType)); // mask
if (allOperands.size() >= 5)
operandTypes.push_back(triton::getPointeeType(srcType)); // other
if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
result.operands))
return failure();
return success();
}
void printInsertSliceAsyncOp(OpAsmPrinter &printer,
InsertSliceAsyncOp insertSliceAsyncOp) {
printer << " ";
printer << insertSliceAsyncOp.getOperation()->getOperands();
printer.printOptionalAttrDict(insertSliceAsyncOp->getAttrs(),
/*elidedAttrs=*/{});
printer << " : ";
printer.printStrippedAttrOrType(insertSliceAsyncOp.src().getType());
printer << " -> ";
printer.printStrippedAttrOrType(insertSliceAsyncOp.result().getType());
}
//===----------------------------------------------------------------------===//
// ExtractSliceOp
//===----------------------------------------------------------------------===//
mlir::LogicalResult ExtractSliceOp::inferReturnTypes(
::mlir::MLIRContext *context, llvm::Optional<::mlir::Location> location,
::mlir::ValueRange operands, mlir::DictionaryAttr attributes,
::mlir::RegionRange regions,
llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
auto srcType = operands[0].getType().cast<RankedTensorType>();
auto encoding = srcType.getEncoding();
auto srcShape = srcType.getShape();
auto axis = attributes.get("axis").cast<IntegerAttr>().getInt();
if (axis < 0 || axis > srcShape.size())
return failure();
auto dstShape = srcShape.drop_front(axis + 1);
auto returnType =
RankedTensorType::get(dstShape, srcType.getElementType(), encoding);
inferredReturnTypes.assign({returnType});
return success();
}
//===----------------------------------------------------------------------===//
// ASM Interface (i.e.: alias)
//===----------------------------------------------------------------------===//
@@ -372,13 +439,30 @@ void TritonGPUDialect::initialize() {
//===----------------------------------------------------------------------===//
static LogicalResult verify(CopyAsyncOp op) {
Type resType = op.getResult().getType();
if (auto tensorType = resType.dyn_cast<RankedTensorType>()) {
Attribute encoding = tensorType.getEncoding();
if (!encoding.isa<SharedEncodingAttr>())
return op.emitOpError("copy_async should return a shared memory tensor");
} else
return op.emitOpError("copy_async should return a tensor");
if (!isSharedEncoding(op.getResult())) {
return op.emitOpError("copy_async should return a shared memory tensor");
}
return success();
}
static LogicalResult verify(InsertSliceAsyncOp op) {
if (!isSharedEncoding(op.getResult())) {
return op.emitOpError("copy_async should return a shared memory tensor");
}
return success();
}
static LogicalResult verify(ExtractSliceOp op) {
if (!isSharedEncoding(op.getResult())) {
return op.emitOpError("extract_slice should return a shared memory tensor");
}
return success();
}
static LogicalResult verify(AllocTensorOp op) {
if (!isSharedEncoding(op.getResult())) {
return op.emitOpError("alloc_tensor should return a shared memory tensor");
}
return success();
}