[OPTIMIZER] Improved layout simplification pass so it handles swizzled layouts better (#789)
Note: uncommented `test_gemm`, since backend has an issue with swizzling. This will get uncommented in a subsequent PR.
This commit is contained in:
@@ -71,7 +71,8 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
|
||||
// convert operands
|
||||
SmallVector<Value, 4> newArgs;
|
||||
for (auto v : op->getOperands()) {
|
||||
if (v.getType().isa<RankedTensorType>())
|
||||
auto vTy = v.getType().dyn_cast<RankedTensorType>();
|
||||
if (vTy && !vTy.getEncoding().isa<triton::gpu::SharedEncodingAttr>())
|
||||
newArgs.push_back(builder.create<triton::gpu::ConvertLayoutOp>(
|
||||
op->getLoc(), convertType(v.getType()), v));
|
||||
else
|
||||
|
@@ -56,7 +56,40 @@ public:
|
||||
// block argument
|
||||
if (!arg)
|
||||
return mlir::failure();
|
||||
// cvt(type2, cvt(type1, x)) -> cvt(type2, x)
|
||||
// cvt(alloc_tensor(x), type2) -> alloc_tensor(x, type2)
|
||||
// cvt(insert_slice(x), type2) -> extract_slice(cvt(x, type2))
|
||||
auto alloc_tensor = dyn_cast<triton::gpu::AllocTensorOp>(arg);
|
||||
if (alloc_tensor) {
|
||||
rewriter.replaceOpWithNewOp<triton::gpu::AllocTensorOp>(
|
||||
op, op->getResult(0).getType());
|
||||
return mlir::success();
|
||||
}
|
||||
auto insert_slice = dyn_cast<triton::gpu::InsertSliceAsyncOp>(arg);
|
||||
if (insert_slice) {
|
||||
auto newType = op->getResult(0).getType();
|
||||
auto new_arg = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||
op->getLoc(), newType, insert_slice.dst());
|
||||
rewriter.replaceOpWithNewOp<triton::gpu::InsertSliceAsyncOp>(
|
||||
op, newType, insert_slice.src(), new_arg.getResult(),
|
||||
insert_slice.index(), insert_slice.mask(), insert_slice.other(),
|
||||
insert_slice.cache(), insert_slice.evict(), insert_slice.isVolatile(),
|
||||
insert_slice.axis());
|
||||
return mlir::success();
|
||||
}
|
||||
// cvt(extract_slice(x), type2) ->extract_slice(cvt(x, type2))
|
||||
auto extract_slice = dyn_cast<triton::gpu::ExtractSliceOp>(arg);
|
||||
if (extract_slice) {
|
||||
auto origType = extract_slice.src().getType().cast<RankedTensorType>();
|
||||
auto newType = RankedTensorType::get(
|
||||
origType.getShape(), origType.getElementType(),
|
||||
op->getResult(0).getType().cast<RankedTensorType>().getEncoding());
|
||||
auto new_arg = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||
op->getLoc(), newType, extract_slice.src());
|
||||
rewriter.replaceOpWithNewOp<triton::gpu::ExtractSliceOp>(
|
||||
op, new_arg.getResult(), extract_slice.index(), extract_slice.axis());
|
||||
return mlir::success();
|
||||
}
|
||||
// cvt(type2, x)
|
||||
if (llvm::isa<triton::gpu::ConvertLayoutOp>(arg)) {
|
||||
rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(
|
||||
op, op->getResultTypes().front(), arg->getOperand(0));
|
||||
|
@@ -50,8 +50,6 @@ struct SwizzlePass : public TritonGPUSwizzleBase<SwizzlePass> {
|
||||
int vec = order[0] == 1 ? mat_shape[2] : mat_shape[0]; // k : m
|
||||
int mmaStride = order[0] == 1 ? mat_shape[0] : mat_shape[2];
|
||||
int maxPhase = mmaStride / perPhase;
|
||||
std::cout << perPhase << " " << mat_shape[0] << " " << mat_shape[1]
|
||||
<< " " << mat_shape[2] << std::endl;
|
||||
return SwizzleInfo{vec, perPhase, maxPhase};
|
||||
}
|
||||
// compute swizzling for B operand
|
||||
|
Reference in New Issue
Block a user