[DRIVER] Added options for developers to cache PTX file so that ti can
be manually modified
This commit is contained in:
@@ -54,7 +54,7 @@ namespace sha1
|
||||
}
|
||||
}
|
||||
|
||||
void innerHash(unsigned int* result, unsigned int* w)
|
||||
inline void innerHash(unsigned int* result, unsigned int* w)
|
||||
{
|
||||
unsigned int a = result[0];
|
||||
unsigned int b = result[1];
|
||||
@@ -114,7 +114,7 @@ namespace sha1
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void calc(const void* src, const int bytelength, unsigned char* hash)
|
||||
inline void calc(const void* src, const int bytelength, unsigned char* hash)
|
||||
{
|
||||
// Init the result array.
|
||||
unsigned int result[5] = { 0x67452301, 0xefcdab89, 0x98badcfe, 0x10325476, 0xc3d2e1f0 };
|
||||
@@ -170,7 +170,7 @@ namespace sha1
|
||||
}
|
||||
}
|
||||
|
||||
void toHexString(const unsigned char* hash, char* hexstring)
|
||||
inline void toHexString(const unsigned char* hash, char* hexstring)
|
||||
{
|
||||
const char hexDigits[] = { "0123456789abcdef" };
|
||||
|
||||
|
@@ -26,6 +26,9 @@
|
||||
#include "triton/driver/module.h"
|
||||
#include "triton/driver/context.h"
|
||||
#include "triton/driver/error.h"
|
||||
#include "triton/tools/sha1.hpp"
|
||||
#include "triton/tools/sys/getenv.hpp"
|
||||
#include "triton/tools/sys/mkdir.hpp"
|
||||
#include "llvm/IR/IRBuilder.h"
|
||||
#include "llvm/IR/Verifier.h"
|
||||
#include "llvm/IR/IRPrintingPasses.h"
|
||||
@@ -346,7 +349,31 @@ cu_module::cu_module(driver::device* device, std::unique_ptr<llvm::Module> ll_mo
|
||||
llvm::raw_string_ostream oss(llir_);
|
||||
oss << *ll_module;
|
||||
oss.flush();
|
||||
std::string cache_path = tools::getenv("TRITON_DEBUG_CACHE_PATH");
|
||||
if(cache_path.empty())
|
||||
ptx_ = compile_llvm_module(std::move(ll_module), device);
|
||||
else{
|
||||
tools::mkdir(cache_path);
|
||||
// update cache path to PTX file
|
||||
unsigned char hash[20];
|
||||
sha1::calc((void*)llir_.data(), llir_.size(), hash);
|
||||
char _hex[40];
|
||||
sha1::toHexString(hash, _hex);
|
||||
std::string hex(_hex, _hex + 40);
|
||||
cache_path += "/" + hex;
|
||||
// read
|
||||
std::ifstream ifs(cache_path);
|
||||
std::ostringstream _ptx;
|
||||
if(ifs)
|
||||
_ptx << ifs.rdbuf();
|
||||
ptx_ = _ptx.str();
|
||||
// compile and write-back if read empty
|
||||
if(ptx_.empty()){
|
||||
ptx_ = compile_llvm_module(std::move(ll_module), device);
|
||||
std::ofstream ofs(cache_path);
|
||||
ofs << ptx_;
|
||||
}
|
||||
}
|
||||
init_from_ptx(ptx_);
|
||||
}
|
||||
|
||||
|
@@ -184,7 +184,7 @@ float triton_dot(drv::context* context, drv::stream* stream,
|
||||
rt::add_arg(oss, *dlocks->cu());
|
||||
// function
|
||||
rt::function function(src::dot, opt, device);
|
||||
// std::cout << function.get_kernels()[0].second->get_asm(rt::ASM_LLIR) << std::endl;
|
||||
// std::cout << function.get_kernels()[0].second->get_asm(rt::ASM_NV_PTX) << std::endl;
|
||||
// grid
|
||||
auto ceil = [](size_t x, size_t y) { return (x + y - 1) / y; };
|
||||
auto grid = [ceil, M, N](const rt::options_t& x) {
|
||||
|
Reference in New Issue
Block a user