testing some register gradient
This commit is contained in:
@@ -24,21 +24,19 @@ namespace rt = triton::runtime;
|
||||
|
||||
/* TF triton op properties */
|
||||
|
||||
std::map<size_t, rt::function::grid_fn_ty> id_grid_map;
|
||||
std::map<size_t, rt::function*> id_fn_map;
|
||||
std::map<size_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
|
||||
std::map<size_t, std::shared_ptr<rt::function>> id_fn_map;
|
||||
std::map<size_t, int64_t> i64scalar_map;
|
||||
|
||||
void register_grid(size_t id,
|
||||
const rt::function::grid_fn_ty& grid_fn) {
|
||||
id_grid_map[id] = grid_fn;
|
||||
id_grid_map[id].reset(new rt::function::grid_fn_ty(grid_fn));
|
||||
}
|
||||
|
||||
void register_fn(size_t id,
|
||||
const std::string& src,
|
||||
const rt::function::options_space_t& opt) {
|
||||
bool is_inserted = id_fn_map.insert({id, new rt::function(src, opt)}).second;
|
||||
if(!is_inserted)
|
||||
assert(false);
|
||||
id_fn_map[id].reset(new rt::function(src, opt));
|
||||
}
|
||||
|
||||
size_t make_op_id() {
|
||||
@@ -135,7 +133,7 @@ void gen_make_launch_function(std::ostream &os, const std::vector<ir::argument*>
|
||||
os << ", ";
|
||||
os << name;
|
||||
}
|
||||
os << "}, id_grid_map.at(id_), stream); \n";
|
||||
os << "}, *id_grid_map.at(id_), stream); \n";
|
||||
}
|
||||
|
||||
void gen_register_kernel_builder(std::ostream &os, const std::string &name,
|
||||
@@ -214,7 +212,7 @@ std::tuple<std::string,
|
||||
parser.Parse();
|
||||
// triton-ir code-gen
|
||||
ir::context ctx;
|
||||
auto ir = std::unique_ptr<ir::module>(new ir::module("", ctx));
|
||||
auto ir = std::shared_ptr<ir::module>(new ir::module("", ctx));
|
||||
Generator gen(&parser);
|
||||
gen.Gen(&*ir);
|
||||
// function
|
||||
@@ -245,8 +243,8 @@ using GPUDevice = Eigen::GpuDevice;
|
||||
namespace rt = triton::runtime;
|
||||
namespace drv = triton::driver;
|
||||
|
||||
extern std::map<size_t, rt::function::grid_fn_ty> id_grid_map;
|
||||
extern std::map<size_t, rt::function*> id_fn_map;
|
||||
extern std::map<size_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
|
||||
extern std::map<size_t, std::shared_ptr<rt::function>> id_fn_map;
|
||||
|
||||
|
||||
class )" << opname << R"(: public OpKernel {
|
||||
@@ -294,6 +292,7 @@ oss << R"(
|
||||
)";
|
||||
gen_register_op(oss, cc_name, fn->args(), outputs);
|
||||
|
||||
|
||||
return {oss.str(), name};
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user