more tests

This commit is contained in:
Philippe Tillet
2019-08-26 11:00:00 -07:00
parent 321d268a4a
commit 0e0399f866
2 changed files with 17 additions and 6 deletions

View File

@@ -48,7 +48,7 @@ endif()
# Triton
file(GLOB_RECURSE LIBTRITON_SRC lib/*.cc)
add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
target_link_libraries(triton LLVM)
target_link_libraries(triton LLVM ${TF_LIBS})
# Warning level
#if(MSVC)

View File

@@ -4,6 +4,7 @@
#include <string>
#include <regex>
#include <algorithm>
#include "tensorflow/core/framework/tensor.h"
#include "triton/codegen/selection/selection.h"
#include "triton/runtime/function.h"
#include "triton/lang/code_gen.h"
@@ -21,14 +22,15 @@ using namespace triton;
namespace rt = triton::runtime;
typedef std::vector<tensorflow::Tensor> tf_grid_t;
typedef std::function<tf_grid_t(const rt::function::options_t& opt)> tf_grid_fn_ty;
/* TF triton op properties */
std::map<size_t, rt::function::grid_fn_ty> id_grid_map;
std::map<size_t, tf_grid_fn_ty> id_grid_map;
std::map<size_t, rt::function*> id_fn_map;
void register_grid(size_t id,
const rt::function::grid_fn_ty& grid_fn) {
const tf_grid_fn_ty& grid_fn) {
id_grid_map[id] = grid_fn;
}
@@ -110,6 +112,13 @@ void gen_make_handles(std::ostream &os, const std::vector<ir::argument*>& args)
}
void gen_make_launch_function(std::ostream &os, const std::vector<ir::argument*>& args) {
os << " rt::function::grid_fn_ty grid_fn = [&](const rt::function::options_t& opt) {" << std::endl;
os << " auto tmp = id_grid_map.at(id_)(opt);" << std::endl;
os << " rt::grid_t result;" << std::endl;
os << " for(auto& x: tmp) { result.push_back(x.scalar<int>()()); }" << std::endl;
os << " return result; }; " << std::endl;
os << " (*id_fn_map.at(id_))({";
for(unsigned i = 0; i < args.size() ; i++){
ir::argument *arg = args[i];
@@ -120,7 +129,7 @@ void gen_make_launch_function(std::ostream &os, const std::vector<ir::argument*>
os << ", ";
os << name;
}
os << "}, id_grid_map.at(id_), stream); \n";
os << "}, grid_fn, stream); \n";
}
void gen_register_kernel_builder(std::ostream &os, const std::string &name,
@@ -230,7 +239,9 @@ using GPUDevice = Eigen::GpuDevice;
namespace rt = triton::runtime;
namespace drv = triton::driver;
extern std::map<size_t, rt::function::grid_fn_ty> id_grid_map;
typedef std::vector<tensorflow::Tensor> tf_grid_t;
typedef std::function<tf_grid_t(const rt::function::options_t& opt)> tf_grid_fn_ty;
extern std::map<size_t, tf_grid_fn_ty> id_grid_map;
extern std::map<size_t, rt::function*> id_fn_map;