[Analysis/Allocation] Allocation passes now assume that slices always alias (#108)
The code in this branch assumes that the `src` operand of `insert_slice_async` always aliases the result. This does not hold in the general case; it is only a workaround to make the pipeline pass work. I'm also working on the complete analysis in another [branch](https://github.com/openai/triton-mlir/tree/keren/analyze-slice).
This commit is contained in:
@@ -292,44 +292,6 @@ void SharedEncodingAttr::print(AsmPrinter &printer) const {
|
||||
<< "}>";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CopyAsyncOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Parses a `copy_async` op from its custom assembly form:
///   copy_async %ptr[, %mask[, %other]] attr-dict : ptr-type -> result-type
/// The optional mask and `other` operand types are deduced from the pointer
/// type rather than being spelled out in the assembly.
ParseResult parseCopyAsyncOp(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::OperandType, 4> operands;
  Type ptrType;
  Type resultTypes[1];
  // Capture the location before consuming the operand list so diagnostics
  // from operand resolution point at the operands.
  SMLoc operandLoc = parser.getCurrentLocation();
  if (parser.parseOperandList(operands) ||
      parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
      parser.parseCustomTypeWithFallback(ptrType) || parser.parseArrow() ||
      parser.parseCustomTypeWithFallback(resultTypes[0]))
    return failure();
  result.addTypes(resultTypes);

  // Reconstruct the operand types from the pointer type: the optional mask is
  // an i1 tensor of the same shape, and the optional `other` value has the
  // pointee type.
  SmallVector<Type> operandTypes;
  operandTypes.push_back(ptrType); // ptr
  if (operands.size() >= 2)
    operandTypes.push_back(triton::getI1SameShape(ptrType)); // mask
  if (operands.size() >= 3)
    operandTypes.push_back(triton::getPointeeType(ptrType)); // other

  if (parser.resolveOperands(operands, operandTypes, operandLoc,
                             result.operands))
    return failure();
  return success();
}
|
||||
|
||||
/// Prints a `copy_async` op in its custom assembly form — the mirror image of
/// parseCopyAsyncOp: operands, then the attribute dictionary, then
/// `: ptr-type -> result-type`. Optional operand types are elided since the
/// parser rebuilds them from the pointer type.
void printCopyAsyncOp(OpAsmPrinter &printer, CopyAsyncOp copyAsyncOp) {
  printer << " " << copyAsyncOp.getOperation()->getOperands();
  printer.printOptionalAttrDict(copyAsyncOp->getAttrs(), /*elidedAttrs=*/{});
  printer << " : ";
  printer.printStrippedAttrOrType(copyAsyncOp.ptr().getType());
  printer << " -> ";
  printer.printStrippedAttrOrType(copyAsyncOp.result().getType());
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// InsertSliceAsyncOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -350,7 +312,7 @@ ParseResult parseInsertSliceAsyncOp(OpAsmParser &parser,
|
||||
operandTypes.push_back(srcType); // src
|
||||
operandTypes.push_back(dstType); // dst
|
||||
operandTypes.push_back(
|
||||
IntegerType::get(parser.getBuilder().getContext(), 32)); // offset
|
||||
IntegerType::get(parser.getBuilder().getContext(), 32)); // index
|
||||
if (allOperands.size() >= 4)
|
||||
operandTypes.push_back(triton::getI1SameShape(srcType)); // mask
|
||||
if (allOperands.size() >= 5)
|
||||
@@ -389,6 +351,8 @@ mlir::LogicalResult ExtractSliceOp::inferReturnTypes(
|
||||
auto axis = attributes.get("axis").cast<IntegerAttr>().getInt();
|
||||
if (axis < 0 || axis > srcShape.size())
|
||||
return failure();
|
||||
// Since we only extract a slice from a certain index on the axis,
|
||||
// the dims before the axis can be dropped.
|
||||
auto dstShape = srcShape.drop_front(axis + 1);
|
||||
auto returnType =
|
||||
RankedTensorType::get(dstShape, srcType.getElementType(), encoding);
|
||||
@@ -438,16 +402,10 @@ void TritonGPUDialect::initialize() {
|
||||
// Verification
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Verifies a `copy_async` op: its result must carry the shared-memory
/// encoding, since the op's purpose is to copy data into shared memory.
static LogicalResult verify(CopyAsyncOp op) {
  if (!isSharedEncoding(op.getResult()))
    return op.emitOpError("copy_async should return a shared memory tensor");
  return success();
}
|
||||
|
||||
/// Verifies an `insert_slice_async` op: its result must carry the
/// shared-memory encoding, since the op writes the slice into shared memory.
static LogicalResult verify(InsertSliceAsyncOp op) {
  if (!isSharedEncoding(op.getResult())) {
    // The diagnostic must name this op (`insert_slice_async`), not
    // `copy_async` — the stale message was copy-pasted from the CopyAsyncOp
    // verifier and made the corrected return below unreachable.
    return op.emitOpError(
        "insert_slice_async should return a shared memory tensor");
  }
  return success();
}
|
||||
|
@@ -69,7 +69,7 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
|
||||
// convert output types
|
||||
SmallVector<Type, 4> newTypes;
|
||||
for (auto t : op->getResultTypes()) {
|
||||
bool is_async = std::is_same<T, triton::gpu::CopyAsyncOp>::value;
|
||||
bool is_async = std::is_same<T, triton::gpu::InsertSliceAsyncOp>::value;
|
||||
newTypes.push_back(is_async ? t : convertType(t));
|
||||
}
|
||||
// construct new op with the new encoding
|
||||
@@ -106,9 +106,9 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
|
||||
builder.setInsertionPoint(curr);
|
||||
if (auto load = dyn_cast<triton::LoadOp>(curr))
|
||||
coalesceOp<triton::LoadOp>(axisInfo, curr, load.ptr(), builder);
|
||||
if (auto load = dyn_cast<triton::gpu::CopyAsyncOp>(curr))
|
||||
coalesceOp<triton::gpu::CopyAsyncOp>(axisInfo, curr, load.ptr(),
|
||||
builder);
|
||||
if (auto load = dyn_cast<triton::gpu::InsertSliceAsyncOp>(curr))
|
||||
coalesceOp<triton::gpu::InsertSliceAsyncOp>(axisInfo, curr, load.src(),
|
||||
builder);
|
||||
if (auto store = dyn_cast<triton::StoreOp>(curr))
|
||||
coalesceOp<triton::StoreOp>(axisInfo, curr, store.ptr(), builder);
|
||||
});
|
||||
|
@@ -119,7 +119,9 @@ public:
|
||||
return mlir::failure();
|
||||
|
||||
auto blacklist = [](Operation *op) {
|
||||
if (isa<triton::gpu::CopyAsyncOp, triton::LoadOp, triton::StoreOp>(op))
|
||||
if (isa<triton::gpu::ExtractSliceOp, triton::gpu::AllocTensorOp,
|
||||
triton::gpu::InsertSliceAsyncOp, triton::LoadOp, triton::StoreOp>(
|
||||
op))
|
||||
return true;
|
||||
if (isa<scf::YieldOp, scf::ForOp>(op))
|
||||
return true;
|
||||
|
@@ -275,7 +275,7 @@ void LoopPipeliner::emitPrologue() {
|
||||
loadStageBuffer[op->getResult(0)] = {loadsBuffer[op->getResult(0)]};
|
||||
}
|
||||
// load => copy async
|
||||
// TODO: check if the hardware supports copyasync
|
||||
// TODO: check if the hardware supports async copy
|
||||
if (auto loadOp = llvm::dyn_cast<triton::LoadOp>(op)) {
|
||||
newOp = builder.create<triton::gpu::InsertSliceAsyncOp>(
|
||||
op->getLoc(), loadsBuffer[loadOp].getType(),
|
||||
|
Reference in New Issue
Block a user