diff --git a/include/triton/codegen/extern_lib.h b/include/triton/codegen/extern_lib.h new file mode 100644 index 000000000..c161ff142 --- /dev/null +++ b/include/triton/codegen/extern_lib.h @@ -0,0 +1,89 @@ +#ifndef _TRITON_CODE_GEN_EXTERN_LIB_H_ +#define _TRITON_CODE_GEN_EXTERN_LIB_H_ + +#include +#include + +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/IRReader/IRReader.h" +#include "llvm/Support/SourceMgr.h" + +namespace triton { +namespace codegen { + +/// +/// \brief ExternLib is a class that represents a library of external functions. +/// +class ExternLib { + public: + ExternLib(const std::string &name, const std::string &path) + : name_(name), path_(path) {} + + virtual ~ExternLib() = default; + + virtual const std::string &name() const { return name_; } + + virtual const std::string &path() const { return path_; } + + /// + /// \brief Load the library and return the module. + /// + std::unique_ptr load(llvm::LLVMContext &ctx); + + /// + /// \brief Link the module into the given module. + /// + void link(std::unique_ptr &llvm, + std::unique_ptr &mod); + + /// + /// \brief Run load, link, and opt on the module. + /// + virtual void install(llvm::LLVMContext &ctx, + std::unique_ptr &llvm) { + auto mod = load(ctx); + link(llvm, mod); + opt(ctx, llvm); + } + + /// + /// \brief Run opt on the module. + /// + virtual void opt(llvm::LLVMContext &ctx, + std::unique_ptr &llvm) = 0; + + private: + std::string name_; + std::string path_; +}; + +/// +/// \brief ExternLibMap is a map of ExternLibs from their names to their paths. +/// +typedef std::map> ExternLibMap; + +/// +/// \brief Concrete class for NVIDIA's libdevice library. +/// +class LibDevice final : public ExternLib { + public: + LibDevice(const std::string &name, const std::string &path) + : ExternLib(name, path) {} + + virtual ~LibDevice() = default; + + virtual void opt(llvm::LLVMContext &ctx, + std::unique_ptr &llvm) override; +}; + +/// +/// \brief Create an ExternLib instance based on the name and path. +/// +std::unique_ptr create_extern_lib(const std::string &lib_name, + const std::string &lib_path); + +} // namespace codegen +} // namespace triton + +#endif diff --git a/include/triton/codegen/pass.h b/include/triton/codegen/pass.h index 0c8f11315..95b00b807 100644 --- a/include/triton/codegen/pass.h +++ b/include/triton/codegen/pass.h @@ -3,6 +3,7 @@ #include +#include "extern_lib.h" namespace llvm{ class Module; @@ -30,12 +31,10 @@ namespace codegen{ // TODO: // There should be a proper pass manager there! -std::unique_ptr add_passes_to_emit_bin(ir::module &ir, llvm::LLVMContext& ctx, - codegen::target* target, - int sm, int num_warps, - int num_stages, int &shared_static); - - +std::unique_ptr add_passes_to_emit_bin( + ir::module &ir, llvm::LLVMContext &ctx, codegen::target *target, + int num_warps, int num_stages, int &shared_static, + const ExternLibMap &extern_libs); } } diff --git a/include/triton/codegen/selection/generator.h b/include/triton/codegen/selection/generator.h index b408a46ca..7867c356b 100644 --- a/include/triton/codegen/selection/generator.h +++ b/include/triton/codegen/selection/generator.h @@ -6,6 +6,7 @@ #include "triton/ir/visitor.h" #include "triton/ir/instructions.h" #include "triton/codegen/analysis/layout.h" +#include "triton/codegen/extern_lib.h" #include // forward @@ -199,6 +200,7 @@ private: void visit_make_range(ir::make_range*); void visit_clock_inst(ir::clock_inst*); void visit_globaltimer_inst(ir::globaltimer_inst*); + void visit_extern_elementwise_inst(ir::extern_elementwise_inst*); // void visit_make_range_sta(ir::make_range_sta*); void visit_undef_value(ir::undef_value*); void visit_constant_int(ir::constant_int*); @@ -209,18 +211,26 @@ private: void visit_argument(ir::argument*); void visit(ir::module &, llvm::Module &); - // layouts void visit_layout_mma(analysis::mma_layout*); void visit_layout_scanline(analysis::scanline_layout*); void visit_layout_shared(analysis::shared_layout*); + // Add a new external library based on given name and path if it doesn't exist + void add_extern_lib(const std::string &lib_name, const std::string &lib_path); -private: + // Get all external libraries + const ExternLibMap &get_extern_lib_map() { + return extern_lib_map_; + } + + private: LLVMContext *ctx_; Builder* builder_; Module *mod_; + std::map> extern_lib_map_; + analysis::axes *a_axes_; analysis::swizzle *swizzle_; std::map axes_; diff --git a/include/triton/ir/builder.h b/include/triton/ir/builder.h index 74028f822..8eb1c2ce3 100644 --- a/include/triton/ir/builder.h +++ b/include/triton/ir/builder.h @@ -169,6 +169,12 @@ public: // Utilities value *create_clock(); value *create_globaltimer(); + // Extern instruction + value *create_extern_elementwise(const std::string &lib_name, + const std::string &lib_path, + const std::string &symbol_name, + const std::vector &args, + type *ret_ty); // Built-in instruction value *create_get_program_id(unsigned axis); value *create_get_num_programs(unsigned axis); diff --git a/include/triton/ir/enums.h b/include/triton/ir/enums.h index 3fa008606..4e60d3444 100644 --- a/include/triton/ir/enums.h +++ b/include/triton/ir/enums.h @@ -154,6 +154,8 @@ enum value_id_t: unsigned { INST_COS, INST_SIN, INST_LOG, + // extern + INST_EXTERN_ELEMENTWISE, // array arithmetic INST_TRANS, INST_REDUCE, diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h index 402208a8b..1bad86c33 100644 --- a/include/triton/ir/instructions.h +++ b/include/triton/ir/instructions.h @@ -1097,7 +1097,28 @@ public: static globaltimer_inst* create(context &ctx, const std::string &name = "", instruction *next = nullptr); }; +class extern_elementwise_inst : public instruction { + extern_elementwise_inst(context &ctx, const std::vector &args, + type *dst_ty, const std::string &lib_name, + const std::string &extern_lib_path, + const std::string &symbol_name, instruction *next); + std::string repr_impl() const { return "extern_elementwise"; } + _TRITON_DEFINE_CLONE(extern_elementwise_inst) + _TRITON_DEFINE_ACCEPT(extern_elementwise_inst) + public: + static extern_elementwise_inst *create( + context &ctx, const std::vector &args, type *dst_ty, + const std::string &lib_name = "", const std::string &lib_path = "", + const std::string &symbol_name = "", instruction *next = nullptr); + + const std::string &get_lib_name() const { return lib_name_; } + const std::string &get_lib_path() const { return lib_path_; } + + private: + std::string lib_name_ = ""; + std::string lib_path_ = ""; +}; } } diff --git a/include/triton/ir/visitor.h b/include/triton/ir/visitor.h index 774f2e172..5f84f414f 100644 --- a/include/triton/ir/visitor.h +++ b/include/triton/ir/visitor.h @@ -84,6 +84,8 @@ class prefetch_s_inst; class clock_inst; class globaltimer_inst; +class extern_elementwise_inst; + class make_range_sta; class undef_value; class constant_int; @@ -177,6 +179,8 @@ public: virtual void visit_constant_int(constant_int*) = 0; virtual void visit_constant_fp(constant_fp*) = 0; virtual void visit_alloc_const(alloc_const*) = 0; + + virtual void visit_extern_elementwise_inst(extern_elementwise_inst*) = 0; }; } diff --git a/lib/codegen/extern_lib.cc b/lib/codegen/extern_lib.cc new file mode 100644 index 000000000..0a1f165ea --- /dev/null +++ b/lib/codegen/extern_lib.cc @@ -0,0 +1,63 @@ +#include "triton/codegen/extern_lib.h" + +#include "llvm/IR/Constants.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/Type.h" +#include "llvm/Linker/Linker.h" +#include "llvm/Transforms/IPO/PassManagerBuilder.h" +#include "triton/codegen/pass.h" + +namespace triton { + +namespace codegen { + +std::unique_ptr ExternLib::load(llvm::LLVMContext& ctx) { + llvm::SMDiagnostic err; + auto mod = llvm::parseIRFile(this->path_, err, ctx); + if (!mod) { + throw std::runtime_error("Failed to load extern lib " + this->name_ + + " at " + this->path_); + } + return mod; +} + +void ExternLib::link(std::unique_ptr& llvm, + std::unique_ptr& mod) { + // Set triple and data layout to match the target module + mod->setTargetTriple(llvm->getTargetTriple()); + mod->setDataLayout(llvm->getDataLayout()); + if (llvm::Linker::linkModules(*llvm, std::move(mod))) { + throw std::runtime_error("Failed to link extern lib " + this->name_ + + " at " + this->path_); + } +} + +void LibDevice::opt(llvm::LLVMContext& ctx, std::unique_ptr& llvm) { + // Add nvvm reflect flags to llvm module + // https://llvm.org/docs/LangRef.html#module-flags-metadata + // i32 4: Override the other module. + // i32 1: Emit an error + // If both modules specify Override, but the values differ, an error + // will be emitted. + llvm::Type* I32 = llvm::Type::getInt32Ty(ctx); + llvm::Metadata* md_four = + llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(I32, 4)); + llvm::Metadata* md_name = llvm::MDString::get(ctx, "nvvm-reflect-ftz"); + llvm::Metadata* md_one = + llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(I32, 1)); + llvm::MDNode* reflect = llvm::MDNode::get(ctx, {md_four, md_name, md_one}); + llvm->addModuleFlag(reflect); +} + +std::unique_ptr create_extern_lib(const std::string& lib_name, + const std::string& lib_path) { + if (lib_name == "libdevice") { + return std::make_unique(lib_name, lib_path); + } else { + throw std::runtime_error("Unknown external library: " + lib_name); + } +} + +} // namespace codegen +} // namespace triton diff --git a/lib/codegen/pass.cc b/lib/codegen/pass.cc index 412e2f4c8..645f10978 100644 --- a/lib/codegen/pass.cc +++ b/lib/codegen/pass.cc @@ -1,4 +1,14 @@ #include "triton/codegen/pass.h" + +#include "llvm/IR/Constants.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Verifier.h" +#include "llvm/IRReader/IRReader.h" +#include "llvm/Linker/Linker.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Transforms/IPO.h" +#include "llvm/Transforms/IPO/PassManagerBuilder.h" #include "triton/codegen/analysis/align.h" #include "triton/codegen/analysis/allocation.h" #include "triton/codegen/analysis/axes.h" @@ -9,24 +19,66 @@ #include "triton/codegen/transform/cts.h" #include "triton/codegen/transform/dce.h" #include "triton/codegen/transform/disassociate.h" +#include "triton/codegen/transform/inline.h" #include "triton/codegen/transform/membar.h" #include "triton/codegen/transform/peephole.h" #include "triton/codegen/transform/pipeline.h" #include "triton/codegen/transform/prefetch.h" -#include "triton/codegen/transform/inline.h" #include "triton/ir/function.h" #include "triton/ir/module.h" #include "triton/ir/print.h" -#include "llvm/IR/Module.h" -#include "llvm/IR/LegacyPassManager.h" -#include "llvm/IR/Verifier.h" + namespace triton { namespace codegen { +static void link_extern_libs(const ExternLibMap& user_extern_lib_map, + const ExternLibMap& target_extern_lib_map, + ir::module& ir, llvm::LLVMContext& ctx, + std::unique_ptr& llvm) { + for (const auto& iter : target_extern_lib_map) { + auto &lib_name = iter.first; + if (user_extern_lib_map.count(lib_name) != 0 && + user_extern_lib_map.at(lib_name)->path() != "") { + // If the user specified a path for this library, use it. + user_extern_lib_map.at(lib_name)->install(ctx, llvm); + } else { + // Otherwise, use the default path. + iter.second->install(ctx, llvm); + } + } + + std::set function_names; + for (auto& func : ir.get_function_list()) { + function_names.insert(func->get_name()); + } + llvm::legacy::PassManager pass; + pass.add(llvm::createInternalizePass([&](const llvm::GlobalValue& v) -> bool { + if (function_names.count(v.getName()) != 0) { + // Preserve global functions + return true; + } + // Internalize all device functions + return false; + })); + + llvm::legacy::PassManager pm; + pm.add(llvm::createVerifierPass()); + pm.run(*llvm); + + llvm::PassManagerBuilder builder; + builder.OptLevel = 3; + builder.SizeLevel = 0; + builder.populateModulePassManager(pass); + + pass.run(*llvm); +} + // TODO: // There should be a proper pass manager there! -std::unique_ptr add_passes_to_emit_bin(ir::module &ir, llvm::LLVMContext& ctx, codegen::target* target, - int cc, int num_warps, int num_stages, int& shared_static) { +std::unique_ptr add_passes_to_emit_bin( + ir::module& ir, llvm::LLVMContext& ctx, codegen::target* target, + int num_warps, int num_stages, int& shared_static, + const ExternLibMap& extern_lib_map) { // generate llvm code std::string name = ir.get_function_list()[0]->get_name(); std::unique_ptr llvm(new llvm::Module(name, ctx)); @@ -47,8 +99,10 @@ std::unique_ptr add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC codegen::transform::peephole peephole(target, &layouts); codegen::transform::coalesce coalesce(&align, &layouts, has_sm80); codegen::transform::prefetch prefetch_s(target); - codegen::transform::membar barriers(&liveness, &layouts, &allocation, &prefetch_s, target); - codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target, num_warps); + codegen::transform::membar barriers(&liveness, &layouts, &allocation, + &prefetch_s, target); + codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, + target, num_warps); // run passes inliner.run(ir); dce.run(ir); @@ -56,7 +110,7 @@ std::unique_ptr add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC peephole.run(ir); dce.run(ir); pipeline.run(ir); - dce.run(ir); + dce.run(ir); disassociate.run(ir); dce.run(ir); align.run(ir); @@ -64,8 +118,7 @@ std::unique_ptr add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC layouts.run(ir); peephole.run(ir); dce.run(ir); - if (target->is_gpu()) - cts.run(ir); + if (target->is_gpu()) cts.run(ir); align.run(ir); axes.run(ir); layouts.run(ir); @@ -73,8 +126,7 @@ std::unique_ptr add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC dce.run(ir); align.run(ir); dce.run(ir); - if (target->is_gpu()) - cts.run(ir); + if (target->is_gpu()) cts.run(ir); dce.run(ir); align.run(ir); axes.run(ir); @@ -97,8 +149,15 @@ std::unique_ptr add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC // ir.print(std::cout); isel.visit(ir, *llvm); shared_static = allocation.allocated_size(); + + if (isel.get_extern_lib_map().size() > 0) { + // If there's any extern lib calls, + // we need to link them in. + link_extern_libs(extern_lib_map, isel.get_extern_lib_map(), ir, ctx, llvm); + } + return llvm; } -} // namespace codegen -} // namespace triton +} // namespace codegen +} // namespace triton diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index e2303b990..b30283ced 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -1195,7 +1195,7 @@ void generator::visit_cos_inst(ir::cos_inst* x){ for(auto idx: idxs_.at(x)){ vals_[x][idx] = call(cos, std::vector{vals_[x->get_operand(0)][idx]}); } - } +} /** * \brief Code Generation for `umulhi` @@ -3154,6 +3154,30 @@ void generator::visit_async_wait_inst(ir::async_wait_inst* i) { call(iasm); } +/** + * \brief Code Generation for `extern_elementwise` + */ +void generator::visit_extern_elementwise_inst(ir::extern_elementwise_inst *i) { + std::vector operand_types; + for (size_t j = 0; j < i->get_num_operands(); j++) { + operand_types.push_back( + cvt(i->get_operand(j)->get_type()->get_scalar_ty())); + } + Type *ret_type = cvt(i->get_type()->get_scalar_ty()); + FunctionType *FT = + FunctionType::get(ret_type, std::move(operand_types), false); + Function *F = llvm::cast( + mod_->getOrInsertFunction(i->get_name(), FT).getCallee()); + for (auto idx : idxs_.at(i)) { + std::vector args; + for (size_t j = 0; j < i->get_num_operands(); j++) { + args.emplace_back(vals_[i->get_operand(j)][idx]); + } + vals_[i][idx] = call(F, std::move(args)); + } + add_extern_lib(i->get_lib_name(), i->get_lib_path()); +} + //void generator::visit_make_range_dyn(ir::make_range_dyn* x) { // for(indices_t idx: idxs_.at(x)){ // assert(idx.size() == 1); @@ -3741,6 +3765,15 @@ void generator::visit(ir::module &src, llvm::Module &dst) { visit_function(fn); } +void generator::add_extern_lib(const std::string &lib_name, + const std::string &lib_path) { + if (extern_lib_map_.count(lib_name) == 0) { + extern_lib_map_[lib_name] = create_extern_lib(lib_name, lib_path); + } else if (extern_lib_map_.at(lib_name)->path() != lib_path) { + throw std::runtime_error("A library has multiple paths (1) " + lib_path + + " (2) " + extern_lib_map_.at(lib_name)->path()); + } +} -} -} +} // namespace codegen +} // namespace triton diff --git a/lib/driver/llvm.cc b/lib/driver/llvm.cc index 92a6b75de..c4a13b806 100644 --- a/lib/driver/llvm.cc +++ b/lib/driver/llvm.cc @@ -358,8 +358,5 @@ hipModule_t amdgpu_to_hipmodule(const std::string& path) { return ret; } - - -} -} - +} // namespace driver +} // namespace triton diff --git a/lib/ir/builder.cc b/lib/ir/builder.cc index 510994fd8..120b575cf 100644 --- a/lib/ir/builder.cc +++ b/lib/ir/builder.cc @@ -379,6 +379,19 @@ value *builder::create_globaltimer() { return insert(globaltimer_inst::create(ctx_)); } +//===----------------------------------------------------------------------===// +// externs +//===----------------------------------------------------------------------===// + +value *builder::create_extern_elementwise(const std::string &lib_name, + const std::string &lib_path, + const std::string &symbol_name, + const std::vector &args, + type *ret_ty) { + return insert(extern_elementwise_inst::create(ctx_, args, ret_ty, lib_name, + lib_path, symbol_name)); +} + //===----------------------------------------------------------------------===// // built-in instructions //===----------------------------------------------------------------------===// diff --git a/lib/ir/instructions.cc b/lib/ir/instructions.cc index dbee5e0ee..7831e1650 100644 --- a/lib/ir/instructions.cc +++ b/lib/ir/instructions.cc @@ -988,6 +988,28 @@ globaltimer_inst* globaltimer_inst::create(context &ctx, const std::string &name return new globaltimer_inst(ctx, name, next); } +// extern elementwise +extern_elementwise_inst::extern_elementwise_inst( + context &ctx, const std::vector &args, type *ret_ty, + const std::string &lib_name, const std::string &lib_path, + const std::string &symbol_name, instruction *next) + : instruction(ret_ty, INST_EXTERN_ELEMENTWISE, args.size(), symbol_name, + next), + lib_name_(lib_name), + lib_path_(lib_path) { + for (size_t i = 0; i < args.size(); i++) { + set_operand(i, args[i]); + } +} + +extern_elementwise_inst *extern_elementwise_inst::create( + context &ctx, const std::vector &args, type *ret_ty, + const std::string &lib_name, const std::string &lib_path, + const std::string &symbol_name, instruction *next) { + return new extern_elementwise_inst(ctx, args, ret_ty, lib_name, lib_path, + symbol_name, next); +} + // clock clock_inst::clock_inst(context &ctx, const std::string &name, instruction *next) : instruction(type::get_int64_ty(ctx), INST_CLOCK, 0, name, next) { } diff --git a/python/setup.py b/python/setup.py index 7ed6ab444..6c136b6c7 100644 --- a/python/setup.py +++ b/python/setup.py @@ -98,7 +98,7 @@ class CMakeBuild(build_ext): if not os.path.exists(self.build_temp): os.makedirs(self.build_temp) # python directories - python_include_dirs = [distutils.sysconfig.get_python_inc()] + ['/usr/local/cuda/include'] + python_include_dirs = [distutils.sysconfig.get_python_inc()] cmake_args = [ "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir, "-DBUILD_TUTORIALS=OFF", diff --git a/python/src/triton.cc b/python/src/triton.cc index 4e1849733..fcebeeb5f 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -1,5 +1,6 @@ #include "triton/codegen/pass.h" #include "triton/codegen/target.h" +#include "triton/codegen/extern_lib.h" #include "triton/driver/error.h" #include "triton/driver/llvm.h" #include "triton/ir/builder.h" @@ -19,7 +20,6 @@ #include #include #include "llvm/IR/Module.h" -#include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/Verifier.h" namespace py = pybind11; @@ -140,7 +140,7 @@ size_t get_pointer_range_size(uint64_t addr){ // Launch void parse_args(py::list& args, py::list do_not_specialize, const std::string& func_key, py::list& arg_names, std::string& cache_key, std::string& params, size_t& params_size, py::dict constants, - int num_warps, int num_stages) { + int num_warps, int num_stages, py::dict& extern_libs) { size_t len = PyList_Size(args.ptr()); params.reserve(8*len); // 8 max bytes by argument char* params_ptr = ¶ms[0]; @@ -256,6 +256,11 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f throw std::runtime_error(err_msg); } params_size = (std::ptrdiff_t)(params_ptr - ¶ms[0]); + + for (auto item : extern_libs) { + cache_key += "-" + item.first.cast(); + cache_key += "_" + item.second.cast(); + } } // @@ -288,7 +293,7 @@ void init_triton_runtime(py::module &&m) { // cache key m.def("launch", [](py::list args, py::list do_not_specialize, const std::string& func_key, py::list& arg_names, py::object device, py::int_ stream, py::dict bin_cache, py::int_ num_warps, py::int_ num_stages, - py::function add_to_cache, py::object grid){ + py::dict extern_libs, py::function add_to_cache, py::object grid){ // parse arguments to compute cache key, compile-time constants and packed kernel arguments long _num_warps = PyLong_AsLong(num_warps.ptr()); long _num_stages = PyLong_AsLong(num_stages.ptr()); @@ -296,13 +301,14 @@ void init_triton_runtime(py::module &&m) { std::string params; size_t params_size; py::dict constants; - parse_args(args, do_not_specialize, func_key, arg_names, cache_key, params, params_size, constants, _num_warps, _num_stages); + parse_args(args, do_not_specialize, func_key, arg_names, cache_key, params, + params_size, constants, _num_warps, _num_stages, extern_libs); // get cached binary py::str key(cache_key); py::bool_ noop = false; if(!bin_cache.contains(key)) { - noop = add_to_cache(key, args, device, num_warps, num_stages); + noop = add_to_cache(key, args, device, num_warps, num_stages, extern_libs); } if (noop) return (py::object)py::none(); @@ -467,11 +473,10 @@ std::tuple hip_load_binary(const std::st // --------------------------------------- // CUDA -std::tuple cu_compile_ttir(const std::string& name, ir::module &ir, - uint64_t device, int num_warps, int num_stages, - asm_map_t &asm_map){ - - int n_shared_bytes; +std::tuple cu_compile_ttir( + const std::string &name, ir::module &ir, uint64_t device, int num_warps, + int num_stages, asm_map_t &asm_map, + const triton::codegen::ExternLibMap &extern_lib_map) { py::gil_scoped_release allow_threads; llvm::LLVMContext ctx; // device properties @@ -483,7 +488,9 @@ std::tuple cu_compile_ttir(const std::string& name, std::string ptxas_path = drv::path_to_ptxas(version); // Triton-IR -> NVPTX LLVM-IR triton::codegen::nvidia_cu_target target(cc); - auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, cc, num_warps, num_stages, n_shared_bytes); + int n_shared_bytes; + auto llvm = triton::codegen::add_passes_to_emit_bin( + ir, ctx, &target, num_warps, num_stages, n_shared_bytes, extern_lib_map); std::string tmp; llvm::raw_string_ostream llir(tmp); llir << *llvm; @@ -502,14 +509,16 @@ std::tuple cu_compile_ttir(const std::string& name, } // HIP -std::tuple hip_compile_ttir(const std::string& name, ir::module &ir, - uint64_t device, int num_warps, int num_stages, - asm_map_t &asm_map){ +std::tuple hip_compile_ttir( + const std::string &name, ir::module &ir, uint64_t device, int num_warps, + int num_stages, asm_map_t &asm_map, + const triton::codegen::ExternLibMap &extern_lib_map) { llvm::LLVMContext ctx; // Triton-IR -> NVPTX LLVM-IR triton::codegen::amd_cl_target target; int n_shared_bytes; - auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, 70, num_warps, num_stages, n_shared_bytes); + auto llvm = triton::codegen::add_passes_to_emit_bin( + ir, ctx, &target, num_warps, num_stages, n_shared_bytes, extern_lib_map); std::string tmp; llvm::raw_string_ostream llir(tmp); llir << *llvm; @@ -523,7 +532,9 @@ std::tuple hip_compile_ttir(const std::string& name void init_triton_codegen(py::module &&m) { m.def( - "compile_ttir", [](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages) { + "compile_ttir", + [](backend_t backend, ir::module &ir, uint64_t device, int num_warps, + int num_stages, py::dict& extern_libs) { std::string name = ir.get_function_list()[0]->get_name(); // record asm as we generate asm_map_t asm_map; @@ -531,11 +542,20 @@ void init_triton_codegen(py::module &&m) { ir.print(ttir); asm_map["ttir"] = py::cast(ttir.str()); llvm::LLVMContext ctx; + // construct extern lib map + triton::codegen::ExternLibMap extern_lib_map; + for (auto item : extern_libs) { + auto name = item.first.cast(); + auto path = item.second.cast(); + extern_lib_map.emplace( + name, triton::codegen::create_extern_lib(name, path)); + } if(backend == CUDA) - return cu_compile_ttir(name, ir, device, num_warps, num_stages, asm_map); + return cu_compile_ttir(name, ir, device, num_warps, num_stages, asm_map, extern_lib_map); if(backend == ROCM) - return hip_compile_ttir(name, ir, device, num_warps, num_stages, asm_map); - }, py::return_value_policy::take_ownership); + return hip_compile_ttir(name, ir, device, num_warps, num_stages, asm_map, extern_lib_map); + }, + py::return_value_policy::take_ownership); m.def("load_binary", [](backend_t backend, const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){ py::gil_scoped_release allow_threads; if(backend == CUDA) @@ -931,7 +951,8 @@ void init_triton_ir(py::module &&m) { // Utilities .def("create_clock", &ir::builder::create_clock, ret::reference) .def("create_globaltimer", &ir::builder::create_globaltimer, ret::reference) - + // Extern instruction + .def("create_extern_elementwise", &ir::builder::create_extern_elementwise, ret::reference) // Built-in instruction .def("create_get_program_id", &ir::builder::create_get_program_id, ret::reference) .def("create_get_num_programs", &ir::builder::create_get_num_programs, ret::reference) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index d032d1e39..cb2cb9c33 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1300,3 +1300,49 @@ def test_num_warps_pow2(): _kernel[(1,)](dst=dst, num_warps=1) _kernel[(1,)](dst=dst, num_warps=2) _kernel[(1,)](dst=dst, num_warps=4) + +# ------------- +# test extern +# ------------- + + +@pytest.mark.parametrize("dtype_str, expr, lib_path", + [('int32', 'libdevice.ffs', ''), + ('float32', 'libdevice.pow', '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'), + ('float64', 'libdevice.norm4d', '')]) +def test_libdevice(dtype_str, expr, lib_path): + + @triton.jit + def kernel(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + y = GENERATE_TEST_HERE + tl.store(Y + tl.arange(0, BLOCK), y) + + shape = (128, ) + rs = RandomState(17) + # limit the range of integers so that the sum does not overflow + x = numpy_random(shape, dtype_str=dtype_str, rs=rs) + + if expr == 'libdevice.ffs': + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.ffs(x)'}) + y_ref = np.zeros(shape, dtype=x.dtype) + for i in range(shape[0]): + y_ref[i] = (int(x[i]) & int(-x[i])).bit_length() + elif expr == 'libdevice.pow': + # numpy does not allow negative factors in power, so we use abs() + x = np.abs(x) + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.pow(x, x)'}) + y_ref = np.power(x, x) + elif expr == 'libdevice.norm4d': + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.norm4d(x, x, x, x)'}) + y_ref = np.sqrt(4 * np.power(x, 2)) + + x_tri = to_triton(x) + # triton result + y_tri = to_triton(numpy_random((shape[0],), dtype_str=dtype_str, rs=rs), device='cuda') + kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'libdevice': lib_path}) + # compare + if expr == 'libdevice.ffs': + np.testing.assert_equal(y_ref, to_numpy(y_tri)) + else: + np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01) diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 30a79bcc9..3951d8b6b 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -689,7 +689,7 @@ class CodeGenerator(ast.NodeVisitor): ret = triton.language.tensor(ret, self.prototypes[fn_name].ret_type) return ret # built-in function - if sys.modules[fn.__module__] is triton.language.core: + if sys.modules[fn.__module__] is triton.language.core or isinstance(fn, triton.language.extern.ExternalFunction): ret = fn(*args, _builder=self.builder, **kws) if fn in self.value_constructor.builtins.values(): args = [arg.value if isinstance(arg, triton.language.constexpr) else arg @@ -933,7 +933,7 @@ class Kernel: self.fn = fn self.cache_key = {} - def add_to_cache(self, key, wargs, device_idx, num_warps, num_stages): + def add_to_cache(self, key, wargs, device_idx, num_warps, num_stages, extern_libs): tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')] # attributes @@ -953,9 +953,10 @@ class Kernel: constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)}) constants.update({i: None for i, arg in enumerate(wargs) if arg is None}) arg_types = [Kernel._to_python_ir(arg) for i, arg in enumerate(wargs) if i not in constants] - return self.fn._warmup(key, arg_types=arg_types, device=device_idx, attributes=attributes, constants=constants, num_warps=num_warps, num_stages=num_stages, is_manual_warmup=False) + return self.fn._warmup(key, arg_types=arg_types, device=device_idx, attributes=attributes, constants=constants, num_warps=num_warps, num_stages=num_stages, + extern_libs=extern_libs, is_manual_warmup=False) - def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **kwargs): + def __call__(self, *wargs, grid, num_warps=4, num_stages=2, extern_libs={}, **kwargs): assert num_warps != 0 and (num_warps & (num_warps - 1)) == 0, f"num_warps={num_warps} must be a power of 2." # handle arguments passed by name kwargs = {self.fn.arg_names.index(name): value for name, value in kwargs.items()} @@ -985,7 +986,7 @@ class Kernel: cache_key = self.cache_key[device] stream = current_cuda_stream(device) return _triton.runtime.launch(wargs, self.fn.do_not_specialize, cache_key, self.fn.arg_names, - device, stream, self.fn.bin_cache, num_warps, num_stages, self.add_to_cache, + device, stream, self.fn.bin_cache, num_warps, num_stages, extern_libs, self.add_to_cache, grid) @@ -1242,7 +1243,7 @@ class JITFunction: def warmup(self, compile): return self._warmup(**compile, is_manual_warmup=True) - def _warmup(self, key, arg_types, device, attributes, constants, num_warps, num_stages, is_manual_warmup): + def _warmup(self, key, arg_types, device, attributes, constants, num_warps, num_stages, extern_libs, is_manual_warmup): hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest() # create cache directory @@ -1264,7 +1265,7 @@ class JITFunction: with open(bin_cache_path, 'rb') as f: binary = pickle.load(f)["binary"] - compile = dict(arg_types=arg_types, device=device, attributes=attributes, constants=constants, num_warps=num_warps, num_stages=num_stages) + compile = dict(arg_types=arg_types, device=device, attributes=attributes, constants=constants, num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs) if JITFunction.cache_hook is not None: name = self.__name__ info = key.split('-')[-3:] @@ -1293,7 +1294,7 @@ class JITFunction: self.bin_cache[key] = LoadedBinary(device, binary) return False - def _compile(self, arg_types, device, attributes, constants, num_warps, num_stages): + def _compile(self, arg_types, device, attributes, constants, num_warps, num_stages, extern_libs): # create IR module context = _triton.ir.context() # get just-in-time proto-type of kernel @@ -1316,7 +1317,7 @@ class JITFunction: backend = _triton.runtime.backend.CUDA else: backend = _triton.runtime.backend.ROCM - name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, generator.module, device, num_warps, num_stages) + name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, generator.module, device, num_warps, num_stages, extern_libs) max_shared_memory = _triton.runtime.max_shared_memory(backend, device) if shared_mem > max_shared_memory: raise OutOfResources(shared_mem, max_shared_memory, "shared memory") diff --git a/python/triton/language/__init__.py b/python/triton/language/__init__.py index 0b04465eb..6b0058dd5 100644 --- a/python/triton/language/__init__.py +++ b/python/triton/language/__init__.py @@ -1,4 +1,4 @@ # flake8: noqa: F401 -from . import core, random +from . import core, extern, libdevice, random from .core import * from .random import * diff --git a/python/triton/language/core.py b/python/triton/language/core.py index cc0db5566..4197a3333 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -248,8 +248,10 @@ class block_type(dtype): # while tensor's shape is a list of constexpr self.shape = shape self.numel = 1 - for s in self.shape: - self.numel *= s + for i, s in enumerate(self.shape): + if isinstance(s, constexpr): + self.shape[i] = s.value + self.numel *= self.shape[i] self.name = self.__str__() diff --git a/python/triton/language/extern.py b/python/triton/language/extern.py new file mode 100644 index 000000000..a306a2e9a --- /dev/null +++ b/python/triton/language/extern.py @@ -0,0 +1,107 @@ +from __future__ import annotations # remove after python 3.11 + +from . import core, semantic + + +def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, ret_shape: tuple, _builder=None): + ''' + Dispatch a function to a library + + :param func: the function to dispatch + :param lib_name: the name of the library + :param lib_path: the path of the library + :param args: the arguments of the function + :param arg_type_symbol_dict: the type of the arguments + :param ret_shape: the shape of the return value + :param _builder: the builder + + :return: the return value of the function + ''' + if len(arg_type_symbol_dict) == 0: + raise ValueError("arg_type_symbol_dict is empty") + + num_args = len(list(arg_type_symbol_dict.keys())[0]) + if len(args) != num_args: + raise ValueError(f"length of input args does not match." + f"Expect {len(args)}, got {num_args}") + + arg_types = [] + arg_list = [] + for arg in args: + if isinstance(arg, core.tensor): + arg_types.append(arg.dtype) + arg_list.append(arg.handle) + else: + arg_types.append(type(arg)) + arg_list.append(arg) + arg_types = tuple(arg_types) + + if arg_types not in arg_type_symbol_dict: + raise ValueError(f"input arg type does not match." + f"Expect one of {arg_type_symbol_dict.keys()}, got {arg_types}") + else: + symbol = arg_type_symbol_dict[arg_types][0] + ret_type = arg_type_symbol_dict[arg_types][1] + ret_type = core.block_type(ret_type, ret_shape) if ret_shape is not None else ret_type + return core.tensor(func(lib_name, lib_path, symbol, arg_list, ret_type.to_ir(_builder)), ret_type) + + +def elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, _builder=None): + ''' + Dispatch an elementwise function to a library + + :param lib_name: the name of the library + :param lib_path: the path of the library + :param args: the arguments of the function + :param arg_type_symbol_dict: the type of the arguments + :param _builder: the builder + + :return: the return value of the function + ''' + dispatch_args = args.copy() + if len(args) == 1: + dispatch_args[0] = core._to_tensor(dispatch_args[0], _builder) + ret_shape = dispatch_args[0].shape + elif len(args) == 2: + dispatch_args[0] = core._to_tensor(dispatch_args[0], _builder) + dispatch_args[1] = core._to_tensor(dispatch_args[1], _builder) + dispatch_args[0], dispatch_args[1] = semantic.binary_op_type_checking_impl( + dispatch_args[0], dispatch_args[1], _builder) + ret_shape = dispatch_args[0].shape + else: + for i in range(len(dispatch_args)): + dispatch_args[i] = core._to_tensor(dispatch_args[i], _builder) + broadcast_arg = dispatch_args[0] + # Get the broadcast shape over all the arguments + for i in range(len(dispatch_args)): + _, broadcast_arg = semantic.binary_op_type_checking_impl( + dispatch_args[i], broadcast_arg, _builder) + # Change the shape of each argument based on the broadcast shape + for i in range(len(dispatch_args)): + dispatch_args[i], _ = semantic.binary_op_type_checking_impl( + dispatch_args[i], broadcast_arg, _builder) + ret_shape = broadcast_arg.shape + func = getattr(_builder, "create_extern_elementwise") + return dispatch(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, ret_shape, _builder) + + +class ExternalFunction: + ''' + A wrapper for external functions + ''' + + def __init__(self, fn): + self.fn = fn + + def __call__(self, *args, **kwargs): + if '_builder' not in kwargs or \ + kwargs['_builder'] is None: + raise ValueError("Did you forget to add @triton.jit ? (`_builder` argument must be provided outside of JIT functions.)") + return self.fn(*args, **kwargs) + + +def extern(fn): + ''' + A decorator for external functions + ''' + return ExternalFunction(fn) diff --git a/python/triton/language/libdevice.10.bc b/python/triton/language/libdevice.10.bc new file mode 100644 index 000000000..ef3ae8d81 Binary files /dev/null and b/python/triton/language/libdevice.10.bc differ diff --git a/python/triton/language/libdevice.py b/python/triton/language/libdevice.py new file mode 100644 index 000000000..226480fa2 --- /dev/null +++ b/python/triton/language/libdevice.py @@ -0,0 +1,1661 @@ +import os + +from . import core, extern + +LIBDEVICE_PATH = os.path.dirname( + os.path.abspath(__file__)) + "/libdevice.10.bc" + + +@extern.extern +def clz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int32"),): ("__nv_clz", core.dtype("int32")), + (core.dtype("int64"),): ("__nv_clzll", core.dtype("int32")), + }, _builder) + + +@extern.extern +def popc(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int32"),): ("__nv_popc", core.dtype("int32")), + (core.dtype("int64"),): ("__nv_popcll", core.dtype("int32")), + }, _builder) + + +@extern.extern +def byte_perm(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("int32"), core.dtype("int32"), core.dtype("int32"),): ("__nv_byte_perm", core.dtype("int32")), + }, _builder) + + +@extern.extern +def min(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("int32"), core.dtype("int32"),): ("__nv_min", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32"),): ("__nv_umin", core.dtype("uint32")), + (core.dtype("int64"), core.dtype("int64"),): ("__nv_llmin", core.dtype("int64")), + (core.dtype("uint64"), core.dtype("uint64"),): ("__nv_ullmin", core.dtype("uint64")), + (core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fminf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fmin", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def max(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("int32"), core.dtype("int32"),): ("__nv_max", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32"),): ("__nv_umax", core.dtype("uint32")), + (core.dtype("int64"), core.dtype("int64"),): ("__nv_llmax", core.dtype("int64")), + (core.dtype("uint64"), core.dtype("uint64"),): ("__nv_ullmax", core.dtype("uint64")), + (core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaxf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fmax", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def mulhi(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("int32"), core.dtype("int32"),): ("__nv_mulhi", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32"),): ("__nv_umulhi", core.dtype("uint32")), + }, _builder) + + +@extern.extern +def mul64hi(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("int64"), core.dtype("int64"),): ("__nv_mul64hi", core.dtype("int64")), + (core.dtype("uint64"), core.dtype("uint64"),): ("__nv_umul64hi", core.dtype("uint64")), + }, _builder) + + +@extern.extern +def mul24(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("int32"), core.dtype("int32"),): ("__nv_mul24", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32"),): ("__nv_umul24", core.dtype("uint32")), + }, _builder) + + +@extern.extern +def brev(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int32"),): ("__nv_brev", core.dtype("int32")), + (core.dtype("int64"),): ("__nv_brevll", core.dtype("int64")), + }, _builder) + + +@extern.extern +def sad(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("int32"), core.dtype("int32"), core.dtype("uint32"),): ("__nv_sad", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32"), core.dtype("uint32"),): ("__nv_usad", core.dtype("uint32")), + }, _builder) + + +@extern.extern +def abs(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int32"),): ("__nv_abs", core.dtype("int32")), + (core.dtype("int64"),): ("__nv_llabs", core.dtype("int64")), + (core.dtype("fp32"),): ("__nv_fabsf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_fabs", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def floor(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_floorf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_floor", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def rcp64h(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_rcp64h", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def rsqrt(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_rsqrtf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_rsqrt", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def ceil(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_ceil", core.dtype("fp64")), + (core.dtype("fp32"),): ("__nv_ceilf", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def trunc(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_trunc", core.dtype("fp64")), + (core.dtype("fp32"),): ("__nv_truncf", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def exp2(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_exp2f", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_exp2", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def saturatef(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_saturatef", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fmaf_rn(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_rn", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fmaf_rz(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_rz", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fmaf_rd(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_rd", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fmaf_ru(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_ru", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fmaf_ieee_rn(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_ieee_rn", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fmaf_ieee_rz(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_ieee_rz", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fmaf_ieee_rd(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_ieee_rd", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fmaf_ieee_ru(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_ieee_ru", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fma_rn(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fma_rn", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def fma_rz(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fma_rz", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def fma_rd(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fma_rd", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def fma_ru(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fma_ru", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def fast_fdividef(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fast_fdividef", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fdiv_rn(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fdiv_rn", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fdiv_rz(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fdiv_rz", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fdiv_rd(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fdiv_rd", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fdiv_ru(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fdiv_ru", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def frcp_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_frcp_rn", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def frcp_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_frcp_rz", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def frcp_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_frcp_rd", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def frcp_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_frcp_ru", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fsqrt_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_fsqrt_rn", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fsqrt_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_fsqrt_rz", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fsqrt_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_fsqrt_rd", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fsqrt_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_fsqrt_ru", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def ddiv_rn(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_ddiv_rn", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def ddiv_rz(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_ddiv_rz", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def ddiv_rd(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_ddiv_rd", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def ddiv_ru(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_ddiv_ru", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def drcp_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_drcp_rn", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def drcp_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_drcp_rz", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def drcp_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_drcp_rd", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def drcp_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_drcp_ru", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def dsqrt_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_dsqrt_rn", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def dsqrt_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_dsqrt_rz", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def dsqrt_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_dsqrt_rd", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def dsqrt_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_dsqrt_ru", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def sqrt(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_sqrtf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_sqrt", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def dadd_rn(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dadd_rn", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def dadd_rz(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dadd_rz", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def dadd_rd(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dadd_rd", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def dadd_ru(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dadd_ru", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def dmul_rn(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dmul_rn", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def dmul_rz(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dmul_rz", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def dmul_rd(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dmul_rd", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def dmul_ru(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dmul_ru", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def fadd_rd(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fadd_rd", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fadd_ru(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fadd_ru", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fmul_rd(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmul_rd", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fmul_ru(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmul_ru", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fadd_rn(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fadd_rn", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fadd_rz(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fadd_rz", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fmul_rn(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmul_rn", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fmul_rz(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmul_rz", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def double2float_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2float_rn", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def double2float_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2float_rz", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def double2float_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2float_rd", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def double2float_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2float_ru", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def double2int_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2int_rn", core.dtype("int32")), + }, _builder) + + +@extern.extern +def double2int_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2int_rz", core.dtype("int32")), + }, _builder) + + +@extern.extern +def double2int_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2int_rd", core.dtype("int32")), + }, _builder) + + +@extern.extern +def double2int_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2int_ru", core.dtype("int32")), + }, _builder) + + +@extern.extern +def double2uint_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2uint_rn", core.dtype("int32")), + }, _builder) + + +@extern.extern +def double2uint_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2uint_rz", core.dtype("int32")), + }, _builder) + + +@extern.extern +def double2uint_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2uint_rd", core.dtype("int32")), + }, _builder) + + +@extern.extern +def double2uint_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2uint_ru", core.dtype("int32")), + }, _builder) + + +@extern.extern +def int2double_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int32"),): ("__nv_int2double_rn", core.dtype("fp64")), + (core.dtype("uint32"),): ("__nv_uint2double_rn", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def float2int_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2int_rn", core.dtype("int32")), + }, _builder) + + +@extern.extern +def float2int_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2int_rz", core.dtype("int32")), + }, _builder) + + +@extern.extern +def float2int_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2int_rd", core.dtype("int32")), + }, _builder) + + +@extern.extern +def float2int_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2int_ru", core.dtype("int32")), + }, _builder) + + +@extern.extern +def float2uint_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2uint_rn", core.dtype("int32")), + }, _builder) + + +@extern.extern +def float2uint_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2uint_rz", core.dtype("int32")), + }, _builder) + + +@extern.extern +def float2uint_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2uint_rd", core.dtype("int32")), + }, _builder) + + +@extern.extern +def float2uint_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2uint_ru", core.dtype("int32")), + }, _builder) + + +@extern.extern +def int2float_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int32"),): ("__nv_int2float_rn", core.dtype("fp32")), + (core.dtype("uint32"),): ("__nv_uint2float_rn", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def int2float_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int32"),): ("__nv_int2float_rz", core.dtype("fp32")), + (core.dtype("uint32"),): ("__nv_uint2float_rz", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def int2float_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int32"),): ("__nv_int2float_rd", core.dtype("fp32")), + (core.dtype("uint32"),): ("__nv_uint2float_rd", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def int2float_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int32"),): ("__nv_int2float_ru", core.dtype("fp32")), + (core.dtype("uint32"),): ("__nv_uint2float_ru", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def hiloint2double(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("int32"), core.dtype("int32"),): ("__nv_hiloint2double", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def double2loint(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2loint", core.dtype("int32")), + }, _builder) + + +@extern.extern +def double2hiint(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2hiint", core.dtype("int32")), + }, _builder) + + +@extern.extern +def float2ll_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2ll_rn", core.dtype("int64")), + }, _builder) + + +@extern.extern +def float2ll_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2ll_rz", core.dtype("int64")), + }, _builder) + + +@extern.extern +def float2ll_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2ll_rd", core.dtype("int64")), + }, _builder) + + +@extern.extern +def float2ll_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2ll_ru", core.dtype("int64")), + }, _builder) + + +@extern.extern +def float2ull_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2ull_rn", core.dtype("int64")), + }, _builder) + + +@extern.extern +def float2ull_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2ull_rz", core.dtype("int64")), + }, _builder) + + +@extern.extern +def float2ull_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2ull_rd", core.dtype("int64")), + }, _builder) + + +@extern.extern +def float2ull_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2ull_ru", core.dtype("int64")), + }, _builder) + + +@extern.extern +def double2ll_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2ll_rn", core.dtype("int64")), + }, _builder) + + +@extern.extern +def double2ll_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2ll_rz", core.dtype("int64")), + }, _builder) + + +@extern.extern +def double2ll_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2ll_rd", core.dtype("int64")), + }, _builder) + + +@extern.extern +def double2ll_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2ll_ru", core.dtype("int64")), + }, _builder) + + +@extern.extern +def double2ull_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2ull_rn", core.dtype("int64")), + }, _builder) + + +@extern.extern +def double2ull_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2ull_rz", core.dtype("int64")), + }, _builder) + + +@extern.extern +def double2ull_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2ull_rd", core.dtype("int64")), + }, _builder) + + +@extern.extern +def double2ull_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2ull_ru", core.dtype("int64")), + }, _builder) + + +@extern.extern +def ll2float_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int64"),): ("__nv_ll2float_rn", core.dtype("fp32")), + (core.dtype("uint64"),): ("__nv_ull2float_rn", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def ll2float_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int64"),): ("__nv_ll2float_rz", core.dtype("fp32")), + (core.dtype("uint64"),): ("__nv_ull2float_rz", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def ll2float_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int64"),): ("__nv_ll2float_rd", core.dtype("fp32")), + (core.dtype("uint64"),): ("__nv_ull2float_rd", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def ll2float_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int64"),): ("__nv_ll2float_ru", core.dtype("fp32")), + (core.dtype("uint64"),): ("__nv_ull2float_ru", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def ll2double_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int64"),): ("__nv_ll2double_rn", core.dtype("fp64")), + (core.dtype("uint64"),): ("__nv_ull2double_rn", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def ll2double_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int64"),): ("__nv_ll2double_rz", core.dtype("fp64")), + (core.dtype("uint64"),): ("__nv_ull2double_rz", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def ll2double_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int64"),): ("__nv_ll2double_rd", core.dtype("fp64")), + (core.dtype("uint64"),): ("__nv_ull2double_rd", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def ll2double_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int64"),): ("__nv_ll2double_ru", core.dtype("fp64")), + (core.dtype("uint64"),): ("__nv_ull2double_ru", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def int_as_float(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int32"),): ("__nv_int_as_float", core.dtype("fp32")), + (core.dtype("uint32"),): ("__nv_uint_as_float", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def float_as_int(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float_as_int", core.dtype("int32")), + }, _builder) + + +@extern.extern +def float_as_uint(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float_as_uint", core.dtype("int32")), + }, _builder) + + +@extern.extern +def longlong_as_double(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int64"),): ("__nv_longlong_as_double", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def double_as_longlong(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double_as_longlong", core.dtype("int64")), + }, _builder) + + +@extern.extern +def fast_sinf(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_fast_sinf", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fast_cosf(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_fast_cosf", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fast_log2f(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_fast_log2f", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fast_logf(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_fast_logf", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fast_expf(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_fast_expf", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fast_tanf(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_fast_tanf", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fast_exp10f(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_fast_exp10f", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fast_log10f(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_fast_log10f", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def pow(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fast_powf", core.dtype("fp32")), + (core.dtype("fp32"), core.dtype("fp32"),): ("__nv_powf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_pow", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def hadd(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("int32"), core.dtype("int32"),): ("__nv_hadd", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32"),): ("__nv_uhadd", core.dtype("uint32")), + }, _builder) + + +@extern.extern +def rhadd(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("int32"), core.dtype("int32"),): ("__nv_rhadd", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32"),): ("__nv_urhadd", core.dtype("uint32")), + }, _builder) + + +@extern.extern +def fsub_rn(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fsub_rn", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fsub_rz(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fsub_rz", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fsub_rd(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fsub_rd", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fsub_ru(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fsub_ru", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def frsqrt_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_frsqrt_rn", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def ffs(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int32"),): ("__nv_ffs", core.dtype("int32")), + (core.dtype("int64"),): ("__nv_ffsll", core.dtype("int32")), + }, _builder) + + +@extern.extern +def rint(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_rintf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_rint", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def llrint(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_llrintf", core.dtype("int64")), + (core.dtype("fp64"),): ("__nv_llrint", core.dtype("int64")), + }, _builder) + + +@extern.extern +def nearbyint(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_nearbyintf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_nearbyint", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def isnanf(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_isnanf", core.dtype("int32")), + }, _builder) + + +@extern.extern +def signbitf(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_signbitf", core.dtype("int32")), + }, _builder) + + +@extern.extern +def copysign(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_copysignf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_copysign", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def finitef(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_finitef", core.dtype("int32")), + }, _builder) + + +@extern.extern +def isinff(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_isinff", core.dtype("int32")), + }, _builder) + + +@extern.extern +def nextafter(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_nextafterf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_nextafter", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def sin(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_sinf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_sin", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def cos(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_cosf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_cos", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def sinpi(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_sinpif", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_sinpi", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def cospi(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_cospif", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_cospi", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def tan(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_tanf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_tan", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def log2(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_log2f", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_log2", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def exp(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_expf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_exp", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def exp10(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_exp10f", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_exp10", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def cosh(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_coshf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_cosh", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def sinh(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_sinhf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_sinh", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def tanh(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_tanhf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_tanh", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def atan2(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_atan2f", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_atan2", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def atan(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_atanf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_atan", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def asin(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_asinf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_asin", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def acos(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_acosf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_acos", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def log(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_logf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_log", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def log10(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_log10f", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_log10", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def log1p(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_log1pf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_log1p", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def acosh(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_acoshf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_acosh", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def asinh(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_asinhf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_asinh", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def atanh(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_atanhf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_atanh", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def expm1(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_expm1f", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_expm1", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def hypot(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_hypotf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_hypot", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def rhypot(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_rhypotf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_rhypot", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def norm3d(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_norm3df", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_norm3d", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def rnorm3d(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_rnorm3df", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_rnorm3d", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def norm4d(arg0, arg1, arg2, arg3, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, arg3, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_norm4df", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_norm4d", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def rnorm4d(arg0, arg1, arg2, arg3, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, arg3, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_rnorm4df", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_rnorm4d", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def cbrt(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_cbrtf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_cbrt", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def rcbrt(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_rcbrtf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_rcbrt", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def j0(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_j0f", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_j0", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def j1(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_j1f", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_j1", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def y0(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_y0f", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_y0", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def y1(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_y1f", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_y1", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def yn(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("int32"), core.dtype("fp32"),): ("__nv_ynf", core.dtype("fp32")), + (core.dtype("int32"), core.dtype("fp64"),): ("__nv_yn", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def jn(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("int32"), core.dtype("fp32"),): ("__nv_jnf", core.dtype("fp32")), + (core.dtype("int32"), core.dtype("fp64"),): ("__nv_jn", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def cyl_bessel_i0(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_cyl_bessel_i0f", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_cyl_bessel_i0", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def cyl_bessel_i1(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_cyl_bessel_i1f", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_cyl_bessel_i1", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def erf(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_erff", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_erf", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def erfinv(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_erfinvf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_erfinv", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def erfc(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_erfcf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_erfc", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def erfcx(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_erfcxf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_erfcx", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def erfcinv(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_erfcinvf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_erfcinv", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def normcdfinv(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_normcdfinvf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_normcdfinv", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def normcdf(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_normcdff", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_normcdf", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def lgamma(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_lgammaf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_lgamma", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def ldexp(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("int32"),): ("__nv_ldexpf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("int32"),): ("__nv_ldexp", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def scalbn(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("int32"),): ("__nv_scalbnf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("int32"),): ("__nv_scalbn", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def fmod(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmodf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fmod", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def remainder(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_remainderf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_remainder", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def fma(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fma", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def powi(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("int32"),): ("__nv_powif", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("int32"),): ("__nv_powi", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def tgamma(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_tgammaf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_tgamma", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def round(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_roundf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_round", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def llround(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_llroundf", core.dtype("int64")), + (core.dtype("fp64"),): ("__nv_llround", core.dtype("int64")), + }, _builder) + + +@extern.extern +def fdim(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fdimf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fdim", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def ilogb(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_ilogbf", core.dtype("int32")), + (core.dtype("fp64"),): ("__nv_ilogb", core.dtype("int32")), + }, _builder) + + +@extern.extern +def logb(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_logbf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_logb", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def signbitd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_signbitd", core.dtype("int32")), + }, _builder) + + +@extern.extern +def isfinited(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_isfinited", core.dtype("int32")), + }, _builder) + + +@extern.extern +def isinfd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_isinfd", core.dtype("int32")), + }, _builder) + + +@extern.extern +def isnand(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_isnand", core.dtype("int32")), + }, _builder) + + +@extern.extern +def dsub_rn(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dsub_rn", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def dsub_rz(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dsub_rz", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def dsub_ru(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dsub_ru", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def dsub_rd(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dsub_rd", core.dtype("fp64")), + }, _builder) diff --git a/python/triton/tools/build_extern.py b/python/triton/tools/build_extern.py new file mode 100644 index 000000000..6d0a04e8e --- /dev/null +++ b/python/triton/tools/build_extern.py @@ -0,0 +1,340 @@ +import argparse +import subprocess +from abc import ABC, abstractmethod + + +class Symbol: + def __init__(self, name: str, op_name: str, ret_type: str, arg_names: list, arg_types: list) -> None: + ''' + A symbol is a function declaration. + + :param name: name of the symbol + :param op_name: name of the operation + :param ret_type: return type of the operation + :param arg_names: names of the arguments + :param arg_types: types of the arguments + ''' + self._name = name + self._op_name = op_name + self._ret_type = ret_type + self._arg_names = arg_names + self._arg_types = arg_types + + @property + def name(self): + return self._name + + @property + def op_name(self): + return self._op_name + + @property + def ret_type(self): + return self._ret_type + + @property + def arg_names(self): + return self._arg_names + + @property + def arg_types(self): + return self._arg_types + + +def convert_type(type_str): + if type_str == "i32": + return "int32" + elif type_str == "u32": + return "uint32" + elif type_str == "i64": + return "int64" + elif type_str == "u64": + return "uint64" + elif type_str == "float": + return "fp32" + elif type_str == "double": + return "fp64" + else: + # ignore other types, such as pointer types + return None + + +def to_unsigned(type_str): + if type_str == "int32": + return "uint32" + elif type_str == "int64": + return "uint64" + else: + return type_str + + +class ExternLibrary(ABC): + def __init__(self, name: str, path: str, format: bool = True, grouping: bool = True) -> None: + ''' + Abstract class for extern library. + + :param name: name of the library + :param path: path of the library + :param format: whether to format the generated stub file + ''' + self._name = name + self._path = path + self._symbols = {} + self._format = True + self._grouping = grouping + + @property + def name(self): + return self._name + + @property + def path(self): + return self._path + + @property + def symbols(self): + return self._symbols + + @property + def grouping(self): + return self._grouping + + @abstractmethod + def parse_symbols(self, input_file): + pass + + @abstractmethod + def _output_stubs(self) -> str: + pass + + def generate_stub_file(self, output_dir): + file_str = self._output_stubs() + if file_str is None or len(file_str) == 0: + raise Exception("file_str is empty") + + output_file = f"{output_dir}/{self._name}.py" + with open(output_file, "w") as f: + f.write(file_str) + f.close() + if self._format: + subprocess.Popen(["autopep8", "-a", "-r", "-i", output_file], + stdout=subprocess.PIPE).communicate() + subprocess.Popen(["isort", output_file], stdout=subprocess.PIPE).communicate() + + +class Libdevice(ExternLibrary): + def __init__(self, path) -> None: + ''' + Constructor for Libdevice. + + :param path: path of the libdevice library + ''' + super().__init__("libdevice", path) + self._symbol_groups = {} + + def _extract_symbol(self, line): + # Extract symbols from line in the following format: + # "define [internal] @(,)" + entries = line.split("@") + ret_str = entries[0] + func_str = entries[1] + # Get ret_type, skip internal symbols + ret_strs = ret_str.split() + if ret_strs[1] == "internal": + return None + ret_type = convert_type(ret_strs[1]) + if ret_type is None: + return None + # Get function name + func_strs = func_str.split("(") + func_name = func_strs[0].replace("@", "") + op_name = func_name.replace("__nv_", "") + # Get arg_types + arg_strs = func_strs[1].split(",") + arg_types = [] + arg_names = [] + for i, arg_str in enumerate(arg_strs): + arg_type = convert_type(arg_str.split()[0]) + if arg_type is None: + return None + arg_name = 'arg' + str(i) + arg_types.append(arg_type) + arg_names.append(arg_name) + if op_name == "sad": + # Special case for sad, where the last argument is an unsigned int + arg_types[-1] = to_unsigned(arg_types[-1]) + elif op_name.startswith("u"): + # LLVM does not differentiate between signed and unsigned integer type. + # We have to convert the types to unsigned + ret_type = to_unsigned(ret_type) + for i, arg_type in enumerate(arg_types): + arg_types[i] = to_unsigned(arg_type) + return Symbol(func_name, op_name, ret_type, arg_names, arg_types) + + def _group_symbols(self): + symbol_set = {} + for symbol in self._symbols.values(): + op_name = symbol.op_name + symbol_set[op_name] = symbol + # The following cases are grouped together: + # op_name, op_name + for symbol in self._symbols.values(): + op_name = symbol.op_name + if "max" in op_name: + op_name = "max" + elif "min" in op_name: + op_name = "min" + elif "abs" in op_name: + op_name = "abs" + elif "pow" in op_name and "fast" in op_name: + op_name = "pow" + elif "round" in op_name: + if "llround" in op_name: + op_name = "llround" + else: + op_name = "round" + elif "rint" in op_name: + if "llrint" in op_name: + op_name = "llrint" + else: + op_name = "rint" + elif op_name.startswith("ull"): + if "2" not in op_name: + # e.g., ullmax->max + op_name = op_name[3:] + else: + # e.g., ull2double->ll2double + op_name = op_name[1:] + elif op_name.startswith("u"): + if "2" not in op_name: + # e.g., uhadd->hadd + op_name = op_name[1:] + else: + # e.g., uint2double_rn->int2double_rn + op_name = op_name[1:] + elif op_name.startswith("ll"): + if "2" not in op_name: + # e.g., llmax->max + op_name = op_name[2:] + elif op_name.endswith("ll"): + op_name = op_name[:-2] + elif op_name.endswith("f"): + op_name = op_name[:-1] + if op_name in symbol_set: + # Update op_name only if there's an existing symbol + symbol._op_name = op_name + else: + op_name = symbol._op_name + if op_name in self._symbol_groups: + self._symbol_groups[op_name].append(symbol) + else: + self._symbol_groups[op_name] = [symbol] + + def parse_symbols(self, input_file): + if len(self.symbols) > 0: + return + output = subprocess.check_output(["grep", "define", input_file]).decode().splitlines() + for line in output: + symbol = self._extract_symbol(line) + if symbol is None: + continue + self._symbols[symbol.name] = symbol + + self._group_symbols() + + def _output_stubs(self): + # Generate python functions in the following format: + # @extern.extern + # def (, _builder=None): + # arg_type_symbol_dict = {[arg_type]: {(symbol, ret_type)}} + # return extern.dispatch("libdevice", , , , _builder) + import_str = "from . import core, extern\n" + import_str += "import os\n" + header_str = "LIBDEVICE_PATH = os.path.dirname(os.path.abspath(__file__)) + \"/libdevice.10.bc\"\n" + func_str = "" + for symbols in self._symbol_groups.values(): + func_str += "@extern.extern\n" + func_name_str = f"def {symbols[0].op_name}(" + for arg_name in symbols[0].arg_names: + func_name_str += f"{arg_name}, " + func_name_str += "_builder=None):\n" + + return_str = f"\treturn extern.elementwise(\"{self._name}\", LIBDEVICE_PATH, [" + for arg_name in symbols[0].arg_names: + return_str += f"{arg_name}, " + return_str += "], \n" + + arg_type_symbol_dict_str = "{" + for symbol in symbols: + arg_type_symbol_dict_str += "(" + for arg_type in symbol.arg_types: + arg_type_symbol_dict_str += f"core.dtype(\"{arg_type}\")," + ret_type = f"core.dtype(\"{symbol.ret_type}\")" + arg_type_symbol_dict_str += "): (\"" + symbol.name + "\", " + ret_type + "),\n" + arg_type_symbol_dict_str += "}" + + return_str += arg_type_symbol_dict_str + return_str += ", _builder)\n" + + func_str += func_name_str + return_str + "\n" + file_str = import_str + header_str + func_str + + return file_str + + +class LLVMDisassembler: + def __init__(self, path): + ''' + Invoke llvm-dis to disassemble the given file. + + :param path: path to llvm-dis + ''' + self._path = path + self._ll_file = "/tmp/extern_lib.ll" + + def disasm(self, lib_path): + subprocess.Popen([self._path, lib_path, "-o", self.ll_file], + stdout=subprocess.PIPE).communicate() + + @property + def ll_file(self): + return self._ll_file + + @property + def path(self): + return self._path + + +extern_libs = ["libdevice"] + + +def build(llvm_dis_path, lib_path, lib_name, output_dir): + ''' + Interface function to build the library file. + + :param llvm_dis_path: path to the llvm-dis binary + :param lib_path: path to the external library file + :param lib_name: name of the library + :param output_dir: path to the output directory + ''' + if lib_name == "libdevice": + extern_lib = Libdevice(lib_path) + else: + raise Exception(f"Unknown extern library: {lib_name}") + + llvm_disassembler = LLVMDisassembler(llvm_dis_path) + llvm_disassembler.disasm(lib_path) + + extern_lib.parse_symbols(llvm_disassembler.ll_file) + extern_lib.generate_stub_file(output_dir) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("-llvm", dest="llvm_dis_path", help="path to llvm-dis", default="llvm-dis") + parser.add_argument("--lib-path", dest="lib_path", help="path to the extern library") + parser.add_argument("--lib-name", dest="lib_name", help="name of the extern library") + parser.add_argument("-o", dest="output_dir", help="output file path", default="/tmp/") + args = parser.parse_args() + + build(args.llvm_dis_path, args.lib_path, args.lib_name, args.output_dir) diff --git a/python/tutorials/07-libdevice-function.py b/python/tutorials/07-libdevice-function.py new file mode 100644 index 000000000..bb5f7b26d --- /dev/null +++ b/python/tutorials/07-libdevice-function.py @@ -0,0 +1,74 @@ +""" +Libdevice function +=============== +Triton can invoke a custom function from an external library. +In this example, we will use the `libdevice` library to apply `asin` on a tensor. +Please refer to https://docs.nvidia.com/cuda/libdevice-users-guide/index.html regarding the semantics of all available libdevice functions. + +In `trition/language/libdevice.py`, we try to aggregate functions with the same computation but different data types together. +For example, both `__nv_asin` and `__nvasinf` calculate the principal value of the arc sine of the input, but `__nv_asin` operates on `double` and `__nv_asinf` operates on `float`. +Using triton, you can simply call `tl.libdevice.asinf`. +triton automatically selects the correct underlying device function to invoke based on input and output types. +""" + +# %% +# asin Kernel +# -------------------------- + +import torch + +import triton +import triton.language as tl + + +@triton.jit +def asin_kernel( + x_ptr, + y_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + x = tl.libdevice.asin(x) + tl.store(y_ptr + offsets, x, mask=mask) + +# %% +# Using the default libdevice library path +# -------------------------- +# We can use the default libdevice library path encoded in `triton/language/libdevice.py` + + +torch.manual_seed(0) +size = 98432 +x = torch.rand(size, device='cuda') +output_triton = torch.zeros(size, device='cuda') +output_torch = torch.asin(x) +assert x.is_cuda and output_triton.is_cuda +n_elements = output_torch.numel() +grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) +asin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024) +print(output_torch) +print(output_triton) +print( + f'The maximum difference between torch and triton is ' + f'{torch.max(torch.abs(output_torch - output_triton))}' +) + +# %% +# Customize the libdevice library path +# -------------------------- +# We can also customize the libdevice library path by passing the path to the `libdevice` library to the `asin` kernel. + +output_triton = torch.empty_like(x) +asin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024, + extern_libs={'libdevice': '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'}) +print(output_torch) +print(output_triton) +print( + f'The maximum difference between torch and triton is ' + f'{torch.max(torch.abs(output_torch - output_triton))}' +)