[FRONTEND] Added support for element-wise function defined in external LLVM bitcode (e.g., libdevice) (#562)
This commit is contained in:
89
include/triton/codegen/extern_lib.h
Normal file
89
include/triton/codegen/extern_lib.h
Normal file
@@ -0,0 +1,89 @@
|
||||
#ifndef _TRITON_CODE_GEN_EXTERN_LIB_H_
|
||||
#define _TRITON_CODE_GEN_EXTERN_LIB_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "llvm/IR/LLVMContext.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "llvm/IRReader/IRReader.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
|
||||
namespace triton {
|
||||
namespace codegen {
|
||||
|
||||
///
|
||||
/// \brief ExternLib is a class that represents a library of external functions.
|
||||
///
|
||||
class ExternLib {
|
||||
public:
|
||||
ExternLib(const std::string &name, const std::string &path)
|
||||
: name_(name), path_(path) {}
|
||||
|
||||
virtual ~ExternLib() = default;
|
||||
|
||||
virtual const std::string &name() const { return name_; }
|
||||
|
||||
virtual const std::string &path() const { return path_; }
|
||||
|
||||
///
|
||||
/// \brief Load the library and return the module.
|
||||
///
|
||||
std::unique_ptr<llvm::Module> load(llvm::LLVMContext &ctx);
|
||||
|
||||
///
|
||||
/// \brief Link the module into the given module.
|
||||
///
|
||||
void link(std::unique_ptr<llvm::Module> &llvm,
|
||||
std::unique_ptr<llvm::Module> &mod);
|
||||
|
||||
///
|
||||
/// \brief Run load, link, and opt on the module.
|
||||
///
|
||||
virtual void install(llvm::LLVMContext &ctx,
|
||||
std::unique_ptr<llvm::Module> &llvm) {
|
||||
auto mod = load(ctx);
|
||||
link(llvm, mod);
|
||||
opt(ctx, llvm);
|
||||
}
|
||||
|
||||
///
|
||||
/// \brief Run opt on the module.
|
||||
///
|
||||
virtual void opt(llvm::LLVMContext &ctx,
|
||||
std::unique_ptr<llvm::Module> &llvm) = 0;
|
||||
|
||||
private:
|
||||
std::string name_;
|
||||
std::string path_;
|
||||
};
|
||||
|
||||
///
|
||||
/// \brief ExternLibMap is a map of ExternLibs from their names to their paths.
|
||||
///
|
||||
typedef std::map<std::string, std::unique_ptr<ExternLib>> ExternLibMap;
|
||||
|
||||
///
|
||||
/// \brief Concrete class for NVIDIA's libdevice library.
|
||||
///
|
||||
class LibDevice final : public ExternLib {
|
||||
public:
|
||||
LibDevice(const std::string &name, const std::string &path)
|
||||
: ExternLib(name, path) {}
|
||||
|
||||
virtual ~LibDevice() = default;
|
||||
|
||||
virtual void opt(llvm::LLVMContext &ctx,
|
||||
std::unique_ptr<llvm::Module> &llvm) override;
|
||||
};
|
||||
|
||||
///
|
||||
/// \brief Create an ExternLib instance based on the name and path.
|
||||
///
|
||||
std::unique_ptr<ExternLib> create_extern_lib(const std::string &lib_name,
|
||||
const std::string &lib_path);
|
||||
|
||||
} // namespace codegen
|
||||
} // namespace triton
|
||||
|
||||
#endif
|
@@ -3,6 +3,7 @@
|
||||
|
||||
|
||||
#include <memory>
|
||||
#include "extern_lib.h"
|
||||
|
||||
namespace llvm{
|
||||
class Module;
|
||||
@@ -30,12 +31,10 @@ namespace codegen{
|
||||
|
||||
// TODO:
|
||||
// There should be a proper pass manager there!
|
||||
std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMContext& ctx,
|
||||
codegen::target* target,
|
||||
int sm, int num_warps,
|
||||
int num_stages, int &shared_static);
|
||||
|
||||
|
||||
std::unique_ptr<llvm::Module> add_passes_to_emit_bin(
|
||||
ir::module &ir, llvm::LLVMContext &ctx, codegen::target *target,
|
||||
int num_warps, int num_stages, int &shared_static,
|
||||
const ExternLibMap &extern_libs);
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -6,6 +6,7 @@
|
||||
#include "triton/ir/visitor.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/codegen/analysis/layout.h"
|
||||
#include "triton/codegen/extern_lib.h"
|
||||
#include <functional>
|
||||
|
||||
// forward
|
||||
@@ -199,6 +200,7 @@ private:
|
||||
void visit_make_range(ir::make_range*);
|
||||
void visit_clock_inst(ir::clock_inst*);
|
||||
void visit_globaltimer_inst(ir::globaltimer_inst*);
|
||||
void visit_extern_elementwise_inst(ir::extern_elementwise_inst*);
|
||||
// void visit_make_range_sta(ir::make_range_sta*);
|
||||
void visit_undef_value(ir::undef_value*);
|
||||
void visit_constant_int(ir::constant_int*);
|
||||
@@ -209,18 +211,26 @@ private:
|
||||
void visit_argument(ir::argument*);
|
||||
void visit(ir::module &, llvm::Module &);
|
||||
|
||||
|
||||
// layouts
|
||||
void visit_layout_mma(analysis::mma_layout*);
|
||||
void visit_layout_scanline(analysis::scanline_layout*);
|
||||
void visit_layout_shared(analysis::shared_layout*);
|
||||
|
||||
// Add a new external library based on given name and path if it doesn't exist
|
||||
void add_extern_lib(const std::string &lib_name, const std::string &lib_path);
|
||||
|
||||
private:
|
||||
// Get all external libraries
|
||||
const ExternLibMap &get_extern_lib_map() {
|
||||
return extern_lib_map_;
|
||||
}
|
||||
|
||||
private:
|
||||
LLVMContext *ctx_;
|
||||
Builder* builder_;
|
||||
Module *mod_;
|
||||
|
||||
std::map<std::string, std::unique_ptr<ExternLib>> extern_lib_map_;
|
||||
|
||||
analysis::axes *a_axes_;
|
||||
analysis::swizzle *swizzle_;
|
||||
std::map<unsigned, distributed_axis> axes_;
|
||||
|
@@ -169,6 +169,12 @@ public:
|
||||
// Utilities
|
||||
value *create_clock();
|
||||
value *create_globaltimer();
|
||||
// Extern instruction
|
||||
value *create_extern_elementwise(const std::string &lib_name,
|
||||
const std::string &lib_path,
|
||||
const std::string &symbol_name,
|
||||
const std::vector<value *> &args,
|
||||
type *ret_ty);
|
||||
// Built-in instruction
|
||||
value *create_get_program_id(unsigned axis);
|
||||
value *create_get_num_programs(unsigned axis);
|
||||
|
@@ -154,6 +154,8 @@ enum value_id_t: unsigned {
|
||||
INST_COS,
|
||||
INST_SIN,
|
||||
INST_LOG,
|
||||
// extern
|
||||
INST_EXTERN_ELEMENTWISE,
|
||||
// array arithmetic
|
||||
INST_TRANS,
|
||||
INST_REDUCE,
|
||||
|
@@ -1097,7 +1097,28 @@ public:
|
||||
static globaltimer_inst* create(context &ctx, const std::string &name = "", instruction *next = nullptr);
|
||||
};
|
||||
|
||||
class extern_elementwise_inst : public instruction {
|
||||
extern_elementwise_inst(context &ctx, const std::vector<value *> &args,
|
||||
type *dst_ty, const std::string &lib_name,
|
||||
const std::string &extern_lib_path,
|
||||
const std::string &symbol_name, instruction *next);
|
||||
std::string repr_impl() const { return "extern_elementwise"; }
|
||||
_TRITON_DEFINE_CLONE(extern_elementwise_inst)
|
||||
_TRITON_DEFINE_ACCEPT(extern_elementwise_inst)
|
||||
|
||||
public:
|
||||
static extern_elementwise_inst *create(
|
||||
context &ctx, const std::vector<value *> &args, type *dst_ty,
|
||||
const std::string &lib_name = "", const std::string &lib_path = "",
|
||||
const std::string &symbol_name = "", instruction *next = nullptr);
|
||||
|
||||
const std::string &get_lib_name() const { return lib_name_; }
|
||||
const std::string &get_lib_path() const { return lib_path_; }
|
||||
|
||||
private:
|
||||
std::string lib_name_ = "";
|
||||
std::string lib_path_ = "";
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -84,6 +84,8 @@ class prefetch_s_inst;
|
||||
class clock_inst;
|
||||
class globaltimer_inst;
|
||||
|
||||
class extern_elementwise_inst;
|
||||
|
||||
class make_range_sta;
|
||||
class undef_value;
|
||||
class constant_int;
|
||||
@@ -177,6 +179,8 @@ public:
|
||||
virtual void visit_constant_int(constant_int*) = 0;
|
||||
virtual void visit_constant_fp(constant_fp*) = 0;
|
||||
virtual void visit_alloc_const(alloc_const*) = 0;
|
||||
|
||||
virtual void visit_extern_elementwise_inst(extern_elementwise_inst*) = 0;
|
||||
};
|
||||
|
||||
}
|
||||
|
Reference in New Issue
Block a user