mpiless ddpg
@@ -12,8 +12,11 @@ import baselines.common.tf_util as U
 
 from baselines import logger
 import numpy as np
-from mpi4py import MPI
+
+try:
+    from mpi4py import MPI
+except ImportError:
+    MPI = None
 
 def learn(network, env,
           seed=None,
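
Aside (not part of the diff): the hunk above replaces the hard mpi4py import with a guarded one, and every later hunk branches on whether MPI ended up as None. A minimal, self-contained sketch of that pattern; the helper names get_rank and get_size are invented here for illustration:

    # Optional MPI: fall back to single-process behaviour if mpi4py is missing.
    try:
        from mpi4py import MPI
    except ImportError:
        MPI = None

    def get_rank():
        # Without MPI there is exactly one worker, conventionally rank 0.
        return MPI.COMM_WORLD.Get_rank() if MPI is not None else 0

    def get_size():
        # World size defaults to 1 when running without MPI.
        return MPI.COMM_WORLD.Get_size() if MPI is not None else 1

    print('worker %d of %d' % (get_rank(), get_size()))
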
@@ -49,7 +52,11 @@ def learn(network, env,
     else:
         nb_epochs = 500
 
-    rank = MPI.COMM_WORLD.Get_rank()
+    if MPI is not None:
+        rank = MPI.COMM_WORLD.Get_rank()
+    else:
+        rank = 0
 
     nb_actions = env.action_space.shape[-1]
     assert (np.abs(env.action_space.low) == env.action_space.high).all() # we assume symmetric actions.
@@ -200,7 +207,11 @@ def learn(network, env,
                             eval_episode_rewards_history.append(eval_episode_reward[d])
                             eval_episode_reward[d] = 0.0
 
-        mpi_size = MPI.COMM_WORLD.Get_size()
+        if MPI is not None:
+            mpi_size = MPI.COMM_WORLD.Get_size()
+        else:
+            mpi_size = 1
+
         # Log stats.
         # XXX shouldn't call np.mean on variable length lists
         duration = time.time() - start_time
@@ -234,7 +245,10 @@ def learn(network, env,
             else:
                 raise ValueError('expected scalar, got %s'%x)
 
-        combined_stats_sums = MPI.COMM_WORLD.allreduce(np.array([ np.array(x).flatten()[0] for x in combined_stats.values()]))
+        combined_stats_sums = np.array([ np.array(x).flatten()[0] for x in combined_stats.values()])
+        if MPI is not None:
+            combined_stats_sums = MPI.COMM_WORLD.allreduce(combined_stats_sums)
+
         combined_stats = {k : v / mpi_size for (k,v) in zip(combined_stats.keys(), combined_stats_sums)}
 
         # Total statistics.
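
Aside (not part of the diff): the stats hunk above first builds the array of scalars, sums it across workers only when MPI is available, and then divides by mpi_size; without MPI, mpi_size is 1 and the division is a no-op. A hedged sketch of that reduction; average_stats and the example dictionary are invented here:

    import numpy as np

    try:
        from mpi4py import MPI
    except ImportError:
        MPI = None

    def average_stats(combined_stats):
        # Sum each scalar across MPI workers, then divide by the worker count.
        # Running without MPI, this returns the input values unchanged.
        mpi_size = MPI.COMM_WORLD.Get_size() if MPI is not None else 1
        sums = np.array([np.array(v).flatten()[0] for v in combined_stats.values()])
        if MPI is not None:
            sums = MPI.COMM_WORLD.allreduce(sums)  # element-wise sum over all ranks (default op is SUM)
        return {k: s / mpi_size for k, s in zip(combined_stats.keys(), sums)}

    print(average_stats({'rollout/return': 12.5, 'train/loss_actor': 0.03}))
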
@@ -9,6 +9,10 @@ from baselines import logger
 from baselines.common.mpi_adam import MpiAdam
 import baselines.common.tf_util as U
 from baselines.common.mpi_running_mean_std import RunningMeanStd
+try:
+    from mpi4py import MPI
+except ImportError:
+    MPI = None
 
 def normalize(x, stats):
     if stats is None:
@@ -375,6 +379,11 @@ class DDPG(object):
             self.param_noise_stddev: self.param_noise.current_stddev,
         })
 
+        if MPI is not None:
+            mean_distance = MPI.COMM_WORLD.allreduce(distance, op=MPI.SUM) / MPI.COMM_WORLD.Get_size()
+        else:
+            mean_distance = distance
+
         if MPI is not None:
             mean_distance = MPI.COMM_WORLD.allreduce(distance, op=MPI.SUM) / MPI.COMM_WORLD.Get_size()
         else:
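
Aside (not part of the diff): the final hunk applies the same guard to the parameter-noise distance, averaging it over workers when MPI is present and otherwise using the local value directly, since a single worker's distance is already its own mean. A small sketch; mean_across_workers is invented here:

    try:
        from mpi4py import MPI
    except ImportError:
        MPI = None

    def mean_across_workers(x):
        # Average a local scalar over all MPI workers; without MPI the local
        # value already equals the mean, mirroring the else-branch above.
        if MPI is not None:
            return MPI.COMM_WORLD.allreduce(x, op=MPI.SUM) / MPI.COMM_WORLD.Get_size()
        return x

    # e.g. a locally computed adaptive param-noise distance
    print(mean_across_workers(0.37))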