[DRIVER] Simplified Driver API by substantially removing reliance on driver::context
This commit is contained in:
@@ -21,6 +21,7 @@ std::map<map_key_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
|
||||
std::map<map_key_t, std::shared_ptr<rt::function>> id_fn_map;
|
||||
|
||||
CUstream torch_get_cuda_stream(int64_t dev_id);
|
||||
CUdevice torch_get_cuda_device(int64_t dev_id);
|
||||
|
||||
/* Grid utilities */
|
||||
|
||||
@@ -47,8 +48,8 @@ void delete_fn(const map_key_t& key) {
|
||||
}
|
||||
|
||||
std::string get_fn_ptx(const map_key_t& key, const rt::function::options_t& opt) {
|
||||
triton::driver::cu_stream stream(torch_get_cuda_stream(key.second), false);
|
||||
return id_fn_map[key]->ptx(&stream, opt);
|
||||
triton::driver::cu_device device(torch_get_cuda_device(key.second), false);
|
||||
return id_fn_map[key]->ptx(&device, opt);
|
||||
}
|
||||
|
||||
void cleanup() {
|
||||
|
@@ -31,12 +31,18 @@ void init_host_stream() {
|
||||
if(!host_stream){
|
||||
host_device.reset(new drv::host_device());
|
||||
host_context.reset(drv::context::create(&*host_device));
|
||||
host_stream.reset(drv::stream::create(&*host_context));
|
||||
host_stream.reset(drv::stream::create(host_context->backend()));
|
||||
}
|
||||
}
|
||||
|
||||
CUstream torch_get_cuda_stream(int64_t dev_id) {
|
||||
return (CUstream)at::cuda::getCurrentCUDAStream(dev_id).stream();
|
||||
return (CUstream)c10::cuda::getCurrentCUDAStream(dev_id).stream();
|
||||
}
|
||||
|
||||
CUdeviceptr torch_get_cuda_device(int64_t dev_id) {
|
||||
CUdevice ret;
|
||||
triton::driver::dispatch::cuDeviceGet(&ret, dev_id);
|
||||
return ret;
|
||||
}
|
||||
|
||||
void synchronize(int64_t dev_id) {
|
||||
@@ -60,12 +66,12 @@ void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args,
|
||||
}
|
||||
if(dev_id == -1){
|
||||
init_host_stream();
|
||||
(*fn)((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &*host_stream);
|
||||
(*fn)((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &*host_stream, &*host_device);
|
||||
}
|
||||
else{
|
||||
triton::driver::cu_stream stream(torch_get_cuda_stream(dev_id), false);
|
||||
triton::driver::context* ctx = stream.context();
|
||||
(*fn)((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &stream);
|
||||
triton::driver::cu_device device(torch_get_cuda_device(dev_id), false);
|
||||
(*fn)((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &stream, &device);
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user