[DRIVER] Simplified Driver API by substantially removing reliance on driver::context

This commit is contained in:
Philippe Tillet
2020-11-26 00:27:12 -05:00
parent f42b04d925
commit 4f08d87fed
24 changed files with 167 additions and 194 deletions

View File

@@ -31,12 +31,18 @@ void init_host_stream() {
if(!host_stream){
host_device.reset(new drv::host_device());
host_context.reset(drv::context::create(&*host_device));
host_stream.reset(drv::stream::create(&*host_context));
host_stream.reset(drv::stream::create(host_context->backend()));
}
}
CUstream torch_get_cuda_stream(int64_t dev_id) {
return (CUstream)at::cuda::getCurrentCUDAStream(dev_id).stream();
return (CUstream)c10::cuda::getCurrentCUDAStream(dev_id).stream();
}
CUdeviceptr torch_get_cuda_device(int64_t dev_id) {
CUdevice ret;
triton::driver::dispatch::cuDeviceGet(&ret, dev_id);
return ret;
}
void synchronize(int64_t dev_id) {
@@ -60,12 +66,12 @@ void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args,
}
if(dev_id == -1){
init_host_stream();
(*fn)((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &*host_stream);
(*fn)((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &*host_stream, &*host_device);
}
else{
triton::driver::cu_stream stream(torch_get_cuda_stream(dev_id), false);
triton::driver::context* ctx = stream.context();
(*fn)((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &stream);
triton::driver::cu_device device(torch_get_cuda_device(dev_id), false);
(*fn)((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &stream, &device);
}
}