DotOp conversion
This commit is contained in:
@@ -144,27 +144,14 @@ struct TritonMakeRangePattern : public OpConversionPattern<triton::MakeRangeOp>
|
||||
}
|
||||
};
|
||||
|
||||
struct TritonBroadcastPattern : public OpConversionPattern<triton::BroadcastOp> {
|
||||
using OpConversionPattern<triton::BroadcastOp>::OpConversionPattern;
|
||||
struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
|
||||
using OpConversionPattern<triton::DotOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(triton::BroadcastOp op, OpAdaptor adaptor,
|
||||
LogicalResult matchAndRewrite(triton::DotOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type retType = getTypeConverter()->convertType(op.getType());
|
||||
rewriter.replaceOpWithNewOp<triton::BroadcastOp>(
|
||||
op, retType, adaptor.src()
|
||||
);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct TritonGEPPattern : public OpConversionPattern<triton::GEPOp> {
|
||||
using OpConversionPattern<triton::GEPOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(triton::GEPOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type retType = getTypeConverter()->convertType(op.getType());
|
||||
rewriter.replaceOpWithNewOp<triton::GEPOp>(
|
||||
op, retType, adaptor.getOperands()
|
||||
rewriter.replaceOpWithNewOp<triton::DotOp>(
|
||||
op, retType, adaptor.a(), adaptor.b(), adaptor.c(), adaptor.allowTF32()
|
||||
);
|
||||
return success();
|
||||
}
|
||||
@@ -197,13 +184,29 @@ struct TritonStorePattern : public OpConversionPattern<triton::StoreOp> {
|
||||
}
|
||||
};
|
||||
|
||||
template <class Op>
|
||||
struct TritonGenericPattern : public OpConversionPattern<Op> {
|
||||
using OpConversionPattern<Op>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type retType = this->getTypeConverter()->convertType(op.getType());
|
||||
rewriter.replaceOpWithNewOp<Op>(
|
||||
op, retType, adaptor.getOperands()
|
||||
);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void populateTritonPatterns(
|
||||
TritonGPUTypeConverter& typeConverter, RewritePatternSet &patterns
|
||||
) {
|
||||
MLIRContext *context = patterns.getContext();
|
||||
patterns.add<TritonMakeRangePattern,
|
||||
TritonBroadcastPattern,
|
||||
TritonGEPPattern,
|
||||
patterns.add<TritonGenericPattern<triton::ReshapeOp>,
|
||||
TritonGenericPattern<triton::BroadcastOp>,
|
||||
TritonGenericPattern<triton::GEPOp>,
|
||||
TritonMakeRangePattern,
|
||||
TritonDotPattern,
|
||||
TritonLoadPattern,
|
||||
TritonStorePattern
|
||||
>(typeConverter, context);
|
||||
|
Reference in New Issue
Block a user