[FRONTEND] Added support for element-wise function defined in external LLVM bitcode (e.g., libdevice) (#562)
This commit is contained in:
63
lib/codegen/extern_lib.cc
Normal file
63
lib/codegen/extern_lib.cc
Normal file
@@ -0,0 +1,63 @@
|
||||
#include "triton/codegen/extern_lib.h"
|
||||
|
||||
#include "llvm/IR/Constants.h"
|
||||
#include "llvm/IR/LegacyPassManager.h"
|
||||
#include "llvm/IR/Metadata.h"
|
||||
#include "llvm/IR/Type.h"
|
||||
#include "llvm/Linker/Linker.h"
|
||||
#include "llvm/Transforms/IPO/PassManagerBuilder.h"
|
||||
#include "triton/codegen/pass.h"
|
||||
|
||||
namespace triton {
|
||||
|
||||
namespace codegen {
|
||||
|
||||
std::unique_ptr<llvm::Module> ExternLib::load(llvm::LLVMContext& ctx) {
|
||||
llvm::SMDiagnostic err;
|
||||
auto mod = llvm::parseIRFile(this->path_, err, ctx);
|
||||
if (!mod) {
|
||||
throw std::runtime_error("Failed to load extern lib " + this->name_ +
|
||||
" at " + this->path_);
|
||||
}
|
||||
return mod;
|
||||
}
|
||||
|
||||
void ExternLib::link(std::unique_ptr<llvm::Module>& llvm,
|
||||
std::unique_ptr<llvm::Module>& mod) {
|
||||
// Set triple and data layout to match the target module
|
||||
mod->setTargetTriple(llvm->getTargetTriple());
|
||||
mod->setDataLayout(llvm->getDataLayout());
|
||||
if (llvm::Linker::linkModules(*llvm, std::move(mod))) {
|
||||
throw std::runtime_error("Failed to link extern lib " + this->name_ +
|
||||
" at " + this->path_);
|
||||
}
|
||||
}
|
||||
|
||||
void LibDevice::opt(llvm::LLVMContext& ctx, std::unique_ptr<llvm::Module>& llvm) {
|
||||
// Add nvvm reflect flags to llvm module
|
||||
// https://llvm.org/docs/LangRef.html#module-flags-metadata
|
||||
// i32 4: Override the other module.
|
||||
// i32 1: Emit an error
|
||||
// If both modules specify Override, but the values differ, an error
|
||||
// will be emitted.
|
||||
llvm::Type* I32 = llvm::Type::getInt32Ty(ctx);
|
||||
llvm::Metadata* md_four =
|
||||
llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(I32, 4));
|
||||
llvm::Metadata* md_name = llvm::MDString::get(ctx, "nvvm-reflect-ftz");
|
||||
llvm::Metadata* md_one =
|
||||
llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(I32, 1));
|
||||
llvm::MDNode* reflect = llvm::MDNode::get(ctx, {md_four, md_name, md_one});
|
||||
llvm->addModuleFlag(reflect);
|
||||
}
|
||||
|
||||
std::unique_ptr<ExternLib> create_extern_lib(const std::string& lib_name,
|
||||
const std::string& lib_path) {
|
||||
if (lib_name == "libdevice") {
|
||||
return std::make_unique<LibDevice>(lib_name, lib_path);
|
||||
} else {
|
||||
throw std::runtime_error("Unknown external library: " + lib_name);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace codegen
|
||||
} // namespace triton
|
@@ -1,4 +1,14 @@
|
||||
#include "triton/codegen/pass.h"
|
||||
|
||||
#include "llvm/IR/Constants.h"
|
||||
#include "llvm/IR/LegacyPassManager.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "llvm/IR/Verifier.h"
|
||||
#include "llvm/IRReader/IRReader.h"
|
||||
#include "llvm/Linker/Linker.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
#include "llvm/Transforms/IPO.h"
|
||||
#include "llvm/Transforms/IPO/PassManagerBuilder.h"
|
||||
#include "triton/codegen/analysis/align.h"
|
||||
#include "triton/codegen/analysis/allocation.h"
|
||||
#include "triton/codegen/analysis/axes.h"
|
||||
@@ -9,24 +19,66 @@
|
||||
#include "triton/codegen/transform/cts.h"
|
||||
#include "triton/codegen/transform/dce.h"
|
||||
#include "triton/codegen/transform/disassociate.h"
|
||||
#include "triton/codegen/transform/inline.h"
|
||||
#include "triton/codegen/transform/membar.h"
|
||||
#include "triton/codegen/transform/peephole.h"
|
||||
#include "triton/codegen/transform/pipeline.h"
|
||||
#include "triton/codegen/transform/prefetch.h"
|
||||
#include "triton/codegen/transform/inline.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/print.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "llvm/IR/LegacyPassManager.h"
|
||||
#include "llvm/IR/Verifier.h"
|
||||
|
||||
namespace triton {
|
||||
namespace codegen {
|
||||
|
||||
static void link_extern_libs(const ExternLibMap& user_extern_lib_map,
|
||||
const ExternLibMap& target_extern_lib_map,
|
||||
ir::module& ir, llvm::LLVMContext& ctx,
|
||||
std::unique_ptr<llvm::Module>& llvm) {
|
||||
for (const auto& iter : target_extern_lib_map) {
|
||||
auto &lib_name = iter.first;
|
||||
if (user_extern_lib_map.count(lib_name) != 0 &&
|
||||
user_extern_lib_map.at(lib_name)->path() != "") {
|
||||
// If the user specified a path for this library, use it.
|
||||
user_extern_lib_map.at(lib_name)->install(ctx, llvm);
|
||||
} else {
|
||||
// Otherwise, use the default path.
|
||||
iter.second->install(ctx, llvm);
|
||||
}
|
||||
}
|
||||
|
||||
std::set<llvm::StringRef> function_names;
|
||||
for (auto& func : ir.get_function_list()) {
|
||||
function_names.insert(func->get_name());
|
||||
}
|
||||
llvm::legacy::PassManager pass;
|
||||
pass.add(llvm::createInternalizePass([&](const llvm::GlobalValue& v) -> bool {
|
||||
if (function_names.count(v.getName()) != 0) {
|
||||
// Preserve global functions
|
||||
return true;
|
||||
}
|
||||
// Internalize all device functions
|
||||
return false;
|
||||
}));
|
||||
|
||||
llvm::legacy::PassManager pm;
|
||||
pm.add(llvm::createVerifierPass());
|
||||
pm.run(*llvm);
|
||||
|
||||
llvm::PassManagerBuilder builder;
|
||||
builder.OptLevel = 3;
|
||||
builder.SizeLevel = 0;
|
||||
builder.populateModulePassManager(pass);
|
||||
|
||||
pass.run(*llvm);
|
||||
}
|
||||
|
||||
// TODO:
|
||||
// There should be a proper pass manager there!
|
||||
std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMContext& ctx, codegen::target* target,
|
||||
int cc, int num_warps, int num_stages, int& shared_static) {
|
||||
std::unique_ptr<llvm::Module> add_passes_to_emit_bin(
|
||||
ir::module& ir, llvm::LLVMContext& ctx, codegen::target* target,
|
||||
int num_warps, int num_stages, int& shared_static,
|
||||
const ExternLibMap& extern_lib_map) {
|
||||
// generate llvm code
|
||||
std::string name = ir.get_function_list()[0]->get_name();
|
||||
std::unique_ptr<llvm::Module> llvm(new llvm::Module(name, ctx));
|
||||
@@ -47,8 +99,10 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC
|
||||
codegen::transform::peephole peephole(target, &layouts);
|
||||
codegen::transform::coalesce coalesce(&align, &layouts, has_sm80);
|
||||
codegen::transform::prefetch prefetch_s(target);
|
||||
codegen::transform::membar barriers(&liveness, &layouts, &allocation, &prefetch_s, target);
|
||||
codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target, num_warps);
|
||||
codegen::transform::membar barriers(&liveness, &layouts, &allocation,
|
||||
&prefetch_s, target);
|
||||
codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle,
|
||||
target, num_warps);
|
||||
// run passes
|
||||
inliner.run(ir);
|
||||
dce.run(ir);
|
||||
@@ -56,7 +110,7 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC
|
||||
peephole.run(ir);
|
||||
dce.run(ir);
|
||||
pipeline.run(ir);
|
||||
dce.run(ir);
|
||||
dce.run(ir);
|
||||
disassociate.run(ir);
|
||||
dce.run(ir);
|
||||
align.run(ir);
|
||||
@@ -64,8 +118,7 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC
|
||||
layouts.run(ir);
|
||||
peephole.run(ir);
|
||||
dce.run(ir);
|
||||
if (target->is_gpu())
|
||||
cts.run(ir);
|
||||
if (target->is_gpu()) cts.run(ir);
|
||||
align.run(ir);
|
||||
axes.run(ir);
|
||||
layouts.run(ir);
|
||||
@@ -73,8 +126,7 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC
|
||||
dce.run(ir);
|
||||
align.run(ir);
|
||||
dce.run(ir);
|
||||
if (target->is_gpu())
|
||||
cts.run(ir);
|
||||
if (target->is_gpu()) cts.run(ir);
|
||||
dce.run(ir);
|
||||
align.run(ir);
|
||||
axes.run(ir);
|
||||
@@ -97,8 +149,15 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC
|
||||
// ir.print(std::cout);
|
||||
isel.visit(ir, *llvm);
|
||||
shared_static = allocation.allocated_size();
|
||||
|
||||
if (isel.get_extern_lib_map().size() > 0) {
|
||||
// If there's any extern lib calls,
|
||||
// we need to link them in.
|
||||
link_extern_libs(extern_lib_map, isel.get_extern_lib_map(), ir, ctx, llvm);
|
||||
}
|
||||
|
||||
return llvm;
|
||||
}
|
||||
|
||||
} // namespace codegen
|
||||
} // namespace triton
|
||||
} // namespace codegen
|
||||
} // namespace triton
|
||||
|
@@ -1195,7 +1195,7 @@ void generator::visit_cos_inst(ir::cos_inst* x){
|
||||
for(auto idx: idxs_.at(x)){
|
||||
vals_[x][idx] = call(cos, std::vector<llvm::Value*>{vals_[x->get_operand(0)][idx]});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Code Generation for `umulhi`
|
||||
@@ -3154,6 +3154,30 @@ void generator::visit_async_wait_inst(ir::async_wait_inst* i) {
|
||||
call(iasm);
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Code Generation for `extern_elementwise`
|
||||
*/
|
||||
void generator::visit_extern_elementwise_inst(ir::extern_elementwise_inst *i) {
|
||||
std::vector<Type *> operand_types;
|
||||
for (size_t j = 0; j < i->get_num_operands(); j++) {
|
||||
operand_types.push_back(
|
||||
cvt(i->get_operand(j)->get_type()->get_scalar_ty()));
|
||||
}
|
||||
Type *ret_type = cvt(i->get_type()->get_scalar_ty());
|
||||
FunctionType *FT =
|
||||
FunctionType::get(ret_type, std::move(operand_types), false);
|
||||
Function *F = llvm::cast<llvm::Function>(
|
||||
mod_->getOrInsertFunction(i->get_name(), FT).getCallee());
|
||||
for (auto idx : idxs_.at(i)) {
|
||||
std::vector<llvm::Value *> args;
|
||||
for (size_t j = 0; j < i->get_num_operands(); j++) {
|
||||
args.emplace_back(vals_[i->get_operand(j)][idx]);
|
||||
}
|
||||
vals_[i][idx] = call(F, std::move(args));
|
||||
}
|
||||
add_extern_lib(i->get_lib_name(), i->get_lib_path());
|
||||
}
|
||||
|
||||
//void generator::visit_make_range_dyn(ir::make_range_dyn* x) {
|
||||
// for(indices_t idx: idxs_.at(x)){
|
||||
// assert(idx.size() == 1);
|
||||
@@ -3741,6 +3765,15 @@ void generator::visit(ir::module &src, llvm::Module &dst) {
|
||||
visit_function(fn);
|
||||
}
|
||||
|
||||
void generator::add_extern_lib(const std::string &lib_name,
|
||||
const std::string &lib_path) {
|
||||
if (extern_lib_map_.count(lib_name) == 0) {
|
||||
extern_lib_map_[lib_name] = create_extern_lib(lib_name, lib_path);
|
||||
} else if (extern_lib_map_.at(lib_name)->path() != lib_path) {
|
||||
throw std::runtime_error("A library has multiple paths (1) " + lib_path +
|
||||
" (2) " + extern_lib_map_.at(lib_name)->path());
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
} // namespace codegen
|
||||
} // namespace triton
|
||||
|
Reference in New Issue
Block a user