diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp index 224ddd22a..dc720a0d0 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp @@ -224,6 +224,19 @@ struct TritonGenericPattern : public OpConversionPattern { } }; +struct TritonReducePattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type retType = this->getTypeConverter()->convertType(op.getType()); + auto newOp = rewriter.replaceOpWithNewOp( + op, retType, adaptor.redOp(), adaptor.operand(), adaptor.axis() + ); + return success(); + } +}; + void populateTritonPatterns( TritonGPUTypeConverter& typeConverter, RewritePatternSet &patterns ) { @@ -231,6 +244,7 @@ void populateTritonPatterns( patterns.add, TritonGenericPattern, TritonGenericPattern, + TritonReducePattern, TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern,