Keren/tensor slice insert alloc (#94)

This branch defines three new triton_gpu operations to partially solve #87. Below is an overview:

```
%tensor = triton_gpu.alloc_tensor : tensor<2x16x16xf16, #A>
%b = triton_gpu.insert_slice_async %a_ptr, %tensor, %offset {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16x!tt.ptr<f16>, #AL> -> tensor<2x16x16xf16, #A>
%c = triton_gpu.extract_slice %b, %offset {axis = 0 : i32} : tensor<2x16x16xf16, #A> -> tensor<16x16xf16, #A>
```

We plan to fully replace `copy_async` with `insert_slice_async`. **This hasn't been done yet.**
This commit is contained in:
Keren Zhou
2022-09-01 12:37:17 -07:00
committed by GitHub
parent d01353de07
commit 328b87aec6
10 changed files with 260 additions and 40 deletions

View File

@@ -27,7 +27,7 @@ private:
/// Value -> Liveness Range
/// Use MapVector to ensure determinism.
using BufferRangeMapT = llvm::MapVector<BufferT *, Range<size_t>>;
using BufferRangeMapT = llvm::MapVector<BufferT *, Interval<size_t>>;
/// Nodes -> Nodes
using GraphT = DenseMap<BufferT *, DenseSet<BufferT *>>;
@@ -110,7 +110,7 @@ private:
/// Computes the liveness range of the allocated value.
/// Each buffer is allocated only once.
void resolveExplicitBufferLiveness(
function_ref<Range<size_t>(Value value)> getLiveness) {
function_ref<Interval<size_t>(Value value)> getLiveness) {
for (auto valueBufferIter : allocation->valueBuffer) {
auto value = valueBufferIter.first;
auto *buffer = valueBufferIter.second;
@@ -122,7 +122,7 @@ private:
/// values because each allocated buffer could be an alias of others, if block
/// arguments are involved.
void resolveAliasBufferLiveness(
function_ref<Range<size_t>(Value value)> getLiveness) {
function_ref<Interval<size_t>(Value value)> getLiveness) {
for (auto aliasBufferIter : allocation->aliasBuffer) {
auto value = aliasBufferIter.first;
auto buffers = aliasBufferIter.second;
@@ -135,7 +135,7 @@ private:
minId = std::min(minId, bufferRange[buffer].start());
maxId = std::max(maxId, bufferRange[buffer].end());
}
bufferRange[buffer] = Range(minId, maxId);
bufferRange[buffer] = Interval(minId, maxId);
}
}
}
@@ -151,8 +151,8 @@ private:
// range.
auto *op = opScratchIter.first;
auto *buffer = opScratchIter.second;
bufferRange.insert(
{buffer, Range(operationId.lookup(op), operationId.lookup(op) + 1)});
bufferRange.insert({buffer, Interval(operationId.lookup(op),
operationId.lookup(op) + 1)});
}
}
@@ -179,7 +179,7 @@ private:
maxId = operationId[liveOp] + 1;
}
});
return Range(minId, maxId);
return Interval(minId, maxId);
};
resolveExplicitBufferLiveness(getValueLivenessRange);
@@ -223,9 +223,9 @@ private:
// |---------------------------------------------| liveness range
// 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 ...
/// Start -> Liveness Range
using TripleMapT = std::multimap<size_t, Range<size_t>>;
using TripleMapT = std::multimap<size_t, Interval<size_t>>;
TripleMapT tripleMap;
tripleMap.insert(std::make_pair(0, Range<size_t>()));
tripleMap.insert(std::make_pair(0, Interval<size_t>()));
SmallVector<BufferT *> xBuffers = buffers;
while (!xBuffers.empty()) {
auto tripleIt = tripleMap.begin();
@@ -246,12 +246,12 @@ private:
auto xRange = bufferRange.lookup(buffer);
bufferStart[buffer] = size;
tripleMap.insert(
{size + xSize, Range{std::max(range.start(), xRange.start()),
std::min(range.end(), xRange.end())}});
{size + xSize, Interval{std::max(range.start(), xRange.start()),
std::min(range.end(), xRange.end())}});
if (range.start() < xRange.start())
tripleMap.insert({size, Range{range.start(), xRange.end()}});
tripleMap.insert({size, Interval{range.start(), xRange.end()}});
if (xRange.end() < range.end())
tripleMap.insert({size, Range{xRange.start(), range.end()}});
tripleMap.insert({size, Interval{xRange.start(), range.end()}});
xBuffers.erase(bufferIt);
}
}
@@ -270,8 +270,8 @@ private:
auto yStart = bufferStart.lookup(y);
auto xSize = x->size;
auto ySize = y->size;
Range xSizeRange = {xStart, xStart + xSize};
Range ySizeRange = {yStart, yStart + ySize};
Interval xSizeRange = {xStart, xStart + xSize};
Interval ySizeRange = {yStart, yStart + ySize};
auto xOpRange = bufferRange.lookup(x);
auto yOpRange = bufferRange.lookup(y);
if (xOpRange.intersects(yOpRange) &&

View File

@@ -7,7 +7,6 @@ namespace mlir {
void MembarAnalysis::run() {
auto *operation = allocation->getOperation();
operation->getContext()->getOrLoadDialect<mlir::gpu::GPUDialect>();
RegionInfo regionInfo;
OpBuilder builder(operation);
dfsOperation(operation, &regionInfo, &builder);