[python] basic tensorflow wrapper working

This commit is contained in:
Philippe Tillet
2019-08-26 16:53:49 -07:00
parent 0e0399f866
commit 4075949f80
26 changed files with 702 additions and 968 deletions

View File

@@ -4,7 +4,6 @@
#include <string>
#include <regex>
#include <algorithm>
#include "tensorflow/core/framework/tensor.h"
#include "triton/codegen/selection/selection.h"
#include "triton/runtime/function.h"
#include "triton/lang/code_gen.h"
@@ -22,15 +21,15 @@ using namespace triton;
namespace rt = triton::runtime;
typedef std::vector<tensorflow::Tensor> tf_grid_t;
typedef std::function<tf_grid_t(const rt::function::options_t& opt)> tf_grid_fn_ty;
/* TF triton op properties */
std::map<size_t, tf_grid_fn_ty> id_grid_map;
std::map<size_t, rt::function::grid_fn_ty> id_grid_map;
std::map<size_t, rt::function*> id_fn_map;
std::map<size_t, int64_t> i64scalar_map;
void register_grid(size_t id,
const tf_grid_fn_ty& grid_fn) {
const rt::function::grid_fn_ty& grid_fn) {
id_grid_map[id] = grid_fn;
}
@@ -43,6 +42,17 @@ size_t register_fn(const std::string& src,
return id;
}
size_t make_scalar_id() {
return i64scalar_map.size();
}
bool has_scalar(size_t id) {
return i64scalar_map.find(id) != i64scalar_map.end();
}
int64_t retrieve_scalar(size_t id) {
return i64scalar_map.at(id);
}
/* TF source-code generation */
@@ -112,13 +122,6 @@ void gen_make_handles(std::ostream &os, const std::vector<ir::argument*>& args)
}
void gen_make_launch_function(std::ostream &os, const std::vector<ir::argument*>& args) {
os << " rt::function::grid_fn_ty grid_fn = [&](const rt::function::options_t& opt) {" << std::endl;
os << " auto tmp = id_grid_map.at(id_)(opt);" << std::endl;
os << " rt::grid_t result;" << std::endl;
os << " for(auto& x: tmp) { result.push_back(x.scalar<int>()()); }" << std::endl;
os << " return result; }; " << std::endl;
os << " (*id_fn_map.at(id_))({";
for(unsigned i = 0; i < args.size() ; i++){
ir::argument *arg = args[i];
@@ -129,7 +132,7 @@ void gen_make_launch_function(std::ostream &os, const std::vector<ir::argument*>
os << ", ";
os << name;
}
os << "}, grid_fn, stream); \n";
os << "}, id_grid_map.at(id_), stream); \n";
}
void gen_register_kernel_builder(std::ostream &os, const std::string &name,
@@ -239,9 +242,7 @@ using GPUDevice = Eigen::GpuDevice;
namespace rt = triton::runtime;
namespace drv = triton::driver;
typedef std::vector<tensorflow::Tensor> tf_grid_t;
typedef std::function<tf_grid_t(const rt::function::options_t& opt)> tf_grid_fn_ty;
extern std::map<size_t, tf_grid_fn_ty> id_grid_map;
extern std::map<size_t, rt::function::grid_fn_ty> id_grid_map;
extern std::map<size_t, rt::function*> id_fn_map;
@@ -307,7 +308,8 @@ PYBIND11_MODULE(libtriton, m) {
// bindings for triton classes
pybind11::class_<options_t>(m, "options")
.def(pybind11::init<>())
.def("D", &options_t::D<int>);
.def("d", &options_t::D<int>)
.def_readonly("num_warps", &options_t::num_warps);
pybind11::class_<options_space_t>(m, "options_space")
.def(pybind11::init<>())
@@ -317,4 +319,6 @@ PYBIND11_MODULE(libtriton, m) {
// hooks into triton constructs since frameworks may not use pybind11
m.def("register_grid", &register_grid);
m.def("register_fn", &register_fn);
m.def("make_scalar_id", &make_scalar_id);
m.def("retrieve_scalar", &retrieve_scalar);
}