diff --git a/CMakeLists.txt b/CMakeLists.txt
index 1f33eda9e..a65f3fc41 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -33,8 +33,8 @@ endif()
 if(BUILD_PYTHON_MODULE)
     message(STATUS "Adding Python module")
     # PyBind11 wrapper source file
-    set(TORCH_SRC torch/launch.cc torch/superblock.cc)
-    set(PYTHON_SRC bindings.cc ${TORCH_SRC})
+    file(GLOB_RECURSE TORCH_SRC torch/*.cc)
+    set(PYTHON_SRC main.cc triton.cc ${TORCH_SRC})
     set_source_files_properties(${TORCH_SRC} PROPERTIES COMPILE_FLAGS "-std=c++14 -D_GLIBCXX_USE_CXX11_ABI=${TORCH_CXX11_ABI}")
     include_directories("." ${PYTHON_INCLUDE_DIRS})
     link_directories(${PYTHON_LINK_DIRS})
diff --git a/python/src/bindings.cc b/python/src/bindings.cc
deleted file mode 100644
index 069555d91..000000000
--- a/python/src/bindings.cc
+++ /dev/null
@@ -1,219 +0,0 @@
-#include <pybind11/pybind11.h>
-#include <pybind11/stl.h>
-#include <pybind11/functional.h>
-#include <string>
-#include <regex>
-#include <algorithm>
-#include "triton/driver/stream.h"
-#include "triton/runtime/function.h"
-#include "triton/runtime/arg.h"
-#include "triton/lang/code_gen.h"
-#include "triton/lang/parser.h"
-#include "triton/lang/cpp.h"
-#include "triton/ir/module.h"
-#include "triton/ir/function.h"
-
-using namespace triton;
-namespace rt = triton::runtime;
-namespace drv = triton::driver;
-namespace lng = triton::lang;
-
-typedef std::pair<size_t, size_t> map_key_t;
-
-std::map<map_key_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
-std::map<int, std::shared_ptr<rt::function>> id_fn_map;
-std::map<int, std::shared_ptr<drv::device>> tt_devices;
-std::map<int, std::shared_ptr<drv::stream>> tt_streams;
-std::unordered_map<const rt::options_t*, pybind11::object> opt_cache_;
-extern CUstream torch_get_cuda_stream(int64_t dev_id);
-extern CUdevice torch_get_cuda_device(int64_t dev_id);
-
-
-/* Grid utilities */
-
-void register_grid(const map_key_t& key,
-                   const rt::function::grid_fn_ty& grid_fn) {
-  id_grid_map[key].reset(new rt::function::grid_fn_ty(grid_fn));
-}
-
-void delete_grid(const map_key_t& key) {
-  id_grid_map.erase(key);
-}
-
-/* Function utilities */
-
-void register_fn(int op_id,
-                 int dev_id,
-                 const std::string& src,
-                 const rt::options_t& opt,
-                 const rt::function::autotune_vals_t& autotune_vals,
-                 const std::vector<std::string>& autotune_key) {
-  if(tt_devices.find(dev_id) == tt_devices.end()) {
-    driver::device* device;
-    driver::stream* stream;
-    if(dev_id >= 0){
-      device = new triton::driver::cu_device(torch_get_cuda_device(dev_id), false);
-      stream = new triton::driver::cu_stream(torch_get_cuda_stream(dev_id), false);
-    }
-    else{
-      device = new triton::driver::host_device();
-      stream = new triton::driver::host_stream();
-    }
-    tt_devices[dev_id].reset(device);
-    tt_streams[dev_id].reset(stream);
-  }
-  if(id_fn_map.find(op_id) == id_fn_map.end()){
-    id_fn_map[op_id].reset(new rt::function(src, opt, &*tt_devices[dev_id], autotune_vals, autotune_key));
-  }
-  for(const auto& k: id_fn_map[op_id]->get_kernels()){
-    const rt::options_t* opt = &k.first;
-    pybind11::object obj = pybind11::cast(opt, pybind11::return_value_policy::reference);
-    for(auto x: opt->defines)
-      if(std::all_of(x.second.begin(), x.second.end(), ::isdigit))
-        obj.attr(x.first.c_str()) = std::stoi(x.second);
-    opt_cache_[&k.second->opt] = obj;
-  }
-
-}
-
-void delete_fn(int op_id) {
-  id_fn_map.erase(op_id);
-}
-
-
-void cleanup() {
-  id_grid_map.clear();
-  id_fn_map.clear();
-  opt_cache_.clear();
-}
-
-size_t make_op_id() {
-  return id_fn_map.size();
-}
-
-std::vector<rt::arg_type> get_fn_signature(size_t op_id) {
-  return id_fn_map[op_id]->get_kernels()[0].second->get_sig();
-}
-
-void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args, size_t grid_0, size_t grid_1, size_t grid_2){
-  rt::function* fn = id_fn_map.at(op_id).get();
-  (*fn)((void**)args.c_str(), args.size(), {grid_0, grid_1, grid_2}, &*tt_streams[dev_id]);
-
-  // for(size_t n = 0; n < constant_names.size(); n++){
-  //   const torch::Tensor& x = constant_vals[n];
-  //   fn->set_cst(constant_names[n].c_str(), (char*)x.data_ptr(), x.numel()*x.element_size());
-}
-
-pybind11::object autotune(int64_t op_id, int64_t dev_id, const std::string& args, const rt::function::grid_fn_ty& grid){
-  rt::function* fn = id_fn_map.at(op_id).get();
-  auto wrapper = [&grid](const rt::options_t& opt){
-    pybind11::object obj = pybind11::cast(&opt, pybind11::return_value_policy::reference);
-    for(auto x: opt.defines)
-      if(std::all_of(x.second.begin(), x.second.end(), ::isdigit))
-        obj.attr(x.first.c_str()) = std::stoi(x.second);
-    return grid(*obj.cast<rt::options_t*>());
-  };
-  rt::kernel* kernel = fn->autotune((void**)args.c_str(), args.size(), wrapper, &*tt_streams[dev_id]);
-  return opt_cache_.at(&kernel->opt);
-}
-
-
-std::string extract_kernels(const std::string& str, const std::vector<std::string>& names) {
-  if(names.empty())
-    return str;
-  // search for all regex matches of kernel_regex in str
-  std::smatch matches;
-  std::regex regex(" *__global__ +void +([_a-zA-Z][_a-zA-Z0-9]{0,30})");
-  std::sregex_iterator it(str.begin(), str.end(), regex);
-  std::sregex_iterator end;
-  std::vector<std::tuple<std::string, int, int>> kernels;
-  for (; it != end; ++it) {
-    int pos = it->position();
-    int len = it->length();
-    std::string name = it->str(1);
-    kernels.push_back(std::make_tuple(name, pos, len));
-  }
-
-  for(const std::string& name: names) {
-    // check that str matches any string in kernels using std::any_of
-    auto pred = [&name](const std::tuple<std::string, int, int>& t) { return std::get<0>(t) == name; };
-    bool found = std::any_of(kernels.begin(), kernels.end(), pred);
-    if(!found) throw std::runtime_error("Unable to find kernel `" + name + "` in provided source code:\n" + str);
-  }
-
-
-  // extract functions
-  std::string ret;
-  for(const auto& k: kernels) {
-    std::string name;
-    int pos, len;
-    std::tie(name, pos, len) = k;
-    if(std::find(names.begin(), names.end(), name) != names.end()){
-      std::string def = str.substr(pos, str.size() - pos);
-      int count, pos;
-      // skip over declaration
-      count = 1;
-      pos = def.find('(');
-      while(!(def[pos++] == ')' && count == 0) && pos < def.size()){
-        count += def[pos] == '(';
-        count -= def[pos] == ')';
-      }
-      // skip over definition
-      count = 1;
-      pos = def.find('{', pos);
-      while(!(def[pos++] == '}' && count == 0) && pos < def.size()){
-        count += def[pos] == '{';
-        count -= def[pos] == '}';
-      }
-      ret += def.substr(0, pos);
-      ret += '\n';
-    }
-  }
-
-  return ret;
-}
-
-
-
-void init_superblocking(pybind11::module &m);
-void init_launch(pybind11::module &m);
-
-PYBIND11_MODULE(libtriton, m) {
-  m.doc() = "Python bindings to the C++ Triton API";
-
-  // bindings for triton classes
-  pybind11::enum_<rt::arg_type>(m, "arg_type")
-      .value("int1"  , rt::INT1_T)
-      .value("int8"  , rt::INT8_T)
-      .value("int16" , rt::INT16_T)
-      .value("int32" , rt::INT32_T)
-      .value("int64" , rt::INT64_T)
-      .value("half"  , rt::HALF_T)
-      .value("float" , rt::FLOAT_T)
-      .value("double", rt::DOUBLE_T)
-      .value("buffer", rt::BUFFER_T);
-
-  pybind11::enum_<rt::asm_mode_t>(m, "asm_mode")
-      .value("ptx" , rt::ASM_NV_PTX)
-      .value("sass", rt::ASM_NV_SASS);
-
-  pybind11::class_<rt::options_t>(m, "options", pybind11::dynamic_attr())
-      .def(pybind11::init<>())
-      .def_readwrite("defines"  , &rt::options_t::defines)
-      .def_readwrite("num_warps", &rt::options_t::num_warps);
-
-  // hooks into triton constructs since frameworks may not use pybind11
-  m.def("extract_kernels", &extract_kernels);
-  m.def("get_fn_signature", &get_fn_signature);
m.def("register_grid", ®ister_grid); - m.def("delete_grid", &delete_grid); - m.def("register_fn", ®ister_fn); - m.def("delete_fn", &delete_fn); - m.def("make_op_id", &make_op_id); - m.def("cleanup", &cleanup); - m.def("autotune", &autotune, pybind11::return_value_policy::reference); - m.def("launch_kernel", &launch_kernel); - - init_launch(m); - init_superblocking(m); -} diff --git a/python/src/main.cc b/python/src/main.cc new file mode 100644 index 000000000..73394a30c --- /dev/null +++ b/python/src/main.cc @@ -0,0 +1,12 @@ +#include + +void init_superblocking(pybind11::module &m); +void init_torch_utils(pybind11::module &m); +void init_triton(pybind11::module &m); + +PYBIND11_MODULE(libtriton, m) { + m.doc() = "Python bindings to the C++ Triton API"; + init_triton(m); + init_torch_utils(m); + init_superblocking(m); +} diff --git a/python/src/torch/launch.cc b/python/src/torch/launch.cc deleted file mode 100644 index bd9461e0c..000000000 --- a/python/src/torch/launch.cc +++ /dev/null @@ -1,83 +0,0 @@ -// Thanks to Scott Gray (OpenAI) for the idea to pass the arguments -// as a string constructed with struct.pack in python - -#include -#include -#include -#include "triton/driver/buffer.h" -#include "triton/driver/stream.h" -#include "triton/runtime/function.h" -#include "triton/tools/bench.hpp" -#include "torch/script.h" -#include "ATen/cuda/CUDAContext.h" -#include -#include - - -namespace rt = triton::runtime; -namespace drv = triton::driver; - -typedef std::pair map_key_t; -extern std::map> id_grid_map; -extern std::map> id_fn_map; -extern std::map> tt_devices; -extern std::map> tt_streams; - - -int64_t cdiv(int64_t a, int64_t b) { - return (a + b - 1) / b; -} - -int64_t largest_pow2_divisor(int64_t a){ - if(a % 8 == 0) return 8; - if(a % 4 == 0) return 4; - if(a % 2 == 0) return 2; - return 1; -} - -int64_t cdiv_sum(torch::Tensor x, int64_t div){ - TORCH_CHECK(!x.is_cuda(), "Argument of cdiv_sum must be a CPU tensor") - auto _x = x.accessor(); - int64_t ret = 0; - for(size_t i = 0; i < x.size(0); i++) - ret += (_x[i] + div - 1) / div; - return ret; -} - -CUstream torch_get_cuda_stream(int64_t dev_id) { - return (CUstream)c10::cuda::getCurrentCUDAStream(dev_id).stream(); -} - -CUdeviceptr torch_get_cuda_device(int64_t dev_id) { - CUdevice ret; - triton::driver::dispatch::cuDeviceGet(&ret, dev_id); - return ret; -} - -void synchronize(int64_t dev_id) { - tt_streams[dev_id]->synchronize(); -} - -torch::Tensor cuda_empty_like(torch::Tensor x){ - if(x.nbytes() == 0) - return torch::empty_like(x); - void* data; - cudaMalloc(&data, x.nbytes()); - auto ret = torch::from_blob((void*)data, x.sizes(), x.strides(), [data](void* ptr) { cudaFree(data); }, x.options()); - return ret; -} - -void cuda_set_device(int64_t dev_id) { - if(dev_id >= 0) - C10_CUDA_CHECK(cudaSetDevice(dev_id)); -} - - -void init_launch(pybind11::module &m) { - m.def("cuda_set_device", &cuda_set_device); - m.def("cuda_empty_like", &cuda_empty_like); - m.def("largest_pow2_divisor", &largest_pow2_divisor); - m.def("cdiv", &cdiv); - m.def("cdiv_sum", &cdiv_sum); - m.def("synchronize", &synchronize); -} \ No newline at end of file diff --git a/python/src/torch/utils.cc b/python/src/torch/utils.cc new file mode 100644 index 000000000..8deb8d4f2 --- /dev/null +++ b/python/src/torch/utils.cc @@ -0,0 +1,66 @@ + +#include "triton/driver/device.h" +#include "triton/driver/stream.h" +#include +#include +#include + +std::map> tt_devices; +std::map> tt_streams; + +namespace torch_utils { + +void register_device(int64_t dev_id) { + if 
+  if (tt_devices.find(dev_id) != tt_devices.end())
+    return;
+  triton::driver::device *device;
+  if (dev_id >= 0) {
+    CUdevice handle;
+    triton::driver::dispatch::cuDeviceGet(&handle, dev_id);
+    device = new triton::driver::cu_device(handle, false);
+  } else
+    device = new triton::driver::host_device();
+  tt_devices[dev_id].reset(device);
+}
+
+void register_stream(int64_t dev_id) {
+  if (tt_streams.find(dev_id) != tt_streams.end())
+    return;
+  triton::driver::stream *stream;
+  if (dev_id >= 0) {
+    CUstream handle = (CUstream)c10::cuda::getCurrentCUDAStream(dev_id).stream();
+    stream = new triton::driver::cu_stream(handle, false);
+  } else
+    stream = new triton::driver::host_stream();
+  tt_streams[dev_id].reset(stream);
+}
+
+void synchronize(int64_t dev_id) {
+  tt_streams[dev_id]->synchronize();
+}
+
+void set_device(int64_t dev_id) {
+  if (dev_id >= 0)
+    C10_CUDA_CHECK(cudaSetDevice(dev_id));
+}
+
+torch::Tensor move_out_of_pool(torch::Tensor x) {
+  if (x.nbytes() == 0)
+    return torch::empty_like(x);
+  void *data;
+  cudaMalloc(&data, x.nbytes());
+  auto ret = torch::from_blob((void *)data, x.sizes(), x.strides(), [data](void *ptr) { cudaFree(data); }, x.options());
+  ret.copy_(x);
+  return ret;
+}
+
+} // namespace torch_utils
+
+void init_torch_utils(pybind11::module &m) {
+  pybind11::module subm = m.def_submodule("torch_utils");
+  subm.def("register_device", &torch_utils::register_device);
+  subm.def("register_stream", &torch_utils::register_stream);
+  subm.def("set_device", &torch_utils::set_device);
+  subm.def("synchronize", &torch_utils::synchronize);
+  subm.def("move_out_of_pool", &torch_utils::move_out_of_pool);
+}
\ No newline at end of file
diff --git a/python/src/triton.cc b/python/src/triton.cc
new file mode 100644
index 000000000..c208e37a7
--- /dev/null
+++ b/python/src/triton.cc
@@ -0,0 +1,169 @@
+#include "triton/driver/stream.h"
+#include "triton/ir/function.h"
+#include "triton/ir/module.h"
+#include "triton/lang/code_gen.h"
+#include "triton/lang/cpp.h"
+#include "triton/lang/parser.h"
+#include "triton/runtime/arg.h"
+#include "triton/runtime/function.h"
+#include <algorithm>
+#include <pybind11/functional.h>
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+#include <regex>
+#include <string>
+
+using namespace triton;
+namespace rt = triton::runtime;
+namespace drv = triton::driver;
+namespace lng = triton::lang;
+
+std::unordered_map<const rt::options_t *, pybind11::object> opt_cache_;
+std::map<int, std::shared_ptr<rt::function>> id_fn_map;
+extern std::map<int, std::shared_ptr<drv::device>> tt_devices;
+extern std::map<int, std::shared_ptr<drv::stream>> tt_streams;
+
+/* Function utilities */
+
+void register_fn(int op_id, int dev_id,
+                 const std::string &src, const rt::options_t &opt,
+                 const rt::function::autotune_vals_t &autotune_vals,
+                 const std::vector<std::string> &autotune_key) {
+  if (id_fn_map.find(op_id) == id_fn_map.end()) {
+    id_fn_map[op_id].reset(new rt::function(src, opt, &*tt_devices[dev_id], autotune_vals, autotune_key));
+  }
+  for (const auto &k : id_fn_map[op_id]->get_kernels()) {
+    const rt::options_t *opt = &k.first;
+    pybind11::object obj = pybind11::cast(opt, pybind11::return_value_policy::reference);
+    for (auto x : opt->defines)
+      if (std::all_of(x.second.begin(), x.second.end(), ::isdigit))
+        obj.attr(x.first.c_str()) = std::stoi(x.second);
+    opt_cache_[&k.second->opt] = obj;
+  }
+}
+
+void delete_fn(int op_id) {
+  id_fn_map.erase(op_id);
+}
+
+void cleanup() {
+  id_fn_map.clear();
+  opt_cache_.clear();
+}
+
+size_t make_op_id() {
+  return id_fn_map.size();
+}
+
+std::vector<rt::arg_type> get_fn_signature(size_t op_id) {
+  return id_fn_map[op_id]->get_kernels()[0].second->get_sig();
+}
+
+// Thanks to Scott Gray (OpenAI) for the idea to pass the arguments
+// as a string constructed with struct.pack in python
+void launch_kernel(int64_t op_id, int64_t dev_id, const std::string &args, size_t grid_0, size_t grid_1, size_t grid_2) {
+  rt::function *fn = id_fn_map.at(op_id).get();
+  (*fn)((void **)args.c_str(), args.size(), {grid_0, grid_1, grid_2}, &*tt_streams[dev_id]);
+}
+
+pybind11::object autotune(int64_t op_id, int64_t dev_id, const std::string &args, const rt::function::grid_fn_ty &grid) {
+  rt::function *fn = id_fn_map.at(op_id).get();
+  auto wrapper = [&grid](const rt::options_t &opt) {
+    pybind11::object obj = pybind11::cast(&opt, pybind11::return_value_policy::reference);
+    for (auto x : opt.defines)
+      if (std::all_of(x.second.begin(), x.second.end(), ::isdigit))
+        obj.attr(x.first.c_str()) = std::stoi(x.second);
+    return grid(*obj.cast<rt::options_t *>());
+  };
+  rt::kernel *kernel = fn->autotune((void **)args.c_str(), args.size(), wrapper, &*tt_streams[dev_id]);
+  return opt_cache_.at(&kernel->opt);
+}
+
+std::string extract_kernels(const std::string &str, const std::vector<std::string> &names) {
+  if (names.empty())
+    return str;
+  // search for all regex matches of kernel_regex in str
+  std::smatch matches;
+  std::regex regex(" *__global__ +void +([_a-zA-Z][_a-zA-Z0-9]{0,30})");
+  std::sregex_iterator it(str.begin(), str.end(), regex);
+  std::sregex_iterator end;
+  std::vector<std::tuple<std::string, int, int>> kernels;
+  for (; it != end; ++it) {
+    int pos = it->position();
+    int len = it->length();
+    std::string name = it->str(1);
+    kernels.push_back(std::make_tuple(name, pos, len));
+  }
+
+  for (const std::string &name : names) {
+    // check that str matches any string in kernels using std::any_of
+    auto pred = [&name](const std::tuple<std::string, int, int> &t) { return std::get<0>(t) == name; };
+    bool found = std::any_of(kernels.begin(), kernels.end(), pred);
+    if (!found)
+      throw std::runtime_error("Unable to find kernel `" + name + "` in provided source code:\n" + str);
+  }
+
+  // extract functions
+  std::string ret;
+  for (const auto &k : kernels) {
+    std::string name;
+    int pos, len;
+    std::tie(name, pos, len) = k;
+    if (std::find(names.begin(), names.end(), name) != names.end()) {
+      std::string def = str.substr(pos, str.size() - pos);
+      int count, pos;
+      // skip over declaration
+      count = 1;
+      pos = def.find('(');
+      while (!(def[pos++] == ')' && count == 0) && pos < def.size()) {
+        count += def[pos] == '(';
+        count -= def[pos] == ')';
+      }
+      // skip over definition
+      count = 1;
+      pos = def.find('{', pos);
+      while (!(def[pos++] == '}' && count == 0) && pos < def.size()) {
+        count += def[pos] == '{';
+        count -= def[pos] == '}';
+      }
+      ret += def.substr(0, pos);
+      ret += '\n';
+    }
+  }
+
+  return ret;
+}
+
+void init_triton(pybind11::module &m) {
+  pybind11::module subm = m.def_submodule("triton");
+  // bindings for triton classes
+  pybind11::enum_<rt::arg_type>(subm, "arg_type")
+      .value("int1", rt::INT1_T)
+      .value("int8", rt::INT8_T)
+      .value("int16", rt::INT16_T)
+      .value("int32", rt::INT32_T)
+      .value("int64", rt::INT64_T)
+      .value("half", rt::HALF_T)
+      .value("float", rt::FLOAT_T)
+      .value("double", rt::DOUBLE_T)
+      .value("buffer", rt::BUFFER_T);
+
+  pybind11::enum_<rt::asm_mode_t>(subm, "asm_mode")
+      .value("ptx", rt::ASM_NV_PTX)
+      .value("sass", rt::ASM_NV_SASS);
+
+  pybind11::class_<rt::options_t>(subm, "options", pybind11::dynamic_attr())
+      .def(pybind11::init<>())
+      .def_readwrite("defines", &rt::options_t::defines)
+      .def_readwrite("num_warps", &rt::options_t::num_warps);
+
+  // hooks into triton constructs since frameworks may not use pybind11
+  subm.def("extract_kernels", &extract_kernels);
+  subm.def("get_fn_signature", &get_fn_signature);
+  subm.def("register_fn", &register_fn);
+  subm.def("delete_fn", &delete_fn);
+  subm.def("make_op_id", &make_op_id);
+  subm.def("cleanup", &cleanup);
+  subm.def("autotune", &autotune, pybind11::return_value_policy::reference);
+  subm.def("launch_kernel", &launch_kernel);
+}
diff --git a/python/triton/__init__.py b/python/triton/__init__.py
index b86e9b27b..becddcc32 100644
--- a/python/triton/__init__.py
+++ b/python/triton/__init__.py
@@ -1,16 +1,11 @@
 # TODO: torch needs to be imported first
 # or pybind11 shows `munmap_chunk(): invalid pointer`
 import torch
-
-# libtriton resources
-import atexit
-import triton._C.libtriton as libtriton
-@atexit.register
-def cleanup():
-    libtriton.cleanup()
-
+# submodules
 from .kernel import *
 from . import ops
+# C bindings
+import triton._C.libtriton.torch_utils as _torch_utils
 # version
 __version__ = '1.0.0'
\ No newline at end of file
diff --git a/python/triton/kernel.py b/python/triton/kernel.py
index 2bb6a98be..ce2a7d579 100644
--- a/python/triton/kernel.py
+++ b/python/triton/kernel.py
@@ -1,18 +1,24 @@
-import triton._C.libtriton as libtriton
 import os
-import time
-from struct import pack
+import struct
 import torch
+# C bindings
+import triton._C.libtriton.triton as _triton
+import triton._C.libtriton.torch_utils as _torch_utils
+# Make sure internal C resources are cleaned up upon exit
+import atexit
+@atexit.register
+def cleanup():
+    _triton.cleanup()

 codes = {
-    libtriton.arg_type.int1: 'B',
-    libtriton.arg_type.int8: 'B',
-    libtriton.arg_type.int32: 'I',
-    libtriton.arg_type.int64: 'Q',
-    libtriton.arg_type.half: 'H',
-    libtriton.arg_type.float: 'f',
-    libtriton.arg_type.double: 'd',
-    libtriton.arg_type.buffer: 'P'
+    _triton.arg_type.int1: 'B',
+    _triton.arg_type.int8: 'B',
+    _triton.arg_type.int32: 'I',
+    _triton.arg_type.int64: 'Q',
+    _triton.arg_type.half: 'H',
+    _triton.arg_type.float: 'f',
+    _triton.arg_type.double: 'd',
+    _triton.arg_type.buffer: 'P'
 }

 def th_to_triton(obj):
@@ -30,17 +36,17 @@ def th_to_triton(obj):
     return str(obj)

 def cdiv(a, b):
-    return libtriton.cdiv(a, b)
+    return (a + b - 1) // b

 def synchronize(device):
     dev_id = device.index
     dev_id = -1 if dev_id is None else dev_id
-    libtriton.synchronize(dev_id)
+    _torch_utils.synchronize(dev_id)

 def read(path, kernel_names=[]):
     with open(path, 'r') as f:
         source = f.read()
-    source = libtriton.extract_kernels(source, kernel_names)
+    source = _triton.extract_kernels(source, kernel_names)
     return source

 class kernel:
@@ -50,7 +56,7 @@ class kernel:
         if src == '':
             raise ValueError('Kernel source code is empty')
         self.src = src
-        self.opt = libtriton.options()
+        self.opt = _triton.options()
         self.opt.defines = {k: th_to_triton(v) for k, v in defines.items()}
         self.opt.num_warps = num_warps
         # device
@@ -59,39 +65,25 @@ class kernel:
         self.device = torch.cuda.current_device() if device.index is None else device.index
         if device.type == 'cpu':
             self.device = -1
+        _torch_utils.register_device(self.device)
+        _torch_utils.register_stream(self.device)
         # C++ function wrapper
-        self.op_id = libtriton.make_op_id()
-        libtriton.register_fn(self.op_id, self.device, self.src, self.opt, autotune_vals, autotune_key)
+        self.op_id = _triton.make_op_id()
+        _triton.register_fn(self.op_id, self.device, self.src, self.opt, autotune_vals, autotune_key)
         # debug mode
         self.is_debug = 'TRITON_DEBUG' in os.environ
         # signature
-        arg_types = libtriton.get_fn_signature(self.op_id)
+        arg_types = _triton.get_fn_signature(self.op_id)
         self.tys = ''.join([codes[x] for x in arg_types])

     def __call__(self, *args, grid):
-        # debug mode (initialize)
-        if self.is_debug:
-            _args = args
-            args = [x.clone() if isinstance(x, torch.Tensor) else x for x in _args]
-            for i in range(len(args)):
-                if isinstance(args[i], torch.Tensor):
-                    args[i] = libtriton.cuda_empty_like(args[i])
-                    args[i].copy_(_args[i])
-        # initialize cuda device if necessary
-        libtriton.cuda_set_device(self.device)
+        _torch_utils.set_device(self.device)
         # pack parameters into a byte buffer
-        params = pack(self.tys, *args)
-        # auto-tune if necessary
-        opt = libtriton.autotune(self.op_id, self.device, params, grid)
+        params = struct.pack(self.tys, *args)
+        opt = _triton.autotune(self.op_id, self.device, params, grid)
         # run kernel
         grid = grid(opt)
         grid_0 = grid[0]
         grid_1 = 1 if len(grid) < 2 else grid[1]
         grid_2 = 1 if len(grid) < 3 else grid[2]
-        libtriton.launch_kernel(self.op_id, self.device, params, grid_0, grid_1, grid_2)
-        # debug mode (finalize)
-        if self.is_debug:
-            for i in range(len(args)):
-                if isinstance(args[i], torch.Tensor):
-                    _args[i].copy_(args[i].clone())
-            args = _args
\ No newline at end of file
+        _triton.launch_kernel(self.op_id, self.device, params, grid_0, grid_1, grid_2)
\ No newline at end of file
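
Note for reviewers: the sketch below shows how the refactored Python front-end drives these bindings end to end, including the struct.pack argument-passing trick credited to Scott Gray in triton.cc. It is a minimal sketch, not code from this patch: the `add` kernel source, its TILE define, and the tensor shapes are hypothetical, and it assumes autotune_vals/autotune_key default to empty in kernel.__init__.

    import torch
    import triton

    # Hypothetical Triton-C source; TILE is a compile-time define.
    src = '''
    __global__ void add(float *X, float *Y, float *Z) {
        int off[TILE] = get_program_id(0) * TILE + 0 ... TILE;
        *(Z + off) = *(X + off) + *(Y + off);
    }
    '''

    x = torch.randn(4096, device='cuda')
    y = torch.randn(4096, device='cuda')
    z = torch.empty_like(x)

    # __init__ registers the device/stream through torch_utils and compiles
    # the source through _triton.make_op_id / _triton.register_fn.
    add = triton.kernel(src, device=x.device, defines={'TILE': 1024})

    # __call__ packs all arguments into a single bytes object via struct.pack
    # ('P' for the three buffer pointers here) and hands it to
    # _triton.launch_kernel as one raw buffer -- no per-argument FFI crossing.
    add(x.data_ptr(), y.data_ptr(), z.data_ptr(),
        grid=lambda opt: [triton.cdiv(x.numel(), opt.TILE)])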
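
The new torch_utils submodule can also be exercised directly; this mirrors what kernel.__init__ now does internally. Device ids follow the torch convention used throughout the patch, with -1 standing for the host:

    import torch  # must be imported first, per the note in triton/__init__.py
    import triton._C.libtriton.torch_utils as _torch_utils

    _torch_utils.register_device(0)   # wrap CUDA device 0 in a triton driver device
    _torch_utils.register_stream(0)   # adopt the current torch CUDA stream for device 0
    _torch_utils.set_device(0)        # cudaSetDevice; no-op for dev_id < 0
    _torch_utils.synchronize(0)       # block until the registered stream is idle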