Analyze shared memory alias (#81)
The purpose of this PR is analyzing shared memory aliases so that we can fix memory allocation bugs and save memory allocations in triton code involving complex control flows. Changes to memory bar and allocation are on the way. Co-authored-by: Philippe Tillet <phil@openai.com>
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
#include "triton/Analysis/Allocation.h"
|
||||
#include "mlir/Analysis/Liveness.h"
|
||||
#include "mlir/Analysis/SliceAnalysis.h"
|
||||
#include "triton/Analysis/Alias.h"
|
||||
#include "triton/Analysis/Utility.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
@@ -37,17 +39,22 @@ private:
|
||||
|
||||
/// Initializes explicitly defined shared memory values for a given operation.
|
||||
void getExplicitValueSize(Operation *op) {
|
||||
/// Values returned from scf.yield will not be allocated even though they
|
||||
/// have the shared encoding.
|
||||
/// For example: %a = scf.if -> yield
|
||||
/// %a must be allocated elsewhere by other operations.
|
||||
if (!maybeSharedAllocationOp(op)) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (Value result : op->getResults()) {
|
||||
auto type = result.getType();
|
||||
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
|
||||
auto encoding = tensorType.getEncoding();
|
||||
if (encoding && encoding.isa<triton::gpu::SharedEncodingAttr>()) {
|
||||
// Bytes could be a different value once we support padding or other
|
||||
// allocation policies.
|
||||
auto bytes = tensorType.getNumElements() *
|
||||
tensorType.getElementTypeBitWidth() / 8;
|
||||
allocation->addBuffer<BufferT::BufferKind::Explicit>(result, bytes);
|
||||
}
|
||||
if (isSharedEncoding(result)) {
|
||||
// Bytes could be a different value once we support padding or other
|
||||
// allocation policies.
|
||||
auto tensorType = result.getType().dyn_cast<RankedTensorType>();
|
||||
auto bytes = tensorType.getNumElements() *
|
||||
tensorType.getElementTypeBitWidth() / 8;
|
||||
allocation->addBuffer<BufferT::BufferKind::Explicit>(result, bytes);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -67,12 +74,86 @@ private:
|
||||
}
|
||||
}
|
||||
|
||||
void getValueAlias(Value value, SharedMemoryAliasAnalysis &analysis) {
|
||||
LatticeElement<AliasInfo> *latticeElement =
|
||||
analysis.lookupLatticeElement(value);
|
||||
if (latticeElement) {
|
||||
auto &info = latticeElement->getValue();
|
||||
if (!info.getAllocs().empty()) {
|
||||
for (auto alloc : info.getAllocs()) {
|
||||
allocation->addAlias(value, alloc);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract all shared memory values and their sizes
|
||||
void getValuesAndSizes() {
|
||||
// Get the alloc values
|
||||
operation->walk<WalkOrder::PreOrder>([&](Operation *op) {
|
||||
getExplicitValueSize(op);
|
||||
getScratchValueSize(op);
|
||||
});
|
||||
// Get the alias values
|
||||
SharedMemoryAliasAnalysis aliasAnalysis(operation->getContext());
|
||||
aliasAnalysis.run(operation);
|
||||
operation->walk<WalkOrder::PreOrder>([&](Operation *op) {
|
||||
for (auto operand : op->getOperands()) {
|
||||
getValueAlias(operand, aliasAnalysis);
|
||||
}
|
||||
for (auto value : op->getResults()) {
|
||||
getValueAlias(value, aliasAnalysis);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/// Computes the liveness range of the allocated value.
|
||||
/// Each buffer is allocated only once.
|
||||
void resolveExplicitBufferLiveness(
|
||||
function_ref<Range<size_t>(Value value)> getLiveness) {
|
||||
for (auto valueBufferIter : allocation->valueBuffer) {
|
||||
auto value = valueBufferIter.first;
|
||||
auto *buffer = valueBufferIter.second;
|
||||
bufferRange[buffer] = getLiveness(value);
|
||||
}
|
||||
}
|
||||
|
||||
/// Extends the liveness range by unioning the liveness range of the aliased
|
||||
/// values because each allocated buffer could be an alias of others, if block
|
||||
/// arguments are involved.
|
||||
void resolveAliasBufferLiveness(
|
||||
function_ref<Range<size_t>(Value value)> getLiveness) {
|
||||
for (auto aliasBufferIter : allocation->aliasBuffer) {
|
||||
auto value = aliasBufferIter.first;
|
||||
auto buffers = aliasBufferIter.second;
|
||||
auto range = getLiveness(value);
|
||||
for (auto *buffer : buffers) {
|
||||
auto minId = range.start();
|
||||
auto maxId = range.end();
|
||||
if (bufferRange.count(buffer)) {
|
||||
// Extend the allocated buffer's range
|
||||
minId = std::min(minId, bufferRange[buffer].start());
|
||||
maxId = std::max(maxId, bufferRange[buffer].end());
|
||||
}
|
||||
bufferRange[buffer] = Range(minId, maxId);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Computes the liveness range of scratched buffers.
|
||||
/// Some operations may have a temporary buffer that is not explicitly
|
||||
/// allocated, but is used to store intermediate results.
|
||||
void resolveScratchBufferLiveness(
|
||||
const DenseMap<Operation *, size_t> &operationId) {
|
||||
// Analyze liveness of scratch buffers
|
||||
for (auto opScratchIter : allocation->opScratch) {
|
||||
// Any scratch memory's live range is the current operation's live
|
||||
// range.
|
||||
auto *op = opScratchIter.first;
|
||||
auto *buffer = opScratchIter.second;
|
||||
bufferRange.insert(
|
||||
{buffer, Range(operationId.lookup(op), operationId.lookup(op) + 1)});
|
||||
}
|
||||
}
|
||||
|
||||
/// Resolves liveness of all values involved under the root operation.
|
||||
@@ -83,34 +164,27 @@ private:
|
||||
operation->walk<WalkOrder::PreOrder>(
|
||||
[&](Operation *op) { operationId[op] = operationId.size(); });
|
||||
|
||||
// Analyze liveness of explicit buffers
|
||||
Liveness liveness(operation);
|
||||
operation->walk<WalkOrder::PreOrder>([&](Operation *op) {
|
||||
for (Value result : op->getResults()) {
|
||||
auto liveOperations = liveness.resolveLiveness(result);
|
||||
auto minId = std::numeric_limits<size_t>::max();
|
||||
auto maxId = std::numeric_limits<size_t>::min();
|
||||
std::for_each(liveOperations.begin(), liveOperations.end(),
|
||||
[&](Operation *liveOp) {
|
||||
if (operationId[liveOp] < minId) {
|
||||
minId = operationId[liveOp];
|
||||
}
|
||||
if (operationId[liveOp] > maxId) {
|
||||
maxId = operationId[liveOp];
|
||||
}
|
||||
});
|
||||
if (allocation->valueBuffer.count(result)) {
|
||||
auto *buffer = allocation->valueBuffer[result];
|
||||
bufferRange.insert({buffer, Range(minId, maxId + 1)});
|
||||
}
|
||||
}
|
||||
if (allocation->opScratch.count(op)) {
|
||||
// Any scratch memory's live range is the current operation's live
|
||||
// range.
|
||||
auto *buffer = allocation->opScratch[op];
|
||||
bufferRange.insert(
|
||||
{buffer, Range(operationId[op], operationId[op] + 1)});
|
||||
}
|
||||
});
|
||||
auto getValueLivenessRange = [&](Value value) {
|
||||
auto liveOperations = liveness.resolveLiveness(value);
|
||||
auto minId = std::numeric_limits<size_t>::max();
|
||||
auto maxId = std::numeric_limits<size_t>::min();
|
||||
std::for_each(liveOperations.begin(), liveOperations.end(),
|
||||
[&](Operation *liveOp) {
|
||||
if (operationId[liveOp] < minId) {
|
||||
minId = operationId[liveOp];
|
||||
}
|
||||
if ((operationId[liveOp] + 1) > maxId) {
|
||||
maxId = operationId[liveOp] + 1;
|
||||
}
|
||||
});
|
||||
return Range(minId, maxId);
|
||||
};
|
||||
|
||||
resolveExplicitBufferLiveness(getValueLivenessRange);
|
||||
resolveAliasBufferLiveness(getValueLivenessRange);
|
||||
resolveScratchBufferLiveness(operationId);
|
||||
}
|
||||
|
||||
/// Computes the shared memory offsets for all related values.
|
||||
|
Reference in New Issue
Block a user