From 2b75158426d982a1ac34f91effcc6631be58790d Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Thu, 29 Apr 2021 09:13:45 -0400 Subject: [PATCH] [PYTHON] Added atomic_add (#94) --- include/triton/ir/dispatch.h | 1 + lib/ir/dispatch.cc | 9 +++++++ python/src/triton.cc | 1 + python/triton/code_gen.py | 8 ++++-- python/triton/language.py | 48 ++++++++++++++++++++++++++++++++++++ 5 files changed, 65 insertions(+), 2 deletions(-) diff --git a/include/triton/ir/dispatch.h b/include/triton/ir/dispatch.h index c2e5ac320..6d06baff1 100644 --- a/include/triton/ir/dispatch.h +++ b/include/triton/ir/dispatch.h @@ -69,6 +69,7 @@ struct dispatch{ static ir::value *store(ir::value* ptr, ir::value *value, ir::value *mask, ir::builder *builder); static ir::value *atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, ir::builder *builder); static ir::value *atomic_xchg(ir::value* ptr, ir::value *val, ir::builder *builder); + static ir::value *atomic_add(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder); // linear algebra static ir::value *dot(ir::value *lhs, ir::value *rhs, ir::builder *builder); diff --git a/lib/ir/dispatch.cc b/lib/ir/dispatch.cc index 1a3b8460c..53c0dd2e8 100644 --- a/lib/ir/dispatch.cc +++ b/lib/ir/dispatch.cc @@ -521,6 +521,15 @@ ir::value *dispatch::atomic_xchg(ir::value* ptr, ir::value *val, ir::builder *bu return builder->create_atomic_exch(ptr, val); } +ir::value *dispatch::atomic_add(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){ + if(!mask){ + mask = builder->get_int1(true); + if(ptr->get_type()->is_block_ty()) + mask = builder->create_splat(mask, ptr->get_type()->get_block_shapes()); + } + return builder->create_atomic_add(ptr, val, mask); +} + //===----------------------------------------------------------------------===// // Linear Algebra //===----------------------------------------------------------------------===// diff --git a/python/src/triton.cc b/python/src/triton.cc index 
29efa6f6b..97fd20a40 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -134,6 +134,7 @@ void init_triton_frontend(py::module &&m) { m.def("store", &ir::dispatch::store, ret::reference); m.def("atomic_cas", &ir::dispatch::atomic_cas, ret::reference); m.def("atomic_xchg", &ir::dispatch::atomic_xchg, ret::reference); + m.def("atomic_add", &ir::dispatch::atomic_add, ret::reference); // linear algebra m.def("dot", &ir::dispatch::dot, ret::reference); // indexing diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 27ae177f4..c2f6fdfd8 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -249,9 +249,13 @@ class CodeGenerator(ast.NodeVisitor): ast.Is: '__eq__', ast.IsNot: '__ne__', }[type(node.ops[0])] - if self.is_triton_object(lhs) or self.is_triton_object(rhs): + if self.is_triton_object(lhs): return getattr(lhs, fn)(rhs, builder=self.builder) - return getattr(lhs, fn)(rhs) + elif self.is_triton_object(rhs): + fn = fn[:2] + 'r' + fn[2:] + return getattr(rhs, fn)(lhs, builder=self.builder) + else: + return getattr(lhs, fn)(rhs) def visit_UnaryOp(self, node): op = self.visit(node.operand) diff --git a/python/triton/language.py b/python/triton/language.py index a9d8b75db..0ee34151f 100644 --- a/python/triton/language.py +++ b/python/triton/language.py @@ -128,6 +128,9 @@ class block: def __sub__(self, other, builder=None): return frontend.sub(self, other, builder) + def __rsub__(self, other, builder=None): + return frontend.sub(other, self, builder) + @builtin def __mul__(self, other, builder=None): return frontend.mul(self, other, builder) @@ -183,22 +186,42 @@ class block: # comparison operators + # > @builtin def __gt__(self, other, builder=None): return frontend.greater_than(self, other, builder) + @builtin + def __rgt__(self, other, builder=None): + return frontend.greater_than(other, self, builder) + + # >= @builtin def __ge__(self, other, builder=None): return frontend.greater_equal(self, other, builder) + def 
__rge__(self, other, builder=None): + return frontend.greater_equal(other, self, builder) + + # < @builtin def __lt__(self, other, builder=None): return frontend.less_than(self, other, builder) + @builtin + def __rlt__(self, other, builder=None): + return frontend.less_than(other, self, builder) + + # <= @builtin def __le__(self, other, builder=None): return frontend.less_equal(self, other, builder) + @builtin + def __rle__(self, other, builder=None): + return frontend.less_equal(other, self, builder) + + # == @builtin def __eq__(self, other, builder=None): return frontend.equal(self, other, builder) @@ -421,6 +444,20 @@ def atomic_xchg(pointer, val, builder=None): return frontend.atomic_xchg(pointer, val, builder) +@builtin +def atomic_add(pointer, val, mask=None, builder=None): + """ + Performs an atomic add at the memory locations specified by :code:`pointer`. + :param pointer: The memory locations which contain the old values + :type pointer: Block of dtype=triton.PointerDType + :param val: The values to add + :type val: Block of dtype=`pointer.dtype.element_ty` + :param mask: If mask[idx] is false, :code:`pointer[idx]` is unaffected. + :type mask: Block of triton.int1, optional + """ + return frontend.atomic_add(pointer, val, mask, builder) + + # ----------------------- # Conditioning # ----------------------- @@ -475,6 +512,17 @@ def log(x, builder=None): return frontend.log(x, builder) +@builtin +def sqrt(x, builder=None): + """ + Computes the element-wise square root of :code:`x` + + :param x: the input values + :type x: Block + """ + return frontend.sqrt(x, builder) + + # ----------------------- # Reductions # -----------------------