[FRONTEND] Added support for element-wise function defined in external LLVM bitcode (e.g., libdevice) (#562)
This commit is contained in:
@@ -1195,7 +1195,7 @@ void generator::visit_cos_inst(ir::cos_inst* x){
|
||||
for(auto idx: idxs_.at(x)){
|
||||
vals_[x][idx] = call(cos, std::vector<llvm::Value*>{vals_[x->get_operand(0)][idx]});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Code Generation for `umulhi`
|
||||
@@ -3154,6 +3154,30 @@ void generator::visit_async_wait_inst(ir::async_wait_inst* i) {
|
||||
call(iasm);
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Code Generation for `extern_elementwise`
|
||||
*/
|
||||
void generator::visit_extern_elementwise_inst(ir::extern_elementwise_inst *i) {
|
||||
std::vector<Type *> operand_types;
|
||||
for (size_t j = 0; j < i->get_num_operands(); j++) {
|
||||
operand_types.push_back(
|
||||
cvt(i->get_operand(j)->get_type()->get_scalar_ty()));
|
||||
}
|
||||
Type *ret_type = cvt(i->get_type()->get_scalar_ty());
|
||||
FunctionType *FT =
|
||||
FunctionType::get(ret_type, std::move(operand_types), false);
|
||||
Function *F = llvm::cast<llvm::Function>(
|
||||
mod_->getOrInsertFunction(i->get_name(), FT).getCallee());
|
||||
for (auto idx : idxs_.at(i)) {
|
||||
std::vector<llvm::Value *> args;
|
||||
for (size_t j = 0; j < i->get_num_operands(); j++) {
|
||||
args.emplace_back(vals_[i->get_operand(j)][idx]);
|
||||
}
|
||||
vals_[i][idx] = call(F, std::move(args));
|
||||
}
|
||||
add_extern_lib(i->get_lib_name(), i->get_lib_path());
|
||||
}
|
||||
|
||||
//void generator::visit_make_range_dyn(ir::make_range_dyn* x) {
|
||||
// for(indices_t idx: idxs_.at(x)){
|
||||
// assert(idx.size() == 1);
|
||||
@@ -3741,6 +3765,15 @@ void generator::visit(ir::module &src, llvm::Module &dst) {
|
||||
visit_function(fn);
|
||||
}
|
||||
|
||||
void generator::add_extern_lib(const std::string &lib_name,
|
||||
const std::string &lib_path) {
|
||||
if (extern_lib_map_.count(lib_name) == 0) {
|
||||
extern_lib_map_[lib_name] = create_extern_lib(lib_name, lib_path);
|
||||
} else if (extern_lib_map_.at(lib_name)->path() != lib_path) {
|
||||
throw std::runtime_error("A library has multiple paths (1) " + lib_path +
|
||||
" (2) " + extern_lib_map_.at(lib_name)->path());
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
} // namespace codegen
|
||||
} // namespace triton
|
||||
|
Reference in New Issue
Block a user