[FRONTEND] Fix a bug in atomic_cas (correct cmp to val) & more tests on atomic_cas (#520)

Fix a bug in atomic_cas (correct cmp to val) & more tests on atomic_cas
This commit is contained in:
daadaada
2022-05-22 00:45:54 +08:00
committed by GitHub
parent abea3dc2c6
commit 205a493b10
2 changed files with 33 additions and 1 deletions

View File

@@ -514,9 +514,41 @@ def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
def test_atomic_cas():
# 1. make sure that atomic_cas changes the original value (Lock)
@triton.jit
def change_value(Lock):
tl.atomic_cas(Lock, 0, 1)
Lock = torch.zeros((1,), device='cuda', dtype=torch.int32)
change_value[(1,)](Lock)
assert(Lock[0] == 1)
# 2. only one block enters the critical section
@triton.jit
def serialized_add(data, Lock):
ptrs = data + tl.arange(0, 128)
while tl.atomic_cas(Lock, 0, 1) == 1:
pass
tl.store(ptrs, tl.load(ptrs) + 1.0)
# release lock
tl.atomic_xchg(Lock, 0)
Lock = torch.zeros((1,), device='cuda', dtype=torch.int32)
data = torch.zeros((128,), device='cuda', dtype=torch.float32)
ref = torch.full((128,), 64.0)
serialized_add[(64,)](data, Lock)
triton.testing.assert_almost_equal(data, ref)
# ---------------
# test cast
# ---------------
@pytest.mark.parametrize("dtype_x, dtype_z, bitcast", [
(dtype_x, dtype_z, False)
for dtype_x in dtypes

View File

@@ -806,7 +806,7 @@ def _add_atomic_docstr(name):
@_add_atomic_docstr("compare-and-swap")
def atomic_cas(pointer, cmp, val, _builder=None):
cmp = _to_tensor(cmp, _builder)
val = _to_tensor(cmp, _builder)
val = _to_tensor(val, _builder)
return semantic.atomic_cas(pointer, cmp, val, _builder)