[GENERAL] Improved caching mechanism:

* Now computing hash in libtriton
* Now only compiling a single pytorch hook per function signature
This commit is contained in:
Philippe Tillet
2020-02-20 20:09:33 -05:00
committed by Philippe Tillet
parent 30f77e9ec5
commit dfb844bf41
14 changed files with 538 additions and 435 deletions

View File

@@ -74,6 +74,7 @@ public:
cu_module(driver::context* context, std::unique_ptr<llvm::Module> module);
cu_module(driver::context* context, const std::string& source);
std::unique_ptr<buffer> symbol(const char * name) const;
const std::string& source() const { return source_; }
private:
std::string source_;

View File

@@ -7,6 +7,9 @@
#include <stdexcept>
namespace triton{
namespace ir{
class type;
}
namespace driver{
class buffer;
@@ -26,6 +29,9 @@ enum arg_type {
BUFFER_T
};
arg_type convert(ir::type *ty);
inline size_t size_of(arg_type ty){
switch(ty){
case INT1_T: return 1;

View File

@@ -8,6 +8,7 @@
#include <string>
#include <memory>
#include <functional>
#include <set>
// codegen
#include "triton/ir/context.h"
#include "triton/codegen/target.h"
@@ -68,6 +69,11 @@ public:
T D(const std::string& name) const {
return convert<T>(defines.at(name));
}
bool operator<(const options_t& other) const {
return std::make_pair(defines, num_warps) <
std::make_pair(other.defines, other.num_warps);
}
std::string to_str() const;
std::map<std::string, std::string> defines;
size_t num_warps;
@@ -79,41 +85,63 @@ public:
private:
class caller {
public:
caller(ir::function *ir, std::shared_ptr<driver::module> program, const options_t& opt_);
void operator()(driver::stream *stream, const grid_t& grid, const std::vector<arg>& args) const;
// constructors
caller(driver::context* ctx, std::ifstream& ifs, const options_t& opt);
caller(ir::function *ir, std::shared_ptr<driver::module> program, const options_t& opt);
// serialization
void write(std::ofstream& ofs);
void read(driver::context* ctx, std::ifstream& ifs);
// accessors
const options_t opt() const { return opt_; }
const driver::module* parent() const { return &*parent_; }
// entry points
void operator()(driver::stream *stream, const grid_t& grid, const std::vector<arg>& args) const;
private:
std::shared_ptr<driver::kernel> bin_;
std::shared_ptr<driver::module> parent_;
std::vector<arg_type> param_tys_;
options_t opt_;
std::string name_;
};
private:
typedef std::pair<driver::device*, std::vector<int64_t>> cache_key_t;
private:
// cache
static std::string get_cache_prefix();
// make
triton::lang::translation_unit *make_ast(const std::string &src);
std::unique_ptr<ir::module> make_ir(Parser &parser);
std::unique_ptr<driver::module> make_bin(ir::module &function, driver::context *context, const options_t &opt);
caller autotune(driver::stream *stream, const grid_fn_ty& grid, const std::vector<arg> &args);
caller *make(driver::stream *stream, options_t opt);
void precompile(driver::stream *stream, const options_space_t& tuning_space);
// autotune
function::cache_key_t get_key(driver::stream *stream, const std::vector<arg>& args);
caller* autotune(driver::stream *stream, const grid_fn_ty& grid, const std::vector<arg> &args);
public:
static std::string preheader();
public:
function(const std::string& src, const options_space_t& opt = options_space_t());
function(const std::string& src, const options_space_t& opt, const std::string &cache_ref = "");
void operator()(const std::vector<arg>& args, const grid_t& grid, driver::stream* stream);
void operator()(const std::vector<arg>& args, const grid_fn_ty& grid, driver::stream *stream);
void set_cst(const std::string& name, void* data, size_t n_bytes);
private:
std::map<std::string, std::vector<char>> cst_;
// pre-compilation
ir::context ctx_;
std::string src_;
options_space_t opt_space_;
std::map<cache_key_t, caller> cache_;
std::map<std::string, std::vector<char>> cst_;
options_space_t opt_;
std::set<options_t> compiled_;
std::map<options_t, std::unique_ptr<caller>> callers_;
// caching
std::string cache_ref_;
std::string cache_path_;
std::map<cache_key_t, caller*> cache_;
};
}

View File

@@ -61,6 +61,14 @@ namespace tools
return (status==0 || errno==EEXIST)?0:-1;
}
inline int mtime(std::string const & path)
{
struct stat st;
if(stat(path.c_str(), &st) != 0)
return 0;
return st.st_mtime;
}
}
}