[BACKEND] Memory allocation (#33)
@@ -13,7 +13,8 @@
 namespace mlir {
 namespace test {
 void registerTestAlignmentPass();
-}
+void registerTestAllocationPass();
+} // namespace test
 } // namespace mlir

 int main(int argc, char **argv) {
@@ -21,6 +22,7 @@ int main(int argc, char **argv) {
   mlir::registerTritonPasses();
   mlir::registerTritonGPUPasses();
   mlir::test::registerTestAlignmentPass();
+  mlir::test::registerTestAllocationPass();
   mlir::triton::registerConvertTritonToTritonGPUPass();
   mlir::triton::registerConvertTritonGPUToLLVMPass();
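With the pass registered, the triton-opt driver exposes it as a command-line flag; the RUN line of the new test file below invokes it as triton-opt %s --mlir-disable-threading -test-print-allocation.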
include/triton/Analysis/Allocation.h (new file, 115 lines)
@@ -0,0 +1,115 @@
#ifndef TRITON_ANALYSIS_ALLOCATION_H
#define TRITON_ANALYSIS_ALLOCATION_H

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <limits>
#include <map>

#include "triton/Dialect/TritonGPU/IR/Dialect.h"

namespace mlir {

/// Modified from llvm-15.0: llvm/ADT/AddressRanges.h
/// A class that represents an address range. The range is specified using
/// a start and an end address: [Start, End).
template <typename AddrT> class Range {
public:
  Range() {}
  Range(AddrT S, AddrT E) : Start(S), End(E) { assert(Start <= End); }
  AddrT start() const { return Start; }
  AddrT end() const { return End; }
  AddrT size() const { return End - Start; }
  bool contains(AddrT Addr) const { return Start <= Addr && Addr < End; }
  bool intersects(const Range &R) const {
    return Start < R.End && R.Start < End;
  }
  bool operator==(const Range &R) const {
    return Start == R.Start && End == R.End;
  }
  bool operator!=(const Range &R) const { return !(*this == R); }
  bool operator<(const Range &R) const {
    return std::make_pair(Start, End) < std::make_pair(R.Start, R.End);
  }

private:
  AddrT Start = std::numeric_limits<AddrT>::min();
  AddrT End = std::numeric_limits<AddrT>::max();
};
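
// Examples (illustrative): ranges are half-open, so
//   Range<size_t>(0, 8).intersects(Range<size_t>(8, 16)) == false
//   Range<size_t>(0, 8).intersects(Range<size_t>(7, 16)) == true
//   Range<size_t>(0, 8).contains(8) == false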

//===----------------------------------------------------------------------===//
// Shared Memory Allocation Analysis
//===----------------------------------------------------------------------===//
class AllocationAnalysis {
public:
  using ValueSizeMapT = llvm::DenseMap<Value, size_t>;

public:
  /// Creates a new Allocation analysis that computes the shared memory
  /// information for all associated shared memory values.
  AllocationAnalysis(Operation *operation) : operation(operation) { run(); }

  /// Returns the operation this analysis was constructed from.
  Operation *getOperation() const { return operation; }

  /// Returns the offset of the given value in the shared memory.
  size_t getOffset(Value value) const { return valueOffset.lookup(value); }

  /// Returns the size of the given value in the shared memory.
  size_t getAllocatedSize(Value value) const { return valueSize.lookup(value); }

  /// Returns the total size of allocated shared memory.
  size_t getSharedMemorySize() const { return sharedMemorySize; }

private:
  /// Value -> Range
  /// Use MapVector to ensure determinism.
  using ValueRangeMapT = llvm::MapVector<Value, Range<size_t>>;
  /// Start -> Range
  using TripleMapT = std::multimap<size_t, Range<size_t>>;
  /// Nodes -> Nodes
  using GraphT = DenseMap<Value, DenseSet<Value>>;

  /// Runs allocation analysis on the given top-level operation.
  void run();

  /// Resolves liveness of all values involved under the root operation.
  void resolveLiveness(ValueRangeMapT &valueRangeMap);

  /// Computes the shared memory offsets for all related values.
  /// Paper: Algorithms for Compile-Time Memory Optimization
  /// (https://www.cs.utexas.edu/users/harrison/papers/compile-time.pdf)
  void computeOffsets(const ValueRangeMapT &valueRangeMap);

  /// Gets the shared memory values and their sizes from valueRangeMap.
  void getSharedMemoryValuesAndSizes(const ValueRangeMapT &valueRangeMap,
                                     SmallVector<Value> &sharedMemoryValues);

  /// Computes the initial shared memory offsets.
  void calculateSharedMemoryStarts(const ValueRangeMapT &valueRangeMap,
                                   const SmallVector<Value> &sharedMemoryValues,
                                   ValueSizeMapT &sharedMemoryStart);

  /// Builds a graph of all shared memory values. Edges are created between
  /// shared memory values that overlap.
  void buildInterferenceGraph(const ValueRangeMapT &valueRangeMap,
                              const SmallVector<Value> &sharedMemoryValues,
                              const ValueSizeMapT &sharedMemoryStart,
                              GraphT &interference);

  /// Finalizes shared memory offsets considering interference.
  void allocateSharedMemory(const ValueRangeMapT &valueRangeMap,
                            const SmallVector<Value> &sharedMemoryValues,
                            const ValueSizeMapT &sharedMemoryStart,
                            const GraphT &interference);

private:
  Operation *operation;
  ValueSizeMapT valueOffset;
  ValueSizeMapT valueSize;
  size_t sharedMemorySize = 0;
};
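
// Typical usage (illustrative sketch; funcOp and value are placeholder names,
// see test/lib/Analysis/TestAllocation.cpp for the in-tree usage):
//   AllocationAnalysis analysis(funcOp); // runs the analysis on construction
//   size_t offset = analysis.getOffset(value);
//   size_t size = analysis.getAllocatedSize(value);
//   size_t total = analysis.getSharedMemorySize();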

} // namespace mlir

#endif
lib/Analysis/Allocation.cpp (new file, 200 lines)
@@ -0,0 +1,200 @@
#include "triton/Analysis/Allocation.h"
#include "mlir/Analysis/Liveness.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SmallVector.h"

#include <algorithm>
#include <limits>

namespace mlir {

void AllocationAnalysis::run() {
  ValueRangeMapT valueRange;
  resolveLiveness(valueRange);
  computeOffsets(valueRange);
}

void AllocationAnalysis::resolveLiveness(
    AllocationAnalysis::ValueRangeMapT &valueRange) {
  Liveness liveness(operation);
  DenseMap<Operation *, size_t> operationIds;
  operation->walk<WalkOrder::PreOrder>([&](Operation *op) {
    operationIds.insert({op, operationIds.size()});
  });

  operation->walk<WalkOrder::PreOrder>([&](Operation *op) {
    for (Value result : op->getResults()) {
      auto liveOperations = liveness.resolveLiveness(result);
      auto minId = std::numeric_limits<size_t>::max();
      auto maxId = std::numeric_limits<size_t>::min();
      std::for_each(liveOperations.begin(), liveOperations.end(),
                    [&](Operation *liveOp) {
                      if (operationIds[liveOp] < minId) {
                        minId = operationIds[liveOp];
                      }
                      if (operationIds[liveOp] > maxId) {
                        maxId = operationIds[liveOp];
                      }
                    });
      valueRange.insert({result, Range(minId, maxId + 1)});
    }
  });
}
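
// Illustrative example: if the pre-order walk assigns ids 0, 1, and 2 to
// three operations and a result is defined by op 0 and last used by op 2,
// its live set is {op0, op1, op2} and its range becomes [0, 3); the end id
// is exclusive, hence maxId + 1.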

void AllocationAnalysis::getSharedMemoryValuesAndSizes(
    const AllocationAnalysis::ValueRangeMapT &valueRange,
    SmallVector<Value> &sharedMemoryValues) {
  for (auto &valueRangePair : valueRange) {
    auto value = valueRangePair.first;
    auto type = value.getType();
    if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
      auto encoding = tensorType.getEncoding();
      if (encoding &&
          encoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
        // Bytes could be a different value once we support padding or other
        // allocation policies.
        auto bytes = tensorType.getNumElements() *
                     tensorType.getElementTypeBitWidth() / 8;
        sharedMemoryValues.emplace_back(value);
        valueSize.insert({value, bytes});
      }
    }
  }
}

void AllocationAnalysis::calculateSharedMemoryStarts(
    const AllocationAnalysis::ValueRangeMapT &valueRange,
    const SmallVector<Value> &sharedMemoryValues,
    ValueSizeMapT &sharedMemoryStart) {
  //                v = values in shared memory
  //                t = triplet of (size, start, end)
  // shared memory space
  // -
  // |         *******t4
  // |     /|\ v2 inserts t4, t5, and t6
  // |      |
  // |      ******t5         ************t6
  // |    ^^^^^v2^^^^^^
  // |  |      *********************t2
  // | \|/ v2 erases t1
  // | ******t1 ^^^^^^^^^v1^^^^^^^^^ ************t3
  // |---------------------------------------------| liveness range
  //   1  2  3  4  5  6  7  8  9  10 11 12 13 14 15 16 17 ...
  TripleMapT tripleMap;
  tripleMap.insert(std::make_pair(0, Range<size_t>()));
  SmallVector<Value> values = sharedMemoryValues;
  while (!values.empty()) {
    auto tripleIt = tripleMap.begin();
    auto size = tripleIt->first;
    auto range = tripleIt->second;
    tripleMap.erase(tripleIt);
    auto valueIt = std::find_if(values.begin(), values.end(), [&](Value value) {
      auto xRange = valueRange.lookup(value);
      bool res = xRange.intersects(range);
      for (auto val : tripleMap)
        res = res && !val.second.intersects(xRange);
      return res;
    });
    if (valueIt != values.end()) {
      auto value = *valueIt;
      auto xSize = valueSize.lookup(value);
      auto xRange = valueRange.lookup(value);
      sharedMemoryStart[value] = size;
      tripleMap.insert(
          {size + xSize, Range{std::max(range.start(), xRange.start()),
                               std::min(range.end(), xRange.end())}});
      if (range.start() < xRange.start())
        tripleMap.insert({size, Range{range.start(), xRange.end()}});
      if (xRange.end() < range.end())
        tripleMap.insert({size, Range{xRange.start(), range.end()}});
      values.erase(valueIt);
    }
  }
}
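
// Illustrative trace (hypothetical values): take v1 with live range [1, 4)
// and size 2, and v2 with live range [2, 6) and size 4. v1 is matched
// against the initial (0, [min, max)) triple and placed at start 0,
// splitting the triple map. Because v1 and v2 are live at the same time,
// the only triple v2 can claim without conflicting with another is
// (2, [1, 4)), so v2 is placed at start 2, stacked on top of v1.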

void AllocationAnalysis::buildInterferenceGraph(
    const AllocationAnalysis::ValueRangeMapT &valueRange,
    const SmallVector<Value> &sharedMemoryValues,
    const ValueSizeMapT &sharedMemoryStart, GraphT &interference) {
  for (auto x : sharedMemoryValues) {
    for (auto y : sharedMemoryValues) {
      if (x == y)
        continue;
      auto xStart = sharedMemoryStart.lookup(x);
      auto yStart = sharedMemoryStart.lookup(y);
      auto xSize = valueSize.lookup(x);
      auto ySize = valueSize.lookup(y);
      Range xSizeRange = {xStart, xStart + xSize};
      Range ySizeRange = {yStart, yStart + ySize};
      auto xOpRange = valueRange.lookup(x);
      auto yOpRange = valueRange.lookup(y);
      if (xOpRange.intersects(yOpRange) && xSizeRange.intersects(ySizeRange)) {
        interference[x].insert(y);
      }
    }
  }
}
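
// Note: two values interfere only when they overlap both in time (their
// live ranges) and in space (their tentative [start, start + size)
// extents); such pairs must be separated by the coloring step below.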

void AllocationAnalysis::allocateSharedMemory(
    const AllocationAnalysis::ValueRangeMapT &valueRangeMap,
    const SmallVector<Value> &sharedMemoryValues,
    const AllocationAnalysis::ValueSizeMapT &sharedMemoryStart,
    const AllocationAnalysis::GraphT &interference) {
  // First-fit graph coloring.
  // Neighbors are nodes that interfere with each other.
  // We color a node with the smallest color index that is not used by any of
  // its already-colored neighbors.
  // Nodes with the same color do not interfere with each other.
  DenseMap<Value, int> colors;
  for (auto value : sharedMemoryValues) {
    colors[value] = (value == sharedMemoryValues[0]) ? 0 : -1;
  }
  SmallVector<bool> available(sharedMemoryValues.size());
  for (auto x : sharedMemoryValues) {
    std::fill(available.begin(), available.end(), true);
    for (auto y : interference.lookup(x)) {
      int color = colors[y];
      if (color >= 0) {
        available[color] = false;
      }
    }
    auto it = std::find(available.begin(), available.end(), true);
    colors[x] = std::distance(available.begin(), it);
  }
  // Finalize allocation.
  // color0: [0, 7), [0, 8), [0, 15) -> [0, 7), [0, 8), [0, 15)
  // color1: [7, 9)  -> [7 + 1 * 15, 9 + 1 * 15)  -> [22, 24)
  // color2: [8, 12) -> [8 + 2 * 15, 12 + 2 * 15) -> [38, 42)
  // TODO(Keren): We are wasting memory here.
  // Nodes with color2 can actually start at 24.
  for (auto x : sharedMemoryValues) {
    size_t adj = 0;
    for (auto y : interference.lookup(x)) {
      adj = std::max(adj, sharedMemoryStart.lookup(y) + valueSize.lookup(y));
    }
    valueOffset[x] = sharedMemoryStart.lookup(x) + colors.lookup(x) * adj;
    sharedMemorySize =
        std::max(sharedMemorySize, valueOffset[x] + valueSize.lookup(x));
  }
}
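
// Illustrative example: with interference edges a--b and b--c, and values
// visited in the order a, b, c, node a keeps color 0, b takes color 1
// because its neighbor a already holds 0, and c reuses color 0 since its
// only neighbor b holds color 1.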

void AllocationAnalysis::computeOffsets(
    const AllocationAnalysis::ValueRangeMapT &valueRange) {
  SmallVector<Value> sharedMemoryValues;
  getSharedMemoryValuesAndSizes(valueRange, sharedMemoryValues);

  ValueSizeMapT sharedMemoryStart;
  calculateSharedMemoryStarts(valueRange, sharedMemoryValues,
                              sharedMemoryStart);

  GraphT interference;
  buildInterferenceGraph(valueRange, sharedMemoryValues, sharedMemoryStart,
                         interference);

  allocateSharedMemory(valueRange, sharedMemoryValues, sharedMemoryStart,
                       interference);
}

} // namespace mlir
@@ -1,5 +1,6 @@
 add_mlir_library(TritonAnalysis
   AxisInfo.cpp
+  Allocation.cpp

   DEPENDS
   TritonGPUAttrDefsIncGen
test/Analysis/test-allocation.mlir (new file, 145 lines)
@@ -0,0 +1,145 @@
// RUN: triton-opt %s --mlir-disable-threading -test-print-allocation 2>&1 | FileCheck %s

#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
#A = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
#B = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
#C = #triton_gpu.mma<{version = 2, warpsPerCTA = [4, 1]}>

func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
  %a_ptr_init = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
  %b_ptr_init = tt.broadcast %B : (!tt.ptr<f16>) -> tensor<32x128x!tt.ptr<f16>, #BL>

  %a_mask = arith.constant dense<true> : tensor<128x32xi1, #AL>
  %a_other = arith.constant dense<0.00e+00> : tensor<128x32xf16, #AL>
  %b_mask = arith.constant dense<true> : tensor<32x128xi1, #BL>
  %b_other = arith.constant dense<0.00e+00> : tensor<32x128xf16, #BL>
  %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C>

  %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL>
  %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL>

  scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
    %a_ = tt.load %a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
    // CHECK: offset = 0, size = 8192
    %a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
    %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
    // CHECK: offset = 8192, size = 8192
    %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>

    %c = tt.dot %a, %b, %prev_c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>

    %next_a_ptr = tt.getelementptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>
    %next_b_ptr = tt.getelementptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>
    scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>
  }
  return
  // CHECK: size = 16384
}

// Shared memory becomes available again after a tensor's liveness range ends.
func @synthesized_reusable(%A : !tt.ptr<f16>) {
  %cst1 = arith.constant dense<true> : tensor<128x32xi1, #AL>
  %cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL>
  %cst3 = arith.constant dense<true> : tensor<32x128xi1, #AL>
  %cst4 = arith.constant dense<0.000000e+00> : tensor<32x128xf16, #AL>
  %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C>

  %a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
  %b_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<32x128x!tt.ptr<f16>, #AL>
  %a1_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
  // CHECK: offset = 0, size = 8192
  %a1 = triton_gpu.convert_layout %a1_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
  %a2_ = tt.load %b_ptr, %cst3, %cst4 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #AL>
  // CHECK: offset = 8192, size = 8192
  %a2 = triton_gpu.convert_layout %a2_ : (tensor<32x128xf16, #AL>) -> tensor<32x128xf16, #A>
  %a3_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
  // CHECK: offset = 16384, size = 8192
  %a3 = triton_gpu.convert_layout %a3_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
  %c = tt.dot %a1, %a2, %c_init {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
  %a4_ = tt.load %b_ptr, %cst3, %cst4 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #AL>
  // CHECK: offset = 0, size = 8192
  %a4 = triton_gpu.convert_layout %a4_ : (tensor<32x128xf16, #AL>) -> tensor<32x128xf16, #A>
  %c1 = tt.dot %a3, %a4, %c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
  return
  // CHECK: size = 24576
}

// A tensor's shared memory offset may be larger than strictly needed, in
// order to accommodate tensors allocated later.
// %cst0 -> %c
// %cst1 -> %cst4
// %cst3 -> %g -> %h -> %i
func @synthesize_preallocate(%A : !tt.ptr<f16>) {
  // CHECK: offset = 0, size = 512
  %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A>
  // CHECK: offset = 1024, size = 512
  %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A>
  // CHECK: offset = 1536, size = 512
  %cst2 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A>
  // CHECK: offset = 2048, size = 1024
  %a = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A>
  // CHECK: offset = 3072, size = 1024
  %b = tt.cat %cst0, %cst2 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A>
  // CHECK: offset = 0, size = 1024
  %c = tt.cat %cst1, %cst2 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A>
  // CHECK: offset = 1024, size = 1024
  %cst4 = arith.constant dense<0.000000e+00> : tensor<32x16xf16, #A>
  // CHECK: offset = 6144, size = 2048
  %e = tt.cat %a, %cst4 {axis = 0} : (tensor<32x16xf16, #A>, tensor<32x16xf16, #A>) -> tensor<64x16xf16, #A>
  // CHECK: offset = 8192, size = 2048
  %d = tt.cat %b, %cst4 {axis = 0} : (tensor<32x16xf16, #A>, tensor<32x16xf16, #A>) -> tensor<64x16xf16, #A>
  // CHECK: offset = 10240, size = 2048
  %f = tt.cat %c, %cst4 {axis = 0} : (tensor<32x16xf16, #A>, tensor<32x16xf16, #A>) -> tensor<64x16xf16, #A>
  // CHECK: offset = 0, size = 2048
  %cst5 = arith.constant dense<0.000000e+00> : tensor<64x16xf16, #A>
  // CHECK: offset = 2048, size = 4096
  %g = tt.cat %e, %cst5 {axis = 0} : (tensor<64x16xf16, #A>, tensor<64x16xf16, #A>) -> tensor<128x16xf16, #A>
  // CHECK: offset = 2048, size = 4096
  %h = tt.cat %d, %cst5 {axis = 0} : (tensor<64x16xf16, #A>, tensor<64x16xf16, #A>) -> tensor<128x16xf16, #A>
  // CHECK: offset = 2048, size = 4096
  %i = tt.cat %f, %cst5 {axis = 0} : (tensor<64x16xf16, #A>, tensor<64x16xf16, #A>) -> tensor<128x16xf16, #A>
  return
  // CHECK: size = 12288
}

// Unused tensors are immediately released.
func @synthesize_unused(%A : !tt.ptr<f16>) {
  // CHECK: offset = 0, size = 1024
  %cst0 = arith.constant dense<0.000000e+00> : tensor<32x16xf16, #A>
  // CHECK: offset = 0, size = 512
  %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A>
  // CHECK: offset = 512, size = 512
  %cst2 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A>
  // CHECK: offset = 1024, size = 1024
  %a = tt.cat %cst1, %cst2 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A>
  return
  // CHECK: size = 2048
}

// %cst0 is alive through the entire function, so it cannot be released
// before the function ends.
func @synthesize_longlive(%A : !tt.ptr<f16>) {
  // CHECK: offset = 0, size = 512
  %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A>
  // CHECK: offset = 512, size = 512
  %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A>
  // CHECK: offset = 1024, size = 512
  %cst2 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A>
  // CHECK: offset = 1536, size = 1024
  %a = tt.cat %cst1, %cst2 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A>
  // CHECK: offset = 512, size = 512
  %cst3 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A>
  // CHECK: offset = 1024, size = 512
  %cst4 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A>
  // CHECK: offset = 1536, size = 1024
  %b = tt.cat %cst3, %cst4 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A>
  // CHECK: offset = 1536, size = 512
  %cst5 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A>
  // CHECK: offset = 1536, size = 512
  %cst6 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A>
  // CHECK: offset = 1536, size = 1024
  %c = tt.cat %cst3, %cst4 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A>
  // CHECK: offset = 512, size = 1024
  %d = tt.cat %cst0, %cst0 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A>
  return
  // CHECK: size = 2560
}
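The sizes in the CHECK lines above follow from the element counts computed in getSharedMemoryValuesAndSizes: a tensor<128x32xf16> buffer in shared memory occupies 128 * 32 * 2 bytes = 8192 bytes, and in @matmul_loop both the #A and #B operands of tt.dot are live at the same time, giving the reported total of 8192 + 8192 = 16384 bytes.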
@@ -1,5 +1,6 @@
 add_mlir_library(TritonTestAnalysis
   TestAxisInfo.cpp
+  TestAllocation.cpp

   LINK_LIBS PUBLIC
   TritonAnalysis
test/lib/Analysis/TestAllocation.cpp (new file, 49 lines)
@@ -0,0 +1,49 @@
#include "mlir/Pass/Pass.h"
#include "triton/Analysis/Allocation.h"

using namespace mlir;

namespace {

struct TestAllocationPass
    : public PassWrapper<TestAllocationPass, OperationPass<FuncOp>> {

  // LLVM15+
  // MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAllocationPass);

  StringRef getArgument() const final { return "test-print-allocation"; }
  StringRef getDescription() const final {
    return "print the result of the allocation pass";
  }

  void runOnOperation() override {
    Operation *operation = getOperation();
    auto &os = llvm::errs();
    os << "Testing: " << operation->getName() << "\n";
    AllocationAnalysis analysis(operation);
    operation->walk([&](Operation *op) {
      if (op->getNumResults() < 1)
        return;
      for (Value result : op->getResults()) {
        Type type = result.getType();
        if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
          Attribute encoding = tensorType.getEncoding();
          if (encoding &&
              encoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
            size_t offset = analysis.getOffset(result);
            size_t size = analysis.getAllocatedSize(result);
            os << "offset = " << offset << ", size = " << size << "\n";
          }
        }
      }
    });
    os << "size = " << analysis.getSharedMemorySize() << "\n";
  }
};

} // namespace

namespace mlir {
namespace test {
void registerTestAllocationPass() { PassRegistration<TestAllocationPass>(); }
} // namespace test
} // namespace mlir