[DRIVER] Improved performance of Host driver code
This commit is contained in:
committed by
Philippe Tillet
parent
8f8d36c7a4
commit
a77c925dfd
@@ -51,6 +51,8 @@ struct host_context_t{
|
||||
|
||||
struct host_stream_t{
|
||||
std::shared_ptr<ThreadPool> pool;
|
||||
std::shared_ptr<std::vector<std::future<void>>> futures;
|
||||
std::vector<std::shared_ptr<char*>> args;
|
||||
};
|
||||
|
||||
struct host_module_t{
|
||||
|
@@ -134,7 +134,7 @@ public:
|
||||
function(const std::string& src, const options_space_t& opt, const std::string &cache_ref = "");
|
||||
void operator()(void** args, size_t args_size, const grid_t& grid, driver::stream* stream);
|
||||
void operator()(void** args, size_t args_size, const grid_fn_ty& grid, driver::stream *stream);
|
||||
void set_cst(const std::string& name, void* data, size_t n_bytes);
|
||||
void set_cst(const char* name, void* data, size_t n_bytes);
|
||||
std::string ptx(driver::stream *stream, const options_t& opt);
|
||||
|
||||
private:
|
||||
|
@@ -21,6 +21,7 @@
|
||||
*/
|
||||
|
||||
#include <cassert>
|
||||
#include <unistd.h>
|
||||
#include <array>
|
||||
#include "triton/driver/backend.h"
|
||||
#include "triton/driver/stream.h"
|
||||
@@ -68,20 +69,25 @@ driver::context* stream::context() const {
|
||||
/* ------------------------ */
|
||||
|
||||
host_stream::host_stream(driver::context *ctx): stream(ctx, host_stream_t(), true) {
|
||||
hst_->pool.reset(new ThreadPool(8));
|
||||
hst_->pool.reset(new ThreadPool(1));
|
||||
hst_->futures.reset(new std::vector<std::future<void>>());
|
||||
}
|
||||
|
||||
void host_stream::synchronize() {
|
||||
hst_->pool.reset(new ThreadPool(8));
|
||||
for(auto& x: *hst_->futures)
|
||||
x.wait();
|
||||
hst_->futures->clear();
|
||||
hst_->args.clear();
|
||||
}
|
||||
|
||||
void host_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event* event, void **args, size_t args_size) {
|
||||
ThreadPool pool(4);
|
||||
auto hst = kernel->module()->hst();
|
||||
char* params = new char[args_size];
|
||||
std::memcpy((void*)params, (void*)args, args_size);
|
||||
for(size_t i = 0; i < grid[0]; i++)
|
||||
for(size_t j = 0; j < grid[1]; j++)
|
||||
for(size_t k = 0; k < grid[2]; k++)
|
||||
hst_->pool->enqueue(hst->fn, (char**)args, int32_t(i), int32_t(j), int32_t(k));
|
||||
hst_->futures->emplace_back(hst_->pool->enqueue(hst->fn, (char**)params, int32_t(i), int32_t(j), int32_t(k)));
|
||||
}
|
||||
|
||||
void host_stream::write(driver::buffer* buffer, bool blocking, std::size_t offset, std::size_t size, void const* ptr) {
|
||||
|
@@ -335,8 +335,8 @@ function::caller* function::autotune(driver::stream* stream, const grid_fn_ty& g
|
||||
}
|
||||
|
||||
// set copy host buffer "data" into constant memory buffer "name"
|
||||
void function::set_cst(const std::string& name, void* data, size_t n_bytes) {
|
||||
cst_[name] = std::vector<char>((char*)data, (char*)data + n_bytes);
|
||||
void function::set_cst(const char* name, void* data, size_t n_bytes) {
|
||||
cst_[std::string(name)] = std::vector<char>((char*)data, (char*)data + n_bytes);
|
||||
}
|
||||
|
||||
|
||||
|
@@ -27,23 +27,39 @@ int64_t cdiv_sum(torch::Tensor x, int64_t div){
|
||||
return ret;
|
||||
}
|
||||
|
||||
void init_host_stream() {
|
||||
if(!host_stream){
|
||||
host_device.reset(new drv::host_device());
|
||||
host_context.reset(drv::context::create(&*host_device));
|
||||
host_stream.reset(drv::stream::create(&*host_context));
|
||||
}
|
||||
}
|
||||
|
||||
CUstream torch_get_cuda_stream(int64_t dev_id) {
|
||||
return (CUstream)at::cuda::getCurrentCUDAStream(dev_id).stream();
|
||||
}
|
||||
|
||||
void synchronize(int64_t dev_id) {
|
||||
if(dev_id == -1){
|
||||
init_host_stream();
|
||||
host_stream->synchronize();
|
||||
}
|
||||
else{
|
||||
triton::driver::cu_stream stream(torch_get_cuda_stream(dev_id), false);
|
||||
stream.synchronize();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args,
|
||||
const std::vector<std::string>& constant_names, const std::vector<torch::Tensor>& constant_vals){
|
||||
rt::function* fn = id_fn_map.at({op_id, dev_id}).get();
|
||||
for(size_t n = 0; n < constant_names.size(); n++){
|
||||
const torch::Tensor& x = constant_vals[n];
|
||||
fn->set_cst(constant_names[n], (char*)x.data_ptr(), x.numel()*x.element_size());
|
||||
fn->set_cst(constant_names[n].c_str(), (char*)x.data_ptr(), x.numel()*x.element_size());
|
||||
}
|
||||
if(dev_id == -1){
|
||||
if(!host_stream){
|
||||
host_device.reset(new drv::host_device());
|
||||
host_context.reset(drv::context::create(&*host_device));
|
||||
host_stream.reset(drv::stream::create(&*host_context));
|
||||
}
|
||||
init_host_stream();
|
||||
(*fn)((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &*host_stream);
|
||||
}
|
||||
else{
|
||||
@@ -56,4 +72,5 @@ void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args,
|
||||
|
||||
static auto registry = torch::RegisterOperators()
|
||||
.op("triton::launch_kernel", &launch_kernel)
|
||||
.op("triton::cdiv_sum", &cdiv_sum);
|
||||
.op("triton::cdiv_sum", &cdiv_sum)
|
||||
.op("triton::synchronize", &synchronize);
|
||||
|
@@ -50,6 +50,11 @@ def cdiv(a, b):
|
||||
def cdiv_sum(a, b):
|
||||
return torch.ops.triton.cdiv_sum(a, b)
|
||||
|
||||
def synchronize(device):
|
||||
dev_id = device.index
|
||||
dev_id = -1 if dev_id is None else dev_id
|
||||
torch.ops.triton.synchronize(dev_id)
|
||||
|
||||
class kernel:
|
||||
|
||||
def __init__(self, src, defines = dict(), num_warps = [2, 4, 8]):
|
||||
|
Reference in New Issue
Block a user