serialize variables as a dict, not as a list
This commit is contained in:
@@ -330,8 +330,9 @@ def save_variables(save_path, variables, sess=None):
|
|||||||
sess = get_session()
|
sess = get_session()
|
||||||
|
|
||||||
ps = sess.run(variables)
|
ps = sess.run(variables)
|
||||||
|
save_dict = {v.name: value for v, value in zip(variables, ps)}
|
||||||
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
||||||
joblib.dump(ps, save_path)
|
joblib.dump(save_dict, save_path)
|
||||||
|
|
||||||
def load_variables(load_path, variables, sess=None):
|
def load_variables(load_path, variables, sess=None):
|
||||||
if sess is None:
|
if sess is None:
|
||||||
@@ -339,8 +340,8 @@ def load_variables(load_path, variables, sess=None):
|
|||||||
|
|
||||||
loaded_params = joblib.load(load_path)
|
loaded_params = joblib.load(load_path)
|
||||||
restores = []
|
restores = []
|
||||||
for p, loaded_p in zip(variables, loaded_params):
|
for v in variables:
|
||||||
restores.append(p.assign(loaded_p))
|
restores.append(v.assign(loaded_params[v.name]))
|
||||||
sess.run(restores)
|
sess.run(restores)
|
||||||
|
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user