From 328b87aec64a631200c768ab2989f5d3e86ce093 Mon Sep 17 00:00:00 2001
From: Keren Zhou
Date: Thu, 1 Sep 2022 12:37:17 -0700
Subject: [PATCH] Keren/tensor slice insert alloc (#94)

This branch defines three new triton_gpu operations to partially solve #87.
Below is an overview:

```
%tensor = triton_gpu.alloc_tensor : tensor<2x16x16xf16, #A>
%b = triton_gpu.insert_slice_async %a_ptr, %tensor, %offset {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16x!tt.ptr<f16>, #AL> -> tensor<2x16x16xf16, #A>
%c = triton_gpu.extract_slice %b, %offset {axis = 0 : i32} : tensor<2x16x16xf16, #A> -> tensor<16x16xf16, #A>
```

We plan to fully replace `copy_async` with `insert_slice_async`. **This hasn't been done yet.**
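For context, a minimal sketch of how the three ops are meant to compose with
`triton_gpu.async_wait`, e.g. for a double-buffered load. The `#AL`/`#A`
layouts, the `f16` element type, and the value names below are illustrative
placeholders, not the output of any pass in this patch:

```
// Allocate a two-stage buffer in shared memory.
%buf = triton_gpu.alloc_tensor : tensor<2x16x16xf16, #A>
// Kick off asynchronous copies into stage 0 and stage 1.
%s0 = triton_gpu.insert_slice_async %ptr0, %buf, %c0 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16x!tt.ptr<f16>, #AL> -> tensor<2x16x16xf16, #A>
%s1 = triton_gpu.insert_slice_async %ptr1, %s0, %c1 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16x!tt.ptr<f16>, #AL> -> tensor<2x16x16xf16, #A>
// Wait for the outstanding copies, then read stage 0 back as a 2-D slice.
triton_gpu.async_wait {num = 0 : i32}
%v0 = triton_gpu.extract_slice %s1, %c0 {axis = 0 : i32} : tensor<2x16x16xf16, #A> -> tensor<16x16xf16, #A>
```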
---
 bin/triton-opt.cpp                            |   2 +-
 bin/triton-translate.cpp                      |   4 +-
 include/triton/Analysis/Allocation.h          |  23 +--
 include/triton/Dialect/TritonGPU/IR/Dialect.h |   1 +
 .../Dialect/TritonGPU/IR/TritonGPUDialect.td  |   3 +-
 .../Dialect/TritonGPU/IR/TritonGPUOps.td      | 135 +++++++++++++++++-
 .../Dialect/TritonGPU/Transforms/Passes.td    |   3 +-
 lib/Analysis/Allocation.cpp                   |  30 ++--
 lib/Analysis/Membar.cpp                       |   1 -
 lib/Dialect/TritonGPU/IR/Dialect.cpp          |  98 ++++++++++++-
 10 files changed, 260 insertions(+), 40 deletions(-)

diff --git a/bin/triton-opt.cpp b/bin/triton-opt.cpp
index 8763bbe6c..f67c14135 100644
--- a/bin/triton-opt.cpp
+++ b/bin/triton-opt.cpp
@@ -35,7 +35,7 @@ int main(int argc, char **argv) {
   registry
       .insert();
+          mlir::scf::SCFDialect, mlir::gpu::GPUDialect>();
 
   return mlir::asMainReturnCode(mlir::MlirOptMain(
       argc, argv, "Triton (GPU) optimizer driver\n", registry));
diff --git a/bin/triton-translate.cpp b/bin/triton-translate.cpp
index d0a766ac9..4a0355486 100644
--- a/bin/triton-translate.cpp
+++ b/bin/triton-translate.cpp
@@ -37,8 +37,8 @@ OwningOpRef loadMLIRModule(llvm::StringRef inputFilename,
   mlir::DialectRegistry registry;
   registry
-      .insert();
+      .insert();
 
   context.appendDialectRegistry(registry);
diff --git a/include/triton/Analysis/Allocation.h b/include/triton/Analysis/Allocation.h
index 5f465f693..c0f54232a 100644
--- a/include/triton/Analysis/Allocation.h
+++ b/include/triton/Analysis/Allocation.h
@@ -17,24 +17,24 @@ class AllocationAnalysis;
 }
 
 /// Modified from llvm-15.0: llvm/ADT/AddressRanges.h
-/// A class that represents a range, specified using a start and an end values:
-/// [Start, End).
-template <class T> class Range {
+/// A class that represents an interval, specified using a start and an end
+/// value: [Start, End).
+template <class T> class Interval {
 public:
-  Range() {}
-  Range(T S, T E) : Start(S), End(E) { assert(Start <= End); }
+  Interval() {}
+  Interval(T S, T E) : Start(S), End(E) { assert(Start <= End); }
   T start() const { return Start; }
   T end() const { return End; }
   T size() const { return End - Start; }
   bool contains(T Addr) const { return Start <= Addr && Addr < End; }
-  bool intersects(const Range &R) const {
+  bool intersects(const Interval &R) const {
     return Start < R.End && R.Start < End;
   }
-  bool operator==(const Range &R) const {
+  bool operator==(const Interval &R) const {
     return Start == R.Start && End == R.End;
   }
-  bool operator!=(const Range &R) const { return !(*this == R); }
-  bool operator<(const Range &R) const {
+  bool operator!=(const Interval &R) const { return !(*this == R); }
+  bool operator<(const Interval &R) const {
     return std::make_pair(Start, End) < std::make_pair(R.Start, R.End);
   }
 
@@ -137,8 +137,9 @@ private:
         : kind(kind), size(size), offset(offset), id(nextId++) {}
 
     bool intersects(const BufferT &other) const {
-      return Range(offset, offset + size)
-          .intersects(Range(other.offset, other.offset + other.size));
+      return Interval(offset, offset + size)
+          .intersects(
+              Interval(other.offset, other.offset + other.size));
     }
   };
 
diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h
index 9e8605ec8..ac02470be 100644
--- a/include/triton/Dialect/TritonGPU/IR/Dialect.h
+++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h
@@ -1,6 +1,7 @@
 #ifndef TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_
 #define TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_
 
+#include "mlir/Dialect/GPU/GPUDialect.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Dialect.h"
 
diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td
index 60e644cbb..2c43e22fd 100644
--- a/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td
+++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td
@@ -15,7 +15,8 @@ def TritonGPU_Dialect : Dialect {
   }];
 
   let dependentDialects = [
-    "triton::TritonDialect"
+    "triton::TritonDialect",
+    "mlir::gpu::GPUDialect"
   ];
 }
 
diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
index 3bdccf8db..ef5b0914e 100644
--- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
+++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
@@ -58,7 +58,7 @@ def TTG_CopyAsyncOp : TTG_Op<"copy_async",
                     "triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
   ];
 
-  let results = (outs TT_Type:$result);
+  let results = (outs TT_Tensor:$result);
 
   // let assemblyFormat = "operands attr-dict `:` type($ptr) `->` type($result)";
   let parser = [{ return parseCopyAsyncOp(parser, result); }];
@@ -97,4 +97,137 @@ def TTG_CmpFOp : TTG_Op<"cmpf"> {
   let results = (outs TT_BoolLike:$result);
 }
 
+def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
+                                    [SameVariadicOperandSize,
+                                     MemoryEffects<[MemRead, MemWrite]>,
+                                     TypesMatchWith<"infer mask type from src type",
+                                                    "src", "mask", "getI1SameShape($_self)",
+                                                    "($_op.getOperands().size() <= 3) || std::equal_to<>()">,
+                                     TypesMatchWith<"infer other type from src type",
+                                                    "src", "other", "getPointeeType($_self)",
+                                                    "($_op.getOperands().size() <= 4) || std::equal_to<>()">]> {
+  let summary = "insert slice async";
+
+  let description = [{
+    This operation inserts a tensor `$src` into another tensor `$dst` as specified by the operation's
+    `$offset` argument and `$axis` attribute.
+
+    It returns a copy of `$dst` with the proper slice updated asynchronously with the value of `$src`.
+    This operation is non-blocking, and `$result` only holds the updated value after the corresponding `async_wait`.
+
+    The insert_slice_async operation supports the following arguments:
+
+    * src: the tensor that is inserted.
+    * dst: the tensor into which the `$src` tensor is inserted.
+    * offset: the index along `$axis` of the `$dst` tensor at which the `$src` tensor is inserted.
+    * mask: an optional boolean tensor that specifies which elements of the
+            `$src` tensor are inserted into the `$dst` tensor.
+    * other: an optional tensor that provides the values to be inserted into
+             the `$dst` tensor when the corresponding element of the `$mask`
+             tensor is false.
+
+    In the future, we may decompose this operation into a sequence of:
+
+    * an `async` operation to specify a sequence of asynchronous operations
+    * a `load` operation to load a tensor from global memory
+    * an `insert_slice` operation to insert the `$src` tensor into the `$dst` tensor
+
+    Example:
+
+    ```
+    %1 = triton_gpu.alloc_tensor : tensor<2x32xf32, #A>
+    %2 = triton_gpu.insert_slice_async %0, %1, %offset { axis = 0 : i32 } : tensor<32x!tt.ptr<f32>, #AL> -> tensor<2x32xf32, #A>
+    triton_gpu.async_wait { num = 0 : i32 }
+    ```
+  }];
+
+  let arguments = (ins TT_PtrTensor:$src, TT_Tensor:$dst, I32:$offset,
+                       Optional<I1Tensor>:$mask, Optional<TT_Type>:$other,
+                       TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict,
+                       BoolAttr:$isVolatile, I32Attr:$axis);
+
+  let builders = [
+    OpBuilder<(ins "Value":$src, "Value":$dst, "Value":$offset,
+                   "triton::CacheModifier":$cache,
+                   "triton::EvictionPolicy":$evict, "bool":$isVolatile, "int":$axis)>,
+    OpBuilder<(ins "Value":$src, "Value":$dst, "Value":$offset, "Value":$mask,
+                   "triton::CacheModifier":$cache,
+                   "triton::EvictionPolicy":$evict, "bool":$isVolatile, "int":$axis)>,
+    OpBuilder<(ins "Value":$src, "Value":$dst, "Value":$offset,
+                   "Value":$mask, "Value":$other,
+                   "triton::CacheModifier":$cache,
+                   "triton::EvictionPolicy":$evict, "bool":$isVolatile, "int":$axis)>,
+  ];
+
+  let results = (outs TT_Tensor:$result);
+
+  //let assemblyFormat = [{
+  //  $src `,` $dst ``
+  //  $offset, $mask, $other
+  //  attr-dict `:` type($src) `->` type($dst)
+  //}];
+
+  // The custom parser could be replaced with oilist in LLVM-16
+  let parser = [{ return parseInsertSliceAsyncOp(parser, result); }];
+
+  let printer = [{ return printInsertSliceAsyncOp(p, *this); }];
+
+  // result needs to be of shared layout
+  let verifier = [{ return ::verify(*this); }];
+}
+
+def TTG_ExtractSliceOp : TTG_Op<"extract_slice", [NoSideEffect, InferTypeOpInterface]> {
+  let summary = "extract slice";
+  let description = [{
+    The "extract_slice" operation extracts a `$result` tensor from a `$src` tensor as
+    specified by the operation's `$offset` argument and `$axis` attribute.
+
+    The extract_slice operation supports the following arguments:
+
+    * src: the tensor from which the slice is extracted.
+    * offset: the index along `$axis` at which the slice is extracted.
+
+    Example:
+
+    ```
+    // Rank-reducing extract_slice.
+    %1 = triton_gpu.extract_slice %0, %offset {axis = 0 : i32} : tensor<8x16x4xf32> -> tensor<16x4xf32>
+    ```
+  }];
+
+  let arguments = (ins TT_Tensor:$src, I32:$offset, I32Attr:$axis);
+
+  let results = (outs TT_Tensor:$result);
+
+  let assemblyFormat = [{$src `,` $offset attr-dict `:` type($src) `->` type($result)}];
+
+  let extraClassDeclaration = [{
+    static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context,
+                                                  ::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands,
+                                                  ::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions,
+                                                  ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes);
+  }];
+
+  // result needs to be of shared layout
+  let verifier = [{ return ::verify(*this); }];
+}
+
+def TTG_AllocTensorOp : TTG_Op<"alloc_tensor", [NoSideEffect]> {
+  let summary = "allocate tensor";
+
+  let description = [{
+    This operation defines a tensor of a particular shape.
+    The contents of the tensor are expected to reside in shared memory.
+
+    Note: This op can be replaced with `bufferization.alloc_tensor` in LLVM 16.
+  }];
+
+  let assemblyFormat = [{attr-dict `:` type($result)}];
+
+  let results = (outs TT_Tensor:$result);
+
+  // result needs to be of shared layout
+  let verifier = [{ return ::verify(*this); }];
+}
+
 #endif
diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.td b/include/triton/Dialect/TritonGPU/Transforms/Passes.td
index 07cda2ee5..62afedfdd 100644
--- a/include/triton/Dialect/TritonGPU/Transforms/Passes.td
+++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.td
@@ -80,7 +80,8 @@ def TritonGPUVerifier : Pass<"tritongpu-verifier", "mlir::ModuleOp"> {
 
   let constructor = "mlir::createTritonGPUVerifier()";
 
-  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
+  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
+                           "mlir::gpu::GPUDialect"];
 }
 
 #endif
diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp
index 3ea775415..7a48ae778 100644
--- a/lib/Analysis/Allocation.cpp
+++ b/lib/Analysis/Allocation.cpp
@@ -27,7 +27,7 @@ private:
 
   /// Value -> Liveness Range
   /// Use MapVector to ensure determinism.
-  using BufferRangeMapT = llvm::MapVector<BufferT *, Range<size_t>>;
+  using BufferRangeMapT = llvm::MapVector<BufferT *, Interval<size_t>>;
   /// Nodes -> Nodes
   using GraphT = DenseMap<BufferT *, DenseSet<BufferT *>>;
 
@@ -110,7 +110,7 @@ private:
   /// Computes the liveness range of the allocated value.
   /// Each buffer is allocated only once.
   void resolveExplicitBufferLiveness(
-      function_ref<Range<size_t>(Value value)> getLiveness) {
+      function_ref<Interval<size_t>(Value value)> getLiveness) {
     for (auto valueBufferIter : allocation->valueBuffer) {
       auto value = valueBufferIter.first;
       auto *buffer = valueBufferIter.second;
@@ -122,7 +122,7 @@ private:
   /// values because each allocated buffer could be an alias of others, if block
   /// arguments are involved.
   void resolveAliasBufferLiveness(
-      function_ref<Range<size_t>(Value value)> getLiveness) {
+      function_ref<Interval<size_t>(Value value)> getLiveness) {
     for (auto aliasBufferIter : allocation->aliasBuffer) {
       auto value = aliasBufferIter.first;
       auto buffers = aliasBufferIter.second;
@@ -135,7 +135,7 @@ private:
           minId = std::min(minId, bufferRange[buffer].start());
           maxId = std::max(maxId, bufferRange[buffer].end());
         }
-        bufferRange[buffer] = Range(minId, maxId);
+        bufferRange[buffer] = Interval(minId, maxId);
       }
     }
   }
@@ -151,8 +151,8 @@ private:
       // range.
      auto *op = opScratchIter.first;
       auto *buffer = opScratchIter.second;
-      bufferRange.insert(
-          {buffer, Range(operationId.lookup(op), operationId.lookup(op) + 1)});
+      bufferRange.insert({buffer, Interval(operationId.lookup(op),
+                                           operationId.lookup(op) + 1)});
     }
   }
 
@@ -179,7 +179,7 @@ private:
         maxId = operationId[liveOp] + 1;
       }
     });
-    return Range(minId, maxId);
+    return Interval(minId, maxId);
   };
 
   resolveExplicitBufferLiveness(getValueLivenessRange);
@@ -223,9 +223,9 @@ private:
     // |---------------------------------------------| liveness range
     //    1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 ...
     /// Start -> Liveness Range
-    using TripleMapT = std::multimap<size_t, Range<size_t>>;
+    using TripleMapT = std::multimap<size_t, Interval<size_t>>;
     TripleMapT tripleMap;
-    tripleMap.insert(std::make_pair(0, Range<size_t>()));
+    tripleMap.insert(std::make_pair(0, Interval<size_t>()));
     SmallVector xBuffers = buffers;
     while (!xBuffers.empty()) {
       auto tripleIt = tripleMap.begin();
@@ -246,12 +246,12 @@ private:
           auto xRange = bufferRange.lookup(buffer);
           bufferStart[buffer] = size;
           tripleMap.insert(
-              {size + xSize, Range{std::max(range.start(), xRange.start()),
-                                   std::min(range.end(), xRange.end())}});
+              {size + xSize, Interval{std::max(range.start(), xRange.start()),
+                                      std::min(range.end(), xRange.end())}});
           if (range.start() < xRange.start())
-            tripleMap.insert({size, Range{range.start(), xRange.end()}});
+            tripleMap.insert({size, Interval{range.start(), xRange.end()}});
           if (xRange.end() < range.end())
-            tripleMap.insert({size, Range{xRange.start(), range.end()}});
+            tripleMap.insert({size, Interval{xRange.start(), range.end()}});
           xBuffers.erase(bufferIt);
         }
       }
@@ -270,8 +270,8 @@ private:
         auto yStart = bufferStart.lookup(y);
         auto xSize = x->size;
         auto ySize = y->size;
-        Range xSizeRange = {xStart, xStart + xSize};
-        Range ySizeRange = {yStart, yStart + ySize};
+        Interval xSizeRange = {xStart, xStart + xSize};
+        Interval ySizeRange = {yStart, yStart + ySize};
         auto xOpRange = bufferRange.lookup(x);
         auto yOpRange = bufferRange.lookup(y);
         if (xOpRange.intersects(yOpRange) &&
diff --git a/lib/Analysis/Membar.cpp b/lib/Analysis/Membar.cpp
index b3939fc2d..f3fb1b083 100644
--- a/lib/Analysis/Membar.cpp
+++ b/lib/Analysis/Membar.cpp
@@ -7,7 +7,6 @@ namespace mlir {
 
 void MembarAnalysis::run() {
   auto *operation = allocation->getOperation();
-  operation->getContext()->getOrLoadDialect();
   RegionInfo regionInfo;
   OpBuilder builder(operation);
   dfsOperation(operation, &regionInfo, &builder);
diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index 91fec7d34..35f5aad0a 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -2,6 +2,7 @@
 
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/OpImplementation.h"
+#include "triton/Analysis/Utility.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "llvm/ADT/TypeSwitch.h"
 
@@ -329,6 +330,72 @@ void printCopyAsyncOp(OpAsmPrinter &printer, CopyAsyncOp copyAsyncOp) {
   printer.printStrippedAttrOrType(copyAsyncOp.result().getType());
 }
 
+//===----------------------------------------------------------------------===//
+// InsertSliceAsyncOp
+//===----------------------------------------------------------------------===//
+
+ParseResult parseInsertSliceAsyncOp(OpAsmParser &parser,
+                                    OperationState &result) {
+  SmallVector allOperands;
+  Type srcType, dstType;
+  SMLoc allOperandLoc = parser.getCurrentLocation();
+  if (parser.parseOperandList(allOperands) ||
+      parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
+      parser.parseCustomTypeWithFallback(srcType) || parser.parseArrow() ||
+      parser.parseCustomTypeWithFallback(dstType))
+    return failure();
+  result.addTypes(dstType);
+
+  SmallVector<Type> operandTypes;
+  operandTypes.push_back(srcType); // src
+  operandTypes.push_back(dstType); // dst
+  operandTypes.push_back(
+      IntegerType::get(parser.getBuilder().getContext(), 32)); // offset
+  if (allOperands.size() >= 4)
+    operandTypes.push_back(triton::getI1SameShape(srcType)); // mask
+  if (allOperands.size() >= 5)
+    operandTypes.push_back(triton::getPointeeType(srcType)); // other
+
+  if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
+                             result.operands))
+    return failure();
+  return success();
+}
+
+void printInsertSliceAsyncOp(OpAsmPrinter &printer,
+                             InsertSliceAsyncOp insertSliceAsyncOp) {
+  printer << " ";
+  printer << insertSliceAsyncOp.getOperation()->getOperands();
+  printer.printOptionalAttrDict(insertSliceAsyncOp->getAttrs(),
+                                /*elidedAttrs=*/{});
+  printer << " : ";
+  printer.printStrippedAttrOrType(insertSliceAsyncOp.src().getType());
+  printer << " -> ";
+  printer.printStrippedAttrOrType(insertSliceAsyncOp.result().getType());
+}
+
+//===----------------------------------------------------------------------===//
+// ExtractSliceOp
+//===----------------------------------------------------------------------===//
+
+mlir::LogicalResult ExtractSliceOp::inferReturnTypes(
+    ::mlir::MLIRContext *context, llvm::Optional<::mlir::Location> location,
+    ::mlir::ValueRange operands, mlir::DictionaryAttr attributes,
+    ::mlir::RegionRange regions,
+    llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
+  auto srcType = operands[0].getType().cast<RankedTensorType>();
+  auto encoding = srcType.getEncoding();
+  auto srcShape = srcType.getShape();
+  auto axis = attributes.get("axis").cast<IntegerAttr>().getInt();
+  if (axis < 0 || axis >= srcShape.size())
+    return failure();
+  auto dstShape = srcShape.drop_front(axis + 1);
+  auto returnType =
+      RankedTensorType::get(dstShape, srcType.getElementType(), encoding);
+  inferredReturnTypes.assign({returnType});
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ASM Interface (i.e.: alias)
 //===----------------------------------------------------------------------===//
@@ -372,13 +439,30 @@ void TritonGPUDialect::initialize() {
 //===----------------------------------------------------------------------===//
 
 static LogicalResult verify(CopyAsyncOp op) {
-  Type resType = op.getResult().getType();
-  if (auto tensorType = resType.dyn_cast()) {
-    Attribute encoding = tensorType.getEncoding();
-    if (!encoding.isa())
-      return op.emitOpError("copy_async should return a shared memory tensor");
-  } else
-    return op.emitOpError("copy_async should return a tensor");
+  if (!isSharedEncoding(op.getResult())) {
+    return op.emitOpError("copy_async should return a shared memory tensor");
+  }
+  return success();
+}
+
+static LogicalResult verify(InsertSliceAsyncOp op) {
+  if (!isSharedEncoding(op.getResult())) {
+    return op.emitOpError("insert_slice_async should return a shared memory tensor");
+  }
+  return success();
+}
+
+static LogicalResult verify(ExtractSliceOp op) {
+  if (!isSharedEncoding(op.getResult())) {
+    return op.emitOpError("extract_slice should return a shared memory tensor");
+  }
+  return success();
+}
+
+static LogicalResult verify(AllocTensorOp op) {
+  if (!isSharedEncoding(op.getResult())) {
+    return op.emitOpError("alloc_tensor should return a shared memory tensor");
+  }
   return success();
 }