diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index 91485c3e5..fd4dfcf28 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -4420,6 +4420,7 @@ struct AtomicRMWOpConversion rewriter, loc, getTypeConverter()->getIndexType(), ii); rmwVal = insert_element(vecTy, rmwVal, valElements[i + ii], iiVal); } + Value rmwPtr = ptrElements[i]; Value rmwMask = maskElements[i]; rmwMask = and_(rmwMask, mask); @@ -4477,9 +4478,10 @@ struct AtomicRMWOpConversion atom.o(rmwOp).o(sTy); if (valueTy) { atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask); - auto ret = ptxBuilderAtomicRMW.launch(rewriter, loc, valueElemTy); + auto retType = vec == 1 ? valueElemTy : vecTy; + auto ret = ptxBuilderAtomicRMW.launch(rewriter, loc, retType); for (int ii = 0; ii < vec; ++ii) { - resultVals[i * vec + ii] = + resultVals[i + ii] = vec == 1 ? ret : extract_element(valueElemTy, ret, idx_val(ii)); } } else {