[RUNTIME] Added auto-alignment mechanism (#71)

This PR adds an automatic memory alignment mechanism in the Triton runtime. Specifically, the JIT compiler detects the alignment (in bytes) of each pointer argument as well as the largest power of two divisor (between 1 and 16) of each integer argument. Proper .aligned and .multipleof attributes are then added to the Triton-IR on-the-fly for all auto-tunable kernels. There is a cache that remembers all the kernels compiled for each possible configuration.

This PR also includes substantial cleaning of the Python API. This adds 2–3µs of overhead, mostly due to accessing integer #defines from the auto-tuned compilation options. The previous solution was slightly faster but hacky and potentially unsafe, so this is preferred for now.
This commit is contained in:
Philippe Tillet
2021-03-04 01:51:11 -05:00
committed by Philippe Tillet
parent ff62f7fffc
commit 62835a0979
19 changed files with 668 additions and 707 deletions

View File

@@ -5,38 +5,16 @@
#include <cuda_runtime_api.h>
#include <torch/extension.h>
std::map<int, std::shared_ptr<triton::driver::device>> tt_devices;
std::map<int, std::shared_ptr<triton::driver::stream>> tt_streams;
namespace torch_utils {
// NOTE(review): this is rendered diff residue — register_device (removed by
// this commit) is missing its closing brace here, and cu_device (its
// replacement) begins mid-way through. Not compilable as shown.
// register_device: lazily create and cache a triton device for `dev_id`;
// negative ids map to a host device, others wrap a CUDA device handle.
void register_device(int64_t dev_id) {
// Already registered? Nothing to do.
if (tt_devices.find(dev_id) != tt_devices.end())
return;
triton::driver::device *device;
if (dev_id >= 0) {
CUdevice handle;
triton::driver::dispatch::cuDeviceGet(&handle, dev_id);
// `false` = non-owning: triton must not destroy the CUDA device handle.
device = new triton::driver::cu_device(handle, false);
} else
device = new triton::driver::host_device();
tt_devices[dev_id].reset(device);
// cu_device (added): return the raw CUdevice handle as an integer so it
// can cross the pybind boundary (opaque pointers are not supported).
uint64_t cu_device(int64_t dev_id) {
CUdevice handle;
triton::driver::dispatch::cuDeviceGet(&handle, dev_id);
return (uint64_t)handle;
}
void register_stream(int64_t dev_id) {
if (tt_streams.find(dev_id) != tt_streams.end())
return;
triton::driver::stream *stream;
if (dev_id >= 0) {
CUstream handle = (CUstream)c10::cuda::getCurrentCUDAStream(dev_id).stream();
stream = new triton::driver::cu_stream(handle, false);
} else
stream = new triton::driver::host_stream();
tt_streams[dev_id].reset(stream);
}
// NOTE(review): diff residue — synchronize (removed) is missing its closing
// brace; cu_stream (added) starts immediately after. Not compilable as shown.
// synchronize: block until the registered stream for `dev_id` is idle.
void synchronize(int64_t dev_id) {
tt_streams[dev_id]->synchronize();
// cu_stream (added): return the current CUDA stream handle as an integer
// so it can cross the pybind boundary.
uint64_t cu_stream(int64_t dev_id) {
return (uint64_t)c10::cuda::getCurrentCUDAStream(dev_id).stream();
}
// Make `dev_id` the active CUDA device for the calling thread.
// NOTE(review): the `@@ ... @@` line below is a diff hunk header left in the
// rendered text, not source code.
void set_device(int64_t dev_id) {
@@ -44,23 +22,11 @@ void set_device(int64_t dev_id) {
C10_CUDA_CHECK(cudaSetDevice(dev_id));
}
// Copy `x` into a buffer allocated directly with cudaMalloc, detaching it
// from PyTorch's caching allocator pool. The returned tensor owns the raw
// buffer via the from_blob deleter, which frees it with cudaFree.
torch::Tensor move_out_of_pool(torch::Tensor x) {
  const size_t nbytes = x.nbytes();
  // Zero-sized tensors need no backing storage.
  if (nbytes == 0)
    return torch::empty_like(x);
  void *buffer = nullptr;
  cudaMalloc(&buffer, nbytes);
  auto release = [buffer](void *) { cudaFree(buffer); };
  torch::Tensor detached = torch::from_blob(buffer, x.sizes(), x.strides(), release, x.options());
  detached.copy_(x);
  return detached;
}
} // namespace torch_utils
// Register every torch_utils helper on the `torch_utils` submodule of `m`.
void init_torch_utils(pybind11::module &m) {
  pybind11::module tu = m.def_submodule("torch_utils");
  tu.def("register_device", &torch_utils::register_device);
  tu.def("register_stream", &torch_utils::register_stream);
  tu.def("cu_device", &torch_utils::cu_device);
  tu.def("cu_stream", &torch_utils::cu_stream);
  tu.def("set_device", &torch_utils::set_device);
  tu.def("synchronize", &torch_utils::synchronize);
  tu.def("move_out_of_pool", &torch_utils::move_out_of_pool);
}

View File

@@ -1,10 +1,4 @@
#include "triton/driver/stream.h"
#include "triton/ir/function.h"
#include "triton/ir/module.h"
#include "triton/lang/code_gen.h"
#include "triton/lang/cpp.h"
#include "triton/lang/parser.h"
#include "triton/runtime/arg.h"
#include "triton/runtime/function.h"
#include <pybind11/buffer_info.h>
#include <pybind11/functional.h>
@@ -13,72 +7,22 @@
#include <regex>
#include <string>
namespace py = pybind11;
using namespace triton;
namespace rt = triton::runtime;
namespace drv = triton::driver;
namespace lng = triton::lang;
std::unordered_map<const rt::options_t *, pybind11::object> opt_cache_;
std::map<int, std::shared_ptr<rt::function>> id_fn_map;
extern std::map<int, std::shared_ptr<triton::driver::device>> tt_devices;
extern std::map<int, std::shared_ptr<triton::driver::stream>> tt_streams;
/*****************************************************************************/
/* Python bindings for triton::tools */
/*****************************************************************************/
/* Function utilities */
// Compile (or fetch from cache) the triton::function for `op_id` and build a
// Python-visible options object for every autotuned kernel variant.
// Numeric #define values are mirrored as integer attributes on the options
// object so Python code can read them cheaply.
void register_fn(int op_id, int dev_id,
                 const std::string &src, const rt::options_t &opt,
                 const rt::function::autotune_vals_t &autotune_vals,
                 const std::vector<std::string> &autotune_key) {
  // Compile only once per op id; later calls just refresh the options cache.
  if (id_fn_map.find(op_id) == id_fn_map.end()) {
    id_fn_map[op_id].reset(new rt::function(src, opt, &*tt_devices[dev_id], autotune_vals, autotune_key));
  }
  // ::isdigit requires a value representable as unsigned char; passing a
  // plain (possibly negative) char is undefined behavior.
  auto is_digit = [](unsigned char c) { return ::isdigit(c) != 0; };
  for (const auto &k : id_fn_map[op_id]->get_kernels()) {
    // Renamed from `opt` to avoid shadowing the function parameter.
    const rt::options_t *kernel_opt = &k.first;
    pybind11::object obj = pybind11::cast(kernel_opt, pybind11::return_value_policy::reference);
    for (const auto &x : kernel_opt->defines) // const ref: no per-pair copy
      // Skip empty values: all_of() is vacuously true on "" and std::stoi
      // would then throw std::invalid_argument.
      if (!x.second.empty() && std::all_of(x.second.begin(), x.second.end(), is_digit))
        obj.attr(x.first.c_str()) = std::stoi(x.second);
    opt_cache_[&k.second->opt] = obj;
  }
}
// Drop the compiled function registered under `op_id`, if any.
void delete_fn(int op_id) {
  auto it = id_fn_map.find(op_id);
  if (it != id_fn_map.end())
    id_fn_map.erase(it);
}
// Release every registered function and invalidate the cached Python
// options objects.
void cleanup() {
  opt_cache_.clear();
  id_fn_map.clear();
}
// Return a fresh identifier for a new op (the current number of registered
// functions).
// NOTE(review): because this is the map size, an id can collide with a
// still-registered op after delete_fn() shrinks the map; a monotonically
// increasing counter would be safer — confirm callers never delete an op
// before creating new ones.
size_t make_op_id() {
  return id_fn_map.size();
}
// Argument-type signature of the first kernel compiled for `op_id`.
std::vector<rt::arg_type> get_fn_signature(size_t op_id) {
  const auto &kernels = id_fn_map[op_id]->get_kernels();
  return kernels[0].second->get_sig();
}
// Thanks to Scott Gray (OpenAI) for the idea to pass the arguments
// as a string constructed with struct.pack in python
void launch_kernel(int64_t op_id, int64_t dev_id, const std::string &args, size_t grid_0, size_t grid_1, size_t grid_2) {
rt::function *fn = id_fn_map.at(op_id).get();
(*fn)((void **)args.c_str(), args.size(), {grid_0, grid_1, grid_2}, &*tt_streams[dev_id]);
}
// Run the autotuner for `op_id` on the packed argument string and return the
// cached Python options object of the winning kernel. `grid` maps a set of
// compilation options to a launch grid.
pybind11::object autotune(int64_t op_id, int64_t dev_id, const std::string &args, const rt::function::grid_fn_ty &grid) {
  rt::function *fn = id_fn_map.at(op_id).get();
  auto wrapper = [&grid](const rt::options_t &opt) {
    pybind11::object obj = pybind11::cast(&opt, pybind11::return_value_policy::reference);
    // Mirror numeric #defines as integer attributes (same scheme as
    // register_fn). Cast to unsigned char keeps ::isdigit well-defined for
    // negative chars; skipping empty values avoids std::stoi("") throwing
    // (all_of is vacuously true on an empty string).
    for (const auto &x : opt.defines) // const ref: no per-pair copy
      if (!x.second.empty() &&
          std::all_of(x.second.begin(), x.second.end(),
                      [](unsigned char c) { return ::isdigit(c) != 0; }))
        obj.attr(x.first.c_str()) = std::stoi(x.second);
    return grid(*obj.cast<rt::options_t *>());
  };
  rt::kernel *kernel = fn->autotune((void **)args.c_str(), args.size(), wrapper, &*tt_streams[dev_id]);
  return opt_cache_.at(&kernel->opt);
}
/*!
  @brief Function for extracting kernels out of a given source-string
  This can be important to enable pre-processor macros (or tunable parameters) that should only
  be defined within the scope of a single kernel function
*/
// NOTE(review): the body below is rendered diff residue — it contains BOTH the
// pre-refactor and post-refactor variants of the per-kernel extraction loop
// (duplicate declarations of `def`, `count`, `pos`), plus a `@@ ... @@` hunk
// header, so it is not compilable as shown.
std::string extract_kernels(const std::string &str, const std::vector<std::string> &names) {
// No filter requested: return the source unchanged.
if (names.empty())
return str;
@@ -94,50 +38,82 @@ std::string extract_kernels(const std::string &str, const std::vector<std::strin
std::string name = it->str(1);
kernels.push_back(std::make_tuple(name, pos, len));
}
// check that all the kernels provided actually exist
for (const std::string &name : names) {
// check that str matches any string in kernels using std::any_of
auto pred = [&name](const std::tuple<std::string, int, int> &t) { return std::get<0>(t) == name; };
bool found = std::any_of(kernels.begin(), kernels.end(), pred);
if (!found)
throw std::runtime_error("Unable to find kernel `" + name + "` in provided source code:\n" + str);
}
// extract functions
// simple parsing logic to extract the declaration and body of each specified kernel
std::string ret;
for (const auto &k : kernels) {
std::string name;
int pos, len;
std::tie(name, pos, len) = k;
// --- old variant of the loop body (removed by this commit) ---
if (std::find(names.begin(), names.end(), name) != names.end()) {
std::string def = str.substr(pos, str.size() - pos);
int count, pos;
// skip over declaration
count = 1;
pos = def.find('(');
while (!(def[pos++] == ')' && count == 0) && pos < def.size()) {
count += def[pos] == '(';
count -= def[pos] == ')';
}
// skip over definition
count = 1;
pos = def.find('{', pos);
while (!(def[pos++] == '}' && count == 0) && pos < def.size()) {
count += def[pos] == '{';
count -= def[pos] == '}';
}
ret += def.substr(0, pos);
ret += '\n';
// --- new variant of the loop body (added by this commit) ---
if (std::find(names.begin(), names.end(), name) == names.end())
continue;
std::string def = str.substr(pos, str.size() - pos);
// skip over declaration
// by finding matching ')' for first '('
int count = 1;
pos = def.find('(');
while (!(def[pos++] == ')' && count == 0) && pos < def.size()) {
count += def[pos] == '(';
count -= def[pos] == ')';
}
// skip over definition
// by finding matching '{' for first '}'
count = 1;
pos = def.find('{', pos);
while (!(def[pos++] == '}' && count == 0) && pos < def.size()) {
count += def[pos] == '{';
count -= def[pos] == '}';
}
ret += def.substr(0, pos);
ret += '\n';
}
return ret;
}
// NOTE(review): diff residue — the first four lines are the opening of the
// OLD init_triton (removed by this commit, unclosed here); the NEW
// init_triton_tools follows.
void init_triton(pybind11::module &m) {
pybind11::module subm = m.def_submodule("triton");
// bindings for triton classes
pybind11::enum_<rt::arg_type>(subm, "arg_type")
// Expose source-level tools (kernel extraction) on the `tools` submodule.
void init_triton_tools(py::module &&m) {
m.def("extract_kernels", &extract_kernels);
}
/*****************************************************************************/
/* Python bindings for triton::driver */
/*****************************************************************************/
// Python bindings for the triton::driver device/stream wrappers.
// Consistency fix: the file declares `namespace drv = triton::driver;` but
// the original mixed `drv::` and `driver::` — `drv::` is now used throughout.
void init_triton_driver(py::module &&m) {
  // base device
  py::class_<drv::device>(m, "device");
  // cuda device
  py::class_<drv::cu_device, drv::device>(m, "cu_device")
      .def(py::init<CUdevice, bool>());
  // host device
  py::class_<drv::host_device, drv::device>(m, "host_device")
      .def(py::init<>());
  // base stream
  py::class_<drv::stream>(m, "stream");
  // host stream
  py::class_<drv::host_stream, drv::stream>(m, "host_stream")
      .def(py::init<>());
  // cuda stream
  py::class_<drv::cu_stream, drv::stream>(m, "cu_stream")
      // pybind11 doesn't support opaque pointers (e.g., CUstream), so the
      // handle is assumed to have been converted to a uint64_t by the caller.
      .def(py::init([](uint64_t handle, bool take_ownership) {
        // make_unique instead of raw new (exception-safe, idiomatic).
        return std::make_unique<drv::cu_stream>((CUstream)handle, take_ownership);
      }));
}
/*****************************************************************************/
/* Python bindings for triton::runtime */
/*****************************************************************************/
// NOTE(review): diff residue — this region interleaves the OLD init_triton
// registrations (the `pybind11::...(subm, ...)` and `subm.def(...)` lines)
// with the NEW init_triton_runtime ones (`py::...(m, ...)`), and a `@@ ... @@`
// hunk header splits the arg_type enum chain. Not compilable as shown.
void init_triton_runtime(py::module &&m) {
// argument type
py::enum_<rt::arg_type>(m, "arg_type")
.value("int1", rt::INT1_T)
.value("int8", rt::INT8_T)
.value("int16", rt::INT16_T)
@@ -147,23 +123,38 @@ void init_triton(pybind11::module &m) {
.value("float", rt::FLOAT_T)
.value("double", rt::DOUBLE_T)
.value("buffer", rt::BUFFER_T);
pybind11::enum_<rt::asm_mode_t>(subm, "asm_mode")
// assembly mode
py::enum_<rt::asm_mode_t>(m, "asm_mode")
.value("ptx", rt::ASM_NV_PTX)
.value("sass", rt::ASM_NV_SASS);
pybind11::class_<rt::options_t>(subm, "options", pybind11::dynamic_attr())
.def(pybind11::init<>())
// compilation options
py::class_<rt::options_t>(m, "options", py::dynamic_attr())
.def(py::init<>())
.def_readwrite("defines", &rt::options_t::defines)
.def_readwrite("num_warps", &rt::options_t::num_warps);
.def_readwrite("num_warps", &rt::options_t::num_warps)
.def("__getattr__", [](rt::options_t *opt, const std::string &name) {
return opt->D<int>(name);
});
// kernel
py::class_<rt::kernel>(m, "kernel")
.def("__call__", &rt::kernel::operator())
.def_readonly("opt", &rt::kernel::opt);
// tune conf
// NOTE(review): `defines` has a default value but the following
// `num_warps` does not — pybind11 requires that arguments after one with
// a default also carry defaults; confirm this binding actually loads.
py::class_<rt::config>(m, "config")
.def(py::init<std::map<std::string, std::string>, int>(),
py::arg("defines") = std::map<std::string, std::string>(),
py::arg("num_warps"));
// hooks into triton constructs since frameworks may not use pybind11
subm.def("extract_kernels", &extract_kernels);
subm.def("get_fn_signature", &get_fn_signature);
subm.def("register_fn", &register_fn);
subm.def("delete_fn", &delete_fn);
subm.def("make_op_id", &make_op_id);
subm.def("cleanup", &cleanup);
subm.def("autotune", &autotune, pybind11::return_value_policy::reference);
subm.def("launch_kernel", &launch_kernel);
// function
py::class_<rt::function>(m, "function")
.def(py::init<const std::string &, const rt::options_t &, driver::device *, const std::vector<rt::config> &, const std::vector<std::string> &>())
.def("autotune", &rt::function::autotune, py::return_value_policy::reference_internal)
.def("signature", &rt::function::get_signature);
}
// Assemble the top-level `triton` submodule from its three components.
void init_triton(py::module &m) {
  py::module subm = m.def_submodule("triton");
  // def_submodule returns a py::module by value; that temporary already binds
  // to the `py::module &&` parameters, so the std::move calls were redundant.
  init_triton_driver(subm.def_submodule("driver"));
  init_triton_runtime(subm.def_submodule("runtime"));
  init_triton_tools(subm.def_submodule("tools"));
}