From 9fa8e1baf1d1f975b87b369a8082122eac812eb1 Mon Sep 17 00:00:00 2001 From: John Schulman Date: Thu, 25 Jan 2018 18:33:48 -0800 Subject: [PATCH] Lots of cleanups Fixes for new gym version Add @olegklimov and @unixpickle to authors list --- README.md | 2 +- baselines/a2c/a2c.py | 45 +- baselines/a2c/policies.py | 138 +++-- baselines/a2c/run_atari.py | 35 +- baselines/a2c/utils.py | 41 +- baselines/acer/policies.py | 21 +- baselines/acer/run_atari.py | 29 +- baselines/acktr/acktr_cont.py | 10 +- baselines/acktr/acktr_disc.py | 83 +-- baselines/acktr/filters.py | 2 +- baselines/acktr/kfac_utils.py | 130 ++--- baselines/acktr/policies.py | 49 +- baselines/acktr/run_atari.py | 29 +- baselines/acktr/run_mujoco.py | 25 +- baselines/acktr/utils.py | 176 +----- baselines/acktr/value_functions.py | 6 +- baselines/bench/benchmarks.py | 24 +- baselines/bench/monitor.py | 51 +- baselines/common/atari_wrappers.py | 46 +- baselines/common/cmd_util.py | 64 +++ baselines/common/distributions.py | 70 +-- baselines/common/mpi_adam.py | 2 +- baselines/common/mpi_moments.py | 53 +- baselines/common/mpi_running_mean_std.py | 6 +- baselines/common/running_mean_std.py | 22 +- baselines/common/tests/test_tf_util.py | 28 - baselines/common/tf_util.py | 501 ++---------------- baselines/common/vec_env/__init__.py | 128 ++++- baselines/common/vec_env/dummy_vec_env.py | 24 +- baselines/common/vec_env/subproc_vec_env.py | 39 +- baselines/common/vec_env/vec_frame_stack.py | 28 +- baselines/common/vec_env/vec_normalize.py | 79 +-- baselines/ddpg/ddpg.py | 24 +- baselines/ddpg/main.py | 1 - baselines/ddpg/training.py | 61 ++- baselines/ddpg/util.py | 44 -- baselines/deepq/build_graph.py | 27 +- baselines/deepq/experiments/atari/enjoy.py | 5 +- baselines/deepq/experiments/atari/model.py | 26 +- baselines/deepq/experiments/atari/train.py | 9 +- .../deepq/experiments/atari/wang2015_eval.py | 5 +- .../deepq/experiments/custom_cartpole.py | 3 +- baselines/deepq/experiments/run_atari.py | 2 - baselines/deepq/experiments/train_cartpole.py | 2 +- baselines/deepq/simple.py | 11 +- baselines/deepq/utils.py | 88 +++ baselines/gail/adversary.py | 8 + baselines/gail/trpo_mpi.py | 8 +- baselines/logger.py | 2 +- baselines/ppo1/cnn_policy.py | 10 +- baselines/ppo1/mlp_policy.py | 16 +- baselines/ppo1/pposgd_simple.py | 20 +- baselines/ppo1/run_atari.py | 10 +- baselines/ppo1/run_mujoco.py | 20 +- baselines/ppo2/policies.py | 63 +-- baselines/ppo2/ppo2.py | 16 +- baselines/ppo2/run_atari.py | 42 +- baselines/ppo2/run_mujoco.py | 10 +- baselines/trpo_mpi/nosharing_cnn_policy.py | 10 +- baselines/trpo_mpi/run_atari.py | 12 +- baselines/trpo_mpi/run_mujoco.py | 26 +- baselines/trpo_mpi/trpo_mpi.py | 26 +- 62 files changed, 989 insertions(+), 1604 deletions(-) create mode 100644 baselines/common/cmd_util.py delete mode 100644 baselines/ddpg/util.py create mode 100644 baselines/deepq/utils.py diff --git a/README.md b/README.md index c51342b..a33059a 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ pip install -e . 
To cite this repository in publications: @misc{baselines, - author = {Dhariwal, Prafulla and Hesse, Christopher and Plappert, Matthias and Radford, Alec and Schulman, John and Sidor, Szymon and Wu, Yuhuai}, + author = {Dhariwal, Prafulla and Hesse, Christopher and Klimov, Oleg and Nichol, Alex and Plappert, Matthias and Radford, Alec and Schulman, John and Sidor, Szymon and Wu, Yuhuai}, title = {OpenAI Baselines}, year = {2017}, publisher = {GitHub}, diff --git a/baselines/a2c/a2c.py b/baselines/a2c/a2c.py index cfb1d7c..a57d9a5 100644 --- a/baselines/a2c/a2c.py +++ b/baselines/a2c/a2c.py @@ -1,3 +1,4 @@ +import os import os.path as osp import gym import time @@ -10,22 +11,19 @@ from baselines import logger from baselines.common import set_global_seeds, explained_variance from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv from baselines.common.atari_wrappers import wrap_deepmind +from baselines.common import tf_util from baselines.a2c.utils import discount_with_dones from baselines.a2c.utils import Scheduler, make_path, find_trainable_variables -from baselines.a2c.policies import CnnPolicy from baselines.a2c.utils import cat_entropy, mse class Model(object): - def __init__(self, policy, ob_space, ac_space, nenvs, nsteps, nstack, num_procs, + def __init__(self, policy, ob_space, ac_space, nenvs, nsteps, ent_coef=0.01, vf_coef=0.5, max_grad_norm=0.5, lr=7e-4, alpha=0.99, epsilon=1e-5, total_timesteps=int(80e6), lrschedule='linear'): - config = tf.ConfigProto(allow_soft_placement=True, - intra_op_parallelism_threads=num_procs, - inter_op_parallelism_threads=num_procs) - config.gpu_options.allow_growth = True - sess = tf.Session(config=config) + + sess = tf_util.make_session() nact = ac_space.n nbatch = nenvs*nsteps @@ -34,8 +32,8 @@ class Model(object): R = tf.placeholder(tf.float32, [nbatch]) LR = tf.placeholder(tf.float32, []) - step_model = policy(sess, ob_space, ac_space, nenvs, 1, nstack, reuse=False) - train_model = policy(sess, ob_space, ac_space, nenvs, nsteps, nstack, reuse=True) + step_model = policy(sess, ob_space, ac_space, nenvs, 1, reuse=False) + train_model = policy(sess, ob_space, ac_space, nenvs*nsteps, nsteps, reuse=True) neglogpac = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=train_model.pi, labels=A) pg_loss = tf.reduce_mean(ADV * neglogpac) @@ -58,7 +56,7 @@ class Model(object): for step in range(len(obs)): cur_lr = lr.value() td_map = {train_model.X:obs, A:actions, ADV:advs, R:rewards, LR:cur_lr} - if states != []: + if states is not None: td_map[train_model.S] = states td_map[train_model.M] = masks policy_loss, value_loss, policy_entropy, _ = sess.run( @@ -91,32 +89,25 @@ class Model(object): class Runner(object): - def __init__(self, env, model, nsteps=5, nstack=4, gamma=0.99): + def __init__(self, env, model, nsteps=5, gamma=0.99): self.env = env self.model = model nh, nw, nc = env.observation_space.shape nenv = env.num_envs - self.batch_ob_shape = (nenv*nsteps, nh, nw, nc*nstack) - self.obs = np.zeros((nenv, nh, nw, nc*nstack), dtype=np.uint8) + self.batch_ob_shape = (nenv*nsteps, nh, nw, nc) + self.obs = np.zeros((nenv, nh, nw, nc), dtype=np.uint8) self.nc = nc obs = env.reset() - self.update_obs(obs) self.gamma = gamma self.nsteps = nsteps self.states = model.initial_state self.dones = [False for _ in range(nenv)] - def update_obs(self, obs): - # Do frame-stacking here instead of the FrameStack wrapper to reduce - # IPC overhead - self.obs = np.roll(self.obs, shift=-self.nc, axis=3) - self.obs[:, :, :, -self.nc:] = obs - def run(self): 
mb_obs, mb_rewards, mb_actions, mb_values, mb_dones = [],[],[],[],[] mb_states = self.states for n in range(self.nsteps): - actions, values, states = self.model.step(self.obs, self.states, self.dones) + actions, values, states, _ = self.model.step(self.obs, self.states, self.dones) mb_obs.append(np.copy(self.obs)) mb_actions.append(actions) mb_values.append(values) @@ -127,7 +118,7 @@ class Runner(object): for n, done in enumerate(dones): if done: self.obs[n] = self.obs[n]*0 - self.update_obs(obs) + self.obs = obs mb_rewards.append(rewards) mb_dones.append(self.dones) #batch of steps to batch of rollouts @@ -154,17 +145,16 @@ class Runner(object): mb_masks = mb_masks.flatten() return mb_obs, mb_states, mb_rewards, mb_masks, mb_actions, mb_values -def learn(policy, env, seed, nsteps=5, nstack=4, total_timesteps=int(80e6), vf_coef=0.5, ent_coef=0.01, max_grad_norm=0.5, lr=7e-4, lrschedule='linear', epsilon=1e-5, alpha=0.99, gamma=0.99, log_interval=100): +def learn(policy, env, seed, nsteps=5, total_timesteps=int(80e6), vf_coef=0.5, ent_coef=0.01, max_grad_norm=0.5, lr=7e-4, lrschedule='linear', epsilon=1e-5, alpha=0.99, gamma=0.99, log_interval=100): tf.reset_default_graph() set_global_seeds(seed) nenvs = env.num_envs ob_space = env.observation_space ac_space = env.action_space - num_procs = len(env.remotes) # HACK - model = Model(policy=policy, ob_space=ob_space, ac_space=ac_space, nenvs=nenvs, nsteps=nsteps, nstack=nstack, num_procs=num_procs, ent_coef=ent_coef, vf_coef=vf_coef, + model = Model(policy=policy, ob_space=ob_space, ac_space=ac_space, nenvs=nenvs, nsteps=nsteps, ent_coef=ent_coef, vf_coef=vf_coef, max_grad_norm=max_grad_norm, lr=lr, alpha=alpha, epsilon=epsilon, total_timesteps=total_timesteps, lrschedule=lrschedule) - runner = Runner(env, model, nsteps=nsteps, nstack=nstack, gamma=gamma) + runner = Runner(env, model, nsteps=nsteps, gamma=gamma) nbatch = nenvs*nsteps tstart = time.time() @@ -183,6 +173,3 @@ def learn(policy, env, seed, nsteps=5, nstack=4, total_timesteps=int(80e6), vf_c logger.record_tabular("explained_variance", float(ev)) logger.dump_tabular() env.close() - -if __name__ == '__main__': - main() diff --git a/baselines/a2c/policies.py b/baselines/a2c/policies.py index 4c37df7..9b2a627 100644 --- a/baselines/a2c/policies.py +++ b/baselines/a2c/policies.py @@ -1,36 +1,48 @@ import numpy as np import tensorflow as tf -from baselines.a2c.utils import conv, fc, conv_to_fc, batch_to_seq, seq_to_batch, lstm, lnlstm, sample +from baselines.a2c.utils import conv, fc, conv_to_fc, batch_to_seq, seq_to_batch, lstm, lnlstm +from baselines.common.distributions import make_pdtype + +def nature_cnn(unscaled_images): + """ + CNN from Nature paper. + """ + scaled_images = tf.cast(unscaled_images, tf.float32) / 255. 
+ activ = tf.nn.relu + h = activ(conv(scaled_images, 'c1', nf=32, rf=8, stride=4, init_scale=np.sqrt(2))) + h2 = activ(conv(h, 'c2', nf=64, rf=4, stride=2, init_scale=np.sqrt(2))) + h3 = activ(conv(h2, 'c3', nf=64, rf=3, stride=1, init_scale=np.sqrt(2))) + h3 = conv_to_fc(h3) + return activ(fc(h3, 'fc1', nh=512, init_scale=np.sqrt(2))) class LnLstmPolicy(object): - def __init__(self, sess, ob_space, ac_space, nenv, nsteps, nstack, nlstm=256, reuse=False): - nbatch = nenv*nsteps + def __init__(self, sess, ob_space, ac_space, nbatch, nsteps, nlstm=256, reuse=False): + nenv = nbatch // nsteps nh, nw, nc = ob_space.shape - ob_shape = (nbatch, nh, nw, nc*nstack) + ob_shape = (nbatch, nh, nw, nc) nact = ac_space.n X = tf.placeholder(tf.uint8, ob_shape) #obs M = tf.placeholder(tf.float32, [nbatch]) #mask (done t-1) S = tf.placeholder(tf.float32, [nenv, nlstm*2]) #states with tf.variable_scope("model", reuse=reuse): - h = conv(tf.cast(X, tf.float32)/255., 'c1', nf=32, rf=8, stride=4, init_scale=np.sqrt(2)) - h2 = conv(h, 'c2', nf=64, rf=4, stride=2, init_scale=np.sqrt(2)) - h3 = conv(h2, 'c3', nf=64, rf=3, stride=1, init_scale=np.sqrt(2)) - h3 = conv_to_fc(h3) - h4 = fc(h3, 'fc1', nh=512, init_scale=np.sqrt(2)) - xs = batch_to_seq(h4, nenv, nsteps) + h = nature_cnn(X) + xs = batch_to_seq(h, nenv, nsteps) ms = batch_to_seq(M, nenv, nsteps) h5, snew = lnlstm(xs, ms, S, 'lstm1', nh=nlstm) h5 = seq_to_batch(h5) - pi = fc(h5, 'pi', nact, act=lambda x:x) - vf = fc(h5, 'v', 1, act=lambda x:x) + pi = fc(h5, 'pi', nact) + vf = fc(h5, 'v', 1) + + self.pdtype = make_pdtype(ac_space) + self.pd = self.pdtype.pdfromflat(pi) v0 = vf[:, 0] - a0 = sample(pi) + a0 = self.pd.sample() + neglogp0 = self.pd.neglogp(a0) self.initial_state = np.zeros((nenv, nlstm*2), dtype=np.float32) def step(ob, state, mask): - a, v, s = sess.run([a0, v0, snew], {X:ob, S:state, M:mask}) - return a, v, s + return sess.run([a0, v0, snew, neglogp0], {X:ob, S:state, M:mask}) def value(ob, state, mask): return sess.run(v0, {X:ob, S:state, M:mask}) @@ -45,34 +57,34 @@ class LnLstmPolicy(object): class LstmPolicy(object): - def __init__(self, sess, ob_space, ac_space, nenv, nsteps, nstack, nlstm=256, reuse=False): - nbatch = nenv*nsteps + def __init__(self, sess, ob_space, ac_space, nbatch, nsteps, nlstm=256, reuse=False): + nenv = nbatch // nsteps + nh, nw, nc = ob_space.shape - ob_shape = (nbatch, nh, nw, nc*nstack) + ob_shape = (nbatch, nh, nw, nc) nact = ac_space.n X = tf.placeholder(tf.uint8, ob_shape) #obs M = tf.placeholder(tf.float32, [nbatch]) #mask (done t-1) S = tf.placeholder(tf.float32, [nenv, nlstm*2]) #states with tf.variable_scope("model", reuse=reuse): - h = conv(tf.cast(X, tf.float32)/255., 'c1', nf=32, rf=8, stride=4, init_scale=np.sqrt(2)) - h2 = conv(h, 'c2', nf=64, rf=4, stride=2, init_scale=np.sqrt(2)) - h3 = conv(h2, 'c3', nf=64, rf=3, stride=1, init_scale=np.sqrt(2)) - h3 = conv_to_fc(h3) - h4 = fc(h3, 'fc1', nh=512, init_scale=np.sqrt(2)) - xs = batch_to_seq(h4, nenv, nsteps) + h = nature_cnn(X) + xs = batch_to_seq(h, nenv, nsteps) ms = batch_to_seq(M, nenv, nsteps) h5, snew = lstm(xs, ms, S, 'lstm1', nh=nlstm) h5 = seq_to_batch(h5) - pi = fc(h5, 'pi', nact, act=lambda x:x) - vf = fc(h5, 'v', 1, act=lambda x:x) + pi = fc(h5, 'pi', nact) + vf = fc(h5, 'v', 1) + + self.pdtype = make_pdtype(ac_space) + self.pd = self.pdtype.pdfromflat(pi) v0 = vf[:, 0] - a0 = sample(pi) + a0 = self.pd.sample() + neglogp0 = self.pd.neglogp(a0) self.initial_state = np.zeros((nenv, nlstm*2), dtype=np.float32) def step(ob, state, mask): - 
a, v, s = sess.run([a0, v0, snew], {X:ob, S:state, M:mask}) - return a, v, s + return sess.run([a0, v0, snew, neglogp0], {X:ob, S:state, M:mask}) def value(ob, state, mask): return sess.run(v0, {X:ob, S:state, M:mask}) @@ -87,31 +99,67 @@ class LstmPolicy(object): class CnnPolicy(object): - def __init__(self, sess, ob_space, ac_space, nenv, nsteps, nstack, reuse=False): - nbatch = nenv*nsteps + def __init__(self, sess, ob_space, ac_space, nbatch, nsteps, reuse=False): #pylint: disable=W0613 nh, nw, nc = ob_space.shape - ob_shape = (nbatch, nh, nw, nc*nstack) + ob_shape = (nbatch, nh, nw, nc) nact = ac_space.n X = tf.placeholder(tf.uint8, ob_shape) #obs with tf.variable_scope("model", reuse=reuse): - h = conv(tf.cast(X, tf.float32)/255., 'c1', nf=32, rf=8, stride=4, init_scale=np.sqrt(2)) - h2 = conv(h, 'c2', nf=64, rf=4, stride=2, init_scale=np.sqrt(2)) - h3 = conv(h2, 'c3', nf=64, rf=3, stride=1, init_scale=np.sqrt(2)) - h3 = conv_to_fc(h3) - h4 = fc(h3, 'fc1', nh=512, init_scale=np.sqrt(2)) - pi = fc(h4, 'pi', nact, act=lambda x:x) - vf = fc(h4, 'v', 1, act=lambda x:x) + h = nature_cnn(X) + pi = fc(h, 'pi', nact, init_scale=0.01) + vf = fc(h, 'v', 1)[:,0] - v0 = vf[:, 0] - a0 = sample(pi) - self.initial_state = [] #not stateful + self.pdtype = make_pdtype(ac_space) + self.pd = self.pdtype.pdfromflat(pi) + + a0 = self.pd.sample() + neglogp0 = self.pd.neglogp(a0) + self.initial_state = None def step(ob, *_args, **_kwargs): - a, v = sess.run([a0, v0], {X:ob}) - return a, v, [] #dummy state + a, v, neglogp = sess.run([a0, vf, neglogp0], {X:ob}) + return a, v, self.initial_state, neglogp def value(ob, *_args, **_kwargs): - return sess.run(v0, {X:ob}) + return sess.run(vf, {X:ob}) + + self.X = X + self.pi = pi + self.vf = vf + self.step = step + self.value = value + +class MlpPolicy(object): + def __init__(self, sess, ob_space, ac_space, nbatch, nsteps, reuse=False): #pylint: disable=W0613 + ob_shape = (nbatch,) + ob_space.shape + actdim = ac_space.shape[0] + X = tf.placeholder(tf.float32, ob_shape, name='Ob') #obs + with tf.variable_scope("model", reuse=reuse): + activ = tf.tanh + h1 = activ(fc(X, 'pi_fc1', nh=64, init_scale=np.sqrt(2))) + h2 = activ(fc(h1, 'pi_fc2', nh=64, init_scale=np.sqrt(2))) + pi = fc(h2, 'pi', actdim, init_scale=0.01) + h1 = activ(fc(X, 'vf_fc1', nh=64, init_scale=np.sqrt(2))) + h2 = activ(fc(h1, 'vf_fc2', nh=64, init_scale=np.sqrt(2))) + vf = fc(h2, 'vf', 1)[:,0] + logstd = tf.get_variable(name="logstd", shape=[1, actdim], + initializer=tf.zeros_initializer()) + + pdparam = tf.concat([pi, pi * 0.0 + logstd], axis=1) + + self.pdtype = make_pdtype(ac_space) + self.pd = self.pdtype.pdfromflat(pdparam) + + a0 = self.pd.sample() + neglogp0 = self.pd.neglogp(a0) + self.initial_state = None + + def step(ob, *_args, **_kwargs): + a, v, neglogp = sess.run([a0, vf, neglogp0], {X:ob}) + return a, v, self.initial_state, neglogp + + def value(ob, *_args, **_kwargs): + return sess.run(vf, {X:ob}) self.X = X self.pi = pi diff --git a/baselines/a2c/run_atari.py b/baselines/a2c/run_atari.py index 8f39f4e..b09d9bb 100644 --- a/baselines/a2c/run_atari.py +++ b/baselines/a2c/run_atari.py @@ -1,45 +1,30 @@ #!/usr/bin/env python3 -import os, logging, gym -from baselines import logger -from baselines.common import set_global_seeds -from baselines import bench -from baselines.a2c.a2c import learn -from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv -from baselines.common.atari_wrappers import make_atari, wrap_deepmind -from baselines.a2c.policies import CnnPolicy, LstmPolicy, 
LnLstmPolicy -def train(env_id, num_timesteps, seed, policy, lrschedule, num_cpu): - def make_env(rank): - def _thunk(): - env = make_atari(env_id) - env.seed(seed + rank) - env = bench.Monitor(env, logger.get_dir() and os.path.join(logger.get_dir(), str(rank))) - gym.logger.setLevel(logging.WARN) - return wrap_deepmind(env) - return _thunk - set_global_seeds(seed) - env = SubprocVecEnv([make_env(i) for i in range(num_cpu)]) +from baselines import logger +from baselines.common.cmd_util import make_atari_env, atari_arg_parser +from baselines.common.vec_env.vec_frame_stack import VecFrameStack +from baselines.a2c.a2c import learn +from baselines.ppo2.policies import CnnPolicy, LstmPolicy, LnLstmPolicy + +def train(env_id, num_timesteps, seed, policy, lrschedule, num_env): if policy == 'cnn': policy_fn = CnnPolicy elif policy == 'lstm': policy_fn = LstmPolicy elif policy == 'lnlstm': policy_fn = LnLstmPolicy + env = VecFrameStack(make_atari_env(env_id, num_env, seed), 4) learn(policy_fn, env, seed, total_timesteps=int(num_timesteps * 1.1), lrschedule=lrschedule) env.close() def main(): - import argparse - parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--env', help='environment ID', default='BreakoutNoFrameskip-v4') - parser.add_argument('--seed', help='RNG seed', type=int, default=0) + parser = atari_arg_parser() parser.add_argument('--policy', help='Policy architecture', choices=['cnn', 'lstm', 'lnlstm'], default='cnn') parser.add_argument('--lrschedule', help='Learning rate schedule', choices=['constant', 'linear'], default='constant') - parser.add_argument('--num-timesteps', type=int, default=int(10e6)) args = parser.parse_args() logger.configure() train(args.env, num_timesteps=args.num_timesteps, seed=args.seed, - policy=args.policy, lrschedule=args.lrschedule, num_cpu=16) + policy=args.policy, lrschedule=args.lrschedule, num_env=16) if __name__ == '__main__': main() diff --git a/baselines/a2c/utils.py b/baselines/a2c/utils.py index 0fd514d..3c362ec 100644 --- a/baselines/a2c/utils.py +++ b/baselines/a2c/utils.py @@ -39,23 +39,19 @@ def ortho_init(scale=1.0): return (scale * q[:shape[0], :shape[1]]).astype(np.float32) return _ortho_init -def conv(x, scope, nf, rf, stride, pad='VALID', act=tf.nn.relu, init_scale=1.0): +def conv(x, scope, *, nf, rf, stride, pad='VALID', init_scale=1.0): with tf.variable_scope(scope): nin = x.get_shape()[3].value w = tf.get_variable("w", [rf, rf, nin, nf], initializer=ortho_init(init_scale)) b = tf.get_variable("b", [nf], initializer=tf.constant_initializer(0.0)) - z = tf.nn.conv2d(x, w, strides=[1, stride, stride, 1], padding=pad)+b - h = act(z) - return h + return tf.nn.conv2d(x, w, strides=[1, stride, stride, 1], padding=pad)+b -def fc(x, scope, nh, act=tf.nn.relu, init_scale=1.0): +def fc(x, scope, nh, *, init_scale=1.0, init_bias=0.0): with tf.variable_scope(scope): nin = x.get_shape()[1].value w = tf.get_variable("w", [nin, nh], initializer=ortho_init(init_scale)) - b = tf.get_variable("b", [nh], initializer=tf.constant_initializer(0.0)) - z = tf.matmul(x, w)+b - h = act(z) - return h + b = tf.get_variable("b", [nh], initializer=tf.constant_initializer(init_bias)) + return tf.matmul(x, w)+b def batch_to_seq(h, nbatch, nsteps, flat=False): if flat: @@ -162,9 +158,34 @@ def constant(p): def linear(p): return 1-p +def middle_drop(p): + eps = 0.75 + if 1-p desired_kl * 2: logger.log("kl too high") - U.eval(tf.assign(stepsize, tf.maximum(min_stepsize, stepsize / 1.5))) + tf.assign(stepsize, 
tf.maximum(min_stepsize, stepsize / 1.5)).eval() elif kl < desired_kl / 2: logger.log("kl too low") - U.eval(tf.assign(stepsize, tf.minimum(max_stepsize, stepsize * 1.5))) + tf.assign(stepsize, tf.minimum(max_stepsize, stepsize * 1.5)).eval() else: logger.log("kl just right!") diff --git a/baselines/acktr/acktr_disc.py b/baselines/acktr/acktr_disc.py index 56e0f03..a8b77b6 100644 --- a/baselines/acktr/acktr_disc.py +++ b/baselines/acktr/acktr_disc.py @@ -7,16 +7,17 @@ from baselines import logger from baselines.common import set_global_seeds, explained_variance -from baselines.acktr.utils import discount_with_dones -from baselines.acktr.utils import Scheduler, find_trainable_variables -from baselines.acktr.utils import cat_entropy, mse +from baselines.a2c.a2c import Runner +from baselines.a2c.utils import discount_with_dones +from baselines.a2c.utils import Scheduler, find_trainable_variables +from baselines.a2c.utils import cat_entropy, mse from baselines.acktr import kfac class Model(object): def __init__(self, policy, ob_space, ac_space, nenvs,total_timesteps, nprocs=32, nsteps=20, - nstack=4, ent_coef=0.01, vf_coef=0.5, vf_fisher_coef=1.0, lr=0.25, max_grad_norm=0.5, + ent_coef=0.01, vf_coef=0.5, vf_fisher_coef=1.0, lr=0.25, max_grad_norm=0.5, kfac_clip=0.001, lrschedule='linear'): config = tf.ConfigProto(allow_soft_placement=True, intra_op_parallelism_threads=nprocs, @@ -31,8 +32,8 @@ class Model(object): PG_LR = tf.placeholder(tf.float32, []) VF_LR = tf.placeholder(tf.float32, []) - self.model = step_model = policy(sess, ob_space, ac_space, nenvs, 1, nstack, reuse=False) - self.model2 = train_model = policy(sess, ob_space, ac_space, nenvs, nsteps, nstack, reuse=True) + self.model = step_model = policy(sess, ob_space, ac_space, nenvs, 1, reuse=False) + self.model2 = train_model = policy(sess, ob_space, ac_space, nenvs*nsteps, nsteps, reuse=True) logpac = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=train_model.pi, labels=A) self.logits = logits = train_model.pi @@ -71,7 +72,7 @@ class Model(object): cur_lr = self.lr.value() td_map = {train_model.X:obs, A:actions, ADV:advs, R:rewards, PG_LR:cur_lr} - if states != []: + if states is not None: td_map[train_model.S] = states td_map[train_model.M] = masks @@ -104,70 +105,8 @@ class Model(object): self.initial_state = step_model.initial_state tf.global_variables_initializer().run(session=sess) -class Runner(object): - - def __init__(self, env, model, nsteps, nstack, gamma): - self.env = env - self.model = model - nh, nw, nc = env.observation_space.shape - nenv = env.num_envs - self.batch_ob_shape = (nenv*nsteps, nh, nw, nc*nstack) - self.obs = np.zeros((nenv, nh, nw, nc*nstack), dtype=np.uint8) - obs = env.reset() - self.update_obs(obs) - self.gamma = gamma - self.nsteps = nsteps - self.states = model.initial_state - self.dones = [False for _ in range(nenv)] - - def update_obs(self, obs): - self.obs = np.roll(self.obs, shift=-1, axis=3) - self.obs[:, :, :, -1] = obs[:, :, :, 0] - - def run(self): - mb_obs, mb_rewards, mb_actions, mb_values, mb_dones = [],[],[],[],[] - mb_states = self.states - for n in range(self.nsteps): - actions, values, states = self.model.step(self.obs, self.states, self.dones) - mb_obs.append(np.copy(self.obs)) - mb_actions.append(actions) - mb_values.append(values) - mb_dones.append(self.dones) - obs, rewards, dones, _ = self.env.step(actions) - self.states = states - self.dones = dones - for n, done in enumerate(dones): - if done: - self.obs[n] = self.obs[n]*0 - self.update_obs(obs) - 
mb_rewards.append(rewards) - mb_dones.append(self.dones) - #batch of steps to batch of rollouts - mb_obs = np.asarray(mb_obs, dtype=np.uint8).swapaxes(1, 0).reshape(self.batch_ob_shape) - mb_rewards = np.asarray(mb_rewards, dtype=np.float32).swapaxes(1, 0) - mb_actions = np.asarray(mb_actions, dtype=np.int32).swapaxes(1, 0) - mb_values = np.asarray(mb_values, dtype=np.float32).swapaxes(1, 0) - mb_dones = np.asarray(mb_dones, dtype=np.bool).swapaxes(1, 0) - mb_masks = mb_dones[:, :-1] - mb_dones = mb_dones[:, 1:] - last_values = self.model.value(self.obs, self.states, self.dones).tolist() - #discount/bootstrap off value fn - for n, (rewards, dones, value) in enumerate(zip(mb_rewards, mb_dones, last_values)): - rewards = rewards.tolist() - dones = dones.tolist() - if dones[-1] == 0: - rewards = discount_with_dones(rewards+[value], dones+[0], self.gamma)[:-1] - else: - rewards = discount_with_dones(rewards, dones, self.gamma) - mb_rewards[n] = rewards - mb_rewards = mb_rewards.flatten() - mb_actions = mb_actions.flatten() - mb_values = mb_values.flatten() - mb_masks = mb_masks.flatten() - return mb_obs, mb_states, mb_rewards, mb_masks, mb_actions, mb_values - def learn(policy, env, seed, total_timesteps=int(40e6), gamma=0.99, log_interval=1, nprocs=32, nsteps=20, - nstack=4, ent_coef=0.01, vf_coef=0.5, vf_fisher_coef=1.0, lr=0.25, max_grad_norm=0.5, + ent_coef=0.01, vf_coef=0.5, vf_fisher_coef=1.0, lr=0.25, max_grad_norm=0.5, kfac_clip=0.001, save_interval=None, lrschedule='linear'): tf.reset_default_graph() set_global_seeds(seed) @@ -176,7 +115,7 @@ def learn(policy, env, seed, total_timesteps=int(40e6), gamma=0.99, log_interval ob_space = env.observation_space ac_space = env.action_space make_model = lambda : Model(policy, ob_space, ac_space, nenvs, total_timesteps, nprocs=nprocs, nsteps - =nsteps, nstack=nstack, ent_coef=ent_coef, vf_coef=vf_coef, vf_fisher_coef= + =nsteps, ent_coef=ent_coef, vf_coef=vf_coef, vf_fisher_coef= vf_fisher_coef, lr=lr, max_grad_norm=max_grad_norm, kfac_clip=kfac_clip, lrschedule=lrschedule) if save_interval and logger.get_dir(): @@ -185,7 +124,7 @@ def learn(policy, env, seed, total_timesteps=int(40e6), gamma=0.99, log_interval fh.write(cloudpickle.dumps(make_model)) model = make_model() - runner = Runner(env, model, nsteps=nsteps, nstack=nstack, gamma=gamma) + runner = Runner(env, model, nsteps=nsteps, gamma=gamma) nbatch = nenvs*nsteps tstart = time.time() coord = tf.train.Coordinator() diff --git a/baselines/acktr/filters.py b/baselines/acktr/filters.py index 9c38c47..5ce019c 100644 --- a/baselines/acktr/filters.py +++ b/baselines/acktr/filters.py @@ -1,4 +1,4 @@ -from baselines.acktr.running_stat import RunningStat +from .running_stat import RunningStat from collections import deque import numpy as np diff --git a/baselines/acktr/kfac_utils.py b/baselines/acktr/kfac_utils.py index 1cc09b9..edc623d 100644 --- a/baselines/acktr/kfac_utils.py +++ b/baselines/acktr/kfac_utils.py @@ -1,93 +1,55 @@ import tensorflow as tf -import numpy as np - def gmatmul(a, b, transpose_a=False, transpose_b=False, reduce_dim=None): - if reduce_dim == None: - # general batch matmul - if len(a.get_shape()) == 3 and len(b.get_shape()) == 3: - return tf.batch_matmul(a, b, adj_x=transpose_a, adj_y=transpose_b) - elif len(a.get_shape()) == 3 and len(b.get_shape()) == 2: - if transpose_b: - N = b.get_shape()[0].value - else: - N = b.get_shape()[1].value - B = a.get_shape()[0].value - if transpose_a: - K = a.get_shape()[1].value - a = tf.reshape(tf.transpose(a, [0, 2, 1]), [-1, K]) - 
else: - K = a.get_shape()[-1].value - a = tf.reshape(a, [-1, K]) - result = tf.matmul(a, b, transpose_b=transpose_b) - result = tf.reshape(result, [B, -1, N]) - return result - elif len(a.get_shape()) == 2 and len(b.get_shape()) == 3: - if transpose_a: - M = a.get_shape()[1].value - else: - M = a.get_shape()[0].value - B = b.get_shape()[0].value - if transpose_b: - K = b.get_shape()[-1].value - b = tf.transpose(tf.reshape(b, [-1, K]), [1, 0]) - else: - K = b.get_shape()[1].value - b = tf.transpose(tf.reshape( - tf.transpose(b, [0, 2, 1]), [-1, K]), [1, 0]) - result = tf.matmul(a, b, transpose_a=transpose_a) - result = tf.transpose(tf.reshape(result, [M, B, -1]), [1, 0, 2]) - return result - else: - return tf.matmul(a, b, transpose_a=transpose_a, transpose_b=transpose_b) - else: - # weird batch matmul - if len(a.get_shape()) == 2 and len(b.get_shape()) > 2: - # reshape reduce_dim to the left most dim in b - b_shape = b.get_shape() - if reduce_dim != 0: - b_dims = list(range(len(b_shape))) - b_dims.remove(reduce_dim) - b_dims.insert(0, reduce_dim) - b = tf.transpose(b, b_dims) - b_t_shape = b.get_shape() - b = tf.reshape(b, [int(b_shape[reduce_dim]), -1]) - result = tf.matmul(a, b, transpose_a=transpose_a, - transpose_b=transpose_b) - result = tf.reshape(result, b_t_shape) - if reduce_dim != 0: - b_dims = list(range(len(b_shape))) - b_dims.remove(0) - b_dims.insert(reduce_dim, 0) - result = tf.transpose(result, b_dims) - return result + assert reduce_dim is not None - elif len(a.get_shape()) > 2 and len(b.get_shape()) == 2: - # reshape reduce_dim to the right most dim in a - a_shape = a.get_shape() - outter_dim = len(a_shape) - 1 - reduce_dim = len(a_shape) - reduce_dim - 1 - if reduce_dim != outter_dim: - a_dims = list(range(len(a_shape))) - a_dims.remove(reduce_dim) - a_dims.insert(outter_dim, reduce_dim) - a = tf.transpose(a, a_dims) - a_t_shape = a.get_shape() - a = tf.reshape(a, [-1, int(a_shape[reduce_dim])]) - result = tf.matmul(a, b, transpose_a=transpose_a, - transpose_b=transpose_b) - result = tf.reshape(result, a_t_shape) - if reduce_dim != outter_dim: - a_dims = list(range(len(a_shape))) - a_dims.remove(outter_dim) - a_dims.insert(reduce_dim, outter_dim) - result = tf.transpose(result, a_dims) - return result + # weird batch matmul + if len(a.get_shape()) == 2 and len(b.get_shape()) > 2: + # reshape reduce_dim to the left most dim in b + b_shape = b.get_shape() + if reduce_dim != 0: + b_dims = list(range(len(b_shape))) + b_dims.remove(reduce_dim) + b_dims.insert(0, reduce_dim) + b = tf.transpose(b, b_dims) + b_t_shape = b.get_shape() + b = tf.reshape(b, [int(b_shape[reduce_dim]), -1]) + result = tf.matmul(a, b, transpose_a=transpose_a, + transpose_b=transpose_b) + result = tf.reshape(result, b_t_shape) + if reduce_dim != 0: + b_dims = list(range(len(b_shape))) + b_dims.remove(0) + b_dims.insert(reduce_dim, 0) + result = tf.transpose(result, b_dims) + return result - elif len(a.get_shape()) == 2 and len(b.get_shape()) == 2: - return tf.matmul(a, b, transpose_a=transpose_a, transpose_b=transpose_b) + elif len(a.get_shape()) > 2 and len(b.get_shape()) == 2: + # reshape reduce_dim to the right most dim in a + a_shape = a.get_shape() + outter_dim = len(a_shape) - 1 + reduce_dim = len(a_shape) - reduce_dim - 1 + if reduce_dim != outter_dim: + a_dims = list(range(len(a_shape))) + a_dims.remove(reduce_dim) + a_dims.insert(outter_dim, reduce_dim) + a = tf.transpose(a, a_dims) + a_t_shape = a.get_shape() + a = tf.reshape(a, [-1, int(a_shape[reduce_dim])]) + result = tf.matmul(a, b, 
transpose_a=transpose_a, + transpose_b=transpose_b) + result = tf.reshape(result, a_t_shape) + if reduce_dim != outter_dim: + a_dims = list(range(len(a_shape))) + a_dims.remove(outter_dim) + a_dims.insert(reduce_dim, outter_dim) + result = tf.transpose(result, a_dims) + return result - assert False, 'something went wrong' + elif len(a.get_shape()) == 2 and len(b.get_shape()) == 2: + return tf.matmul(a, b, transpose_a=transpose_a, transpose_b=transpose_b) + + assert False, 'something went wrong' def clipoutNeg(vec, threshold=1e-6): diff --git a/baselines/acktr/policies.py b/baselines/acktr/policies.py index 47965a5..39bb6cb 100644 --- a/baselines/acktr/policies.py +++ b/baselines/acktr/policies.py @@ -1,43 +1,8 @@ import numpy as np import tensorflow as tf -from baselines.acktr.utils import conv, fc, dense, conv_to_fc, sample, kl_div +from baselines.acktr.utils import dense, kl_div import baselines.common.tf_util as U -class CnnPolicy(object): - - def __init__(self, sess, ob_space, ac_space, nenv, nsteps, nstack, reuse=False): - nbatch = nenv*nsteps - nh, nw, nc = ob_space.shape - ob_shape = (nbatch, nh, nw, nc*nstack) - nact = ac_space.n - X = tf.placeholder(tf.uint8, ob_shape) #obs - with tf.variable_scope("model", reuse=reuse): - h = conv(tf.cast(X, tf.float32)/255., 'c1', nf=32, rf=8, stride=4, init_scale=np.sqrt(2)) - h2 = conv(h, 'c2', nf=64, rf=4, stride=2, init_scale=np.sqrt(2)) - h3 = conv(h2, 'c3', nf=32, rf=3, stride=1, init_scale=np.sqrt(2)) - h3 = conv_to_fc(h3) - h4 = fc(h3, 'fc1', nh=512, init_scale=np.sqrt(2)) - pi = fc(h4, 'pi', nact, act=lambda x:x) - vf = fc(h4, 'v', 1, act=lambda x:x) - - v0 = vf[:, 0] - a0 = sample(pi) - self.initial_state = [] #not stateful - - def step(ob, *_args, **_kwargs): - a, v = sess.run([a0, v0], {X:ob}) - return a, v, [] #dummy state - - def value(ob, *_args, **_kwargs): - return sess.run(v0, {X:ob}) - - self.X = X - self.pi = pi - self.vf = vf - self.step = step - self.value = value - - class GaussianMlpPolicy(object): def __init__(self, ob_dim, ac_dim): # Here we'll construct a bunch of expressions, which will be used in two places: @@ -60,12 +25,12 @@ class GaussianMlpPolicy(object): std_na = tf.tile(std_1a, [tf.shape(mean_na)[0], 1]) ac_dist = tf.concat([tf.reshape(mean_na, [-1, ac_dim]), tf.reshape(std_na, [-1, ac_dim])], 1) sampled_ac_na = tf.random_normal(tf.shape(ac_dist[:,ac_dim:])) * ac_dist[:,ac_dim:] + ac_dist[:,:ac_dim] # This is the sampled action we'll perform. 
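The logprob expressions in this hunk — both the removed U.sum/U.mean forms and the tf.reduce_* replacements just below — evaluate the log-density of a diagonal Gaussian, log N(x; mu, sigma) = -sum_i log sigma_i - (d/2) log(2*pi) - (1/2) sum_i ((x_i - mu_i)/sigma_i)^2; only the reduction helpers change, not the math. A small NumPy/SciPy check of that formula (an illustrative sketch, not part of the patch):

    import numpy as np
    from scipy.stats import multivariate_normal

    mu = np.array([0.5, -1.0])    # mean of the action distribution
    std = np.array([1.0, 2.0])    # per-dimension standard deviations
    x = np.array([0.0, 0.3])      # an action to score

    # Same expression as the logprob_n / logprobsampled_n lines below, with NumPy reductions.
    logp = (-np.sum(np.log(std))
            - 0.5 * np.log(2.0 * np.pi) * len(mu)
            - 0.5 * np.sum(np.square((x - mu) / std)))
    assert np.isclose(logp, multivariate_normal(mu, np.diag(std**2)).logpdf(x))
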
- logprobsampled_n = - U.sum(tf.log(ac_dist[:,ac_dim:]), axis=1) - 0.5 * tf.log(2.0*np.pi)*ac_dim - 0.5 * U.sum(tf.square(ac_dist[:,:ac_dim] - sampled_ac_na) / (tf.square(ac_dist[:,ac_dim:])), axis=1) # Logprob of sampled action - logprob_n = - U.sum(tf.log(ac_dist[:,ac_dim:]), axis=1) - 0.5 * tf.log(2.0*np.pi)*ac_dim - 0.5 * U.sum(tf.square(ac_dist[:,:ac_dim] - oldac_na) / (tf.square(ac_dist[:,ac_dim:])), axis=1) # Logprob of previous actions under CURRENT policy (whereas oldlogprob_n is under OLD policy) - kl = U.mean(kl_div(oldac_dist, ac_dist, ac_dim)) - #kl = .5 * U.mean(tf.square(logprob_n - oldlogprob_n)) # Approximation of KL divergence between old policy used to generate actions, and new policy used to compute logprob_n - surr = - U.mean(adv_n * logprob_n) # Loss function that we'll differentiate to get the policy gradient - surr_sampled = - U.mean(logprob_n) # Sampled loss of the policy + logprobsampled_n = - tf.reduce_sum(tf.log(ac_dist[:,ac_dim:]), axis=1) - 0.5 * tf.log(2.0*np.pi)*ac_dim - 0.5 * tf.reduce_sum(tf.square(ac_dist[:,:ac_dim] - sampled_ac_na) / (tf.square(ac_dist[:,ac_dim:])), axis=1) # Logprob of sampled action + logprob_n = - tf.reduce_sum(tf.log(ac_dist[:,ac_dim:]), axis=1) - 0.5 * tf.log(2.0*np.pi)*ac_dim - 0.5 * tf.reduce_sum(tf.square(ac_dist[:,:ac_dim] - oldac_na) / (tf.square(ac_dist[:,ac_dim:])), axis=1) # Logprob of previous actions under CURRENT policy (whereas oldlogprob_n is under OLD policy) + kl = tf.reduce_mean(kl_div(oldac_dist, ac_dist, ac_dim)) + #kl = .5 * tf.reduce_mean(tf.square(logprob_n - oldlogprob_n)) # Approximation of KL divergence between old policy used to generate actions, and new policy used to compute logprob_n + surr = - tf.reduce_mean(adv_n * logprob_n) # Loss function that we'll differentiate to get the policy gradient + surr_sampled = - tf.reduce_mean(logprob_n) # Sampled loss of the policy self._act = U.function([ob_no], [sampled_ac_na, ac_dist, logprobsampled_n]) # Generate a new action and its logprob #self.compute_kl = U.function([ob_no, oldac_na, oldlogprob_n], kl) # Compute (approximate) KL divergence between old policy and new policy self.compute_kl = U.function([ob_no, oldac_dist], kl) diff --git a/baselines/acktr/run_atari.py b/baselines/acktr/run_atari.py index 0c48758..7569f2e 100644 --- a/baselines/acktr/run_atari.py +++ b/baselines/acktr/run_atari.py @@ -1,38 +1,21 @@ #!/usr/bin/env python3 -import os, logging, gym + from baselines import logger -from baselines.common import set_global_seeds -from baselines import bench from baselines.acktr.acktr_disc import learn -from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv -from baselines.common.atari_wrappers import make_atari, wrap_deepmind -from baselines.acktr.policies import CnnPolicy +from baselines.common.cmd_util import make_atari_env, atari_arg_parser +from baselines.common.vec_env.vec_frame_stack import VecFrameStack +from baselines.ppo2.policies import CnnPolicy def train(env_id, num_timesteps, seed, num_cpu): - def make_env(rank): - def _thunk(): - env = make_atari(env_id) - env.seed(seed + rank) - env = bench.Monitor(env, logger.get_dir() and os.path.join(logger.get_dir(), str(rank))) - gym.logger.setLevel(logging.WARN) - return wrap_deepmind(env) - return _thunk - set_global_seeds(seed) - env = SubprocVecEnv([make_env(i) for i in range(num_cpu)]) + env = VecFrameStack(make_atari_env(env_id, num_cpu, seed), 4) policy_fn = CnnPolicy learn(policy_fn, env, seed, total_timesteps=int(num_timesteps * 1.1), nprocs=num_cpu) env.close() def main(): - 
import argparse - parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--env', help='environment ID', default='BreakoutNoFrameskip-v4') - parser.add_argument('--seed', help='RNG seed', type=int, default=0) - parser.add_argument('--num-timesteps', type=int, default=int(10e6)) - args = parser.parse_args() + args = atari_arg_parser().parse_args() logger.configure() train(args.env, num_timesteps=args.num_timesteps, seed=args.seed, num_cpu=32) - if __name__ == '__main__': main() diff --git a/baselines/acktr/run_mujoco.py b/baselines/acktr/run_mujoco.py index ed6d06a..9065d58 100644 --- a/baselines/acktr/run_mujoco.py +++ b/baselines/acktr/run_mujoco.py @@ -1,22 +1,14 @@ #!/usr/bin/env python3 -import argparse -import logging -import os + import tensorflow as tf -import gym from baselines import logger -from baselines.common import set_global_seeds -from baselines import bench +from baselines.common.cmd_util import make_mujoco_env, mujoco_arg_parser from baselines.acktr.acktr_cont import learn from baselines.acktr.policies import GaussianMlpPolicy from baselines.acktr.value_functions import NeuralNetValueFunction def train(env_id, num_timesteps, seed): - env=gym.make(env_id) - env = bench.Monitor(env, logger.get_dir() and os.path.join(logger.get_dir(), str(rank))) - set_global_seeds(seed) - env.seed(seed) - gym.logger.setLevel(logging.WARN) + env = make_mujoco_env(env_id, seed) with tf.Session(config=tf.ConfigProto()): ob_dim = env.observation_space.shape[0] @@ -33,11 +25,10 @@ def train(env_id, num_timesteps, seed): env.close() -if __name__ == "__main__": - parser = argparse.ArgumentParser(description='Run Mujoco benchmark.') - parser.add_argument('--seed', help='RNG seed', type=int, default=0) - parser.add_argument('--env', help='environment ID', type=str, default="Reacher-v1") - parser.add_argument('--num-timesteps', type=int, default=int(1e6)) - args = parser.parse_args() +def main(): + args = mujoco_arg_parser().parse_args() logger.configure() train(args.env, num_timesteps=args.num_timesteps, seed=args.seed) + +if __name__ == "__main__": + main() diff --git a/baselines/acktr/utils.py b/baselines/acktr/utils.py index 089cc78..227350f 100644 --- a/baselines/acktr/utils.py +++ b/baselines/acktr/utils.py @@ -1,69 +1,8 @@ -import os -import numpy as np import tensorflow as tf -import baselines.common.tf_util as U -from collections import deque - -def sample(logits): - noise = tf.random_uniform(tf.shape(logits)) - return tf.argmax(logits - tf.log(-tf.log(noise)), 1) - -def std(x): - mean = tf.reduce_mean(x) - var = tf.reduce_mean(tf.square(x-mean)) - return tf.sqrt(var) - -def cat_entropy(logits): - a0 = logits - tf.reduce_max(logits, 1, keep_dims=True) - ea0 = tf.exp(a0) - z0 = tf.reduce_sum(ea0, 1, keep_dims=True) - p0 = ea0 / z0 - return tf.reduce_sum(p0 * (tf.log(z0) - a0), 1) - -def cat_entropy_softmax(p0): - return - tf.reduce_sum(p0 * tf.log(p0 + 1e-6), axis = 1) - -def mse(pred, target): - return tf.square(pred-target)/2. 
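The sample(logits) helper deleted above — like the equivalent CategoricalPd.sample in baselines/common/distributions.py — relies on the Gumbel-max trick: adding independent Gumbel noise -log(-log(u)) to the logits and taking the argmax draws exactly from softmax(logits). A quick NumPy check of that claim (an illustrative sketch, not part of the patch):

    import numpy as np

    rng = np.random.default_rng(0)
    logits = np.array([1.0, 0.0, -1.0])
    probs = np.exp(logits) / np.exp(logits).sum()

    # argmax(logits + Gumbel noise) is distributed as Categorical(softmax(logits))
    u = rng.uniform(size=(100000, logits.size))
    samples = np.argmax(logits - np.log(-np.log(u)), axis=1)
    freqs = np.bincount(samples, minlength=logits.size) / samples.size
    assert np.allclose(freqs, probs, atol=0.01)
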
- -def ortho_init(scale=1.0): - def _ortho_init(shape, dtype, partition_info=None): - #lasagne ortho init for tf - shape = tuple(shape) - if len(shape) == 2: - flat_shape = shape - elif len(shape) == 4: # assumes NHWC - flat_shape = (np.prod(shape[:-1]), shape[-1]) - else: - raise NotImplementedError - a = np.random.normal(0.0, 1.0, flat_shape) - u, _, v = np.linalg.svd(a, full_matrices=False) - q = u if u.shape == flat_shape else v # pick the one with the correct shape - q = q.reshape(shape) - return (scale * q[:shape[0], :shape[1]]).astype(np.float32) - return _ortho_init - -def conv(x, scope, nf, rf, stride, pad='VALID', act=tf.nn.relu, init_scale=1.0): - with tf.variable_scope(scope): - nin = x.get_shape()[3].value - w = tf.get_variable("w", [rf, rf, nin, nf], initializer=ortho_init(init_scale)) - b = tf.get_variable("b", [nf], initializer=tf.constant_initializer(0.0)) - z = tf.nn.conv2d(x, w, strides=[1, stride, stride, 1], padding=pad)+b - h = act(z) - return h - -def fc(x, scope, nh, act=tf.nn.relu, init_scale=1.0): - with tf.variable_scope(scope): - nin = x.get_shape()[1].value - w = tf.get_variable("w", [nin, nh], initializer=ortho_init(init_scale)) - b = tf.get_variable("b", [nh], initializer=tf.constant_initializer(0.0)) - z = tf.matmul(x, w)+b - h = act(z) - return h def dense(x, size, name, weight_init=None, bias_init=0, weight_loss_dict=None, reuse=None): with tf.variable_scope(name, reuse=reuse): - assert (len(U.scope_name().split('/')) == 2) + assert (len(tf.get_variable_scope().name.split('/')) == 2) w = tf.get_variable("w", [x.get_shape()[1], size], initializer=weight_init) b = tf.get_variable("b", [size], initializer=tf.constant_initializer(bias_init)) @@ -75,15 +14,10 @@ def dense(x, size, name, weight_init=None, bias_init=0, weight_loss_dict=None, r weight_loss_dict[w] = weight_decay_fc weight_loss_dict[b] = 0.0 - tf.add_to_collection(U.scope_name().split('/')[0] + '_' + 'losses', weight_decay) + tf.add_to_collection(tf.get_variable_scope().name.split('/')[0] + '_' + 'losses', weight_decay) return tf.nn.bias_add(tf.matmul(x, w), b) -def conv_to_fc(x): - nh = np.prod([v.value for v in x.get_shape()[1:]]) - x = tf.reshape(x, [-1, nh]) - return x - def kl_div(action_dist1, action_dist2, action_size): mean1, std1 = action_dist1[:, :action_size], action_dist1[:, action_size:] mean2, std2 = action_dist2[:, :action_size], action_dist2[:, action_size:] @@ -92,109 +26,3 @@ def kl_div(action_dist1, action_dist2, action_size): denominator = 2 * tf.square(std2) + 1e-8 return tf.reduce_sum( numerator/denominator + tf.log(std2) - tf.log(std1),reduction_indices=-1) - -def discount_with_dones(rewards, dones, gamma): - discounted = [] - r = 0 - for reward, done in zip(rewards[::-1], dones[::-1]): - r = reward + gamma*r*(1.-done) # fixed off by one bug - discounted.append(r) - return discounted[::-1] - -def find_trainable_variables(key): - with tf.variable_scope(key): - return tf.trainable_variables() - -def make_path(f): - return os.makedirs(f, exist_ok=True) - -def constant(p): - return 1 - -def linear(p): - return 1-p - - -def middle_drop(p): - eps = 0.75 - if 1-p= 3 - def _reset(self, **kwargs): + def reset(self, **kwargs): self.env.reset(**kwargs) obs, _, done, _ = self.env.step(1) if done: @@ -47,6 +51,9 @@ class FireResetEnv(gym.Wrapper): self.env.reset(**kwargs) return obs + def step(self, ac): + return self.env.step(ac) + class EpisodicLifeEnv(gym.Wrapper): def __init__(self, env): """Make end-of-life == end-of-episode, but only reset on true game over. 
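The _step/_reset/_observation/_reward -> step/reset/observation/reward renames throughout this file track the newer gym Wrapper API, which invokes the public methods directly rather than the underscore-prefixed hooks. A minimal new-style wrapper in the same spirit as the ClipRewardEnv change below (an illustrative sketch with a hypothetical name, not part of the patch):

    import gym

    class ScaleRewardWrapper(gym.RewardWrapper):
        """New-style wrapper: gym calls the public reward() hook directly."""
        def __init__(self, env, scale=0.01):
            gym.RewardWrapper.__init__(self, env)
            self.scale = scale

        def reward(self, reward):
            return reward * self.scale
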
@@ -56,7 +63,7 @@ class EpisodicLifeEnv(gym.Wrapper): self.lives = 0 self.was_real_done = True - def _step(self, action): + def step(self, action): obs, reward, done, info = self.env.step(action) self.was_real_done = done # check current lives, make loss of life terminal, @@ -70,7 +77,7 @@ class EpisodicLifeEnv(gym.Wrapper): self.lives = lives return obs, reward, done, info - def _reset(self, **kwargs): + def reset(self, **kwargs): """Reset only when lives are exhausted. This way all states are still reachable even though lives are episodic, and the learner need not know about any of this behind-the-scenes. @@ -88,10 +95,13 @@ class MaxAndSkipEnv(gym.Wrapper): """Return only every `skip`-th frame""" gym.Wrapper.__init__(self, env) # most recent raw observations (for max pooling across time steps) - self._obs_buffer = np.zeros((2,)+env.observation_space.shape, dtype='uint8') + self._obs_buffer = np.zeros((2,)+env.observation_space.shape, dtype=np.uint8) self._skip = skip - def _step(self, action): + def reset(self): + return self.env.reset() + + def step(self, action): """Repeat action, sum reward, and max over last observations.""" total_reward = 0.0 done = None @@ -108,8 +118,14 @@ class MaxAndSkipEnv(gym.Wrapper): return max_frame, total_reward, done, info + def reset(self, **kwargs): + return self.env.reset(**kwargs) + class ClipRewardEnv(gym.RewardWrapper): - def _reward(self, reward): + def __init__(self, env): + gym.RewardWrapper.__init__(self, env) + + def reward(self, reward): """Bin reward to {+1, 0, -1} by its sign.""" return np.sign(reward) @@ -119,9 +135,10 @@ class WarpFrame(gym.ObservationWrapper): gym.ObservationWrapper.__init__(self, env) self.width = 84 self.height = 84 - self.observation_space = spaces.Box(low=0, high=255, shape=(self.height, self.width, 1)) + self.observation_space = spaces.Box(low=0, high=255, + shape=(self.height, self.width, 1), dtype=np.uint8) - def _observation(self, frame): + def observation(self, frame): frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA) return frame[:, :, None] @@ -140,15 +157,15 @@ class FrameStack(gym.Wrapper): self.k = k self.frames = deque([], maxlen=k) shp = env.observation_space.shape - self.observation_space = spaces.Box(low=0, high=255, shape=(shp[0], shp[1], shp[2] * k)) + self.observation_space = spaces.Box(low=0, high=255, shape=(shp[0], shp[1], shp[2] * k), dtype=np.uint8) - def _reset(self): + def reset(self): ob = self.env.reset() for _ in range(self.k): self.frames.append(ob) return self._get_ob() - def _step(self, action): + def step(self, action): ob, reward, done, info = self.env.step(action) self.frames.append(ob) return self._get_ob(), reward, done, info @@ -158,7 +175,10 @@ class FrameStack(gym.Wrapper): return LazyFrames(list(self.frames)) class ScaledFloatFrame(gym.ObservationWrapper): - def _observation(self, observation): + def __init__(self, env): + gym.ObservationWrapper.__init__(self, env) + + def observation(self, observation): # careful! This undoes the memory optimization, use # with smaller replay buffers only. return np.array(observation).astype(np.float32) / 255.0 diff --git a/baselines/common/cmd_util.py b/baselines/common/cmd_util.py new file mode 100644 index 0000000..d8a48ae --- /dev/null +++ b/baselines/common/cmd_util.py @@ -0,0 +1,64 @@ +""" +Helpers for scripts like run_atari.py. 
+""" + +import os +import gym +from baselines import logger +from baselines.bench import Monitor +from baselines.common import set_global_seeds +from baselines.common.atari_wrappers import make_atari, wrap_deepmind +from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv +from mpi4py import MPI + +def make_atari_env(env_id, num_env, seed, wrapper_kwargs=None, start_index=0): + """ + Create a wrapped, monitored SubprocVecEnv for Atari. + """ + if wrapper_kwargs is None: wrapper_kwargs = {} + def make_env(rank): # pylint: disable=C0111 + def _thunk(): + env = make_atari(env_id) + env.seed(seed + rank) + env = Monitor(env, logger.get_dir() and os.path.join(logger.get_dir(), str(rank))) + return wrap_deepmind(env, **wrapper_kwargs) + return _thunk + set_global_seeds(seed) + return SubprocVecEnv([make_env(i + start_index) for i in range(num_env)]) + +def make_mujoco_env(env_id, seed): + """ + Create a wrapped, monitored gym.Env for MuJoCo. + """ + set_global_seeds(seed) + env = gym.make(env_id) + env = Monitor(env, logger.get_dir()) + env.seed(seed) + return env + +def arg_parser(): + """ + Create an empty argparse.ArgumentParser. + """ + import argparse + return argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + +def atari_arg_parser(): + """ + Create an argparse.ArgumentParser for run_atari.py. + """ + parser = arg_parser() + parser.add_argument('--env', help='environment ID', default='BreakoutNoFrameskip-v4') + parser.add_argument('--seed', help='RNG seed', type=int, default=0) + parser.add_argument('--num-timesteps', type=int, default=int(10e6)) + return parser + +def mujoco_arg_parser(): + """ + Create an argparse.ArgumentParser for run_mujoco.py. + """ + parser = arg_parser() + parser.add_argument('--env', help='environment ID', type=str, default="Reacher-v1") + parser.add_argument('--seed', help='RNG seed', type=int, default=0) + parser.add_argument('--num-timesteps', type=int, default=int(1e6)) + return parser diff --git a/baselines/common/distributions.py b/baselines/common/distributions.py index 1dc02c3..6f5b522 100644 --- a/baselines/common/distributions.py +++ b/baselines/common/distributions.py @@ -57,14 +57,12 @@ class CategoricalPdType(PdType): class MultiCategoricalPdType(PdType): - def __init__(self, low, high): - self.low = low - self.high = high - self.ncats = high - low + 1 + def __init__(self, nvec): + self.ncats = nvec def pdclass(self): return MultiCategoricalPd def pdfromflat(self, flat): - return MultiCategoricalPd(self.low, self.high, flat) + return MultiCategoricalPd(self.ncats, flat) def param_shape(self): return [sum(self.ncats)] def sample_shape(self): @@ -125,7 +123,7 @@ class CategoricalPd(Pd): def flatparam(self): return self.logits def mode(self): - return U.argmax(self.logits, axis=-1) + return tf.argmax(self.logits, axis=-1) def neglogp(self, x): # return tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits, labels=x) # Note: we can't use sparse_softmax_cross_entropy_with_logits because @@ -135,20 +133,20 @@ class CategoricalPd(Pd): logits=self.logits, labels=one_hot_actions) def kl(self, other): - a0 = self.logits - U.max(self.logits, axis=-1, keepdims=True) - a1 = other.logits - U.max(other.logits, axis=-1, keepdims=True) + a0 = self.logits - tf.reduce_max(self.logits, axis=-1, keep_dims=True) + a1 = other.logits - tf.reduce_max(other.logits, axis=-1, keep_dims=True) ea0 = tf.exp(a0) ea1 = tf.exp(a1) - z0 = U.sum(ea0, axis=-1, keepdims=True) - z1 = U.sum(ea1, axis=-1, keepdims=True) + z0 = 
tf.reduce_sum(ea0, axis=-1, keep_dims=True) + z1 = tf.reduce_sum(ea1, axis=-1, keep_dims=True) p0 = ea0 / z0 - return U.sum(p0 * (a0 - tf.log(z0) - a1 + tf.log(z1)), axis=-1) + return tf.reduce_sum(p0 * (a0 - tf.log(z0) - a1 + tf.log(z1)), axis=-1) def entropy(self): - a0 = self.logits - U.max(self.logits, axis=-1, keepdims=True) + a0 = self.logits - tf.reduce_max(self.logits, axis=-1, keep_dims=True) ea0 = tf.exp(a0) - z0 = U.sum(ea0, axis=-1, keepdims=True) + z0 = tf.reduce_sum(ea0, axis=-1, keep_dims=True) p0 = ea0 / z0 - return U.sum(p0 * (tf.log(z0) - a0), axis=-1) + return tf.reduce_sum(p0 * (tf.log(z0) - a0), axis=-1) def sample(self): u = tf.random_uniform(tf.shape(self.logits)) return tf.argmax(self.logits - tf.log(-tf.log(u)), axis=-1) @@ -157,24 +155,21 @@ class CategoricalPd(Pd): return cls(flat) class MultiCategoricalPd(Pd): - def __init__(self, low, high, flat): + def __init__(self, nvec, flat): self.flat = flat - self.low = tf.constant(low, dtype=tf.int32) - self.categoricals = list(map(CategoricalPd, tf.split(flat, high - low + 1, axis=len(flat.get_shape()) - 1))) + self.categoricals = list(map(CategoricalPd, tf.split(flat, nvec, axis=-1))) def flatparam(self): return self.flat def mode(self): - return self.low + tf.cast(tf.stack([p.mode() for p in self.categoricals], axis=-1), tf.int32) + return tf.cast(tf.stack([p.mode() for p in self.categoricals], axis=-1), tf.int32) def neglogp(self, x): - return tf.add_n([p.neglogp(px) for p, px in zip(self.categoricals, tf.unstack(x - self.low, axis=len(x.get_shape()) - 1))]) + return tf.add_n([p.neglogp(px) for p, px in zip(self.categoricals, tf.unstack(x, axis=-1))]) def kl(self, other): - return tf.add_n([ - p.kl(q) for p, q in zip(self.categoricals, other.categoricals) - ]) + return tf.add_n([p.kl(q) for p, q in zip(self.categoricals, other.categoricals)]) def entropy(self): return tf.add_n([p.entropy() for p in self.categoricals]) def sample(self): - return self.low + tf.cast(tf.stack([p.sample() for p in self.categoricals], axis=-1), tf.int32) + return tf.cast(tf.stack([p.sample() for p in self.categoricals], axis=-1), tf.int32) @classmethod def fromflat(cls, flat): raise NotImplementedError @@ -191,14 +186,14 @@ class DiagGaussianPd(Pd): def mode(self): return self.mean def neglogp(self, x): - return 0.5 * U.sum(tf.square((x - self.mean) / self.std), axis=-1) \ + return 0.5 * tf.reduce_sum(tf.square((x - self.mean) / self.std), axis=-1) \ + 0.5 * np.log(2.0 * np.pi) * tf.to_float(tf.shape(x)[-1]) \ - + U.sum(self.logstd, axis=-1) + + tf.reduce_sum(self.logstd, axis=-1) def kl(self, other): assert isinstance(other, DiagGaussianPd) - return U.sum(other.logstd - self.logstd + (tf.square(self.std) + tf.square(self.mean - other.mean)) / (2.0 * tf.square(other.std)) - 0.5, axis=-1) + return tf.reduce_sum(other.logstd - self.logstd + (tf.square(self.std) + tf.square(self.mean - other.mean)) / (2.0 * tf.square(other.std)) - 0.5, axis=-1) def entropy(self): - return U.sum(self.logstd + .5 * np.log(2.0 * np.pi * np.e), axis=-1) + return tf.reduce_sum(self.logstd + .5 * np.log(2.0 * np.pi * np.e), axis=-1) def sample(self): return self.mean + self.std * tf.random_normal(tf.shape(self.mean)) @classmethod @@ -214,11 +209,11 @@ class BernoulliPd(Pd): def mode(self): return tf.round(self.ps) def neglogp(self, x): - return U.sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=tf.to_float(x)), axis=-1) + return tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=tf.to_float(x)), axis=-1) def 
kl(self, other): - return U.sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=other.logits, labels=self.ps), axis=-1) - U.sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=self.ps), axis=-1) + return tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=other.logits, labels=self.ps), axis=-1) - tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=self.ps), axis=-1) def entropy(self): - return U.sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=self.ps), axis=-1) + return tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=self.ps), axis=-1) def sample(self): u = tf.random_uniform(tf.shape(self.ps)) return tf.to_float(math_ops.less(u, self.ps)) @@ -234,7 +229,7 @@ def make_pdtype(ac_space): elif isinstance(ac_space, spaces.Discrete): return CategoricalPdType(ac_space.n) elif isinstance(ac_space, spaces.MultiDiscrete): - return MultiCategoricalPdType(ac_space.low, ac_space.high) + return MultiCategoricalPdType(ac_space.nvec) elif isinstance(ac_space, spaces.MultiBinary): return BernoulliPdType(ac_space.n) else: @@ -259,6 +254,11 @@ def test_probtypes(): categorical = CategoricalPdType(pdparam_categorical.size) #pylint: disable=E1101 validate_probtype(categorical, pdparam_categorical) + nvec = [1,2,3] + pdparam_multicategorical = np.array([-.2, .3, .5, .1, 1, -.1]) + multicategorical = MultiCategoricalPdType(nvec) #pylint: disable=E1101 + validate_probtype(multicategorical, pdparam_multicategorical) + pdparam_bernoulli = np.array([-.2, .3, .5]) bernoulli = BernoulliPdType(pdparam_bernoulli.size) #pylint: disable=E1101 validate_probtype(bernoulli, pdparam_bernoulli) @@ -270,10 +270,10 @@ def validate_probtype(probtype, pdparam): Mval = np.repeat(pdparam[None, :], N, axis=0) M = probtype.param_placeholder([N]) X = probtype.sample_placeholder([N]) - pd = probtype.pdclass()(M) + pd = probtype.pdfromflat(M) calcloglik = U.function([X, M], pd.logp(X)) calcent = U.function([M], pd.entropy()) - Xval = U.eval(pd.sample(), feed_dict={M:Mval}) + Xval = tf.get_default_session().run(pd.sample(), feed_dict={M:Mval}) logliks = calcloglik(Xval, Mval) entval_ll = - logliks.mean() #pylint: disable=E1101 entval_ll_stderr = logliks.std() / np.sqrt(N) #pylint: disable=E1101 @@ -282,7 +282,7 @@ def validate_probtype(probtype, pdparam): # Check to see if kldiv[p,q] = - ent[p] - E_p[log q] M2 = probtype.param_placeholder([N]) - pd2 = probtype.pdclass()(M2) + pd2 = probtype.pdfromflat(M2) q = pdparam + np.random.randn(pdparam.size) * 0.1 Mval2 = np.repeat(q[None, :], N, axis=0) calckl = U.function([M, M2], pd.kl(pd2)) @@ -291,3 +291,5 @@ def validate_probtype(probtype, pdparam): klval_ll = - entval - logliks.mean() #pylint: disable=E1101 klval_ll_stderr = logliks.std() / np.sqrt(N) #pylint: disable=E1101 assert np.abs(klval - klval_ll) < 3 * klval_ll_stderr # within 3 sigmas + print('ok on', probtype, pdparam) + diff --git a/baselines/common/mpi_adam.py b/baselines/common/mpi_adam.py index 30ebaba..4902caf 100644 --- a/baselines/common/mpi_adam.py +++ b/baselines/common/mpi_adam.py @@ -53,7 +53,7 @@ class MpiAdam(object): def test_MpiAdam(): np.random.seed(0) tf.set_random_seed(0) - + a = tf.Variable(np.random.randn(3).astype('float32')) b = tf.Variable(np.random.randn(2,5).astype('float32')) loss = tf.reduce_sum(tf.square(a)) + tf.reduce_sum(tf.sin(b)) diff --git a/baselines/common/mpi_moments.py b/baselines/common/mpi_moments.py index c4c8fc0..d13cc2f 100644 --- a/baselines/common/mpi_moments.py +++ 
b/baselines/common/mpi_moments.py @@ -2,29 +2,41 @@ from mpi4py import MPI import numpy as np from baselines.common import zipsame -def mpi_moments(x, axis=0): - x = np.asarray(x, dtype='float64') - newshape = list(x.shape) - newshape.pop(axis) - n = np.prod(newshape,dtype=int) - totalvec = np.zeros(n*2+1, 'float64') - addvec = np.concatenate([x.sum(axis=axis).ravel(), - np.square(x).sum(axis=axis).ravel(), - np.array([x.shape[axis]],dtype='float64')]) - MPI.COMM_WORLD.Allreduce(addvec, totalvec, op=MPI.SUM) - sum = totalvec[:n] - sumsq = totalvec[n:2*n] - count = totalvec[2*n] - if count == 0: - mean = np.empty(newshape); mean[:] = np.nan - std = np.empty(newshape); std[:] = np.nan - else: - mean = sum/count - std = np.sqrt(np.maximum(sumsq/count - np.square(mean),0)) +def mpi_mean(x, axis=0, comm=None, keepdims=False): + x = np.asarray(x) + assert x.ndim > 0 + if comm is None: comm = MPI.COMM_WORLD + xsum = x.sum(axis=axis, keepdims=keepdims) + n = xsum.size + localsum = np.zeros(n+1, x.dtype) + localsum[:n] = xsum.ravel() + localsum[n] = x.shape[axis] + globalsum = np.zeros_like(localsum) + comm.Allreduce(localsum, globalsum, op=MPI.SUM) + return globalsum[:n].reshape(xsum.shape) / globalsum[n], globalsum[n] + +def mpi_moments(x, axis=0, comm=None, keepdims=False): + x = np.asarray(x) + assert x.ndim > 0 + mean, count = mpi_mean(x, axis=axis, comm=comm, keepdims=True) + sqdiffs = np.square(x - mean) + meansqdiff, count1 = mpi_mean(sqdiffs, axis=axis, comm=comm, keepdims=True) + assert count1 == count + std = np.sqrt(meansqdiff) + if not keepdims: + newshape = mean.shape[:axis] + mean.shape[axis+1:] + mean = mean.reshape(newshape) + std = std.reshape(newshape) return mean, std, count def test_runningmeanstd(): + import subprocess + subprocess.check_call(['mpirun', '-np', '3', + 'python','-c', + 'from baselines.common.mpi_moments import _helper_runningmeanstd; _helper_runningmeanstd()']) + +def _helper_runningmeanstd(): comm = MPI.COMM_WORLD np.random.seed(0) for (triple,axis) in [ @@ -45,6 +57,3 @@ def test_runningmeanstd(): assert np.allclose(a1, a2) print("ok!") -if __name__ == "__main__": - #mpirun -np 3 python
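
The rewritten mpi_moments above pools the mean with one Allreduce, then averages squared deviations from that mean with a second Allreduce, instead of accumulating sum and sum-of-squares in a single packed vector. A single-process NumPy sketch of the same two-pass computation (illustrative only; the real version reduces across MPI ranks):

    import numpy as np

    def two_pass_moments(x, axis=0):
        x = np.asarray(x)
        mean = x.mean(axis=axis, keepdims=True)                             # pass 1: mean
        std = np.sqrt(np.square(x - mean).mean(axis=axis, keepdims=True))   # pass 2: spread about that mean
        count = x.shape[axis]
        return mean.squeeze(axis), std.squeeze(axis), count

    x = np.random.randn(64, 3)
    mean, std, count = two_pass_moments(x)
    assert np.allclose(std, x.std(axis=0)) and count == 64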