Merge branch 'peterz_profile_vec_normalize' into peterz_migrate_rlalgs

This commit is contained in:
Peter Zhokhov
2018-08-03 11:47:36 -07:00

View File

@@ -16,21 +16,22 @@ class RunningMeanStd(object):
self.update_from_moments(batch_mean, batch_var, batch_count)
def update_from_moments(self, batch_mean, batch_var, batch_count):
delta = batch_mean - self.mean
tot_count = self.count + batch_count
self.mean, self.var, self.count = update_mean_var_count_from_moments(
self.mean, self.var, self.count, batch_mean, batch_var, batch_count)
new_mean = self.mean + delta * batch_count / tot_count
m_a = self.var * (self.count)
m_b = batch_var * (batch_count)
M2 = m_a + m_b + np.square(delta) * self.count * batch_count / (self.count + batch_count)
new_var = M2 / (self.count + batch_count)
new_count = batch_count + self.count
self.mean = new_mean
self.var = new_var
self.count = new_count
def update_mean_var_count_from_moments(mean, var, count, batch_mean, batch_var, batch_count):
delta = batch_mean - mean
tot_count = count + batch_count
new_mean = mean + delta * batch_count / tot_count
m_a = var * count
m_b = batch_var * batch_count
M2 = m_a + m_b + np.square(delta) * count * batch_count / (count + batch_count)
new_var = M2 / (count + batch_count)
new_count = batch_count + count
return new_mean, new_var, new_count
class TfRunningMeanStd(object):
# https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
@@ -41,76 +42,45 @@ class TfRunningMeanStd(object):
def __init__(self, epsilon=1e-4, shape=(), scope=''):
sess = get_session()
_batch_mean = tf.placeholder(shape=shape, dtype=tf.float64)
_batch_var = tf.placeholder(shape=shape, dtype=tf.float64)
_batch_count = tf.placeholder(shape=(), dtype=tf.float64)
self._new_mean = tf.placeholder(shape=shape, dtype=tf.float64)
self._new_var = tf.placeholder(shape=shape, dtype=tf.float64)
self._new_count = tf.placeholder(shape=(), dtype=tf.float64)
with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
_mean = tf.get_variable('mean', initializer=np.zeros(shape, 'float64'), dtype=tf.float64)
_var = tf.get_variable('std', initializer=np.ones(shape, 'float64'), dtype=tf.float64)
_count = tf.get_variable('count', initializer=np.full((), epsilon, 'float64'), dtype=tf.float64)
self._mean = tf.get_variable('mean', initializer=np.zeros(shape, 'float64'), dtype=tf.float64)
self._var = tf.get_variable('std', initializer=np.ones(shape, 'float64'), dtype=tf.float64)
self._count = tf.get_variable('count', initializer=np.full((), epsilon, 'float64'), dtype=tf.float64)
delta = _batch_mean - _mean
tot_count = _count + _batch_count
self.update_ops = tf.group([
self._var.assign(self._new_var),
self._mean.assign(self._new_mean),
self._count.assign(self._new_count)
])
new_mean = _mean + delta * _batch_count / tot_count
m_a = _var * (_count)
m_b = _batch_var * (_batch_count)
M2 = m_a + m_b + np.square(delta) * _count * _batch_count / (_count + _batch_count)
new_var = M2 / (_count + _batch_count)
new_count = _batch_count + _count
update_ops = [
_var.assign(new_var),
_mean.assign(new_mean),
_count.assign(new_count)
]
self._mean = _mean
self._var = _var
self._count = _count
self._batch_mean = _batch_mean
self._batch_var = _batch_var
self._batch_count = _batch_count
def update_from_moments(batch_mean, batch_var, batch_count):
for op in update_ops:
sess.run(op, feed_dict={
_batch_mean: batch_mean,
_batch_var: batch_var,
_batch_count: batch_count
})
sess.run(tf.variables_initializer([_mean, _var, _count]))
sess.run(tf.variables_initializer([self._mean, self._var, self._count]))
self.sess = sess
self.update_from_moments = update_from_moments
@property
def mean(self):
return self.sess.run(self._mean)
@property
def var(self):
return self.sess.run(self._var)
@property
def count(self):
return self.sess.run(self._count)
self._set_mean_var_count()
def _set_mean_var_count(self):
self.mean, self.var, self.count = self.sess.run([self._mean, self._var, self._count])
def update(self, x):
batch_mean = np.mean(x, axis=0)
batch_var = np.var(x, axis=0)
batch_count = x.shape[0]
self.update_from_moments(batch_mean, batch_var, batch_count)
new_mean, new_var, new_count = update_mean_var_count_from_moments(self.mean, self.var, self.count, batch_mean, batch_var, batch_count)
self.sess.run(self.update_ops, feed_dict={
self._new_mean: new_mean,
self._new_var: new_var,
self._new_count: new_count
})
self._set_mean_var_count()
def test_runningmeanstd():
for (x1, x2, x3) in [
@@ -145,3 +115,71 @@ def test_tf_runningmeanstd():
ms2 = [rms.mean, rms.var]
np.testing.assert_allclose(ms1, ms2)
def profile_tf_runningmeanstd():
import time
from baselines.common import tf_util
tf_util.get_session( config=tf.ConfigProto(
inter_op_parallelism_threads=1,
intra_op_parallelism_threads=1,
allow_soft_placement=True
))
x = np.random.random((376,))
n_trials = 10000
rms = RunningMeanStd()
tfrms = TfRunningMeanStd()
tic1 = time.time()
for _ in range(n_trials):
rms.update(x)
tic2 = time.time()
for _ in range(n_trials):
tfrms.update(x)
tic3 = time.time()
print('rms update time ({} trials): {} s'.format(n_trials, tic2 - tic1))
print('tfrms update time ({} trials): {} s'.format(n_trials, tic3 - tic2))
tic1 = time.time()
for _ in range(n_trials):
z1 = rms.mean
tic2 = time.time()
for _ in range(n_trials):
z2 = tfrms.mean
tic3 = time.time()
print('rms get mean time ({} trials): {} s'.format(n_trials, tic2 - tic1))
print('tfrms get mean time ({} trials): {} s'.format(n_trials, tic3 - tic2))
'''
options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) #pylint: disable=E1101
run_metadata = tf.RunMetadata()
profile_opts = dict(options=options, run_metadata=run_metadata)
from tensorflow.python.client import timeline
fetched_timeline = timeline.Timeline(run_metadata.step_stats) #pylint: disable=E1101
chrome_trace = fetched_timeline.generate_chrome_trace_format()
outfile = '/tmp/timeline.json'
with open(outfile, 'wt') as f:
f.write(chrome_trace)
print(f'Successfully saved profile to {outfile}. Exiting.')
exit(0)
'''
if __name__ == '__main__':
profile_tf_runningmeanstd()