[TRITON-MLIR][BACKEND] Atomic support mask (#889)

Co-authored-by: dongdongl <dongdongl@nvidia.com>
This commit is contained in:
donproc
2022-11-19 19:57:19 +08:00
committed by GitHub
parent dab4855bdf
commit afaf59b0c9
2 changed files with 15 additions and 5 deletions

View File

@@ -5921,11 +5921,13 @@ struct AtomicRMWOpConversion
Value llPtr = adaptor.ptr();
Value llVal = adaptor.val();
Value llMask = adaptor.mask();
auto valElements = getElementsFromStruct(loc, llVal, rewriter);
auto ptrElements = getElementsFromStruct(loc, llPtr, rewriter);
auto maskElements = getElementsFromStruct(loc, llMask, rewriter);
// TODO[dongdongl]: Support mask and scalar
// TODO[dongdongl]: Support scalar
auto valueTy = op.getResult().getType().dyn_cast<RankedTensorType>();
if (!valueTy)
@@ -5940,6 +5942,14 @@ struct AtomicRMWOpConversion
auto vecTy = vec_ty(valueElemTy, vec);
auto elemsPerThread = getElemsPerThread(val.getType());
// mask
Value mask = int_val(1, 1);
auto shape = valueTy.getShape();
auto numElements = product(shape);
auto tid = tid_val();
mask = and_(mask, icmp_slt(mul(tid, i32_val(elemsPerThread)),
i32_val(numElements)));
SmallVector<Value> resultVals(elemsPerThread);
for (size_t i = 0; i < elemsPerThread; i += vec) {
Value rmwVal = undef(vecTy);
@@ -5949,6 +5959,8 @@ struct AtomicRMWOpConversion
rmwVal = insert_element(vecTy, rmwVal, valElements[i + ii], iiVal);
}
Value rmwPtr = ptrElements[i];
Value rmwMask = maskElements[i];
rmwMask = and_(rmwMask, mask);
std::string sTy;
PTXBuilder ptxBuilder;
@@ -5996,9 +6008,7 @@ struct AtomicRMWOpConversion
return failure();
}
atom.o(rmwOp).o(sTy);
//TODO:[dongdongl] actual mask support
Value pred = int_val(1, 1);
atom(dstOpr, ptrOpr, valOpr).predicate(pred);
atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
auto ret = ptxBuilder.launch(rewriter, loc, valueElemTy);
for (int ii = 0; ii < vec; ++ii) {
resultVals[i * vec + ii] =