Fix ReduceOp conversion

This commit is contained in:
Yan Da
2022-05-25 16:03:06 +08:00
parent 9b670cfb9f
commit e6f89a5777

View File

@@ -224,6 +224,19 @@ struct TritonGenericPattern : public OpConversionPattern<Op> {
}
};
struct TritonReducePattern : public OpConversionPattern<triton::ReduceOp> {
using OpConversionPattern<triton::ReduceOp>::OpConversionPattern;
LogicalResult matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type retType = this->getTypeConverter()->convertType(op.getType());
auto newOp = rewriter.replaceOpWithNewOp<triton::ReduceOp>(
op, retType, adaptor.redOp(), adaptor.operand(), adaptor.axis()
);
return success();
}
};
void populateTritonPatterns(
TritonGPUTypeConverter& typeConverter, RewritePatternSet &patterns
) {
@@ -231,6 +244,7 @@ void populateTritonPatterns(
patterns.add<TritonGenericPattern<triton::ReshapeOp>,
TritonGenericPattern<triton::BroadcastOp>,
TritonGenericPattern<triton::GEPOp>,
TritonReducePattern,
TritonMakeRangePattern,
TritonDotPattern,
TritonLoadPattern,