[BACKEND] Keren/shared memory barrier (#59)
This commit is contained in:
@@ -4,23 +4,28 @@
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/MapVector.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include <limits>
|
||||
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include <atomic>
|
||||
#include <limits>
|
||||
|
||||
namespace mlir {
|
||||
|
||||
namespace triton {
|
||||
class AllocationAnalysis;
|
||||
}
|
||||
|
||||
/// Modified from llvm-15.0: llvm/ADT/AddressRanges.h
|
||||
/// A class that represents an address range. The range is specified using
|
||||
/// a start and an end address: [Start, End).
|
||||
template <typename AddrT> class Range {
|
||||
/// A class that represents a range, specified using a start and an end values:
|
||||
/// [Start, End).
|
||||
template <typename T> class Range {
|
||||
public:
|
||||
Range() {}
|
||||
Range(AddrT S, AddrT E) : Start(S), End(E) { assert(Start <= End); }
|
||||
AddrT start() const { return Start; }
|
||||
AddrT end() const { return End; }
|
||||
AddrT size() const { return End - Start; }
|
||||
bool contains(AddrT Addr) const { return Start <= Addr && Addr < End; }
|
||||
Range(T S, T E) : Start(S), End(E) { assert(Start <= End); }
|
||||
T start() const { return Start; }
|
||||
T end() const { return End; }
|
||||
T size() const { return End - Start; }
|
||||
bool contains(T Addr) const { return Start <= Addr && Addr < End; }
|
||||
bool intersects(const Range &R) const {
|
||||
return Start < R.End && R.Start < End;
|
||||
}
|
||||
@@ -33,83 +38,122 @@ public:
|
||||
}
|
||||
|
||||
private:
|
||||
AddrT Start = std::numeric_limits<AddrT>::min();
|
||||
AddrT End = std::numeric_limits<AddrT>::max();
|
||||
T Start = std::numeric_limits<T>::min();
|
||||
T End = std::numeric_limits<T>::max();
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Shared Memory Allocation Analysis
|
||||
//===----------------------------------------------------------------------===//
|
||||
class AllocationAnalysis {
|
||||
class Allocation {
|
||||
public:
|
||||
using ValueSizeMapT = llvm::DenseMap<Value, size_t>;
|
||||
/// A unique identifier for shared memory buffers
|
||||
using BufferId = size_t;
|
||||
static constexpr BufferId InvalidBufferId =
|
||||
std::numeric_limits<BufferId>::max();
|
||||
|
||||
public:
|
||||
/// Creates a new Allocation analysis that computes the shared memory
|
||||
/// information for all associated shared memory values.
|
||||
AllocationAnalysis(Operation *operation) : operation(operation) { run(); }
|
||||
Allocation(Operation *operation) : operation(operation) { run(); }
|
||||
|
||||
/// Returns the operation this analysis was constructed from.
|
||||
Operation *getOperation() const { return operation; }
|
||||
|
||||
/// Returns the offset of the given value in the shared memory.
|
||||
size_t getOffset(Value value) const { return valueOffset.lookup(value); }
|
||||
/// Returns the offset of the given buffer in the shared memory.
|
||||
size_t getOffset(BufferId bufferId) const {
|
||||
return bufferSet.lookup(bufferId).offset;
|
||||
}
|
||||
|
||||
/// Returns the size of the given value in the shared memory.
|
||||
size_t getAllocatedSize(Value value) const { return valueSize.lookup(value); }
|
||||
/// Returns the size of the given buffer in the shared memory.
|
||||
size_t getAllocatedSize(BufferId bufferId) const {
|
||||
return bufferSet.lookup(bufferId).size;
|
||||
}
|
||||
|
||||
/// Returns the buffer id of the given value.
|
||||
BufferId getBufferId(Value value) const {
|
||||
if (valueBuffer.count(value)) {
|
||||
return valueBuffer.lookup(value)->id;
|
||||
} else {
|
||||
return InvalidBufferId;
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the scratch buffer id of the given value.
|
||||
BufferId getBufferId(Operation *operation) const {
|
||||
if (opScratch.count(operation)) {
|
||||
return opScratch.lookup(operation)->id;
|
||||
} else {
|
||||
return InvalidBufferId;
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the size of total shared memory allocated
|
||||
size_t getSharedMemorySize() const { return sharedMemorySize; }
|
||||
|
||||
private:
|
||||
/// Value -> Range
|
||||
/// Use MapVector to ensure determinism.
|
||||
using ValueRangeMapT = llvm::MapVector<Value, Range<size_t>>;
|
||||
/// Start -> Range
|
||||
using TripleMapT = std::multimap<size_t, Range<size_t>>;
|
||||
/// Nodes -> Nodes
|
||||
using GraphT = DenseMap<Value, DenseSet<Value>>;
|
||||
/// Returns true if the shared memory ranges [offset, offset + size) of the
/// two buffers overlap.
///
/// Returns false when either id is InvalidBufferId or is not present in
/// bufferSet.
bool isIntersected(BufferId lhsId, BufferId rhsId) const {
  if (lhsId == InvalidBufferId || rhsId == InvalidBufferId)
    return false;
  // Use find() rather than lookup(): lookup() returns a BufferT by value and
  // default-constructs one for a missing key, and every BufferT constructor
  // advances the global nextId counter. A const query must not do that.
  auto lhsIt = bufferSet.find(lhsId);
  auto rhsIt = bufferSet.find(rhsId);
  if (lhsIt == bufferSet.end() || rhsIt == bufferSet.end())
    return false;
  return lhsIt->second.intersects(rhsIt->second);
}
|
||||
|
||||
private:
|
||||
/// A class that represents a shared memory buffer
|
||||
struct BufferT {
|
||||
enum class BufferKind { Explicit, Scratch };
|
||||
|
||||
/// MT: thread-safe
|
||||
inline static std::atomic<BufferId> nextId = 0;
|
||||
|
||||
BufferKind kind;
|
||||
BufferId id;
|
||||
size_t size;
|
||||
size_t offset;
|
||||
|
||||
bool operator==(const BufferT &other) const { return id == other.id; }
|
||||
bool operator<(const BufferT &other) const { return id < other.id; }
|
||||
|
||||
BufferT() : BufferT(BufferKind::Explicit) {}
|
||||
BufferT(BufferKind kind) : BufferT(kind, 0, 0) {}
|
||||
BufferT(BufferKind kind, size_t size) : BufferT(kind, size, 0) {}
|
||||
BufferT(BufferKind kind, size_t size, size_t offset)
|
||||
: kind(kind), size(size), offset(offset), id(nextId++) {}
|
||||
|
||||
bool intersects(const BufferT &other) const {
|
||||
return Range<size_t>(offset, offset + size)
|
||||
.intersects(Range<size_t>(other.offset, other.offset + other.size));
|
||||
}
|
||||
};
|
||||
|
||||
/// Op -> Scratch Buffer
|
||||
using OpScratchMapT = DenseMap<Operation *, BufferT *>;
|
||||
/// Value -> Explicit Buffer
|
||||
using ValueBufferMapT = DenseMap<Value, BufferT *>;
|
||||
/// BufferId -> Buffer
|
||||
using BufferSetT = DenseMap<BufferId, BufferT>;
|
||||
/// Runs allocation analysis on the given top-level operation.
|
||||
void run();
|
||||
|
||||
/// Resolves liveness of all values involved under the root operation.
|
||||
void resolveLiveness(ValueRangeMapT &valueRangeMap);
|
||||
|
||||
/// Computes the shared memory offsets for all related values.
|
||||
/// Paper: Algorithms for Compile-Time Memory Optimization
|
||||
/// (https://www.cs.utexas.edu/users/harrison/papers/compile-time.pdf)
|
||||
void computeOffsets(const ValueRangeMapT &valueRangeMap);
|
||||
|
||||
/// Gets shared memory value and size from valueRangeMap.
|
||||
void getSharedMemoryValuesAndSizes(const ValueRangeMapT &valueRangeMap,
|
||||
SmallVector<Value> &sharedMemoryValues);
|
||||
|
||||
/// Computes the initial shared memory offsets.
|
||||
void calculateSharedMemoryStarts(const ValueRangeMapT &valueRangeMap,
|
||||
const SmallVector<Value> &sharedMemoryValues,
|
||||
ValueSizeMapT &sharedMemoryStart);
|
||||
|
||||
/// Builds a graph of all shared memory values. Edges are created between
|
||||
/// shared memory values that are overlapping.
|
||||
void buildInterferenceGraph(const ValueRangeMapT &valueRangeMap,
|
||||
const SmallVector<Value> &sharedMemoryValues,
|
||||
const ValueSizeMapT &sharedMemoryStart,
|
||||
GraphT &interference);
|
||||
|
||||
/// Finalizes shared memory offsets considering interference.
|
||||
void allocateSharedMemory(const ValueRangeMapT &valueRangeMap,
|
||||
const SmallVector<Value> &sharedMemoryValues,
|
||||
const ValueSizeMapT &sharedMemoryStart,
|
||||
const GraphT &interference);
|
||||
private:
|
||||
template <BufferT::BufferKind Kind, typename KeyType, typename... Args>
|
||||
void addBuffer(KeyType &key, Args &&... args) {
|
||||
auto buffer = BufferT(Kind, std::forward<Args>(args)...);
|
||||
bufferSet[buffer.id] = std::move(buffer);
|
||||
if constexpr (Kind == BufferT::BufferKind::Explicit) {
|
||||
valueBuffer[key] = &bufferSet[buffer.id];
|
||||
} else {
|
||||
opScratch[key] = &bufferSet[buffer.id];
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
Operation *operation;
|
||||
ValueSizeMapT valueOffset;
|
||||
ValueSizeMapT valueSize;
|
||||
OpScratchMapT opScratch;
|
||||
ValueBufferMapT valueBuffer;
|
||||
BufferSetT bufferSet;
|
||||
size_t sharedMemorySize = 0;
|
||||
|
||||
friend class triton::AllocationAnalysis;
|
||||
};
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
#endif // TRITON_ANALYSIS_ALLOCATION_H
|
||||
|
115
include/triton/Analysis/Membar.h
Normal file
115
include/triton/Analysis/Membar.h
Normal file
@@ -0,0 +1,115 @@
|
||||
#ifndef TRITON_ANALYSIS_MEMBAR_H
|
||||
#define TRITON_ANALYSIS_MEMBAR_H
|
||||
|
||||
#include "Allocation.h"
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class OpBuilder;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Shared Memory Barrier Analysis
|
||||
//===----------------------------------------------------------------------===//
|
||||
class MembarAnalysis {
|
||||
public:
|
||||
/// Creates a new Membar analysis that generates the shared memory barrier
|
||||
/// in the following circumstances:
|
||||
/// - RAW: If a shared memory write is followed by a shared memory read, and
|
||||
/// their addresses are intersected, a barrier is inserted.
|
||||
/// - WAR: If a shared memory read is followed by a shared memory write, and
|
||||
/// their addresses are intersected, a barrier is inserted.
|
||||
/// The following circumstances do not require a barrier:
|
||||
/// - WAW: not possible because overlapped memory allocation is not allowed.
|
||||
/// - RAR: no write is performed.
|
||||
/// Temporary storage of operations such as Reduce is considered as both a shared memory write and
|
||||
/// a shared memory read. If the temporary storage is written but not read,
|
||||
/// it is considered as the problem of the operation itself but not the membar
|
||||
/// analysis.
|
||||
/// The following circumstances are not considered yet:
|
||||
/// - Double buffers
|
||||
/// - N buffers
|
||||
MembarAnalysis(Allocation *allocation) : allocation(allocation) { run(); }
|
||||
|
||||
private:
|
||||
struct RegionInfo {
|
||||
using BufferIdSetT = DenseSet<Allocation::BufferId>;
|
||||
|
||||
BufferIdSetT syncReadBuffers;
|
||||
BufferIdSetT syncWriteBuffers;
|
||||
|
||||
RegionInfo() = default;
|
||||
RegionInfo(const BufferIdSetT &syncReadBuffers,
|
||||
const BufferIdSetT &syncWriteBuffers)
|
||||
: syncReadBuffers(syncReadBuffers), syncWriteBuffers(syncWriteBuffers) {
|
||||
}
|
||||
|
||||
/// Unions two RegionInfo objects.
|
||||
void join(const RegionInfo &other) {
|
||||
syncReadBuffers.insert(other.syncReadBuffers.begin(),
|
||||
other.syncReadBuffers.end());
|
||||
syncWriteBuffers.insert(other.syncWriteBuffers.begin(),
|
||||
other.syncWriteBuffers.end());
|
||||
}
|
||||
|
||||
/// Returns true if buffers in two RegionInfo objects are intersected.
|
||||
bool isIntersected(const RegionInfo &other, Allocation *allocation) const {
|
||||
return /*RAW*/ isIntersected(syncWriteBuffers, other.syncReadBuffers,
|
||||
allocation) ||
|
||||
/*WAR*/ isIntersected(syncReadBuffers, other.syncWriteBuffers,
|
||||
allocation);
|
||||
}
|
||||
|
||||
/// Clears the buffers because a barrier is inserted.
|
||||
void sync() {
|
||||
syncReadBuffers.clear();
|
||||
syncWriteBuffers.clear();
|
||||
}
|
||||
|
||||
private:
|
||||
/// Returns true if buffers in two sets are intersected.
|
||||
bool isIntersected(const BufferIdSetT &lhs, const BufferIdSetT &rhs,
|
||||
Allocation *allocation) const {
|
||||
return std::any_of(lhs.begin(), lhs.end(), [&](auto lhsId) {
|
||||
return std::any_of(rhs.begin(), rhs.end(), [&](auto rhsId) {
|
||||
return allocation->isIntersected(lhsId, rhsId);
|
||||
});
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
/// Runs the membar analysis to the given operation, inserts a barrier if
|
||||
/// necessary.
|
||||
void run();
|
||||
|
||||
/// Applies the barrier analysis based on the SCF dialect, in which each
|
||||
/// region has a single basic block only.
|
||||
/// Example:
|
||||
/// region1
|
||||
/// op1
|
||||
/// op2 (scf.if)
|
||||
/// region2
|
||||
/// op3
|
||||
/// op4
|
||||
/// region3
|
||||
/// op5
|
||||
/// op6
|
||||
/// op7
|
||||
/// region2 and region3 started with the information of region1.
|
||||
/// Each region is analyzed separately and keeps their own copy of the
|
||||
/// information. At op7, we union the information of the region2 and region3
|
||||
/// and update the information of region1.
|
||||
void dfsOperation(Operation *operation, RegionInfo *blockInfo,
|
||||
OpBuilder *builder);
|
||||
|
||||
/// Updates the RegionInfo based on the given operation.
|
||||
void transfer(Operation *operation, RegionInfo *blockInfo,
|
||||
OpBuilder *builder);
|
||||
|
||||
private:
|
||||
Allocation *allocation;
|
||||
};
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TRITON_ANALYSIS_MEMBAR_H
|
@@ -1,6 +1,7 @@
|
||||
#ifndef TRITON_DIALECT_TRITON_IR_DIALECT_H_
|
||||
#define TRITON_DIALECT_TRITON_IR_DIALECT_H_
|
||||
|
||||
#include "mlir/Dialect/GPU/GPUDialect.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
|
@@ -228,13 +228,13 @@ def TT_DotOp : TT_Op<"dot", [NoSideEffect,
|
||||
def TT_ReduceOp : TT_Op<"reduce"> {
|
||||
let summary = "reduce";
|
||||
|
||||
let arguments = (ins TT_RedOpAttr:$redOp, TT_Type:$operand, I32Attr:$axis);
|
||||
let arguments = (ins TT_RedOpAttr:$redOp, TT_Tensor:$operand, I32Attr:$axis);
|
||||
|
||||
let results = (outs TT_Type:$result);
|
||||
let results = (outs TT_Tensor:$result);
|
||||
|
||||
// let builders = [
|
||||
// OpBuilder<(ins "triton::RedOp":$redOp, "value":$operand, "int":$axis)>,
|
||||
// ];
|
||||
let builders = [
|
||||
OpBuilder<(ins "triton::RedOp":$redOp, "Value":$operand, "int":$axis)>,
|
||||
];
|
||||
|
||||
let assemblyFormat = "$operand attr-dict `:` type($operand) `->` type($result)";
|
||||
}
|
||||
|
Reference in New Issue
Block a user