diff --git a/include/triton/Target/LLVMIR/LLVMIRTranslation.h b/include/triton/Target/LLVMIR/LLVMIRTranslation.h index 52395b0b6..2becb6fe5 100644 --- a/include/triton/Target/LLVMIR/LLVMIRTranslation.h +++ b/include/triton/Target/LLVMIR/LLVMIRTranslation.h @@ -1,5 +1,6 @@ #ifndef TRITON_TARGET_LLVMIRTRANSLATION_H #define TRITON_TARGET_LLVMIRTRANSLATION_H +#include "llvm/ADT/StringRef.h" #include #include @@ -29,6 +30,8 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext, std::unique_ptr translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module); +bool linkExternLib(llvm::Module &module, llvm::StringRef path); + } // namespace triton } // namespace mlir diff --git a/lib/Target/LLVMIR/LLVMIRTranslation.cpp b/lib/Target/LLVMIR/LLVMIRTranslation.cpp index eaabb7c24..645aaff43 100644 --- a/lib/Target/LLVMIR/LLVMIRTranslation.cpp +++ b/lib/Target/LLVMIR/LLVMIRTranslation.cpp @@ -151,7 +151,7 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext, return nullptr; } - std::map extern_libs; + std::map externLibs; SmallVector funcs; module.walk([&](LLVM::LLVMFuncOp func) { if (func.isExternal()) @@ -166,7 +166,7 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext, func.getOperation()->getAttr("libpath").dyn_cast(); if (name) { std::string lib_name = name.str(); - extern_libs[lib_name] = path.str(); + externLibs[lib_name] = path.str(); } } } @@ -176,7 +176,7 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext, ->getAttr("triton_gpu.externs") .dyn_cast(); for (auto &attr : dict) { - extern_libs[attr.getName().strref().trim().str()] = + externLibs[attr.getName().strref().trim().str()] = attr.getValue().dyn_cast().strref().trim().str(); } } @@ -188,20 +188,9 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext, } llvm::SMDiagnostic err; - for (auto &lib : extern_libs) { - auto ext_mod = llvm::parseIRFile(lib.second, err, *llvmContext); - if (!ext_mod) { - llvm::errs() << "Failed to load extern lib " << lib.first; + for 
(auto &lib : externLibs) { + if (linkExternLib(*llvmir, lib.second)) return nullptr; - } - ext_mod->setTargetTriple(llvmir->getTargetTriple()); - ext_mod->setDataLayout(llvmir->getDataLayout()); - - if (llvm::Linker::linkModules(*llvmir, std::move(ext_mod), - llvm::Linker::Flags::LinkOnlyNeeded)) { - llvm::errs() << "Failed to link extern lib " << lib.first; - return nullptr; - } } return llvmir; @@ -227,5 +216,39 @@ void addExternalLibs(mlir::ModuleOp &module, return; }
+/// Parses the LLVM IR file at `path` and links it into `module`.
+///
+/// The external module is parsed in `module`'s LLVMContext, and its target
+/// triple and data layout are overwritten with `module`'s before linking.
+/// Linking is performed with the LinkOnlyNeeded flag.
+///
+/// \param module destination module that the library is linked into.
+/// \param path   location of the external library file.
+/// \returns true on failure (parse or link error), false on success.
+bool linkExternLib(llvm::Module &module, llvm::StringRef path) {
+  llvm::SMDiagnostic err;
+  auto &ctx = module.getContext();
+
+  auto extMod = llvm::parseIRFile(path, err, ctx);
+  if (!extMod) {
+    llvm::errs() << "Failed to load " << path << "\n";
+    // Surface the parser's own diagnostic instead of discarding it.
+    err.print("triton", llvm::errs());
+    return true;
+  }
+
+  // Make the library agree with the destination module before linking.
+  extMod->setTargetTriple(module.getTargetTriple());
+  extMod->setDataLayout(module.getDataLayout());
+
+  if (llvm::Linker::linkModules(module, std::move(extMod),
+                                llvm::Linker::Flags::LinkOnlyNeeded)) {
+    llvm::errs() << "Failed to link " << path << "\n";
+    return true;
+  }
+
+  return false;
+}
+
 } // namespace triton } // namespace mlir diff --git a/lib/Target/PTX/PTXTranslation.cpp b/lib/Target/PTX/PTXTranslation.cpp index a8266322e..d55ce4b44 100644 --- a/lib/Target/PTX/PTXTranslation.cpp +++ b/lib/Target/PTX/PTXTranslation.cpp @@ -29,6 +29,7 @@ #include "llvm/Target/TargetOptions.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/Cloning.h" +#include #include namespace triton { @@ -61,6 +62,43 @@ static bool find_and_replace(std::string &str, const std::string &begin, } static std::string llir_to_ptx(llvm::Module *module, int capability, int ptx) { + bool hasExternal = false; + for (auto &func : *module) { + if (func.hasExternalLinkage()) { + hasExternal = true; + break; + } + } + + if (hasExternal) { + namespace fs = std::filesystem; + // [triton root dir]/python/triton/language/libdevice.10.bc + static const fs::path libdevice = fs::path(__FILE__) + .parent_path() + .parent_path() + .parent_path() + .parent_path() / +
"python" / "triton" / "language" / + "libdevice.10.bc"; + if (mlir::triton::linkExternLib(*module, libdevice.string())) + llvm::errs() << "link failed for: " << libdevice.string(); + } + + // please check https://llvm.org/docs/NVPTXUsage.html#reflection-parameters + // this will enable fast math path in libdevice + // for example, when enable nvvm-reflect-ftz, sqrt.approx.f32 will change to + // sqrt.approx.ftz.f32 + { + auto &ctx = module->getContext(); + llvm::Type *I32 = llvm::Type::getInt32Ty(ctx); + llvm::Metadata *mdFour = + llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(I32, 4)); + llvm::Metadata *mdName = llvm::MDString::get(ctx, "nvvm-reflect-ftz"); + llvm::Metadata *mdOne = + llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(I32, 1)); + llvm::MDNode *reflect = llvm::MDNode::get(ctx, {mdFour, mdName, mdOne}); + module->addModuleFlag(reflect); + } // LLVM version in use may not officially support target hardware int max_nvvm_cc = 75; // int max_nvvm_ptx = 74; diff --git a/python/tests/test_core.py b/python/tests/test_core.py index 7609c9419..eb9ffecef 100644 --- a/python/tests/test_core.py +++ b/python/tests/test_core.py @@ -144,7 +144,7 @@ def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda'): # triton result x_tri = to_triton(x, device=device, dst_type=dtype_x) z_tri = to_triton(np.empty_like(z_ref), device=device, dst_type=dtype_x) - kernel[(1, )](z_tri, x_tri, SIZE=SIZE, num_warps=4, extern_libs={"libdevice": "/usr/local/cuda/nvvm/libdevice/libdevice.10.bc"}) + kernel[(1, )](z_tri, x_tri, SIZE=SIZE, num_warps=4) # compare np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)