MpiAdam becomes regular Adam if Mpi not present
This commit is contained in:
@@ -1,7 +1,11 @@
|
||||
from mpi4py import MPI
|
||||
import baselines.common.tf_util as U
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
try:
|
||||
from mpi4py import MPI
|
||||
except ImportError:
|
||||
MPI = None
|
||||
|
||||
|
||||
class MpiAdam(object):
|
||||
def __init__(self, var_list, *, beta1=0.9, beta2=0.999, epsilon=1e-08, scale_grad_by_procs=True, comm=None):
|
||||
@@ -16,14 +20,15 @@ class MpiAdam(object):
|
||||
self.t = 0
|
||||
self.setfromflat = U.SetFromFlat(var_list)
|
||||
self.getflat = U.GetFlat(var_list)
|
||||
self.comm = MPI.COMM_WORLD if comm is None else comm
|
||||
self.comm = MPI.COMM_WORLD if comm is None and MPI is not None else comm
|
||||
|
||||
def update(self, localg, stepsize):
|
||||
if self.t % 100 == 0:
|
||||
self.check_synced()
|
||||
localg = localg.astype('float32')
|
||||
globalg = np.zeros_like(localg)
|
||||
self.comm.Allreduce(localg, globalg, op=MPI.SUM)
|
||||
if self.comm is not None:
|
||||
self.comm.Allreduce(localg, globalg, op=MPI.SUM)
|
||||
if self.scale_grad_by_procs:
|
||||
globalg /= self.comm.Get_size()
|
||||
|
||||
@@ -46,7 +51,8 @@ class MpiAdam(object):
|
||||
else:
|
||||
thetalocal = self.getflat()
|
||||
thetaroot = np.empty_like(thetalocal)
|
||||
self.comm.Bcast(thetaroot, root=0)
|
||||
if self.comm is not None:
|
||||
self.comm.Bcast(thetaroot, root=0)
|
||||
assert (thetaroot == thetalocal).all(), (thetaroot, thetalocal)
|
||||
|
||||
@U.in_session
|
||||
|
Reference in New Issue
Block a user