[DRIVER] Improved performance of Host driver code

This commit is contained in:
Philippe Tillet
2020-11-12 02:11:45 -05:00
committed by Philippe Tillet
parent 8f8d36c7a4
commit a77c925dfd
6 changed files with 44 additions and 14 deletions

View File

@@ -21,6 +21,7 @@
*/
#include <cassert>
#include <unistd.h>
#include <array>
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
@@ -68,20 +69,25 @@ driver::context* stream::context() const {
/* ------------------------ */
host_stream::host_stream(driver::context *ctx): stream(ctx, host_stream_t(), true) {
hst_->pool.reset(new ThreadPool(8));
hst_->pool.reset(new ThreadPool(1));
hst_->futures.reset(new std::vector<std::future<void>>());
}
void host_stream::synchronize() {
hst_->pool.reset(new ThreadPool(8));
for(auto& x: *hst_->futures)
x.wait();
hst_->futures->clear();
hst_->args.clear();
}
void host_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event* event, void **args, size_t args_size) {
ThreadPool pool(4);
auto hst = kernel->module()->hst();
char* params = new char[args_size];
std::memcpy((void*)params, (void*)args, args_size);
for(size_t i = 0; i < grid[0]; i++)
for(size_t j = 0; j < grid[1]; j++)
for(size_t k = 0; k < grid[2]; k++)
hst_->pool->enqueue(hst->fn, (char**)args, int32_t(i), int32_t(j), int32_t(k));
hst_->futures->emplace_back(hst_->pool->enqueue(hst->fn, (char**)params, int32_t(i), int32_t(j), int32_t(k)));
}
void host_stream::write(driver::buffer* buffer, bool blocking, std::size_t offset, std::size_t size, void const* ptr) {

View File

@@ -335,8 +335,8 @@ function::caller* function::autotune(driver::stream* stream, const grid_fn_ty& g
}
// set copy host buffer "data" into constant memory buffer "name"
void function::set_cst(const std::string& name, void* data, size_t n_bytes) {
cst_[name] = std::vector<char>((char*)data, (char*)data + n_bytes);
void function::set_cst(const char* name, void* data, size_t n_bytes) {
cst_[std::string(name)] = std::vector<char>((char*)data, (char*)data + n_bytes);
}