[PYTHON] Some cleaning of the PyBind11 wrappers (#62)
This commit is contained in:
committed by
Philippe Tillet
parent
08909b49c8
commit
2a02fabdac
66
python/src/torch/utils.cc
Normal file
66
python/src/torch/utils.cc
Normal file
@@ -0,0 +1,66 @@
|
||||
|
||||
#include "triton/driver/device.h"
|
||||
#include "triton/driver/stream.h"
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
std::map<int, std::shared_ptr<triton::driver::device>> tt_devices;
|
||||
std::map<int, std::shared_ptr<triton::driver::stream>> tt_streams;
|
||||
|
||||
namespace torch_utils {
|
||||
|
||||
void register_device(int64_t dev_id) {
|
||||
if (tt_devices.find(dev_id) != tt_devices.end())
|
||||
return;
|
||||
triton::driver::device *device;
|
||||
if (dev_id >= 0) {
|
||||
CUdevice handle;
|
||||
triton::driver::dispatch::cuDeviceGet(&handle, dev_id);
|
||||
device = new triton::driver::cu_device(handle, false);
|
||||
} else
|
||||
device = new triton::driver::host_device();
|
||||
tt_devices[dev_id].reset(device);
|
||||
}
|
||||
|
||||
void register_stream(int64_t dev_id) {
|
||||
if (tt_streams.find(dev_id) != tt_streams.end())
|
||||
return;
|
||||
triton::driver::stream *stream;
|
||||
if (dev_id >= 0) {
|
||||
CUstream handle = (CUstream)c10::cuda::getCurrentCUDAStream(dev_id).stream();
|
||||
stream = new triton::driver::cu_stream(handle, false);
|
||||
} else
|
||||
stream = new triton::driver::host_stream();
|
||||
tt_streams[dev_id].reset(stream);
|
||||
}
|
||||
|
||||
void synchronize(int64_t dev_id) {
|
||||
tt_streams[dev_id]->synchronize();
|
||||
}
|
||||
|
||||
void set_device(int64_t dev_id) {
|
||||
if (dev_id >= 0)
|
||||
C10_CUDA_CHECK(cudaSetDevice(dev_id));
|
||||
}
|
||||
|
||||
torch::Tensor move_out_of_pool(torch::Tensor x) {
|
||||
if (x.nbytes() == 0)
|
||||
return torch::empty_like(x);
|
||||
void *data;
|
||||
cudaMalloc(&data, x.nbytes());
|
||||
auto ret = torch::from_blob((void *)data, x.sizes(), x.strides(), [data](void *ptr) { cudaFree(data); }, x.options());
|
||||
ret.copy_(x);
|
||||
return ret;
|
||||
}
|
||||
|
||||
} // namespace torch_utils
|
||||
|
||||
void init_torch_utils(pybind11::module &m) {
|
||||
pybind11::module subm = m.def_submodule("torch_utils");
|
||||
subm.def("register_device", &torch_utils::register_device);
|
||||
subm.def("register_stream", &torch_utils::register_stream);
|
||||
subm.def("set_device", &torch_utils::set_device);
|
||||
subm.def("synchronize", &torch_utils::synchronize);
|
||||
subm.def("move_out_of_pool", &torch_utils::move_out_of_pool);
|
||||
}
|
Reference in New Issue
Block a user