Wrote some comments to explain the A2C and PPO2 implementation (#607)

* added comments in A2C and PPO2

* Fixed format errors to respect PEP 8 style guide
This commit is contained in:
Thomas Simonini
2018-09-21 22:12:31 +02:00
committed by pzhokhov
parent a7fd8a4477
commit 8158f35611
5 changed files with 220 additions and 5 deletions

View File

@@ -4,3 +4,10 @@
- Baselines blog post: https://blog.openai.com/baselines-acktr-a2c/
- `python -m baselines.run --alg=a2c --env=PongNoFrameskip-v4` runs the algorithm for 40M frames = 10M timesteps on an Atari Pong. See help (`-h`) for more options
- also refer to the repo-wide [README.md](../../README.md#training-models)
## Files
- `run_atari`: file used to run the algorithm.
- `policies.py`: contains the different versions of the A2C architecture (MlpPolicy, CNNPolicy, LstmPolicy...).
- `a2c.py`: - `Model`: class used to initialize the step_model (sampling) and train_model (training)
   - `learn`: main entrypoint for the A2C algorithm. Trains a policy with a given network architecture on a given environment using the A2C algorithm.
- `runner.py`: class used to generate a batch of experiences

View File

@@ -16,6 +16,18 @@ from tensorflow import losses
class Model(object):
"""
We use this class to:
__init__:
- Creates the step_model
- Creates the train_model
train():
- Run the training part (feedforward and backpropagation of gradients)
save/load():
- Save/load the model
"""
def __init__(self, policy, env, nsteps,
ent_coef=0.01, vf_coef=0.5, max_grad_norm=0.5, lr=7e-4,
alpha=0.99, epsilon=1e-5, total_timesteps=int(80e6), lrschedule='linear'):
@@ -26,7 +38,10 @@ class Model(object):
with tf.variable_scope('a2c_model', reuse=tf.AUTO_REUSE):
# step_model is used for sampling
step_model = policy(nenvs, 1, sess)
# train_model is used to train our network
train_model = policy(nbatch, nsteps, sess)
A = tf.placeholder(train_model.action.dtype, train_model.action.shape)
@@ -34,25 +49,48 @@ class Model(object):
R = tf.placeholder(tf.float32, [nbatch])
LR = tf.placeholder(tf.float32, [])
# Calculate the loss
# Total loss = Policy gradient loss - entropy * entropy coefficient + Value coefficient * value loss
# Policy loss
# Output -log(pi)
neglogpac = train_model.pd.neglogp(A)
# Entropy is used to improve exploration by limiting premature convergence to a suboptimal policy.
entropy = tf.reduce_mean(train_model.pd.entropy())
# 1/n * sum A(si,ai) * -logpi(ai|si)
pg_loss = tf.reduce_mean(ADV * neglogpac)
# Value loss: mean of [R - V(s)]^2
vf_loss = losses.mean_squared_error(tf.squeeze(train_model.vf), R)
loss = pg_loss - entropy*ent_coef + vf_loss * vf_coef
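# Worked example with illustrative numbers: if ADV = [0.5, -0.2] and neglogpac = [1.2, 0.8],
# then pg_loss = mean([0.6, -0.16]) = 0.22; with entropy = 1.0, vf_loss = 0.025,
# ent_coef = 0.01 and vf_coef = 0.5, loss = 0.22 - 0.01 + 0.0125 = 0.2225.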
# Update parameters using loss
# 1. Get the model parameters
params = find_trainable_variables("a2c_model")
# 2. Calculate the gradients
grads = tf.gradients(loss, params)
if max_grad_norm is not None:
# Clip the gradients by their global norm
grads, grad_norm = tf.clip_by_global_norm(grads, max_grad_norm)
grads = list(zip(grads, params))
# zip pairs each gradient with its associated parameter
# For instance, zip(ABCD, xyza) => Ax, By, Cz, Da
# 3. Build our trainer
trainer = tf.train.RMSPropOptimizer(learning_rate=LR, decay=alpha, epsilon=epsilon)
# 4. Apply the gradients (parameter update)
_train = trainer.apply_gradients(grads)
lr = Scheduler(v=lr, nvalues=total_timesteps, schedule=lrschedule)
def train(obs, states, rewards, masks, actions, values):
# Here we calculate advantage A(s,a) = R + gamma*V(s') - V(s)
# rewards = R + gamma*V(s')
advs = rewards - values
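# Worked example with illustrative numbers: with gamma = 0.99, a 2-step rollout where
# r0 = 0, r1 = 1 and bootstrap value V(s2) = 0.5 gives
# rewards[0] = 0 + 0.99 * 1 + 0.99**2 * 0.5 ~= 1.480; if values[0] = 1.2,
# advs[0] = 1.480 - 1.2 = 0.280.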
for step in range(len(obs)):
cur_lr = lr.value()
@@ -148,23 +186,47 @@ def learn(
set_global_seeds(seed)
# Get the number of environments
nenvs = env.num_envs
policy = build_policy(env, network, **network_kwargs)
# Instantiate the model object (that creates step_model and train_model)
model = Model(policy=policy, env=env, nsteps=nsteps, ent_coef=ent_coef, vf_coef=vf_coef,
max_grad_norm=max_grad_norm, lr=lr, alpha=alpha, epsilon=epsilon, total_timesteps=total_timesteps, lrschedule=lrschedule)
if load_path is not None:
model.load(load_path)
# Instantiate the runner object
runner = Runner(env, model, nsteps=nsteps, gamma=gamma)
# Calculate the batch_size
nbatch = nenvs*nsteps
# Start total timer
tstart = time.time()
for update in range(1, total_timesteps//nbatch+1):
# Get mini batch of experiences
obs, states, rewards, masks, actions, values = runner.run()
policy_loss, value_loss, policy_entropy = model.train(obs, states, rewards, masks, actions, values)
nseconds = time.time()-tstart
# Calculate the fps (frames per second)
fps = int((update*nbatch)/nseconds)
if update % log_interval == 0 or update == 1:
"""
explained_variance computes whether the value function is a good
predictor of the returns, or whether it is worse than predicting
nothing.
The goal is for ev to get closer and closer to 1.
interpretation:
ev=0 => might as well have predicted zero
ev=1 => perfect prediction
ev<0 => worse than just predicting zero
"""
ev = explained_variance(values, rewards)
logger.record_tabular("nupdates", update)
logger.record_tabular("total_timesteps", update*nbatch)

View File

@@ -3,7 +3,15 @@ from baselines.a2c.utils import discount_with_dones
from baselines.common.runners import AbstractEnvRunner
class Runner(AbstractEnvRunner):
"""
We use this class to generate batches of experiences
__init__:
- Initialize the runner
run():
- Make a mini batch of experiences
"""
def __init__(self, env, model, nsteps=5, gamma=0.99):
super().__init__(env=env, model=model, nsteps=nsteps)
self.gamma = gamma
@@ -11,14 +19,21 @@ class Runner(AbstractEnvRunner):
self.ob_dtype = model.train_model.X.dtype.as_numpy_dtype
def run(self):
# We initialize the lists that will contain the mb of experiences
mb_obs, mb_rewards, mb_actions, mb_values, mb_dones = [],[],[],[],[]
mb_states = self.states
for n in range(self.nsteps):
# Given observations, get actions and values (V(s))
# We already have self.obs because AbstractEnvRunner runs self.obs[:] = env.reset()
actions, values, states, _ = self.model.step(self.obs, S=self.states, M=self.dones)
# Append the experiences
mb_obs.append(np.copy(self.obs))
mb_actions.append(actions)
mb_values.append(values)
mb_dones.append(self.dones)
# Take actions in the env and look at the results
obs, rewards, dones, _ = self.env.step(actions)
self.states = states
self.dones = dones
@@ -28,8 +43,8 @@ class Runner(AbstractEnvRunner):
self.obs = obs
mb_rewards.append(rewards)
mb_dones.append(self.dones)
# Batch of steps to batch of rollouts
mb_obs = np.asarray(mb_obs, dtype=self.ob_dtype).swapaxes(1, 0).reshape(self.batch_ob_shape)
mb_rewards = np.asarray(mb_rewards, dtype=np.float32).swapaxes(1, 0)
mb_actions = np.asarray(mb_actions, dtype=self.model.train_model.action.dtype.name).swapaxes(1, 0)
@@ -40,7 +55,7 @@ class Runner(AbstractEnvRunner):
if self.gamma > 0.0:
# Discount/bootstrap off value fn
last_values = self.model.value(self.obs, S=self.states, M=self.dones).tolist()
for n, (rewards, dones, value) in enumerate(zip(mb_rewards, mb_dones, last_values)):
rewards = rewards.tolist()
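The discount/bootstrap step above can be sketched on its own with numpy (an illustrative helper with hypothetical names, not the actual `discount_with_dones` used by the runner): if the rollout did not end in a terminal state, bootstrap from the value of the last observation, and let a done flag cut the return at that step.

import numpy as np

def n_step_returns(rewards, dones, last_value, gamma=0.99):
    # Work backwards through the rollout, resetting at done flags
    ret = last_value * (1.0 - dones[-1])
    out = np.zeros(len(rewards))
    for t in reversed(range(len(rewards))):
        ret = rewards[t] + gamma * ret * (1.0 - dones[t])
        out[t] = ret
    return out

# e.g. rewards = [0, 0, 1], dones = [0, 0, 0], last_value = 0.5, gamma = 0.99
# -> returns ~= [1.465, 1.480, 1.495]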

View File

@@ -43,11 +43,17 @@ class PolicyWithValue(object):
vf_latent = tf.layers.flatten(vf_latent)
latent = tf.layers.flatten(latent)
# Based on the action space, select the probability distribution type
self.pdtype = make_pdtype(env.action_space)
# This builds a fully connected layer that returns a probability distribution
# over actions (self.pd) and our pi logits (self.pi).
self.pd, self.pi = self.pdtype.pdfromlatent(latent, init_scale=0.01)
# Sample an action
self.action = self.pd.sample()
# Calculate the negative log probability of the sampled action
self.neglogp = self.pd.neglogp(self.action)
self.sess = sess
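As a toy illustration of what the distribution object provides for a discrete action space (plain numpy, not the baselines pdtype classes):

import numpy as np

logits = np.array([2.0, 0.5, 0.1])              # pi logits from the fc layer
probs = np.exp(logits) / np.exp(logits).sum()   # softmax over actions
action = np.random.choice(len(probs), p=probs)  # analogous to self.pd.sample()
neglogp = -np.log(probs[action])                # analogous to self.pd.neglogp(action)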

View File

@@ -17,50 +17,112 @@ from baselines.common.tf_util import initialize
from baselines.common.mpi_util import sync_from_root
class Model(object):
"""
We use this object to:
__init__:
- Creates the act_model (sampling)
- Creates the train_model (training)
train():
- Run the training part (feedforward and backpropagation of gradients)
save/load():
- Save/load the model
"""
def __init__(self, *, policy, ob_space, ac_space, nbatch_act, nbatch_train,
nsteps, ent_coef, vf_coef, max_grad_norm):
sess = get_session()
with tf.variable_scope('ppo2_model', reuse=tf.AUTO_REUSE):
# CREATE OUR TWO MODELS
# act_model is used for sampling
act_model = policy(nbatch_act, 1, sess)
# train_model is used for training
train_model = policy(nbatch_train, nsteps, sess)
# CREATE THE PLACEHOLDERS
A = train_model.pdtype.sample_placeholder([None])
ADV = tf.placeholder(tf.float32, [None])
R = tf.placeholder(tf.float32, [None])
# Keep track of old actor
OLDNEGLOGPAC = tf.placeholder(tf.float32, [None])
# Keep track of old critic
OLDVPRED = tf.placeholder(tf.float32, [None])
LR = tf.placeholder(tf.float32, [])
# Cliprange
CLIPRANGE = tf.placeholder(tf.float32, [])
neglogpac = train_model.pd.neglogp(A)
# Calculate the entropy
# Entropy is used to improve exploration by limiting premature convergence to a suboptimal policy.
entropy = tf.reduce_mean(train_model.pd.entropy())
# CALCULATE THE LOSS
# Total loss = Policy gradient loss - entropy * entropy coefficient + Value coefficient * value loss
# Clip the value
# Get the value predicted
vpred = train_model.vf
# Clip the value = Oldvalue + clip(value - oldvalue, min = - cliprange, max = cliprange)
vpredclipped = OLDVPRED + tf.clip_by_value(train_model.vf - OLDVPRED, - CLIPRANGE, CLIPRANGE)
# Unclipped value
vf_losses1 = tf.square(vpred - R)
# Clipped value
vf_losses2 = tf.square(vpredclipped - R)
# Value loss: 0.5 * mean[max(unclipped, clipped)]
vf_loss = .5 * tf.reduce_mean(tf.maximum(vf_losses1, vf_losses2))
# Remember we want the ratio (pi current policy / pi old policy)
# But neglogpac gives us -log(policy)
# So we want to transform it into a ratio
# e^(-log old - (-log new)) == e^(log new - log old) == e^(log(new / old))
# = new/old (since the exponential cancels the log)
ratio = tf.exp(OLDNEGLOGPAC - neglogpac)
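# Numeric check (illustrative numbers): if OLDNEGLOGPAC = 2.0 (pi_old(a|s) ~= 0.135) and
# neglogpac = 1.5 (pi_new(a|s) ~= 0.223), then ratio = exp(2.0 - 1.5) ~= 1.65 = 0.223 / 0.135.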
# Remember also that we're doing gradient ascent, i.e. we want to MAXIMIZE the objective function J, which is equivalent to minimizing
# Loss = - J
# To turn the objective into a loss we negate the product: -Advantages * (pi new / pi old)
pg_losses = -ADV * ratio
# Clipped surrogate: the ratio is clipped between 1 - cliprange and 1 + cliprange
pg_losses2 = -ADV * tf.clip_by_value(ratio, 1.0 - CLIPRANGE, 1.0 + CLIPRANGE)
# Final PG loss
# Why maximum? Because pg_losses and pg_losses2 are the NEGATED objectives, so taking
# their maximum is the same as taking the minimum of the unclipped and clipped objectives
pg_loss = tf.reduce_mean(tf.maximum(pg_losses, pg_losses2))
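# Numeric check (illustrative numbers): with ADV = 2.0, ratio = 1.5 and CLIPRANGE = 0.2:
# pg_losses = -2.0 * 1.5 = -3.0, pg_losses2 = -2.0 * 1.2 = -2.4 and
# maximum(-3.0, -2.4) = -2.4, i.e. the clipped (more pessimistic) term is kept.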
approxkl = .5 * tf.reduce_mean(tf.square(neglogpac - OLDNEGLOGPAC))
clipfrac = tf.reduce_mean(tf.to_float(tf.greater(tf.abs(ratio - 1.0), CLIPRANGE)))
# Total loss (remember that minimizing L = - J is the same as maximizing J)
loss = pg_loss - entropy * ent_coef + vf_loss * vf_coef
# UPDATE THE PARAMETERS USING LOSS
# 1. Get the model parameters
params = tf.trainable_variables('ppo2_model')
# 2. Build our trainer
trainer = MpiAdamOptimizer(MPI.COMM_WORLD, learning_rate=LR, epsilon=1e-5)
# 3. Calculate the gradients
grads_and_var = trainer.compute_gradients(loss, params)
grads, var = zip(*grads_and_var)
if max_grad_norm is not None:
# Clip the gradients by their global norm
grads, _grad_norm = tf.clip_by_global_norm(grads, max_grad_norm)
grads_and_var = list(zip(grads, var))
# zip pairs each gradient with its associated parameter
# For instance, zip(ABCD, xyza) => Ax, By, Cz, Da
# 4. Apply the gradients (parameter update)
_train = trainer.apply_gradients(grads_and_var)
def train(lr, cliprange, obs, returns, masks, actions, values, neglogpacs, states=None):
# Here we calculate advantage A(s,a) = R + gamma*V(s') - V(s)
# Returns = R + gamma*V(s')
advs = returns - values
# Normalize the advantages
advs = (advs - advs.mean()) / (advs.std() + 1e-8)
td_map = {train_model.X:obs, A:actions, ADV:advs, R:returns, LR:lr,
CLIPRANGE:cliprange, OLDNEGLOGPAC:neglogpacs, OLDVPRED:values}
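Putting the pieces of the loss above together, here is a compact standalone numpy sketch of the same math (an illustration only; `ppo_loss` and its argument names are hypothetical):

import numpy as np

def ppo_loss(neglogp_new, neglogp_old, adv, returns, vpred, vpred_old,
             cliprange=0.2, entropy=0.0, ent_coef=0.0, vf_coef=0.5):
    # Clipped surrogate policy loss (we minimize -J)
    ratio = np.exp(neglogp_old - neglogp_new)
    pg_unclipped = -adv * ratio
    pg_clipped = -adv * np.clip(ratio, 1.0 - cliprange, 1.0 + cliprange)
    pg_loss = np.mean(np.maximum(pg_unclipped, pg_clipped))
    # Clipped value loss
    vpredclipped = vpred_old + np.clip(vpred - vpred_old, -cliprange, cliprange)
    vf_loss = 0.5 * np.mean(np.maximum(np.square(vpred - returns),
                                       np.square(vpredclipped - returns)))
    # Total loss
    return pg_loss - entropy * ent_coef + vf_loss * vf_coef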
@@ -90,23 +152,39 @@ class Model(object):
sync_from_root(sess, global_variables) #pylint: disable=E1101
class Runner(AbstractEnvRunner):
"""
We use this object to make a mini batch of experiences
__init__:
- Initialize the runner
run():
- Make a mini batch
"""
def __init__(self, *, env, model, nsteps, gamma, lam):
super().__init__(env=env, model=model, nsteps=nsteps)
# Lambda used in GAE (Generalized Advantage Estimation)
self.lam = lam
# Discount rate
self.gamma = gamma
def run(self):
# Here, we init the lists that will contain the mb of experiences
mb_obs, mb_rewards, mb_actions, mb_values, mb_dones, mb_neglogpacs = [],[],[],[],[],[]
mb_states = self.states
epinfos = []
# For each of the nsteps steps
for _ in range(self.nsteps):
# Given observations, get actions, values and neglogpacs
# We already have self.obs because AbstractEnvRunner runs self.obs[:] = env.reset()
actions, values, self.states, neglogpacs = self.model.step(self.obs, S=self.states, M=self.dones)
mb_obs.append(self.obs.copy())
mb_actions.append(actions)
mb_values.append(values)
mb_neglogpacs.append(neglogpacs)
mb_dones.append(self.dones)
# Take actions in the env and look at the results
# infos contains a ton of useful information
self.obs[:], rewards, self.dones, infos = self.env.step(actions)
for info in infos:
maybeepinfo = info.get('episode')
@@ -120,19 +198,36 @@ class Runner(AbstractEnvRunner):
mb_neglogpacs = np.asarray(mb_neglogpacs, dtype=np.float32)
mb_dones = np.asarray(mb_dones, dtype=np.bool)
last_values = self.model.value(self.obs, S=self.states, M=self.dones)
### GENERALIZED ADVANTAGE ESTIMATION
# discount/bootstrap off value fn
# We create mb_returns and mb_advantages
# mb_returns will contain Advantage + value
mb_returns = np.zeros_like(mb_rewards)
mb_advs = np.zeros_like(mb_rewards)
lastgaelam = 0
# From last step to first step
for t in reversed(range(self.nsteps)):
# If t == the last step of the rollout
if t == self.nsteps - 1:
# If a state is done, nextnonterminal = 0
# nextnonterminal implements the following logic:
# if done (so nextnonterminal = 0):
# delta = R - V(s) (because self.gamma * nextvalues * nextnonterminal = 0)
# else (not done):
# delta = R + gamma * V(st+1) - V(s)
nextnonterminal = 1.0 - self.dones
# V(t+1)
nextvalues = last_values
else:
nextnonterminal = 1.0 - mb_dones[t+1]
nextvalues = mb_values[t+1]
# Delta = R(st) + gamma * V(t+1) * nextnonterminal - V(st)
delta = mb_rewards[t] + self.gamma * nextvalues * nextnonterminal - mb_values[t]
# Advantage = delta + gamma * λ (lambda) * nextnonterminal * lastgaelam
mb_advs[t] = lastgaelam = delta + self.gamma * self.lam * nextnonterminal * lastgaelam
# Returns
mb_returns = mb_advs + mb_values
return (*map(sf01, (mb_obs, mb_returns, mb_dones, mb_actions, mb_values, mb_neglogpacs)),
mb_states, epinfos)
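The GAE recursion commented above can also be written as a small standalone numpy sketch (an illustrative helper with hypothetical names, not the runner code itself):

import numpy as np

def gae(rewards, values, dones, last_value, last_done, gamma=0.99, lam=0.95):
    # Generalized Advantage Estimation over one rollout of length nsteps
    nsteps = len(rewards)
    advs = np.zeros(nsteps)
    lastgaelam = 0.0
    for t in reversed(range(nsteps)):
        if t == nsteps - 1:
            nextnonterminal = 1.0 - last_done
            nextvalues = last_value
        else:
            nextnonterminal = 1.0 - dones[t + 1]
            nextvalues = values[t + 1]
        delta = rewards[t] + gamma * nextvalues * nextnonterminal - values[t]
        advs[t] = lastgaelam = delta + gamma * lam * nextnonterminal * lastgaelam
    returns = advs + np.asarray(values)
    return advs, returns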
@@ -218,37 +313,56 @@ def learn(*, network, env, total_timesteps, seed=None, nsteps=2048, ent_coef=0.0
policy = build_policy(env, network, **network_kwargs)
# Get the number of environments
nenvs = env.num_envs
# Get the observation space and action space
ob_space = env.observation_space
ac_space = env.action_space
# Calculate the batch_size
nbatch = nenvs * nsteps
nbatch_train = nbatch // nminibatches
# Instantiate the model object (that creates act_model and train_model)
make_model = lambda : Model(policy=policy, ob_space=ob_space, ac_space=ac_space, nbatch_act=nenvs, nbatch_train=nbatch_train,
nsteps=nsteps, ent_coef=ent_coef, vf_coef=vf_coef,
max_grad_norm=max_grad_norm)
model = make_model()
if load_path is not None:
model.load(load_path)
# Instantiate the runner object
runner = Runner(env=env, model=model, nsteps=nsteps, gamma=gamma, lam=lam)
epinfobuf = deque(maxlen=100)
# Start total timer
tfirststart = time.time()
nupdates = total_timesteps//nbatch
for update in range(1, nupdates+1):
assert nbatch % nminibatches == 0
# Start timer
tstart = time.time()
frac = 1.0 - (update - 1.0) / nupdates
# Calculate the learning rate
lrnow = lr(frac)
# Calculate the cliprange
cliprangenow = cliprange(frac)
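# Example: since lrnow = lr(frac) and frac shrinks from 1.0 towards 0 over training,
# passing lr = lambda f: f * 2.5e-4 makes the learning rate anneal linearly from 2.5e-4
# towards 0; cliprange anneals the same way.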
# Get minibatch
obs, returns, masks, actions, values, neglogpacs, states, epinfos = runner.run() #pylint: disable=E0632
epinfobuf.extend(epinfos)
# For each minibatch we calculate the loss and append it to mblossvals
mblossvals = []
if states is None: # nonrecurrent version
# Create the array of batch indices
inds = np.arange(nbatch)
for _ in range(noptepochs):
# Randomize the indexes
np.random.shuffle(inds)
# 0 to batch_size with batch_train_size step
for start in range(0, nbatch, nbatch_train):
end = start + nbatch_train
mbinds = inds[start:end]
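With toy sizes, the epoch/minibatch indexing above works like this (a small numpy illustration, values chosen only for the example):

import numpy as np

nbatch, nbatch_train, noptepochs = 8, 4, 2   # toy sizes for the illustration
inds = np.arange(nbatch)
for _ in range(noptepochs):
    np.random.shuffle(inds)                  # e.g. [5 2 7 0 3 6 1 4]
    for start in range(0, nbatch, nbatch_train):
        mbinds = inds[start:start + nbatch_train]
        # first slice: [5 2 7 0], second slice: [3 6 1 4] (for the shuffle above);
        # these indices are used to slice obs, returns, actions, values, neglogpacs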
@@ -270,10 +384,21 @@ def learn(*, network, env, total_timesteps, seed=None, nsteps=2048, ent_coef=0.0
mbstates = states[mbenvinds]
mblossvals.append(model.train(lrnow, cliprangenow, *slices, mbstates))
# Feedforward --> get losses --> update
lossvals = np.mean(mblossvals, axis=0)
# End timer
tnow = time.time()
# Calculate the fps (frames per second)
fps = int(nbatch / (tnow - tstart))
if update % log_interval == 0 or update == 1:
"""
Computes fraction of variance that ypred explains about y.
Returns 1 - Var[y-ypred] / Var[y]
interpretation:
ev=0 => might as well have predicted zero
ev=1 => perfect prediction
ev<0 => worse than just predicting zero
"""
ev = explained_variance(values, returns)
logger.logkv("serial_timesteps", update*nsteps)
logger.logkv("nupdates", update)
@@ -294,7 +419,7 @@ def learn(*, network, env, total_timesteps, seed=None, nsteps=2048, ent_coef=0.0
print('Saving to', savepath)
model.save(savepath)
return model
# Avoid division errors when calculating the mean (if epinfos is empty, return np.nan instead of raising an error)
def safemean(xs):
return np.nan if len(xs) == 0 else np.mean(xs)