[DOCS] Separate atomic cas from other atomic operations since operands are very different (#559)

This commit is contained in:
Keren Zhou
2022-06-22 17:51:17 -07:00
committed by GitHub
parent b02bac41ba
commit d345ddf837
2 changed files with 27 additions and 12 deletions

View File

@@ -106,9 +106,13 @@ Atomic Ops
:nosignatures: :nosignatures:
atomic_cas atomic_cas
atomic_xchg
atomic_add atomic_add
atomic_max atomic_max
atomic_min atomic_min
atomic_and
atomic_or
atomic_xor
Comparison ops Comparison ops

View File

@@ -806,11 +806,10 @@ def store(pointer, value, mask=None, _builder=None):
# Atomic Memory Operations # Atomic Memory Operations
# ----------------------- # -----------------------
def _add_atomic_docstr(name): @builtin
def atomic_cas(pointer, cmp, val, _builder=None):
def _decorator(func): """
docstr = """ Performs an atomic compare-and-swap at the memory location specified by :code:`pointer`.
Performs an atomic {name} at the memory location specified by :code:`pointer`.
Return the data stored at :code:`pointer` before the atomic operation. Return the data stored at :code:`pointer` before the atomic operation.
@@ -820,6 +819,26 @@ def _add_atomic_docstr(name):
:type cmp: Block of dtype=`pointer.dtype.element_ty` :type cmp: Block of dtype=`pointer.dtype.element_ty`
:param val: The values to copy in case the expected value matches the contained value. :param val: The values to copy in case the expected value matches the contained value.
:type val: Block of dtype=`pointer.dtype.element_ty` :type val: Block of dtype=`pointer.dtype.element_ty`
"""
cmp = _to_tensor(cmp, _builder)
val = _to_tensor(val, _builder)
return semantic.atomic_cas(pointer, cmp, val, _builder)
def _add_atomic_docstr(name):
def _decorator(func):
docstr = """
Performs an atomic {name} at the memory location specified by :code:`pointer`.
Return the data stored at :code:`pointer` before the atomic operation.
:param pointer: The memory locations to apply {name}.
:type pointer: Block of dtype=triton.PointerDType
:param val: The values to {name} in the atomic object.
:type val: Block of dtype=`pointer.dtype.element_ty`
:param mask: If mask[idx] is false, do not apply {name}.
:type mask: Block of triton.int1, optional
""" """
func.__doc__ = docstr.format(name=name) func.__doc__ = docstr.format(name=name)
return func return func
@@ -827,14 +846,6 @@ def _add_atomic_docstr(name):
return _decorator return _decorator
@builtin
@_add_atomic_docstr("compare-and-swap")
def atomic_cas(pointer, cmp, val, _builder=None):
cmp = _to_tensor(cmp, _builder)
val = _to_tensor(val, _builder)
return semantic.atomic_cas(pointer, cmp, val, _builder)
@builtin @builtin
@_add_atomic_docstr("exchange") @_add_atomic_docstr("exchange")
def atomic_xchg(pointer, val, mask=None, _builder=None): def atomic_xchg(pointer, val, mask=None, _builder=None):