[FRONTEND] Added support for element-wise function defined in external LLVM bitcode (e.g., libdevice) (#562)
This commit is contained in:
89
include/triton/codegen/extern_lib.h
Normal file
89
include/triton/codegen/extern_lib.h
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
#ifndef _TRITON_CODE_GEN_EXTERN_LIB_H_
|
||||||
|
#define _TRITON_CODE_GEN_EXTERN_LIB_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "llvm/IR/LLVMContext.h"
|
||||||
|
#include "llvm/IR/Module.h"
|
||||||
|
#include "llvm/IRReader/IRReader.h"
|
||||||
|
#include "llvm/Support/SourceMgr.h"
|
||||||
|
|
||||||
|
namespace triton {
|
||||||
|
namespace codegen {
|
||||||
|
|
||||||
|
///
|
||||||
|
/// \brief ExternLib is a class that represents a library of external functions.
|
||||||
|
///
|
||||||
|
class ExternLib {
|
||||||
|
public:
|
||||||
|
ExternLib(const std::string &name, const std::string &path)
|
||||||
|
: name_(name), path_(path) {}
|
||||||
|
|
||||||
|
virtual ~ExternLib() = default;
|
||||||
|
|
||||||
|
virtual const std::string &name() const { return name_; }
|
||||||
|
|
||||||
|
virtual const std::string &path() const { return path_; }
|
||||||
|
|
||||||
|
///
|
||||||
|
/// \brief Load the library and return the module.
|
||||||
|
///
|
||||||
|
std::unique_ptr<llvm::Module> load(llvm::LLVMContext &ctx);
|
||||||
|
|
||||||
|
///
|
||||||
|
/// \brief Link the module into the given module.
|
||||||
|
///
|
||||||
|
void link(std::unique_ptr<llvm::Module> &llvm,
|
||||||
|
std::unique_ptr<llvm::Module> &mod);
|
||||||
|
|
||||||
|
///
|
||||||
|
/// \brief Run load, link, and opt on the module.
|
||||||
|
///
|
||||||
|
virtual void install(llvm::LLVMContext &ctx,
|
||||||
|
std::unique_ptr<llvm::Module> &llvm) {
|
||||||
|
auto mod = load(ctx);
|
||||||
|
link(llvm, mod);
|
||||||
|
opt(ctx, llvm);
|
||||||
|
}
|
||||||
|
|
||||||
|
///
|
||||||
|
/// \brief Run opt on the module.
|
||||||
|
///
|
||||||
|
virtual void opt(llvm::LLVMContext &ctx,
|
||||||
|
std::unique_ptr<llvm::Module> &llvm) = 0;
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::string name_;
|
||||||
|
std::string path_;
|
||||||
|
};
|
||||||
|
|
||||||
|
///
|
||||||
|
/// \brief ExternLibMap is a map of ExternLibs from their names to their paths.
|
||||||
|
///
|
||||||
|
typedef std::map<std::string, std::unique_ptr<ExternLib>> ExternLibMap;
|
||||||
|
|
||||||
|
///
|
||||||
|
/// \brief Concrete class for NVIDIA's libdevice library.
|
||||||
|
///
|
||||||
|
class LibDevice final : public ExternLib {
|
||||||
|
public:
|
||||||
|
LibDevice(const std::string &name, const std::string &path)
|
||||||
|
: ExternLib(name, path) {}
|
||||||
|
|
||||||
|
virtual ~LibDevice() = default;
|
||||||
|
|
||||||
|
virtual void opt(llvm::LLVMContext &ctx,
|
||||||
|
std::unique_ptr<llvm::Module> &llvm) override;
|
||||||
|
};
|
||||||
|
|
||||||
|
///
|
||||||
|
/// \brief Create an ExternLib instance based on the name and path.
|
||||||
|
///
|
||||||
|
std::unique_ptr<ExternLib> create_extern_lib(const std::string &lib_name,
|
||||||
|
const std::string &lib_path);
|
||||||
|
|
||||||
|
} // namespace codegen
|
||||||
|
} // namespace triton
|
||||||
|
|
||||||
|
#endif
|
@@ -3,6 +3,7 @@
|
|||||||
|
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include "extern_lib.h"
|
||||||
|
|
||||||
namespace llvm{
|
namespace llvm{
|
||||||
class Module;
|
class Module;
|
||||||
@@ -30,12 +31,10 @@ namespace codegen{
|
|||||||
|
|
||||||
// TODO:
|
// TODO:
|
||||||
// There should be a proper pass manager there!
|
// There should be a proper pass manager there!
|
||||||
std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMContext& ctx,
|
std::unique_ptr<llvm::Module> add_passes_to_emit_bin(
|
||||||
codegen::target* target,
|
ir::module &ir, llvm::LLVMContext &ctx, codegen::target *target,
|
||||||
int sm, int num_warps,
|
int num_warps, int num_stages, int &shared_static,
|
||||||
int num_stages, int &shared_static);
|
const ExternLibMap &extern_libs);
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -6,6 +6,7 @@
|
|||||||
#include "triton/ir/visitor.h"
|
#include "triton/ir/visitor.h"
|
||||||
#include "triton/ir/instructions.h"
|
#include "triton/ir/instructions.h"
|
||||||
#include "triton/codegen/analysis/layout.h"
|
#include "triton/codegen/analysis/layout.h"
|
||||||
|
#include "triton/codegen/extern_lib.h"
|
||||||
#include <functional>
|
#include <functional>
|
||||||
|
|
||||||
// forward
|
// forward
|
||||||
@@ -199,6 +200,7 @@ private:
|
|||||||
void visit_make_range(ir::make_range*);
|
void visit_make_range(ir::make_range*);
|
||||||
void visit_clock_inst(ir::clock_inst*);
|
void visit_clock_inst(ir::clock_inst*);
|
||||||
void visit_globaltimer_inst(ir::globaltimer_inst*);
|
void visit_globaltimer_inst(ir::globaltimer_inst*);
|
||||||
|
void visit_extern_elementwise_inst(ir::extern_elementwise_inst*);
|
||||||
// void visit_make_range_sta(ir::make_range_sta*);
|
// void visit_make_range_sta(ir::make_range_sta*);
|
||||||
void visit_undef_value(ir::undef_value*);
|
void visit_undef_value(ir::undef_value*);
|
||||||
void visit_constant_int(ir::constant_int*);
|
void visit_constant_int(ir::constant_int*);
|
||||||
@@ -209,18 +211,26 @@ private:
|
|||||||
void visit_argument(ir::argument*);
|
void visit_argument(ir::argument*);
|
||||||
void visit(ir::module &, llvm::Module &);
|
void visit(ir::module &, llvm::Module &);
|
||||||
|
|
||||||
|
|
||||||
// layouts
|
// layouts
|
||||||
void visit_layout_mma(analysis::mma_layout*);
|
void visit_layout_mma(analysis::mma_layout*);
|
||||||
void visit_layout_scanline(analysis::scanline_layout*);
|
void visit_layout_scanline(analysis::scanline_layout*);
|
||||||
void visit_layout_shared(analysis::shared_layout*);
|
void visit_layout_shared(analysis::shared_layout*);
|
||||||
|
|
||||||
|
// Add a new external library based on given name and path if it doesn't exist
|
||||||
|
void add_extern_lib(const std::string &lib_name, const std::string &lib_path);
|
||||||
|
|
||||||
private:
|
// Get all external libraries
|
||||||
|
const ExternLibMap &get_extern_lib_map() {
|
||||||
|
return extern_lib_map_;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
LLVMContext *ctx_;
|
LLVMContext *ctx_;
|
||||||
Builder* builder_;
|
Builder* builder_;
|
||||||
Module *mod_;
|
Module *mod_;
|
||||||
|
|
||||||
|
std::map<std::string, std::unique_ptr<ExternLib>> extern_lib_map_;
|
||||||
|
|
||||||
analysis::axes *a_axes_;
|
analysis::axes *a_axes_;
|
||||||
analysis::swizzle *swizzle_;
|
analysis::swizzle *swizzle_;
|
||||||
std::map<unsigned, distributed_axis> axes_;
|
std::map<unsigned, distributed_axis> axes_;
|
||||||
|
@@ -169,6 +169,12 @@ public:
|
|||||||
// Utilities
|
// Utilities
|
||||||
value *create_clock();
|
value *create_clock();
|
||||||
value *create_globaltimer();
|
value *create_globaltimer();
|
||||||
|
// Extern instruction
|
||||||
|
value *create_extern_elementwise(const std::string &lib_name,
|
||||||
|
const std::string &lib_path,
|
||||||
|
const std::string &symbol_name,
|
||||||
|
const std::vector<value *> &args,
|
||||||
|
type *ret_ty);
|
||||||
// Built-in instruction
|
// Built-in instruction
|
||||||
value *create_get_program_id(unsigned axis);
|
value *create_get_program_id(unsigned axis);
|
||||||
value *create_get_num_programs(unsigned axis);
|
value *create_get_num_programs(unsigned axis);
|
||||||
|
@@ -154,6 +154,8 @@ enum value_id_t: unsigned {
|
|||||||
INST_COS,
|
INST_COS,
|
||||||
INST_SIN,
|
INST_SIN,
|
||||||
INST_LOG,
|
INST_LOG,
|
||||||
|
// extern
|
||||||
|
INST_EXTERN_ELEMENTWISE,
|
||||||
// array arithmetic
|
// array arithmetic
|
||||||
INST_TRANS,
|
INST_TRANS,
|
||||||
INST_REDUCE,
|
INST_REDUCE,
|
||||||
|
@@ -1097,7 +1097,28 @@ public:
|
|||||||
static globaltimer_inst* create(context &ctx, const std::string &name = "", instruction *next = nullptr);
|
static globaltimer_inst* create(context &ctx, const std::string &name = "", instruction *next = nullptr);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class extern_elementwise_inst : public instruction {
|
||||||
|
extern_elementwise_inst(context &ctx, const std::vector<value *> &args,
|
||||||
|
type *dst_ty, const std::string &lib_name,
|
||||||
|
const std::string &extern_lib_path,
|
||||||
|
const std::string &symbol_name, instruction *next);
|
||||||
|
std::string repr_impl() const { return "extern_elementwise"; }
|
||||||
|
_TRITON_DEFINE_CLONE(extern_elementwise_inst)
|
||||||
|
_TRITON_DEFINE_ACCEPT(extern_elementwise_inst)
|
||||||
|
|
||||||
|
public:
|
||||||
|
static extern_elementwise_inst *create(
|
||||||
|
context &ctx, const std::vector<value *> &args, type *dst_ty,
|
||||||
|
const std::string &lib_name = "", const std::string &lib_path = "",
|
||||||
|
const std::string &symbol_name = "", instruction *next = nullptr);
|
||||||
|
|
||||||
|
const std::string &get_lib_name() const { return lib_name_; }
|
||||||
|
const std::string &get_lib_path() const { return lib_path_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::string lib_name_ = "";
|
||||||
|
std::string lib_path_ = "";
|
||||||
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -84,6 +84,8 @@ class prefetch_s_inst;
|
|||||||
class clock_inst;
|
class clock_inst;
|
||||||
class globaltimer_inst;
|
class globaltimer_inst;
|
||||||
|
|
||||||
|
class extern_elementwise_inst;
|
||||||
|
|
||||||
class make_range_sta;
|
class make_range_sta;
|
||||||
class undef_value;
|
class undef_value;
|
||||||
class constant_int;
|
class constant_int;
|
||||||
@@ -177,6 +179,8 @@ public:
|
|||||||
virtual void visit_constant_int(constant_int*) = 0;
|
virtual void visit_constant_int(constant_int*) = 0;
|
||||||
virtual void visit_constant_fp(constant_fp*) = 0;
|
virtual void visit_constant_fp(constant_fp*) = 0;
|
||||||
virtual void visit_alloc_const(alloc_const*) = 0;
|
virtual void visit_alloc_const(alloc_const*) = 0;
|
||||||
|
|
||||||
|
virtual void visit_extern_elementwise_inst(extern_elementwise_inst*) = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
}
|
}
|
||||||
|
63
lib/codegen/extern_lib.cc
Normal file
63
lib/codegen/extern_lib.cc
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
#include "triton/codegen/extern_lib.h"
|
||||||
|
|
||||||
|
#include "llvm/IR/Constants.h"
|
||||||
|
#include "llvm/IR/LegacyPassManager.h"
|
||||||
|
#include "llvm/IR/Metadata.h"
|
||||||
|
#include "llvm/IR/Type.h"
|
||||||
|
#include "llvm/Linker/Linker.h"
|
||||||
|
#include "llvm/Transforms/IPO/PassManagerBuilder.h"
|
||||||
|
#include "triton/codegen/pass.h"
|
||||||
|
|
||||||
|
namespace triton {
|
||||||
|
|
||||||
|
namespace codegen {
|
||||||
|
|
||||||
|
std::unique_ptr<llvm::Module> ExternLib::load(llvm::LLVMContext& ctx) {
|
||||||
|
llvm::SMDiagnostic err;
|
||||||
|
auto mod = llvm::parseIRFile(this->path_, err, ctx);
|
||||||
|
if (!mod) {
|
||||||
|
throw std::runtime_error("Failed to load extern lib " + this->name_ +
|
||||||
|
" at " + this->path_);
|
||||||
|
}
|
||||||
|
return mod;
|
||||||
|
}
|
||||||
|
|
||||||
|
void ExternLib::link(std::unique_ptr<llvm::Module>& llvm,
|
||||||
|
std::unique_ptr<llvm::Module>& mod) {
|
||||||
|
// Set triple and data layout to match the target module
|
||||||
|
mod->setTargetTriple(llvm->getTargetTriple());
|
||||||
|
mod->setDataLayout(llvm->getDataLayout());
|
||||||
|
if (llvm::Linker::linkModules(*llvm, std::move(mod))) {
|
||||||
|
throw std::runtime_error("Failed to link extern lib " + this->name_ +
|
||||||
|
" at " + this->path_);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void LibDevice::opt(llvm::LLVMContext& ctx, std::unique_ptr<llvm::Module>& llvm) {
|
||||||
|
// Add nvvm reflect flags to llvm module
|
||||||
|
// https://llvm.org/docs/LangRef.html#module-flags-metadata
|
||||||
|
// i32 4: Override the other module.
|
||||||
|
// i32 1: Emit an error
|
||||||
|
// If both modules specify Override, but the values differ, an error
|
||||||
|
// will be emitted.
|
||||||
|
llvm::Type* I32 = llvm::Type::getInt32Ty(ctx);
|
||||||
|
llvm::Metadata* md_four =
|
||||||
|
llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(I32, 4));
|
||||||
|
llvm::Metadata* md_name = llvm::MDString::get(ctx, "nvvm-reflect-ftz");
|
||||||
|
llvm::Metadata* md_one =
|
||||||
|
llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(I32, 1));
|
||||||
|
llvm::MDNode* reflect = llvm::MDNode::get(ctx, {md_four, md_name, md_one});
|
||||||
|
llvm->addModuleFlag(reflect);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<ExternLib> create_extern_lib(const std::string& lib_name,
|
||||||
|
const std::string& lib_path) {
|
||||||
|
if (lib_name == "libdevice") {
|
||||||
|
return std::make_unique<LibDevice>(lib_name, lib_path);
|
||||||
|
} else {
|
||||||
|
throw std::runtime_error("Unknown external library: " + lib_name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace codegen
|
||||||
|
} // namespace triton
|
@@ -1,4 +1,14 @@
|
|||||||
#include "triton/codegen/pass.h"
|
#include "triton/codegen/pass.h"
|
||||||
|
|
||||||
|
#include "llvm/IR/Constants.h"
|
||||||
|
#include "llvm/IR/LegacyPassManager.h"
|
||||||
|
#include "llvm/IR/Module.h"
|
||||||
|
#include "llvm/IR/Verifier.h"
|
||||||
|
#include "llvm/IRReader/IRReader.h"
|
||||||
|
#include "llvm/Linker/Linker.h"
|
||||||
|
#include "llvm/Support/SourceMgr.h"
|
||||||
|
#include "llvm/Transforms/IPO.h"
|
||||||
|
#include "llvm/Transforms/IPO/PassManagerBuilder.h"
|
||||||
#include "triton/codegen/analysis/align.h"
|
#include "triton/codegen/analysis/align.h"
|
||||||
#include "triton/codegen/analysis/allocation.h"
|
#include "triton/codegen/analysis/allocation.h"
|
||||||
#include "triton/codegen/analysis/axes.h"
|
#include "triton/codegen/analysis/axes.h"
|
||||||
@@ -9,24 +19,66 @@
|
|||||||
#include "triton/codegen/transform/cts.h"
|
#include "triton/codegen/transform/cts.h"
|
||||||
#include "triton/codegen/transform/dce.h"
|
#include "triton/codegen/transform/dce.h"
|
||||||
#include "triton/codegen/transform/disassociate.h"
|
#include "triton/codegen/transform/disassociate.h"
|
||||||
|
#include "triton/codegen/transform/inline.h"
|
||||||
#include "triton/codegen/transform/membar.h"
|
#include "triton/codegen/transform/membar.h"
|
||||||
#include "triton/codegen/transform/peephole.h"
|
#include "triton/codegen/transform/peephole.h"
|
||||||
#include "triton/codegen/transform/pipeline.h"
|
#include "triton/codegen/transform/pipeline.h"
|
||||||
#include "triton/codegen/transform/prefetch.h"
|
#include "triton/codegen/transform/prefetch.h"
|
||||||
#include "triton/codegen/transform/inline.h"
|
|
||||||
#include "triton/ir/function.h"
|
#include "triton/ir/function.h"
|
||||||
#include "triton/ir/module.h"
|
#include "triton/ir/module.h"
|
||||||
#include "triton/ir/print.h"
|
#include "triton/ir/print.h"
|
||||||
#include "llvm/IR/Module.h"
|
|
||||||
#include "llvm/IR/LegacyPassManager.h"
|
|
||||||
#include "llvm/IR/Verifier.h"
|
|
||||||
namespace triton {
|
namespace triton {
|
||||||
namespace codegen {
|
namespace codegen {
|
||||||
|
|
||||||
|
static void link_extern_libs(const ExternLibMap& user_extern_lib_map,
|
||||||
|
const ExternLibMap& target_extern_lib_map,
|
||||||
|
ir::module& ir, llvm::LLVMContext& ctx,
|
||||||
|
std::unique_ptr<llvm::Module>& llvm) {
|
||||||
|
for (const auto& iter : target_extern_lib_map) {
|
||||||
|
auto &lib_name = iter.first;
|
||||||
|
if (user_extern_lib_map.count(lib_name) != 0 &&
|
||||||
|
user_extern_lib_map.at(lib_name)->path() != "") {
|
||||||
|
// If the user specified a path for this library, use it.
|
||||||
|
user_extern_lib_map.at(lib_name)->install(ctx, llvm);
|
||||||
|
} else {
|
||||||
|
// Otherwise, use the default path.
|
||||||
|
iter.second->install(ctx, llvm);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::set<llvm::StringRef> function_names;
|
||||||
|
for (auto& func : ir.get_function_list()) {
|
||||||
|
function_names.insert(func->get_name());
|
||||||
|
}
|
||||||
|
llvm::legacy::PassManager pass;
|
||||||
|
pass.add(llvm::createInternalizePass([&](const llvm::GlobalValue& v) -> bool {
|
||||||
|
if (function_names.count(v.getName()) != 0) {
|
||||||
|
// Preserve global functions
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
// Internalize all device functions
|
||||||
|
return false;
|
||||||
|
}));
|
||||||
|
|
||||||
|
llvm::legacy::PassManager pm;
|
||||||
|
pm.add(llvm::createVerifierPass());
|
||||||
|
pm.run(*llvm);
|
||||||
|
|
||||||
|
llvm::PassManagerBuilder builder;
|
||||||
|
builder.OptLevel = 3;
|
||||||
|
builder.SizeLevel = 0;
|
||||||
|
builder.populateModulePassManager(pass);
|
||||||
|
|
||||||
|
pass.run(*llvm);
|
||||||
|
}
|
||||||
|
|
||||||
// TODO:
|
// TODO:
|
||||||
// There should be a proper pass manager there!
|
// There should be a proper pass manager there!
|
||||||
std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMContext& ctx, codegen::target* target,
|
std::unique_ptr<llvm::Module> add_passes_to_emit_bin(
|
||||||
int cc, int num_warps, int num_stages, int& shared_static) {
|
ir::module& ir, llvm::LLVMContext& ctx, codegen::target* target,
|
||||||
|
int num_warps, int num_stages, int& shared_static,
|
||||||
|
const ExternLibMap& extern_lib_map) {
|
||||||
// generate llvm code
|
// generate llvm code
|
||||||
std::string name = ir.get_function_list()[0]->get_name();
|
std::string name = ir.get_function_list()[0]->get_name();
|
||||||
std::unique_ptr<llvm::Module> llvm(new llvm::Module(name, ctx));
|
std::unique_ptr<llvm::Module> llvm(new llvm::Module(name, ctx));
|
||||||
@@ -47,8 +99,10 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC
|
|||||||
codegen::transform::peephole peephole(target, &layouts);
|
codegen::transform::peephole peephole(target, &layouts);
|
||||||
codegen::transform::coalesce coalesce(&align, &layouts, has_sm80);
|
codegen::transform::coalesce coalesce(&align, &layouts, has_sm80);
|
||||||
codegen::transform::prefetch prefetch_s(target);
|
codegen::transform::prefetch prefetch_s(target);
|
||||||
codegen::transform::membar barriers(&liveness, &layouts, &allocation, &prefetch_s, target);
|
codegen::transform::membar barriers(&liveness, &layouts, &allocation,
|
||||||
codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target, num_warps);
|
&prefetch_s, target);
|
||||||
|
codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle,
|
||||||
|
target, num_warps);
|
||||||
// run passes
|
// run passes
|
||||||
inliner.run(ir);
|
inliner.run(ir);
|
||||||
dce.run(ir);
|
dce.run(ir);
|
||||||
@@ -56,7 +110,7 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC
|
|||||||
peephole.run(ir);
|
peephole.run(ir);
|
||||||
dce.run(ir);
|
dce.run(ir);
|
||||||
pipeline.run(ir);
|
pipeline.run(ir);
|
||||||
dce.run(ir);
|
dce.run(ir);
|
||||||
disassociate.run(ir);
|
disassociate.run(ir);
|
||||||
dce.run(ir);
|
dce.run(ir);
|
||||||
align.run(ir);
|
align.run(ir);
|
||||||
@@ -64,8 +118,7 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC
|
|||||||
layouts.run(ir);
|
layouts.run(ir);
|
||||||
peephole.run(ir);
|
peephole.run(ir);
|
||||||
dce.run(ir);
|
dce.run(ir);
|
||||||
if (target->is_gpu())
|
if (target->is_gpu()) cts.run(ir);
|
||||||
cts.run(ir);
|
|
||||||
align.run(ir);
|
align.run(ir);
|
||||||
axes.run(ir);
|
axes.run(ir);
|
||||||
layouts.run(ir);
|
layouts.run(ir);
|
||||||
@@ -73,8 +126,7 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC
|
|||||||
dce.run(ir);
|
dce.run(ir);
|
||||||
align.run(ir);
|
align.run(ir);
|
||||||
dce.run(ir);
|
dce.run(ir);
|
||||||
if (target->is_gpu())
|
if (target->is_gpu()) cts.run(ir);
|
||||||
cts.run(ir);
|
|
||||||
dce.run(ir);
|
dce.run(ir);
|
||||||
align.run(ir);
|
align.run(ir);
|
||||||
axes.run(ir);
|
axes.run(ir);
|
||||||
@@ -97,8 +149,15 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC
|
|||||||
// ir.print(std::cout);
|
// ir.print(std::cout);
|
||||||
isel.visit(ir, *llvm);
|
isel.visit(ir, *llvm);
|
||||||
shared_static = allocation.allocated_size();
|
shared_static = allocation.allocated_size();
|
||||||
|
|
||||||
|
if (isel.get_extern_lib_map().size() > 0) {
|
||||||
|
// If there's any extern lib calls,
|
||||||
|
// we need to link them in.
|
||||||
|
link_extern_libs(extern_lib_map, isel.get_extern_lib_map(), ir, ctx, llvm);
|
||||||
|
}
|
||||||
|
|
||||||
return llvm;
|
return llvm;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace codegen
|
} // namespace codegen
|
||||||
} // namespace triton
|
} // namespace triton
|
||||||
|
@@ -1195,7 +1195,7 @@ void generator::visit_cos_inst(ir::cos_inst* x){
|
|||||||
for(auto idx: idxs_.at(x)){
|
for(auto idx: idxs_.at(x)){
|
||||||
vals_[x][idx] = call(cos, std::vector<llvm::Value*>{vals_[x->get_operand(0)][idx]});
|
vals_[x][idx] = call(cos, std::vector<llvm::Value*>{vals_[x->get_operand(0)][idx]});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* \brief Code Generation for `umulhi`
|
* \brief Code Generation for `umulhi`
|
||||||
@@ -3154,6 +3154,30 @@ void generator::visit_async_wait_inst(ir::async_wait_inst* i) {
|
|||||||
call(iasm);
|
call(iasm);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* \brief Code Generation for `extern_elementwise`
|
||||||
|
*/
|
||||||
|
void generator::visit_extern_elementwise_inst(ir::extern_elementwise_inst *i) {
|
||||||
|
std::vector<Type *> operand_types;
|
||||||
|
for (size_t j = 0; j < i->get_num_operands(); j++) {
|
||||||
|
operand_types.push_back(
|
||||||
|
cvt(i->get_operand(j)->get_type()->get_scalar_ty()));
|
||||||
|
}
|
||||||
|
Type *ret_type = cvt(i->get_type()->get_scalar_ty());
|
||||||
|
FunctionType *FT =
|
||||||
|
FunctionType::get(ret_type, std::move(operand_types), false);
|
||||||
|
Function *F = llvm::cast<llvm::Function>(
|
||||||
|
mod_->getOrInsertFunction(i->get_name(), FT).getCallee());
|
||||||
|
for (auto idx : idxs_.at(i)) {
|
||||||
|
std::vector<llvm::Value *> args;
|
||||||
|
for (size_t j = 0; j < i->get_num_operands(); j++) {
|
||||||
|
args.emplace_back(vals_[i->get_operand(j)][idx]);
|
||||||
|
}
|
||||||
|
vals_[i][idx] = call(F, std::move(args));
|
||||||
|
}
|
||||||
|
add_extern_lib(i->get_lib_name(), i->get_lib_path());
|
||||||
|
}
|
||||||
|
|
||||||
//void generator::visit_make_range_dyn(ir::make_range_dyn* x) {
|
//void generator::visit_make_range_dyn(ir::make_range_dyn* x) {
|
||||||
// for(indices_t idx: idxs_.at(x)){
|
// for(indices_t idx: idxs_.at(x)){
|
||||||
// assert(idx.size() == 1);
|
// assert(idx.size() == 1);
|
||||||
@@ -3741,6 +3765,15 @@ void generator::visit(ir::module &src, llvm::Module &dst) {
|
|||||||
visit_function(fn);
|
visit_function(fn);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void generator::add_extern_lib(const std::string &lib_name,
|
||||||
|
const std::string &lib_path) {
|
||||||
|
if (extern_lib_map_.count(lib_name) == 0) {
|
||||||
|
extern_lib_map_[lib_name] = create_extern_lib(lib_name, lib_path);
|
||||||
|
} else if (extern_lib_map_.at(lib_name)->path() != lib_path) {
|
||||||
|
throw std::runtime_error("A library has multiple paths (1) " + lib_path +
|
||||||
|
" (2) " + extern_lib_map_.at(lib_name)->path());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
}
|
} // namespace codegen
|
||||||
}
|
} // namespace triton
|
||||||
|
@@ -358,8 +358,5 @@ hipModule_t amdgpu_to_hipmodule(const std::string& path) {
|
|||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
} // namespace driver
|
||||||
|
} // namespace triton
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
@@ -379,6 +379,19 @@ value *builder::create_globaltimer() {
|
|||||||
return insert(globaltimer_inst::create(ctx_));
|
return insert(globaltimer_inst::create(ctx_));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// externs
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
value *builder::create_extern_elementwise(const std::string &lib_name,
|
||||||
|
const std::string &lib_path,
|
||||||
|
const std::string &symbol_name,
|
||||||
|
const std::vector<value *> &args,
|
||||||
|
type *ret_ty) {
|
||||||
|
return insert(extern_elementwise_inst::create(ctx_, args, ret_ty, lib_name,
|
||||||
|
lib_path, symbol_name));
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// built-in instructions
|
// built-in instructions
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@@ -988,6 +988,28 @@ globaltimer_inst* globaltimer_inst::create(context &ctx, const std::string &name
|
|||||||
return new globaltimer_inst(ctx, name, next);
|
return new globaltimer_inst(ctx, name, next);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// extern elementwise
|
||||||
|
extern_elementwise_inst::extern_elementwise_inst(
|
||||||
|
context &ctx, const std::vector<value *> &args, type *ret_ty,
|
||||||
|
const std::string &lib_name, const std::string &lib_path,
|
||||||
|
const std::string &symbol_name, instruction *next)
|
||||||
|
: instruction(ret_ty, INST_EXTERN_ELEMENTWISE, args.size(), symbol_name,
|
||||||
|
next),
|
||||||
|
lib_name_(lib_name),
|
||||||
|
lib_path_(lib_path) {
|
||||||
|
for (size_t i = 0; i < args.size(); i++) {
|
||||||
|
set_operand(i, args[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
extern_elementwise_inst *extern_elementwise_inst::create(
|
||||||
|
context &ctx, const std::vector<value *> &args, type *ret_ty,
|
||||||
|
const std::string &lib_name, const std::string &lib_path,
|
||||||
|
const std::string &symbol_name, instruction *next) {
|
||||||
|
return new extern_elementwise_inst(ctx, args, ret_ty, lib_name, lib_path,
|
||||||
|
symbol_name, next);
|
||||||
|
}
|
||||||
|
|
||||||
// clock
|
// clock
|
||||||
clock_inst::clock_inst(context &ctx, const std::string &name, instruction *next)
|
clock_inst::clock_inst(context &ctx, const std::string &name, instruction *next)
|
||||||
: instruction(type::get_int64_ty(ctx), INST_CLOCK, 0, name, next) { }
|
: instruction(type::get_int64_ty(ctx), INST_CLOCK, 0, name, next) { }
|
||||||
|
@@ -98,7 +98,7 @@ class CMakeBuild(build_ext):
|
|||||||
if not os.path.exists(self.build_temp):
|
if not os.path.exists(self.build_temp):
|
||||||
os.makedirs(self.build_temp)
|
os.makedirs(self.build_temp)
|
||||||
# python directories
|
# python directories
|
||||||
python_include_dirs = [distutils.sysconfig.get_python_inc()] + ['/usr/local/cuda/include']
|
python_include_dirs = [distutils.sysconfig.get_python_inc()]
|
||||||
cmake_args = [
|
cmake_args = [
|
||||||
"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir,
|
"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir,
|
||||||
"-DBUILD_TUTORIALS=OFF",
|
"-DBUILD_TUTORIALS=OFF",
|
||||||
|
@@ -1,5 +1,6 @@
|
|||||||
#include "triton/codegen/pass.h"
|
#include "triton/codegen/pass.h"
|
||||||
#include "triton/codegen/target.h"
|
#include "triton/codegen/target.h"
|
||||||
|
#include "triton/codegen/extern_lib.h"
|
||||||
#include "triton/driver/error.h"
|
#include "triton/driver/error.h"
|
||||||
#include "triton/driver/llvm.h"
|
#include "triton/driver/llvm.h"
|
||||||
#include "triton/ir/builder.h"
|
#include "triton/ir/builder.h"
|
||||||
@@ -19,7 +20,6 @@
|
|||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include "llvm/IR/Module.h"
|
#include "llvm/IR/Module.h"
|
||||||
#include "llvm/IR/LegacyPassManager.h"
|
|
||||||
#include "llvm/IR/Verifier.h"
|
#include "llvm/IR/Verifier.h"
|
||||||
|
|
||||||
namespace py = pybind11;
|
namespace py = pybind11;
|
||||||
@@ -140,7 +140,7 @@ size_t get_pointer_range_size(uint64_t addr){
|
|||||||
// Launch
|
// Launch
|
||||||
void parse_args(py::list& args, py::list do_not_specialize, const std::string& func_key, py::list& arg_names,
|
void parse_args(py::list& args, py::list do_not_specialize, const std::string& func_key, py::list& arg_names,
|
||||||
std::string& cache_key, std::string& params, size_t& params_size, py::dict constants,
|
std::string& cache_key, std::string& params, size_t& params_size, py::dict constants,
|
||||||
int num_warps, int num_stages) {
|
int num_warps, int num_stages, py::dict& extern_libs) {
|
||||||
size_t len = PyList_Size(args.ptr());
|
size_t len = PyList_Size(args.ptr());
|
||||||
params.reserve(8*len); // 8 max bytes by argument
|
params.reserve(8*len); // 8 max bytes by argument
|
||||||
char* params_ptr = ¶ms[0];
|
char* params_ptr = ¶ms[0];
|
||||||
@@ -256,6 +256,11 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
|
|||||||
throw std::runtime_error(err_msg);
|
throw std::runtime_error(err_msg);
|
||||||
}
|
}
|
||||||
params_size = (std::ptrdiff_t)(params_ptr - ¶ms[0]);
|
params_size = (std::ptrdiff_t)(params_ptr - ¶ms[0]);
|
||||||
|
|
||||||
|
for (auto item : extern_libs) {
|
||||||
|
cache_key += "-" + item.first.cast<std::string>();
|
||||||
|
cache_key += "_" + item.second.cast<std::string>();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
@@ -288,7 +293,7 @@ void init_triton_runtime(py::module &&m) {
|
|||||||
// cache key
|
// cache key
|
||||||
m.def("launch", [](py::list args, py::list do_not_specialize, const std::string& func_key, py::list& arg_names,
|
m.def("launch", [](py::list args, py::list do_not_specialize, const std::string& func_key, py::list& arg_names,
|
||||||
py::object device, py::int_ stream, py::dict bin_cache, py::int_ num_warps, py::int_ num_stages,
|
py::object device, py::int_ stream, py::dict bin_cache, py::int_ num_warps, py::int_ num_stages,
|
||||||
py::function add_to_cache, py::object grid){
|
py::dict extern_libs, py::function add_to_cache, py::object grid){
|
||||||
// parse arguments to compute cache key, compile-time constants and packed kernel arguments
|
// parse arguments to compute cache key, compile-time constants and packed kernel arguments
|
||||||
long _num_warps = PyLong_AsLong(num_warps.ptr());
|
long _num_warps = PyLong_AsLong(num_warps.ptr());
|
||||||
long _num_stages = PyLong_AsLong(num_stages.ptr());
|
long _num_stages = PyLong_AsLong(num_stages.ptr());
|
||||||
@@ -296,13 +301,14 @@ void init_triton_runtime(py::module &&m) {
|
|||||||
std::string params;
|
std::string params;
|
||||||
size_t params_size;
|
size_t params_size;
|
||||||
py::dict constants;
|
py::dict constants;
|
||||||
parse_args(args, do_not_specialize, func_key, arg_names, cache_key, params, params_size, constants, _num_warps, _num_stages);
|
parse_args(args, do_not_specialize, func_key, arg_names, cache_key, params,
|
||||||
|
params_size, constants, _num_warps, _num_stages, extern_libs);
|
||||||
|
|
||||||
// get cached binary
|
// get cached binary
|
||||||
py::str key(cache_key);
|
py::str key(cache_key);
|
||||||
py::bool_ noop = false;
|
py::bool_ noop = false;
|
||||||
if(!bin_cache.contains(key)) {
|
if(!bin_cache.contains(key)) {
|
||||||
noop = add_to_cache(key, args, device, num_warps, num_stages);
|
noop = add_to_cache(key, args, device, num_warps, num_stages, extern_libs);
|
||||||
}
|
}
|
||||||
if (noop)
|
if (noop)
|
||||||
return (py::object)py::none();
|
return (py::object)py::none();
|
||||||
@@ -467,11 +473,10 @@ std::tuple<uint64_t, uint64_t, uint64_t, uint64_t> hip_load_binary(const std::st
|
|||||||
// ---------------------------------------
|
// ---------------------------------------
|
||||||
|
|
||||||
// CUDA
|
// CUDA
|
||||||
std::tuple<std::string, asm_map_t, int> cu_compile_ttir(const std::string& name, ir::module &ir,
|
std::tuple<std::string, asm_map_t, int> cu_compile_ttir(
|
||||||
uint64_t device, int num_warps, int num_stages,
|
const std::string &name, ir::module &ir, uint64_t device, int num_warps,
|
||||||
asm_map_t &asm_map){
|
int num_stages, asm_map_t &asm_map,
|
||||||
|
const triton::codegen::ExternLibMap &extern_lib_map) {
|
||||||
int n_shared_bytes;
|
|
||||||
py::gil_scoped_release allow_threads;
|
py::gil_scoped_release allow_threads;
|
||||||
llvm::LLVMContext ctx;
|
llvm::LLVMContext ctx;
|
||||||
// device properties
|
// device properties
|
||||||
@@ -483,7 +488,9 @@ std::tuple<std::string, asm_map_t, int> cu_compile_ttir(const std::string& name,
|
|||||||
std::string ptxas_path = drv::path_to_ptxas(version);
|
std::string ptxas_path = drv::path_to_ptxas(version);
|
||||||
// Triton-IR -> NVPTX LLVM-IR
|
// Triton-IR -> NVPTX LLVM-IR
|
||||||
triton::codegen::nvidia_cu_target target(cc);
|
triton::codegen::nvidia_cu_target target(cc);
|
||||||
auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, cc, num_warps, num_stages, n_shared_bytes);
|
int n_shared_bytes;
|
||||||
|
auto llvm = triton::codegen::add_passes_to_emit_bin(
|
||||||
|
ir, ctx, &target, num_warps, num_stages, n_shared_bytes, extern_lib_map);
|
||||||
std::string tmp;
|
std::string tmp;
|
||||||
llvm::raw_string_ostream llir(tmp);
|
llvm::raw_string_ostream llir(tmp);
|
||||||
llir << *llvm;
|
llir << *llvm;
|
||||||
@@ -502,14 +509,16 @@ std::tuple<std::string, asm_map_t, int> cu_compile_ttir(const std::string& name,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// HIP
|
// HIP
|
||||||
std::tuple<std::string, asm_map_t, int> hip_compile_ttir(const std::string& name, ir::module &ir,
|
std::tuple<std::string, asm_map_t, int> hip_compile_ttir(
|
||||||
uint64_t device, int num_warps, int num_stages,
|
const std::string &name, ir::module &ir, uint64_t device, int num_warps,
|
||||||
asm_map_t &asm_map){
|
int num_stages, asm_map_t &asm_map,
|
||||||
|
const triton::codegen::ExternLibMap &extern_lib_map) {
|
||||||
llvm::LLVMContext ctx;
|
llvm::LLVMContext ctx;
|
||||||
// Triton-IR -> NVPTX LLVM-IR
|
// Triton-IR -> NVPTX LLVM-IR
|
||||||
triton::codegen::amd_cl_target target;
|
triton::codegen::amd_cl_target target;
|
||||||
int n_shared_bytes;
|
int n_shared_bytes;
|
||||||
auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, 70, num_warps, num_stages, n_shared_bytes);
|
auto llvm = triton::codegen::add_passes_to_emit_bin(
|
||||||
|
ir, ctx, &target, num_warps, num_stages, n_shared_bytes, extern_lib_map);
|
||||||
std::string tmp;
|
std::string tmp;
|
||||||
llvm::raw_string_ostream llir(tmp);
|
llvm::raw_string_ostream llir(tmp);
|
||||||
llir << *llvm;
|
llir << *llvm;
|
||||||
@@ -523,7 +532,9 @@ std::tuple<std::string, asm_map_t, int> hip_compile_ttir(const std::string& name
|
|||||||
|
|
||||||
void init_triton_codegen(py::module &&m) {
|
void init_triton_codegen(py::module &&m) {
|
||||||
m.def(
|
m.def(
|
||||||
"compile_ttir", [](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages) {
|
"compile_ttir",
|
||||||
|
[](backend_t backend, ir::module &ir, uint64_t device, int num_warps,
|
||||||
|
int num_stages, py::dict& extern_libs) {
|
||||||
std::string name = ir.get_function_list()[0]->get_name();
|
std::string name = ir.get_function_list()[0]->get_name();
|
||||||
// record asm as we generate
|
// record asm as we generate
|
||||||
asm_map_t asm_map;
|
asm_map_t asm_map;
|
||||||
@@ -531,11 +542,20 @@ void init_triton_codegen(py::module &&m) {
|
|||||||
ir.print(ttir);
|
ir.print(ttir);
|
||||||
asm_map["ttir"] = py::cast(ttir.str());
|
asm_map["ttir"] = py::cast(ttir.str());
|
||||||
llvm::LLVMContext ctx;
|
llvm::LLVMContext ctx;
|
||||||
|
// construct extern lib map
|
||||||
|
triton::codegen::ExternLibMap extern_lib_map;
|
||||||
|
for (auto item : extern_libs) {
|
||||||
|
auto name = item.first.cast<std::string>();
|
||||||
|
auto path = item.second.cast<std::string>();
|
||||||
|
extern_lib_map.emplace(
|
||||||
|
name, triton::codegen::create_extern_lib(name, path));
|
||||||
|
}
|
||||||
if(backend == CUDA)
|
if(backend == CUDA)
|
||||||
return cu_compile_ttir(name, ir, device, num_warps, num_stages, asm_map);
|
return cu_compile_ttir(name, ir, device, num_warps, num_stages, asm_map, extern_lib_map);
|
||||||
if(backend == ROCM)
|
if(backend == ROCM)
|
||||||
return hip_compile_ttir(name, ir, device, num_warps, num_stages, asm_map);
|
return hip_compile_ttir(name, ir, device, num_warps, num_stages, asm_map, extern_lib_map);
|
||||||
}, py::return_value_policy::take_ownership);
|
},
|
||||||
|
py::return_value_policy::take_ownership);
|
||||||
m.def("load_binary", [](backend_t backend, const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){
|
m.def("load_binary", [](backend_t backend, const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){
|
||||||
py::gil_scoped_release allow_threads;
|
py::gil_scoped_release allow_threads;
|
||||||
if(backend == CUDA)
|
if(backend == CUDA)
|
||||||
@@ -931,7 +951,8 @@ void init_triton_ir(py::module &&m) {
|
|||||||
// Utilities
|
// Utilities
|
||||||
.def("create_clock", &ir::builder::create_clock, ret::reference)
|
.def("create_clock", &ir::builder::create_clock, ret::reference)
|
||||||
.def("create_globaltimer", &ir::builder::create_globaltimer, ret::reference)
|
.def("create_globaltimer", &ir::builder::create_globaltimer, ret::reference)
|
||||||
|
// Extern instruction
|
||||||
|
.def("create_extern_elementwise", &ir::builder::create_extern_elementwise, ret::reference)
|
||||||
// Built-in instruction
|
// Built-in instruction
|
||||||
.def("create_get_program_id", &ir::builder::create_get_program_id, ret::reference)
|
.def("create_get_program_id", &ir::builder::create_get_program_id, ret::reference)
|
||||||
.def("create_get_num_programs", &ir::builder::create_get_num_programs, ret::reference)
|
.def("create_get_num_programs", &ir::builder::create_get_num_programs, ret::reference)
|
||||||
|
@@ -1300,3 +1300,49 @@ def test_num_warps_pow2():
|
|||||||
_kernel[(1,)](dst=dst, num_warps=1)
|
_kernel[(1,)](dst=dst, num_warps=1)
|
||||||
_kernel[(1,)](dst=dst, num_warps=2)
|
_kernel[(1,)](dst=dst, num_warps=2)
|
||||||
_kernel[(1,)](dst=dst, num_warps=4)
|
_kernel[(1,)](dst=dst, num_warps=4)
|
||||||
|
|
||||||
|
# -------------
|
||||||
|
# test extern
|
||||||
|
# -------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("dtype_str, expr, lib_path",
|
||||||
|
[('int32', 'libdevice.ffs', ''),
|
||||||
|
('float32', 'libdevice.pow', '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'),
|
||||||
|
('float64', 'libdevice.norm4d', '')])
|
||||||
|
def test_libdevice(dtype_str, expr, lib_path):
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def kernel(X, Y, BLOCK: tl.constexpr):
|
||||||
|
x = tl.load(X + tl.arange(0, BLOCK))
|
||||||
|
y = GENERATE_TEST_HERE
|
||||||
|
tl.store(Y + tl.arange(0, BLOCK), y)
|
||||||
|
|
||||||
|
shape = (128, )
|
||||||
|
rs = RandomState(17)
|
||||||
|
# limit the range of integers so that the sum does not overflow
|
||||||
|
x = numpy_random(shape, dtype_str=dtype_str, rs=rs)
|
||||||
|
|
||||||
|
if expr == 'libdevice.ffs':
|
||||||
|
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.ffs(x)'})
|
||||||
|
y_ref = np.zeros(shape, dtype=x.dtype)
|
||||||
|
for i in range(shape[0]):
|
||||||
|
y_ref[i] = (int(x[i]) & int(-x[i])).bit_length()
|
||||||
|
elif expr == 'libdevice.pow':
|
||||||
|
# numpy does not allow negative factors in power, so we use abs()
|
||||||
|
x = np.abs(x)
|
||||||
|
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.pow(x, x)'})
|
||||||
|
y_ref = np.power(x, x)
|
||||||
|
elif expr == 'libdevice.norm4d':
|
||||||
|
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.norm4d(x, x, x, x)'})
|
||||||
|
y_ref = np.sqrt(4 * np.power(x, 2))
|
||||||
|
|
||||||
|
x_tri = to_triton(x)
|
||||||
|
# triton result
|
||||||
|
y_tri = to_triton(numpy_random((shape[0],), dtype_str=dtype_str, rs=rs), device='cuda')
|
||||||
|
kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'libdevice': lib_path})
|
||||||
|
# compare
|
||||||
|
if expr == 'libdevice.ffs':
|
||||||
|
np.testing.assert_equal(y_ref, to_numpy(y_tri))
|
||||||
|
else:
|
||||||
|
np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01)
|
||||||
|
@@ -689,7 +689,7 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
ret = triton.language.tensor(ret, self.prototypes[fn_name].ret_type)
|
ret = triton.language.tensor(ret, self.prototypes[fn_name].ret_type)
|
||||||
return ret
|
return ret
|
||||||
# built-in function
|
# built-in function
|
||||||
if sys.modules[fn.__module__] is triton.language.core:
|
if sys.modules[fn.__module__] is triton.language.core or isinstance(fn, triton.language.extern.ExternalFunction):
|
||||||
ret = fn(*args, _builder=self.builder, **kws)
|
ret = fn(*args, _builder=self.builder, **kws)
|
||||||
if fn in self.value_constructor.builtins.values():
|
if fn in self.value_constructor.builtins.values():
|
||||||
args = [arg.value if isinstance(arg, triton.language.constexpr) else arg
|
args = [arg.value if isinstance(arg, triton.language.constexpr) else arg
|
||||||
@@ -933,7 +933,7 @@ class Kernel:
|
|||||||
self.fn = fn
|
self.fn = fn
|
||||||
self.cache_key = {}
|
self.cache_key = {}
|
||||||
|
|
||||||
def add_to_cache(self, key, wargs, device_idx, num_warps, num_stages):
|
def add_to_cache(self, key, wargs, device_idx, num_warps, num_stages, extern_libs):
|
||||||
tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
|
tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
|
||||||
|
|
||||||
# attributes
|
# attributes
|
||||||
@@ -953,9 +953,10 @@ class Kernel:
|
|||||||
constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)})
|
constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)})
|
||||||
constants.update({i: None for i, arg in enumerate(wargs) if arg is None})
|
constants.update({i: None for i, arg in enumerate(wargs) if arg is None})
|
||||||
arg_types = [Kernel._to_python_ir(arg) for i, arg in enumerate(wargs) if i not in constants]
|
arg_types = [Kernel._to_python_ir(arg) for i, arg in enumerate(wargs) if i not in constants]
|
||||||
return self.fn._warmup(key, arg_types=arg_types, device=device_idx, attributes=attributes, constants=constants, num_warps=num_warps, num_stages=num_stages, is_manual_warmup=False)
|
return self.fn._warmup(key, arg_types=arg_types, device=device_idx, attributes=attributes, constants=constants, num_warps=num_warps, num_stages=num_stages,
|
||||||
|
extern_libs=extern_libs, is_manual_warmup=False)
|
||||||
|
|
||||||
def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **kwargs):
|
def __call__(self, *wargs, grid, num_warps=4, num_stages=2, extern_libs={}, **kwargs):
|
||||||
assert num_warps != 0 and (num_warps & (num_warps - 1)) == 0, f"num_warps={num_warps} must be a power of 2."
|
assert num_warps != 0 and (num_warps & (num_warps - 1)) == 0, f"num_warps={num_warps} must be a power of 2."
|
||||||
# handle arguments passed by name
|
# handle arguments passed by name
|
||||||
kwargs = {self.fn.arg_names.index(name): value for name, value in kwargs.items()}
|
kwargs = {self.fn.arg_names.index(name): value for name, value in kwargs.items()}
|
||||||
@@ -985,7 +986,7 @@ class Kernel:
|
|||||||
cache_key = self.cache_key[device]
|
cache_key = self.cache_key[device]
|
||||||
stream = current_cuda_stream(device)
|
stream = current_cuda_stream(device)
|
||||||
return _triton.runtime.launch(wargs, self.fn.do_not_specialize, cache_key, self.fn.arg_names,
|
return _triton.runtime.launch(wargs, self.fn.do_not_specialize, cache_key, self.fn.arg_names,
|
||||||
device, stream, self.fn.bin_cache, num_warps, num_stages, self.add_to_cache,
|
device, stream, self.fn.bin_cache, num_warps, num_stages, extern_libs, self.add_to_cache,
|
||||||
grid)
|
grid)
|
||||||
|
|
||||||
|
|
||||||
@@ -1242,7 +1243,7 @@ class JITFunction:
|
|||||||
def warmup(self, compile):
|
def warmup(self, compile):
|
||||||
return self._warmup(**compile, is_manual_warmup=True)
|
return self._warmup(**compile, is_manual_warmup=True)
|
||||||
|
|
||||||
def _warmup(self, key, arg_types, device, attributes, constants, num_warps, num_stages, is_manual_warmup):
|
def _warmup(self, key, arg_types, device, attributes, constants, num_warps, num_stages, extern_libs, is_manual_warmup):
|
||||||
hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest()
|
hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest()
|
||||||
|
|
||||||
# create cache directory
|
# create cache directory
|
||||||
@@ -1264,7 +1265,7 @@ class JITFunction:
|
|||||||
with open(bin_cache_path, 'rb') as f:
|
with open(bin_cache_path, 'rb') as f:
|
||||||
binary = pickle.load(f)["binary"]
|
binary = pickle.load(f)["binary"]
|
||||||
|
|
||||||
compile = dict(arg_types=arg_types, device=device, attributes=attributes, constants=constants, num_warps=num_warps, num_stages=num_stages)
|
compile = dict(arg_types=arg_types, device=device, attributes=attributes, constants=constants, num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs)
|
||||||
if JITFunction.cache_hook is not None:
|
if JITFunction.cache_hook is not None:
|
||||||
name = self.__name__
|
name = self.__name__
|
||||||
info = key.split('-')[-3:]
|
info = key.split('-')[-3:]
|
||||||
@@ -1293,7 +1294,7 @@ class JITFunction:
|
|||||||
self.bin_cache[key] = LoadedBinary(device, binary)
|
self.bin_cache[key] = LoadedBinary(device, binary)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _compile(self, arg_types, device, attributes, constants, num_warps, num_stages):
|
def _compile(self, arg_types, device, attributes, constants, num_warps, num_stages, extern_libs):
|
||||||
# create IR module
|
# create IR module
|
||||||
context = _triton.ir.context()
|
context = _triton.ir.context()
|
||||||
# get just-in-time proto-type of kernel
|
# get just-in-time proto-type of kernel
|
||||||
@@ -1316,7 +1317,7 @@ class JITFunction:
|
|||||||
backend = _triton.runtime.backend.CUDA
|
backend = _triton.runtime.backend.CUDA
|
||||||
else:
|
else:
|
||||||
backend = _triton.runtime.backend.ROCM
|
backend = _triton.runtime.backend.ROCM
|
||||||
name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, generator.module, device, num_warps, num_stages)
|
name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, generator.module, device, num_warps, num_stages, extern_libs)
|
||||||
max_shared_memory = _triton.runtime.max_shared_memory(backend, device)
|
max_shared_memory = _triton.runtime.max_shared_memory(backend, device)
|
||||||
if shared_mem > max_shared_memory:
|
if shared_mem > max_shared_memory:
|
||||||
raise OutOfResources(shared_mem, max_shared_memory, "shared memory")
|
raise OutOfResources(shared_mem, max_shared_memory, "shared memory")
|
||||||
|
@@ -1,4 +1,4 @@
|
|||||||
# flake8: noqa: F401
|
# flake8: noqa: F401
|
||||||
from . import core, random
|
from . import core, extern, libdevice, random
|
||||||
from .core import *
|
from .core import *
|
||||||
from .random import *
|
from .random import *
|
||||||
|
@@ -248,8 +248,10 @@ class block_type(dtype):
|
|||||||
# while tensor's shape is a list of constexpr
|
# while tensor's shape is a list of constexpr
|
||||||
self.shape = shape
|
self.shape = shape
|
||||||
self.numel = 1
|
self.numel = 1
|
||||||
for s in self.shape:
|
for i, s in enumerate(self.shape):
|
||||||
self.numel *= s
|
if isinstance(s, constexpr):
|
||||||
|
self.shape[i] = s.value
|
||||||
|
self.numel *= self.shape[i]
|
||||||
|
|
||||||
self.name = self.__str__()
|
self.name = self.__str__()
|
||||||
|
|
||||||
|
107
python/triton/language/extern.py
Normal file
107
python/triton/language/extern.py
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
from __future__ import annotations # remove after python 3.11
|
||||||
|
|
||||||
|
from . import core, semantic
|
||||||
|
|
||||||
|
|
||||||
|
def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, ret_shape: tuple, _builder=None):
|
||||||
|
'''
|
||||||
|
Dispatch a function to a library
|
||||||
|
|
||||||
|
:param func: the function to dispatch
|
||||||
|
:param lib_name: the name of the library
|
||||||
|
:param lib_path: the path of the library
|
||||||
|
:param args: the arguments of the function
|
||||||
|
:param arg_type_symbol_dict: the type of the arguments
|
||||||
|
:param ret_shape: the shape of the return value
|
||||||
|
:param _builder: the builder
|
||||||
|
|
||||||
|
:return: the return value of the function
|
||||||
|
'''
|
||||||
|
if len(arg_type_symbol_dict) == 0:
|
||||||
|
raise ValueError("arg_type_symbol_dict is empty")
|
||||||
|
|
||||||
|
num_args = len(list(arg_type_symbol_dict.keys())[0])
|
||||||
|
if len(args) != num_args:
|
||||||
|
raise ValueError(f"length of input args does not match."
|
||||||
|
f"Expect {len(args)}, got {num_args}")
|
||||||
|
|
||||||
|
arg_types = []
|
||||||
|
arg_list = []
|
||||||
|
for arg in args:
|
||||||
|
if isinstance(arg, core.tensor):
|
||||||
|
arg_types.append(arg.dtype)
|
||||||
|
arg_list.append(arg.handle)
|
||||||
|
else:
|
||||||
|
arg_types.append(type(arg))
|
||||||
|
arg_list.append(arg)
|
||||||
|
arg_types = tuple(arg_types)
|
||||||
|
|
||||||
|
if arg_types not in arg_type_symbol_dict:
|
||||||
|
raise ValueError(f"input arg type does not match."
|
||||||
|
f"Expect one of {arg_type_symbol_dict.keys()}, got {arg_types}")
|
||||||
|
else:
|
||||||
|
symbol = arg_type_symbol_dict[arg_types][0]
|
||||||
|
ret_type = arg_type_symbol_dict[arg_types][1]
|
||||||
|
ret_type = core.block_type(ret_type, ret_shape) if ret_shape is not None else ret_type
|
||||||
|
return core.tensor(func(lib_name, lib_path, symbol, arg_list, ret_type.to_ir(_builder)), ret_type)
|
||||||
|
|
||||||
|
|
||||||
|
def elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, _builder=None):
|
||||||
|
'''
|
||||||
|
Dispatch an elementwise function to a library
|
||||||
|
|
||||||
|
:param lib_name: the name of the library
|
||||||
|
:param lib_path: the path of the library
|
||||||
|
:param args: the arguments of the function
|
||||||
|
:param arg_type_symbol_dict: the type of the arguments
|
||||||
|
:param _builder: the builder
|
||||||
|
|
||||||
|
:return: the return value of the function
|
||||||
|
'''
|
||||||
|
dispatch_args = args.copy()
|
||||||
|
if len(args) == 1:
|
||||||
|
dispatch_args[0] = core._to_tensor(dispatch_args[0], _builder)
|
||||||
|
ret_shape = dispatch_args[0].shape
|
||||||
|
elif len(args) == 2:
|
||||||
|
dispatch_args[0] = core._to_tensor(dispatch_args[0], _builder)
|
||||||
|
dispatch_args[1] = core._to_tensor(dispatch_args[1], _builder)
|
||||||
|
dispatch_args[0], dispatch_args[1] = semantic.binary_op_type_checking_impl(
|
||||||
|
dispatch_args[0], dispatch_args[1], _builder)
|
||||||
|
ret_shape = dispatch_args[0].shape
|
||||||
|
else:
|
||||||
|
for i in range(len(dispatch_args)):
|
||||||
|
dispatch_args[i] = core._to_tensor(dispatch_args[i], _builder)
|
||||||
|
broadcast_arg = dispatch_args[0]
|
||||||
|
# Get the broadcast shape over all the arguments
|
||||||
|
for i in range(len(dispatch_args)):
|
||||||
|
_, broadcast_arg = semantic.binary_op_type_checking_impl(
|
||||||
|
dispatch_args[i], broadcast_arg, _builder)
|
||||||
|
# Change the shape of each argument based on the broadcast shape
|
||||||
|
for i in range(len(dispatch_args)):
|
||||||
|
dispatch_args[i], _ = semantic.binary_op_type_checking_impl(
|
||||||
|
dispatch_args[i], broadcast_arg, _builder)
|
||||||
|
ret_shape = broadcast_arg.shape
|
||||||
|
func = getattr(_builder, "create_extern_elementwise")
|
||||||
|
return dispatch(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, ret_shape, _builder)
|
||||||
|
|
||||||
|
|
||||||
|
class ExternalFunction:
|
||||||
|
'''
|
||||||
|
A wrapper for external functions
|
||||||
|
'''
|
||||||
|
|
||||||
|
def __init__(self, fn):
|
||||||
|
self.fn = fn
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
if '_builder' not in kwargs or \
|
||||||
|
kwargs['_builder'] is None:
|
||||||
|
raise ValueError("Did you forget to add @triton.jit ? (`_builder` argument must be provided outside of JIT functions.)")
|
||||||
|
return self.fn(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def extern(fn):
|
||||||
|
'''
|
||||||
|
A decorator for external functions
|
||||||
|
'''
|
||||||
|
return ExternalFunction(fn)
|
BIN
python/triton/language/libdevice.10.bc
Normal file
BIN
python/triton/language/libdevice.10.bc
Normal file
Binary file not shown.
1661
python/triton/language/libdevice.py
Normal file
1661
python/triton/language/libdevice.py
Normal file
File diff suppressed because it is too large
Load Diff
340
python/triton/tools/build_extern.py
Normal file
340
python/triton/tools/build_extern.py
Normal file
@@ -0,0 +1,340 @@
|
|||||||
|
import argparse
|
||||||
|
import subprocess
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
|
||||||
|
class Symbol:
|
||||||
|
def __init__(self, name: str, op_name: str, ret_type: str, arg_names: list, arg_types: list) -> None:
|
||||||
|
'''
|
||||||
|
A symbol is a function declaration.
|
||||||
|
|
||||||
|
:param name: name of the symbol
|
||||||
|
:param op_name: name of the operation
|
||||||
|
:param ret_type: return type of the operation
|
||||||
|
:param arg_names: names of the arguments
|
||||||
|
:param arg_types: types of the arguments
|
||||||
|
'''
|
||||||
|
self._name = name
|
||||||
|
self._op_name = op_name
|
||||||
|
self._ret_type = ret_type
|
||||||
|
self._arg_names = arg_names
|
||||||
|
self._arg_types = arg_types
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self):
|
||||||
|
return self._name
|
||||||
|
|
||||||
|
@property
|
||||||
|
def op_name(self):
|
||||||
|
return self._op_name
|
||||||
|
|
||||||
|
@property
|
||||||
|
def ret_type(self):
|
||||||
|
return self._ret_type
|
||||||
|
|
||||||
|
@property
|
||||||
|
def arg_names(self):
|
||||||
|
return self._arg_names
|
||||||
|
|
||||||
|
@property
|
||||||
|
def arg_types(self):
|
||||||
|
return self._arg_types
|
||||||
|
|
||||||
|
|
||||||
|
def convert_type(type_str):
|
||||||
|
if type_str == "i32":
|
||||||
|
return "int32"
|
||||||
|
elif type_str == "u32":
|
||||||
|
return "uint32"
|
||||||
|
elif type_str == "i64":
|
||||||
|
return "int64"
|
||||||
|
elif type_str == "u64":
|
||||||
|
return "uint64"
|
||||||
|
elif type_str == "float":
|
||||||
|
return "fp32"
|
||||||
|
elif type_str == "double":
|
||||||
|
return "fp64"
|
||||||
|
else:
|
||||||
|
# ignore other types, such as pointer types
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def to_unsigned(type_str):
|
||||||
|
if type_str == "int32":
|
||||||
|
return "uint32"
|
||||||
|
elif type_str == "int64":
|
||||||
|
return "uint64"
|
||||||
|
else:
|
||||||
|
return type_str
|
||||||
|
|
||||||
|
|
||||||
|
class ExternLibrary(ABC):
|
||||||
|
def __init__(self, name: str, path: str, format: bool = True, grouping: bool = True) -> None:
|
||||||
|
'''
|
||||||
|
Abstract class for extern library.
|
||||||
|
|
||||||
|
:param name: name of the library
|
||||||
|
:param path: path of the library
|
||||||
|
:param format: whether to format the generated stub file
|
||||||
|
'''
|
||||||
|
self._name = name
|
||||||
|
self._path = path
|
||||||
|
self._symbols = {}
|
||||||
|
self._format = True
|
||||||
|
self._grouping = grouping
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self):
|
||||||
|
return self._name
|
||||||
|
|
||||||
|
@property
|
||||||
|
def path(self):
|
||||||
|
return self._path
|
||||||
|
|
||||||
|
@property
|
||||||
|
def symbols(self):
|
||||||
|
return self._symbols
|
||||||
|
|
||||||
|
@property
|
||||||
|
def grouping(self):
|
||||||
|
return self._grouping
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def parse_symbols(self, input_file):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _output_stubs(self) -> str:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def generate_stub_file(self, output_dir):
|
||||||
|
file_str = self._output_stubs()
|
||||||
|
if file_str is None or len(file_str) == 0:
|
||||||
|
raise Exception("file_str is empty")
|
||||||
|
|
||||||
|
output_file = f"{output_dir}/{self._name}.py"
|
||||||
|
with open(output_file, "w") as f:
|
||||||
|
f.write(file_str)
|
||||||
|
f.close()
|
||||||
|
if self._format:
|
||||||
|
subprocess.Popen(["autopep8", "-a", "-r", "-i", output_file],
|
||||||
|
stdout=subprocess.PIPE).communicate()
|
||||||
|
subprocess.Popen(["isort", output_file], stdout=subprocess.PIPE).communicate()
|
||||||
|
|
||||||
|
|
||||||
|
class Libdevice(ExternLibrary):
|
||||||
|
def __init__(self, path) -> None:
|
||||||
|
'''
|
||||||
|
Constructor for Libdevice.
|
||||||
|
|
||||||
|
:param path: path of the libdevice library
|
||||||
|
'''
|
||||||
|
super().__init__("libdevice", path)
|
||||||
|
self._symbol_groups = {}
|
||||||
|
|
||||||
|
def _extract_symbol(self, line):
|
||||||
|
# Extract symbols from line in the following format:
|
||||||
|
# "define [internal] <ret_type> @<name>(<arg_types>,)"
|
||||||
|
entries = line.split("@")
|
||||||
|
ret_str = entries[0]
|
||||||
|
func_str = entries[1]
|
||||||
|
# Get ret_type, skip internal symbols
|
||||||
|
ret_strs = ret_str.split()
|
||||||
|
if ret_strs[1] == "internal":
|
||||||
|
return None
|
||||||
|
ret_type = convert_type(ret_strs[1])
|
||||||
|
if ret_type is None:
|
||||||
|
return None
|
||||||
|
# Get function name
|
||||||
|
func_strs = func_str.split("(")
|
||||||
|
func_name = func_strs[0].replace("@", "")
|
||||||
|
op_name = func_name.replace("__nv_", "")
|
||||||
|
# Get arg_types
|
||||||
|
arg_strs = func_strs[1].split(",")
|
||||||
|
arg_types = []
|
||||||
|
arg_names = []
|
||||||
|
for i, arg_str in enumerate(arg_strs):
|
||||||
|
arg_type = convert_type(arg_str.split()[0])
|
||||||
|
if arg_type is None:
|
||||||
|
return None
|
||||||
|
arg_name = 'arg' + str(i)
|
||||||
|
arg_types.append(arg_type)
|
||||||
|
arg_names.append(arg_name)
|
||||||
|
if op_name == "sad":
|
||||||
|
# Special case for sad, where the last argument is an unsigned int
|
||||||
|
arg_types[-1] = to_unsigned(arg_types[-1])
|
||||||
|
elif op_name.startswith("u"):
|
||||||
|
# LLVM does not differentiate between signed and unsigned integer type.
|
||||||
|
# We have to convert the types to unsigned
|
||||||
|
ret_type = to_unsigned(ret_type)
|
||||||
|
for i, arg_type in enumerate(arg_types):
|
||||||
|
arg_types[i] = to_unsigned(arg_type)
|
||||||
|
return Symbol(func_name, op_name, ret_type, arg_names, arg_types)
|
||||||
|
|
||||||
|
def _group_symbols(self):
|
||||||
|
symbol_set = {}
|
||||||
|
for symbol in self._symbols.values():
|
||||||
|
op_name = symbol.op_name
|
||||||
|
symbol_set[op_name] = symbol
|
||||||
|
# The following cases are grouped together:
|
||||||
|
# op_name, <u/ull/ll>op_name<ll/f/i>
|
||||||
|
for symbol in self._symbols.values():
|
||||||
|
op_name = symbol.op_name
|
||||||
|
if "max" in op_name:
|
||||||
|
op_name = "max"
|
||||||
|
elif "min" in op_name:
|
||||||
|
op_name = "min"
|
||||||
|
elif "abs" in op_name:
|
||||||
|
op_name = "abs"
|
||||||
|
elif "pow" in op_name and "fast" in op_name:
|
||||||
|
op_name = "pow"
|
||||||
|
elif "round" in op_name:
|
||||||
|
if "llround" in op_name:
|
||||||
|
op_name = "llround"
|
||||||
|
else:
|
||||||
|
op_name = "round"
|
||||||
|
elif "rint" in op_name:
|
||||||
|
if "llrint" in op_name:
|
||||||
|
op_name = "llrint"
|
||||||
|
else:
|
||||||
|
op_name = "rint"
|
||||||
|
elif op_name.startswith("ull"):
|
||||||
|
if "2" not in op_name:
|
||||||
|
# e.g., ullmax->max
|
||||||
|
op_name = op_name[3:]
|
||||||
|
else:
|
||||||
|
# e.g., ull2double->ll2double
|
||||||
|
op_name = op_name[1:]
|
||||||
|
elif op_name.startswith("u"):
|
||||||
|
if "2" not in op_name:
|
||||||
|
# e.g., uhadd->hadd
|
||||||
|
op_name = op_name[1:]
|
||||||
|
else:
|
||||||
|
# e.g., uint2double_rn->int2double_rn
|
||||||
|
op_name = op_name[1:]
|
||||||
|
elif op_name.startswith("ll"):
|
||||||
|
if "2" not in op_name:
|
||||||
|
# e.g., llmax->max
|
||||||
|
op_name = op_name[2:]
|
||||||
|
elif op_name.endswith("ll"):
|
||||||
|
op_name = op_name[:-2]
|
||||||
|
elif op_name.endswith("f"):
|
||||||
|
op_name = op_name[:-1]
|
||||||
|
if op_name in symbol_set:
|
||||||
|
# Update op_name only if there's an existing symbol
|
||||||
|
symbol._op_name = op_name
|
||||||
|
else:
|
||||||
|
op_name = symbol._op_name
|
||||||
|
if op_name in self._symbol_groups:
|
||||||
|
self._symbol_groups[op_name].append(symbol)
|
||||||
|
else:
|
||||||
|
self._symbol_groups[op_name] = [symbol]
|
||||||
|
|
||||||
|
def parse_symbols(self, input_file):
|
||||||
|
if len(self.symbols) > 0:
|
||||||
|
return
|
||||||
|
output = subprocess.check_output(["grep", "define", input_file]).decode().splitlines()
|
||||||
|
for line in output:
|
||||||
|
symbol = self._extract_symbol(line)
|
||||||
|
if symbol is None:
|
||||||
|
continue
|
||||||
|
self._symbols[symbol.name] = symbol
|
||||||
|
|
||||||
|
self._group_symbols()
|
||||||
|
|
||||||
|
def _output_stubs(self):
|
||||||
|
# Generate python functions in the following format:
|
||||||
|
# @extern.extern
|
||||||
|
# def <op_name>(<args>, _builder=None):
|
||||||
|
# arg_type_symbol_dict = {[arg_type]: {(symbol, ret_type)}}
|
||||||
|
# return extern.dispatch("libdevice", <path>, <args>, <arg_type_symbol_dict>, _builder)
|
||||||
|
import_str = "from . import core, extern\n"
|
||||||
|
import_str += "import os\n"
|
||||||
|
header_str = "LIBDEVICE_PATH = os.path.dirname(os.path.abspath(__file__)) + \"/libdevice.10.bc\"\n"
|
||||||
|
func_str = ""
|
||||||
|
for symbols in self._symbol_groups.values():
|
||||||
|
func_str += "@extern.extern\n"
|
||||||
|
func_name_str = f"def {symbols[0].op_name}("
|
||||||
|
for arg_name in symbols[0].arg_names:
|
||||||
|
func_name_str += f"{arg_name}, "
|
||||||
|
func_name_str += "_builder=None):\n"
|
||||||
|
|
||||||
|
return_str = f"\treturn extern.elementwise(\"{self._name}\", LIBDEVICE_PATH, ["
|
||||||
|
for arg_name in symbols[0].arg_names:
|
||||||
|
return_str += f"{arg_name}, "
|
||||||
|
return_str += "], \n"
|
||||||
|
|
||||||
|
arg_type_symbol_dict_str = "{"
|
||||||
|
for symbol in symbols:
|
||||||
|
arg_type_symbol_dict_str += "("
|
||||||
|
for arg_type in symbol.arg_types:
|
||||||
|
arg_type_symbol_dict_str += f"core.dtype(\"{arg_type}\"),"
|
||||||
|
ret_type = f"core.dtype(\"{symbol.ret_type}\")"
|
||||||
|
arg_type_symbol_dict_str += "): (\"" + symbol.name + "\", " + ret_type + "),\n"
|
||||||
|
arg_type_symbol_dict_str += "}"
|
||||||
|
|
||||||
|
return_str += arg_type_symbol_dict_str
|
||||||
|
return_str += ", _builder)\n"
|
||||||
|
|
||||||
|
func_str += func_name_str + return_str + "\n"
|
||||||
|
file_str = import_str + header_str + func_str
|
||||||
|
|
||||||
|
return file_str
|
||||||
|
|
||||||
|
|
||||||
|
class LLVMDisassembler:
|
||||||
|
def __init__(self, path):
|
||||||
|
'''
|
||||||
|
Invoke llvm-dis to disassemble the given file.
|
||||||
|
|
||||||
|
:param path: path to llvm-dis
|
||||||
|
'''
|
||||||
|
self._path = path
|
||||||
|
self._ll_file = "/tmp/extern_lib.ll"
|
||||||
|
|
||||||
|
def disasm(self, lib_path):
|
||||||
|
subprocess.Popen([self._path, lib_path, "-o", self.ll_file],
|
||||||
|
stdout=subprocess.PIPE).communicate()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def ll_file(self):
|
||||||
|
return self._ll_file
|
||||||
|
|
||||||
|
@property
|
||||||
|
def path(self):
|
||||||
|
return self._path
|
||||||
|
|
||||||
|
|
||||||
|
extern_libs = ["libdevice"]
|
||||||
|
|
||||||
|
|
||||||
|
def build(llvm_dis_path, lib_path, lib_name, output_dir):
|
||||||
|
'''
|
||||||
|
Interface function to build the library file.
|
||||||
|
|
||||||
|
:param llvm_dis_path: path to the llvm-dis binary
|
||||||
|
:param lib_path: path to the external library file
|
||||||
|
:param lib_name: name of the library
|
||||||
|
:param output_dir: path to the output directory
|
||||||
|
'''
|
||||||
|
if lib_name == "libdevice":
|
||||||
|
extern_lib = Libdevice(lib_path)
|
||||||
|
else:
|
||||||
|
raise Exception(f"Unknown extern library: {lib_name}")
|
||||||
|
|
||||||
|
llvm_disassembler = LLVMDisassembler(llvm_dis_path)
|
||||||
|
llvm_disassembler.disasm(lib_path)
|
||||||
|
|
||||||
|
extern_lib.parse_symbols(llvm_disassembler.ll_file)
|
||||||
|
extern_lib.generate_stub_file(output_dir)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("-llvm", dest="llvm_dis_path", help="path to llvm-dis", default="llvm-dis")
|
||||||
|
parser.add_argument("--lib-path", dest="lib_path", help="path to the extern library")
|
||||||
|
parser.add_argument("--lib-name", dest="lib_name", help="name of the extern library")
|
||||||
|
parser.add_argument("-o", dest="output_dir", help="output file path", default="/tmp/")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
build(args.llvm_dis_path, args.lib_path, args.lib_name, args.output_dir)
|
74
python/tutorials/07-libdevice-function.py
Normal file
74
python/tutorials/07-libdevice-function.py
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
"""
|
||||||
|
Libdevice function
|
||||||
|
===============
|
||||||
|
Triton can invoke a custom function from an external library.
|
||||||
|
In this example, we will use the `libdevice` library to apply `asin` on a tensor.
|
||||||
|
Please refer to https://docs.nvidia.com/cuda/libdevice-users-guide/index.html regarding the semantics of all available libdevice functions.
|
||||||
|
|
||||||
|
In `trition/language/libdevice.py`, we try to aggregate functions with the same computation but different data types together.
|
||||||
|
For example, both `__nv_asin` and `__nvasinf` calculate the principal value of the arc sine of the input, but `__nv_asin` operates on `double` and `__nv_asinf` operates on `float`.
|
||||||
|
Using triton, you can simply call `tl.libdevice.asinf`.
|
||||||
|
triton automatically selects the correct underlying device function to invoke based on input and output types.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# asin Kernel
|
||||||
|
# --------------------------
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def asin_kernel(
|
||||||
|
x_ptr,
|
||||||
|
y_ptr,
|
||||||
|
n_elements,
|
||||||
|
BLOCK_SIZE: tl.constexpr,
|
||||||
|
):
|
||||||
|
pid = tl.program_id(axis=0)
|
||||||
|
block_start = pid * BLOCK_SIZE
|
||||||
|
offsets = block_start + tl.arange(0, BLOCK_SIZE)
|
||||||
|
mask = offsets < n_elements
|
||||||
|
x = tl.load(x_ptr + offsets, mask=mask)
|
||||||
|
x = tl.libdevice.asin(x)
|
||||||
|
tl.store(y_ptr + offsets, x, mask=mask)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# Using the default libdevice library path
|
||||||
|
# --------------------------
|
||||||
|
# We can use the default libdevice library path encoded in `triton/language/libdevice.py`
|
||||||
|
|
||||||
|
|
||||||
|
torch.manual_seed(0)
|
||||||
|
size = 98432
|
||||||
|
x = torch.rand(size, device='cuda')
|
||||||
|
output_triton = torch.zeros(size, device='cuda')
|
||||||
|
output_torch = torch.asin(x)
|
||||||
|
assert x.is_cuda and output_triton.is_cuda
|
||||||
|
n_elements = output_torch.numel()
|
||||||
|
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
|
||||||
|
asin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024)
|
||||||
|
print(output_torch)
|
||||||
|
print(output_triton)
|
||||||
|
print(
|
||||||
|
f'The maximum difference between torch and triton is '
|
||||||
|
f'{torch.max(torch.abs(output_torch - output_triton))}'
|
||||||
|
)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# Customize the libdevice library path
|
||||||
|
# --------------------------
|
||||||
|
# We can also customize the libdevice library path by passing the path to the `libdevice` library to the `asin` kernel.
|
||||||
|
|
||||||
|
output_triton = torch.empty_like(x)
|
||||||
|
asin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024,
|
||||||
|
extern_libs={'libdevice': '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'})
|
||||||
|
print(output_torch)
|
||||||
|
print(output_triton)
|
||||||
|
print(
|
||||||
|
f'The maximum difference between torch and triton is '
|
||||||
|
f'{torch.max(torch.abs(output_torch - output_triton))}'
|
||||||
|
)
|
Reference in New Issue
Block a user