add assertion to test in mpi_adam; fix trpo_mpi failure without MPI on cartpole

This commit is contained in:
Peter Zhokhov
2018-10-30 14:45:20 -07:00
parent ef2e7246c9
commit f0d49fb67d
2 changed files with 29 additions and 8 deletions

View File

@@ -26,11 +26,13 @@ class MpiAdam(object):
if self.t % 100 == 0:
self.check_synced()
localg = localg.astype('float32')
globalg = np.zeros_like(localg)
if self.comm is not None:
globalg = np.zeros_like(localg)
self.comm.Allreduce(localg, globalg, op=MPI.SUM)
if self.scale_grad_by_procs:
globalg /= self.comm.Get_size()
if self.scale_grad_by_procs:
globalg /= self.comm.Get_size()
else:
globalg = np.copy(localg)
self.t += 1
a = stepsize * np.sqrt(1 - self.beta2**self.t)/(1 - self.beta1**self.t)
@@ -40,19 +42,22 @@ class MpiAdam(object):
self.setfromflat(self.getflat() + step)
def sync(self):
if self.comm is None:
return
theta = self.getflat()
self.comm.Bcast(theta, root=0)
self.setfromflat(theta)
def check_synced(self):
if self.comm is None:
return
if self.comm.Get_rank() == 0: # this is root
theta = self.getflat()
self.comm.Bcast(theta, root=0)
else:
thetalocal = self.getflat()
thetaroot = np.empty_like(thetalocal)
if self.comm is not None:
self.comm.Bcast(thetaroot, root=0)
self.comm.Bcast(thetaroot, root=0)
assert (thetaroot == thetalocal).all(), (thetaroot, thetalocal)
@U.in_session
@@ -69,17 +74,30 @@ def test_MpiAdam():
do_update = U.function([], loss, updates=[update_op])
tf.get_default_session().run(tf.global_variables_initializer())
losslist_ref = []
for i in range(10):
print(i,do_update())
l = do_update()
print(i, l)
losslist_ref.append(l)
tf.set_random_seed(0)
tf.get_default_session().run(tf.global_variables_initializer())
var_list = [a,b]
lossandgrad = U.function([], [loss, U.flatgrad(loss, var_list)], updates=[update_op])
lossandgrad = U.function([], [loss, U.flatgrad(loss, var_list)])
adam = MpiAdam(var_list)
losslist_test = []
for i in range(10):
l,g = lossandgrad()
adam.update(g, stepsize)
print(i,l)
losslist_test.append(l)
np.testing.assert_allclose(np.array(losslist_ref), np.array(losslist_test), atol=1e-4)
if __name__ == '__main__':
test_MpiAdam()

View File

@@ -244,10 +244,13 @@ def learn(*,
def allmean(x):
assert isinstance(x, np.ndarray)
out = np.empty_like(x)
if MPI is not None:
out = np.empty_like(x)
MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
out /= nworkers
else:
out = np.copy(x)
return out
U.initialize()