[DRIVER] Improved performance of Host driver code

This commit is contained in:
Philippe Tillet
2020-11-12 02:11:45 -05:00
committed by Philippe Tillet
parent 8f8d36c7a4
commit a77c925dfd
6 changed files with 44 additions and 14 deletions

View File

@@ -27,23 +27,39 @@ int64_t cdiv_sum(torch::Tensor x, int64_t div){
return ret;
}
void init_host_stream() {
if(!host_stream){
host_device.reset(new drv::host_device());
host_context.reset(drv::context::create(&*host_device));
host_stream.reset(drv::stream::create(&*host_context));
}
}
CUstream torch_get_cuda_stream(int64_t dev_id) {
return (CUstream)at::cuda::getCurrentCUDAStream(dev_id).stream();
}
void synchronize(int64_t dev_id) {
if(dev_id == -1){
init_host_stream();
host_stream->synchronize();
}
else{
triton::driver::cu_stream stream(torch_get_cuda_stream(dev_id), false);
stream.synchronize();
}
}
void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args,
const std::vector<std::string>& constant_names, const std::vector<torch::Tensor>& constant_vals){
rt::function* fn = id_fn_map.at({op_id, dev_id}).get();
for(size_t n = 0; n < constant_names.size(); n++){
const torch::Tensor& x = constant_vals[n];
fn->set_cst(constant_names[n], (char*)x.data_ptr(), x.numel()*x.element_size());
fn->set_cst(constant_names[n].c_str(), (char*)x.data_ptr(), x.numel()*x.element_size());
}
if(dev_id == -1){
if(!host_stream){
host_device.reset(new drv::host_device());
host_context.reset(drv::context::create(&*host_device));
host_stream.reset(drv::stream::create(&*host_context));
}
init_host_stream();
(*fn)((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &*host_stream);
}
else{
@@ -56,4 +72,5 @@ void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args,
static auto registry = torch::RegisterOperators()
.op("triton::launch_kernel", &launch_kernel)
.op("triton::cdiv_sum", &cdiv_sum);
.op("triton::cdiv_sum", &cdiv_sum)
.op("triton::synchronize", &synchronize);

View File

@@ -49,6 +49,11 @@ def cdiv(a, b):
def cdiv_sum(a, b):
return torch.ops.triton.cdiv_sum(a, b)
def synchronize(device):
dev_id = device.index
dev_id = -1 if dev_id is None else dev_id
torch.ops.triton.synchronize(dev_id)
class kernel: