diff --git a/Dockerfile b/Dockerfile index a7c71bc..49a9c79 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,16 +1,7 @@ -FROM ubuntu:16.04 +FROM python:3.6 -RUN apt-get -y update && apt-get -y install git wget python-dev python3-dev libopenmpi-dev python-pip zlib1g-dev cmake python-opencv +# RUN apt-get -y update && apt-get -y install git wget python-dev python3-dev libopenmpi-dev python-pip zlib1g-dev cmake python-opencv ENV CODE_DIR /root/code -ENV VENV /root/venv - -RUN \ - pip install virtualenv && \ - virtualenv $VENV --python=python3 && \ - . $VENV/bin/activate && \ - pip install --upgrade pip - -ENV PATH=$VENV/bin:$PATH COPY . $CODE_DIR/baselines WORKDIR $CODE_DIR/baselines diff --git a/baselines/common/mpi_adam.py b/baselines/common/mpi_adam.py index 17491d7..10f1195 100644 --- a/baselines/common/mpi_adam.py +++ b/baselines/common/mpi_adam.py @@ -1,7 +1,11 @@ -from mpi4py import MPI import baselines.common.tf_util as U import tensorflow as tf import numpy as np +try: + from mpi4py import MPI +except ImportError: + MPI = None + class MpiAdam(object): def __init__(self, var_list, *, beta1=0.9, beta2=0.999, epsilon=1e-08, scale_grad_by_procs=True, comm=None): @@ -16,16 +20,19 @@ class MpiAdam(object): self.t = 0 self.setfromflat = U.SetFromFlat(var_list) self.getflat = U.GetFlat(var_list) - self.comm = MPI.COMM_WORLD if comm is None else comm + self.comm = MPI.COMM_WORLD if comm is None and MPI is not None else comm def update(self, localg, stepsize): if self.t % 100 == 0: self.check_synced() localg = localg.astype('float32') - globalg = np.zeros_like(localg) - self.comm.Allreduce(localg, globalg, op=MPI.SUM) - if self.scale_grad_by_procs: - globalg /= self.comm.Get_size() + if self.comm is not None: + globalg = np.zeros_like(localg) + self.comm.Allreduce(localg, globalg, op=MPI.SUM) + if self.scale_grad_by_procs: + globalg /= self.comm.Get_size() + else: + globalg = np.copy(localg) self.t += 1 a = stepsize * np.sqrt(1 - self.beta2**self.t)/(1 - self.beta1**self.t) @@ -35,11 +42,15 @@ class MpiAdam(object): self.setfromflat(self.getflat() + step) def sync(self): + if self.comm is None: + return theta = self.getflat() self.comm.Bcast(theta, root=0) self.setfromflat(theta) def check_synced(self): + if self.comm is None: + return if self.comm.Get_rank() == 0: # this is root theta = self.getflat() self.comm.Bcast(theta, root=0) @@ -63,17 +74,30 @@ def test_MpiAdam(): do_update = U.function([], loss, updates=[update_op]) tf.get_default_session().run(tf.global_variables_initializer()) + losslist_ref = [] for i in range(10): - print(i,do_update()) + l = do_update() + print(i, l) + losslist_ref.append(l) + + tf.set_random_seed(0) tf.get_default_session().run(tf.global_variables_initializer()) var_list = [a,b] - lossandgrad = U.function([], [loss, U.flatgrad(loss, var_list)], updates=[update_op]) + lossandgrad = U.function([], [loss, U.flatgrad(loss, var_list)]) adam = MpiAdam(var_list) + losslist_test = [] for i in range(10): l,g = lossandgrad() adam.update(g, stepsize) print(i,l) + losslist_test.append(l) + + np.testing.assert_allclose(np.array(losslist_ref), np.array(losslist_test), atol=1e-4) + + +if __name__ == '__main__': + test_MpiAdam() diff --git a/baselines/common/mpi_running_mean_std.py b/baselines/common/mpi_running_mean_std.py index 408f8a2..488d2a1 100644 --- a/baselines/common/mpi_running_mean_std.py +++ b/baselines/common/mpi_running_mean_std.py @@ -1,4 +1,8 @@ -from mpi4py import MPI +try: + from mpi4py import MPI +except ImportError: + MPI = None + import tensorflow as tf, baselines.common.tf_util as U, numpy as np class RunningMeanStd(object): @@ -39,7 +43,8 @@ class RunningMeanStd(object): n = int(np.prod(self.shape)) totalvec = np.zeros(n*2+1, 'float64') addvec = np.concatenate([x.sum(axis=0).ravel(), np.square(x).sum(axis=0).ravel(), np.array([len(x)],dtype='float64')]) - MPI.COMM_WORLD.Allreduce(addvec, totalvec, op=MPI.SUM) + if MPI is not None: + MPI.COMM_WORLD.Allreduce(addvec, totalvec, op=MPI.SUM) self.incfiltparams(totalvec[0:n].reshape(self.shape), totalvec[n:2*n].reshape(self.shape), totalvec[2*n]) @U.in_session diff --git a/baselines/ddpg/ddpg.py b/baselines/ddpg/ddpg.py index 307205f..37551d4 100755 --- a/baselines/ddpg/ddpg.py +++ b/baselines/ddpg/ddpg.py @@ -12,8 +12,11 @@ import baselines.common.tf_util as U from baselines import logger import numpy as np -from mpi4py import MPI +try: + from mpi4py import MPI +except ImportError: + MPI = None def learn(network, env, seed=None, @@ -49,7 +52,11 @@ def learn(network, env, else: nb_epochs = 500 - rank = MPI.COMM_WORLD.Get_rank() + if MPI is not None: + rank = MPI.COMM_WORLD.Get_rank() + else: + rank = 0 + nb_actions = env.action_space.shape[-1] assert (np.abs(env.action_space.low) == env.action_space.high).all() # we assume symmetric actions. @@ -199,7 +206,11 @@ def learn(network, env, eval_episode_rewards_history.append(eval_episode_reward[d]) eval_episode_reward[d] = 0.0 - mpi_size = MPI.COMM_WORLD.Get_size() + if MPI is not None: + mpi_size = MPI.COMM_WORLD.Get_size() + else: + mpi_size = 1 + # Log stats. # XXX shouldn't call np.mean on variable length lists duration = time.time() - start_time @@ -233,7 +244,10 @@ def learn(network, env, else: raise ValueError('expected scalar, got %s'%x) - combined_stats_sums = MPI.COMM_WORLD.allreduce(np.array([ np.array(x).flatten()[0] for x in combined_stats.values()])) + combined_stats_sums = np.array([ np.array(x).flatten()[0] for x in combined_stats.values()]) + if MPI is not None: + combined_stats_sums = MPI.COMM_WORLD.allreduce(combined_stats_sums) + combined_stats = {k : v / mpi_size for (k,v) in zip(combined_stats.keys(), combined_stats_sums)} # Total statistics. diff --git a/baselines/ddpg/ddpg_learner.py b/baselines/ddpg/ddpg_learner.py index 5b3b5ea..b8a5d60 100755 --- a/baselines/ddpg/ddpg_learner.py +++ b/baselines/ddpg/ddpg_learner.py @@ -9,7 +9,10 @@ from baselines import logger from baselines.common.mpi_adam import MpiAdam import baselines.common.tf_util as U from baselines.common.mpi_running_mean_std import RunningMeanStd -from mpi4py import MPI +try: + from mpi4py import MPI +except ImportError: + MPI = None def normalize(x, stats): if stats is None: @@ -358,6 +361,11 @@ class DDPG(object): return stats def adapt_param_noise(self): + try: + from mpi4py import MPI + except ImportError: + MPI = None + if self.param_noise is None: return 0. @@ -371,7 +379,16 @@ class DDPG(object): self.param_noise_stddev: self.param_noise.current_stddev, }) - mean_distance = MPI.COMM_WORLD.allreduce(distance, op=MPI.SUM) / MPI.COMM_WORLD.Get_size() + if MPI is not None: + mean_distance = MPI.COMM_WORLD.allreduce(distance, op=MPI.SUM) / MPI.COMM_WORLD.Get_size() + else: + mean_distance = distance + + if MPI is not None: + mean_distance = MPI.COMM_WORLD.allreduce(distance, op=MPI.SUM) / MPI.COMM_WORLD.Get_size() + else: + mean_distance = distance + self.param_noise.adapt(mean_distance) return mean_distance diff --git a/baselines/ppo2/ppo2.py b/baselines/ppo2/ppo2.py index 9a13003..70df06d 100644 --- a/baselines/ppo2/ppo2.py +++ b/baselines/ppo2/ppo2.py @@ -10,11 +10,15 @@ from baselines.common import explained_variance, set_global_seeds from baselines.common.policies import build_policy from baselines.common.runners import AbstractEnvRunner from baselines.common.tf_util import get_session, save_variables, load_variables -from baselines.common.mpi_adam_optimizer import MpiAdamOptimizer -from mpi4py import MPI +try: + from baselines.common.mpi_adam_optimizer import MpiAdamOptimizer + from mpi4py import MPI + from baselines.common.mpi_util import sync_from_root +except ImportError: + MPI = None + from baselines.common.tf_util import initialize -from baselines.common.mpi_util import sync_from_root class Model(object): """ @@ -93,7 +97,10 @@ class Model(object): # 1. Get the model parameters params = tf.trainable_variables('ppo2_model') # 2. Build our trainer - trainer = MpiAdamOptimizer(MPI.COMM_WORLD, learning_rate=LR, epsilon=1e-5) + if MPI is not None: + trainer = MpiAdamOptimizer(MPI.COMM_WORLD, learning_rate=LR, epsilon=1e-5) + else: + trainer = tf.train.AdamOptimizer(learning_rate=LR, epsilon=1e-5) # 3. Calculate the gradients grads_and_var = trainer.compute_gradients(loss, params) grads, var = zip(*grads_and_var) @@ -136,10 +143,12 @@ class Model(object): self.save = functools.partial(save_variables, sess=sess) self.load = functools.partial(load_variables, sess=sess) - if MPI.COMM_WORLD.Get_rank() == 0: + if MPI is None or MPI.COMM_WORLD.Get_rank() == 0: initialize() global_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="") - sync_from_root(sess, global_variables) #pylint: disable=E1101 + + if MPI is not None: + sync_from_root(sess, global_variables) #pylint: disable=E1101 class Runner(AbstractEnvRunner): """ @@ -392,9 +401,9 @@ def learn(*, network, env, total_timesteps, eval_env = None, seed=None, nsteps=2 logger.logkv('time_elapsed', tnow - tfirststart) for (lossval, lossname) in zip(lossvals, model.loss_names): logger.logkv(lossname, lossval) - if MPI.COMM_WORLD.Get_rank() == 0: + if MPI is None or MPI.COMM_WORLD.Get_rank() == 0: logger.dumpkvs() - if save_interval and (update % save_interval == 0 or update == 1) and logger.get_dir() and MPI.COMM_WORLD.Get_rank() == 0: + if save_interval and (update % save_interval == 0 or update == 1) and logger.get_dir() and (MPI is None or MPI.COMM_WORLD.Get_rank() == 0): checkdir = osp.join(logger.get_dir(), 'checkpoints') os.makedirs(checkdir, exist_ok=True) savepath = osp.join(checkdir, '%.5i'%update) diff --git a/baselines/trpo_mpi/trpo_mpi.py b/baselines/trpo_mpi/trpo_mpi.py index ec0a991..cd1e7ea 100644 --- a/baselines/trpo_mpi/trpo_mpi.py +++ b/baselines/trpo_mpi/trpo_mpi.py @@ -4,7 +4,6 @@ import baselines.common.tf_util as U import tensorflow as tf, numpy as np import time from baselines.common import colorize -from mpi4py import MPI from collections import deque from baselines.common import set_global_seeds from baselines.common.mpi_adam import MpiAdam @@ -13,6 +12,11 @@ from baselines.common.input import observation_placeholder from baselines.common.policies import build_policy from contextlib import contextmanager +try: + from mpi4py import MPI +except ImportError: + MPI = None + def traj_segment_generator(pi, env, horizon, stochastic): # Initialize state variables t = 0 @@ -146,9 +150,12 @@ def learn(*, ''' - - nworkers = MPI.COMM_WORLD.Get_size() - rank = MPI.COMM_WORLD.Get_rank() + if MPI is not None: + nworkers = MPI.COMM_WORLD.Get_size() + rank = MPI.COMM_WORLD.Get_rank() + else: + nworkers = 1 + rank = 0 cpus_per_worker = 1 U.get_session(config=tf.ConfigProto( @@ -237,9 +244,13 @@ def learn(*, def allmean(x): assert isinstance(x, np.ndarray) - out = np.empty_like(x) - MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM) - out /= nworkers + if MPI is not None: + out = np.empty_like(x) + MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM) + out /= nworkers + else: + out = np.copy(x) + return out U.initialize() @@ -247,7 +258,9 @@ def learn(*, pi.load(load_path) th_init = get_flat() - MPI.COMM_WORLD.Bcast(th_init, root=0) + if MPI is not None: + MPI.COMM_WORLD.Bcast(th_init, root=0) + set_from_flat(th_init) vfadam.sync() print("Init param sum", th_init.sum(), flush=True) @@ -353,7 +366,11 @@ def learn(*, logger.record_tabular("ev_tdlam_before", explained_variance(vpredbefore, tdlamret)) lrlocal = (seg["ep_lens"], seg["ep_rets"]) # local values - listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal) # list of tuples + if MPI is not None: + listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal) # list of tuples + else: + listoflrpairs = [lrlocal] + lens, rews = map(flatten_lists, zip(*listoflrpairs)) lenbuffer.extend(lens) rewbuffer.extend(rews) diff --git a/setup.py b/setup.py index 7244c18..f77faf0 100644 --- a/setup.py +++ b/setup.py @@ -15,6 +15,9 @@ extras = { ], 'bullet': [ 'pybullet', + ], + 'mpi': [ + 'mpi4py' ] } @@ -34,7 +37,6 @@ setup(name='baselines', 'joblib', 'dill', 'progressbar2', - 'mpi4py', 'cloudpickle', 'click', 'opencv-python'