[GENERAL] Some minor fixups (#393)

* [RUNTIME] Now displaying error message when generated PTX is invalid

* [CODEGEN] Now converting `if` condition to bool implicitly
This commit is contained in:
Philippe Tillet
2021-12-17 18:06:21 -08:00
committed by GitHub
parent e062812969
commit 4e93b41c52
2 changed files with 81 additions and 60 deletions

View File

@@ -158,12 +158,25 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){
} }
std::string ptx_to_cubin(const std::string& ptx, int cc) { std::string ptx_to_cubin(const std::string& ptx, int cc) {
std::string ptxas = "ptxas";
std::string version; std::string version;
int use_system_ptxas = tools::exec(ptxas + " --version 2>&1", version) == 0; // search pathes for ptxas
if(!use_system_ptxas) std::vector<std::string> ptxas_prefixes = {"", "/usr/local/cuda/bin/"};
return ""; std::string triton_ptxas = tools::getenv("TRITON_PTXAS_PATH");
if(!triton_ptxas.empty())
ptxas_prefixes.insert(ptxas_prefixes.begin(), triton_ptxas);
// see what path for ptxas are valid
std::vector<std::string> working_ptxas;
for(std::string prefix: ptxas_prefixes){
std::string ptxas = prefix + "ptxas";
bool works = tools::exec(ptxas + " --version 2>&1", version) == 0;
if(works)
working_ptxas.push_back(ptxas);
}
// error if no working ptxas was found
if(working_ptxas.empty())
throw std::runtime_error("`ptxas` was searched in TRITON_PTXAS_PATH, /usr/local/cuda/bin/ or PATH"
" but a working version could not be found.");
std::string ptxas = working_ptxas.front();
// compile ptx with ptxas // compile ptx with ptxas
char _fsrc[] = "/tmp/triton_k_XXXXXX"; char _fsrc[] = "/tmp/triton_k_XXXXXX";
char _flog[] = "/tmp/triton_l_XXXXXX"; char _flog[] = "/tmp/triton_l_XXXXXX";
@@ -180,6 +193,11 @@ std::string ptx_to_cubin(const std::string& ptx, int cc) {
int err; int err;
cmd = ptxas + " -v --gpu-name=sm_" + std::to_string(cc) + " " + fsrc + " -o " + fsrc + ".o 2> " + flog; cmd = ptxas + " -v --gpu-name=sm_" + std::to_string(cc) + " " + fsrc + " -o " + fsrc + ".o 2> " + flog;
err = system(cmd.c_str()); err = system(cmd.c_str());
if(err != 0){
std::ifstream _log(_flog);
std::string log(std::istreambuf_iterator<char>(_log), {});
throw std::runtime_error("Internal Triton PTX codegen error: \n" + log);
}
CUmodule ret; CUmodule ret;
std::ifstream _cubin(_fbin, std::ios::binary ); std::ifstream _cubin(_fbin, std::ios::binary );
std::string cubin(std::istreambuf_iterator<char>(_cubin), {}); std::string cubin(std::istreambuf_iterator<char>(_cubin), {});
@@ -191,62 +209,62 @@ std::string ptx_to_cubin(const std::string& ptx, int cc) {
return cubin; return cubin;
} }
CUmodule ptx_to_cumodule(const std::string& ptx, int cc) { //CUmodule ptx_to_cumodule(const std::string& ptx, int cc) {
// JIT compile source-code // // JIT compile source-code
try{ // try{
// use ptxas if present in PATH. Otherwise, use JIT from the driver // // use ptxas if present in PATH. Otherwise, use JIT from the driver
std::string ptxas = "ptxas"; // std::string ptxas = "ptxas";
std::string version; // std::string version;
int use_system_ptxas = tools::exec(ptxas + " --version 2>&1", version) == 0; // int use_system_ptxas = tools::exec(ptxas + " --version 2>&1", version) == 0;
// Use PTXAS via system call // // Use PTXAS via system call
if(use_system_ptxas){ // if(use_system_ptxas){
// compile ptx with ptxas // // compile ptx with ptxas
char _fsrc[] = "/tmp/triton_k_XXXXXX"; // char _fsrc[] = "/tmp/triton_k_XXXXXX";
char _flog[] = "/tmp/triton_l_XXXXXX"; // char _flog[] = "/tmp/triton_l_XXXXXX";
mkstemp(_fsrc); // mkstemp(_fsrc);
mkstemp(_flog); // mkstemp(_flog);
std::string fsrc = _fsrc; // std::string fsrc = _fsrc;
std::string flog = _flog; // std::string flog = _flog;
std::string fbin = fsrc + ".o"; // std::string fbin = fsrc + ".o";
const char* _fbin = fbin.c_str(); // const char* _fbin = fbin.c_str();
std::ofstream ofs(fsrc); // std::ofstream ofs(fsrc);
ofs << ptx; // ofs << ptx;
ofs.close(); // ofs.close();
std::string cmd; // std::string cmd;
int err; // int err;
cmd = ptxas + " -v --gpu-name=sm_" + std::to_string(cc) + " " + fsrc + " -o " + fsrc + ".o 2> " + flog; // cmd = ptxas + " -v --gpu-name=sm_" + std::to_string(cc) + " " + fsrc + " -o " + fsrc + ".o 2> " + flog;
err = system(cmd.c_str()); // err = system(cmd.c_str());
CUmodule ret; // CUmodule ret;
std::ifstream _cubin(_fbin, std::ios::binary ); // std::ifstream _cubin(_fbin, std::ios::binary );
std::string cubin(std::istreambuf_iterator<char>(_cubin), {}); // std::string cubin(std::istreambuf_iterator<char>(_cubin), {});
_cubin.close(); // _cubin.close();
dispatch::cuModuleLoadData(&ret, cubin.c_str()); // dispatch::cuModuleLoadData(&ret, cubin.c_str());
unlink(_fsrc); // unlink(_fsrc);
unlink(_flog); // unlink(_flog);
unlink(_fbin); // unlink(_fbin);
return ret; // return ret;
} // }
// Use PTXAS included in driver // // Use PTXAS included in driver
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER, // CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER,
CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES, CU_JIT_INFO_LOG_BUFFER, // CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES, CU_JIT_INFO_LOG_BUFFER,
CU_JIT_LOG_VERBOSE}; // CU_JIT_LOG_VERBOSE};
unsigned int errbufsize = 8192; // unsigned int errbufsize = 8192;
unsigned int logbufsize = 8192; // unsigned int logbufsize = 8192;
char _err[errbufsize]; // char _err[errbufsize];
char _log[logbufsize]; // char _log[logbufsize];
void* optval[] = {(void*)(uintptr_t)errbufsize, (void*)_err, (void*)(uintptr_t)logbufsize, (void*)_log, (void*)1}; // void* optval[] = {(void*)(uintptr_t)errbufsize, (void*)_err, (void*)(uintptr_t)logbufsize, (void*)_log, (void*)1};
CUmodule ret; // CUmodule ret;
dispatch::cuModuleLoadDataEx(&ret, ptx.data(), 5, opt, optval); // dispatch::cuModuleLoadDataEx(&ret, ptx.data(), 5, opt, optval);
return ret; // return ret;
} // }
catch(exception::cuda::invalid_ptx const &){ // catch(exception::cuda::invalid_ptx const &){
std::cout << ptx << std::endl; // std::cout << ptx << std::endl;
std::cerr << "It appears that Triton produced invalid PTX code:" << std::endl; // std::cerr << "It appears that Triton produced invalid PTX code:" << std::endl;
throw; // throw;
} // }
} //}
/* ------------------------ */ /* ------------------------ */
// HIP // // HIP //

View File

@@ -248,7 +248,8 @@ class CodeGenerator(ast.NodeVisitor):
def visit_If(self, node): def visit_If(self, node):
cond = self.visit(node.test) cond = self.visit(node.test)
if self.is_triton_object(cond): if isinstance(cond, triton.language.block):
cond = cond.to(triton.language.int1, _builder=self.builder)
current_bb = self.builder.get_insert_block() current_bb = self.builder.get_insert_block()
then_bb = _triton.ir.basic_block.create(self.builder.context, "then", current_bb.parent) then_bb = _triton.ir.basic_block.create(self.builder.context, "then", current_bb.parent)
else_bb = _triton.ir.basic_block.create(self.builder.context, "else", current_bb.parent) if node.orelse else None else_bb = _triton.ir.basic_block.create(self.builder.context, "else", current_bb.parent) if node.orelse else None
@@ -273,6 +274,8 @@ class CodeGenerator(ast.NodeVisitor):
self.module.seal_block(endif_bb) self.module.seal_block(endif_bb)
self.builder.set_insert_block(endif_bb) self.builder.set_insert_block(endif_bb)
else: else:
if isinstance(cond, triton.language.constexpr):
cond = cond.value
if cond: if cond:
self.visit_compound_statement(node.body) self.visit_compound_statement(node.body)
else: else: