mpiless ddpg

This commit is contained in:
Peter Zhokhov
2018-10-31 09:48:41 -07:00
parent f0d49fb67d
commit d1f7d12743
2 changed files with 27 additions and 4 deletions

View File

@@ -12,8 +12,11 @@ import baselines.common.tf_util as U
from baselines import logger
import numpy as np
from mpi4py import MPI
try:
from mpi4py import MPI
except ImportError:
MPI = None
def learn(network, env,
seed=None,
@@ -49,7 +52,11 @@ def learn(network, env,
else:
nb_epochs = 500
rank = MPI.COMM_WORLD.Get_rank()
if MPI is not None:
rank = MPI.COMM_WORLD.Get_rank()
else:
rank = 0
nb_actions = env.action_space.shape[-1]
assert (np.abs(env.action_space.low) == env.action_space.high).all() # we assume symmetric actions.
@@ -200,7 +207,11 @@ def learn(network, env,
eval_episode_rewards_history.append(eval_episode_reward[d])
eval_episode_reward[d] = 0.0
mpi_size = MPI.COMM_WORLD.Get_size()
if MPI is not None:
mpi_size = MPI.COMM_WORLD.Get_size()
else:
mpi_size = 1
# Log stats.
# XXX shouldn't call np.mean on variable length lists
duration = time.time() - start_time
@@ -234,7 +245,10 @@ def learn(network, env,
else:
raise ValueError('expected scalar, got %s'%x)
combined_stats_sums = MPI.COMM_WORLD.allreduce(np.array([ np.array(x).flatten()[0] for x in combined_stats.values()]))
combined_stats_sums = np.array([ np.array(x).flatten()[0] for x in combined_stats.values()])
if MPI is not None:
combined_stats_sums = MPI.COMM_WORLD.allreduce(combined_stats_sums)
combined_stats = {k : v / mpi_size for (k,v) in zip(combined_stats.keys(), combined_stats_sums)}
# Total statistics.

View File

@@ -9,6 +9,10 @@ from baselines import logger
from baselines.common.mpi_adam import MpiAdam
import baselines.common.tf_util as U
from baselines.common.mpi_running_mean_std import RunningMeanStd
try:
from mpi4py import MPI
except ImportError:
MPI = None
def normalize(x, stats):
if stats is None:
@@ -375,6 +379,11 @@ class DDPG(object):
self.param_noise_stddev: self.param_noise.current_stddev,
})
if MPI is not None:
mean_distance = MPI.COMM_WORLD.allreduce(distance, op=MPI.SUM) / MPI.COMM_WORLD.Get_size()
else:
mean_distance = distance
if MPI is not None:
mean_distance = MPI.COMM_WORLD.allreduce(distance, op=MPI.SUM) / MPI.COMM_WORLD.Get_size()
else: