serialize variables as a dict, not as a list

This commit is contained in:
Peter Zhokhov
2018-07-31 11:13:31 -07:00
parent 9c48f9fad5
commit 2a93ea8782

View File

@@ -330,8 +330,9 @@ def save_variables(save_path, variables, sess=None):
sess = get_session()
ps = sess.run(variables)
save_dict = {v.name: value for v, value in zip(variables, ps)}
os.makedirs(os.path.dirname(save_path), exist_ok=True)
joblib.dump(ps, save_path)
joblib.dump(save_dict, save_path)
def load_variables(load_path, variables, sess=None):
if sess is None:
@@ -339,8 +340,8 @@ def load_variables(load_path, variables, sess=None):
loaded_params = joblib.load(load_path)
restores = []
for p, loaded_p in zip(variables, loaded_params):
restores.append(p.assign(loaded_p))
for v in variables:
restores.append(v.assign(loaded_params[v.name]))
sess.run(restores)