diff --git a/baselines/common/running_mean_std.py b/baselines/common/running_mean_std.py index 28d0ba1..9573cb0 100644 --- a/baselines/common/running_mean_std.py +++ b/baselines/common/running_mean_std.py @@ -41,15 +41,15 @@ class TfRunningMeanStd(object): def __init__(self, epsilon=1e-4, shape=(), scope=''): sess = get_session() - _batch_mean = tf.placeholder(shape=shape, dtype=tf.float64) - _batch_var = tf.placeholder(shape=shape, dtype=tf.float64) - _batch_count = tf.placeholder(shape=(), dtype=tf.float64) + _batch_mean = tf.placeholder(shape=shape, dtype=tf.float32) + _batch_var = tf.placeholder(shape=shape, dtype=tf.float32) + _batch_count = tf.placeholder(shape=(), dtype=tf.float32) with tf.variable_scope(scope, reuse=tf.AUTO_REUSE): - _mean = tf.get_variable('mean', initializer=np.zeros(shape, 'float64'), dtype=tf.float64) - _var = tf.get_variable('std', initializer=np.ones(shape, 'float64'), dtype=tf.float64) - _count = tf.get_variable('count', initializer=np.ones((), 'float64')*epsilon, dtype=tf.float64) + _mean = tf.get_variable('mean', initializer=np.zeros(shape, 'float32'), dtype=tf.float32) + _var = tf.get_variable('std', initializer=np.ones(shape, 'float32'), dtype=tf.float32) + _count = tf.get_variable('count', initializer=np.full((), epsilon, 'float32'), dtype=tf.float32) delta = _batch_mean - _mean tot_count = _count + _batch_count @@ -78,35 +78,29 @@ class TfRunningMeanStd(object): def update_from_moments(batch_mean, batch_var, batch_count): for op in update_ops: - sess.run(update_ops, feed_dict={ + sess.run(op, feed_dict={ _batch_mean: batch_mean, _batch_var: batch_var, _batch_count: batch_count }) + sess.run(tf.variables_initializer([_mean, _var, _count])) self.sess = sess self.update_from_moments = update_from_moments + self._set_mean_var_count() - @property - def mean(self): - return self.sess.run(self._mean) - - @property - def var(self): - return self.sess.run(self._var) - - @property - def count(self): - return self.sess.run(self._count) - + def _set_mean_var_count(self): + self.mean, self.var, self.count = self.sess.run([self._mean, self._var, self._count]) def update(self, x): batch_mean = np.mean(x, axis=0) batch_var = np.var(x, axis=0) batch_count = x.shape[0] self.update_from_moments(batch_mean, batch_var, batch_count) + self._set_mean_var_count() + diff --git a/baselines/common/vec_env/vec_normalize.py b/baselines/common/vec_env/vec_normalize.py index 37ee02f..2f2cc8d 100644 --- a/baselines/common/vec_env/vec_normalize.py +++ b/baselines/common/vec_env/vec_normalize.py @@ -10,8 +10,8 @@ class VecNormalize(VecEnvWrapper): VecEnvWrapper.__init__(self, venv) self.ob_rms = RunningMeanStd(shape=self.observation_space.shape) if ob else None self.ret_rms = RunningMeanStd(shape=()) if ret else None - #self.ob_rms = TfRunningMeanStd(shape=self.observation_space.shape, scope='observation_running_mean_std') if ob else None - #self.ret_rms = TfRunningMeanStd(shape=(), scope='return_running_mean_std') if ret else None + self.ob_rms = TfRunningMeanStd(shape=self.observation_space.shape, scope='observation_running_mean_std') if ob else None + self.ret_rms = TfRunningMeanStd(shape=(), scope='return_running_mean_std') if ret else None self.clipob = clipob self.cliprew = cliprew self.ret = np.zeros(self.num_envs) diff --git a/baselines/run.py b/baselines/run.py index 344f035..cba8515 100644 --- a/baselines/run.py +++ b/baselines/run.py @@ -4,10 +4,11 @@ import os import os.path as osp import gym from collections import defaultdict +import tensorflow as tf from baselines.common.vec_env.vec_frame_stack import VecFrameStack from baselines.common.cmd_util import common_arg_parser, parse_unknown_args, make_mujoco_env, make_atari_env -from baselines.common.tf_util import save_state, load_state +from baselines.common.tf_util import save_state, load_state, get_session from baselines import bench, logger from importlib import import_module @@ -84,6 +85,10 @@ def build_env(args, render=False): env_type, env_id = get_env_type(args.env) if env_type == 'mujoco': + get_session(tf.ConfigProto(allow_soft_placement=True, + intra_op_parallelism_threads=1, + inter_op_parallelism_threads=1)) + if args.num_env: env = SubprocVecEnv([lambda: make_mujoco_env(env_id, seed + i if seed is not None else None, args.reward_scale) for i in range(args.num_env)]) else: @@ -193,6 +198,7 @@ def main(): args, unknown_args = arg_parser.parse_known_args() extra_args = {k: parse(v) for k,v in parse_unknown_args(unknown_args).items()} + if MPI is None or MPI.COMM_WORLD.Get_rank() == 0: rank = 0 logger.configure()