[RUNTIME] Major code cleanup (#711)

This PR does the following:
- CUDA utilities (e.g., cuGetInfo) won't be compiled as part of libtriton.so anymore.
- driver/llvm.cc is refactored and split between PTX codegen and Python.
- By extension, this also deprecates include/external, so Triton no longer has to carry its own copy of some CUDA/HIP headers.
- `triton-translate` becomes a `triton.tools.aot` Python utility that re-uses functions from the triton.compile sub-module.
This commit is contained in:
Philippe Tillet
2022-09-26 16:38:06 -07:00
committed by GitHub
parent 8bb09f83ee
commit 1e91ed30d0
28 changed files with 509 additions and 31483 deletions

View File

@@ -13,7 +13,6 @@
#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
#include "mlir/Transforms/Passes.h"
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
#include "triton/driver/llvm.h"
#include "triton/tools/sys/getenv.hpp"
#include "llvm/IR/Constants.h"
@@ -99,7 +98,6 @@ translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module) {
}
// Initialize LLVM targets.
::triton::driver::init_llvm();
mlir::ExecutionEngine::setupTargetTriple(llvmModule.get());
auto optPipeline = mlir::makeOptimizingTransformer(

View File

@@ -11,31 +11,129 @@
#include "mlir/Target/LLVMIR/Export.h"
#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
#include "triton/driver/dispatch.h"
#include "triton/driver/llvm.h"
#include "llvm/ExecutionEngine/ExecutionEngine.h"
#include "llvm/ExecutionEngine/SectionMemoryManager.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IRPrintingPasses.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/CodeGen.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include <memory>
#include <regex>
namespace triton {
void getCuCCAndVersionFromDevice(uint64_t device, int *cc, int *version,
std::string *ptxasPath) {
CUdevice dev = (CUdevice)device;
size_t major = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>(dev);
size_t minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>(dev);
*cc = major * 10 + minor;
*ptxasPath = driver::path_to_ptxas(*version); // assign version
// C-linkage stub definitions; each ignores its arguments and reports 0.
// NOTE(review): the names match terminfo entry points, so these presumably
// exist to satisfy link-time references without a terminfo library — confirm.
extern "C" {

int set_curterm(char * /*nterm*/) {
  return 0;
}

int del_curterm(char * /*nterm*/) {
  return 0;
}

int tigetnum(char * /*capname*/) {
  return 0;
}

int setupterm(char * /*term*/, int /*fildes*/, int * /*errret*/) {
  return 0;
}

} // extern "C"
std::tuple<std::string, size_t, int, std::string>
translateTritonGPUToPTX(mlir::ModuleOp module, uint64_t device) {
int cc;
int version;
std::string ptxasPath;
getCuCCAndVersionFromDevice(device, &cc, &version, &ptxasPath);
// Calls the four LLVM NVPTX initialization entry points, in this order:
// target info, target, MC, and asm printer.
// llir_to_ptx() invokes this before it looks up the target for its triple.
static void init_llvm() {
LLVMInitializeNVPTXTargetInfo();
LLVMInitializeNVPTXTarget();
LLVMInitializeNVPTXTargetMC();
LLVMInitializeNVPTXAsmPrinter();
}
llvm::LLVMContext ctx;
auto llModule = mlir::triton::translateTritonGPUToLLVMIR(&ctx, module);
auto ptxCode = driver::llir_to_ptx(llModule.get(), cc, version);
return std::make_tuple(ptxCode, cc, version, ptxasPath);
// Replaces one delimited region of `str` with `target`.
//
// The region starts at the first occurrence of `begin` and extends through
// the end of the first occurrence of `end` found at or after that position.
//
// Returns true when a replacement was made. Returns false, leaving `str`
// untouched, when either delimiter cannot be found.
static bool find_and_replace(std::string &str, const std::string &begin,
                             const std::string &end,
                             const std::string &target) {
  const size_t start_replace = str.find(begin);
  if (start_replace == std::string::npos)
    return false;
  const size_t end_replace = str.find(end, start_replace);
  if (end_replace == std::string::npos)
    return false;
  // Consume the entire `end` delimiter rather than only its first character,
  // so multi-character delimiters do not leave their tail behind. An empty
  // delimiter still consumes one character, matching the previous behavior.
  const size_t end_length = end.empty() ? 1 : end.size();
  str.replace(start_replace, end_replace + end_length - start_replace, target);
  return true;
}
// Compiles an LLVM IR module to PTX assembly text with LLVM's NVPTX backend.
//
// Parameters:
//   module     - module to compile. It is modified in place: the target
//                triple and data layout are set, and every function gets the
//                AlwaysInline attribute.
//   capability - number appended to "sm_" for the emitted `.target`
//                directive (e.g. 80 -> "sm_80"). The processor handed to the
//                backend is capped at sm_75.
//   ptx        - PTX version encoded as major * 10 + minor (e.g. 74 -> 7.4),
//                written into the emitted `.version` directive.
//
// Returns the PTX text with `.version` / `.target` rewritten and the
// "begin/end inline asm" marker lines removed.
//
// Aborts through llvm::report_fatal_error if the NVPTX option or target is
// unavailable, or if the backend cannot emit assembly.
static std::string llir_to_ptx(llvm::Module *module, int capability, int ptx) {
  // LLVM version in use may not officially support target hardware, so the
  // processor passed to the backend is clamped to this compute capability.
  const int max_nvvm_cc = 75;

  // options
  auto options = llvm::cl::getRegisteredOptions();
  auto *short_ptr =
      static_cast<llvm::cl::opt<bool> *>(options["nvptx-short-ptr"]);
  // Checked at runtime (not only via assert) so release builds do not
  // dereference a null pointer when the option is missing.
  if (!short_ptr)
    llvm::report_fatal_error(
        "LLVM command-line option 'nvptx-short-ptr' is not registered");
  short_ptr->setValue(true);

  // compute capability
  const std::string sm = "sm_" + std::to_string(capability);
  // max PTX version
  const int ptx_major = ptx / 10;
  const int ptx_minor = ptx % 10;

  // create
  llvm::SmallVector<char, 0> buffer;
  const std::string triple = "nvptx64-nvidia-cuda";
  const std::string proc =
      "sm_" + std::to_string(std::min(capability, max_nvvm_cc));
  // NOTE: a feature string of the form "+ptx" + min(ptx, 74) is left
  // disabled; the emitted `.version` directive is patched below instead.
  const std::string features = "";
  init_llvm();

  // verify and store llvm
  llvm::legacy::PassManager pm;
  pm.add(llvm::createVerifierPass());
  pm.run(*module);

  // create machine
  module->setTargetTriple(triple);
  std::string error;
  const llvm::Target *target =
      llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error);
  if (!target) {
    const std::string message =
        "failed to look up LLVM target for '" + triple + "': " + error;
    llvm::report_fatal_error(message.c_str());
  }
  llvm::TargetOptions opt;
  opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
  opt.UnsafeFPMath = false;
  opt.NoInfsFPMath = false;
  opt.NoNaNsFPMath = true;
  // Owned by a unique_ptr so the target machine is released on return
  // instead of leaking on every compilation.
  std::unique_ptr<llvm::TargetMachine> machine(target->createTargetMachine(
      module->getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_,
      llvm::None, llvm::CodeGenOpt::Aggressive));
  if (!machine) {
    const std::string message =
        "failed to create LLVM target machine for '" + triple + "'";
    llvm::report_fatal_error(message.c_str());
  }

  // set data layout
  module->setDataLayout(machine->createDataLayout());

  // emit machine code
  for (llvm::Function &f : module->functions())
    f.addFnAttr(llvm::Attribute::AlwaysInline);
  llvm::legacy::PassManager pass;
  llvm::raw_svector_ostream stream(buffer);
  // emit; addPassesToEmitFile returns true when the target cannot produce
  // the requested file type.
  if (machine->addPassesToEmitFile(pass, stream, nullptr,
                                   llvm::CodeGenFileType::CGFT_AssemblyFile))
    llvm::report_fatal_error("NVPTX target machine cannot emit assembly");
  pass.run(*module);

  // post-process: force the requested PTX version and real compute
  // capability, then strip the inline-asm marker comments.
  std::string result(buffer.begin(), buffer.end());
  find_and_replace(result, ".version", "\n",
                   ".version " + std::to_string(ptx_major) + "." +
                       std::to_string(ptx_minor) + "\n");
  find_and_replace(result, ".target", "\n", ".target " + sm + "\n");
  while (find_and_replace(result, "\t// begin inline asm", "\n", ""))
    ;
  while (find_and_replace(result, "\t// end inline asm", "\n", ""))
    ;
  return result;
}
// Public entry point: forwards `module`, `cc` and `version` to llir_to_ptx()
// and returns the PTX string it produces.
std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version) {
  return llir_to_ptx(&module, cc, version);
}
} // namespace triton