[FRONTEND] Fix a bug in atomic_cas (correct cmp to val) & more tests on atomic_cas (#520)
Fix a bug in atomic_cas (correct cmp to val) & more tests on atomic_cas
This commit is contained in:
@@ -514,9 +514,41 @@ def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
|
|||||||
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
|
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
|
||||||
|
|
||||||
|
|
||||||
|
def test_atomic_cas():
|
||||||
|
# 1. make sure that atomic_cas changes the original value (Lock)
|
||||||
|
@triton.jit
|
||||||
|
def change_value(Lock):
|
||||||
|
tl.atomic_cas(Lock, 0, 1)
|
||||||
|
|
||||||
|
Lock = torch.zeros((1,), device='cuda', dtype=torch.int32)
|
||||||
|
change_value[(1,)](Lock)
|
||||||
|
|
||||||
|
assert(Lock[0] == 1)
|
||||||
|
|
||||||
|
# 2. only one block enters the critical section
|
||||||
|
@triton.jit
|
||||||
|
def serialized_add(data, Lock):
|
||||||
|
ptrs = data + tl.arange(0, 128)
|
||||||
|
while tl.atomic_cas(Lock, 0, 1) == 1:
|
||||||
|
pass
|
||||||
|
|
||||||
|
tl.store(ptrs, tl.load(ptrs) + 1.0)
|
||||||
|
|
||||||
|
# release lock
|
||||||
|
tl.atomic_xchg(Lock, 0)
|
||||||
|
|
||||||
|
Lock = torch.zeros((1,), device='cuda', dtype=torch.int32)
|
||||||
|
data = torch.zeros((128,), device='cuda', dtype=torch.float32)
|
||||||
|
ref = torch.full((128,), 64.0)
|
||||||
|
serialized_add[(64,)](data, Lock)
|
||||||
|
triton.testing.assert_almost_equal(data, ref)
|
||||||
|
|
||||||
|
|
||||||
# ---------------
|
# ---------------
|
||||||
# test cast
|
# test cast
|
||||||
# ---------------
|
# ---------------
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("dtype_x, dtype_z, bitcast", [
|
@pytest.mark.parametrize("dtype_x, dtype_z, bitcast", [
|
||||||
(dtype_x, dtype_z, False)
|
(dtype_x, dtype_z, False)
|
||||||
for dtype_x in dtypes
|
for dtype_x in dtypes
|
||||||
|
@@ -806,7 +806,7 @@ def _add_atomic_docstr(name):
|
|||||||
@_add_atomic_docstr("compare-and-swap")
|
@_add_atomic_docstr("compare-and-swap")
|
||||||
def atomic_cas(pointer, cmp, val, _builder=None):
|
def atomic_cas(pointer, cmp, val, _builder=None):
|
||||||
cmp = _to_tensor(cmp, _builder)
|
cmp = _to_tensor(cmp, _builder)
|
||||||
val = _to_tensor(cmp, _builder)
|
val = _to_tensor(val, _builder)
|
||||||
return semantic.atomic_cas(pointer, cmp, val, _builder)
|
return semantic.atomic_cas(pointer, cmp, val, _builder)
|
||||||
|
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user