[OPTIMIZER] Improved layout simplification pass so it handles swizzled layouts better (#789)

Note: uncommented `test_gemm`, since backend has an issue with swizzling. This will get uncommented in a subsequent PR.
This commit is contained in:
Philippe Tillet
2022-10-20 19:03:37 -07:00
committed by GitHub
parent 0d22d2bc03
commit dc0588a898
7 changed files with 68 additions and 28 deletions

View File

@@ -71,7 +71,8 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
// convert operands
SmallVector<Value, 4> newArgs;
for (auto v : op->getOperands()) {
if (v.getType().isa<RankedTensorType>())
auto vTy = v.getType().dyn_cast<RankedTensorType>();
if (vTy && !vTy.getEncoding().isa<triton::gpu::SharedEncodingAttr>())
newArgs.push_back(builder.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), convertType(v.getType()), v));
else

View File

@@ -56,7 +56,40 @@ public:
// block argument
if (!arg)
return mlir::failure();
// cvt(type2, cvt(type1, x)) -> cvt(type2, x)
// cvt(alloc_tensor(x), type2) -> alloc_tensor(x, type2)
// cvt(insert_slice(x), type2) -> extract_slice(cvt(x, type2))
auto alloc_tensor = dyn_cast<triton::gpu::AllocTensorOp>(arg);
if (alloc_tensor) {
rewriter.replaceOpWithNewOp<triton::gpu::AllocTensorOp>(
op, op->getResult(0).getType());
return mlir::success();
}
auto insert_slice = dyn_cast<triton::gpu::InsertSliceAsyncOp>(arg);
if (insert_slice) {
auto newType = op->getResult(0).getType();
auto new_arg = rewriter.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), newType, insert_slice.dst());
rewriter.replaceOpWithNewOp<triton::gpu::InsertSliceAsyncOp>(
op, newType, insert_slice.src(), new_arg.getResult(),
insert_slice.index(), insert_slice.mask(), insert_slice.other(),
insert_slice.cache(), insert_slice.evict(), insert_slice.isVolatile(),
insert_slice.axis());
return mlir::success();
}
// cvt(extract_slice(x), type2) ->extract_slice(cvt(x, type2))
auto extract_slice = dyn_cast<triton::gpu::ExtractSliceOp>(arg);
if (extract_slice) {
auto origType = extract_slice.src().getType().cast<RankedTensorType>();
auto newType = RankedTensorType::get(
origType.getShape(), origType.getElementType(),
op->getResult(0).getType().cast<RankedTensorType>().getEncoding());
auto new_arg = rewriter.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), newType, extract_slice.src());
rewriter.replaceOpWithNewOp<triton::gpu::ExtractSliceOp>(
op, new_arg.getResult(), extract_slice.index(), extract_slice.axis());
return mlir::success();
}
// cvt(type2, x)
if (llvm::isa<triton::gpu::ConvertLayoutOp>(arg)) {
rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(
op, op->getResultTypes().front(), arg->getOperand(0));

View File

@@ -50,8 +50,6 @@ struct SwizzlePass : public TritonGPUSwizzleBase<SwizzlePass> {
int vec = order[0] == 1 ? mat_shape[2] : mat_shape[0]; // k : m
int mmaStride = order[0] == 1 ? mat_shape[0] : mat_shape[2];
int maxPhase = mmaStride / perPhase;
std::cout << perPhase << " " << mat_shape[0] << " " << mat_shape[1]
<< " " << mat_shape[2] << std::endl;
return SwizzleInfo{vec, perPhase, maxPhase};
}
// compute swizzling for B operand