[PYTHON] Added atomic_add (#94)
This commit is contained in:
committed by
Philippe Tillet
parent
d7f87929fa
commit
2b75158426
@@ -69,6 +69,7 @@ struct dispatch{
|
||||
static ir::value *store(ir::value* ptr, ir::value *value, ir::value *mask, ir::builder *builder);
|
||||
static ir::value *atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, ir::builder *builder);
|
||||
static ir::value *atomic_xchg(ir::value* ptr, ir::value *val, ir::builder *builder);
|
||||
static ir::value *atomic_add(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
|
||||
|
||||
// linear algebra
|
||||
static ir::value *dot(ir::value *lhs, ir::value *rhs, ir::builder *builder);
|
||||
|
@@ -521,6 +521,15 @@ ir::value *dispatch::atomic_xchg(ir::value* ptr, ir::value *val, ir::builder *bu
|
||||
return builder->create_atomic_exch(ptr, val);
|
||||
}
|
||||
|
||||
// Emits an atomic add of `val` at the location(s) addressed by `ptr`.
// A null `mask` is replaced by a constant-true mask: a scalar int1, or that
// scalar splatted to the pointer's block shape when `ptr` has a block type.
ir::value *dispatch::atomic_add(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
  // Caller supplied a mask: forward it unchanged.
  if(mask)
    return builder->create_atomic_add(ptr, val, mask);
  // No mask given: synthesize an all-true one matching the pointer's shape.
  ir::value *all_true = builder->get_int1(true);
  if(ptr->get_type()->is_block_ty())
    all_true = builder->create_splat(all_true, ptr->get_type()->get_block_shapes());
  return builder->create_atomic_add(ptr, val, all_true);
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Linear Algebra
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@@ -134,6 +134,7 @@ void init_triton_frontend(py::module &&m) {
|
||||
m.def("store", &ir::dispatch::store, ret::reference);
|
||||
m.def("atomic_cas", &ir::dispatch::atomic_cas, ret::reference);
|
||||
m.def("atomic_xchg", &ir::dispatch::atomic_xchg, ret::reference);
|
||||
m.def("atomic_add", &ir::dispatch::atomic_add, ret::reference);
|
||||
// linear algebra
|
||||
m.def("dot", &ir::dispatch::dot, ret::reference);
|
||||
// indexing
|
||||
|
@@ -249,8 +249,12 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
ast.Is: '__eq__',
|
||||
ast.IsNot: '__ne__',
|
||||
}[type(node.ops[0])]
|
||||
if self.is_triton_object(lhs) or self.is_triton_object(rhs):
|
||||
if self.is_triton_object(lhs):
|
||||
return getattr(lhs, fn)(rhs, builder=self.builder)
|
||||
elif self.is_triton_object(rhs):
|
||||
fn = fn[:2] + 'r' + fn[2:]
|
||||
return getattr(rhs, fn)(lhs, builder=self.builder)
|
||||
else:
|
||||
return getattr(lhs, fn)(rhs)
|
||||
|
||||
def visit_UnaryOp(self, node):
|
||||
|
@@ -128,6 +128,9 @@ class block:
|
||||
def __sub__(self, other, builder=None):
    # `self - other`: operands are handed to frontend.sub in source order.
    difference = frontend.sub(self, other, builder)
    return difference
|
||||
|
||||
@builtin
def __rsub__(self, other, builder=None):
    # Reflected subtraction: this block is the right-hand operand, so `other`
    # is passed first to frontend.sub.
    # Decorated with @builtin for consistency with the sibling operator
    # methods of this class (e.g. __mul__, __rgt__, __rlt__), which all carry it.
    return frontend.sub(other, self, builder)
|
||||
|
||||
@builtin
def __mul__(self, other, builder=None):
    # `self * other`: operands are handed to frontend.mul in source order.
    product = frontend.mul(self, other, builder)
    return product
|
||||
@@ -183,22 +186,42 @@ class block:
|
||||
|
||||
# comparison operators
|
||||
|
||||
# >
|
||||
@builtin
def __gt__(self, other, builder=None):
    # `self > other`: forwards (self, other) to frontend.greater_than.
    outcome = frontend.greater_than(self, other, builder)
    return outcome
|
||||
|
||||
@builtin
def __rgt__(self, other, builder=None):
    # Reflected form of `>`: this block is the right-hand operand, so the
    # operands are swapped before calling frontend.greater_than.
    outcome = frontend.greater_than(other, self, builder)
    return outcome
|
||||
|
||||
# >=
|
||||
@builtin
def __ge__(self, other, builder=None):
    # `self >= other`: forwards (self, other) to frontend.greater_equal.
    outcome = frontend.greater_equal(self, other, builder)
    return outcome
|
||||
|
||||
@builtin
def __rge__(self, other, builder=None):
    # Reflected form of `>=`: this block is the right-hand operand, so the
    # operands are swapped before calling frontend.greater_equal.
    # Decorated with @builtin for consistency with the other reflected
    # comparisons of this class (__rgt__, __rlt__, __rle__), which all carry it.
    return frontend.greater_equal(other, self, builder)
|
||||
|
||||
# <
|
||||
@builtin
def __lt__(self, other, builder=None):
    # `self < other`: forwards (self, other) to frontend.less_than.
    outcome = frontend.less_than(self, other, builder)
    return outcome
|
||||
|
||||
@builtin
def __rlt__(self, other, builder=None):
    # Reflected form of `<`: this block is the right-hand operand, so the
    # operands are swapped before calling frontend.less_than.
    outcome = frontend.less_than(other, self, builder)
    return outcome
|
||||
|
||||
# <=
|
||||
@builtin
def __le__(self, other, builder=None):
    # `self <= other`: forwards (self, other) to frontend.less_equal.
    outcome = frontend.less_equal(self, other, builder)
    return outcome
|
||||
|
||||
@builtin
def __rle__(self, other, builder=None):
    # Reflected form of `<=`: this block is the right-hand operand, so the
    # operands are swapped before calling frontend.less_equal.
    outcome = frontend.less_equal(other, self, builder)
    return outcome
|
||||
|
||||
# ==
|
||||
@builtin
def __eq__(self, other, builder=None):
    # `self == other`: forwards (self, other) to frontend.equal.
    outcome = frontend.equal(self, other, builder)
    return outcome
|
||||
@@ -421,6 +444,20 @@ def atomic_xchg(pointer, val, builder=None):
|
||||
return frontend.atomic_xchg(pointer, val, builder)
|
||||
|
||||
|
||||
@builtin
def atomic_add(pointer, val, mask=None, builder=None):
    """
        Performs an atomic add at the memory locations specified by :code:`pointer`.

        :param pointer: The memory locations which contain the old values
        :type pointer: Block of dtype=triton.PointerDType
        :param val: The values to add
        :type val: Block of dtype=`pointer.dtype.element_ty`
        :param mask: If mask[idx] is false, :code:`pointer[idx]` is unaffected.
        :type mask: Block of triton.int1, optional
    """
    # `mask` may be left as None; it is forwarded to the frontend as-is.
    result = frontend.atomic_add(pointer, val, mask, builder)
    return result
|
||||
|
||||
|
||||
# -----------------------
|
||||
# Conditioning
|
||||
# -----------------------
|
||||
@@ -475,6 +512,17 @@ def log(x, builder=None):
|
||||
return frontend.log(x, builder)
|
||||
|
||||
|
||||
@builtin
def sqrt(x, builder=None):
    """
        Computes the element-wise square root of :code:`x`

        :param x: the input values
        :type x: Block
    """
    # Thin wrapper: the work is delegated to frontend.sqrt.
    root = frontend.sqrt(x, builder)
    return root
|
||||
|
||||
|
||||
# -----------------------
|
||||
# Reductions
|
||||
# -----------------------
|
||||
|
Reference in New Issue
Block a user