[TRITON-MLIR][BACKEND] Atomic support mask (#889)
Co-authored-by: dongdongl <dongdongl@nvidia.com>
@@ -5921,11 +5921,13 @@ struct AtomicRMWOpConversion

     Value llPtr = adaptor.ptr();
     Value llVal = adaptor.val();
+    Value llMask = adaptor.mask();

     auto valElements = getElementsFromStruct(loc, llVal, rewriter);
     auto ptrElements = getElementsFromStruct(loc, llPtr, rewriter);
+    auto maskElements = getElementsFromStruct(loc, llMask, rewriter);

-    // TODO[dongdongl]: Support mask and scalar
+    // TODO[dongdongl]: Support scalar

     auto valueTy = op.getResult().getType().dyn_cast<RankedTensorType>();
     if (!valueTy)
@@ -5940,6 +5942,14 @@ struct AtomicRMWOpConversion

     auto vecTy = vec_ty(valueElemTy, vec);
     auto elemsPerThread = getElemsPerThread(val.getType());
+    // mask
+    Value mask = int_val(1, 1);
+    auto shape = valueTy.getShape();
+    auto numElements = product(shape);
+    auto tid = tid_val();
+    mask = and_(mask, icmp_slt(mul(tid, i32_val(elemsPerThread)),
+                               i32_val(numElements)));
+
     SmallVector<Value> resultVals(elemsPerThread);
     for (size_t i = 0; i < elemsPerThread; i += vec) {
       Value rmwVal = undef(vecTy);
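The lines added here build a per-thread bounds guard: a thread's first element index is tid * elemsPerThread, so the predicate tid * elemsPerThread < numElements is false exactly for threads that own no element of the tensor. A minimal Python model of that guard (names mirror the C++ values; this is an illustration, not the generated IR):

    # Model of the thread-bound guard added above. A thread whose first
    # element index is already past the end of the tensor owns no real
    # data, so every atomic it would issue gets predicated off.
    def thread_in_bounds(tid, elems_per_thread, num_elements):
        return tid * elems_per_thread < num_elements

    # With the test's 2x8 tensor (16 elements) and, say, 4 elements per
    # thread, threads 0-3 carry data and threads 4 and up are masked off.
    assert thread_in_bounds(3, 4, 16)
    assert not thread_in_bounds(4, 4, 16)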
@@ -5949,6 +5959,8 @@ struct AtomicRMWOpConversion
         rmwVal = insert_element(vecTy, rmwVal, valElements[i + ii], iiVal);
       }
       Value rmwPtr = ptrElements[i];
+      Value rmwMask = maskElements[i];
+      rmwMask = and_(rmwMask, mask);
       std::string sTy;
       PTXBuilder ptxBuilder;

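Inside the element loop the guard is then combined with the user-supplied mask element, so an atomic fires only when both the user's mask bit and the thread-bound check hold. A one-line model of `rmwMask = and_(rmwMask, mask)` (illustrative names):

    # Mirrors the AND above: the element's user mask bit and the thread's
    # in-bounds guard must both be true for the atomic to execute.
    def combined_mask(user_mask_bit, in_bounds):
        return user_mask_bit and in_bounds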
@@ -5996,9 +6008,7 @@ struct AtomicRMWOpConversion
         return failure();
       }
       atom.o(rmwOp).o(sTy);
-      //TODO:[dongdongl] actual mask support
-      Value pred = int_val(1, 1);
-      atom(dstOpr, ptrOpr, valOpr).predicate(pred);
+      atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
       auto ret = ptxBuilder.launch(rewriter, loc, valueElemTy);
       for (int ii = 0; ii < vec; ++ii) {
         resultVals[i * vec + ii] =
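The constant-true predicate is gone: `.predicate(rmwMask)` attaches the combined mask as a PTX `@` guard on the `atom` instruction, so masked-off lanes issue no memory transaction at all. A behavioural sketch in Python (the return value for a masked-off lane is a stand-in; on hardware the destination register is simply left unwritten):

    # Behavioural model of the predicated atomic: when the predicate is
    # false the lane touches no memory, matching `@%p atom...` semantics.
    def predicated_atomic_add(memory, addr, val, pred):
        if not pred:
            return None  # masked-off lane: no update, no defined result
        old = memory.get(addr, 0.0)
        memory[addr] = old + val
        return old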
@@ -672,7 +672,7 @@ def test_tuples():
     # np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)

 def test_tensor_atomic_rmw_add_elementwise(device="cuda"):
-    shape0, shape1 = 16, 16
+    shape0, shape1 = 2, 8
     @triton.jit
     def kernel(Z, X, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
         off0 = tl.arange(0, SHAPE0)
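The test now uses a 2x8 tensor, small enough that it no longer fills a whole warp, which forces the new thread-bound guard to actually mask off trailing threads. For reference, this is roughly how a masked tensor atomic is written at the Triton Python level (a hedged sketch, not the test's exact kernel; `tl.atomic_add` accepts an optional `mask` argument):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def masked_atomic_add(Z, X, N, BLOCK: tl.constexpr):
        offs = tl.arange(0, BLOCK)
        x = tl.load(X + offs, mask=offs < N, other=0.0)
        # Lanes with offs >= N perform no atomic update.
        tl.atomic_add(Z + offs, x, mask=offs < N)

    # Illustrative launch: BLOCK exceeds N, the mask keeps the tail inert.
    z = torch.zeros(5, device="cuda", dtype=torch.float32)
    x = torch.ones(5, device="cuda", dtype=torch.float32)
    masked_atomic_add[(1,)](z, x, 5, BLOCK=8)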