diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp index 788d20eaa..b580f5971 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp @@ -242,10 +242,17 @@ struct TritonLoadPattern : public OpConversionPattern { LogicalResult matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp( - op, typeConverter->convertType(op.getType()), adaptor.ptr(), - adaptor.mask(), adaptor.other(), adaptor.cache(), adaptor.evict(), - adaptor.isVolatile()); + if (op.getNumOperands() == 2) { // ptr & mask + rewriter.replaceOpWithNewOp( + op, typeConverter->convertType(op.getType()), adaptor.ptr(), + adaptor.getOperands()[1], adaptor.other(), adaptor.cache(), + adaptor.evict(), adaptor.isVolatile()); + } else { + rewriter.replaceOpWithNewOp( + op, typeConverter->convertType(op.getType()), adaptor.ptr(), + adaptor.mask(), adaptor.other(), adaptor.cache(), adaptor.evict(), + adaptor.isVolatile()); + } return success(); } };