Analyze shared memory alias (#81)
The purpose of this PR is to analyze shared memory aliases so that we can fix memory allocation bugs and reduce memory allocations in Triton code involving complex control flow. Changes to the memory barrier (membar) and allocation code are on the way. Co-authored-by: Philippe Tillet <phil@openai.com>
This commit is contained in:
@@ -46,6 +46,12 @@ void MembarAnalysis::transfer(Operation *op, RegionInfo *regionInfo,
|
||||
if (op->getNumResults() < 1)
|
||||
return;
|
||||
|
||||
if (dyn_cast<scf::ForOp>(op) || dyn_cast<scf::IfOp>(op) ||
|
||||
dyn_cast<scf::YieldOp>(op)) {
|
||||
// Do not insert barriers before control flow operations.
|
||||
return;
|
||||
}
|
||||
|
||||
if (dyn_cast<gpu::BarrierOp>(op)) {
|
||||
// If the current op is a barrier, we sync previous reads and writes
|
||||
regionInfo->sync();
|
||||
@@ -62,24 +68,28 @@ void MembarAnalysis::transfer(Operation *op, RegionInfo *regionInfo,
|
||||
return;
|
||||
}
|
||||
|
||||
auto addBuffer = [&](RegionInfo::BufferIdSetT &bufferSet,
|
||||
Allocation::BufferId bufferId) {
|
||||
if (bufferId != Allocation::InvalidBufferId) {
|
||||
bufferSet.insert(bufferId);
|
||||
}
|
||||
};
|
||||
|
||||
RegionInfo curRegionInfo;
|
||||
for (Value value : op->getOperands()) {
|
||||
// ConvertLayoutOp: shared memory -> registers
|
||||
addBuffer(curRegionInfo.syncReadBuffers, allocation->getBufferId(value));
|
||||
// Need to consider all alias buffers
|
||||
for (auto bufferId : allocation->getBufferIds(value)) {
|
||||
if (bufferId != Allocation::InvalidBufferId) {
|
||||
curRegionInfo.syncReadBuffers.insert(bufferId);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (Value value : op->getResults()) {
|
||||
// ConvertLayoutOp: registers -> shared memory
|
||||
addBuffer(curRegionInfo.syncWriteBuffers, allocation->getBufferId(value));
|
||||
auto bufferId = allocation->getBufferId(value);
|
||||
if (bufferId != Allocation::InvalidBufferId) {
|
||||
curRegionInfo.syncWriteBuffers.insert(bufferId);
|
||||
}
|
||||
}
|
||||
// Scratch buffer is considered as a shared memory read
|
||||
addBuffer(curRegionInfo.syncReadBuffers, allocation->getBufferId(op));
|
||||
auto bufferId = allocation->getBufferId(op);
|
||||
if (bufferId != Allocation::InvalidBufferId) {
|
||||
curRegionInfo.syncReadBuffers.insert(bufferId);
|
||||
}
|
||||
|
||||
if (regionInfo->isIntersected(curRegionInfo, allocation)) {
|
||||
OpBuilder::InsertionGuard g(*builder);
|
||||
|
Reference in New Issue
Block a user