diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index d8cb20ae7..2cbb47e85 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -5958,19 +5958,19 @@ struct AtomicRMWOpConversion
     auto elemsPerThread = getElemsPerThread(val.getType());
     SmallVector<Value> resultVals(elemsPerThread);
     for (size_t i = 0; i < elemsPerThread; i += vec) {
-      Value rmvVal = undef(vecTy);
+      Value rmwVal = undef(vecTy);
       for (int ii = 0; ii < vec; ++ii) {
         Value iiVal = createIndexAttrConstant(
             rewriter, loc, getTypeConverter()->getIndexType(), ii);
-        rmvVal = insert_element(vecTy, rmvVal, valElements[i], iiVal);
+        rmwVal = insert_element(vecTy, rmwVal, valElements[i + ii], iiVal);
       }
-      Value rmwPtr = bitcast(ptrElements[i], ptr_ty(valTy.getElementType()));
+      Value rmwPtr = ptrElements[i];
       std::string sTy;
       PTXBuilder ptxBuilder;

       auto *dstOpr = ptxBuilder.newOperand("=r");
-      auto *ptrOpr = ptxBuilder.newAddrOperand(rmwPtr, "r");
-      auto *valOpr = ptxBuilder.newOperand(rmvVal, "r");
+      auto *ptrOpr = ptxBuilder.newAddrOperand(rmwPtr, "l");
+      auto *valOpr = ptxBuilder.newOperand(rmwVal, "r");
       auto &atom = ptxBuilder.create<>("atom")->global().o("gpu");

       auto rmwOp = stringifyRMWOp(atomicRmwAttr).str();
@@ -6012,12 +6012,13 @@ struct AtomicRMWOpConversion
         return failure();
       }
       atom.o(rmwOp).o(sTy);
-
-      atom(dstOpr, ptrOpr, valOpr);
-      auto ret = ptxBuilder.launch(rewriter, loc, valueElemTy, false);
+      //TODO:[dongdongl] actual mask support
+      Value pred = int_val(1, 1);
+      atom(dstOpr, ptrOpr, valOpr).predicate(pred);
+      auto ret = ptxBuilder.launch(rewriter, loc, valueElemTy);
       for (int ii = 0; ii < vec; ++ii) {
         resultVals[i * vec + ii] =
-            vec == 1 ? ret : extract_element(vecTy, ret, idx_val(ii));
+            vec == 1 ? ret : extract_element(valueElemTy, ret, idx_val(ii));
       }
     }
     Type structTy = getTypeConverter()->convertType(valueTy);
diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
index 4f4d56847..d4473f5b8 100644
--- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
+++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
@@ -378,7 +378,7 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
       TritonGenericPattern, TritonReducePattern, TritonExpandDimsPattern,
       TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern,
       TritonStorePattern, TritonExtElemwisePattern,
-      TritonPrintfPattern>(typeConverter, context);
+      TritonPrintfPattern, TritonAtomicRMWPattern>(typeConverter, context);
 }

 //
diff --git a/python/src/triton.cc b/python/src/triton.cc
index d2294210e..5a0d25732 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -1086,9 +1086,7 @@ void init_triton_ir(py::module &&m) {
              mlir::Value &ptr, mlir::Value &val,
              mlir::Value &mask) -> mlir::Value {
             auto loc = self.getUnknownLoc();
-            auto ptrType = mlir::getElementTypeOrSelf(ptr)
-                               .cast<mlir::triton::PointerType>();
-            mlir::Type dstType = ptrType.getPointeeType();
+            mlir::Type dstType = val.getType();
             return self.create<mlir::triton::AtomicRMWOp>(loc, dstType,
                                                           rmwOp, ptr, val,
                                                           mask);
           })
diff --git a/python/tests/test_core.py b/python/tests/test_core.py
index 7b994c3cb..59cc8d755 100644
--- a/python/tests/test_core.py
+++ b/python/tests/test_core.py
@@ -671,6 +671,25 @@ def test_tuples():
 #     kernel[(1,)](z_tri, x_tri, axis, shape0, shape1)
 #     np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)

+def test_tensor_atomic_rmw_add_elementwise(device="cuda"):
+    shape0, shape1 = 16, 16
+    @triton.jit
+    def kernel(Z, X, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
+        off0 = tl.arange(0, SHAPE0)
+        off1 = tl.arange(0, SHAPE1)
+        x = tl.load(X + off0[:, None] * SHAPE1 + off1[None, :])
+        tl.atomic_add(Z + off0[:, None] * SHAPE1 + off1[None, :], x)
+
+    rs = RandomState(17)
+    x = numpy_random((shape0, shape1), dtype_str="float32", rs=rs)
+    z = numpy_random((shape0, shape1), dtype_str="float32", rs=rs)
+    # reference
+    z_ref = z + x
+    # triton result
+    x_tri = torch.from_numpy(x).to(device=device)
+    z_tri = torch.from_numpy(z).to(device=device)
+    kernel[(1,)](z_tri, x_tri, shape0, shape1)
+    np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)

 # def test_atomic_cas():
 #     # 1. make sure that atomic_cas changes the original value (Lock)
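
For orientation: after this patch, AtomicRMWOpConversion assembles an opcode of the form atom.global.gpu.<rmwOp>.<sTy>, takes the pointer through a 64-bit register constraint ("l" rather than the 32-bit "r", which would truncate 64-bit addresses), and guards the instruction with an always-true predicate as a placeholder until real mask support lands. A hand-written sketch of the PTX this should yield for a scalar float32 add (register names invented for illustration; not captured compiler output):

    // assumed output shape, not taken from ptxas:
    // %p1 is the always-true predicate built from int_val(1, 1),
    // %rd1 the 64-bit address, %f1 the operand, %f0 the returned old value
    @%p1 atom.global.gpu.add.f32 %f0, [ %rd1 + 0 ], %f1;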