|
|
|
@@ -18,6 +18,7 @@
|
|
|
|
|
#include "llvm/IRReader/IRReader.h"
|
|
|
|
|
#include "llvm/Linker/Linker.h"
|
|
|
|
|
#include "llvm/Support/SourceMgr.h"
|
|
|
|
|
#include <filesystem>
|
|
|
|
|
|
|
|
|
|
namespace mlir {
|
|
|
|
|
namespace triton {
|
|
|
|
@@ -26,19 +27,18 @@ namespace triton {
|
|
|
|
|
// information from mlir module.
|
|
|
|
|
struct NVVMMetadata {
|
|
|
|
|
int maxntidx{-1};
|
|
|
|
|
bool is_kernel{};
|
|
|
|
|
bool isKernel{};
|
|
|
|
|
// Free to extend with other information.
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// Add the nvvm related metadata to LLVM IR.
|
|
|
|
|
void amendLLVMFunc(llvm::Function *func, const NVVMMetadata &metadata) {
|
|
|
|
|
static void amendLLVMFunc(llvm::Function *func, const NVVMMetadata &metadata) {
|
|
|
|
|
auto *module = func->getParent();
|
|
|
|
|
auto &ctx = func->getContext();
|
|
|
|
|
|
|
|
|
|
if (metadata.maxntidx > 0) {
|
|
|
|
|
auto i32_ty = llvm::IntegerType::get(ctx, 32);
|
|
|
|
|
auto warps =
|
|
|
|
|
llvm::ConstantInt::get(i32_ty, llvm::APInt(32, metadata.maxntidx));
|
|
|
|
|
auto warps = llvm::ConstantInt::get(llvm::IntegerType::get(ctx, 32),
|
|
|
|
|
llvm::APInt(32, metadata.maxntidx));
|
|
|
|
|
|
|
|
|
|
llvm::Metadata *md_args[] = {llvm::ValueAsMetadata::get(func),
|
|
|
|
|
llvm::MDString::get(ctx, "maxntidx"),
|
|
|
|
@@ -48,17 +48,18 @@ void amendLLVMFunc(llvm::Function *func, const NVVMMetadata &metadata) {
|
|
|
|
|
->addOperand(llvm::MDNode::get(ctx, md_args));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (metadata.is_kernel) {
|
|
|
|
|
llvm::Metadata *md_args[] = {
|
|
|
|
|
if (metadata.isKernel) {
|
|
|
|
|
llvm::Metadata *mdArgs[] = {
|
|
|
|
|
llvm::ValueAsMetadata::get(func), llvm::MDString::get(ctx, "kernel"),
|
|
|
|
|
llvm::ValueAsMetadata::get(
|
|
|
|
|
llvm::ConstantInt::get(llvm::Type::getInt32Ty(ctx), 1))};
|
|
|
|
|
module->getOrInsertNamedMetadata("nvvm.annotations")
|
|
|
|
|
->addOperand(llvm::MDNode::get(ctx, md_args));
|
|
|
|
|
->addOperand(llvm::MDNode::get(ctx, mdArgs));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void extractNVVMMetadata(mlir::ModuleOp module,
|
|
|
|
|
static void
|
|
|
|
|
extractNVVMMetadata(mlir::ModuleOp module,
|
|
|
|
|
llvm::DenseMap<llvm::StringRef, NVVMMetadata> *dic) {
|
|
|
|
|
for (auto op : module.getOps<LLVM::LLVMFuncOp>()) {
|
|
|
|
|
NVVMMetadata meta;
|
|
|
|
@@ -74,7 +75,7 @@ void extractNVVMMetadata(mlir::ModuleOp module,
|
|
|
|
|
|
|
|
|
|
// kernel
|
|
|
|
|
if (op->hasAttr("nvvm.kernel")) {
|
|
|
|
|
meta.is_kernel = true;
|
|
|
|
|
meta.isKernel = true;
|
|
|
|
|
hasMetadata = true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@@ -83,13 +84,109 @@ void extractNVVMMetadata(mlir::ModuleOp module,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static std::map<std::string, std::string> getExternLibs(mlir::ModuleOp module) {
|
|
|
|
|
std::map<std::string, std::string> externLibs;
|
|
|
|
|
SmallVector<LLVM::LLVMFuncOp> funcs;
|
|
|
|
|
module.walk([&](LLVM::LLVMFuncOp func) {
|
|
|
|
|
if (func.isExternal())
|
|
|
|
|
funcs.push_back(func);
|
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
for (auto &func : funcs) {
|
|
|
|
|
if (func.getOperation()->hasAttr("libname")) {
|
|
|
|
|
auto name =
|
|
|
|
|
func.getOperation()->getAttr("libname").dyn_cast<StringAttr>();
|
|
|
|
|
auto path =
|
|
|
|
|
func.getOperation()->getAttr("libpath").dyn_cast<StringAttr>();
|
|
|
|
|
if (name) {
|
|
|
|
|
std::string libName = name.str();
|
|
|
|
|
externLibs[libName] = path.str();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (module.getOperation()->hasAttr("triton_gpu.externs")) {
|
|
|
|
|
auto dict = module.getOperation()
|
|
|
|
|
->getAttr("triton_gpu.externs")
|
|
|
|
|
.dyn_cast<DictionaryAttr>();
|
|
|
|
|
for (auto &attr : dict) {
|
|
|
|
|
externLibs[attr.getName().strref().trim().str()] =
|
|
|
|
|
attr.getValue().dyn_cast<StringAttr>().strref().trim().str();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!funcs.empty()) {
|
|
|
|
|
// When using the Math Dialect, it is possible that some ops (e.g., log) are
|
|
|
|
|
// lowered to a function call. In this case, we need to link libdevice
|
|
|
|
|
// using its default path:
|
|
|
|
|
// [triton root dir]/python/triton/language/libdevice.10.bc
|
|
|
|
|
// TODO(Keren): handle external linkage other than libdevice?
|
|
|
|
|
namespace fs = std::filesystem;
|
|
|
|
|
static const std::string libdevice = "libdevice";
|
|
|
|
|
static const std::filesystem::path path = std::filesystem::path(__FILE__)
|
|
|
|
|
.parent_path()
|
|
|
|
|
.parent_path()
|
|
|
|
|
.parent_path()
|
|
|
|
|
.parent_path() /
|
|
|
|
|
"python" / "triton" / "language" /
|
|
|
|
|
"libdevice.10.bc";
|
|
|
|
|
externLibs.try_emplace(libdevice, path.string());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return externLibs;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static void linkLibdevice(llvm::Module &module) {
|
|
|
|
|
// please check https://llvm.org/docs/NVPTXUsage.html#reflection-parameters
|
|
|
|
|
// this will enable fast math path in libdevice
|
|
|
|
|
// for example, when enable nvvm-reflect-ftz, sqrt.approx.f32 will change to
|
|
|
|
|
// sqrt.approx.ftz.f32
|
|
|
|
|
auto &ctx = module.getContext();
|
|
|
|
|
llvm::Type *i32 = llvm::Type::getInt32Ty(ctx);
|
|
|
|
|
llvm::Metadata *mdFour =
|
|
|
|
|
llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(i32, 4));
|
|
|
|
|
llvm::Metadata *mdName = llvm::MDString::get(ctx, "nvvm-reflect-ftz");
|
|
|
|
|
llvm::Metadata *mdOne =
|
|
|
|
|
llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(i32, 1));
|
|
|
|
|
llvm::MDNode *reflect = llvm::MDNode::get(ctx, {mdFour, mdName, mdOne});
|
|
|
|
|
module.addModuleFlag(reflect);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static bool linkExternLib(llvm::Module &module, llvm::StringRef name,
|
|
|
|
|
llvm::StringRef path) {
|
|
|
|
|
llvm::SMDiagnostic err;
|
|
|
|
|
auto &ctx = module.getContext();
|
|
|
|
|
|
|
|
|
|
auto extMod = llvm::parseIRFile(path, err, ctx);
|
|
|
|
|
if (!extMod) {
|
|
|
|
|
llvm::errs() << "Failed to load " << path;
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
extMod->setTargetTriple(module.getTargetTriple());
|
|
|
|
|
extMod->setDataLayout(module.getDataLayout());
|
|
|
|
|
|
|
|
|
|
if (llvm::Linker::linkModules(module, std::move(extMod),
|
|
|
|
|
llvm::Linker::Flags::LinkOnlyNeeded)) {
|
|
|
|
|
llvm::errs() << "Failed to link " << path;
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (name == "libdevice") {
|
|
|
|
|
linkLibdevice(module);
|
|
|
|
|
} else {
|
|
|
|
|
assert(false && "unknown extern lib: ");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<llvm::Module>
|
|
|
|
|
translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module) {
|
|
|
|
|
auto context = module->getContext();
|
|
|
|
|
DialectRegistry registry;
|
|
|
|
|
mlir::registerLLVMDialectTranslation(registry);
|
|
|
|
|
mlir::registerNVVMDialectTranslation(registry);
|
|
|
|
|
context->appendDialectRegistry(registry);
|
|
|
|
|
module->getContext()->appendDialectRegistry(registry);
|
|
|
|
|
|
|
|
|
|
llvm::DenseMap<llvm::StringRef, NVVMMetadata> nvvmMetadata;
|
|
|
|
|
extractNVVMMetadata(module, &nvvmMetadata);
|
|
|
|
@@ -100,6 +197,20 @@ translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Link external libraries before perform optimizations
|
|
|
|
|
// Note from libdevice users guide:
|
|
|
|
|
// https://docs.nvidia.com/cuda/libdevice-users-guide/basic-usage.html
|
|
|
|
|
// The standard process for linking with libdevice is to first link it with
|
|
|
|
|
// the target module, then run the standard LLVM optimization and code
|
|
|
|
|
// generation passes. This allows the optimizers to inline and perform
|
|
|
|
|
// analyses on the used library functions, and eliminate any used functions as
|
|
|
|
|
// dead code.
|
|
|
|
|
auto externLibs = getExternLibs(module);
|
|
|
|
|
for (auto &lib : externLibs) {
|
|
|
|
|
if (linkExternLib(*llvmModule, lib.first, lib.second))
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto optPipeline = mlir::makeOptimizingTransformer(
|
|
|
|
|
/*optLevel=*/3, /*sizeLevel=*/0,
|
|
|
|
|
/*targetMachine=*/nullptr);
|
|
|
|
@@ -147,49 +258,12 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::map<std::string, std::string> externLibs;
|
|
|
|
|
SmallVector<LLVM::LLVMFuncOp> funcs;
|
|
|
|
|
module.walk([&](LLVM::LLVMFuncOp func) {
|
|
|
|
|
if (func.isExternal())
|
|
|
|
|
funcs.push_back(func);
|
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
for (auto &func : funcs) {
|
|
|
|
|
if (func.getOperation()->hasAttr("libname")) {
|
|
|
|
|
auto name =
|
|
|
|
|
func.getOperation()->getAttr("libname").dyn_cast<StringAttr>();
|
|
|
|
|
auto path =
|
|
|
|
|
func.getOperation()->getAttr("libpath").dyn_cast<StringAttr>();
|
|
|
|
|
if (name) {
|
|
|
|
|
std::string lib_name = name.str();
|
|
|
|
|
externLibs[lib_name] = path.str();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (module.getOperation()->hasAttr("triton_gpu.externs")) {
|
|
|
|
|
auto dict = module.getOperation()
|
|
|
|
|
->getAttr("triton_gpu.externs")
|
|
|
|
|
.dyn_cast<DictionaryAttr>();
|
|
|
|
|
for (auto &attr : dict) {
|
|
|
|
|
externLibs[attr.getName().strref().trim().str()] =
|
|
|
|
|
attr.getValue().dyn_cast<StringAttr>().strref().trim().str();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto llvmir = translateLLVMToLLVMIR(llvmContext, module);
|
|
|
|
|
if (!llvmir) {
|
|
|
|
|
auto llvmIR = translateLLVMToLLVMIR(llvmContext, module);
|
|
|
|
|
if (!llvmIR) {
|
|
|
|
|
llvm::errs() << "Translate to LLVM IR failed";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
llvm::SMDiagnostic err;
|
|
|
|
|
for (auto &lib : externLibs) {
|
|
|
|
|
if (linkExternLib(*llvmir, lib.second))
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return llvmir;
|
|
|
|
|
return llvmIR;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void addExternalLibs(mlir::ModuleOp &module,
|
|
|
|
@@ -211,27 +285,5 @@ void addExternalLibs(mlir::ModuleOp &module,
|
|
|
|
|
module.getOperation()->setAttr("triton_gpu.externs", dict);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool linkExternLib(llvm::Module &module, llvm::StringRef path) {
|
|
|
|
|
llvm::SMDiagnostic err;
|
|
|
|
|
auto &ctx = module.getContext();
|
|
|
|
|
|
|
|
|
|
auto extMod = llvm::parseIRFile(path, err, ctx);
|
|
|
|
|
if (!extMod) {
|
|
|
|
|
llvm::errs() << "Failed to load " << path;
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
extMod->setTargetTriple(module.getTargetTriple());
|
|
|
|
|
extMod->setDataLayout(module.getDataLayout());
|
|
|
|
|
|
|
|
|
|
if (llvm::Linker::linkModules(module, std::move(extMod),
|
|
|
|
|
llvm::Linker::Flags::LinkOnlyNeeded)) {
|
|
|
|
|
llvm::errs() << "Failed to link " << path;
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace triton
|
|
|
|
|
} // namespace mlir
|
|
|
|
|