diff --git a/include/triton/driver/module.h b/include/triton/driver/module.h index ddfdef380..d746ba6d3 100755 --- a/include/triton/driver/module.h +++ b/include/triton/driver/module.h @@ -60,7 +60,7 @@ public: // CUDA class cu_module: public module { std::string compile_llvm_module(llvm::Module* module, driver::device* device); - void init_from_ptx(const std::string& ptx); + void init_from_ptx(const std::string& ptx, cu_device *device); public: cu_module(driver::device* device, std::unique_ptr module); diff --git a/lib/driver/module.cc b/lib/driver/module.cc index b1a054e85..c5d04fa4f 100755 --- a/lib/driver/module.cc +++ b/lib/driver/module.cc @@ -281,53 +281,46 @@ std::string cu_module::compile_llvm_module(llvm::Module* module, driver::device* return result; } -void cu_module::init_from_ptx(const std::string& ptx) { +void cu_module::init_from_ptx(const std::string& ptx, driver::cu_device* device) { // JIT compile source-code // std::cout << ptx << std::endl; try{ -// // compile ptx with ptxas -// char _fsrc[] = "/tmp/triton_k_XXXXXX"; -// char _flog[] = "/tmp/triton_l_XXXXXX"; -// int fdsrc = mkstemp(_fsrc); -// int fdlog = mkstemp(_flog); -// std::string fsrc = _fsrc; -// std::string flog = _flog; -// std::ofstream ofs(fsrc); -// ofs << ptx; -// ofs.close(); -// std::string cmd; -// int err; -// driver::cu_device* cu_device = (driver::cu_device*)device; -// cmd = "ptxas -v --gpu-name=sm_" + std::to_string(cu_device->compute_capability()) + " " + fsrc + " -o " + fsrc + ".o 2> " + flog; -// err = system(cmd.c_str()); -// dispatch::cuModuleLoad(&*cu_, (fsrc + ".o").c_str()); -// std::ifstream file(flog); -// std::string log; -// if(file) -// while (!file.eof()) log.push_back(file.get()); -// unlink(_fsrc); -// unlink(_flog); + std::string ptxas = tools::getenv("TRITON_PTXAS"); -// std::smatch match; -// std::regex expr ("\\b([0-9]+) bytes spill"); -// spilled_ = 0; -// while (std::regex_search (log,match,expr)){ -// spilled_ += std::stoi(match[1]); -// log = match.suffix(); -// } -// std::cout << log << std::endl; -// std::cout << ptx_ << std::endl; + // Use PTXAS via system call + if(!ptxas.empty()){ + // compile ptx with ptxas + char _fsrc[] = "/tmp/triton_k_XXXXXX"; + char _flog[] = "/tmp/triton_l_XXXXXX"; + mkstemp(_fsrc); + mkstemp(_flog); + std::string fsrc = _fsrc; + std::string flog = _flog; + std::ofstream ofs(fsrc); + ofs << ptx; + ofs.close(); + std::string cmd; + int err; + std::string cc = std::to_string(device->compute_capability()); + cmd = "ptxas -v --gpu-name=sm_" + cc + " " + fsrc + " -o " + fsrc + ".o 2> " + flog; + err = system(cmd.c_str()); + dispatch::cuModuleLoad(&*cu_, (fsrc + ".o").c_str()); + unlink(_fsrc); + unlink(_flog); + return; + } - CUlinkState link_state; - dispatch::cuLinkCreate_v2(0, 0, 0, &link_state); - dispatch::cuLinkAddData_v2(link_state, CU_JIT_INPUT_PTX, (void*)ptx_.data(), ptx_.size(), 0, 0, 0, 0); - size_t cubin_size; - void *cubin; - dispatch::cuLinkComplete(link_state, &cubin, &cubin_size); - dispatch::cuModuleLoadData(&*cu_, cubin); - cubin_ = std::string((const char*)cubin, cubin_size); - dispatch::cuLinkDestroy(link_state); + // Use PTXAS included in driver + CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER, + CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES, CU_JIT_INFO_LOG_BUFFER, + CU_JIT_LOG_VERBOSE}; + unsigned int errbufsize = 8192; + unsigned int logbufsize = 8192; + char _err[errbufsize]; + char _log[logbufsize]; + void* optval[] = {(void*)(uintptr_t)errbufsize, (void*)_err, (void*)(uintptr_t)logbufsize, (void*)_log, (void*)1}; + dispatch::cuModuleLoadDataEx(&*cu_, ptx_.data(), 5, opt, optval); } catch(exception::cuda::invalid_ptx const &){ //#ifdef TRITON_LOG_PTX_ERROR @@ -343,37 +336,12 @@ cu_module::cu_module(driver::device* device, std::unique_ptr ll_mo llvm::raw_string_ostream oss(llir_); oss << *ll_module; oss.flush(); - std::string cache_path = tools::getenv("TRITON_DEBUG_CACHE_PATH"); - if(cache_path.empty()){ - ptx_ = compile_llvm_module(ll_module.get(), device); - } - else{ - tools::mkdir(cache_path); - // update cache path to PTX file - unsigned char hash[20]; - sha1::calc((void*)llir_.data(), llir_.size(), hash); - char _hex[40]; - sha1::toHexString(hash, _hex); - std::string hex(_hex, _hex + 40); - cache_path += "/" + hex; - // read - std::ifstream ifs(cache_path); - std::ostringstream _ptx; - if(ifs) - _ptx << ifs.rdbuf(); - ptx_ = _ptx.str(); - // compile and write-back if read empty - if(ptx_.empty()){ - ptx_ = compile_llvm_module(ll_module.get(), device); - std::ofstream ofs(cache_path); - ofs << ptx_; - } - } - init_from_ptx(ptx_); + ptx_ = compile_llvm_module(ll_module.get(), device); + init_from_ptx(ptx_, (driver::cu_device*)device); } -cu_module::cu_module(driver::device*, std::string const & source) : module(CUmodule(), true), ptx_(source){ - init_from_ptx(ptx_); +cu_module::cu_module(driver::device* device, std::string const & source) : module(CUmodule(), true), ptx_(source){ + init_from_ptx(ptx_, (driver::cu_device*)device); } std::unique_ptr cu_module::symbol(const char *name) const{ diff --git a/python/tutorials/02-fused-softmax.py b/python/tutorials/02-fused-softmax.py index dd93f8d83..91d2f8bdc 100644 --- a/python/tutorials/02-fused-softmax.py +++ b/python/tutorials/02-fused-softmax.py @@ -153,7 +153,7 @@ def benchmark(M, N, provider): return gbps(ms), gbps(max_ms), gbps(min_ms) -benchmark.run(show_plots=True) +benchmark.run(show_plots=True, print_data=True) # %% # In the above plot, we can see that: