Analyze shared memory alias (#81)
The purpose of this PR is analyzing shared memory aliases so that we can fix memory allocation bugs and save memory allocations in triton code involving complex control flows. Changes to memory bar and allocation are on the way. Co-authored-by: Philippe Tillet <phil@openai.com>
This commit is contained in:
66
lib/Analysis/Alias.cpp
Normal file
66
lib/Analysis/Alias.cpp
Normal file
@@ -0,0 +1,66 @@
|
||||
#include "triton/Analysis/Alias.h"
|
||||
#include "triton/Analysis/Utility.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
AliasInfo AliasInfo::join(const AliasInfo &lhs, const AliasInfo &rhs) {
|
||||
if (lhs == rhs)
|
||||
return lhs;
|
||||
AliasInfo ret;
|
||||
for (auto value : lhs.allocs) {
|
||||
ret.insert(value);
|
||||
}
|
||||
for (auto value : rhs.allocs) {
|
||||
ret.insert(value);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
ChangeResult SharedMemoryAliasAnalysis::visitOperation(
|
||||
Operation *op, ArrayRef<LatticeElement<AliasInfo> *> operands) {
|
||||
AliasInfo aliasInfo;
|
||||
bool pessimistic = true;
|
||||
if (maybeSharedAllocationOp(op)) {
|
||||
// These ops will allocate a new shared memory buffer.
|
||||
auto result = op->getResult(0);
|
||||
if (isSharedEncoding(result)) {
|
||||
aliasInfo.insert(result);
|
||||
pessimistic = false;
|
||||
}
|
||||
} else {
|
||||
llvm::errs() << "op: " << op->getName() << "\n";
|
||||
}
|
||||
// XXX(Keren): triton ops don't support aliasing yet.
|
||||
// else if (auto viewOp = dyn_cast<triton::ViewOp>(op) ||
|
||||
// dyn_cast<triton::ExpandDimsOp>(op)) {
|
||||
// // These ops will reate a new view of the same shared memory buffer.
|
||||
// auto result = op->getResult(0);
|
||||
// if (isSharedEncoding(result)) {
|
||||
// aliasInfo = AliasInfo(operands[0]->getValue());
|
||||
// pessimistic = false;
|
||||
// }
|
||||
//}
|
||||
if (pessimistic) {
|
||||
return markAllPessimisticFixpoint(op->getResults());
|
||||
}
|
||||
// Join all latice elements
|
||||
ChangeResult result = ChangeResult::NoChange;
|
||||
for (Value value : op->getResults()) {
|
||||
result |= getLatticeElement(value).join(aliasInfo);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
AliasResult SharedMemoryAliasAnalysis::alias(Value lhs, Value rhs) {
|
||||
// TODO: implement
|
||||
return AliasResult::MayAlias;
|
||||
}
|
||||
|
||||
ModRefResult SharedMemoryAliasAnalysis::getModRef(Operation *op,
|
||||
Value location) {
|
||||
// TODO: implement
|
||||
return ModRefResult::getModAndRef();
|
||||
}
|
||||
|
||||
} // namespace mlir
|
@@ -1,6 +1,8 @@
|
||||
#include "triton/Analysis/Allocation.h"
|
||||
#include "mlir/Analysis/Liveness.h"
|
||||
#include "mlir/Analysis/SliceAnalysis.h"
|
||||
#include "triton/Analysis/Alias.h"
|
||||
#include "triton/Analysis/Utility.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
@@ -37,17 +39,22 @@ private:
|
||||
|
||||
/// Initializes explicitly defined shared memory values for a given operation.
|
||||
void getExplicitValueSize(Operation *op) {
|
||||
/// Values returned from scf.yield will not be allocated even though they
|
||||
/// have the shared encoding.
|
||||
/// For example: %a = scf.if -> yield
|
||||
/// %a must be allocated elsewhere by other operations.
|
||||
if (!maybeSharedAllocationOp(op)) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (Value result : op->getResults()) {
|
||||
auto type = result.getType();
|
||||
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
|
||||
auto encoding = tensorType.getEncoding();
|
||||
if (encoding && encoding.isa<triton::gpu::SharedEncodingAttr>()) {
|
||||
// Bytes could be a different value once we support padding or other
|
||||
// allocation policies.
|
||||
auto bytes = tensorType.getNumElements() *
|
||||
tensorType.getElementTypeBitWidth() / 8;
|
||||
allocation->addBuffer<BufferT::BufferKind::Explicit>(result, bytes);
|
||||
}
|
||||
if (isSharedEncoding(result)) {
|
||||
// Bytes could be a different value once we support padding or other
|
||||
// allocation policies.
|
||||
auto tensorType = result.getType().dyn_cast<RankedTensorType>();
|
||||
auto bytes = tensorType.getNumElements() *
|
||||
tensorType.getElementTypeBitWidth() / 8;
|
||||
allocation->addBuffer<BufferT::BufferKind::Explicit>(result, bytes);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -67,12 +74,86 @@ private:
|
||||
}
|
||||
}
|
||||
|
||||
void getValueAlias(Value value, SharedMemoryAliasAnalysis &analysis) {
|
||||
LatticeElement<AliasInfo> *latticeElement =
|
||||
analysis.lookupLatticeElement(value);
|
||||
if (latticeElement) {
|
||||
auto &info = latticeElement->getValue();
|
||||
if (!info.getAllocs().empty()) {
|
||||
for (auto alloc : info.getAllocs()) {
|
||||
allocation->addAlias(value, alloc);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract all shared memory values and their sizes
|
||||
void getValuesAndSizes() {
|
||||
// Get the alloc values
|
||||
operation->walk<WalkOrder::PreOrder>([&](Operation *op) {
|
||||
getExplicitValueSize(op);
|
||||
getScratchValueSize(op);
|
||||
});
|
||||
// Get the alias values
|
||||
SharedMemoryAliasAnalysis aliasAnalysis(operation->getContext());
|
||||
aliasAnalysis.run(operation);
|
||||
operation->walk<WalkOrder::PreOrder>([&](Operation *op) {
|
||||
for (auto operand : op->getOperands()) {
|
||||
getValueAlias(operand, aliasAnalysis);
|
||||
}
|
||||
for (auto value : op->getResults()) {
|
||||
getValueAlias(value, aliasAnalysis);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/// Computes the liveness range of the allocated value.
|
||||
/// Each buffer is allocated only once.
|
||||
void resolveExplicitBufferLiveness(
|
||||
function_ref<Range<size_t>(Value value)> getLiveness) {
|
||||
for (auto valueBufferIter : allocation->valueBuffer) {
|
||||
auto value = valueBufferIter.first;
|
||||
auto *buffer = valueBufferIter.second;
|
||||
bufferRange[buffer] = getLiveness(value);
|
||||
}
|
||||
}
|
||||
|
||||
/// Extends the liveness range by unioning the liveness range of the aliased
|
||||
/// values because each allocated buffer could be an alias of others, if block
|
||||
/// arguments are involved.
|
||||
void resolveAliasBufferLiveness(
|
||||
function_ref<Range<size_t>(Value value)> getLiveness) {
|
||||
for (auto aliasBufferIter : allocation->aliasBuffer) {
|
||||
auto value = aliasBufferIter.first;
|
||||
auto buffers = aliasBufferIter.second;
|
||||
auto range = getLiveness(value);
|
||||
for (auto *buffer : buffers) {
|
||||
auto minId = range.start();
|
||||
auto maxId = range.end();
|
||||
if (bufferRange.count(buffer)) {
|
||||
// Extend the allocated buffer's range
|
||||
minId = std::min(minId, bufferRange[buffer].start());
|
||||
maxId = std::max(maxId, bufferRange[buffer].end());
|
||||
}
|
||||
bufferRange[buffer] = Range(minId, maxId);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Computes the liveness range of scratched buffers.
|
||||
/// Some operations may have a temporary buffer that is not explicitly
|
||||
/// allocated, but is used to store intermediate results.
|
||||
void resolveScratchBufferLiveness(
|
||||
const DenseMap<Operation *, size_t> &operationId) {
|
||||
// Analyze liveness of scratch buffers
|
||||
for (auto opScratchIter : allocation->opScratch) {
|
||||
// Any scratch memory's live range is the current operation's live
|
||||
// range.
|
||||
auto *op = opScratchIter.first;
|
||||
auto *buffer = opScratchIter.second;
|
||||
bufferRange.insert(
|
||||
{buffer, Range(operationId.lookup(op), operationId.lookup(op) + 1)});
|
||||
}
|
||||
}
|
||||
|
||||
/// Resolves liveness of all values involved under the root operation.
|
||||
@@ -83,34 +164,27 @@ private:
|
||||
operation->walk<WalkOrder::PreOrder>(
|
||||
[&](Operation *op) { operationId[op] = operationId.size(); });
|
||||
|
||||
// Analyze liveness of explicit buffers
|
||||
Liveness liveness(operation);
|
||||
operation->walk<WalkOrder::PreOrder>([&](Operation *op) {
|
||||
for (Value result : op->getResults()) {
|
||||
auto liveOperations = liveness.resolveLiveness(result);
|
||||
auto minId = std::numeric_limits<size_t>::max();
|
||||
auto maxId = std::numeric_limits<size_t>::min();
|
||||
std::for_each(liveOperations.begin(), liveOperations.end(),
|
||||
[&](Operation *liveOp) {
|
||||
if (operationId[liveOp] < minId) {
|
||||
minId = operationId[liveOp];
|
||||
}
|
||||
if (operationId[liveOp] > maxId) {
|
||||
maxId = operationId[liveOp];
|
||||
}
|
||||
});
|
||||
if (allocation->valueBuffer.count(result)) {
|
||||
auto *buffer = allocation->valueBuffer[result];
|
||||
bufferRange.insert({buffer, Range(minId, maxId + 1)});
|
||||
}
|
||||
}
|
||||
if (allocation->opScratch.count(op)) {
|
||||
// Any scratch memory's live range is the current operation's live
|
||||
// range.
|
||||
auto *buffer = allocation->opScratch[op];
|
||||
bufferRange.insert(
|
||||
{buffer, Range(operationId[op], operationId[op] + 1)});
|
||||
}
|
||||
});
|
||||
auto getValueLivenessRange = [&](Value value) {
|
||||
auto liveOperations = liveness.resolveLiveness(value);
|
||||
auto minId = std::numeric_limits<size_t>::max();
|
||||
auto maxId = std::numeric_limits<size_t>::min();
|
||||
std::for_each(liveOperations.begin(), liveOperations.end(),
|
||||
[&](Operation *liveOp) {
|
||||
if (operationId[liveOp] < minId) {
|
||||
minId = operationId[liveOp];
|
||||
}
|
||||
if ((operationId[liveOp] + 1) > maxId) {
|
||||
maxId = operationId[liveOp] + 1;
|
||||
}
|
||||
});
|
||||
return Range(minId, maxId);
|
||||
};
|
||||
|
||||
resolveExplicitBufferLiveness(getValueLivenessRange);
|
||||
resolveAliasBufferLiveness(getValueLivenessRange);
|
||||
resolveScratchBufferLiveness(operationId);
|
||||
}
|
||||
|
||||
/// Computes the shared memory offsets for all related values.
|
||||
|
@@ -2,6 +2,8 @@ add_mlir_library(TritonAnalysis
|
||||
AxisInfo.cpp
|
||||
Allocation.cpp
|
||||
Membar.cpp
|
||||
Alias.cpp
|
||||
Utility.cpp
|
||||
|
||||
DEPENDS
|
||||
TritonGPUAttrDefsIncGen
|
||||
|
@@ -46,6 +46,12 @@ void MembarAnalysis::transfer(Operation *op, RegionInfo *regionInfo,
|
||||
if (op->getNumResults() < 1)
|
||||
return;
|
||||
|
||||
if (dyn_cast<scf::ForOp>(op) || dyn_cast<scf::IfOp>(op) ||
|
||||
dyn_cast<scf::YieldOp>(op)) {
|
||||
// Do not insert barriers before control flow operations.
|
||||
return;
|
||||
}
|
||||
|
||||
if (dyn_cast<gpu::BarrierOp>(op)) {
|
||||
// If the current op is a barrier, we sync previous reads and writes
|
||||
regionInfo->sync();
|
||||
@@ -62,24 +68,28 @@ void MembarAnalysis::transfer(Operation *op, RegionInfo *regionInfo,
|
||||
return;
|
||||
}
|
||||
|
||||
auto addBuffer = [&](RegionInfo::BufferIdSetT &bufferSet,
|
||||
Allocation::BufferId bufferId) {
|
||||
if (bufferId != Allocation::InvalidBufferId) {
|
||||
bufferSet.insert(bufferId);
|
||||
}
|
||||
};
|
||||
|
||||
RegionInfo curRegionInfo;
|
||||
for (Value value : op->getOperands()) {
|
||||
// ConvertLayoutOp: shared memory -> registers
|
||||
addBuffer(curRegionInfo.syncReadBuffers, allocation->getBufferId(value));
|
||||
// Need to consider all alias buffers
|
||||
for (auto bufferId : allocation->getBufferIds(value)) {
|
||||
if (bufferId != Allocation::InvalidBufferId) {
|
||||
curRegionInfo.syncReadBuffers.insert(bufferId);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (Value value : op->getResults()) {
|
||||
// ConvertLayoutOp: registers -> shared memory
|
||||
addBuffer(curRegionInfo.syncWriteBuffers, allocation->getBufferId(value));
|
||||
auto bufferId = allocation->getBufferId(value);
|
||||
if (bufferId != Allocation::InvalidBufferId) {
|
||||
curRegionInfo.syncWriteBuffers.insert(bufferId);
|
||||
}
|
||||
}
|
||||
// Scratch buffer is considered as a shared memory read
|
||||
addBuffer(curRegionInfo.syncReadBuffers, allocation->getBufferId(op));
|
||||
auto bufferId = allocation->getBufferId(op);
|
||||
if (bufferId != Allocation::InvalidBufferId) {
|
||||
curRegionInfo.syncReadBuffers.insert(bufferId);
|
||||
}
|
||||
|
||||
if (regionInfo->isIntersected(curRegionInfo, allocation)) {
|
||||
OpBuilder::InsertionGuard g(*builder);
|
||||
|
38
lib/Analysis/Utility.cpp
Normal file
38
lib/Analysis/Utility.cpp
Normal file
@@ -0,0 +1,38 @@
|
||||
#include "triton/Analysis/Utility.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
bool isSharedEncoding(Value value) {
|
||||
auto type = value.getType();
|
||||
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
|
||||
auto encoding = tensorType.getEncoding();
|
||||
return encoding && encoding.isa<triton::gpu::SharedEncodingAttr>();
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool maybeSharedAllocationOp(Operation *op) {
|
||||
// TODO(Keren): This function can be replaced by adding
|
||||
// MemoryEffectOpInterface. We can then use the MemoryEffectOpInterface to
|
||||
// query the memory effects of the op.
|
||||
auto *dialect = op->getDialect();
|
||||
return dialect &&
|
||||
(dialect->getTypeID() ==
|
||||
mlir::TypeID::get<triton::gpu::TritonGPUDialect>() ||
|
||||
dialect->getTypeID() == mlir::TypeID::get<triton::TritonDialect>() ||
|
||||
dialect->getTypeID() ==
|
||||
mlir::TypeID::get<arith::ArithmeticDialect>());
|
||||
}
|
||||
|
||||
std::string getValueOperandName(Value value, AsmState &state) {
|
||||
auto *op = value.getDefiningOp();
|
||||
std::string opName;
|
||||
llvm::raw_string_ostream ss(opName);
|
||||
value.printAsOperand(ss, state);
|
||||
return std::move(opName);
|
||||
}
|
||||
|
||||
} // namespace mlir
|
Reference in New Issue
Block a user