2019-08-25 21:26:09 -07:00
|
|
|
|
#include <pybind11/pybind11.h>
|
2020-01-16 12:09:50 -05:00
|
|
|
|
#include <pybind11/buffer_info.h>
|
2019-08-15 20:50:10 -07:00
|
|
|
|
#include <pybind11/stl.h>
|
2019-08-25 21:26:09 -07:00
|
|
|
|
#include <pybind11/functional.h>
|
2019-08-15 20:50:10 -07:00
|
|
|
|
#include <cstdint>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
|
|
|
|
|
|
#include "triton/runtime/function.h"
|
2020-02-20 20:09:33 -05:00
|
|
|
|
#include "triton/runtime/arg.h"
|
2019-08-25 21:26:09 -07:00
|
|
|
|
#include "triton/lang/code_gen.h"
|
|
|
|
|
|
#include "triton/lang/parser.h"
|
|
|
|
|
|
#include "triton/lang/cpp.h"
|
2019-08-15 20:50:10 -07:00
|
|
|
|
#include "triton/ir/module.h"
|
|
|
|
|
|
#include "triton/ir/function.h"
|
|
|
|
|
|
|
|
|
|
|
|
using namespace triton;
|
|
|
|
|
|
|
2019-08-25 21:26:09 -07:00
|
|
|
|
namespace rt = triton::runtime;
|
|
|
|
|
|
|
2020-05-04 18:33:56 -04:00
|
|
|
|
typedef std::pair<size_t, size_t> map_key_t;
|
|
|
|
|
|
std::map<map_key_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
|
|
|
|
|
|
std::map<map_key_t, std::shared_ptr<rt::function>> id_fn_map;
|
2019-10-27 15:32:34 -04:00
|
|
|
|
std::map<size_t, double> fp64scalar_map;
|
2019-08-26 16:53:49 -07:00
|
|
|
|
std::map<size_t, int64_t> i64scalar_map;
|
2019-08-25 21:26:09 -07:00
|
|
|
|
|
2019-10-27 15:32:34 -04:00
|
|
|
|
/* Grid map */
|
|
|
|
|
|
|
2020-05-04 18:33:56 -04:00
|
|
|
|
void register_grid(const map_key_t& key,
|
2019-08-26 16:53:49 -07:00
|
|
|
|
const rt::function::grid_fn_ty& grid_fn) {
|
2020-05-04 18:33:56 -04:00
|
|
|
|
id_grid_map[key].reset(new rt::function::grid_fn_ty(grid_fn));
|
2019-08-25 21:26:09 -07:00
|
|
|
|
}
|
|
|
|
|
|
|
2020-05-04 18:33:56 -04:00
|
|
|
|
/// Remove the grid registered under `key`.
///
/// A key that was never registered is silently ignored.
///
/// @param key  registry key
void delete_grid(const map_key_t& key) {
  auto pos = id_grid_map.find(key);
  if (pos != id_grid_map.end())
    id_grid_map.erase(pos);
}
|
|
|
|
|
|
|
2019-10-27 15:32:34 -04:00
|
|
|
|
/* Function map */
|
|
|
|
|
|
|
2020-05-04 18:33:56 -04:00
|
|
|
|
void register_fn(const map_key_t& key,
|
2020-02-20 20:09:33 -05:00
|
|
|
|
const std::string& src,
|
|
|
|
|
|
const rt::function::options_space_t& opt,
|
|
|
|
|
|
const std::string &cache_ref) {
|
2020-05-04 18:33:56 -04:00
|
|
|
|
id_fn_map[key].reset(new rt::function(src, opt, cache_ref));
|
2019-08-26 17:21:09 -07:00
|
|
|
|
}
|
|
|
|
|
|
|
2020-05-04 18:33:56 -04:00
|
|
|
|
/// Remove the function registered under `key`.
///
/// A key that was never registered is silently ignored.
///
/// @param key  registry key
void delete_fn(const map_key_t& key) {
  auto pos = id_fn_map.find(key);
  if (pos != id_fn_map.end())
    id_fn_map.erase(pos);
}
|
|
|
|
|
|
|
2020-05-04 18:33:56 -04:00
|
|
|
|
void register_cst(const map_key_t& key, const std::string& name, pybind11::buffer& data) {
|
2020-01-16 12:09:50 -05:00
|
|
|
|
pybind11::buffer_info info = data.request();
|
2020-05-04 18:33:56 -04:00
|
|
|
|
id_fn_map[key]->set_cst(name, info.ptr, info.size*info.itemsize);
|
2020-01-16 12:09:50 -05:00
|
|
|
|
}
|
|
|
|
|
|
|
2019-09-04 03:12:23 -04:00
|
|
|
|
void cleanup() {
|
|
|
|
|
|
id_grid_map.clear();
|
|
|
|
|
|
id_fn_map.clear();
|
|
|
|
|
|
i64scalar_map.clear();
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2019-08-26 17:21:09 -07:00
|
|
|
|
size_t make_op_id() {
|
|
|
|
|
|
return id_fn_map.size();
|
2019-08-25 21:26:09 -07:00
|
|
|
|
}
|
|
|
|
|
|
|
2020-02-20 20:09:33 -05:00
|
|
|
|
|
2019-10-27 15:32:34 -04:00
|
|
|
|
/* TF scalar wrapper */
|
2019-08-26 16:53:49 -07:00
|
|
|
|
/// Allocate a new int64 scalar slot, set it to 0, and return its id.
///
/// The id is the number of slots that existed before the call.
size_t make_scalar_id() {
  const size_t id = i64scalar_map.size();
  i64scalar_map[id] = 0;
  return id;
}
|
|
|
|
|
|
|
|
|
|
|
|
bool has_scalar(size_t id) {
|
|
|
|
|
|
return i64scalar_map.find(id) != i64scalar_map.end();
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
int64_t retrieve_scalar(size_t id) {
|
|
|
|
|
|
return i64scalar_map.at(id);
|
|
|
|
|
|
}
|
2019-08-25 21:26:09 -07:00
|
|
|
|
|
2019-08-29 17:06:59 -07:00
|
|
|
|
void make_module(const std::string& src, ir::module* ir,
|
|
|
|
|
|
const runtime::function::options_space_t& opt) {
|
2019-10-29 17:29:11 -04:00
|
|
|
|
std::string copy = triton::runtime::function::preheader() + src;
|
2019-08-25 21:26:09 -07:00
|
|
|
|
// pre-process
|
|
|
|
|
|
TokenSequence tokens;
|
2019-08-29 17:06:59 -07:00
|
|
|
|
Preprocessor cpp(©, true);
|
2019-08-25 21:26:09 -07:00
|
|
|
|
for(auto it: opt.defines){
|
|
|
|
|
|
cpp.AddMacro(it.first, &it.second[0]);
|
|
|
|
|
|
}
|
|
|
|
|
|
cpp.Process(tokens);
|
|
|
|
|
|
// parse
|
|
|
|
|
|
Parser parser(tokens);
|
|
|
|
|
|
parser.Parse();
|
2019-08-29 17:06:59 -07:00
|
|
|
|
Generator gen(&parser);
|
|
|
|
|
|
gen.Gen(ir);
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2020-02-20 20:09:33 -05:00
|
|
|
|
/* Function signature */
|
|
|
|
|
|
std::vector<rt::arg_type> get_fn_signature(const std::string& src,
|
|
|
|
|
|
const runtime::function::options_space_t& opt) {
|
|
|
|
|
|
// triton-ir code-gen
|
|
|
|
|
|
ir::context ctx;
|
|
|
|
|
|
auto ir = std::shared_ptr<ir::module>(new ir::module("", ctx));
|
|
|
|
|
|
make_module(src, &*ir, opt);
|
|
|
|
|
|
// function
|
|
|
|
|
|
ir::function* fn = ir->get_function_list().front();
|
|
|
|
|
|
// extract signature
|
|
|
|
|
|
std::vector<rt::arg_type> ret;
|
|
|
|
|
|
ir::function_type* ty = fn->get_fn_type();
|
|
|
|
|
|
for(size_t i = 0; i < ty->get_num_params(); i++)
|
|
|
|
|
|
ret.push_back(rt::convert(ty->get_param_ty(i)));
|
|
|
|
|
|
return ret;
|
|
|
|
|
|
}
|
2019-08-29 17:06:59 -07:00
|
|
|
|
|
2019-08-25 21:26:09 -07:00
|
|
|
|
// Short aliases for the runtime option types exposed to Python below.
using options_t = triton::runtime::function::options_t;
using options_space_t = triton::runtime::function::options_space_t;
|
2019-08-15 20:50:10 -07:00
|
|
|
|
|
|
|
|
|
|
// Python extension module `libtriton`. It exposes the runtime argument-type
// enum, the option structs, and free-function hooks over the global
// registries defined in this file.
PYBIND11_MODULE(libtriton, m) {
  m.doc() = "Python bindings to the C++ Triton API";

  // bindings for triton classes
  pybind11::enum_<rt::arg_type>(m, "arg_type")
      .value("int1", rt::INT1_T)
      .value("int8", rt::INT8_T)
      .value("int16", rt::INT16_T)
      .value("int32", rt::INT32_T)
      .value("int64", rt::INT64_T)
      .value("half", rt::HALF_T)
      .value("float", rt::FLOAT_T)
      .value("double", rt::DOUBLE_T)
      .value("buffer", rt::BUFFER_T);

  pybind11::class_<options_t>(m, "options")
      .def(pybind11::init<>())
      .def("d", &options_t::D<int>)
      .def_readonly("num_warps", &options_t::num_warps);

  pybind11::class_<options_space_t>(m, "options_space")
      .def(pybind11::init<>())
      .def_readwrite("defines", &options_space_t::defines)
      .def_readwrite("num_warps", &options_space_t::num_warps);

  // hooks into triton constructs since frameworks may not use pybind11.
  // The address-of operators below were previously corrupted: "&reg" had been
  // decoded as the HTML entity for (R), which broke compilation.
  m.def("get_fn_signature", &get_fn_signature);
  m.def("register_grid", &register_grid);
  m.def("delete_grid", &delete_grid);
  m.def("register_fn", &register_fn);
  m.def("register_cst", &register_cst);
  m.def("delete_fn", &delete_fn);
  m.def("make_op_id", &make_op_id);
  m.def("make_scalar_id", &make_scalar_id);
  m.def("retrieve_scalar", &retrieve_scalar);
  m.def("cleanup", &cleanup);
}
|