[Triton-MLIR] Enable libdevice for the PTX backend when the module has external functions. (#848)

At the phase from ptx to cubin, check whether the llvm::Module has external
functions. If it does, link it with libdevice, located at:
https://github.com/openai/triton/blob/triton-mlir/python/triton/language/libdevice.10.bc
This commit is contained in:
ben-zhang-609
2022-11-07 16:01:50 +08:00
committed by GitHub
parent fdd59900f7
commit 84ad215268
4 changed files with 69 additions and 17 deletions

View File

@@ -1,5 +1,6 @@
#ifndef TRITON_TARGET_LLVMIRTRANSLATION_H #ifndef TRITON_TARGET_LLVMIRTRANSLATION_H
#define TRITON_TARGET_LLVMIRTRANSLATION_H #define TRITON_TARGET_LLVMIRTRANSLATION_H
#include "llvm/ADT/StringRef.h"
#include <memory> #include <memory>
#include <vector> #include <vector>
@@ -29,6 +30,8 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
std::unique_ptr<llvm::Module> std::unique_ptr<llvm::Module>
translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module); translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module);
bool linkExternLib(llvm::Module &module, llvm::StringRef path);
} // namespace triton } // namespace triton
} // namespace mlir } // namespace mlir

View File

@@ -151,7 +151,7 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
return nullptr; return nullptr;
} }
std::map<std::string, std::string> extern_libs; std::map<std::string, std::string> externLibs;
SmallVector<LLVM::LLVMFuncOp> funcs; SmallVector<LLVM::LLVMFuncOp> funcs;
module.walk([&](LLVM::LLVMFuncOp func) { module.walk([&](LLVM::LLVMFuncOp func) {
if (func.isExternal()) if (func.isExternal())
@@ -166,7 +166,7 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
func.getOperation()->getAttr("libpath").dyn_cast<StringAttr>(); func.getOperation()->getAttr("libpath").dyn_cast<StringAttr>();
if (name) { if (name) {
std::string lib_name = name.str(); std::string lib_name = name.str();
extern_libs[lib_name] = path.str(); externLibs[lib_name] = path.str();
} }
} }
} }
@@ -176,7 +176,7 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
->getAttr("triton_gpu.externs") ->getAttr("triton_gpu.externs")
.dyn_cast<DictionaryAttr>(); .dyn_cast<DictionaryAttr>();
for (auto &attr : dict) { for (auto &attr : dict) {
extern_libs[attr.getName().strref().trim().str()] = externLibs[attr.getName().strref().trim().str()] =
attr.getValue().dyn_cast<StringAttr>().strref().trim().str(); attr.getValue().dyn_cast<StringAttr>().strref().trim().str();
} }
} }
@@ -188,20 +188,9 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
} }
llvm::SMDiagnostic err; llvm::SMDiagnostic err;
for (auto &lib : extern_libs) { for (auto &lib : externLibs) {
auto ext_mod = llvm::parseIRFile(lib.second, err, *llvmContext); if (linkExternLib(*llvmir, lib.second))
if (!ext_mod) {
llvm::errs() << "Failed to load extern lib " << lib.first;
return nullptr; return nullptr;
}
ext_mod->setTargetTriple(llvmir->getTargetTriple());
ext_mod->setDataLayout(llvmir->getDataLayout());
if (llvm::Linker::linkModules(*llvmir, std::move(ext_mod),
llvm::Linker::Flags::LinkOnlyNeeded)) {
llvm::errs() << "Failed to link extern lib " << lib.first;
return nullptr;
}
} }
return llvmir; return llvmir;
@@ -227,5 +216,27 @@ void addExternalLibs(mlir::ModuleOp &module,
return; return;
} }
/// Link the LLVM IR / bitcode file at `path` into `module`.
///
/// The file is parsed in `module`'s own LLVMContext. Before linking, the
/// parsed module's target triple and data layout are overwritten with those
/// of `module`, and the link is performed with
/// llvm::Linker::Flags::LinkOnlyNeeded.
///
/// \param module destination module; modified in place on success.
/// \param path   file-system path of the library to parse and link.
/// \returns true on FAILURE (the file could not be parsed, or linking
///          failed; a message is written to llvm::errs()), false on success.
bool linkExternLib(llvm::Module &module, llvm::StringRef path) {
  llvm::SMDiagnostic err;
  auto &ctx = module.getContext();

  auto extMod = llvm::parseIRFile(path, err, ctx);
  if (!extMod) {
    // Surface the parser's diagnostic so the user can tell a missing file
    // from a malformed one, and terminate the line so that later output on
    // llvm::errs() does not run into this message.
    llvm::errs() << "Failed to load " << path << ": " << err.getMessage()
                 << "\n";
    return true;
  }

  // Make the library agree with the destination module before linking.
  extMod->setTargetTriple(module.getTargetTriple());
  extMod->setDataLayout(module.getDataLayout());

  // A truthy result from linkModules is treated as an error.
  if (llvm::Linker::linkModules(module, std::move(extMod),
                                llvm::Linker::Flags::LinkOnlyNeeded)) {
    llvm::errs() << "Failed to link " << path << "\n";
    return true;
  }

  return false;
}
} // namespace triton } // namespace triton
} // namespace mlir } // namespace mlir

View File

@@ -29,6 +29,7 @@
#include "llvm/Target/TargetOptions.h" #include "llvm/Target/TargetOptions.h"
#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/Cloning.h"
#include <filesystem>
#include <regex> #include <regex>
namespace triton { namespace triton {
@@ -61,6 +62,43 @@ static bool find_and_replace(std::string &str, const std::string &begin,
} }
static std::string llir_to_ptx(llvm::Module *module, int capability, int ptx) { static std::string llir_to_ptx(llvm::Module *module, int capability, int ptx) {
bool hasExternal = false;
for (auto &func : *module) {
if (func.hasExternalLinkage()) {
hasExternal = true;
break;
}
}
if (hasExternal) {
namespace fs = std::filesystem;
// [triton root dir]/python/triton/language/libdevice.10.bc
static const fs::path libdevice = fs::path(__FILE__)
.parent_path()
.parent_path()
.parent_path()
.parent_path() /
"python" / "triton" / "language" /
"libdevice.10.bc";
if (mlir::triton::linkExternLib(*module, libdevice.string()))
llvm::errs() << "link failed for: " << libdevice.string();
}
// See https://llvm.org/docs/NVPTXUsage.html#reflection-parameters
// Setting the nvvm-reflect-ftz module flag enables the fast-math path in
// libdevice; for example, with nvvm-reflect-ftz enabled, sqrt.approx.f32
// becomes sqrt.approx.ftz.f32.
{
auto &ctx = module->getContext();
llvm::Type *I32 = llvm::Type::getInt32Ty(ctx);
llvm::Metadata *mdFour =
llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(I32, 4));
llvm::Metadata *mdName = llvm::MDString::get(ctx, "nvvm-reflect-ftz");
llvm::Metadata *mdOne =
llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(I32, 1));
llvm::MDNode *reflect = llvm::MDNode::get(ctx, {mdFour, mdName, mdOne});
module->addModuleFlag(reflect);
}
// LLVM version in use may not officially support target hardware // LLVM version in use may not officially support target hardware
int max_nvvm_cc = 75; int max_nvvm_cc = 75;
// int max_nvvm_ptx = 74; // int max_nvvm_ptx = 74;

View File

@@ -144,7 +144,7 @@ def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda'):
# triton result # triton result
x_tri = to_triton(x, device=device, dst_type=dtype_x) x_tri = to_triton(x, device=device, dst_type=dtype_x)
z_tri = to_triton(np.empty_like(z_ref), device=device, dst_type=dtype_x) z_tri = to_triton(np.empty_like(z_ref), device=device, dst_type=dtype_x)
kernel[(1, )](z_tri, x_tri, SIZE=SIZE, num_warps=4, extern_libs={"libdevice": "/usr/local/cuda/nvvm/libdevice/libdevice.10.bc"}) kernel[(1, )](z_tri, x_tri, SIZE=SIZE, num_warps=4)
# compare # compare
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01) np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)