From 0e0399f866625e4c20f82b26f1ab79dacaaa11fe Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Mon, 26 Aug 2019 11:00:00 -0700 Subject: [PATCH] more tests --- CMakeLists.txt | 2 +- python/src/tensorflow.cc | 21 ++++++++++++++++----- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 15985cc87..0616d19f1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -48,7 +48,7 @@ endif() # Triton file(GLOB_RECURSE LIBTRITON_SRC lib/*.cc) add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC}) -target_link_libraries(triton LLVM) +target_link_libraries(triton LLVM ${TF_LIBS}) # Warning level #if(MSVC) diff --git a/python/src/tensorflow.cc b/python/src/tensorflow.cc index 489c545ac..d843fecab 100644 --- a/python/src/tensorflow.cc +++ b/python/src/tensorflow.cc @@ -4,6 +4,7 @@ #include #include #include +#include "tensorflow/core/framework/tensor.h" #include "triton/codegen/selection/selection.h" #include "triton/runtime/function.h" #include "triton/lang/code_gen.h" @@ -21,14 +22,15 @@ using namespace triton; namespace rt = triton::runtime; +typedef std::vector tf_grid_t; +typedef std::function tf_grid_fn_ty; /* TF triton op properties */ - -std::map id_grid_map; +std::map id_grid_map; std::map id_fn_map; void register_grid(size_t id, - const rt::function::grid_fn_ty& grid_fn) { + const tf_grid_fn_ty& grid_fn) { id_grid_map[id] = grid_fn; } @@ -110,6 +112,13 @@ void gen_make_handles(std::ostream &os, const std::vector& args) } void gen_make_launch_function(std::ostream &os, const std::vector& args) { + os << " rt::function::grid_fn_ty grid_fn = [&](const rt::function::options_t& opt) {" << std::endl; + os << " auto tmp = id_grid_map.at(id_)(opt);" << std::endl; + os << " rt::grid_t result;" << std::endl; + os << " for(auto& x: tmp) { result.push_back(x.scalar()()); }" << std::endl; + os << " return result; }; " << std::endl; + + os << " (*id_fn_map.at(id_))({"; for(unsigned i = 0; i < args.size() ; i++){ ir::argument *arg = args[i]; @@ -120,7 +129,7 @@ void gen_make_launch_function(std::ostream &os, const std::vector os << ", "; os << name; } - os << "}, id_grid_map.at(id_), stream); \n"; + os << "}, grid_fn, stream); \n"; } void gen_register_kernel_builder(std::ostream &os, const std::string &name, @@ -230,7 +239,9 @@ using GPUDevice = Eigen::GpuDevice; namespace rt = triton::runtime; namespace drv = triton::driver; -extern std::map id_grid_map; +typedef std::vector tf_grid_t; +typedef std::function tf_grid_fn_ty; +extern std::map id_grid_map; extern std::map id_fn_map;