[FRONTEND] Fix a bug in atomic_cas (correct cmp to val) & more tests on atomic_cas (#520)
Fix a bug in atomic_cas (correct cmp to val) & more tests on atomic_cas
This commit is contained in:
@@ -514,9 +514,41 @@ def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
|
||||
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
|
||||
|
||||
|
||||
def test_atomic_cas():
|
||||
# 1. make sure that atomic_cas changes the original value (Lock)
|
||||
@triton.jit
|
||||
def change_value(Lock):
|
||||
tl.atomic_cas(Lock, 0, 1)
|
||||
|
||||
Lock = torch.zeros((1,), device='cuda', dtype=torch.int32)
|
||||
change_value[(1,)](Lock)
|
||||
|
||||
assert(Lock[0] == 1)
|
||||
|
||||
# 2. only one block enters the critical section
|
||||
@triton.jit
|
||||
def serialized_add(data, Lock):
|
||||
ptrs = data + tl.arange(0, 128)
|
||||
while tl.atomic_cas(Lock, 0, 1) == 1:
|
||||
pass
|
||||
|
||||
tl.store(ptrs, tl.load(ptrs) + 1.0)
|
||||
|
||||
# release lock
|
||||
tl.atomic_xchg(Lock, 0)
|
||||
|
||||
Lock = torch.zeros((1,), device='cuda', dtype=torch.int32)
|
||||
data = torch.zeros((128,), device='cuda', dtype=torch.float32)
|
||||
ref = torch.full((128,), 64.0)
|
||||
serialized_add[(64,)](data, Lock)
|
||||
triton.testing.assert_almost_equal(data, ref)
|
||||
|
||||
|
||||
# ---------------
|
||||
# test cast
|
||||
# ---------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype_x, dtype_z, bitcast", [
|
||||
(dtype_x, dtype_z, False)
|
||||
for dtype_x in dtypes
|
||||
|
@@ -806,7 +806,7 @@ def _add_atomic_docstr(name):
|
||||
@_add_atomic_docstr("compare-and-swap")
|
||||
def atomic_cas(pointer, cmp, val, _builder=None):
|
||||
cmp = _to_tensor(cmp, _builder)
|
||||
val = _to_tensor(cmp, _builder)
|
||||
val = _to_tensor(val, _builder)
|
||||
return semantic.atomic_cas(pointer, cmp, val, _builder)
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user