save all variables to make sure we save the vec_normalize normalization
This commit is contained in:
@@ -74,8 +74,8 @@ class Model(object):
|
||||
self.step = step_model.step
|
||||
self.value = step_model.value
|
||||
self.initial_state = step_model.initial_state
|
||||
self.save = functools.partial(tf_util.save_variables, sess=sess, variables=params)
|
||||
self.load = functools.partial(tf_util.load_variables, sess=sess, variables=params)
|
||||
self.save = functools.partial(tf_util.save_variables, sess=sess)
|
||||
self.load = functools.partial(tf_util.load_variables, sess=sess)
|
||||
tf.global_variables_initializer().run(session=sess)
|
||||
|
||||
|
||||
|
@@ -83,8 +83,8 @@ class Model(object):
|
||||
|
||||
|
||||
self.train = train
|
||||
self.save = functools.partial(save_variables, variables=params, sess=sess)
|
||||
self.load = functools.partial(load_variables, variables=params, sess=sess)
|
||||
self.save = functools.partial(save_variables, sess=sess)
|
||||
self.load = functools.partial(load_variables, sess=sess)
|
||||
self.train_model = train_model
|
||||
self.step_model = step_model
|
||||
self.step = step_model.step
|
||||
|
@@ -86,6 +86,7 @@ register_benchmark({
|
||||
'description': 'Some small 2D MuJoCo tasks, run for 1M timesteps',
|
||||
'tasks': [{'env_id': _envid, 'trials': 6, 'num_timesteps': int(1e6)} for _envid in _mujocosmall]
|
||||
})
|
||||
|
||||
register_benchmark({
|
||||
'name': 'MujocoWalkers',
|
||||
'description': 'MuJoCo forward walkers, run for 8M, humanoid 100M',
|
||||
|
@@ -46,10 +46,10 @@ class TfRunningMeanStd(object):
|
||||
_batch_count = tf.placeholder(shape=(), dtype=tf.float64)
|
||||
|
||||
|
||||
with tf.variable_scope(scope, reuse=False):
|
||||
_mean = tf.get_variable('mean', initializer=np.zeros(shape))
|
||||
_var = tf.get_variable('std', initializer=np.ones(shape))
|
||||
_count = tf.get_variable('count', initializer=np.ones(shape=())*epsilon)
|
||||
with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
|
||||
_mean = tf.get_variable('mean', initializer=np.zeros(shape, 'float64'), dtype=tf.float64)
|
||||
_var = tf.get_variable('std', initializer=np.ones(shape, 'float64'), dtype=tf.float64)
|
||||
_count = tf.get_variable('count', initializer=np.ones((), 'float64')*epsilon, dtype=tf.float64)
|
||||
|
||||
delta = _batch_mean - _mean
|
||||
tot_count = _count + _batch_count
|
||||
|
@@ -325,20 +325,20 @@ def save_state(fname, sess=None):
|
||||
# The methods above and below are clearly doing the same thing, and in a rather similar way
|
||||
# TODO: ensure there is no subtle differences and remove one
|
||||
|
||||
def save_variables(save_path, variables, sess=None):
|
||||
if sess is None:
|
||||
sess = get_session()
|
||||
def save_variables(save_path, variables=None, sess=None):
|
||||
sess = sess or get_session()
|
||||
variables = variables or tf.trainable_variables()
|
||||
|
||||
ps = sess.run(variables)
|
||||
save_dict = {v.name: value for v, value in zip(variables, ps)}
|
||||
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
||||
joblib.dump(save_dict, save_path)
|
||||
|
||||
def load_variables(load_path, variables, sess=None):
|
||||
if sess is None:
|
||||
sess = get_session()
|
||||
def load_variables(load_path, variables=None, sess=None):
|
||||
sess = sess or get_session()
|
||||
variables = variables or tf.trainable_variables()
|
||||
|
||||
loaded_params = joblib.load(load_path)
|
||||
loaded_params = joblib.load(os.path.expanduser(load_path))
|
||||
restores = []
|
||||
for v in variables:
|
||||
restores.append(v.assign(loaded_params[v.name]))
|
||||
|
@@ -81,9 +81,8 @@ class Model(object):
|
||||
self.value = act_model.value
|
||||
self.initial_state = act_model.initial_state
|
||||
|
||||
# If you want to load weights, also save/load observation scaling inside VecNormalize ?
|
||||
self.save = functools.partial(save_variables, sess=sess, variables=params)
|
||||
self.load = functools.partial(load_variables, sess=sess, variables=params)
|
||||
self.save = functools.partial(save_variables, sess=sess)
|
||||
self.load = functools.partial(load_variables, sess=sess)
|
||||
|
||||
if MPI.COMM_WORLD.Get_rank() == 0:
|
||||
initialize()
|
||||
|
Reference in New Issue
Block a user