[DRIVER] Added options for developers to cache PTX file so that it can

be manually modified
This commit is contained in:
Philippe Tillet
2021-02-09 00:09:10 -05:00
parent 5e3c7f5a60
commit 3ca40b05cf
3 changed files with 32 additions and 5 deletions

View File

@@ -54,7 +54,7 @@ namespace sha1
} }
} }
void innerHash(unsigned int* result, unsigned int* w) inline void innerHash(unsigned int* result, unsigned int* w)
{ {
unsigned int a = result[0]; unsigned int a = result[0];
unsigned int b = result[1]; unsigned int b = result[1];
@@ -114,7 +114,7 @@ namespace sha1
} }
} // namespace } // namespace
void calc(const void* src, const int bytelength, unsigned char* hash) inline void calc(const void* src, const int bytelength, unsigned char* hash)
{ {
// Init the result array. // Init the result array.
unsigned int result[5] = { 0x67452301, 0xefcdab89, 0x98badcfe, 0x10325476, 0xc3d2e1f0 }; unsigned int result[5] = { 0x67452301, 0xefcdab89, 0x98badcfe, 0x10325476, 0xc3d2e1f0 };
@@ -170,7 +170,7 @@ namespace sha1
} }
} }
void toHexString(const unsigned char* hash, char* hexstring) inline void toHexString(const unsigned char* hash, char* hexstring)
{ {
const char hexDigits[] = { "0123456789abcdef" }; const char hexDigits[] = { "0123456789abcdef" };

View File

@@ -26,6 +26,9 @@
#include "triton/driver/module.h" #include "triton/driver/module.h"
#include "triton/driver/context.h" #include "triton/driver/context.h"
#include "triton/driver/error.h" #include "triton/driver/error.h"
#include "triton/tools/sha1.hpp"
#include "triton/tools/sys/getenv.hpp"
#include "triton/tools/sys/mkdir.hpp"
#include "llvm/IR/IRBuilder.h" #include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Verifier.h" #include "llvm/IR/Verifier.h"
#include "llvm/IR/IRPrintingPasses.h" #include "llvm/IR/IRPrintingPasses.h"
@@ -346,7 +349,31 @@ cu_module::cu_module(driver::device* device, std::unique_ptr<llvm::Module> ll_mo
llvm::raw_string_ostream oss(llir_); llvm::raw_string_ostream oss(llir_);
oss << *ll_module; oss << *ll_module;
oss.flush(); oss.flush();
ptx_ = compile_llvm_module(std::move(ll_module), device); std::string cache_path = tools::getenv("TRITON_DEBUG_CACHE_PATH");
if(cache_path.empty())
ptx_ = compile_llvm_module(std::move(ll_module), device);
else{
tools::mkdir(cache_path);
// update cache path to PTX file
unsigned char hash[20];
sha1::calc((void*)llir_.data(), llir_.size(), hash);
char _hex[40];
sha1::toHexString(hash, _hex);
std::string hex(_hex, _hex + 40);
cache_path += "/" + hex;
// read
std::ifstream ifs(cache_path);
std::ostringstream _ptx;
if(ifs)
_ptx << ifs.rdbuf();
ptx_ = _ptx.str();
// compile and write-back if read empty
if(ptx_.empty()){
ptx_ = compile_llvm_module(std::move(ll_module), device);
std::ofstream ofs(cache_path);
ofs << ptx_;
}
}
init_from_ptx(ptx_); init_from_ptx(ptx_);
} }

View File

@@ -184,7 +184,7 @@ float triton_dot(drv::context* context, drv::stream* stream,
rt::add_arg(oss, *dlocks->cu()); rt::add_arg(oss, *dlocks->cu());
// function // function
rt::function function(src::dot, opt, device); rt::function function(src::dot, opt, device);
// std::cout << function.get_kernels()[0].second->get_asm(rt::ASM_LLIR) << std::endl; // std::cout << function.get_kernels()[0].second->get_asm(rt::ASM_NV_PTX) << std::endl;
// grid // grid
auto ceil = [](size_t x, size_t y) { return (x + y - 1) / y; }; auto ceil = [](size_t x, size_t y) { return (x + y - 1) / y; };
auto grid = [ceil, M, N](const rt::options_t& x) { auto grid = [ceil, M, N](const rt::options_t& x) {