MpiAdam becomes regular Adam if Mpi not present
This commit is contained in:
@@ -1,7 +1,11 @@
|
|||||||
from mpi4py import MPI
|
|
||||||
import baselines.common.tf_util as U
|
import baselines.common.tf_util as U
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
try:
|
||||||
|
from mpi4py import MPI
|
||||||
|
except ImportError:
|
||||||
|
MPI = None
|
||||||
|
|
||||||
|
|
||||||
class MpiAdam(object):
|
class MpiAdam(object):
|
||||||
def __init__(self, var_list, *, beta1=0.9, beta2=0.999, epsilon=1e-08, scale_grad_by_procs=True, comm=None):
|
def __init__(self, var_list, *, beta1=0.9, beta2=0.999, epsilon=1e-08, scale_grad_by_procs=True, comm=None):
|
||||||
@@ -16,14 +20,15 @@ class MpiAdam(object):
|
|||||||
self.t = 0
|
self.t = 0
|
||||||
self.setfromflat = U.SetFromFlat(var_list)
|
self.setfromflat = U.SetFromFlat(var_list)
|
||||||
self.getflat = U.GetFlat(var_list)
|
self.getflat = U.GetFlat(var_list)
|
||||||
self.comm = MPI.COMM_WORLD if comm is None else comm
|
self.comm = MPI.COMM_WORLD if comm is None and MPI is not None else comm
|
||||||
|
|
||||||
def update(self, localg, stepsize):
|
def update(self, localg, stepsize):
|
||||||
if self.t % 100 == 0:
|
if self.t % 100 == 0:
|
||||||
self.check_synced()
|
self.check_synced()
|
||||||
localg = localg.astype('float32')
|
localg = localg.astype('float32')
|
||||||
globalg = np.zeros_like(localg)
|
globalg = np.zeros_like(localg)
|
||||||
self.comm.Allreduce(localg, globalg, op=MPI.SUM)
|
if self.comm is not None:
|
||||||
|
self.comm.Allreduce(localg, globalg, op=MPI.SUM)
|
||||||
if self.scale_grad_by_procs:
|
if self.scale_grad_by_procs:
|
||||||
globalg /= self.comm.Get_size()
|
globalg /= self.comm.Get_size()
|
||||||
|
|
||||||
@@ -46,7 +51,8 @@ class MpiAdam(object):
|
|||||||
else:
|
else:
|
||||||
thetalocal = self.getflat()
|
thetalocal = self.getflat()
|
||||||
thetaroot = np.empty_like(thetalocal)
|
thetaroot = np.empty_like(thetalocal)
|
||||||
self.comm.Bcast(thetaroot, root=0)
|
if self.comm is not None:
|
||||||
|
self.comm.Bcast(thetaroot, root=0)
|
||||||
assert (thetaroot == thetalocal).all(), (thetaroot, thetalocal)
|
assert (thetaroot == thetalocal).all(), (thetaroot, thetalocal)
|
||||||
|
|
||||||
@U.in_session
|
@U.in_session
|
||||||
|
Reference in New Issue
Block a user