make load_variables compatible with old list format (#71)

* make load_variables compatible with old list format

* cosmetic fixes
This commit is contained in:
pzhokhov
2018-08-30 15:32:55 -07:00
committed by Peter Zhokhov
parent 7bccb2969f
commit 51cefc933b

View File

@@ -312,11 +312,15 @@ def get_available_gpus():
# ================================================================
def load_state(fname, sess=None):
from baselines import logger
logger.warn('load_state method is deprecated, please use load_variables instead')
sess = sess or get_session()
saver = tf.train.Saver()
saver.restore(tf.get_default_session(), fname)
def save_state(fname, sess=None):
from baselines import logger
logger.warn('save_state method is deprecated, please use save_variables instead')
sess = sess or get_session()
os.makedirs(os.path.dirname(fname), exist_ok=True)
saver = tf.train.Saver()
@@ -339,11 +343,16 @@ def load_variables(load_path, variables=None, sess=None):
variables = variables or tf.trainable_variables()
loaded_params = joblib.load(os.path.expanduser(load_path))
restores = []
for v in variables:
restores.append(v.assign(loaded_params[v.name]))
sess.run(restores)
restores = []
if isinstance(loaded_params, list):
assert len(loaded_params) == len(variables), 'number of variables loaded mismatches len(variables)'
for d, v in zip(loaded_params, variables):
restores.append(v.assign(d))
else:
for v in variables:
restores.append(v.assign(loaded_params[v.name]))
sess.run(restores)
# ================================================================
# Shape adjustment for feeding into tf placeholders