[DRIVER] Added the option to use the system ptxas via the TRITON_PTXAS environment variable (#123)

This commit is contained in:
Philippe Tillet
2021-06-11 13:48:11 -04:00
committed by Philippe Tillet
parent 80c86ecf4a
commit b7b05a560e
3 changed files with 40 additions and 72 deletions

View File

@@ -60,7 +60,7 @@ public:
// CUDA
class cu_module: public module {
std::string compile_llvm_module(llvm::Module* module, driver::device* device);
void init_from_ptx(const std::string& ptx);
void init_from_ptx(const std::string& ptx, cu_device *device);
public:
cu_module(driver::device* device, std::unique_ptr<llvm::Module> module);

View File

@@ -281,53 +281,46 @@ std::string cu_module::compile_llvm_module(llvm::Module* module, driver::device*
return result;
}
void cu_module::init_from_ptx(const std::string& ptx) {
void cu_module::init_from_ptx(const std::string& ptx, driver::cu_device* device) {
// JIT compile source-code
// std::cout << ptx << std::endl;
try{
// // compile ptx with ptxas
// char _fsrc[] = "/tmp/triton_k_XXXXXX";
// char _flog[] = "/tmp/triton_l_XXXXXX";
// int fdsrc = mkstemp(_fsrc);
// int fdlog = mkstemp(_flog);
// std::string fsrc = _fsrc;
// std::string flog = _flog;
// std::ofstream ofs(fsrc);
// ofs << ptx;
// ofs.close();
// std::string cmd;
// int err;
// driver::cu_device* cu_device = (driver::cu_device*)device;
// cmd = "ptxas -v --gpu-name=sm_" + std::to_string(cu_device->compute_capability()) + " " + fsrc + " -o " + fsrc + ".o 2> " + flog;
// err = system(cmd.c_str());
// dispatch::cuModuleLoad(&*cu_, (fsrc + ".o").c_str());
// std::ifstream file(flog);
// std::string log;
// if(file)
// while (!file.eof()) log.push_back(file.get());
// unlink(_fsrc);
// unlink(_flog);
std::string ptxas = tools::getenv("TRITON_PTXAS");
// std::smatch match;
// std::regex expr ("\\b([0-9]+) bytes spill");
// spilled_ = 0;
// while (std::regex_search (log,match,expr)){
// spilled_ += std::stoi(match[1]);
// log = match.suffix();
// }
// std::cout << log << std::endl;
// std::cout << ptx_ << std::endl;
// Use the `ptxas` found on PATH via a system() call (enabled when TRITON_PTXAS is non-empty)
if(!ptxas.empty()){
// compile ptx with ptxas
char _fsrc[] = "/tmp/triton_k_XXXXXX";
char _flog[] = "/tmp/triton_l_XXXXXX";
mkstemp(_fsrc);
mkstemp(_flog);
std::string fsrc = _fsrc;
std::string flog = _flog;
std::ofstream ofs(fsrc);
ofs << ptx;
ofs.close();
std::string cmd;
int err;
std::string cc = std::to_string(device->compute_capability());
cmd = "ptxas -v --gpu-name=sm_" + cc + " " + fsrc + " -o " + fsrc + ".o 2> " + flog;
err = system(cmd.c_str());
dispatch::cuModuleLoad(&*cu_, (fsrc + ".o").c_str());
unlink(_fsrc);
unlink(_flog);
return;
}
CUlinkState link_state;
dispatch::cuLinkCreate_v2(0, 0, 0, &link_state);
dispatch::cuLinkAddData_v2(link_state, CU_JIT_INPUT_PTX, (void*)ptx_.data(), ptx_.size(), 0, 0, 0, 0);
size_t cubin_size;
void *cubin;
dispatch::cuLinkComplete(link_state, &cubin, &cubin_size);
dispatch::cuModuleLoadData(&*cu_, cubin);
cubin_ = std::string((const char*)cubin, cubin_size);
dispatch::cuLinkDestroy(link_state);
// Use PTXAS included in driver
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER,
CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES, CU_JIT_INFO_LOG_BUFFER,
CU_JIT_LOG_VERBOSE};
unsigned int errbufsize = 8192;
unsigned int logbufsize = 8192;
char _err[errbufsize];
char _log[logbufsize];
void* optval[] = {(void*)(uintptr_t)errbufsize, (void*)_err, (void*)(uintptr_t)logbufsize, (void*)_log, (void*)1};
dispatch::cuModuleLoadDataEx(&*cu_, ptx_.data(), 5, opt, optval);
}
catch(exception::cuda::invalid_ptx const &){
//#ifdef TRITON_LOG_PTX_ERROR
@@ -343,37 +336,12 @@ cu_module::cu_module(driver::device* device, std::unique_ptr<llvm::Module> ll_mo
llvm::raw_string_ostream oss(llir_);
oss << *ll_module;
oss.flush();
std::string cache_path = tools::getenv("TRITON_DEBUG_CACHE_PATH");
if(cache_path.empty()){
ptx_ = compile_llvm_module(ll_module.get(), device);
}
else{
tools::mkdir(cache_path);
// update cache path to PTX file
unsigned char hash[20];
sha1::calc((void*)llir_.data(), llir_.size(), hash);
char _hex[40];
sha1::toHexString(hash, _hex);
std::string hex(_hex, _hex + 40);
cache_path += "/" + hex;
// read
std::ifstream ifs(cache_path);
std::ostringstream _ptx;
if(ifs)
_ptx << ifs.rdbuf();
ptx_ = _ptx.str();
// compile and write-back if read empty
if(ptx_.empty()){
ptx_ = compile_llvm_module(ll_module.get(), device);
std::ofstream ofs(cache_path);
ofs << ptx_;
}
}
init_from_ptx(ptx_);
init_from_ptx(ptx_, (driver::cu_device*)device);
}
cu_module::cu_module(driver::device*, std::string const & source) : module(CUmodule(), true), ptx_(source){
init_from_ptx(ptx_);
cu_module::cu_module(driver::device* device, std::string const & source) : module(CUmodule(), true), ptx_(source){
init_from_ptx(ptx_, (driver::cu_device*)device);
}
std::unique_ptr<buffer> cu_module::symbol(const char *name) const{

View File

@@ -153,7 +153,7 @@ def benchmark(M, N, provider):
return gbps(ms), gbps(max_ms), gbps(min_ms)
benchmark.run(show_plots=True)
benchmark.run(show_plots=True, print_data=True)
# %%
# In the above plot, we can see that: