[DRIVER] Now giving the option to use system ptxas through environment variable (#123)
This commit is contained in:
committed by
Philippe Tillet
parent
80c86ecf4a
commit
b7b05a560e
@@ -60,7 +60,7 @@ public:
|
||||
// CUDA
|
||||
class cu_module: public module {
|
||||
std::string compile_llvm_module(llvm::Module* module, driver::device* device);
|
||||
void init_from_ptx(const std::string& ptx);
|
||||
void init_from_ptx(const std::string& ptx, cu_device *device);
|
||||
|
||||
public:
|
||||
cu_module(driver::device* device, std::unique_ptr<llvm::Module> module);
|
||||
|
@@ -281,53 +281,46 @@ std::string cu_module::compile_llvm_module(llvm::Module* module, driver::device*
|
||||
return result;
|
||||
}
|
||||
|
||||
void cu_module::init_from_ptx(const std::string& ptx) {
|
||||
void cu_module::init_from_ptx(const std::string& ptx, driver::cu_device* device) {
|
||||
// JIT compile source-code
|
||||
// std::cout << ptx << std::endl;
|
||||
|
||||
try{
|
||||
// // compile ptx with ptxas
|
||||
// char _fsrc[] = "/tmp/triton_k_XXXXXX";
|
||||
// char _flog[] = "/tmp/triton_l_XXXXXX";
|
||||
// int fdsrc = mkstemp(_fsrc);
|
||||
// int fdlog = mkstemp(_flog);
|
||||
// std::string fsrc = _fsrc;
|
||||
// std::string flog = _flog;
|
||||
// std::ofstream ofs(fsrc);
|
||||
// ofs << ptx;
|
||||
// ofs.close();
|
||||
// std::string cmd;
|
||||
// int err;
|
||||
// driver::cu_device* cu_device = (driver::cu_device*)device;
|
||||
// cmd = "ptxas -v --gpu-name=sm_" + std::to_string(cu_device->compute_capability()) + " " + fsrc + " -o " + fsrc + ".o 2> " + flog;
|
||||
// err = system(cmd.c_str());
|
||||
// dispatch::cuModuleLoad(&*cu_, (fsrc + ".o").c_str());
|
||||
// std::ifstream file(flog);
|
||||
// std::string log;
|
||||
// if(file)
|
||||
// while (!file.eof()) log.push_back(file.get());
|
||||
// unlink(_fsrc);
|
||||
// unlink(_flog);
|
||||
std::string ptxas = tools::getenv("TRITON_PTXAS");
|
||||
|
||||
// std::smatch match;
|
||||
// std::regex expr ("\\b([0-9]+) bytes spill");
|
||||
// spilled_ = 0;
|
||||
// while (std::regex_search (log,match,expr)){
|
||||
// spilled_ += std::stoi(match[1]);
|
||||
// log = match.suffix();
|
||||
// }
|
||||
// std::cout << log << std::endl;
|
||||
// std::cout << ptx_ << std::endl;
|
||||
// Use PTXAS via system call
|
||||
if(!ptxas.empty()){
|
||||
// compile ptx with ptxas
|
||||
char _fsrc[] = "/tmp/triton_k_XXXXXX";
|
||||
char _flog[] = "/tmp/triton_l_XXXXXX";
|
||||
mkstemp(_fsrc);
|
||||
mkstemp(_flog);
|
||||
std::string fsrc = _fsrc;
|
||||
std::string flog = _flog;
|
||||
std::ofstream ofs(fsrc);
|
||||
ofs << ptx;
|
||||
ofs.close();
|
||||
std::string cmd;
|
||||
int err;
|
||||
std::string cc = std::to_string(device->compute_capability());
|
||||
cmd = "ptxas -v --gpu-name=sm_" + cc + " " + fsrc + " -o " + fsrc + ".o 2> " + flog;
|
||||
err = system(cmd.c_str());
|
||||
dispatch::cuModuleLoad(&*cu_, (fsrc + ".o").c_str());
|
||||
unlink(_fsrc);
|
||||
unlink(_flog);
|
||||
return;
|
||||
}
|
||||
|
||||
CUlinkState link_state;
|
||||
dispatch::cuLinkCreate_v2(0, 0, 0, &link_state);
|
||||
dispatch::cuLinkAddData_v2(link_state, CU_JIT_INPUT_PTX, (void*)ptx_.data(), ptx_.size(), 0, 0, 0, 0);
|
||||
size_t cubin_size;
|
||||
void *cubin;
|
||||
dispatch::cuLinkComplete(link_state, &cubin, &cubin_size);
|
||||
dispatch::cuModuleLoadData(&*cu_, cubin);
|
||||
cubin_ = std::string((const char*)cubin, cubin_size);
|
||||
dispatch::cuLinkDestroy(link_state);
|
||||
// Use PTXAS included in driver
|
||||
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER,
|
||||
CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES, CU_JIT_INFO_LOG_BUFFER,
|
||||
CU_JIT_LOG_VERBOSE};
|
||||
unsigned int errbufsize = 8192;
|
||||
unsigned int logbufsize = 8192;
|
||||
char _err[errbufsize];
|
||||
char _log[logbufsize];
|
||||
void* optval[] = {(void*)(uintptr_t)errbufsize, (void*)_err, (void*)(uintptr_t)logbufsize, (void*)_log, (void*)1};
|
||||
dispatch::cuModuleLoadDataEx(&*cu_, ptx_.data(), 5, opt, optval);
|
||||
}
|
||||
catch(exception::cuda::invalid_ptx const &){
|
||||
//#ifdef TRITON_LOG_PTX_ERROR
|
||||
@@ -343,37 +336,12 @@ cu_module::cu_module(driver::device* device, std::unique_ptr<llvm::Module> ll_mo
|
||||
llvm::raw_string_ostream oss(llir_);
|
||||
oss << *ll_module;
|
||||
oss.flush();
|
||||
std::string cache_path = tools::getenv("TRITON_DEBUG_CACHE_PATH");
|
||||
if(cache_path.empty()){
|
||||
ptx_ = compile_llvm_module(ll_module.get(), device);
|
||||
}
|
||||
else{
|
||||
tools::mkdir(cache_path);
|
||||
// update cache path to PTX file
|
||||
unsigned char hash[20];
|
||||
sha1::calc((void*)llir_.data(), llir_.size(), hash);
|
||||
char _hex[40];
|
||||
sha1::toHexString(hash, _hex);
|
||||
std::string hex(_hex, _hex + 40);
|
||||
cache_path += "/" + hex;
|
||||
// read
|
||||
std::ifstream ifs(cache_path);
|
||||
std::ostringstream _ptx;
|
||||
if(ifs)
|
||||
_ptx << ifs.rdbuf();
|
||||
ptx_ = _ptx.str();
|
||||
// compile and write-back if read empty
|
||||
if(ptx_.empty()){
|
||||
ptx_ = compile_llvm_module(ll_module.get(), device);
|
||||
std::ofstream ofs(cache_path);
|
||||
ofs << ptx_;
|
||||
}
|
||||
}
|
||||
init_from_ptx(ptx_);
|
||||
init_from_ptx(ptx_, (driver::cu_device*)device);
|
||||
}
|
||||
|
||||
cu_module::cu_module(driver::device*, std::string const & source) : module(CUmodule(), true), ptx_(source){
|
||||
init_from_ptx(ptx_);
|
||||
cu_module::cu_module(driver::device* device, std::string const & source) : module(CUmodule(), true), ptx_(source){
|
||||
init_from_ptx(ptx_, (driver::cu_device*)device);
|
||||
}
|
||||
|
||||
std::unique_ptr<buffer> cu_module::symbol(const char *name) const{
|
||||
|
@@ -153,7 +153,7 @@ def benchmark(M, N, provider):
|
||||
return gbps(ms), gbps(max_ms), gbps(min_ms)
|
||||
|
||||
|
||||
benchmark.run(show_plots=True)
|
||||
benchmark.run(show_plots=True, print_data=True)
|
||||
|
||||
# %%
|
||||
# In the above plot, we can see that:
|
||||
|
Reference in New Issue
Block a user