more tests
This commit is contained in:
@@ -48,7 +48,7 @@ endif()
|
||||
# Triton
|
||||
file(GLOB_RECURSE LIBTRITON_SRC lib/*.cc)
|
||||
add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
|
||||
target_link_libraries(triton LLVM)
|
||||
target_link_libraries(triton LLVM ${TF_LIBS})
|
||||
|
||||
# Warning level
|
||||
#if(MSVC)
|
||||
|
@@ -4,6 +4,7 @@
|
||||
#include <string>
|
||||
#include <regex>
|
||||
#include <algorithm>
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "triton/codegen/selection/selection.h"
|
||||
#include "triton/runtime/function.h"
|
||||
#include "triton/lang/code_gen.h"
|
||||
@@ -21,14 +22,15 @@ using namespace triton;
|
||||
|
||||
namespace rt = triton::runtime;
|
||||
|
||||
typedef std::vector<tensorflow::Tensor> tf_grid_t;
|
||||
typedef std::function<tf_grid_t(const rt::function::options_t& opt)> tf_grid_fn_ty;
|
||||
|
||||
/* TF triton op properties */
|
||||
|
||||
std::map<size_t, rt::function::grid_fn_ty> id_grid_map;
|
||||
std::map<size_t, tf_grid_fn_ty> id_grid_map;
|
||||
std::map<size_t, rt::function*> id_fn_map;
|
||||
|
||||
void register_grid(size_t id,
|
||||
const rt::function::grid_fn_ty& grid_fn) {
|
||||
const tf_grid_fn_ty& grid_fn) {
|
||||
id_grid_map[id] = grid_fn;
|
||||
}
|
||||
|
||||
@@ -110,6 +112,13 @@ void gen_make_handles(std::ostream &os, const std::vector<ir::argument*>& args)
|
||||
}
|
||||
|
||||
void gen_make_launch_function(std::ostream &os, const std::vector<ir::argument*>& args) {
|
||||
os << " rt::function::grid_fn_ty grid_fn = [&](const rt::function::options_t& opt) {" << std::endl;
|
||||
os << " auto tmp = id_grid_map.at(id_)(opt);" << std::endl;
|
||||
os << " rt::grid_t result;" << std::endl;
|
||||
os << " for(auto& x: tmp) { result.push_back(x.scalar<int>()()); }" << std::endl;
|
||||
os << " return result; }; " << std::endl;
|
||||
|
||||
|
||||
os << " (*id_fn_map.at(id_))({";
|
||||
for(unsigned i = 0; i < args.size() ; i++){
|
||||
ir::argument *arg = args[i];
|
||||
@@ -120,7 +129,7 @@ void gen_make_launch_function(std::ostream &os, const std::vector<ir::argument*>
|
||||
os << ", ";
|
||||
os << name;
|
||||
}
|
||||
os << "}, id_grid_map.at(id_), stream); \n";
|
||||
os << "}, grid_fn, stream); \n";
|
||||
}
|
||||
|
||||
void gen_register_kernel_builder(std::ostream &os, const std::string &name,
|
||||
@@ -230,7 +239,9 @@ using GPUDevice = Eigen::GpuDevice;
|
||||
namespace rt = triton::runtime;
|
||||
namespace drv = triton::driver;
|
||||
|
||||
extern std::map<size_t, rt::function::grid_fn_ty> id_grid_map;
|
||||
typedef std::vector<tensorflow::Tensor> tf_grid_t;
|
||||
typedef std::function<tf_grid_t(const rt::function::options_t& opt)> tf_grid_fn_ty;
|
||||
extern std::map<size_t, tf_grid_fn_ty> id_grid_map;
|
||||
extern std::map<size_t, rt::function*> id_fn_map;
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user