[DRIVER] Simplified Driver API by substantially removing reliance on driver::context

This commit is contained in:
Philippe Tillet
2020-11-26 00:27:12 -05:00
parent f42b04d925
commit 4f08d87fed
24 changed files with 167 additions and 194 deletions

View File

@@ -16,15 +16,14 @@ class stream;
// Base
// Base class for driver buffers, derived from
// polymorphic_resource<CUdeviceptr, host_buffer_t>.
// NOTE(review): this listing is a diff rendered without +/- markers, so the
// constructors show up both with and without a leading driver::context*.
// Per the commit title the context-taking pair is presumably the removed one;
// confirm against the checked-in header.
class buffer : public polymorphic_resource<CUdeviceptr, host_buffer_t> {
public:
// Construct from a CUDA device pointer (`cl`) or a host buffer handle (`hst`).
// NOTE(review): `take_ownership` presumably controls whether the handle is
// released together with this object -- confirm in polymorphic_resource.
buffer(driver::context* ctx, size_t size, CUdeviceptr cl, bool take_ownership);
buffer(driver::context* ctx, size_t size, host_buffer_t hst, bool take_ownership);
buffer(size_t size, CUdeviceptr cl, bool take_ownership);
buffer(size_t size, host_buffer_t hst, bool take_ownership);
// Buffer address as an integer (going by the name; definition not shown here).
uintptr_t addr_as_uintptr_t();
// Factory returning a raw buffer*; how the concrete subclass is selected from
// `ctx` is not visible in this header.
static buffer* create(driver::context* ctx, size_t size);
// NOTE(review): possibly one of the removed lines, together with context_
// below -- the diff markers are missing, so verify.
driver::context* context();
// Accessor paired with the size_ member (definition not shown here).
size_t size();
protected:
driver::context* context_;
size_t size_;
};
@@ -32,15 +31,15 @@ protected:
// Host-side buffer; subclass of buffer.
// NOTE(review): both the old (context + size) and new (size only) constructor
// signatures are listed because the diff markers were stripped; per the commit
// title the context-taking one is presumably removed -- confirm.
class host_buffer: public buffer
{
public:
host_buffer(driver::context* context, size_t size);
host_buffer(size_t size);
};
// CUDA
// CUDA buffer; subclass of buffer.
// NOTE(review): constructors are listed in both their context-taking and
// context-free forms because the diff markers were stripped; per the commit
// title the driver::context* variants are presumably the removed ones.
class cu_buffer: public buffer
{
public:
// Size-only form (allocation behaviour is defined elsewhere).
cu_buffer(driver::context* context, size_t size);
// Form taking an existing CUdeviceptr `cu` plus an ownership flag.
cu_buffer(driver::context* context, size_t size, CUdeviceptr cu, bool take_ownership);
cu_buffer(size_t size);
cu_buffer(size_t size, CUdeviceptr cu, bool take_ownership);
// Presumably zero-fills `size` units of the buffer using stream `queue`
// (inferred from the name; definition not shown) -- confirm units.
void set_zero(triton::driver::stream *queue, size_t size);
};

View File

@@ -93,6 +93,7 @@ public:
// Static wrappers named after CUDA driver API entry points, each returning
// CUresult. Their definitions are not visible in this excerpt.
static CUresult cuCtxPopCurrent_v2(CUcontext *pctx);
static CUresult cuModuleGetFunction(CUfunction *hfunc, CUmodule hmod, const char *name);
static CUresult cuStreamSynchronize(CUstream hStream);
// NOTE(review): this appears to be the declaration added by the commit (the
// hunk grows by one line); it takes a stream and an output CUcontext*.
static CUresult cuStreamGetCtx(CUstream hStream, CUcontext* pctx);
static CUresult cuStreamDestroy_v2(CUstream hStream);
static CUresult cuEventDestroy_v2(CUevent hEvent);
static CUresult cuMemAlloc_v2(CUdeviceptr *dptr, size_t bytesize);
@@ -154,6 +155,7 @@ private:
// Untyped static slots named after the public wrappers with a trailing
// underscore. NOTE(review): presumably they cache the resolved driver symbol
// for each wrapper -- confirm in the dispatch implementation.
static void* cuModuleGetFunction_;
static void* cuStreamSynchronize_;
static void* cuStreamDestroy_v2_;
// NOTE(review): appears to be the slot added alongside the cuStreamGetCtx wrapper.
static void* cuStreamGetCtx_;
static void* cuEventDestroy_v2_;
static void* cuMemAlloc_v2_;
static void* cuPointerGetAttribute_;

View File

@@ -35,26 +35,21 @@ protected:
};
public:
// NOTE(review): this listing is a diff rendered without +/- markers, so the
// pre-change (driver::context*-based) and post-change declarations appear
// side by side; per the commit title the context-based ones are presumably
// removed -- confirm against the checked-in header.
// Constructors from a CUDA module handle or a host module handle, plus an
// ownership flag.
module(driver::context* ctx, CUmodule mod, bool has_ownership);
module(driver::context* ctx, host_module_t mod, bool has_ownership);
// Factory taking ownership of an llvm::Module (passed as unique_ptr).
static module* create(driver::context* ctx, std::unique_ptr<llvm::Module> src);
driver::context* context() const;
module(CUmodule mod, bool has_ownership);
module(host_module_t mod, bool has_ownership);
// Device-based counterpart of the factory above.
static module* create(driver::device* device, std::unique_ptr<llvm::Module> src);
// Compiles `module`, writing into `buffer`; the remaining parameters (triple,
// proc, layout, features, file_type) are target settings whose exact use is
// defined in the implementation, not shown here.
void compile_llvm_module(std::unique_ptr<llvm::Module> module, const std::string& triple,
const std::string &proc, std::string layout,
llvm::SmallVectorImpl<char> &buffer,
const std::string &features,
file_type_t file_type);
// Pure virtual: subclasses return a buffer for the symbol called `name`.
virtual std::unique_ptr<buffer> symbol(const char * name) const = 0;
protected:
driver::context* ctx_;
};
// CPU
// CPU module; subclass of module.
// NOTE(review): the constructor is listed with and without a leading
// driver::context* because the diff markers were stripped; per the commit
// title the context-taking form is presumably the removed one.
class host_module: public module{
public:
host_module(driver::context* context, std::unique_ptr<llvm::Module> module);
host_module(std::unique_ptr<llvm::Module> module);
// Same signature as the pure-virtual symbol() declared for module.
std::unique_ptr<buffer> symbol(const char * name) const;
};
@@ -63,8 +58,8 @@ class cu_module: public module {
// Declared before `public:`; returns a std::string produced from `module` for
// `device`. NOTE(review): presumably the generated PTX source -- confirm.
std::string compile_llvm_module(std::unique_ptr<llvm::Module> module, driver::device* device);
public:
// NOTE(review): old (driver::context*) and new constructor signatures are
// interleaved because the diff markers were stripped; per the commit title the
// context-taking ones are presumably removed -- confirm.
cu_module(driver::context* context, std::unique_ptr<llvm::Module> module);
cu_module(driver::context* context, const std::string& source);
cu_module(driver::device* device, std::unique_ptr<llvm::Module> module);
cu_module(const std::string& source);
// Same signature as the pure-virtual symbol() declared for module.
std::unique_ptr<buffer> symbol(const char * name) const;
// Returns a const reference to the stored source_ string.
const std::string& source() const { return source_; }

View File

@@ -23,10 +23,10 @@ class cu_buffer;
// Base
class stream: public polymorphic_resource<CUstream, host_stream_t> {
public:
// NOTE(review): this listing is a diff rendered without +/- markers; the
// driver::context*-taking constructors and factory are presumably the
// pre-change signatures (per the commit title) -- confirm.
// Constructors from a CUDA stream handle or a host stream handle, plus an
// ownership flag.
stream(driver::context *ctx, CUstream, bool has_ownership);
stream(driver::context *ctx, host_stream_t, bool has_ownership);
stream(CUstream, bool has_ownership);
stream(host_stream_t, bool has_ownership);
// factory
static driver::stream* create(driver::context* ctx);
// Backend-selected counterpart of the factory above.
static driver::stream* create(backend_t backend);
// accessors
driver::context* context() const;
// methods
@@ -39,16 +39,13 @@ public:
{ write(buf, blocking, offset, x.size()*sizeof(T), x.data()); }
// Vector convenience overload: forwards to the pointer-based read(), passing
// the byte size of `x` (element count times sizeof(T)) and its data pointer.
template<class T>
void read(driver::buffer* buf, bool blocking, std::size_t offset, std::vector<T>& x)
{
  const std::size_t n_bytes = x.size()*sizeof(T);
  read(buf, blocking, offset, n_bytes, x.data());
}
protected:
driver::context *ctx_;
};
// Host
class host_stream: public stream {
public:
// Constructors
// NOTE(review): both the context-taking and the default constructor are
// listed because the diff markers were stripped; per the commit title the
// driver::context* form is presumably the removed one.
host_stream(driver::context *ctx);
host_stream();
// Overridden
void synchronize();
@@ -62,7 +59,7 @@ class cu_stream: public stream {
public:
// Constructors
// Form taking an existing CUstream plus an ownership flag.
cu_stream(CUstream str, bool take_ownership);
// NOTE(review): the next two lines are presumably the old (context-taking)
// and new (default) signatures of the same constructor; the diff markers were
// stripped, so confirm against the checked-in header.
cu_stream(driver::context* context);
cu_stream();
// Overridden
void synchronize();

View File

@@ -87,11 +87,11 @@ private:
class caller {
public:
// constructors
// NOTE(review): this listing is a diff rendered without +/- markers; the
// overloads with a leading driver::context* are presumably the pre-change
// signatures (per the commit title) -- confirm.
// Stream-based forms take an input file stream and options.
caller(driver::context* ctx, std::ifstream& ifs, const options_t& opt);
caller(std::ifstream& ifs, const options_t& opt);
// Form taking an IR function, a shared driver module and options.
caller(ir::function *ir, std::shared_ptr<driver::module> program, const options_t& opt);
// serialization
void write(std::ofstream& ofs);
void read(driver::context* ctx, std::ifstream& ifs);
void read(std::ifstream& ifs);
// accessors
// Returns a copy of opt_.
const options_t opt() const { return opt_; }
// Returns a raw, non-owning pointer to the object parent_ refers to.
const driver::module* parent() const { return &*parent_; }
@@ -101,7 +101,7 @@ private:
// Returns a copy of retune_.
std::vector<int> retune() const { return retune_; }
// entry points
// NOTE(review): the two operator() lines are presumably the old and new
// signatures (diff markers were stripped). The second adds an unnamed trailing
// map<string, vector<char>> parameter that defaults to empty.
void operator()(driver::stream *stream, const grid_t& grid, void **args, size_t args_size) const;
void operator()(driver::stream *stream, const grid_t& grid, void **args, size_t args_size, const std::map<std::string, std::vector<char>>& = {}) const;
private:
std::shared_ptr<driver::kernel> bin_;
@@ -121,9 +121,9 @@ private:
// make
// NOTE(review): this listing is a diff rendered without +/- markers. The
// first make_bin/make/precompile trio (driver::context* / driver::stream*) is
// presumably the pre-change set and the driver::device* trio its replacement
// -- confirm against the checked-in header.
triton::lang::translation_unit *make_ast(const std::string &src);
std::unique_ptr<ir::module> make_ir(Parser &parser);
std::unique_ptr<driver::module> make_bin(ir::module &function, driver::context *context, const options_t &opt);
void make(driver::stream *stream, options_t opt);
void precompile(driver::stream *stream, const options_space_t& tuning_space);
std::unique_ptr<driver::module> make_bin(ir::module &function, driver::device *device, const options_t &opt);
void make(driver::device *device, options_t opt);
void precompile(driver::device *device, const options_space_t& tuning_space);
// autotune
// Returns a raw caller*; ownership of the result is not visible from this
// declaration.
caller* autotune(driver::stream *stream, const grid_fn_ty& grid, void **args, size_t args_size);
@@ -132,10 +132,10 @@ public:
public:
// Constructed from source text and an options space; `cache_ref` defaults to
// the empty string.
function(const std::string& src, const options_space_t& opt, const std::string &cache_ref = "");
// NOTE(review): this listing is a diff rendered without +/- markers. The
// operator() pair without driver::device* and the stream-taking ptx() are
// presumably the pre-change signatures; the device-taking ones their
// replacements -- confirm against the checked-in header.
// Call operators take either a fixed grid (grid_t) or a grid_fn_ty.
void operator()(void** args, size_t args_size, const grid_t& grid, driver::stream* stream);
void operator()(void** args, size_t args_size, const grid_fn_ty& grid, driver::stream *stream);
void operator()(void** args, size_t args_size, const grid_t& grid, driver::stream* stream, driver::device* device);
void operator()(void** args, size_t args_size, const grid_fn_ty& grid, driver::stream *stream, driver::device* device);
// Takes a name plus `n_bytes` of `data`. NOTE(review): presumably stored in
// the cst_ map declared in this class -- confirm in the definition.
void set_cst(const char* name, void* data, size_t n_bytes);
std::string ptx(driver::stream *stream, const options_t& opt);
std::string ptx(driver::device *device, const options_t& opt);
private:
std::map<std::string, std::vector<char>> cst_;