From 205a493b10a5112ec1fccdbe9d59fe9f172e027d Mon Sep 17 00:00:00 2001
From: daadaada
Date: Sun, 22 May 2022 00:45:54 +0800
Subject: [PATCH] [FRONTEND] Fix a bug in atomic_cas (correct cmp to val) & more tests on atomic_cas (#520)

Fix a bug in atomic_cas (correct cmp to val) & more tests on atomic_cas
---
 python/test/unit/language/test_core.py | 32 ++++++++++++++++++++++++++
 python/triton/language/core.py         |  2 +-
 2 files changed, 33 insertions(+), 1 deletion(-)

diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index 22c6f99f4..952922f6b 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -514,9 +514,41 @@ def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
         np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
 
 
+def test_atomic_cas():
+    # 1. make sure that atomic_cas changes the original value (Lock)
+    @triton.jit
+    def change_value(Lock):
+        tl.atomic_cas(Lock, 0, 1)
+
+    Lock = torch.zeros((1,), device='cuda', dtype=torch.int32)
+    change_value[(1,)](Lock)
+
+    assert(Lock[0] == 1)
+
+    # 2. only one block enters the critical section
+    @triton.jit
+    def serialized_add(data, Lock):
+        ptrs = data + tl.arange(0, 128)
+        while tl.atomic_cas(Lock, 0, 1) == 1:
+            pass
+
+        tl.store(ptrs, tl.load(ptrs) + 1.0)
+
+        # release lock
+        tl.atomic_xchg(Lock, 0)
+
+    Lock = torch.zeros((1,), device='cuda', dtype=torch.int32)
+    data = torch.zeros((128,), device='cuda', dtype=torch.float32)
+    ref = torch.full((128,), 64.0)
+    serialized_add[(64,)](data, Lock)
+    triton.testing.assert_almost_equal(data, ref)
+
+
 # ---------------
 # test cast
 # ---------------
+
+
 @pytest.mark.parametrize("dtype_x, dtype_z, bitcast", [
     (dtype_x, dtype_z, False)
     for dtype_x in dtypes
diff --git a/python/triton/language/core.py b/python/triton/language/core.py
index 9af259a14..7ef63abba 100644
--- a/python/triton/language/core.py
+++ b/python/triton/language/core.py
@@ -806,7 +806,7 @@ def _add_atomic_docstr(name):
 @_add_atomic_docstr("compare-and-swap")
 def atomic_cas(pointer, cmp, val, _builder=None):
     cmp = _to_tensor(cmp, _builder)
-    val = _to_tensor(cmp, _builder)
+    val = _to_tensor(val, _builder)
     return semantic.atomic_cas(pointer, cmp, val, _builder)
 
 