Use ncpu=1 (a TensorFlow session limited to one intra-op and one inter-op thread) for MuJoCo sessions; this gives a small speedup.

This commit is contained in:
Peter Zhokhov
2018-08-02 10:24:21 -07:00
parent c2df27bee4
commit f6d5a47bed
3 changed files with 22 additions and 22 deletions

View File

@@ -41,15 +41,15 @@ class TfRunningMeanStd(object):
def __init__(self, epsilon=1e-4, shape=(), scope=''):
sess = get_session()
_batch_mean = tf.placeholder(shape=shape, dtype=tf.float64)
_batch_var = tf.placeholder(shape=shape, dtype=tf.float64)
_batch_count = tf.placeholder(shape=(), dtype=tf.float64)
_batch_mean = tf.placeholder(shape=shape, dtype=tf.float32)
_batch_var = tf.placeholder(shape=shape, dtype=tf.float32)
_batch_count = tf.placeholder(shape=(), dtype=tf.float32)
with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
_mean = tf.get_variable('mean', initializer=np.zeros(shape, 'float64'), dtype=tf.float64)
_var = tf.get_variable('std', initializer=np.ones(shape, 'float64'), dtype=tf.float64)
_count = tf.get_variable('count', initializer=np.ones((), 'float64')*epsilon, dtype=tf.float64)
_mean = tf.get_variable('mean', initializer=np.zeros(shape, 'float32'), dtype=tf.float32)
_var = tf.get_variable('std', initializer=np.ones(shape, 'float32'), dtype=tf.float32)
_count = tf.get_variable('count', initializer=np.full((), epsilon, 'float32'), dtype=tf.float32)
delta = _batch_mean - _mean
tot_count = _count + _batch_count
@@ -78,35 +78,29 @@ class TfRunningMeanStd(object):
def update_from_moments(batch_mean, batch_var, batch_count):
for op in update_ops:
sess.run(update_ops, feed_dict={
sess.run(op, feed_dict={
_batch_mean: batch_mean,
_batch_var: batch_var,
_batch_count: batch_count
})
sess.run(tf.variables_initializer([_mean, _var, _count]))
self.sess = sess
self.update_from_moments = update_from_moments
self._set_mean_var_count()
@property
def mean(self):
return self.sess.run(self._mean)
@property
def var(self):
return self.sess.run(self._var)
@property
def count(self):
return self.sess.run(self._count)
def _set_mean_var_count(self):
self.mean, self.var, self.count = self.sess.run([self._mean, self._var, self._count])
def update(self, x):
batch_mean = np.mean(x, axis=0)
batch_var = np.var(x, axis=0)
batch_count = x.shape[0]
self.update_from_moments(batch_mean, batch_var, batch_count)
self._set_mean_var_count()

View File

@@ -10,8 +10,8 @@ class VecNormalize(VecEnvWrapper):
VecEnvWrapper.__init__(self, venv)
self.ob_rms = RunningMeanStd(shape=self.observation_space.shape) if ob else None
self.ret_rms = RunningMeanStd(shape=()) if ret else None
#self.ob_rms = TfRunningMeanStd(shape=self.observation_space.shape, scope='observation_running_mean_std') if ob else None
#self.ret_rms = TfRunningMeanStd(shape=(), scope='return_running_mean_std') if ret else None
self.ob_rms = TfRunningMeanStd(shape=self.observation_space.shape, scope='observation_running_mean_std') if ob else None
self.ret_rms = TfRunningMeanStd(shape=(), scope='return_running_mean_std') if ret else None
self.clipob = clipob
self.cliprew = cliprew
self.ret = np.zeros(self.num_envs)

View File

@@ -4,10 +4,11 @@ import os
import os.path as osp
import gym
from collections import defaultdict
import tensorflow as tf
from baselines.common.vec_env.vec_frame_stack import VecFrameStack
from baselines.common.cmd_util import common_arg_parser, parse_unknown_args, make_mujoco_env, make_atari_env
from baselines.common.tf_util import save_state, load_state
from baselines.common.tf_util import save_state, load_state, get_session
from baselines import bench, logger
from importlib import import_module
@@ -84,6 +85,10 @@ def build_env(args, render=False):
env_type, env_id = get_env_type(args.env)
if env_type == 'mujoco':
get_session(tf.ConfigProto(allow_soft_placement=True,
intra_op_parallelism_threads=1,
inter_op_parallelism_threads=1))
if args.num_env:
env = SubprocVecEnv([lambda: make_mujoco_env(env_id, seed + i if seed is not None else None, args.reward_scale) for i in range(args.num_env)])
else:
@@ -193,6 +198,7 @@ def main():
args, unknown_args = arg_parser.parse_known_args()
extra_args = {k: parse(v) for k,v in parse_unknown_args(unknown_args).items()}
if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
rank = 0
logger.configure()