[BACKEND] Keren/shared memory barrier (#59)

This commit is contained in:
Keren Zhou
2022-08-18 12:32:57 -07:00
committed by GitHub
parent 8776ad1a0e
commit e0bedeb44c
13 changed files with 904 additions and 278 deletions

View File

@@ -4,23 +4,28 @@
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/Support/raw_ostream.h"
#include <limits>
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include <atomic>
#include <limits>
namespace mlir {
namespace triton {
class AllocationAnalysis;
}
/// Modified from llvm-15.0: llvm/ADT/AddressRanges.h
/// A class that represents an address range. The range is specified using
/// a start and an end address: [Start, End).
template <typename AddrT> class Range {
/// A class that represents a range, specified using a start and an end values:
/// [Start, End).
template <typename T> class Range {
public:
Range() {}
Range(AddrT S, AddrT E) : Start(S), End(E) { assert(Start <= End); }
AddrT start() const { return Start; }
AddrT end() const { return End; }
AddrT size() const { return End - Start; }
bool contains(AddrT Addr) const { return Start <= Addr && Addr < End; }
Range(T S, T E) : Start(S), End(E) { assert(Start <= End); }
T start() const { return Start; }
T end() const { return End; }
T size() const { return End - Start; }
bool contains(T Addr) const { return Start <= Addr && Addr < End; }
bool intersects(const Range &R) const {
return Start < R.End && R.Start < End;
}
@@ -33,83 +38,122 @@ public:
}
private:
AddrT Start = std::numeric_limits<AddrT>::min();
AddrT End = std::numeric_limits<AddrT>::max();
T Start = std::numeric_limits<T>::min();
T End = std::numeric_limits<T>::max();
};
//===----------------------------------------------------------------------===//
// Shared Memory Allocation Analysis
//===----------------------------------------------------------------------===//
class AllocationAnalysis {
class Allocation {
public:
using ValueSizeMapT = llvm::DenseMap<Value, size_t>;
/// A unique identifier for shared memory buffers
using BufferId = size_t;
static constexpr BufferId InvalidBufferId =
std::numeric_limits<BufferId>::max();
public:
/// Creates a new Allocation analysis that computes the shared memory
/// information for all associated shared memory values.
AllocationAnalysis(Operation *operation) : operation(operation) { run(); }
Allocation(Operation *operation) : operation(operation) { run(); }
/// Returns the operation this analysis was constructed from, i.e. the
/// `operation` pointer stored by the constructor, unchanged.
Operation *getOperation() const { return operation; }
/// Returns the offset of the given value in the shared memory.
size_t getOffset(Value value) const { return valueOffset.lookup(value); }
/// Returns the shared memory offset recorded for the buffer with the given
/// id. The buffer is fetched from bufferSet by value.
size_t getOffset(BufferId bufferId) const {
  const auto buffer = bufferSet.lookup(bufferId);
  return buffer.offset;
}
/// Returns the size of the given value in the shared memory.
size_t getAllocatedSize(Value value) const { return valueSize.lookup(value); }
/// Returns the shared memory size recorded for the buffer with the given id.
/// The buffer is fetched from bufferSet by value.
size_t getAllocatedSize(BufferId bufferId) const {
  const auto buffer = bufferSet.lookup(bufferId);
  return buffer.size;
}
/// Returns the buffer id of the given value, or InvalidBufferId when the
/// value has no entry in valueBuffer.
BufferId getBufferId(Value value) const {
  // One find() instead of count() followed by lookup(): a single hash probe.
  auto it = valueBuffer.find(value);
  if (it == valueBuffer.end())
    return InvalidBufferId;
  return it->second->id;
}
/// Returns the scratch buffer id of the given operation, or InvalidBufferId
/// when the operation has no entry in opScratch.
BufferId getBufferId(Operation *operation) const {
  // One find() instead of count() followed by lookup(): a single hash probe.
  auto it = opScratch.find(operation);
  if (it == opScratch.end())
    return InvalidBufferId;
  return it->second->id;
}
/// Returns the total size of shared memory allocated, as stored in
/// sharedMemorySize (which starts at 0).
size_t getSharedMemorySize() const { return sharedMemorySize; }
private:
/// Value -> Range
/// Use MapVector to ensure determinism.
using ValueRangeMapT = llvm::MapVector<Value, Range<size_t>>;
/// Start -> Range
using TripleMapT = std::multimap<size_t, Range<size_t>>;
/// Nodes -> Nodes
using GraphT = DenseMap<Value, DenseSet<Value>>;
/// Returns true when both ids are valid and the two buffers they name in
/// bufferSet intersect (see BufferT::intersects). Returns false as soon as
/// either id is InvalidBufferId.
bool isIntersected(BufferId lhsId, BufferId rhsId) const {
  const bool bothValid = lhsId != InvalidBufferId && rhsId != InvalidBufferId;
  if (!bothValid)
    return false;
  // The left buffer is fetched first, then the right one.
  return bufferSet.lookup(lhsId).intersects(bufferSet.lookup(rhsId));
}
private:
/// A class that represents a shared memory buffer.
///
/// Every constructed BufferT, including a default-constructed one, takes the
/// next value of the static `nextId` counter as its id.
struct BufferT {
  enum class BufferKind { Explicit, Scratch };

  /// MT: thread-safe
  /// Counter shared by all BufferT objects; post-incremented on construction.
  inline static std::atomic<BufferId> nextId = 0;

  BufferKind kind;
  BufferId id;
  size_t size;
  size_t offset;

  /// Equality and ordering compare ids only; kind, size and offset are
  /// ignored.
  bool operator==(const BufferT &other) const { return id == other.id; }
  bool operator<(const BufferT &other) const { return id < other.id; }

  /// All constructors delegate to the three-argument form; the defaults are
  /// an Explicit buffer with size 0 at offset 0.
  BufferT() : BufferT(BufferKind::Explicit) {}
  BufferT(BufferKind kind) : BufferT(kind, 0, 0) {}
  BufferT(BufferKind kind, size_t size) : BufferT(kind, size, 0) {}
  // The initializers are listed in member declaration order (kind, id, size,
  // offset). Members are always initialized in that order, so the list now
  // matches what actually happens and no longer triggers -Wreorder.
  BufferT(BufferKind kind, size_t size, size_t offset)
      : kind(kind), id(nextId++), size(size), offset(offset) {}

  /// Returns true if the half-open ranges [offset, offset + size) of this
  /// buffer and `other` overlap.
  bool intersects(const BufferT &other) const {
    return Range<size_t>(offset, offset + size)
        .intersects(Range<size_t>(other.offset, other.offset + other.size));
  }
};
/// Op -> Scratch Buffer
using OpScratchMapT = DenseMap<Operation *, BufferT *>;
/// Value -> Explicit Buffer
using ValueBufferMapT = DenseMap<Value, BufferT *>;
/// BufferId -> Buffer
using BufferSetT = DenseMap<BufferId, BufferT>;
/// Runs allocation analysis on the given top-level operation.
void run();
/// Resolves liveness of all values involved under the root operation.
void resolveLiveness(ValueRangeMapT &valueRangeMap);
/// Computes the shared memory offsets for all related values.
/// Paper: Algorithms for Compile-Time Memory Optimization
/// (https://www.cs.utexas.edu/users/harrison/papers/compile-time.pdf)
void computeOffsets(const ValueRangeMapT &valueRangeMap);
/// Gets shared memory value and size from valueRangeMap.
void getSharedMemoryValuesAndSizes(const ValueRangeMapT &valueRangeMap,
SmallVector<Value> &sharedMemoryValues);
/// Computes the initial shared memory offsets.
void calculateSharedMemoryStarts(const ValueRangeMapT &valueRangeMap,
const SmallVector<Value> &sharedMemoryValues,
ValueSizeMapT &sharedMemoryStart);
/// Builds a graph of all shared memory values. Edges are created between
/// shared memory values that are overlapping.
void buildInterferenceGraph(const ValueRangeMapT &valueRangeMap,
const SmallVector<Value> &sharedMemoryValues,
const ValueSizeMapT &sharedMemoryStart,
GraphT &interference);
/// Finalizes shared memory offsets considering interference.
void allocateSharedMemory(const ValueRangeMapT &valueRangeMap,
const SmallVector<Value> &sharedMemoryValues,
const ValueSizeMapT &sharedMemoryStart,
const GraphT &interference);
private:
/// Creates a buffer of kind `Kind`, forwarding `args` to the BufferT
/// constructor after the kind, and stores it in bufferSet under its id.
/// `key` is then mapped to the stored buffer: in valueBuffer when `Kind` is
/// Explicit, in opScratch otherwise.
template <BufferT::BufferKind Kind, typename KeyType, typename... Args>
void addBuffer(KeyType &key, Args &&... args) {
  auto buffer = BufferT(Kind, std::forward<Args>(args)...);
  // Capture the id before moving so nothing reads a moved-from object, and
  // resolve the bufferSet slot once instead of three times.
  const BufferId id = buffer.id;
  BufferT &stored = bufferSet[id];
  stored = std::move(buffer);
  // NOTE(review): the address below points into bufferSet's own storage. If
  // BufferSetT is an llvm::DenseMap, later insertions may rehash and leave
  // this pointer dangling -- confirm, or switch to a node-based map.
  if constexpr (Kind == BufferT::BufferKind::Explicit) {
    valueBuffer[key] = &stored;
  } else {
    opScratch[key] = &stored;
  }
}
private:
Operation *operation;
ValueSizeMapT valueOffset;
ValueSizeMapT valueSize;
OpScratchMapT opScratch;
ValueBufferMapT valueBuffer;
BufferSetT bufferSet;
size_t sharedMemorySize = 0;
friend class triton::AllocationAnalysis;
};
} // namespace mlir
#endif
#endif // TRITON_ANALYSIS_ALLOCATION_H

View File

@@ -0,0 +1,115 @@
#ifndef TRITON_ANALYSIS_MEMBAR_H
#define TRITON_ANALYSIS_MEMBAR_H
#include "Allocation.h"
#include "llvm/ADT/SmallPtrSet.h"
namespace mlir {
class OpBuilder;
//===----------------------------------------------------------------------===//
// Shared Memory Barrier Analysis
//===----------------------------------------------------------------------===//
class MembarAnalysis {
public:
/// Creates a new Membar analysis that generates the shared memory barrier
/// in the following circumstances:
/// - RAW: If a shared memory write is followed by a shared memory read, and
/// their addresses are intersected, a barrier is inserted.
/// - WAR: If a shared memory read is followed by a shared memory read, and
/// their addresses are intersected, a barrier is inserted.
/// The following circumstances do not require a barrier:
/// - WAW: not possible because overlapped memory allocation is not allowed.
/// - RAR: no write is performed.
/// Temporary storage of operations such as Reduce are considered as both
/// a shared memory read. If the temporary storage is written but not read,
/// it is considered as the problem of the operation itself but not the membar
/// analysis.
/// The following circumstances are not considered yet:
/// - Double buffers
/// - N buffers
MembarAnalysis(Allocation *allocation) : allocation(allocation) { run(); }
private:
struct RegionInfo {
using BufferIdSetT = DenseSet<Allocation::BufferId>;
BufferIdSetT syncReadBuffers;
BufferIdSetT syncWriteBuffers;
RegionInfo() = default;
RegionInfo(const BufferIdSetT &syncReadBuffers,
const BufferIdSetT &syncWriteBuffers)
: syncReadBuffers(syncReadBuffers), syncWriteBuffers(syncWriteBuffers) {
}
/// Unions two RegionInfo objects.
void join(const RegionInfo &other) {
syncReadBuffers.insert(other.syncReadBuffers.begin(),
other.syncReadBuffers.end());
syncWriteBuffers.insert(other.syncWriteBuffers.begin(),
other.syncWriteBuffers.end());
}
/// Returns true if buffers in two RegionInfo objects are intersected.
bool isIntersected(const RegionInfo &other, Allocation *allocation) const {
return /*RAW*/ isIntersected(syncWriteBuffers, other.syncReadBuffers,
allocation) ||
/*WAR*/ isIntersected(syncReadBuffers, other.syncWriteBuffers,
allocation);
}
/// Clears the buffers because a barrier is inserted.
void sync() {
syncReadBuffers.clear();
syncWriteBuffers.clear();
}
private:
/// Returns true if buffers in two sets are intersected.
bool isIntersected(const BufferIdSetT &lhs, const BufferIdSetT &rhs,
Allocation *allocation) const {
return std::any_of(lhs.begin(), lhs.end(), [&](auto lhsId) {
return std::any_of(rhs.begin(), rhs.end(), [&](auto rhsId) {
return allocation->isIntersected(lhsId, rhsId);
});
});
}
};
/// Runs the membar analysis to the given operation, inserts a barrier if
/// necessary.
void run();
/// Applies the barrier analysis based on the SCF dialect, in which each
/// region has a single basic block only.
/// Example:
/// region1
/// op1
/// op2 (scf.if)
/// region2
/// op3
/// op4
/// region3
/// op5
/// op6
/// op7
/// region2 and region3 started with the information of region1.
/// Each region is analyzed separately and keeps their own copy of the
/// information. At op7, we union the information of the region2 and region3
/// and update the information of region1.
void dfsOperation(Operation *operation, RegionInfo *blockInfo,
OpBuilder *builder);
/// Updates the RegionInfo operation based on the operation.
void transfer(Operation *operation, RegionInfo *blockInfo,
OpBuilder *builder);
private:
Allocation *allocation;
};
} // namespace mlir
#endif // TRITON_ANALYSIS_MEMBAR_H

View File

@@ -1,6 +1,7 @@
#ifndef TRITON_DIALECT_TRITON_IR_DIALECT_H_
#define TRITON_DIALECT_TRITON_IR_DIALECT_H_
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BuiltinOps.h"

View File

@@ -228,13 +228,13 @@ def TT_DotOp : TT_Op<"dot", [NoSideEffect,
def TT_ReduceOp : TT_Op<"reduce"> {
let summary = "reduce";
let arguments = (ins TT_RedOpAttr:$redOp, TT_Type:$operand, I32Attr:$axis);
let arguments = (ins TT_RedOpAttr:$redOp, TT_Tensor:$operand, I32Attr:$axis);
let results = (outs TT_Type:$result);
let results = (outs TT_Tensor:$result);
// let builders = [
// OpBuilder<(ins "triton::RedOp":$redOp, "value":$operand, "int":$axis)>,
// ];
let builders = [
OpBuilder<(ins "triton::RedOp":$redOp, "Value":$operand, "int":$axis)>,
];
let assemblyFormat = "$operand attr-dict `:` type($operand) `->` type($result)";
}