#include "triton/driver/stream.h" #include "triton/runtime/function.h" #include #include #include #include #include #include namespace py = pybind11; using namespace triton; namespace rt = triton::runtime; namespace drv = triton::driver; /*****************************************************************************/ /* Python bindings for triton::tools */ /*****************************************************************************/ /*! @brief Function for extracting kernels out of a given source-string This can be important to enable pre-processor macros (or tunable parameters) that should only be defined within the scope of a single kernel function */ std::string extract_kernels(const std::string &str, const std::vector &names) { if (names.empty()) return str; // search for all regex matches of kernel_regex in str std::smatch matches; std::regex regex(" *__global__ +void +([_a-zA-Z][_a-zA-Z0-9]{0,30})"); std::sregex_iterator it(str.begin(), str.end(), regex); std::sregex_iterator end; std::vector> kernels; for (; it != end; ++it) { int pos = it->position(); int len = it->length(); std::string name = it->str(1); kernels.push_back(std::make_tuple(name, pos, len)); } // check that all the kernels provided actually exist for (const std::string &name : names) { auto pred = [&name](const std::tuple &t) { return std::get<0>(t) == name; }; bool found = std::any_of(kernels.begin(), kernels.end(), pred); if (!found) throw std::runtime_error("Unable to find kernel `" + name + "` in provided source code:\n" + str); } // simple parsing logic to extract the declaration and body of each specified kernel std::string ret; for (const auto &k : kernels) { std::string name; int pos, len; std::tie(name, pos, len) = k; if (std::find(names.begin(), names.end(), name) == names.end()) continue; std::string def = str.substr(pos, str.size() - pos); // skip over declaration // by finding matching ')' for first '(' int count = 1; pos = def.find('('); while (!(def[pos++] == ')' && count == 0) && pos < def.size()) { count += def[pos] == '('; count -= def[pos] == ')'; } // skip over definition // by finding matching '{' for first '}' count = 1; pos = def.find('{', pos); while (!(def[pos++] == '}' && count == 0) && pos < def.size()) { count += def[pos] == '{'; count -= def[pos] == '}'; } ret += def.substr(0, pos); ret += '\n'; } return ret; } void init_triton_tools(py::module &&m) { m.def("extract_kernels", &extract_kernels); } /*****************************************************************************/ /* Python bindings for triton::driver */ /*****************************************************************************/ void init_triton_driver(py::module &&m) { // base device py::class_(m, "device"); // cuda device py::class_(m, "cu_device") .def(py::init()); // host device py::class_(m, "host_device") .def(py::init<>()); // base stream py::class_(m, "stream"); // host stream py::class_(m, "host_stream") .def(py::init<>()); // cuda stream py::class_(m, "cu_stream") // py doesn't support opaque pointer (e.g., CUstream) so // we assume it has been converted to uint64_t .def(py::init([](uint64_t handle, bool take_ownership) { return std::unique_ptr(new driver::cu_stream((CUstream)handle, take_ownership)); })); } /*****************************************************************************/ /* Python bindings for triton::runtime */ /*****************************************************************************/ void init_triton_runtime(py::module &&m) { // argument type py::enum_(m, "arg_type") .value("int1", rt::INT1_T) .value("int8", rt::INT8_T) .value("int16", rt::INT16_T) .value("int32", rt::INT32_T) .value("int64", rt::INT64_T) .value("half", rt::HALF_T) .value("float", rt::FLOAT_T) .value("double", rt::DOUBLE_T) .value("buffer", rt::BUFFER_T); // assembly mode py::enum_(m, "asm_mode") .value("ptx", rt::ASM_NV_PTX) .value("sass", rt::ASM_NV_SASS); // compilation options py::class_(m, "options", py::dynamic_attr()) .def(py::init<>()) .def_readwrite("defines", &rt::options_t::defines) .def_readwrite("num_warps", &rt::options_t::num_warps) .def("__getattr__", [](rt::options_t *opt, const std::string &name) { return opt->D(name); }); // kernel py::class_(m, "kernel") .def("__call__", &rt::kernel::operator()) .def_readonly("opt", &rt::kernel::opt); // tune conf py::class_(m, "config") .def(py::init, int>(), py::arg("defines") = std::map(), py::arg("num_warps")); // function py::class_(m, "function") .def(py::init &, const std::vector &>()) .def("autotune", &rt::function::autotune, py::return_value_policy::reference_internal) .def("signature", &rt::function::get_signature); } void init_triton(py::module &m) { py::module subm = m.def_submodule("triton"); init_triton_driver(std::move(subm.def_submodule("driver"))); init_triton_runtime(std::move(subm.def_submodule("runtime"))); init_triton_tools(std::move(subm.def_submodule("tools"))); }