[TRITON-MLIR][BACKEND]support atomic_cas (#914)

1. support atomics-cas
2. add xchg support in atomic_rmw

Co-authored-by: dongdongl <dongdongl@nvidia.com>
This commit is contained in:
donproc
2022-11-25 12:02:08 +08:00
committed by GitHub
parent 153aecb339
commit f63be0e9b5
6 changed files with 150 additions and 15 deletions

View File

@@ -700,6 +700,16 @@ def test_tensor_atomic_rmw(axis, device="cuda"):
# serialized_add[(64,)](data, Lock)
# triton.testing.assert_almost_equal(data, ref)
def test_simple_atomic_cas():
# 1. make sure that atomic_cas changes the original value (Lock)
@triton.jit
def change_value(Lock):
tl.atomic_cas(Lock, 0, 1)
Lock = torch.zeros((1,), device='cuda', dtype=torch.int32)
change_value[(1,)](Lock)
assert (Lock[0] == 1)
# # ---------------
# # test cast