[Triton-MLIR] Enable libdevice for ptx backend when has external functions. (#848)
At the phase from ptx to cubin, check whether llvm::Module has external functions. if has, link with libdevice at: https://github.com/openai/triton/blob/triton-mlir/python/triton/language/libdevice.10.bc
This commit is contained in:
@@ -151,7 +151,7 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::map<std::string, std::string> extern_libs;
|
||||
std::map<std::string, std::string> externLibs;
|
||||
SmallVector<LLVM::LLVMFuncOp> funcs;
|
||||
module.walk([&](LLVM::LLVMFuncOp func) {
|
||||
if (func.isExternal())
|
||||
@@ -166,7 +166,7 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
|
||||
func.getOperation()->getAttr("libpath").dyn_cast<StringAttr>();
|
||||
if (name) {
|
||||
std::string lib_name = name.str();
|
||||
extern_libs[lib_name] = path.str();
|
||||
externLibs[lib_name] = path.str();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -176,7 +176,7 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
|
||||
->getAttr("triton_gpu.externs")
|
||||
.dyn_cast<DictionaryAttr>();
|
||||
for (auto &attr : dict) {
|
||||
extern_libs[attr.getName().strref().trim().str()] =
|
||||
externLibs[attr.getName().strref().trim().str()] =
|
||||
attr.getValue().dyn_cast<StringAttr>().strref().trim().str();
|
||||
}
|
||||
}
|
||||
@@ -188,20 +188,9 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
|
||||
}
|
||||
|
||||
llvm::SMDiagnostic err;
|
||||
for (auto &lib : extern_libs) {
|
||||
auto ext_mod = llvm::parseIRFile(lib.second, err, *llvmContext);
|
||||
if (!ext_mod) {
|
||||
llvm::errs() << "Failed to load extern lib " << lib.first;
|
||||
for (auto &lib : externLibs) {
|
||||
if (linkExternLib(*llvmir, lib.second))
|
||||
return nullptr;
|
||||
}
|
||||
ext_mod->setTargetTriple(llvmir->getTargetTriple());
|
||||
ext_mod->setDataLayout(llvmir->getDataLayout());
|
||||
|
||||
if (llvm::Linker::linkModules(*llvmir, std::move(ext_mod),
|
||||
llvm::Linker::Flags::LinkOnlyNeeded)) {
|
||||
llvm::errs() << "Failed to link extern lib " << lib.first;
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
return llvmir;
|
||||
@@ -227,5 +216,27 @@ void addExternalLibs(mlir::ModuleOp &module,
|
||||
return;
|
||||
}
|
||||
|
||||
bool linkExternLib(llvm::Module &module, llvm::StringRef path) {
|
||||
llvm::SMDiagnostic err;
|
||||
auto &ctx = module.getContext();
|
||||
|
||||
auto extMod = llvm::parseIRFile(path, err, ctx);
|
||||
if (!extMod) {
|
||||
llvm::errs() << "Failed to load " << path;
|
||||
return true;
|
||||
}
|
||||
|
||||
extMod->setTargetTriple(module.getTargetTriple());
|
||||
extMod->setDataLayout(module.getDataLayout());
|
||||
|
||||
if (llvm::Linker::linkModules(module, std::move(extMod),
|
||||
llvm::Linker::Flags::LinkOnlyNeeded)) {
|
||||
llvm::errs() << "Failed to link " << path;
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
Reference in New Issue
Block a user