Keren/tensor slice insert alloc (#94)

This branch defines three new triton_gpu operations to partially solve #87. Below is an overview:

```
%tensor = triton_gpu.alloc_tensor : tensor<2x16x16xf16, #A>
%b = triton_gpu.insert_slice_async %a_ptr, %tensor, %offset {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16x!tt.ptr<f16>, #AL> -> tensor<2x16x16xf16, #A>
%c = triton_gpu.extract_slice %b, %offset {axis = 0 : i32} : tensor<2x16x16xf16, #A> -> tensor<16x16xf16, #A>
```

We plan to fully replace `copy_async` with `insert_slice_async`. **This hasn't been done yet.**
This commit is contained in:
Keren Zhou
2022-09-01 12:37:17 -07:00
committed by GitHub
parent d01353de07
commit 328b87aec6
10 changed files with 260 additions and 40 deletions

View File

@@ -35,7 +35,7 @@ int main(int argc, char **argv) {
registry
.insert<mlir::triton::TritonDialect, mlir::triton::gpu::TritonGPUDialect,
mlir::arith::ArithmeticDialect, mlir::StandardOpsDialect,
mlir::scf::SCFDialect>();
mlir::scf::SCFDialect, mlir::gpu::GPUDialect>();
return mlir::asMainReturnCode(mlir::MlirOptMain(
argc, argv, "Triton (GPU) optimizer driver\n", registry));

View File

@@ -37,8 +37,8 @@ OwningOpRef<ModuleOp> loadMLIRModule(llvm::StringRef inputFilename,
mlir::DialectRegistry registry;
registry
.insert<TritonDialect, gpu::TritonGPUDialect, arith::ArithmeticDialect,
StandardOpsDialect, scf::SCFDialect>();
.insert<TritonDialect, triton::gpu::TritonGPUDialect,
arith::ArithmeticDialect, StandardOpsDialect, scf::SCFDialect>();
context.appendDialectRegistry(registry);

View File

@@ -17,24 +17,24 @@ class AllocationAnalysis;
}
/// Modified from llvm-15.0: llvm/ADT/AddressRanges.h
/// A class that represents a range, specified using a start and an end values:
/// [Start, End).
template <typename T> class Range {
/// A class that represents an interval, specified using a start and an end
/// values: [Start, End).
template <typename T> class Interval {
public:
Range() {}
Range(T S, T E) : Start(S), End(E) { assert(Start <= End); }
Interval() {}
Interval(T S, T E) : Start(S), End(E) { assert(Start <= End); }
T start() const { return Start; }
T end() const { return End; }
T size() const { return End - Start; }
bool contains(T Addr) const { return Start <= Addr && Addr < End; }
bool intersects(const Range &R) const {
bool intersects(const Interval &R) const {
return Start < R.End && R.Start < End;
}
bool operator==(const Range &R) const {
bool operator==(const Interval &R) const {
return Start == R.Start && End == R.End;
}
bool operator!=(const Range &R) const { return !(*this == R); }
bool operator<(const Range &R) const {
bool operator!=(const Interval &R) const { return !(*this == R); }
bool operator<(const Interval &R) const {
return std::make_pair(Start, End) < std::make_pair(R.Start, R.End);
}
@@ -137,8 +137,9 @@ private:
: kind(kind), size(size), offset(offset), id(nextId++) {}
bool intersects(const BufferT &other) const {
return Range<size_t>(offset, offset + size)
.intersects(Range<size_t>(other.offset, other.offset + other.size));
return Interval<size_t>(offset, offset + size)
.intersects(
Interval<size_t>(other.offset, other.offset + other.size));
}
};

View File

@@ -1,6 +1,7 @@
#ifndef TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_
#define TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"

View File

@@ -15,7 +15,8 @@ def TritonGPU_Dialect : Dialect {
}];
let dependentDialects = [
"triton::TritonDialect"
"triton::TritonDialect",
"mlir::gpu::GPUDialect"
];
}

View File

@@ -58,7 +58,7 @@ def TTG_CopyAsyncOp : TTG_Op<"copy_async",
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
];
let results = (outs TT_Type:$result);
let results = (outs TT_Tensor:$result);
// let assemblyFormat = "operands attr-dict `:` type($ptr) `->` type($result)";
let parser = [{ return parseCopyAsyncOp(parser, result); }];
@@ -97,4 +97,137 @@ def TTG_CmpFOp : TTG_Op<"cmpf"> {
let results = (outs TT_BoolLike:$result);
}
def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
[SameVariadicOperandSize,
MemoryEffects<[MemRead, MemWrite]>,
TypesMatchWith<"infer mask type from src type",
"src", "mask", "getI1SameShape($_self)",
"($_op.getOperands().size() <= 3) || std::equal_to<>()">,
TypesMatchWith<"infer other type from src type",
"src", "other", "getPointeeType($_self)",
"($_op.getOperands().size() <= 4) || std::equal_to<>()">]> {
let summary = "insert slice async";
let description = [{
This operation inserts a tensor `$src` into another tensor `$dst` as specified by the operations
`$offset` argument and `$axis` attribute.
It returns a copy of `$dst` with the proper slice updated asynchronously with the value of `$src`.
This operation is non-blocking, and `$results` will have the updated value after the corresponding async_wait.
The insert_slice_async operation supports the following arguments:
* src: the tensor that is inserted.
* dst: the tensor into which the `$src` tensor is inserted.
* offset: the offset of the `$src` tensor at the given `$axis` from which the `$dst` tensor is inserted into
* mask: optional tensor-rank number of boolean masks which specify which
elements of the `$src` tensor are inserted into the `$dst` tensor.
* other: optional tensor-rank number of other tensors which specify what
values are inserted into the `$dst` tensor if the corresponding
element of the `$mask` tensor is false.
In the future, we may decompose this operation into a sequence of:
* `async` operation to specify a sequence of asynchronous operations
* `load` operation to load a tensor from global memory
* `insert_slice` operations to insert the `$src` tensor into the `$dst` tensor
Example:
```
%1 = triton_gpu.alloc_tensor : tensor<2x32xf32>
%2 = triton_gpu.insert_slice_async %0, %1, %offset { axis = 0 } : tensor<32x!tt.ptr<f32>, #AL> -> tensor<2x32xf32, #A>
triiton_gpu.async_wait { num = 0 : i32 }
```
}];
let arguments = (ins TT_PtrTensor:$src, TT_Tensor:$dst, I32:$offset,
Optional<I1Tensor>:$mask, Optional<TT_Type>:$other,
TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict,
BoolAttr:$isVolatile, I32Attr:$axis);
let builders = [
OpBuilder<(ins "Value":$src, "Value":$dst, "Value":$offset,
"triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile, "int":$axis)>,
OpBuilder<(ins "Value":$src, "Value":$dst, "Value":$offset, "Value":$mask,
"triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile, "int":$axis)>,
OpBuilder<(ins "Value":$src, "Value":$dst, "Value":$offset,
"Value":$mask, "Value":$other,
"triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile, "int":$axis)>,
];
let results = (outs TT_Tensor:$result);
//let assemblyFormat = [{
// $src `,` $dst ``
// $offset, $mask, $other
// attr-dict `:` type($src) `->` type($dst)
//}];
// The custom parser could be replaced with oilist in LLVM-16
let parser = [{ return parseInsertSliceAsyncOp(parser, result); }];
let printer = [{ return printInsertSliceAsyncOp(p, *this); }];
// result needs to be of shared layout
let verifier = [{ return ::verify(*this); }];
}
def TTG_ExtractSliceOp : TTG_Op<"extract_slice", [NoSideEffect, InferTypeOpInterface]> {
let summary = "extract slice";
let description = [{
The "extract_slice" operation extracts a `$result` tensor from a `$src` tensor as
specified by the operation's `$offset` and `$axis` arguments.
The extract_slice operation supports the following arguments:
* src: the tensor that is extracted from.
* offset: the offset at the given `$axis` from which the `$src` tensor is extracted
Example:
```
// Rank-reducing extract_slice.
%1 = tensor.extract_slice %0, %offset {axis = 0} : tensor<8x16x4xf32> -> tensor<1x16x4xf32>
```
}];
let arguments = (ins TT_Tensor:$src, I32:$offset, I32Attr:$axis);
let results = (outs TT_Tensor:$result);
let assemblyFormat = [{$src `,` $offset attr-dict `:` type($src) `->` type($result)}];
let extraClassDeclaration = [{
static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context,
::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands,
::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions,
::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes);
}];
// result needs to be of shared layout
let verifier = [{ return ::verify(*this); }];
}
def TTG_AllocTensorOp : TTG_Op<"alloc_tensor", [NoSideEffect]> {
let summary = "allocate tensor";
let description = [{
This operation defines a tensor of a particular shape.
The contents of the tensor are supposed to be in shared memory.
Note: This op can be repalced to a `bufferization.alloc_tensor` in LLVM 16.
}];
let assemblyFormat = [{attr-dict `:` type($result)}];
let results = (outs TT_Tensor:$result);
// result needs to be of shared layout
let verifier = [{ return ::verify(*this); }];
}
#endif

View File

@@ -80,7 +80,8 @@ def TritonGPUVerifier : Pass<"tritongpu-verifier", "mlir::ModuleOp"> {
let constructor = "mlir::createTritonGPUVerifier()";
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::gpu::GPUDialect"];
}
#endif

View File

@@ -27,7 +27,7 @@ private:
/// Value -> Liveness Range
/// Use MapVector to ensure determinism.
using BufferRangeMapT = llvm::MapVector<BufferT *, Range<size_t>>;
using BufferRangeMapT = llvm::MapVector<BufferT *, Interval<size_t>>;
/// Nodes -> Nodes
using GraphT = DenseMap<BufferT *, DenseSet<BufferT *>>;
@@ -110,7 +110,7 @@ private:
/// Computes the liveness range of the allocated value.
/// Each buffer is allocated only once.
void resolveExplicitBufferLiveness(
function_ref<Range<size_t>(Value value)> getLiveness) {
function_ref<Interval<size_t>(Value value)> getLiveness) {
for (auto valueBufferIter : allocation->valueBuffer) {
auto value = valueBufferIter.first;
auto *buffer = valueBufferIter.second;
@@ -122,7 +122,7 @@ private:
/// values because each allocated buffer could be an alias of others, if block
/// arguments are involved.
void resolveAliasBufferLiveness(
function_ref<Range<size_t>(Value value)> getLiveness) {
function_ref<Interval<size_t>(Value value)> getLiveness) {
for (auto aliasBufferIter : allocation->aliasBuffer) {
auto value = aliasBufferIter.first;
auto buffers = aliasBufferIter.second;
@@ -135,7 +135,7 @@ private:
minId = std::min(minId, bufferRange[buffer].start());
maxId = std::max(maxId, bufferRange[buffer].end());
}
bufferRange[buffer] = Range(minId, maxId);
bufferRange[buffer] = Interval(minId, maxId);
}
}
}
@@ -151,8 +151,8 @@ private:
// range.
auto *op = opScratchIter.first;
auto *buffer = opScratchIter.second;
bufferRange.insert(
{buffer, Range(operationId.lookup(op), operationId.lookup(op) + 1)});
bufferRange.insert({buffer, Interval(operationId.lookup(op),
operationId.lookup(op) + 1)});
}
}
@@ -179,7 +179,7 @@ private:
maxId = operationId[liveOp] + 1;
}
});
return Range(minId, maxId);
return Interval(minId, maxId);
};
resolveExplicitBufferLiveness(getValueLivenessRange);
@@ -223,9 +223,9 @@ private:
// |---------------------------------------------| liveness range
// 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 ...
/// Start -> Liveness Range
using TripleMapT = std::multimap<size_t, Range<size_t>>;
using TripleMapT = std::multimap<size_t, Interval<size_t>>;
TripleMapT tripleMap;
tripleMap.insert(std::make_pair(0, Range<size_t>()));
tripleMap.insert(std::make_pair(0, Interval<size_t>()));
SmallVector<BufferT *> xBuffers = buffers;
while (!xBuffers.empty()) {
auto tripleIt = tripleMap.begin();
@@ -246,12 +246,12 @@ private:
auto xRange = bufferRange.lookup(buffer);
bufferStart[buffer] = size;
tripleMap.insert(
{size + xSize, Range{std::max(range.start(), xRange.start()),
{size + xSize, Interval{std::max(range.start(), xRange.start()),
std::min(range.end(), xRange.end())}});
if (range.start() < xRange.start())
tripleMap.insert({size, Range{range.start(), xRange.end()}});
tripleMap.insert({size, Interval{range.start(), xRange.end()}});
if (xRange.end() < range.end())
tripleMap.insert({size, Range{xRange.start(), range.end()}});
tripleMap.insert({size, Interval{xRange.start(), range.end()}});
xBuffers.erase(bufferIt);
}
}
@@ -270,8 +270,8 @@ private:
auto yStart = bufferStart.lookup(y);
auto xSize = x->size;
auto ySize = y->size;
Range xSizeRange = {xStart, xStart + xSize};
Range ySizeRange = {yStart, yStart + ySize};
Interval xSizeRange = {xStart, xStart + xSize};
Interval ySizeRange = {yStart, yStart + ySize};
auto xOpRange = bufferRange.lookup(x);
auto yOpRange = bufferRange.lookup(y);
if (xOpRange.intersects(yOpRange) &&

View File

@@ -7,7 +7,6 @@ namespace mlir {
void MembarAnalysis::run() {
auto *operation = allocation->getOperation();
operation->getContext()->getOrLoadDialect<mlir::gpu::GPUDialect>();
RegionInfo regionInfo;
OpBuilder builder(operation);
dfsOperation(operation, &regionInfo, &builder);

View File

@@ -2,6 +2,7 @@
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "llvm/ADT/TypeSwitch.h"
@@ -329,6 +330,72 @@ void printCopyAsyncOp(OpAsmPrinter &printer, CopyAsyncOp copyAsyncOp) {
printer.printStrippedAttrOrType(copyAsyncOp.result().getType());
}
//===----------------------------------------------------------------------===//
// InsertSliceAsyncOp
//===----------------------------------------------------------------------===//
ParseResult parseInsertSliceAsyncOp(OpAsmParser &parser,
OperationState &result) {
SmallVector<OpAsmParser::OperandType, 4> allOperands;
Type srcType, dstType;
SMLoc allOperandLoc = parser.getCurrentLocation();
if (parser.parseOperandList(allOperands) ||
parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
parser.parseCustomTypeWithFallback(srcType) || parser.parseArrow() ||
parser.parseCustomTypeWithFallback(dstType))
return failure();
result.addTypes(dstType);
SmallVector<Type> operandTypes;
operandTypes.push_back(srcType); // src
operandTypes.push_back(dstType); // dst
operandTypes.push_back(
IntegerType::get(parser.getBuilder().getContext(), 32)); // offset
if (allOperands.size() >= 4)
operandTypes.push_back(triton::getI1SameShape(srcType)); // mask
if (allOperands.size() >= 5)
operandTypes.push_back(triton::getPointeeType(srcType)); // other
if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
result.operands))
return failure();
return success();
}
void printInsertSliceAsyncOp(OpAsmPrinter &printer,
InsertSliceAsyncOp insertSliceAsyncOp) {
printer << " ";
printer << insertSliceAsyncOp.getOperation()->getOperands();
printer.printOptionalAttrDict(insertSliceAsyncOp->getAttrs(),
/*elidedAttrs=*/{});
printer << " : ";
printer.printStrippedAttrOrType(insertSliceAsyncOp.src().getType());
printer << " -> ";
printer.printStrippedAttrOrType(insertSliceAsyncOp.result().getType());
}
//===----------------------------------------------------------------------===//
// ExtractSliceOp
//===----------------------------------------------------------------------===//
mlir::LogicalResult ExtractSliceOp::inferReturnTypes(
::mlir::MLIRContext *context, llvm::Optional<::mlir::Location> location,
::mlir::ValueRange operands, mlir::DictionaryAttr attributes,
::mlir::RegionRange regions,
llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
auto srcType = operands[0].getType().cast<RankedTensorType>();
auto encoding = srcType.getEncoding();
auto srcShape = srcType.getShape();
auto axis = attributes.get("axis").cast<IntegerAttr>().getInt();
if (axis < 0 || axis > srcShape.size())
return failure();
auto dstShape = srcShape.drop_front(axis + 1);
auto returnType =
RankedTensorType::get(dstShape, srcType.getElementType(), encoding);
inferredReturnTypes.assign({returnType});
return success();
}
//===----------------------------------------------------------------------===//
// ASM Interface (i.e.: alias)
//===----------------------------------------------------------------------===//
@@ -372,13 +439,30 @@ void TritonGPUDialect::initialize() {
//===----------------------------------------------------------------------===//
static LogicalResult verify(CopyAsyncOp op) {
Type resType = op.getResult().getType();
if (auto tensorType = resType.dyn_cast<RankedTensorType>()) {
Attribute encoding = tensorType.getEncoding();
if (!encoding.isa<SharedEncodingAttr>())
if (!isSharedEncoding(op.getResult())) {
return op.emitOpError("copy_async should return a shared memory tensor");
} else
return op.emitOpError("copy_async should return a tensor");
}
return success();
}
static LogicalResult verify(InsertSliceAsyncOp op) {
if (!isSharedEncoding(op.getResult())) {
return op.emitOpError("copy_async should return a shared memory tensor");
}
return success();
}
static LogicalResult verify(ExtractSliceOp op) {
if (!isSharedEncoding(op.getResult())) {
return op.emitOpError("extract_slice should return a shared memory tensor");
}
return success();
}
static LogicalResult verify(AllocTensorOp op) {
if (!isSharedEncoding(op.getResult())) {
return op.emitOpError("alloc_tensor should return a shared memory tensor");
}
return success();
}