From e6f89a5777e09f5c1850b79d0327a06fa29d40d7 Mon Sep 17 00:00:00 2001 From: Yan Da Date: Wed, 25 May 2022 16:03:06 +0800 Subject: [PATCH] Fix ReduceOp conversion --- .../TritonToTritonGPU/TritonToTritonGPU.cpp | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp index 224ddd22a..dc720a0d0 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp @@ -224,6 +224,19 @@ struct TritonGenericPattern : public OpConversionPattern { } }; +struct TritonReducePattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type retType = this->getTypeConverter()->convertType(op.getType()); + auto newOp = rewriter.replaceOpWithNewOp( + op, retType, adaptor.redOp(), adaptor.operand(), adaptor.axis() + ); + return success(); + } +}; + void populateTritonPatterns( TritonGPUTypeConverter& typeConverter, RewritePatternSet &patterns ) { @@ -231,6 +244,7 @@ void populateTritonPatterns( patterns.add, TritonGenericPattern, TritonGenericPattern, + TritonReducePattern, TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern,