[RUNTIME] Added auto-alignment mechanism (#71)

This PR adds an automatic memory alignment mechanism in the Triton runtime. Specifically, the JIT compiler detects the alignment (in bytes) of each pointer argument as well as the largest power of two divisor (between 1 and 16) of each integer argument. Proper .aligned and .multipleof attributes are then added to the Triton-IR on-the-fly for all auto-tunable kernels. There is a cache that remembers all the kernels compiled for each possible configuration.

This PR also includes substantial cleaning of the Python API. This adds 2-3us overhead, mostly due to accessing integer #defines from the auto-tuned compilation options. The previous solution was slightly faster but hacky and potentially unsafe, so this is preferred for now.
This commit is contained in:
Philippe Tillet
2021-03-04 01:51:11 -05:00
committed by Philippe Tillet
parent ff62f7fffc
commit 62835a0979
19 changed files with 668 additions and 707 deletions

View File

@@ -5,38 +5,16 @@
#include <cuda_runtime_api.h>
#include <torch/extension.h>
std::map<int, std::shared_ptr<triton::driver::device>> tt_devices;
std::map<int, std::shared_ptr<triton::driver::stream>> tt_streams;
namespace torch_utils {
void register_device(int64_t dev_id) {
if (tt_devices.find(dev_id) != tt_devices.end())
return;
triton::driver::device *device;
if (dev_id >= 0) {
CUdevice handle;
triton::driver::dispatch::cuDeviceGet(&handle, dev_id);
device = new triton::driver::cu_device(handle, false);
} else
device = new triton::driver::host_device();
tt_devices[dev_id].reset(device);
uint64_t cu_device(int64_t dev_id) {
CUdevice handle;
triton::driver::dispatch::cuDeviceGet(&handle, dev_id);
return (uint64_t)handle;
}
void register_stream(int64_t dev_id) {
if (tt_streams.find(dev_id) != tt_streams.end())
return;
triton::driver::stream *stream;
if (dev_id >= 0) {
CUstream handle = (CUstream)c10::cuda::getCurrentCUDAStream(dev_id).stream();
stream = new triton::driver::cu_stream(handle, false);
} else
stream = new triton::driver::host_stream();
tt_streams[dev_id].reset(stream);
}
void synchronize(int64_t dev_id) {
tt_streams[dev_id]->synchronize();
uint64_t cu_stream(int64_t dev_id) {
return (uint64_t)c10::cuda::getCurrentCUDAStream(dev_id).stream();
}
void set_device(int64_t dev_id) {
@@ -44,23 +22,11 @@ void set_device(int64_t dev_id) {
C10_CUDA_CHECK(cudaSetDevice(dev_id));
}
torch::Tensor move_out_of_pool(torch::Tensor x) {
if (x.nbytes() == 0)
return torch::empty_like(x);
void *data;
cudaMalloc(&data, x.nbytes());
auto ret = torch::from_blob((void *)data, x.sizes(), x.strides(), [data](void *ptr) { cudaFree(data); }, x.options());
ret.copy_(x);
return ret;
}
} // namespace torch_utils
void init_torch_utils(pybind11::module &m) {
pybind11::module subm = m.def_submodule("torch_utils");
subm.def("register_device", &torch_utils::register_device);
subm.def("register_stream", &torch_utils::register_stream);
subm.def("cu_device", &torch_utils::cu_device);
subm.def("cu_stream", &torch_utils::cu_stream);
subm.def("set_device", &torch_utils::set_device);
subm.def("synchronize", &torch_utils::synchronize);
subm.def("move_out_of_pool", &torch_utils::move_out_of_pool);
}