[PYTHON] Added atomic_add (#94)

This commit is contained in:
Philippe Tillet
2021-04-29 09:13:45 -04:00
committed by Philippe Tillet
parent d7f87929fa
commit 2b75158426
5 changed files with 65 additions and 2 deletions

View File

@@ -69,6 +69,7 @@ struct dispatch{
static ir::value *store(ir::value* ptr, ir::value *value, ir::value *mask, ir::builder *builder);
static ir::value *atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, ir::builder *builder);
static ir::value *atomic_xchg(ir::value* ptr, ir::value *val, ir::builder *builder);
// Atomic add of `val` at `ptr`. `msk` may be null; the definition then substitutes an all-true mask.
static ir::value *atomic_add(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
// linear algebra
static ir::value *dot(ir::value *lhs, ir::value *rhs, ir::builder *builder);

View File

@@ -521,6 +521,15 @@ ir::value *dispatch::atomic_xchg(ir::value* ptr, ir::value *val, ir::builder *bu
return builder->create_atomic_exch(ptr, val);
}
// Emits an atomic-add instruction through `builder`.
//   ptr     : destination pointer value (scalar or block-typed).
//   val     : value to add.
//   mask    : optional predicate; when null an all-true predicate is synthesized.
//   builder : IR builder used to create every instruction.
// Returns whatever `builder->create_atomic_add` produces.
ir::value *dispatch::atomic_add(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
  // Caller supplied a predicate: forward it untouched.
  if(mask)
    return builder->create_atomic_add(ptr, val, mask);
  // No predicate: start from a scalar `true` ...
  ir::value *all_true = builder->get_int1(true);
  auto *ptr_ty = ptr->get_type();
  // ... and splat it to the pointer's block shape when the pointer is block-typed.
  if(ptr_ty->is_block_ty())
    all_true = builder->create_splat(all_true, ptr_ty->get_block_shapes());
  return builder->create_atomic_add(ptr, val, all_true);
}
//===----------------------------------------------------------------------===//
// Linear Algebra
//===----------------------------------------------------------------------===//

View File

@@ -134,6 +134,7 @@ void init_triton_frontend(py::module &&m) {
m.def("store", &ir::dispatch::store, ret::reference);
m.def("atomic_cas", &ir::dispatch::atomic_cas, ret::reference);
m.def("atomic_xchg", &ir::dispatch::atomic_xchg, ret::reference);
m.def("atomic_add", &ir::dispatch::atomic_add, ret::reference);
// linear algebra
m.def("dot", &ir::dispatch::dot, ret::reference);
// indexing

View File

@@ -249,8 +249,12 @@ class CodeGenerator(ast.NodeVisitor):
ast.Is: '__eq__',
ast.IsNot: '__ne__',
}[type(node.ops[0])]
if self.is_triton_object(lhs) or self.is_triton_object(rhs):
if self.is_triton_object(lhs):
return getattr(lhs, fn)(rhs, builder=self.builder)
elif self.is_triton_object(rhs):
fn = fn[:2] + 'r' + fn[2:]
return getattr(rhs, fn)(lhs, builder=self.builder)
else:
return getattr(lhs, fn)(rhs)
def visit_UnaryOp(self, node):

View File

@@ -128,6 +128,9 @@ class block:
def __sub__(self, other, builder=None):
    """Build ``self - other`` by delegating to :code:`frontend.sub`."""
    minuend, subtrahend = self, other
    return frontend.sub(minuend, subtrahend, builder)
# Decorated like the sibling operator methods (e.g. __mul__, __rgt__), which all
# carry @builtin. NOTE(review): presumably the decorator prepares the operands
# and builder before reaching the frontend; confirm against its definition.
@builtin
def __rsub__(self, other, builder=None):
    """
    Reflected subtraction: build ``other - self``.

    :param other: the left-hand operand of the subtraction
    :param builder: IR builder forwarded to :code:`frontend.sub`
    :return: the result of :code:`frontend.sub(other, self, builder)`
    """
    # Operands are swapped on purpose: `self` is the right-hand side here.
    return frontend.sub(other, self, builder)
@builtin
def __mul__(self, other, builder=None):
    """Build ``self * other`` by delegating to :code:`frontend.mul`."""
    product = frontend.mul(self, other, builder)
    return product
@@ -183,22 +186,42 @@ class block:
# comparison operators
# >
@builtin
def __gt__(self, other, builder=None):
    """Build the ``self > other`` comparison via :code:`frontend.greater_than`."""
    lhs, rhs = self, other
    return frontend.greater_than(lhs, rhs, builder)
@builtin
def __rgt__(self, other, builder=None):
    """Reflected ``>``: build ``other > self`` via :code:`frontend.greater_than`."""
    # `self` is the right-hand operand in the reflected form.
    lhs, rhs = other, self
    return frontend.greater_than(lhs, rhs, builder)
# >=
@builtin
def __ge__(self, other, builder=None):
    """Build the ``self >= other`` comparison via :code:`frontend.greater_equal`."""
    lhs, rhs = self, other
    return frontend.greater_equal(lhs, rhs, builder)
# Decorated to match the other reflected comparisons (__rgt__, __rlt__, __rle__),
# which all carry @builtin. NOTE(review): presumably the decorator prepares the
# operands and builder before reaching the frontend; confirm against its definition.
@builtin
def __rge__(self, other, builder=None):
    """
    Reflected ``>=``: build ``other >= self``.

    :param other: the left-hand operand of the comparison
    :param builder: IR builder forwarded to :code:`frontend.greater_equal`
    :return: the result of :code:`frontend.greater_equal(other, self, builder)`
    """
    # Operands are swapped on purpose: `self` is the right-hand side here.
    return frontend.greater_equal(other, self, builder)
# <
@builtin
def __lt__(self, other, builder=None):
    """Build the ``self < other`` comparison via :code:`frontend.less_than`."""
    lhs, rhs = self, other
    return frontend.less_than(lhs, rhs, builder)
@builtin
def __rlt__(self, other, builder=None):
    """Reflected ``<``: build ``other < self`` via :code:`frontend.less_than`."""
    # `self` is the right-hand operand in the reflected form.
    lhs, rhs = other, self
    return frontend.less_than(lhs, rhs, builder)
# <=
@builtin
def __le__(self, other, builder=None):
    """Build the ``self <= other`` comparison via :code:`frontend.less_equal`."""
    lhs, rhs = self, other
    return frontend.less_equal(lhs, rhs, builder)
@builtin
def __rle__(self, other, builder=None):
    """Reflected ``<=``: build ``other <= self`` via :code:`frontend.less_equal`."""
    # `self` is the right-hand operand in the reflected form.
    lhs, rhs = other, self
    return frontend.less_equal(lhs, rhs, builder)
# ==
@builtin
def __eq__(self, other, builder=None):
    """Build the ``self == other`` comparison via :code:`frontend.equal`."""
    lhs, rhs = self, other
    return frontend.equal(lhs, rhs, builder)
@@ -421,6 +444,20 @@ def atomic_xchg(pointer, val, builder=None):
return frontend.atomic_xchg(pointer, val, builder)
@builtin
def atomic_add(pointer, val, mask=None, builder=None):
    """
    Performs an atomic add at the memory locations specified by :code:`pointer`.

    :param pointer: The memory locations which contain the old values
    :type pointer: Block of dtype=triton.PointerDType
    :param val: The values to add
    :type val: Block of dtype=`pointer.dtype.element_ty`
    :param mask: If mask[idx] is false, :code:`pointer[idx]` is unaffected.
        May be omitted (:code:`None`), in which case it is forwarded as-is to the frontend.
    :type mask: Block of triton.int1, optional
    """
    outcome = frontend.atomic_add(pointer, val, mask, builder)
    return outcome
# -----------------------
# Conditioning
# -----------------------
@@ -475,6 +512,17 @@ def log(x, builder=None):
return frontend.log(x, builder)
@builtin
def sqrt(x, builder=None):
    """
    Computes the element-wise square root of :code:`x`.

    :param x: the input values
    :type x: Block
    """
    root = frontend.sqrt(x, builder)
    return root
# -----------------------
# Reductions
# -----------------------