diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp index bb1af4b24..1a6512648 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp @@ -144,27 +144,14 @@ struct TritonMakeRangePattern : public OpConversionPattern } }; -struct TritonBroadcastPattern : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +struct TritonDotPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(triton::BroadcastOp op, OpAdaptor adaptor, + LogicalResult matchAndRewrite(triton::DotOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type retType = getTypeConverter()->convertType(op.getType()); - rewriter.replaceOpWithNewOp( - op, retType, adaptor.src() - ); - return success(); - } -}; - -struct TritonGEPPattern : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(triton::GEPOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type retType = getTypeConverter()->convertType(op.getType()); - rewriter.replaceOpWithNewOp( - op, retType, adaptor.getOperands() + rewriter.replaceOpWithNewOp( + op, retType, adaptor.a(), adaptor.b(), adaptor.c(), adaptor.allowTF32() ); return success(); } @@ -197,13 +184,29 @@ struct TritonStorePattern : public OpConversionPattern { } }; +template +struct TritonGenericPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type retType = this->getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp( + op, retType, adaptor.getOperands() + ); + return success(); + } +}; + void populateTritonPatterns( TritonGPUTypeConverter& typeConverter, RewritePatternSet &patterns ) { MLIRContext *context = patterns.getContext(); - patterns.add, + TritonGenericPattern, + TritonGenericPattern, + TritonMakeRangePattern, + TritonDotPattern, TritonLoadPattern, TritonStorePattern >(typeConverter, context);