DotOp conversion

This commit is contained in:
Yan Da
2022-05-04 15:56:24 +08:00
parent 2d281cbc0a
commit a96fe07e1c

View File

@@ -144,27 +144,14 @@ struct TritonMakeRangePattern : public OpConversionPattern<triton::MakeRangeOp>
} }
}; };
struct TritonBroadcastPattern : public OpConversionPattern<triton::BroadcastOp> { struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
using OpConversionPattern<triton::BroadcastOp>::OpConversionPattern; using OpConversionPattern<triton::DotOp>::OpConversionPattern;
LogicalResult matchAndRewrite(triton::BroadcastOp op, OpAdaptor adaptor, LogicalResult matchAndRewrite(triton::DotOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
Type retType = getTypeConverter()->convertType(op.getType()); Type retType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<triton::BroadcastOp>( rewriter.replaceOpWithNewOp<triton::DotOp>(
op, retType, adaptor.src() op, retType, adaptor.a(), adaptor.b(), adaptor.c(), adaptor.allowTF32()
);
return success();
}
};
struct TritonGEPPattern : public OpConversionPattern<triton::GEPOp> {
using OpConversionPattern<triton::GEPOp>::OpConversionPattern;
LogicalResult matchAndRewrite(triton::GEPOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type retType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<triton::GEPOp>(
op, retType, adaptor.getOperands()
); );
return success(); return success();
} }
@@ -197,13 +184,29 @@ struct TritonStorePattern : public OpConversionPattern<triton::StoreOp> {
} }
}; };
template <class Op>
struct TritonGenericPattern : public OpConversionPattern<Op> {
using OpConversionPattern<Op>::OpConversionPattern;
LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type retType = this->getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<Op>(
op, retType, adaptor.getOperands()
);
return success();
}
};
void populateTritonPatterns( void populateTritonPatterns(
TritonGPUTypeConverter& typeConverter, RewritePatternSet &patterns TritonGPUTypeConverter& typeConverter, RewritePatternSet &patterns
) { ) {
MLIRContext *context = patterns.getContext(); MLIRContext *context = patterns.getContext();
patterns.add<TritonMakeRangePattern, patterns.add<TritonGenericPattern<triton::ReshapeOp>,
TritonBroadcastPattern, TritonGenericPattern<triton::BroadcastOp>,
TritonGEPPattern, TritonGenericPattern<triton::GEPOp>,
TritonMakeRangePattern,
TritonDotPattern,
TritonLoadPattern, TritonLoadPattern,
TritonStorePattern TritonStorePattern
>(typeConverter, context); >(typeConverter, context);