From c7cf9c6a323baa83f41a6ad2c26364717e1af8b5 Mon Sep 17 00:00:00 2001 From: Dongdong Li Date: Thu, 8 Dec 2022 20:03:40 +0800 Subject: [PATCH] [TRITON-MLIR][BACKEND]fix atomic_rmw for vector (#966) Co-authored-by: dongdongl --- lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index 91485c3e5..fd4dfcf28 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -4420,6 +4420,7 @@ struct AtomicRMWOpConversion rewriter, loc, getTypeConverter()->getIndexType(), ii); rmwVal = insert_element(vecTy, rmwVal, valElements[i + ii], iiVal); } + Value rmwPtr = ptrElements[i]; Value rmwMask = maskElements[i]; rmwMask = and_(rmwMask, mask); @@ -4477,9 +4478,10 @@ struct AtomicRMWOpConversion atom.o(rmwOp).o(sTy); if (valueTy) { atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask); - auto ret = ptxBuilderAtomicRMW.launch(rewriter, loc, valueElemTy); + auto retType = vec == 1 ? valueElemTy : vecTy; + auto ret = ptxBuilderAtomicRMW.launch(rewriter, loc, retType); for (int ii = 0; ii < vec; ++ii) { - resultVals[i * vec + ii] = + resultVals[i + ii] = vec == 1 ? ret : extract_element(valueElemTy, ret, idx_val(ii)); } } else {