Keren/tensor slice insert alloc (#94)
This branch defines three new triton_gpu operations to partially solve #87. Below is an overview:

```
%tensor = triton_gpu.alloc_tensor : tensor<2x16x16xf16, #A>
%b = triton_gpu.insert_slice_async %a_ptr, %tensor, %offset {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16x!tt.ptr<f16>, #AL> -> tensor<2x16x16xf16, #A>
%c = triton_gpu.extract_slice %b, %offset {axis = 0 : i32} : tensor<2x16x16xf16, #A> -> tensor<16x16xf16, #A>
```

We plan to fully replace `copy_async` with `insert_slice_async`. **This hasn't been done yet.**
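The leading `2` in `tensor<2x16x16xf16, #A>` is the number of buffers, so the three ops compose into a multi-buffered (pipelined) load pattern. The sketch below illustrates that intended usage; the `%write_idx`/`%read_idx` values and the loop that would cycle them are hypothetical and not part of this commit:

```
// One shared-memory allocation holds both pipeline stages.
%bufs = triton_gpu.alloc_tensor : tensor<2x16x16xf16, #A>
// Each iteration issues an async copy into slice %write_idx along axis 0
// (hypothetically alternating between 0 and 1 for double buffering) ...
%bufs2 = triton_gpu.insert_slice_async %a_ptr, %bufs, %write_idx {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16x!tt.ptr<f16>, #AL> -> tensor<2x16x16xf16, #A>
// ... while the slice filled on the previous iteration is read back for compute.
%tile = triton_gpu.extract_slice %bufs2, %read_idx {axis = 0 : i32} : tensor<2x16x16xf16, #A> -> tensor<16x16xf16, #A>
```

Once the planned replacement lands, each former `copy_async` would become an `insert_slice_async` into one of these slices.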
@@ -27,7 +27,7 @@ private:

   /// Value -> Liveness Range
   /// Use MapVector to ensure determinism.
-  using BufferRangeMapT = llvm::MapVector<BufferT *, Range<size_t>>;
+  using BufferRangeMapT = llvm::MapVector<BufferT *, Interval<size_t>>;
   /// Nodes -> Nodes
   using GraphT = DenseMap<BufferT *, DenseSet<BufferT *>>;

@@ -110,7 +110,7 @@ private:
   /// Computes the liveness range of the allocated value.
   /// Each buffer is allocated only once.
   void resolveExplicitBufferLiveness(
-      function_ref<Range<size_t>(Value value)> getLiveness) {
+      function_ref<Interval<size_t>(Value value)> getLiveness) {
     for (auto valueBufferIter : allocation->valueBuffer) {
       auto value = valueBufferIter.first;
       auto *buffer = valueBufferIter.second;
@@ -122,7 +122,7 @@ private:
   /// values because each allocated buffer could be an alias of others, if block
   /// arguments are involved.
   void resolveAliasBufferLiveness(
-      function_ref<Range<size_t>(Value value)> getLiveness) {
+      function_ref<Interval<size_t>(Value value)> getLiveness) {
     for (auto aliasBufferIter : allocation->aliasBuffer) {
       auto value = aliasBufferIter.first;
       auto buffers = aliasBufferIter.second;
@@ -135,7 +135,7 @@ private:
           minId = std::min(minId, bufferRange[buffer].start());
           maxId = std::max(maxId, bufferRange[buffer].end());
         }
-        bufferRange[buffer] = Range(minId, maxId);
+        bufferRange[buffer] = Interval(minId, maxId);
       }
     }
   }
@@ -151,8 +151,8 @@ private:
       // range.
       auto *op = opScratchIter.first;
       auto *buffer = opScratchIter.second;
-      bufferRange.insert(
-          {buffer, Range(operationId.lookup(op), operationId.lookup(op) + 1)});
+      bufferRange.insert({buffer, Interval(operationId.lookup(op),
+                                           operationId.lookup(op) + 1)});
     }
   }

@@ -179,7 +179,7 @@ private:
                       maxId = operationId[liveOp] + 1;
                     }
                   });
-      return Range(minId, maxId);
+      return Interval(minId, maxId);
     };

     resolveExplicitBufferLiveness(getValueLivenessRange);
@@ -223,9 +223,9 @@ private:
     // |---------------------------------------------| liveness range
     //   1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 ...
     /// Start -> Liveness Range
-    using TripleMapT = std::multimap<size_t, Range<size_t>>;
+    using TripleMapT = std::multimap<size_t, Interval<size_t>>;
     TripleMapT tripleMap;
-    tripleMap.insert(std::make_pair(0, Range<size_t>()));
+    tripleMap.insert(std::make_pair(0, Interval<size_t>()));
     SmallVector<BufferT *> xBuffers = buffers;
     while (!xBuffers.empty()) {
       auto tripleIt = tripleMap.begin();
@@ -246,12 +246,12 @@ private:
         auto xRange = bufferRange.lookup(buffer);
         bufferStart[buffer] = size;
         tripleMap.insert(
-            {size + xSize, Range{std::max(range.start(), xRange.start()),
-                                 std::min(range.end(), xRange.end())}});
+            {size + xSize, Interval{std::max(range.start(), xRange.start()),
+                                    std::min(range.end(), xRange.end())}});
         if (range.start() < xRange.start())
-          tripleMap.insert({size, Range{range.start(), xRange.end()}});
+          tripleMap.insert({size, Interval{range.start(), xRange.end()}});
         if (xRange.end() < range.end())
-          tripleMap.insert({size, Range{xRange.start(), range.end()}});
+          tripleMap.insert({size, Interval{xRange.start(), range.end()}});
         xBuffers.erase(bufferIt);
       }
     }
@@ -270,8 +270,8 @@ private:
         auto yStart = bufferStart.lookup(y);
         auto xSize = x->size;
         auto ySize = y->size;
-        Range xSizeRange = {xStart, xStart + xSize};
-        Range ySizeRange = {yStart, yStart + ySize};
+        Interval xSizeRange = {xStart, xStart + xSize};
+        Interval ySizeRange = {yStart, yStart + ySize};
         auto xOpRange = bufferRange.lookup(x);
         auto yOpRange = bufferRange.lookup(y);
         if (xOpRange.intersects(yOpRange) &&
@@ -7,7 +7,6 @@ namespace mlir {

 void MembarAnalysis::run() {
   auto *operation = allocation->getOperation();
-  operation->getContext()->getOrLoadDialect<mlir::gpu::GPUDialect>();
   RegionInfo regionInfo;
   OpBuilder builder(operation);
   dfsOperation(operation, &regionInfo, &builder);
@@ -2,6 +2,7 @@

 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/OpImplementation.h"
+#include "triton/Analysis/Utility.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "llvm/ADT/TypeSwitch.h"

@@ -329,6 +330,72 @@ void printCopyAsyncOp(OpAsmPrinter &printer, CopyAsyncOp copyAsyncOp) {
   printer.printStrippedAttrOrType(copyAsyncOp.result().getType());
 }

+//===----------------------------------------------------------------------===//
+// InsertSliceAsyncOp
+//===----------------------------------------------------------------------===//
+
+ParseResult parseInsertSliceAsyncOp(OpAsmParser &parser,
+                                    OperationState &result) {
+  SmallVector<OpAsmParser::OperandType, 4> allOperands;
+  Type srcType, dstType;
+  SMLoc allOperandLoc = parser.getCurrentLocation();
+  if (parser.parseOperandList(allOperands) ||
+      parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
+      parser.parseCustomTypeWithFallback(srcType) || parser.parseArrow() ||
+      parser.parseCustomTypeWithFallback(dstType))
+    return failure();
+  result.addTypes(dstType);
+
+  SmallVector<Type> operandTypes;
+  operandTypes.push_back(srcType); // src
+  operandTypes.push_back(dstType); // dst
+  operandTypes.push_back(
+      IntegerType::get(parser.getBuilder().getContext(), 32)); // offset
+  if (allOperands.size() >= 4)
+    operandTypes.push_back(triton::getI1SameShape(srcType)); // mask
+  if (allOperands.size() >= 5)
+    operandTypes.push_back(triton::getPointeeType(srcType)); // other
+
+  if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
+                             result.operands))
+    return failure();
+  return success();
+}
+
+void printInsertSliceAsyncOp(OpAsmPrinter &printer,
+                             InsertSliceAsyncOp insertSliceAsyncOp) {
+  printer << " ";
+  printer << insertSliceAsyncOp.getOperation()->getOperands();
+  printer.printOptionalAttrDict(insertSliceAsyncOp->getAttrs(),
+                                /*elidedAttrs=*/{});
+  printer << " : ";
+  printer.printStrippedAttrOrType(insertSliceAsyncOp.src().getType());
+  printer << " -> ";
+  printer.printStrippedAttrOrType(insertSliceAsyncOp.result().getType());
+}
+
+//===----------------------------------------------------------------------===//
+// ExtractSliceOp
+//===----------------------------------------------------------------------===//
+
+mlir::LogicalResult ExtractSliceOp::inferReturnTypes(
+    ::mlir::MLIRContext *context, llvm::Optional<::mlir::Location> location,
+    ::mlir::ValueRange operands, mlir::DictionaryAttr attributes,
+    ::mlir::RegionRange regions,
+    llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
+  auto srcType = operands[0].getType().cast<RankedTensorType>();
+  auto encoding = srcType.getEncoding();
+  auto srcShape = srcType.getShape();
+  auto axis = attributes.get("axis").cast<IntegerAttr>().getInt();
+  if (axis < 0 || axis > srcShape.size())
+    return failure();
+  auto dstShape = srcShape.drop_front(axis + 1);
+  auto returnType =
+      RankedTensorType::get(dstShape, srcType.getElementType(), encoding);
+  inferredReturnTypes.assign({returnType});
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ASM Interface (i.e.: alias)
 //===----------------------------------------------------------------------===//
@@ -372,13 +439,30 @@ void TritonGPUDialect::initialize() {
 //===----------------------------------------------------------------------===//

 static LogicalResult verify(CopyAsyncOp op) {
-  Type resType = op.getResult().getType();
-  if (auto tensorType = resType.dyn_cast<RankedTensorType>()) {
-    Attribute encoding = tensorType.getEncoding();
-    if (!encoding.isa<SharedEncodingAttr>())
-      return op.emitOpError("copy_async should return a shared memory tensor");
-  } else
-    return op.emitOpError("copy_async should return a tensor");
+  if (!isSharedEncoding(op.getResult())) {
+    return op.emitOpError("copy_async should return a shared memory tensor");
+  }
   return success();
 }

+static LogicalResult verify(InsertSliceAsyncOp op) {
+  if (!isSharedEncoding(op.getResult())) {
+    return op.emitOpError("copy_async should return a shared memory tensor");
+  }
+  return success();
+}
+
+static LogicalResult verify(ExtractSliceOp op) {
+  if (!isSharedEncoding(op.getResult())) {
+    return op.emitOpError("extract_slice should return a shared memory tensor");
+  }
+  return success();
+}
+
+static LogicalResult verify(AllocTensorOp op) {
+  if (!isSharedEncoding(op.getResult())) {
+    return op.emitOpError("alloc_tensor should return a shared memory tensor");
+  }
+  return success();
+}
+