Fix ReduceOp conversion
This commit is contained in:
@@ -224,6 +224,19 @@ struct TritonGenericPattern : public OpConversionPattern<Op> {
|
||||
}
|
||||
};
|
||||
|
||||
struct TritonReducePattern : public OpConversionPattern<triton::ReduceOp> {
|
||||
using OpConversionPattern<triton::ReduceOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type retType = this->getTypeConverter()->convertType(op.getType());
|
||||
auto newOp = rewriter.replaceOpWithNewOp<triton::ReduceOp>(
|
||||
op, retType, adaptor.redOp(), adaptor.operand(), adaptor.axis()
|
||||
);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void populateTritonPatterns(
|
||||
TritonGPUTypeConverter& typeConverter, RewritePatternSet &patterns
|
||||
) {
|
||||
@@ -231,6 +244,7 @@ void populateTritonPatterns(
|
||||
patterns.add<TritonGenericPattern<triton::ReshapeOp>,
|
||||
TritonGenericPattern<triton::BroadcastOp>,
|
||||
TritonGenericPattern<triton::GEPOp>,
|
||||
TritonReducePattern,
|
||||
TritonMakeRangePattern,
|
||||
TritonDotPattern,
|
||||
TritonLoadPattern,
|
||||
|
Reference in New Issue
Block a user