[GENERAL] Some minor fixups (#393)

* [RUNTIME] Now displaying error message when generated PTX is invalid

* [CODEGEN] Now converting `if` condition to bool implicitly
This commit is contained in:
Philippe Tillet
2021-12-17 18:06:21 -08:00
committed by GitHub
parent e062812969
commit 4e93b41c52
2 changed files with 81 additions and 60 deletions

View File

@@ -158,12 +158,25 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){
}
std::string ptx_to_cubin(const std::string& ptx, int cc) {
std::string ptxas = "ptxas";
std::string version;
int use_system_ptxas = tools::exec(ptxas + " --version 2>&1", version) == 0;
if(!use_system_ptxas)
return "";
// search pathes for ptxas
std::vector<std::string> ptxas_prefixes = {"", "/usr/local/cuda/bin/"};
std::string triton_ptxas = tools::getenv("TRITON_PTXAS_PATH");
if(!triton_ptxas.empty())
ptxas_prefixes.insert(ptxas_prefixes.begin(), triton_ptxas);
// see what path for ptxas are valid
std::vector<std::string> working_ptxas;
for(std::string prefix: ptxas_prefixes){
std::string ptxas = prefix + "ptxas";
bool works = tools::exec(ptxas + " --version 2>&1", version) == 0;
if(works)
working_ptxas.push_back(ptxas);
}
// error if no working ptxas was found
if(working_ptxas.empty())
throw std::runtime_error("`ptxas` was searched in TRITON_PTXAS_PATH, /usr/local/cuda/bin/ or PATH"
" but a working version could not be found.");
std::string ptxas = working_ptxas.front();
// compile ptx with ptxas
char _fsrc[] = "/tmp/triton_k_XXXXXX";
char _flog[] = "/tmp/triton_l_XXXXXX";
@@ -180,6 +193,11 @@ std::string ptx_to_cubin(const std::string& ptx, int cc) {
int err;
cmd = ptxas + " -v --gpu-name=sm_" + std::to_string(cc) + " " + fsrc + " -o " + fsrc + ".o 2> " + flog;
err = system(cmd.c_str());
if(err != 0){
std::ifstream _log(_flog);
std::string log(std::istreambuf_iterator<char>(_log), {});
throw std::runtime_error("Internal Triton PTX codegen error: \n" + log);
}
CUmodule ret;
std::ifstream _cubin(_fbin, std::ios::binary );
std::string cubin(std::istreambuf_iterator<char>(_cubin), {});
@@ -191,62 +209,62 @@ std::string ptx_to_cubin(const std::string& ptx, int cc) {
return cubin;
}
CUmodule ptx_to_cumodule(const std::string& ptx, int cc) {
// JIT compile source-code
try{
// use ptxas if present in PATH. Otherwise, use JIT from the driver
std::string ptxas = "ptxas";
std::string version;
int use_system_ptxas = tools::exec(ptxas + " --version 2>&1", version) == 0;
//CUmodule ptx_to_cumodule(const std::string& ptx, int cc) {
// // JIT compile source-code
// try{
// // use ptxas if present in PATH. Otherwise, use JIT from the driver
// std::string ptxas = "ptxas";
// std::string version;
// int use_system_ptxas = tools::exec(ptxas + " --version 2>&1", version) == 0;
// Use PTXAS via system call
if(use_system_ptxas){
// compile ptx with ptxas
char _fsrc[] = "/tmp/triton_k_XXXXXX";
char _flog[] = "/tmp/triton_l_XXXXXX";
mkstemp(_fsrc);
mkstemp(_flog);
std::string fsrc = _fsrc;
std::string flog = _flog;
std::string fbin = fsrc + ".o";
const char* _fbin = fbin.c_str();
std::ofstream ofs(fsrc);
ofs << ptx;
ofs.close();
std::string cmd;
int err;
cmd = ptxas + " -v --gpu-name=sm_" + std::to_string(cc) + " " + fsrc + " -o " + fsrc + ".o 2> " + flog;
err = system(cmd.c_str());
CUmodule ret;
std::ifstream _cubin(_fbin, std::ios::binary );
std::string cubin(std::istreambuf_iterator<char>(_cubin), {});
_cubin.close();
dispatch::cuModuleLoadData(&ret, cubin.c_str());
unlink(_fsrc);
unlink(_flog);
unlink(_fbin);
return ret;
}
// // Use PTXAS via system call
// if(use_system_ptxas){
// // compile ptx with ptxas
// char _fsrc[] = "/tmp/triton_k_XXXXXX";
// char _flog[] = "/tmp/triton_l_XXXXXX";
// mkstemp(_fsrc);
// mkstemp(_flog);
// std::string fsrc = _fsrc;
// std::string flog = _flog;
// std::string fbin = fsrc + ".o";
// const char* _fbin = fbin.c_str();
// std::ofstream ofs(fsrc);
// ofs << ptx;
// ofs.close();
// std::string cmd;
// int err;
// cmd = ptxas + " -v --gpu-name=sm_" + std::to_string(cc) + " " + fsrc + " -o " + fsrc + ".o 2> " + flog;
// err = system(cmd.c_str());
// CUmodule ret;
// std::ifstream _cubin(_fbin, std::ios::binary );
// std::string cubin(std::istreambuf_iterator<char>(_cubin), {});
// _cubin.close();
// dispatch::cuModuleLoadData(&ret, cubin.c_str());
// unlink(_fsrc);
// unlink(_flog);
// unlink(_fbin);
// return ret;
// }
// Use PTXAS included in driver
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER,
CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES, CU_JIT_INFO_LOG_BUFFER,
CU_JIT_LOG_VERBOSE};
unsigned int errbufsize = 8192;
unsigned int logbufsize = 8192;
char _err[errbufsize];
char _log[logbufsize];
void* optval[] = {(void*)(uintptr_t)errbufsize, (void*)_err, (void*)(uintptr_t)logbufsize, (void*)_log, (void*)1};
CUmodule ret;
dispatch::cuModuleLoadDataEx(&ret, ptx.data(), 5, opt, optval);
return ret;
}
catch(exception::cuda::invalid_ptx const &){
std::cout << ptx << std::endl;
std::cerr << "It appears that Triton produced invalid PTX code:" << std::endl;
throw;
}
}
// // Use PTXAS included in driver
// CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER,
// CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES, CU_JIT_INFO_LOG_BUFFER,
// CU_JIT_LOG_VERBOSE};
// unsigned int errbufsize = 8192;
// unsigned int logbufsize = 8192;
// char _err[errbufsize];
// char _log[logbufsize];
// void* optval[] = {(void*)(uintptr_t)errbufsize, (void*)_err, (void*)(uintptr_t)logbufsize, (void*)_log, (void*)1};
// CUmodule ret;
// dispatch::cuModuleLoadDataEx(&ret, ptx.data(), 5, opt, optval);
// return ret;
// }
// catch(exception::cuda::invalid_ptx const &){
// std::cout << ptx << std::endl;
// std::cerr << "It appears that Triton produced invalid PTX code:" << std::endl;
// throw;
// }
//}
/* ------------------------ */
// HIP //