diff --git a/baselines/common/mpi_adam.py b/baselines/common/mpi_adam.py
index 17491d7..fd10bd5 100644
--- a/baselines/common/mpi_adam.py
+++ b/baselines/common/mpi_adam.py
@@ -1,7 +1,13 @@
-from mpi4py import MPI
 import baselines.common.tf_util as U
 import tensorflow as tf
 import numpy as np
+# MPI is optional: when mpi4py is unavailable, MpiAdam degrades to a
+# plain single-process Adam optimizer (comm stays None everywhere).
+try:
+    from mpi4py import MPI
+except ImportError:
+    MPI = None
+
 
 class MpiAdam(object):
     def __init__(self, var_list, *, beta1=0.9, beta2=0.999, epsilon=1e-08, scale_grad_by_procs=True, comm=None):
@@ -16,14 +22,21 @@ class MpiAdam(object):
         self.t = 0
         self.setfromflat = U.SetFromFlat(var_list)
         self.getflat = U.GetFlat(var_list)
-        self.comm = MPI.COMM_WORLD if comm is None else comm
+        # comm remains None when MPI is unavailable -> single-process mode.
+        self.comm = MPI.COMM_WORLD if comm is None and MPI is not None else comm
 
     def update(self, localg, stepsize):
-        if self.t % 100 == 0:
+        # The periodic sync check uses MPI collectives; it is meaningless
+        # (and would crash) without a communicator, so skip it entirely.
+        if self.t % 100 == 0 and self.comm is not None:
             self.check_synced()
         localg = localg.astype('float32')
-        globalg = np.zeros_like(localg)
-        self.comm.Allreduce(localg, globalg, op=MPI.SUM)
-        if self.scale_grad_by_procs:
-            globalg /= self.comm.Get_size()
+        if self.comm is not None:
+            # Sum (and optionally average) gradients across MPI workers.
+            globalg = np.zeros_like(localg)
+            self.comm.Allreduce(localg, globalg, op=MPI.SUM)
+            if self.scale_grad_by_procs:
+                globalg /= self.comm.Get_size()
+        else:
+            # No MPI: the local gradient IS the global gradient. Do not
+            # leave globalg as zeros, or the optimizer would silently no-op.
+            globalg = np.copy(localg)
         self.t += 1