[TRITON-MLIR][BACKEND]fix atomic_rmw for vector (#966)
Co-authored-by: dongdongl <dongdongl@nvidia.com>
This commit is contained in:
@@ -4420,6 +4420,7 @@ struct AtomicRMWOpConversion
|
|||||||
rewriter, loc, getTypeConverter()->getIndexType(), ii);
|
rewriter, loc, getTypeConverter()->getIndexType(), ii);
|
||||||
rmwVal = insert_element(vecTy, rmwVal, valElements[i + ii], iiVal);
|
rmwVal = insert_element(vecTy, rmwVal, valElements[i + ii], iiVal);
|
||||||
}
|
}
|
||||||
|
|
||||||
Value rmwPtr = ptrElements[i];
|
Value rmwPtr = ptrElements[i];
|
||||||
Value rmwMask = maskElements[i];
|
Value rmwMask = maskElements[i];
|
||||||
rmwMask = and_(rmwMask, mask);
|
rmwMask = and_(rmwMask, mask);
|
||||||
@@ -4477,9 +4478,10 @@ struct AtomicRMWOpConversion
|
|||||||
atom.o(rmwOp).o(sTy);
|
atom.o(rmwOp).o(sTy);
|
||||||
if (valueTy) {
|
if (valueTy) {
|
||||||
atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
|
atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
|
||||||
auto ret = ptxBuilderAtomicRMW.launch(rewriter, loc, valueElemTy);
|
auto retType = vec == 1 ? valueElemTy : vecTy;
|
||||||
|
auto ret = ptxBuilderAtomicRMW.launch(rewriter, loc, retType);
|
||||||
for (int ii = 0; ii < vec; ++ii) {
|
for (int ii = 0; ii < vec; ++ii) {
|
||||||
resultVals[i * vec + ii] =
|
resultVals[i + ii] =
|
||||||
vec == 1 ? ret : extract_element(valueElemTy, ret, idx_val(ii));
|
vec == 1 ? ret : extract_element(valueElemTy, ret, idx_val(ii));
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
Reference in New Issue
Block a user