[DRIVER] Added options for developers to cache PTX file so that ti can
be manually modified
This commit is contained in:
@@ -54,7 +54,7 @@ namespace sha1
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void innerHash(unsigned int* result, unsigned int* w)
|
inline void innerHash(unsigned int* result, unsigned int* w)
|
||||||
{
|
{
|
||||||
unsigned int a = result[0];
|
unsigned int a = result[0];
|
||||||
unsigned int b = result[1];
|
unsigned int b = result[1];
|
||||||
@@ -114,7 +114,7 @@ namespace sha1
|
|||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void calc(const void* src, const int bytelength, unsigned char* hash)
|
inline void calc(const void* src, const int bytelength, unsigned char* hash)
|
||||||
{
|
{
|
||||||
// Init the result array.
|
// Init the result array.
|
||||||
unsigned int result[5] = { 0x67452301, 0xefcdab89, 0x98badcfe, 0x10325476, 0xc3d2e1f0 };
|
unsigned int result[5] = { 0x67452301, 0xefcdab89, 0x98badcfe, 0x10325476, 0xc3d2e1f0 };
|
||||||
@@ -170,7 +170,7 @@ namespace sha1
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void toHexString(const unsigned char* hash, char* hexstring)
|
inline void toHexString(const unsigned char* hash, char* hexstring)
|
||||||
{
|
{
|
||||||
const char hexDigits[] = { "0123456789abcdef" };
|
const char hexDigits[] = { "0123456789abcdef" };
|
||||||
|
|
||||||
|
@@ -26,6 +26,9 @@
|
|||||||
#include "triton/driver/module.h"
|
#include "triton/driver/module.h"
|
||||||
#include "triton/driver/context.h"
|
#include "triton/driver/context.h"
|
||||||
#include "triton/driver/error.h"
|
#include "triton/driver/error.h"
|
||||||
|
#include "triton/tools/sha1.hpp"
|
||||||
|
#include "triton/tools/sys/getenv.hpp"
|
||||||
|
#include "triton/tools/sys/mkdir.hpp"
|
||||||
#include "llvm/IR/IRBuilder.h"
|
#include "llvm/IR/IRBuilder.h"
|
||||||
#include "llvm/IR/Verifier.h"
|
#include "llvm/IR/Verifier.h"
|
||||||
#include "llvm/IR/IRPrintingPasses.h"
|
#include "llvm/IR/IRPrintingPasses.h"
|
||||||
@@ -346,7 +349,31 @@ cu_module::cu_module(driver::device* device, std::unique_ptr<llvm::Module> ll_mo
|
|||||||
llvm::raw_string_ostream oss(llir_);
|
llvm::raw_string_ostream oss(llir_);
|
||||||
oss << *ll_module;
|
oss << *ll_module;
|
||||||
oss.flush();
|
oss.flush();
|
||||||
|
std::string cache_path = tools::getenv("TRITON_DEBUG_CACHE_PATH");
|
||||||
|
if(cache_path.empty())
|
||||||
ptx_ = compile_llvm_module(std::move(ll_module), device);
|
ptx_ = compile_llvm_module(std::move(ll_module), device);
|
||||||
|
else{
|
||||||
|
tools::mkdir(cache_path);
|
||||||
|
// update cache path to PTX file
|
||||||
|
unsigned char hash[20];
|
||||||
|
sha1::calc((void*)llir_.data(), llir_.size(), hash);
|
||||||
|
char _hex[40];
|
||||||
|
sha1::toHexString(hash, _hex);
|
||||||
|
std::string hex(_hex, _hex + 40);
|
||||||
|
cache_path += "/" + hex;
|
||||||
|
// read
|
||||||
|
std::ifstream ifs(cache_path);
|
||||||
|
std::ostringstream _ptx;
|
||||||
|
if(ifs)
|
||||||
|
_ptx << ifs.rdbuf();
|
||||||
|
ptx_ = _ptx.str();
|
||||||
|
// compile and write-back if read empty
|
||||||
|
if(ptx_.empty()){
|
||||||
|
ptx_ = compile_llvm_module(std::move(ll_module), device);
|
||||||
|
std::ofstream ofs(cache_path);
|
||||||
|
ofs << ptx_;
|
||||||
|
}
|
||||||
|
}
|
||||||
init_from_ptx(ptx_);
|
init_from_ptx(ptx_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -184,7 +184,7 @@ float triton_dot(drv::context* context, drv::stream* stream,
|
|||||||
rt::add_arg(oss, *dlocks->cu());
|
rt::add_arg(oss, *dlocks->cu());
|
||||||
// function
|
// function
|
||||||
rt::function function(src::dot, opt, device);
|
rt::function function(src::dot, opt, device);
|
||||||
// std::cout << function.get_kernels()[0].second->get_asm(rt::ASM_LLIR) << std::endl;
|
// std::cout << function.get_kernels()[0].second->get_asm(rt::ASM_NV_PTX) << std::endl;
|
||||||
// grid
|
// grid
|
||||||
auto ceil = [](size_t x, size_t y) { return (x + y - 1) / y; };
|
auto ceil = [](size_t x, size_t y) { return (x + y - 1) / y; };
|
||||||
auto grid = [ceil, M, N](const rt::options_t& x) {
|
auto grid = [ceil, M, N](const rt::options_t& x) {
|
||||||
|
Reference in New Issue
Block a user