import os
import os.path as osp
import time
from collections import deque

import numpy as np
import tensorflow as tf

from baselines import logger
from baselines.common import explained_variance
from baselines.common import set_global_seeds
from baselines.common.tf_util import display_var_info
from baselines.ppo2.policies import build_ppo_policy
from baselines.ppo2.runner import Runner

try:
    from mpi4py import MPI
except ImportError:
    MPI = None


def constfn(val):
    def f(_):
        return val
    return f


def learn(*, network, env, total_timesteps, eval_env=None, seed=None, nsteps=128, ent_coef=0.0, lr=3e-4,
          vf_coef=0.5, max_grad_norm=0.5, gamma=0.99, lam=0.95, log_interval=10, nminibatches=4,
          noptepochs=4, cliprange=0.2, save_interval=10, load_path=None, model_fn=None, **network_kwargs):
    """
    Learn a policy using the PPO algorithm (https://arxiv.org/abs/1707.06347)

    Parameters:
    ----------

    network: string or function       policy network architecture. Either a string (mlp, lstm, lnlstm, cnn_lstm, cnn,
                                      cnn_small, conv_only - see baselines.common/models.py for the full list) naming a
                                      standard architecture, or a function that takes a tensorflow tensor as input and
                                      returns a tuple (output_tensor, extra_feed), where output_tensor is the last
                                      network layer output and extra_feed is None for feed-forward nets or a dictionary
                                      describing how to feed state into the network for recurrent nets. See
                                      common/models.py/lstm for details on using recurrent nets in policies.

    env: baselines.common.vec_env.VecEnv   environment. Needs to be vectorized for parallel environment simulation.
                                      Environments produced by gym.make can be wrapped using the
                                      baselines.common.vec_env.DummyVecEnv class.

    nsteps: int                       number of steps of the vectorized environment per update (i.e. the batch size is
                                      nsteps * nenv, where nenv is the number of environment copies simulated in parallel)

    total_timesteps: int              number of timesteps (i.e. number of actions taken in the environment)

    ent_coef: float                   policy entropy coefficient in the optimization objective

    lr: float or function             learning rate, constant or a schedule function [0,1] -> R+ where 1 is the
                                      beginning of training and 0 is the end

    vf_coef: float                    value function loss coefficient in the optimization objective

    max_grad_norm: float or None      gradient norm clipping coefficient

    gamma: float                      discounting factor for rewards

    lam: float                        advantage estimation discounting factor (lambda in the paper)

    log_interval: int                 number of updates between logging events

    nminibatches: int                 number of training minibatches per update. For recurrent policies, should be
                                      less than or equal to the number of environments run in parallel.

    noptepochs: int                   number of training epochs per update

    cliprange: float or function      clipping range, constant or a schedule function [0,1] -> R+ where 1 is the
                                      beginning of training and 0 is the end

    save_interval: int                number of updates between saving events

    load_path: str                    path to load the model from

    **network_kwargs:                 keyword arguments to the policy / network builder. See
                                      baselines.common/policies.py/build_policy and the arguments to a particular type
                                      of network. For instance, the 'mlp' architecture has arguments num_hidden and
                                      num_layers.
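
    Example:
    -------

    A minimal usage sketch; the CartPole environment, the number of parallel copies, and the
    linear learning-rate schedule below are illustrative assumptions, not defaults of this module:

        import gym
        from baselines.common.vec_env.dummy_vec_env import DummyVecEnv

        # Four copies of the environment run in parallel; DummyVecEnv takes a list of
        # zero-argument environment constructors.
        venv = DummyVecEnv([lambda: gym.make('CartPole-v1') for _ in range(4)])

        # lr may be a schedule: frac starts near 1.0 and anneals towards 0.0 over training.
        model = learn(network='mlp', env=venv, total_timesteps=100000,
                      lr=lambda frac: 3e-4 * frac)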
""" set_global_seeds(seed) if isinstance(lr, float): lr = constfn(lr) else: assert callable(lr) if isinstance(cliprange, float): cliprange = constfn(cliprange) else: assert callable(cliprange) total_timesteps = int(total_timesteps) policy = build_ppo_policy(env, network, **network_kwargs) # Get the nb of env nenvs = env.num_envs # Get state_space and action_space ob_space = env.observation_space ac_space = env.action_space # Calculate the batch_size nbatch = nenvs * nsteps nbatch_train = nbatch // nminibatches # Instantiate the model object (that creates act_model and train_model) if model_fn is None: from baselines.ppo2.model import Model model_fn = Model model = model_fn(policy=policy, ob_space=ob_space, ac_space=ac_space, nbatch_act=nenvs, nbatch_train=nbatch_train, nsteps=nsteps, ent_coef=ent_coef, vf_coef=vf_coef, max_grad_norm=max_grad_norm) if load_path is not None: model.load(load_path) allvars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=model.name) display_var_info(allvars) # Instantiate the runner object runner = Runner(env=env, model=model, nsteps=nsteps, gamma=gamma, ob_space=ob_space, lam=lam) if eval_env is not None: eval_runner = Runner(env=eval_env, model=model, nsteps=nsteps, gamma=gamma, ob_space=ob_space, lam=lam) epinfobuf = deque(maxlen=100) if eval_env is not None: eval_epinfobuf = deque(maxlen=100) # Start total timer tfirststart = time.time() nupdates = total_timesteps // nbatch for update in range(1, nupdates + 1): assert nbatch % nminibatches == 0 # Start timer tstart = time.time() frac = 1.0 - (update - 1.0) / nupdates # Calculate the learning rate lrnow = lr(frac) # Calculate the cliprange cliprangenow = cliprange(frac) # Get minibatch minibatch = runner.run() if eval_env is not None: eval_minibatch = eval_runner.run() eval_obs = eval_minibatch['obs'] eval_returns = eval_minibatch['returns'] eval_masks = eval_minibatch['masks'] eval_actions = eval_minibatch['actions'] eval_values = eval_minibatch['values'] eval_neglogpacs = eval_minibatch['neglogpacs'] eval_states = eval_minibatch['state'] eval_epinfos = eval_minibatch['epinfos'] epinfobuf.extend(minibatch.pop('epinfos')) if eval_env is not None: eval_epinfobuf.extend(eval_epinfos) # Here what we're going to do is for each minibatch calculate the loss and append it. 
        mblossvals = []

        # Create the array of batch indices
        inds = np.arange(nbatch)
        for _ in range(noptepochs):
            # Randomize the indices
            np.random.shuffle(inds)
            # Step from 0 to nbatch in increments of nbatch_train
            for start in range(0, nbatch, nbatch_train):
                end = start + nbatch_train
                mbinds = inds[start:end]
                slices = {key: minibatch[key][mbinds] for key in minibatch}
                mblossvals.append(model.train(lrnow, cliprangenow, **slices))

        # Feedforward --> get losses --> update
        lossvals = np.mean(mblossvals, axis=0)
        # End the per-update timer
        tnow = time.time()
        # Calculate fps (frames per second)
        fps = int(nbatch / (tnow - tstart))

        if update % log_interval == 0 or update == 1:
            # Check whether the value function is a good predictor of the returns (ev close to 1)
            # or just worse than predicting nothing (ev <= 0)
            ev = explained_variance(minibatch['values'], minibatch['returns'])
            logger.logkv("serial_timesteps", update * nsteps)
            logger.logkv("nupdates", update)
            logger.logkv("total_timesteps", update * nbatch)
            logger.logkv("fps", fps)
            logger.logkv("explained_variance", float(ev))
            logger.logkv('eprewmean', safemean([epinfo['r'] for epinfo in epinfobuf]))
            logger.logkv('eplenmean', safemean([epinfo['l'] for epinfo in epinfobuf]))
            logger.logkv('rewards_per_step', safemean(minibatch['rewards']))
            logger.logkv('advantages_per_step', safemean(minibatch['advs']))
            if eval_env is not None:
                logger.logkv('eval_eprewmean', safemean([epinfo['r'] for epinfo in eval_epinfobuf]))
                logger.logkv('eval_eplenmean', safemean([epinfo['l'] for epinfo in eval_epinfobuf]))
            logger.logkv('time_elapsed', tnow - tfirststart)
            for (lossval, lossname) in zip(lossvals, model.loss_names):
                logger.logkv(lossname, lossval)
            if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
                logger.dumpkvs()

        if save_interval and (update % save_interval == 0 or update == 1) and logger.get_dir() and (
                MPI is None or MPI.COMM_WORLD.Get_rank() == 0):
            checkdir = osp.join(logger.get_dir(), 'checkpoints')
            os.makedirs(checkdir, exist_ok=True)
            savepath = osp.join(checkdir, '%.5i' % update)
            print('Saving to', savepath)
            model.save(savepath)

        del minibatch

    return model


# Avoid a division error when computing the mean: if the episode info list is empty,
# return np.nan instead of raising an error.
def safemean(xs):
    return np.nan if len(xs) == 0 else np.mean(xs)