[DRIVER] Removed obsolete SetArg
This commit is contained in:
@@ -36,18 +36,8 @@ public:
|
|||||||
|
|
||||||
// CUDA
|
// CUDA
|
||||||
class cu_context: public context {
|
class cu_context: public context {
|
||||||
public:
|
|
||||||
class context_switcher{
|
|
||||||
public:
|
|
||||||
context_switcher(driver::context const & ctx);
|
|
||||||
~context_switcher();
|
|
||||||
private:
|
|
||||||
driver::cu_context const & ctx_;
|
|
||||||
};
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
static CUdevice get_device_of(CUcontext);
|
static CUdevice get_device_of(CUcontext);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
//Constructors
|
//Constructors
|
||||||
cu_context(CUcontext cu, bool take_ownership = true);
|
cu_context(CUcontext cu, bool take_ownership = true);
|
||||||
|
@@ -25,14 +25,8 @@ class kernel: public polymorphic_resource<CUfunction, host_function_t> {
|
|||||||
public:
|
public:
|
||||||
kernel(driver::module* program, CUfunction fn, bool has_ownership);
|
kernel(driver::module* program, CUfunction fn, bool has_ownership);
|
||||||
kernel(driver::module* program, host_function_t fn, bool has_ownership);
|
kernel(driver::module* program, host_function_t fn, bool has_ownership);
|
||||||
// Getters
|
|
||||||
driver::module* module();
|
driver::module* module();
|
||||||
// Factory methods
|
|
||||||
static kernel* create(driver::module* program, const char* name);
|
static kernel* create(driver::module* program, const char* name);
|
||||||
// Arguments setters
|
|
||||||
virtual void setArg(unsigned int index, std::size_t size, void* ptr) = 0;
|
|
||||||
virtual void setArg(unsigned int index, buffer *) = 0;
|
|
||||||
template<class T> void setArg(unsigned int index, T value) { setArg(index, sizeof(T), (void*)&value); }
|
|
||||||
private:
|
private:
|
||||||
driver::module* program_;
|
driver::module* program_;
|
||||||
};
|
};
|
||||||
@@ -42,14 +36,6 @@ class host_kernel: public kernel {
|
|||||||
public:
|
public:
|
||||||
//Constructors
|
//Constructors
|
||||||
host_kernel(driver::module* program, const char* name);
|
host_kernel(driver::module* program, const char* name);
|
||||||
// Arguments setters
|
|
||||||
void setArg(unsigned int index, std::size_t size, void* ptr);
|
|
||||||
void setArg(unsigned int index, driver::buffer* buffer);
|
|
||||||
// Params
|
|
||||||
const std::vector<void*>& params();
|
|
||||||
private:
|
|
||||||
std::vector<std::shared_ptr<void> > params_store_;
|
|
||||||
std::vector<void*> params_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// CUDA
|
// CUDA
|
||||||
@@ -57,15 +43,6 @@ class cu_kernel: public kernel {
|
|||||||
public:
|
public:
|
||||||
//Constructors
|
//Constructors
|
||||||
cu_kernel(driver::module* program, const char * name);
|
cu_kernel(driver::module* program, const char * name);
|
||||||
// Arguments setters
|
|
||||||
void setArg(unsigned int index, std::size_t size, void* ptr);
|
|
||||||
void setArg(unsigned int index, driver::buffer* buffer);
|
|
||||||
//Arguments getters
|
|
||||||
void* const* cu_params() const;
|
|
||||||
|
|
||||||
private:
|
|
||||||
std::vector<std::shared_ptr<void> > cu_params_store_;
|
|
||||||
std::vector<void*> cu_params_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@@ -27,8 +27,6 @@ public:
|
|||||||
stream(host_stream_t, bool has_ownership);
|
stream(host_stream_t, bool has_ownership);
|
||||||
// factory
|
// factory
|
||||||
static driver::stream* create(backend_t backend);
|
static driver::stream* create(backend_t backend);
|
||||||
// accessors
|
|
||||||
driver::context* context() const;
|
|
||||||
// methods
|
// methods
|
||||||
virtual void synchronize() = 0;
|
virtual void synchronize() = 0;
|
||||||
virtual void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const * = NULL, event *event = NULL, void **args = NULL, size_t args_size = 0) = 0;
|
virtual void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const * = NULL, event *event = NULL, void **args = NULL, size_t args_size = 0) = 0;
|
||||||
|
@@ -94,17 +94,6 @@ host_context::host_context(driver::device* dev): context(dev, host_context_t(),
|
|||||||
// CUDA //
|
// CUDA //
|
||||||
/* ------------------------ */
|
/* ------------------------ */
|
||||||
|
|
||||||
// RAII context switcher
|
|
||||||
cu_context::context_switcher::context_switcher(const context &ctx): ctx_((const cu_context&)ctx) {
|
|
||||||
dispatch::cuCtxPushCurrent_v2(*ctx_.cu());
|
|
||||||
}
|
|
||||||
|
|
||||||
cu_context::context_switcher::~context_switcher() {
|
|
||||||
CUcontext tmp;
|
|
||||||
dispatch::cuCtxPopCurrent_v2(&tmp);
|
|
||||||
assert(tmp==*ctx_.cu() && "Switching back to invalid context!");
|
|
||||||
}
|
|
||||||
|
|
||||||
// import CUdevice
|
// import CUdevice
|
||||||
CUdevice cu_context::get_device_of(CUcontext context){
|
CUdevice cu_context::get_device_of(CUcontext context){
|
||||||
dispatch::cuCtxPushCurrent_v2(context);
|
dispatch::cuCtxPushCurrent_v2(context);
|
||||||
|
@@ -64,59 +64,15 @@ host_kernel::host_kernel(driver::module* program, const char *name): kernel(prog
|
|||||||
hst_->fn = program->hst()->functions.at(name);
|
hst_->fn = program->hst()->functions.at(name);
|
||||||
}
|
}
|
||||||
|
|
||||||
void host_kernel::setArg(unsigned int index, std::size_t size, void* ptr){
|
|
||||||
if(index + 1> params_store_.size()){
|
|
||||||
params_store_.resize(index+1);
|
|
||||||
params_.resize(index+1);
|
|
||||||
}
|
|
||||||
params_store_[index].reset(malloc(size), free);
|
|
||||||
memcpy(params_store_[index].get(), ptr, size);
|
|
||||||
params_[index] = params_store_[index].get();
|
|
||||||
}
|
|
||||||
|
|
||||||
void host_kernel::setArg(unsigned int index, driver::buffer* buffer){
|
|
||||||
if(buffer)
|
|
||||||
kernel::setArg(index, (void*)buffer->hst()->data);
|
|
||||||
else
|
|
||||||
kernel::setArg(index, (std::ptrdiff_t)0);
|
|
||||||
}
|
|
||||||
|
|
||||||
const std::vector<void *> &host_kernel::params(){
|
|
||||||
return params_;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ------------------------ */
|
/* ------------------------ */
|
||||||
// CUDA //
|
// CUDA //
|
||||||
/* ------------------------ */
|
/* ------------------------ */
|
||||||
|
|
||||||
cu_kernel::cu_kernel(driver::module *program, const char * name) : kernel(program, CUfunction(), true) {
|
cu_kernel::cu_kernel(driver::module *program, const char * name) : kernel(program, CUfunction(), true) {
|
||||||
cu_params_store_.reserve(64);
|
|
||||||
cu_params_.reserve(64);
|
|
||||||
dispatch::cuModuleGetFunction(&*cu_, *program->cu(), name);
|
dispatch::cuModuleGetFunction(&*cu_, *program->cu(), name);
|
||||||
// dispatch::cuFuncSetCacheConfig(*cu_, CU_FUNC_CACHE_PREFER_SHARED);
|
// dispatch::cuFuncSetCacheConfig(*cu_, CU_FUNC_CACHE_PREFER_SHARED);
|
||||||
}
|
}
|
||||||
|
|
||||||
void cu_kernel::setArg(unsigned int index, std::size_t size, void* ptr){
|
|
||||||
if(index + 1> cu_params_store_.size()){
|
|
||||||
cu_params_store_.resize(index+1);
|
|
||||||
cu_params_.resize(index+1);
|
|
||||||
}
|
|
||||||
cu_params_store_[index].reset(malloc(size), free);
|
|
||||||
memcpy(cu_params_store_[index].get(), ptr, size);
|
|
||||||
cu_params_[index] = cu_params_store_[index].get();
|
|
||||||
}
|
|
||||||
|
|
||||||
void cu_kernel::setArg(unsigned int index, driver::buffer* data){
|
|
||||||
if(data)
|
|
||||||
kernel::setArg(index, *data->cu());
|
|
||||||
else
|
|
||||||
kernel::setArg(index, (std::ptrdiff_t)0);
|
|
||||||
}
|
|
||||||
|
|
||||||
void* const* cu_kernel::cu_params() const
|
|
||||||
{ return cu_params_.data(); }
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@@ -193,9 +193,6 @@ inline void cublasGemm(cublasDataType_t dtype,
|
|||||||
triton::driver::buffer* B, int32_t ldb,
|
triton::driver::buffer* B, int32_t ldb,
|
||||||
void* beta, triton::driver::buffer* C, int32_t ldc,
|
void* beta, triton::driver::buffer* C, int32_t ldc,
|
||||||
cublasGemmAlgo_t* fastest = NULL, cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT) {
|
cublasGemmAlgo_t* fastest = NULL, cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT) {
|
||||||
|
|
||||||
// switch triton context
|
|
||||||
triton::driver::cu_context::context_switcher scope(*stream->context());
|
|
||||||
// get handle
|
// get handle
|
||||||
static cublasHandle_t handle = cublasGetHandle(stream);
|
static cublasHandle_t handle = cublasGetHandle(stream);
|
||||||
// set math mode
|
// set math mode
|
||||||
|
Reference in New Issue
Block a user