[PYTHON] Some cleaning of the PyBind11 wrappers (#62)

Author:    Philippe Tillet
Date:      2021-02-06 17:10:44 -08:00
Committer: Philippe Tillet
Commit:    2a02fabdac (parent 08909b49c8)

8 changed files with 282 additions and 350 deletions

@@ -33,8 +33,8 @@ endif()
 if(BUILD_PYTHON_MODULE)
   message(STATUS "Adding Python module")
   # PyBind11 wrapper source file
-  set(TORCH_SRC torch/launch.cc torch/superblock.cc)
-  set(PYTHON_SRC bindings.cc ${TORCH_SRC})
+  file(GLOB_RECURSE TORCH_SRC torch/*.cc)
+  set(PYTHON_SRC main.cc triton.cc ${TORCH_SRC})
   set_source_files_properties(${TORCH_SRC} PROPERTIES COMPILE_FLAGS "-std=c++14 -D_GLIBCXX_USE_CXX11_ABI=${TORCH_CXX11_ABI}")
   include_directories("." ${PYTHON_INCLUDE_DIRS})
   link_directories(${PYTHON_LINK_DIRS})

python/src/bindings.cc (deleted file)

@@ -1,219 +0,0 @@
#include <pybind11/pybind11.h>
#include <pybind11/buffer_info.h>
#include <pybind11/stl.h>
#include <pybind11/functional.h>
#include <string>
#include <regex>
#include "triton/driver/stream.h"
#include "triton/runtime/function.h"
#include "triton/runtime/arg.h"
#include "triton/lang/code_gen.h"
#include "triton/lang/parser.h"
#include "triton/lang/cpp.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
using namespace triton;
namespace rt = triton::runtime;
namespace drv = triton::driver;
namespace lng = triton::lang;
typedef std::pair<int, int> map_key_t;
std::map<map_key_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
std::map<int, std::shared_ptr<rt::function>> id_fn_map;
std::map<int, std::shared_ptr<triton::driver::device>> tt_devices;
std::map<int, std::shared_ptr<triton::driver::stream>> tt_streams;
std::unordered_map<const rt::options_t*, pybind11::object> opt_cache_;
extern CUstream torch_get_cuda_stream(int64_t dev_id);
extern CUdevice torch_get_cuda_device(int64_t dev_id);
/* Grid utilities */
void register_grid(const map_key_t& key,
const rt::function::grid_fn_ty& grid_fn) {
id_grid_map[key].reset(new rt::function::grid_fn_ty(grid_fn));
}
void delete_grid(const map_key_t& key) {
id_grid_map.erase(key);
}
/* Function utilities */
void register_fn(int op_id,
int dev_id,
const std::string& src,
const rt::options_t& opt,
const rt::function::autotune_vals_t& autotune_vals,
const std::vector<std::string>& autotune_key) {
if(tt_devices.find(dev_id) == tt_devices.end()) {
driver::device* device;
driver::stream* stream;
if(dev_id >= 0){
device = new triton::driver::cu_device(torch_get_cuda_device(dev_id), false);
stream = new triton::driver::cu_stream(torch_get_cuda_stream(dev_id), false);
}
else{
device = new triton::driver::host_device();
stream = new triton::driver::host_stream();
}
tt_devices[dev_id].reset(device);
tt_streams[dev_id].reset(stream);
}
if(id_fn_map.find(op_id) == id_fn_map.end()){
id_fn_map[op_id].reset(new rt::function(src, opt, &*tt_devices[dev_id], autotune_vals, autotune_key));
}
for(const auto& k: id_fn_map[op_id]->get_kernels()){
const rt::options_t* opt = &k.first;
pybind11::object obj = pybind11::cast(opt, pybind11::return_value_policy::reference);
for(auto x: opt->defines)
if(std::all_of(x.second.begin(), x.second.end(), ::isdigit))
obj.attr(x.first.c_str()) = std::stoi(x.second);
opt_cache_[&k.second->opt] = obj;
}
}
void delete_fn(int op_id) {
id_fn_map.erase(op_id);
}
void cleanup() {
id_grid_map.clear();
id_fn_map.clear();
opt_cache_.clear();
}
size_t make_op_id() {
return id_fn_map.size();
}
std::vector<rt::arg_type> get_fn_signature(size_t op_id) {
return id_fn_map[op_id]->get_kernels()[0].second->get_sig();
}
void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args, size_t grid_0, size_t grid_1, size_t grid_2){
rt::function* fn = id_fn_map.at(op_id).get();
(*fn)((void**)args.c_str(), args.size(), {grid_0, grid_1, grid_2}, &*tt_streams[dev_id]);
// for(size_t n = 0; n < constant_names.size(); n++){
// const torch::Tensor& x = constant_vals[n];
// fn->set_cst(constant_names[n].c_str(), (char*)x.data_ptr(), x.numel()*x.element_size());
}
pybind11::object autotune(int64_t op_id, int64_t dev_id, const std::string& args, const rt::function::grid_fn_ty& grid){
rt::function* fn = id_fn_map.at(op_id).get();
auto wrapper = [&grid](const rt::options_t& opt){
pybind11::object obj = pybind11::cast(&opt, pybind11::return_value_policy::reference);
for(auto x: opt.defines)
if(std::all_of(x.second.begin(), x.second.end(), ::isdigit))
obj.attr(x.first.c_str()) = std::stoi(x.second);
return grid(*obj.cast<rt::options_t*>());
};
rt::kernel* kernel = fn->autotune((void**)args.c_str(), args.size(), wrapper, &*tt_streams[dev_id]);
return opt_cache_.at(&kernel->opt);
}
std::string extract_kernels(const std::string& str, const std::vector<std::string>& names) {
if(names.empty())
return str;
// search for all regex matches of kernel_regex in str
std::smatch matches;
std::regex regex(" *__global__ +void +([_a-zA-Z][_a-zA-Z0-9]{0,30})");
std::sregex_iterator it(str.begin(), str.end(), regex);
std::sregex_iterator end;
std::vector<std::tuple<std::string, int, int>> kernels;
for (; it != end; ++it) {
int pos = it->position();
int len = it->length();
std::string name = it->str(1);
kernels.push_back(std::make_tuple(name, pos, len));
}
for(const std::string& name: names) {
// check that str matches any string in kernels using std::any_of
auto pred = [&name](const std::tuple<std::string, int, int>& t) { return std::get<0>(t) == name; };
bool found = std::any_of(kernels.begin(), kernels.end(), pred);
if(!found) throw std::runtime_error("Unable to find kernel `" + name + "` in provided source code:\n" + str);
}
// extract functions
std::string ret;
for(const auto& k: kernels) {
std::string name;
int pos, len;
std::tie(name, pos, len) = k;
if(std::find(names.begin(), names.end(), name) != names.end()){
std::string def = str.substr(pos, str.size() - pos);
int count, pos;
// skip over declaration
count = 1;
pos = def.find('(');
while(!(def[pos++] == ')' && count == 0) && pos < def.size()){
count += def[pos] == '(';
count -= def[pos] == ')';
}
// skip over definition
count = 1;
pos = def.find('{', pos);
while(!(def[pos++] == '}' && count == 0) && pos < def.size()){
count += def[pos] == '{';
count -= def[pos] == '}';
}
ret += def.substr(0, pos);
ret += '\n';
}
}
return ret;
}
void init_superblocking(pybind11::module &m);
void init_launch(pybind11::module &m);
PYBIND11_MODULE(libtriton, m) {
m.doc() = "Python bindings to the C++ Triton API";
// bindings for triton classes
pybind11::enum_<rt::arg_type>(m, "arg_type")
.value("int1" , rt::INT1_T)
.value("int8" , rt::INT8_T)
.value("int16" , rt::INT16_T)
.value("int32" , rt::INT32_T)
.value("int64" , rt::INT64_T)
.value("half" , rt::HALF_T)
.value("float" , rt::FLOAT_T)
.value("double", rt::DOUBLE_T)
.value("buffer", rt::BUFFER_T);
pybind11::enum_<rt::asm_mode_t>(m, "asm_mode")
.value("ptx" , rt::ASM_NV_PTX)
.value("sass", rt::ASM_NV_SASS);
pybind11::class_<rt::options_t>(m, "options", pybind11::dynamic_attr())
.def(pybind11::init<>())
.def_readwrite("defines" , &rt::options_t::defines)
.def_readwrite("num_warps", &rt::options_t::num_warps);
// hooks into triton constructs since frameworks may not use pybind11
m.def("extract_kernels", &extract_kernels);
m.def("get_fn_signature", &get_fn_signature);
m.def("register_grid", &register_grid);
m.def("delete_grid", &delete_grid);
m.def("register_fn", &register_fn);
m.def("delete_fn", &delete_fn);
m.def("make_op_id", &make_op_id);
m.def("cleanup", &cleanup);
m.def("autotune", &autotune, pybind11::return_value_policy::reference);
m.def("launch_kernel", &launch_kernel);
init_launch(m);
init_superblocking(m);
}

python/src/main.cc (new file)

@@ -0,0 +1,12 @@
#include <pybind11/pybind11.h>
void init_superblocking(pybind11::module &m);
void init_torch_utils(pybind11::module &m);
void init_triton(pybind11::module &m);
PYBIND11_MODULE(libtriton, m) {
m.doc() = "Python bindings to the C++ Triton API";
init_triton(m);
init_torch_utils(m);
init_superblocking(m);
}
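
Note: with this split, the single libtriton extension is organized into named submodules rather than one flat module. The sketch below (not part of this commit's diff) shows how the two submodules created with def_submodule in the new files are reached from Python; the superblocking bindings are registered by init_superblocking, whose source is not shown in this diff.

# Import sketch for the reorganized extension (illustrative only).
# torch must be imported first (see the __init__.py note later in this commit).
import torch
import triton._C.libtriton as libtriton

libtriton.triton       # runtime/compiler bindings from triton.cc (init_triton)
libtriton.torch_utils  # device/stream helpers from torch/utils.cc (init_torch_utils)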

python/src/torch/launch.cc (deleted file)

@@ -1,83 +0,0 @@
// Thanks to Scott Gray (OpenAI) for the idea to pass the arguments
// as a string constructed with struct.pack in python
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/functional.h>
#include "triton/driver/buffer.h"
#include "triton/driver/stream.h"
#include "triton/runtime/function.h"
#include "triton/tools/bench.hpp"
#include "torch/script.h"
#include "ATen/cuda/CUDAContext.h"
#include <c10/cuda/CUDAException.h>
#include <cuda_runtime_api.h>
namespace rt = triton::runtime;
namespace drv = triton::driver;
typedef std::pair<int, int> map_key_t;
extern std::map<map_key_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
extern std::map<int, std::shared_ptr<rt::function>> id_fn_map;
extern std::map<int, std::shared_ptr<drv::device>> tt_devices;
extern std::map<int, std::shared_ptr<drv::stream>> tt_streams;
int64_t cdiv(int64_t a, int64_t b) {
return (a + b - 1) / b;
}
int64_t largest_pow2_divisor(int64_t a){
if(a % 8 == 0) return 8;
if(a % 4 == 0) return 4;
if(a % 2 == 0) return 2;
return 1;
}
int64_t cdiv_sum(torch::Tensor x, int64_t div){
TORCH_CHECK(!x.is_cuda(), "Argument of cdiv_sum must be a CPU tensor")
auto _x = x.accessor<int, 1>();
int64_t ret = 0;
for(size_t i = 0; i < x.size(0); i++)
ret += (_x[i] + div - 1) / div;
return ret;
}
CUstream torch_get_cuda_stream(int64_t dev_id) {
return (CUstream)c10::cuda::getCurrentCUDAStream(dev_id).stream();
}
CUdeviceptr torch_get_cuda_device(int64_t dev_id) {
CUdevice ret;
triton::driver::dispatch::cuDeviceGet(&ret, dev_id);
return ret;
}
void synchronize(int64_t dev_id) {
tt_streams[dev_id]->synchronize();
}
torch::Tensor cuda_empty_like(torch::Tensor x){
if(x.nbytes() == 0)
return torch::empty_like(x);
void* data;
cudaMalloc(&data, x.nbytes());
auto ret = torch::from_blob((void*)data, x.sizes(), x.strides(), [data](void* ptr) { cudaFree(data); }, x.options());
return ret;
}
void cuda_set_device(int64_t dev_id) {
if(dev_id >= 0)
C10_CUDA_CHECK(cudaSetDevice(dev_id));
}
void init_launch(pybind11::module &m) {
m.def("cuda_set_device", &cuda_set_device);
m.def("cuda_empty_like", &cuda_empty_like);
m.def("largest_pow2_divisor", &largest_pow2_divisor);
m.def("cdiv", &cdiv);
m.def("cdiv_sum", &cdiv_sum);
m.def("synchronize", &synchronize);
}

python/src/torch/utils.cc (new file)

@@ -0,0 +1,66 @@
#include "triton/driver/device.h"
#include "triton/driver/stream.h"
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime_api.h>
#include <torch/extension.h>
std::map<int, std::shared_ptr<triton::driver::device>> tt_devices;
std::map<int, std::shared_ptr<triton::driver::stream>> tt_streams;
namespace torch_utils {
void register_device(int64_t dev_id) {
if (tt_devices.find(dev_id) != tt_devices.end())
return;
triton::driver::device *device;
if (dev_id >= 0) {
CUdevice handle;
triton::driver::dispatch::cuDeviceGet(&handle, dev_id);
device = new triton::driver::cu_device(handle, false);
} else
device = new triton::driver::host_device();
tt_devices[dev_id].reset(device);
}
void register_stream(int64_t dev_id) {
if (tt_streams.find(dev_id) != tt_streams.end())
return;
triton::driver::stream *stream;
if (dev_id >= 0) {
CUstream handle = (CUstream)c10::cuda::getCurrentCUDAStream(dev_id).stream();
stream = new triton::driver::cu_stream(handle, false);
} else
stream = new triton::driver::host_stream();
tt_streams[dev_id].reset(stream);
}
void synchronize(int64_t dev_id) {
tt_streams[dev_id]->synchronize();
}
void set_device(int64_t dev_id) {
if (dev_id >= 0)
C10_CUDA_CHECK(cudaSetDevice(dev_id));
}
torch::Tensor move_out_of_pool(torch::Tensor x) {
if (x.nbytes() == 0)
return torch::empty_like(x);
void *data;
cudaMalloc(&data, x.nbytes());
auto ret = torch::from_blob((void *)data, x.sizes(), x.strides(), [data](void *ptr) { cudaFree(data); }, x.options());
ret.copy_(x);
return ret;
}
} // namespace torch_utils
void init_torch_utils(pybind11::module &m) {
pybind11::module subm = m.def_submodule("torch_utils");
subm.def("register_device", &torch_utils::register_device);
subm.def("register_stream", &torch_utils::register_stream);
subm.def("set_device", &torch_utils::set_device);
subm.def("synchronize", &torch_utils::synchronize);
subm.def("move_out_of_pool", &torch_utils::move_out_of_pool);
}
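
As a usage note (not part of this commit's diff), the torch_utils submodule is driven from kernel.py roughly as sketched below; a dev_id of -1 selects the host device/stream path, a non-negative id selects the corresponding CUDA device.

# Usage sketch mirroring the calls made by kernel.py in this commit.
import torch  # must be imported before libtriton (see __init__.py)
import triton._C.libtriton.torch_utils as _torch_utils

dev_id = torch.cuda.current_device() if torch.cuda.is_available() else -1
_torch_utils.register_device(dev_id)   # caches a triton::driver::device
_torch_utils.register_stream(dev_id)   # caches a triton::driver::stream
_torch_utils.set_device(dev_id)        # cudaSetDevice; no-op for dev_id < 0
_torch_utils.synchronize(dev_id)       # blocks on the cached stream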

python/src/triton.cc (new file)

@@ -0,0 +1,169 @@
#include "triton/driver/stream.h"
#include "triton/ir/function.h"
#include "triton/ir/module.h"
#include "triton/lang/code_gen.h"
#include "triton/lang/cpp.h"
#include "triton/lang/parser.h"
#include "triton/runtime/arg.h"
#include "triton/runtime/function.h"
#include <pybind11/buffer_info.h>
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <regex>
#include <string>
using namespace triton;
namespace rt = triton::runtime;
namespace drv = triton::driver;
namespace lng = triton::lang;
std::unordered_map<const rt::options_t *, pybind11::object> opt_cache_;
std::map<int, std::shared_ptr<rt::function>> id_fn_map;
extern std::map<int, std::shared_ptr<triton::driver::device>> tt_devices;
extern std::map<int, std::shared_ptr<triton::driver::stream>> tt_streams;
/* Function utilities */
void register_fn(int op_id, int dev_id,
const std::string &src, const rt::options_t &opt,
const rt::function::autotune_vals_t &autotune_vals,
const std::vector<std::string> &autotune_key) {
if (id_fn_map.find(op_id) == id_fn_map.end()) {
id_fn_map[op_id].reset(new rt::function(src, opt, &*tt_devices[dev_id], autotune_vals, autotune_key));
}
for (const auto &k : id_fn_map[op_id]->get_kernels()) {
const rt::options_t *opt = &k.first;
pybind11::object obj = pybind11::cast(opt, pybind11::return_value_policy::reference);
for (auto x : opt->defines)
if (std::all_of(x.second.begin(), x.second.end(), ::isdigit))
obj.attr(x.first.c_str()) = std::stoi(x.second);
opt_cache_[&k.second->opt] = obj;
}
}
void delete_fn(int op_id) {
id_fn_map.erase(op_id);
}
void cleanup() {
id_fn_map.clear();
opt_cache_.clear();
}
size_t make_op_id() {
return id_fn_map.size();
}
std::vector<rt::arg_type> get_fn_signature(size_t op_id) {
return id_fn_map[op_id]->get_kernels()[0].second->get_sig();
}
// Thanks to Scott Gray (OpenAI) for the idea to pass the arguments
// as a string constructed with struct.pack in python
void launch_kernel(int64_t op_id, int64_t dev_id, const std::string &args, size_t grid_0, size_t grid_1, size_t grid_2) {
rt::function *fn = id_fn_map.at(op_id).get();
(*fn)((void **)args.c_str(), args.size(), {grid_0, grid_1, grid_2}, &*tt_streams[dev_id]);
}
pybind11::object autotune(int64_t op_id, int64_t dev_id, const std::string &args, const rt::function::grid_fn_ty &grid) {
rt::function *fn = id_fn_map.at(op_id).get();
auto wrapper = [&grid](const rt::options_t &opt) {
pybind11::object obj = pybind11::cast(&opt, pybind11::return_value_policy::reference);
for (auto x : opt.defines)
if (std::all_of(x.second.begin(), x.second.end(), ::isdigit))
obj.attr(x.first.c_str()) = std::stoi(x.second);
return grid(*obj.cast<rt::options_t *>());
};
rt::kernel *kernel = fn->autotune((void **)args.c_str(), args.size(), wrapper, &*tt_streams[dev_id]);
return opt_cache_.at(&kernel->opt);
}
std::string extract_kernels(const std::string &str, const std::vector<std::string> &names) {
if (names.empty())
return str;
// search for all regex matches of kernel_regex in str
std::smatch matches;
std::regex regex(" *__global__ +void +([_a-zA-Z][_a-zA-Z0-9]{0,30})");
std::sregex_iterator it(str.begin(), str.end(), regex);
std::sregex_iterator end;
std::vector<std::tuple<std::string, int, int>> kernels;
for (; it != end; ++it) {
int pos = it->position();
int len = it->length();
std::string name = it->str(1);
kernels.push_back(std::make_tuple(name, pos, len));
}
for (const std::string &name : names) {
// check that str matches any string in kernels using std::any_of
auto pred = [&name](const std::tuple<std::string, int, int> &t) { return std::get<0>(t) == name; };
bool found = std::any_of(kernels.begin(), kernels.end(), pred);
if (!found)
throw std::runtime_error("Unable to find kernel `" + name + "` in provided source code:\n" + str);
}
// extract functions
std::string ret;
for (const auto &k : kernels) {
std::string name;
int pos, len;
std::tie(name, pos, len) = k;
if (std::find(names.begin(), names.end(), name) != names.end()) {
std::string def = str.substr(pos, str.size() - pos);
int count, pos;
// skip over declaration
count = 1;
pos = def.find('(');
while (!(def[pos++] == ')' && count == 0) && pos < def.size()) {
count += def[pos] == '(';
count -= def[pos] == ')';
}
// skip over definition
count = 1;
pos = def.find('{', pos);
while (!(def[pos++] == '}' && count == 0) && pos < def.size()) {
count += def[pos] == '{';
count -= def[pos] == '}';
}
ret += def.substr(0, pos);
ret += '\n';
}
}
return ret;
}
void init_triton(pybind11::module &m) {
pybind11::module subm = m.def_submodule("triton");
// bindings for triton classes
pybind11::enum_<rt::arg_type>(subm, "arg_type")
.value("int1", rt::INT1_T)
.value("int8", rt::INT8_T)
.value("int16", rt::INT16_T)
.value("int32", rt::INT32_T)
.value("int64", rt::INT64_T)
.value("half", rt::HALF_T)
.value("float", rt::FLOAT_T)
.value("double", rt::DOUBLE_T)
.value("buffer", rt::BUFFER_T);
pybind11::enum_<rt::asm_mode_t>(subm, "asm_mode")
.value("ptx", rt::ASM_NV_PTX)
.value("sass", rt::ASM_NV_SASS);
pybind11::class_<rt::options_t>(subm, "options", pybind11::dynamic_attr())
.def(pybind11::init<>())
.def_readwrite("defines", &rt::options_t::defines)
.def_readwrite("num_warps", &rt::options_t::num_warps);
// hooks into triton constructs since frameworks may not use pybind11
subm.def("extract_kernels", &extract_kernels);
subm.def("get_fn_signature", &get_fn_signature);
subm.def("register_fn", &register_fn);
subm.def("delete_fn", &delete_fn);
subm.def("make_op_id", &make_op_id);
subm.def("cleanup", &cleanup);
subm.def("autotune", &autotune, pybind11::return_value_policy::reference);
subm.def("launch_kernel", &launch_kernel);
}
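
To make the struct.pack convention credited above concrete (an illustrative sketch, not part of this commit's diff): the Python side packs the kernel arguments into a byte string whose format characters come from the `codes` table in kernel.py, and launch_kernel reinterprets that buffer as the kernel's parameter block. The format string and values below are placeholders.

# Argument-packing illustration; values are made up for the example.
import struct

fmt = 'PIf'  # e.g. a signature of (buffer, int32, float)
params = struct.pack(fmt, 0x7f0000000000, 1024, 0.5)
# _triton.launch_kernel(op_id, dev_id, params, grid_0, grid_1, grid_2)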

python/triton/__init__.py (modified)

@@ -1,16 +1,11 @@
 # TODO: torch needs to be imported first
 # or pybind11 shows `munmap_chunk(): invalid pointer`
 import torch
-# submodules
-# libtriton resources
-import atexit
-import triton._C.libtriton as libtriton
-@atexit.register
-def cleanup():
-    libtriton.cleanup()
 from .kernel import *
 from . import ops
+# C bindings
+import triton._C.libtriton.torch_utils as _torch_utils
 # version
 __version__ = '1.0.0'

python/triton/kernel.py (modified)

@@ -1,18 +1,24 @@
-import triton._C.libtriton as libtriton
 import os
-import time
-from struct import pack
+import struct
 import torch
+# C bindings
+import triton._C.libtriton.triton as _triton
+import triton._C.libtriton.torch_utils as _torch_utils
+# Make sure internal C resources are cleaned up upon exit
+import atexit
+@atexit.register
+def cleanup():
+    _triton.cleanup()
 codes = {
-    libtriton.arg_type.int1: 'B',
-    libtriton.arg_type.int8: 'B',
-    libtriton.arg_type.int32: 'I',
-    libtriton.arg_type.int64: 'Q',
-    libtriton.arg_type.half: 'H',
-    libtriton.arg_type.float: 'f',
-    libtriton.arg_type.double: 'd',
-    libtriton.arg_type.buffer: 'P'
+    _triton.arg_type.int1: 'B',
+    _triton.arg_type.int8: 'B',
+    _triton.arg_type.int32: 'I',
+    _triton.arg_type.int64: 'Q',
+    _triton.arg_type.half: 'H',
+    _triton.arg_type.float: 'f',
+    _triton.arg_type.double: 'd',
+    _triton.arg_type.buffer: 'P'
 }
def th_to_triton(obj): def th_to_triton(obj):
@@ -30,17 +36,17 @@ def th_to_triton(obj):
     return str(obj)
 def cdiv(a, b):
-    return libtriton.cdiv(a, b)
+    return (a + b - 1) // b
 def synchronize(device):
     dev_id = device.index
     dev_id = -1 if dev_id is None else dev_id
-    libtriton.synchronize(dev_id)
+    _torch_utils.synchronize(dev_id)
 def read(path, kernel_names=[]):
     with open(path, 'r') as f:
         source = f.read()
-    source = libtriton.extract_kernels(source, kernel_names)
+    source = _triton.extract_kernels(source, kernel_names)
     return source
class kernel: class kernel:
@@ -50,7 +56,7 @@ class kernel:
         if src == '':
             raise ValueError('Kernel source code is empty')
         self.src = src
-        self.opt = libtriton.options()
+        self.opt = _triton.options()
         self.opt.defines = {k: th_to_triton(v) for k, v in defines.items()}
         self.opt.num_warps = num_warps
         # device
@@ -59,39 +65,25 @@ class kernel:
         self.device = torch.cuda.current_device() if device.index is None else device.index
         if device.type == 'cpu':
             self.device = -1
+        _torch_utils.register_device(self.device)
+        _torch_utils.register_stream(self.device)
         # C++ function wrapper
-        self.op_id = libtriton.make_op_id()
-        libtriton.register_fn(self.op_id, self.device, self.src, self.opt, autotune_vals, autotune_key)
+        self.op_id = _triton.make_op_id()
+        _triton.register_fn(self.op_id, self.device, self.src, self.opt, autotune_vals, autotune_key)
         # debug mode
         self.is_debug = 'TRITON_DEBUG' in os.environ
         # signature
-        arg_types = libtriton.get_fn_signature(self.op_id)
+        arg_types = _triton.get_fn_signature(self.op_id)
         self.tys = ''.join([codes[x] for x in arg_types])
     def __call__(self, *args, grid):
-        # debug mode (initialize)
-        if self.is_debug:
-            _args = args
-            args = [x.clone() if isinstance(x, torch.Tensor) else x for x in _args]
-            for i in range(len(args)):
-                if isinstance(args[i], torch.Tensor):
-                    args[i] = libtriton.cuda_empty_like(args[i])
-                    args[i].copy_(_args[i])
-        # initialize cuda device if necessary
-        libtriton.cuda_set_device(self.device)
+        _torch_utils.set_device(self.device)
         # pack parameters into a byte buffer
-        params = pack(self.tys, *args)
-        # auto-tune if necessary
-        opt = libtriton.autotune(self.op_id, self.device, params, grid)
+        params = struct.pack(self.tys, *args)
+        opt = _triton.autotune(self.op_id, self.device, params, grid)
         # run kernel
         grid = grid(opt)
         grid_0 = grid[0]
         grid_1 = 1 if len(grid) < 2 else grid[1]
         grid_2 = 1 if len(grid) < 3 else grid[2]
-        libtriton.launch_kernel(self.op_id, self.device, params, grid_0, grid_1, grid_2)
-        # debug mode (finalize)
-        if self.is_debug:
-            for i in range(len(args)):
-                if isinstance(args[i], torch.Tensor):
-                    _args[i].copy_(args[i].clone())
-            args = _args
+        _triton.launch_kernel(self.op_id, self.device, params, grid_0, grid_1, grid_2)
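
A hypothetical end-to-end use of the updated Python API, for orientation only: the kernel constructor's keyword names and defaults are not shown in this diff, so device, defines and num_warps below are assumptions based on the attributes used above, and the Triton-C kernel body is elided.

# Hypothetical usage sketch of the updated triton.kernel class.
import torch
import triton

src = """
__global__ void add(float *X, float *Y, float *Z, int N) {
    // kernel body elided
}
"""
x = torch.randn(1024, device='cuda')
y = torch.randn(1024, device='cuda')
z = torch.empty_like(x)
# defines with all-digit values become integer attributes on the options object
add = triton.kernel(src, device=x.device, defines={'BLOCK': 128}, num_warps=4)
add(x.data_ptr(), y.data_ptr(), z.data_ptr(), x.numel(),
    grid=lambda opt: [triton.cdiv(x.numel(), opt.BLOCK)])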