[PYTHON] Some cleaning of the PyBind11 wrappers (#62)
Committed by Philippe Tillet
parent 08909b49c8 · commit 2a02fabdac
@@ -33,8 +33,8 @@ endif()
 if(BUILD_PYTHON_MODULE)
   message(STATUS "Adding Python module")
   # PyBind11 wrapper source file
-  set(TORCH_SRC torch/launch.cc torch/superblock.cc)
-  set(PYTHON_SRC bindings.cc ${TORCH_SRC})
+  file(GLOB_RECURSE TORCH_SRC torch/*.cc)
+  set(PYTHON_SRC main.cc triton.cc ${TORCH_SRC})
   set_source_files_properties(${TORCH_SRC} PROPERTIES COMPILE_FLAGS "-std=c++14 -D_GLIBCXX_USE_CXX11_ABI=${TORCH_CXX11_ABI}")
   include_directories("." ${PYTHON_INCLUDE_DIRS})
   link_directories(${PYTHON_LINK_DIRS})
Deleted file (219 lines):
@@ -1,219 +0,0 @@
#include <pybind11/pybind11.h>
#include <pybind11/buffer_info.h>
#include <pybind11/stl.h>
#include <pybind11/functional.h>
#include <string>
#include <regex>
#include "triton/driver/stream.h"
#include "triton/runtime/function.h"
#include "triton/runtime/arg.h"
#include "triton/lang/code_gen.h"
#include "triton/lang/parser.h"
#include "triton/lang/cpp.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"

using namespace triton;
namespace rt = triton::runtime;
namespace drv = triton::driver;
namespace lng = triton::lang;

typedef std::pair<int, int> map_key_t;

std::map<map_key_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
std::map<int, std::shared_ptr<rt::function>> id_fn_map;
std::map<int, std::shared_ptr<triton::driver::device>> tt_devices;
std::map<int, std::shared_ptr<triton::driver::stream>> tt_streams;
std::unordered_map<const rt::options_t*, pybind11::object> opt_cache_;
extern CUstream torch_get_cuda_stream(int64_t dev_id);
extern CUdevice torch_get_cuda_device(int64_t dev_id);

/* Grid utilities */

void register_grid(const map_key_t& key,
                   const rt::function::grid_fn_ty& grid_fn) {
  id_grid_map[key].reset(new rt::function::grid_fn_ty(grid_fn));
}

void delete_grid(const map_key_t& key) {
  id_grid_map.erase(key);
}

/* Function utilities */

void register_fn(int op_id,
                 int dev_id,
                 const std::string& src,
                 const rt::options_t& opt,
                 const rt::function::autotune_vals_t& autotune_vals,
                 const std::vector<std::string>& autotune_key) {
  if(tt_devices.find(dev_id) == tt_devices.end()) {
    driver::device* device;
    driver::stream* stream;
    if(dev_id >= 0){
      device = new triton::driver::cu_device(torch_get_cuda_device(dev_id), false);
      stream = new triton::driver::cu_stream(torch_get_cuda_stream(dev_id), false);
    }
    else{
      device = new triton::driver::host_device();
      stream = new triton::driver::host_stream();
    }
    tt_devices[dev_id].reset(device);
    tt_streams[dev_id].reset(stream);
  }
  if(id_fn_map.find(op_id) == id_fn_map.end()){
    id_fn_map[op_id].reset(new rt::function(src, opt, &*tt_devices[dev_id], autotune_vals, autotune_key));
  }
  for(const auto& k: id_fn_map[op_id]->get_kernels()){
    const rt::options_t* opt = &k.first;
    pybind11::object obj = pybind11::cast(opt, pybind11::return_value_policy::reference);
    for(auto x: opt->defines)
      if(std::all_of(x.second.begin(), x.second.end(), ::isdigit))
        obj.attr(x.first.c_str()) = std::stoi(x.second);
    opt_cache_[&k.second->opt] = obj;
  }
}

void delete_fn(int op_id) {
  id_fn_map.erase(op_id);
}

void cleanup() {
  id_grid_map.clear();
  id_fn_map.clear();
  opt_cache_.clear();
}

size_t make_op_id() {
  return id_fn_map.size();
}

std::vector<rt::arg_type> get_fn_signature(size_t op_id) {
  return id_fn_map[op_id]->get_kernels()[0].second->get_sig();
}

void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args, size_t grid_0, size_t grid_1, size_t grid_2){
  rt::function* fn = id_fn_map.at(op_id).get();
  (*fn)((void**)args.c_str(), args.size(), {grid_0, grid_1, grid_2}, &*tt_streams[dev_id]);
  // for(size_t n = 0; n < constant_names.size(); n++){
  //   const torch::Tensor& x = constant_vals[n];
  //   fn->set_cst(constant_names[n].c_str(), (char*)x.data_ptr(), x.numel()*x.element_size());
}

pybind11::object autotune(int64_t op_id, int64_t dev_id, const std::string& args, const rt::function::grid_fn_ty& grid){
  rt::function* fn = id_fn_map.at(op_id).get();
  auto wrapper = [&grid](const rt::options_t& opt){
    pybind11::object obj = pybind11::cast(&opt, pybind11::return_value_policy::reference);
    for(auto x: opt.defines)
      if(std::all_of(x.second.begin(), x.second.end(), ::isdigit))
        obj.attr(x.first.c_str()) = std::stoi(x.second);
    return grid(*obj.cast<rt::options_t*>());
  };
  rt::kernel* kernel = fn->autotune((void**)args.c_str(), args.size(), wrapper, &*tt_streams[dev_id]);
  return opt_cache_.at(&kernel->opt);
}

std::string extract_kernels(const std::string& str, const std::vector<std::string>& names) {
  if(names.empty())
    return str;
  // search for all regex matches of kernel_regex in str
  std::smatch matches;
  std::regex regex(" *__global__ +void +([_a-zA-Z][_a-zA-Z0-9]{0,30})");
  std::sregex_iterator it(str.begin(), str.end(), regex);
  std::sregex_iterator end;
  std::vector<std::tuple<std::string, int, int>> kernels;
  for (; it != end; ++it) {
    int pos = it->position();
    int len = it->length();
    std::string name = it->str(1);
    kernels.push_back(std::make_tuple(name, pos, len));
  }

  for(const std::string& name: names) {
    // check that str matches any string in kernels using std::any_of
    auto pred = [&name](const std::tuple<std::string, int, int>& t) { return std::get<0>(t) == name; };
    bool found = std::any_of(kernels.begin(), kernels.end(), pred);
    if(!found) throw std::runtime_error("Unable to find kernel `" + name + "` in provided source code:\n" + str);
  }

  // extract functions
  std::string ret;
  for(const auto& k: kernels) {
    std::string name;
    int pos, len;
    std::tie(name, pos, len) = k;
    if(std::find(names.begin(), names.end(), name) != names.end()){
      std::string def = str.substr(pos, str.size() - pos);
      int count, pos;
      // skip over declaration
      count = 1;
      pos = def.find('(');
      while(!(def[pos++] == ')' && count == 0) && pos < def.size()){
        count += def[pos] == '(';
        count -= def[pos] == ')';
      }
      // skip over definition
      count = 1;
      pos = def.find('{', pos);
      while(!(def[pos++] == '}' && count == 0) && pos < def.size()){
        count += def[pos] == '{';
        count -= def[pos] == '}';
      }
      ret += def.substr(0, pos);
      ret += '\n';
    }
  }

  return ret;
}

void init_superblocking(pybind11::module &m);
void init_launch(pybind11::module &m);

PYBIND11_MODULE(libtriton, m) {
  m.doc() = "Python bindings to the C++ Triton API";

  // bindings for triton classes
  pybind11::enum_<rt::arg_type>(m, "arg_type")
      .value("int1"  , rt::INT1_T)
      .value("int8"  , rt::INT8_T)
      .value("int16" , rt::INT16_T)
      .value("int32" , rt::INT32_T)
      .value("int64" , rt::INT64_T)
      .value("half"  , rt::HALF_T)
      .value("float" , rt::FLOAT_T)
      .value("double", rt::DOUBLE_T)
      .value("buffer", rt::BUFFER_T);

  pybind11::enum_<rt::asm_mode_t>(m, "asm_mode")
      .value("ptx" , rt::ASM_NV_PTX)
      .value("sass", rt::ASM_NV_SASS);

  pybind11::class_<rt::options_t>(m, "options", pybind11::dynamic_attr())
      .def(pybind11::init<>())
      .def_readwrite("defines"  , &rt::options_t::defines)
      .def_readwrite("num_warps", &rt::options_t::num_warps);

  // hooks into triton constructs since frameworks may not use pybind11
  m.def("extract_kernels", &extract_kernels);
  m.def("get_fn_signature", &get_fn_signature);
  m.def("register_grid", &register_grid);
  m.def("delete_grid", &delete_grid);
  m.def("register_fn", &register_fn);
  m.def("delete_fn", &delete_fn);
  m.def("make_op_id", &make_op_id);
  m.def("cleanup", &cleanup);
  m.def("autotune", &autotune, pybind11::return_value_policy::reference);
  m.def("launch_kernel", &launch_kernel);

  init_launch(m);
  init_superblocking(m);
}
python/src/main.cc (new file, 12 lines)
@@ -0,0 +1,12 @@
#include <pybind11/pybind11.h>

void init_superblocking(pybind11::module &m);
void init_torch_utils(pybind11::module &m);
void init_triton(pybind11::module &m);

PYBIND11_MODULE(libtriton, m) {
  m.doc() = "Python bindings to the C++ Triton API";
  init_triton(m);
  init_torch_utils(m);
  init_superblocking(m);
}
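Example (not part of the commit): the new main.cc keeps the top-level module minimal and delegates each concern to an init_* function. A minimal sketch of the resulting layout from the Python side, assuming the extension builds as triton._C.libtriton (consistent with the imports later in this diff):

    # Sketch: inspect the refactored module layout (requires a built extension).
    import triton._C.libtriton as libtriton

    print(libtriton.__doc__)              # "Python bindings to the C++ Triton API"
    _triton = libtriton.triton            # compiler/runtime bindings (python/src/triton.cc)
    _torch_utils = libtriton.torch_utils  # PyTorch helpers (python/src/torch/utils.cc)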
Deleted file (83 lines):
@@ -1,83 +0,0 @@
// Thanks to Scott Gray (OpenAI) for the idea to pass the arguments
// as a string constructed with struct.pack in python

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/functional.h>
#include "triton/driver/buffer.h"
#include "triton/driver/stream.h"
#include "triton/runtime/function.h"
#include "triton/tools/bench.hpp"
#include "torch/script.h"
#include "ATen/cuda/CUDAContext.h"
#include <c10/cuda/CUDAException.h>
#include <cuda_runtime_api.h>

namespace rt = triton::runtime;
namespace drv = triton::driver;

typedef std::pair<int, int> map_key_t;
extern std::map<map_key_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
extern std::map<int, std::shared_ptr<rt::function>> id_fn_map;
extern std::map<int, std::shared_ptr<drv::device>> tt_devices;
extern std::map<int, std::shared_ptr<drv::stream>> tt_streams;

int64_t cdiv(int64_t a, int64_t b) {
  return (a + b - 1) / b;
}

int64_t largest_pow2_divisor(int64_t a){
  if(a % 8 == 0) return 8;
  if(a % 4 == 0) return 4;
  if(a % 2 == 0) return 2;
  return 1;
}

int64_t cdiv_sum(torch::Tensor x, int64_t div){
  TORCH_CHECK(!x.is_cuda(), "Argument of cdiv_sum must be a CPU tensor")
  auto _x = x.accessor<int, 1>();
  int64_t ret = 0;
  for(size_t i = 0; i < x.size(0); i++)
    ret += (_x[i] + div - 1) / div;
  return ret;
}

CUstream torch_get_cuda_stream(int64_t dev_id) {
  return (CUstream)c10::cuda::getCurrentCUDAStream(dev_id).stream();
}

CUdeviceptr torch_get_cuda_device(int64_t dev_id) {
  CUdevice ret;
  triton::driver::dispatch::cuDeviceGet(&ret, dev_id);
  return ret;
}

void synchronize(int64_t dev_id) {
  tt_streams[dev_id]->synchronize();
}

torch::Tensor cuda_empty_like(torch::Tensor x){
  if(x.nbytes() == 0)
    return torch::empty_like(x);
  void* data;
  cudaMalloc(&data, x.nbytes());
  auto ret = torch::from_blob((void*)data, x.sizes(), x.strides(), [data](void* ptr) { cudaFree(data); }, x.options());
  return ret;
}

void cuda_set_device(int64_t dev_id) {
  if(dev_id >= 0)
    C10_CUDA_CHECK(cudaSetDevice(dev_id));
}

void init_launch(pybind11::module &m) {
  m.def("cuda_set_device", &cuda_set_device);
  m.def("cuda_empty_like", &cuda_empty_like);
  m.def("largest_pow2_divisor", &largest_pow2_divisor);
  m.def("cdiv", &cdiv);
  m.def("cdiv_sum", &cdiv_sum);
  m.def("synchronize", &synchronize);
}
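Example (not part of the commit): the header comment above credits the argument-passing scheme this file implemented: Python serializes all kernel arguments into one byte string with struct.pack, and the C++ wrapper reinterprets that buffer as an argument array. A hedged illustration; the values and format string are made up, while the real format codes come from the `codes` dict in the kernel.py hunk below:

    import struct

    n = 1024              # an int32 argument -> format code 'I'
    ptr = 0x7f00deadbeef  # hypothetical device pointer -> format code 'P'
    params = struct.pack('IP', n, ptr)
    # C++ side then does: (*fn)((void**)params.c_str(), params.size(), grid, stream)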
python/src/torch/utils.cc (new file, 66 lines)
@@ -0,0 +1,66 @@
#include "triton/driver/device.h"
#include "triton/driver/stream.h"
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime_api.h>
#include <torch/extension.h>

std::map<int, std::shared_ptr<triton::driver::device>> tt_devices;
std::map<int, std::shared_ptr<triton::driver::stream>> tt_streams;

namespace torch_utils {

void register_device(int64_t dev_id) {
  if (tt_devices.find(dev_id) != tt_devices.end())
    return;
  triton::driver::device *device;
  if (dev_id >= 0) {
    CUdevice handle;
    triton::driver::dispatch::cuDeviceGet(&handle, dev_id);
    device = new triton::driver::cu_device(handle, false);
  } else
    device = new triton::driver::host_device();
  tt_devices[dev_id].reset(device);
}

void register_stream(int64_t dev_id) {
  if (tt_streams.find(dev_id) != tt_streams.end())
    return;
  triton::driver::stream *stream;
  if (dev_id >= 0) {
    CUstream handle = (CUstream)c10::cuda::getCurrentCUDAStream(dev_id).stream();
    stream = new triton::driver::cu_stream(handle, false);
  } else
    stream = new triton::driver::host_stream();
  tt_streams[dev_id].reset(stream);
}

void synchronize(int64_t dev_id) {
  tt_streams[dev_id]->synchronize();
}

void set_device(int64_t dev_id) {
  if (dev_id >= 0)
    C10_CUDA_CHECK(cudaSetDevice(dev_id));
}

torch::Tensor move_out_of_pool(torch::Tensor x) {
  if (x.nbytes() == 0)
    return torch::empty_like(x);
  void *data;
  cudaMalloc(&data, x.nbytes());
  auto ret = torch::from_blob((void *)data, x.sizes(), x.strides(), [data](void *ptr) { cudaFree(data); }, x.options());
  ret.copy_(x);
  return ret;
}

} // namespace torch_utils

void init_torch_utils(pybind11::module &m) {
  pybind11::module subm = m.def_submodule("torch_utils");
  subm.def("register_device", &torch_utils::register_device);
  subm.def("register_stream", &torch_utils::register_stream);
  subm.def("set_device", &torch_utils::set_device);
  subm.def("synchronize", &torch_utils::synchronize);
  subm.def("move_out_of_pool", &torch_utils::move_out_of_pool);
}
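Example (not part of the commit): torch_utils now owns the device/stream registries that the deleted launch.cc populated lazily inside register_fn. A short usage sketch, assuming a built extension and a visible CUDA device:

    import torch  # must be imported before the extension (see the __init__.py TODO below)
    import triton._C.libtriton.torch_utils as _torch_utils

    dev_id = 0                            # CUDA device index; -1 selects the host path
    _torch_utils.register_device(dev_id)  # caches a triton::driver::device
    _torch_utils.register_stream(dev_id)  # caches the current CUDA stream
    _torch_utils.set_device(dev_id)
    _torch_utils.synchronize(dev_id)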
python/src/triton.cc (new file, 169 lines)
@@ -0,0 +1,169 @@
#include "triton/driver/stream.h"
#include "triton/ir/function.h"
#include "triton/ir/module.h"
#include "triton/lang/code_gen.h"
#include "triton/lang/cpp.h"
#include "triton/lang/parser.h"
#include "triton/runtime/arg.h"
#include "triton/runtime/function.h"
#include <pybind11/buffer_info.h>
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <regex>
#include <string>

using namespace triton;
namespace rt = triton::runtime;
namespace drv = triton::driver;
namespace lng = triton::lang;

std::unordered_map<const rt::options_t *, pybind11::object> opt_cache_;
std::map<int, std::shared_ptr<rt::function>> id_fn_map;
extern std::map<int, std::shared_ptr<triton::driver::device>> tt_devices;
extern std::map<int, std::shared_ptr<triton::driver::stream>> tt_streams;

/* Function utilities */

void register_fn(int op_id, int dev_id,
                 const std::string &src, const rt::options_t &opt,
                 const rt::function::autotune_vals_t &autotune_vals,
                 const std::vector<std::string> &autotune_key) {
  if (id_fn_map.find(op_id) == id_fn_map.end()) {
    id_fn_map[op_id].reset(new rt::function(src, opt, &*tt_devices[dev_id], autotune_vals, autotune_key));
  }
  for (const auto &k : id_fn_map[op_id]->get_kernels()) {
    const rt::options_t *opt = &k.first;
    pybind11::object obj = pybind11::cast(opt, pybind11::return_value_policy::reference);
    for (auto x : opt->defines)
      if (std::all_of(x.second.begin(), x.second.end(), ::isdigit))
        obj.attr(x.first.c_str()) = std::stoi(x.second);
    opt_cache_[&k.second->opt] = obj;
  }
}

void delete_fn(int op_id) {
  id_fn_map.erase(op_id);
}

void cleanup() {
  id_fn_map.clear();
  opt_cache_.clear();
}

size_t make_op_id() {
  return id_fn_map.size();
}

std::vector<rt::arg_type> get_fn_signature(size_t op_id) {
  return id_fn_map[op_id]->get_kernels()[0].second->get_sig();
}

// Thanks to Scott Gray (OpenAI) for the idea to pass the arguments
// as a string constructed with struct.pack in python
void launch_kernel(int64_t op_id, int64_t dev_id, const std::string &args, size_t grid_0, size_t grid_1, size_t grid_2) {
  rt::function *fn = id_fn_map.at(op_id).get();
  (*fn)((void **)args.c_str(), args.size(), {grid_0, grid_1, grid_2}, &*tt_streams[dev_id]);
}

pybind11::object autotune(int64_t op_id, int64_t dev_id, const std::string &args, const rt::function::grid_fn_ty &grid) {
  rt::function *fn = id_fn_map.at(op_id).get();
  auto wrapper = [&grid](const rt::options_t &opt) {
    pybind11::object obj = pybind11::cast(&opt, pybind11::return_value_policy::reference);
    for (auto x : opt.defines)
      if (std::all_of(x.second.begin(), x.second.end(), ::isdigit))
        obj.attr(x.first.c_str()) = std::stoi(x.second);
    return grid(*obj.cast<rt::options_t *>());
  };
  rt::kernel *kernel = fn->autotune((void **)args.c_str(), args.size(), wrapper, &*tt_streams[dev_id]);
  return opt_cache_.at(&kernel->opt);
}

std::string extract_kernels(const std::string &str, const std::vector<std::string> &names) {
  if (names.empty())
    return str;
  // search for all regex matches of kernel_regex in str
  std::smatch matches;
  std::regex regex(" *__global__ +void +([_a-zA-Z][_a-zA-Z0-9]{0,30})");
  std::sregex_iterator it(str.begin(), str.end(), regex);
  std::sregex_iterator end;
  std::vector<std::tuple<std::string, int, int>> kernels;
  for (; it != end; ++it) {
    int pos = it->position();
    int len = it->length();
    std::string name = it->str(1);
    kernels.push_back(std::make_tuple(name, pos, len));
  }

  for (const std::string &name : names) {
    // check that str matches any string in kernels using std::any_of
    auto pred = [&name](const std::tuple<std::string, int, int> &t) { return std::get<0>(t) == name; };
    bool found = std::any_of(kernels.begin(), kernels.end(), pred);
    if (!found)
      throw std::runtime_error("Unable to find kernel `" + name + "` in provided source code:\n" + str);
  }

  // extract functions
  std::string ret;
  for (const auto &k : kernels) {
    std::string name;
    int pos, len;
    std::tie(name, pos, len) = k;
    if (std::find(names.begin(), names.end(), name) != names.end()) {
      std::string def = str.substr(pos, str.size() - pos);
      int count, pos;
      // skip over declaration
      count = 1;
      pos = def.find('(');
      while (!(def[pos++] == ')' && count == 0) && pos < def.size()) {
        count += def[pos] == '(';
        count -= def[pos] == ')';
      }
      // skip over definition
      count = 1;
      pos = def.find('{', pos);
      while (!(def[pos++] == '}' && count == 0) && pos < def.size()) {
        count += def[pos] == '{';
        count -= def[pos] == '}';
      }
      ret += def.substr(0, pos);
      ret += '\n';
    }
  }

  return ret;
}

void init_triton(pybind11::module &m) {
  pybind11::module subm = m.def_submodule("triton");
  // bindings for triton classes
  pybind11::enum_<rt::arg_type>(subm, "arg_type")
      .value("int1", rt::INT1_T)
      .value("int8", rt::INT8_T)
      .value("int16", rt::INT16_T)
      .value("int32", rt::INT32_T)
      .value("int64", rt::INT64_T)
      .value("half", rt::HALF_T)
      .value("float", rt::FLOAT_T)
      .value("double", rt::DOUBLE_T)
      .value("buffer", rt::BUFFER_T);

  pybind11::enum_<rt::asm_mode_t>(subm, "asm_mode")
      .value("ptx", rt::ASM_NV_PTX)
      .value("sass", rt::ASM_NV_SASS);

  pybind11::class_<rt::options_t>(subm, "options", pybind11::dynamic_attr())
      .def(pybind11::init<>())
      .def_readwrite("defines", &rt::options_t::defines)
      .def_readwrite("num_warps", &rt::options_t::num_warps);

  // hooks into triton constructs since frameworks may not use pybind11
  subm.def("extract_kernels", &extract_kernels);
  subm.def("get_fn_signature", &get_fn_signature);
  subm.def("register_fn", &register_fn);
  subm.def("delete_fn", &delete_fn);
  subm.def("make_op_id", &make_op_id);
  subm.def("cleanup", &cleanup);
  subm.def("autotune", &autotune, pybind11::return_value_policy::reference);
  subm.def("launch_kernel", &launch_kernel);
}
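Example (not part of the commit): the triton submodule carries the whole function lifecycle: register, introspect, autotune, launch. A hedged end-to-end sketch mirroring what kernel.py does in the hunks below; `src` is a placeholder for real Triton-C source and the packed arguments must match the signature, so this will not run as written:

    import struct
    import torch  # must be imported before the extension
    import triton._C.libtriton.triton as _triton
    import triton._C.libtriton.torch_utils as _torch_utils

    src = '...'                    # placeholder: source with one __global__ kernel
    opt = _triton.options()
    opt.defines = {'TILE': '128'}  # hypothetical compile-time define
    opt.num_warps = 4

    dev_id = 0
    _torch_utils.register_device(dev_id)
    _torch_utils.register_stream(dev_id)
    op_id = _triton.make_op_id()
    _triton.register_fn(op_id, dev_id, src, opt, [], [])
    arg_types = _triton.get_fn_signature(op_id)  # [arg_type, ...]
    params = struct.pack('P', 0)                 # pack args to match arg_types
    best = _triton.autotune(op_id, dev_id, params, lambda opt: [1, 1, 1])
    _triton.launch_kernel(op_id, dev_id, params, 1, 1, 1)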
@@ -1,16 +1,11 @@
 # TODO: torch needs to be imported first
 # or pybind11 shows `munmap_chunk(): invalid pointer`
 import torch

-# libtriton resources
-import atexit
-import triton._C.libtriton as libtriton
-@atexit.register
-def cleanup():
-    libtriton.cleanup()
-
 # submodules
 from .kernel import *
 from . import ops
+# C bindings
+import triton._C.libtriton.torch_utils as _torch_utils

 # version
 __version__ = '1.0.0'
@@ -1,18 +1,24 @@
-import triton._C.libtriton as libtriton
 import os
 import time
-from struct import pack
+import struct
 import torch
+# C bindings
+import triton._C.libtriton.triton as _triton
+import triton._C.libtriton.torch_utils as _torch_utils
+# Make sure internal C resources are cleaned up upon exit
+import atexit
+@atexit.register
+def cleanup():
+    _triton.cleanup()

 codes = {
-    libtriton.arg_type.int1: 'B',
-    libtriton.arg_type.int8: 'B',
-    libtriton.arg_type.int32: 'I',
-    libtriton.arg_type.int64: 'Q',
-    libtriton.arg_type.half: 'H',
-    libtriton.arg_type.float: 'f',
-    libtriton.arg_type.double: 'd',
-    libtriton.arg_type.buffer: 'P'
+    _triton.arg_type.int1: 'B',
+    _triton.arg_type.int8: 'B',
+    _triton.arg_type.int32: 'I',
+    _triton.arg_type.int64: 'Q',
+    _triton.arg_type.half: 'H',
+    _triton.arg_type.float: 'f',
+    _triton.arg_type.double: 'd',
+    _triton.arg_type.buffer: 'P'
 }

@@ -30,17 +36,17 @@ def th_to_triton(obj):
     return str(obj)

 def cdiv(a, b):
-    return libtriton.cdiv(a, b)
+    return (a + b - 1) // b

 def synchronize(device):
     dev_id = device.index
     dev_id = -1 if dev_id is None else dev_id
-    libtriton.synchronize(dev_id)
+    _torch_utils.synchronize(dev_id)

 def read(path, kernel_names=[]):
     with open(path, 'r') as f:
         source = f.read()
-    source = libtriton.extract_kernels(source, kernel_names)
+    source = _triton.extract_kernels(source, kernel_names)
     return source

 class kernel:

@@ -50,7 +56,7 @@ class kernel:
         if src == '':
             raise ValueError('Kernel source code is empty')
         self.src = src
-        self.opt = libtriton.options()
+        self.opt = _triton.options()
         self.opt.defines = {k: th_to_triton(v) for k, v in defines.items()}
         self.opt.num_warps = num_warps
         # device

@@ -59,39 +65,25 @@ class kernel:
         self.device = torch.cuda.current_device() if device.index is None else device.index
         if device.type == 'cpu':
             self.device = -1
+        _torch_utils.register_device(self.device)
+        _torch_utils.register_stream(self.device)
         # C++ function wrapper
-        self.op_id = libtriton.make_op_id()
-        libtriton.register_fn(self.op_id, self.device, self.src, self.opt, autotune_vals, autotune_key)
-        # debug mode
-        self.is_debug = 'TRITON_DEBUG' in os.environ
+        self.op_id = _triton.make_op_id()
+        _triton.register_fn(self.op_id, self.device, self.src, self.opt, autotune_vals, autotune_key)
         # signature
-        arg_types = libtriton.get_fn_signature(self.op_id)
+        arg_types = _triton.get_fn_signature(self.op_id)
         self.tys = ''.join([codes[x] for x in arg_types])

     def __call__(self, *args, grid):
-        # debug mode (initialize)
-        if self.is_debug:
-            _args = args
-            args = [x.clone() if isinstance(x, torch.Tensor) else x for x in _args]
-            for i in range(len(args)):
-                if isinstance(args[i], torch.Tensor):
-                    args[i] = libtriton.cuda_empty_like(args[i])
-                    args[i].copy_(_args[i])
         # initialize cuda device if necessary
-        libtriton.cuda_set_device(self.device)
+        _torch_utils.set_device(self.device)
         # pack parameters into a byte buffer
-        params = pack(self.tys, *args)
+        params = struct.pack(self.tys, *args)
         # auto-tune if necessary
-        opt = libtriton.autotune(self.op_id, self.device, params, grid)
+        opt = _triton.autotune(self.op_id, self.device, params, grid)
         # run kernel
         grid = grid(opt)
         grid_0 = grid[0]
         grid_1 = 1 if len(grid) < 2 else grid[1]
         grid_2 = 1 if len(grid) < 3 else grid[2]
-        libtriton.launch_kernel(self.op_id, self.device, params, grid_0, grid_1, grid_2)
-        # debug mode (finalize)
-        if self.is_debug:
-            for i in range(len(args)):
-                if isinstance(args[i], torch.Tensor):
-                    _args[i].copy_(args[i].clone())
-            args = _args
+        _triton.launch_kernel(self.op_id, self.device, params, grid_0, grid_1, grid_2)
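Example (not part of the commit): end-to-end, the refactored wrapper is driven as before; only the internal bindings moved. A hedged sketch of user code (keyword names inferred from this diff; the kernel body is illustrative, not from the commit):

    import torch
    import triton

    src = '''
    __global__ void add(float *z, float *x, float *y) {
        /* illustrative Triton-C body */
    }
    '''
    k = triton.kernel(src, device=torch.device('cuda:0'),
                      defines={}, num_warps=4,
                      autotune_vals=[], autotune_key=[])
    z = torch.empty(1024, device='cuda')
    x = torch.rand(1024, device='cuda')
    y = torch.rand(1024, device='cuda')
    k(z.data_ptr(), x.data_ptr(), y.data_ptr(), grid=lambda opt: [64])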