[FRONTEND] Added support for element-wise function defined in external LLVM bitcode (e.g., libdevice) (#562)

This commit is contained in:
Keren Zhou
2022-07-13 15:52:21 -07:00
committed by GitHub
parent 971f5782b4
commit 4912916c11
24 changed files with 2634 additions and 64 deletions

View File

@@ -1195,7 +1195,7 @@ void generator::visit_cos_inst(ir::cos_inst* x){
for(auto idx: idxs_.at(x)){
vals_[x][idx] = call(cos, std::vector<llvm::Value*>{vals_[x->get_operand(0)][idx]});
}
}
}
/**
* \brief Code Generation for `umulhi`
@@ -3154,6 +3154,30 @@ void generator::visit_async_wait_inst(ir::async_wait_inst* i) {
call(iasm);
}
/**
* \brief Code Generation for `extern_elementwise`
*/
void generator::visit_extern_elementwise_inst(ir::extern_elementwise_inst *i) {
std::vector<Type *> operand_types;
for (size_t j = 0; j < i->get_num_operands(); j++) {
operand_types.push_back(
cvt(i->get_operand(j)->get_type()->get_scalar_ty()));
}
Type *ret_type = cvt(i->get_type()->get_scalar_ty());
FunctionType *FT =
FunctionType::get(ret_type, std::move(operand_types), false);
Function *F = llvm::cast<llvm::Function>(
mod_->getOrInsertFunction(i->get_name(), FT).getCallee());
for (auto idx : idxs_.at(i)) {
std::vector<llvm::Value *> args;
for (size_t j = 0; j < i->get_num_operands(); j++) {
args.emplace_back(vals_[i->get_operand(j)][idx]);
}
vals_[i][idx] = call(F, std::move(args));
}
add_extern_lib(i->get_lib_name(), i->get_lib_path());
}
//void generator::visit_make_range_dyn(ir::make_range_dyn* x) {
// for(indices_t idx: idxs_.at(x)){
// assert(idx.size() == 1);
@@ -3741,6 +3765,15 @@ void generator::visit(ir::module &src, llvm::Module &dst) {
visit_function(fn);
}
void generator::add_extern_lib(const std::string &lib_name,
const std::string &lib_path) {
if (extern_lib_map_.count(lib_name) == 0) {
extern_lib_map_[lib_name] = create_extern_lib(lib_name, lib_path);
} else if (extern_lib_map_.at(lib_name)->path() != lib_path) {
throw std::runtime_error("A library has multiple paths (1) " + lib_path +
" (2) " + extern_lib_map_.at(lib_name)->path());
}
}
}
}
} // namespace codegen
} // namespace triton