diff --git a/include/triton/Target/LLVMIR/LLVMIRTranslation.h b/include/triton/Target/LLVMIR/LLVMIRTranslation.h
index b83ff9f57..3e9f73efe 100644
--- a/include/triton/Target/LLVMIR/LLVMIRTranslation.h
+++ b/include/triton/Target/LLVMIR/LLVMIRTranslation.h
@@ -31,8 +31,6 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
 
 std::unique_ptr<llvm::Module>
 translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module);
 
-bool linkExternLib(llvm::Module &module, llvm::StringRef path);
-
 } // namespace triton
 } // namespace mlir
diff --git a/lib/Target/LLVMIR/LLVMIRTranslation.cpp b/lib/Target/LLVMIR/LLVMIRTranslation.cpp
index f193bcc1d..ab9659da6 100644
--- a/lib/Target/LLVMIR/LLVMIRTranslation.cpp
+++ b/lib/Target/LLVMIR/LLVMIRTranslation.cpp
@@ -18,6 +18,7 @@
 #include "llvm/IRReader/IRReader.h"
 #include "llvm/Linker/Linker.h"
 #include "llvm/Support/SourceMgr.h"
+#include <filesystem>
 
 namespace mlir {
 namespace triton {
@@ -26,19 +27,18 @@ namespace triton {
 // information from mlir module.
 struct NVVMMetadata {
   int maxntidx{-1};
-  bool is_kernel{};
+  bool isKernel{};
   // Free to extend with other information.
 };
 
 // Add the nvvm related metadata to LLVM IR.
-void amendLLVMFunc(llvm::Function *func, const NVVMMetadata &metadata) {
+static void amendLLVMFunc(llvm::Function *func, const NVVMMetadata &metadata) {
   auto *module = func->getParent();
   auto &ctx = func->getContext();
 
   if (metadata.maxntidx > 0) {
-    auto i32_ty = llvm::IntegerType::get(ctx, 32);
-    auto warps =
-        llvm::ConstantInt::get(i32_ty, llvm::APInt(32, metadata.maxntidx));
+    auto warps = llvm::ConstantInt::get(llvm::IntegerType::get(ctx, 32),
+                                        llvm::APInt(32, metadata.maxntidx));
 
     llvm::Metadata *md_args[] = {llvm::ValueAsMetadata::get(func),
                                  llvm::MDString::get(ctx, "maxntidx"),
@@ -48,18 +48,19 @@ void amendLLVMFunc(llvm::Function *func, const NVVMMetadata &metadata) {
         ->addOperand(llvm::MDNode::get(ctx, md_args));
   }
 
-  if (metadata.is_kernel) {
-    llvm::Metadata *md_args[] = {
+  if (metadata.isKernel) {
+    llvm::Metadata *mdArgs[] = {
         llvm::ValueAsMetadata::get(func), llvm::MDString::get(ctx, "kernel"),
         llvm::ValueAsMetadata::get(
             llvm::ConstantInt::get(llvm::Type::getInt32Ty(ctx), 1))};
     module->getOrInsertNamedMetadata("nvvm.annotations")
-        ->addOperand(llvm::MDNode::get(ctx, md_args));
+        ->addOperand(llvm::MDNode::get(ctx, mdArgs));
   }
 }
 
-void extractNVVMMetadata(mlir::ModuleOp module,
-                         llvm::DenseMap<llvm::StringRef, NVVMMetadata> *dic) {
+static void
+extractNVVMMetadata(mlir::ModuleOp module,
+                    llvm::DenseMap<llvm::StringRef, NVVMMetadata> *dic) {
   for (auto op : module.getOps<LLVM::LLVMFuncOp>()) {
     NVVMMetadata meta;
 
@@ -74,7 +75,7 @@ void extractNVVMMetadata(mlir::ModuleOp module,
 
     // kernel
     if (op->hasAttr("nvvm.kernel")) {
-      meta.is_kernel = true;
+      meta.isKernel = true;
       hasMetadata = true;
    }
 
@@ -83,13 +84,109 @@ void extractNVVMMetadata(mlir::ModuleOp module,
   }
 }
 
+static std::map<std::string, std::string> getExternLibs(mlir::ModuleOp module) {
+  std::map<std::string, std::string> externLibs;
+  SmallVector<LLVM::LLVMFuncOp> funcs;
+  module.walk([&](LLVM::LLVMFuncOp func) {
+    if (func.isExternal())
+      funcs.push_back(func);
+  });
+
+  for (auto &func : funcs) {
+    if (func.getOperation()->hasAttr("libname")) {
+      auto name =
+          func.getOperation()->getAttr("libname").dyn_cast<StringAttr>();
+      auto path =
+          func.getOperation()->getAttr("libpath").dyn_cast<StringAttr>();
+      if (name) {
+        std::string libName = name.str();
+        externLibs[libName] = path.str();
+      }
+    }
+  }
+
+  if (module.getOperation()->hasAttr("triton_gpu.externs")) {
+    auto dict = module.getOperation()
+                    ->getAttr("triton_gpu.externs")
+                    .dyn_cast<DictionaryAttr>();
+    for (auto &attr : dict) {
+      externLibs[attr.getName().strref().trim().str()] =
+          attr.getValue().dyn_cast<StringAttr>().strref().trim().str();
+    }
+  }
+
+  if (!funcs.empty()) {
+    // When using the Math Dialect, it is possible that some ops (e.g., log) are
+    // lowered to a function call. In this case, we need to link libdevice
+    // using its default path:
+    // [triton root dir]/python/triton/language/libdevice.10.bc
+    // TODO(Keren): handle external linkage other than libdevice?
+    namespace fs = std::filesystem;
+    static const std::string libdevice = "libdevice";
+    static const std::filesystem::path path = std::filesystem::path(__FILE__)
+                                                  .parent_path()
+                                                  .parent_path()
+                                                  .parent_path()
+                                                  .parent_path() /
+                                              "python" / "triton" / "language" /
+                                              "libdevice.10.bc";
+    externLibs.try_emplace(libdevice, path.string());
+  }
+
+  return externLibs;
+}
+
+static void linkLibdevice(llvm::Module &module) {
+  // Please check https://llvm.org/docs/NVPTXUsage.html#reflection-parameters
+  // This enables the fast math path in libdevice; for example, when
+  // nvvm-reflect-ftz is enabled, sqrt.approx.f32 changes to
+  // sqrt.approx.ftz.f32.
+  auto &ctx = module.getContext();
+  llvm::Type *i32 = llvm::Type::getInt32Ty(ctx);
+  llvm::Metadata *mdFour =
+      llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(i32, 4));
+  llvm::Metadata *mdName = llvm::MDString::get(ctx, "nvvm-reflect-ftz");
+  llvm::Metadata *mdOne =
+      llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(i32, 1));
+  llvm::MDNode *reflect = llvm::MDNode::get(ctx, {mdFour, mdName, mdOne});
+  module.addModuleFlag(reflect);
+}
+
+static bool linkExternLib(llvm::Module &module, llvm::StringRef name,
+                          llvm::StringRef path) {
+  llvm::SMDiagnostic err;
+  auto &ctx = module.getContext();
+
+  auto extMod = llvm::parseIRFile(path, err, ctx);
+  if (!extMod) {
+    llvm::errs() << "Failed to load " << path;
+    return true;
+  }
+
+  extMod->setTargetTriple(module.getTargetTriple());
+  extMod->setDataLayout(module.getDataLayout());
+
+  if (llvm::Linker::linkModules(module, std::move(extMod),
+                                llvm::Linker::Flags::LinkOnlyNeeded)) {
+    llvm::errs() << "Failed to link " << path;
+    return true;
+  }
+
+  if (name == "libdevice") {
+    linkLibdevice(module);
+  } else {
+    assert(false && "unknown extern lib: ");
+  }
+
+  return false;
+}
+
 std::unique_ptr<llvm::Module>
 translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module) {
-  auto context = module->getContext();
   DialectRegistry registry;
   mlir::registerLLVMDialectTranslation(registry);
   mlir::registerNVVMDialectTranslation(registry);
-  context->appendDialectRegistry(registry);
+  module->getContext()->appendDialectRegistry(registry);
 
   llvm::DenseMap<llvm::StringRef, NVVMMetadata> nvvmMetadata;
   extractNVVMMetadata(module, &nvvmMetadata);
@@ -100,6 +197,20 @@ translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module) {
     return nullptr;
   }
 
+  // Link external libraries before performing optimizations.
+  // Note from the libdevice users guide:
+  // https://docs.nvidia.com/cuda/libdevice-users-guide/basic-usage.html
+  // The standard process for linking with libdevice is to first link it with
+  // the target module, then run the standard LLVM optimization and code
+  // generation passes. This allows the optimizers to inline and perform
+  // analyses on the used library functions, and eliminate any unused
+  // functions as dead code.
+  auto externLibs = getExternLibs(module);
+  for (auto &lib : externLibs) {
+    if (linkExternLib(*llvmModule, lib.first, lib.second))
+      return nullptr;
+  }
+
   auto optPipeline = mlir::makeOptimizingTransformer(
       /*optLevel=*/3, /*sizeLevel=*/0,
       /*targetMachine=*/nullptr);
@@ -147,49 +258,12 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
     return nullptr;
   }
 
-  std::map<std::string, std::string> externLibs;
-  SmallVector<LLVM::LLVMFuncOp> funcs;
-  module.walk([&](LLVM::LLVMFuncOp func) {
-    if (func.isExternal())
-      funcs.push_back(func);
-  });
-
-  for (auto &func : funcs) {
-    if (func.getOperation()->hasAttr("libname")) {
-      auto name =
-          func.getOperation()->getAttr("libname").dyn_cast<StringAttr>();
-      auto path =
-          func.getOperation()->getAttr("libpath").dyn_cast<StringAttr>();
-      if (name) {
-        std::string lib_name = name.str();
-        externLibs[lib_name] = path.str();
-      }
-    }
-  }
-
-  if (module.getOperation()->hasAttr("triton_gpu.externs")) {
-    auto dict = module.getOperation()
-                    ->getAttr("triton_gpu.externs")
-                    .dyn_cast<DictionaryAttr>();
-    for (auto &attr : dict) {
-      externLibs[attr.getName().strref().trim().str()] =
-          attr.getValue().dyn_cast<StringAttr>().strref().trim().str();
-    }
-  }
-
-  auto llvmir = translateLLVMToLLVMIR(llvmContext, module);
-  if (!llvmir) {
+  auto llvmIR = translateLLVMToLLVMIR(llvmContext, module);
+  if (!llvmIR) {
     llvm::errs() << "Translate to LLVM IR failed";
     return nullptr;
   }
-
-  llvm::SMDiagnostic err;
-  for (auto &lib : externLibs) {
-    if (linkExternLib(*llvmir, lib.second))
-      return nullptr;
-  }
-
-  return llvmir;
+  return llvmIR;
 }
 
 void addExternalLibs(mlir::ModuleOp &module,
@@ -211,27 +285,5 @@ void addExternalLibs(mlir::ModuleOp &module,
   module.getOperation()->setAttr("triton_gpu.externs", dict);
 }
 
-bool linkExternLib(llvm::Module &module, llvm::StringRef path) {
-  llvm::SMDiagnostic err;
-  auto &ctx = module.getContext();
-
-  auto extMod = llvm::parseIRFile(path, err, ctx);
-  if (!extMod) {
-    llvm::errs() << "Failed to load " << path;
-    return true;
-  }
-
-  extMod->setTargetTriple(module.getTargetTriple());
-  extMod->setDataLayout(module.getDataLayout());
-
-  if (llvm::Linker::linkModules(module, std::move(extMod),
-                                llvm::Linker::Flags::LinkOnlyNeeded)) {
-    llvm::errs() << "Failed to link " << path;
-    return true;
-  }
-
-  return false;
-}
-
 } // namespace triton
 } // namespace mlir
diff --git a/lib/Target/PTX/PTXTranslation.cpp b/lib/Target/PTX/PTXTranslation.cpp
index fae3b5c33..32c99df95 100644
--- a/lib/Target/PTX/PTXTranslation.cpp
+++ b/lib/Target/PTX/PTXTranslation.cpp
@@ -8,7 +8,6 @@
 #include "llvm/MC/TargetRegistry.h"
 #include "llvm/Support/TargetSelect.h"
 #include "llvm/Target/TargetMachine.h"
-#include <filesystem>
 
 namespace triton {
 
@@ -31,47 +30,7 @@ static bool findAndReplace(std::string &str, const std::string &begin,
   return true;
 }
 
-static void linkExternal(llvm::Module &module) {
-  bool hasExternal = false;
-  for (auto &func : module) {
-    if (func.hasExternalLinkage()) {
-      hasExternal = true;
-      break;
-    }
-  }
-
-  if (hasExternal) {
-    namespace fs = std::filesystem;
-    // [triton root dir]/python/triton/language/libdevice.10.bc
-    static const fs::path libdevice = fs::path(__FILE__)
-                                          .parent_path()
-                                          .parent_path()
-                                          .parent_path()
-                                          .parent_path() /
-                                      "python" / "triton" / "language" /
-                                      "libdevice.10.bc";
-    if (mlir::triton::linkExternLib(module, libdevice.string()))
-      llvm::errs() << "link failed for: " << libdevice.string();
-
-    // please check https://llvm.org/docs/NVPTXUsage.html#reflection-parameters
-    // this will enable fast math path in libdevice
-    // for example, when enable nvvm-reflect-ftz, sqrt.approx.f32 will change to
-    // sqrt.approx.ftz.f32
-    auto &ctx = module.getContext();
-    llvm::Type *I32 = llvm::Type::getInt32Ty(ctx);
-    llvm::Metadata *mdFour =
-        llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(I32, 4));
-    llvm::Metadata *mdName = llvm::MDString::get(ctx, "nvvm-reflect-ftz");
-    llvm::Metadata *mdOne =
-        llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(I32, 1));
-    llvm::MDNode *reflect = llvm::MDNode::get(ctx, {mdFour, mdName, mdOne});
-    module.addModuleFlag(reflect);
-  }
-}
-
 std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version) {
-  linkExternal(module);
-
   // LLVM version in use may not officially support target hardware
   int maxNNVMCC = 75;
   // options
diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py
index d65931b48..548af578b 100644
--- a/python/triton/language/semantic.py
+++ b/python/triton/language/semantic.py
@@ -1116,11 +1116,13 @@ def xor_sum(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
 
 def umulhi(x: tl.tensor, y: tl.tensor, builder: ir.builder) -> tl.tensor:
     x, y = binary_op_type_checking_impl(x, y, builder)
+    # FIXME(Keren): not portable, should be fixed
     from . import libdevice
     return libdevice.mulhi(x, y, _builder=builder)
 
 
 def floor(x: tl.tensor, builder: ir.builder) -> tl.tensor:
+    # FIXME(Keren): not portable, should be fixed
     from . import libdevice
     return libdevice.floor(x, _builder=builder)
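
Usage note (illustrative, not part of the patch above): one way to exercise the new linking path from Python is a kernel that uses `tl.umulhi`. Per the `semantic.py` change, `umulhi` is lowered through `libdevice`, so `translateLLVMToLLVMIR` now has to discover `libdevice.10.bc` via `getExternLibs` and link it with `linkExternLib` before the optimization pipeline runs. The kernel shape, grid, and tensor sizes below are arbitrary choices for the sketch.

```python
# Minimal sketch: tl.umulhi lowers to an extern libdevice call, which the new
# getExternLibs/linkExternLib path links into the LLVM module before
# optimization. Sizes and launch grid are illustrative only.
import torch
import triton
import triton.language as tl


@triton.jit
def umulhi_kernel(X, Y, Z, N, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N
    x = tl.load(X + offs, mask=mask)
    y = tl.load(Y + offs, mask=mask)
    # Lowered via semantic.umulhi -> libdevice.mulhi; needs libdevice.10.bc.
    tl.store(Z + offs, tl.umulhi(x, y), mask=mask)


N = 1024
x = torch.randint(0, 2**31 - 1, (N,), dtype=torch.int32, device="cuda")
y = torch.randint(0, 2**31 - 1, (N,), dtype=torch.int32, device="cuda")
z = torch.empty_like(x)
umulhi_kernel[(triton.cdiv(N, 256),)](x, y, z, N, BLOCK=256)
```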