[python] basic tensorflow wrapper working
This commit is contained in:
@@ -4,7 +4,6 @@
|
||||
#include <string>
|
||||
#include <regex>
|
||||
#include <algorithm>
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "triton/codegen/selection/selection.h"
|
||||
#include "triton/runtime/function.h"
|
||||
#include "triton/lang/code_gen.h"
|
||||
@@ -22,15 +21,15 @@ using namespace triton;
|
||||
|
||||
namespace rt = triton::runtime;
|
||||
|
||||
typedef std::vector<tensorflow::Tensor> tf_grid_t;
|
||||
typedef std::function<tf_grid_t(const rt::function::options_t& opt)> tf_grid_fn_ty;
|
||||
|
||||
/* TF triton op properties */
|
||||
std::map<size_t, tf_grid_fn_ty> id_grid_map;
|
||||
|
||||
std::map<size_t, rt::function::grid_fn_ty> id_grid_map;
|
||||
std::map<size_t, rt::function*> id_fn_map;
|
||||
std::map<size_t, int64_t> i64scalar_map;
|
||||
|
||||
void register_grid(size_t id,
|
||||
const tf_grid_fn_ty& grid_fn) {
|
||||
const rt::function::grid_fn_ty& grid_fn) {
|
||||
id_grid_map[id] = grid_fn;
|
||||
}
|
||||
|
||||
@@ -43,6 +42,17 @@ size_t register_fn(const std::string& src,
|
||||
return id;
|
||||
}
|
||||
|
||||
size_t make_scalar_id() {
|
||||
return i64scalar_map.size();
|
||||
}
|
||||
|
||||
bool has_scalar(size_t id) {
|
||||
return i64scalar_map.find(id) != i64scalar_map.end();
|
||||
}
|
||||
|
||||
int64_t retrieve_scalar(size_t id) {
|
||||
return i64scalar_map.at(id);
|
||||
}
|
||||
|
||||
/* TF source-code generation */
|
||||
|
||||
@@ -112,13 +122,6 @@ void gen_make_handles(std::ostream &os, const std::vector<ir::argument*>& args)
|
||||
}
|
||||
|
||||
void gen_make_launch_function(std::ostream &os, const std::vector<ir::argument*>& args) {
|
||||
os << " rt::function::grid_fn_ty grid_fn = [&](const rt::function::options_t& opt) {" << std::endl;
|
||||
os << " auto tmp = id_grid_map.at(id_)(opt);" << std::endl;
|
||||
os << " rt::grid_t result;" << std::endl;
|
||||
os << " for(auto& x: tmp) { result.push_back(x.scalar<int>()()); }" << std::endl;
|
||||
os << " return result; }; " << std::endl;
|
||||
|
||||
|
||||
os << " (*id_fn_map.at(id_))({";
|
||||
for(unsigned i = 0; i < args.size() ; i++){
|
||||
ir::argument *arg = args[i];
|
||||
@@ -129,7 +132,7 @@ void gen_make_launch_function(std::ostream &os, const std::vector<ir::argument*>
|
||||
os << ", ";
|
||||
os << name;
|
||||
}
|
||||
os << "}, grid_fn, stream); \n";
|
||||
os << "}, id_grid_map.at(id_), stream); \n";
|
||||
}
|
||||
|
||||
void gen_register_kernel_builder(std::ostream &os, const std::string &name,
|
||||
@@ -239,9 +242,7 @@ using GPUDevice = Eigen::GpuDevice;
|
||||
namespace rt = triton::runtime;
|
||||
namespace drv = triton::driver;
|
||||
|
||||
typedef std::vector<tensorflow::Tensor> tf_grid_t;
|
||||
typedef std::function<tf_grid_t(const rt::function::options_t& opt)> tf_grid_fn_ty;
|
||||
extern std::map<size_t, tf_grid_fn_ty> id_grid_map;
|
||||
extern std::map<size_t, rt::function::grid_fn_ty> id_grid_map;
|
||||
extern std::map<size_t, rt::function*> id_fn_map;
|
||||
|
||||
|
||||
@@ -307,7 +308,8 @@ PYBIND11_MODULE(libtriton, m) {
|
||||
// bindings for triton classes
|
||||
pybind11::class_<options_t>(m, "options")
|
||||
.def(pybind11::init<>())
|
||||
.def("D", &options_t::D<int>);
|
||||
.def("d", &options_t::D<int>)
|
||||
.def_readonly("num_warps", &options_t::num_warps);
|
||||
|
||||
pybind11::class_<options_space_t>(m, "options_space")
|
||||
.def(pybind11::init<>())
|
||||
@@ -317,4 +319,6 @@ PYBIND11_MODULE(libtriton, m) {
|
||||
// hooks into triton constructs since frameworks may not use pybind11
|
||||
m.def("register_grid", ®ister_grid);
|
||||
m.def("register_fn", ®ister_fn);
|
||||
m.def("make_scalar_id", &make_scalar_id);
|
||||
m.def("retrieve_scalar", &retrieve_scalar);
|
||||
}
|
||||
|
Reference in New Issue
Block a user