.
This commit is contained in:
@@ -1266,14 +1266,16 @@ public:
|
||||
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
|
||||
1, context) {}
|
||||
|
||||
LogicalResult matchAndRewrite(mlir::Operation* op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
LogicalResult
|
||||
matchAndRewrite(mlir::Operation *op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto dstOp = cast<triton::gpu::ConvertLayoutOp>(op);
|
||||
auto tmpOp = dyn_cast_or_null<triton::TransOp>(dstOp.src().getDefiningOp());
|
||||
if(!tmpOp)
|
||||
if (!tmpOp)
|
||||
return mlir::failure();
|
||||
auto srcOp = dyn_cast_or_null<triton::gpu::ConvertLayoutOp>(tmpOp.src().getDefiningOp());
|
||||
if(!srcOp)
|
||||
auto srcOp = dyn_cast_or_null<triton::gpu::ConvertLayoutOp>(
|
||||
tmpOp.src().getDefiningOp());
|
||||
if (!srcOp)
|
||||
return mlir::failure();
|
||||
auto arg = srcOp.src();
|
||||
auto X = tmpOp.src();
|
||||
@@ -1285,25 +1287,74 @@ public:
|
||||
auto ZType = dstOp.getResult().getType().cast<RankedTensorType>();
|
||||
// encodings
|
||||
auto argEncoding = argType.getEncoding();
|
||||
auto XEncoding = XType.getEncoding().cast<triton::gpu::SharedEncodingAttr>();
|
||||
auto YEncoding = YType.getEncoding().cast<triton::gpu::SharedEncodingAttr>();
|
||||
auto ZEncoding = ZType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
|
||||
if(!ZEncoding)
|
||||
auto XEncoding =
|
||||
XType.getEncoding().cast<triton::gpu::SharedEncodingAttr>();
|
||||
auto YEncoding =
|
||||
YType.getEncoding().cast<triton::gpu::SharedEncodingAttr>();
|
||||
auto ZEncoding =
|
||||
ZType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
|
||||
if (!ZEncoding)
|
||||
return mlir::failure();
|
||||
// new X encoding
|
||||
auto newXOrder = triton::gpu::getOrder(argEncoding);
|
||||
auto newXEncoding = triton::gpu::SharedEncodingAttr::get(
|
||||
getContext(), ZEncoding, XType.getShape(), newXOrder,
|
||||
XType.getElementType());
|
||||
auto newXType = RankedTensorType::get(XType.getShape(), XType.getElementType(),
|
||||
newXEncoding);
|
||||
if(XEncoding == newXEncoding)
|
||||
auto newXType = RankedTensorType::get(XType.getShape(),
|
||||
XType.getElementType(), newXEncoding);
|
||||
if (XEncoding == newXEncoding)
|
||||
return mlir::failure();
|
||||
|
||||
|
||||
auto newX = rewriter.create<triton::gpu::ConvertLayoutOp>(srcOp.getLoc(), newXType, arg);
|
||||
auto newX = rewriter.create<triton::gpu::ConvertLayoutOp>(srcOp.getLoc(),
|
||||
newXType, arg);
|
||||
auto newY = rewriter.create<triton::TransOp>(tmpOp.getLoc(), newX);
|
||||
rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(dstOp, ZType, newY);
|
||||
rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(dstOp, ZType,
|
||||
newY);
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
class ConvertDotConvert : public mlir::RewritePattern {
|
||||
public:
|
||||
ConvertDotConvert(mlir::MLIRContext *context)
|
||||
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
|
||||
1, context) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(mlir::Operation *op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto dstOp = cast<triton::gpu::ConvertLayoutOp>(op);
|
||||
auto dotOp = dyn_cast_or_null<triton::DotOp>(dstOp.src().getDefiningOp());
|
||||
if (!dotOp)
|
||||
return mlir::failure();
|
||||
if (std::distance(dstOp->user_begin(), dstOp->user_end()) != 1 ||
|
||||
std::distance(dotOp->user_begin(), dotOp->user_end()) != 1)
|
||||
return mlir::failure();
|
||||
auto cvtOp = dyn_cast_or_null<triton::gpu::ConvertLayoutOp>(
|
||||
dotOp.getOperand(2).getDefiningOp());
|
||||
if (!cvtOp)
|
||||
return mlir::failure();
|
||||
auto loadOp = dyn_cast_or_null<triton::LoadOp>(cvtOp.src().getDefiningOp());
|
||||
if (!loadOp)
|
||||
return mlir::failure();
|
||||
auto dstTy = dstOp.getResult().getType().cast<RankedTensorType>();
|
||||
auto srcTy = cvtOp.getOperand().getType().cast<RankedTensorType>();
|
||||
if (dstTy != srcTy)
|
||||
return mlir::failure();
|
||||
|
||||
// TODO: int tensor cores
|
||||
auto _0f = rewriter.create<arith::ConstantFloatOp>(
|
||||
op->getLoc(), APFloat(0.0f), dstTy.getElementType().cast<FloatType>());
|
||||
auto _0 = rewriter.create<triton::SplatOp>(
|
||||
op->getLoc(), dotOp.getResult().getType(), _0f);
|
||||
auto newDot = rewriter.create<triton::DotOp>(
|
||||
op->getLoc(), dotOp.getResult().getType(), dotOp.getOperand(0),
|
||||
dotOp.getOperand(1), _0, dotOp.allowTF32());
|
||||
auto newCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||
op->getLoc(), dstTy, newDot.getResult());
|
||||
auto newAdd = rewriter.replaceOpWithNewOp<arith::AddFOp>(
|
||||
op, newCvt, cvtOp.getOperand());
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
@@ -1477,6 +1528,7 @@ public:
|
||||
patterns.add<MoveConvertOutOfIf>(context);
|
||||
patterns.add<BlockedToMMA>(context, computeCapability);
|
||||
patterns.add<ConvertTransConvert>(context);
|
||||
patterns.add<ConvertDotConvert>(context);
|
||||
|
||||
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) {
|
||||
signalPassFailure();
|
||||
|
Reference in New Issue
Block a user