[DRIVER] Now giving the option to use system ptxas through environment variable (#123)

This commit is contained in:
Philippe Tillet
2021-06-11 13:48:11 -04:00
committed by Philippe Tillet
parent 80c86ecf4a
commit b7b05a560e
3 changed files with 40 additions and 72 deletions

View File

@@ -60,7 +60,7 @@ public:
// CUDA // CUDA
class cu_module: public module { class cu_module: public module {
std::string compile_llvm_module(llvm::Module* module, driver::device* device); std::string compile_llvm_module(llvm::Module* module, driver::device* device);
void init_from_ptx(const std::string& ptx); void init_from_ptx(const std::string& ptx, cu_device *device);
public: public:
cu_module(driver::device* device, std::unique_ptr<llvm::Module> module); cu_module(driver::device* device, std::unique_ptr<llvm::Module> module);

View File

@@ -281,53 +281,46 @@ std::string cu_module::compile_llvm_module(llvm::Module* module, driver::device*
return result; return result;
} }
void cu_module::init_from_ptx(const std::string& ptx) { void cu_module::init_from_ptx(const std::string& ptx, driver::cu_device* device) {
// JIT compile source-code // JIT compile source-code
// std::cout << ptx << std::endl; // std::cout << ptx << std::endl;
try{ try{
// // compile ptx with ptxas std::string ptxas = tools::getenv("TRITON_PTXAS");
// char _fsrc[] = "/tmp/triton_k_XXXXXX";
// char _flog[] = "/tmp/triton_l_XXXXXX";
// int fdsrc = mkstemp(_fsrc);
// int fdlog = mkstemp(_flog);
// std::string fsrc = _fsrc;
// std::string flog = _flog;
// std::ofstream ofs(fsrc);
// ofs << ptx;
// ofs.close();
// std::string cmd;
// int err;
// driver::cu_device* cu_device = (driver::cu_device*)device;
// cmd = "ptxas -v --gpu-name=sm_" + std::to_string(cu_device->compute_capability()) + " " + fsrc + " -o " + fsrc + ".o 2> " + flog;
// err = system(cmd.c_str());
// dispatch::cuModuleLoad(&*cu_, (fsrc + ".o").c_str());
// std::ifstream file(flog);
// std::string log;
// if(file)
// while (!file.eof()) log.push_back(file.get());
// unlink(_fsrc);
// unlink(_flog);
// std::smatch match; // Use PTXAS via system call
// std::regex expr ("\\b([0-9]+) bytes spill"); if(!ptxas.empty()){
// spilled_ = 0; // compile ptx with ptxas
// while (std::regex_search (log,match,expr)){ char _fsrc[] = "/tmp/triton_k_XXXXXX";
// spilled_ += std::stoi(match[1]); char _flog[] = "/tmp/triton_l_XXXXXX";
// log = match.suffix(); mkstemp(_fsrc);
// } mkstemp(_flog);
// std::cout << log << std::endl; std::string fsrc = _fsrc;
// std::cout << ptx_ << std::endl; std::string flog = _flog;
std::ofstream ofs(fsrc);
ofs << ptx;
ofs.close();
std::string cmd;
int err;
std::string cc = std::to_string(device->compute_capability());
cmd = "ptxas -v --gpu-name=sm_" + cc + " " + fsrc + " -o " + fsrc + ".o 2> " + flog;
err = system(cmd.c_str());
dispatch::cuModuleLoad(&*cu_, (fsrc + ".o").c_str());
unlink(_fsrc);
unlink(_flog);
return;
}
CUlinkState link_state; // Use PTXAS included in driver
dispatch::cuLinkCreate_v2(0, 0, 0, &link_state); CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER,
dispatch::cuLinkAddData_v2(link_state, CU_JIT_INPUT_PTX, (void*)ptx_.data(), ptx_.size(), 0, 0, 0, 0); CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES, CU_JIT_INFO_LOG_BUFFER,
size_t cubin_size; CU_JIT_LOG_VERBOSE};
void *cubin; unsigned int errbufsize = 8192;
dispatch::cuLinkComplete(link_state, &cubin, &cubin_size); unsigned int logbufsize = 8192;
dispatch::cuModuleLoadData(&*cu_, cubin); char _err[errbufsize];
cubin_ = std::string((const char*)cubin, cubin_size); char _log[logbufsize];
dispatch::cuLinkDestroy(link_state); void* optval[] = {(void*)(uintptr_t)errbufsize, (void*)_err, (void*)(uintptr_t)logbufsize, (void*)_log, (void*)1};
dispatch::cuModuleLoadDataEx(&*cu_, ptx_.data(), 5, opt, optval);
} }
catch(exception::cuda::invalid_ptx const &){ catch(exception::cuda::invalid_ptx const &){
//#ifdef TRITON_LOG_PTX_ERROR //#ifdef TRITON_LOG_PTX_ERROR
@@ -343,37 +336,12 @@ cu_module::cu_module(driver::device* device, std::unique_ptr<llvm::Module> ll_mo
llvm::raw_string_ostream oss(llir_); llvm::raw_string_ostream oss(llir_);
oss << *ll_module; oss << *ll_module;
oss.flush(); oss.flush();
std::string cache_path = tools::getenv("TRITON_DEBUG_CACHE_PATH"); ptx_ = compile_llvm_module(ll_module.get(), device);
if(cache_path.empty()){ init_from_ptx(ptx_, (driver::cu_device*)device);
ptx_ = compile_llvm_module(ll_module.get(), device);
}
else{
tools::mkdir(cache_path);
// update cache path to PTX file
unsigned char hash[20];
sha1::calc((void*)llir_.data(), llir_.size(), hash);
char _hex[40];
sha1::toHexString(hash, _hex);
std::string hex(_hex, _hex + 40);
cache_path += "/" + hex;
// read
std::ifstream ifs(cache_path);
std::ostringstream _ptx;
if(ifs)
_ptx << ifs.rdbuf();
ptx_ = _ptx.str();
// compile and write-back if read empty
if(ptx_.empty()){
ptx_ = compile_llvm_module(ll_module.get(), device);
std::ofstream ofs(cache_path);
ofs << ptx_;
}
}
init_from_ptx(ptx_);
} }
cu_module::cu_module(driver::device*, std::string const & source) : module(CUmodule(), true), ptx_(source){ cu_module::cu_module(driver::device* device, std::string const & source) : module(CUmodule(), true), ptx_(source){
init_from_ptx(ptx_); init_from_ptx(ptx_, (driver::cu_device*)device);
} }
std::unique_ptr<buffer> cu_module::symbol(const char *name) const{ std::unique_ptr<buffer> cu_module::symbol(const char *name) const{

View File

@@ -153,7 +153,7 @@ def benchmark(M, N, provider):
return gbps(ms), gbps(max_ms), gbps(min_ms) return gbps(ms), gbps(max_ms), gbps(min_ms)
benchmark.run(show_plots=True) benchmark.run(show_plots=True, print_data=True)
# %% # %%
# In the above plot, we can see that: # In the above plot, we can see that: