[RUNTIME] Lower-level interface for executing functions
This commit is contained in:
committed by
Philippe Tillet
parent
f4f216b88a
commit
acff1b5e05
@@ -32,7 +32,7 @@ public:
|
||||
driver::context* context() const;
|
||||
// methods
|
||||
virtual void synchronize() = 0;
|
||||
virtual void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const * = NULL, event *event = NULL) = 0;
|
||||
virtual void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const * = NULL, event *event = NULL, void **extra = NULL) = 0;
|
||||
virtual void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr) = 0;
|
||||
virtual void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr) = 0;
|
||||
// template helpers
|
||||
@@ -53,7 +53,7 @@ public:
|
||||
|
||||
// Overridden
|
||||
void synchronize();
|
||||
void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event *event);
|
||||
void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event *event, void **extra);
|
||||
void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr);
|
||||
void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr);
|
||||
};
|
||||
@@ -66,7 +66,7 @@ public:
|
||||
|
||||
// Overridden
|
||||
void synchronize();
|
||||
void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event *event);
|
||||
void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event *event, void **extra);
|
||||
void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr);
|
||||
void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr);
|
||||
};
|
||||
@@ -80,7 +80,7 @@ public:
|
||||
|
||||
// Overridden
|
||||
void synchronize();
|
||||
void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event *event);
|
||||
void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event *event, void **extra);
|
||||
void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr);
|
||||
void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr);
|
||||
};
|
||||
|
@@ -40,6 +40,7 @@ enum attribute_kind_t {
|
||||
noalias,
|
||||
aligned,
|
||||
multiple_of,
|
||||
retune,
|
||||
not_implemented
|
||||
};
|
||||
|
||||
@@ -113,6 +114,7 @@ public:
|
||||
// attributes
|
||||
void add_attr(unsigned arg_id, attribute attr) { attrs_[arg_id].insert(attr); }
|
||||
const attr_map_t &attrs() { return attrs_; }
|
||||
bool has_attr(unsigned arg_id) const { return attrs_.find(arg_id) != attrs_.end(); }
|
||||
std::set<attribute> get_attributes(argument* arg) { return attrs_[arg->get_arg_no() + 1]; }
|
||||
|
||||
// visitor
|
||||
|
@@ -64,7 +64,8 @@ public:
|
||||
ALIGNED,
|
||||
NOALIAS,
|
||||
READONLY,
|
||||
WRITEONLY
|
||||
WRITEONLY,
|
||||
RETUNE,
|
||||
};
|
||||
|
||||
KindT kind;
|
||||
|
@@ -3,7 +3,7 @@
|
||||
#ifndef _TRITON_RUNTIME_FUNCTION_H_
|
||||
#define _TRITON_RUNTIME_FUNCTION_H_
|
||||
|
||||
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
@@ -62,6 +62,7 @@ public:
|
||||
typedef std::pair<std::string, std::vector<std::string>> define_t;
|
||||
std::vector<define_t> defines;
|
||||
std::vector<int> num_warps;
|
||||
std::vector<int> recompile_key;
|
||||
};
|
||||
|
||||
struct options_t {
|
||||
@@ -94,19 +95,25 @@ private:
|
||||
// accessors
|
||||
const options_t opt() const { return opt_; }
|
||||
const driver::module* parent() const { return &*parent_; }
|
||||
const driver::kernel* bin() const { return &*bin_; }
|
||||
arg_type param_ty(size_t i) const { return param_tys_.at(i);}
|
||||
const std::vector<arg_type>& param_tys() const { return param_tys_; }
|
||||
|
||||
std::vector<int> retune() const { return retune_; }
|
||||
// entry points
|
||||
void operator()(driver::stream *stream, const grid_t& grid, const std::vector<arg>& args) const;
|
||||
void operator()(driver::stream *stream, const grid_t& grid, void **args, size_t args_size) const;
|
||||
|
||||
private:
|
||||
std::shared_ptr<driver::kernel> bin_;
|
||||
std::shared_ptr<driver::module> parent_;
|
||||
std::vector<arg_type> param_tys_;
|
||||
std::vector<int> retune_;
|
||||
options_t opt_;
|
||||
std::string name_;
|
||||
};
|
||||
|
||||
private:
|
||||
typedef std::pair<driver::device*, std::vector<int64_t>> cache_key_t;
|
||||
typedef std::pair<driver::device*, std::vector<int32_t>> cache_key_t;
|
||||
|
||||
private:
|
||||
// cache
|
||||
@@ -118,16 +125,15 @@ private:
|
||||
caller *make(driver::stream *stream, options_t opt);
|
||||
void precompile(driver::stream *stream, const options_space_t& tuning_space);
|
||||
// autotune
|
||||
function::cache_key_t get_key(driver::stream *stream, const std::vector<arg>& args);
|
||||
caller* autotune(driver::stream *stream, const grid_fn_ty& grid, const std::vector<arg> &args);
|
||||
caller* autotune(driver::stream *stream, const grid_fn_ty& grid, void **args, size_t args_size);
|
||||
|
||||
public:
|
||||
static std::string preheader();
|
||||
|
||||
public:
|
||||
function(const std::string& src, const options_space_t& opt, const std::string &cache_ref = "");
|
||||
void operator()(const std::vector<arg>& args, const grid_t& grid, driver::stream* stream);
|
||||
void operator()(const std::vector<arg>& args, const grid_fn_ty& grid, driver::stream *stream);
|
||||
void operator()(void** args, size_t args_size, const grid_t& grid, driver::stream* stream);
|
||||
void operator()(void** args, size_t args_size, const grid_fn_ty& grid, driver::stream *stream);
|
||||
void set_cst(const std::string& name, void* data, size_t n_bytes);
|
||||
|
||||
private:
|
||||
@@ -138,6 +144,8 @@ private:
|
||||
options_space_t opt_;
|
||||
std::set<options_t> compiled_;
|
||||
std::map<options_t, std::unique_ptr<caller>> callers_;
|
||||
std::vector<int> args_off_;
|
||||
size_t args_size_;
|
||||
// caching
|
||||
std::string cache_ref_;
|
||||
std::string cache_path_;
|
||||
|
Reference in New Issue
Block a user