2022-08-18 01:42:48 +08:00
|
|
|
#include "triton/Target/PTX/PTXTranslation.h"
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"

#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Target/TargetMachine.h"

#include <algorithm>
#include <filesystem>
#include <memory>
#include <string>
|
2022-08-18 01:42:48 +08:00
|
|
|
|
|
|
|
namespace triton {
|
|
|
|
|
2022-09-26 16:38:06 -07:00
|
|
|
// Initializes the NVPTX backend components (target info, target, MC layer and
// assembly printer) so the NVPTX target can be looked up in LLVM's
// TargetRegistry and used to emit assembly. The initializers are invoked in a
// fixed order: target info first, asm printer last.
static void init_llvm() {
  using InitFn = void (*)();
  const InitFn initializers[] = {
      LLVMInitializeNVPTXTargetInfo,
      LLVMInitializeNVPTXTarget,
      LLVMInitializeNVPTXTargetMC,
      LLVMInitializeNVPTXAsmPrinter,
  };
  for (InitFn initialize : initializers)
    initialize();
}
|
|
|
|
|
|
|
|
// Replaces the first span of `str` that starts at an occurrence of `begin` and
// runs through the next occurrence of `end` after it (both delimiters
// included) with `target`.
//
// Returns true if a replacement was made. Returns false, leaving `str`
// unchanged, when `begin` is absent or no `end` follows it.
//
// Callers that loop until this returns false must not pass a `target` that
// itself contains `begin`, or the loop never terminates.
static bool find_and_replace(std::string &str, const std::string &begin,
                             const std::string &end,
                             const std::string &target) {
  const size_t start_replace = str.find(begin);
  if (start_replace == std::string::npos)
    return false;
  // Search for `end` only after `begin`, so a delimiter that also occurs
  // inside `begin` is not matched against `begin` itself.
  const size_t end_replace = str.find(end, start_replace + begin.size());
  if (end_replace == std::string::npos)
    return false;
  // Remove the whole `end` delimiter. The previous `+ 1` was only correct for
  // single-character delimiters such as "\n".
  str.replace(start_replace, end_replace + end.size() - start_replace, target);
  return true;
}
|
|
|
|
|
|
|
|
static std::string llir_to_ptx(llvm::Module *module, int capability, int ptx) {
|
2022-11-07 16:01:50 +08:00
|
|
|
bool hasExternal = false;
|
|
|
|
for (auto &func : *module) {
|
|
|
|
if (func.hasExternalLinkage()) {
|
|
|
|
hasExternal = true;
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
if (hasExternal) {
|
|
|
|
namespace fs = std::filesystem;
|
|
|
|
// [triton root dir]/python/triton/language/libdevice.10.bc
|
|
|
|
static const fs::path libdevice = fs::path(__FILE__)
|
|
|
|
.parent_path()
|
|
|
|
.parent_path()
|
|
|
|
.parent_path()
|
|
|
|
.parent_path() /
|
|
|
|
"python" / "triton" / "language" /
|
|
|
|
"libdevice.10.bc";
|
|
|
|
if (mlir::triton::linkExternLib(*module, libdevice.string()))
|
|
|
|
llvm::errs() << "link failed for: " << libdevice.string();
|
|
|
|
}
|
|
|
|
|
|
|
|
// please check https://llvm.org/docs/NVPTXUsage.html#reflection-parameters
|
|
|
|
// this will enable fast math path in libdevice
|
|
|
|
// for example, when enable nvvm-reflect-ftz, sqrt.approx.f32 will change to
|
|
|
|
// sqrt.approx.ftz.f32
|
|
|
|
{
|
|
|
|
auto &ctx = module->getContext();
|
|
|
|
llvm::Type *I32 = llvm::Type::getInt32Ty(ctx);
|
|
|
|
llvm::Metadata *mdFour =
|
|
|
|
llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(I32, 4));
|
|
|
|
llvm::Metadata *mdName = llvm::MDString::get(ctx, "nvvm-reflect-ftz");
|
|
|
|
llvm::Metadata *mdOne =
|
|
|
|
llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(I32, 1));
|
|
|
|
llvm::MDNode *reflect = llvm::MDNode::get(ctx, {mdFour, mdName, mdOne});
|
|
|
|
module->addModuleFlag(reflect);
|
|
|
|
}
|
2022-09-26 16:38:06 -07:00
|
|
|
// LLVM version in use may not officially support target hardware
|
|
|
|
int max_nvvm_cc = 75;
|
2022-10-28 12:36:09 -07:00
|
|
|
// int max_nvvm_ptx = 74;
|
2022-09-26 16:38:06 -07:00
|
|
|
// options
|
|
|
|
auto options = llvm::cl::getRegisteredOptions();
|
|
|
|
auto *short_ptr =
|
|
|
|
static_cast<llvm::cl::opt<bool> *>(options["nvptx-short-ptr"]);
|
|
|
|
assert(short_ptr);
|
|
|
|
short_ptr->setValue(true);
|
|
|
|
// compute capability
|
|
|
|
std::string sm = "sm_" + std::to_string(capability);
|
|
|
|
// max PTX version
|
|
|
|
int ptx_major = ptx / 10;
|
|
|
|
int ptx_minor = ptx % 10;
|
|
|
|
// create
|
|
|
|
llvm::SmallVector<char, 0> buffer;
|
|
|
|
std::string triple = "nvptx64-nvidia-cuda";
|
|
|
|
std::string proc = "sm_" + std::to_string(std::min(capability, max_nvvm_cc));
|
|
|
|
std::string layout = "";
|
|
|
|
std::string features = "";
|
|
|
|
// std::string features = "+ptx" + std::to_string(std::min(ptx,
|
|
|
|
// max_nvvm_ptx));
|
|
|
|
init_llvm();
|
|
|
|
// verify and store llvm
|
|
|
|
llvm::legacy::PassManager pm;
|
|
|
|
pm.add(llvm::createVerifierPass());
|
|
|
|
pm.run(*module);
|
|
|
|
// module->print(llvm::outs(), nullptr);
|
|
|
|
|
|
|
|
// create machine
|
|
|
|
module->setTargetTriple(triple);
|
|
|
|
std::string error;
|
|
|
|
auto target =
|
|
|
|
llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error);
|
|
|
|
llvm::TargetOptions opt;
|
|
|
|
opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
|
|
|
|
opt.UnsafeFPMath = false;
|
|
|
|
opt.NoInfsFPMath = false;
|
|
|
|
opt.NoNaNsFPMath = true;
|
|
|
|
llvm::TargetMachine *machine = target->createTargetMachine(
|
|
|
|
module->getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_,
|
|
|
|
llvm::None, llvm::CodeGenOpt::Aggressive);
|
|
|
|
// set data layout
|
|
|
|
if (layout.empty())
|
|
|
|
module->setDataLayout(machine->createDataLayout());
|
|
|
|
else
|
|
|
|
module->setDataLayout(layout);
|
|
|
|
// emit machine code
|
|
|
|
for (llvm::Function &f : module->functions())
|
|
|
|
f.addFnAttr(llvm::Attribute::AlwaysInline);
|
|
|
|
llvm::legacy::PassManager pass;
|
|
|
|
llvm::raw_svector_ostream stream(buffer);
|
|
|
|
// emit
|
|
|
|
machine->addPassesToEmitFile(pass, stream, nullptr,
|
|
|
|
llvm::CodeGenFileType::CGFT_AssemblyFile);
|
|
|
|
pass.run(*module);
|
|
|
|
|
|
|
|
// post-process
|
|
|
|
std::string result(buffer.begin(), buffer.end());
|
|
|
|
find_and_replace(result, ".version", "\n",
|
|
|
|
".version " + std::to_string(ptx_major) + "." +
|
|
|
|
std::to_string(ptx_minor) + "\n");
|
|
|
|
find_and_replace(result, ".target", "\n", ".target " + sm + "\n");
|
|
|
|
while (find_and_replace(result, "\t// begin inline asm", "\n", ""))
|
|
|
|
;
|
|
|
|
while (find_and_replace(result, "\t// end inline asm", "\n", ""))
|
|
|
|
;
|
|
|
|
return result;
|
2022-08-18 01:42:48 +08:00
|
|
|
}
|
|
|
|
|
2022-09-26 16:38:06 -07:00
|
|
|
/// @brief Translates an LLVM module into PTX assembly text.
/// @param module  Module to translate. It is modified in place: libdevice may
///                be linked in, and module flags, the target triple, the data
///                layout and function attributes are set.
/// @param cc      Compute capability used for the `.target` directive
///                (e.g. 80 -> sm_80).
/// @param version PTX ISA version used for the `.version` directive
///                (e.g. 74 -> 7.4).
/// @return The generated PTX source.
std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version) {
  return llir_to_ptx(&module, cc, version);
}
|
|
|
|
|
|
|
|
} // namespace triton
|