more work

This commit is contained in:
Phil Tillet
2023-01-09 15:45:06 -08:00
parent 6c750b6856
commit 8ebb593bbb
5 changed files with 179 additions and 125 deletions

View File

@@ -1255,6 +1255,59 @@ public:
}
};
// Convert + trans + convert
//
// Matches the chain
//   x = convert_layout arg -> #shared_x
//   y = trans x            -> #shared_y
//   z = convert_layout y   -> #dot_operand
// and rebuilds it so that the shared encoding of `x` is derived from the
// dot-operand encoding of `z` and from the order of `arg`'s encoding, instead
// of whatever shared encoding the original first conversion picked.
class ConvertTransConvert : public mlir::RewritePattern {
public:
  // Anchors the pattern on ConvertLayoutOp (the final `z` conversion) with a
  // pattern benefit of 1.
  ConvertTransConvert(mlir::MLIRContext *context)
      : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
                             1, context) {}

  // Returns success() after replacing `op` with a rebuilt
  // convert -> trans -> convert chain; returns failure() (leaving the IR
  // untouched) when `op` is not the tail of such a chain, when the
  // encodings involved are not of the expected kinds, or when the first
  // conversion already uses the encoding this pattern would pick.
  LogicalResult matchAndRewrite(mlir::Operation *op,
                                mlir::PatternRewriter &rewriter) const override {
    // Walk the def chain backwards: z (dstOp) <- y (tmpOp) <- x (srcOp).
    auto dstOp = cast<triton::gpu::ConvertLayoutOp>(op);
    auto tmpOp = dyn_cast_or_null<triton::TransOp>(dstOp.src().getDefiningOp());
    if (!tmpOp)
      return mlir::failure();
    auto srcOp = dyn_cast_or_null<triton::gpu::ConvertLayoutOp>(
        tmpOp.src().getDefiningOp());
    if (!srcOp)
      return mlir::failure();

    auto arg = srcOp.src();
    // types
    auto argType = arg.getType().cast<RankedTensorType>();
    auto XType = tmpOp.src().getType().cast<RankedTensorType>();
    auto YType = dstOp.src().getType().cast<RankedTensorType>();
    auto ZType = dstOp.getResult().getType().cast<RankedTensorType>();

    // encodings
    // The final conversion must produce a dot operand.
    auto ZEncoding =
        ZType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
    if (!ZEncoding)
      return mlir::failure();
    // Both intermediates must be shared-encoded. Use dyn_cast rather than
    // cast so that a chain with other encodings is declined instead of
    // tripping an assertion inside the pattern driver.
    auto XEncoding =
        XType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
    auto YEncoding =
        YType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
    if (!XEncoding || !YEncoding)
      return mlir::failure();

    // new X encoding: built from the dot-operand encoding, X's shape, the
    // order of the original argument's encoding, and X's element type.
    auto newXOrder = triton::gpu::getOrder(argType.getEncoding());
    auto newXEncoding = triton::gpu::SharedEncodingAttr::get(
        getContext(), ZEncoding, XType.getShape(), newXOrder,
        XType.getElementType());
    // Nothing to do if the existing conversion already uses that encoding;
    // bailing out here also keeps the greedy driver from re-applying the
    // pattern to its own output forever.
    if (XEncoding == newXEncoding)
      return mlir::failure();
    auto newXType = RankedTensorType::get(XType.getShape(),
                                          XType.getElementType(), newXEncoding);

    // Rebuild the chain. Only dstOp is replaced here; srcOp and tmpOp are
    // left in place (presumably removed later as dead code if unused).
    auto newX = rewriter.create<triton::gpu::ConvertLayoutOp>(srcOp.getLoc(),
                                                              newXType, arg);
    auto newY = rewriter.create<triton::TransOp>(tmpOp.getLoc(), newX);
    rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(dstOp, ZType,
                                                              newY);
    return mlir::success();
  }
};
// Correct the versionMinor field in MmaEncodingAttr for Volta.
class UpdateMMAVersionMinorForVolta : public mlir::RewritePattern {
const DenseMap<MmaEncodingAttr, MmaEncodingAttr> &mmaToUpdate;
@@ -1423,6 +1476,7 @@ public:
patterns.add<MoveConvertOutOfLoop>(context);
patterns.add<MoveConvertOutOfIf>(context);
patterns.add<BlockedToMMA>(context, computeCapability);
patterns.add<ConvertTransConvert>(context);
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) {
signalPassFailure();