[DRIVER] Bumped CUDA requirement to 11.4+. This avoids bad performance surprises, as older ptxas
versions are much slower. (#769)
This also makes codegen simpler by avoiding special handling of eviction policies.
This commit is contained in:
@@ -59,17 +59,21 @@
|
|||||||
#include "llvm/Analysis/TargetLibraryInfo.h"
|
#include "llvm/Analysis/TargetLibraryInfo.h"
|
||||||
// end AMD stuff
|
// end AMD stuff
|
||||||
|
|
||||||
extern "C"{
|
extern "C"
|
||||||
|
{
|
||||||
int set_curterm(char *nterm) { return 0; }
|
int set_curterm(char *nterm) { return 0; }
|
||||||
int del_curterm(char *nterm) { return 0; }
|
int del_curterm(char *nterm) { return 0; }
|
||||||
int tigetnum(char *capname) { return 0; }
|
int tigetnum(char *capname) { return 0; }
|
||||||
int setupterm(char *term, int fildes, int *errret) { return 0; }
|
int setupterm(char *term, int fildes, int *errret) { return 0; }
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace triton{
|
namespace triton
|
||||||
namespace driver{
|
{
|
||||||
|
namespace driver
|
||||||
|
{
|
||||||
|
|
||||||
void init_llvm() {
|
void init_llvm()
|
||||||
|
{
|
||||||
LLVMInitializeNVPTXTargetInfo();
|
LLVMInitializeNVPTXTargetInfo();
|
||||||
LLVMInitializeNVPTXTarget();
|
LLVMInitializeNVPTXTarget();
|
||||||
LLVMInitializeNVPTXTargetMC();
|
LLVMInitializeNVPTXTargetMC();
|
||||||
@@ -80,11 +84,11 @@ void init_llvm() {
|
|||||||
LLVMInitializeAMDGPUAsmPrinter();
|
LLVMInitializeAMDGPUAsmPrinter();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/* ------------------------ */
|
/* ------------------------ */
|
||||||
// CUDA //
|
// CUDA //
|
||||||
/* ------------------------ */
|
/* ------------------------ */
|
||||||
static bool find_and_replace(std::string& str, const std::string& begin, const std::string& end, const std::string& target){
|
static bool find_and_replace(std::string &str, const std::string &begin, const std::string &end, const std::string &target)
|
||||||
|
{
|
||||||
size_t start_replace = str.find(begin);
|
size_t start_replace = str.find(begin);
|
||||||
size_t end_replace = str.find(end, start_replace);
|
size_t end_replace = str.find(end, start_replace);
|
||||||
if (start_replace == std::string::npos)
|
if (start_replace == std::string::npos)
|
||||||
@@ -93,7 +97,8 @@ static bool find_and_replace(std::string& str, const std::string& begin, const s
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string path_to_ptxas(int& version) {
|
std::string path_to_ptxas(int &version)
|
||||||
|
{
|
||||||
std::vector<std::string> rets;
|
std::vector<std::string> rets;
|
||||||
std::string ret;
|
std::string ret;
|
||||||
// search paths for ptxas
|
// search paths for ptxas
|
||||||
@@ -103,10 +108,12 @@ std::string path_to_ptxas(int& version) {
|
|||||||
ptxas_prefixes.insert(ptxas_prefixes.begin(), triton_ptxas);
|
ptxas_prefixes.insert(ptxas_prefixes.begin(), triton_ptxas);
|
||||||
// see what path for ptxas are valid
|
// see what path for ptxas are valid
|
||||||
std::vector<std::string> working_ptxas;
|
std::vector<std::string> working_ptxas;
|
||||||
for(std::string prefix: ptxas_prefixes){
|
for (std::string prefix : ptxas_prefixes)
|
||||||
|
{
|
||||||
std::string ptxas = prefix + "ptxas";
|
std::string ptxas = prefix + "ptxas";
|
||||||
bool works = tools::exec(ptxas + " --version 2>&1", ret) == 0;
|
bool works = tools::exec(ptxas + " --version 2>&1", ret) == 0;
|
||||||
if(works) {
|
if (works)
|
||||||
|
{
|
||||||
working_ptxas.push_back(ptxas);
|
working_ptxas.push_back(ptxas);
|
||||||
rets.push_back(ret);
|
rets.push_back(ret);
|
||||||
}
|
}
|
||||||
@@ -121,8 +128,10 @@ std::string path_to_ptxas(int& version) {
|
|||||||
std::smatch match;
|
std::smatch match;
|
||||||
bool found = false;
|
bool found = false;
|
||||||
// currently choosing the first ptxas. Other logics can be implemented in future
|
// currently choosing the first ptxas. Other logics can be implemented in future
|
||||||
for(std::string ret : rets) {
|
for (std::string ret : rets)
|
||||||
if(std::regex_search(ret, match, version_regex)){
|
{
|
||||||
|
if (std::regex_search(ret, match, version_regex))
|
||||||
|
{
|
||||||
int major = std::stoi(match[1]);
|
int major = std::stoi(match[1]);
|
||||||
int minor = std::stoi(match[2]);
|
int minor = std::stoi(match[2]);
|
||||||
version = major * 1000 + minor * 10;
|
version = major * 1000 + minor * 10;
|
||||||
@@ -130,26 +139,29 @@ std::string path_to_ptxas(int& version) {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if ( not found) {
|
if (not found)
|
||||||
|
{
|
||||||
throw std::runtime_error("Error in parsing version");
|
throw std::runtime_error("Error in parsing version");
|
||||||
}
|
}
|
||||||
return ptxas;
|
return ptxas;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int vptx(int version)
|
||||||
int vptx(int version){
|
{
|
||||||
if(version >= 11040) return 74;
|
if (version >= 11040)
|
||||||
if(version >= 11030) return 73;
|
return 74;
|
||||||
if(version >= 11020) return 72;
|
// if(version >= 11030) return 73;
|
||||||
if(version >= 11010) return 71;
|
// if(version >= 11020) return 72;
|
||||||
if(version >= 11000) return 70;
|
// if(version >= 11010) return 71;
|
||||||
if(version >= 10020) return 65;
|
// if(version >= 11000) return 70;
|
||||||
if(version >= 10010) return 64;
|
// if(version >= 10020) return 65;
|
||||||
if(version >= 10000) return 63;
|
// if(version >= 10010) return 64;
|
||||||
throw std::runtime_error("Triton requires CUDA 10+");
|
// if(version >= 10000) return 63;
|
||||||
|
throw std::runtime_error("Triton requires CUDA 11.4+");
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string llir_to_ptx(llvm::Module* module, int cc, int version){
|
std::string llir_to_ptx(llvm::Module *module, int cc, int version)
|
||||||
|
{
|
||||||
// LLVM version in use may not officially support target hardware
|
// LLVM version in use may not officially support target hardware
|
||||||
int max_nvvm_cc = 75;
|
int max_nvvm_cc = 75;
|
||||||
int max_nvvm_ptx = 74;
|
int max_nvvm_ptx = 74;
|
||||||
@@ -209,13 +221,15 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){
|
|||||||
std::string result(buffer.begin(), buffer.end());
|
std::string result(buffer.begin(), buffer.end());
|
||||||
find_and_replace(result, ".version", "\n", ".version " + std::to_string(ptx_major) + "." + std::to_string(ptx_minor) + "\n");
|
find_and_replace(result, ".version", "\n", ".version " + std::to_string(ptx_major) + "." + std::to_string(ptx_minor) + "\n");
|
||||||
find_and_replace(result, ".target", "\n", ".target " + sm + "\n");
|
find_and_replace(result, ".target", "\n", ".target " + sm + "\n");
|
||||||
while(find_and_replace(result, "\t// begin inline asm", "\n", ""));
|
while (find_and_replace(result, "\t// begin inline asm", "\n", ""))
|
||||||
while(find_and_replace(result, "\t// end inline asm", "\n", ""));
|
;
|
||||||
|
while (find_and_replace(result, "\t// end inline asm", "\n", ""))
|
||||||
|
;
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string ptx_to_cubin(const std::string &ptx, const std::string &ptxas, int cc)
|
||||||
std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas, int cc) {
|
{
|
||||||
// compile ptx with ptxas
|
// compile ptx with ptxas
|
||||||
char _fsrc[L_tmpnam];
|
char _fsrc[L_tmpnam];
|
||||||
char _flog[L_tmpnam];
|
char _flog[L_tmpnam];
|
||||||
@@ -232,7 +246,8 @@ std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas, int c
|
|||||||
int err;
|
int err;
|
||||||
cmd = ptxas + " -v --gpu-name=sm_" + std::to_string(cc) + " " + fsrc + " -o " + fsrc + ".o 2> " + flog;
|
cmd = ptxas + " -v --gpu-name=sm_" + std::to_string(cc) + " " + fsrc + " -o " + fsrc + ".o 2> " + flog;
|
||||||
err = system(cmd.c_str());
|
err = system(cmd.c_str());
|
||||||
if(err != 0){
|
if (err != 0)
|
||||||
|
{
|
||||||
std::ifstream _log(_flog);
|
std::ifstream _log(_flog);
|
||||||
std::string log(std::istreambuf_iterator<char>(_log), {});
|
std::string log(std::istreambuf_iterator<char>(_log), {});
|
||||||
unlink(_fsrc);
|
unlink(_fsrc);
|
||||||
@@ -252,7 +267,8 @@ std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas, int c
|
|||||||
// HIP //
|
// HIP //
|
||||||
/* ------------------------ */
|
/* ------------------------ */
|
||||||
|
|
||||||
std::string llir_to_amdgpu(llvm::Module* module, const std::string& _proc) {
|
std::string llir_to_amdgpu(llvm::Module *module, const std::string &_proc)
|
||||||
|
{
|
||||||
init_llvm();
|
init_llvm();
|
||||||
|
|
||||||
// proc = std::get<0>(GetFeatureStrFromGCNArchName(rocminfo));
|
// proc = std::get<0>(GetFeatureStrFromGCNArchName(rocminfo));
|
||||||
@@ -331,8 +347,8 @@ std::string llir_to_amdgpu(llvm::Module* module, const std::string& _proc) {
|
|||||||
return hsaco_path;
|
return hsaco_path;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
hipModule_t amdgpu_to_hipmodule(const std::string &path)
|
||||||
hipModule_t amdgpu_to_hipmodule(const std::string& path) {
|
{
|
||||||
// Read HSACO.
|
// Read HSACO.
|
||||||
std::ifstream hsaco_file(path, std::ios::binary | std::ios::ate);
|
std::ifstream hsaco_file(path, std::ios::binary | std::ios::ate);
|
||||||
std::ifstream::pos_type hsaco_file_size = hsaco_file.tellg();
|
std::ifstream::pos_type hsaco_file_size = hsaco_file.tellg();
|
||||||
|
Reference in New Issue
Block a user