[GENERAL] Improved caching mechanism:
* Now computing hash in libtriton * Now only compiling a single pytorch hook per function signature
This commit is contained in:
committed by
Philippe Tillet
parent
30f77e9ec5
commit
dfb844bf41
@@ -74,6 +74,7 @@ public:
|
||||
cu_module(driver::context* context, std::unique_ptr<llvm::Module> module);
|
||||
cu_module(driver::context* context, const std::string& source);
|
||||
std::unique_ptr<buffer> symbol(const char * name) const;
|
||||
const std::string& source() const { return source_; }
|
||||
|
||||
private:
|
||||
std::string source_;
|
||||
|
@@ -7,6 +7,9 @@
|
||||
#include <stdexcept>
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
class type;
|
||||
}
|
||||
|
||||
namespace driver{
|
||||
class buffer;
|
||||
@@ -26,6 +29,9 @@ enum arg_type {
|
||||
BUFFER_T
|
||||
};
|
||||
|
||||
arg_type convert(ir::type *ty);
|
||||
|
||||
|
||||
inline size_t size_of(arg_type ty){
|
||||
switch(ty){
|
||||
case INT1_T: return 1;
|
||||
|
@@ -8,6 +8,7 @@
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <functional>
|
||||
#include <set>
|
||||
// codegen
|
||||
#include "triton/ir/context.h"
|
||||
#include "triton/codegen/target.h"
|
||||
@@ -68,6 +69,11 @@ public:
|
||||
T D(const std::string& name) const {
|
||||
return convert<T>(defines.at(name));
|
||||
}
|
||||
bool operator<(const options_t& other) const {
|
||||
return std::make_pair(defines, num_warps) <
|
||||
std::make_pair(other.defines, other.num_warps);
|
||||
}
|
||||
std::string to_str() const;
|
||||
|
||||
std::map<std::string, std::string> defines;
|
||||
size_t num_warps;
|
||||
@@ -79,41 +85,63 @@ public:
|
||||
private:
|
||||
class caller {
|
||||
public:
|
||||
caller(ir::function *ir, std::shared_ptr<driver::module> program, const options_t& opt_);
|
||||
void operator()(driver::stream *stream, const grid_t& grid, const std::vector<arg>& args) const;
|
||||
// constructors
|
||||
caller(driver::context* ctx, std::ifstream& ifs, const options_t& opt);
|
||||
caller(ir::function *ir, std::shared_ptr<driver::module> program, const options_t& opt);
|
||||
// serialization
|
||||
void write(std::ofstream& ofs);
|
||||
void read(driver::context* ctx, std::ifstream& ifs);
|
||||
// accessors
|
||||
const options_t opt() const { return opt_; }
|
||||
const driver::module* parent() const { return &*parent_; }
|
||||
// entry points
|
||||
void operator()(driver::stream *stream, const grid_t& grid, const std::vector<arg>& args) const;
|
||||
|
||||
private:
|
||||
std::shared_ptr<driver::kernel> bin_;
|
||||
std::shared_ptr<driver::module> parent_;
|
||||
std::vector<arg_type> param_tys_;
|
||||
options_t opt_;
|
||||
std::string name_;
|
||||
};
|
||||
|
||||
private:
|
||||
typedef std::pair<driver::device*, std::vector<int64_t>> cache_key_t;
|
||||
|
||||
private:
|
||||
// cache
|
||||
static std::string get_cache_prefix();
|
||||
// make
|
||||
triton::lang::translation_unit *make_ast(const std::string &src);
|
||||
std::unique_ptr<ir::module> make_ir(Parser &parser);
|
||||
std::unique_ptr<driver::module> make_bin(ir::module &function, driver::context *context, const options_t &opt);
|
||||
caller autotune(driver::stream *stream, const grid_fn_ty& grid, const std::vector<arg> &args);
|
||||
caller *make(driver::stream *stream, options_t opt);
|
||||
void precompile(driver::stream *stream, const options_space_t& tuning_space);
|
||||
// autotune
|
||||
function::cache_key_t get_key(driver::stream *stream, const std::vector<arg>& args);
|
||||
caller* autotune(driver::stream *stream, const grid_fn_ty& grid, const std::vector<arg> &args);
|
||||
|
||||
public:
|
||||
static std::string preheader();
|
||||
|
||||
public:
|
||||
function(const std::string& src, const options_space_t& opt = options_space_t());
|
||||
function(const std::string& src, const options_space_t& opt, const std::string &cache_ref = "");
|
||||
void operator()(const std::vector<arg>& args, const grid_t& grid, driver::stream* stream);
|
||||
void operator()(const std::vector<arg>& args, const grid_fn_ty& grid, driver::stream *stream);
|
||||
void set_cst(const std::string& name, void* data, size_t n_bytes);
|
||||
|
||||
private:
|
||||
std::map<std::string, std::vector<char>> cst_;
|
||||
// pre-compilation
|
||||
ir::context ctx_;
|
||||
std::string src_;
|
||||
options_space_t opt_space_;
|
||||
std::map<cache_key_t, caller> cache_;
|
||||
std::map<std::string, std::vector<char>> cst_;
|
||||
options_space_t opt_;
|
||||
std::set<options_t> compiled_;
|
||||
std::map<options_t, std::unique_ptr<caller>> callers_;
|
||||
// caching
|
||||
std::string cache_ref_;
|
||||
std::string cache_path_;
|
||||
std::map<cache_key_t, caller*> cache_;
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -61,6 +61,14 @@ namespace tools
|
||||
return (status==0 || errno==EEXIST)?0:-1;
|
||||
}
|
||||
|
||||
inline int mtime(std::string const & path)
|
||||
{
|
||||
struct stat st;
|
||||
if(stat(path.c_str(), &st) != 0)
|
||||
return 0;
|
||||
return st.st_mtime;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
Reference in New Issue
Block a user