diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp index 18ebd035d..7be905fab 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp @@ -11,6 +11,21 @@ using namespace mlir::triton; namespace { +template +class ArithBinaryPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type retType = this->getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp( + op.getOperation(), retType, op.getLhs(), op.getRhs() + ); + return success(); + } +}; + class ConvertArithmeticOp: public ConversionPattern { public: ConvertArithmeticOp(TritonGPUTypeConverter &typeConverter, MLIRContext *context) @@ -44,7 +59,38 @@ void populateArithmeticPatternsAndLegality( } ); // Rewrite rule - patterns.add(typeConverter, context); + // patterns.add(typeConverter, context); + patterns.add, + ArithBinaryPattern, + ArithBinaryPattern, + ArithBinaryPattern, + ArithBinaryPattern, + ArithBinaryPattern, + ArithBinaryPattern, + ArithBinaryPattern, + ArithBinaryPattern, + ArithBinaryPattern, + ArithBinaryPattern, + ArithBinaryPattern, + ArithBinaryPattern, + ArithBinaryPattern, + ArithBinaryPattern, + ArithBinaryPattern, // NegFOp + // Floating point + ArithBinaryPattern, + ArithBinaryPattern, + // MaxMin + ArithBinaryPattern, + ArithBinaryPattern, + ArithBinaryPattern, + ArithBinaryPattern, + ArithBinaryPattern, + ArithBinaryPattern, + // Floating point + ArithBinaryPattern, + ArithBinaryPattern, + ArithBinaryPattern + >(typeConverter, context); } //