Use ncpu=1 for MuJoCo sessions (TensorFlow intra-op and inter-op parallelism threads set to 1) - gives a bit of a performance speedup
This commit is contained in:
@@ -41,15 +41,15 @@ class TfRunningMeanStd(object):
|
|||||||
def __init__(self, epsilon=1e-4, shape=(), scope=''):
|
def __init__(self, epsilon=1e-4, shape=(), scope=''):
|
||||||
sess = get_session()
|
sess = get_session()
|
||||||
|
|
||||||
_batch_mean = tf.placeholder(shape=shape, dtype=tf.float64)
|
_batch_mean = tf.placeholder(shape=shape, dtype=tf.float32)
|
||||||
_batch_var = tf.placeholder(shape=shape, dtype=tf.float64)
|
_batch_var = tf.placeholder(shape=shape, dtype=tf.float32)
|
||||||
_batch_count = tf.placeholder(shape=(), dtype=tf.float64)
|
_batch_count = tf.placeholder(shape=(), dtype=tf.float32)
|
||||||
|
|
||||||
|
|
||||||
with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
|
with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
|
||||||
_mean = tf.get_variable('mean', initializer=np.zeros(shape, 'float64'), dtype=tf.float64)
|
_mean = tf.get_variable('mean', initializer=np.zeros(shape, 'float32'), dtype=tf.float32)
|
||||||
_var = tf.get_variable('std', initializer=np.ones(shape, 'float64'), dtype=tf.float64)
|
_var = tf.get_variable('std', initializer=np.ones(shape, 'float32'), dtype=tf.float32)
|
||||||
_count = tf.get_variable('count', initializer=np.ones((), 'float64')*epsilon, dtype=tf.float64)
|
_count = tf.get_variable('count', initializer=np.full((), epsilon, 'float32'), dtype=tf.float32)
|
||||||
|
|
||||||
delta = _batch_mean - _mean
|
delta = _batch_mean - _mean
|
||||||
tot_count = _count + _batch_count
|
tot_count = _count + _batch_count
|
||||||
@@ -78,35 +78,29 @@ class TfRunningMeanStd(object):
|
|||||||
|
|
||||||
def update_from_moments(batch_mean, batch_var, batch_count):
|
def update_from_moments(batch_mean, batch_var, batch_count):
|
||||||
for op in update_ops:
|
for op in update_ops:
|
||||||
sess.run(update_ops, feed_dict={
|
sess.run(op, feed_dict={
|
||||||
_batch_mean: batch_mean,
|
_batch_mean: batch_mean,
|
||||||
_batch_var: batch_var,
|
_batch_var: batch_var,
|
||||||
_batch_count: batch_count
|
_batch_count: batch_count
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
sess.run(tf.variables_initializer([_mean, _var, _count]))
|
sess.run(tf.variables_initializer([_mean, _var, _count]))
|
||||||
self.sess = sess
|
self.sess = sess
|
||||||
self.update_from_moments = update_from_moments
|
self.update_from_moments = update_from_moments
|
||||||
|
self._set_mean_var_count()
|
||||||
|
|
||||||
@property
|
def _set_mean_var_count(self):
|
||||||
def mean(self):
|
self.mean, self.var, self.count = self.sess.run([self._mean, self._var, self._count])
|
||||||
return self.sess.run(self._mean)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def var(self):
|
|
||||||
return self.sess.run(self._var)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def count(self):
|
|
||||||
return self.sess.run(self._count)
|
|
||||||
|
|
||||||
|
|
||||||
def update(self, x):
|
def update(self, x):
|
||||||
batch_mean = np.mean(x, axis=0)
|
batch_mean = np.mean(x, axis=0)
|
||||||
batch_var = np.var(x, axis=0)
|
batch_var = np.var(x, axis=0)
|
||||||
batch_count = x.shape[0]
|
batch_count = x.shape[0]
|
||||||
self.update_from_moments(batch_mean, batch_var, batch_count)
|
self.update_from_moments(batch_mean, batch_var, batch_count)
|
||||||
|
self._set_mean_var_count()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@@ -10,8 +10,8 @@ class VecNormalize(VecEnvWrapper):
|
|||||||
VecEnvWrapper.__init__(self, venv)
|
VecEnvWrapper.__init__(self, venv)
|
||||||
self.ob_rms = RunningMeanStd(shape=self.observation_space.shape) if ob else None
|
self.ob_rms = RunningMeanStd(shape=self.observation_space.shape) if ob else None
|
||||||
self.ret_rms = RunningMeanStd(shape=()) if ret else None
|
self.ret_rms = RunningMeanStd(shape=()) if ret else None
|
||||||
#self.ob_rms = TfRunningMeanStd(shape=self.observation_space.shape, scope='observation_running_mean_std') if ob else None
|
self.ob_rms = TfRunningMeanStd(shape=self.observation_space.shape, scope='observation_running_mean_std') if ob else None
|
||||||
#self.ret_rms = TfRunningMeanStd(shape=(), scope='return_running_mean_std') if ret else None
|
self.ret_rms = TfRunningMeanStd(shape=(), scope='return_running_mean_std') if ret else None
|
||||||
self.clipob = clipob
|
self.clipob = clipob
|
||||||
self.cliprew = cliprew
|
self.cliprew = cliprew
|
||||||
self.ret = np.zeros(self.num_envs)
|
self.ret = np.zeros(self.num_envs)
|
||||||
|
@@ -4,10 +4,11 @@ import os
|
|||||||
import os.path as osp
|
import os.path as osp
|
||||||
import gym
|
import gym
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
from baselines.common.vec_env.vec_frame_stack import VecFrameStack
|
from baselines.common.vec_env.vec_frame_stack import VecFrameStack
|
||||||
from baselines.common.cmd_util import common_arg_parser, parse_unknown_args, make_mujoco_env, make_atari_env
|
from baselines.common.cmd_util import common_arg_parser, parse_unknown_args, make_mujoco_env, make_atari_env
|
||||||
from baselines.common.tf_util import save_state, load_state
|
from baselines.common.tf_util import save_state, load_state, get_session
|
||||||
from baselines import bench, logger
|
from baselines import bench, logger
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
|
|
||||||
@@ -84,6 +85,10 @@ def build_env(args, render=False):
|
|||||||
|
|
||||||
env_type, env_id = get_env_type(args.env)
|
env_type, env_id = get_env_type(args.env)
|
||||||
if env_type == 'mujoco':
|
if env_type == 'mujoco':
|
||||||
|
get_session(tf.ConfigProto(allow_soft_placement=True,
|
||||||
|
intra_op_parallelism_threads=1,
|
||||||
|
inter_op_parallelism_threads=1))
|
||||||
|
|
||||||
if args.num_env:
|
if args.num_env:
|
||||||
env = SubprocVecEnv([lambda: make_mujoco_env(env_id, seed + i if seed is not None else None, args.reward_scale) for i in range(args.num_env)])
|
env = SubprocVecEnv([lambda: make_mujoco_env(env_id, seed + i if seed is not None else None, args.reward_scale) for i in range(args.num_env)])
|
||||||
else:
|
else:
|
||||||
@@ -193,6 +198,7 @@ def main():
|
|||||||
args, unknown_args = arg_parser.parse_known_args()
|
args, unknown_args = arg_parser.parse_known_args()
|
||||||
extra_args = {k: parse(v) for k,v in parse_unknown_args(unknown_args).items()}
|
extra_args = {k: parse(v) for k,v in parse_unknown_args(unknown_args).items()}
|
||||||
|
|
||||||
|
|
||||||
if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
|
if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
|
||||||
rank = 0
|
rank = 0
|
||||||
logger.configure()
|
logger.configure()
|
||||||
|
Reference in New Issue
Block a user