From 3c635449e51008ad32e262602298d1c89d9cf7d6 Mon Sep 17 00:00:00 2001 From: Shintaro Iwasaki Date: Thu, 1 Sep 2022 16:34:27 -0700 Subject: [PATCH] [Triton] Support math and libdevice ops (#91) This PR adds basic math ops by using `MathDialect` and `libdevice` ops by using `extern_elementwise`. This is needed to compile some tutorial code (e.g., `softmax`). This PR implements only interface till PTX (so from frontend to TritonGPU-MLIR) - Currently till TritonGPU. It cannot be lowered to PTX now. - No special optimizations (e.g., constant folding etc) are applied. - 14.x does not define folders for many operators for math ops, but 15.x seems to increase its coverage: https://github.com/llvm/llvm-project/blob/llvmorg-15.0.0-rc3/mlir/include/mlir/Dialect/Math/IR/MathOps.td - No constant folding etc for `libdevice` ops. ```py import triton import triton.language as tl import sys @triton.jit def add_kernel( x_ptr, y_ptr, BLOCK_SIZE: tl.constexpr, ): offsets = tl.arange(0, BLOCK_SIZE) x = tl.load(x_ptr + offsets) x = tl.sin(x) output = tl.libdevice.sin(x) output = tl.libdevice.fdiv_rn(output, output) output = tl.libdevice.fmaf_rd(output, output, output) tl.store(y_ptr + offsets, output) if __name__ == "__main__" and len(sys.argv) >= 2: signature = "*fp32,*fp32" constants = {'BLOCK_SIZE': 1024} output = triton.compile(add_kernel, signature, device=0, constants=constants, output="ttgir") print(output) ``` -> ```llvm #blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> module attributes {"triton_gpu.num-warps" = 4 : i32} { func @add_kernel__Pfp32_Pfp32__2c1024(%arg0: !tt.ptr, %arg1: !tt.ptr) { %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> %1 = tt.splat %arg0 : (!tt.ptr) -> tensor<1024x!tt.ptr, #blocked> %2 = tt.getelementptr %1, %0 : tensor<1024x!tt.ptr, #blocked> %3 = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked> %4 = math.sin %3 : tensor<1024xf32, #blocked> %5 = tt.ext_elemwise %4 {libname = "libdevice", libpath = "/home/siwasaki/triton/python/triton/language/libdevice.10.bc", symbol = "__nv_sinf"} : tensor<1024xf32, #blocked> -> tensor<1024xf32, #blocked> %6 = tt.ext_elemwise %5, %5 {libname = "libdevice", libpath = "/home/siwasaki/triton/python/triton/language/libdevice.10.bc", symbol = "__nv_fdiv_rn"} : tensor<1024xf32, #blocked>, tensor<1024xf32, #blocked> -> tensor<1024xf32, #blocked> %7 = tt.ext_elemwise %6, %6, %6 {libname = "libdevice", libpath = "/home/siwasaki/triton/python/triton/language/libdevice.10.bc", symbol = "__nv_fmaf_rd"} : tensor<1024xf32, #blocked>, tensor<1024xf32, #blocked>, tensor<1024xf32, #blocked> -> tensor<1024xf32, #blocked> %8 = tt.splat %arg1 : (!tt.ptr) -> tensor<1024x!tt.ptr, #blocked> %9 = tt.getelementptr %8, %0 : tensor<1024x!tt.ptr, #blocked> tt.store %9, %7 : tensor<1024xf32, #blocked> return } } ``` --- CMakeLists.txt | 1 + bin/CMakeLists.txt | 1 + bin/triton-opt.cpp | 8 +- bin/triton-translate.cpp | 6 +- include/triton/Conversion/Passes.td | 4 +- include/triton/Dialect/Triton/IR/Dialect.h | 1 + .../triton/Dialect/Triton/IR/TritonDialect.td | 3 + include/triton/Dialect/Triton/IR/TritonOps.td | 20 + .../TritonGPUToLLVM/TritonGPUToLLVM.cpp | 4 +- .../TritonToTritonGPU/TritonToTritonGPU.cpp | 69 +- .../Transforms/TritonGPUConversion.cpp | 16 +- python/src/triton.cc | 41 +- python/tests/test_math_ops.py | 33 + python/triton/compiler.py | 3 +- python/triton/language/__init__.py | 2 +- python/triton/language/core.py | 12 +- python/triton/language/extern.py | 104 ++ python/triton/language/libdevice.py | 1661 +++++++++++++++++ 18 files changed, 1938 insertions(+), 51 deletions(-) create mode 100644 python/tests/test_math_ops.py create mode 100644 python/triton/language/extern.py create mode 100644 python/triton/language/libdevice.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 0ae7c245e..6b431b37d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -196,6 +196,7 @@ target_link_libraries(triton MLIRSupport MLIRTargetLLVMIRExport MLIRExecutionEngine + MLIRMathToLLVM MLIRNVVMToLLVMIRTranslation MLIRIR ) diff --git a/bin/CMakeLists.txt b/bin/CMakeLists.txt index 462873630..ac7e877dc 100644 --- a/bin/CMakeLists.txt +++ b/bin/CMakeLists.txt @@ -52,6 +52,7 @@ target_link_libraries(triton-translate PRIVATE MLIRSupport MLIRTransforms MLIRExecutionEngine + MLIRMathToLLVM MLIRTransformUtils MLIRLLVMToLLVMIRTranslation MLIRNVVMToLLVMIRTranslation diff --git a/bin/triton-opt.cpp b/bin/triton-opt.cpp index f67c14135..9f3b53b7a 100644 --- a/bin/triton-opt.cpp +++ b/bin/triton-opt.cpp @@ -32,10 +32,10 @@ int main(int argc, char **argv) { // TODO: register Triton & TritonGPU passes mlir::DialectRegistry registry; - registry - .insert(); + registry.insert(); return mlir::asMainReturnCode(mlir::MlirOptMain( argc, argv, "Triton (GPU) optimizer driver\n", registry)); diff --git a/bin/triton-translate.cpp b/bin/triton-translate.cpp index 4a0355486..9aa8881f5 100644 --- a/bin/triton-translate.cpp +++ b/bin/triton-translate.cpp @@ -36,9 +36,9 @@ OwningOpRef loadMLIRModule(llvm::StringRef inputFilename, } mlir::DialectRegistry registry; - registry - .insert(); + registry.insert(); context.appendDialectRegistry(registry); diff --git a/include/triton/Conversion/Passes.td b/include/triton/Conversion/Passes.td index 2fa727b70..086cf2dd8 100644 --- a/include/triton/Conversion/Passes.td +++ b/include/triton/Conversion/Passes.td @@ -6,11 +6,12 @@ include "mlir/Pass/PassBase.td" def ConvertTritonToTritonGPU: Pass<"convert-triton-to-tritongpu", "mlir::ModuleOp"> { let summary = "Convert Triton to TritonGPU"; let description = [{ - + }]; let constructor = "mlir::triton::createConvertTritonToTritonGPUPass()"; let dependentDialects = ["mlir::arith::ArithmeticDialect", + "mlir::math::MathDialect", "mlir::StandardOpsDialect", // TODO: Does this pass depend on SCF? "mlir::scf::SCFDialect", @@ -33,6 +34,7 @@ def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp" let constructor = "mlir::triton::createConvertTritonGPUToLLVMPass()"; let dependentDialects = ["mlir::arith::ArithmeticDialect", + "mlir::math::MathDialect", "mlir::gpu::GPUDialect", "mlir::scf::SCFDialect", "mlir::LLVM::LLVMDialect", diff --git a/include/triton/Dialect/Triton/IR/Dialect.h b/include/triton/Dialect/Triton/IR/Dialect.h index 725aed72d..6ce3bd76a 100644 --- a/include/triton/Dialect/Triton/IR/Dialect.h +++ b/include/triton/Dialect/Triton/IR/Dialect.h @@ -2,6 +2,7 @@ #define TRITON_DIALECT_TRITON_IR_DIALECT_H_ #include "mlir/Dialect/GPU/GPUDialect.h" +#include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/BuiltinOps.h" diff --git a/include/triton/Dialect/Triton/IR/TritonDialect.td b/include/triton/Dialect/Triton/IR/TritonDialect.td index 9262dc098..eb5d5c8f2 100644 --- a/include/triton/Dialect/Triton/IR/TritonDialect.td +++ b/include/triton/Dialect/Triton/IR/TritonDialect.td @@ -16,12 +16,15 @@ def Triton_Dialect : Dialect { Dependent Dialects: * Arithmetic: * addf, addi, andi, cmpf, cmpi, divf, fptosi, ... + * Math: + * exp, sin, cos, log, ... * StructuredControlFlow: * ForOp, IfOp, WhileOp, YieldOp, ConditionOp }]; let dependentDialects = [ "arith::ArithmeticDialect", + "math::MathDialect", "StandardOpsDialect", "scf::SCFDialect", "gpu::GPUDialect", diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index 32ebe4123..69a301c31 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -283,6 +283,26 @@ def TT_AtomicCASOp : TT_Op<"atomic_cas"> { let results = (outs TT_Type:$result); } +// +// External Function Ops +// +def TT_ExtElemwiseOp : TT_Op<"ext_elemwise", [NoSideEffect, Elementwise, SameOperandsAndResultShape, + SameVariadicOperandSize]> { + let summary = "ext_elemwise"; + + let description = [{ + call an external function $symbol implemented in $libpath/$libname with $args + + return $libpath/$libname:$symbol($args...) + }]; + + let arguments = (ins Variadic:$args, StrAttr:$libname, StrAttr:$libpath, StrAttr:$symbol); + + let results = (outs TT_Tensor:$result); + + let assemblyFormat = "operands attr-dict `:` type(operands) `->` type($result)"; +} + // // Intrinsics // diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index 3964de000..9c6ca2488 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -4,6 +4,7 @@ #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Conversion/MathToLLVM/MathToLLVM.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" @@ -1328,9 +1329,10 @@ public: populateTritonToLLVMPatterns(typeConverter, patterns, numWarps, *axisAnalysis, 10 /*benefit*/); - // Add arith's patterns to help convert scalar expression to LLVM. + // Add arith/math's patterns to help convert scalar expression to LLVM. mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter, patterns); + mlir::populateMathToLLVMConversionPatterns(typeConverter, patterns); mlir::populateGpuToNVVMConversionPatterns(typeConverter, patterns); diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp index e000d2604..19cefee86 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp @@ -7,13 +7,12 @@ #include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" #include "llvm/ADT/APSInt.h" #include - using namespace mlir; using namespace mlir::triton; namespace { -template class ArithGenericPattern : public OpConversionPattern { +template class GenericOpPattern : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -23,6 +22,7 @@ public: Type retType = this->getTypeConverter()->convertType(op.getType()); Op res = rewriter.replaceOpWithNewOp(op, retType, adaptor.getOperands()); + return success(); } }; @@ -98,32 +98,41 @@ void populateArithmeticPatternsAndLegality( // Rewrite rule // patterns.add(typeConverter, context); patterns.add< - ArithConstantPattern, ArithGenericPattern, - ArithGenericPattern, ArithGenericPattern, - ArithGenericPattern, ArithGenericPattern, - ArithGenericPattern, - ArithGenericPattern, - ArithGenericPattern, - ArithGenericPattern, ArithGenericPattern, - ArithGenericPattern, ArithGenericPattern, - ArithGenericPattern, ArithGenericPattern, - ArithGenericPattern, - ArithGenericPattern, // NegFOp + ArithConstantPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, // NegFOp // Floating point - ArithGenericPattern, ArithGenericPattern, + GenericOpPattern, GenericOpPattern, // MaxMin - ArithGenericPattern, ArithGenericPattern, - ArithGenericPattern, ArithGenericPattern, - ArithGenericPattern, ArithGenericPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, // Floating point - ArithGenericPattern, ArithGenericPattern, - ArithGenericPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, // Cmp ArithCmpPattern, ArithCmpPattern, // Cast Ops - ArithGenericPattern, - ArithGenericPattern>(typeConverter, context); + GenericOpPattern, GenericOpPattern>( + typeConverter, context); +} + +void populateMathPatternsAndLegality(TritonGPUTypeConverter &typeConverter, + RewritePatternSet &patterns, + TritonGPUConversionTarget &target) { + MLIRContext *context = patterns.getContext(); + // Rewrite rule + patterns.add, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern>(typeConverter, context); } // @@ -246,6 +255,20 @@ struct TritonStorePattern : public OpConversionPattern { } }; +struct TritonExtElemwisePattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::ExtElemwiseOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp( + op, typeConverter->convertType(op.getType()), adaptor.args(), + adaptor.libname(), adaptor.libpath(), adaptor.symbol()); + return success(); + } +}; + template struct TritonGenericPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -302,7 +325,8 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter, TritonGenericPattern, TritonBroadcastPattern, TritonGenericPattern, TritonReducePattern, TritonExpandDimsPattern, TritonMakeRangePattern, TritonDotPattern, - TritonLoadPattern, TritonStorePattern>(typeConverter, context); + TritonLoadPattern, TritonStorePattern, TritonExtElemwisePattern>( + typeConverter, context); } // @@ -389,6 +413,7 @@ public: RewritePatternSet patterns(context); // add rules populateArithmeticPatternsAndLegality(typeConverter, patterns, target); + populateMathPatternsAndLegality(typeConverter, patterns, target); populateTritonPatterns(typeConverter, patterns); // TODO: can we use // mlir::scf::populateSCFStructurealTypeConversionsAndLegality(...) here? diff --git a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp index 4adb11143..b303efab8 100644 --- a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp +++ b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp @@ -81,13 +81,13 @@ TritonGPUConversionTarget::TritonGPUConversionTarget( addIllegalOp(); - addDynamicallyLegalDialect( - [&](Operation *op) { - if (typeConverter.isLegal(op)) - return true; - return false; - }); + addDynamicallyLegalDialect([&](Operation *op) { + if (typeConverter.isLegal(op)) + return true; + return false; + }); // We have requirements for the data layouts addDynamicallyLegalOp([this](triton::DotOp dotOp) -> bool { @@ -100,4 +100,4 @@ TritonGPUConversionTarget::TritonGPUConversionTarget( return true; return false; }); -} \ No newline at end of file +} diff --git a/python/src/triton.cc b/python/src/triton.cc index 8366b90b9..41f99f84a 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -1542,7 +1542,16 @@ void init_triton_ir(py::module &&m) { return self.create(loc, dstType, rmwOp, ptr, val, mask); }) - + // External + .def("create_external_elementwise", + [](mlir::OpBuilder &self, const std::string &libName, + const std::string &libPath, const std::string &symbol, + std::vector &argList, + mlir::Type retType) -> mlir::Value { + auto loc = self.getUnknownLoc(); + return self.create( + loc, retType, argList, libName, libPath, symbol); + }) // Built-in instruction .def("create_get_program_id", [](mlir::OpBuilder &self, int axis) -> mlir::Value { @@ -1563,12 +1572,32 @@ void init_triton_ir(py::module &&m) { return self.create(loc, c.getType(), a, b, c, allowTF32); }) - // .def("create_exp", &ir::builder::create_exp, ret::reference) - // .def("create_cos", &ir::builder::create_cos, ret::reference) - // .def("create_sin", &ir::builder::create_sin, ret::reference) - // .def("create_log", &ir::builder::create_log, ret::reference) + .def("create_exp", + [](mlir::OpBuilder &self, mlir::Value &val) -> mlir::Value { + auto loc = self.getUnknownLoc(); + return self.create(loc, val); + }) + .def("create_cos", + [](mlir::OpBuilder &self, mlir::Value &val) -> mlir::Value { + auto loc = self.getUnknownLoc(); + return self.create(loc, val); + }) + .def("create_sin", + [](mlir::OpBuilder &self, mlir::Value &val) -> mlir::Value { + auto loc = self.getUnknownLoc(); + return self.create(loc, val); + }) + .def("create_log", + [](mlir::OpBuilder &self, mlir::Value &val) -> mlir::Value { + auto loc = self.getUnknownLoc(); + return self.create(loc, val); + }) + .def("create_sqrt", + [](mlir::OpBuilder &self, mlir::Value &val) -> mlir::Value { + auto loc = self.getUnknownLoc(); + return self.create(loc, val); + }) // .def("create_trans", &ir::builder::create_trans, ret::reference) - // .def("create_sqrt", &ir::builder::create_sqrt, ret::reference) .def("create_reduce", [](mlir::OpBuilder &self, mlir::Value &operand, mlir::triton::RedOp redOp, int axis) -> mlir::Value { diff --git a/python/tests/test_math_ops.py b/python/tests/test_math_ops.py new file mode 100644 index 000000000..f5ed9fdf9 --- /dev/null +++ b/python/tests/test_math_ops.py @@ -0,0 +1,33 @@ + +import triton +import triton.language as tl + + +@triton.jit +def math_kernel(x1_ptr, x2_ptr, x3_ptr, x4_ptr, n, BLOCK_SIZE: tl.constexpr): + offsets = tl.arange(0, BLOCK_SIZE) + x1 = tl.load(x1_ptr + offsets, mask=offsets < n) + x2 = tl.load(x2_ptr + offsets, mask=offsets < n) + x3 = tl.load(x3_ptr + offsets, mask=offsets < n) + x4 = tl.load(x4_ptr + offsets, mask=offsets < n) + + y1 = tl.sin(x1) + y2 = tl.libdevice.sin(x2) + y3 = tl.libdevice.fdiv_rn(x3, x3) + y4 = tl.libdevice.fmaf_rd(x4, x4, x4) + + tl.store(x1_ptr + offsets, y1, mask=offsets < n) + tl.store(x2_ptr + offsets, y2, mask=offsets < n) + tl.store(x3_ptr + offsets, y3, mask=offsets < n) + tl.store(x4_ptr + offsets, y4, mask=offsets < n) + + +def test_empty_kernel_cubin_compile(): + kernel = triton.compile(math_kernel, + "*fp32,*fp32,*fp32,*fp32,i32", + device=0, + constants={"BLOCK_SIZE": 256}, + output="ttgir") # "cubin" + assert kernel + # TODO: Check if the values are correct. + # TODO: Cover all the math operators diff --git a/python/triton/compiler.py b/python/triton/compiler.py index 821889dd6..203c114a6 100644 --- a/python/triton/compiler.py +++ b/python/triton/compiler.py @@ -671,7 +671,8 @@ class CodeGenerator(ast.NodeVisitor): results.append(triton.language.tensor(call_op.get_result(i), callee_ret_type[i])) return tuple(results) if hasattr(fn, '__self__') and self.is_triton_tensor(fn.__self__) or \ - sys.modules[fn.__module__] is triton.language.core: + sys.modules[fn.__module__] is triton.language.core or \ + isinstance(fn, triton.language.extern.ExternalFunction): return fn(*args, _builder=self.builder, **kws) if fn in self.builtins.values(): args = [arg.value if isinstance(arg, triton.language.constexpr) else arg diff --git a/python/triton/language/__init__.py b/python/triton/language/__init__.py index 0b04465eb..6b0058dd5 100644 --- a/python/triton/language/__init__.py +++ b/python/triton/language/__init__.py @@ -1,4 +1,4 @@ # flake8: noqa: F401 -from . import core, random +from . import core, extern, libdevice, random from .core import * from .random import * diff --git a/python/triton/language/core.py b/python/triton/language/core.py index df5477db2..26277257b 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -234,11 +234,15 @@ class pointer_type(dtype): class block_type(dtype): - def __init__(self, element_ty: dtype, shape: List[int]): + def __init__(self, element_ty: dtype, shape: List): self.element_ty = element_ty - # FIXME: - # block_type's shape is a list of int - # while tensor's shape is a list of constexpr + + # Note that block_type's shape is a list of int + # while tensor's shape is a list of constexpr. + assert shape + if isinstance(shape[0], constexpr): + shape = [s.value for s in shape] + self.shape = shape self.numel = 1 for s in self.shape: diff --git a/python/triton/language/extern.py b/python/triton/language/extern.py new file mode 100644 index 000000000..2ef440633 --- /dev/null +++ b/python/triton/language/extern.py @@ -0,0 +1,104 @@ +from __future__ import annotations # remove after python 3.11 + +from . import core, semantic + + +def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, ret_shape: tuple, _builder=None): + ''' + Dispatch a function to a library + :param func: the function to dispatch + :param lib_name: the name of the library + :param lib_path: the path of the library + :param args: the arguments of the function + :param arg_type_symbol_dict: the type of the arguments + :param ret_shape: the shape of the return value + :param _builder: the builder + :return: the return value of the function + ''' + if len(arg_type_symbol_dict) == 0: + raise ValueError("arg_type_symbol_dict is empty") + + num_args = len(list(arg_type_symbol_dict.keys())[0]) + if len(args) != num_args: + raise ValueError(f"length of input args does not match." + f"Expect {len(args)}, got {num_args}") + + arg_types = [] + arg_list = [] + for arg in args: + if isinstance(arg, core.tensor): + arg_types.append(arg.dtype) + arg_list.append(arg.handle) + else: + arg_types.append(type(arg)) + arg_list.append(arg) + arg_types = tuple(arg_types) + + if arg_types not in arg_type_symbol_dict: + raise ValueError(f"input arg type does not match." + f"Expect one of {arg_type_symbol_dict.keys()}, got {arg_types}") + else: + symbol = arg_type_symbol_dict[arg_types][0] + ret_type = arg_type_symbol_dict[arg_types][1] + if ret_shape: + ret_type = core.block_type(ret_type, ret_shape) + return core.tensor(func(lib_name, lib_path, symbol, arg_list, ret_type.to_ir(_builder)), ret_type) + + +def elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, _builder=None): + ''' + Dispatch an elementwise function to a library + :param lib_name: the name of the library + :param lib_path: the path of the library + :param args: the arguments of the function + :param arg_type_symbol_dict: the type of the arguments + :param _builder: the builder + :return: the return value of the function + ''' + dispatch_args = args.copy() + if len(args) == 1: + dispatch_args[0] = core._to_tensor(dispatch_args[0], _builder) + ret_shape = dispatch_args[0].shape + elif len(args) == 2: + dispatch_args[0] = core._to_tensor(dispatch_args[0], _builder) + dispatch_args[1] = core._to_tensor(dispatch_args[1], _builder) + dispatch_args[0], dispatch_args[1] = semantic.binary_op_type_checking_impl( + dispatch_args[0], dispatch_args[1], _builder) + ret_shape = dispatch_args[0].shape + else: + for i in range(len(dispatch_args)): + dispatch_args[i] = core._to_tensor(dispatch_args[i], _builder) + broadcast_arg = dispatch_args[0] + # Get the broadcast shape over all the arguments + for i in range(len(dispatch_args)): + _, broadcast_arg = semantic.binary_op_type_checking_impl( + dispatch_args[i], broadcast_arg, _builder) + # Change the shape of each argument based on the broadcast shape + for i in range(len(dispatch_args)): + dispatch_args[i], _ = semantic.binary_op_type_checking_impl( + dispatch_args[i], broadcast_arg, _builder) + ret_shape = broadcast_arg.shape + func = getattr(_builder, "create_external_elementwise") + return dispatch(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, ret_shape, _builder) + + +class ExternalFunction: + ''' + A wrapper for external functions + ''' + + def __init__(self, fn): + self.fn = fn + + def __call__(self, *args, **kwargs): + if '_builder' not in kwargs or \ + kwargs['_builder'] is None: + raise ValueError("Did you forget to add @triton.jit ? (`_builder` argument must be provided outside of JIT functions.)") + return self.fn(*args, **kwargs) + + +def extern(fn): + ''' + A decorator for external functions + ''' + return ExternalFunction(fn) diff --git a/python/triton/language/libdevice.py b/python/triton/language/libdevice.py new file mode 100644 index 000000000..226480fa2 --- /dev/null +++ b/python/triton/language/libdevice.py @@ -0,0 +1,1661 @@ +import os + +from . import core, extern + +LIBDEVICE_PATH = os.path.dirname( + os.path.abspath(__file__)) + "/libdevice.10.bc" + + +@extern.extern +def clz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int32"),): ("__nv_clz", core.dtype("int32")), + (core.dtype("int64"),): ("__nv_clzll", core.dtype("int32")), + }, _builder) + + +@extern.extern +def popc(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int32"),): ("__nv_popc", core.dtype("int32")), + (core.dtype("int64"),): ("__nv_popcll", core.dtype("int32")), + }, _builder) + + +@extern.extern +def byte_perm(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("int32"), core.dtype("int32"), core.dtype("int32"),): ("__nv_byte_perm", core.dtype("int32")), + }, _builder) + + +@extern.extern +def min(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("int32"), core.dtype("int32"),): ("__nv_min", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32"),): ("__nv_umin", core.dtype("uint32")), + (core.dtype("int64"), core.dtype("int64"),): ("__nv_llmin", core.dtype("int64")), + (core.dtype("uint64"), core.dtype("uint64"),): ("__nv_ullmin", core.dtype("uint64")), + (core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fminf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fmin", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def max(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("int32"), core.dtype("int32"),): ("__nv_max", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32"),): ("__nv_umax", core.dtype("uint32")), + (core.dtype("int64"), core.dtype("int64"),): ("__nv_llmax", core.dtype("int64")), + (core.dtype("uint64"), core.dtype("uint64"),): ("__nv_ullmax", core.dtype("uint64")), + (core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaxf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fmax", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def mulhi(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("int32"), core.dtype("int32"),): ("__nv_mulhi", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32"),): ("__nv_umulhi", core.dtype("uint32")), + }, _builder) + + +@extern.extern +def mul64hi(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("int64"), core.dtype("int64"),): ("__nv_mul64hi", core.dtype("int64")), + (core.dtype("uint64"), core.dtype("uint64"),): ("__nv_umul64hi", core.dtype("uint64")), + }, _builder) + + +@extern.extern +def mul24(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("int32"), core.dtype("int32"),): ("__nv_mul24", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32"),): ("__nv_umul24", core.dtype("uint32")), + }, _builder) + + +@extern.extern +def brev(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int32"),): ("__nv_brev", core.dtype("int32")), + (core.dtype("int64"),): ("__nv_brevll", core.dtype("int64")), + }, _builder) + + +@extern.extern +def sad(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("int32"), core.dtype("int32"), core.dtype("uint32"),): ("__nv_sad", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32"), core.dtype("uint32"),): ("__nv_usad", core.dtype("uint32")), + }, _builder) + + +@extern.extern +def abs(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int32"),): ("__nv_abs", core.dtype("int32")), + (core.dtype("int64"),): ("__nv_llabs", core.dtype("int64")), + (core.dtype("fp32"),): ("__nv_fabsf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_fabs", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def floor(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_floorf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_floor", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def rcp64h(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_rcp64h", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def rsqrt(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_rsqrtf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_rsqrt", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def ceil(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_ceil", core.dtype("fp64")), + (core.dtype("fp32"),): ("__nv_ceilf", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def trunc(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_trunc", core.dtype("fp64")), + (core.dtype("fp32"),): ("__nv_truncf", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def exp2(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_exp2f", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_exp2", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def saturatef(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_saturatef", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fmaf_rn(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_rn", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fmaf_rz(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_rz", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fmaf_rd(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_rd", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fmaf_ru(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_ru", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fmaf_ieee_rn(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_ieee_rn", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fmaf_ieee_rz(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_ieee_rz", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fmaf_ieee_rd(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_ieee_rd", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fmaf_ieee_ru(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_ieee_ru", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fma_rn(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fma_rn", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def fma_rz(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fma_rz", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def fma_rd(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fma_rd", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def fma_ru(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fma_ru", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def fast_fdividef(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fast_fdividef", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fdiv_rn(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fdiv_rn", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fdiv_rz(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fdiv_rz", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fdiv_rd(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fdiv_rd", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fdiv_ru(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fdiv_ru", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def frcp_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_frcp_rn", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def frcp_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_frcp_rz", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def frcp_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_frcp_rd", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def frcp_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_frcp_ru", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fsqrt_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_fsqrt_rn", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fsqrt_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_fsqrt_rz", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fsqrt_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_fsqrt_rd", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fsqrt_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_fsqrt_ru", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def ddiv_rn(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_ddiv_rn", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def ddiv_rz(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_ddiv_rz", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def ddiv_rd(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_ddiv_rd", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def ddiv_ru(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_ddiv_ru", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def drcp_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_drcp_rn", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def drcp_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_drcp_rz", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def drcp_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_drcp_rd", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def drcp_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_drcp_ru", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def dsqrt_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_dsqrt_rn", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def dsqrt_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_dsqrt_rz", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def dsqrt_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_dsqrt_rd", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def dsqrt_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_dsqrt_ru", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def sqrt(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_sqrtf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_sqrt", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def dadd_rn(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dadd_rn", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def dadd_rz(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dadd_rz", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def dadd_rd(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dadd_rd", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def dadd_ru(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dadd_ru", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def dmul_rn(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dmul_rn", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def dmul_rz(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dmul_rz", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def dmul_rd(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dmul_rd", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def dmul_ru(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dmul_ru", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def fadd_rd(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fadd_rd", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fadd_ru(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fadd_ru", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fmul_rd(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmul_rd", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fmul_ru(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmul_ru", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fadd_rn(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fadd_rn", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fadd_rz(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fadd_rz", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fmul_rn(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmul_rn", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fmul_rz(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmul_rz", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def double2float_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2float_rn", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def double2float_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2float_rz", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def double2float_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2float_rd", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def double2float_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2float_ru", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def double2int_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2int_rn", core.dtype("int32")), + }, _builder) + + +@extern.extern +def double2int_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2int_rz", core.dtype("int32")), + }, _builder) + + +@extern.extern +def double2int_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2int_rd", core.dtype("int32")), + }, _builder) + + +@extern.extern +def double2int_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2int_ru", core.dtype("int32")), + }, _builder) + + +@extern.extern +def double2uint_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2uint_rn", core.dtype("int32")), + }, _builder) + + +@extern.extern +def double2uint_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2uint_rz", core.dtype("int32")), + }, _builder) + + +@extern.extern +def double2uint_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2uint_rd", core.dtype("int32")), + }, _builder) + + +@extern.extern +def double2uint_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2uint_ru", core.dtype("int32")), + }, _builder) + + +@extern.extern +def int2double_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int32"),): ("__nv_int2double_rn", core.dtype("fp64")), + (core.dtype("uint32"),): ("__nv_uint2double_rn", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def float2int_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2int_rn", core.dtype("int32")), + }, _builder) + + +@extern.extern +def float2int_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2int_rz", core.dtype("int32")), + }, _builder) + + +@extern.extern +def float2int_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2int_rd", core.dtype("int32")), + }, _builder) + + +@extern.extern +def float2int_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2int_ru", core.dtype("int32")), + }, _builder) + + +@extern.extern +def float2uint_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2uint_rn", core.dtype("int32")), + }, _builder) + + +@extern.extern +def float2uint_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2uint_rz", core.dtype("int32")), + }, _builder) + + +@extern.extern +def float2uint_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2uint_rd", core.dtype("int32")), + }, _builder) + + +@extern.extern +def float2uint_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2uint_ru", core.dtype("int32")), + }, _builder) + + +@extern.extern +def int2float_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int32"),): ("__nv_int2float_rn", core.dtype("fp32")), + (core.dtype("uint32"),): ("__nv_uint2float_rn", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def int2float_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int32"),): ("__nv_int2float_rz", core.dtype("fp32")), + (core.dtype("uint32"),): ("__nv_uint2float_rz", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def int2float_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int32"),): ("__nv_int2float_rd", core.dtype("fp32")), + (core.dtype("uint32"),): ("__nv_uint2float_rd", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def int2float_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int32"),): ("__nv_int2float_ru", core.dtype("fp32")), + (core.dtype("uint32"),): ("__nv_uint2float_ru", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def hiloint2double(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("int32"), core.dtype("int32"),): ("__nv_hiloint2double", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def double2loint(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2loint", core.dtype("int32")), + }, _builder) + + +@extern.extern +def double2hiint(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2hiint", core.dtype("int32")), + }, _builder) + + +@extern.extern +def float2ll_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2ll_rn", core.dtype("int64")), + }, _builder) + + +@extern.extern +def float2ll_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2ll_rz", core.dtype("int64")), + }, _builder) + + +@extern.extern +def float2ll_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2ll_rd", core.dtype("int64")), + }, _builder) + + +@extern.extern +def float2ll_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2ll_ru", core.dtype("int64")), + }, _builder) + + +@extern.extern +def float2ull_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2ull_rn", core.dtype("int64")), + }, _builder) + + +@extern.extern +def float2ull_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2ull_rz", core.dtype("int64")), + }, _builder) + + +@extern.extern +def float2ull_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2ull_rd", core.dtype("int64")), + }, _builder) + + +@extern.extern +def float2ull_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2ull_ru", core.dtype("int64")), + }, _builder) + + +@extern.extern +def double2ll_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2ll_rn", core.dtype("int64")), + }, _builder) + + +@extern.extern +def double2ll_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2ll_rz", core.dtype("int64")), + }, _builder) + + +@extern.extern +def double2ll_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2ll_rd", core.dtype("int64")), + }, _builder) + + +@extern.extern +def double2ll_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2ll_ru", core.dtype("int64")), + }, _builder) + + +@extern.extern +def double2ull_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2ull_rn", core.dtype("int64")), + }, _builder) + + +@extern.extern +def double2ull_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2ull_rz", core.dtype("int64")), + }, _builder) + + +@extern.extern +def double2ull_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2ull_rd", core.dtype("int64")), + }, _builder) + + +@extern.extern +def double2ull_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2ull_ru", core.dtype("int64")), + }, _builder) + + +@extern.extern +def ll2float_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int64"),): ("__nv_ll2float_rn", core.dtype("fp32")), + (core.dtype("uint64"),): ("__nv_ull2float_rn", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def ll2float_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int64"),): ("__nv_ll2float_rz", core.dtype("fp32")), + (core.dtype("uint64"),): ("__nv_ull2float_rz", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def ll2float_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int64"),): ("__nv_ll2float_rd", core.dtype("fp32")), + (core.dtype("uint64"),): ("__nv_ull2float_rd", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def ll2float_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int64"),): ("__nv_ll2float_ru", core.dtype("fp32")), + (core.dtype("uint64"),): ("__nv_ull2float_ru", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def ll2double_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int64"),): ("__nv_ll2double_rn", core.dtype("fp64")), + (core.dtype("uint64"),): ("__nv_ull2double_rn", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def ll2double_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int64"),): ("__nv_ll2double_rz", core.dtype("fp64")), + (core.dtype("uint64"),): ("__nv_ull2double_rz", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def ll2double_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int64"),): ("__nv_ll2double_rd", core.dtype("fp64")), + (core.dtype("uint64"),): ("__nv_ull2double_rd", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def ll2double_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int64"),): ("__nv_ll2double_ru", core.dtype("fp64")), + (core.dtype("uint64"),): ("__nv_ull2double_ru", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def int_as_float(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int32"),): ("__nv_int_as_float", core.dtype("fp32")), + (core.dtype("uint32"),): ("__nv_uint_as_float", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def float_as_int(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float_as_int", core.dtype("int32")), + }, _builder) + + +@extern.extern +def float_as_uint(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float_as_uint", core.dtype("int32")), + }, _builder) + + +@extern.extern +def longlong_as_double(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int64"),): ("__nv_longlong_as_double", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def double_as_longlong(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double_as_longlong", core.dtype("int64")), + }, _builder) + + +@extern.extern +def fast_sinf(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_fast_sinf", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fast_cosf(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_fast_cosf", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fast_log2f(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_fast_log2f", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fast_logf(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_fast_logf", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fast_expf(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_fast_expf", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fast_tanf(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_fast_tanf", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fast_exp10f(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_fast_exp10f", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fast_log10f(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_fast_log10f", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def pow(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fast_powf", core.dtype("fp32")), + (core.dtype("fp32"), core.dtype("fp32"),): ("__nv_powf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_pow", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def hadd(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("int32"), core.dtype("int32"),): ("__nv_hadd", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32"),): ("__nv_uhadd", core.dtype("uint32")), + }, _builder) + + +@extern.extern +def rhadd(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("int32"), core.dtype("int32"),): ("__nv_rhadd", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32"),): ("__nv_urhadd", core.dtype("uint32")), + }, _builder) + + +@extern.extern +def fsub_rn(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fsub_rn", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fsub_rz(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fsub_rz", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fsub_rd(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fsub_rd", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fsub_ru(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fsub_ru", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def frsqrt_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_frsqrt_rn", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def ffs(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int32"),): ("__nv_ffs", core.dtype("int32")), + (core.dtype("int64"),): ("__nv_ffsll", core.dtype("int32")), + }, _builder) + + +@extern.extern +def rint(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_rintf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_rint", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def llrint(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_llrintf", core.dtype("int64")), + (core.dtype("fp64"),): ("__nv_llrint", core.dtype("int64")), + }, _builder) + + +@extern.extern +def nearbyint(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_nearbyintf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_nearbyint", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def isnanf(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_isnanf", core.dtype("int32")), + }, _builder) + + +@extern.extern +def signbitf(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_signbitf", core.dtype("int32")), + }, _builder) + + +@extern.extern +def copysign(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_copysignf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_copysign", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def finitef(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_finitef", core.dtype("int32")), + }, _builder) + + +@extern.extern +def isinff(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_isinff", core.dtype("int32")), + }, _builder) + + +@extern.extern +def nextafter(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_nextafterf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_nextafter", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def sin(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_sinf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_sin", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def cos(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_cosf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_cos", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def sinpi(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_sinpif", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_sinpi", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def cospi(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_cospif", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_cospi", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def tan(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_tanf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_tan", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def log2(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_log2f", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_log2", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def exp(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_expf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_exp", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def exp10(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_exp10f", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_exp10", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def cosh(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_coshf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_cosh", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def sinh(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_sinhf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_sinh", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def tanh(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_tanhf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_tanh", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def atan2(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_atan2f", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_atan2", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def atan(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_atanf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_atan", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def asin(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_asinf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_asin", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def acos(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_acosf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_acos", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def log(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_logf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_log", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def log10(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_log10f", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_log10", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def log1p(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_log1pf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_log1p", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def acosh(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_acoshf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_acosh", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def asinh(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_asinhf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_asinh", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def atanh(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_atanhf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_atanh", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def expm1(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_expm1f", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_expm1", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def hypot(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_hypotf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_hypot", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def rhypot(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_rhypotf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_rhypot", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def norm3d(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_norm3df", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_norm3d", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def rnorm3d(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_rnorm3df", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_rnorm3d", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def norm4d(arg0, arg1, arg2, arg3, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, arg3, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_norm4df", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_norm4d", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def rnorm4d(arg0, arg1, arg2, arg3, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, arg3, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_rnorm4df", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_rnorm4d", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def cbrt(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_cbrtf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_cbrt", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def rcbrt(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_rcbrtf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_rcbrt", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def j0(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_j0f", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_j0", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def j1(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_j1f", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_j1", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def y0(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_y0f", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_y0", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def y1(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_y1f", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_y1", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def yn(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("int32"), core.dtype("fp32"),): ("__nv_ynf", core.dtype("fp32")), + (core.dtype("int32"), core.dtype("fp64"),): ("__nv_yn", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def jn(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("int32"), core.dtype("fp32"),): ("__nv_jnf", core.dtype("fp32")), + (core.dtype("int32"), core.dtype("fp64"),): ("__nv_jn", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def cyl_bessel_i0(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_cyl_bessel_i0f", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_cyl_bessel_i0", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def cyl_bessel_i1(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_cyl_bessel_i1f", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_cyl_bessel_i1", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def erf(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_erff", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_erf", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def erfinv(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_erfinvf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_erfinv", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def erfc(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_erfcf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_erfc", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def erfcx(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_erfcxf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_erfcx", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def erfcinv(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_erfcinvf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_erfcinv", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def normcdfinv(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_normcdfinvf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_normcdfinv", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def normcdf(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_normcdff", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_normcdf", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def lgamma(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_lgammaf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_lgamma", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def ldexp(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("int32"),): ("__nv_ldexpf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("int32"),): ("__nv_ldexp", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def scalbn(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("int32"),): ("__nv_scalbnf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("int32"),): ("__nv_scalbn", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def fmod(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmodf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fmod", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def remainder(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_remainderf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_remainder", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def fma(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fma", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def powi(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("int32"),): ("__nv_powif", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("int32"),): ("__nv_powi", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def tgamma(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_tgammaf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_tgamma", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def round(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_roundf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_round", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def llround(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_llroundf", core.dtype("int64")), + (core.dtype("fp64"),): ("__nv_llround", core.dtype("int64")), + }, _builder) + + +@extern.extern +def fdim(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fdimf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fdim", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def ilogb(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_ilogbf", core.dtype("int32")), + (core.dtype("fp64"),): ("__nv_ilogb", core.dtype("int32")), + }, _builder) + + +@extern.extern +def logb(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_logbf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_logb", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def signbitd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_signbitd", core.dtype("int32")), + }, _builder) + + +@extern.extern +def isfinited(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_isfinited", core.dtype("int32")), + }, _builder) + + +@extern.extern +def isinfd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_isinfd", core.dtype("int32")), + }, _builder) + + +@extern.extern +def isnand(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_isnand", core.dtype("int32")), + }, _builder) + + +@extern.extern +def dsub_rn(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dsub_rn", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def dsub_rz(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dsub_rz", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def dsub_ru(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dsub_ru", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def dsub_rd(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dsub_rd", core.dtype("fp64")), + }, _builder)