testing some register gradient

This commit is contained in:
Philippe Tillet
2019-08-26 19:25:58 -07:00
parent 9ece3eccc6
commit 7cb73f66e2
4 changed files with 47 additions and 19 deletions

View File

@@ -24,21 +24,19 @@ namespace rt = triton::runtime;
/* TF triton op properties */
std::map<size_t, rt::function::grid_fn_ty> id_grid_map;
std::map<size_t, rt::function*> id_fn_map;
std::map<size_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
std::map<size_t, std::shared_ptr<rt::function>> id_fn_map;
std::map<size_t, int64_t> i64scalar_map;
void register_grid(size_t id,
const rt::function::grid_fn_ty& grid_fn) {
id_grid_map[id] = grid_fn;
id_grid_map[id].reset(new rt::function::grid_fn_ty(grid_fn));
}
void register_fn(size_t id,
const std::string& src,
const rt::function::options_space_t& opt) {
bool is_inserted = id_fn_map.insert({id, new rt::function(src, opt)}).second;
if(!is_inserted)
assert(false);
id_fn_map[id].reset(new rt::function(src, opt));
}
size_t make_op_id() {
@@ -135,7 +133,7 @@ void gen_make_launch_function(std::ostream &os, const std::vector<ir::argument*>
os << ", ";
os << name;
}
os << "}, id_grid_map.at(id_), stream); \n";
os << "}, *id_grid_map.at(id_), stream); \n";
}
void gen_register_kernel_builder(std::ostream &os, const std::string &name,
@@ -214,7 +212,7 @@ std::tuple<std::string,
parser.Parse();
// triton-ir code-gen
ir::context ctx;
auto ir = std::unique_ptr<ir::module>(new ir::module("", ctx));
auto ir = std::shared_ptr<ir::module>(new ir::module("", ctx));
Generator gen(&parser);
gen.Gen(&*ir);
// function
@@ -245,8 +243,8 @@ using GPUDevice = Eigen::GpuDevice;
namespace rt = triton::runtime;
namespace drv = triton::driver;
extern std::map<size_t, rt::function::grid_fn_ty> id_grid_map;
extern std::map<size_t, rt::function*> id_fn_map;
extern std::map<size_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
extern std::map<size_t, std::shared_ptr<rt::function>> id_fn_map;
class )" << opname << R"(: public OpKernel {
@@ -294,6 +292,7 @@ oss << R"(
)";
gen_register_op(oss, cc_name, fn->args(), outputs);
return {oss.str(), name};
}