add assertion to test in mpi_adam; fix trpo_mpi failure without MPI on cartpole
This commit is contained in:
@@ -26,11 +26,13 @@ class MpiAdam(object):
|
||||
if self.t % 100 == 0:
|
||||
self.check_synced()
|
||||
localg = localg.astype('float32')
|
||||
globalg = np.zeros_like(localg)
|
||||
if self.comm is not None:
|
||||
globalg = np.zeros_like(localg)
|
||||
self.comm.Allreduce(localg, globalg, op=MPI.SUM)
|
||||
if self.scale_grad_by_procs:
|
||||
globalg /= self.comm.Get_size()
|
||||
if self.scale_grad_by_procs:
|
||||
globalg /= self.comm.Get_size()
|
||||
else:
|
||||
globalg = np.copy(localg)
|
||||
|
||||
self.t += 1
|
||||
a = stepsize * np.sqrt(1 - self.beta2**self.t)/(1 - self.beta1**self.t)
|
||||
@@ -40,19 +42,22 @@ class MpiAdam(object):
|
||||
self.setfromflat(self.getflat() + step)
|
||||
|
||||
def sync(self):
|
||||
if self.comm is None:
|
||||
return
|
||||
theta = self.getflat()
|
||||
self.comm.Bcast(theta, root=0)
|
||||
self.setfromflat(theta)
|
||||
|
||||
def check_synced(self):
|
||||
if self.comm is None:
|
||||
return
|
||||
if self.comm.Get_rank() == 0: # this is root
|
||||
theta = self.getflat()
|
||||
self.comm.Bcast(theta, root=0)
|
||||
else:
|
||||
thetalocal = self.getflat()
|
||||
thetaroot = np.empty_like(thetalocal)
|
||||
if self.comm is not None:
|
||||
self.comm.Bcast(thetaroot, root=0)
|
||||
self.comm.Bcast(thetaroot, root=0)
|
||||
assert (thetaroot == thetalocal).all(), (thetaroot, thetalocal)
|
||||
|
||||
@U.in_session
|
||||
@@ -69,17 +74,30 @@ def test_MpiAdam():
|
||||
do_update = U.function([], loss, updates=[update_op])
|
||||
|
||||
tf.get_default_session().run(tf.global_variables_initializer())
|
||||
losslist_ref = []
|
||||
for i in range(10):
|
||||
print(i,do_update())
|
||||
l = do_update()
|
||||
print(i, l)
|
||||
losslist_ref.append(l)
|
||||
|
||||
|
||||
|
||||
tf.set_random_seed(0)
|
||||
tf.get_default_session().run(tf.global_variables_initializer())
|
||||
|
||||
var_list = [a,b]
|
||||
lossandgrad = U.function([], [loss, U.flatgrad(loss, var_list)], updates=[update_op])
|
||||
lossandgrad = U.function([], [loss, U.flatgrad(loss, var_list)])
|
||||
adam = MpiAdam(var_list)
|
||||
|
||||
losslist_test = []
|
||||
for i in range(10):
|
||||
l,g = lossandgrad()
|
||||
adam.update(g, stepsize)
|
||||
print(i,l)
|
||||
losslist_test.append(l)
|
||||
|
||||
np.testing.assert_allclose(np.array(losslist_ref), np.array(losslist_test), atol=1e-4)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_MpiAdam()
|
||||
|
@@ -244,10 +244,13 @@ def learn(*,
|
||||
|
||||
def allmean(x):
|
||||
assert isinstance(x, np.ndarray)
|
||||
out = np.empty_like(x)
|
||||
if MPI is not None:
|
||||
out = np.empty_like(x)
|
||||
MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
|
||||
out /= nworkers
|
||||
else:
|
||||
out = np.copy(x)
|
||||
|
||||
return out
|
||||
|
||||
U.initialize()
|
||||
|
Reference in New Issue
Block a user