From 4e93b41c528a84660dd42f96beb278fcd84948d2 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Fri, 17 Dec 2021 18:06:21 -0800 Subject: [PATCH] [GENERAL] Some minor fixups (#393) * [RUNTIME] Now displaying error message when generated PTX is invalid * [CODEGEN] Now converting `if` condition to bool implicitly --- lib/driver/llvm.cc | 136 +++++++++++++++++++++----------------- python/triton/code_gen.py | 5 +- 2 files changed, 81 insertions(+), 60 deletions(-) diff --git a/lib/driver/llvm.cc b/lib/driver/llvm.cc index f3c76ce77..db64aa73b 100644 --- a/lib/driver/llvm.cc +++ b/lib/driver/llvm.cc @@ -158,12 +158,25 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){ } std::string ptx_to_cubin(const std::string& ptx, int cc) { - std::string ptxas = "ptxas"; std::string version; - int use_system_ptxas = tools::exec(ptxas + " --version 2>&1", version) == 0; - if(!use_system_ptxas) - return ""; - + // search pathes for ptxas + std::vector ptxas_prefixes = {"", "/usr/local/cuda/bin/"}; + std::string triton_ptxas = tools::getenv("TRITON_PTXAS_PATH"); + if(!triton_ptxas.empty()) + ptxas_prefixes.insert(ptxas_prefixes.begin(), triton_ptxas); + // see what path for ptxas are valid + std::vector working_ptxas; + for(std::string prefix: ptxas_prefixes){ + std::string ptxas = prefix + "ptxas"; + bool works = tools::exec(ptxas + " --version 2>&1", version) == 0; + if(works) + working_ptxas.push_back(ptxas); + } + // error if no working ptxas was found + if(working_ptxas.empty()) + throw std::runtime_error("`ptxas` was searched in TRITON_PTXAS_PATH, /usr/local/cuda/bin/ or PATH" + " but a working version could not be found."); + std::string ptxas = working_ptxas.front(); // compile ptx with ptxas char _fsrc[] = "/tmp/triton_k_XXXXXX"; char _flog[] = "/tmp/triton_l_XXXXXX"; @@ -180,6 +193,11 @@ std::string ptx_to_cubin(const std::string& ptx, int cc) { int err; cmd = ptxas + " -v --gpu-name=sm_" + std::to_string(cc) + " " + fsrc + " -o " + fsrc + ".o 2> " + flog; err = system(cmd.c_str()); + if(err != 0){ + std::ifstream _log(_flog); + std::string log(std::istreambuf_iterator(_log), {}); + throw std::runtime_error("Internal Triton PTX codegen error: \n" + log); + } CUmodule ret; std::ifstream _cubin(_fbin, std::ios::binary ); std::string cubin(std::istreambuf_iterator(_cubin), {}); @@ -191,62 +209,62 @@ std::string ptx_to_cubin(const std::string& ptx, int cc) { return cubin; } -CUmodule ptx_to_cumodule(const std::string& ptx, int cc) { - // JIT compile source-code - try{ - // use ptxas if present in PATH. Otherwise, use JIT from the driver - std::string ptxas = "ptxas"; - std::string version; - int use_system_ptxas = tools::exec(ptxas + " --version 2>&1", version) == 0; +//CUmodule ptx_to_cumodule(const std::string& ptx, int cc) { +// // JIT compile source-code +// try{ +// // use ptxas if present in PATH. Otherwise, use JIT from the driver +// std::string ptxas = "ptxas"; +// std::string version; +// int use_system_ptxas = tools::exec(ptxas + " --version 2>&1", version) == 0; - // Use PTXAS via system call - if(use_system_ptxas){ - // compile ptx with ptxas - char _fsrc[] = "/tmp/triton_k_XXXXXX"; - char _flog[] = "/tmp/triton_l_XXXXXX"; - mkstemp(_fsrc); - mkstemp(_flog); - std::string fsrc = _fsrc; - std::string flog = _flog; - std::string fbin = fsrc + ".o"; - const char* _fbin = fbin.c_str(); - std::ofstream ofs(fsrc); - ofs << ptx; - ofs.close(); - std::string cmd; - int err; - cmd = ptxas + " -v --gpu-name=sm_" + std::to_string(cc) + " " + fsrc + " -o " + fsrc + ".o 2> " + flog; - err = system(cmd.c_str()); - CUmodule ret; - std::ifstream _cubin(_fbin, std::ios::binary ); - std::string cubin(std::istreambuf_iterator(_cubin), {}); - _cubin.close(); - dispatch::cuModuleLoadData(&ret, cubin.c_str()); - unlink(_fsrc); - unlink(_flog); - unlink(_fbin); - return ret; - } +// // Use PTXAS via system call +// if(use_system_ptxas){ +// // compile ptx with ptxas +// char _fsrc[] = "/tmp/triton_k_XXXXXX"; +// char _flog[] = "/tmp/triton_l_XXXXXX"; +// mkstemp(_fsrc); +// mkstemp(_flog); +// std::string fsrc = _fsrc; +// std::string flog = _flog; +// std::string fbin = fsrc + ".o"; +// const char* _fbin = fbin.c_str(); +// std::ofstream ofs(fsrc); +// ofs << ptx; +// ofs.close(); +// std::string cmd; +// int err; +// cmd = ptxas + " -v --gpu-name=sm_" + std::to_string(cc) + " " + fsrc + " -o " + fsrc + ".o 2> " + flog; +// err = system(cmd.c_str()); +// CUmodule ret; +// std::ifstream _cubin(_fbin, std::ios::binary ); +// std::string cubin(std::istreambuf_iterator(_cubin), {}); +// _cubin.close(); +// dispatch::cuModuleLoadData(&ret, cubin.c_str()); +// unlink(_fsrc); +// unlink(_flog); +// unlink(_fbin); +// return ret; +// } - // Use PTXAS included in driver - CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER, - CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES, CU_JIT_INFO_LOG_BUFFER, - CU_JIT_LOG_VERBOSE}; - unsigned int errbufsize = 8192; - unsigned int logbufsize = 8192; - char _err[errbufsize]; - char _log[logbufsize]; - void* optval[] = {(void*)(uintptr_t)errbufsize, (void*)_err, (void*)(uintptr_t)logbufsize, (void*)_log, (void*)1}; - CUmodule ret; - dispatch::cuModuleLoadDataEx(&ret, ptx.data(), 5, opt, optval); - return ret; - } - catch(exception::cuda::invalid_ptx const &){ - std::cout << ptx << std::endl; - std::cerr << "It appears that Triton produced invalid PTX code:" << std::endl; - throw; - } -} +// // Use PTXAS included in driver +// CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER, +// CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES, CU_JIT_INFO_LOG_BUFFER, +// CU_JIT_LOG_VERBOSE}; +// unsigned int errbufsize = 8192; +// unsigned int logbufsize = 8192; +// char _err[errbufsize]; +// char _log[logbufsize]; +// void* optval[] = {(void*)(uintptr_t)errbufsize, (void*)_err, (void*)(uintptr_t)logbufsize, (void*)_log, (void*)1}; +// CUmodule ret; +// dispatch::cuModuleLoadDataEx(&ret, ptx.data(), 5, opt, optval); +// return ret; +// } +// catch(exception::cuda::invalid_ptx const &){ +// std::cout << ptx << std::endl; +// std::cerr << "It appears that Triton produced invalid PTX code:" << std::endl; +// throw; +// } +//} /* ------------------------ */ // HIP // diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 96948e360..d00a9d50c 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -248,7 +248,8 @@ class CodeGenerator(ast.NodeVisitor): def visit_If(self, node): cond = self.visit(node.test) - if self.is_triton_object(cond): + if isinstance(cond, triton.language.block): + cond = cond.to(triton.language.int1, _builder=self.builder) current_bb = self.builder.get_insert_block() then_bb = _triton.ir.basic_block.create(self.builder.context, "then", current_bb.parent) else_bb = _triton.ir.basic_block.create(self.builder.context, "else", current_bb.parent) if node.orelse else None @@ -273,6 +274,8 @@ class CodeGenerator(ast.NodeVisitor): self.module.seal_block(endif_bb) self.builder.set_insert_block(endif_bb) else: + if isinstance(cond, triton.language.constexpr): + cond = cond.value if cond: self.visit_compound_statement(node.body) else: