serialize variables as a dict, not as a list
This commit is contained in:
@@ -330,8 +330,9 @@ def save_variables(save_path, variables, sess=None):
|
||||
sess = get_session()
|
||||
|
||||
ps = sess.run(variables)
|
||||
save_dict = {v.name: value for v, value in zip(variables, ps)}
|
||||
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
||||
joblib.dump(ps, save_path)
|
||||
joblib.dump(save_dict, save_path)
|
||||
|
||||
def load_variables(load_path, variables, sess=None):
|
||||
if sess is None:
|
||||
@@ -339,8 +340,8 @@ def load_variables(load_path, variables, sess=None):
|
||||
|
||||
loaded_params = joblib.load(load_path)
|
||||
restores = []
|
||||
for p, loaded_p in zip(variables, loaded_params):
|
||||
restores.append(p.assign(loaded_p))
|
||||
for v in variables:
|
||||
restores.append(v.assign(loaded_params[v.name]))
|
||||
sess.run(restores)
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user