Analyze shared memory alias (#81)
The purpose of this PR is analyzing shared memory aliases so that we can fix memory allocation bugs and save memory allocations in triton code involving complex control flows. Changes to memory bar and allocation are on the way. Co-authored-by: Philippe Tillet <phil@openai.com>
This commit is contained in:
80
include/triton/Analysis/Alias.h
Normal file
80
include/triton/Analysis/Alias.h
Normal file
@@ -0,0 +1,80 @@
|
||||
#ifndef TRITON_ANALYSIS_ALIAS_H
|
||||
#define TRITON_ANALYSIS_ALIAS_H
|
||||
|
||||
#include "mlir/Analysis/AliasAnalysis.h"
|
||||
#include "mlir/Analysis/DataFlowAnalysis.h"
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class AliasInfo {
|
||||
public:
|
||||
AliasInfo() = default;
|
||||
AliasInfo(Value value) { insert(value); }
|
||||
|
||||
void insert(Value value) { allocs.insert(value); }
|
||||
|
||||
const DenseSet<Value> &getAllocs() const { return allocs; }
|
||||
|
||||
bool operator==(const AliasInfo &other) const {
|
||||
return allocs == other.allocs;
|
||||
}
|
||||
|
||||
/// The pessimistic value state of a value without alias
|
||||
static AliasInfo getPessimisticValueState(MLIRContext *context) {
|
||||
return AliasInfo();
|
||||
}
|
||||
static AliasInfo getPessimisticValueState(Value value) { return AliasInfo(); }
|
||||
|
||||
/// The union of both arguments
|
||||
static AliasInfo join(const AliasInfo &lhs, const AliasInfo &rhs);
|
||||
|
||||
private:
|
||||
/// The set of allocated values that are aliased by this lattice.
|
||||
/// For now, we only consider aliased value produced by the following
|
||||
/// situations:
|
||||
/// 1. values returned by scf.yield
|
||||
/// 2. block arguments in scf.for
|
||||
/// Example:
|
||||
/// alloc v1 alloc v2
|
||||
/// | |
|
||||
/// |--------------| |------------|
|
||||
/// scf.for v3 scf.for v4 scf.for v5
|
||||
/// |
|
||||
/// scf.yield v6
|
||||
///
|
||||
/// v1's alloc [v1]
|
||||
/// v2's alloc [v2]
|
||||
/// v3's alloc [v1]
|
||||
/// v4's alloc [v1, v2]
|
||||
/// v5's alloc [v2]
|
||||
/// v6's alloc [v1]
|
||||
///
|
||||
/// Therefore, v1's liveness range is the union of v3, v4, and v6
|
||||
/// v2's liveness range is the union of v4 and v5.
|
||||
DenseSet<Value> allocs;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Shared Memory Alias Analysis
|
||||
//===----------------------------------------------------------------------===//
|
||||
class SharedMemoryAliasAnalysis : public ForwardDataFlowAnalysis<AliasInfo> {
|
||||
public:
|
||||
using ForwardDataFlowAnalysis<AliasInfo>::ForwardDataFlowAnalysis;
|
||||
|
||||
/// XXX(Keren): Compatible interface with MLIR AliasAnalysis for future use.
|
||||
/// Given two values, returns their aliasing behavior.
|
||||
AliasResult alias(Value lhs, Value rhs);
|
||||
|
||||
/// Returns the modify-reference behavior of `op` on `location`.
|
||||
ModRefResult getModRef(Operation *op, Value location);
|
||||
|
||||
/// Computes if the alloc set of the results are changed.
|
||||
ChangeResult
|
||||
visitOperation(Operation *op,
|
||||
ArrayRef<LatticeElement<AliasInfo> *> operands) override;
|
||||
};
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TRITON_ANALYSIS_ALIAS_H
|
@@ -3,6 +3,7 @@
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/MapVector.h"
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
@@ -46,6 +47,8 @@ class Allocation {
|
||||
public:
|
||||
/// A unique identifier for shared memory buffers
|
||||
using BufferId = size_t;
|
||||
using BufferIdSetT = DenseSet<BufferId>;
|
||||
|
||||
static constexpr BufferId InvalidBufferId =
|
||||
std::numeric_limits<BufferId>::max();
|
||||
|
||||
@@ -67,6 +70,9 @@ public:
|
||||
}
|
||||
|
||||
/// Returns the buffer id of the given value.
|
||||
/// This interface only returns the allocated buffer id.
|
||||
/// If you want to get all the buffer ids that are associated with the given
|
||||
/// value, including alias buffers, use getBufferIds.
|
||||
BufferId getBufferId(Value value) const {
|
||||
if (valueBuffer.count(value)) {
|
||||
return valueBuffer.lookup(value)->id;
|
||||
@@ -75,6 +81,19 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns all the buffer ids of the given value, including alias buffers.
|
||||
BufferIdSetT getBufferIds(Value value) const {
|
||||
BufferIdSetT bufferIds;
|
||||
auto allocBufferId = getBufferId(value);
|
||||
if (allocBufferId != InvalidBufferId)
|
||||
bufferIds.insert(allocBufferId);
|
||||
for (auto *buffer : aliasBuffer.lookup(value)) {
|
||||
if (buffer->id != InvalidBufferId)
|
||||
bufferIds.insert(buffer->id);
|
||||
}
|
||||
return bufferIds;
|
||||
}
|
||||
|
||||
/// Returns the scratch buffer id of the given value.
|
||||
BufferId getBufferId(Operation *operation) const {
|
||||
if (opScratch.count(operation)) {
|
||||
@@ -126,7 +145,9 @@ private:
|
||||
/// Op -> Scratch Buffer
|
||||
using OpScratchMapT = DenseMap<Operation *, BufferT *>;
|
||||
/// Value -> Explicit Buffer
|
||||
using ValueBufferMapT = DenseMap<Value, BufferT *>;
|
||||
using ValueBufferMapT = llvm::MapVector<Value, BufferT *>;
|
||||
/// Value -> Alias Buffer
|
||||
using AliasBufferMapT = llvm::MapVector<Value, llvm::SetVector<BufferT *>>;
|
||||
/// BufferId -> Buffer
|
||||
using BufferSetT = DenseMap<BufferId, BufferT>;
|
||||
/// Runs allocation analysis on the given top-level operation.
|
||||
@@ -144,10 +165,15 @@ private:
|
||||
}
|
||||
}
|
||||
|
||||
void addAlias(Value value, Value alloc) {
|
||||
aliasBuffer[value].insert(valueBuffer[alloc]);
|
||||
}
|
||||
|
||||
private:
|
||||
Operation *operation;
|
||||
OpScratchMapT opScratch;
|
||||
ValueBufferMapT valueBuffer;
|
||||
AliasBufferMapT aliasBuffer;
|
||||
BufferSetT bufferSet;
|
||||
size_t sharedMemorySize = 0;
|
||||
|
||||
|
@@ -33,7 +33,7 @@ public:
|
||||
|
||||
private:
|
||||
struct RegionInfo {
|
||||
using BufferIdSetT = DenseSet<Allocation::BufferId>;
|
||||
using BufferIdSetT = Allocation::BufferIdSetT;
|
||||
|
||||
BufferIdSetT syncReadBuffers;
|
||||
BufferIdSetT syncWriteBuffers;
|
||||
|
16
include/triton/Analysis/Utility.h
Normal file
16
include/triton/Analysis/Utility.h
Normal file
@@ -0,0 +1,16 @@
|
||||
#ifndef TRITON_ANALYSIS_UTILITY_H
|
||||
#define TRITON_ANALYSIS_UTILITY_H
|
||||
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include <string>
|
||||
namespace mlir {
|
||||
|
||||
bool isSharedEncoding(Value value);
|
||||
|
||||
bool maybeSharedAllocationOp(Operation *op);
|
||||
|
||||
std::string getValueOperandName(Value value, AsmState &state);
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TRITON_ANALYSIS_UTILITY_H
|
Reference in New Issue
Block a user