diff --git a/lib/Target/LLVMIR/LLVMIRTranslation.cpp b/lib/Target/LLVMIR/LLVMIRTranslation.cpp index 54ee54132..d2994f923 100644 --- a/lib/Target/LLVMIR/LLVMIRTranslation.cpp +++ b/lib/Target/LLVMIR/LLVMIRTranslation.cpp @@ -99,9 +99,6 @@ translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module) { return nullptr; } - // Initialize LLVM targets. - mlir::ExecutionEngine::setupTargetTriple(llvmModule.get()); - auto optPipeline = mlir::makeOptimizingTransformer( /*optLevel=*/3, /*sizeLevel=*/0, /*targetMachine=*/nullptr); diff --git a/lib/Target/PTX/PTXTranslation.cpp b/lib/Target/PTX/PTXTranslation.cpp index 40846fa86..fae3b5c33 100644 --- a/lib/Target/PTX/PTXTranslation.cpp +++ b/lib/Target/PTX/PTXTranslation.cpp @@ -12,29 +12,28 @@ namespace triton { -static void init_llvm() { +static void initLLVM() { LLVMInitializeNVPTXTargetInfo(); LLVMInitializeNVPTXTarget(); LLVMInitializeNVPTXTargetMC(); LLVMInitializeNVPTXAsmPrinter(); } -static bool find_and_replace(std::string &str, const std::string &begin, - const std::string &end, - const std::string &target) { - size_t start_replace = str.find(begin); - if (start_replace == std::string::npos) +static bool findAndReplace(std::string &str, const std::string &begin, + const std::string &end, const std::string &target) { + size_t startReplace = str.find(begin); + if (startReplace == std::string::npos) return false; - size_t end_replace = str.find(end, start_replace); - if (end_replace == std::string::npos) + size_t endReplace = str.find(end, startReplace); + if (endReplace == std::string::npos) return false; - str.replace(start_replace, end_replace + 1 - start_replace, target); + str.replace(startReplace, endReplace + 1 - startReplace, target); return true; } -static std::string llir_to_ptx(llvm::Module *module, int capability, int ptx) { +static void linkExternal(llvm::Module &module) { bool hasExternal = false; - for (auto &func : *module) { + for (auto &func : module) { if (func.hasExternalLinkage()) { hasExternal = true; break; @@ -51,16 +50,14 @@ static std::string llir_to_ptx(llvm::Module *module, int capability, int ptx) { .parent_path() / "python" / "triton" / "language" / "libdevice.10.bc"; - if (mlir::triton::linkExternLib(*module, libdevice.string())) + if (mlir::triton::linkExternLib(module, libdevice.string())) llvm::errs() << "link failed for: " << libdevice.string(); - } - // please check https://llvm.org/docs/NVPTXUsage.html#reflection-parameters - // this will enable fast math path in libdevice - // for example, when enable nvvm-reflect-ftz, sqrt.approx.f32 will change to - // sqrt.approx.ftz.f32 - { - auto &ctx = module->getContext(); + // please check https://llvm.org/docs/NVPTXUsage.html#reflection-parameters + // this will enable fast math path in libdevice + // for example, when enable nvvm-reflect-ftz, sqrt.approx.f32 will change to + // sqrt.approx.ftz.f32 + auto &ctx = module.getContext(); llvm::Type *I32 = llvm::Type::getInt32Ty(ctx); llvm::Metadata *mdFour = llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(I32, 4)); @@ -68,81 +65,80 @@ static std::string llir_to_ptx(llvm::Module *module, int capability, int ptx) { llvm::Metadata *mdOne = llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(I32, 1)); llvm::MDNode *reflect = llvm::MDNode::get(ctx, {mdFour, mdName, mdOne}); - module->addModuleFlag(reflect); + module.addModuleFlag(reflect); } +} + +std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version) { + linkExternal(module); + // LLVM version in use may not officially support target hardware - int max_nvvm_cc = 75; - // int max_nvvm_ptx = 74; + int maxNNVMCC = 75; // options auto options = llvm::cl::getRegisteredOptions(); - auto *short_ptr = + auto *shortPtr = static_cast *>(options["nvptx-short-ptr"]); - assert(short_ptr); - short_ptr->setValue(true); + assert(shortPtr); + shortPtr->setValue(true); // compute capability - std::string sm = "sm_" + std::to_string(capability); + std::string sm = "sm_" + std::to_string(cc); // max PTX version - int ptx_major = ptx / 10; - int ptx_minor = ptx % 10; + int ptxMajor = version / 10; + int ptxMinor = version % 10; // create llvm::SmallVector buffer; std::string triple = "nvptx64-nvidia-cuda"; - std::string proc = "sm_" + std::to_string(std::min(capability, max_nvvm_cc)); + std::string proc = "sm_" + std::to_string(std::min(cc, maxNNVMCC)); std::string layout = ""; std::string features = ""; // std::string features = "+ptx" + std::to_string(std::min(ptx, // max_nvvm_ptx)); - init_llvm(); + initLLVM(); // verify and store llvm llvm::legacy::PassManager pm; pm.add(llvm::createVerifierPass()); - pm.run(*module); + pm.run(module); // module->print(llvm::outs(), nullptr); // create machine - module->setTargetTriple(triple); + module.setTargetTriple(triple); std::string error; auto target = - llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error); + llvm::TargetRegistry::lookupTarget(module.getTargetTriple(), error); llvm::TargetOptions opt; opt.AllowFPOpFusion = llvm::FPOpFusion::Fast; opt.UnsafeFPMath = false; opt.NoInfsFPMath = false; opt.NoNaNsFPMath = true; llvm::TargetMachine *machine = target->createTargetMachine( - module->getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_, + module.getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_, llvm::None, llvm::CodeGenOpt::Aggressive); // set data layout if (layout.empty()) - module->setDataLayout(machine->createDataLayout()); + module.setDataLayout(machine->createDataLayout()); else - module->setDataLayout(layout); + module.setDataLayout(layout); // emit machine code - for (llvm::Function &f : module->functions()) + for (llvm::Function &f : module.functions()) f.addFnAttr(llvm::Attribute::AlwaysInline); llvm::legacy::PassManager pass; llvm::raw_svector_ostream stream(buffer); // emit machine->addPassesToEmitFile(pass, stream, nullptr, llvm::CodeGenFileType::CGFT_AssemblyFile); - pass.run(*module); + pass.run(module); // post-process std::string result(buffer.begin(), buffer.end()); - find_and_replace(result, ".version", "\n", - ".version " + std::to_string(ptx_major) + "." + - std::to_string(ptx_minor) + "\n"); - find_and_replace(result, ".target", "\n", ".target " + sm + "\n"); - while (find_and_replace(result, "\t// begin inline asm", "\n", "")) + findAndReplace(result, ".version", "\n", + ".version " + std::to_string(ptxMajor) + "." + + std::to_string(ptxMinor) + "\n"); + findAndReplace(result, ".target", "\n", ".target " + sm + "\n"); + while (findAndReplace(result, "\t// begin inline asm", "\n", "")) ; - while (find_and_replace(result, "\t// end inline asm", "\n", "")) + while (findAndReplace(result, "\t// end inline asm", "\n", "")) ; return result; } -std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version) { - auto ptxCode = llir_to_ptx(&module, cc, version); - return ptxCode; -} - } // namespace triton