diff --git a/include/triton/driver/context.h b/include/triton/driver/context.h index d893ee87a..3d542f38c 100755 --- a/include/triton/driver/context.h +++ b/include/triton/driver/context.h @@ -36,18 +36,8 @@ public: // CUDA class cu_context: public context { -public: - class context_switcher{ - public: - context_switcher(driver::context const & ctx); - ~context_switcher(); - private: - driver::cu_context const & ctx_; - }; - private: static CUdevice get_device_of(CUcontext); - public: //Constructors cu_context(CUcontext cu, bool take_ownership = true); diff --git a/include/triton/driver/kernel.h b/include/triton/driver/kernel.h index 0aa7efc5e..ca5eebd6e 100755 --- a/include/triton/driver/kernel.h +++ b/include/triton/driver/kernel.h @@ -25,14 +25,8 @@ class kernel: public polymorphic_resource { public: kernel(driver::module* program, CUfunction fn, bool has_ownership); kernel(driver::module* program, host_function_t fn, bool has_ownership); - // Getters driver::module* module(); - // Factory methods static kernel* create(driver::module* program, const char* name); - // Arguments setters - virtual void setArg(unsigned int index, std::size_t size, void* ptr) = 0; - virtual void setArg(unsigned int index, buffer *) = 0; - template void setArg(unsigned int index, T value) { setArg(index, sizeof(T), (void*)&value); } private: driver::module* program_; }; @@ -42,14 +36,6 @@ class host_kernel: public kernel { public: //Constructors host_kernel(driver::module* program, const char* name); - // Arguments setters - void setArg(unsigned int index, std::size_t size, void* ptr); - void setArg(unsigned int index, driver::buffer* buffer); - // Params - const std::vector& params(); -private: - std::vector > params_store_; - std::vector params_; }; // CUDA @@ -57,15 +43,6 @@ class cu_kernel: public kernel { public: //Constructors cu_kernel(driver::module* program, const char * name); - // Arguments setters - void setArg(unsigned int index, std::size_t size, void* ptr); - void setArg(unsigned int index, driver::buffer* buffer); - //Arguments getters - void* const* cu_params() const; - -private: - std::vector > cu_params_store_; - std::vector cu_params_; }; } diff --git a/include/triton/driver/stream.h b/include/triton/driver/stream.h index 9874d2a60..0d45975ff 100755 --- a/include/triton/driver/stream.h +++ b/include/triton/driver/stream.h @@ -27,8 +27,6 @@ public: stream(host_stream_t, bool has_ownership); // factory static driver::stream* create(backend_t backend); - // accessors - driver::context* context() const; // methods virtual void synchronize() = 0; virtual void enqueue(driver::kernel* kernel, std::array grid, std::array block, std::vector const * = NULL, event *event = NULL, void **args = NULL, size_t args_size = 0) = 0; diff --git a/lib/driver/context.cc b/lib/driver/context.cc index b6091e0ac..bf403e32e 100755 --- a/lib/driver/context.cc +++ b/lib/driver/context.cc @@ -94,17 +94,6 @@ host_context::host_context(driver::device* dev): context(dev, host_context_t(), // CUDA // /* ------------------------ */ -// RAII context switcher -cu_context::context_switcher::context_switcher(const context &ctx): ctx_((const cu_context&)ctx) { - dispatch::cuCtxPushCurrent_v2(*ctx_.cu()); -} - -cu_context::context_switcher::~context_switcher() { - CUcontext tmp; - dispatch::cuCtxPopCurrent_v2(&tmp); - assert(tmp==*ctx_.cu() && "Switching back to invalid context!"); -} - // import CUdevice CUdevice cu_context::get_device_of(CUcontext context){ dispatch::cuCtxPushCurrent_v2(context); diff --git a/lib/driver/kernel.cc b/lib/driver/kernel.cc index 05d8b7b3c..2b340bc4a 100755 --- a/lib/driver/kernel.cc +++ b/lib/driver/kernel.cc @@ -64,59 +64,15 @@ host_kernel::host_kernel(driver::module* program, const char *name): kernel(prog hst_->fn = program->hst()->functions.at(name); } -void host_kernel::setArg(unsigned int index, std::size_t size, void* ptr){ - if(index + 1> params_store_.size()){ - params_store_.resize(index+1); - params_.resize(index+1); - } - params_store_[index].reset(malloc(size), free); - memcpy(params_store_[index].get(), ptr, size); - params_[index] = params_store_[index].get(); -} - -void host_kernel::setArg(unsigned int index, driver::buffer* buffer){ - if(buffer) - kernel::setArg(index, (void*)buffer->hst()->data); - else - kernel::setArg(index, (std::ptrdiff_t)0); -} - -const std::vector &host_kernel::params(){ - return params_; -} - /* ------------------------ */ // CUDA // /* ------------------------ */ cu_kernel::cu_kernel(driver::module *program, const char * name) : kernel(program, CUfunction(), true) { - cu_params_store_.reserve(64); - cu_params_.reserve(64); dispatch::cuModuleGetFunction(&*cu_, *program->cu(), name); // dispatch::cuFuncSetCacheConfig(*cu_, CU_FUNC_CACHE_PREFER_SHARED); } -void cu_kernel::setArg(unsigned int index, std::size_t size, void* ptr){ - if(index + 1> cu_params_store_.size()){ - cu_params_store_.resize(index+1); - cu_params_.resize(index+1); - } - cu_params_store_[index].reset(malloc(size), free); - memcpy(cu_params_store_[index].get(), ptr, size); - cu_params_[index] = cu_params_store_[index].get(); -} - -void cu_kernel::setArg(unsigned int index, driver::buffer* data){ - if(data) - kernel::setArg(index, *data->cu()); - else - kernel::setArg(index, (std::ptrdiff_t)0); -} - -void* const* cu_kernel::cu_params() const -{ return cu_params_.data(); } - - } } diff --git a/tests/common/cuda/cublas.h b/tests/common/cuda/cublas.h index 1d403c413..aee399f57 100644 --- a/tests/common/cuda/cublas.h +++ b/tests/common/cuda/cublas.h @@ -193,9 +193,6 @@ inline void cublasGemm(cublasDataType_t dtype, triton::driver::buffer* B, int32_t ldb, void* beta, triton::driver::buffer* C, int32_t ldc, cublasGemmAlgo_t* fastest = NULL, cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT) { - - // switch triton context - triton::driver::cu_context::context_switcher scope(*stream->context()); // get handle static cublasHandle_t handle = cublasGetHandle(stream); // set math mode