[DRIVER] Bumped CUDA requirement to 11.4+. This is to avoid bad performance surprises, as older ptxas versions are much slower. (#769)

This also makes codegen simpler by avoiding special handling of eviction policies.
This commit is contained in:
Philippe Tillet
2022-10-12 12:02:30 -07:00
committed by GitHub
parent af76c989eb
commit 33e6f0df7f

View File

@@ -1,27 +1,27 @@
/* Copyright 2015-2017 Philippe Tillet /* Copyright 2015-2017 Philippe Tillet
* *
* Permission is hereby granted, free of charge, to any person obtaining * Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files * a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction, * (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge, * including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software, * publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so, * and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions: * subject to the following conditions:
* *
* The above copyright notice and this permission notice shall be * The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software. * included in all copies or substantial portions of the Software.
* *
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/ */
#include <fstream> #include <fstream>
#if __has_include(<unistd.h>) #if __has_include(<unistd.h>)
#include <unistd.h> #include <unistd.h>
#endif #endif
#include <memory> #include <memory>
#include <regex> #include <regex>
@@ -59,17 +59,21 @@
#include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetLibraryInfo.h"
// end AMD stuff // end AMD stuff
// No-op replacements for a handful of terminfo entry points. Each one ignores
// its arguments and reports 0.
// NOTE(review): presumably these exist so the binary links without a real
// terminfo/curses library being present -- confirm against the build setup.
extern "C" int set_curterm(char *nterm) { return 0; }
extern "C" int del_curterm(char *nterm) { return 0; }
extern "C" int tigetnum(char *capname) { return 0; }
extern "C" int setupterm(char *term, int fildes, int *errret) { return 0; }
namespace triton{ namespace triton
namespace driver{ {
namespace driver
{
void init_llvm() { void init_llvm()
{
LLVMInitializeNVPTXTargetInfo(); LLVMInitializeNVPTXTargetInfo();
LLVMInitializeNVPTXTarget(); LLVMInitializeNVPTXTarget();
LLVMInitializeNVPTXTargetMC(); LLVMInitializeNVPTXTargetMC();
@@ -78,41 +82,44 @@ void init_llvm() {
LLVMInitializeAMDGPUTarget(); LLVMInitializeAMDGPUTarget();
LLVMInitializeAMDGPUTargetMC(); LLVMInitializeAMDGPUTargetMC();
LLVMInitializeAMDGPUAsmPrinter(); LLVMInitializeAMDGPUAsmPrinter();
} }
/* ------------------------ */
//         CUDA             //
/* ------------------------ */

// Replaces the first region of `str` that starts at the first occurrence of
// `begin` and runs up to and including the first character of the next
// occurrence of `end` (searched from the start of `begin`) with `target`.
//
// If `end` does not occur after `begin`, the region extends to the end of
// `str`. This guarantees that a successful call always consumes the matched
// `begin`, so loops of the form `while (find_and_replace(...));` terminate
// even when the match sits at index 0 with no terminator following it.
//
// Returns true if `begin` was found (and `str` was modified), false otherwise.
static bool find_and_replace(std::string &str, const std::string &begin, const std::string &end, const std::string &target)
{
  const size_t start_replace = str.find(begin);
  if (start_replace == std::string::npos)
    return false;
  const size_t end_replace = str.find(end, start_replace);
  // std::string::replace clamps the count to the remaining length, so npos
  // means "everything from start_replace onwards".
  const size_t count = (end_replace == std::string::npos)
                           ? std::string::npos
                           : end_replace + 1 - start_replace;
  str.replace(start_replace, count, target);
  return true;
}
std::string path_to_ptxas(int& version) { std::string path_to_ptxas(int &version)
{
std::vector<std::string> rets; std::vector<std::string> rets;
std::string ret; std::string ret;
// search paths for ptxas // search paths for ptxas
std::vector<std::string> ptxas_prefixes = {"", "/usr/local/cuda/bin/"}; std::vector<std::string> ptxas_prefixes = {"", "/usr/local/cuda/bin/"};
std::string triton_ptxas = tools::getenv("TRITON_PTXAS_PATH"); std::string triton_ptxas = tools::getenv("TRITON_PTXAS_PATH");
if(!triton_ptxas.empty()) if (!triton_ptxas.empty())
ptxas_prefixes.insert(ptxas_prefixes.begin(), triton_ptxas); ptxas_prefixes.insert(ptxas_prefixes.begin(), triton_ptxas);
// see what path for ptxas are valid // see what path for ptxas are valid
std::vector<std::string> working_ptxas; std::vector<std::string> working_ptxas;
for(std::string prefix: ptxas_prefixes){ for (std::string prefix : ptxas_prefixes)
{
std::string ptxas = prefix + "ptxas"; std::string ptxas = prefix + "ptxas";
bool works = tools::exec(ptxas + " --version 2>&1", ret) == 0; bool works = tools::exec(ptxas + " --version 2>&1", ret) == 0;
if(works) { if (works)
{
working_ptxas.push_back(ptxas); working_ptxas.push_back(ptxas);
rets.push_back(ret); rets.push_back(ret);
} }
} }
// error if no working ptxas was found // error if no working ptxas was found
if(working_ptxas.empty()) if (working_ptxas.empty())
throw std::runtime_error("`ptxas` was searched in TRITON_PTXAS_PATH, /usr/local/cuda/bin/ or PATH" throw std::runtime_error("`ptxas` was searched in TRITON_PTXAS_PATH, /usr/local/cuda/bin/ or PATH"
" but a working version could not be found."); " but a working version could not be found.");
std::string ptxas = working_ptxas.front(); std::string ptxas = working_ptxas.front();
@@ -121,41 +128,46 @@ std::string path_to_ptxas(int& version) {
std::smatch match; std::smatch match;
bool found = false; bool found = false;
// currently choosing the first ptxas. Other logics can be implemented in future // currently choosing the first ptxas. Other logics can be implemented in future
for(std::string ret : rets) { for (std::string ret : rets)
if(std::regex_search(ret, match, version_regex)){ {
if (std::regex_search(ret, match, version_regex))
{
int major = std::stoi(match[1]); int major = std::stoi(match[1]);
int minor = std::stoi(match[2]); int minor = std::stoi(match[2]);
version = major*1000 + minor*10; version = major * 1000 + minor * 10;
found = true; found = true;
break; break;
} }
} }
if ( not found) { if (not found)
{
throw std::runtime_error("Error in parsing version"); throw std::runtime_error("Error in parsing version");
} }
return ptxas; return ptxas;
} }
// Maps an encoded CUDA toolkit version to the PTX ISA version to emit.
//
// `version` uses the encoding produced by path_to_ptxas:
// major * 1000 + minor * 10 (so CUDA 11.4 is 11040).
//
// Returns 74 for any toolkit at or above CUDA 11.4. Older toolkits are no
// longer supported; per the commit rationale, older ptxas releases are much
// slower, so they are rejected outright instead of mapped to older PTX levels.
//
// Throws std::runtime_error when `version` is below 11040.
int vptx(int version)
{
  if (version < 11040)
    throw std::runtime_error("Triton requires CUDA 11.4+");
  return 74;
}
std::string llir_to_ptx(llvm::Module* module, int cc, int version){
// LLVM version in use may not officially support target hardware // LLVM version in use may not officially support target hardware
int max_nvvm_cc = 75; int max_nvvm_cc = 75;
int max_nvvm_ptx = 74; int max_nvvm_ptx = 74;
// options // options
auto options = llvm::cl::getRegisteredOptions(); auto options = llvm::cl::getRegisteredOptions();
auto* short_ptr = static_cast<llvm::cl::opt<bool>*>(options["nvptx-short-ptr"]); auto *short_ptr = static_cast<llvm::cl::opt<bool> *>(options["nvptx-short-ptr"]);
assert(short_ptr); assert(short_ptr);
short_ptr->setValue(true); short_ptr->setValue(true);
// compute capability // compute capability
@@ -174,7 +186,7 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){
init_llvm(); init_llvm();
// verify and store llvm // verify and store llvm
llvm::legacy::PassManager pm; llvm::legacy::PassManager pm;
// pm.add(llvm::createPrintModulePass(llvm::outs())); // pm.add(llvm::createPrintModulePass(llvm::outs()));
pm.add(llvm::createVerifierPass()); pm.add(llvm::createVerifierPass());
pm.run(*module); pm.run(*module);
// module->print(llvm::outs(), nullptr); // module->print(llvm::outs(), nullptr);
@@ -182,7 +194,7 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){
// create machine // create machine
module->setTargetTriple(triple); module->setTargetTriple(triple);
std::string error; std::string error;
llvm::TargetMachine* machine; llvm::TargetMachine *machine;
auto target = llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error); auto target = llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error);
llvm::TargetOptions opt; llvm::TargetOptions opt;
opt.AllowFPOpFusion = llvm::FPOpFusion::Fast; opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
@@ -192,7 +204,7 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){
machine = target->createTargetMachine(module->getTargetTriple(), proc, features, opt, machine = target->createTargetMachine(module->getTargetTriple(), proc, features, opt,
llvm::Reloc::PIC_, llvm::None, llvm::CodeGenOpt::Aggressive); llvm::Reloc::PIC_, llvm::None, llvm::CodeGenOpt::Aggressive);
// set data layout // set data layout
if(layout.empty()) if (layout.empty())
module->setDataLayout(machine->createDataLayout()); module->setDataLayout(machine->createDataLayout());
else else
module->setDataLayout(layout); module->setDataLayout(layout);
@@ -209,13 +221,15 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){
std::string result(buffer.begin(), buffer.end()); std::string result(buffer.begin(), buffer.end());
find_and_replace(result, ".version", "\n", ".version " + std::to_string(ptx_major) + "." + std::to_string(ptx_minor) + "\n"); find_and_replace(result, ".version", "\n", ".version " + std::to_string(ptx_major) + "." + std::to_string(ptx_minor) + "\n");
find_and_replace(result, ".target", "\n", ".target " + sm + "\n"); find_and_replace(result, ".target", "\n", ".target " + sm + "\n");
while(find_and_replace(result, "\t// begin inline asm", "\n", "")); while (find_and_replace(result, "\t// begin inline asm", "\n", ""))
while(find_and_replace(result, "\t// end inline asm", "\n", "")); ;
while (find_and_replace(result, "\t// end inline asm", "\n", ""))
;
return result; return result;
} }
std::string ptx_to_cubin(const std::string &ptx, const std::string &ptxas, int cc)
std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas, int cc) { {
// compile ptx with ptxas // compile ptx with ptxas
char _fsrc[L_tmpnam]; char _fsrc[L_tmpnam];
char _flog[L_tmpnam]; char _flog[L_tmpnam];
@@ -224,7 +238,7 @@ std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas, int c
std::string fsrc = _fsrc; std::string fsrc = _fsrc;
std::string flog = _flog; std::string flog = _flog;
std::string fbin = fsrc + ".o"; std::string fbin = fsrc + ".o";
const char* _fbin = fbin.c_str(); const char *_fbin = fbin.c_str();
std::ofstream ofs(fsrc); std::ofstream ofs(fsrc);
ofs << ptx << std::endl; ofs << ptx << std::endl;
ofs.close(); ofs.close();
@@ -232,31 +246,33 @@ std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas, int c
int err; int err;
cmd = ptxas + " -v --gpu-name=sm_" + std::to_string(cc) + " " + fsrc + " -o " + fsrc + ".o 2> " + flog; cmd = ptxas + " -v --gpu-name=sm_" + std::to_string(cc) + " " + fsrc + " -o " + fsrc + ".o 2> " + flog;
err = system(cmd.c_str()); err = system(cmd.c_str());
if(err != 0){ if (err != 0)
{
std::ifstream _log(_flog); std::ifstream _log(_flog);
std::string log(std::istreambuf_iterator<char>(_log), {}); std::string log(std::istreambuf_iterator<char>(_log), {});
unlink(_fsrc); unlink(_fsrc);
unlink(_flog); unlink(_flog);
throw std::runtime_error("Internal Triton PTX codegen error: \n" + log); throw std::runtime_error("Internal Triton PTX codegen error: \n" + log);
} }
std::ifstream _cubin(_fbin, std::ios::binary ); std::ifstream _cubin(_fbin, std::ios::binary);
std::string cubin(std::istreambuf_iterator<char>(_cubin), {}); std::string cubin(std::istreambuf_iterator<char>(_cubin), {});
_cubin.close(); _cubin.close();
unlink(_fsrc); unlink(_fsrc);
unlink(_flog); unlink(_flog);
unlink(_fbin); unlink(_fbin);
return cubin; return cubin;
} }
/* ------------------------ */ /* ------------------------ */
// HIP // // HIP //
/* ------------------------ */ /* ------------------------ */
std::string llir_to_amdgpu(llvm::Module* module, const std::string& _proc) { std::string llir_to_amdgpu(llvm::Module *module, const std::string &_proc)
{
init_llvm(); init_llvm();
// proc = std::get<0>(GetFeatureStrFromGCNArchName(rocminfo)); // proc = std::get<0>(GetFeatureStrFromGCNArchName(rocminfo));
// features = std::get<1>(GetFeatureStrFromGCNArchName(rocminfo)); // features = std::get<1>(GetFeatureStrFromGCNArchName(rocminfo));
// create // create
llvm::SmallVector<char, 0> buffer; llvm::SmallVector<char, 0> buffer;
@@ -281,7 +297,7 @@ std::string llir_to_amdgpu(llvm::Module* module, const std::string& _proc) {
llvm::Reloc::PIC_, llvm::None, llvm::Reloc::PIC_, llvm::None,
llvm::CodeGenOpt::Aggressive); llvm::CodeGenOpt::Aggressive);
// set data layout // set data layout
if(layout.empty()) if (layout.empty())
module->setDataLayout(machine->createDataLayout()); module->setDataLayout(machine->createDataLayout());
else else
module->setDataLayout(layout); module->setDataLayout(layout);
@@ -329,17 +345,17 @@ std::string llir_to_amdgpu(llvm::Module* module, const std::string& _proc) {
} }
return hsaco_path; return hsaco_path;
} }
hipModule_t amdgpu_to_hipmodule(const std::string &path)
hipModule_t amdgpu_to_hipmodule(const std::string& path) { {
// Read HSACO. // Read HSACO.
std::ifstream hsaco_file(path, std::ios::binary | std::ios::ate); std::ifstream hsaco_file(path, std::ios::binary | std::ios::ate);
std::ifstream::pos_type hsaco_file_size = hsaco_file.tellg(); std::ifstream::pos_type hsaco_file_size = hsaco_file.tellg();
std::vector<unsigned char> hsaco(hsaco_file_size); std::vector<unsigned char> hsaco(hsaco_file_size);
hsaco_file.seekg(0, std::ios::beg); hsaco_file.seekg(0, std::ios::beg);
hsaco_file.read(reinterpret_cast<char*>(&hsaco[0]), hsaco_file_size); hsaco_file.read(reinterpret_cast<char *>(&hsaco[0]), hsaco_file_size);
hsaco_file.close(); hsaco_file.close();
hipJitOption opt[] = {hipJitOptionErrorLogBufferSizeBytes, hipJitOptionErrorLogBuffer, hipJitOption opt[] = {hipJitOptionErrorLogBufferSizeBytes, hipJitOptionErrorLogBuffer,
hipJitOptionInfoLogBufferSizeBytes, hipJitOptionInfoLogBuffer, hipJitOptionInfoLogBufferSizeBytes, hipJitOptionInfoLogBuffer,
@@ -348,13 +364,13 @@ hipModule_t amdgpu_to_hipmodule(const std::string& path) {
const unsigned int logbufsize = 8192; const unsigned int logbufsize = 8192;
char _err[errbufsize]; char _err[errbufsize];
char _log[logbufsize]; char _log[logbufsize];
void* optval[] = {(void*)(uintptr_t)errbufsize, void *optval[] = {(void *)(uintptr_t)errbufsize,
(void*)_err, (void*)(uintptr_t)logbufsize, (void *)_err, (void *)(uintptr_t)logbufsize,
(void*)_log, (void*)1}; (void *)_log, (void *)1};
hipModule_t ret; hipModule_t ret;
dispatch::hipModuleLoadDataEx(&ret, hsaco.data(), 5, opt, optval); dispatch::hipModuleLoadDataEx(&ret, hsaco.data(), 5, opt, optval);
return ret; return ret;
} }
} // namespace driver } // namespace driver
} // namespace triton } // namespace triton