diff --git a/include/triton/tools/sha1.hpp b/include/triton/tools/sha1.hpp
index 93b9869c4..630a3fd77 100644
--- a/include/triton/tools/sha1.hpp
+++ b/include/triton/tools/sha1.hpp
@@ -54,7 +54,7 @@ namespace sha1
             }
         }
 
-        void innerHash(unsigned int* result, unsigned int* w)
+        inline void innerHash(unsigned int* result, unsigned int* w)
         {
             unsigned int a = result[0];
             unsigned int b = result[1];
@@ -114,7 +114,7 @@ namespace sha1
         }
     } // namespace
 
-    void calc(const void* src, const int bytelength, unsigned char* hash)
+    inline void calc(const void* src, const int bytelength, unsigned char* hash)
     {
         // Init the result array.
         unsigned int result[5] = { 0x67452301, 0xefcdab89, 0x98badcfe, 0x10325476, 0xc3d2e1f0 };
@@ -170,7 +170,7 @@ namespace sha1
         }
     }
 
-    void toHexString(const unsigned char* hash, char* hexstring)
+    inline void toHexString(const unsigned char* hash, char* hexstring)
     {
         const char hexDigits[] = { "0123456789abcdef" };
 
diff --git a/lib/driver/module.cc b/lib/driver/module.cc
index f0035fc84..7c7286d3e 100755
--- a/lib/driver/module.cc
+++ b/lib/driver/module.cc
@@ -26,6 +26,9 @@
 #include "triton/driver/module.h"
 #include "triton/driver/context.h"
 #include "triton/driver/error.h"
+#include "triton/tools/sha1.hpp"
+#include "triton/tools/sys/getenv.hpp"
+#include "triton/tools/sys/mkdir.hpp"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Verifier.h"
 #include "llvm/IR/IRPrintingPasses.h"
@@ -346,7 +349,31 @@ cu_module::cu_module(driver::device* device, std::unique_ptr ll_mo
   llvm::raw_string_ostream oss(llir_);
   oss << *ll_module;
   oss.flush();
-  ptx_ = compile_llvm_module(std::move(ll_module), device);
+  std::string cache_path = tools::getenv("TRITON_DEBUG_CACHE_PATH");
+  if(cache_path.empty())
+    ptx_ = compile_llvm_module(std::move(ll_module), device);
+  else{
+    tools::mkdir(cache_path);
+    // update cache path to PTX file
+    unsigned char hash[20];
+    sha1::calc((void*)llir_.data(), llir_.size(), hash);
+    char _hex[41]; // 40 hex digits + the terminating NUL that sha1::toHexString stores at index 40
+    sha1::toHexString(hash, _hex);
+    std::string hex(_hex, _hex + 40);
+    cache_path += "/" + hex;
+    // read
+    std::ifstream ifs(cache_path);
+    std::ostringstream _ptx;
+    if(ifs)
+      _ptx << ifs.rdbuf();
+    ptx_ = _ptx.str();
+    // compile and write-back if read empty
+    if(ptx_.empty()){
+      ptx_ = compile_llvm_module(std::move(ll_module), device);
+      std::ofstream ofs(cache_path);
+      ofs << ptx_;
+    }
+  }
   init_from_ptx(ptx_);
 }
 
diff --git a/tutorials/01-matmul.cc b/tutorials/01-matmul.cc
index c878a5a54..1188b222d 100644
--- a/tutorials/01-matmul.cc
+++ b/tutorials/01-matmul.cc
@@ -184,7 +184,7 @@ float triton_dot(drv::context* context, drv::stream* stream,
   rt::add_arg(oss, *dlocks->cu());
   // function
   rt::function function(src::dot, opt, device);
-// std::cout << function.get_kernels()[0].second->get_asm(rt::ASM_LLIR) << std::endl;
+// std::cout << function.get_kernels()[0].second->get_asm(rt::ASM_NV_PTX) << std::endl;
   // grid
   auto ceil = [](size_t x, size_t y) { return (x + y - 1) / y; };
   auto grid = [ceil, M, N](const rt::options_t& x) {