Wrote some comments to explain the A2C and PPO2 implementation (#607)

* Added comments in A2C and PPO2
* Fixed formatting errors to respect the PEP 8 style guide

committed by pzhokhov
parent a7fd8a4477
commit 8158f35611
@@ -4,3 +4,10 @@

- Baselines blog post: https://blog.openai.com/baselines-acktr-a2c/
- `python -m baselines.run --alg=a2c --env=PongNoFrameskip-v4` runs the algorithm for 40M frames = 10M timesteps on an Atari Pong. See help (`-h`) for more options.
- also refer to the repo-wide [README.md](../../README.md#training-models)

## Files
- `run_atari`: file used to run the algorithm.
- `policies.py`: contains the different versions of the A2C architecture (MlpPolicy, CNNPolicy, LstmPolicy, ...).
- `a2c.py`:
  - `Model`: class used to initialize the step_model (sampling) and the train_model (training)
  - `learn`: main entrypoint for the A2C algorithm; trains a policy with a given network architecture on a given environment using the A2C algorithm.
- `runner.py`: contains the `Runner` class, used to generate a batch of experiences.
@@ -16,6 +16,18 @@ from tensorflow import losses

class Model(object):

    """
    We use this class to:
        __init__:
        - Creates the step_model
        - Creates the train_model

        train():
        - Performs the training step (feedforward and backpropagation of gradients)

        save/load():
        - Saves / loads the model
    """
    def __init__(self, policy, env, nsteps,
            ent_coef=0.01, vf_coef=0.5, max_grad_norm=0.5, lr=7e-4,
            alpha=0.99, epsilon=1e-5, total_timesteps=int(80e6), lrschedule='linear'):
@@ -26,7 +38,10 @@ class Model(object):

        with tf.variable_scope('a2c_model', reuse=tf.AUTO_REUSE):
            # step_model is used for sampling
            step_model = policy(nenvs, 1, sess)

            # train_model is used to train our network
            train_model = policy(nbatch, nsteps, sess)

        A = tf.placeholder(train_model.action.dtype, train_model.action.shape)
@@ -34,25 +49,48 @@ class Model(object):

        R = tf.placeholder(tf.float32, [nbatch])
        LR = tf.placeholder(tf.float32, [])

        # Calculate the loss
        # Total loss = policy gradient loss - entropy * entropy coefficient + value coefficient * value loss

        # Policy loss
        # Output -log(pi)
        neglogpac = train_model.pd.neglogp(A)

        # Entropy is used to improve exploration by limiting premature convergence to a suboptimal policy.
        entropy = tf.reduce_mean(train_model.pd.entropy())

        # Policy gradient loss: 1/n * sum A(si,ai) * -log pi(ai|si)
        pg_loss = tf.reduce_mean(ADV * neglogpac)

        # Value loss: mean squared error between the predicted value V(s) and the return R
        vf_loss = losses.mean_squared_error(tf.squeeze(train_model.vf), R)

        loss = pg_loss - entropy*ent_coef + vf_loss * vf_coef

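To make the loss above concrete, here is a minimal NumPy sketch of the same total loss for a categorical policy. It is illustrative only (not part of this diff); `logits`, `actions`, `advs`, `returns` and `values` are hypothetical input arrays.

import numpy as np

def a2c_loss_sketch(logits, actions, advs, returns, values, ent_coef=0.01, vf_coef=0.5):
    # log-softmax (numerically stable) -> log pi(.|s) for every action
    z = logits - logits.max(axis=1, keepdims=True)
    logp = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    probs = np.exp(logp)
    # -log pi(a|s) for the actions that were actually taken
    neglogpac = -logp[np.arange(len(actions)), actions]
    # policy gradient loss: 1/n * sum A(si,ai) * -log pi(ai|si)
    pg_loss = np.mean(advs * neglogpac)
    # entropy bonus, encourages exploration
    entropy = np.mean(-(probs * logp).sum(axis=1))
    # value loss: mean squared error between V(s) and the return R
    vf_loss = np.mean(np.square(values - returns))
    return pg_loss - entropy * ent_coef + vf_loss * vf_coef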
        # Update the parameters using the loss
        # 1. Get the model parameters
        params = find_trainable_variables("a2c_model")

        # 2. Calculate the gradients
        grads = tf.gradients(loss, params)
        if max_grad_norm is not None:
            # Clip the gradients by their global norm
            grads, grad_norm = tf.clip_by_global_norm(grads, max_grad_norm)
        grads = list(zip(grads, params))
        # zip pairs each gradient with its associated parameter
        # For instance zip(ABCD, xyza) => Ax, By, Cz, Da

        # 3. Build our trainer
        trainer = tf.train.RMSPropOptimizer(learning_rate=LR, decay=alpha, epsilon=epsilon)

        # 4. Backpropagation
        _train = trainer.apply_gradients(grads)

        lr = Scheduler(v=lr, nvalues=total_timesteps, schedule=lrschedule)
        def train(obs, states, rewards, masks, actions, values):
            # Here we calculate the advantage A(s,a) = R + yV(s') - V(s)
            # rewards = R + yV(s')
            advs = rewards - values
            for step in range(len(obs)):
                cur_lr = lr.value()
@@ -148,23 +186,47 @@ def learn(

    set_global_seeds(seed)

    # Get the number of environments
    nenvs = env.num_envs
    policy = build_policy(env, network, **network_kwargs)

    # Instantiate the model object (that creates step_model and train_model)
    model = Model(policy=policy, env=env, nsteps=nsteps, ent_coef=ent_coef, vf_coef=vf_coef,
        max_grad_norm=max_grad_norm, lr=lr, alpha=alpha, epsilon=epsilon, total_timesteps=total_timesteps, lrschedule=lrschedule)
    if load_path is not None:
        model.load(load_path)

    # Instantiate the runner object
    runner = Runner(env, model, nsteps=nsteps, gamma=gamma)

    # Calculate the batch size
    nbatch = nenvs*nsteps

    # Start total timer
    tstart = time.time()

    for update in range(1, total_timesteps//nbatch+1):
        # Get a minibatch of experiences
        obs, states, rewards, masks, actions, values = runner.run()

        policy_loss, value_loss, policy_entropy = model.train(obs, states, rewards, masks, actions, values)
        nseconds = time.time()-tstart

        # Calculate the fps (frames per second)
        fps = int((update*nbatch)/nseconds)
        if update % log_interval == 0 or update == 1:
            """
            explained_variance calculates whether the value function is a good
            predictor of the returns, or whether it is worse than predicting
            nothing at all.
            The goal is for ev to get closer and closer to 1.

            Interpretation:
            ev=0  =>  might as well have predicted zero
            ev=1  =>  perfect prediction
            ev<0  =>  worse than just predicting zero
            """
            ev = explained_variance(values, rewards)
logger.record_tabular("nupdates", update)
|
||||
logger.record_tabular("total_timesteps", update*nbatch)
|
||||
|
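For reference, a one-function NumPy sketch of what explained_variance computes (1 - Var[y - ypred] / Var[y]); it is written from the docstring above, not copied from the baselines implementation.

import numpy as np

def explained_variance_sketch(ypred, y):
    # 1 - Var[y - ypred] / Var[y]; NaN if y has zero variance
    vary = np.var(y)
    return np.nan if vary == 0 else 1.0 - np.var(y - ypred) / vary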
@@ -3,7 +3,15 @@ from baselines.a2c.utils import discount_with_dones
from baselines.common.runners import AbstractEnvRunner

class Runner(AbstractEnvRunner):
    """
    We use this class to generate batches of experiences

    __init__:
    - Initialize the runner

    run():
    - Make a minibatch of experiences
    """
    def __init__(self, env, model, nsteps=5, gamma=0.99):
        super().__init__(env=env, model=model, nsteps=nsteps)
        self.gamma = gamma
@@ -11,14 +19,21 @@ class Runner(AbstractEnvRunner):
        self.ob_dtype = model.train_model.X.dtype.as_numpy_dtype

    def run(self):
        # We initialize the lists that will contain the minibatch of experiences
        mb_obs, mb_rewards, mb_actions, mb_values, mb_dones = [],[],[],[],[]
        mb_states = self.states
        for n in range(self.nsteps):
            # Given the observations, take the action and value (V(s))
            # We already have self.obs because AbstractEnvRunner runs self.obs[:] = env.reset()
            actions, values, states, _ = self.model.step(self.obs, S=self.states, M=self.dones)

            # Append the experiences
            mb_obs.append(np.copy(self.obs))
            mb_actions.append(actions)
            mb_values.append(values)
            mb_dones.append(self.dones)

            # Take actions in the environment and observe the results
            obs, rewards, dones, _ = self.env.step(actions)
            self.states = states
            self.dones = dones
@@ -28,8 +43,8 @@ class Runner(AbstractEnvRunner):
            self.obs = obs
            mb_rewards.append(rewards)
        mb_dones.append(self.dones)

        # Batch of steps to batch of rollouts
        mb_obs = np.asarray(mb_obs, dtype=self.ob_dtype).swapaxes(1, 0).reshape(self.batch_ob_shape)
        mb_rewards = np.asarray(mb_rewards, dtype=np.float32).swapaxes(1, 0)
        mb_actions = np.asarray(mb_actions, dtype=self.model.train_model.action.dtype.name).swapaxes(1, 0)
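The swapaxes(1, 0) calls above convert data stored as (nsteps, nenv) into (nenv, nsteps), so that each row is one environment's rollout before flattening. A tiny hypothetical example:

import numpy as np

# hypothetical rewards collected over nsteps=3 for nenv=2 environments
mb_rewards = np.array([[1., 2.],    # step 0: reward of env 0, env 1
                       [3., 4.],    # step 1
                       [5., 6.]])   # step 2
rollouts = mb_rewards.swapaxes(1, 0)   # shape (nenv, nsteps)
# rollouts[0] == [1., 3., 5.]  -> trajectory of environment 0
# rollouts[1] == [2., 4., 6.]  -> trajectory of environment 1
flat = rollouts.reshape(-1)            # (nenv * nsteps,) rows fed to the training batch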
@@ -40,7 +55,7 @@ class Runner(AbstractEnvRunner):

        if self.gamma > 0.0:
            # Discount/bootstrap off the value function
            last_values = self.model.value(self.obs, S=self.states, M=self.dones).tolist()
            for n, (rewards, dones, value) in enumerate(zip(mb_rewards, mb_dones, last_values)):
                rewards = rewards.tolist()
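For reference, a minimal sketch of the bootstrapped discounting that discount_with_dones performs, written from the comments above rather than copied from baselines.a2c.utils: returns are accumulated backwards with factor gamma and reset whenever an episode ends.

def discount_with_dones_sketch(rewards, dones, gamma):
    # walk the trajectory backwards, resetting the running return at episode ends
    discounted, running = [], 0.0
    for reward, done in zip(rewards[::-1], dones[::-1]):
        running = reward + gamma * running * (1.0 - done)
        discounted.append(running)
    return discounted[::-1]

# e.g. discount_with_dones_sketch([1, 1, 1], [0, 0, 1], 0.99) -> [2.9701, 1.99, 1.0]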
@@ -43,11 +43,17 @@ class PolicyWithValue(object):
        vf_latent = tf.layers.flatten(vf_latent)
        latent = tf.layers.flatten(latent)

        # Based on the action space, select the right probability distribution type
        self.pdtype = make_pdtype(env.action_space)

        # This builds a fully connected layer that returns a probability distribution
        # over actions (self.pd) and our pi logits (self.pi).
        self.pd, self.pi = self.pdtype.pdfromlatent(latent, init_scale=0.01)

        # Take an action
        self.action = self.pd.sample()

        # Calculate the negative log of our probability
        self.neglogp = self.pd.neglogp(self.action)
        self.sess = sess

@@ -17,50 +17,112 @@ from baselines.common.tf_util import initialize
from baselines.common.mpi_util import sync_from_root

class Model(object):
    """
    We use this object to:
        __init__:
        - Creates the step_model
        - Creates the train_model

        train():
        - Performs the training step (feedforward and backpropagation of gradients)

        save/load():
        - Saves / loads the model
    """
    def __init__(self, *, policy, ob_space, ac_space, nbatch_act, nbatch_train,
                nsteps, ent_coef, vf_coef, max_grad_norm):
        sess = get_session()

        with tf.variable_scope('ppo2_model', reuse=tf.AUTO_REUSE):
            # CREATE OUR TWO MODELS
            # act_model is used for sampling
            act_model = policy(nbatch_act, 1, sess)

            # train_model is used for training
            train_model = policy(nbatch_train, nsteps, sess)

        # CREATE THE PLACEHOLDERS
        A = train_model.pdtype.sample_placeholder([None])
        ADV = tf.placeholder(tf.float32, [None])
        R = tf.placeholder(tf.float32, [None])
        # Keep track of the old actor
        OLDNEGLOGPAC = tf.placeholder(tf.float32, [None])
        # Keep track of the old critic
        OLDVPRED = tf.placeholder(tf.float32, [None])
        LR = tf.placeholder(tf.float32, [])
        # Clip range
        CLIPRANGE = tf.placeholder(tf.float32, [])

        neglogpac = train_model.pd.neglogp(A)

        # Calculate the entropy
        # Entropy is used to improve exploration by limiting premature convergence to a suboptimal policy.
        entropy = tf.reduce_mean(train_model.pd.entropy())

        # CALCULATE THE LOSS
        # Total loss = policy gradient loss - entropy * entropy coefficient + value coefficient * value loss

        # Clip the value function
        # Get the predicted value
        vpred = train_model.vf
        # Clipped value = old value + clip(value - old value, min = -cliprange, max = cliprange)
        vpredclipped = OLDVPRED + tf.clip_by_value(train_model.vf - OLDVPRED, - CLIPRANGE, CLIPRANGE)
        # Unclipped value loss
        vf_losses1 = tf.square(vpred - R)
        # Clipped value loss
        vf_losses2 = tf.square(vpredclipped - R)
        # Value loss = 0.5 * mean(max(unclipped, clipped))
        vf_loss = .5 * tf.reduce_mean(tf.maximum(vf_losses1, vf_losses2))

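A small NumPy sketch of this clipped value loss (illustrative only; the array names are hypothetical):

import numpy as np

def clipped_value_loss_sketch(vpred, old_vpred, returns, cliprange=0.2):
    # keep the new value prediction within +/- cliprange of the old one
    vpred_clipped = old_vpred + np.clip(vpred - old_vpred, -cliprange, cliprange)
    unclipped = np.square(vpred - returns)
    clipped = np.square(vpred_clipped - returns)
    # the element-wise maximum penalizes value updates that move too far
    return 0.5 * np.mean(np.maximum(unclipped, clipped))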
        # Remember we want the ratio (pi current policy / pi old policy)
        # But neglogpac returns -log(policy)
        # So we transform it into the ratio:
        # e^(-log old - (-log new)) == e^(log new - log old) == e^(log(new / old)) == new / old
        # (since the exponential cancels the log)
        ratio = tf.exp(OLDNEGLOGPAC - neglogpac)

        # Remember also that we're doing gradient ascent: we want to MAXIMIZE the objective function J,
        # which is equivalent to minimizing Loss = -J.
        # To make the objective a loss we negate the product: -Advantages * (pi new / pi old)
        pg_losses = -ADV * ratio

        # Same term, but with the ratio clipped to [1 - cliprange, 1 + cliprange]
        pg_losses2 = -ADV * tf.clip_by_value(ratio, 1.0 - CLIPRANGE, 1.0 + CLIPRANGE)

        # Final PG loss
        # Why maximum? pg_losses and pg_losses2 are the negated objectives, so taking their maximum
        # is the same as taking the minimum of the two (positive) surrogate objectives.
        pg_loss = tf.reduce_mean(tf.maximum(pg_losses, pg_losses2))
        approxkl = .5 * tf.reduce_mean(tf.square(neglogpac - OLDNEGLOGPAC))
        clipfrac = tf.reduce_mean(tf.to_float(tf.greater(tf.abs(ratio - 1.0), CLIPRANGE)))

        # Total loss (remember that L = -J, since minimizing L is the same as maximizing J)
        loss = pg_loss - entropy * ent_coef + vf_loss * vf_coef

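And a matching NumPy sketch of the clipped surrogate objective, including the ratio recovered from the stored negative log-probabilities (illustrative, hypothetical inputs):

import numpy as np

def clipped_pg_loss_sketch(old_neglogp, new_neglogp, advs, cliprange=0.2):
    # ratio = pi_new / pi_old, recovered via exp(-log pi_old - (-log pi_new))
    ratio = np.exp(old_neglogp - new_neglogp)
    unclipped = -advs * ratio
    clipped = -advs * np.clip(ratio, 1.0 - cliprange, 1.0 + cliprange)
    # max of the two negated objectives == minus the min of the two objectives
    return np.mean(np.maximum(unclipped, clipped))

# ratio sanity check: if pi_old(a|s)=0.2 and pi_new(a|s)=0.4,
# exp(-log(0.2) - (-log(0.4))) == 0.4 / 0.2 == 2.0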
        # UPDATE THE PARAMETERS USING THE LOSS
        # 1. Get the model parameters
        params = tf.trainable_variables('ppo2_model')
        # 2. Build our trainer
        trainer = MpiAdamOptimizer(MPI.COMM_WORLD, learning_rate=LR, epsilon=1e-5)
        # 3. Calculate the gradients
        grads_and_var = trainer.compute_gradients(loss, params)
        grads, var = zip(*grads_and_var)

        if max_grad_norm is not None:
            # Clip the gradients by their global norm
            grads, _grad_norm = tf.clip_by_global_norm(grads, max_grad_norm)
        grads_and_var = list(zip(grads, var))
        # zip pairs each gradient with its associated parameter
        # For instance zip(ABCD, xyza) => Ax, By, Cz, Da

        # 4. Backpropagation
        _train = trainer.apply_gradients(grads_and_var)

        def train(lr, cliprange, obs, returns, masks, actions, values, neglogpacs, states=None):
            # Here we calculate the advantage A(s,a) = R + yV(s') - V(s)
            # returns = R + yV(s')
            advs = returns - values

            # Normalize the advantages
            advs = (advs - advs.mean()) / (advs.std() + 1e-8)
            td_map = {train_model.X:obs, A:actions, ADV:advs, R:returns, LR:lr,
                    CLIPRANGE:cliprange, OLDNEGLOGPAC:neglogpacs, OLDVPRED:values}
@@ -90,23 +152,39 @@ class Model(object):
        sync_from_root(sess, global_variables) #pylint: disable=E1101

class Runner(AbstractEnvRunner):
    """
    We use this object to make a minibatch of experiences
    __init__:
    - Initialize the runner

    run():
    - Make a minibatch
    """
    def __init__(self, *, env, model, nsteps, gamma, lam):
        super().__init__(env=env, model=model, nsteps=nsteps)
        # Lambda used in GAE (Generalized Advantage Estimation)
        self.lam = lam
        # Discount rate
        self.gamma = gamma

    def run(self):
        # Here, we init the lists that will contain the minibatch of experiences
        mb_obs, mb_rewards, mb_actions, mb_values, mb_dones, mb_neglogpacs = [],[],[],[],[],[]
        mb_states = self.states
        epinfos = []
        # For n in range number of steps
        for _ in range(self.nsteps):
            # Given the observations, get the action, value and neglogpacs
            # We already have self.obs because AbstractEnvRunner runs self.obs[:] = env.reset()
            actions, values, self.states, neglogpacs = self.model.step(self.obs, S=self.states, M=self.dones)
            mb_obs.append(self.obs.copy())
            mb_actions.append(actions)
            mb_values.append(values)
            mb_neglogpacs.append(neglogpacs)
            mb_dones.append(self.dones)

            # Take actions in the environment and observe the results
            # infos contains a ton of useful information
            self.obs[:], rewards, self.dones, infos = self.env.step(actions)
            for info in infos:
                maybeepinfo = info.get('episode')
@@ -120,19 +198,36 @@ class Runner(AbstractEnvRunner):
        mb_neglogpacs = np.asarray(mb_neglogpacs, dtype=np.float32)
        mb_dones = np.asarray(mb_dones, dtype=np.bool)
        last_values = self.model.value(self.obs, S=self.states, M=self.dones)

        ### GENERALIZED ADVANTAGE ESTIMATION
        # Discount/bootstrap off the value function
        # We create mb_returns and mb_advs
        # mb_returns will contain advantage + value
        mb_returns = np.zeros_like(mb_rewards)
        mb_advs = np.zeros_like(mb_rewards)
        lastgaelam = 0
        # From the last step back to the first step
        for t in reversed(range(self.nsteps)):
            # If t is the last step of the rollout
            if t == self.nsteps - 1:
                # If a state is done, nextnonterminal = 0
                # In fact nextnonterminal implements the following logic:
                # if done (so nextnonterminal = 0):
                #     delta = R - V(s)  (because self.gamma * nextvalues * nextnonterminal = 0)
                # else (not done):
                #     delta = R + gamma * V(st+1) - V(st)
                nextnonterminal = 1.0 - self.dones
                # V(t+1)
                nextvalues = last_values
            else:
                nextnonterminal = 1.0 - mb_dones[t+1]
                nextvalues = mb_values[t+1]
            # Delta = R(st) + gamma * V(t+1) * nextnonterminal - V(st)
            delta = mb_rewards[t] + self.gamma * nextvalues * nextnonterminal - mb_values[t]
            # Advantage = delta + gamma * λ (lambda) * nextnonterminal * lastgaelam
            mb_advs[t] = lastgaelam = delta + self.gamma * self.lam * nextnonterminal * lastgaelam
        # Returns
        mb_returns = mb_advs + mb_values
        return (*map(sf01, (mb_obs, mb_returns, mb_dones, mb_actions, mb_values, mb_neglogpacs)),
            mb_states, epinfos)
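For reference, the same GAE recursion written as a standalone NumPy sketch over arrays of shape (nsteps, nenv); it is illustrative only, and rewards, values, dones, last_values and last_dones are hypothetical inputs.

import numpy as np

def gae_sketch(rewards, values, dones, last_values, last_dones, gamma=0.99, lam=0.95):
    nsteps = len(rewards)
    advs = np.zeros_like(rewards, dtype=np.float64)
    lastgaelam = 0.0
    for t in reversed(range(nsteps)):
        if t == nsteps - 1:
            nextnonterminal = 1.0 - last_dones     # done flags after the rollout
            nextvalues = last_values               # bootstrap value V(s_T)
        else:
            nextnonterminal = 1.0 - dones[t + 1]
            nextvalues = values[t + 1]
        # TD error: delta = r_t + gamma * V(s_{t+1}) * (1 - done) - V(s_t)
        delta = rewards[t] + gamma * nextvalues * nextnonterminal - values[t]
        # GAE: A_t = delta_t + gamma * lambda * (1 - done) * A_{t+1}
        advs[t] = lastgaelam = delta + gamma * lam * nextnonterminal * lastgaelam
    returns = advs + values
    return advs, returns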
@@ -218,37 +313,56 @@ def learn(*, network, env, total_timesteps, seed=None, nsteps=2048, ent_coef=0.0

    policy = build_policy(env, network, **network_kwargs)

    # Get the number of environments
    nenvs = env.num_envs

    # Get the observation space and action space
    ob_space = env.observation_space
    ac_space = env.action_space

    # Calculate the batch size
    nbatch = nenvs * nsteps
    nbatch_train = nbatch // nminibatches

    # Instantiate the model object (that creates act_model and train_model)
    make_model = lambda : Model(policy=policy, ob_space=ob_space, ac_space=ac_space, nbatch_act=nenvs, nbatch_train=nbatch_train,
                    nsteps=nsteps, ent_coef=ent_coef, vf_coef=vf_coef,
                    max_grad_norm=max_grad_norm)
    model = make_model()
    if load_path is not None:
        model.load(load_path)
    # Instantiate the runner object
    runner = Runner(env=env, model=model, nsteps=nsteps, gamma=gamma, lam=lam)

    epinfobuf = deque(maxlen=100)

    # Start total timer
    tfirststart = time.time()

    nupdates = total_timesteps//nbatch
    for update in range(1, nupdates+1):
        assert nbatch % nminibatches == 0
        # Start timer
        tstart = time.time()
        frac = 1.0 - (update - 1.0) / nupdates
        # Calculate the learning rate
        lrnow = lr(frac)
        # Calculate the cliprange
        cliprangenow = cliprange(frac)
        # Get the minibatch
        obs, returns, masks, actions, values, neglogpacs, states, epinfos = runner.run() #pylint: disable=E0632
        epinfobuf.extend(epinfos)

        # For each minibatch we calculate the loss and append it.
        mblossvals = []
        if states is None: # nonrecurrent version
            # Index of each element of batch_size
            # Create the indices array
            inds = np.arange(nbatch)
            for _ in range(noptepochs):
                # Randomize the indexes
                np.random.shuffle(inds)
                # 0 to batch_size with batch_train_size step
                for start in range(0, nbatch, nbatch_train):
                    end = start + nbatch_train
                    mbinds = inds[start:end]
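A tiny sketch of how these indices carve the batch into shuffled minibatches (hypothetical sizes):

import numpy as np

nbatch, nbatch_train, noptepochs = 8, 4, 2   # hypothetical sizes
inds = np.arange(nbatch)
for _ in range(noptepochs):
    np.random.shuffle(inds)                  # new random order every epoch
    for start in range(0, nbatch, nbatch_train):
        mbinds = inds[start:start + nbatch_train]
        # e.g. mbinds could be [5 0 7 2]; these rows of obs/returns/actions/...
        # form one minibatch passed to model.train()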
@@ -270,10 +384,21 @@ def learn(*, network, env, total_timesteps, seed=None, nsteps=2048, ent_coef=0.0
                    mbstates = states[mbenvinds]
                    mblossvals.append(model.train(lrnow, cliprangenow, *slices, mbstates))

        # Feedforward --> get the losses --> update
        lossvals = np.mean(mblossvals, axis=0)
        # End timer
        tnow = time.time()
        # Calculate the fps (frames per second)
        fps = int(nbatch / (tnow - tstart))
        if update % log_interval == 0 or update == 1:
"""
|
||||
Computes fraction of variance that ypred explains about y.
|
||||
Returns 1 - Var[y-ypred] / Var[y]
|
||||
interpretation:
|
||||
ev=0 => might as well have predicted zero
|
||||
ev=1 => perfect prediction
|
||||
ev<0 => worse than just predicting zero
|
||||
"""
|
||||
ev = explained_variance(values, returns)
|
||||
logger.logkv("serial_timesteps", update*nsteps)
|
||||
logger.logkv("nupdates", update)
|
||||
@@ -294,7 +419,7 @@ def learn(*, network, env, total_timesteps, seed=None, nsteps=2048, ent_coef=0.0
            print('Saving to', savepath)
            model.save(savepath)
    return model

# Avoid a division error when calculating the mean (if epinfos is empty, return np.nan instead of raising an error)
def safemean(xs):
    return np.nan if len(xs) == 0 else np.mean(xs)