[OPTIMIZER] Improved layout simplification pass so it handles swizzled layouts better (#789)

Note: uncommented `test_gemm`, since backend has an issue with swizzling. This will get uncommented in a subsequent PR.
This commit is contained in:
Philippe Tillet
2022-10-20 19:03:37 -07:00
committed by GitHub
parent 0d22d2bc03
commit dc0588a898
7 changed files with 68 additions and 28 deletions

View File

@@ -56,7 +56,40 @@ public:
// block argument
if (!arg)
return mlir::failure();
// cvt(type2, cvt(type1, x)) -> cvt(type2, x)
// cvt(alloc_tensor(x), type2) -> alloc_tensor(x, type2)
// cvt(insert_slice(x), type2) -> extract_slice(cvt(x, type2))
auto alloc_tensor = dyn_cast<triton::gpu::AllocTensorOp>(arg);
if (alloc_tensor) {
rewriter.replaceOpWithNewOp<triton::gpu::AllocTensorOp>(
op, op->getResult(0).getType());
return mlir::success();
}
auto insert_slice = dyn_cast<triton::gpu::InsertSliceAsyncOp>(arg);
if (insert_slice) {
auto newType = op->getResult(0).getType();
auto new_arg = rewriter.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), newType, insert_slice.dst());
rewriter.replaceOpWithNewOp<triton::gpu::InsertSliceAsyncOp>(
op, newType, insert_slice.src(), new_arg.getResult(),
insert_slice.index(), insert_slice.mask(), insert_slice.other(),
insert_slice.cache(), insert_slice.evict(), insert_slice.isVolatile(),
insert_slice.axis());
return mlir::success();
}
// cvt(extract_slice(x), type2) ->extract_slice(cvt(x, type2))
auto extract_slice = dyn_cast<triton::gpu::ExtractSliceOp>(arg);
if (extract_slice) {
auto origType = extract_slice.src().getType().cast<RankedTensorType>();
auto newType = RankedTensorType::get(
origType.getShape(), origType.getElementType(),
op->getResult(0).getType().cast<RankedTensorType>().getEncoding());
auto new_arg = rewriter.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), newType, extract_slice.src());
rewriter.replaceOpWithNewOp<triton::gpu::ExtractSliceOp>(
op, new_arg.getResult(), extract_slice.index(), extract_slice.axis());
return mlir::success();
}
// cvt(type2, x)
if (llvm::isa<triton::gpu::ConvertLayoutOp>(arg)) {
rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(
op, op->getResultTypes().front(), arg->getOperand(0));