Analyze shared memory alias (#81)

This PR analyzes shared memory aliases so that we can fix memory
allocation bugs and reduce memory allocations in Triton code
involving complex control flow.

Changes to membar (memory barrier) and allocation are on the way.

Co-authored-by: Philippe Tillet <phil@openai.com>
This commit is contained in:
Keren Zhou
2022-08-29 10:43:20 -07:00
committed by GitHub
parent 83287d7193
commit 02ebf24d35
15 changed files with 761 additions and 61 deletions

View File

@@ -46,6 +46,12 @@ void MembarAnalysis::transfer(Operation *op, RegionInfo *regionInfo,
if (op->getNumResults() < 1)
return;
if (dyn_cast<scf::ForOp>(op) || dyn_cast<scf::IfOp>(op) ||
dyn_cast<scf::YieldOp>(op)) {
// Do not insert barriers before control flow operations.
return;
}
if (dyn_cast<gpu::BarrierOp>(op)) {
// If the current op is a barrier, we sync previous reads and writes
regionInfo->sync();
@@ -62,24 +68,28 @@ void MembarAnalysis::transfer(Operation *op, RegionInfo *regionInfo,
return;
}
auto addBuffer = [&](RegionInfo::BufferIdSetT &bufferSet,
Allocation::BufferId bufferId) {
if (bufferId != Allocation::InvalidBufferId) {
bufferSet.insert(bufferId);
}
};
RegionInfo curRegionInfo;
for (Value value : op->getOperands()) {
// ConvertLayoutOp: shared memory -> registers
addBuffer(curRegionInfo.syncReadBuffers, allocation->getBufferId(value));
// Need to consider all alias buffers
for (auto bufferId : allocation->getBufferIds(value)) {
if (bufferId != Allocation::InvalidBufferId) {
curRegionInfo.syncReadBuffers.insert(bufferId);
}
}
}
for (Value value : op->getResults()) {
// ConvertLayoutOp: registers -> shared memory
addBuffer(curRegionInfo.syncWriteBuffers, allocation->getBufferId(value));
auto bufferId = allocation->getBufferId(value);
if (bufferId != Allocation::InvalidBufferId) {
curRegionInfo.syncWriteBuffers.insert(bufferId);
}
}
// Scratch buffer is considered as a shared memory read
addBuffer(curRegionInfo.syncReadBuffers, allocation->getBufferId(op));
auto bufferId = allocation->getBufferId(op);
if (bufferId != Allocation::InvalidBufferId) {
curRegionInfo.syncReadBuffers.insert(bufferId);
}
if (regionInfo->isIntersected(curRegionInfo, allocation)) {
OpBuilder::InsertionGuard g(*builder);