From e0bedeb44cee1835483896108c0e54d64078e762 Mon Sep 17 00:00:00 2001 From: Keren Zhou Date: Thu, 18 Aug 2022 12:32:57 -0700 Subject: [PATCH] [BACKEND] Keren/shared memory barrier (#59) --- bin/triton-opt.cpp | 2 + include/triton/Analysis/Allocation.h | 168 +++++--- include/triton/Analysis/Membar.h | 115 ++++++ include/triton/Dialect/Triton/IR/Dialect.h | 1 + include/triton/Dialect/Triton/IR/TritonOps.td | 10 +- lib/Analysis/Allocation.cpp | 387 ++++++++++-------- lib/Analysis/CMakeLists.txt | 1 + lib/Analysis/Membar.cpp | 95 +++++ test/Analysis/test-allocation.mlir | 147 +++++-- test/Analysis/test-membar.mlir | 178 ++++++++ test/lib/Analysis/CMakeLists.txt | 1 + test/lib/Analysis/TestAllocation.cpp | 27 +- test/lib/Analysis/TestMembar.cpp | 50 +++ 13 files changed, 904 insertions(+), 278 deletions(-) create mode 100644 include/triton/Analysis/Membar.h create mode 100644 lib/Analysis/Membar.cpp create mode 100644 test/Analysis/test-membar.mlir create mode 100644 test/lib/Analysis/TestMembar.cpp diff --git a/bin/triton-opt.cpp b/bin/triton-opt.cpp index 359889582..06c5ec519 100644 --- a/bin/triton-opt.cpp +++ b/bin/triton-opt.cpp @@ -14,6 +14,7 @@ namespace mlir { namespace test { void registerTestAlignmentPass(); void registerTestAllocationPass(); +void registerTestMembarPass(); } // namespace test } // namespace mlir @@ -23,6 +24,7 @@ int main(int argc, char **argv) { mlir::registerTritonGPUPasses(); mlir::test::registerTestAlignmentPass(); mlir::test::registerTestAllocationPass(); + mlir::test::registerTestMembarPass(); mlir::triton::registerConvertTritonToTritonGPUPass(); mlir::triton::registerConvertTritonGPUToLLVMPass(); diff --git a/include/triton/Analysis/Allocation.h b/include/triton/Analysis/Allocation.h index b489f5142..64c870cd0 100644 --- a/include/triton/Analysis/Allocation.h +++ b/include/triton/Analysis/Allocation.h @@ -4,23 +4,28 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/MapVector.h" #include "llvm/Support/raw_ostream.h" -#include #include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include +#include namespace mlir { +namespace triton { +class AllocationAnalysis; +} + /// Modified from llvm-15.0: llvm/ADT/AddressRanges.h -/// A class that represents an address range. The range is specified using -/// a start and an end address: [Start, End). -template class Range { +/// A class that represents a range, specified using a start and an end values: +/// [Start, End). 
+template class Range { public: Range() {} - Range(AddrT S, AddrT E) : Start(S), End(E) { assert(Start <= End); } - AddrT start() const { return Start; } - AddrT end() const { return End; } - AddrT size() const { return End - Start; } - bool contains(AddrT Addr) const { return Start <= Addr && Addr < End; } + Range(T S, T E) : Start(S), End(E) { assert(Start <= End); } + T start() const { return Start; } + T end() const { return End; } + T size() const { return End - Start; } + bool contains(T Addr) const { return Start <= Addr && Addr < End; } bool intersects(const Range &R) const { return Start < R.End && R.Start < End; } @@ -33,83 +38,122 @@ public: } private: - AddrT Start = std::numeric_limits::min(); - AddrT End = std::numeric_limits::max(); + T Start = std::numeric_limits::min(); + T End = std::numeric_limits::max(); }; -//===----------------------------------------------------------------------===// -// Shared Memory Allocation Analysis -//===----------------------------------------------------------------------===// -class AllocationAnalysis { +class Allocation { public: - using ValueSizeMapT = llvm::DenseMap; + /// A unique identifier for shared memory buffers + using BufferId = size_t; + static constexpr BufferId InvalidBufferId = + std::numeric_limits::max(); -public: /// Creates a new Allocation analysis that computes the shared memory /// information for all associated shared memory values. - AllocationAnalysis(Operation *operation) : operation(operation) { run(); } + Allocation(Operation *operation) : operation(operation) { run(); } /// Returns the operation this analysis was constructed from. Operation *getOperation() const { return operation; } - /// Returns the offset of the given value in the shared memory. - size_t getOffset(Value value) const { return valueOffset.lookup(value); } + /// Returns the offset of the given buffer in the shared memory. + size_t getOffset(BufferId bufferId) const { + return bufferSet.lookup(bufferId).offset; + } - /// Returns the size of the given value in the shared memory. - size_t getAllocatedSize(Value value) const { return valueSize.lookup(value); } + /// Returns the size of the given buffer in the shared memory. + size_t getAllocatedSize(BufferId bufferId) const { + return bufferSet.lookup(bufferId).size; + } + + /// Returns the buffer id of the given value. + BufferId getBufferId(Value value) const { + if (valueBuffer.count(value)) { + return valueBuffer.lookup(value)->id; + } else { + return InvalidBufferId; + } + } + + /// Returns the scratch buffer id of the given value. + BufferId getBufferId(Operation *operation) const { + if (opScratch.count(operation)) { + return opScratch.lookup(operation)->id; + } else { + return InvalidBufferId; + } + } /// Returns the size of total shared memory allocated size_t getSharedMemorySize() const { return sharedMemorySize; } -private: - /// Value -> Range - /// Use MapVector to ensure determinism. 
- using ValueRangeMapT = llvm::MapVector>; - /// Start -> Range - using TripleMapT = std::multimap>; - /// Nodes -> Nodes - using GraphT = DenseMap>; + bool isIntersected(BufferId lhsId, BufferId rhsId) const { + if (lhsId == InvalidBufferId || rhsId == InvalidBufferId) + return false; + auto lhsBuffer = bufferSet.lookup(lhsId); + auto rhsBuffer = bufferSet.lookup(rhsId); + return lhsBuffer.intersects(rhsBuffer); + } +private: + /// A class that represents a shared memory buffer + struct BufferT { + enum class BufferKind { Explicit, Scratch }; + + /// MT: thread-safe + inline static std::atomic nextId = 0; + + BufferKind kind; + BufferId id; + size_t size; + size_t offset; + + bool operator==(const BufferT &other) const { return id == other.id; } + bool operator<(const BufferT &other) const { return id < other.id; } + + BufferT() : BufferT(BufferKind::Explicit) {} + BufferT(BufferKind kind) : BufferT(kind, 0, 0) {} + BufferT(BufferKind kind, size_t size) : BufferT(kind, size, 0) {} + BufferT(BufferKind kind, size_t size, size_t offset) + : kind(kind), size(size), offset(offset), id(nextId++) {} + + bool intersects(const BufferT &other) const { + return Range(offset, offset + size) + .intersects(Range(other.offset, other.offset + other.size)); + } + }; + + /// Op -> Scratch Buffer + using OpScratchMapT = DenseMap; + /// Value -> Explicit Buffer + using ValueBufferMapT = DenseMap; + /// BufferId -> Buffer + using BufferSetT = DenseMap; /// Runs allocation analysis on the given top-level operation. void run(); - /// Resolves liveness of all values involved under the root operation. - void resolveLiveness(ValueRangeMapT &valueRangeMap); - - /// Computes the shared memory offsets for all related values. - /// Paper: Algorithms for Compile-Time Memory Optimization - /// (https://www.cs.utexas.edu/users/harrison/papers/compile-time.pdf) - void computeOffsets(const ValueRangeMapT &valueRangeMap); - - /// Gets shared memory value and size from valueRangeMap. - void getSharedMemoryValuesAndSizes(const ValueRangeMapT &valueRangeMap, - SmallVector &sharedMemoryValues); - - /// Computes the initial shared memory offsets. - void calculateSharedMemoryStarts(const ValueRangeMapT &valueRangeMap, - const SmallVector &sharedMemoryValues, - ValueSizeMapT &sharedMemoryStart); - - /// Builds a graph of all shared memory values. Edges are created between - /// between shared memory values that are overlapping. - void buildInterferenceGraph(const ValueRangeMapT &valueRangeMap, - const SmallVector &sharedMemoryValues, - const ValueSizeMapT &sharedMemoryStart, - GraphT &interference); - - /// Finalizes shared memory offsets considering interference. - void allocateSharedMemory(const ValueRangeMapT &valueRangeMap, - const SmallVector &sharedMemoryValues, - const ValueSizeMapT &sharedMemoryStart, - const GraphT &interference); +private: + template + void addBuffer(KeyType &key, Args &&... 
args) {
+    auto buffer = BufferT(Kind, std::forward<Args>(args)...);
+    bufferSet[buffer.id] = std::move(buffer);
+    if constexpr (Kind == BufferT::BufferKind::Explicit) {
+      valueBuffer[key] = &bufferSet[buffer.id];
+    } else {
+      opScratch[key] = &bufferSet[buffer.id];
+    }
+  }
 
 private:
   Operation *operation;
-  ValueSizeMapT valueOffset;
-  ValueSizeMapT valueSize;
+  OpScratchMapT opScratch;
+  ValueBufferMapT valueBuffer;
+  BufferSetT bufferSet;
   size_t sharedMemorySize = 0;
+
+  friend class triton::AllocationAnalysis;
 };
 
 } // namespace mlir
 
-#endif
+#endif // TRITON_ANALYSIS_ALLOCATION_H
diff --git a/include/triton/Analysis/Membar.h b/include/triton/Analysis/Membar.h
new file mode 100644
index 000000000..383c7a053
--- /dev/null
+++ b/include/triton/Analysis/Membar.h
@@ -0,0 +1,115 @@
+#ifndef TRITON_ANALYSIS_MEMBAR_H
+#define TRITON_ANALYSIS_MEMBAR_H
+
+#include "Allocation.h"
+#include "llvm/ADT/SmallPtrSet.h"
+
+namespace mlir {
+
+class OpBuilder;
+
+//===----------------------------------------------------------------------===//
+// Shared Memory Barrier Analysis
+//===----------------------------------------------------------------------===//
+class MembarAnalysis {
+public:
+  /// Creates a new Membar analysis that inserts shared memory barriers in
+  /// the following circumstances:
+  /// - RAW: If a shared memory write is followed by a shared memory read, and
+  /// their addresses are intersected, a barrier is inserted.
+  /// - WAR: If a shared memory read is followed by a shared memory write, and
+  /// their addresses are intersected, a barrier is inserted.
+  /// The following circumstances do not require a barrier:
+  /// - WAW: not possible because overlapped memory allocation is not allowed.
+  /// - RAR: no write is performed.
+  /// Temporary storage of operations such as Reduce is considered as a
+  /// shared memory read. If the temporary storage is written but never
+  /// read, that is a problem of the operation itself rather than of the
+  /// membar analysis.
+  /// The following circumstances are not considered yet:
+  /// - Double buffers
+  /// - N buffers
+  MembarAnalysis(Allocation *allocation) : allocation(allocation) { run(); }
+
+private:
+  struct RegionInfo {
+    using BufferIdSetT = DenseSet<Allocation::BufferId>;
+
+    BufferIdSetT syncReadBuffers;
+    BufferIdSetT syncWriteBuffers;
+
+    RegionInfo() = default;
+    RegionInfo(const BufferIdSetT &syncReadBuffers,
+               const BufferIdSetT &syncWriteBuffers)
+        : syncReadBuffers(syncReadBuffers), syncWriteBuffers(syncWriteBuffers) {
+    }
+
+    /// Unions two RegionInfo objects.
+    void join(const RegionInfo &other) {
+      syncReadBuffers.insert(other.syncReadBuffers.begin(),
+                             other.syncReadBuffers.end());
+      syncWriteBuffers.insert(other.syncWriteBuffers.begin(),
+                              other.syncWriteBuffers.end());
+    }
+
+    /// Returns true if buffers in two RegionInfo objects are intersected.
+    bool isIntersected(const RegionInfo &other, Allocation *allocation) const {
+      return /*RAW*/ isIntersected(syncWriteBuffers, other.syncReadBuffers,
+                                   allocation) ||
+             /*WAR*/ isIntersected(syncReadBuffers, other.syncWriteBuffers,
+                                   allocation);
+    }
+
+    /// Clears the buffers because a barrier is inserted.
+    void sync() {
+      syncReadBuffers.clear();
+      syncWriteBuffers.clear();
+    }
+
+  private:
+    /// Returns true if buffers in two sets are intersected.
+    bool isIntersected(const BufferIdSetT &lhs, const BufferIdSetT &rhs,
+                       Allocation *allocation) const {
+      return std::any_of(lhs.begin(), lhs.end(), [&](auto lhsId) {
+        return std::any_of(rhs.begin(), rhs.end(), [&](auto rhsId) {
+          return allocation->isIntersected(lhsId, rhsId);
+        });
+      });
+    }
+  };
+
+  /// Runs the membar analysis on the given operation and inserts barriers
+  /// where necessary.
+  void run();
+
+  /// Applies the barrier analysis based on the SCF dialect, in which each
+  /// region has a single basic block only.
+  /// Example:
+  /// region1
+  ///   op1
+  ///   op2 (scf.if)
+  ///     region2
+  ///       op3
+  ///       op4
+  ///     region3
+  ///       op5
+  ///       op6
+  ///   op7
+  /// region2 and region3 start with the information of region1.
+  /// Each region is analyzed separately and keeps its own copy of the
+  /// information. At op7, we union the information of region2 and region3
+  /// and update the information of region1.
+  void dfsOperation(Operation *operation, RegionInfo *blockInfo,
+                    OpBuilder *builder);
+
+  /// Updates the RegionInfo based on the given operation.
+  void transfer(Operation *operation, RegionInfo *blockInfo,
+                OpBuilder *builder);
+
+private:
+  Allocation *allocation;
+};
+
+} // namespace mlir
+
+#endif // TRITON_ANALYSIS_MEMBAR_H
diff --git a/include/triton/Dialect/Triton/IR/Dialect.h b/include/triton/Dialect/Triton/IR/Dialect.h
index 8590db9c4..725aed72d 100644
--- a/include/triton/Dialect/Triton/IR/Dialect.h
+++ b/include/triton/Dialect/Triton/IR/Dialect.h
@@ -1,6 +1,7 @@
 #ifndef TRITON_DIALECT_TRITON_IR_DIALECT_H_
 #define TRITON_DIALECT_TRITON_IR_DIALECT_H_
 
+#include "mlir/Dialect/GPU/GPUDialect.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/BuiltinOps.h"
diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td
index bcd96ef53..99e593855 100644
--- a/include/triton/Dialect/Triton/IR/TritonOps.td
+++ b/include/triton/Dialect/Triton/IR/TritonOps.td
@@ -228,13 +228,13 @@ def TT_DotOp : TT_Op<"dot", [NoSideEffect,
 def TT_ReduceOp : TT_Op<"reduce"> {
   let summary = "reduce";
 
-  let arguments = (ins TT_RedOpAttr:$redOp, TT_Type:$operand, I32Attr:$axis);
+  let arguments = (ins TT_RedOpAttr:$redOp, TT_Tensor:$operand, I32Attr:$axis);
 
-  let results = (outs TT_Type:$result);
+  let results = (outs TT_Tensor:$result);
 
-  // let builders = [
-  //   OpBuilder<(ins "triton::RedOp":$redOp, "value":$operand, "int":$axis)>,
-  // ];
+  let builders = [
+    OpBuilder<(ins "triton::RedOp":$redOp, "Value":$operand, "int":$axis)>,
+  ];
 
   let assemblyFormat = "$operand attr-dict `:` type($operand) `->` type($result)";
 }
diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp
index 8d5b6afcc..5a5b09832 100644
--- a/lib/Analysis/Allocation.cpp
+++ b/lib/Analysis/Allocation.cpp
@@ -1,5 +1,6 @@
 #include "triton/Analysis/Allocation.h"
 #include "mlir/Analysis/Liveness.h"
+#include "mlir/Analysis/SliceAnalysis.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/DenseSet.h"
@@ -10,191 +11,255 @@
 
 namespace mlir {
 
-void AllocationAnalysis::run() {
-  ValueRangeMapT valueRange;
-  resolveLiveness(valueRange);
-  computeOffsets(valueRange);
-}
+//===----------------------------------------------------------------------===//
+// Shared Memory Allocation Analysis
+//===----------------------------------------------------------------------===//
+namespace triton {
+class AllocationAnalysis {
+public:
+  AllocationAnalysis(Operation *operation, 
Allocation *allocation) + : operation(operation), allocation(allocation) { + run(); + } -void AllocationAnalysis::resolveLiveness( - AllocationAnalysis::ValueRangeMapT &valueRange) { - Liveness liveness(operation); - DenseMap operationIds; - operation->walk([&](Operation *op) { - operationIds.insert({op, operationIds.size()}); - }); +private: + using BufferT = Allocation::BufferT; - operation->walk([&](Operation *op) { + /// Value -> Liveness Range + /// Use MapVector to ensure determinism. + using BufferRangeMapT = llvm::MapVector>; + /// Nodes -> Nodes + using GraphT = DenseMap>; + + void run() { + getValuesAndSizes(); + resolveLiveness(); + computeOffsets(); + } + + /// Initializes explicitly defined shared memory values for a given operation. + void getExplicitValueSize(Operation *op) { for (Value result : op->getResults()) { - auto liveOperations = liveness.resolveLiveness(result); - auto minId = std::numeric_limits::max(); - auto maxId = std::numeric_limits::min(); - std::for_each(liveOperations.begin(), liveOperations.end(), - [&](Operation *liveOp) { - if (operationIds[liveOp] < minId) { - minId = operationIds[liveOp]; - } - if (operationIds[liveOp] > maxId) { - maxId = operationIds[liveOp]; - } - }); - valueRange.insert({result, Range(minId, maxId + 1)}); + auto type = result.getType(); + if (auto tensorType = type.dyn_cast()) { + auto encoding = tensorType.getEncoding(); + if (encoding && + encoding.isa()) { + // Bytes could be a different value once we support padding or other + // allocation policies. + auto bytes = tensorType.getNumElements() * + tensorType.getElementTypeBitWidth() / 8; + allocation->addBuffer(result, bytes); + } + } } - }); -} + } -void AllocationAnalysis::getSharedMemoryValuesAndSizes( - const AllocationAnalysis::ValueRangeMapT &valueRange, - SmallVector &sharedMemoryValues) { - for (auto &valueRange : valueRange) { - auto value = valueRange.first; - auto type = value.getType(); - if (auto tensorType = type.dyn_cast()) { - auto encoding = tensorType.getEncoding(); - if (encoding && - encoding.isa()) { - // Bytes could be a different value once we support padding or other - // allocation policies. + /// Initializes temporary shared memory for a given operation. + void getScratchValueSize(Operation *op) { + // TODO(Keren): Add atomic ops + // TODO(Keren): Add convert ops + if (auto reduceOp = dyn_cast(op)) { + // TODO(Keren): Reduce with index is not supported yet. + auto value = op->getOperand(0); + if (auto tensorType = value.getType().dyn_cast()) { auto bytes = tensorType.getNumElements() * tensorType.getElementTypeBitWidth() / 8; - sharedMemoryValues.emplace_back(value); - valueSize.insert({value, bytes}); + allocation->addBuffer(op, bytes); } } } -} -void AllocationAnalysis::calculateSharedMemoryStarts( - const AllocationAnalysis::ValueRangeMapT &valueRange, - const SmallVector &sharedMemoryValues, - ValueSizeMapT &sharedMemoryStart) { - // v = values in shared memory - // t = triplet of (size, start, end) - // shared memory space - // - - // | *******t4 - // | /|\ v2 inserts t4, t5, and t6 - // | | - // | ******t5 ************t6 - // | ^^^^^v2^^^^^^ - // | | *********************t2 - // | \|/ v2 erases t1 - // | ******t1 ^^^^^^^^^v1^^^^^^^^^ ************t3 - // |---------------------------------------------| liveness range - // 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 ... 
- TripleMapT tripleMap; - tripleMap.insert(std::make_pair(0, Range())); - SmallVector values = sharedMemoryValues; - while (!values.empty()) { - auto tripleIt = tripleMap.begin(); - auto size = tripleIt->first; - auto range = tripleIt->second; - tripleMap.erase(tripleIt); - auto valueIt = std::find_if(values.begin(), values.end(), [&](Value value) { - auto xRange = valueRange.lookup(value); - bool res = xRange.intersects(range); - for (auto val : tripleMap) - res = res && !val.second.intersects(xRange); - return res; + /// Extract all shared memory values and their sizes + void getValuesAndSizes() { + operation->walk([&](Operation *op) { + getExplicitValueSize(op); + getScratchValueSize(op); }); - if (valueIt != values.end()) { - auto value = *valueIt; - auto xSize = valueSize.lookup(value); - auto xRange = valueRange.lookup(value); - sharedMemoryStart[value] = size; - tripleMap.insert( - {size + xSize, Range{std::max(range.start(), xRange.start()), - std::min(range.end(), xRange.end())}}); - if (range.start() < xRange.start()) - tripleMap.insert({size, Range{range.start(), xRange.end()}}); - if (xRange.end() < range.end()) - tripleMap.insert({size, Range{xRange.start(), range.end()}}); - values.erase(valueIt); - } } -} -void AllocationAnalysis::buildInterferenceGraph( - const AllocationAnalysis::ValueRangeMapT &valueRange, - const SmallVector &sharedMemoryValues, - const ValueSizeMapT &sharedMemoryStart, GraphT &interference) { - for (auto x : sharedMemoryValues) { - for (auto y : sharedMemoryValues) { - if (x == y) - continue; - auto xStart = sharedMemoryStart.lookup(x); - auto yStart = sharedMemoryStart.lookup(y); - auto xSize = valueSize.lookup(x); - auto ySize = valueSize.lookup(y); - Range xSizeRange = {xStart, xStart + xSize}; - Range ySizeRange = {yStart, yStart + ySize}; - auto xOpRange = valueRange.lookup(x); - auto yOpRange = valueRange.lookup(y); - if (xOpRange.intersects(yOpRange) && xSizeRange.intersects(ySizeRange)) { - interference[x].insert(y); + /// Resolves liveness of all values involved under the root operation. + void resolveLiveness() { + // In the SCF dialect, we always have a sequentially nested structure of + // blocks + DenseMap operationId; + operation->walk( + [&](Operation *op) { operationId[op] = operationId.size(); }); + + Liveness liveness(operation); + operation->walk([&](Operation *op) { + for (Value result : op->getResults()) { + auto liveOperations = liveness.resolveLiveness(result); + auto minId = std::numeric_limits::max(); + auto maxId = std::numeric_limits::min(); + std::for_each(liveOperations.begin(), liveOperations.end(), + [&](Operation *liveOp) { + if (operationId[liveOp] < minId) { + minId = operationId[liveOp]; + } + if (operationId[liveOp] > maxId) { + maxId = operationId[liveOp]; + } + }); + if (allocation->valueBuffer.count(result)) { + auto *buffer = allocation->valueBuffer[result]; + bufferRange.insert({buffer, Range(minId, maxId + 1)}); + } + } + if (allocation->opScratch.count(op)) { + // Any scratch memory's live range is the current operation's live + // range. + auto *buffer = allocation->opScratch[op]; + bufferRange.insert( + {buffer, Range(operationId[op], operationId[op] + 1)}); + } + }); + } + + /// Computes the shared memory offsets for all related values. 
+ /// Paper: Algorithms for Compile-Time Memory Optimization + /// (https://www.cs.utexas.edu/users/harrison/papers/compile-time.pdf) + void computeOffsets() { + SmallVector buffers; + for (auto bufferIter : bufferRange) { + buffers.emplace_back(bufferIter.first); + } + + DenseMap bufferStart; + calculateStarts(buffers, bufferStart); + + GraphT interference; + buildInterferenceGraph(buffers, bufferStart, interference); + + allocate(buffers, bufferStart, interference); + } + + /// Computes the initial shared memory offsets. + void calculateStarts(const SmallVector &buffers, + DenseMap &bufferStart) { + // v = values in shared memory + // t = triplet of (size, start, end) + // shared memory space + // - + // | *******t4 + // | /|\ v2 inserts t4, t5, and t6 + // | | + // | ******t5 ************t6 + // | ^^^^^v2^^^^^^ + // | | *********************t2 + // | \|/ v2 erases t1 + // | ******t1 ^^^^^^^^^v1^^^^^^^^^ ************t3 + // |---------------------------------------------| liveness range + // 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 ... + /// Start -> Liveness Range + using TripleMapT = std::multimap>; + TripleMapT tripleMap; + tripleMap.insert(std::make_pair(0, Range())); + SmallVector xBuffers = buffers; + while (!xBuffers.empty()) { + auto tripleIt = tripleMap.begin(); + auto size = tripleIt->first; + auto range = tripleIt->second; + tripleMap.erase(tripleIt); + auto bufferIt = + std::find_if(xBuffers.begin(), xBuffers.end(), [&](auto *buffer) { + auto xRange = bufferRange[buffer]; + bool res = xRange.intersects(range); + for (auto val : tripleMap) + res = res && !val.second.intersects(xRange); + return res; + }); + if (bufferIt != xBuffers.end()) { + auto buffer = *bufferIt; + auto xSize = buffer->size; + auto xRange = bufferRange.lookup(buffer); + bufferStart[buffer] = size; + tripleMap.insert( + {size + xSize, Range{std::max(range.start(), xRange.start()), + std::min(range.end(), xRange.end())}}); + if (range.start() < xRange.start()) + tripleMap.insert({size, Range{range.start(), xRange.end()}}); + if (xRange.end() < range.end()) + tripleMap.insert({size, Range{xRange.start(), range.end()}}); + xBuffers.erase(bufferIt); } } } -} -void AllocationAnalysis::allocateSharedMemory( - const AllocationAnalysis::ValueRangeMapT &valueRangeMap, - const SmallVector &sharedMemoryValues, - const AllocationAnalysis::ValueSizeMapT &sharedMemoryStart, - const AllocationAnalysis::GraphT &interference) { - // First-fit graph coloring - // Neighbors are nodes that interfere with each other. - // We color a node by finding the index of the first available non-neighboring - // node or the first neighboring node without any color. - // Nodes with the same color do not interfere with each other. - DenseMap colors; - for (auto value : sharedMemoryValues) { - colors[value] = (value == sharedMemoryValues[0]) ? 0 : -1; - } - SmallVector available(sharedMemoryValues.size()); - for (auto x : sharedMemoryValues) { - std::fill(available.begin(), available.end(), true); - for (auto y : interference.lookup(x)) { - int color = colors[y]; - if (color >= 0) { - available[color] = false; + /// Builds a graph of all shared memory values. Edges are created between + /// shared memory values that are overlapping. 
+ void buildInterferenceGraph(const SmallVector &buffers, + const DenseMap &bufferStart, + GraphT &interference) { + for (auto x : buffers) { + for (auto y : buffers) { + if (x == y) + continue; + auto xStart = bufferStart.lookup(x); + auto yStart = bufferStart.lookup(y); + auto xSize = x->size; + auto ySize = y->size; + Range xSizeRange = {xStart, xStart + xSize}; + Range ySizeRange = {yStart, yStart + ySize}; + auto xOpRange = bufferRange.lookup(x); + auto yOpRange = bufferRange.lookup(y); + if (xOpRange.intersects(yOpRange) && + xSizeRange.intersects(ySizeRange)) { + interference[x].insert(y); + } } } - auto it = std::find(available.begin(), available.end(), true); - colors[x] = std::distance(available.begin(), it); } - // Finalize allocation - // color0: [0, 7), [0, 8), [0, 15) -> [0, 7), [0, 8), [0, 15) - // color1: [7, 9) -> [0 + 1 * 15, 9 + 1 * 15) -> [15, 24) - // color2: [8, 12) -> [8 + 2 * 15, 12 + 2 * 15) -> [38, 42) - // TODO(Keren): We are wasting memory here. - // Nodes with color2 can actually start with 24. - for (auto x : sharedMemoryValues) { - size_t adj = 0; - for (auto y : interference.lookup(x)) { - adj = std::max(adj, sharedMemoryStart.lookup(y) + valueSize.lookup(y)); + + /// Finalizes shared memory offsets considering interference. + void allocate(const SmallVector &buffers, + const DenseMap &bufferStart, + const GraphT &interference) { + // First-fit graph coloring + // Neighbors are nodes that interfere with each other. + // We color a node by finding the index of the first available + // non-neighboring node or the first neighboring node without any color. + // Nodes with the same color do not interfere with each other. + DenseMap colors; + for (auto value : buffers) { + colors[value] = (value == buffers[0]) ? 0 : -1; + } + SmallVector available(buffers.size()); + for (auto x : buffers) { + std::fill(available.begin(), available.end(), true); + for (auto y : interference.lookup(x)) { + int color = colors[y]; + if (color >= 0) { + available[color] = false; + } + } + auto it = std::find(available.begin(), available.end(), true); + colors[x] = std::distance(available.begin(), it); + } + // Finalize allocation + // color0: [0, 7), [0, 8), [0, 15) -> [0, 7), [0, 8), [0, 15) + // color1: [7, 9) -> [0 + 1 * 15, 9 + 1 * 15) -> [15, 24) + // color2: [8, 12) -> [8 + 2 * 15, 12 + 2 * 15) -> [38, 42) + // TODO(Keren): We are wasting memory here. + // Nodes with color2 can actually start with 24. 
+ for (auto x : buffers) { + size_t adj = 0; + for (auto y : interference.lookup(x)) { + adj = std::max(adj, bufferStart.lookup(y) + y->size); + } + x->offset = bufferStart.lookup(x) + colors.lookup(x) * adj; + allocation->sharedMemorySize = + std::max(allocation->sharedMemorySize, x->offset + x->size); } - valueOffset[x] = sharedMemoryStart.lookup(x) + colors.lookup(x) * adj; - sharedMemorySize = - std::max(sharedMemorySize, valueOffset[x] + valueSize.lookup(x)); } -} -void AllocationAnalysis::computeOffsets( - const AllocationAnalysis::ValueRangeMapT &valueRange) { - SmallVector sharedMemoryValues; - getSharedMemoryValuesAndSizes(valueRange, sharedMemoryValues); +private: + Operation *operation; + Allocation *allocation; + BufferRangeMapT bufferRange; +}; +} // namespace triton - ValueSizeMapT sharedMemoryStart; - calculateSharedMemoryStarts(valueRange, sharedMemoryValues, - sharedMemoryStart); - - GraphT interference; - buildInterferenceGraph(valueRange, sharedMemoryValues, sharedMemoryStart, - interference); - - allocateSharedMemory(valueRange, sharedMemoryValues, sharedMemoryStart, - interference); -} +void Allocation::run() { triton::AllocationAnalysis(getOperation(), this); } } // namespace mlir diff --git a/lib/Analysis/CMakeLists.txt b/lib/Analysis/CMakeLists.txt index a0d2092d5..a4acf328f 100644 --- a/lib/Analysis/CMakeLists.txt +++ b/lib/Analysis/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_library(TritonAnalysis AxisInfo.cpp Allocation.cpp + Membar.cpp DEPENDS TritonGPUAttrDefsIncGen diff --git a/lib/Analysis/Membar.cpp b/lib/Analysis/Membar.cpp new file mode 100644 index 000000000..ab3b0d68c --- /dev/null +++ b/lib/Analysis/Membar.cpp @@ -0,0 +1,95 @@ +#include "triton/Analysis/Membar.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +#include "mlir/Dialect/GPU/GPUDialect.h" + +namespace mlir { + +void MembarAnalysis::run() { + auto *operation = allocation->getOperation(); + operation->getContext()->getOrLoadDialect(); + RegionInfo regionInfo; + OpBuilder builder(operation); + dfsOperation(operation, ®ionInfo, &builder); +} + +void MembarAnalysis::dfsOperation(Operation *operation, + RegionInfo *parentRegionInfo, + OpBuilder *builder) { + transfer(operation, parentRegionInfo, builder); + if (operation->getNumRegions()) { + // If there's any nested regions, we need to visit them. + // scf.if and scf.else: two regions + // scf.if only: two regions + // scf.for: one region + RegionInfo curRegionInfo; + for (auto ®ion : operation->getRegions()) { + // Copy the parent info as the current info. + RegionInfo regionInfo = *parentRegionInfo; + for (auto &block : region.getBlocks()) { + assert(region.getBlocks().size() == 1 && + "Multiple blocks in a region is not supported"); + for (auto &op : block.getOperations()) { + // Traverse the nested operation. + dfsOperation(&op, ®ionInfo, builder); + } + } + curRegionInfo.join(regionInfo); + } + // Set the parent region info as the union of the nested region info. + *parentRegionInfo = curRegionInfo; + } +} + +void MembarAnalysis::transfer(Operation *op, RegionInfo *regionInfo, + OpBuilder *builder) { + if (op->getNumResults() < 1) + return; + + if (dyn_cast(op)) { + // If the current op is a barrier, we sync previous reads and writes + regionInfo->sync(); + return; + } + + if (dyn_cast(op)) { + // If the current op is an async wait, we insert a barrier op and sync + // previous reads and writes. 
+ OpBuilder::InsertionGuard g(*builder); + builder->setInsertionPointAfter(op); + builder->create(op->getLoc()); + regionInfo->sync(); + return; + } + + auto addBuffer = [&](RegionInfo::BufferIdSetT &bufferSet, + Allocation::BufferId bufferId) { + if (bufferId != Allocation::InvalidBufferId) { + bufferSet.insert(bufferId); + } + }; + + RegionInfo curRegionInfo; + for (Value value : op->getOperands()) { + // ConvertLayoutOp: shared memory -> registers + addBuffer(curRegionInfo.syncReadBuffers, allocation->getBufferId(value)); + } + for (Value value : op->getResults()) { + // ConvertLayoutOp: registers -> shared memory + addBuffer(curRegionInfo.syncWriteBuffers, allocation->getBufferId(value)); + } + // Scratch buffer is considered as a shared memory read + addBuffer(curRegionInfo.syncReadBuffers, allocation->getBufferId(op)); + + if (regionInfo->isIntersected(curRegionInfo, allocation)) { + OpBuilder::InsertionGuard g(*builder); + builder->setInsertionPoint(op); + builder->create(op->getLoc()); + regionInfo->sync(); + } + // Update the region info, even if barrier is inserted, we have to maintain + // the current op's read/write buffers. + regionInfo->join(curRegionInfo); +} + +} // namespace mlir diff --git a/test/Analysis/test-allocation.mlir b/test/Analysis/test-allocation.mlir index 5da49a17f..9dc541363 100644 --- a/test/Analysis/test-allocation.mlir +++ b/test/Analysis/test-allocation.mlir @@ -1,4 +1,4 @@ -// RUN: triton-opt %s --mlir-disable-threading -test-print-allocation 2>&1 | FileCheck %s +// RUN: triton-opt %s -split-input-file --mlir-disable-threading -test-print-allocation 2>&1 | FileCheck %s #AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> #BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> @@ -6,6 +6,7 @@ #B = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> #C = #triton_gpu.mma<{version = 2, warpsPerCTA = [4, 1]}> +// CHECK-LABEL: matmul_loop func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { %a_ptr_init = tt.broadcast %A : (!tt.ptr) -> tensor<128x32x!tt.ptr, #AL> %b_ptr_init = tt.broadcast %B : (!tt.ptr) -> tensor<32x128x!tt.ptr, #BL> @@ -24,7 +25,7 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B // CHECK: offset = 0, size = 8192 %a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A> %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<32x128xf16, #BL> - // CHECK: offset = 8192, size = 8192 + // CHECK-NEXT: offset = 8192, size = 8192 %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B> %c = tt.dot %a, %b, %prev_c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> @@ -34,11 +35,12 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> } return - // CHECK: size = 16384 + // CHECK-NEXT: size = 16384 } // Shared memory is available after a tensor's liveness range ends -func @synthesized_reusable(%A : !tt.ptr) { +// CHECK-LABEL: reusable +func @reusable(%A : !tt.ptr) { %cst1 = arith.constant dense : tensor<128x32xi1, #AL> %cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> %cst3 = arith.constant dense : 
tensor<32x128xi1, #AL> @@ -51,95 +53,162 @@ func @synthesized_reusable(%A : !tt.ptr) { // CHECK: offset = 0, size = 8192 %a1 = triton_gpu.convert_layout %a1_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A> %a2_ = tt.load %b_ptr, %cst3, %cst4 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<32x128xf16, #AL> - // CHECK: offset = 8192, size = 8192 + // CHECK-NEXT: offset = 8192, size = 8192 %a2 = triton_gpu.convert_layout %a2_ : (tensor<32x128xf16, #AL>) -> tensor<32x128xf16, #A> %a3_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<128x32xf16, #AL> - // CHECK: offset = 16384, size = 8192 + // CHECK-NEXT: offset = 16384, size = 8192 %a3 = triton_gpu.convert_layout %a3_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A> %c = tt.dot %a1, %a2, %c_init {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> %a4_ = tt.load %b_ptr, %cst3, %cst4 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<32x128xf16, #AL> - // CHECK: offset = 0, size = 8192 + // CHECK-NEXT: offset = 0, size = 8192 %a4 = triton_gpu.convert_layout %a4_ : (tensor<32x128xf16, #AL>) -> tensor<32x128xf16, #A> %c1 = tt.dot %a3, %a4, %c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> return - // CHECK: size = 24576 + // CHECK-NEXT: size = 24576 } // A tensor's shared memory offset is larger than it needs to accommodate further tensors // %cst0->%c // %cst1->%cst4 // %cst3->%g->%h->%i -func @synthesize_preallocate(%A : !tt.ptr) { +// CHECK-LABEL: preallocate +func @preallocate(%A : !tt.ptr) { // CHECK: offset = 0, size = 512 %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> - // CHECK: offset = 1024, size = 512 + // CHECK-NEXT: offset = 1024, size = 512 %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> - // CHECK: offset = 1536, size = 512 + // CHECK-NEXT: offset = 1536, size = 512 %cst2 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> - // CHECK: offset = 2048, size = 1024 + // CHECK-NEXT: offset = 2048, size = 1024 %a = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> - // CHECK: offset = 3072, size = 1024 + // CHECK-NEXT: offset = 3072, size = 1024 %b = tt.cat %cst0, %cst2 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> - // CHECK: offset = 0, size = 1024 + // CHECK-NEXT: offset = 0, size = 1024 %c = tt.cat %cst1, %cst2 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> - // CHECK: offset = 1024, size = 1024 + // CHECK-NEXT: offset = 1024, size = 1024 %cst4 = arith.constant dense<0.000000e+00> : tensor<32x16xf16, #A> - // CHECK: offset = 6144, size = 2048 + // CHECK-NEXT: offset = 6144, size = 2048 %e = tt.cat %a, %cst4 {axis = 0} : (tensor<32x16xf16, #A>, tensor<32x16xf16, #A>) -> tensor<64x16xf16, #A> - // CHECK: offset = 8192, size = 2048 + // CHECK-NEXT: offset = 8192, size = 2048 %d = tt.cat %b, %cst4 {axis = 0} : (tensor<32x16xf16, #A>, tensor<32x16xf16, #A>) -> tensor<64x16xf16, #A> - // CHECK: offset = 10240, size = 2048 + // CHECK-NEXT: offset = 10240, size = 2048 %f = tt.cat %c, %cst4 {axis = 0} : (tensor<32x16xf16, #A>, tensor<32x16xf16, #A>) -> tensor<64x16xf16, #A> - // CHECK: offset = 0, size = 2048 + // CHECK-NEXT: offset = 0, size = 2048 %cst5 = arith.constant dense<0.000000e+00> : 
tensor<64x16xf16, #A> - // CHECK: offset = 2048, size = 4096 + // CHECK-NEXT: offset = 2048, size = 4096 %g = tt.cat %e, %cst5 {axis = 0} : (tensor<64x16xf16, #A>, tensor<64x16xf16, #A>) -> tensor<128x16xf16, #A> - // CHECK: offset = 2048, size = 4096 + // CHECK-NEXT: offset = 2048, size = 4096 %h = tt.cat %d, %cst5 {axis = 0} : (tensor<64x16xf16, #A>, tensor<64x16xf16, #A>) -> tensor<128x16xf16, #A> - // CHECK: offset = 2048, size = 4096 + // CHECK-NEXT: offset = 2048, size = 4096 %i = tt.cat %f, %cst5 {axis = 0} : (tensor<64x16xf16, #A>, tensor<64x16xf16, #A>) -> tensor<128x16xf16, #A> return - // CHECK: size = 12288 + // CHECK-NEXT: size = 12288 } // Unused tensors are immediately released -func @synthesize_unused(%A : !tt.ptr) { +// CHECK-LABEL: unused +func @unused(%A : !tt.ptr) { // CHECK: offset = 0, size = 1024 %cst0 = arith.constant dense<0.000000e+00> : tensor<32x16xf16, #A> - // CHECK: offset = 0, size = 512 + // CHECK-NEXT: offset = 0, size = 512 %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> - // CHECK: offset = 512, size = 512 + // CHECK-NEXT: offset = 512, size = 512 %cst2 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> - // CHECK: offset = 1024, size = 1024 + // CHECK-NEXT: offset = 1024, size = 1024 %a = tt.cat %cst1, %cst2 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> return // CHECK: size = 2048 } // cst0 is alive through the entire function, it cannot be released before the end of the function -func @synthesize_longlive(%A : !tt.ptr) { +// CHECK-LABEL: longlive +func @longlive(%A : !tt.ptr) { // CHECK: offset = 0, size = 512 %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> - // CHECK: offset = 512, size = 512 + // CHECK-NEXT: offset = 512, size = 512 %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> - // CHECK: offset = 1024, size = 512 + // CHECK-NEXT: offset = 1024, size = 512 %cst2 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> - // CHECK: offset = 1536, size = 1024 + // CHECK-NEXT: offset = 1536, size = 1024 %a = tt.cat %cst1, %cst2 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> - // CHECK: offset = 512, size = 512 + // CHECK-NEXT: offset = 512, size = 512 %cst3 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> - // CHECK: offset = 1024, size = 512 + // CHECK-NEXT: offset = 1024, size = 512 %cst4 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> - // CHECK: offset = 1536, size = 1024 + // CHECK-NEXT: offset = 1536, size = 1024 %b = tt.cat %cst3, %cst4 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> - // CHECK: offset = 1536, size = 512 + // CHECK-NEXT: offset = 1536, size = 512 %cst5 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> - // CHECK: offset = 1536, size = 512 + // CHECK-NEXT: offset = 1536, size = 512 %cst6 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> - // CHECK: offset = 1536, size = 1024 + // CHECK-NEXT: offset = 1536, size = 1024 %c = tt.cat %cst3, %cst4 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> - // CHECK: offset = 512, size = 1024 + // CHECK-NEXT: offset = 512, size = 1024 %d = tt.cat %cst0, %cst0 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> return - // CHECK: size = 2560 + // CHECK-NEXT: size = 2560 +} + +// CHECK-LABEL: scratch +func @scratch() { + // CHECK: offset = 0, size = 512 + %cst0 = arith.constant dense<0.000000e+00> : 
tensor<16x16xf16, #A> + // CHECK-NEXT: offset = 1056, size = 1024 + %a = tt.cat %cst0, %cst0 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> + // CHECK-NEXT: scratch offset = 32, size = 1024 + // CHECK-NEXT: offset = 0, size = 32 + %b = tt.reduce %a {redOp = 1 : i32, axis = 0 : i32} : tensor<32x16xf16, #A> -> tensor<16xf16, #A> + return + // CHECK-NEXT: size = 2080 +} + +// B0 -> (B1) -> B0 +// Memory used by B1 can be reused by B0. +// CHECK-LABEL: multi_blocks_reuse +func @multi_blocks_reuse(%i1 : i1) { + // CHECK: offset = 0, size = 512 + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + // CHECK-NEXT: offset = 512, size = 512 + %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + scf.if %i1 { + // CHECK-NEXT: offset = 1024, size = 1024 + %a = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> + // CHECK-NEXT: offset = 1024, size = 1024 + %b = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> + } + // CHECK-NEXT: offset = 0, size = 512 + %cst2 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + // CHECK-NEXT: offset = 512, size = 512 + %cst3 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + // CHECK-NEXT: offset = 1024, size = 1024 + %a = tt.cat %cst2, %cst3 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> + return + // CHECK-NEXT: size = 2048 +} + +// B0 -> (B1) -> (B2) -> B0 +// Memory used by B0 cannot be reused by B1 or B2. +// CHECK-LABEL: multi_blocks_noreuse +func @multi_blocks_noreuse(%i1 : i1) { + // CHECK: offset = 0, size = 512 + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + // CHECK-NEXT: offset = 512, size = 512 + %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + scf.if %i1 { + // CHECK-NEXT: offset = 1024, size = 1024 + %a = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> + // CHECK-NEXT: offset = 1024, size = 1024 + %b = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> + } else { + // CHECK-NEXT: offset = 1024, size = 512 + %cst2 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + // CHECK-NEXT: offset = 1536, size = 512 + %cst3 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + // CHECK-NEXT: offset = 2048, size = 1024 + %a = tt.cat %cst2, %cst3 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> + } + // CHECK-NEXT: offset = 1024, size = 1024 + %a = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> + return + // CHECK-NEXT: size = 3072 } diff --git a/test/Analysis/test-membar.mlir b/test/Analysis/test-membar.mlir new file mode 100644 index 000000000..23ccfcdda --- /dev/null +++ b/test/Analysis/test-membar.mlir @@ -0,0 +1,178 @@ +// RUN: triton-opt %s -split-input-file --mlir-disable-threading -test-print-membar 2>&1 | FileCheck %s + +#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#A = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> +#B = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> +#C = #triton_gpu.mma<{version = 2, warpsPerCTA = [4, 
1]}> + +// CHECK-LABEL: matmul_loop +func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { + %a_ptr_init = tt.broadcast %A : (!tt.ptr) -> tensor<128x32x!tt.ptr, #AL> + %b_ptr_init = tt.broadcast %B : (!tt.ptr) -> tensor<32x128x!tt.ptr, #BL> + + %a_mask = arith.constant dense : tensor<128x32xi1, #AL> + %a_other = arith.constant dense<0.00e+00> : tensor<128x32xf16, #AL> + %b_mask = arith.constant dense : tensor<32x128xi1, #BL> + %b_other = arith.constant dense<0.00e+00> : tensor<32x128xf16, #BL> + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + + scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { + %a_ = tt.load %a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<128x32xf16, #AL> + %a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A> + %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<32x128xf16, #BL> + %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B> + // CHECK: Membar 13 + %c = tt.dot %a, %b, %prev_c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + + %next_a_ptr = tt.getelementptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL> + %next_b_ptr = tt.getelementptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL> + scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> + } + return +} + +// CHECK-LABEL: raw_single_block +func @raw_single_block(%A : !tt.ptr) { + %cst1 = arith.constant dense : tensor<128x32xi1, #AL> + %cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> + %a_ptr = tt.broadcast %A : (!tt.ptr) -> tensor<128x32x!tt.ptr, #AL> + %a1_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<128x32xf16, #AL> + %a1 = triton_gpu.convert_layout %a1_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A> + // CHECK: Membar 5 + %a2 = triton_gpu.convert_layout %a1 : (tensor<128x32xf16, #A>) -> tensor<128x32xf16, #A> + return +} + +// CHECK-LABEL: war_single_block +func @war_single_block(%A : !tt.ptr) { + %cst1 = arith.constant dense : tensor<128x32xi1, #AL> + %cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> + %a_ptr = tt.broadcast %A : (!tt.ptr) -> tensor<128x32x!tt.ptr, #AL> + %a1_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<128x32xf16, #AL> + %a1 = triton_gpu.convert_layout %a1_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A> + // CHECK: Membar 5 + %a2 = triton_gpu.convert_layout %a1 : (tensor<128x32xf16, #A>) -> tensor<128x32xf16, #AL> + // CHECK-NEXT: Membar 7 + %a3 = triton_gpu.convert_layout %a1_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A> + return +} + +// CHECK-LABEL: scratch +func @scratch() { + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + // CHECK: Membar 1 + %a = tt.cat %cst0, %cst0 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> + // CHECK-NEXT: Membar 3 + %b = tt.reduce %a {redOp 
= 1 : i32, axis = 0 : i32} : tensor<32x16xf16, #A> -> tensor<16xf16, #A> + return +} + +// CHECK-LABEL: async_wait +func @async_wait() { + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + // CHECK: Membar 1 + %a = tt.cat %cst0, %cst0 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> + triton_gpu.async_wait {num = 4 : i32} + // CHECK-NEXT: Membar 4 + %a_ = triton_gpu.convert_layout %a : (tensor<32x16xf16, #A>) -> tensor<32x16xf16, #AL> + return +} + +// If branch inserted a barrier for %cst0 and %cst1, but else didn't, then the barrier should be inserted in the parent region +// CHECK-LABEL: multi_blocks +func @multi_blocks(%i1 : i1) { + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + scf.if %i1 { + // CHECK: Membar 2 + %a = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> + scf.yield + } else { + %cst2 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + %cst3 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + // CHECK-NEXT: Membar 7 + %b = tt.cat %cst2, %cst3 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> + scf.yield + } + // CHECK-NEXT: Membar 10 + %c = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> + return +} + +// Both branches inserted a barrier for %cst0 and %cst1, then the barrier doesn't need to be inserted in the parent region +// CHECK-LABEL: multi_blocks_join_barrier +func @multi_blocks_join_barrier(%i1 : i1) { + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + scf.if %i1 { + // CHECK: Membar 2 + %a = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> + scf.yield + } else { + // CHECK-NEXT: Membar 5 + %a = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> + scf.yield + } + %a_ = triton_gpu.convert_layout %cst0 : (tensor<16x16xf16, #A>) -> tensor<16x16xf16, #AL> + return +} + +// Read yielded tensor requires a barrier +// CHECK-LABEL: multi_blocks_yield +func @multi_blocks_yield(%i1 : i1) { + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + %a = scf.if %i1 -> (tensor<32x16xf16, #A>) { + // CHECK: Membar 2 + %a = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> + scf.yield %a : tensor<32x16xf16, #A> + } else { + // CHECK-NEXT: Membar 5 + %b = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> + scf.yield %b : tensor<32x16xf16, #A> + } + %a_ = triton_gpu.convert_layout %cst0 : (tensor<16x16xf16, #A>) -> tensor<16x16xf16, #AL> + // CHECK-NEXT: Membar 9 + %b = tt.cat %a, %a {axis = 0} : (tensor<32x16xf16, #A>, tensor<32x16xf16, #A>) -> tensor<64x16xf16, #A> + return +} + +// Conservatively add a barrier as if the branch (%i1) is never taken +// CHECK-LABEL: multi_blocks_noelse +func @multi_blocks_noelse(%i1 : i1) { + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + scf.if %i1 { + // CHECK: Membar 2 + %a = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> 
tensor<32x16xf16, #A> + scf.yield + } + %a_ = triton_gpu.convert_layout %cst0 : (tensor<16x16xf16, #A>) -> tensor<16x16xf16, #AL> + return +} + +// Conservatively add a barrier as if the branch (%i2) is never taken +// CHECK-LABEL: multi_blocks_nested_scf +func @multi_blocks_nested_scf(%i1 : i1, %i2 : i1) { + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + scf.if %i1 { + scf.if %i2 { + // CHECK: Membar 2 + %b = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> + scf.yield + } + scf.yield + } else { + // CHECK-NEXT: Membar 6 + %b = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> + scf.yield + } + // CHECK-NEXT: Membar 9 + %a_ = triton_gpu.convert_layout %cst0 : (tensor<16x16xf16, #A>) -> tensor<16x16xf16, #AL> + return +} diff --git a/test/lib/Analysis/CMakeLists.txt b/test/lib/Analysis/CMakeLists.txt index cbb7661ad..992d687e0 100644 --- a/test/lib/Analysis/CMakeLists.txt +++ b/test/lib/Analysis/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_library(TritonTestAnalysis TestAxisInfo.cpp TestAllocation.cpp + TestMembar.cpp LINK_LIBS PUBLIC TritonAnalysis diff --git a/test/lib/Analysis/TestAllocation.cpp b/test/lib/Analysis/TestAllocation.cpp index 7a3cc55b0..a29465630 100644 --- a/test/lib/Analysis/TestAllocation.cpp +++ b/test/lib/Analysis/TestAllocation.cpp @@ -19,24 +19,29 @@ struct TestAllocationPass void runOnOperation() override { Operation *operation = getOperation(); auto &os = llvm::errs(); - os << "Testing: " << operation->getName() << "\n"; - AllocationAnalysis analysis(operation); + // Convert to std::string can remove quotes from op_name + auto op_name = SymbolTable::getSymbolName(operation).getValue().str(); + os << op_name << "\n"; + Allocation allocation(operation); operation->walk([&](Operation *op) { + auto scratchBufferId = allocation.getBufferId(op); + if (scratchBufferId != Allocation::InvalidBufferId) { + size_t offset = allocation.getOffset(scratchBufferId); + size_t size = allocation.getAllocatedSize(scratchBufferId); + os << "scratch offset = " << offset << ", size = " << size << "\n"; + } if (op->getNumResults() < 1) return; for (Value result : op->getResults()) { - Type type = result.getType(); - if (auto tensorType = type.dyn_cast()) { - Attribute encoding = tensorType.getEncoding(); - if (encoding.isa()) { - size_t offset = analysis.getOffset(result); - size_t size = analysis.getAllocatedSize(result); - os << "offset = " << offset << ", size = " << size << "\n"; - } + auto bufferId = allocation.getBufferId(result); + if (bufferId != Allocation::InvalidBufferId) { + size_t offset = allocation.getOffset(bufferId); + size_t size = allocation.getAllocatedSize(bufferId); + os << "offset = " << offset << ", size = " << size << "\n"; } } }); - os << "size = " << analysis.getSharedMemorySize() << "\n"; + os << "size = " << allocation.getSharedMemorySize() << "\n"; } }; diff --git a/test/lib/Analysis/TestMembar.cpp b/test/lib/Analysis/TestMembar.cpp new file mode 100644 index 000000000..ec8c53948 --- /dev/null +++ b/test/lib/Analysis/TestMembar.cpp @@ -0,0 +1,50 @@ +#include "mlir/Dialect/GPU/GPUDialect.h" +#include "mlir/IR/Dialect.h" +#include "mlir/Pass/Pass.h" +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/Membar.h" + +using namespace mlir; + +namespace { + +struct TestMembarPass + : public PassWrapper> { + + // LLVM15+ + // 
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMembarPass);
+
+  StringRef getArgument() const final { return "test-print-membar"; }
+  StringRef getDescription() const final {
+    return "print the result of the membar pass";
+  }
+
+  void runOnOperation() override {
+    Operation *operation = getOperation();
+    auto &os = llvm::errs();
+    // Converting to std::string removes the quotes from op_name
+    auto op_name = SymbolTable::getSymbolName(operation).getValue().str();
+    os << op_name << "\n";
+    Allocation allocation(operation);
+    MembarAnalysis analysis(&allocation);
+    size_t operationId = 0;
+    operation->walk([&](Operation *op) {
+      if (dyn_cast<gpu::BarrierOp>(op)) {
+        os << "Membar " << operationId << "\n";
+      }
+      if (op->getNumRegions() == 0) {
+        // Don't count the parent operation to simplify the test.
+        operationId++;
+      }
+      return;
+    });
+  }
+};
+
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestMembarPass() { PassRegistration<TestMembarPass>(); }
+} // namespace test
+} // namespace mlir