import os
import os.path as osp
import time
from collections import deque

import numpy as np
import tensorflow as tf

from baselines import logger
from baselines.common import explained_variance
from baselines.common import set_global_seeds
from baselines.common.tf_util import display_var_info
from baselines.ppo2.policies import build_ppo_policy
from baselines.ppo2.runner import Runner

try:
    from mpi4py import MPI
except ImportError:
    MPI = None


def constfn(val):
    """Return a function that ignores its argument (the schedule fraction) and always returns val."""
    def f(_):
        return val

    return f


def learn(*, network, env, total_timesteps, eval_env=None, seed=None, nsteps=128, ent_coef=0.0, lr=3e-4,
          vf_coef=0.5, max_grad_norm=0.5, gamma=0.99, lam=0.95,
          log_interval=10, nminibatches=4, noptepochs=4, cliprange=0.2,
          save_interval=10, load_path=None, model_fn=None, **network_kwargs):
"""
|
|
Learn policy using PPO algorithm (https://arxiv.org/abs/1707.06347)
|
|
|
|
Parameters:
|
|
----------
|
|
|
|
network: policy network architecture. Either string (mlp, lstm, lnlstm, cnn_lstm, cnn, cnn_small, conv_only - see baselines.common/models.py for full list)
|
|
specifying the standard network architecture, or a function that takes tensorflow tensor as input and returns
|
|
tuple (output_tensor, extra_feed) where output tensor is the last network layer output, extra_feed is None for feed-forward
|
|
neural nets, and extra_feed is a dictionary describing how to feed state into the network for recurrent neural nets.
|
|
See common/models.py/lstm for more details on using recurrent nets in policies
|
|
|
|
env: baselines.common.vec_env.VecEnv environment. Needs to be vectorized for parallel environment simulation.
|
|
The environments produced by gym.make can be wrapped using baselines.common.vec_env.DummyVecEnv class.
|
|
|
|
|
|
nsteps: int number of steps of the vectorized environment per update (i.e. batch size is nsteps * nenv where
|
|
nenv is number of environment copies simulated in parallel)
|
|
|
|
total_timesteps: int number of timesteps (i.e. number of actions taken in the environment)
|
|
|
|
ent_coef: float policy entropy coefficient in the optimization objective
|
|
|
|
lr: float or function learning rate, constant or a schedule function [0,1] -> R+ where 1 is beginning of the
|
|
training and 0 is the end of the training.
|
|
|
|
vf_coef: float value function loss coefficient in the optimization objective
|
|
|
|
max_grad_norm: float or None gradient norm clipping coefficient
|
|
|
|
gamma: float discounting factor for rewards
|
|
|
|
lam: float advantage estimation discounting factor (lambda in the paper)
|
|
|
|
log_interval: int number of timesteps between logging events
|
|
|
|
nminibatches: int number of training minibatches per update. For recurrent policies,
|
|
should be smaller or equal than number of environments run in parallel.
|
|
|
|
noptepochs: int number of training epochs per update
|
|
|
|
cliprange: float or function clipping range, constant or schedule function [0,1] -> R+ where 1 is beginning of the training
|
|
and 0 is the end of the training
|
|
|
|
save_interval: int number of timesteps between saving events
|
|
|
|
load_path: str path to load the model from
|
|
|
|
    **network_kwargs: keyword arguments to the policy / network builder. See baselines.common/policies.py/build_policy
        and the arguments to the particular type of network. For instance, the 'mlp' network architecture has the
        arguments num_hidden and num_layers.
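
    Example:
    -------

    A minimal usage sketch (illustrative only; it assumes gym is installed and that the chosen environment,
    'CartPole-v0' here, is compatible with the 'mlp' network):

        import gym
        from baselines.common.vec_env.dummy_vec_env import DummyVecEnv

        env = DummyVecEnv([lambda: gym.make('CartPole-v0')])
        model = learn(network='mlp', env=env, total_timesteps=100000)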
"""
|
|
|
|
    set_global_seeds(seed)

    if isinstance(lr, float):
        lr = constfn(lr)
    else:
        assert callable(lr)
    if isinstance(cliprange, float):
        cliprange = constfn(cliprange)
    else:
        assert callable(cliprange)
    total_timesteps = int(total_timesteps)

    policy = build_ppo_policy(env, network, **network_kwargs)

    # Get the number of environments
    nenvs = env.num_envs

    # Get the observation and action spaces
    ob_space = env.observation_space
    ac_space = env.action_space

    # Calculate the batch size
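    # e.g. with the defaults nsteps=128 and nminibatches=4 and 8 parallel environments,
    # nbatch = 8 * 128 = 1024 and nbatch_train = 1024 // 4 = 256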
    nbatch = nenvs * nsteps
    nbatch_train = nbatch // nminibatches

    # Instantiate the model object (that creates act_model and train_model)
    if model_fn is None:
        from baselines.ppo2.model import Model
        model_fn = Model

    model = model_fn(policy=policy, ob_space=ob_space, ac_space=ac_space, nbatch_act=nenvs, nbatch_train=nbatch_train,
                     nsteps=nsteps, ent_coef=ent_coef, vf_coef=vf_coef, max_grad_norm=max_grad_norm)

    if load_path is not None:
        model.load(load_path)

    allvars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=model.name)
    display_var_info(allvars)

    # Instantiate the runner object
    runner = Runner(env=env, model=model, nsteps=nsteps, gamma=gamma, ob_space=ob_space, lam=lam)

    if eval_env is not None:
        eval_runner = Runner(env=eval_env, model=model, nsteps=nsteps, gamma=gamma, ob_space=ob_space, lam=lam)

    epinfobuf = deque(maxlen=100)
    if eval_env is not None:
        eval_epinfobuf = deque(maxlen=100)

    # Start total timer
    tfirststart = time.time()
    nupdates = total_timesteps // nbatch

    for update in range(1, nupdates + 1):
        assert nbatch % nminibatches == 0
        # Start timer
        tstart = time.time()
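        # frac anneals linearly from 1.0 (first update) towards 0.0 (end of training); e.g. a callable
        # lr such as lambda f: 3e-4 * f would decay the learning rate over training, while a float lr
        # stays constant because it was wrapped with constfn above.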
        frac = 1.0 - (update - 1.0) / nupdates
        # Calculate the learning rate
        lrnow = lr(frac)
        # Calculate the cliprange
        cliprangenow = cliprange(frac)

        # Run the environments for nsteps to collect a new batch of experience
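        # runner.run() is assumed to return a dict of batch arrays; the keys used below include 'obs',
        # 'returns', 'masks', 'actions', 'values', 'neglogpacs', 'state', 'rewards', 'advs' and 'epinfos'.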
        minibatch = runner.run()

        if eval_env is not None:
            eval_minibatch = eval_runner.run()
            eval_obs = eval_minibatch['obs']
            eval_returns = eval_minibatch['returns']
            eval_masks = eval_minibatch['masks']
            eval_actions = eval_minibatch['actions']
            eval_values = eval_minibatch['values']
            eval_neglogpacs = eval_minibatch['neglogpacs']
            eval_states = eval_minibatch['state']
            eval_epinfos = eval_minibatch['epinfos']

        epinfobuf.extend(minibatch.pop('epinfos'))
        if eval_env is not None:
            eval_epinfobuf.extend(eval_epinfos)

        # For each minibatch, calculate the loss and append it to mblossvals
        mblossvals = []

        # Create an array of indices, one per sample in the batch
        inds = np.arange(nbatch)
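        # Each update performs noptepochs passes over the batch, split into nminibatches shuffled
        # minibatches, i.e. noptepochs * nminibatches gradient steps per update.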
        for _ in range(noptepochs):
            # Randomize the indices
            np.random.shuffle(inds)
            # 0 to nbatch with a step of nbatch_train
            for start in range(0, nbatch, nbatch_train):
                end = start + nbatch_train
                mbinds = inds[start:end]
                slices = {key: minibatch[key][mbinds] for key in minibatch}
                mblossvals.append(model.train(lrnow, cliprangenow, **slices))

        # Feedforward --> get losses --> update
        lossvals = np.mean(mblossvals, axis=0)
        # End timer
        tnow = time.time()
        # Calculate the fps (frames per second)
        fps = int(nbatch / (tnow - tstart))
        if update % log_interval == 0 or update == 1:
            # Check whether the value function is a good predictor of the returns (ev close to 1)
            # or worse than predicting nothing (ev <= 0); ev = 1 - Var(returns - values) / Var(returns)
            ev = explained_variance(minibatch['values'], minibatch['returns'])
            logger.logkv("serial_timesteps", update * nsteps)
            logger.logkv("nupdates", update)
            logger.logkv("total_timesteps", update * nbatch)
            logger.logkv("fps", fps)
            logger.logkv("explained_variance", float(ev))
            logger.logkv('eprewmean', safemean([epinfo['r'] for epinfo in epinfobuf]))
            logger.logkv('eplenmean', safemean([epinfo['l'] for epinfo in epinfobuf]))
            logger.logkv('rewards_per_step', safemean(minibatch['rewards']))
            logger.logkv('advantages_per_step', safemean(minibatch['advs']))

            if eval_env is not None:
                logger.logkv('eval_eprewmean', safemean([epinfo['r'] for epinfo in eval_epinfobuf]))
                logger.logkv('eval_eplenmean', safemean([epinfo['l'] for epinfo in eval_epinfobuf]))
            logger.logkv('time_elapsed', tnow - tfirststart)
            for (lossval, lossname) in zip(lossvals, model.loss_names):
                logger.logkv(lossname, lossval)
            if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
                logger.dumpkvs()
        if save_interval and (update % save_interval == 0 or update == 1) and logger.get_dir() and (
                MPI is None or MPI.COMM_WORLD.Get_rank() == 0):
            checkdir = osp.join(logger.get_dir(), 'checkpoints')
            os.makedirs(checkdir, exist_ok=True)
            savepath = osp.join(checkdir, '%.5i' % update)
            print('Saving to', savepath)
            model.save(savepath)
        del minibatch
    return model


# Avoid a division error when calculating the mean: if xs is empty, return np.nan instead of raising an error
def safemean(xs):
    return np.nan if len(xs) == 0 else np.mean(xs)