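// File: triton/python/src/torch/utils.cc
// Source-viewer metadata captured with this file (not part of the code):
//   66 lines, 2.0 KiB, C++ — viewer labels: "Files", "Raw Normal View History".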

#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>

#include "triton/driver/device.h"
#include "triton/driver/stream.h"
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime_api.h>
#include <torch/extension.h>
// Process-wide caches of triton driver objects, keyed by device id.
// A negative id stands for the host (torch_utils::register_device /
// register_stream create host_device / host_stream for dev_id < 0).
// Entries are created lazily by those register_* functions; nothing in this
// file erases them. No locking is done here — NOTE(review): callers presumably
// serialize access (e.g. via the Python GIL); confirm before using off-thread.
std::map<int, std::shared_ptr<triton::driver::device>> tt_devices;
std::map<int, std::shared_ptr<triton::driver::stream>> tt_streams;
namespace torch_utils {
void register_device(int64_t dev_id) {
if (tt_devices.find(dev_id) != tt_devices.end())
return;
triton::driver::device *device;
if (dev_id >= 0) {
CUdevice handle;
triton::driver::dispatch::cuDeviceGet(&handle, dev_id);
device = new triton::driver::cu_device(handle, false);
} else
device = new triton::driver::host_device();
tt_devices[dev_id].reset(device);
}
void register_stream(int64_t dev_id) {
if (tt_streams.find(dev_id) != tt_streams.end())
return;
triton::driver::stream *stream;
if (dev_id >= 0) {
CUstream handle = (CUstream)c10::cuda::getCurrentCUDAStream(dev_id).stream();
stream = new triton::driver::cu_stream(handle, false);
} else
stream = new triton::driver::host_stream();
tt_streams[dev_id].reset(stream);
}
void synchronize(int64_t dev_id) {
tt_streams[dev_id]->synchronize();
}
void set_device(int64_t dev_id) {
if (dev_id >= 0)
C10_CUDA_CHECK(cudaSetDevice(dev_id));
}
torch::Tensor move_out_of_pool(torch::Tensor x) {
if (x.nbytes() == 0)
return torch::empty_like(x);
void *data;
cudaMalloc(&data, x.nbytes());
auto ret = torch::from_blob((void *)data, x.sizes(), x.strides(), [data](void *ptr) { cudaFree(data); }, x.options());
ret.copy_(x);
return ret;
}
} // namespace torch_utils
// Registers the torch_utils helpers on `m` as the Python submodule
// `torch_utils`: register_device, register_stream, set_device, synchronize
// and move_out_of_pool, each bound directly to its C++ function.
void init_torch_utils(pybind11::module &m) {
  auto submodule = m.def_submodule("torch_utils");
  submodule.def("register_device", &torch_utils::register_device)
      .def("register_stream", &torch_utils::register_stream)
      .def("set_device", &torch_utils::set_device)
      .def("synchronize", &torch_utils::synchronize)
      .def("move_out_of_pool", &torch_utils::move_out_of_pool);
}