From 5f08e2fdae5cd53d125dfd777b3138dc7faa946b Mon Sep 17 00:00:00 2001 From: Yan Da Date: Mon, 2 May 2022 22:31:29 +0800 Subject: [PATCH] More arith patterns --- .../TritonToTritonGPU/TritonToTritonGPU.cpp | 48 ++++++++++++++++++- 1 file changed, 47 insertions(+), 1 deletion(-) diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp index 18ebd035d..7be905fab 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp @@ -11,6 +11,21 @@ using namespace mlir::triton; namespace { +template +class ArithBinaryPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type retType = this->getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp( + op.getOperation(), retType, op.getLhs(), op.getRhs() + ); + return success(); + } +}; + class ConvertArithmeticOp: public ConversionPattern { public: ConvertArithmeticOp(TritonGPUTypeConverter &typeConverter, MLIRContext *context) @@ -44,7 +59,38 @@ void populateArithmeticPatternsAndLegality( } ); // Rewrite rule - patterns.add(typeConverter, context); + // patterns.add(typeConverter, context); + patterns.add, + ArithBinaryPattern, + ArithBinaryPattern, + ArithBinaryPattern, + ArithBinaryPattern, + ArithBinaryPattern, + ArithBinaryPattern, + ArithBinaryPattern, + ArithBinaryPattern, + ArithBinaryPattern, + ArithBinaryPattern, + ArithBinaryPattern, + ArithBinaryPattern, + ArithBinaryPattern, + ArithBinaryPattern, + ArithBinaryPattern, // NegFOp + // Floating point + ArithBinaryPattern, + ArithBinaryPattern, + // MaxMin + ArithBinaryPattern, + ArithBinaryPattern, + ArithBinaryPattern, + ArithBinaryPattern, + ArithBinaryPattern, + ArithBinaryPattern, + // Floating point + ArithBinaryPattern, + ArithBinaryPattern, + ArithBinaryPattern + >(typeConverter, context); } //