Compare commits


1 Commit

Author: JongGyun Kim
Commit: fc0c43b199
RNN support for PPO2 (#859)
* initial implementation of ppo2_rnn.

* set lstm memory as tf.GraphKeys.LOCAL_VARIABLES.

* replace dones with tf.placeholder_with_default.

* improvements for the 'play' option.

* removed unnecessary TODO.

* improve lstm code.

* move learning rate placeholder to optimizer scope.

* support the microbatched model.

* sync cnn lstm layer with originals.

* add cnn_lnlstm layer.

* fix a case when `states` is None.

* add initial_state variable to help test.

* make ppo2 rnn test available.

* rename 'obs' to 'observations'.
rename 'transition' to 'transitions'.
fix forgetting `dones` in the replay buffer.
fix a misuse of `states` and `next_states` in the replay buffer.

* make initialization once.
make `test_fixed_sequence` compatible with ppo2.

* adjust input shape.

* fix checking of a model input args in `simple_test` function.

* disable warning on purpose.

* support the play.

* improve scopes to be compatible with multiple models (i.e., other tensorflow global/local variables)

* clean the scope of ppo2 policy model.

* name the memory variable of PPO RNNs more descriptively

* wrap the initializations in ppo2.

* remove redundant lines.

* update `README.md`.

* add RNN layers.

* add the result of the HalfCheetah-v2 env experiment.

* correct a typo.

* add RNN class.

* rename `nlstm` to `num_units` in RNN builder functions.

* remove state saving.

* reuse RNNs in a2c.utils.

* revert baselines/run.py.

* replace `ppo2.step()` with original interface.

* revert `baselines/common/tests/util.py`.

* remove redundant lines.

* revert `baselines/common/test/util.py` to b875fb7.

* remove `states` variable.

* move RNN class to `baselines/ppo2/layers.py` and revert `baselines/common/models.py` to 858afa8.

* rename `model.step_as_dict` to `model.step_with_dict`.

* removed `ppo_lstm_mlp`.

* fix 02e26fd.
Date: 2019-04-26 15:17:56 -07:00
13 changed files with 626 additions and 276 deletions


@@ -17,10 +17,10 @@ learn_kwargs = {
     # 'trpo_mpi': lambda e, p: trpo_mpi.learn(policy_fn=p(env=e), env=e, max_timesteps=30000, timesteps_per_batch=100, cg_iters=10, gamma=0.9, lam=1.0, max_kl=0.001)
 }
 
 alg_list = learn_kwargs.keys()
 rnn_list = ['lstm']
 
 @pytest.mark.slow
 @pytest.mark.parametrize("alg", alg_list)
 @pytest.mark.parametrize("rnn", rnn_list)
@@ -33,6 +33,9 @@ def test_fixed_sequence(alg, rnn):
     kwargs = learn_kwargs[alg]
     kwargs.update(common_kwargs)
 
+    if alg == 'ppo2' and rnn.endswith('lstm'):
+        rnn = 'ppo_' + rnn
+
     env_fn = lambda: FixedSequenceEnv(n_actions=10, episode_len=5)
     learn = lambda e: get_learn_function(alg)(
         env=e,
@@ -45,6 +48,3 @@ def test_fixed_sequence(alg, rnn):
 
 if __name__ == '__main__':
     test_fixed_sequence('ppo2', 'lstm')


@@ -1,17 +1,16 @@
 import os
-import gym
 import tempfile
-import pytest
-import tensorflow as tf
-import numpy as np
-from baselines.common.tests.envs.mnist_env import MnistEnv
-from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
-from baselines.run import get_learn_function
-from baselines.common.tf_util import make_session, get_session
 from functools import partial
+
+import gym
+import numpy as np
+import pytest
+import tensorflow as tf
+
+from baselines.common.tests.envs.mnist_env import MnistEnv
+from baselines.common.tf_util import make_session, get_session
+from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
+from baselines.run import get_learn_function
 
 learn_kwargs = {
     'deepq': {},
@@ -37,12 +36,15 @@ def test_serialization(learn_fn, network_fn):
     Test if the trained model can be serialized
     '''
 
+    _network_kwargs = network_kwargs[network_fn]
+
     if network_fn.endswith('lstm') and learn_fn in ['acer', 'acktr', 'trpo_mpi', 'deepq']:
         # TODO make acktr work with recurrent policies
        # and test
        # github issue: https://github.com/openai/baselines/issues/660
        return
+    elif network_fn.endswith('lstm') and learn_fn == 'ppo2':
+        network_fn = 'ppo_' + network_fn
 
     def make_env():
         env = MnistEnv(episode_len=100)
@@ -54,10 +56,9 @@ def test_serialization(learn_fn, network_fn):
     learn = get_learn_function(learn_fn)
 
     kwargs = {}
-    kwargs.update(network_kwargs[network_fn])
+    kwargs.update(_network_kwargs)
     kwargs.update(learn_kwargs[learn_fn])
 
     learn = partial(learn, env=env, network=network_fn, seed=0, **kwargs)
 
     with tempfile.TemporaryDirectory() as td:
@@ -117,7 +118,6 @@ def test_coexistence(learn_fn, network_fn):
     model2.step(env.observation_space.sample())
 
 def _serialize_variables():
     sess = get_session()
     variables = tf.trainable_variables()
@@ -136,4 +136,3 @@ def _get_action_stats(model, ob):
     std = np.std(actions, axis=0)
 
     return mean, std


@@ -3,6 +3,16 @@
 - Original paper: https://arxiv.org/abs/1707.06347
 - Baselines blog post: https://blog.openai.com/openai-baselines-ppo/
+
+## Examples
 - `python -m baselines.run --alg=ppo2 --env=PongNoFrameskip-v4` runs the algorithm for 40M frames = 10M timesteps on an Atari Pong. See help (`-h`) for more options.
 - `python -m baselines.run --alg=ppo2 --env=Ant-v2 --num_timesteps=1e6` runs the algorithm for 1M frames on a Mujoco Ant environment.
-- also refer to the repo-wide [README.md](../../README.md#training-models)
+
+### RNN networks
+- `python -m baselines.run --alg=ppo2 --env=PongNoFrameskip-v4 --network=ppo_cnn_lstm` runs on an Atari Pong with
+  the `ppo_cnn_lstm` network.
+- `python -m baselines.run --alg=ppo2 --env=Ant-v2 --num_timesteps=1e6 --network=ppo_lstm --value_network=copy`
+  runs on a Mujoco Ant environment with the `ppo_lstm` network, whose value and policy networks are separate but have
+  the same structure.
+
+## See Also
+- refer to the repo-wide [README.md](../../README.md#training-models)


@@ -0,0 +1 @@
+from baselines.ppo2.layers import ppo_lstm, ppo_cnn_lstm, ppo_cnn_lnlstm  # pylint: disable=unused-import # noqa: F401

baselines/ppo2/layers.py (new file, 55 lines)

@@ -0,0 +1,55 @@
+import numpy as np
+import tensorflow as tf
+
+from baselines.a2c.utils import ortho_init, lstm, lnlstm
+from baselines.common.models import register, nature_cnn
+
+
+class RNN(object):
+    def __init__(self, func, memory_size=None):
+        self._func = func
+        self.memory_size = memory_size
+
+    def __call__(self, *args, **kwargs):
+        return self._func(*args, **kwargs)
+
+
+@register("ppo_lstm")
+def ppo_lstm(num_units=128, layer_norm=False):
+    def network_fn(input, mask, state):
+        input = tf.layers.flatten(input)
+        mask = tf.to_float(mask)
+
+        if layer_norm:
+            h, next_state = lnlstm([input], [mask[:, None]], state, scope='lnlstm', nh=num_units)
+        else:
+            h, next_state = lstm([input], [mask[:, None]], state, scope='lstm', nh=num_units)
+        h = h[0]
+        return h, next_state
+
+    return RNN(network_fn, memory_size=num_units * 2)
+
+
+@register("ppo_cnn_lstm")
+def ppo_cnn_lstm(num_units=128, layer_norm=False, **conv_kwargs):
+    def network_fn(input, mask, state):
+        mask = tf.to_float(mask)
+        initializer = ortho_init(np.sqrt(2))
+
+        h = nature_cnn(input, **conv_kwargs)
+        h = tf.layers.flatten(h)
+        h = tf.layers.dense(h, units=512, activation=tf.nn.relu, kernel_initializer=initializer)
+
+        if layer_norm:
+            h, next_state = lnlstm([h], [mask[:, None]], state, scope='lnlstm', nh=num_units)
+        else:
+            h, next_state = lstm([h], [mask[:, None]], state, scope='lstm', nh=num_units)
+        h = h[0]
+        return h, next_state
+
+    return RNN(network_fn, memory_size=num_units * 2)
+
+
+@register("ppo_cnn_lnlstm")
+def ppo_cnn_lnlstm(num_units=128, **conv_kwargs):
+    return ppo_cnn_lstm(num_units, layer_norm=True, **conv_kwargs)

@@ -1,18 +1,21 @@
-import tensorflow as tf
 import numpy as np
+import tensorflow as tf
+
 from baselines.ppo2.model import Model
 
+
 class MicrobatchedModel(Model):
     """
     Model that does training one microbatch at a time - when gradient computation
     on the entire minibatch causes some overflow
     """
+
     def __init__(self, *, policy, ob_space, ac_space, nbatch_act, nbatch_train,
                  nsteps, ent_coef, vf_coef, max_grad_norm, microbatch_size):
 
         self.nmicrobatches = nbatch_train // microbatch_size
         self.microbatch_size = microbatch_size
-        assert nbatch_train % microbatch_size == 0, 'microbatch_size ({}) should divide nbatch_train ({}) evenly'.format(microbatch_size, nbatch_train)
+        assert nbatch_train % microbatch_size == 0, 'microbatch_size ({}) should divide nbatch_train ({}) evenly'.format(
+            microbatch_size, nbatch_train)
 
         super().__init__(
             policy=policy,
@@ -29,14 +32,16 @@ class MicrobatchedModel(Model):
         grads_ph_and_vars = list(zip(self.grads_ph, self.var))
         self._apply_gradients_op = self.trainer.apply_gradients(grads_ph_and_vars)
 
-    def train(self, lr, cliprange, obs, returns, masks, actions, values, neglogpacs, states=None):
-        assert states is None, "microbatches with recurrent models are not supported yet"
-
-        # Here we calculate advantage A(s,a) = R + yV(s') - V(s)
-        # Returns = R + yV(s')
-        advs = returns - values
+    def train(self,
+              lr,
+              cliprange,
+              observations,
+              advs,
+              returns,
+              actions,
+              values,
+              neglogpacs,
+              **_kwargs):
 
         # Normalize the advantages
         advs = (advs - advs.mean()) / (advs.std() + 1e-8)
@@ -45,16 +50,21 @@ class MicrobatchedModel(Model):
         for microbatch_idx in range(self.nmicrobatches):
             _sli = range(microbatch_idx * self.microbatch_size, (microbatch_idx + 1) * self.microbatch_size)
+
             td_map = {
-                self.train_model.X: obs[_sli],
+                self.train_model.X: observations[_sli],
                 self.A: actions[_sli],
                 self.ADV: advs[_sli],
-                self.R: returns[_sli],
+                self.RETURNS: returns[_sli],
+                self.LR: lr,
                 self.CLIPRANGE: cliprange,
                 self.OLDNEGLOGPAC: neglogpacs[_sli],
-                self.OLDVPRED: values[_sli]
+                self.VALUE_PREV: values[_sli],
             }
+
+            sliced_kwargs = {key: _kwargs[key][_sli] for key in _kwargs}
+            td_map.update(self.train_model.feed_dict(**sliced_kwargs))
+
             # Compute gradient on a microbatch (note that variables do not change here) ...
             grad_v, stats_v = self.sess.run([self.grads, self.stats_list], td_map)
             if microbatch_idx == 0:
@@ -71,6 +81,3 @@ class MicrobatchedModel(Model):
             self.sess.run(self._apply_gradients_op, feed_dict)
 
         # Return average of the stats
         return np.mean(np.array(stats_vs), axis=0).tolist()
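The class above trades memory for extra passes: gradients are computed one microbatch at a time, summed, and only their average is applied in a single optimizer step. A small numpy sketch of that accumulation pattern, with an invented `grad_fn` and batch layout (this is not the PR's code):

import numpy as np

def train_microbatched(grad_fn, batch, microbatch_size):
    # Sum per-microbatch gradients and return their mean, which a single
    # apply_gradients call would then consume.
    nbatch = len(next(iter(batch.values())))
    assert nbatch % microbatch_size == 0, 'microbatch_size should divide the batch evenly'
    summed = None
    for start in range(0, nbatch, microbatch_size):
        sli = slice(start, start + microbatch_size)
        grads = grad_fn(**{key: value[sli] for key, value in batch.items()})
        summed = grads if summed is None else [s + g for s, g in zip(summed, grads)]
    return [s / (nbatch // microbatch_size) for s in summed]

# Toy usage: gradient of a mean squared error with respect to a single weight w.
w = 2.0
grad_fn = lambda x, y: [np.array([np.mean(2.0 * (w * x - y) * x)])]
batch = {'x': np.random.randn(8), 'y': np.random.randn(8)}
print(train_microbatched(grad_fn, batch, microbatch_size=2))

Because every sample appears in exactly one equally sized microbatch, the averaged gradient matches what a single full-batch pass would produce for mean-reduced losses.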

@@ -1,8 +1,8 @@
-import tensorflow as tf
 import functools
+
+import tensorflow as tf
 
 from baselines.common.tf_util import get_session, save_variables, load_variables
-from baselines.common.tf_util import initialize
 
 try:
     from baselines.common.mpi_adam_optimizer import MpiAdamOptimizer
@@ -11,6 +11,7 @@ try:
 except ImportError:
     MPI = None
 
+
 class Model(object):
     """
     We use this object to :
@@ -24,72 +25,82 @@ class Model(object):
     save/load():
         - Save load the model
     """
-    def __init__(self, *, policy, ob_space, ac_space, nbatch_act, nbatch_train,
-                nsteps, ent_coef, vf_coef, max_grad_norm, microbatch_size=None):
-        self.sess = sess = get_session()
-
-        with tf.variable_scope('ppo2_model', reuse=tf.AUTO_REUSE):
+
+    def __init__(self, *, policy, ob_space, ac_space, nbatch_act, nbatch_train,
+                 nsteps, ent_coef, vf_coef, max_grad_norm,
+                 name='ppo_model',
+                 sess=None,
+                 microbatch_size=None):
+        if sess is None:
+            sess = get_session()
+        self.sess = sess
+        self.name = name
+
+        with tf.variable_scope(name) as scope:
+            self.scope = scope
+            with tf.variable_scope('models', reuse=tf.AUTO_REUSE):
+                with tf.name_scope('act_model'):
                     # CREATE OUR TWO MODELS
                     # act_model that is used for sampling
                     act_model = policy(nbatch_act, 1, sess)
 
+                with tf.name_scope('train_model'):
                     # Train model for training
                     if microbatch_size is None:
                         train_model = policy(nbatch_train, nsteps, sess)
                     else:
                         train_model = policy(microbatch_size, nsteps, sess)
 
+            with tf.variable_scope('losses'):
                 # CREATE THE PLACEHOLDERS
-        self.A = A = train_model.pdtype.sample_placeholder([None])
-        self.ADV = ADV = tf.placeholder(tf.float32, [None])
-        self.R = R = tf.placeholder(tf.float32, [None])
-        # Keep track of old actor
-        self.OLDNEGLOGPAC = OLDNEGLOGPAC = tf.placeholder(tf.float32, [None])
-        # Keep track of old critic
-        self.OLDVPRED = OLDVPRED = tf.placeholder(tf.float32, [None])
-        self.LR = LR = tf.placeholder(tf.float32, [])
-        # Cliprange
-        self.CLIPRANGE = CLIPRANGE = tf.placeholder(tf.float32, [])
+                self.A = A = train_model.pdtype.sample_placeholder([None], name='action')
+                self.ADV = ADV = tf.placeholder(tf.float32, [None], name='advantage')
+                self.RETURNS = RETURNS = tf.placeholder(tf.float32, [None], name='reward')
+                self.VALUE_PREV = VALUE_PREV = tf.placeholder(tf.float32, [None], name='value_prev')
+                self.OLDNEGLOGPAC = OLDNEGLOGPAC = tf.placeholder(tf.float32, [None],
+                                                                  name='negative_log_p_action_old')
+                self.CLIPRANGE = CLIPRANGE = tf.placeholder(tf.float32, [], name='clip_range')
 
+                with tf.name_scope("neglogpac"):
                     neglogpac = train_model.pd.neglogp(A)
 
+                with tf.name_scope("entropy"):
                     # Calculate the entropy
                     # Entropy is used to improve exploration by limiting the premature convergence to suboptimal policy.
                     entropy = tf.reduce_mean(train_model.pd.entropy())
+                    entropy_loss = (- ent_coef) * entropy
 
+                with tf.name_scope("value_loss"):
                     # CALCULATE THE LOSS
-        # Total loss = Policy gradient loss - entropy * entropy coefficient + Value coefficient * value loss
-
-        # Clip the value to reduce variability during Critic training
-        # Get the predicted value
-        vpred = train_model.vf
-        vpredclipped = OLDVPRED + tf.clip_by_value(train_model.vf - OLDVPRED, - CLIPRANGE, CLIPRANGE)
-        # Unclipped value
-        vf_losses1 = tf.square(vpred - R)
-        # Clipped value
-        vf_losses2 = tf.square(vpredclipped - R)
-        vf_loss = .5 * tf.reduce_mean(tf.maximum(vf_losses1, vf_losses2))
+                    value = train_model.value
+                    value_clipped = VALUE_PREV + tf.clip_by_value(value - VALUE_PREV, -CLIPRANGE, CLIPRANGE)
+                    vf_losses1 = tf.squared_difference(value, RETURNS)
+                    vf_losses2 = tf.squared_difference(value_clipped, RETURNS)
+                    vf_loss = 0.5 * vf_coef * tf.reduce_mean(tf.maximum(vf_losses1, vf_losses2))
 
+                with tf.name_scope("policy_loss"):
                     # Calculate ratio (pi current policy / pi old policy)
                     ratio = tf.exp(OLDNEGLOGPAC - neglogpac)
-        # Defining Loss = - J is equivalent to max J
                     pg_losses = -ADV * ratio
                     pg_losses2 = -ADV * tf.clip_by_value(ratio, 1.0 - CLIPRANGE, 1.0 + CLIPRANGE)
-        # Final PG loss
                     pg_loss = tf.reduce_mean(tf.maximum(pg_losses, pg_losses2))
-        approxkl = .5 * tf.reduce_mean(tf.square(neglogpac - OLDNEGLOGPAC))
 
+                with tf.name_scope("approxkl"):
+                    approxkl = .5 * tf.reduce_mean(tf.squared_difference(neglogpac, OLDNEGLOGPAC))
+
+                with tf.name_scope("clip_fraction"):
                     clipfrac = tf.reduce_mean(tf.to_float(tf.greater(tf.abs(ratio - 1.0), CLIPRANGE)))
 
-        # Total loss
-        loss = pg_loss - entropy * ent_coef + vf_loss * vf_coef
+                with tf.name_scope("total_loss"):
+                    loss = pg_loss + entropy_loss + vf_loss
+
+            with tf.variable_scope('optimizer'):
+                self.LR = LR = tf.placeholder(tf.float32, [], name='learning_rate')
 
                 # UPDATE THE PARAMETERS USING LOSS
                 # 1. Get the model parameters
-        params = tf.trainable_variables('ppo2_model')
+                params = tf.trainable_variables(self.scope.name)
 
                 # 2. Build our trainer
                 if MPI is not None:
                     self.trainer = MpiAdamOptimizer(MPI.COMM_WORLD, learning_rate=LR, epsilon=1e-5)
@@ -103,54 +114,68 @@ class Model(object):
                     # Clip the gradients (normalize)
                     grads, _grad_norm = tf.clip_by_global_norm(grads, max_grad_norm)
                 grads_and_var = list(zip(grads, var))
-        # zip aggregate each gradient with parameters associated
-        # For instance zip(ABCD, xyza) => Ax, By, Cz, Da
 
                 self.grads = grads
                 self.var = var
                 self._train_op = self.trainer.apply_gradients(grads_and_var)
-        self.loss_names = ['policy_loss', 'value_loss', 'policy_entropy', 'approxkl', 'clipfrac']
-        self.stats_list = [pg_loss, vf_loss, entropy, approxkl, clipfrac]
+
+        self.loss_names = ['policy_loss', 'value_loss', 'entropy_loss', 'approxkl', 'clipfrac',
+                           'total_loss']
+        self.stats_list = [pg_loss, vf_loss, entropy_loss, approxkl, clipfrac, loss]
 
         self.train_model = train_model
         self.act_model = act_model
-        self.step = act_model.step
-        self.value = act_model.value
        self.initial_state = act_model.initial_state
 
        self.save = functools.partial(save_variables, sess=sess)
        self.load = functools.partial(load_variables, sess=sess)
 
-        initialize()
-        global_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="")
+        with tf.variable_scope('initialization'):
+            sess.run(tf.initializers.variables(tf.global_variables(self.scope.name)))
+            sess.run(tf.initializers.variables(tf.local_variables(self.scope.name)))
+            global_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=self.scope.name)
 
            if MPI is not None:
                sync_from_root(sess, global_variables)  # pylint: disable=E1101
 
-    def train(self, lr, cliprange, obs, returns, masks, actions, values, neglogpacs, states=None):
-        # Here we calculate advantage A(s,a) = R + yV(s') - V(s)
-        # Returns = R + yV(s')
-        advs = returns - values
+    def step_with_dict(self, **kwargs):
+        return self.act_model.step(**kwargs)
+
+    def step(self, obs, M=None, S=None, **kwargs):
+        kwargs.update({'observations': obs})
+        if M is not None and S is not None:
+            kwargs.update({'dones': M})
+            kwargs.update({'states': S})
+        transition = self.act_model.step(**kwargs)
+        states = transition['next_states'] if 'next_states' in transition else None
+        return transition['actions'], transition['values'], states, transition['neglogpacs']
+
+    def train(self,
+              lr,
+              cliprange,
+              observations,
+              advs,
+              returns,
+              actions,
+              values,
+              neglogpacs,
+              **_kwargs):
 
         # Normalize the advantages
         advs = (advs - advs.mean()) / (advs.std() + 1e-8)
+
         td_map = {
-            self.train_model.X : obs,
+            self.train_model.X: observations,
             self.A: actions,
             self.ADV: advs,
-            self.R : returns,
+            self.RETURNS: returns,
             self.LR: lr,
             self.CLIPRANGE: cliprange,
             self.OLDNEGLOGPAC: neglogpacs,
-            self.OLDVPRED : values
+            self.VALUE_PREV: values,
        }
-        if states is not None:
-            td_map[self.train_model.S] = states
-            td_map[self.train_model.M] = masks
+
+        td_map.update(self.train_model.feed_dict(**_kwargs))
 
        return self.sess.run(
            self.stats_list + [self._train_op],
            td_map
        )[:-1]
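For reference, the loss the graph above assembles can be restated in a few lines of numpy. This is only a hedged restatement of the clipped surrogate and clipped value terms visible in the diff, not the PR's code; the real graph also normalizes advantages in `train` and takes `entropy` from the action distribution:

import numpy as np

def ppo_loss(neglogpac, neglogpac_old, advs, values, values_prev, returns,
             cliprange=0.2, ent_coef=0.0, vf_coef=0.5, entropy=0.0):
    # Clipped surrogate objective (the "policy_loss" scope).
    ratio = np.exp(neglogpac_old - neglogpac)
    pg_loss = np.mean(np.maximum(-advs * ratio,
                                 -advs * np.clip(ratio, 1.0 - cliprange, 1.0 + cliprange)))
    # Clipped value loss (the "value_loss" scope).
    value_clipped = values_prev + np.clip(values - values_prev, -cliprange, cliprange)
    vf_loss = 0.5 * vf_coef * np.mean(np.maximum((values - returns) ** 2,
                                                 (value_clipped - returns) ** 2))
    entropy_loss = -ent_coef * entropy
    return pg_loss + vf_loss + entropy_loss

# Toy numbers for three transitions.
print(ppo_loss(neglogpac=np.array([1.0, 0.9, 1.2]),
               neglogpac_old=np.array([1.1, 0.9, 1.0]),
               advs=np.array([0.5, -0.2, 0.1]),
               values=np.array([1.0, 0.8, 0.9]),
               values_prev=np.array([0.9, 0.8, 1.1]),
               returns=np.array([1.2, 0.7, 1.0])))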

baselines/ppo2/policies.py (new file, 188 lines)

@@ -0,0 +1,188 @@
+import gym
+import numpy as np
+import tensorflow as tf
+
+from baselines.a2c.utils import fc
+from baselines.common import tf_util
+from baselines.common.distributions import make_pdtype
+from baselines.common.input import observation_placeholder, encode_observation
+from baselines.common.models import get_network_builder
+from baselines.common.tf_util import adjust_shape
+from baselines.ppo2.layers import RNN
+
+
+class PolicyWithValue(object):
+    """
+    Encapsulates fields and methods for RL policy and two value function estimation with shared parameters
+    """
+
+    def __init__(self, env, observations, latent, dones, states=None, estimate_q=False, vf_latent=None, sess=None):
+        """
+        Parameters:
+        ----------
+        env             RL environment
+        observations    tensorflow placeholder in which the observations will be fed
+        latent          latent state from which policy distribution parameters should be inferred
+        vf_latent       latent state from which value function should be inferred (if None, then latent is used)
+        sess            tensorflow session to run calculations in (if None, default session is used)
+        **tensors       tensorflow tensors for additional attributes such as state or mask
+        """
+        self.X = observations
+        self.dones = dones
+        self.pdtype = make_pdtype(env.action_space)
+        self.states = states
+        self.sess = sess or tf.get_default_session()
+
+        vf_latent = vf_latent if vf_latent is not None else latent
+
+        with tf.variable_scope('policy'):
+            latent = tf.layers.flatten(latent)
+            # Based on the action space, will select what probability distribution type
+            self.pd, self.pi = self.pdtype.pdfromlatent(latent, init_scale=0.01)
+
+            with tf.variable_scope('sample_action'):
+                self.action = self.pd.sample()
+
+            with tf.variable_scope('negative_log_probability'):
+                # Calculate the neg log of our probability
+                self.neglogp = self.pd.neglogp(self.action)
+
+        with tf.variable_scope('value'):
+            vf_latent = tf.layers.flatten(vf_latent)
+
+            if estimate_q:
+                assert isinstance(env.action_space, gym.spaces.Discrete)
+                self.q = fc(vf_latent, 'q', env.action_space.n)
+                self.value = self.q
+            else:
+                vf_latent = tf.layers.flatten(vf_latent)
+                self.value = fc(vf_latent, 'value', 1, init_scale=0.01)
+                self.value = self.value[:, 0]
+
+        self.step_input = {
+            'observations': observations,
+            'dones': self.dones,
+        }
+
+        self.step_output = {
+            'actions': self.action,
+            'values': self.value,
+            'neglogpacs': self.neglogp,
+        }
+
+        if self.states:
+            self.initial_state = np.zeros(self.states['current'].get_shape())
+            self.step_input.update({'states': self.states['current']})
+            self.step_output.update({'states': self.states['current'],
+                                     'next_states': self.states['next']})
+        else:
+            self.initial_state = None
+
+    def feed_dict(self, **kwargs):
+        feed_dict = {}
+        for key in kwargs:
+            if key in self.step_input:
+                feed_dict[self.step_input[key]] = adjust_shape(self.step_input[key], kwargs[key])
+        return feed_dict
+
+    def step(self, **kwargs):
+        return self.sess.run(self.step_output,
+                             feed_dict=self.feed_dict(**kwargs))
+
+    def values(self, **kwargs):
+        return self.sess.run({'values': self.value},
+                             feed_dict=self.feed_dict(**kwargs))
+
+    def save(self, save_path):
+        tf_util.save_state(save_path, sess=self.sess)
+
+    def load(self, load_path):
+        tf_util.load_state(load_path, sess=self.sess)
+
+
+def build_ppo_policy(env, policy_network, value_network=None, estimate_q=False, **policy_kwargs):
+    if isinstance(policy_network, str):
+        network_type = policy_network
+        policy_network = get_network_builder(network_type)(**policy_kwargs)
+
+    if value_network is None:
+        value_network = 'shared'
+
+    def policy_fn(nbatch=None, nsteps=None, sess=None, observ_placeholder=None):
+        next_states_list = []
+        state_map = {}
+        state_placeholder = None
+
+        ob_space = env.observation_space
+        X = observ_placeholder if observ_placeholder is not None else observation_placeholder(ob_space,
+                                                                                              batch_size=nbatch)
+        dones = tf.placeholder(tf.float32, shape=[X.shape[0]], name='dones')
+        encoded_x = encode_observation(ob_space, X)
+
+        with tf.variable_scope('current_rnn_memory'):
+            if value_network == 'shared':
+                value_network_ = value_network
+            else:
+                if value_network == 'copy':
+                    value_network_ = policy_network
+                else:
+                    assert callable(value_network)
+                    value_network_ = value_network
+
+            policy_memory_size = policy_network.memory_size if isinstance(policy_network, RNN) else 0
+            value_memory_size = value_network_.memory_size if isinstance(value_network_, RNN) else 0
+            state_size = policy_memory_size + value_memory_size
+
+            if state_size > 0:
+                state_placeholder = tf.placeholder(dtype=tf.float32, shape=(nbatch, state_size),
+                                                   name='states')
+
+                state_map['policy'] = state_placeholder[:, 0:policy_memory_size]
+                state_map['value'] = state_placeholder[:, policy_memory_size:]
+
+        with tf.variable_scope('policy_latent', reuse=tf.AUTO_REUSE):
+            if isinstance(policy_network, RNN):
+                assert policy_memory_size > 0
+                policy_latent, next_policy_state = \
+                    policy_network(encoded_x, dones, state_map['policy'])
+                next_states_list.append(next_policy_state)
+            else:
+                policy_latent = policy_network(encoded_x)
+
+        with tf.variable_scope('value_latent', reuse=tf.AUTO_REUSE):
+            if value_network_ == 'shared':
+                value_latent = policy_latent
+            elif isinstance(value_network_, RNN):
+                assert value_memory_size > 0
+                value_latent, next_value_state = \
+                    value_network_(encoded_x, dones, state_map['value'])
+                next_states_list.append(next_value_state)
+            else:
+                value_latent = value_network_(encoded_x)
+
+        with tf.name_scope("next_rnn_memory"):
+            if state_size > 0:
+                next_states = tf.concat(next_states_list, axis=1)
+                state_info = {'current': state_placeholder,
+                              'next': next_states, }
+            else:
+                state_info = None
+
+        policy = PolicyWithValue(
+            env=env,
+            observations=X,
+            dones=dones,
+            latent=policy_latent,
+            vf_latent=value_latent,
+            states=state_info,
+            sess=sess,
+            estimate_q=estimate_q,
+        )
+        return policy
+
+    return policy_fn
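The dict-based interface is what lets the runner pass whole transition dictionaries around: `feed_dict` only picks out keys the graph actually declares in `step_input` and silently ignores the rest. A toy illustration of that filtering, with made-up placeholder names and no TensorFlow dependency:

import numpy as np

class DictStepPolicy:
    # Toy stand-in for PolicyWithValue.feed_dict/step.
    def __init__(self):
        self.step_input = {'observations': 'obs_placeholder', 'dones': 'dones_placeholder'}

    def feed_dict(self, **kwargs):
        # Keep only the keys the model declares as inputs.
        return {self.step_input[key]: value
                for key, value in kwargs.items() if key in self.step_input}

policy = DictStepPolicy()
feed = policy.feed_dict(observations=np.zeros((2, 4)),
                        dones=np.zeros(2),
                        rewards=np.ones(2))   # 'rewards' is dropped: not a model input
print(sorted(feed))                           # ['dones_placeholder', 'obs_placeholder']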

@@ -1,28 +1,35 @@
 import os
-import time
-import numpy as np
 import os.path as osp
-from baselines import logger
+import time
 from collections import deque
-from baselines.common import explained_variance, set_global_seeds
-from baselines.common.policies import build_policy
+
+import numpy as np
+import tensorflow as tf
+
+from baselines import logger
+from baselines.common import explained_variance
+from baselines.common import set_global_seeds
+from baselines.common.tf_util import display_var_info
+from baselines.ppo2.policies import build_ppo_policy
+from baselines.ppo2.runner import Runner
 
 try:
     from mpi4py import MPI
 except ImportError:
     MPI = None
-from baselines.ppo2.runner import Runner
 
 
 def constfn(val):
     def f(_):
         return val
     return f
 
-def learn(*, network, env, total_timesteps, eval_env = None, seed=None, nsteps=2048, ent_coef=0.0, lr=3e-4,
+
+def learn(*, network, env, total_timesteps, eval_env=None, seed=None, nsteps=128, ent_coef=0.0, lr=3e-4,
           vf_coef=0.5, max_grad_norm=0.5, gamma=0.99, lam=0.95,
           log_interval=10, nminibatches=4, noptepochs=4, cliprange=0.2,
-          save_interval=0, load_path=None, model_fn=None, **network_kwargs):
-    '''
+          save_interval=10, load_path=None, model_fn=None, **network_kwargs):
+    """
     Learn policy using PPO algorithm (https://arxiv.org/abs/1707.06347)
 
     Parameters:
@@ -52,7 +59,7 @@ def learn(*, network, env, total_timesteps, eval_env = None, seed=None, nsteps=2
     max_grad_norm: float or None  gradient norm clipping coefficient
 
-    gamma: float                  discounting factor
+    gamma: float                  discounting factor for rewards
 
     lam: float                    advantage estimation discounting factor (lambda in the paper)
@@ -72,20 +79,21 @@ def learn(*, network, env, total_timesteps, eval_env = None, seed=None, nsteps=2
     **network_kwargs: keyword arguments to the policy / network builder. See baselines.common/policies.py/build_policy and arguments to a particular type of network
                                   For instance, 'mlp' network architecture has arguments num_hidden and num_layers.
-    '''
+    """
     set_global_seeds(seed)
 
-    if isinstance(lr, float): lr = constfn(lr)
-    else: assert callable(lr)
-    if isinstance(cliprange, float): cliprange = constfn(cliprange)
-    else: assert callable(cliprange)
+    if isinstance(lr, float):
+        lr = constfn(lr)
+    else:
+        assert callable(lr)
+
+    if isinstance(cliprange, float):
+        cliprange = constfn(cliprange)
+    else:
+        assert callable(cliprange)
+
     total_timesteps = int(total_timesteps)
 
-    policy = build_policy(env, network, **network_kwargs)
+    policy = build_ppo_policy(env, network, **network_kwargs)
 
     # Get the nb of env
     nenvs = env.num_envs
@@ -104,15 +112,19 @@ def learn(*, network, env, total_timesteps, eval_env = None, seed=None, nsteps=2
         model_fn = Model
 
     model = model_fn(policy=policy, ob_space=ob_space, ac_space=ac_space, nbatch_act=nenvs, nbatch_train=nbatch_train,
-                    nsteps=nsteps, ent_coef=ent_coef, vf_coef=vf_coef,
-                    max_grad_norm=max_grad_norm)
+                     nsteps=nsteps, ent_coef=ent_coef, vf_coef=vf_coef, max_grad_norm=max_grad_norm)
 
     if load_path is not None:
         model.load(load_path)
 
+    allvars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=model.name)
+    display_var_info(allvars)
+
     # Instantiate the runner object
-    runner = Runner(env=env, model=model, nsteps=nsteps, gamma=gamma, lam=lam)
+    runner = Runner(env=env, model=model, nsteps=nsteps, gamma=gamma, ob_space=ob_space, lam=lam)
+
     if eval_env is not None:
-        eval_runner = Runner(env = eval_env, model = model, nsteps = nsteps, gamma = gamma, lam= lam)
+        eval_runner = Runner(env=eval_env, model=model, nsteps=nsteps, gamma=gamma, ob_space=ob_space, lam=lam)
 
     epinfobuf = deque(maxlen=100)
     if eval_env is not None:
@@ -120,8 +132,8 @@ def learn(*, network, env, total_timesteps, eval_env = None, seed=None, nsteps=2
     # Start total timer
     tfirststart = time.perf_counter()
 
     nupdates = total_timesteps // nbatch
 
     for update in range(1, nupdates + 1):
         assert nbatch % nminibatches == 0
         # Start timer
@@ -131,18 +143,28 @@ def learn(*, network, env, total_timesteps, eval_env = None, seed=None, nsteps=2
         lrnow = lr(frac)
         # Calculate the cliprange
         cliprangenow = cliprange(frac)
-        # Get minibatch
-        obs, returns, masks, actions, values, neglogpacs, states, epinfos = runner.run() #pylint: disable=E0632
-        if eval_env is not None:
-            eval_obs, eval_returns, eval_masks, eval_actions, eval_values, eval_neglogpacs, eval_states, eval_epinfos = eval_runner.run() #pylint: disable=E0632
 
-        epinfobuf.extend(epinfos)
+        # Get minibatch
+        minibatch = runner.run()
+
+        if eval_env is not None:
+            eval_minibatch = eval_runner.run()
+            _eval_obs = eval_minibatch['observations']  # noqa: F841
+            _eval_returns = eval_minibatch['returns']  # noqa: F841
+            _eval_masks = eval_minibatch['masks']  # noqa: F841
+            _eval_actions = eval_minibatch['actions']  # noqa: F841
+            _eval_values = eval_minibatch['values']  # noqa: F841
+            _eval_neglogpacs = eval_minibatch['neglogpacs']  # noqa: F841
+            _eval_states = eval_minibatch['state']  # noqa: F841
+            eval_epinfos = eval_minibatch['epinfos']
+
+        epinfobuf.extend(minibatch.pop('epinfos'))
         if eval_env is not None:
             eval_epinfobuf.extend(eval_epinfos)
 
         # Here what we're going to do is for each minibatch calculate the loss and append it.
         mblossvals = []
-        if states is None: # nonrecurrent version
         # Index of each element of batch_size
         # Create the indices array
         inds = np.arange(nbatch)
@@ -153,22 +175,8 @@ def learn(*, network, env, total_timesteps, eval_env = None, seed=None, nsteps=2
             for start in range(0, nbatch, nbatch_train):
                 end = start + nbatch_train
                 mbinds = inds[start:end]
-                    slices = (arr[mbinds] for arr in (obs, returns, masks, actions, values, neglogpacs))
-                    mblossvals.append(model.train(lrnow, cliprangenow, *slices))
-        else: # recurrent version
-            assert nenvs % nminibatches == 0
-            envsperbatch = nenvs // nminibatches
-            envinds = np.arange(nenvs)
-            flatinds = np.arange(nenvs * nsteps).reshape(nenvs, nsteps)
-            for _ in range(noptepochs):
-                np.random.shuffle(envinds)
-                for start in range(0, nenvs, envsperbatch):
-                    end = start + envsperbatch
-                    mbenvinds = envinds[start:end]
-                    mbflatinds = flatinds[mbenvinds].ravel()
-                    slices = (arr[mbflatinds] for arr in (obs, returns, masks, actions, values, neglogpacs))
-                    mbstates = states[mbenvinds]
-                    mblossvals.append(model.train(lrnow, cliprangenow, *slices, mbstates))
+                slices = {key: minibatch[key][mbinds] for key in minibatch}
+                mblossvals.append(model.train(lrnow, cliprangenow, **slices))
 
         # Feedforward --> get losses --> update
         lossvals = np.mean(mblossvals, axis=0)
@@ -179,7 +187,7 @@ def learn(*, network, env, total_timesteps, eval_env = None, seed=None, nsteps=2
         if update % log_interval == 0 or update == 1:
             # Calculates if value function is a good predicator of the returns (ev > 1)
             # or if it's just worse than predicting nothing (ev =< 0)
-            ev = explained_variance(values, returns)
+            ev = explained_variance(minibatch['values'], minibatch['returns'])
             logger.logkv("serial_timesteps", update * nsteps)
             logger.logkv("nupdates", update)
             logger.logkv("total_timesteps", update * nbatch)
@@ -187,6 +195,9 @@ def learn(*, network, env, total_timesteps, eval_env = None, seed=None, nsteps=2
             logger.logkv("explained_variance", float(ev))
             logger.logkv('eprewmean', safemean([epinfo['r'] for epinfo in epinfobuf]))
             logger.logkv('eplenmean', safemean([epinfo['l'] for epinfo in epinfobuf]))
+            logger.logkv('rewards_per_step', safemean(minibatch['rewards']))
+            logger.logkv('advantages_per_step', safemean(minibatch['advs']))
+
             if eval_env is not None:
                 logger.logkv('eval_eprewmean', safemean([epinfo['r'] for epinfo in eval_epinfobuf]))
                 logger.logkv('eval_eplenmean', safemean([epinfo['l'] for epinfo in eval_epinfobuf]))
@@ -195,16 +206,17 @@ def learn(*, network, env, total_timesteps, eval_env = None, seed=None, nsteps=2
                 logger.logkv(lossname, lossval)
             if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
                 logger.dumpkvs()
-        if save_interval and (update % save_interval == 0 or update == 1) and logger.get_dir() and (MPI is None or MPI.COMM_WORLD.Get_rank() == 0):
+
+        if save_interval and (update % save_interval == 0 or update == 1) and logger.get_dir() and (
+                MPI is None or MPI.COMM_WORLD.Get_rank() == 0):
             checkdir = osp.join(logger.get_dir(), 'checkpoints')
             os.makedirs(checkdir, exist_ok=True)
             savepath = osp.join(checkdir, '%.5i' % update)
             print('Saving to', savepath)
             model.save(savepath)
+
+    del minibatch
     return model
 
 
 # Avoid division error when calculate the mean (in our case if epinfo is empty returns np.nan, not return an error)
 def safemean(xs):
     return np.nan if len(xs) == 0 else np.mean(xs)

Three binary image files added (not shown); sizes: 177 KiB, 100 KiB, 92 KiB.

@@ -1,6 +1,8 @@
 import numpy as np
+
 from baselines.common.runners import AbstractEnvRunner
 
+
 class Runner(AbstractEnvRunner):
     """
     We use this object to make a mini batch of experiences
@@ -10,67 +12,118 @@ class Runner(AbstractEnvRunner):
     run():
     - Make a mini batch
     """
-    def __init__(self, *, env, model, nsteps, gamma, lam):
+
+    def __init__(self, *, env, model, nsteps, gamma, ob_space, lam):
         super().__init__(env=env, model=model, nsteps=nsteps)
-        # Lambda used in GAE (General Advantage Estimation)
-        self.lam = lam
-        # Discount rate
-        self.gamma = gamma
+        self.lam = lam  # Lambda used in GAE (General Advantage Estimation)
+        self.gamma = gamma  # Discount rate for rewards
+        self.ob_space = ob_space
 
     def run(self):
-        # Here, we init the lists that will contain the mb of experiences
-        mb_obs, mb_rewards, mb_actions, mb_values, mb_dones, mb_neglogpacs = [],[],[],[],[],[]
-        mb_states = self.states
+        minibatch = {
+            "observations": [],
+            "actions": [],
+            "rewards": [],
+            "values": [],
+            "dones": [],
+            "neglogpacs": [],
+        }
+
+        data_type = {
+            "observations": self.obs.dtype,
+            "actions": np.float32,
+            "rewards": np.float32,
+            "values": np.float32,
+            "dones": np.float32,
+            "neglogpacs": np.float32,
+        }
+
+        prev_transition = {'next_states': self.model.initial_state} if self.model.initial_state is not None else {}
        epinfos = []
+
        # For n in range number of steps
        for _ in range(self.nsteps):
-            # Given observations, get action value and neglopacs
-            # We already have self.obs because Runner superclass run self.obs[:] = env.reset() on init
-            actions, values, self.states, neglogpacs = self.model.step(self.obs, S=self.states, M=self.dones)
-            mb_obs.append(self.obs.copy())
-            mb_actions.append(actions)
-            mb_values.append(values)
-            mb_neglogpacs.append(neglogpacs)
-            mb_dones.append(self.dones)
+            transitions = {}
+            transitions['observations'] = self.obs.copy()
+            transitions['dones'] = self.dones
+            if 'next_states' in prev_transition:
+                transitions['states'] = prev_transition['next_states']
+            transitions.update(self.model.step_with_dict(**transitions))
 
            # Take actions in env and look the results
            # Infos contains a ton of useful informations
-            self.obs[:], rewards, self.dones, infos = self.env.step(actions)
+            self.obs, transitions['rewards'], self.dones, infos = self.env.step(transitions['actions'])
+            self.dones = np.array(self.dones, dtype=np.float)
            for info in infos:
                maybeepinfo = info.get('episode')
-                if maybeepinfo: epinfos.append(maybeepinfo)
-            mb_rewards.append(rewards)
-        #batch of steps to batch of rollouts
-        mb_obs = np.asarray(mb_obs, dtype=self.obs.dtype)
-        mb_rewards = np.asarray(mb_rewards, dtype=np.float32)
-        mb_actions = np.asarray(mb_actions)
-        mb_values = np.asarray(mb_values, dtype=np.float32)
-        mb_neglogpacs = np.asarray(mb_neglogpacs, dtype=np.float32)
-        mb_dones = np.asarray(mb_dones, dtype=np.bool)
-        last_values = self.model.value(self.obs, S=self.states, M=self.dones)
+                if maybeepinfo:
+                    epinfos.append(maybeepinfo)
 
-        # discount/bootstrap off value fn
-        mb_returns = np.zeros_like(mb_rewards)
-        mb_advs = np.zeros_like(mb_rewards)
-        lastgaelam = 0
-        for t in reversed(range(self.nsteps)):
-            if t == self.nsteps - 1:
-                nextnonterminal = 1.0 - self.dones
-                nextvalues = last_values
-            else:
-                nextnonterminal = 1.0 - mb_dones[t+1]
-                nextvalues = mb_values[t+1]
-            delta = mb_rewards[t] + self.gamma * nextvalues * nextnonterminal - mb_values[t]
-            mb_advs[t] = lastgaelam = delta + self.gamma * self.lam * nextnonterminal * lastgaelam
-        mb_returns = mb_advs + mb_values
-        return (*map(sf01, (mb_obs, mb_returns, mb_dones, mb_actions, mb_values, mb_neglogpacs)),
-            mb_states, epinfos)
-# obs, returns, masks, actions, values, neglogpacs, states = runner.run()
+            for key in transitions:
+                if key not in minibatch:
+                    minibatch[key] = []
+                minibatch[key].append(transitions[key])
+            prev_transition = transitions
+
+        for key in minibatch:
+            dtype = data_type[key] if key in data_type else np.float
+            minibatch[key] = np.array(minibatch[key], dtype=dtype)
+
+        transitions['observations'] = self.obs.copy()
+        transitions['dones'] = self.dones
+        if 'states' in transitions:
+            transitions['states'] = transitions.pop('next_states')
+
+        for key in minibatch:
+            dtype = data_type[key] if key in data_type else np.float
+            minibatch[key] = np.asarray(minibatch[key], dtype=dtype)
+
+        last_values = self.model.step_with_dict(**transitions)['values']
+
+        # Calculate returns and advantages.
+        minibatch['advs'], minibatch['returns'] = \
+            self.advantage_and_returns(values=minibatch['values'],
+                                       rewards=minibatch['rewards'],
+                                       dones=minibatch['dones'],
+                                       last_values=last_values,
+                                       last_dones=self.dones,
+                                       gamma=self.gamma)
+
+        for key in minibatch:
+            minibatch[key] = sf01(minibatch[key])
+
+        minibatch['epinfos'] = epinfos
+        return minibatch
+
+    def advantage_and_returns(self, values, rewards, dones, last_values, last_dones, gamma,
+                              use_non_episodic_rewards=False):
+        """
+        calculate Generalized Advantage Estimation (GAE), https://arxiv.org/abs/1506.02438
+        see also Proximal Policy Optimization Algorithms, https://arxiv.org/abs/1707.06347
+        """
+        advantages = np.zeros_like(rewards)
+        lastgaelam = 0  # Lambda used in General Advantage Estimation
+
+        for t in reversed(range(self.nsteps)):
+            if not use_non_episodic_rewards:
+                if t == self.nsteps - 1:
+                    next_non_terminal = 1.0 - last_dones
+                else:
+                    next_non_terminal = 1.0 - dones[t + 1]
+            else:
+                next_non_terminal = 1.0
+            next_value = values[t + 1] if t < self.nsteps - 1 else last_values
+
+            delta = rewards[t] + gamma * next_value * next_non_terminal - values[t]
+            advantages[t] = lastgaelam = delta + gamma * self.lam * next_non_terminal * lastgaelam
+
+        returns = advantages + values
+        return advantages, returns
+
 
 def sf01(arr):
     """
     swap and then flatten axes 0 and 1
     """
     s = arr.shape
     return arr.swapaxes(0, 1).reshape(s[0] * s[1], *s[2:])
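The new `advantage_and_returns` method is a plain GAE(lambda) recursion. A self-contained numpy version of the same recursion can serve as a sanity check; parameter names follow the diff, while the toy data at the end is invented:

import numpy as np

def gae(rewards, values, dones, last_values, last_dones, gamma=0.99, lam=0.95):
    # Arrays are (nsteps, nenvs); dones[t] is the done flag observed at the start of step t.
    nsteps = rewards.shape[0]
    advantages = np.zeros_like(rewards)
    lastgaelam = 0.0
    for t in reversed(range(nsteps)):
        next_non_terminal = 1.0 - (last_dones if t == nsteps - 1 else dones[t + 1])
        next_value = last_values if t == nsteps - 1 else values[t + 1]
        delta = rewards[t] + gamma * next_value * next_non_terminal - values[t]
        advantages[t] = lastgaelam = delta + gamma * lam * next_non_terminal * lastgaelam
    return advantages, advantages + values

# Tiny check: constant reward 1, zero value estimates, no episode ends.
advs, rets = gae(rewards=np.ones((4, 1)), values=np.zeros((4, 1)),
                 dones=np.zeros((4, 1)), last_values=np.zeros(1), last_dones=np.zeros(1))
print(rets[:, 0])   # with zero baselines these are the GAE(lambda) returns, roughly [3.66, 2.83, 1.94, 1.0]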