DotOp conversion

This commit is contained in:
Yan Da
2022-05-04 15:56:24 +08:00
parent 2d281cbc0a
commit a96fe07e1c

View File

@@ -144,27 +144,14 @@ struct TritonMakeRangePattern : public OpConversionPattern<triton::MakeRangeOp>
}
};
struct TritonBroadcastPattern : public OpConversionPattern<triton::BroadcastOp> {
using OpConversionPattern<triton::BroadcastOp>::OpConversionPattern;
struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
using OpConversionPattern<triton::DotOp>::OpConversionPattern;
LogicalResult matchAndRewrite(triton::BroadcastOp op, OpAdaptor adaptor,
LogicalResult matchAndRewrite(triton::DotOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type retType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<triton::BroadcastOp>(
op, retType, adaptor.src()
);
return success();
}
};
struct TritonGEPPattern : public OpConversionPattern<triton::GEPOp> {
using OpConversionPattern<triton::GEPOp>::OpConversionPattern;
LogicalResult matchAndRewrite(triton::GEPOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type retType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<triton::GEPOp>(
op, retType, adaptor.getOperands()
rewriter.replaceOpWithNewOp<triton::DotOp>(
op, retType, adaptor.a(), adaptor.b(), adaptor.c(), adaptor.allowTF32()
);
return success();
}
@@ -197,13 +184,29 @@ struct TritonStorePattern : public OpConversionPattern<triton::StoreOp> {
}
};
template <class Op>
struct TritonGenericPattern : public OpConversionPattern<Op> {
using OpConversionPattern<Op>::OpConversionPattern;
LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type retType = this->getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<Op>(
op, retType, adaptor.getOperands()
);
return success();
}
};
void populateTritonPatterns(
TritonGPUTypeConverter& typeConverter, RewritePatternSet &patterns
) {
MLIRContext *context = patterns.getContext();
patterns.add<TritonMakeRangePattern,
TritonBroadcastPattern,
TritonGEPPattern,
patterns.add<TritonGenericPattern<triton::ReshapeOp>,
TritonGenericPattern<triton::BroadcastOp>,
TritonGenericPattern<triton::GEPOp>,
TritonMakeRangePattern,
TritonDotPattern,
TritonLoadPattern,
TritonStorePattern
>(typeConverter, context);