MpiAdam becomes regular Adam if Mpi not present

This commit is contained in:
Peter Zhokhov
2018-10-30 14:04:30 -07:00
parent d00f3bce34
commit 3e3e2b7998

View File

@@ -1,7 +1,11 @@
from mpi4py import MPI
import baselines.common.tf_util as U
import tensorflow as tf
import numpy as np
try:
from mpi4py import MPI
except ImportError:
MPI = None
class MpiAdam(object):
def __init__(self, var_list, *, beta1=0.9, beta2=0.999, epsilon=1e-08, scale_grad_by_procs=True, comm=None):
@@ -16,14 +20,15 @@ class MpiAdam(object):
self.t = 0
self.setfromflat = U.SetFromFlat(var_list)
self.getflat = U.GetFlat(var_list)
self.comm = MPI.COMM_WORLD if comm is None else comm
self.comm = MPI.COMM_WORLD if comm is None and MPI is not None else comm
def update(self, localg, stepsize):
if self.t % 100 == 0:
self.check_synced()
localg = localg.astype('float32')
globalg = np.zeros_like(localg)
self.comm.Allreduce(localg, globalg, op=MPI.SUM)
if self.comm is not None:
self.comm.Allreduce(localg, globalg, op=MPI.SUM)
if self.scale_grad_by_procs:
globalg /= self.comm.Get_size()
@@ -46,7 +51,8 @@ class MpiAdam(object):
else:
thetalocal = self.getflat()
thetaroot = np.empty_like(thetalocal)
self.comm.Bcast(thetaroot, root=0)
if self.comm is not None:
self.comm.Bcast(thetaroot, root=0)
assert (thetaroot == thetalocal).all(), (thetaroot, thetalocal)
@U.in_session