#include "triton/driver/stream.h" #include "triton/ir/function.h" #include "triton/ir/module.h" #include "triton/lang/code_gen.h" #include "triton/lang/cpp.h" #include "triton/lang/parser.h" #include "triton/runtime/arg.h" #include "triton/runtime/function.h" #include #include #include #include #include #include using namespace triton; namespace rt = triton::runtime; namespace drv = triton::driver; namespace lng = triton::lang; std::unordered_map opt_cache_; std::map> id_fn_map; extern std::map> tt_devices; extern std::map> tt_streams; /* Function utilities */ void register_fn(int op_id, int dev_id, const std::string &src, const rt::options_t &opt, const rt::function::autotune_vals_t &autotune_vals, const std::vector &autotune_key) { if (id_fn_map.find(op_id) == id_fn_map.end()) { id_fn_map[op_id].reset(new rt::function(src, opt, &*tt_devices[dev_id], autotune_vals, autotune_key)); } for (const auto &k : id_fn_map[op_id]->get_kernels()) { const rt::options_t *opt = &k.first; pybind11::object obj = pybind11::cast(opt, pybind11::return_value_policy::reference); for (auto x : opt->defines) if (std::all_of(x.second.begin(), x.second.end(), ::isdigit)) obj.attr(x.first.c_str()) = std::stoi(x.second); opt_cache_[&k.second->opt] = obj; } } void delete_fn(int op_id) { id_fn_map.erase(op_id); } void cleanup() { id_fn_map.clear(); opt_cache_.clear(); } size_t make_op_id() { return id_fn_map.size(); } std::vector get_fn_signature(size_t op_id) { return id_fn_map[op_id]->get_kernels()[0].second->get_sig(); } // Thanks to Scott Gray (OpenAI) for the idea to pass the arguments // as a string constructed with struct.pack in python void launch_kernel(int64_t op_id, int64_t dev_id, const std::string &args, size_t grid_0, size_t grid_1, size_t grid_2) { rt::function *fn = id_fn_map.at(op_id).get(); (*fn)((void **)args.c_str(), args.size(), {grid_0, grid_1, grid_2}, &*tt_streams[dev_id]); } pybind11::object autotune(int64_t op_id, int64_t dev_id, const std::string &args, const rt::function::grid_fn_ty &grid) { rt::function *fn = id_fn_map.at(op_id).get(); auto wrapper = [&grid](const rt::options_t &opt) { pybind11::object obj = pybind11::cast(&opt, pybind11::return_value_policy::reference); for (auto x : opt.defines) if (std::all_of(x.second.begin(), x.second.end(), ::isdigit)) obj.attr(x.first.c_str()) = std::stoi(x.second); return grid(*obj.cast()); }; rt::kernel *kernel = fn->autotune((void **)args.c_str(), args.size(), wrapper, &*tt_streams[dev_id]); return opt_cache_.at(&kernel->opt); } std::string extract_kernels(const std::string &str, const std::vector &names) { if (names.empty()) return str; // search for all regex matches of kernel_regex in str std::smatch matches; std::regex regex(" *__global__ +void +([_a-zA-Z][_a-zA-Z0-9]{0,30})"); std::sregex_iterator it(str.begin(), str.end(), regex); std::sregex_iterator end; std::vector> kernels; for (; it != end; ++it) { int pos = it->position(); int len = it->length(); std::string name = it->str(1); kernels.push_back(std::make_tuple(name, pos, len)); } for (const std::string &name : names) { // check that str matches any string in kernels using std::any_of auto pred = [&name](const std::tuple &t) { return std::get<0>(t) == name; }; bool found = std::any_of(kernels.begin(), kernels.end(), pred); if (!found) throw std::runtime_error("Unable to find kernel `" + name + "` in provided source code:\n" + str); } // extract functions std::string ret; for (const auto &k : kernels) { std::string name; int pos, len; std::tie(name, pos, len) = k; if (std::find(names.begin(), names.end(), name) != names.end()) { std::string def = str.substr(pos, str.size() - pos); int count, pos; // skip over declaration count = 1; pos = def.find('('); while (!(def[pos++] == ')' && count == 0) && pos < def.size()) { count += def[pos] == '('; count -= def[pos] == ')'; } // skip over definition count = 1; pos = def.find('{', pos); while (!(def[pos++] == '}' && count == 0) && pos < def.size()) { count += def[pos] == '{'; count -= def[pos] == '}'; } ret += def.substr(0, pos); ret += '\n'; } } return ret; } void init_triton(pybind11::module &m) { pybind11::module subm = m.def_submodule("triton"); // bindings for triton classes pybind11::enum_(subm, "arg_type") .value("int1", rt::INT1_T) .value("int8", rt::INT8_T) .value("int16", rt::INT16_T) .value("int32", rt::INT32_T) .value("int64", rt::INT64_T) .value("half", rt::HALF_T) .value("float", rt::FLOAT_T) .value("double", rt::DOUBLE_T) .value("buffer", rt::BUFFER_T); pybind11::enum_(subm, "asm_mode") .value("ptx", rt::ASM_NV_PTX) .value("sass", rt::ASM_NV_SASS); pybind11::class_(subm, "options", pybind11::dynamic_attr()) .def(pybind11::init<>()) .def_readwrite("defines", &rt::options_t::defines) .def_readwrite("num_warps", &rt::options_t::num_warps); // hooks into triton constructs since frameworks may not use pybind11 subm.def("extract_kernels", &extract_kernels); subm.def("get_fn_signature", &get_fn_signature); subm.def("register_fn", ®ister_fn); subm.def("delete_fn", &delete_fn); subm.def("make_op_id", &make_op_id); subm.def("cleanup", &cleanup); subm.def("autotune", &autotune, pybind11::return_value_policy::reference); subm.def("launch_kernel", &launch_kernel); }