diff --git a/baselines/ddpg/ddpg.py b/baselines/ddpg/ddpg.py index 8b8659b..f0843be 100755 --- a/baselines/ddpg/ddpg.py +++ b/baselines/ddpg/ddpg.py @@ -12,8 +12,11 @@ import baselines.common.tf_util as U from baselines import logger import numpy as np -from mpi4py import MPI +try: + from mpi4py import MPI +except ImportError: + MPI = None def learn(network, env, seed=None, @@ -49,7 +52,11 @@ def learn(network, env, else: nb_epochs = 500 - rank = MPI.COMM_WORLD.Get_rank() + if MPI is not None: + rank = MPI.COMM_WORLD.Get_rank() + else: + rank = 0 + nb_actions = env.action_space.shape[-1] assert (np.abs(env.action_space.low) == env.action_space.high).all() # we assume symmetric actions. @@ -200,7 +207,11 @@ def learn(network, env, eval_episode_rewards_history.append(eval_episode_reward[d]) eval_episode_reward[d] = 0.0 - mpi_size = MPI.COMM_WORLD.Get_size() + if MPI is not None: + mpi_size = MPI.COMM_WORLD.Get_size() + else: + mpi_size = 1 + # Log stats. # XXX shouldn't call np.mean on variable length lists duration = time.time() - start_time @@ -234,7 +245,10 @@ def learn(network, env, else: raise ValueError('expected scalar, got %s'%x) - combined_stats_sums = MPI.COMM_WORLD.allreduce(np.array([ np.array(x).flatten()[0] for x in combined_stats.values()])) + combined_stats_sums = np.array([ np.array(x).flatten()[0] for x in combined_stats.values()]) + if MPI is not None: + combined_stats_sums = MPI.COMM_WORLD.allreduce(combined_stats_sums) + combined_stats = {k : v / mpi_size for (k,v) in zip(combined_stats.keys(), combined_stats_sums)} # Total statistics. 
diff --git a/baselines/ddpg/ddpg_learner.py b/baselines/ddpg/ddpg_learner.py index b8b9fb3..ba3f92c 100755 --- a/baselines/ddpg/ddpg_learner.py +++ b/baselines/ddpg/ddpg_learner.py @@ -9,6 +9,10 @@ from baselines import logger from baselines.common.mpi_adam import MpiAdam import baselines.common.tf_util as U from baselines.common.mpi_running_mean_std import RunningMeanStd +try: + from mpi4py import MPI +except ImportError: + MPI = None def normalize(x, stats): if stats is None: @@ -375,4 +379,7 @@ class DDPG(object): self.param_noise_stddev: self.param_noise.current_stddev, }) - mean_distance = MPI.COMM_WORLD.allreduce(distance, op=MPI.SUM) / MPI.COMM_WORLD.Get_size() + if MPI is not None: + mean_distance = MPI.COMM_WORLD.allreduce(distance, op=MPI.SUM) / MPI.COMM_WORLD.Get_size() + else: + mean_distance = distance