MpiAdam becomes regular Adam if Mpi not present

This commit is contained in:
Peter Zhokhov
2018-10-30 14:04:30 -07:00
parent d00f3bce34
commit 3e3e2b7998

View File

@@ -1,7 +1,11 @@
from mpi4py import MPI
import baselines.common.tf_util as U import baselines.common.tf_util as U
import tensorflow as tf import tensorflow as tf
import numpy as np import numpy as np
try:
from mpi4py import MPI
except ImportError:
MPI = None
class MpiAdam(object): class MpiAdam(object):
def __init__(self, var_list, *, beta1=0.9, beta2=0.999, epsilon=1e-08, scale_grad_by_procs=True, comm=None): def __init__(self, var_list, *, beta1=0.9, beta2=0.999, epsilon=1e-08, scale_grad_by_procs=True, comm=None):
@@ -16,14 +20,15 @@ class MpiAdam(object):
self.t = 0 self.t = 0
self.setfromflat = U.SetFromFlat(var_list) self.setfromflat = U.SetFromFlat(var_list)
self.getflat = U.GetFlat(var_list) self.getflat = U.GetFlat(var_list)
self.comm = MPI.COMM_WORLD if comm is None else comm self.comm = MPI.COMM_WORLD if comm is None and MPI is not None else comm
def update(self, localg, stepsize): def update(self, localg, stepsize):
if self.t % 100 == 0: if self.t % 100 == 0:
self.check_synced() self.check_synced()
localg = localg.astype('float32') localg = localg.astype('float32')
globalg = np.zeros_like(localg) globalg = np.zeros_like(localg)
self.comm.Allreduce(localg, globalg, op=MPI.SUM) if self.comm is not None:
self.comm.Allreduce(localg, globalg, op=MPI.SUM)
if self.scale_grad_by_procs: if self.scale_grad_by_procs:
globalg /= self.comm.Get_size() globalg /= self.comm.Get_size()
@@ -46,7 +51,8 @@ class MpiAdam(object):
else: else:
thetalocal = self.getflat() thetalocal = self.getflat()
thetaroot = np.empty_like(thetalocal) thetaroot = np.empty_like(thetalocal)
self.comm.Bcast(thetaroot, root=0) if self.comm is not None:
self.comm.Bcast(thetaroot, root=0)
assert (thetaroot == thetalocal).all(), (thetaroot, thetalocal) assert (thetaroot == thetalocal).all(), (thetaroot, thetalocal)
@U.in_session @U.in_session