reverted VecNormalize to use RunningMeanStd (no tf)
This commit is contained in:
@@ -16,21 +16,22 @@ class RunningMeanStd(object):
|
||||
self.update_from_moments(batch_mean, batch_var, batch_count)
|
||||
|
||||
def update_from_moments(self, batch_mean, batch_var, batch_count):
|
||||
delta = batch_mean - self.mean
|
||||
tot_count = self.count + batch_count
|
||||
self.mean, self.var, self.count = update_mean_var_count_from_moments(
|
||||
self.mean, self.var, self.count, batch_mean, batch_var, batch_count)
|
||||
|
||||
new_mean = self.mean + delta * batch_count / tot_count
|
||||
m_a = self.var * (self.count)
|
||||
m_b = batch_var * (batch_count)
|
||||
M2 = m_a + m_b + np.square(delta) * self.count * batch_count / (self.count + batch_count)
|
||||
new_var = M2 / (self.count + batch_count)
|
||||
|
||||
new_count = batch_count + self.count
|
||||
|
||||
self.mean = new_mean
|
||||
self.var = new_var
|
||||
self.count = new_count
|
||||
def update_mean_var_count_from_moments(mean, var, count, batch_mean, batch_var, batch_count):
|
||||
delta = batch_mean - mean
|
||||
tot_count = count + batch_count
|
||||
|
||||
new_mean = mean + delta * batch_count / tot_count
|
||||
m_a = var * count
|
||||
m_b = batch_var * batch_count
|
||||
M2 = m_a + m_b + np.square(delta) * count * batch_count / (count + batch_count)
|
||||
new_var = M2 / (count + batch_count)
|
||||
new_count = batch_count + count
|
||||
|
||||
return new_mean, new_var, new_count
|
||||
|
||||
|
||||
class TfRunningMeanStd(object):
|
||||
# https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
|
||||
@@ -41,55 +42,27 @@ class TfRunningMeanStd(object):
|
||||
def __init__(self, epsilon=1e-4, shape=(), scope=''):
|
||||
sess = get_session()
|
||||
|
||||
_batch_mean = tf.placeholder(shape=shape, dtype=tf.float64)
|
||||
_batch_var = tf.placeholder(shape=shape, dtype=tf.float64)
|
||||
_batch_count = tf.placeholder(shape=(), dtype=tf.float64)
|
||||
self._new_mean = tf.placeholder(shape=shape, dtype=tf.float64)
|
||||
self._new_var = tf.placeholder(shape=shape, dtype=tf.float64)
|
||||
self._new_count = tf.placeholder(shape=(), dtype=tf.float64)
|
||||
|
||||
|
||||
with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
|
||||
_mean = tf.get_variable('mean', initializer=np.zeros(shape, 'float64'), dtype=tf.float64)
|
||||
_var = tf.get_variable('std', initializer=np.ones(shape, 'float64'), dtype=tf.float64)
|
||||
_count = tf.get_variable('count', initializer=np.full((), epsilon, 'float64'), dtype=tf.float64)
|
||||
self._mean = tf.get_variable('mean', initializer=np.zeros(shape, 'float64'), dtype=tf.float64)
|
||||
self._var = tf.get_variable('std', initializer=np.ones(shape, 'float64'), dtype=tf.float64)
|
||||
self._count = tf.get_variable('count', initializer=np.full((), epsilon, 'float64'), dtype=tf.float64)
|
||||
|
||||
delta = _batch_mean - _mean
|
||||
tot_count = _count + _batch_count
|
||||
|
||||
new_mean = _mean + delta * _batch_count / tot_count
|
||||
m_a = _var * (_count)
|
||||
m_b = _batch_var * (_batch_count)
|
||||
M2 = m_a + m_b + np.square(delta) * _count * _batch_count / (_count + _batch_count)
|
||||
new_var = M2 / (_count + _batch_count)
|
||||
new_count = _batch_count + _count
|
||||
|
||||
update_ops = [
|
||||
_var.assign(new_var),
|
||||
_mean.assign(new_mean),
|
||||
_count.assign(new_count)
|
||||
self.update_ops = [
|
||||
self._var.assign(self._new_var),
|
||||
self._mean.assign(self._new_mean),
|
||||
self._count.assign(self._new_count)
|
||||
]
|
||||
|
||||
self._mean = _mean
|
||||
self._var = _var
|
||||
self._count = _count
|
||||
|
||||
self._batch_mean = _batch_mean
|
||||
self._batch_var = _batch_var
|
||||
self._batch_count = _batch_count
|
||||
sess.run(tf.variables_initializer([self._mean, self._var, self._count]))
|
||||
self.sess = sess
|
||||
|
||||
|
||||
def update_from_moments(batch_mean, batch_var, batch_count):
|
||||
for op in update_ops:
|
||||
sess.run(op, feed_dict={
|
||||
_batch_mean: batch_mean,
|
||||
_batch_var: batch_var,
|
||||
_batch_count: batch_count
|
||||
})
|
||||
|
||||
|
||||
|
||||
sess.run(tf.variables_initializer([_mean, _var, _count]))
|
||||
self.sess = sess
|
||||
self.update_from_moments = update_from_moments
|
||||
|
||||
|
||||
@property
|
||||
def mean(self):
|
||||
return self.sess.run(self._mean)
|
||||
@@ -107,10 +80,18 @@ class TfRunningMeanStd(object):
|
||||
batch_mean = np.mean(x, axis=0)
|
||||
batch_var = np.var(x, axis=0)
|
||||
batch_count = x.shape[0]
|
||||
self.update_from_moments(batch_mean, batch_var, batch_count)
|
||||
|
||||
|
||||
|
||||
mean, var, count = self.sess.run([self._mean, self._var, self._count])
|
||||
new_mean, new_var, new_count = update_mean_var_count_from_moments(mean, var, count, batch_mean, batch_var, batch_count)
|
||||
|
||||
|
||||
self.sess.run(self.update_ops, feed_dict={
|
||||
self._new_mean: new_mean,
|
||||
self._new_var: new_var,
|
||||
self._new_count: new_count
|
||||
})
|
||||
|
||||
|
||||
|
||||
def test_runningmeanstd():
|
||||
for (x1, x2, x3) in [
|
||||
|
@@ -1,5 +1,5 @@
|
||||
from baselines.common.vec_env import VecEnvWrapper
|
||||
from baselines.common.running_mean_std import TfRunningMeanStd
|
||||
from baselines.common.running_mean_std import RunningMeanStd
|
||||
import numpy as np
|
||||
|
||||
class VecNormalize(VecEnvWrapper):
|
||||
@@ -8,10 +8,10 @@ class VecNormalize(VecEnvWrapper):
|
||||
"""
|
||||
def __init__(self, venv, ob=True, ret=True, clipob=10., cliprew=10., gamma=0.99, epsilon=1e-8):
|
||||
VecEnvWrapper.__init__(self, venv)
|
||||
#self.ob_rms = RunningMeanStd(shape=self.observation_space.shape) if ob else None
|
||||
#self.ret_rms = RunningMeanStd(shape=()) if ret else None
|
||||
self.ob_rms = TfRunningMeanStd(shape=self.observation_space.shape, scope='observation_running_mean_std') if ob else None
|
||||
self.ret_rms = TfRunningMeanStd(shape=(), scope='return_running_mean_std') if ret else None
|
||||
self.ob_rms = RunningMeanStd(shape=self.observation_space.shape) if ob else None
|
||||
self.ret_rms = RunningMeanStd(shape=()) if ret else None
|
||||
#self.ob_rms = TfRunningMeanStd(shape=self.observation_space.shape, scope='observation_running_mean_std') if ob else None
|
||||
#self.ret_rms = TfRunningMeanStd(shape=(), scope='return_running_mean_std') if ret else None
|
||||
self.clipob = clipob
|
||||
self.cliprew = cliprew
|
||||
self.ret = np.zeros(self.num_envs)
|
||||
|
Reference in New Issue
Block a user