diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.cpp b/lib/Dialect/TritonGPU/Transforms/Combine.cpp index a354993a8..5574b3ffd 100644 --- a/lib/Dialect/TritonGPU/Transforms/Combine.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Combine.cpp @@ -306,6 +306,8 @@ public: auto fwdCvtIt = std::find_if(opIt, fwdEndIt, isCvt); auto bwdCvtIt = std::find_if(bwdBeginIt, opIt, isCvt); + if (!iterArg.value().getType().isa()) + continue; if (fwdCvtIt != fwdEndIt) { auto newFor = tryConvertIterArg(forOp, rewriter, iterArg.index(), (*fwdCvtIt)->getResult(0).getType());