ConstantOp conversion pattern
This commit is contained in:
@@ -41,6 +41,20 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
class ArithConstantPattern : public OpConversionPattern<arith::ConstantOp> {
|
||||
public:
|
||||
using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type retType = getTypeConverter()->convertType(op.getType());
|
||||
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
|
||||
op, retType, adaptor.getValue()
|
||||
);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class ConvertArithmeticOp: public ConversionPattern {
|
||||
public:
|
||||
ConvertArithmeticOp(TritonGPUTypeConverter &typeConverter, MLIRContext *context)
|
||||
@@ -75,7 +89,8 @@ void populateArithmeticPatternsAndLegality(
|
||||
// );
|
||||
// Rewrite rule
|
||||
// patterns.add<ConvertArithmeticOp>(typeConverter, context);
|
||||
patterns.add<ArithBinaryPattern<arith::AddIOp>,
|
||||
patterns.add<ArithConstantPattern,
|
||||
ArithBinaryPattern<arith::AddIOp>,
|
||||
ArithBinaryPattern<arith::SubIOp>,
|
||||
ArithBinaryPattern<arith::MulIOp>,
|
||||
ArithBinaryPattern<arith::DivUIOp>,
|
||||
@@ -106,10 +121,9 @@ void populateArithmeticPatternsAndLegality(
|
||||
ArithBinaryPattern<arith::DivFOp>,
|
||||
ArithBinaryPattern<arith::RemFOp>,
|
||||
// Cmp
|
||||
// ArithCmpPattern<arith::CmpIOp, triton::gpu::CmpIOp>,
|
||||
// ArithCmpPattern<arith::CmpFOp, triton::gpu::CmpFOp>
|
||||
ArithCmpPattern<arith::CmpIOp, triton::gpu::CmpIOp>,
|
||||
ArithCmpPattern<arith::CmpFOp, triton::gpu::CmpFOp>
|
||||
// Cast Ops
|
||||
>(typeConverter, context);
|
||||
}
|
||||
|
||||
@@ -205,7 +219,7 @@ public:
|
||||
ModuleOp mod = getOperation();
|
||||
// int numThreads = mod.getAttr();
|
||||
// type converter
|
||||
TritonGPUTypeConverter typeConverter(context, /*numThreads*/128);
|
||||
TritonGPUTypeConverter typeConverter(context, /*numThreads*/32);
|
||||
TritonGPUConversionTarget target(*context, typeConverter);
|
||||
// rewrite patterns
|
||||
RewritePatternSet patterns(context);
|
||||
|
Reference in New Issue
Block a user