[Triton-MLIR] Enable libdevice for ptx backend when has external functions. (#848)

At the phase from ptx to cubin, check whether llvm::Module has external
functions. if has, link with libdevice at:
https://github.com/openai/triton/blob/triton-mlir/python/triton/language/libdevice.10.bc
This commit is contained in:
ben-zhang-609
2022-11-07 16:01:50 +08:00
committed by GitHub
parent fdd59900f7
commit 84ad215268
4 changed files with 69 additions and 17 deletions

View File

@@ -29,6 +29,7 @@
#include "llvm/Target/TargetOptions.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include <filesystem>
#include <regex>
namespace triton {
@@ -61,6 +62,43 @@ static bool find_and_replace(std::string &str, const std::string &begin,
}
static std::string llir_to_ptx(llvm::Module *module, int capability, int ptx) {
bool hasExternal = false;
for (auto &func : *module) {
if (func.hasExternalLinkage()) {
hasExternal = true;
break;
}
}
if (hasExternal) {
namespace fs = std::filesystem;
// [triton root dir]/python/triton/language/libdevice.10.bc
static const fs::path libdevice = fs::path(__FILE__)
.parent_path()
.parent_path()
.parent_path()
.parent_path() /
"python" / "triton" / "language" /
"libdevice.10.bc";
if (mlir::triton::linkExternLib(*module, libdevice.string()))
llvm::errs() << "link failed for: " << libdevice.string();
}
// please check https://llvm.org/docs/NVPTXUsage.html#reflection-parameters
// this will enable fast math path in libdevice
// for example, when enable nvvm-reflect-ftz, sqrt.approx.f32 will change to
// sqrt.approx.ftz.f32
{
auto &ctx = module->getContext();
llvm::Type *I32 = llvm::Type::getInt32Ty(ctx);
llvm::Metadata *mdFour =
llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(I32, 4));
llvm::Metadata *mdName = llvm::MDString::get(ctx, "nvvm-reflect-ftz");
llvm::Metadata *mdOne =
llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(I32, 1));
llvm::MDNode *reflect = llvm::MDNode::get(ctx, {mdFour, mdName, mdOne});
module->addModuleFlag(reflect);
}
// LLVM version in use may not officially support target hardware
int max_nvvm_cc = 75;
// int max_nvvm_ptx = 74;