More arith patterns
This commit is contained in:
@@ -11,6 +11,21 @@ using namespace mlir::triton;
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
template<class Op>
|
||||||
|
class ArithBinaryPattern : public OpConversionPattern<Op> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern<Op>::OpConversionPattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
Type retType = this->getTypeConverter()->convertType(op.getType());
|
||||||
|
rewriter.replaceOpWithNewOp<Op>(
|
||||||
|
op.getOperation(), retType, op.getLhs(), op.getRhs()
|
||||||
|
);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
class ConvertArithmeticOp: public ConversionPattern {
|
class ConvertArithmeticOp: public ConversionPattern {
|
||||||
public:
|
public:
|
||||||
ConvertArithmeticOp(TritonGPUTypeConverter &typeConverter, MLIRContext *context)
|
ConvertArithmeticOp(TritonGPUTypeConverter &typeConverter, MLIRContext *context)
|
||||||
@@ -44,7 +59,38 @@ void populateArithmeticPatternsAndLegality(
|
|||||||
}
|
}
|
||||||
);
|
);
|
||||||
// Rewrite rule
|
// Rewrite rule
|
||||||
patterns.add<ConvertArithmeticOp>(typeConverter, context);
|
// patterns.add<ConvertArithmeticOp>(typeConverter, context);
|
||||||
|
patterns.add<ArithBinaryPattern<arith::AddIOp>,
|
||||||
|
ArithBinaryPattern<arith::SubIOp>,
|
||||||
|
ArithBinaryPattern<arith::MulIOp>,
|
||||||
|
ArithBinaryPattern<arith::DivUIOp>,
|
||||||
|
ArithBinaryPattern<arith::DivSIOp>,
|
||||||
|
ArithBinaryPattern<arith::CeilDivUIOp>,
|
||||||
|
ArithBinaryPattern<arith::CeilDivSIOp>,
|
||||||
|
ArithBinaryPattern<arith::FloorDivSIOp>,
|
||||||
|
ArithBinaryPattern<arith::RemUIOp>,
|
||||||
|
ArithBinaryPattern<arith::RemSIOp>,
|
||||||
|
ArithBinaryPattern<arith::AndIOp>,
|
||||||
|
ArithBinaryPattern<arith::OrIOp>,
|
||||||
|
ArithBinaryPattern<arith::XOrIOp>,
|
||||||
|
ArithBinaryPattern<arith::ShLIOp>,
|
||||||
|
ArithBinaryPattern<arith::ShRUIOp>,
|
||||||
|
ArithBinaryPattern<arith::ShRSIOp>, // NegFOp
|
||||||
|
// Floating point
|
||||||
|
ArithBinaryPattern<arith::AddFOp>,
|
||||||
|
ArithBinaryPattern<arith::SubFOp>,
|
||||||
|
// MaxMin
|
||||||
|
ArithBinaryPattern<arith::MaxFOp>,
|
||||||
|
ArithBinaryPattern<arith::MaxSIOp>,
|
||||||
|
ArithBinaryPattern<arith::MaxUIOp>,
|
||||||
|
ArithBinaryPattern<arith::MinFOp>,
|
||||||
|
ArithBinaryPattern<arith::MinSIOp>,
|
||||||
|
ArithBinaryPattern<arith::MinUIOp>,
|
||||||
|
// Floating point
|
||||||
|
ArithBinaryPattern<arith::MulFOp>,
|
||||||
|
ArithBinaryPattern<arith::DivFOp>,
|
||||||
|
ArithBinaryPattern<arith::RemFOp>
|
||||||
|
>(typeConverter, context);
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
|
Reference in New Issue
Block a user