[DRIVER] Simplified Driver API by substantially removing reliance on driver::context
This commit is contained in:
@@ -16,15 +16,14 @@ class stream;
|
||||
// Base
|
||||
class buffer : public polymorphic_resource<CUdeviceptr, host_buffer_t> {
|
||||
public:
|
||||
buffer(driver::context* ctx, size_t size, CUdeviceptr cl, bool take_ownership);
|
||||
buffer(driver::context* ctx, size_t size, host_buffer_t hst, bool take_ownership);
|
||||
buffer(size_t size, CUdeviceptr cl, bool take_ownership);
|
||||
buffer(size_t size, host_buffer_t hst, bool take_ownership);
|
||||
uintptr_t addr_as_uintptr_t();
|
||||
static buffer* create(driver::context* ctx, size_t size);
|
||||
driver::context* context();
|
||||
size_t size();
|
||||
|
||||
protected:
|
||||
driver::context* context_;
|
||||
size_t size_;
|
||||
};
|
||||
|
||||
@@ -32,15 +31,15 @@ protected:
|
||||
class host_buffer: public buffer
|
||||
{
|
||||
public:
|
||||
host_buffer(driver::context* context, size_t size);
|
||||
host_buffer(size_t size);
|
||||
};
|
||||
|
||||
// CUDA
|
||||
class cu_buffer: public buffer
|
||||
{
|
||||
public:
|
||||
cu_buffer(driver::context* context, size_t size);
|
||||
cu_buffer(driver::context* context, size_t size, CUdeviceptr cu, bool take_ownership);
|
||||
cu_buffer(size_t size);
|
||||
cu_buffer(size_t size, CUdeviceptr cu, bool take_ownership);
|
||||
void set_zero(triton::driver::stream *queue, size_t size);
|
||||
};
|
||||
|
||||
|
@@ -93,6 +93,7 @@ public:
|
||||
static CUresult cuCtxPopCurrent_v2(CUcontext *pctx);
|
||||
static CUresult cuModuleGetFunction(CUfunction *hfunc, CUmodule hmod, const char *name);
|
||||
static CUresult cuStreamSynchronize(CUstream hStream);
|
||||
static CUresult cuStreamGetCtx(CUstream hStream, CUcontext* pctx);
|
||||
static CUresult cuStreamDestroy_v2(CUstream hStream);
|
||||
static CUresult cuEventDestroy_v2(CUevent hEvent);
|
||||
static CUresult cuMemAlloc_v2(CUdeviceptr *dptr, size_t bytesize);
|
||||
@@ -154,6 +155,7 @@ private:
|
||||
static void* cuModuleGetFunction_;
|
||||
static void* cuStreamSynchronize_;
|
||||
static void* cuStreamDestroy_v2_;
|
||||
static void* cuStreamGetCtx_;
|
||||
static void* cuEventDestroy_v2_;
|
||||
static void* cuMemAlloc_v2_;
|
||||
static void* cuPointerGetAttribute_;
|
||||
|
@@ -35,26 +35,21 @@ protected:
|
||||
};
|
||||
|
||||
public:
|
||||
module(driver::context* ctx, CUmodule mod, bool has_ownership);
|
||||
module(driver::context* ctx, host_module_t mod, bool has_ownership);
|
||||
static module* create(driver::context* ctx, std::unique_ptr<llvm::Module> src);
|
||||
driver::context* context() const;
|
||||
module(CUmodule mod, bool has_ownership);
|
||||
module(host_module_t mod, bool has_ownership);
|
||||
static module* create(driver::device* device, std::unique_ptr<llvm::Module> src);
|
||||
void compile_llvm_module(std::unique_ptr<llvm::Module> module, const std::string& triple,
|
||||
const std::string &proc, std::string layout,
|
||||
llvm::SmallVectorImpl<char> &buffer,
|
||||
const std::string &features,
|
||||
file_type_t file_type);
|
||||
virtual std::unique_ptr<buffer> symbol(const char * name) const = 0;
|
||||
|
||||
|
||||
protected:
|
||||
driver::context* ctx_;
|
||||
};
|
||||
|
||||
// CPU
|
||||
class host_module: public module{
|
||||
public:
|
||||
host_module(driver::context* context, std::unique_ptr<llvm::Module> module);
|
||||
host_module(std::unique_ptr<llvm::Module> module);
|
||||
std::unique_ptr<buffer> symbol(const char * name) const;
|
||||
};
|
||||
|
||||
@@ -63,8 +58,8 @@ class cu_module: public module {
|
||||
std::string compile_llvm_module(std::unique_ptr<llvm::Module> module, driver::device* device);
|
||||
|
||||
public:
|
||||
cu_module(driver::context* context, std::unique_ptr<llvm::Module> module);
|
||||
cu_module(driver::context* context, const std::string& source);
|
||||
cu_module(driver::device* device, std::unique_ptr<llvm::Module> module);
|
||||
cu_module(const std::string& source);
|
||||
std::unique_ptr<buffer> symbol(const char * name) const;
|
||||
const std::string& source() const { return source_; }
|
||||
|
||||
|
@@ -23,10 +23,10 @@ class cu_buffer;
|
||||
// Base
|
||||
class stream: public polymorphic_resource<CUstream, host_stream_t> {
|
||||
public:
|
||||
stream(driver::context *ctx, CUstream, bool has_ownership);
|
||||
stream(driver::context *ctx, host_stream_t, bool has_ownership);
|
||||
stream(CUstream, bool has_ownership);
|
||||
stream(host_stream_t, bool has_ownership);
|
||||
// factory
|
||||
static driver::stream* create(driver::context* ctx);
|
||||
static driver::stream* create(backend_t backend);
|
||||
// accessors
|
||||
driver::context* context() const;
|
||||
// methods
|
||||
@@ -39,16 +39,13 @@ public:
|
||||
{ write(buf, blocking, offset, x.size()*sizeof(T), x.data()); }
|
||||
template<class T> void read(driver::buffer* buf, bool blocking, std::size_t offset, std::vector<T>& x)
|
||||
{ read(buf, blocking, offset, x.size()*sizeof(T), x.data()); }
|
||||
|
||||
protected:
|
||||
driver::context *ctx_;
|
||||
};
|
||||
|
||||
// Host
|
||||
class host_stream: public stream {
|
||||
public:
|
||||
// Constructors
|
||||
host_stream(driver::context *ctx);
|
||||
host_stream();
|
||||
|
||||
// Overridden
|
||||
void synchronize();
|
||||
@@ -62,7 +59,7 @@ class cu_stream: public stream {
|
||||
public:
|
||||
// Constructors
|
||||
cu_stream(CUstream str, bool take_ownership);
|
||||
cu_stream(driver::context* context);
|
||||
cu_stream();
|
||||
|
||||
// Overridden
|
||||
void synchronize();
|
||||
|
@@ -87,11 +87,11 @@ private:
|
||||
class caller {
|
||||
public:
|
||||
// constructors
|
||||
caller(driver::context* ctx, std::ifstream& ifs, const options_t& opt);
|
||||
caller(std::ifstream& ifs, const options_t& opt);
|
||||
caller(ir::function *ir, std::shared_ptr<driver::module> program, const options_t& opt);
|
||||
// serialization
|
||||
void write(std::ofstream& ofs);
|
||||
void read(driver::context* ctx, std::ifstream& ifs);
|
||||
void read(std::ifstream& ifs);
|
||||
// accessors
|
||||
const options_t opt() const { return opt_; }
|
||||
const driver::module* parent() const { return &*parent_; }
|
||||
@@ -101,7 +101,7 @@ private:
|
||||
|
||||
std::vector<int> retune() const { return retune_; }
|
||||
// entry points
|
||||
void operator()(driver::stream *stream, const grid_t& grid, void **args, size_t args_size) const;
|
||||
void operator()(driver::stream *stream, const grid_t& grid, void **args, size_t args_size, const std::map<std::string, std::vector<char>>& = {}) const;
|
||||
|
||||
private:
|
||||
std::shared_ptr<driver::kernel> bin_;
|
||||
@@ -121,9 +121,9 @@ private:
|
||||
// make
|
||||
triton::lang::translation_unit *make_ast(const std::string &src);
|
||||
std::unique_ptr<ir::module> make_ir(Parser &parser);
|
||||
std::unique_ptr<driver::module> make_bin(ir::module &function, driver::context *context, const options_t &opt);
|
||||
void make(driver::stream *stream, options_t opt);
|
||||
void precompile(driver::stream *stream, const options_space_t& tuning_space);
|
||||
std::unique_ptr<driver::module> make_bin(ir::module &function, driver::device *device, const options_t &opt);
|
||||
void make(driver::device *device, options_t opt);
|
||||
void precompile(driver::device *device, const options_space_t& tuning_space);
|
||||
// autotune
|
||||
caller* autotune(driver::stream *stream, const grid_fn_ty& grid, void **args, size_t args_size);
|
||||
|
||||
@@ -132,10 +132,10 @@ public:
|
||||
|
||||
public:
|
||||
function(const std::string& src, const options_space_t& opt, const std::string &cache_ref = "");
|
||||
void operator()(void** args, size_t args_size, const grid_t& grid, driver::stream* stream);
|
||||
void operator()(void** args, size_t args_size, const grid_fn_ty& grid, driver::stream *stream);
|
||||
void operator()(void** args, size_t args_size, const grid_t& grid, driver::stream* stream, driver::device* device);
|
||||
void operator()(void** args, size_t args_size, const grid_fn_ty& grid, driver::stream *stream, driver::device* device);
|
||||
void set_cst(const char* name, void* data, size_t n_bytes);
|
||||
std::string ptx(driver::stream *stream, const options_t& opt);
|
||||
std::string ptx(driver::device *device, const options_t& opt);
|
||||
|
||||
private:
|
||||
std::map<std::string, std::vector<char>> cst_;
|
||||
|
Reference in New Issue
Block a user