More conversion patterns

This commit is contained in:
Yan Da
2022-05-04 12:50:02 +08:00
parent 5f08e2fdae
commit 3ad7bee35e
4 changed files with 69 additions and 20 deletions

View File

@@ -19,8 +19,23 @@ public:
LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type retType = this->getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<Op>(
op.getOperation(), retType, op.getLhs(), op.getRhs()
Op res = rewriter.replaceOpWithNewOp<Op>(
op, retType, adaptor.getOperands()
);
return success();
}
};
template<class Op>
class ArithCmpPattern : public OpConversionPattern<Op> {
public:
using OpConversionPattern<Op>::OpConversionPattern;
LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type retType = this->getTypeConverter()->convertType(op.getType());
Op res = rewriter.replaceOpWithNewOp<Op>(
op, retType, adaptor.getPredicate(), adaptor.getLhs(), adaptor.getRhs()
);
return success();
}
@@ -51,13 +66,13 @@ void populateArithmeticPatternsAndLegality(
// non-null encoding
// --------------
MLIRContext *context = patterns.getContext();
// Legality rule
target.addDynamicallyLegalDialect<arith::ArithmeticDialect>(
// TODO: check above rule here
[](Operation *op){
return true;
}
);
// // Legality rule
// target.addDynamicallyLegalDialect<arith::ArithmeticDialect>(
// // TODO: check above rule here
// [](Operation *op){
// return true;
// }
// );
// Rewrite rule
// patterns.add<ConvertArithmeticOp>(typeConverter, context);
patterns.add<ArithBinaryPattern<arith::AddIOp>,
@@ -89,7 +104,10 @@ void populateArithmeticPatternsAndLegality(
// Floating point
ArithBinaryPattern<arith::MulFOp>,
ArithBinaryPattern<arith::DivFOp>,
ArithBinaryPattern<arith::RemFOp>
ArithBinaryPattern<arith::RemFOp>,
// Cmp
ArithCmpPattern<arith::CmpIOp>,
ArithCmpPattern<arith::CmpFOp>
>(typeConverter, context);
}
@@ -103,9 +121,8 @@ struct TritonMakeRangePattern : public OpConversionPattern<triton::MakeRangeOp>
LogicalResult matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type retType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<triton::MakeRangeOp>(
op.getOperation(), retType, op.start(), op.end()
op, retType, adaptor.start(), adaptor.end()
);
return success();
}
@@ -118,7 +135,7 @@ struct TritonBroadcastPattern : public OpConversionPattern<triton::BroadcastOp>
ConversionPatternRewriter &rewriter) const override {
Type retType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<triton::BroadcastOp>(
op.getOperation(), retType, op.src()
op, retType, adaptor.src()
);
return success();
}
@@ -131,7 +148,7 @@ struct TritonGEPPattern : public OpConversionPattern<triton::GEPOp> {
ConversionPatternRewriter &rewriter) const override {
Type retType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<triton::GEPOp>(
op.getOperation(), retType, op.ptr(), op.offset()
op, retType, adaptor.getOperands()
);
return success();
}
@@ -144,8 +161,21 @@ struct TritonLoadPattern : public OpConversionPattern<triton::LoadOp> {
ConversionPatternRewriter &rewriter) const override {
Type retType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<triton::LoadOp>(
op.getOperation(), retType,
op.ptr(), op.mask(), op.other(), op.cache(), op.evict(), op.isVolatile()
op, retType,
adaptor.ptr(), adaptor.mask(), adaptor.other(),
adaptor.cache(), adaptor.evict(), adaptor.isVolatile()
);
return success();
}
};
struct TritonStorePattern : public OpConversionPattern<triton::StoreOp> {
using OpConversionPattern<triton::StoreOp>::OpConversionPattern;
LogicalResult matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<triton::StoreOp>(
op, adaptor.ptr(), adaptor.value(), adaptor.mask()
);
return success();
}
@@ -158,7 +188,8 @@ void populateTritonPatterns(
patterns.add<TritonMakeRangePattern,
TritonBroadcastPattern,
TritonGEPPattern,
TritonLoadPattern
TritonLoadPattern,
TritonStorePattern
>(typeConverter, context);
}