From afaf59b0c9bbab3d9a3b9cc71c028b6fcb0e04a4 Mon Sep 17 00:00:00 2001
From: donproc
Date: Sat, 19 Nov 2022 19:57:19 +0800
Subject: [PATCH] [TRITON-MLIR][BACKEND] Atomic support mask (#889)

Co-authored-by: dongdongl
---
 .../TritonGPUToLLVM/TritonGPUToLLVM.cpp | 18 ++++++++++++++----
 python/tests/test_core.py              |  2 +-
 2 files changed, 15 insertions(+), 5 deletions(-)

diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index 137512c69..5804c06e3 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -5921,11 +5921,13 @@ struct AtomicRMWOpConversion
     Value llPtr = adaptor.ptr();
     Value llVal = adaptor.val();
+    Value llMask = adaptor.mask();
 
     auto valElements = getElementsFromStruct(loc, llVal, rewriter);
     auto ptrElements = getElementsFromStruct(loc, llPtr, rewriter);
+    auto maskElements = getElementsFromStruct(loc, llMask, rewriter);
 
-    // TODO[dongdongl]: Support mask and scalar
+    // TODO[dongdongl]: Support scalar
 
     auto valueTy = op.getResult().getType().dyn_cast<RankedTensorType>();
     if (!valueTy)
       return failure();
@@ -5940,6 +5942,14 @@ struct AtomicRMWOpConversion
     auto vecTy = vec_ty(valueElemTy, vec);
     auto elemsPerThread = getElemsPerThread(val.getType());
 
+    // mask
+    Value mask = int_val(1, 1);
+    auto shape = valueTy.getShape();
+    auto numElements = product(shape);
+    auto tid = tid_val();
+    mask = and_(mask, icmp_slt(mul(tid, i32_val(elemsPerThread)),
+                               i32_val(numElements)));
+
     SmallVector<Value> resultVals(elemsPerThread);
     for (size_t i = 0; i < elemsPerThread; i += vec) {
       Value rmwVal = undef(vecTy);
@@ -5949,6 +5959,8 @@ struct AtomicRMWOpConversion
         rmwVal = insert_element(vecTy, rmwVal, valElements[i + ii], iiVal);
       }
       Value rmwPtr = ptrElements[i];
+      Value rmwMask = maskElements[i];
+      rmwMask = and_(rmwMask, mask);
       std::string sTy;
       PTXBuilder ptxBuilder;
 
@@ -5996,9 +6008,7 @@ struct AtomicRMWOpConversion
         return failure();
       }
       atom.o(rmwOp).o(sTy);
-      //TODO:[dongdongl] actual mask support
-      Value pred = int_val(1, 1);
-      atom(dstOpr, ptrOpr, valOpr).predicate(pred);
+      atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
       auto ret = ptxBuilder.launch(rewriter, loc, valueElemTy);
       for (int ii = 0; ii < vec; ++ii) {
         resultVals[i * vec + ii] =
diff --git a/python/tests/test_core.py b/python/tests/test_core.py
index 59cc8d755..5f46f2517 100644
--- a/python/tests/test_core.py
+++ b/python/tests/test_core.py
@@ -672,7 +672,7 @@ def test_tuples():
 #     np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)
 
 def test_tensor_atomic_rmw_add_elementwise(device="cuda"):
-    shape0, shape1 = 16, 16
+    shape0, shape1 = 2, 8
     @triton.jit
     def kernel(Z, X, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
         off0 = tl.arange(0, SHAPE0)
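
The conversion above predicates the PTX atom instruction on rmwMask, the AND of the user-supplied per-element mask with a bound check (tid * elemsPerThread < numElements) that keeps redundant threads from issuing atomics. A minimal sketch of what this enables at the Triton language level, assuming the mask= keyword of tl.atomic_add; the kernel name, N, and BLOCK are hypothetical, modeled on test_tensor_atomic_rmw_add_elementwise:

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def masked_atomic_add(Z, X, N, BLOCK: tl.constexpr):
        offs = tl.arange(0, BLOCK)
        mask = offs < N  # lanes with a false mask issue no atomic (predicated off)
        x = tl.load(X + offs, mask=mask, other=0.0)
        tl.atomic_add(Z + offs, x, mask=mask)

    # BLOCK (16) overshoots N (10); the out-of-range lanes are masked off.
    x = torch.randn(10, device="cuda")
    z = torch.zeros(10, device="cuda")
    masked_atomic_add[(1,)](z, x, 10, BLOCK=16)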