More progress on TritonGPU conversion

This commit is contained in:
Yan Da
2022-05-04 14:54:31 +08:00
parent 3ad7bee35e
commit b9279d2e3b
4 changed files with 48 additions and 26 deletions

View File

@@ -26,15 +26,15 @@ public:
}
};
template<class Op>
class ArithCmpPattern : public OpConversionPattern<Op> {
template<class SrcOp, class DstOp>
class ArithCmpPattern : public OpConversionPattern<SrcOp> {
public:
using OpConversionPattern<Op>::OpConversionPattern;
using OpConversionPattern<SrcOp>::OpConversionPattern;
LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor,
LogicalResult matchAndRewrite(SrcOp op, typename SrcOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type retType = this->getTypeConverter()->convertType(op.getType());
Op res = rewriter.replaceOpWithNewOp<Op>(
DstOp res = rewriter.replaceOpWithNewOp<DstOp>(
op, retType, adaptor.getPredicate(), adaptor.getLhs(), adaptor.getRhs()
);
return success();
@@ -106,8 +106,10 @@ void populateArithmeticPatternsAndLegality(
ArithBinaryPattern<arith::DivFOp>,
ArithBinaryPattern<arith::RemFOp>,
// Cmp
ArithCmpPattern<arith::CmpIOp>,
ArithCmpPattern<arith::CmpFOp>
// ArithCmpPattern<arith::CmpIOp, triton::gpu::CmpIOp>,
// ArithCmpPattern<arith::CmpFOp, triton::gpu::CmpFOp>
ArithCmpPattern<arith::CmpIOp, triton::gpu::CmpIOp>,
ArithCmpPattern<arith::CmpFOp, triton::gpu::CmpFOp>
>(typeConverter, context);
}