ConstantOp conversion pattern

This commit is contained in:
Yan Da
2022-05-04 15:35:43 +08:00
parent b9279d2e3b
commit 2d281cbc0a
3 changed files with 34 additions and 21 deletions

View File

@@ -41,6 +41,20 @@ public:
}
};
class ArithConstantPattern : public OpConversionPattern<arith::ConstantOp> {
public:
using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
LogicalResult matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type retType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
op, retType, adaptor.getValue()
);
return success();
}
};
class ConvertArithmeticOp: public ConversionPattern {
public:
ConvertArithmeticOp(TritonGPUTypeConverter &typeConverter, MLIRContext *context)
@@ -75,7 +89,8 @@ void populateArithmeticPatternsAndLegality(
// );
// Rewrite rule
// patterns.add<ConvertArithmeticOp>(typeConverter, context);
patterns.add<ArithBinaryPattern<arith::AddIOp>,
patterns.add<ArithConstantPattern,
ArithBinaryPattern<arith::AddIOp>,
ArithBinaryPattern<arith::SubIOp>,
ArithBinaryPattern<arith::MulIOp>,
ArithBinaryPattern<arith::DivUIOp>,
@@ -106,10 +121,9 @@ void populateArithmeticPatternsAndLegality(
ArithBinaryPattern<arith::DivFOp>,
ArithBinaryPattern<arith::RemFOp>,
// Cmp
// ArithCmpPattern<arith::CmpIOp, triton::gpu::CmpIOp>,
// ArithCmpPattern<arith::CmpFOp, triton::gpu::CmpFOp>
ArithCmpPattern<arith::CmpIOp, triton::gpu::CmpIOp>,
ArithCmpPattern<arith::CmpFOp, triton::gpu::CmpFOp>
// Cast Ops
>(typeConverter, context);
}
@@ -205,7 +219,7 @@ public:
ModuleOp mod = getOperation();
// int numThreads = mod.getAttr();
// type converter
TritonGPUTypeConverter typeConverter(context, /*numThreads*/128);
TritonGPUTypeConverter typeConverter(context, /*numThreads*/32);
TritonGPUConversionTarget target(*context, typeConverter);
// rewrite patterns
RewritePatternSet patterns(context);