use ncpu=1 for mujoco sessions - gives a bit of a performance speedup
This commit is contained in:
@@ -41,15 +41,15 @@ class TfRunningMeanStd(object):
|
||||
def __init__(self, epsilon=1e-4, shape=(), scope=''):
|
||||
sess = get_session()
|
||||
|
||||
_batch_mean = tf.placeholder(shape=shape, dtype=tf.float64)
|
||||
_batch_var = tf.placeholder(shape=shape, dtype=tf.float64)
|
||||
_batch_count = tf.placeholder(shape=(), dtype=tf.float64)
|
||||
_batch_mean = tf.placeholder(shape=shape, dtype=tf.float32)
|
||||
_batch_var = tf.placeholder(shape=shape, dtype=tf.float32)
|
||||
_batch_count = tf.placeholder(shape=(), dtype=tf.float32)
|
||||
|
||||
|
||||
with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
|
||||
_mean = tf.get_variable('mean', initializer=np.zeros(shape, 'float64'), dtype=tf.float64)
|
||||
_var = tf.get_variable('std', initializer=np.ones(shape, 'float64'), dtype=tf.float64)
|
||||
_count = tf.get_variable('count', initializer=np.ones((), 'float64')*epsilon, dtype=tf.float64)
|
||||
_mean = tf.get_variable('mean', initializer=np.zeros(shape, 'float32'), dtype=tf.float32)
|
||||
_var = tf.get_variable('std', initializer=np.ones(shape, 'float32'), dtype=tf.float32)
|
||||
_count = tf.get_variable('count', initializer=np.full((), epsilon, 'float32'), dtype=tf.float32)
|
||||
|
||||
delta = _batch_mean - _mean
|
||||
tot_count = _count + _batch_count
|
||||
@@ -78,35 +78,29 @@ class TfRunningMeanStd(object):
|
||||
|
||||
def update_from_moments(batch_mean, batch_var, batch_count):
|
||||
for op in update_ops:
|
||||
sess.run(update_ops, feed_dict={
|
||||
sess.run(op, feed_dict={
|
||||
_batch_mean: batch_mean,
|
||||
_batch_var: batch_var,
|
||||
_batch_count: batch_count
|
||||
})
|
||||
|
||||
|
||||
|
||||
sess.run(tf.variables_initializer([_mean, _var, _count]))
|
||||
self.sess = sess
|
||||
self.update_from_moments = update_from_moments
|
||||
self._set_mean_var_count()
|
||||
|
||||
@property
|
||||
def mean(self):
|
||||
return self.sess.run(self._mean)
|
||||
|
||||
@property
|
||||
def var(self):
|
||||
return self.sess.run(self._var)
|
||||
|
||||
@property
|
||||
def count(self):
|
||||
return self.sess.run(self._count)
|
||||
|
||||
def _set_mean_var_count(self):
|
||||
self.mean, self.var, self.count = self.sess.run([self._mean, self._var, self._count])
|
||||
|
||||
def update(self, x):
|
||||
batch_mean = np.mean(x, axis=0)
|
||||
batch_var = np.var(x, axis=0)
|
||||
batch_count = x.shape[0]
|
||||
self.update_from_moments(batch_mean, batch_var, batch_count)
|
||||
self._set_mean_var_count()
|
||||
|
||||
|
||||
|
||||
|
||||
|
@@ -10,8 +10,8 @@ class VecNormalize(VecEnvWrapper):
|
||||
VecEnvWrapper.__init__(self, venv)
|
||||
self.ob_rms = RunningMeanStd(shape=self.observation_space.shape) if ob else None
|
||||
self.ret_rms = RunningMeanStd(shape=()) if ret else None
|
||||
#self.ob_rms = TfRunningMeanStd(shape=self.observation_space.shape, scope='observation_running_mean_std') if ob else None
|
||||
#self.ret_rms = TfRunningMeanStd(shape=(), scope='return_running_mean_std') if ret else None
|
||||
self.ob_rms = TfRunningMeanStd(shape=self.observation_space.shape, scope='observation_running_mean_std') if ob else None
|
||||
self.ret_rms = TfRunningMeanStd(shape=(), scope='return_running_mean_std') if ret else None
|
||||
self.clipob = clipob
|
||||
self.cliprew = cliprew
|
||||
self.ret = np.zeros(self.num_envs)
|
||||
|
@@ -4,10 +4,11 @@ import os
|
||||
import os.path as osp
|
||||
import gym
|
||||
from collections import defaultdict
|
||||
import tensorflow as tf
|
||||
|
||||
from baselines.common.vec_env.vec_frame_stack import VecFrameStack
|
||||
from baselines.common.cmd_util import common_arg_parser, parse_unknown_args, make_mujoco_env, make_atari_env
|
||||
from baselines.common.tf_util import save_state, load_state
|
||||
from baselines.common.tf_util import save_state, load_state, get_session
|
||||
from baselines import bench, logger
|
||||
from importlib import import_module
|
||||
|
||||
@@ -84,6 +85,10 @@ def build_env(args, render=False):
|
||||
|
||||
env_type, env_id = get_env_type(args.env)
|
||||
if env_type == 'mujoco':
|
||||
get_session(tf.ConfigProto(allow_soft_placement=True,
|
||||
intra_op_parallelism_threads=1,
|
||||
inter_op_parallelism_threads=1))
|
||||
|
||||
if args.num_env:
|
||||
env = SubprocVecEnv([lambda: make_mujoco_env(env_id, seed + i if seed is not None else None, args.reward_scale) for i in range(args.num_env)])
|
||||
else:
|
||||
@@ -193,6 +198,7 @@ def main():
|
||||
args, unknown_args = arg_parser.parse_known_args()
|
||||
extra_args = {k: parse(v) for k,v in parse_unknown_args(unknown_args).items()}
|
||||
|
||||
|
||||
if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
|
||||
rank = 0
|
||||
logger.configure()
|
||||
|
Reference in New Issue
Block a user