[DOCS] Separate atomic cas from other atomic operations since operands are very different (#559)

This commit is contained in:
Keren Zhou
2022-06-22 17:51:17 -07:00
committed by GitHub
parent b02bac41ba
commit d345ddf837
2 changed files with 27 additions and 12 deletions

View File

@@ -106,9 +106,13 @@ Atomic Ops
:nosignatures:
atomic_cas
atomic_xchg
atomic_add
atomic_max
atomic_min
atomic_and
atomic_or
atomic_xor
Comparison ops

View File

@@ -806,6 +806,25 @@ def store(pointer, value, mask=None, _builder=None):
# Atomic Memory Operations
# -----------------------
@builtin
def atomic_cas(pointer, cmp, val, _builder=None):
"""
Performs an atomic compare-and-swap at the memory location specified by :code:`pointer`.
Return the data stored at :code:`pointer` before the atomic operation.
:param pointer: The memory locations to compare-and-swap.
:type pointer: Block of dtype=triton.PointerDType
:param cmp: The values expected to be found in the atomic object
:type cmp: Block of dtype=`pointer.dtype.element_ty`
:param val: The values to copy in case the expected value matches the contained value.
:type val: Block of dtype=`pointer.dtype.element_ty`
"""
cmp = _to_tensor(cmp, _builder)
val = _to_tensor(val, _builder)
return semantic.atomic_cas(pointer, cmp, val, _builder)
def _add_atomic_docstr(name):
def _decorator(func):
@@ -814,12 +833,12 @@ def _add_atomic_docstr(name):
Return the data stored at :code:`pointer` before the atomic operation.
:param pointer: The memory locations to compare-and-swap.
:param pointer: The memory locations to apply {name}.
:type pointer: Block of dtype=triton.PointerDType
:param cmp: The values expected to be found in the atomic object
:type cmp: Block of dtype=`pointer.dtype.element_ty`
:param val: The values to copy in case the expected value matches the contained value.
:param val: The values to {name} in the atomic object.
:type val: Block of dtype=`pointer.dtype.element_ty`
:param mask: If mask[idx] is false, do not apply {name}.
:type mask: Block of triton.int1, optional
"""
func.__doc__ = docstr.format(name=name)
return func
@@ -827,14 +846,6 @@ def _add_atomic_docstr(name):
return _decorator
@builtin
@_add_atomic_docstr("compare-and-swap")
def atomic_cas(pointer, cmp, val, _builder=None):
cmp = _to_tensor(cmp, _builder)
val = _to_tensor(val, _builder)
return semantic.atomic_cas(pointer, cmp, val, _builder)
@builtin
@_add_atomic_docstr("exchange")
def atomic_xchg(pointer, val, mask=None, _builder=None):