Fix ReduceOp conversion
This commit is contained in:
@@ -224,6 +224,19 @@ struct TritonGenericPattern : public OpConversionPattern<Op> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct TritonReducePattern : public OpConversionPattern<triton::ReduceOp> {
|
||||||
|
using OpConversionPattern<triton::ReduceOp>::OpConversionPattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
Type retType = this->getTypeConverter()->convertType(op.getType());
|
||||||
|
auto newOp = rewriter.replaceOpWithNewOp<triton::ReduceOp>(
|
||||||
|
op, retType, adaptor.redOp(), adaptor.operand(), adaptor.axis()
|
||||||
|
);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
void populateTritonPatterns(
|
void populateTritonPatterns(
|
||||||
TritonGPUTypeConverter& typeConverter, RewritePatternSet &patterns
|
TritonGPUTypeConverter& typeConverter, RewritePatternSet &patterns
|
||||||
) {
|
) {
|
||||||
@@ -231,6 +244,7 @@ void populateTritonPatterns(
|
|||||||
patterns.add<TritonGenericPattern<triton::ReshapeOp>,
|
patterns.add<TritonGenericPattern<triton::ReshapeOp>,
|
||||||
TritonGenericPattern<triton::BroadcastOp>,
|
TritonGenericPattern<triton::BroadcastOp>,
|
||||||
TritonGenericPattern<triton::GEPOp>,
|
TritonGenericPattern<triton::GEPOp>,
|
||||||
|
TritonReducePattern,
|
||||||
TritonMakeRangePattern,
|
TritonMakeRangePattern,
|
||||||
TritonDotPattern,
|
TritonDotPattern,
|
||||||
TritonLoadPattern,
|
TritonLoadPattern,
|
||||||
|
Reference in New Issue
Block a user