diff --git a/bin/triton-opt.cpp b/bin/triton-opt.cpp
index d476135d8..359889582 100644
--- a/bin/triton-opt.cpp
+++ b/bin/triton-opt.cpp
@@ -13,7 +13,8 @@
 namespace mlir {
 namespace test {
 void registerTestAlignmentPass();
-}
+void registerTestAllocationPass();
+} // namespace test
 } // namespace mlir
 
 int main(int argc, char **argv) {
@@ -21,6 +22,7 @@ int main(int argc, char **argv) {
   mlir::registerTritonPasses();
   mlir::registerTritonGPUPasses();
   mlir::test::registerTestAlignmentPass();
+  mlir::test::registerTestAllocationPass();
   mlir::triton::registerConvertTritonToTritonGPUPass();
   mlir::triton::registerConvertTritonGPUToLLVMPass();
 
diff --git a/include/triton/Analysis/Allocation.h b/include/triton/Analysis/Allocation.h
new file mode 100644
index 000000000..b489f5142
--- /dev/null
+++ b/include/triton/Analysis/Allocation.h
@@ -0,0 +1,115 @@
+#ifndef TRITON_ANALYSIS_ALLOCATION_H
+#define TRITON_ANALYSIS_ALLOCATION_H
+
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/MapVector.h"
+#include "llvm/Support/raw_ostream.h"
+#include <limits>
+
+#include "triton/Dialect/TritonGPU/IR/Dialect.h"
+
+namespace mlir {
+
+/// Modified from llvm-15.0: llvm/ADT/AddressRanges.h
+/// A class that represents an address range. The range is specified using
+/// a start and an end address: [Start, End).
+template <typename AddrT> class Range {
+public:
+  Range() {}
+  Range(AddrT S, AddrT E) : Start(S), End(E) { assert(Start <= End); }
+  AddrT start() const { return Start; }
+  AddrT end() const { return End; }
+  AddrT size() const { return End - Start; }
+  bool contains(AddrT Addr) const { return Start <= Addr && Addr < End; }
+  bool intersects(const Range &R) const {
+    return Start < R.End && R.Start < End;
+  }
+  bool operator==(const Range &R) const {
+    return Start == R.Start && End == R.End;
+  }
+  bool operator!=(const Range &R) const { return !(*this == R); }
+  bool operator<(const Range &R) const {
+    return std::make_pair(Start, End) < std::make_pair(R.Start, R.End);
+  }
+
+private:
+  AddrT Start = std::numeric_limits<AddrT>::min();
+  AddrT End = std::numeric_limits<AddrT>::max();
+};
+
+//===----------------------------------------------------------------------===//
+// Shared Memory Allocation Analysis
+//===----------------------------------------------------------------------===//
+class AllocationAnalysis {
+public:
+  using ValueSizeMapT = llvm::DenseMap<Value, size_t>;
+
+public:
+  /// Creates a new Allocation analysis that computes the shared memory
+  /// information for all associated shared memory values.
+  AllocationAnalysis(Operation *operation) : operation(operation) { run(); }
+
+  /// Returns the operation this analysis was constructed from.
+  Operation *getOperation() const { return operation; }
+
+  /// Returns the offset of the given value in shared memory.
+  size_t getOffset(Value value) const { return valueOffset.lookup(value); }
+
+  /// Returns the allocated size of the given value in shared memory.
+  size_t getAllocatedSize(Value value) const { return valueSize.lookup(value); }
+
+  /// Returns the total size of the allocated shared memory.
+  size_t getSharedMemorySize() const { return sharedMemorySize; }
+
+private:
+  /// Value -> Range
+  /// Use MapVector to ensure determinism.
+  using ValueRangeMapT = llvm::MapVector<Value, Range<size_t>>;
+  /// Start -> Range
+  using TripleMapT = std::multimap<size_t, Range<size_t>>;
+  /// Nodes -> Nodes
+  using GraphT = DenseMap<Value, DenseSet<Value>>;
+
+  /// Runs allocation analysis on the given top-level operation.
+  void run();
+
+  /// Resolves the liveness range of all values defined under the root
+  /// operation.
+  void resolveLiveness(ValueRangeMapT &valueRangeMap);
+
+  /// Computes the shared memory offsets for all related values.
+  /// Paper: Algorithms for Compile-Time Memory Optimization
+  /// (https://www.cs.utexas.edu/users/harrison/papers/compile-time.pdf)
+  void computeOffsets(const ValueRangeMapT &valueRangeMap);
+
+  /// Gets the shared memory values and their sizes from valueRangeMap.
+  void getSharedMemoryValuesAndSizes(const ValueRangeMapT &valueRangeMap,
+                                     SmallVector<Value> &sharedMemoryValues);
+
+  /// Computes the initial shared memory offsets.
+  void calculateSharedMemoryStarts(const ValueRangeMapT &valueRangeMap,
+                                   const SmallVector<Value> &sharedMemoryValues,
+                                   ValueSizeMapT &sharedMemoryStart);
+
+  /// Builds a graph of all shared memory values. Edges are created between
+  /// shared memory values whose allocations overlap.
+  void buildInterferenceGraph(const ValueRangeMapT &valueRangeMap,
+                              const SmallVector<Value> &sharedMemoryValues,
+                              const ValueSizeMapT &sharedMemoryStart,
+                              GraphT &interference);
+
+  /// Finalizes shared memory offsets considering interference.
+  void allocateSharedMemory(const ValueRangeMapT &valueRangeMap,
+                            const SmallVector<Value> &sharedMemoryValues,
+                            const ValueSizeMapT &sharedMemoryStart,
+                            const GraphT &interference);
+
+private:
+  Operation *operation;
+  ValueSizeMapT valueOffset;
+  ValueSizeMapT valueSize;
+  size_t sharedMemorySize = 0;
+};
+
+} // namespace mlir
+
+#endif
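For orientation, here is a minimal usage sketch of the Range helper defined in the header above. It is illustrative only and not part of the patch; the function name and the byte values are made up. It simply exercises the half-open [Start, End) semantics the allocator relies on.

#include "triton/Analysis/Allocation.h"
#include <cassert>
#include <cstddef>

// Ranges are half-open: the end offset is excluded, so adjacent ranges touch
// but do not intersect.
void rangeSemanticsExample() {
  mlir::Range<std::size_t> a(0, 8192), b(8192, 16384), c(4096, 12288);
  assert(!a.intersects(b));                   // adjacent, not overlapping
  assert(a.intersects(c) && b.intersects(c)); // genuine overlap
  assert(a.contains(0) && !a.contains(8192)); // the end offset is exclusive
  assert(a.size() == 8192);
}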
diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp
new file mode 100644
index 000000000..8d5b6afcc
--- /dev/null
+++ b/lib/Analysis/Allocation.cpp
@@ -0,0 +1,200 @@
+#include "triton/Analysis/Allocation.h"
+#include "mlir/Analysis/Liveness.h"
+#include "triton/Dialect/TritonGPU/IR/Dialect.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/SmallVector.h"
+
+#include <algorithm>
+#include <limits>
+
+namespace mlir {
+
+void AllocationAnalysis::run() {
+  ValueRangeMapT valueRange;
+  resolveLiveness(valueRange);
+  computeOffsets(valueRange);
+}
+
+void AllocationAnalysis::resolveLiveness(
+    AllocationAnalysis::ValueRangeMapT &valueRange) {
+  Liveness liveness(operation);
+  DenseMap<Operation *, size_t> operationIds;
+  operation->walk([&](Operation *op) {
+    operationIds.insert({op, operationIds.size()});
+  });
+
+  operation->walk([&](Operation *op) {
+    for (Value result : op->getResults()) {
+      auto liveOperations = liveness.resolveLiveness(result);
+      auto minId = std::numeric_limits<size_t>::max();
+      auto maxId = std::numeric_limits<size_t>::min();
+      std::for_each(liveOperations.begin(), liveOperations.end(),
+                    [&](Operation *liveOp) {
+                      if (operationIds[liveOp] < minId) {
+                        minId = operationIds[liveOp];
+                      }
+                      if (operationIds[liveOp] > maxId) {
+                        maxId = operationIds[liveOp];
+                      }
+                    });
+      valueRange.insert({result, Range(minId, maxId + 1)});
+    }
+  });
+}
+
+void AllocationAnalysis::getSharedMemoryValuesAndSizes(
+    const AllocationAnalysis::ValueRangeMapT &valueRange,
+    SmallVector<Value> &sharedMemoryValues) {
+  for (auto &valueRangePair : valueRange) {
+    auto value = valueRangePair.first;
+    auto type = value.getType();
+    if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
+      auto encoding = tensorType.getEncoding();
+      if (encoding &&
+          encoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
+        // Bytes could be a different value once we support padding or other
+        // allocation policies.
+        auto bytes = tensorType.getNumElements() *
+                     tensorType.getElementTypeBitWidth() / 8;
+        sharedMemoryValues.emplace_back(value);
+        valueSize.insert({value, bytes});
+      }
+    }
+  }
+}
+
+void AllocationAnalysis::calculateSharedMemoryStarts(
+    const AllocationAnalysis::ValueRangeMapT &valueRange,
+    const SmallVector<Value> &sharedMemoryValues,
+    ValueSizeMapT &sharedMemoryStart) {
+  //  v = values in shared memory
+  //  t = triplet of (size, start, end)
+  //  shared memory space
+  //  -
+  //  |         *******t4
+  //  | /|\ v2 inserts t4, t5, and t6
+  //  |  |
+  //  | ******t5         ************t6
+  //  | ^^^^^v2^^^^^^
+  //  |  |      *********************t2
+  //  | \|/ v2 erases t1
+  //  | ******t1 ^^^^^^^^^v1^^^^^^^^^ ************t3
+  //  |---------------------------------------------| liveness range
+  //    1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 ...
+  TripleMapT tripleMap;
+  tripleMap.insert(std::make_pair(0, Range<size_t>()));
+  SmallVector<Value> values = sharedMemoryValues;
+  while (!values.empty()) {
+    auto tripleIt = tripleMap.begin();
+    auto size = tripleIt->first;
+    auto range = tripleIt->second;
+    tripleMap.erase(tripleIt);
+    auto valueIt =
+        std::find_if(values.begin(), values.end(), [&](Value value) {
+          auto xRange = valueRange.lookup(value);
+          bool res = xRange.intersects(range);
+          for (auto val : tripleMap)
+            res = res && !val.second.intersects(xRange);
+          return res;
+        });
+    if (valueIt != values.end()) {
+      auto value = *valueIt;
+      auto xSize = valueSize.lookup(value);
+      auto xRange = valueRange.lookup(value);
+      sharedMemoryStart[value] = size;
+      tripleMap.insert(
+          {size + xSize, Range{std::max(range.start(), xRange.start()),
+                               std::min(range.end(), xRange.end())}});
+      if (range.start() < xRange.start())
+        tripleMap.insert({size, Range{range.start(), xRange.end()}});
+      if (xRange.end() < range.end())
+        tripleMap.insert({size, Range{xRange.start(), range.end()}});
+      values.erase(valueIt);
+    }
+  }
+}
+
+void AllocationAnalysis::buildInterferenceGraph(
+    const AllocationAnalysis::ValueRangeMapT &valueRange,
+    const SmallVector<Value> &sharedMemoryValues,
+    const ValueSizeMapT &sharedMemoryStart, GraphT &interference) {
+  for (auto x : sharedMemoryValues) {
+    for (auto y : sharedMemoryValues) {
+      if (x == y)
+        continue;
+      auto xStart = sharedMemoryStart.lookup(x);
+      auto yStart = sharedMemoryStart.lookup(y);
+      auto xSize = valueSize.lookup(x);
+      auto ySize = valueSize.lookup(y);
+      Range xSizeRange = {xStart, xStart + xSize};
+      Range ySizeRange = {yStart, yStart + ySize};
+      auto xOpRange = valueRange.lookup(x);
+      auto yOpRange = valueRange.lookup(y);
+      if (xOpRange.intersects(yOpRange) &&
+          xSizeRange.intersects(ySizeRange)) {
+        interference[x].insert(y);
+      }
+    }
+  }
+}
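To make the interference condition above concrete, here is a small standalone sketch, illustrative only and not part of the patch (the helper name is made up). Two buffers interfere only when they overlap both in time (liveness) and in space (tentative offsets):

#include "triton/Analysis/Allocation.h"
#include <cstddef>

// Mirrors the check in buildInterferenceGraph: an edge is added only when the
// liveness ranges and the tentative address ranges both intersect.
bool wouldInterfere(mlir::Range<std::size_t> xLive, mlir::Range<std::size_t> yLive,
                    std::size_t xStart, std::size_t xSize,
                    std::size_t yStart, std::size_t ySize) {
  mlir::Range<std::size_t> xAddr(xStart, xStart + xSize);
  mlir::Range<std::size_t> yAddr(yStart, yStart + ySize);
  return xLive.intersects(yLive) && xAddr.intersects(yAddr);
}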
+
+void AllocationAnalysis::allocateSharedMemory(
+    const AllocationAnalysis::ValueRangeMapT &valueRangeMap,
+    const SmallVector<Value> &sharedMemoryValues,
+    const AllocationAnalysis::ValueSizeMapT &sharedMemoryStart,
+    const AllocationAnalysis::GraphT &interference) {
+  // First-fit graph coloring.
+  // Neighbors are nodes that interfere with each other.
+  // Each node is colored with the smallest color (index) that is not already
+  // used by any of its neighbors, so nodes with the same color never
+  // interfere with each other.
+  DenseMap<Value, int> colors;
+  for (auto value : sharedMemoryValues) {
+    colors[value] = (value == sharedMemoryValues[0]) ? 0 : -1;
+  }
+  SmallVector<bool> available(sharedMemoryValues.size());
+  for (auto x : sharedMemoryValues) {
+    std::fill(available.begin(), available.end(), true);
+    for (auto y : interference.lookup(x)) {
+      int color = colors[y];
+      if (color >= 0) {
+        available[color] = false;
+      }
+    }
+    auto it = std::find(available.begin(), available.end(), true);
+    colors[x] = std::distance(available.begin(), it);
+  }
+  // Finalize allocation.
+  // color0: [0, 7), [0, 8), [0, 15) -> [0, 7), [0, 8), [0, 15)
+  // color1: [7, 9) -> [0 + 1 * 15, 9 + 1 * 15) -> [15, 24)
+  // color2: [8, 12) -> [8 + 2 * 15, 12 + 2 * 15) -> [38, 42)
+  // TODO(Keren): We are wasting memory here.
+  // Nodes with color2 can actually start with 24.
+  for (auto x : sharedMemoryValues) {
+    size_t adj = 0;
+    for (auto y : interference.lookup(x)) {
+      adj = std::max(adj, sharedMemoryStart.lookup(y) + valueSize.lookup(y));
+    }
+    valueOffset[x] = sharedMemoryStart.lookup(x) + colors.lookup(x) * adj;
+    sharedMemorySize =
+        std::max(sharedMemorySize, valueOffset[x] + valueSize.lookup(x));
+  }
+}
+
+void AllocationAnalysis::computeOffsets(
+    const AllocationAnalysis::ValueRangeMapT &valueRange) {
+  SmallVector<Value> sharedMemoryValues;
+  getSharedMemoryValuesAndSizes(valueRange, sharedMemoryValues);
+
+  ValueSizeMapT sharedMemoryStart;
+  calculateSharedMemoryStarts(valueRange, sharedMemoryValues,
+                              sharedMemoryStart);
+
+  GraphT interference;
+  buildInterferenceGraph(valueRange, sharedMemoryValues, sharedMemoryStart,
+                         interference);
+
+  allocateSharedMemory(valueRange, sharedMemoryValues, sharedMemoryStart,
+                       interference);
+}
+
+} // namespace mlir
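A compact standalone sketch of the first-fit coloring step performed by allocateSharedMemory above. It is illustrative only; the helper name and the plain std containers are mine, not the patch's. Every node receives the smallest color not taken by an interfering neighbor.

#include <algorithm>
#include <cstddef>
#include <vector>

// First-fit greedy coloring over an adjacency list: node i receives the
// smallest color index not already used by one of its neighbors, so nodes
// that share a color never interfere.
std::vector<int> firstFitColoring(const std::vector<std::vector<std::size_t>> &adj) {
  std::vector<int> colors(adj.size(), -1);
  std::vector<bool> available;
  for (std::size_t i = 0; i < adj.size(); ++i) {
    available.assign(adj.size(), true);
    for (std::size_t j : adj[i])
      if (colors[j] >= 0)
        available[colors[j]] = false;
    auto it = std::find(available.begin(), available.end(), true);
    colors[i] = static_cast<int>(std::distance(available.begin(), it));
  }
  return colors;
}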
diff --git a/lib/Analysis/CMakeLists.txt b/lib/Analysis/CMakeLists.txt
index 68278474b..a0d2092d5 100644
--- a/lib/Analysis/CMakeLists.txt
+++ b/lib/Analysis/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_library(TritonAnalysis
   AxisInfo.cpp
+  Allocation.cpp
 
   DEPENDS
   TritonGPUAttrDefsIncGen
diff --git a/test/Analysis/test-allocation.mlir b/test/Analysis/test-allocation.mlir
new file mode 100644
index 000000000..8362ea58c
--- /dev/null
+++ b/test/Analysis/test-allocation.mlir
@@ -0,0 +1,145 @@
+// RUN: triton-opt %s --mlir-disable-threading -test-print-allocation 2>&1 | FileCheck %s
+
+#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+#A = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
+#B = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
+#C = #triton_gpu.mma<{version = 2, warpsPerCTA = [4, 1]}>
+
+func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
+  %a_ptr_init = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
+  %b_ptr_init = tt.broadcast %B : (!tt.ptr<f16>) -> tensor<32x128x!tt.ptr<f16>, #BL>
+
+  %a_mask = arith.constant dense<true> : tensor<128x32xi1, #AL>
+  %a_other = arith.constant dense<0.00e+00> : tensor<128x32xf16, #AL>
+  %b_mask = arith.constant dense<true> : tensor<32x128xi1, #BL>
+  %b_other = arith.constant dense<0.00e+00> : tensor<32x128xf16, #BL>
+  %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C>
+
+  %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL>
+  %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL>
+
+  scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
+    %a_ = tt.load %a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
+    // CHECK: offset = 0, size = 8192
+    %a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
+    %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
+    // CHECK: offset = 8192, size = 8192
+    %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>
+
+    %c = tt.dot %a, %b, %prev_c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
+
+    %next_a_ptr = tt.getelementptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>
+    %next_b_ptr = tt.getelementptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>
+    scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>
+  }
+  return
+  // CHECK: size = 16384
+}
+
+// Shared memory is available after a tensor's liveness range ends.
+func @synthesized_reusable(%A : !tt.ptr<f16>) {
+  %cst1 = arith.constant dense<true> : tensor<128x32xi1, #AL>
+  %cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL>
+  %cst3 = arith.constant dense<true> : tensor<32x128xi1, #AL>
+  %cst4 = arith.constant dense<0.000000e+00> : tensor<32x128xf16, #AL>
+  %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C>
+
+  %a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
+  %b_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<32x128x!tt.ptr<f16>, #AL>
+  %a1_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
+  // CHECK: offset = 0, size = 8192
+  %a1 = triton_gpu.convert_layout %a1_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
+  %a2_ = tt.load %b_ptr, %cst3, %cst4 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #AL>
+  // CHECK: offset = 8192, size = 8192
+  %a2 = triton_gpu.convert_layout %a2_ : (tensor<32x128xf16, #AL>) -> tensor<32x128xf16, #A>
+  %a3_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
+  // CHECK: offset = 16384, size = 8192
+  %a3 = triton_gpu.convert_layout %a3_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
+  %c = tt.dot %a1, %a2, %c_init {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
+  %a4_ = tt.load %b_ptr, %cst3, %cst4 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #AL>
+  // CHECK: offset = 0, size = 8192
+  %a4 = triton_gpu.convert_layout %a4_ : (tensor<32x128xf16, #AL>) -> tensor<32x128xf16, #A>
+  %c1 = tt.dot %a3, %a4, %c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
+  return
+  // CHECK: size = 24576
+}
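The per-tensor sizes in the checks above follow the bytes = numElements * elementBitWidth / 8 rule from getSharedMemoryValuesAndSizes; a quick sanity check of that arithmetic (illustrative only, not part of the patch):

#include <cstddef>

// bytes = numElements * elementBitWidth / 8
constexpr std::size_t sharedBytes(std::size_t numElements, std::size_t bitWidth) {
  return numElements * bitWidth / 8;
}
static_assert(sharedBytes(128 * 32, 16) == 8192, "one 128x32xf16 (or 32x128xf16) tile");
static_assert(8192 + 8192 == 16384, "@matmul_loop peak: %a and %b are live together");
static_assert(3 * 8192 == 24576, "@synthesized_reusable peak: %a1, %a2, %a3 are live together");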
+
+// A tensor may be assigned a shared memory offset larger than it strictly
+// needs so that tensors allocated later can be accommodated.
+// %cst0->%c
+// %cst1->%cst4
+// %cst3->%g->%h->%i
+func @synthesize_preallocate(%A : !tt.ptr<f16>) {
+  // CHECK: offset = 0, size = 512
+  %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A>
+  // CHECK: offset = 1024, size = 512
+  %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A>
+  // CHECK: offset = 1536, size = 512
+  %cst2 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A>
+  // CHECK: offset = 2048, size = 1024
+  %a = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A>
+  // CHECK: offset = 3072, size = 1024
+  %b = tt.cat %cst0, %cst2 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A>
+  // CHECK: offset = 0, size = 1024
+  %c = tt.cat %cst1, %cst2 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A>
+  // CHECK: offset = 1024, size = 1024
+  %cst4 = arith.constant dense<0.000000e+00> : tensor<32x16xf16, #A>
+  // CHECK: offset = 6144, size = 2048
+  %e = tt.cat %a, %cst4 {axis = 0} : (tensor<32x16xf16, #A>, tensor<32x16xf16, #A>) -> tensor<64x16xf16, #A>
+  // CHECK: offset = 8192, size = 2048
+  %d = tt.cat %b, %cst4 {axis = 0} : (tensor<32x16xf16, #A>, tensor<32x16xf16, #A>) -> tensor<64x16xf16, #A>
+  // CHECK: offset = 10240, size = 2048
+  %f = tt.cat %c, %cst4 {axis = 0} : (tensor<32x16xf16, #A>, tensor<32x16xf16, #A>) -> tensor<64x16xf16, #A>
+  // CHECK: offset = 0, size = 2048
+  %cst5 = arith.constant dense<0.000000e+00> : tensor<64x16xf16, #A>
+  // CHECK: offset = 2048, size = 4096
+  %g = tt.cat %e, %cst5 {axis = 0} : (tensor<64x16xf16, #A>, tensor<64x16xf16, #A>) -> tensor<128x16xf16, #A>
+  // CHECK: offset = 2048, size = 4096
+  %h = tt.cat %d, %cst5 {axis = 0} : (tensor<64x16xf16, #A>, tensor<64x16xf16, #A>) -> tensor<128x16xf16, #A>
+  // CHECK: offset = 2048, size = 4096
+  %i = tt.cat %f, %cst5 {axis = 0} : (tensor<64x16xf16, #A>, tensor<64x16xf16, #A>) -> tensor<128x16xf16, #A>
+  return
+  // CHECK: size = 12288
+}
+
+// Unused tensors are immediately released.
+func @synthesize_unused(%A : !tt.ptr<f16>) {
+  // CHECK: offset = 0, size = 1024
+  %cst0 = arith.constant dense<0.000000e+00> : tensor<32x16xf16, #A>
+  // CHECK: offset = 0, size = 512
+  %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A>
+  // CHECK: offset = 512, size = 512
+  %cst2 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A>
+  // CHECK: offset = 1024, size = 1024
+  %a = tt.cat %cst1, %cst2 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A>
+  return
+  // CHECK: size = 2048
+}
+
+// %cst0 is alive throughout the entire function, so it cannot be released
+// before the function returns.
+func @synthesize_longlive(%A : !tt.ptr<f16>) {
+  // CHECK: offset = 0, size = 512
+  %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A>
+  // CHECK: offset = 512, size = 512
+  %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A>
+  // CHECK: offset = 1024, size = 512
+  %cst2 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A>
+  // CHECK: offset = 1536, size = 1024
+  %a = tt.cat %cst1, %cst2 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A>
+  // CHECK: offset = 512, size = 512
+  %cst3 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A>
+  // CHECK: offset = 1024, size = 512
+  %cst4 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A>
+  // CHECK: offset = 1536, size = 1024
+  %b = tt.cat %cst3, %cst4 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A>
+  // CHECK: offset = 1536, size = 512
+  %cst5 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A>
+  // CHECK: offset = 1536, size = 512
+  %cst6 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A>
+  // CHECK: offset = 1536, size = 1024
+  %c = tt.cat %cst3, %cst4 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A>
+  // CHECK: offset = 512, size = 1024
+  %d = tt.cat %cst0, %cst0 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A>
+  return
+  // CHECK: size = 2560
+}
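The expected totals above can also be sanity-checked by hand: every 16x16xf16 tile takes 512 bytes and every 32x16xf16 tile 1024 bytes. The reasoning in the comments below is my reading of the CHECK lines, not text from the patch:

#include <cstddef>

static_assert(16 * 16 * 16 / 8 == 512, "16x16xf16 tile");
static_assert(32 * 16 * 16 / 8 == 1024, "32x16xf16 tile");
// @synthesize_unused: the unused 32x16 tile at [0, 1024) is released right
// away, so the two 16x16 operands reuse that space and only the 32x16 result
// extends the footprint.
static_assert(2 * 512 + 1024 == 2048, "expected total for @synthesize_unused");
// @synthesize_longlive: %cst0 stays live for the whole function, so the 32x16
// results end up at offset 1536 = 3 * 512, giving 1536 + 1024 = 2560.
static_assert(3 * 512 + 1024 == 2560, "expected total for @synthesize_longlive");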
diff --git a/test/lib/Analysis/CMakeLists.txt b/test/lib/Analysis/CMakeLists.txt
index 2f45f2e53..cbb7661ad 100644
--- a/test/lib/Analysis/CMakeLists.txt
+++ b/test/lib/Analysis/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_library(TritonTestAnalysis
   TestAxisInfo.cpp
+  TestAllocation.cpp
 
   LINK_LIBS PUBLIC
   TritonAnalysis
diff --git a/test/lib/Analysis/TestAllocation.cpp b/test/lib/Analysis/TestAllocation.cpp
new file mode 100644
index 000000000..7a3cc55b0
--- /dev/null
+++ b/test/lib/Analysis/TestAllocation.cpp
@@ -0,0 +1,49 @@
+#include "mlir/Pass/Pass.h"
+#include "triton/Analysis/Allocation.h"
+
+using namespace mlir;
+
+namespace {
+
+struct TestAllocationPass
+    : public PassWrapper<TestAllocationPass, OperationPass<FuncOp>> {
+
+  // LLVM15+
+  // MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAllocationPass);
+
+  StringRef getArgument() const final { return "test-print-allocation"; }
+  StringRef getDescription() const final {
+    return "print the result of the allocation analysis";
+  }
+
+  void runOnOperation() override {
+    Operation *operation = getOperation();
+    auto &os = llvm::errs();
+    os << "Testing: " << operation->getName() << "\n";
+    AllocationAnalysis analysis(operation);
+    operation->walk([&](Operation *op) {
+      if (op->getNumResults() < 1)
+        return;
+      for (Value result : op->getResults()) {
+        Type type = result.getType();
+        if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
+          Attribute encoding = tensorType.getEncoding();
+          if (encoding &&
+              encoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
+            size_t offset = analysis.getOffset(result);
+            size_t size = analysis.getAllocatedSize(result);
+            os << "offset = " << offset << ", size = " << size << "\n";
+          }
+        }
+      }
+    });
+    os << "size = " << analysis.getSharedMemorySize() << "\n";
+  }
+};
+
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestAllocationPass() { PassRegistration<TestAllocationPass>(); }
+} // namespace test
+} // namespace mlir
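To close, a condensed sketch of how the analysis is meant to be consumed by other code, mirroring what TestAllocationPass does above; the free function name is made up for illustration. The test pass itself is exercised through the RUN line in test/Analysis/test-allocation.mlir: triton-opt %s --mlir-disable-threading -test-print-allocation 2>&1 | FileCheck %s.

#include "triton/Analysis/Allocation.h"
#include <cstddef>

// Constructing AllocationAnalysis runs the analysis; afterwards per-value
// offsets and sizes, as well as the total footprint, can be queried.
std::size_t sharedMemoryFootprint(mlir::Operation *funcOp) {
  mlir::AllocationAnalysis analysis(funcOp);
  return analysis.getSharedMemorySize();
}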