diff --git a/include/triton/driver/handle.h b/include/triton/driver/handle.h
index d750aeba5..e520326e6 100755
--- a/include/triton/driver/handle.h
+++ b/include/triton/driver/handle.h
@@ -51,6 +51,8 @@ struct host_context_t{
 
 struct host_stream_t{
   std::shared_ptr<ThreadPool> pool;
+  std::shared_ptr<std::vector<std::future<void>>> futures;
+  std::vector<std::shared_ptr<char*>> args;
 };
 
 struct host_module_t{
diff --git a/include/triton/runtime/function.h b/include/triton/runtime/function.h
index f0adafdd2..84c0aaa03 100644
--- a/include/triton/runtime/function.h
+++ b/include/triton/runtime/function.h
@@ -134,7 +134,7 @@ public:
   function(const std::string& src, const options_space_t& opt, const std::string &cache_ref = "");
   void operator()(void** args, size_t args_size, const grid_t& grid, driver::stream* stream);
   void operator()(void** args, size_t args_size, const grid_fn_ty& grid, driver::stream *stream);
-  void set_cst(const std::string& name, void* data, size_t n_bytes);
+  void set_cst(const char* name, void* data, size_t n_bytes);
   std::string ptx(driver::stream *stream, const options_t& opt);
 
 private:
diff --git a/lib/driver/stream.cc b/lib/driver/stream.cc
index 31f35cab3..fd581b085 100755
--- a/lib/driver/stream.cc
+++ b/lib/driver/stream.cc
@@ -21,6 +21,7 @@
  */
 
 #include <cassert>
+#include <cstring>
 #include <array>
 #include "triton/driver/backend.h"
 #include "triton/driver/stream.h"
@@ -68,20 +69,25 @@ driver::context* stream::context() const {
 /* ------------------------ */
 
 host_stream::host_stream(driver::context *ctx): stream(ctx, host_stream_t(), true) {
-  hst_->pool.reset(new ThreadPool(8));
+  hst_->pool.reset(new ThreadPool(1));
+  hst_->futures.reset(new std::vector<std::future<void>>());
 }
 
 void host_stream::synchronize() {
-  hst_->pool.reset(new ThreadPool(8));
+  for(auto& x: *hst_->futures)
+    x.wait();
+  hst_->futures->clear();
+  hst_->args.clear();
 }
 
 void host_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event* event, void **args, size_t args_size) {
-  ThreadPool pool(4);
   auto hst = kernel->module()->hst();
+  char* params = new char[args_size];
+  std::memcpy((void*)params, (void*)args, args_size);
   for(size_t i = 0; i < grid[0]; i++)
     for(size_t j = 0; j < grid[1]; j++)
       for(size_t k = 0; k < grid[2]; k++)
-        hst_->pool->enqueue(hst->fn, (char**)args, int32_t(i), int32_t(j), int32_t(k));
+        hst_->futures->emplace_back(hst_->pool->enqueue(hst->fn, (char**)params, int32_t(i), int32_t(j), int32_t(k)));
 }
 
 void host_stream::write(driver::buffer* buffer, bool blocking, std::size_t offset, std::size_t size, void const* ptr) {
diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc
index 66d8723dc..bd9d8cb48 100644
--- a/lib/runtime/function.cc
+++ b/lib/runtime/function.cc
@@ -335,8 +335,8 @@ function::caller* function::autotune(driver::stream* stream, const grid_fn_ty& g
 }
 
 // set copy host buffer "data" into constant memory buffer "name"
-void function::set_cst(const std::string& name, void* data, size_t n_bytes) {
-  cst_[name] = std::vector<char>((char*)data, (char*)data + n_bytes);
+void function::set_cst(const char* name, void* data, size_t n_bytes) {
+  cst_[std::string(name)] = std::vector<char>((char*)data, (char*)data + n_bytes);
 }
 
 
diff --git a/python/src/launch.cc b/python/src/launch.cc
index 3883aba46..9f2426dd0 100644
--- a/python/src/launch.cc
+++ b/python/src/launch.cc
@@ -27,23 +27,39 @@ int64_t cdiv_sum(torch::Tensor x, int64_t div){
   return ret;
 }
 
+void init_host_stream() {
+  if(!host_stream){
+    host_device.reset(new drv::host_device());
+    host_context.reset(drv::context::create(&*host_device));
+    host_stream.reset(drv::stream::create(&*host_context));
+  }
+}
+
 CUstream torch_get_cuda_stream(int64_t dev_id) {
   return (CUstream)at::cuda::getCurrentCUDAStream(dev_id).stream();
 }
 
+void synchronize(int64_t dev_id) {
+  if(dev_id == -1){
+    init_host_stream();
+    host_stream->synchronize();
+  }
+  else{
+    triton::driver::cu_stream stream(torch_get_cuda_stream(dev_id), false);
+    stream.synchronize();
+  }
+
+}
+
 void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args,
                    const std::vector<std::string>& constant_names, const std::vector<torch::Tensor>& constant_vals){
   rt::function* fn = id_fn_map.at({op_id, dev_id}).get();
   for(size_t n = 0; n < constant_names.size(); n++){
     const torch::Tensor& x = constant_vals[n];
-    fn->set_cst(constant_names[n], (char*)x.data_ptr(), x.numel()*x.element_size());
+    fn->set_cst(constant_names[n].c_str(), (char*)x.data_ptr(), x.numel()*x.element_size());
   }
   if(dev_id == -1){
-    if(!host_stream){
-      host_device.reset(new drv::host_device());
-      host_context.reset(drv::context::create(&*host_device));
-      host_stream.reset(drv::stream::create(&*host_context));
-    }
+    init_host_stream();
     (*fn)((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &*host_stream);
   }
   else{
@@ -56,4 +72,5 @@ void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args,
 
 static auto registry = torch::RegisterOperators()
                        .op("triton::launch_kernel", &launch_kernel)
-                       .op("triton::cdiv_sum", &cdiv_sum);
+                       .op("triton::cdiv_sum", &cdiv_sum)
+                       .op("triton::synchronize", &synchronize);
diff --git a/python/triton/kernel.py b/python/triton/kernel.py
index e4f83aa41..5bab8f6ea 100644
--- a/python/triton/kernel.py
+++ b/python/triton/kernel.py
@@ -49,6 +49,11 @@ def cdiv(a, b):
 
 def cdiv_sum(a, b):
     return torch.ops.triton.cdiv_sum(a, b)
+
+def synchronize(device):
+    dev_id = device.index
+    dev_id = -1 if dev_id is None else dev_id
+    torch.ops.triton.synchronize(dev_id)
 
 class kernel: