More progress on TritonGPU conversion
This commit is contained in:
@@ -26,15 +26,15 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
template<class Op>
|
||||
class ArithCmpPattern : public OpConversionPattern<Op> {
|
||||
template<class SrcOp, class DstOp>
|
||||
class ArithCmpPattern : public OpConversionPattern<SrcOp> {
|
||||
public:
|
||||
using OpConversionPattern<Op>::OpConversionPattern;
|
||||
using OpConversionPattern<SrcOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor,
|
||||
LogicalResult matchAndRewrite(SrcOp op, typename SrcOp::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type retType = this->getTypeConverter()->convertType(op.getType());
|
||||
Op res = rewriter.replaceOpWithNewOp<Op>(
|
||||
DstOp res = rewriter.replaceOpWithNewOp<DstOp>(
|
||||
op, retType, adaptor.getPredicate(), adaptor.getLhs(), adaptor.getRhs()
|
||||
);
|
||||
return success();
|
||||
@@ -106,8 +106,10 @@ void populateArithmeticPatternsAndLegality(
|
||||
ArithBinaryPattern<arith::DivFOp>,
|
||||
ArithBinaryPattern<arith::RemFOp>,
|
||||
// Cmp
|
||||
ArithCmpPattern<arith::CmpIOp>,
|
||||
ArithCmpPattern<arith::CmpFOp>
|
||||
// ArithCmpPattern<arith::CmpIOp, triton::gpu::CmpIOp>,
|
||||
// ArithCmpPattern<arith::CmpFOp, triton::gpu::CmpFOp>
|
||||
ArithCmpPattern<arith::CmpIOp, triton::gpu::CmpIOp>,
|
||||
ArithCmpPattern<arith::CmpFOp, triton::gpu::CmpFOp>
|
||||
>(typeConverter, context);
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user