diff --git a/baselines/common/running_mean_std.py b/baselines/common/running_mean_std.py index b58935a..17515b2 100644 --- a/baselines/common/running_mean_std.py +++ b/baselines/common/running_mean_std.py @@ -16,21 +16,22 @@ class RunningMeanStd(object): self.update_from_moments(batch_mean, batch_var, batch_count) def update_from_moments(self, batch_mean, batch_var, batch_count): - delta = batch_mean - self.mean - tot_count = self.count + batch_count + self.mean, self.var, self.count = update_mean_var_count_from_moments( + self.mean, self.var, self.count, batch_mean, batch_var, batch_count) - new_mean = self.mean + delta * batch_count / tot_count - m_a = self.var * (self.count) - m_b = batch_var * (batch_count) - M2 = m_a + m_b + np.square(delta) * self.count * batch_count / (self.count + batch_count) - new_var = M2 / (self.count + batch_count) - - new_count = batch_count + self.count - - self.mean = new_mean - self.var = new_var - self.count = new_count +def update_mean_var_count_from_moments(mean, var, count, batch_mean, batch_var, batch_count): + delta = batch_mean - mean + tot_count = count + batch_count + new_mean = mean + delta * batch_count / tot_count + m_a = var * count + m_b = batch_var * batch_count + M2 = m_a + m_b + np.square(delta) * count * batch_count / (count + batch_count) + new_var = M2 / (count + batch_count) + new_count = batch_count + count + + return new_mean, new_var, new_count + class TfRunningMeanStd(object): # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm @@ -41,76 +42,45 @@ class TfRunningMeanStd(object): def __init__(self, epsilon=1e-4, shape=(), scope=''): sess = get_session() - _batch_mean = tf.placeholder(shape=shape, dtype=tf.float64) - _batch_var = tf.placeholder(shape=shape, dtype=tf.float64) - _batch_count = tf.placeholder(shape=(), dtype=tf.float64) + self._new_mean = tf.placeholder(shape=shape, dtype=tf.float64) + self._new_var = tf.placeholder(shape=shape, dtype=tf.float64) + self._new_count = tf.placeholder(shape=(), dtype=tf.float64) with tf.variable_scope(scope, reuse=tf.AUTO_REUSE): - _mean = tf.get_variable('mean', initializer=np.zeros(shape, 'float64'), dtype=tf.float64) - _var = tf.get_variable('std', initializer=np.ones(shape, 'float64'), dtype=tf.float64) - _count = tf.get_variable('count', initializer=np.full((), epsilon, 'float64'), dtype=tf.float64) + self._mean = tf.get_variable('mean', initializer=np.zeros(shape, 'float64'), dtype=tf.float64) + self._var = tf.get_variable('std', initializer=np.ones(shape, 'float64'), dtype=tf.float64) + self._count = tf.get_variable('count', initializer=np.full((), epsilon, 'float64'), dtype=tf.float64) - delta = _batch_mean - _mean - tot_count = _count + _batch_count + self.update_ops = tf.group([ + self._var.assign(self._new_var), + self._mean.assign(self._new_mean), + self._count.assign(self._new_count) + ]) - new_mean = _mean + delta * _batch_count / tot_count - m_a = _var * (_count) - m_b = _batch_var * (_batch_count) - M2 = m_a + m_b + np.square(delta) * _count * _batch_count / (_count + _batch_count) - new_var = M2 / (_count + _batch_count) - new_count = _batch_count + _count - - update_ops = [ - _var.assign(new_var), - _mean.assign(new_mean), - _count.assign(new_count) - ] - - self._mean = _mean - self._var = _var - self._count = _count - - self._batch_mean = _batch_mean - self._batch_var = _batch_var - self._batch_count = _batch_count - - - def update_from_moments(batch_mean, batch_var, batch_count): - for op in update_ops: - sess.run(op, feed_dict={ - _batch_mean: batch_mean, - _batch_var: batch_var, - _batch_count: batch_count - }) - - - - sess.run(tf.variables_initializer([_mean, _var, _count])) + sess.run(tf.variables_initializer([self._mean, self._var, self._count])) self.sess = sess - self.update_from_moments = update_from_moments - - @property - def mean(self): - return self.sess.run(self._mean) - - @property - def var(self): - return self.sess.run(self._var) - - @property - def count(self): - return self.sess.run(self._count) - + self._set_mean_var_count() + + def _set_mean_var_count(self): + self.mean, self.var, self.count = self.sess.run([self._mean, self._var, self._count]) def update(self, x): batch_mean = np.mean(x, axis=0) batch_var = np.var(x, axis=0) batch_count = x.shape[0] - self.update_from_moments(batch_mean, batch_var, batch_count) - - + new_mean, new_var, new_count = update_mean_var_count_from_moments(self.mean, self.var, self.count, batch_mean, batch_var, batch_count) + + self.sess.run(self.update_ops, feed_dict={ + self._new_mean: new_mean, + self._new_var: new_var, + self._new_count: new_count + }) + + self._set_mean_var_count() + + def test_runningmeanstd(): for (x1, x2, x3) in [ @@ -145,3 +115,71 @@ def test_tf_runningmeanstd(): ms2 = [rms.mean, rms.var] np.testing.assert_allclose(ms1, ms2) + + +def profile_tf_runningmeanstd(): + import time + from baselines.common import tf_util + + tf_util.get_session( config=tf.ConfigProto( + inter_op_parallelism_threads=1, + intra_op_parallelism_threads=1, + allow_soft_placement=True + )) + + x = np.random.random((376,)) + + n_trials = 10000 + rms = RunningMeanStd() + tfrms = TfRunningMeanStd() + + tic1 = time.time() + for _ in range(n_trials): + rms.update(x) + + tic2 = time.time() + for _ in range(n_trials): + tfrms.update(x) + + tic3 = time.time() + + print('rms update time ({} trials): {} s'.format(n_trials, tic2 - tic1)) + print('tfrms update time ({} trials): {} s'.format(n_trials, tic3 - tic2)) + + + tic1 = time.time() + for _ in range(n_trials): + z1 = rms.mean + + tic2 = time.time() + for _ in range(n_trials): + z2 = tfrms.mean + + tic3 = time.time() + + print('rms get mean time ({} trials): {} s'.format(n_trials, tic2 - tic1)) + print('tfrms get mean time ({} trials): {} s'.format(n_trials, tic3 - tic2)) + + + + ''' + options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) #pylint: disable=E1101 + run_metadata = tf.RunMetadata() + profile_opts = dict(options=options, run_metadata=run_metadata) + + + + from tensorflow.python.client import timeline + fetched_timeline = timeline.Timeline(run_metadata.step_stats) #pylint: disable=E1101 + chrome_trace = fetched_timeline.generate_chrome_trace_format() + outfile = '/tmp/timeline.json' + with open(outfile, 'wt') as f: + f.write(chrome_trace) + print(f'Successfully saved profile to {outfile}. Exiting.') + exit(0) + ''' + + + +if __name__ == '__main__': + profile_tf_runningmeanstd()