[Triton-MLIR][FRONTEND] [BACKEND] fix atomics (#879)
minor fix to backend and frontend of atomics, we can pass 1 test without mask and the shape aligned with CTA size now Co-authored-by: dongdongl <dongdongl@nvidia.com>
This commit is contained in:
@@ -671,6 +671,25 @@ def test_tuples():
|
||||
# kernel[(1,)](z_tri, x_tri, axis, shape0, shape1)
|
||||
# np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)
|
||||
|
||||
def test_tensor_atomic_rmw_add_elementwise(device="cuda"):
|
||||
shape0, shape1 = 16, 16
|
||||
@triton.jit
|
||||
def kernel(Z, X, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
|
||||
off0 = tl.arange(0, SHAPE0)
|
||||
off1 = tl.arange(0, SHAPE1)
|
||||
x = tl.load(X + off0[:, None] * SHAPE1 + off1[None, :])
|
||||
tl.atomic_add(Z + off0[:, None] * SHAPE1 + off1[None, :], x)
|
||||
|
||||
rs = RandomState(17)
|
||||
x = numpy_random((shape0, shape1), dtype_str="float32", rs=rs)
|
||||
z = numpy_random((shape0, shape1), dtype_str="float32", rs=rs)
|
||||
# reference
|
||||
z_ref = z + x
|
||||
# triton result
|
||||
x_tri = torch.from_numpy(x).to(device=device)
|
||||
z_tri = torch.from_numpy(z).to(device=device)
|
||||
kernel[(1,)](z_tri, x_tri, shape0, shape1)
|
||||
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)
|
||||
|
||||
# def test_atomic_cas():
|
||||
# # 1. make sure that atomic_cas changes the original value (Lock)
|
||||
|
Reference in New Issue
Block a user