More arith patterns

This commit is contained in:
Yan Da
2022-05-02 22:31:29 +08:00
parent 75d32e2442
commit 5f08e2fdae

View File

@@ -11,6 +11,21 @@ using namespace mlir::triton;
namespace {
template<class Op>
class ArithBinaryPattern : public OpConversionPattern<Op> {
public:
using OpConversionPattern<Op>::OpConversionPattern;
LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type retType = this->getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<Op>(
op.getOperation(), retType, op.getLhs(), op.getRhs()
);
return success();
}
};
class ConvertArithmeticOp: public ConversionPattern {
public:
ConvertArithmeticOp(TritonGPUTypeConverter &typeConverter, MLIRContext *context)
@@ -44,7 +59,38 @@ void populateArithmeticPatternsAndLegality(
}
);
// Rewrite rule
patterns.add<ConvertArithmeticOp>(typeConverter, context);
// patterns.add<ConvertArithmeticOp>(typeConverter, context);
patterns.add<ArithBinaryPattern<arith::AddIOp>,
ArithBinaryPattern<arith::SubIOp>,
ArithBinaryPattern<arith::MulIOp>,
ArithBinaryPattern<arith::DivUIOp>,
ArithBinaryPattern<arith::DivSIOp>,
ArithBinaryPattern<arith::CeilDivUIOp>,
ArithBinaryPattern<arith::CeilDivSIOp>,
ArithBinaryPattern<arith::FloorDivSIOp>,
ArithBinaryPattern<arith::RemUIOp>,
ArithBinaryPattern<arith::RemSIOp>,
ArithBinaryPattern<arith::AndIOp>,
ArithBinaryPattern<arith::OrIOp>,
ArithBinaryPattern<arith::XOrIOp>,
ArithBinaryPattern<arith::ShLIOp>,
ArithBinaryPattern<arith::ShRUIOp>,
ArithBinaryPattern<arith::ShRSIOp>, // NegFOp
// Floating point
ArithBinaryPattern<arith::AddFOp>,
ArithBinaryPattern<arith::SubFOp>,
// MaxMin
ArithBinaryPattern<arith::MaxFOp>,
ArithBinaryPattern<arith::MaxSIOp>,
ArithBinaryPattern<arith::MaxUIOp>,
ArithBinaryPattern<arith::MinFOp>,
ArithBinaryPattern<arith::MinSIOp>,
ArithBinaryPattern<arith::MinUIOp>,
// Floating point
ArithBinaryPattern<arith::MulFOp>,
ArithBinaryPattern<arith::DivFOp>,
ArithBinaryPattern<arith::RemFOp>
>(typeConverter, context);
}
//