Compare commits
1 commit: master...stateful_r (fc0c43b199)
@@ -17,10 +17,10 @@ learn_kwargs = {
    # 'trpo_mpi': lambda e, p: trpo_mpi.learn(policy_fn=p(env=e), env=e, max_timesteps=30000, timesteps_per_batch=100, cg_iters=10, gamma=0.9, lam=1.0, max_kl=0.001)
}

alg_list = learn_kwargs.keys()
rnn_list = ['lstm']


@pytest.mark.slow
@pytest.mark.parametrize("alg", alg_list)
@pytest.mark.parametrize("rnn", rnn_list)
@@ -33,6 +33,9 @@ def test_fixed_sequence(alg, rnn):
    kwargs = learn_kwargs[alg]
    kwargs.update(common_kwargs)

    if alg == 'ppo2' and rnn.endswith('lstm'):
        rnn = 'ppo_' + rnn

    env_fn = lambda: FixedSequenceEnv(n_actions=10, episode_len=5)
    learn = lambda e: get_learn_function(alg)(
        env=e,
@@ -45,6 +48,3 @@ def test_fixed_sequence(alg, rnn):

if __name__ == '__main__':
    test_fixed_sequence('ppo2', 'lstm')
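These recurrent-policy tests are marked slow, so they only run when slow tests are selected. One way to exercise just the ppo2/lstm case from the repository root (a sketch assuming a standard pytest setup for this branch; the -k expression is only illustrative):

pytest -m slow -k "fixed_sequence and ppo2" baselines/common/tests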
@@ -1,17 +1,16 @@
import os
import gym
import tempfile
import pytest
import tensorflow as tf
import numpy as np

from baselines.common.tests.envs.mnist_env import MnistEnv
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
from baselines.run import get_learn_function
from baselines.common.tf_util import make_session, get_session

from functools import partial

import gym
import numpy as np
import pytest
import tensorflow as tf

from baselines.common.tests.envs.mnist_env import MnistEnv
from baselines.common.tf_util import make_session, get_session
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
from baselines.run import get_learn_function

learn_kwargs = {
    'deepq': {},
@@ -37,12 +36,15 @@ def test_serialization(learn_fn, network_fn):
    Test if the trained model can be serialized
    '''

    _network_kwargs = network_kwargs[network_fn]

    if network_fn.endswith('lstm') and learn_fn in ['acer', 'acktr', 'trpo_mpi', 'deepq']:
        # TODO make acktr work with recurrent policies
        # and test
        # github issue: https://github.com/openai/baselines/issues/660
        return
    elif network_fn.endswith('lstm') and learn_fn == 'ppo2':
        network_fn = 'ppo_' + network_fn

    def make_env():
        env = MnistEnv(episode_len=100)
@@ -54,10 +56,9 @@ def test_serialization(learn_fn, network_fn):
    learn = get_learn_function(learn_fn)

    kwargs = {}
    kwargs.update(network_kwargs[network_fn])
    kwargs.update(_network_kwargs)
    kwargs.update(learn_kwargs[learn_fn])

    learn = partial(learn, env=env, network=network_fn, seed=0, **kwargs)

    with tempfile.TemporaryDirectory() as td:
@@ -76,7 +77,7 @@ def test_serialization(learn_fn, network_fn):

    for k, v in variables_dict1.items():
        np.testing.assert_allclose(v, variables_dict2[k], atol=0.01,
                                   err_msg='saved and loaded variable {} value mismatch'.format(k))

    np.testing.assert_allclose(mean1, mean2, atol=0.5)
    np.testing.assert_allclose(std1, std2, atol=0.5)
@@ -90,15 +91,15 @@ def test_coexistence(learn_fn, network_fn):
    '''

    if learn_fn == 'deepq':
        # TODO enable multiple DQN models to be useable at the same time
        # github issue https://github.com/openai/baselines/issues/656
        return

    if network_fn.endswith('lstm') and learn_fn in ['acktr', 'trpo_mpi', 'deepq']:
        # TODO make acktr work with recurrent policies
        # and test
        # github issue: https://github.com/openai/baselines/issues/660
        return

    env = DummyVecEnv([lambda: gym.make('CartPole-v0')])
    learn = get_learn_function(learn_fn)
@@ -107,7 +108,7 @@ def test_coexistence(learn_fn, network_fn):
    kwargs.update(network_kwargs[network_fn])
    kwargs.update(learn_kwargs[learn_fn])

    learn = partial(learn, env=env, network=network_fn, total_timesteps=0, **kwargs)
    make_session(make_default=True, graph=tf.Graph())
    model1 = learn(seed=1)
    make_session(make_default=True, graph=tf.Graph())
@@ -117,7 +118,6 @@ def test_coexistence(learn_fn, network_fn):
    model2.step(env.observation_space.sample())


def _serialize_variables():
    sess = get_session()
    variables = tf.trainable_variables()
@@ -136,4 +136,3 @@ def _get_action_stats(model, ob):
    std = np.std(actions, axis=0)

    return mean, std
@@ -3,6 +3,16 @@
- Original paper: https://arxiv.org/abs/1707.06347
- Baselines blog post: https://blog.openai.com/openai-baselines-ppo/

## Examples
- `python -m baselines.run --alg=ppo2 --env=PongNoFrameskip-v4` runs the algorithm for 40M frames = 10M timesteps on Atari Pong. See help (`-h`) for more options.
- `python -m baselines.run --alg=ppo2 --env=Ant-v2 --num_timesteps=1e6` runs the algorithm for 1M timesteps on a Mujoco Ant environment.
- Also refer to the repo-wide [README.md](../../README.md#training-models).

### RNN networks
- `python -m baselines.run --alg=ppo2 --env=PongNoFrameskip-v4 --network=ppo_cnn_lstm` runs on Atari Pong with the `ppo_cnn_lstm` network.
- `python -m baselines.run --alg=ppo2 --env=Ant-v2 --num_timesteps=1e6 --network=ppo_lstm --value_network=copy` runs on a Mujoco Ant environment with the `ppo_lstm` network, whose value and policy networks are separate but share the same structure. A programmatic equivalent is sketched below.

## See Also
- Refer to the repo-wide [README.md](../../README.md#training-models).
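The same RNN configuration can also be set up from Python rather than the CLI. A minimal sketch, assuming this branch is installed; the CartPole environment here is only an illustration (any vectorized env works), and the call mirrors the `learn(...)` signature shown later in this diff:

import gym

from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
from baselines.ppo2 import ppo2

# Vectorized environment; a single cheap env keeps the sketch small.
env = DummyVecEnv([lambda: gym.make('CartPole-v0')])

# network='ppo_lstm' resolves to the recurrent network registered in baselines/ppo2/layers.py;
# value_network='copy' gives the value function its own copy of that network.
model = ppo2.learn(network='ppo_lstm', env=env, total_timesteps=10000, value_network='copy')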
@@ -0,0 +1 @@
from baselines.ppo2.layers import ppo_lstm, ppo_cnn_lstm, ppo_cnn_lnlstm  # pylint: disable=unused-import # noqa: F401
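This one-line addition is an unused import kept for its side effect: loading `baselines.ppo2.layers` runs the `@register` decorators, so the new network names can be resolved by string. A quick check of that registration (a sketch, not part of the PR):

from baselines.common.models import get_network_builder
import baselines.ppo2.layers  # noqa: F401 -- importing the module runs the @register decorators

network_fn = get_network_builder('ppo_lstm')(num_units=128)
print(network_fn.memory_size)  # 256: cell state plus hidden state for a 128-unit LSTM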
baselines/ppo2/layers.py (new file, 55 lines)
@@ -0,0 +1,55 @@
import numpy as np
import tensorflow as tf

from baselines.a2c.utils import ortho_init, lstm, lnlstm
from baselines.common.models import register, nature_cnn


class RNN(object):
    def __init__(self, func, memory_size=None):
        self._func = func
        self.memory_size = memory_size

    def __call__(self, *args, **kwargs):
        return self._func(*args, **kwargs)


@register("ppo_lstm")
def ppo_lstm(num_units=128, layer_norm=False):
    def network_fn(input, mask, state):
        input = tf.layers.flatten(input)
        mask = tf.to_float(mask)

        if layer_norm:
            h, next_state = lnlstm([input], [mask[:, None]], state, scope='lnlstm', nh=num_units)
        else:
            h, next_state = lstm([input], [mask[:, None]], state, scope='lstm', nh=num_units)
        h = h[0]
        return h, next_state

    return RNN(network_fn, memory_size=num_units * 2)


@register("ppo_cnn_lstm")
def ppo_cnn_lstm(num_units=128, layer_norm=False, **conv_kwargs):
    def network_fn(input, mask, state):
        mask = tf.to_float(mask)
        initializer = ortho_init(np.sqrt(2))

        h = nature_cnn(input, **conv_kwargs)
        h = tf.layers.flatten(h)
        h = tf.layers.dense(h, units=512, activation=tf.nn.relu, kernel_initializer=initializer)

        if layer_norm:
            h, next_state = lnlstm([h], [mask[:, None]], state, scope='lnlstm', nh=num_units)
        else:
            h, next_state = lstm([h], [mask[:, None]], state, scope='lstm', nh=num_units)
        h = h[0]
        return h, next_state

    return RNN(network_fn, memory_size=num_units * 2)


@register("ppo_cnn_lnlstm")
def ppo_cnn_lnlstm(num_units=128, **conv_kwargs):
    return ppo_cnn_lstm(num_units, layer_norm=True, **conv_kwargs)
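The RNN wrapper plus @register pattern above is reusable for other recurrent policies: the inner function takes (input, mask, state) and returns (latent, next_state), and memory_size tells the policy builder how many state columns the network owns. A sketch of a custom variant, using the same imports as layers.py; the name "ppo_mlp_lstm" and the hidden-layer size are hypothetical, not part of the PR:

@register("ppo_mlp_lstm")  # hypothetical name, shown only to illustrate the registration pattern
def ppo_mlp_lstm(num_units=128, hidden=64):
    def network_fn(input, mask, state):
        h = tf.layers.flatten(input)
        h = tf.layers.dense(h, units=hidden, activation=tf.nn.tanh)
        mask = tf.to_float(mask)
        # lstm() expects lists of per-step tensors plus the flat (cell, hidden) state vector.
        h, next_state = lstm([h], [mask[:, None]], state, scope='lstm', nh=num_units)
        return h[0], next_state

    # memory_size tells build_ppo_policy how many state columns this network needs.
    return RNN(network_fn, memory_size=num_units * 2)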
@@ -1,42 +1,47 @@
import numpy as np
import tensorflow as tf

from baselines.ppo2.model import Model


class MicrobatchedModel(Model):
    """
    Model that does training one microbatch at a time - when gradient computation
    on the entire minibatch causes some overflow
    """

    def __init__(self, *, policy, ob_space, ac_space, nbatch_act, nbatch_train,
                 nsteps, ent_coef, vf_coef, max_grad_norm, microbatch_size):

        self.nmicrobatches = nbatch_train // microbatch_size
        self.microbatch_size = microbatch_size
        assert nbatch_train % microbatch_size == 0, 'microbatch_size ({}) should divide nbatch_train ({}) evenly'.format(
            microbatch_size, nbatch_train)

        super().__init__(
            policy=policy,
            ob_space=ob_space,
            ac_space=ac_space,
            nbatch_act=nbatch_act,
            nbatch_train=microbatch_size,
            nsteps=nsteps,
            ent_coef=ent_coef,
            vf_coef=vf_coef,
            max_grad_norm=max_grad_norm)

        self.grads_ph = [tf.placeholder(dtype=g.dtype, shape=g.shape) for g in self.grads]
        grads_ph_and_vars = list(zip(self.grads_ph, self.var))
        self._apply_gradients_op = self.trainer.apply_gradients(grads_ph_and_vars)

    def train(self, lr, cliprange, obs, returns, masks, actions, values, neglogpacs, states=None):
        assert states is None, "microbatches with recurrent models are not supported yet"

        # Here we calculate advantage A(s,a) = R + yV(s') - V(s)
        # Returns = R + yV(s')
        advs = returns - values

    def train(self,
              lr,
              cliprange,
              observations,
              advs,
              returns,
              actions,
              values,
              neglogpacs,
              **_kwargs):
        # Normalize the advantages
        advs = (advs - advs.mean()) / (advs.std() + 1e-8)

@@ -44,19 +49,24 @@ class MicrobatchedModel(Model):
        stats_vs = []

        for microbatch_idx in range(self.nmicrobatches):
            _sli = range(microbatch_idx * self.microbatch_size, (microbatch_idx + 1) * self.microbatch_size)

            td_map = {
                self.train_model.X: obs[_sli],
                self.A: actions[_sli],
                self.ADV: advs[_sli],
                self.R: returns[_sli],
                self.CLIPRANGE: cliprange,
                self.OLDNEGLOGPAC: neglogpacs[_sli],
                self.OLDVPRED: values[_sli],
                self.train_model.X: observations[_sli],
                self.A: actions[_sli],
                self.ADV: advs[_sli],
                self.RETURNS: returns[_sli],
                self.LR: lr,
                self.CLIPRANGE: cliprange,
                self.OLDNEGLOGPAC: neglogpacs[_sli],
                self.VALUE_PREV: values[_sli],
            }

            sliced_kwargs = {key: _kwargs[key][_sli] for key in _kwargs}
            td_map.update(self.train_model.feed_dict(**sliced_kwargs))

            # Compute gradient on a microbatch (note that variables do not change here) ...
            grad_v, stats_v = self.sess.run([self.grads, self.stats_list], td_map)
            if microbatch_idx == 0:
                sum_grad_v = grad_v
            else:
@@ -71,6 +81,3 @@ class MicrobatchedModel(Model):
        self.sess.run(self._apply_gradients_op, feed_dict)
        # Return average of the stats
        return np.mean(np.array(stats_vs), axis=0).tolist()
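MicrobatchedModel computes gradients microbatch by microbatch, accumulates them, and applies the averaged gradient once through grads_ph / _apply_gradients_op, trading memory for extra sess.run calls. For a loss that is a mean over samples, averaging per-microbatch gradients reproduces the full-minibatch gradient exactly; a small numpy check of that identity (illustrative only, not part of the PR):

import numpy as np

# Toy linear model with squared loss; gradient of the mean loss w.r.t. w.
x = np.random.randn(8, 3)
y = np.random.randn(8)
w = np.zeros(3)

def grad(xb, yb, w):
    return 2 * xb.T.dot(xb.dot(w) - yb) / len(yb)

full = grad(x, y, w)
# Two microbatches of 4: average their gradients before the (single) update.
micro = np.mean([grad(x[i:i + 4], y[i:i + 4], w) for i in (0, 4)], axis=0)
assert np.allclose(full, micro)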
@@ -1,8 +1,8 @@
import functools

import tensorflow as tf

from baselines.common.tf_util import get_session, save_variables, load_variables
from baselines.common.tf_util import initialize

try:
    from baselines.common.mpi_adam_optimizer import MpiAdamOptimizer
@@ -11,6 +11,7 @@ try:
except ImportError:
    MPI = None


class Model(object):
    """
    We use this object to :
@@ -24,133 +25,157 @@ class Model(object):
    save/load():
    - Save load the model
    """

    def __init__(self, *, policy, ob_space, ac_space, nbatch_act, nbatch_train,
                 nsteps, ent_coef, vf_coef, max_grad_norm, microbatch_size=None):
        self.sess = sess = get_session()
                 nsteps, ent_coef, vf_coef, max_grad_norm,
                 name='ppo_model',
                 sess=None,
                 microbatch_size=None):
        if sess is None:
            sess = get_session()
        self.sess = sess
        self.name = name

        with tf.variable_scope('ppo2_model', reuse=tf.AUTO_REUSE):
        with tf.variable_scope(name) as scope:
            self.scope = scope
            with tf.variable_scope('models', reuse=tf.AUTO_REUSE):
                with tf.name_scope('act_model'):
                    # CREATE OUR TWO MODELS
                    # act_model that is used for sampling
                    act_model = policy(nbatch_act, 1, sess)

                with tf.name_scope('train_model'):
                    # Train model for training
                    if microbatch_size is None:
                        train_model = policy(nbatch_train, nsteps, sess)
                    else:
                        train_model = policy(microbatch_size, nsteps, sess)

            # CREATE THE PLACEHOLDERS
            self.A = A = train_model.pdtype.sample_placeholder([None])
            self.ADV = ADV = tf.placeholder(tf.float32, [None])
            self.R = R = tf.placeholder(tf.float32, [None])
            # Keep track of old actor
            self.OLDNEGLOGPAC = OLDNEGLOGPAC = tf.placeholder(tf.float32, [None])
            # Keep track of old critic
            self.OLDVPRED = OLDVPRED = tf.placeholder(tf.float32, [None])
            self.LR = LR = tf.placeholder(tf.float32, [])
            # Cliprange
            self.CLIPRANGE = CLIPRANGE = tf.placeholder(tf.float32, [])
            with tf.variable_scope('losses'):
                # CREATE THE PLACEHOLDERS
                self.A = A = train_model.pdtype.sample_placeholder([None], name='action')
                self.ADV = ADV = tf.placeholder(tf.float32, [None], name='advantage')
                self.RETURNS = RETURNS = tf.placeholder(tf.float32, [None], name='reward')
                self.VALUE_PREV = VALUE_PREV = tf.placeholder(tf.float32, [None], name='value_prev')
                self.OLDNEGLOGPAC = OLDNEGLOGPAC = tf.placeholder(tf.float32, [None],
                                                                  name='negative_log_p_action_old')
                self.CLIPRANGE = CLIPRANGE = tf.placeholder(tf.float32, [], name='clip_range')

                with tf.name_scope("neglogpac"):
                    neglogpac = train_model.pd.neglogp(A)

                with tf.name_scope("entropy"):
                    # Calculate the entropy
                    # Entropy is used to improve exploration by limiting the premature convergence to suboptimal policy.
                    entropy = tf.reduce_mean(train_model.pd.entropy())
                    entropy_loss = (- ent_coef) * entropy

                # CALCULATE THE LOSS
                # Total loss = Policy gradient loss - entropy * entropy coefficient + Value coefficient * value loss
                with tf.name_scope("value_loss"):
                    value = train_model.value
                    value_clipped = VALUE_PREV + tf.clip_by_value(value - VALUE_PREV, -CLIPRANGE, CLIPRANGE)
                    vf_losses1 = tf.squared_difference(value, RETURNS)
                    vf_losses2 = tf.squared_difference(value_clipped, RETURNS)
                    vf_loss = 0.5 * vf_coef * tf.reduce_mean(tf.maximum(vf_losses1, vf_losses2))

                # Clip the value to reduce variability during Critic training
                # Get the predicted value
                vpred = train_model.vf
                vpredclipped = OLDVPRED + tf.clip_by_value(train_model.vf - OLDVPRED, - CLIPRANGE, CLIPRANGE)
                # Unclipped value
                vf_losses1 = tf.square(vpred - R)
                # Clipped value
                vf_losses2 = tf.square(vpredclipped - R)
                vf_loss = .5 * tf.reduce_mean(tf.maximum(vf_losses1, vf_losses2))

                with tf.name_scope("policy_loss"):
                    # Calculate ratio (pi current policy / pi old policy)
                    ratio = tf.exp(OLDNEGLOGPAC - neglogpac)
                    pg_losses = -ADV * ratio
                    pg_losses2 = -ADV * tf.clip_by_value(ratio, 1.0 - CLIPRANGE, 1.0 + CLIPRANGE)
                    # Defining Loss = - J is equivalent to max J
                    pg_loss = tf.reduce_mean(tf.maximum(pg_losses, pg_losses2))

                with tf.name_scope("approxkl"):
                    approxkl = .5 * tf.reduce_mean(tf.squared_difference(neglogpac, OLDNEGLOGPAC))

                with tf.name_scope("clip_fraction"):
                    clipfrac = tf.reduce_mean(tf.to_float(tf.greater(tf.abs(ratio - 1.0), CLIPRANGE)))

                with tf.name_scope("total_loss"):
                    loss = pg_loss + entropy_loss + vf_loss

                # Total loss
                loss = pg_loss - entropy * ent_coef + vf_loss * vf_coef

            with tf.variable_scope('optimizer'):
                self.LR = LR = tf.placeholder(tf.float32, [], name='learning_rate')

                # UPDATE THE PARAMETERS USING LOSS
                # 1. Get the model parameters
                params = tf.trainable_variables(self.scope.name)
                params = tf.trainable_variables('ppo2_model')
                # 2. Build our trainer
                if MPI is not None:
                    self.trainer = MpiAdamOptimizer(MPI.COMM_WORLD, learning_rate=LR, epsilon=1e-5)
                else:
                    self.trainer = tf.train.AdamOptimizer(learning_rate=LR, epsilon=1e-5)
                # 3. Calculate the gradients
                grads_and_var = self.trainer.compute_gradients(loss, params)
                grads, var = zip(*grads_and_var)
                if max_grad_norm is not None:
                    # Clip the gradients (normalize)
                    grads, _grad_norm = tf.clip_by_global_norm(grads, max_grad_norm)
                grads_and_var = list(zip(grads, var))
                # zip aggregate each gradient with parameters associated
                # For instance zip(ABCD, xyza) => Ax, By, Cz, Da
                self.grads = grads
                self.var = var
                self._train_op = self.trainer.apply_gradients(grads_and_var)

            self.loss_names = ['policy_loss', 'value_loss', 'policy_entropy', 'approxkl', 'clipfrac']
            self.stats_list = [pg_loss, vf_loss, entropy, approxkl, clipfrac]
            self.loss_names = ['policy_loss', 'value_loss', 'entropy_loss', 'approxkl', 'clipfrac',
                               'total_loss']
            self.stats_list = [pg_loss, vf_loss, entropy_loss, approxkl, clipfrac, loss]

            self.train_model = train_model
            self.act_model = act_model
            self.step = act_model.step
            self.value = act_model.value
            self.initial_state = act_model.initial_state
            self.save = functools.partial(save_variables, sess=sess)
            self.load = functools.partial(load_variables, sess=sess)

            with tf.variable_scope('initialization'):
                sess.run(tf.initializers.variables(tf.global_variables(self.scope.name)))
                sess.run(tf.initializers.variables(tf.local_variables(self.scope.name)))
                global_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=self.scope.name)
                if MPI is not None:
                    sync_from_root(sess, global_variables)  # pylint: disable=E1101

        initialize()
        global_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="")
        if MPI is not None:
            sync_from_root(sess, global_variables)  # pylint: disable=E1101

    def step_with_dict(self, **kwargs):
        return self.act_model.step(**kwargs)

    def train(self, lr, cliprange, obs, returns, masks, actions, values, neglogpacs, states=None):
        # Here we calculate advantage A(s,a) = R + yV(s') - V(s)
        # Returns = R + yV(s')
        advs = returns - values

    def step(self, obs, M=None, S=None, **kwargs):
        kwargs.update({'observations': obs})
        if M is not None and S is not None:
            kwargs.update({'dones': M})
            kwargs.update({'states': S})
        transition = self.act_model.step(**kwargs)
        states = transition['next_states'] if 'next_states' in transition else None
        return transition['actions'], transition['values'], states, transition['neglogpacs']

    def train(self,
              lr,
              cliprange,
              observations,
              advs,
              returns,
              actions,
              values,
              neglogpacs,
              **_kwargs):
        # Normalize the advantages
        advs = (advs - advs.mean()) / (advs.std() + 1e-8)

        td_map = {
            self.train_model.X: obs,
            self.A: actions,
            self.ADV: advs,
            self.R: returns,
            self.LR: lr,
            self.CLIPRANGE: cliprange,
            self.OLDNEGLOGPAC: neglogpacs,
            self.OLDVPRED: values,
            self.train_model.X: observations,
            self.A: actions,
            self.ADV: advs,
            self.RETURNS: returns,
            self.LR: lr,
            self.CLIPRANGE: cliprange,
            self.OLDNEGLOGPAC: neglogpacs,
            self.VALUE_PREV: values,
        }
        if states is not None:
            td_map[self.train_model.S] = states
            td_map[self.train_model.M] = masks

        td_map.update(self.train_model.feed_dict(**_kwargs))

        return self.sess.run(
            self.stats_list + [self._train_op],
            td_map
        )[:-1]
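The losses/* scopes implement the standard PPO clipped surrogate: per sample, the objective is min(ratio * A, clip(ratio, 1 - eps, 1 + eps) * A), written in the graph as a maximum of the negated terms pg_losses / pg_losses2. A one-transition numeric check of the clipping, with illustrative numbers only:

import numpy as np

# One transition: new policy is exp(0.5) ~ 1.65x more likely to take the action than the old one.
ratio = np.exp(0.5)
adv, cliprange = 1.0, 0.2
pg1 = -adv * ratio
pg2 = -adv * np.clip(ratio, 1.0 - cliprange, 1.0 + cliprange)
pg_loss = max(pg1, pg2)  # = -1.2: the improvement is credited only up to ratio 1.2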
baselines/ppo2/policies.py (new file, 188 lines)
@@ -0,0 +1,188 @@
import gym
import numpy as np
import tensorflow as tf

from baselines.a2c.utils import fc
from baselines.common import tf_util
from baselines.common.distributions import make_pdtype
from baselines.common.input import observation_placeholder, encode_observation
from baselines.common.models import get_network_builder
from baselines.common.tf_util import adjust_shape
from baselines.ppo2.layers import RNN


class PolicyWithValue(object):
    """
    Encapsulates fields and methods for RL policy and two value function estimation with shared parameters
    """

    def __init__(self, env, observations, latent, dones, states=None, estimate_q=False, vf_latent=None, sess=None):
        """
        Parameters:
        ----------
        env             RL environment

        observations    tensorflow placeholder in which the observations will be fed

        latent          latent state from which policy distribution parameters should be inferred

        vf_latent       latent state from which value function should be inferred (if None, then latent is used)

        sess            tensorflow session to run calculations in (if None, default session is used)

        **tensors       tensorflow tensors for additional attributes such as state or mask
        """
        self.X = observations
        self.dones = dones
        self.pdtype = make_pdtype(env.action_space)
        self.states = states
        self.sess = sess or tf.get_default_session()

        vf_latent = vf_latent if vf_latent is not None else latent

        with tf.variable_scope('policy'):
            latent = tf.layers.flatten(latent)
            # Based on the action space, will select what probability distribution type
            self.pd, self.pi = self.pdtype.pdfromlatent(latent, init_scale=0.01)

            with tf.variable_scope('sample_action'):
                self.action = self.pd.sample()

            with tf.variable_scope('negative_log_probability'):
                # Calculate the neg log of our probability
                self.neglogp = self.pd.neglogp(self.action)

        with tf.variable_scope('value'):
            vf_latent = tf.layers.flatten(vf_latent)

            if estimate_q:
                assert isinstance(env.action_space, gym.spaces.Discrete)
                self.q = fc(vf_latent, 'q', env.action_space.n)
                self.value = self.q
            else:
                vf_latent = tf.layers.flatten(vf_latent)
                self.value = fc(vf_latent, 'value', 1, init_scale=0.01)
                self.value = self.value[:, 0]

        self.step_input = {
            'observations': observations,
            'dones': self.dones,
        }

        self.step_output = {
            'actions': self.action,
            'values': self.value,
            'neglogpacs': self.neglogp,
        }
        if self.states:
            self.initial_state = np.zeros(self.states['current'].get_shape())
            self.step_input.update({'states': self.states['current']})
            self.step_output.update({'states': self.states['current'],
                                     'next_states': self.states['next']})
        else:
            self.initial_state = None

    def feed_dict(self, **kwargs):
        feed_dict = {}
        for key in kwargs:
            if key in self.step_input:
                feed_dict[self.step_input[key]] = adjust_shape(self.step_input[key], kwargs[key])
        return feed_dict

    def step(self, **kwargs):
        return self.sess.run(self.step_output,
                             feed_dict=self.feed_dict(**kwargs))

    def values(self, **kwargs):
        return self.sess.run({'values': self.value},
                             feed_dict=self.feed_dict(**kwargs))

    def save(self, save_path):
        tf_util.save_state(save_path, sess=self.sess)

    def load(self, load_path):
        tf_util.load_state(load_path, sess=self.sess)


def build_ppo_policy(env, policy_network, value_network=None, estimate_q=False, **policy_kwargs):
    if isinstance(policy_network, str):
        network_type = policy_network
        policy_network = get_network_builder(network_type)(**policy_kwargs)

    if value_network is None:
        value_network = 'shared'

    def policy_fn(nbatch=None, nsteps=None, sess=None, observ_placeholder=None):
        next_states_list = []
        state_map = {}
        state_placeholder = None

        ob_space = env.observation_space
        X = observ_placeholder if observ_placeholder is not None else observation_placeholder(ob_space,
                                                                                              batch_size=nbatch)
        dones = tf.placeholder(tf.float32, shape=[X.shape[0]], name='dones')
        encoded_x = encode_observation(ob_space, X)

        with tf.variable_scope('current_rnn_memory'):
            if value_network == 'shared':
                value_network_ = value_network
            else:
                if value_network == 'copy':
                    value_network_ = policy_network
                else:
                    assert callable(value_network)
                    value_network_ = value_network

            policy_memory_size = policy_network.memory_size if isinstance(policy_network, RNN) else 0
            value_memory_size = value_network_.memory_size if isinstance(value_network_, RNN) else 0
            state_size = policy_memory_size + value_memory_size

            if state_size > 0:
                state_placeholder = tf.placeholder(dtype=tf.float32, shape=(nbatch, state_size),
                                                   name='states')

                state_map['policy'] = state_placeholder[:, 0:policy_memory_size]
                state_map['value'] = state_placeholder[:, policy_memory_size:]

        with tf.variable_scope('policy_latent', reuse=tf.AUTO_REUSE):
            if isinstance(policy_network, RNN):
                assert policy_memory_size > 0
                policy_latent, next_policy_state = \
                    policy_network(encoded_x, dones, state_map['policy'])
                next_states_list.append(next_policy_state)
            else:
                policy_latent = policy_network(encoded_x)

        with tf.variable_scope('value_latent', reuse=tf.AUTO_REUSE):
            if value_network_ == 'shared':
                value_latent = policy_latent
            elif isinstance(value_network_, RNN):
                assert value_memory_size > 0
                value_latent, next_value_state = \
                    value_network_(encoded_x, dones, state_map['value'])
                next_states_list.append(next_value_state)
            else:
                value_latent = value_network_(encoded_x)

        with tf.name_scope("next_rnn_memory"):
            if state_size > 0:
                next_states = tf.concat(next_states_list, axis=1)
                state_info = {'current': state_placeholder,
                              'next': next_states, }
            else:
                state_info = None

        policy = PolicyWithValue(
            env=env,
            observations=X,
            dones=dones,
            latent=policy_latent,
            vf_latent=value_latent,
            states=state_info,
            sess=sess,
            estimate_q=estimate_q,
        )
        return policy

    return policy_fn
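build_ppo_policy concatenates the policy RNN memory and, for value_network='copy', the value RNN memory into one flat states vector and splits it back by column inside the graph. A sketch of how a caller sees this, untested against the branch; CartPole and the batch sizes are only illustrative, and the shapes assume the default num_units=128 (256 state columns per LSTM):

import gym
import numpy as np
import tensorflow as tf

from baselines.common.tf_util import make_session
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
from baselines.ppo2.layers import ppo_lstm
from baselines.ppo2.policies import build_ppo_policy

env = DummyVecEnv([lambda: gym.make('CartPole-v0')])
sess = make_session(make_default=True, graph=tf.Graph())

# Separate value network: the state vector is 256 (policy LSTM) + 256 (value LSTM) columns.
policy_fn = build_ppo_policy(env, ppo_lstm(num_units=128), value_network='copy')
policy = policy_fn(nbatch=1, nsteps=1, sess=sess)
sess.run(tf.global_variables_initializer())

out = policy.step(observations=env.reset(),
                  dones=np.zeros(1, dtype=np.float32),
                  states=policy.initial_state)
print(out['actions'], out['next_states'].shape)  # next_states should have shape (1, 512)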
@@ -1,28 +1,35 @@
import os
import time
import os.path as osp
from collections import deque
from baselines.common.policies import build_policy

import numpy as np
import tensorflow as tf
from baselines import logger
from baselines.common import explained_variance
from baselines.common import set_global_seeds
from baselines.common.tf_util import display_var_info
from baselines.ppo2.policies import build_ppo_policy
from baselines.ppo2.runner import Runner

try:
    from mpi4py import MPI
except ImportError:
    MPI = None


def constfn(val):
    def f(_):
        return val

    return f


def learn(*, network, env, total_timesteps, eval_env=None, seed=None, nsteps=2048, ent_coef=0.0, lr=3e-4,
          vf_coef=0.5, max_grad_norm=0.5, gamma=0.99, lam=0.95,
          log_interval=10, nminibatches=4, noptepochs=4, cliprange=0.2,
          save_interval=0, load_path=None, model_fn=None, **network_kwargs):

def learn(*, network, env, total_timesteps, eval_env=None, seed=None, nsteps=128, ent_coef=0.0, lr=3e-4,
          vf_coef=0.5, max_grad_norm=0.5, gamma=0.99, lam=0.95,
          log_interval=10, nminibatches=4, noptepochs=4, cliprange=0.2,
          save_interval=10, load_path=None, model_fn=None, **network_kwargs):
    """
    Learn policy using PPO algorithm (https://arxiv.org/abs/1707.06347)

    Parameters:
@@ -52,7 +59,7 @@ def learn(*, network, env, total_timesteps, eval_env = None, seed=None, nsteps=2

    max_grad_norm: float or None      gradient norm clipping coefficient

    gamma: float                      discounting factor for rewards

    lam: float                        advantage estimation discounting factor (lambda in the paper)

@@ -72,20 +79,21 @@ def learn(*, network, env, total_timesteps, eval_env = None, seed=None, nsteps=2

    **network_kwargs: keyword arguments to the policy / network builder. See baselines.common/policies.py/build_policy and arguments to a particular type of network
                      For instance, 'mlp' network architecture has arguments num_hidden and num_layers.
    """

    set_global_seeds(seed)

    if isinstance(lr, float):
        lr = constfn(lr)
    else:
        assert callable(lr)
    if isinstance(cliprange, float):
        cliprange = constfn(cliprange)
    else:
        assert callable(cliprange)
    total_timesteps = int(total_timesteps)

    policy = build_policy(env, network, **network_kwargs)
    policy = build_ppo_policy(env, network, **network_kwargs)

    # Get the nb of env
    nenvs = env.num_envs
@@ -104,15 +112,19 @@ def learn(*, network, env, total_timesteps, eval_env = None, seed=None, nsteps=2
        model_fn = Model

    model = model_fn(policy=policy, ob_space=ob_space, ac_space=ac_space, nbatch_act=nenvs, nbatch_train=nbatch_train,
                     nsteps=nsteps, ent_coef=ent_coef, vf_coef=vf_coef, max_grad_norm=max_grad_norm)

    if load_path is not None:
        model.load(load_path)

    allvars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=model.name)
    display_var_info(allvars)

    # Instantiate the runner object
    runner = Runner(env=env, model=model, nsteps=nsteps, gamma=gamma, lam=lam)
    runner = Runner(env=env, model=model, nsteps=nsteps, gamma=gamma, ob_space=ob_space, lam=lam)

    if eval_env is not None:
        eval_runner = Runner(env=eval_env, model=model, nsteps=nsteps, gamma=gamma, ob_space=ob_space, lam=lam)

    epinfobuf = deque(maxlen=100)
    if eval_env is not None:
@@ -120,9 +132,9 @@ def learn(*, network, env, total_timesteps, eval_env = None, seed=None, nsteps=2

    # Start total timer
    tfirststart = time.perf_counter()
    nupdates = total_timesteps // nbatch

    for update in range(1, nupdates + 1):
        assert nbatch % nminibatches == 0
        # Start timer
        tstart = time.perf_counter()
@@ -131,44 +143,40 @@ def learn(*, network, env, total_timesteps, eval_env = None, seed=None, nsteps=2
        lrnow = lr(frac)
        # Calculate the cliprange
        cliprangenow = cliprange(frac)
        # Get minibatch
        obs, returns, masks, actions, values, neglogpacs, states, epinfos = runner.run()  # pylint: disable=E0632
        if eval_env is not None:
            eval_obs, eval_returns, eval_masks, eval_actions, eval_values, eval_neglogpacs, eval_states, eval_epinfos = eval_runner.run()  # pylint: disable=E0632

        epinfobuf.extend(epinfos)
        # Get minibatch
        minibatch = runner.run()

        if eval_env is not None:
            eval_minibatch = eval_runner.run()
            _eval_obs = eval_minibatch['observations']  # noqa: F841
            _eval_returns = eval_minibatch['returns']  # noqa: F841
            _eval_masks = eval_minibatch['masks']  # noqa: F841
            _eval_actions = eval_minibatch['actions']  # noqa: F841
            _eval_values = eval_minibatch['values']  # noqa: F841
            _eval_neglogpacs = eval_minibatch['neglogpacs']  # noqa: F841
            _eval_states = eval_minibatch['state']  # noqa: F841
            eval_epinfos = eval_minibatch['epinfos']

        epinfobuf.extend(minibatch.pop('epinfos'))
        if eval_env is not None:
            eval_epinfobuf.extend(eval_epinfos)

        # Here what we're going to do is for each minibatch calculate the loss and append it.
        mblossvals = []
        if states is None:  # nonrecurrent version
            # Index of each element of batch_size
            # Create the indices array
            inds = np.arange(nbatch)
            for _ in range(noptepochs):
                # Randomize the indexes
                np.random.shuffle(inds)
                # 0 to batch_size with batch_train_size step
                for start in range(0, nbatch, nbatch_train):
                    end = start + nbatch_train
                    mbinds = inds[start:end]
                    slices = (arr[mbinds] for arr in (obs, returns, masks, actions, values, neglogpacs))
                    mblossvals.append(model.train(lrnow, cliprangenow, *slices))
        else:  # recurrent version
            assert nenvs % nminibatches == 0
            envsperbatch = nenvs // nminibatches
            envinds = np.arange(nenvs)
            flatinds = np.arange(nenvs * nsteps).reshape(nenvs, nsteps)
            for _ in range(noptepochs):
                np.random.shuffle(envinds)
                for start in range(0, nenvs, envsperbatch):
                    end = start + envsperbatch
                    mbenvinds = envinds[start:end]
                    mbflatinds = flatinds[mbenvinds].ravel()
                    slices = (arr[mbflatinds] for arr in (obs, returns, masks, actions, values, neglogpacs))
                    mbstates = states[mbenvinds]
                    mblossvals.append(model.train(lrnow, cliprangenow, *slices, mbstates))

        # Index of each element of batch_size
        # Create the indices array
        inds = np.arange(nbatch)
        for _ in range(noptepochs):
            # Randomize the indexes
            np.random.shuffle(inds)
            # 0 to batch_size with batch_train_size step
            for start in range(0, nbatch, nbatch_train):
                end = start + nbatch_train
                mbinds = inds[start:end]
                slices = {key: minibatch[key][mbinds] for key in minibatch}
                mblossvals.append(model.train(lrnow, cliprangenow, **slices))

        # Feedforward --> get losses --> update
        lossvals = np.mean(mblossvals, axis=0)
@@ -179,32 +187,36 @@ def learn(*, network, env, total_timesteps, eval_env = None, seed=None, nsteps=2
        if update % log_interval == 0 or update == 1:
            # Calculates if value function is a good predicator of the returns (ev > 1)
            # or if it's just worse than predicting nothing (ev =< 0)
            ev = explained_variance(values, returns)
            ev = explained_variance(minibatch['values'], minibatch['returns'])
            logger.logkv("serial_timesteps", update * nsteps)
            logger.logkv("nupdates", update)
            logger.logkv("total_timesteps", update * nbatch)
            logger.logkv("fps", fps)
            logger.logkv("explained_variance", float(ev))
            logger.logkv('eprewmean', safemean([epinfo['r'] for epinfo in epinfobuf]))
            logger.logkv('eplenmean', safemean([epinfo['l'] for epinfo in epinfobuf]))
            logger.logkv('rewards_per_step', safemean(minibatch['rewards']))
            logger.logkv('advantages_per_step', safemean(minibatch['advs']))

            if eval_env is not None:
                logger.logkv('eval_eprewmean', safemean([epinfo['r'] for epinfo in eval_epinfobuf]))
                logger.logkv('eval_eplenmean', safemean([epinfo['l'] for epinfo in eval_epinfobuf]))
            logger.logkv('time_elapsed', tnow - tfirststart)
            for (lossval, lossname) in zip(lossvals, model.loss_names):
                logger.logkv(lossname, lossval)
            if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
                logger.dumpkvs()
        if save_interval and (update % save_interval == 0 or update == 1) and logger.get_dir() and (
                MPI is None or MPI.COMM_WORLD.Get_rank() == 0):
            checkdir = osp.join(logger.get_dir(), 'checkpoints')
            os.makedirs(checkdir, exist_ok=True)
            savepath = osp.join(checkdir, '%.5i' % update)
            print('Saving to', savepath)
            model.save(savepath)
        del minibatch
    return model


# Avoid division error when calculate the mean (in our case if epinfo is empty returns np.nan, not return an error)
def safemean(xs):
    return np.nan if len(xs) == 0 else np.mean(xs)
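As the constfn wrapper and the lr(frac) / cliprange(frac) calls above imply, lr and cliprange may be either floats or callables of the remaining-progress fraction frac, which shrinks from 1.0 toward 0 over total_timesteps. A hedged usage sketch, reusing an env built as in the earlier sketches:

from baselines.ppo2 import ppo2

# Linearly anneal both the learning rate and the clip range as training progresses.
model = ppo2.learn(network='mlp', env=env, total_timesteps=int(1e6),
                   lr=lambda frac: 3e-4 * frac,
                   cliprange=lambda frac: 0.2 * frac)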
New binary files (result plots, not shown):
- baselines/ppo2/result/all_result.png (177 KiB)
- baselines/ppo2/result/original_vs_pr.png (100 KiB)
- baselines/ppo2/result/rnn_comparison.png (92 KiB)
@@ -1,6 +1,8 @@
import numpy as np

from baselines.common.runners import AbstractEnvRunner


class Runner(AbstractEnvRunner):
    """
    We use this object to make a mini batch of experiences
@@ -10,67 +12,118 @@ class Runner(AbstractEnvRunner):
    run():
    - Make a mini batch
    """

    def __init__(self, *, env, model, nsteps, gamma, lam):

    def __init__(self, *, env, model, nsteps, gamma, ob_space, lam):
        super().__init__(env=env, model=model, nsteps=nsteps)

        self.lam = lam  # Lambda used in GAE (General Advantage Estimation)
        self.gamma = gamma  # Discount rate for rewards
        self.ob_space = ob_space

    def run(self):
        # Here, we init the lists that will contain the mb of experiences
        mb_obs, mb_rewards, mb_actions, mb_values, mb_dones, mb_neglogpacs = [], [], [], [], [], []
        mb_states = self.states
        minibatch = {
            "observations": [],
            "actions": [],
            "rewards": [],
            "values": [],
            "dones": [],
            "neglogpacs": [],
        }

        data_type = {
            "observations": self.obs.dtype,
            "actions": np.float32,
            "rewards": np.float32,
            "values": np.float32,
            "dones": np.float32,
            "neglogpacs": np.float32,
        }

        prev_transition = {'next_states': self.model.initial_state} if self.model.initial_state is not None else {}
        epinfos = []

        # For n in range number of steps
        for _ in range(self.nsteps):
            # Given observations, get action value and neglopacs
            # We already have self.obs because Runner superclass run self.obs[:] = env.reset() on init
            actions, values, self.states, neglogpacs = self.model.step(self.obs, S=self.states, M=self.dones)
            mb_obs.append(self.obs.copy())
            mb_actions.append(actions)
            mb_values.append(values)
            mb_neglogpacs.append(neglogpacs)
            mb_dones.append(self.dones)
            transitions = {}
            transitions['observations'] = self.obs.copy()
            transitions['dones'] = self.dones
            if 'next_states' in prev_transition:
                transitions['states'] = prev_transition['next_states']
            transitions.update(self.model.step_with_dict(**transitions))

            # Take actions in env and look the results
            # Infos contains a ton of useful informations
            self.obs[:], rewards, self.dones, infos = self.env.step(actions)
            self.obs, transitions['rewards'], self.dones, infos = self.env.step(transitions['actions'])
            self.dones = np.array(self.dones, dtype=np.float)

            for info in infos:
                maybeepinfo = info.get('episode')
                if maybeepinfo:
                    epinfos.append(maybeepinfo)
            mb_rewards.append(rewards)

        # batch of steps to batch of rollouts
        mb_obs = np.asarray(mb_obs, dtype=self.obs.dtype)
        mb_rewards = np.asarray(mb_rewards, dtype=np.float32)
        mb_actions = np.asarray(mb_actions)
        mb_values = np.asarray(mb_values, dtype=np.float32)
        mb_neglogpacs = np.asarray(mb_neglogpacs, dtype=np.float32)
        mb_dones = np.asarray(mb_dones, dtype=np.bool)
        last_values = self.model.value(self.obs, S=self.states, M=self.dones)

        # discount/bootstrap off value fn
        mb_returns = np.zeros_like(mb_rewards)
        mb_advs = np.zeros_like(mb_rewards)
        lastgaelam = 0

            for key in transitions:
                if key not in minibatch:
                    minibatch[key] = []
                minibatch[key].append(transitions[key])
            prev_transition = transitions

        transitions['observations'] = self.obs.copy()
        transitions['dones'] = self.dones
        if 'states' in transitions:
            transitions['states'] = transitions.pop('next_states')

        for key in minibatch:
            dtype = data_type[key] if key in data_type else np.float
            minibatch[key] = np.asarray(minibatch[key], dtype=dtype)

        last_values = self.model.step_with_dict(**transitions)['values']

        # Calculate returns and advantages.
        minibatch['advs'], minibatch['returns'] = \
            self.advantage_and_returns(values=minibatch['values'],
                                       rewards=minibatch['rewards'],
                                       dones=minibatch['dones'],
                                       last_values=last_values,
                                       last_dones=self.dones,
                                       gamma=self.gamma)

        for key in minibatch:
            minibatch[key] = sf01(minibatch[key])

        minibatch['epinfos'] = epinfos
        return minibatch

    def advantage_and_returns(self, values, rewards, dones, last_values, last_dones, gamma,
                              use_non_episodic_rewards=False):
        """
        calculate Generalized Advantage Estimation (GAE), https://arxiv.org/abs/1506.02438
        see also Proximal Policy Optimization Algorithms, https://arxiv.org/abs/1707.06347
        """

        advantages = np.zeros_like(rewards)
        lastgaelam = 0  # Lambda used in General Advantage Estimation
        for t in reversed(range(self.nsteps)):
            if t == self.nsteps - 1:
                nextnonterminal = 1.0 - self.dones
                nextvalues = last_values
            if not use_non_episodic_rewards:
                if t == self.nsteps - 1:
                    next_non_terminal = 1.0 - last_dones
                else:
                    next_non_terminal = 1.0 - dones[t + 1]
            else:
                nextnonterminal = 1.0 - mb_dones[t + 1]
                nextvalues = mb_values[t + 1]
            delta = mb_rewards[t] + self.gamma * nextvalues * nextnonterminal - mb_values[t]
            mb_advs[t] = lastgaelam = delta + self.gamma * self.lam * nextnonterminal * lastgaelam
        mb_returns = mb_advs + mb_values
        return (*map(sf01, (mb_obs, mb_returns, mb_dones, mb_actions, mb_values, mb_neglogpacs)),
                mb_states, epinfos)
        # obs, returns, masks, actions, values, neglogpacs, states = runner.run()
                next_non_terminal = 1.0
            next_value = values[t + 1] if t < self.nsteps - 1 else last_values
            delta = rewards[t] + gamma * next_value * next_non_terminal - values[t]
            advantages[t] = lastgaelam = delta + gamma * self.lam * next_non_terminal * lastgaelam
        returns = advantages + values
        return advantages, returns


def sf01(arr):
    """
    swap and then flatten axes 0 and 1
    """
    s = arr.shape
    return arr.swapaxes(0, 1).reshape(s[0] * s[1], *s[2:])
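advantage_and_returns implements GAE: delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_{t+1}) - V(s_t), A_t = delta_t + gamma * lambda * (1 - done_{t+1}) * A_{t+1}, and returns = advantages + values. A tiny hand-checkable example with illustrative numbers (not taken from the PR):

import numpy as np

gamma, lam = 0.99, 0.95
rewards = np.array([1.0, 1.0, 1.0])
values = np.array([0.5, 0.6, 0.7])
dones = np.zeros(3)                 # no episode boundary inside the rollout
last_value, last_done = 0.8, 0.0    # bootstrap value and done flag after the last step

advs = np.zeros(3)
lastgaelam = 0.0
for t in reversed(range(3)):
    next_value = values[t + 1] if t < 2 else last_value
    next_non_terminal = 1.0 - (dones[t + 1] if t < 2 else last_done)
    delta = rewards[t] + gamma * next_value * next_non_terminal - values[t]
    advs[t] = lastgaelam = delta + gamma * lam * next_non_terminal * lastgaelam
returns = advs + values
print(advs, returns)  # advs ~ [3.09, 2.12, 1.09]; returns = advs + values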