diff --git a/docs/python-api/triton.language.rst b/docs/python-api/triton.language.rst index 1f05ce8a6..18bf95be4 100644 --- a/docs/python-api/triton.language.rst +++ b/docs/python-api/triton.language.rst @@ -106,9 +106,13 @@ Atomic Ops :nosignatures: atomic_cas + atomic_xchg atomic_add atomic_max atomic_min + atomic_and + atomic_or + atomic_xor Comparison ops diff --git a/python/triton/language/core.py b/python/triton/language/core.py index d775abf40..1c54ef2c7 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -806,6 +806,25 @@ def store(pointer, value, mask=None, _builder=None): # Atomic Memory Operations # ----------------------- +@builtin +def atomic_cas(pointer, cmp, val, _builder=None): + """ + Performs an atomic compare-and-swap at the memory location specified by :code:`pointer`. + + Return the data stored at :code:`pointer` before the atomic operation. + + :param pointer: The memory locations to compare-and-swap. + :type pointer: Block of dtype=triton.PointerDType + :param cmp: The values expected to be found in the atomic object + :type cmp: Block of dtype=`pointer.dtype.element_ty` + :param val: The values to copy in case the expected value matches the contained value. + :type val: Block of dtype=`pointer.dtype.element_ty` + """ + cmp = _to_tensor(cmp, _builder) + val = _to_tensor(val, _builder) + return semantic.atomic_cas(pointer, cmp, val, _builder) + + def _add_atomic_docstr(name): def _decorator(func): @@ -814,12 +833,12 @@ def _add_atomic_docstr(name): Return the data stored at :code:`pointer` before the atomic operation. - :param pointer: The memory locations to compare-and-swap. + :param pointer: The memory locations to apply {name}. :type pointer: Block of dtype=triton.PointerDType - :param cmp: The values expected to be found in the atomic object - :type cmp: Block of dtype=`pointer.dtype.element_ty` - :param val: The values to copy in case the expected value matches the contained value. + :param val: The values to {name} in the atomic object. :type val: Block of dtype=`pointer.dtype.element_ty` + :param mask: If mask[idx] is false, do not apply {name}. + :type mask: Block of triton.int1, optional """ func.__doc__ = docstr.format(name=name) return func @@ -827,14 +846,6 @@ def _add_atomic_docstr(name): return _decorator -@builtin -@_add_atomic_docstr("compare-and-swap") -def atomic_cas(pointer, cmp, val, _builder=None): - cmp = _to_tensor(cmp, _builder) - val = _to_tensor(val, _builder) - return semantic.atomic_cas(pointer, cmp, val, _builder) - - @builtin @_add_atomic_docstr("exchange") def atomic_xchg(pointer, val, mask=None, _builder=None):