Compare commits
24 Commits
peterz_viz
...
peterz_383
Author | SHA1 | Date | |
---|---|---|---|
|
cddd97bd2d | ||
|
146bbf886b | ||
|
f3a5abaeeb | ||
|
97e039127f | ||
|
25ecb64821 | ||
|
7dc6bc7c70 | ||
|
7139a66d33 | ||
|
8607dca99e | ||
|
9f9835fe38 | ||
|
d3fed181b5 | ||
|
339d5640b9 | ||
|
a75bc37a40 | ||
|
87b3a04a38 | ||
|
c5b1a1b643 | ||
|
c59a10947d | ||
|
5cd66010dc | ||
|
0a13da8dfe | ||
|
18b6390be6 | ||
|
52255beda5 | ||
|
d80acbb4d1 | ||
|
556b198454 | ||
|
cc88804042 | ||
|
c14d307834 | ||
|
0b71d4c6c4 |
@@ -1,3 +1,5 @@
|
||||
**Status:** Active (under active development, breaking changes may occur)
|
||||
|
||||
<img src="data/logo.jpg" width=25% align="right" /> [](https://travis-ci.org/openai/baselines)
|
||||
|
||||
# Baselines
|
||||
@@ -110,7 +112,7 @@ python -m baselines.run --alg=ppo2 --env=PongNoFrameskip-v4 --num_timesteps=0 --
|
||||
*NOTE:* At the moment Mujoco training uses VecNormalize wrapper for the environment which is not being saved correctly; so loading the models trained on Mujoco will not work well if the environment is recreated. If necessary, you can work around that by replacing RunningMeanStd by TfRunningMeanStd in [baselines/common/vec_env/vec_normalize.py](baselines/common/vec_env/vec_normalize.py#L12). This way, mean and std of environment normalizing wrapper will be saved in tensorflow variables and included in the model file; however, training is slower that way - hence not including it by default
|
||||
|
||||
## Loading and vizualizing learning curves and other training metrics
|
||||
See [here](docs/viz/viz.md) for instructions on how to load and display the training data.
|
||||
See [here](docs/viz/viz.ipynb) for instructions on how to load and display the training data.
|
||||
|
||||
## Subpackages
|
||||
|
||||
|
@@ -37,9 +37,6 @@ class Runner(AbstractEnvRunner):
|
||||
obs, rewards, dones, _ = self.env.step(actions)
|
||||
self.states = states
|
||||
self.dones = dones
|
||||
for n, done in enumerate(dones):
|
||||
if done:
|
||||
self.obs[n] = self.obs[n]*0
|
||||
self.obs = obs
|
||||
mb_rewards.append(rewards)
|
||||
mb_dones.append(self.dones)
|
||||
|
@@ -72,8 +72,8 @@ class EpisodicLifeEnv(gym.Wrapper):
|
||||
# then update lives to handle bonus lives
|
||||
lives = self.env.unwrapped.ale.lives()
|
||||
if lives < self.lives and lives > 0:
|
||||
# for Qbert sometimes we stay in lives == 0 condtion for a few frames
|
||||
# so its important to keep lives > 0, so that we only reset once
|
||||
# for Qbert sometimes we stay in lives == 0 condition for a few frames
|
||||
# so it's important to keep lives > 0, so that we only reset once
|
||||
# the environment advertises done.
|
||||
done = True
|
||||
self.lives = lives
|
||||
@@ -129,18 +129,26 @@ class ClipRewardEnv(gym.RewardWrapper):
|
||||
return np.sign(reward)
|
||||
|
||||
class WarpFrame(gym.ObservationWrapper):
|
||||
def __init__(self, env):
|
||||
def __init__(self, env, width=84, height=84, grayscale=True):
|
||||
"""Warp frames to 84x84 as done in the Nature paper and later work."""
|
||||
gym.ObservationWrapper.__init__(self, env)
|
||||
self.width = 84
|
||||
self.height = 84
|
||||
self.observation_space = spaces.Box(low=0, high=255,
|
||||
shape=(self.height, self.width, 1), dtype=np.uint8)
|
||||
self.width = width
|
||||
self.height = height
|
||||
self.grayscale = grayscale
|
||||
if self.grayscale:
|
||||
self.observation_space = spaces.Box(low=0, high=255,
|
||||
shape=(self.height, self.width, 1), dtype=np.uint8)
|
||||
else:
|
||||
self.observation_space = spaces.Box(low=0, high=255,
|
||||
shape=(self.height, self.width, 3), dtype=np.uint8)
|
||||
|
||||
def observation(self, frame):
|
||||
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
|
||||
if self.grayscale:
|
||||
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
|
||||
frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA)
|
||||
return frame[:, :, None]
|
||||
if self.grayscale:
|
||||
frame = np.expand_dims(frame, -1)
|
||||
return frame
|
||||
|
||||
class FrameStack(gym.Wrapper):
|
||||
def __init__(self, env, k):
|
||||
@@ -156,7 +164,7 @@ class FrameStack(gym.Wrapper):
|
||||
self.k = k
|
||||
self.frames = deque([], maxlen=k)
|
||||
shp = env.observation_space.shape
|
||||
self.observation_space = spaces.Box(low=0, high=255, shape=(shp[0], shp[1], shp[2] * k), dtype=env.observation_space.dtype)
|
||||
self.observation_space = spaces.Box(low=0, high=255, shape=(shp[:-1] + (shp[-1] * k,)), dtype=env.observation_space.dtype)
|
||||
|
||||
def reset(self):
|
||||
ob = self.env.reset()
|
||||
@@ -197,7 +205,7 @@ class LazyFrames(object):
|
||||
|
||||
def _force(self):
|
||||
if self._out is None:
|
||||
self._out = np.concatenate(self._frames, axis=2)
|
||||
self._out = np.concatenate(self._frames, axis=-1)
|
||||
self._frames = None
|
||||
return self._out
|
||||
|
||||
|
@@ -62,7 +62,7 @@ class CategoricalPdType(PdType):
|
||||
def pdclass(self):
|
||||
return CategoricalPd
|
||||
def pdfromlatent(self, latent_vector, init_scale=1.0, init_bias=0.0):
|
||||
pdparam = fc(latent_vector, 'pi', self.ncat, init_scale=init_scale, init_bias=init_bias)
|
||||
pdparam = _matching_fc(latent_vector, 'pi', self.ncat, init_scale=init_scale, init_bias=init_bias)
|
||||
return self.pdfromflat(pdparam), pdparam
|
||||
|
||||
def param_shape(self):
|
||||
@@ -82,7 +82,7 @@ class MultiCategoricalPdType(PdType):
|
||||
return MultiCategoricalPd(self.ncats, flat)
|
||||
|
||||
def pdfromlatent(self, latent, init_scale=1.0, init_bias=0.0):
|
||||
pdparam = fc(latent, 'pi', self.ncats.sum(), init_scale=init_scale, init_bias=init_bias)
|
||||
pdparam = _matching_fc(latent, 'pi', self.ncats.sum(), init_scale=init_scale, init_bias=init_bias)
|
||||
return self.pdfromflat(pdparam), pdparam
|
||||
|
||||
def param_shape(self):
|
||||
@@ -99,7 +99,7 @@ class DiagGaussianPdType(PdType):
|
||||
return DiagGaussianPd
|
||||
|
||||
def pdfromlatent(self, latent_vector, init_scale=1.0, init_bias=0.0):
|
||||
mean = fc(latent_vector, 'pi', self.size, init_scale=init_scale, init_bias=init_bias)
|
||||
mean = _matching_fc(latent_vector, 'pi', self.size, init_scale=init_scale, init_bias=init_bias)
|
||||
logstd = tf.get_variable(name='pi/logstd', shape=[1, self.size], initializer=tf.zeros_initializer())
|
||||
pdparam = tf.concat([mean, mean * 0.0 + logstd], axis=1)
|
||||
return self.pdfromflat(pdparam), mean
|
||||
@@ -123,7 +123,7 @@ class BernoulliPdType(PdType):
|
||||
def sample_dtype(self):
|
||||
return tf.int32
|
||||
def pdfromlatent(self, latent_vector, init_scale=1.0, init_bias=0.0):
|
||||
pdparam = fc(latent_vector, 'pi', self.size, init_scale=init_scale, init_bias=init_bias)
|
||||
pdparam = _matching_fc(latent_vector, 'pi', self.size, init_scale=init_scale, init_bias=init_bias)
|
||||
return self.pdfromflat(pdparam), pdparam
|
||||
|
||||
# WRONG SECOND DERIVATIVES
|
||||
@@ -345,3 +345,9 @@ def validate_probtype(probtype, pdparam):
|
||||
assert np.abs(klval - klval_ll) < 3 * klval_ll_stderr # within 3 sigmas
|
||||
print('ok on', probtype, pdparam)
|
||||
|
||||
|
||||
def _matching_fc(tensor, name, size, init_scale, init_bias):
|
||||
if tensor.shape[-1] == size:
|
||||
return tensor
|
||||
else:
|
||||
return fc(tensor, name, size, init_scale=init_scale, init_bias=init_bias)
|
||||
|
@@ -332,7 +332,7 @@ def plot_results(
|
||||
xys = gresults[group]
|
||||
if not any(xys):
|
||||
continue
|
||||
color = COLORS[groups.index(group)]
|
||||
color = COLORS[groups.index(group) % len(COLORS)]
|
||||
origxs = [xy[0] for xy in xys]
|
||||
minxlen = min(map(len, origxs))
|
||||
def allequal(qs):
|
||||
|
@@ -132,10 +132,8 @@ class MovieRecord(gym.Wrapper):
|
||||
self.epcount = 0
|
||||
def reset(self):
|
||||
if self.epcount % self.k == 0:
|
||||
print('saving movie this episode', self.savedir)
|
||||
self.env.unwrapped.movie_path = self.savedir
|
||||
else:
|
||||
print('not saving this episode')
|
||||
self.env.unwrapped.movie_path = None
|
||||
self.env.unwrapped.movie = None
|
||||
self.epcount += 1
|
||||
|
@@ -103,9 +103,9 @@ def test_coexistence(learn_fn, network_fn):
|
||||
kwargs.update(learn_kwargs[learn_fn])
|
||||
|
||||
learn = partial(learn, env=env, network=network_fn, total_timesteps=0, **kwargs)
|
||||
make_session(make_default=True, graph=tf.Graph());
|
||||
make_session(make_default=True, graph=tf.Graph())
|
||||
model1 = learn(seed=1)
|
||||
make_session(make_default=True, graph=tf.Graph());
|
||||
make_session(make_default=True, graph=tf.Graph())
|
||||
model2 = learn(seed=2)
|
||||
|
||||
model1.step(env.observation_space.sample())
|
||||
|
@@ -165,6 +165,10 @@ def function(inputs, outputs, updates=None, givens=None):
|
||||
outputs: [tf.Variable] or tf.Variable
|
||||
list of outputs or a single output to be returned from function. Returned
|
||||
value will also have the same shape.
|
||||
updates: [tf.Operation] or tf.Operation
|
||||
list of update functions or single update function that will be run whenever
|
||||
the function is called. The return is ignored.
|
||||
|
||||
"""
|
||||
if isinstance(outputs, list):
|
||||
return _Function(inputs, outputs, updates, givens=givens)
|
||||
|
@@ -67,7 +67,6 @@ class DDPG(object):
|
||||
def __init__(self, actor, critic, memory, observation_shape, action_shape, param_noise=None, action_noise=None,
|
||||
gamma=0.99, tau=0.001, normalize_returns=False, enable_popart=False, normalize_observations=True,
|
||||
batch_size=128, observation_range=(-5., 5.), action_range=(-1., 1.), return_range=(-np.inf, np.inf),
|
||||
adaptive_param_noise=True, adaptive_param_noise_policy_threshold=.1,
|
||||
critic_l2_reg=0., actor_lr=1e-4, critic_lr=1e-3, clip_norm=None, reward_scale=1.):
|
||||
# Inputs.
|
||||
self.obs0 = tf.placeholder(tf.float32, shape=(None,) + observation_shape, name='obs0')
|
||||
@@ -186,7 +185,7 @@ class DDPG(object):
|
||||
normalized_critic_target_tf = tf.clip_by_value(normalize(self.critic_target, self.ret_rms), self.return_range[0], self.return_range[1])
|
||||
self.critic_loss = tf.reduce_mean(tf.square(self.normalized_critic_tf - normalized_critic_target_tf))
|
||||
if self.critic_l2_reg > 0.:
|
||||
critic_reg_vars = [var for var in self.critic.trainable_vars if 'kernel' in var.name and 'output' not in var.name]
|
||||
critic_reg_vars = [var for var in self.critic.trainable_vars if var.name.endswith('/w:0') and 'output' not in var.name]
|
||||
for var in critic_reg_vars:
|
||||
logger.info(' regularizing: {}'.format(var.name))
|
||||
logger.info(' applying l2 regularization with {}'.format(self.critic_l2_reg))
|
||||
|
@@ -42,7 +42,7 @@ class Critic(Model):
|
||||
with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE):
|
||||
x = tf.concat([obs, action], axis=-1) # this assumes observation and action can be concatenated
|
||||
x = self.network_builder(x)
|
||||
x = tf.layers.dense(x, 1, kernel_initializer=tf.random_uniform_initializer(minval=-3e-3, maxval=3e-3))
|
||||
x = tf.layers.dense(x, 1, kernel_initializer=tf.random_uniform_initializer(minval=-3e-3, maxval=3e-3), name='output')
|
||||
return x
|
||||
|
||||
@property
|
||||
|
17
baselines/ddpg/test_smoke.py
Normal file
17
baselines/ddpg/test_smoke.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from baselines.run import main as M
|
||||
|
||||
def _run(argstr):
|
||||
M(('--alg=ddpg --env=Pendulum-v0 --num_timesteps=0 ' + argstr).split(' '))
|
||||
|
||||
def test_popart():
|
||||
_run('--normalize_returns=True --popart=True')
|
||||
|
||||
def test_noise_normal():
|
||||
_run('--noise_type=normal_0.1')
|
||||
|
||||
def test_noise_ou():
|
||||
_run('--noise_type=ou_0.1')
|
||||
|
||||
def test_noise_adaptive():
|
||||
_run('--noise_type=adaptive-param_0.2,normal_0.1')
|
||||
|
@@ -5,4 +5,4 @@ from baselines.deepq.replay_buffer import ReplayBuffer, PrioritizedReplayBuffer
|
||||
|
||||
def wrap_atari_dqn(env):
|
||||
from baselines.common.atari_wrappers import wrap_deepmind
|
||||
return wrap_deepmind(env, frame_stack=True, scale=True)
|
||||
return wrap_deepmind(env, frame_stack=True, scale=False)
|
||||
|
@@ -33,7 +33,7 @@ The functions in this file can are used to create the following functions:
|
||||
stochastic: bool
|
||||
if set to False all the actions are always deterministic (default False)
|
||||
update_eps_ph: float
|
||||
update epsilon a new value, if negative not update happens
|
||||
update epsilon to a new value, if negative no update happens
|
||||
(default: no update)
|
||||
reset_ph: bool
|
||||
reset the perturbed policy by sampling a new perturbation
|
||||
|
@@ -2,9 +2,9 @@ import tensorflow as tf
|
||||
import tensorflow.contrib.layers as layers
|
||||
|
||||
|
||||
def _mlp(hiddens, inpt, num_actions, scope, reuse=False, layer_norm=False):
|
||||
def _mlp(hiddens, input_, num_actions, scope, reuse=False, layer_norm=False):
|
||||
with tf.variable_scope(scope, reuse=reuse):
|
||||
out = inpt
|
||||
out = input_
|
||||
for hidden in hiddens:
|
||||
out = layers.fully_connected(out, num_outputs=hidden, activation_fn=None)
|
||||
if layer_norm:
|
||||
@@ -21,6 +21,9 @@ def mlp(hiddens=[], layer_norm=False):
|
||||
----------
|
||||
hiddens: [int]
|
||||
list of sizes of hidden layers
|
||||
layer_norm: bool
|
||||
if true applies layer normalization for every layer
|
||||
as described in https://arxiv.org/abs/1607.06450
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -30,9 +33,9 @@ def mlp(hiddens=[], layer_norm=False):
|
||||
return lambda *args, **kwargs: _mlp(hiddens, layer_norm=layer_norm, *args, **kwargs)
|
||||
|
||||
|
||||
def _cnn_to_mlp(convs, hiddens, dueling, inpt, num_actions, scope, reuse=False, layer_norm=False):
|
||||
def _cnn_to_mlp(convs, hiddens, dueling, input_, num_actions, scope, reuse=False, layer_norm=False):
|
||||
with tf.variable_scope(scope, reuse=reuse):
|
||||
out = inpt
|
||||
out = input_
|
||||
with tf.variable_scope("convnet"):
|
||||
for num_outputs, kernel_size, stride in convs:
|
||||
out = layers.convolution2d(out,
|
||||
@@ -72,7 +75,7 @@ def cnn_to_mlp(convs, hiddens, dueling=False, layer_norm=False):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
convs: [(int, int int)]
|
||||
convs: [(int, int, int)]
|
||||
list of convolutional layers in form of
|
||||
(num_outputs, kernel_size, stride)
|
||||
hiddens: [int]
|
||||
@@ -80,6 +83,9 @@ def cnn_to_mlp(convs, hiddens, dueling=False, layer_norm=False):
|
||||
dueling: bool
|
||||
if true double the output MLP to compute a baseline
|
||||
for action scores
|
||||
layer_norm: bool
|
||||
if true applies layer normalization for every layer
|
||||
as described in https://arxiv.org/abs/1607.06450
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
@@ -367,8 +367,6 @@ class DDPG(object):
|
||||
self.pi_loss_tf = -tf.reduce_mean(self.main.Q_pi_tf)
|
||||
self.pi_loss_tf += self.action_l2 * tf.reduce_mean(tf.square(self.main.pi_tf / self.max_u))
|
||||
|
||||
self.pi_loss_tf = -tf.reduce_mean(self.main.Q_pi_tf)
|
||||
self.pi_loss_tf += self.action_l2 * tf.reduce_mean(tf.square(self.main.pi_tf / self.max_u))
|
||||
Q_grads_tf = tf.gradients(self.Q_loss_tf, self._vars('main/Q'))
|
||||
pi_grads_tf = tf.gradients(self.pi_loss_tf, self._vars('main/pi'))
|
||||
assert len(self._vars('main/Q')) == len(Q_grads_tf)
|
||||
|
@@ -54,7 +54,7 @@ class HumanOutputFormat(KVWriter, SeqWriter):
|
||||
# Write out the data
|
||||
dashes = '-' * (keywidth + valwidth + 7)
|
||||
lines = [dashes]
|
||||
for (key, val) in sorted(key2str.items()):
|
||||
for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()):
|
||||
lines.append('| %s%s | %s%s |' % (
|
||||
key,
|
||||
' ' * (keywidth - len(key)),
|
||||
|
@@ -97,7 +97,7 @@ def learn(env, policy_fn, *,
|
||||
ret = tf.placeholder(dtype=tf.float32, shape=[None]) # Empirical return
|
||||
|
||||
lrmult = tf.placeholder(name='lrmult', dtype=tf.float32, shape=[]) # learning rate multiplier, updated with schedule
|
||||
clip_param = clip_param * lrmult # Annealed cliping parameter epislon
|
||||
clip_param = clip_param * lrmult # Annealed clipping parameter epsilon
|
||||
|
||||
ob = U.get_placeholder_cached(name="ob")
|
||||
ac = pi.pdtype.sample_placeholder([None])
|
||||
|
76
baselines/ppo2/microbatched_model.py
Normal file
76
baselines/ppo2/microbatched_model.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
from baselines.ppo2.model import Model
|
||||
|
||||
class MicrobatchedModel(Model):
|
||||
"""
|
||||
Model that does training one microbatch at a time - when gradient computation
|
||||
on the entire minibatch causes some overflow
|
||||
"""
|
||||
def __init__(self, *, policy, ob_space, ac_space, nbatch_act, nbatch_train,
|
||||
nsteps, ent_coef, vf_coef, max_grad_norm, microbatch_size):
|
||||
|
||||
self.nmicrobatches = nbatch_train // microbatch_size
|
||||
self.microbatch_size = microbatch_size
|
||||
assert nbatch_train % microbatch_size == 0, 'microbatch_size ({}) should divide nbatch_train ({}) evenly'.format(microbatch_size, nbatch_train)
|
||||
|
||||
super().__init__(
|
||||
policy=policy,
|
||||
ob_space=ob_space,
|
||||
ac_space=ac_space,
|
||||
nbatch_act=nbatch_act,
|
||||
nbatch_train=microbatch_size,
|
||||
nsteps=nsteps,
|
||||
ent_coef=ent_coef,
|
||||
vf_coef=vf_coef,
|
||||
max_grad_norm=max_grad_norm)
|
||||
|
||||
self.grads_ph = [tf.placeholder(dtype=g.dtype, shape=g.shape) for g in self.grads]
|
||||
grads_ph_and_vars = list(zip(self.grads_ph, self.var))
|
||||
self._apply_gradients_op = self.trainer.apply_gradients(grads_ph_and_vars)
|
||||
|
||||
|
||||
def train(self, lr, cliprange, obs, returns, masks, actions, values, neglogpacs, states=None):
|
||||
assert states is None, "microbatches with recurrent models are not supported yet"
|
||||
|
||||
# Here we calculate advantage A(s,a) = R + yV(s') - V(s)
|
||||
# Returns = R + yV(s')
|
||||
advs = returns - values
|
||||
|
||||
# Normalize the advantages
|
||||
advs = (advs - advs.mean()) / (advs.std() + 1e-8)
|
||||
|
||||
# Initialize empty list for per-microbatch stats like pg_loss, vf_loss, entropy, approxkl (whatever is in self.stats_list)
|
||||
stats_vs = []
|
||||
|
||||
for microbatch_idx in range(self.nmicrobatches):
|
||||
_sli = range(microbatch_idx * self.microbatch_size, (microbatch_idx+1) * self.microbatch_size)
|
||||
td_map = {
|
||||
self.train_model.X: obs[_sli],
|
||||
self.A:actions[_sli],
|
||||
self.ADV:advs[_sli],
|
||||
self.R:returns[_sli],
|
||||
self.CLIPRANGE:cliprange,
|
||||
self.OLDNEGLOGPAC:neglogpacs[_sli],
|
||||
self.OLDVPRED:values[_sli]
|
||||
}
|
||||
|
||||
# Compute gradient on a microbatch (note that variables do not change here) ...
|
||||
grad_v, stats_v = self.sess.run([self.grads, self.stats_list], td_map)
|
||||
if microbatch_idx == 0:
|
||||
sum_grad_v = grad_v
|
||||
else:
|
||||
# .. and add to the total of the gradients
|
||||
for i, g in enumerate(grad_v):
|
||||
sum_grad_v[i] += g
|
||||
stats_vs.append(stats_v)
|
||||
|
||||
feed_dict = {ph: sum_g / self.nmicrobatches for ph, sum_g in zip(self.grads_ph, sum_grad_v)}
|
||||
feed_dict[self.LR] = lr
|
||||
# Update variables using average of the gradients
|
||||
self.sess.run(self._apply_gradients_op, feed_dict)
|
||||
# Return average of the stats
|
||||
return np.mean(np.array(stats_vs), axis=0).tolist()
|
||||
|
||||
|
||||
|
156
baselines/ppo2/model.py
Normal file
156
baselines/ppo2/model.py
Normal file
@@ -0,0 +1,156 @@
|
||||
import tensorflow as tf
|
||||
import functools
|
||||
|
||||
from baselines.common.tf_util import get_session, save_variables, load_variables
|
||||
from baselines.common.tf_util import initialize
|
||||
|
||||
try:
|
||||
from baselines.common.mpi_adam_optimizer import MpiAdamOptimizer
|
||||
from mpi4py import MPI
|
||||
from baselines.common.mpi_util import sync_from_root
|
||||
except ImportError:
|
||||
MPI = None
|
||||
|
||||
class Model(object):
|
||||
"""
|
||||
We use this object to :
|
||||
__init__:
|
||||
- Creates the step_model
|
||||
- Creates the train_model
|
||||
|
||||
train():
|
||||
- Make the training part (feedforward and retropropagation of gradients)
|
||||
|
||||
save/load():
|
||||
- Save load the model
|
||||
"""
|
||||
def __init__(self, *, policy, ob_space, ac_space, nbatch_act, nbatch_train,
|
||||
nsteps, ent_coef, vf_coef, max_grad_norm, microbatch_size=None):
|
||||
self.sess = sess = get_session()
|
||||
|
||||
with tf.variable_scope('ppo2_model', reuse=tf.AUTO_REUSE):
|
||||
# CREATE OUR TWO MODELS
|
||||
# act_model that is used for sampling
|
||||
act_model = policy(nbatch_act, 1, sess)
|
||||
|
||||
# Train model for training
|
||||
if microbatch_size is None:
|
||||
train_model = policy(nbatch_train, nsteps, sess)
|
||||
else:
|
||||
train_model = policy(microbatch_size, nsteps, sess)
|
||||
|
||||
# CREATE THE PLACEHOLDERS
|
||||
self.A = A = train_model.pdtype.sample_placeholder([None])
|
||||
self.ADV = ADV = tf.placeholder(tf.float32, [None])
|
||||
self.R = R = tf.placeholder(tf.float32, [None])
|
||||
# Keep track of old actor
|
||||
self.OLDNEGLOGPAC = OLDNEGLOGPAC = tf.placeholder(tf.float32, [None])
|
||||
# Keep track of old critic
|
||||
self.OLDVPRED = OLDVPRED = tf.placeholder(tf.float32, [None])
|
||||
self.LR = LR = tf.placeholder(tf.float32, [])
|
||||
# Cliprange
|
||||
self.CLIPRANGE = CLIPRANGE = tf.placeholder(tf.float32, [])
|
||||
|
||||
neglogpac = train_model.pd.neglogp(A)
|
||||
|
||||
# Calculate the entropy
|
||||
# Entropy is used to improve exploration by limiting the premature convergence to suboptimal policy.
|
||||
entropy = tf.reduce_mean(train_model.pd.entropy())
|
||||
|
||||
# CALCULATE THE LOSS
|
||||
# Total loss = Policy gradient loss - entropy * entropy coefficient + Value coefficient * value loss
|
||||
|
||||
# Clip the value to reduce variability during Critic training
|
||||
# Get the predicted value
|
||||
vpred = train_model.vf
|
||||
vpredclipped = OLDVPRED + tf.clip_by_value(train_model.vf - OLDVPRED, - CLIPRANGE, CLIPRANGE)
|
||||
# Unclipped value
|
||||
vf_losses1 = tf.square(vpred - R)
|
||||
# Clipped value
|
||||
vf_losses2 = tf.square(vpredclipped - R)
|
||||
|
||||
vf_loss = .5 * tf.reduce_mean(tf.maximum(vf_losses1, vf_losses2))
|
||||
|
||||
# Calculate ratio (pi current policy / pi old policy)
|
||||
ratio = tf.exp(OLDNEGLOGPAC - neglogpac)
|
||||
|
||||
# Defining Loss = - J is equivalent to max J
|
||||
pg_losses = -ADV * ratio
|
||||
|
||||
pg_losses2 = -ADV * tf.clip_by_value(ratio, 1.0 - CLIPRANGE, 1.0 + CLIPRANGE)
|
||||
|
||||
# Final PG loss
|
||||
pg_loss = tf.reduce_mean(tf.maximum(pg_losses, pg_losses2))
|
||||
approxkl = .5 * tf.reduce_mean(tf.square(neglogpac - OLDNEGLOGPAC))
|
||||
clipfrac = tf.reduce_mean(tf.to_float(tf.greater(tf.abs(ratio - 1.0), CLIPRANGE)))
|
||||
|
||||
# Total loss
|
||||
loss = pg_loss - entropy * ent_coef + vf_loss * vf_coef
|
||||
|
||||
# UPDATE THE PARAMETERS USING LOSS
|
||||
# 1. Get the model parameters
|
||||
params = tf.trainable_variables('ppo2_model')
|
||||
# 2. Build our trainer
|
||||
if MPI is not None:
|
||||
self.trainer = MpiAdamOptimizer(MPI.COMM_WORLD, learning_rate=LR, epsilon=1e-5)
|
||||
else:
|
||||
self.trainer = tf.train.AdamOptimizer(learning_rate=LR, epsilon=1e-5)
|
||||
# 3. Calculate the gradients
|
||||
grads_and_var = self.trainer.compute_gradients(loss, params)
|
||||
grads, var = zip(*grads_and_var)
|
||||
|
||||
if max_grad_norm is not None:
|
||||
# Clip the gradients (normalize)
|
||||
grads, _grad_norm = tf.clip_by_global_norm(grads, max_grad_norm)
|
||||
grads_and_var = list(zip(grads, var))
|
||||
# zip aggregate each gradient with parameters associated
|
||||
# For instance zip(ABCD, xyza) => Ax, By, Cz, Da
|
||||
|
||||
self.grads = grads
|
||||
self.var = var
|
||||
self._train_op = self.trainer.apply_gradients(grads_and_var)
|
||||
self.loss_names = ['policy_loss', 'value_loss', 'policy_entropy', 'approxkl', 'clipfrac']
|
||||
self.stats_list = [pg_loss, vf_loss, entropy, approxkl, clipfrac]
|
||||
|
||||
|
||||
self.train_model = train_model
|
||||
self.act_model = act_model
|
||||
self.step = act_model.step
|
||||
self.value = act_model.value
|
||||
self.initial_state = act_model.initial_state
|
||||
|
||||
self.save = functools.partial(save_variables, sess=sess)
|
||||
self.load = functools.partial(load_variables, sess=sess)
|
||||
|
||||
initialize()
|
||||
global_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="")
|
||||
if MPI is not None:
|
||||
sync_from_root(sess, global_variables) #pylint: disable=E1101
|
||||
|
||||
def train(self, lr, cliprange, obs, returns, masks, actions, values, neglogpacs, states=None):
|
||||
# Here we calculate advantage A(s,a) = R + yV(s') - V(s)
|
||||
# Returns = R + yV(s')
|
||||
advs = returns - values
|
||||
|
||||
# Normalize the advantages
|
||||
advs = (advs - advs.mean()) / (advs.std() + 1e-8)
|
||||
|
||||
td_map = {
|
||||
self.train_model.X : obs,
|
||||
self.A : actions,
|
||||
self.ADV : advs,
|
||||
self.R : returns,
|
||||
self.LR : lr,
|
||||
self.CLIPRANGE : cliprange,
|
||||
self.OLDNEGLOGPAC : neglogpacs,
|
||||
self.OLDVPRED : values
|
||||
}
|
||||
if states is not None:
|
||||
td_map[self.train_model.S] = states
|
||||
td_map[self.train_model.M] = masks
|
||||
|
||||
return self.sess.run(
|
||||
self.stats_list + [self._train_op],
|
||||
td_map
|
||||
)[:-1]
|
||||
|
@@ -1,226 +1,17 @@
|
||||
import os
|
||||
import time
|
||||
import functools
|
||||
import numpy as np
|
||||
import os.path as osp
|
||||
import tensorflow as tf
|
||||
from baselines import logger
|
||||
from collections import deque
|
||||
from baselines.common import explained_variance, set_global_seeds
|
||||
from baselines.common.policies import build_policy
|
||||
from baselines.common.runners import AbstractEnvRunner
|
||||
from baselines.common.tf_util import get_session, save_variables, load_variables
|
||||
|
||||
try:
|
||||
from baselines.common.mpi_adam_optimizer import MpiAdamOptimizer
|
||||
from mpi4py import MPI
|
||||
from baselines.common.mpi_util import sync_from_root
|
||||
except ImportError:
|
||||
MPI = None
|
||||
from baselines.ppo2.runner import Runner
|
||||
|
||||
from baselines.common.tf_util import initialize
|
||||
|
||||
class Model(object):
|
||||
"""
|
||||
We use this object to :
|
||||
__init__:
|
||||
- Creates the step_model
|
||||
- Creates the train_model
|
||||
|
||||
train():
|
||||
- Make the training part (feedforward and retropropagation of gradients)
|
||||
|
||||
save/load():
|
||||
- Save load the model
|
||||
"""
|
||||
def __init__(self, *, policy, ob_space, ac_space, nbatch_act, nbatch_train,
|
||||
nsteps, ent_coef, vf_coef, max_grad_norm):
|
||||
sess = get_session()
|
||||
|
||||
with tf.variable_scope('ppo2_model', reuse=tf.AUTO_REUSE):
|
||||
# CREATE OUR TWO MODELS
|
||||
# act_model that is used for sampling
|
||||
act_model = policy(nbatch_act, 1, sess)
|
||||
|
||||
# Train model for training
|
||||
train_model = policy(nbatch_train, nsteps, sess)
|
||||
|
||||
# CREATE THE PLACEHOLDERS
|
||||
A = train_model.pdtype.sample_placeholder([None])
|
||||
ADV = tf.placeholder(tf.float32, [None])
|
||||
R = tf.placeholder(tf.float32, [None])
|
||||
# Keep track of old actor
|
||||
OLDNEGLOGPAC = tf.placeholder(tf.float32, [None])
|
||||
# Keep track of old critic
|
||||
OLDVPRED = tf.placeholder(tf.float32, [None])
|
||||
LR = tf.placeholder(tf.float32, [])
|
||||
# Cliprange
|
||||
CLIPRANGE = tf.placeholder(tf.float32, [])
|
||||
|
||||
neglogpac = train_model.pd.neglogp(A)
|
||||
|
||||
# Calculate the entropy
|
||||
# Entropy is used to improve exploration by limiting the premature convergence to suboptimal policy.
|
||||
entropy = tf.reduce_mean(train_model.pd.entropy())
|
||||
|
||||
# CALCULATE THE LOSS
|
||||
# Total loss = Policy gradient loss - entropy * entropy coefficient + Value coefficient * value loss
|
||||
|
||||
# Clip the value to reduce variability during Critic training
|
||||
# Get the predicted value
|
||||
vpred = train_model.vf
|
||||
vpredclipped = OLDVPRED + tf.clip_by_value(train_model.vf - OLDVPRED, - CLIPRANGE, CLIPRANGE)
|
||||
# Unclipped value
|
||||
vf_losses1 = tf.square(vpred - R)
|
||||
# Clipped value
|
||||
vf_losses2 = tf.square(vpredclipped - R)
|
||||
|
||||
vf_loss = .5 * tf.reduce_mean(tf.maximum(vf_losses1, vf_losses2))
|
||||
|
||||
# Calculate ratio (pi current policy / pi old policy)
|
||||
ratio = tf.exp(OLDNEGLOGPAC - neglogpac)
|
||||
|
||||
# Defining Loss = - J is equivalent to max J
|
||||
pg_losses = -ADV * ratio
|
||||
|
||||
pg_losses2 = -ADV * tf.clip_by_value(ratio, 1.0 - CLIPRANGE, 1.0 + CLIPRANGE)
|
||||
|
||||
# Final PG loss
|
||||
pg_loss = tf.reduce_mean(tf.maximum(pg_losses, pg_losses2))
|
||||
approxkl = .5 * tf.reduce_mean(tf.square(neglogpac - OLDNEGLOGPAC))
|
||||
clipfrac = tf.reduce_mean(tf.to_float(tf.greater(tf.abs(ratio - 1.0), CLIPRANGE)))
|
||||
|
||||
# Total loss
|
||||
loss = pg_loss - entropy * ent_coef + vf_loss * vf_coef
|
||||
|
||||
# UPDATE THE PARAMETERS USING LOSS
|
||||
# 1. Get the model parameters
|
||||
params = tf.trainable_variables('ppo2_model')
|
||||
# 2. Build our trainer
|
||||
if MPI is not None:
|
||||
trainer = MpiAdamOptimizer(MPI.COMM_WORLD, learning_rate=LR, epsilon=1e-5)
|
||||
else:
|
||||
trainer = tf.train.AdamOptimizer(learning_rate=LR, epsilon=1e-5)
|
||||
# 3. Calculate the gradients
|
||||
grads_and_var = trainer.compute_gradients(loss, params)
|
||||
grads, var = zip(*grads_and_var)
|
||||
|
||||
if max_grad_norm is not None:
|
||||
# Clip the gradients (normalize)
|
||||
grads, _grad_norm = tf.clip_by_global_norm(grads, max_grad_norm)
|
||||
grads_and_var = list(zip(grads, var))
|
||||
# zip aggregate each gradient with parameters associated
|
||||
# For instance zip(ABCD, xyza) => Ax, By, Cz, Da
|
||||
|
||||
_train = trainer.apply_gradients(grads_and_var)
|
||||
|
||||
def train(lr, cliprange, obs, returns, masks, actions, values, neglogpacs, states=None):
|
||||
# Here we calculate advantage A(s,a) = R + yV(s') - V(s)
|
||||
# Returns = R + yV(s')
|
||||
advs = returns - values
|
||||
|
||||
# Normalize the advantages
|
||||
advs = (advs - advs.mean()) / (advs.std() + 1e-8)
|
||||
td_map = {train_model.X:obs, A:actions, ADV:advs, R:returns, LR:lr,
|
||||
CLIPRANGE:cliprange, OLDNEGLOGPAC:neglogpacs, OLDVPRED:values}
|
||||
if states is not None:
|
||||
td_map[train_model.S] = states
|
||||
td_map[train_model.M] = masks
|
||||
return sess.run(
|
||||
[pg_loss, vf_loss, entropy, approxkl, clipfrac, _train],
|
||||
td_map
|
||||
)[:-1]
|
||||
self.loss_names = ['policy_loss', 'value_loss', 'policy_entropy', 'approxkl', 'clipfrac']
|
||||
|
||||
|
||||
self.train = train
|
||||
self.train_model = train_model
|
||||
self.act_model = act_model
|
||||
self.step = act_model.step
|
||||
self.value = act_model.value
|
||||
self.initial_state = act_model.initial_state
|
||||
|
||||
self.save = functools.partial(save_variables, sess=sess)
|
||||
self.load = functools.partial(load_variables, sess=sess)
|
||||
|
||||
if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
|
||||
initialize()
|
||||
global_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="")
|
||||
|
||||
if MPI is not None:
|
||||
sync_from_root(sess, global_variables) #pylint: disable=E1101
|
||||
|
||||
class Runner(AbstractEnvRunner):
|
||||
"""
|
||||
We use this object to make a mini batch of experiences
|
||||
__init__:
|
||||
- Initialize the runner
|
||||
|
||||
run():
|
||||
- Make a mini batch
|
||||
"""
|
||||
def __init__(self, *, env, model, nsteps, gamma, lam):
|
||||
super().__init__(env=env, model=model, nsteps=nsteps)
|
||||
# Lambda used in GAE (General Advantage Estimation)
|
||||
self.lam = lam
|
||||
# Discount rate
|
||||
self.gamma = gamma
|
||||
|
||||
def run(self):
|
||||
# Here, we init the lists that will contain the mb of experiences
|
||||
mb_obs, mb_rewards, mb_actions, mb_values, mb_dones, mb_neglogpacs = [],[],[],[],[],[]
|
||||
mb_states = self.states
|
||||
epinfos = []
|
||||
# For n in range number of steps
|
||||
for _ in range(self.nsteps):
|
||||
# Given observations, get action value and neglopacs
|
||||
# We already have self.obs because Runner superclass run self.obs[:] = env.reset() on init
|
||||
actions, values, self.states, neglogpacs = self.model.step(self.obs, S=self.states, M=self.dones)
|
||||
mb_obs.append(self.obs.copy())
|
||||
mb_actions.append(actions)
|
||||
mb_values.append(values)
|
||||
mb_neglogpacs.append(neglogpacs)
|
||||
mb_dones.append(self.dones)
|
||||
|
||||
# Take actions in env and look the results
|
||||
# Infos contains a ton of useful informations
|
||||
self.obs[:], rewards, self.dones, infos = self.env.step(actions)
|
||||
for info in infos:
|
||||
maybeepinfo = info.get('episode')
|
||||
if maybeepinfo: epinfos.append(maybeepinfo)
|
||||
mb_rewards.append(rewards)
|
||||
#batch of steps to batch of rollouts
|
||||
mb_obs = np.asarray(mb_obs, dtype=self.obs.dtype)
|
||||
mb_rewards = np.asarray(mb_rewards, dtype=np.float32)
|
||||
mb_actions = np.asarray(mb_actions)
|
||||
mb_values = np.asarray(mb_values, dtype=np.float32)
|
||||
mb_neglogpacs = np.asarray(mb_neglogpacs, dtype=np.float32)
|
||||
mb_dones = np.asarray(mb_dones, dtype=np.bool)
|
||||
last_values = self.model.value(self.obs, S=self.states, M=self.dones)
|
||||
|
||||
# discount/bootstrap off value fn
|
||||
mb_returns = np.zeros_like(mb_rewards)
|
||||
mb_advs = np.zeros_like(mb_rewards)
|
||||
lastgaelam = 0
|
||||
for t in reversed(range(self.nsteps)):
|
||||
if t == self.nsteps - 1:
|
||||
nextnonterminal = 1.0 - self.dones
|
||||
nextvalues = last_values
|
||||
else:
|
||||
nextnonterminal = 1.0 - mb_dones[t+1]
|
||||
nextvalues = mb_values[t+1]
|
||||
delta = mb_rewards[t] + self.gamma * nextvalues * nextnonterminal - mb_values[t]
|
||||
mb_advs[t] = lastgaelam = delta + self.gamma * self.lam * nextnonterminal * lastgaelam
|
||||
mb_returns = mb_advs + mb_values
|
||||
return (*map(sf01, (mb_obs, mb_returns, mb_dones, mb_actions, mb_values, mb_neglogpacs)),
|
||||
mb_states, epinfos)
|
||||
# obs, returns, masks, actions, values, neglogpacs, states = runner.run()
|
||||
def sf01(arr):
|
||||
"""
|
||||
swap and then flatten axes 0 and 1
|
||||
"""
|
||||
s = arr.shape
|
||||
return arr.swapaxes(0, 1).reshape(s[0] * s[1], *s[2:])
|
||||
|
||||
def constfn(val):
|
||||
def f(_):
|
||||
@@ -230,7 +21,7 @@ def constfn(val):
|
||||
def learn(*, network, env, total_timesteps, eval_env = None, seed=None, nsteps=2048, ent_coef=0.0, lr=3e-4,
|
||||
vf_coef=0.5, max_grad_norm=0.5, gamma=0.99, lam=0.95,
|
||||
log_interval=10, nminibatches=4, noptepochs=4, cliprange=0.2,
|
||||
save_interval=0, load_path=None, **network_kwargs):
|
||||
save_interval=0, load_path=None, model_fn=None, **network_kwargs):
|
||||
'''
|
||||
Learn policy using PPO algorithm (https://arxiv.org/abs/1707.06347)
|
||||
|
||||
@@ -308,10 +99,14 @@ def learn(*, network, env, total_timesteps, eval_env = None, seed=None, nsteps=2
|
||||
nbatch_train = nbatch // nminibatches
|
||||
|
||||
# Instantiate the model object (that creates act_model and train_model)
|
||||
make_model = lambda : Model(policy=policy, ob_space=ob_space, ac_space=ac_space, nbatch_act=nenvs, nbatch_train=nbatch_train,
|
||||
if model_fn is None:
|
||||
from baselines.ppo2.model import Model
|
||||
model_fn = Model
|
||||
|
||||
model = model_fn(policy=policy, ob_space=ob_space, ac_space=ac_space, nbatch_act=nenvs, nbatch_train=nbatch_train,
|
||||
nsteps=nsteps, ent_coef=ent_coef, vf_coef=vf_coef,
|
||||
max_grad_norm=max_grad_norm)
|
||||
model = make_model()
|
||||
|
||||
if load_path is not None:
|
||||
model.load(load_path)
|
||||
# Instantiate the runner object
|
||||
@@ -319,8 +114,6 @@ def learn(*, network, env, total_timesteps, eval_env = None, seed=None, nsteps=2
|
||||
if eval_env is not None:
|
||||
eval_runner = Runner(env = eval_env, model = model, nsteps = nsteps, gamma = gamma, lam= lam)
|
||||
|
||||
|
||||
|
||||
epinfobuf = deque(maxlen=100)
|
||||
if eval_env is not None:
|
||||
eval_epinfobuf = deque(maxlen=100)
|
||||
|
76
baselines/ppo2/runner.py
Normal file
76
baselines/ppo2/runner.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import numpy as np
|
||||
from baselines.common.runners import AbstractEnvRunner
|
||||
|
||||
class Runner(AbstractEnvRunner):
|
||||
"""
|
||||
We use this object to make a mini batch of experiences
|
||||
__init__:
|
||||
- Initialize the runner
|
||||
|
||||
run():
|
||||
- Make a mini batch
|
||||
"""
|
||||
def __init__(self, *, env, model, nsteps, gamma, lam):
|
||||
super().__init__(env=env, model=model, nsteps=nsteps)
|
||||
# Lambda used in GAE (General Advantage Estimation)
|
||||
self.lam = lam
|
||||
# Discount rate
|
||||
self.gamma = gamma
|
||||
|
||||
def run(self):
|
||||
# Here, we init the lists that will contain the mb of experiences
|
||||
mb_obs, mb_rewards, mb_actions, mb_values, mb_dones, mb_neglogpacs = [],[],[],[],[],[]
|
||||
mb_states = self.states
|
||||
epinfos = []
|
||||
# For n in range number of steps
|
||||
for _ in range(self.nsteps):
|
||||
# Given observations, get action value and neglopacs
|
||||
# We already have self.obs because Runner superclass run self.obs[:] = env.reset() on init
|
||||
actions, values, self.states, neglogpacs = self.model.step(self.obs, S=self.states, M=self.dones)
|
||||
mb_obs.append(self.obs.copy())
|
||||
mb_actions.append(actions)
|
||||
mb_values.append(values)
|
||||
mb_neglogpacs.append(neglogpacs)
|
||||
mb_dones.append(self.dones)
|
||||
|
||||
# Take actions in env and look the results
|
||||
# Infos contains a ton of useful informations
|
||||
self.obs[:], rewards, self.dones, infos = self.env.step(actions)
|
||||
for info in infos:
|
||||
maybeepinfo = info.get('episode')
|
||||
if maybeepinfo: epinfos.append(maybeepinfo)
|
||||
mb_rewards.append(rewards)
|
||||
#batch of steps to batch of rollouts
|
||||
mb_obs = np.asarray(mb_obs, dtype=self.obs.dtype)
|
||||
mb_rewards = np.asarray(mb_rewards, dtype=np.float32)
|
||||
mb_actions = np.asarray(mb_actions)
|
||||
mb_values = np.asarray(mb_values, dtype=np.float32)
|
||||
mb_neglogpacs = np.asarray(mb_neglogpacs, dtype=np.float32)
|
||||
mb_dones = np.asarray(mb_dones, dtype=np.bool)
|
||||
last_values = self.model.value(self.obs, S=self.states, M=self.dones)
|
||||
|
||||
# discount/bootstrap off value fn
|
||||
mb_returns = np.zeros_like(mb_rewards)
|
||||
mb_advs = np.zeros_like(mb_rewards)
|
||||
lastgaelam = 0
|
||||
for t in reversed(range(self.nsteps)):
|
||||
if t == self.nsteps - 1:
|
||||
nextnonterminal = 1.0 - self.dones
|
||||
nextvalues = last_values
|
||||
else:
|
||||
nextnonterminal = 1.0 - mb_dones[t+1]
|
||||
nextvalues = mb_values[t+1]
|
||||
delta = mb_rewards[t] + self.gamma * nextvalues * nextnonterminal - mb_values[t]
|
||||
mb_advs[t] = lastgaelam = delta + self.gamma * self.lam * nextnonterminal * lastgaelam
|
||||
mb_returns = mb_advs + mb_values
|
||||
return (*map(sf01, (mb_obs, mb_returns, mb_dones, mb_actions, mb_values, mb_neglogpacs)),
|
||||
mb_states, epinfos)
|
||||
# obs, returns, masks, actions, values, neglogpacs, states = runner.run()
|
||||
def sf01(arr):
|
||||
"""
|
||||
swap and then flatten axes 0 and 1
|
||||
"""
|
||||
s = arr.shape
|
||||
return arr.swapaxes(0, 1).reshape(s[0] * s[1], *s[2:])
|
||||
|
||||
|
34
baselines/ppo2/test_microbatches.py
Normal file
34
baselines/ppo2/test_microbatches.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import gym
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
from functools import partial
|
||||
|
||||
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
|
||||
from baselines.common.tf_util import make_session
|
||||
from baselines.ppo2.ppo2 import learn
|
||||
|
||||
from baselines.ppo2.microbatched_model import MicrobatchedModel
|
||||
|
||||
def test_microbatches():
|
||||
def env_fn():
|
||||
env = gym.make('CartPole-v0')
|
||||
env.seed(0)
|
||||
return env
|
||||
|
||||
learn_fn = partial(learn, network='mlp', nsteps=32, total_timesteps=32, seed=0)
|
||||
|
||||
env_ref = DummyVecEnv([env_fn])
|
||||
sess_ref = make_session(make_default=True, graph=tf.Graph())
|
||||
learn_fn(env=env_ref)
|
||||
vars_ref = {v.name: sess_ref.run(v) for v in tf.trainable_variables()}
|
||||
|
||||
env_test = DummyVecEnv([env_fn])
|
||||
sess_test = make_session(make_default=True, graph=tf.Graph())
|
||||
learn_fn(env=env_test, model_fn=partial(MicrobatchedModel, microbatch_size=2))
|
||||
vars_test = {v.name: sess_test.run(v) for v in tf.trainable_variables()}
|
||||
|
||||
for v in vars_ref:
|
||||
np.testing.assert_allclose(vars_ref[v], vars_test[v], atol=1e-3)
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_microbatches()
|
@@ -5,7 +5,7 @@ matplotlib.use('TkAgg') # Can change to 'Agg' for non-interactive mode
|
||||
import matplotlib.pyplot as plt
|
||||
plt.rcParams['svg.fonttype'] = 'none'
|
||||
|
||||
from baselines.bench.monitor import load_results
|
||||
from baselines.common import plot_util
|
||||
|
||||
X_TIMESTEPS = 'timesteps'
|
||||
X_EPISODES = 'episodes'
|
||||
@@ -16,7 +16,7 @@ POSSIBLE_X_AXES = [X_TIMESTEPS, X_EPISODES, X_WALLTIME]
|
||||
EPISODES_WINDOW = 100
|
||||
COLORS = ['blue', 'green', 'red', 'cyan', 'magenta', 'yellow', 'black', 'purple', 'pink',
|
||||
'brown', 'orange', 'teal', 'coral', 'lightblue', 'lime', 'lavender', 'turquoise',
|
||||
'darkgreen', 'tan', 'salmon', 'gold', 'lightpurple', 'darkred', 'darkblue']
|
||||
'darkgreen', 'tan', 'salmon', 'gold', 'darkred', 'darkblue']
|
||||
|
||||
def rolling_window(a, window):
|
||||
shape = a.shape[:-1] + (a.shape[-1] - window + 1, window)
|
||||
@@ -50,7 +50,7 @@ def plot_curves(xy_list, xaxis, yaxis, title):
|
||||
maxx = max(xy[0][-1] for xy in xy_list)
|
||||
minx = 0
|
||||
for (i, (x, y)) in enumerate(xy_list):
|
||||
color = COLORS[i]
|
||||
color = COLORS[i % len(COLORS)]
|
||||
plt.scatter(x, y, s=2)
|
||||
x, y_mean = window_func(x, y, EPISODES_WINDOW, np.mean) #So returns average of last EPISODE_WINDOW episodes
|
||||
plt.plot(x, y_mean, color=color)
|
||||
@@ -62,19 +62,18 @@ def plot_curves(xy_list, xaxis, yaxis, title):
|
||||
fig.canvas.mpl_connect('resize_event', lambda event: plt.tight_layout())
|
||||
plt.grid(True)
|
||||
|
||||
def plot_results(dirs, num_timesteps, xaxis, yaxis, task_name):
|
||||
tslist = []
|
||||
for dir in dirs:
|
||||
ts = load_results(dir)
|
||||
ts = ts[ts.l.cumsum() <= num_timesteps]
|
||||
tslist.append(ts)
|
||||
xy_list = [ts2xy(ts, xaxis, yaxis) for ts in tslist]
|
||||
plot_curves(xy_list, xaxis, yaxis, task_name)
|
||||
|
||||
def split_by_task(taskpath):
|
||||
return taskpath['dirname'].split('/')[-1].split('-')[0]
|
||||
|
||||
def plot_results(dirs, num_timesteps=10e6, xaxis=X_TIMESTEPS, yaxis=Y_REWARD, title='', split_fn=split_by_task):
|
||||
results = plot_util.load_results(dirs)
|
||||
plot_util.plot_results(results, xy_fn=lambda r: ts2xy(r['monitor'], xaxis, yaxis), split_fn=split_fn, average_group=True, resample=int(1e6))
|
||||
|
||||
# Example usage in jupyter-notebook
|
||||
# from baselines import log_viewer
|
||||
# from baselines.results_plotter import plot_results
|
||||
# %matplotlib inline
|
||||
# log_viewer.plot_results(["./log"], 10e6, log_viewer.X_TIMESTEPS, "Breakout")
|
||||
# plot_results("./log")
|
||||
# Here ./log is a directory containing the monitor.csv files
|
||||
|
||||
def main():
|
||||
|
@@ -181,11 +181,11 @@ def parse_cmdline_kwargs(args):
|
||||
|
||||
|
||||
|
||||
def main():
|
||||
def main(args):
|
||||
# configure logger, disable logging in child MPI processes (with rank > 0)
|
||||
|
||||
arg_parser = common_arg_parser()
|
||||
args, unknown_args = arg_parser.parse_known_args()
|
||||
args, unknown_args = arg_parser.parse_known_args(args)
|
||||
extra_args = parse_cmdline_kwargs(unknown_args)
|
||||
|
||||
if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
|
||||
@@ -220,5 +220,7 @@ def main():
|
||||
|
||||
env.close()
|
||||
|
||||
return model
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
main(sys.argv)
|
||||
|
808
docs/viz/viz.ipynb
Normal file
808
docs/viz/viz.ipynb
Normal file
File diff suppressed because one or more lines are too long
136
docs/viz/viz.md
136
docs/viz/viz.md
@@ -1,136 +0,0 @@
|
||||
# Loading and visualizing results ([colab notebook](https://colab.research.google.com/drive/1Wez1SA9PmNkCoYc8Fvl53bhU3F8OffGm))
|
||||
In order to compare performance of algorithms, we often would like to visualize learning curves (reward as a function of time steps), or some other auxiliary information about learning
|
||||
aggregated into a plot. Baselines repo provides tools for doing so in several different ways, depending on the goal.
|
||||
|
||||
## Preliminaries
|
||||
For all algorithms in baselines summary data is saved into a folder defined by logger. By default, a folder `$TMPDIR/openai-<date>-<time>` is used;
|
||||
you can see the location of logger directory at the beginning of the training in the message like this:
|
||||
|
||||
```
|
||||
Logging to /var/folders/mq/tgrn7bs17s1fnhlwt314b2fm0000gn/T/openai-2018-10-29-15-03-13-537078
|
||||
```
|
||||
The location can be changed by changing `OPENAI_LOGDIR` environment variable; for instance:
|
||||
```bash
|
||||
export OPENAI_LOGDIR=$HOME/logs/cartpole-ppo
|
||||
python -m baselines.run --alg=ppo2 --env=CartPole-v0 --num_timesteps=30000 --nsteps=128
|
||||
```
|
||||
will log data to `~/logs/cartpole-ppo`.
|
||||
|
||||
## Using TensorBoard
|
||||
One of the most straightforward ways to visualize data is to use [TensorBoard](https://www.tensorflow.org/guide/summaries_and_tensorboard). Baselines logger can dump data in tensorboard-compatible format; to
|
||||
set that up, set environment variables `OPENAI_LOG_FORMAT`
|
||||
```bash
|
||||
export OPENAI_LOG_FORMAT='stdout,log,csv,tensorboard' # formats are comma-separated, but for tensorboard you only really need the last one
|
||||
```
|
||||
And you can now start TensorBoard with:
|
||||
```bash
|
||||
tensorboard --logdir=$OPENAI_LOGDIR
|
||||
```
|
||||
|
||||
## Loading summaries of the results
|
||||
If the summary overview provided by tensorboard is not sufficient, and you would like to either access to raw environment episode data, or use complex post-processing notavailable in tensorboard, you can load results into python as [pandas](https://pandas.pydata.org/) dataframes.
|
||||
|
||||
For instance, the following snippet:
|
||||
```python
|
||||
from baselines.common import plot_util as pu
|
||||
results = pu.load_results('~/logs/cartpole-ppo')
|
||||
```
|
||||
will search for all folders with baselines-compatible results in `~/logs/cartpole-ppo` and subfolders and
|
||||
return a list of Result objects. Each Result object is a named tuple with the following fields:
|
||||
|
||||
- dirname: str - name of the folder from which data was loaded
|
||||
|
||||
- metadata: dict) - dictionary with various metadata (read from metadata.json file)
|
||||
|
||||
- progress: pandas.DataFrame - tabular data saved by logger as a pandas dataframe. Available if csv is in logger formats.
|
||||
|
||||
- monitor: pandas.DataFrame - raw episode data (length, episode reward, timestamp). Available if environment wrapped with [Monitor](../../baselines/bench/monitor.py) wrapper
|
||||
|
||||
## Plotting: single- and few curve plots
|
||||
Once results are loaded, they can be plotted in all conventional means. For example:
|
||||
```python
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
r = results[0]
|
||||
plt.plot(np.cumsum(r.monitor.l), r.monitor.r)
|
||||
```
|
||||
will print a (very noisy learning curve) for CartPole (assuming we ran the training command for CartPole above). Note the cumulative sum trick to get convert length of the episode into number of time steps taken so far.
|
||||
|
||||
<img src="https://storage.googleapis.com/baselines/assets/viz/Screen%20Shot%202018-10-29%20at%204.44.46%20PM.png" width="500">
|
||||
|
||||
We can get a smoothened version of the same curve by using `plot_util.smooth()` function:
|
||||
```python
|
||||
plt.plot(np.cumsum(r.monitor.l), pu.smooth(r.monitor.r, radius=10))
|
||||
```
|
||||
|
||||
<img src="https://storage.googleapis.com/baselines/assets/viz/Screen%20Shot%202018-10-29%20at%204.49.13%20PM.png" width="730">
|
||||
|
||||
We can also get a similar curve by using logger summaries (instead of raw episode data in monitor.csv):
|
||||
```python
|
||||
plt.plot(r.progress.total_timesteps, r.progress.eprewmean)
|
||||
```
|
||||
|
||||
<img src="https://storage.googleapis.com/baselines/assets/viz/Screen%20Shot%202018-10-29%20at%205.04.31%20PM.png" width="730">
|
||||
|
||||
Note, however, that raw episode data is stored by the Monitor wrapper, and hence looks similar for all algorithms, whereas progress data
|
||||
is handled by the algorithm itself, and hence can vary (column names, type of data available) between algorithms.
|
||||
|
||||
## Plotting: many curves
|
||||
While the loading and the plotting functions described above in principle give you access to any slice of the training summaries,
|
||||
sometimes it is necessary to plot and compare many training runs (multiple algorithms, multiple seeds for random number generator),
|
||||
and usage of the functions above can get tedious and messy. For that case, `baselines.common.plot_util` provides convenience function
|
||||
`plot_results` that handles multiple Result objects that need to be routed in multiple plots. Consider the following bash snippet that
|
||||
runs ppo2 with cartpole with 6 different seeds for 30k time steps, first with batch size 32, and then with batch size 128:
|
||||
|
||||
```bash
|
||||
for seed in $(seq 0 5); do
|
||||
OPENAI_LOGDIR=$HOME/logs/cartpole-ppo/b32-$seed python -m baselines.run --alg=ppo2 --env=CartPole-v0 --num_timesteps=3e4 --seed=$seed --nsteps=32
|
||||
done
|
||||
for seed in $(seq 0 5); do
|
||||
OPENAI_LOGDIR=$HOME/logs/cartpole-ppo/b128-$seed python -m baselines.run --alg=ppo2 --env=CartPole-v0 --num_timesteps=3e4 --seed=$seed --nsteps=128
|
||||
done
|
||||
```
|
||||
These 12 runs can be loaded just as before:
|
||||
```python
|
||||
results = pu.load_results('~/logs/cartpole-ppo')
|
||||
```
|
||||
But how do we plot all 12 of them in a sensible manner? `baselines.common.plot_util` module provides `plot_results` function to do just that:
|
||||
```
|
||||
results = results[1:]
|
||||
pu.plot_results(results)
|
||||
```
|
||||
(note that now the length of the results list is 13, due to the data from the previous run stored directly in `~/logs/cartpole-ppo`; we discard first element for the same reason)
|
||||
The results are split into two groups based on batch size and are plotted on a separate graph. More specifically, by default `plot_results` considers digits after dash at the end of the directory name to be seed id and groups the runs that differ only by those together.
|
||||
|
||||
<img src="https://storage.googleapis.com/baselines/assets/viz/Screen%20Shot%202018-10-29%20at%205.53.45%20PM.png" width="700">
|
||||
|
||||
Showing all seeds on the same plot may be somewhat hard to comprehend and analyse. We can instead average over all seeds via the following command:
|
||||
|
||||
<img src="https://storage.googleapis.com/baselines/assets/viz/Screen%20Shot%202018-11-02%20at%204.42.52%20PM.png" width="720">
|
||||
|
||||
The lighter shade shows the standard deviation of data, and darker shade -
|
||||
error in estimate of the mean (that is, standard deviation divided by square root of number of seeds).
|
||||
Note that averaging over seeds requires resampling to a common grid, which, in turn, requires smoothing
|
||||
(using language of signal processing, we need to do low-pass filtering before resampling to avoid aliasing effects).
|
||||
You can change the amount of smoothing by adjusting `resample` and `smooth_step` arguments to achieve desired smoothing effect
|
||||
See the docstring of `plot_util` function for more info.
|
||||
|
||||
To plot both groups on the same graph, we can use the following:
|
||||
```python
|
||||
pu.plot_results(results, average_group=True, split_fn=lambda _: '')
|
||||
```
|
||||
Option `split_fn=labmda _:'' ` effectively disables splitting, so that all curves end up on the same panel.
|
||||
|
||||
<img src="https://storage.googleapis.com/baselines/assets/viz/Screen%20Shot%202018-11-06%20at%203.11.51%20PM.png" width=720>
|
||||
|
||||
Now, with many groups the overlapping shaded regions may start looking messy. We can disable either
|
||||
light shaded region (corresponding to standard deviation of the curves in the group) or darker shaded region
|
||||
(corresponding to the error in mean estimate) by using `shaded_std=False` or `shaded_err=False` options respectively.
|
||||
For instance,
|
||||
|
||||
```python
|
||||
pu.plot_results(results, average_group=True, split_fn=lambda _: '', shaded_std=False)
|
||||
```
|
||||
produces the following plot:
|
||||
|
||||
<img src="https://storage.googleapis.com/baselines/assets/viz/Screen%20Shot%202018-11-06%20at%203.12.02%20PM.png" width=820>
|
Reference in New Issue
Block a user