From 189491727ae3c0c35ca719f7bc9a8acfde868c64 Mon Sep 17 00:00:00 2001 From: Crutcher Dunnavant Date: Mon, 5 Dec 2022 14:59:41 -0800 Subject: [PATCH] [FRONTEND] Extract and unify @builtin/@extern (#913) This change attaches builtin-ness as an explicit attribute, rather than a module prefix expectation. This permits us to source those builtins from multiple sub-modules (useful when some builtins are part of the true cyclic implementation core, and some are just useful library additions); but also prevents accidental inclusion of non-builtins that happen to be in the right library. Once the flag exists, and the compiler is using `is_builtin()` for decision making; the existence of the current `@extern` interface becomes isomorphic to `@builtin`; and the interface can be unified. Leaving `@extern` a thin-wrapper, and encouraging continued use of it, establishes future-proofing towards adding additional extern tracing, metric hooks, or scanning in the future. * Add `triton.impl` package to hold the core, order dependent impl details. * Extract `@builtin` and unify `@extern`; add `is_builtin()` * Add sense bit so that `@builtin` detection is less fragile. * Modify the compiler to use `is_builtin()` --- python/triton/__init__.py | 2 + python/triton/compiler.py | 7 +- python/triton/impl/__init__.py | 22 ++ python/triton/impl/base.py | 36 +++ python/triton/language/__init__.py | 6 +- python/triton/language/core.py | 14 +- python/triton/language/extern.py | 22 -- python/triton/language/libdevice.py | 399 ++++++++++++++-------------- 8 files changed, 269 insertions(+), 239 deletions(-) create mode 100644 python/triton/impl/__init__.py create mode 100644 python/triton/impl/base.py diff --git a/python/triton/__init__.py b/python/triton/__init__.py index 426a7e40b..9b43de73d 100644 --- a/python/triton/__init__.py +++ b/python/triton/__init__.py @@ -9,6 +9,7 @@ __version__ = '2.0.0' import torch # noqa: F401 # submodules +from . import impl from .utils import ( cdiv, MockTensor, @@ -36,6 +37,7 @@ __all__ = [ "compile", "Config", "heuristics", + "impl", "jit", "JITFunction", "KernelInterface", diff --git a/python/triton/compiler.py b/python/triton/compiler.py index ad238cdc0..86b655f21 100644 --- a/python/triton/compiler.py +++ b/python/triton/compiler.py @@ -25,6 +25,8 @@ from filelock import FileLock import triton import triton._C.libtriton.triton as _triton + +from . import impl from .tools.disasm import extract @@ -715,9 +717,8 @@ class CodeGenerator(ast.NodeVisitor): for i in range(call_op.get_num_results()): results.append(triton.language.tensor(call_op.get_result(i), callee_ret_type[i])) return tuple(results) - if hasattr(fn, '__self__') and self.is_triton_tensor(fn.__self__) or \ - sys.modules[fn.__module__] is triton.language.core or \ - isinstance(fn, triton.language.extern.ExternalFunction): + if (hasattr(fn, '__self__') and self.is_triton_tensor(fn.__self__)) \ + or impl.is_builtin(fn): return fn(*args, _builder=self.builder, **kws) if fn in self.builtins.values(): args = [arg.value if isinstance(arg, triton.language.constexpr) else arg diff --git a/python/triton/impl/__init__.py b/python/triton/impl/__init__.py new file mode 100644 index 000000000..d0221991b --- /dev/null +++ b/python/triton/impl/__init__.py @@ -0,0 +1,22 @@ +"""Triton internal implementation details. + +Client libraries should not import interfaces from the `triton.impl` module; +as the details are subject to change. + +APIs defined in the `triton.impl` module which are public will be re-exported +in other relevant `triton` module namespaces. +""" + +from triton._C.libtriton.triton import ir +from .base import ( + builtin, + extern, + is_builtin, +) + +__all__ = [ + "builtin", + "extern", + "ir", + "is_builtin", +] diff --git a/python/triton/impl/base.py b/python/triton/impl/base.py new file mode 100644 index 000000000..24048c56d --- /dev/null +++ b/python/triton/impl/base.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +from functools import wraps +from typing import TypeVar + +T = TypeVar("T") + +TRITON_BUILTIN = "__triton_builtin__" + + +def builtin(fn: T) -> T: + """Mark a function as a builtin.""" + assert callable(fn) + + @wraps(fn) + def wrapper(*args, **kwargs): + if "_builder" not in kwargs or kwargs["_builder"] is None: + raise ValueError( + "Did you forget to add @triton.jit ? " + "(`_builder` argument must be provided outside of JIT functions.)" + ) + return fn(*args, **kwargs) + + setattr(wrapper, TRITON_BUILTIN, True) + + return wrapper + + +def is_builtin(fn) -> bool: + """Is this a registered triton builtin function?""" + return getattr(fn, TRITON_BUILTIN, False) + + +def extern(fn: T) -> T: + """A decorator for external functions.""" + return builtin(fn) diff --git a/python/triton/language/__init__.py b/python/triton/language/__init__.py index f9acbd3dc..4b7df9515 100644 --- a/python/triton/language/__init__.py +++ b/python/triton/language/__init__.py @@ -1,8 +1,10 @@ """isort:skip_file""" # Import order is significant here. -from triton._C.libtriton.triton import ir - +from ..impl import ( + ir, + builtin, +) from . import core, extern, libdevice, random from .core import ( abs, diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 7fb2cf2ae..d208cc45e 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -1,11 +1,10 @@ from __future__ import annotations from enum import Enum -from functools import wraps from typing import List, Callable, TypeVar import triton -from . import semantic +from . import builtin, semantic from triton._C.libtriton.triton import ir T = TypeVar('T') @@ -34,17 +33,6 @@ def _to_tensor(x, builder): assert False, f'cannot convert {x} to tensor' -def builtin(fn: T) -> T: - @wraps(fn) - def wrapper(*args, **kwargs): - if '_builder' not in kwargs or \ - kwargs['_builder'] is None: - raise ValueError("Did you forget to add @triton.jit ? (`_builder` argument must be provided outside of JIT functions.)") - return fn(*args, **kwargs) - - return wrapper - - class dtype: SINT_TYPES = ['int1', 'int8', 'int16', 'int32', 'int64'] UINT_TYPES = ['uint8', 'uint16', 'uint32', 'uint64'] diff --git a/python/triton/language/extern.py b/python/triton/language/extern.py index 3bb457fb8..400ba6645 100644 --- a/python/triton/language/extern.py +++ b/python/triton/language/extern.py @@ -86,25 +86,3 @@ def elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: ret_shape = broadcast_arg.shape func = getattr(_builder, "create_external_elementwise") return dispatch(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, ret_shape, _builder) - - -class ExternalFunction: - ''' - A wrapper for external functions - ''' - - def __init__(self, fn): - self.fn = fn - - def __call__(self, *args, **kwargs): - if '_builder' not in kwargs or \ - kwargs['_builder'] is None: - raise ValueError("Did you forget to add @triton.jit ? (`_builder` argument must be provided outside of JIT functions.)") - return self.fn(*args, **kwargs) - - -def extern(fn): - ''' - A decorator for external functions - ''' - return ExternalFunction(fn) diff --git a/python/triton/language/libdevice.py b/python/triton/language/libdevice.py index cae14a797..f42705dbc 100644 --- a/python/triton/language/libdevice.py +++ b/python/triton/language/libdevice.py @@ -1,12 +1,13 @@ import os +from .. import impl from . import core, extern LIBDEVICE_PATH = os.path.dirname( os.path.abspath(__file__)) + "/libdevice.10.bc" -@extern.extern +@impl.extern def clz(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("int32"),): ("__nv_clz", core.dtype("int32")), @@ -14,7 +15,7 @@ def clz(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def popc(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("int32"),): ("__nv_popc", core.dtype("int32")), @@ -22,14 +23,14 @@ def popc(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def byte_perm(arg0, arg1, arg2, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], {(core.dtype("int32"), core.dtype("int32"), core.dtype("int32"),): ("__nv_byte_perm", core.dtype("int32")), }, _builder) -@extern.extern +@impl.extern def min(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.dtype("int32"), core.dtype("int32"),): ("__nv_min", core.dtype("int32")), @@ -41,7 +42,7 @@ def min(arg0, arg1, _builder=None): }, _builder) -@extern.extern +@impl.extern def max(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.dtype("int32"), core.dtype("int32"),): ("__nv_max", core.dtype("int32")), @@ -53,7 +54,7 @@ def max(arg0, arg1, _builder=None): }, _builder) -@extern.extern +@impl.extern def mulhi(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.dtype("int32"), core.dtype("int32"),): ("__nv_mulhi", core.dtype("int32")), @@ -63,7 +64,7 @@ def mulhi(arg0, arg1, _builder=None): }, _builder) -@extern.extern +@impl.extern def mul24(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.dtype("int32"), core.dtype("int32"),): ("__nv_mul24", core.dtype("int32")), @@ -71,7 +72,7 @@ def mul24(arg0, arg1, _builder=None): }, _builder) -@extern.extern +@impl.extern def brev(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("int32"),): ("__nv_brev", core.dtype("int32")), @@ -79,7 +80,7 @@ def brev(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def sad(arg0, arg1, arg2, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], {(core.dtype("int32"), core.dtype("int32"), core.dtype("uint32"),): ("__nv_sad", core.dtype("int32")), @@ -87,7 +88,7 @@ def sad(arg0, arg1, arg2, _builder=None): }, _builder) -@extern.extern +@impl.extern def abs(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("int32"),): ("__nv_abs", core.dtype("int32")), @@ -97,7 +98,7 @@ def abs(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def floor(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_floorf", core.dtype("fp32")), @@ -105,14 +106,14 @@ def floor(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def rcp64h(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp64"),): ("__nv_rcp64h", core.dtype("fp64")), }, _builder) -@extern.extern +@impl.extern def rsqrt(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_rsqrtf", core.dtype("fp32")), @@ -120,7 +121,7 @@ def rsqrt(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def ceil(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp64"),): ("__nv_ceil", core.dtype("fp64")), @@ -128,7 +129,7 @@ def ceil(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def trunc(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp64"),): ("__nv_trunc", core.dtype("fp64")), @@ -136,7 +137,7 @@ def trunc(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def exp2(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_exp2f", core.dtype("fp32")), @@ -144,14 +145,14 @@ def exp2(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def saturatef(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_saturatef", core.dtype("fp32")), }, _builder) -@extern.extern +@impl.extern def fma_rn(arg0, arg1, arg2, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_rn", core.dtype("fp32")), @@ -159,7 +160,7 @@ def fma_rn(arg0, arg1, arg2, _builder=None): }, _builder) -@extern.extern +@impl.extern def fma_rz(arg0, arg1, arg2, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_rz", core.dtype("fp32")), @@ -167,7 +168,7 @@ def fma_rz(arg0, arg1, arg2, _builder=None): }, _builder) -@extern.extern +@impl.extern def fma_rd(arg0, arg1, arg2, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_rd", core.dtype("fp32")), @@ -175,7 +176,7 @@ def fma_rd(arg0, arg1, arg2, _builder=None): }, _builder) -@extern.extern +@impl.extern def fma_ru(arg0, arg1, arg2, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_ru", core.dtype("fp32")), @@ -183,14 +184,14 @@ def fma_ru(arg0, arg1, arg2, _builder=None): }, _builder) -@extern.extern +@impl.extern def fast_dividef(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fast_fdividef", core.dtype("fp32")), }, _builder) -@extern.extern +@impl.extern def div_rn(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fdiv_rn", core.dtype("fp32")), @@ -198,7 +199,7 @@ def div_rn(arg0, arg1, _builder=None): }, _builder) -@extern.extern +@impl.extern def div_rz(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fdiv_rz", core.dtype("fp32")), @@ -206,7 +207,7 @@ def div_rz(arg0, arg1, _builder=None): }, _builder) -@extern.extern +@impl.extern def div_rd(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fdiv_rd", core.dtype("fp32")), @@ -214,7 +215,7 @@ def div_rd(arg0, arg1, _builder=None): }, _builder) -@extern.extern +@impl.extern def div_ru(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fdiv_ru", core.dtype("fp32")), @@ -222,7 +223,7 @@ def div_ru(arg0, arg1, _builder=None): }, _builder) -@extern.extern +@impl.extern def rcp_rn(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_frcp_rn", core.dtype("fp32")), @@ -230,7 +231,7 @@ def rcp_rn(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def rcp_rz(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_frcp_rz", core.dtype("fp32")), @@ -238,7 +239,7 @@ def rcp_rz(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def rcp_rd(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_frcp_rd", core.dtype("fp32")), @@ -246,7 +247,7 @@ def rcp_rd(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def rcp_ru(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_frcp_ru", core.dtype("fp32")), @@ -254,7 +255,7 @@ def rcp_ru(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def sqrt_rn(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_fsqrt_rn", core.dtype("fp32")), @@ -262,7 +263,7 @@ def sqrt_rn(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def sqrt_rz(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_fsqrt_rz", core.dtype("fp32")), @@ -270,7 +271,7 @@ def sqrt_rz(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def sqrt_rd(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_fsqrt_rd", core.dtype("fp32")), @@ -278,7 +279,7 @@ def sqrt_rd(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def sqrt_ru(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_fsqrt_ru", core.dtype("fp32")), @@ -286,7 +287,7 @@ def sqrt_ru(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def sqrt(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_sqrtf", core.dtype("fp32")), @@ -294,7 +295,7 @@ def sqrt(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def add_rn(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dadd_rn", core.dtype("fp64")), @@ -302,7 +303,7 @@ def add_rn(arg0, arg1, _builder=None): }, _builder) -@extern.extern +@impl.extern def add_rz(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dadd_rz", core.dtype("fp64")), @@ -310,7 +311,7 @@ def add_rz(arg0, arg1, _builder=None): }, _builder) -@extern.extern +@impl.extern def add_rd(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dadd_rd", core.dtype("fp64")), @@ -318,7 +319,7 @@ def add_rd(arg0, arg1, _builder=None): }, _builder) -@extern.extern +@impl.extern def add_ru(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dadd_ru", core.dtype("fp64")), @@ -326,7 +327,7 @@ def add_ru(arg0, arg1, _builder=None): }, _builder) -@extern.extern +@impl.extern def mul_rn(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dmul_rn", core.dtype("fp64")), @@ -334,7 +335,7 @@ def mul_rn(arg0, arg1, _builder=None): }, _builder) -@extern.extern +@impl.extern def mul_rz(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dmul_rz", core.dtype("fp64")), @@ -342,7 +343,7 @@ def mul_rz(arg0, arg1, _builder=None): }, _builder) -@extern.extern +@impl.extern def mul_rd(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dmul_rd", core.dtype("fp64")), @@ -350,7 +351,7 @@ def mul_rd(arg0, arg1, _builder=None): }, _builder) -@extern.extern +@impl.extern def mul_ru(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dmul_ru", core.dtype("fp64")), @@ -358,567 +359,567 @@ def mul_ru(arg0, arg1, _builder=None): }, _builder) -@extern.extern +@impl.extern def double2float_rn(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp64"),): ("__nv_double2float_rn", core.dtype("fp32")), }, _builder) -@extern.extern +@impl.extern def double2float_rz(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp64"),): ("__nv_double2float_rz", core.dtype("fp32")), }, _builder) -@extern.extern +@impl.extern def double2float_rd(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp64"),): ("__nv_double2float_rd", core.dtype("fp32")), }, _builder) -@extern.extern +@impl.extern def double2float_ru(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp64"),): ("__nv_double2float_ru", core.dtype("fp32")), }, _builder) -@extern.extern +@impl.extern def double2int_rn(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp64"),): ("__nv_double2int_rn", core.dtype("int32")), }, _builder) -@extern.extern +@impl.extern def double2int_rz(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp64"),): ("__nv_double2int_rz", core.dtype("int32")), }, _builder) -@extern.extern +@impl.extern def double2int_rd(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp64"),): ("__nv_double2int_rd", core.dtype("int32")), }, _builder) -@extern.extern +@impl.extern def double2int_ru(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp64"),): ("__nv_double2int_ru", core.dtype("int32")), }, _builder) -@extern.extern +@impl.extern def double2uint_rn(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp64"),): ("__nv_double2uint_rn", core.dtype("int32")), }, _builder) -@extern.extern +@impl.extern def double2uint_rz(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp64"),): ("__nv_double2uint_rz", core.dtype("int32")), }, _builder) -@extern.extern +@impl.extern def double2uint_rd(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp64"),): ("__nv_double2uint_rd", core.dtype("int32")), }, _builder) -@extern.extern +@impl.extern def double2uint_ru(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp64"),): ("__nv_double2uint_ru", core.dtype("int32")), }, _builder) -@extern.extern +@impl.extern def int2double_rn(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("int32"),): ("__nv_int2double_rn", core.dtype("fp64")), }, _builder) -@extern.extern +@impl.extern def uint2double_rn(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("uint32"),): ("__nv_uint2double_rn", core.dtype("fp64")), }, _builder) -@extern.extern +@impl.extern def float2int_rn(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_float2int_rn", core.dtype("int32")), }, _builder) -@extern.extern +@impl.extern def float2int_rz(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_float2int_rz", core.dtype("int32")), }, _builder) -@extern.extern +@impl.extern def float2int_rd(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_float2int_rd", core.dtype("int32")), }, _builder) -@extern.extern +@impl.extern def float2int_ru(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_float2int_ru", core.dtype("int32")), }, _builder) -@extern.extern +@impl.extern def float2uint_rn(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_float2uint_rn", core.dtype("int32")), }, _builder) -@extern.extern +@impl.extern def float2uint_rz(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_float2uint_rz", core.dtype("int32")), }, _builder) -@extern.extern +@impl.extern def float2uint_rd(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_float2uint_rd", core.dtype("int32")), }, _builder) -@extern.extern +@impl.extern def float2uint_ru(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_float2uint_ru", core.dtype("int32")), }, _builder) -@extern.extern +@impl.extern def int2float_rn(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("int32"),): ("__nv_int2float_rn", core.dtype("fp32")), }, _builder) -@extern.extern +@impl.extern def int2float_rz(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("int32"),): ("__nv_int2float_rz", core.dtype("fp32")), }, _builder) -@extern.extern +@impl.extern def int2float_rd(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("int32"),): ("__nv_int2float_rd", core.dtype("fp32")), }, _builder) -@extern.extern +@impl.extern def int2float_ru(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("int32"),): ("__nv_int2float_ru", core.dtype("fp32")), }, _builder) -@extern.extern +@impl.extern def uint2float_rn(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("uint32"),): ("__nv_uint2float_rn", core.dtype("fp32")), }, _builder) -@extern.extern +@impl.extern def uint2float_rz(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("uint32"),): ("__nv_uint2float_rz", core.dtype("fp32")), }, _builder) -@extern.extern +@impl.extern def uint2float_rd(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("uint32"),): ("__nv_uint2float_rd", core.dtype("fp32")), }, _builder) -@extern.extern +@impl.extern def uint2float_ru(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("uint32"),): ("__nv_uint2float_ru", core.dtype("fp32")), }, _builder) -@extern.extern +@impl.extern def hiloint2double(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.dtype("int32"), core.dtype("int32"),): ("__nv_hiloint2double", core.dtype("fp64")), }, _builder) -@extern.extern +@impl.extern def double2loint(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp64"),): ("__nv_double2loint", core.dtype("int32")), }, _builder) -@extern.extern +@impl.extern def double2hiint(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp64"),): ("__nv_double2hiint", core.dtype("int32")), }, _builder) -@extern.extern +@impl.extern def float2ll_rn(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_float2ll_rn", core.dtype("int64")), }, _builder) -@extern.extern +@impl.extern def float2ll_rz(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_float2ll_rz", core.dtype("int64")), }, _builder) -@extern.extern +@impl.extern def float2ll_rd(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_float2ll_rd", core.dtype("int64")), }, _builder) -@extern.extern +@impl.extern def float2ll_ru(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_float2ll_ru", core.dtype("int64")), }, _builder) -@extern.extern +@impl.extern def float2ull_rn(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_float2ull_rn", core.dtype("int64")), }, _builder) -@extern.extern +@impl.extern def float2ull_rz(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_float2ull_rz", core.dtype("int64")), }, _builder) -@extern.extern +@impl.extern def float2ull_rd(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_float2ull_rd", core.dtype("int64")), }, _builder) -@extern.extern +@impl.extern def float2ull_ru(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_float2ull_ru", core.dtype("int64")), }, _builder) -@extern.extern +@impl.extern def double2ll_rn(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp64"),): ("__nv_double2ll_rn", core.dtype("int64")), }, _builder) -@extern.extern +@impl.extern def double2ll_rz(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp64"),): ("__nv_double2ll_rz", core.dtype("int64")), }, _builder) -@extern.extern +@impl.extern def double2ll_rd(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp64"),): ("__nv_double2ll_rd", core.dtype("int64")), }, _builder) -@extern.extern +@impl.extern def double2ll_ru(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp64"),): ("__nv_double2ll_ru", core.dtype("int64")), }, _builder) -@extern.extern +@impl.extern def double2ull_rn(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp64"),): ("__nv_double2ull_rn", core.dtype("int64")), }, _builder) -@extern.extern +@impl.extern def double2ull_rz(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp64"),): ("__nv_double2ull_rz", core.dtype("int64")), }, _builder) -@extern.extern +@impl.extern def double2ull_rd(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp64"),): ("__nv_double2ull_rd", core.dtype("int64")), }, _builder) -@extern.extern +@impl.extern def double2ull_ru(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp64"),): ("__nv_double2ull_ru", core.dtype("int64")), }, _builder) -@extern.extern +@impl.extern def ll2float_rn(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("int64"),): ("__nv_ll2float_rn", core.dtype("fp32")), }, _builder) -@extern.extern +@impl.extern def ll2float_rz(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("int64"),): ("__nv_ll2float_rz", core.dtype("fp32")), }, _builder) -@extern.extern +@impl.extern def ll2float_rd(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("int64"),): ("__nv_ll2float_rd", core.dtype("fp32")), }, _builder) -@extern.extern +@impl.extern def ll2float_ru(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("int64"),): ("__nv_ll2float_ru", core.dtype("fp32")), }, _builder) -@extern.extern +@impl.extern def ull2float_rn(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("uint64"),): ("__nv_ull2float_rn", core.dtype("fp32")), }, _builder) -@extern.extern +@impl.extern def ull2float_rz(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("uint64"),): ("__nv_ull2float_rz", core.dtype("fp32")), }, _builder) -@extern.extern +@impl.extern def ull2float_rd(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("uint64"),): ("__nv_ull2float_rd", core.dtype("fp32")), }, _builder) -@extern.extern +@impl.extern def ull2float_ru(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("uint64"),): ("__nv_ull2float_ru", core.dtype("fp32")), }, _builder) -@extern.extern +@impl.extern def ll2double_rn(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("int64"),): ("__nv_ll2double_rn", core.dtype("fp64")), }, _builder) -@extern.extern +@impl.extern def ll2double_rz(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("int64"),): ("__nv_ll2double_rz", core.dtype("fp64")), }, _builder) -@extern.extern +@impl.extern def ll2double_rd(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("int64"),): ("__nv_ll2double_rd", core.dtype("fp64")), }, _builder) -@extern.extern +@impl.extern def ll2double_ru(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("int64"),): ("__nv_ll2double_ru", core.dtype("fp64")), }, _builder) -@extern.extern +@impl.extern def ull2double_rn(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("uint64"),): ("__nv_ull2double_rn", core.dtype("fp64")), }, _builder) -@extern.extern +@impl.extern def ull2double_rz(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("uint64"),): ("__nv_ull2double_rz", core.dtype("fp64")), }, _builder) -@extern.extern +@impl.extern def ull2double_rd(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("uint64"),): ("__nv_ull2double_rd", core.dtype("fp64")), }, _builder) -@extern.extern +@impl.extern def ull2double_ru(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("uint64"),): ("__nv_ull2double_ru", core.dtype("fp64")), }, _builder) -@extern.extern +@impl.extern def int_as_float(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("int32"),): ("__nv_int_as_float", core.dtype("fp32")), }, _builder) -@extern.extern +@impl.extern def float_as_int(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_float_as_int", core.dtype("int32")), }, _builder) -@extern.extern +@impl.extern def uint_as_float(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("uint32"),): ("__nv_uint_as_float", core.dtype("fp32")), }, _builder) -@extern.extern +@impl.extern def float_as_uint(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_float_as_uint", core.dtype("int32")), }, _builder) -@extern.extern +@impl.extern def longlong_as_double(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("int64"),): ("__nv_longlong_as_double", core.dtype("fp64")), }, _builder) -@extern.extern +@impl.extern def double_as_longlong(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp64"),): ("__nv_double_as_longlong", core.dtype("int64")), }, _builder) -@extern.extern +@impl.extern def fast_sinf(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_fast_sinf", core.dtype("fp32")), }, _builder) -@extern.extern +@impl.extern def fast_cosf(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_fast_cosf", core.dtype("fp32")), }, _builder) -@extern.extern +@impl.extern def fast_log2f(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_fast_log2f", core.dtype("fp32")), }, _builder) -@extern.extern +@impl.extern def fast_logf(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_fast_logf", core.dtype("fp32")), }, _builder) -@extern.extern +@impl.extern def fast_expf(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_fast_expf", core.dtype("fp32")), }, _builder) -@extern.extern +@impl.extern def fast_tanf(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_fast_tanf", core.dtype("fp32")), }, _builder) -@extern.extern +@impl.extern def fast_exp10f(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_fast_exp10f", core.dtype("fp32")), }, _builder) -@extern.extern +@impl.extern def fast_log10f(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_fast_log10f", core.dtype("fp32")), }, _builder) -@extern.extern +@impl.extern def fast_powf(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fast_powf", core.dtype("fp32")), }, _builder) -@extern.extern +@impl.extern def hadd(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.dtype("int32"), core.dtype("int32"),): ("__nv_hadd", core.dtype("int32")), @@ -926,7 +927,7 @@ def hadd(arg0, arg1, _builder=None): }, _builder) -@extern.extern +@impl.extern def rhadd(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.dtype("int32"), core.dtype("int32"),): ("__nv_rhadd", core.dtype("int32")), @@ -934,7 +935,7 @@ def rhadd(arg0, arg1, _builder=None): }, _builder) -@extern.extern +@impl.extern def sub_rn(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fsub_rn", core.dtype("fp32")), @@ -942,7 +943,7 @@ def sub_rn(arg0, arg1, _builder=None): }, _builder) -@extern.extern +@impl.extern def sub_rz(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fsub_rz", core.dtype("fp32")), @@ -950,7 +951,7 @@ def sub_rz(arg0, arg1, _builder=None): }, _builder) -@extern.extern +@impl.extern def sub_rd(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fsub_rd", core.dtype("fp32")), @@ -958,7 +959,7 @@ def sub_rd(arg0, arg1, _builder=None): }, _builder) -@extern.extern +@impl.extern def sub_ru(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fsub_ru", core.dtype("fp32")), @@ -966,14 +967,14 @@ def sub_ru(arg0, arg1, _builder=None): }, _builder) -@extern.extern +@impl.extern def rsqrt_rn(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_frsqrt_rn", core.dtype("fp32")), }, _builder) -@extern.extern +@impl.extern def ffs(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("int32"),): ("__nv_ffs", core.dtype("int32")), @@ -981,7 +982,7 @@ def ffs(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def rint(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_rintf", core.dtype("fp32")), @@ -989,7 +990,7 @@ def rint(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def llrint(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_llrintf", core.dtype("int64")), @@ -997,7 +998,7 @@ def llrint(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def nearbyint(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_nearbyintf", core.dtype("fp32")), @@ -1005,7 +1006,7 @@ def nearbyint(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def isnan(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_isnanf", core.dtype("int32")), @@ -1013,7 +1014,7 @@ def isnan(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def signbit(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_signbitf", core.dtype("int32")), @@ -1021,7 +1022,7 @@ def signbit(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def copysign(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_copysignf", core.dtype("fp32")), @@ -1029,14 +1030,14 @@ def copysign(arg0, arg1, _builder=None): }, _builder) -@extern.extern +@impl.extern def finitef(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_finitef", core.dtype("int32")), }, _builder) -@extern.extern +@impl.extern def isinf(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_isinff", core.dtype("int32")), @@ -1044,7 +1045,7 @@ def isinf(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def nextafter(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_nextafterf", core.dtype("fp32")), @@ -1052,7 +1053,7 @@ def nextafter(arg0, arg1, _builder=None): }, _builder) -@extern.extern +@impl.extern def sin(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_sinf", core.dtype("fp32")), @@ -1060,7 +1061,7 @@ def sin(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def cos(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_cosf", core.dtype("fp32")), @@ -1068,7 +1069,7 @@ def cos(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def sinpi(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_sinpif", core.dtype("fp32")), @@ -1076,7 +1077,7 @@ def sinpi(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def cospi(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_cospif", core.dtype("fp32")), @@ -1084,7 +1085,7 @@ def cospi(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def tan(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_tanf", core.dtype("fp32")), @@ -1092,7 +1093,7 @@ def tan(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def log2(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_log2f", core.dtype("fp32")), @@ -1100,7 +1101,7 @@ def log2(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def exp(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_expf", core.dtype("fp32")), @@ -1108,7 +1109,7 @@ def exp(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def exp10(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_exp10f", core.dtype("fp32")), @@ -1116,7 +1117,7 @@ def exp10(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def cosh(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_coshf", core.dtype("fp32")), @@ -1124,7 +1125,7 @@ def cosh(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def sinh(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_sinhf", core.dtype("fp32")), @@ -1132,7 +1133,7 @@ def sinh(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def tanh(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_tanhf", core.dtype("fp32")), @@ -1140,7 +1141,7 @@ def tanh(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def atan2(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_atan2f", core.dtype("fp32")), @@ -1148,7 +1149,7 @@ def atan2(arg0, arg1, _builder=None): }, _builder) -@extern.extern +@impl.extern def atan(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_atanf", core.dtype("fp32")), @@ -1156,7 +1157,7 @@ def atan(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def asin(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_asinf", core.dtype("fp32")), @@ -1164,7 +1165,7 @@ def asin(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def acos(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_acosf", core.dtype("fp32")), @@ -1172,7 +1173,7 @@ def acos(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def log(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_logf", core.dtype("fp32")), @@ -1180,7 +1181,7 @@ def log(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def log10(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_log10f", core.dtype("fp32")), @@ -1188,7 +1189,7 @@ def log10(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def log1p(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_log1pf", core.dtype("fp32")), @@ -1196,7 +1197,7 @@ def log1p(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def acosh(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_acoshf", core.dtype("fp32")), @@ -1204,7 +1205,7 @@ def acosh(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def asinh(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_asinhf", core.dtype("fp32")), @@ -1212,7 +1213,7 @@ def asinh(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def atanh(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_atanhf", core.dtype("fp32")), @@ -1220,7 +1221,7 @@ def atanh(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def expm1(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_expm1f", core.dtype("fp32")), @@ -1228,7 +1229,7 @@ def expm1(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def hypot(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_hypotf", core.dtype("fp32")), @@ -1236,7 +1237,7 @@ def hypot(arg0, arg1, _builder=None): }, _builder) -@extern.extern +@impl.extern def rhypot(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_rhypotf", core.dtype("fp32")), @@ -1244,7 +1245,7 @@ def rhypot(arg0, arg1, _builder=None): }, _builder) -@extern.extern +@impl.extern def norm3d(arg0, arg1, arg2, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_norm3df", core.dtype("fp32")), @@ -1252,7 +1253,7 @@ def norm3d(arg0, arg1, arg2, _builder=None): }, _builder) -@extern.extern +@impl.extern def rnorm3d(arg0, arg1, arg2, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_rnorm3df", core.dtype("fp32")), @@ -1260,7 +1261,7 @@ def rnorm3d(arg0, arg1, arg2, _builder=None): }, _builder) -@extern.extern +@impl.extern def norm4d(arg0, arg1, arg2, arg3, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, arg3, ], {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_norm4df", core.dtype("fp32")), @@ -1268,7 +1269,7 @@ def norm4d(arg0, arg1, arg2, arg3, _builder=None): }, _builder) -@extern.extern +@impl.extern def rnorm4d(arg0, arg1, arg2, arg3, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, arg3, ], {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_rnorm4df", core.dtype("fp32")), @@ -1276,7 +1277,7 @@ def rnorm4d(arg0, arg1, arg2, arg3, _builder=None): }, _builder) -@extern.extern +@impl.extern def cbrt(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_cbrtf", core.dtype("fp32")), @@ -1284,7 +1285,7 @@ def cbrt(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def rcbrt(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_rcbrtf", core.dtype("fp32")), @@ -1292,7 +1293,7 @@ def rcbrt(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def j0(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_j0f", core.dtype("fp32")), @@ -1300,7 +1301,7 @@ def j0(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def j1(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_j1f", core.dtype("fp32")), @@ -1308,7 +1309,7 @@ def j1(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def y0(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_y0f", core.dtype("fp32")), @@ -1316,7 +1317,7 @@ def y0(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def y1(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_y1f", core.dtype("fp32")), @@ -1324,7 +1325,7 @@ def y1(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def yn(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.dtype("int32"), core.dtype("fp32"),): ("__nv_ynf", core.dtype("fp32")), @@ -1332,7 +1333,7 @@ def yn(arg0, arg1, _builder=None): }, _builder) -@extern.extern +@impl.extern def jn(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.dtype("int32"), core.dtype("fp32"),): ("__nv_jnf", core.dtype("fp32")), @@ -1340,7 +1341,7 @@ def jn(arg0, arg1, _builder=None): }, _builder) -@extern.extern +@impl.extern def cyl_bessel_i0(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_cyl_bessel_i0f", core.dtype("fp32")), @@ -1348,7 +1349,7 @@ def cyl_bessel_i0(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def cyl_bessel_i1(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_cyl_bessel_i1f", core.dtype("fp32")), @@ -1356,7 +1357,7 @@ def cyl_bessel_i1(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def erf(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_erff", core.dtype("fp32")), @@ -1364,7 +1365,7 @@ def erf(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def erfinv(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_erfinvf", core.dtype("fp32")), @@ -1372,7 +1373,7 @@ def erfinv(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def erfc(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_erfcf", core.dtype("fp32")), @@ -1380,7 +1381,7 @@ def erfc(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def erfcx(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_erfcxf", core.dtype("fp32")), @@ -1388,7 +1389,7 @@ def erfcx(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def erfcinv(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_erfcinvf", core.dtype("fp32")), @@ -1396,7 +1397,7 @@ def erfcinv(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def normcdfinv(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_normcdfinvf", core.dtype("fp32")), @@ -1404,7 +1405,7 @@ def normcdfinv(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def normcdf(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_normcdff", core.dtype("fp32")), @@ -1412,7 +1413,7 @@ def normcdf(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def lgamma(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_lgammaf", core.dtype("fp32")), @@ -1420,7 +1421,7 @@ def lgamma(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def ldexp(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.dtype("fp32"), core.dtype("int32"),): ("__nv_ldexpf", core.dtype("fp32")), @@ -1428,7 +1429,7 @@ def ldexp(arg0, arg1, _builder=None): }, _builder) -@extern.extern +@impl.extern def scalbn(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.dtype("fp32"), core.dtype("int32"),): ("__nv_scalbnf", core.dtype("fp32")), @@ -1436,7 +1437,7 @@ def scalbn(arg0, arg1, _builder=None): }, _builder) -@extern.extern +@impl.extern def fmod(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmodf", core.dtype("fp32")), @@ -1444,7 +1445,7 @@ def fmod(arg0, arg1, _builder=None): }, _builder) -@extern.extern +@impl.extern def remainder(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_remainderf", core.dtype("fp32")), @@ -1452,7 +1453,7 @@ def remainder(arg0, arg1, _builder=None): }, _builder) -@extern.extern +@impl.extern def fma(arg0, arg1, arg2, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf", core.dtype("fp32")), @@ -1460,7 +1461,7 @@ def fma(arg0, arg1, arg2, _builder=None): }, _builder) -@extern.extern +@impl.extern def pow(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.dtype("fp32"), core.dtype("int32"),): ("__nv_powif", core.dtype("fp32")), @@ -1470,7 +1471,7 @@ def pow(arg0, arg1, _builder=None): }, _builder) -@extern.extern +@impl.extern def tgamma(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_tgammaf", core.dtype("fp32")), @@ -1478,7 +1479,7 @@ def tgamma(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def round(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_roundf", core.dtype("fp32")), @@ -1486,7 +1487,7 @@ def round(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def llround(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_llroundf", core.dtype("int64")), @@ -1494,7 +1495,7 @@ def llround(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def fdim(arg0, arg1, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fdimf", core.dtype("fp32")), @@ -1502,7 +1503,7 @@ def fdim(arg0, arg1, _builder=None): }, _builder) -@extern.extern +@impl.extern def ilogb(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_ilogbf", core.dtype("int32")), @@ -1510,7 +1511,7 @@ def ilogb(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def logb(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp32"),): ("__nv_logbf", core.dtype("fp32")), @@ -1518,7 +1519,7 @@ def logb(arg0, _builder=None): }, _builder) -@extern.extern +@impl.extern def isfinited(arg0, _builder=None): return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], {(core.dtype("fp64"),): ("__nv_isfinited", core.dtype("int32")),