[TRITON-MLIR][BACKEND] Atomic support mask (#889)

Co-authored-by: dongdongl <dongdongl@nvidia.com>

Author: donproc
Date: 2022-11-19 19:57:19 +08:00
Committed by: GitHub
Parent: dab4855bdf
Commit: afaf59b0c9

2 changed files with 15 additions and 5 deletions
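
For context: a minimal sketch of the user-facing behavior this backend change supports. The kernel below is illustrative, not part of the commit; it assumes the standard tl.atomic_add(pointer, val, mask=...) signature.

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def masked_atomic_add(Z, X, N, BLOCK: tl.constexpr):
        offs = tl.arange(0, BLOCK)
        mask = offs < N  # lanes past N must not perform the atomic
        x = tl.load(X + offs, mask=mask, other=0.0)
        tl.atomic_add(Z + offs, x, mask=mask)

    x = torch.randn(10, device="cuda")
    z = torch.zeros(10, device="cuda")
    masked_atomic_add[(1,)](z, x, 10, BLOCK=16)  # BLOCK padded past N
    torch.testing.assert_close(z, x)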


@@ -5921,11 +5921,13 @@ struct AtomicRMWOpConversion
     Value llPtr = adaptor.ptr();
     Value llVal = adaptor.val();
+    Value llMask = adaptor.mask();
     auto valElements = getElementsFromStruct(loc, llVal, rewriter);
     auto ptrElements = getElementsFromStruct(loc, llPtr, rewriter);
+    auto maskElements = getElementsFromStruct(loc, llMask, rewriter);
-    // TODO[dongdongl]: Support mask and scalar
+    // TODO[dongdongl]: Support scalar
     auto valueTy = op.getResult().getType().dyn_cast<RankedTensorType>();
     if (!valueTy)
@@ -5940,6 +5942,14 @@ struct AtomicRMWOpConversion
     auto vecTy = vec_ty(valueElemTy, vec);
     auto elemsPerThread = getElemsPerThread(val.getType());
+    // mask
+    Value mask = int_val(1, 1);
+    auto shape = valueTy.getShape();
+    auto numElements = product(shape);
+    auto tid = tid_val();
+    mask = and_(mask, icmp_slt(mul(tid, i32_val(elemsPerThread)),
+                               i32_val(numElements)));
     SmallVector<Value> resultVals(elemsPerThread);
     for (size_t i = 0; i < elemsPerThread; i += vec) {
       Value rmwVal = undef(vecTy);
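
The added block builds a per-thread bounds guard: a thread whose first element index, tid * elemsPerThread, already lies past numElements is masked off entirely. A numpy model of that predicate (names are mine, not from the code):

    import numpy as np

    def thread_bounds_mask(num_threads, elems_per_thread, num_elements):
        # One predicate bit per thread: true iff the thread's first
        # element index still falls inside the tensor.
        tid = np.arange(num_threads)
        return tid * elems_per_thread < num_elements

    # 8 threads x 4 elements each over a 20-element tensor:
    # threads 0-4 start in range, threads 5-7 are fully out of bounds.
    print(thread_bounds_mask(8, 4, 20))
    # [ True  True  True  True  True False False False]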
@@ -5949,6 +5959,8 @@ struct AtomicRMWOpConversion
         rmwVal = insert_element(vecTy, rmwVal, valElements[i + ii], iiVal);
       }
       Value rmwPtr = ptrElements[i];
+      Value rmwMask = maskElements[i];
+      rmwMask = and_(rmwMask, mask);
       std::string sTy;
       PTXBuilder ptxBuilder;
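
Inside the loop the predicate is taken per vector: the user's mask element at the vector's base index i, AND-ed with the thread bounds guard. In scalar terms (a sketch, not the builder API):

    def vector_predicates(mask_elements, vec, bounds_ok):
        # One predicate per vector of `vec` elements: the mask bit at
        # the vector's base index, AND-ed with the bounds guard.
        return [mask_elements[i] and bounds_ok
                for i in range(0, len(mask_elements), vec)]

    print(vector_predicates([True, True, False, False], vec=2, bounds_ok=True))
    # [True, False]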
@@ -5996,9 +6008,7 @@ struct AtomicRMWOpConversion
         return failure();
       }
       atom.o(rmwOp).o(sTy);
-      // TODO:[dongdongl] actual mask support
-      Value pred = int_val(1, 1);
-      atom(dstOpr, ptrOpr, valOpr).predicate(pred);
+      atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
       auto ret = ptxBuilder.launch(rewriter, loc, valueElemTy);
       for (int ii = 0; ii < vec; ++ii) {
         resultVals[i * vec + ii] =
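
With the predicate attached to the atom instruction, masked-off lanes skip the atomic entirely instead of issuing it with a dummy value. A numpy-level model of those semantics (the helper name is hypothetical):

    import numpy as np

    def predicated_atomic_add(dst, idx, vals, pred):
        # Lanes with pred == False leave dst untouched, like an
        # `@%p atom.add` whose guard predicate is clear.
        np.add.at(dst, idx[pred], vals[pred])

    z = np.zeros(4, dtype=np.float32)
    predicated_atomic_add(z, np.arange(4), np.ones(4, dtype=np.float32),
                          np.array([True, True, False, True]))
    print(z)  # [1. 1. 0. 1.]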


@@ -672,7 +672,7 @@ def test_tuples():
     # np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)
 def test_tensor_atomic_rmw_add_elementwise(device="cuda"):
-    shape0, shape1 = 16, 16
+    shape0, shape1 = 2, 8
     @triton.jit
     def kernel(Z, X, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
         off0 = tl.arange(0, SHAPE0)
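
The hunk cuts off inside the kernel. A plausible completion, offered only as a sketch: everything past the tl.arange lines is my assumption, not shown in the diff.

    import triton
    import triton.language as tl

    @triton.jit
    def kernel(Z, X, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
        off0 = tl.arange(0, SHAPE0)
        off1 = tl.arange(0, SHAPE1)
        # Assumed continuation: compute 2-D offsets, load a tile
        # from X, and atomically accumulate it into Z.
        offs = off0[:, None] * SHAPE1 + off1[None, :]
        x = tl.load(X + offs)
        tl.atomic_add(Z + offs, x)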