[Triton-MLIR] Enable libdevice for ptx backend when has external functions. (#848)
At the phase from ptx to cubin, check whether llvm::Module has external functions. if has, link with libdevice at: https://github.com/openai/triton/blob/triton-mlir/python/triton/language/libdevice.10.bc
This commit is contained in:
@@ -29,6 +29,7 @@
|
||||
#include "llvm/Target/TargetOptions.h"
|
||||
#include "llvm/Transforms/Scalar.h"
|
||||
#include "llvm/Transforms/Utils/Cloning.h"
|
||||
#include <filesystem>
|
||||
#include <regex>
|
||||
|
||||
namespace triton {
|
||||
@@ -61,6 +62,43 @@ static bool find_and_replace(std::string &str, const std::string &begin,
|
||||
}
|
||||
|
||||
static std::string llir_to_ptx(llvm::Module *module, int capability, int ptx) {
|
||||
bool hasExternal = false;
|
||||
for (auto &func : *module) {
|
||||
if (func.hasExternalLinkage()) {
|
||||
hasExternal = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (hasExternal) {
|
||||
namespace fs = std::filesystem;
|
||||
// [triton root dir]/python/triton/language/libdevice.10.bc
|
||||
static const fs::path libdevice = fs::path(__FILE__)
|
||||
.parent_path()
|
||||
.parent_path()
|
||||
.parent_path()
|
||||
.parent_path() /
|
||||
"python" / "triton" / "language" /
|
||||
"libdevice.10.bc";
|
||||
if (mlir::triton::linkExternLib(*module, libdevice.string()))
|
||||
llvm::errs() << "link failed for: " << libdevice.string();
|
||||
}
|
||||
|
||||
// please check https://llvm.org/docs/NVPTXUsage.html#reflection-parameters
|
||||
// this will enable fast math path in libdevice
|
||||
// for example, when enable nvvm-reflect-ftz, sqrt.approx.f32 will change to
|
||||
// sqrt.approx.ftz.f32
|
||||
{
|
||||
auto &ctx = module->getContext();
|
||||
llvm::Type *I32 = llvm::Type::getInt32Ty(ctx);
|
||||
llvm::Metadata *mdFour =
|
||||
llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(I32, 4));
|
||||
llvm::Metadata *mdName = llvm::MDString::get(ctx, "nvvm-reflect-ftz");
|
||||
llvm::Metadata *mdOne =
|
||||
llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(I32, 1));
|
||||
llvm::MDNode *reflect = llvm::MDNode::get(ctx, {mdFour, mdName, mdOne});
|
||||
module->addModuleFlag(reflect);
|
||||
}
|
||||
// LLVM version in use may not officially support target hardware
|
||||
int max_nvvm_cc = 75;
|
||||
// int max_nvvm_ptx = 74;
|
||||
|
Reference in New Issue
Block a user