make load_variables compatible with old list format (#71)
* make load_variables compatible with old list format * cosmetic fixes
This commit is contained in:
@@ -312,11 +312,15 @@ def get_available_gpus():
|
|||||||
# ================================================================
|
# ================================================================
|
||||||
|
|
||||||
def load_state(fname, sess=None):
|
def load_state(fname, sess=None):
|
||||||
|
from baselines import logger
|
||||||
|
logger.warn('load_state method is deprecated, please use load_variables instead')
|
||||||
sess = sess or get_session()
|
sess = sess or get_session()
|
||||||
saver = tf.train.Saver()
|
saver = tf.train.Saver()
|
||||||
saver.restore(tf.get_default_session(), fname)
|
saver.restore(tf.get_default_session(), fname)
|
||||||
|
|
||||||
def save_state(fname, sess=None):
|
def save_state(fname, sess=None):
|
||||||
|
from baselines import logger
|
||||||
|
logger.warn('save_state method is deprecated, please use save_variables instead')
|
||||||
sess = sess or get_session()
|
sess = sess or get_session()
|
||||||
os.makedirs(os.path.dirname(fname), exist_ok=True)
|
os.makedirs(os.path.dirname(fname), exist_ok=True)
|
||||||
saver = tf.train.Saver()
|
saver = tf.train.Saver()
|
||||||
@@ -339,11 +343,16 @@ def load_variables(load_path, variables=None, sess=None):
|
|||||||
variables = variables or tf.trainable_variables()
|
variables = variables or tf.trainable_variables()
|
||||||
|
|
||||||
loaded_params = joblib.load(os.path.expanduser(load_path))
|
loaded_params = joblib.load(os.path.expanduser(load_path))
|
||||||
restores = []
|
restores = []
|
||||||
for v in variables:
|
if isinstance(loaded_params, list):
|
||||||
restores.append(v.assign(loaded_params[v.name]))
|
assert len(loaded_params) == len(variables), 'number of variables loaded mismatches len(variables)'
|
||||||
sess.run(restores)
|
for d, v in zip(loaded_params, variables):
|
||||||
|
restores.append(v.assign(d))
|
||||||
|
else:
|
||||||
|
for v in variables:
|
||||||
|
restores.append(v.assign(loaded_params[v.name]))
|
||||||
|
|
||||||
|
sess.run(restores)
|
||||||
|
|
||||||
# ================================================================
|
# ================================================================
|
||||||
# Shape adjustment for feeding into tf placeholders
|
# Shape adjustment for feeding into tf placeholders
|
||||||
|
Reference in New Issue
Block a user