[Triton-MLIR][FRONTEND] [BACKEND] fix atomics (#879)
Minor fix to the backend and frontend of atomics; we can now pass 1 test that uses no mask and whose shape is aligned with the CTA size. Co-authored-by: dongdongl <dongdongl@nvidia.com>
This commit is contained in:
@@ -5958,19 +5958,19 @@ struct AtomicRMWOpConversion
|
||||
auto elemsPerThread = getElemsPerThread(val.getType());
|
||||
SmallVector<Value> resultVals(elemsPerThread);
|
||||
for (size_t i = 0; i < elemsPerThread; i += vec) {
|
||||
Value rmvVal = undef(vecTy);
|
||||
Value rmwVal = undef(vecTy);
|
||||
for (int ii = 0; ii < vec; ++ii) {
|
||||
Value iiVal = createIndexAttrConstant(
|
||||
rewriter, loc, getTypeConverter()->getIndexType(), ii);
|
||||
rmvVal = insert_element(vecTy, rmvVal, valElements[i], iiVal);
|
||||
rmwVal = insert_element(vecTy, rmwVal, valElements[i + ii], iiVal);
|
||||
}
|
||||
Value rmwPtr = bitcast(ptrElements[i], ptr_ty(valTy.getElementType()));
|
||||
Value rmwPtr = ptrElements[i];
|
||||
std::string sTy;
|
||||
PTXBuilder ptxBuilder;
|
||||
|
||||
auto *dstOpr = ptxBuilder.newOperand("=r");
|
||||
auto *ptrOpr = ptxBuilder.newAddrOperand(rmwPtr, "r");
|
||||
auto *valOpr = ptxBuilder.newOperand(rmvVal, "r");
|
||||
auto *ptrOpr = ptxBuilder.newAddrOperand(rmwPtr, "l");
|
||||
auto *valOpr = ptxBuilder.newOperand(rmwVal, "r");
|
||||
|
||||
auto &atom = ptxBuilder.create<>("atom")->global().o("gpu");
|
||||
auto rmwOp = stringifyRMWOp(atomicRmwAttr).str();
|
||||
@@ -6012,12 +6012,13 @@ struct AtomicRMWOpConversion
|
||||
return failure();
|
||||
}
|
||||
atom.o(rmwOp).o(sTy);
|
||||
|
||||
atom(dstOpr, ptrOpr, valOpr);
|
||||
auto ret = ptxBuilder.launch(rewriter, loc, valueElemTy, false);
|
||||
//TODO:[dongdongl] actual mask support
|
||||
Value pred = int_val(1, 1);
|
||||
atom(dstOpr, ptrOpr, valOpr).predicate(pred);
|
||||
auto ret = ptxBuilder.launch(rewriter, loc, valueElemTy);
|
||||
for (int ii = 0; ii < vec; ++ii) {
|
||||
resultVals[i * vec + ii] =
|
||||
vec == 1 ? ret : extract_element(vecTy, ret, idx_val(ii));
|
||||
vec == 1 ? ret : extract_element(valueElemTy, ret, idx_val(ii));
|
||||
}
|
||||
}
|
||||
Type structTy = getTypeConverter()->convertType(valueTy);
|
||||
|
Reference in New Issue
Block a user