[PYTHON] Added atomic_add (#94)

This commit is contained in:
Philippe Tillet
2021-04-29 09:13:45 -04:00
committed by Philippe Tillet
parent d7f87929fa
commit 2b75158426
5 changed files with 65 additions and 2 deletions

View File

@@ -134,6 +134,7 @@ void init_triton_frontend(py::module &&m) {
m.def("store", &ir::dispatch::store, ret::reference);
m.def("atomic_cas", &ir::dispatch::atomic_cas, ret::reference);
m.def("atomic_xchg", &ir::dispatch::atomic_xchg, ret::reference);
m.def("atomic_add", &ir::dispatch::atomic_add, ret::reference);
// linear algebra
m.def("dot", &ir::dispatch::dot, ret::reference);
// indexing

View File

@@ -249,9 +249,13 @@ class CodeGenerator(ast.NodeVisitor):
ast.Is: '__eq__',
ast.IsNot: '__ne__',
}[type(node.ops[0])]
if self.is_triton_object(lhs) or self.is_triton_object(rhs):
if self.is_triton_object(lhs):
return getattr(lhs, fn)(rhs, builder=self.builder)
return getattr(lhs, fn)(rhs)
elif self.is_triton_object(rhs):
fn = fn[:2] + 'r' + fn[2:]
return getattr(rhs, fn)(lhs, builder=self.builder)
else:
return getattr(lhs, fn)(rhs)
def visit_UnaryOp(self, node):
op = self.visit(node.operand)

View File

@@ -128,6 +128,9 @@ class block:
def __sub__(self, other, builder=None):
return frontend.sub(self, other, builder)
def __rsub__(self, other, builder=None):
return frontend.sub(other, self, builder)
@builtin
def __mul__(self, other, builder=None):
return frontend.mul(self, other, builder)
@@ -183,22 +186,42 @@ class block:
# comparison operators
# >
@builtin
def __gt__(self, other, builder=None):
return frontend.greater_than(self, other, builder)
@builtin
def __rgt__(self, other, builder=None):
return frontend.greater_than(other, self, builder)
# >=
@builtin
def __ge__(self, other, builder=None):
return frontend.greater_equal(self, other, builder)
def __rge__(self, other, builder=None):
return frontend.greater_equal(other, self, builder)
# <
@builtin
def __lt__(self, other, builder=None):
return frontend.less_than(self, other, builder)
@builtin
def __rlt__(self, other, builder=None):
return frontend.less_than(other, self, builder)
# <=
@builtin
def __le__(self, other, builder=None):
return frontend.less_equal(self, other, builder)
@builtin
def __rle__(self, other, builder=None):
return frontend.less_equal(other, self, builder)
# ==
@builtin
def __eq__(self, other, builder=None):
return frontend.equal(self, other, builder)
@@ -421,6 +444,20 @@ def atomic_xchg(pointer, val, builder=None):
return frontend.atomic_xchg(pointer, val, builder)
@builtin
def atomic_add(pointer, val, mask=None, builder=None):
"""
Performs an atomic add and the memory locations specified by :code:`pointer`.
:param pointer: The memory locations which contain the old values
:type pointer: Block of dtype=triton.PointerDType
:param val: The values to add
:type val: Block of dtype=`pointer.dtype.element_ty`
:param mask: If mask[idx] is false, :code:`pointer[idx]` is unaffected.
:type mask: Block of triton.int1, optional
"""
return frontend.atomic_add(pointer, val, mask, builder)
# -----------------------
# Conditioning
# -----------------------
@@ -475,6 +512,17 @@ def log(x, builder=None):
return frontend.log(x, builder)
@builtin
def sqrt(x, builder=None):
"""
Computes the element-wise square root of :code:`x`
:param x: the input values
:type x: Block
"""
return frontend.sqrt(x, builder)
# -----------------------
# Reductions
# -----------------------