[FRONTEND] Added support for element-wise functions defined in external LLVM bitcode (e.g., libdevice) (#562)

This commit is contained in:
Keren Zhou
2022-07-13 15:52:21 -07:00
committed by GitHub
parent 971f5782b4
commit 4912916c11
24 changed files with 2634 additions and 64 deletions

View File

@@ -0,0 +1,89 @@
#ifndef _TRITON_CODE_GEN_EXTERN_LIB_H_
#define _TRITON_CODE_GEN_EXTERN_LIB_H_
#include <map>
#include <memory>
#include <string>
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Support/SourceMgr.h"
namespace triton {
namespace codegen {
/// \brief Abstract description of a library of external functions,
/// identified by a name and a path.
///
/// The class declares the load / link / opt steps plus the `install` driver
/// that chains them. `opt` is pure virtual and must come from a subclass.
class ExternLib {
public:
  /// \brief Store the library's name and path; the constructor body is empty.
  ExternLib(const std::string &name, const std::string &path)
      : name_(name), path_(path) {}

  virtual ~ExternLib() = default;

  /// \brief Name supplied at construction.
  virtual const std::string &name() const { return name_; }

  /// \brief Path supplied at construction.
  virtual const std::string &path() const { return path_; }

  /// \brief Load the library and return the resulting module.
  std::unique_ptr<llvm::Module> load(llvm::LLVMContext &ctx);

  /// \brief Link `mod` into `llvm`.
  void link(std::unique_ptr<llvm::Module> &llvm,
            std::unique_ptr<llvm::Module> &mod);

  /// \brief Driver that runs load(), link() and opt(), in that order.
  virtual void install(llvm::LLVMContext &ctx,
                       std::unique_ptr<llvm::Module> &llvm) {
    // link() takes its second argument by non-const reference, so the loaded
    // module has to live in a named local.
    auto extern_module = load(ctx);
    link(llvm, extern_module);
    opt(ctx, llvm);
  }

  /// \brief Library-specific optimization step applied to `llvm`.
  virtual void opt(llvm::LLVMContext &ctx,
                   std::unique_ptr<llvm::Module> &llvm) = 0;

private:
  std::string name_;
  std::string path_;
};
///
/// \brief ExternLibMap maps a library name to the ExternLib instance that the
/// map owns (through std::unique_ptr) for that name.
///
using ExternLibMap = std::map<std::string, std::unique_ptr<ExternLib>>;
/// \brief ExternLib specialization for NVIDIA's libdevice library.
///
/// Only `opt` is overridden; it is declared here and defined outside this
/// header. Everything else is inherited from ExternLib.
class LibDevice final : public ExternLib {
public:
  /// \brief Forward name and path to the ExternLib constructor.
  LibDevice(const std::string &name, const std::string &path)
      : ExternLib(name, path) {}

  virtual ~LibDevice() = default;

  /// \brief libdevice implementation of the ExternLib::opt hook.
  virtual void opt(llvm::LLVMContext &ctx,
                   std::unique_ptr<llvm::Module> &llvm) override;
};
///
/// \brief Create an ExternLib instance based on the name and path.
///
/// \param lib_name name of the external library.
/// \param lib_path path of the external library.
/// \return an owning pointer to the created ExternLib.
///
/// NOTE(review): only the declaration is visible here. Which concrete class
/// is chosen for a given name (LibDevice is the only one declared in this
/// header) and what happens for an unrecognized name should be confirmed in
/// the definition.
///
std::unique_ptr<ExternLib> create_extern_lib(const std::string &lib_name,
const std::string &lib_path);
} // namespace codegen
} // namespace triton
#endif

View File

@@ -3,6 +3,7 @@
#include <memory>
#include "extern_lib.h"
namespace llvm{
class Module;
@@ -30,12 +31,10 @@ namespace codegen{
// TODO:
// There should be a proper pass manager here!
// NOTE(review): this declaration takes an extra `int sm` and no extern-lib
// map. In this diff view it appears to be the signature that is replaced by
// the declaration following it; confirm against the actual header.
std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMContext& ctx,
codegen::target* target,
int sm, int num_warps,
int num_stages, int &shared_static);
// Declaration of add_passes_to_emit_bin taking the Triton IR module, an LLVM
// context, the codegen target, the warp and stage counts, `shared_static`
// and the map of external libraries. Returns an owning pointer to an
// llvm::Module.
// NOTE(review): shared_static is a non-const `int &`, presumably an output
// parameter; confirm in the definition.
std::unique_ptr<llvm::Module> add_passes_to_emit_bin(
ir::module &ir, llvm::LLVMContext &ctx, codegen::target *target,
int num_warps, int num_stages, int &shared_static,
const ExternLibMap &extern_libs);
}
}

View File

@@ -6,6 +6,7 @@
#include "triton/ir/visitor.h"
#include "triton/ir/instructions.h"
#include "triton/codegen/analysis/layout.h"
#include "triton/codegen/extern_lib.h"
#include <functional>
// forward
@@ -199,6 +200,7 @@ private:
// Per-node visit hooks (declarations only).
void visit_make_range(ir::make_range*);
void visit_clock_inst(ir::clock_inst*);
void visit_globaltimer_inst(ir::globaltimer_inst*);
// Hook for ir::extern_elementwise_inst.
// NOTE(review): presumably registers the instruction's library through
// add_extern_lib; confirm in the definition.
void visit_extern_elementwise_inst(ir::extern_elementwise_inst*);
// void visit_make_range_sta(ir::make_range_sta*);
void visit_undef_value(ir::undef_value*);
void visit_constant_int(ir::constant_int*);
@@ -209,18 +211,26 @@ private:
// Hooks taking an ir::argument and an (ir::module, llvm::Module) pair
void visit_argument(ir::argument*);
void visit(ir::module &, llvm::Module &);
// layouts
void visit_layout_mma(analysis::mma_layout*);
void visit_layout_scanline(analysis::scanline_layout*);
void visit_layout_shared(analysis::shared_layout*);
// Add a new external library based on given name and path if it doesn't exist
// NOTE(review): "doesn't exist" presumably means "no entry under lib_name in
// extern_lib_map_"; confirm in the definition.
void add_extern_lib(const std::string &lib_name, const std::string &lib_path);
private:
// Read-only access to the external-library map held by this object
const ExternLibMap &get_extern_lib_map() { return extern_lib_map_; }
private:
LLVMContext *ctx_;
Builder* builder_;
Module *mod_;
// External libraries keyed by a string. This is the same type as the
// ExternLibMap alias that get_extern_lib_map() returns a reference to.
// NOTE(review): could be spelled `ExternLibMap` for consistency.
std::map<std::string, std::unique_ptr<ExternLib>> extern_lib_map_;
analysis::axes *a_axes_;
analysis::swizzle *swizzle_;
std::map<unsigned, distributed_axis> axes_;

View File

@@ -169,6 +169,12 @@ public:
// Utilities
// Both take no arguments and return a `value *`.
// NOTE(review): presumably each emits the matching clock / globaltimer
// instruction; confirm in the definitions.
value *create_clock();
value *create_globaltimer();
// Extern instruction
// Takes the external library's name and path, the symbol name, the argument
// values and the return type; returns a `value *`.
// NOTE(review): presumably builds an ir::extern_elementwise_inst from these
// arguments; confirm in the definition.
value *create_extern_elementwise(const std::string &lib_name,
const std::string &lib_path,
const std::string &symbol_name,
const std::vector<value *> &args,
type *ret_ty);
// Built-in instruction
// Both take an unsigned `axis` and return a `value *`.
value *create_get_program_id(unsigned axis);
value *create_get_num_programs(unsigned axis);

View File

@@ -154,6 +154,8 @@ enum value_id_t: unsigned {
INST_COS,
INST_SIN,
INST_LOG,
// extern
INST_EXTERN_ELEMENTWISE,
// array arithmetic
INST_TRANS,
INST_REDUCE,

View File

@@ -1097,7 +1097,28 @@ public:
static globaltimer_inst* create(context &ctx, const std::string &name = "", instruction *next = nullptr);
};
// IR instruction for an element-wise function that comes from an external
// library. The constructor is private; instances are obtained via create().
class extern_elementwise_inst : public instruction {
  extern_elementwise_inst(context &ctx, const std::vector<value *> &args,
                          type *dst_ty, const std::string &lib_name,
                          const std::string &lib_path,
                          const std::string &symbol_name, instruction *next);

  // Textual name of this instruction kind.
  std::string repr_impl() const { return "extern_elementwise"; }

  _TRITON_DEFINE_CLONE(extern_elementwise_inst)
  _TRITON_DEFINE_ACCEPT(extern_elementwise_inst)

public:
  // Factory; the library name, library path and symbol name default to the
  // empty string and `next` defaults to nullptr.
  static extern_elementwise_inst *create(
      context &ctx, const std::vector<value *> &args, type *dst_ty,
      const std::string &lib_name = "", const std::string &lib_path = "",
      const std::string &symbol_name = "", instruction *next = nullptr);

  // Accessors for the stored library name and library path.
  const std::string &get_lib_name() const { return lib_name_; }
  const std::string &get_lib_path() const { return lib_path_; }

private:
  std::string lib_name_ = "";
  std::string lib_path_ = "";
};
}
}

View File

@@ -84,6 +84,8 @@ class prefetch_s_inst;
class clock_inst;
class globaltimer_inst;
class extern_elementwise_inst;
class make_range_sta;
class undef_value;
class constant_int;
@@ -177,6 +179,8 @@ public:
virtual void visit_constant_int(constant_int*) = 0;
virtual void visit_constant_fp(constant_fp*) = 0;
virtual void visit_alloc_const(alloc_const*) = 0;
// Pure-virtual hook for extern_elementwise_inst; every concrete visitor has
// to provide an override.
virtual void visit_extern_elementwise_inst(extern_elementwise_inst*) = 0;
};
}