running_mean_std uses tensorflow variables
This commit is contained in:
@@ -1,4 +1,7 @@
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
from baselines.common.tf_util import get_session
|
||||
|
||||
class RunningMeanStd(object):
|
||||
# https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
|
||||
def __init__(self, epsilon=1e-4, shape=()):
|
||||
@@ -28,6 +31,85 @@ class RunningMeanStd(object):
|
||||
self.var = new_var
|
||||
self.count = new_count
|
||||
|
||||
|
||||
class TfRunningMeanStd(object):
    """Running mean and variance kept in TensorFlow variables.

    TensorFlow variables-based implementation of the parallel algorithm for
    combining the moments of two samples:
    https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm

    The benefit of this implementation is that the statistics can be saved
    and loaded together with the TensorFlow model.

    Args:
        epsilon: initial value of the sample-count variable.
        shape: shape of the mean and variance variables (and of the batch
            moments fed to ``update_from_moments``).
        scope: variable scope in which the variables are created. The scope
            is opened with ``reuse=False``.
    """

    def __init__(self, epsilon=1e-4, shape=(), scope=''):
        sess = get_session()

        # Placeholders through which the moments of a new batch are fed.
        self._batch_mean = tf.placeholder(shape=shape, dtype=tf.float64)
        self._batch_var = tf.placeholder(shape=shape, dtype=tf.float64)
        self._batch_count = tf.placeholder(shape=(), dtype=tf.float64)

        with tf.variable_scope(scope, reuse=False):
            self._mean = tf.get_variable('mean', initializer=np.zeros(shape))
            # NOTE: this variable stores the variance; the name 'std' is kept
            # unchanged so that previously saved checkpoints remain loadable.
            self._var = tf.get_variable('std', initializer=np.ones(shape))
            self._count = tf.get_variable(
                'count', initializer=np.ones(shape=()) * epsilon)

        # Merge the running moments with the batch moments.
        delta = self._batch_mean - self._mean
        tot_count = self._count + self._batch_count

        new_mean = self._mean + delta * self._batch_count / tot_count
        m_a = self._var * self._count
        m_b = self._batch_var * self._batch_count
        m2 = m_a + m_b + tf.square(delta) * self._count * self._batch_count / tot_count
        new_var = m2 / tot_count
        new_count = tot_count

        # All three new values are derived from the *old* variable values, so
        # they must be fully computed before any variable is overwritten.
        with tf.control_dependencies([new_mean, new_var, new_count]):
            self._update_ops = [
                self._var.assign(new_var),
                self._mean.assign(new_mean),
                self._count.assign(new_count),
            ]

        sess.run(tf.variables_initializer([self._mean, self._var, self._count]))
        self.sess = sess

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        """Fold the moments of one batch into the running statistics.

        The update ops are run exactly once per call, so each batch
        contributes ``batch_count`` samples to the running count.
        """
        self.sess.run(self._update_ops, feed_dict={
            self._batch_mean: batch_mean,
            self._batch_var: batch_var,
            self._batch_count: batch_count,
        })

    @property
    def mean(self):
        """Current value of the running mean, fetched from the session."""
        return self.sess.run(self._mean)

    @property
    def var(self):
        """Current value of the running variance, fetched from the session."""
        return self.sess.run(self._var)

    @property
    def count(self):
        """Current value of the sample count, fetched from the session."""
        return self.sess.run(self._count)

    def update(self, x):
        """Update the running statistics from a batch ``x``.

        Moments are taken along axis 0 and ``x.shape[0]`` is used as the
        number of samples in the batch.
        """
        batch_mean = np.mean(x, axis=0)
        batch_var = np.var(x, axis=0)
        batch_count = x.shape[0]
        self.update_from_moments(batch_mean, batch_var, batch_count)
|
||||
|
||||
|
||||
|
||||
def test_runningmeanstd():
|
||||
for (x1, x2, x3) in [
|
||||
(np.random.randn(3), np.random.randn(4), np.random.randn(5)),
|
||||
@@ -43,4 +125,21 @@ def test_runningmeanstd():
|
||||
rms.update(x3)
|
||||
ms2 = [rms.mean, rms.var]
|
||||
|
||||
assert np.allclose(ms1, ms2)
|
||||
np.testing.assert_allclose(ms1, ms2)
|
||||
|
||||
def test_tf_runningmeanstd():
    """Check TfRunningMeanStd against numpy moments of the concatenated data.

    Three batches are fed one after another; the resulting running mean and
    variance must match the mean and variance of all samples taken together.
    """
    # Local import keeps this test self-contained.
    import uuid

    for (x1, x2, x3) in [
        (np.random.randn(3), np.random.randn(4), np.random.randn(5)),
        (np.random.randn(3, 2), np.random.randn(4, 2), np.random.randn(5, 2)),
    ]:
        # Variables are created with reuse=False, so every instance needs its
        # own scope. A random integer suffix could repeat between iterations
        # and make the test fail intermittently; a UUID cannot.
        scope = 'running_mean_std' + uuid.uuid4().hex
        rms = TfRunningMeanStd(epsilon=0.0, shape=x1.shape[1:], scope=scope)

        x = np.concatenate([x1, x2, x3], axis=0)
        ms1 = [x.mean(axis=0), x.var(axis=0)]
        rms.update(x1)
        rms.update(x2)
        rms.update(x3)
        ms2 = [rms.mean, rms.var]

        np.testing.assert_allclose(ms1, ms2)
|
||||
|
@@ -39,6 +39,7 @@ def simple_test(env_fn, learn_fn, min_reward_fraction, n_trials=N_TRIALS):
|
||||
'sum of rewards {} is less than {} of the total number of trials {}'.format(sum_rew, min_reward_fraction, n_trials)
|
||||
|
||||
|
||||
|
||||
def reward_per_episode_test(env_fn, learn_fn, min_avg_reward, n_trials=N_EPISODES):
|
||||
env = DummyVecEnv([env_fn])
|
||||
|
||||
@@ -50,24 +51,44 @@ def reward_per_episode_test(env_fn, learn_fn, min_avg_reward, n_trials=N_EPISODE
|
||||
|
||||
rewards = []
|
||||
|
||||
for i in range(N_TRIALS):
|
||||
obs = env.reset()
|
||||
state = model.initial_state
|
||||
episode_rew = 0
|
||||
while True:
|
||||
if state is not None:
|
||||
a, v, state, _ = model.step(obs, S=state, M=[False])
|
||||
else:
|
||||
a,v, _, _ = model.step(obs)
|
||||
observations, actions, rewards = rollout(env, model, N_TRIALS)
|
||||
rewards = [sum(r) for r in rewards]
|
||||
|
||||
obs, rew, done, _ = env.step(a)
|
||||
episode_rew += rew
|
||||
if done:
|
||||
break
|
||||
|
||||
rewards.append(episode_rew)
|
||||
avg_rew = sum(rewards) / N_TRIALS
|
||||
print("Average reward in {} episodes is {}".format(n_trials, avg_rew))
|
||||
assert avg_rew > min_avg_reward, \
|
||||
'average reward in {} episodes ({}) is less than {}'.format(n_trials, avg_rew, min_avg_reward)
|
||||
|
||||
def rollout(env, model, n_trials):
    """Run ``n_trials`` episodes of ``model`` in ``env`` and record them.

    Each episode starts from ``env.reset()`` and ``model.initial_state`` and
    runs until ``env.step`` reports ``done``. When the model state is not
    None it is passed to ``model.step`` as ``S`` together with ``M=[False]``
    and replaced by the state the model returns; otherwise ``model.step`` is
    called with the observation only.

    Returns:
        A tuple ``(observations, actions, rewards)``. Each element is a list
        with one entry per episode, and each entry is the list of per-step
        values of that episode. The recorded observation of a step is the one
        returned by ``env.step`` for that step.
    """
    all_obs, all_actions, all_rewards = [], [], []

    for _trial in range(n_trials):
        obs = env.reset()
        state = model.initial_state
        ep_obs, ep_actions, ep_rewards = [], [], []

        finished = False
        while not finished:
            if state is None:
                # Stateless model: the returned state is ignored.
                action, _value, _, _ = model.step(obs)
            else:
                action, _value, state, _ = model.step(obs, S=state, M=[False])

            obs, reward, finished, _ = env.step(action)

            ep_rewards.append(reward)
            ep_actions.append(action)
            ep_obs.append(obs)

        all_rewards.append(ep_rewards)
        all_actions.append(ep_actions)
        all_obs.append(ep_obs)

    return all_obs, all_actions, all_rewards
|
||||
|
||||
|
@@ -8,8 +8,10 @@ class VecNormalize(VecEnvWrapper):
|
||||
"""
|
||||
def __init__(self, venv, ob=True, ret=True, clipob=10., cliprew=10., gamma=0.99, epsilon=1e-8):
|
||||
VecEnvWrapper.__init__(self, venv)
|
||||
self.ob_rms = RunningMeanStd(shape=self.observation_space.shape) if ob else None
|
||||
self.ret_rms = RunningMeanStd(shape=()) if ret else None
|
||||
# self.ob_rms = RunningMeanStd(shape=self.observation_space.shape) if ob else None
|
||||
# self.ret_rms = RunningMeanStd(shape=()) if ret else None
|
||||
self.ob_rms = TfRunningMeanStd(shape=self.observation_space.shape, scope='observation_running_mean_std') if ob else None
|
||||
self.ret_rms = TfRunningMeanStd(shape=(), scope='return_running_mean_std') if ret else None
|
||||
self.clipob = clipob
|
||||
self.cliprew = cliprew
|
||||
self.ret = np.zeros(self.num_envs)
|
||||
|
@@ -148,10 +148,10 @@ def get_alg_module(alg, submodule=None):
|
||||
submodule = submodule or alg
|
||||
try:
|
||||
# first try to import the alg module from baselines
|
||||
alg_module = import_module('.'.join(['rl_' + 'algs', alg, submodule]))
|
||||
except ImportError:
|
||||
# then from baselines
|
||||
alg_module = import_module('.'.join(['baselines', alg, submodule]))
|
||||
except ImportError:
|
||||
# then from rl_algs
|
||||
alg_module = import_module('.'.join(['rl_' + 'algs', alg, submodule]))
|
||||
|
||||
return alg_module
|
||||
|
||||
|
Reference in New Issue
Block a user