DotOp conversion
This commit is contained in:
@@ -144,27 +144,14 @@ struct TritonMakeRangePattern : public OpConversionPattern<triton::MakeRangeOp>
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct TritonBroadcastPattern : public OpConversionPattern<triton::BroadcastOp> {
|
struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
|
||||||
using OpConversionPattern<triton::BroadcastOp>::OpConversionPattern;
|
using OpConversionPattern<triton::DotOp>::OpConversionPattern;
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(triton::BroadcastOp op, OpAdaptor adaptor,
|
LogicalResult matchAndRewrite(triton::DotOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
Type retType = getTypeConverter()->convertType(op.getType());
|
Type retType = getTypeConverter()->convertType(op.getType());
|
||||||
rewriter.replaceOpWithNewOp<triton::BroadcastOp>(
|
rewriter.replaceOpWithNewOp<triton::DotOp>(
|
||||||
op, retType, adaptor.src()
|
op, retType, adaptor.a(), adaptor.b(), adaptor.c(), adaptor.allowTF32()
|
||||||
);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct TritonGEPPattern : public OpConversionPattern<triton::GEPOp> {
|
|
||||||
using OpConversionPattern<triton::GEPOp>::OpConversionPattern;
|
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(triton::GEPOp op, OpAdaptor adaptor,
|
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
|
||||||
Type retType = getTypeConverter()->convertType(op.getType());
|
|
||||||
rewriter.replaceOpWithNewOp<triton::GEPOp>(
|
|
||||||
op, retType, adaptor.getOperands()
|
|
||||||
);
|
);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
@@ -197,13 +184,29 @@ struct TritonStorePattern : public OpConversionPattern<triton::StoreOp> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <class Op>
|
||||||
|
struct TritonGenericPattern : public OpConversionPattern<Op> {
|
||||||
|
using OpConversionPattern<Op>::OpConversionPattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
Type retType = this->getTypeConverter()->convertType(op.getType());
|
||||||
|
rewriter.replaceOpWithNewOp<Op>(
|
||||||
|
op, retType, adaptor.getOperands()
|
||||||
|
);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
void populateTritonPatterns(
|
void populateTritonPatterns(
|
||||||
TritonGPUTypeConverter& typeConverter, RewritePatternSet &patterns
|
TritonGPUTypeConverter& typeConverter, RewritePatternSet &patterns
|
||||||
) {
|
) {
|
||||||
MLIRContext *context = patterns.getContext();
|
MLIRContext *context = patterns.getContext();
|
||||||
patterns.add<TritonMakeRangePattern,
|
patterns.add<TritonGenericPattern<triton::ReshapeOp>,
|
||||||
TritonBroadcastPattern,
|
TritonGenericPattern<triton::BroadcastOp>,
|
||||||
TritonGEPPattern,
|
TritonGenericPattern<triton::GEPOp>,
|
||||||
|
TritonMakeRangePattern,
|
||||||
|
TritonDotPattern,
|
||||||
TritonLoadPattern,
|
TritonLoadPattern,
|
||||||
TritonStorePattern
|
TritonStorePattern
|
||||||
>(typeConverter, context);
|
>(typeConverter, context);
|
||||||
|
Reference in New Issue
Block a user