From f0d49fb67d9374498441fb41bd7ea4048ddb9090 Mon Sep 17 00:00:00 2001
From: Peter Zhokhov
Date: Tue, 30 Oct 2018 14:45:20 -0700
Subject: [PATCH] add assertion to test in mpi_adam; fix trpo_mpi failure
 without MPI on cartpole

---
 baselines/common/mpi_adam.py   | 32 +++++++++++++++++++++++++-------
 baselines/trpo_mpi/trpo_mpi.py |  5 ++++-
 2 files changed, 29 insertions(+), 8 deletions(-)

diff --git a/baselines/common/mpi_adam.py b/baselines/common/mpi_adam.py
index 9168e5b..10f1195 100644
--- a/baselines/common/mpi_adam.py
+++ b/baselines/common/mpi_adam.py
@@ -26,11 +26,13 @@ class MpiAdam(object):
         if self.t % 100 == 0:
             self.check_synced()
         localg = localg.astype('float32')
-        globalg = np.zeros_like(localg)
         if self.comm is not None:
+            globalg = np.zeros_like(localg)
             self.comm.Allreduce(localg, globalg, op=MPI.SUM)
-        if self.scale_grad_by_procs:
-            globalg /= self.comm.Get_size()
+            if self.scale_grad_by_procs:
+                globalg /= self.comm.Get_size()
+        else:
+            globalg = np.copy(localg)
 
         self.t += 1
         a = stepsize * np.sqrt(1 - self.beta2**self.t)/(1 - self.beta1**self.t)
@@ -40,19 +42,22 @@ class MpiAdam(object):
         self.setfromflat(self.getflat() + step)
 
     def sync(self):
+        if self.comm is None:
+            return
         theta = self.getflat()
         self.comm.Bcast(theta, root=0)
         self.setfromflat(theta)
 
     def check_synced(self):
+        if self.comm is None:
+            return
         if self.comm.Get_rank() == 0: # this is root
             theta = self.getflat()
             self.comm.Bcast(theta, root=0)
         else:
             thetalocal = self.getflat()
             thetaroot = np.empty_like(thetalocal)
-            if self.comm is not None:
-                self.comm.Bcast(thetaroot, root=0)
+            self.comm.Bcast(thetaroot, root=0)
             assert (thetaroot == thetalocal).all(), (thetaroot, thetalocal)
 
 @U.in_session
@@ -69,17 +74,30 @@ def test_MpiAdam():
     do_update = U.function([], loss, updates=[update_op])
 
     tf.get_default_session().run(tf.global_variables_initializer())
+
+    losslist_ref = []
     for i in range(10):
-        print(i,do_update())
+        l = do_update()
+        print(i, l)
+        losslist_ref.append(l)
+
 
     tf.set_random_seed(0)
     tf.get_default_session().run(tf.global_variables_initializer())
 
     var_list = [a,b]
-    lossandgrad = U.function([], [loss, U.flatgrad(loss, var_list)], updates=[update_op])
+    lossandgrad = U.function([], [loss, U.flatgrad(loss, var_list)])
     adam = MpiAdam(var_list)
 
+    losslist_test = []
     for i in range(10):
         l,g = lossandgrad()
         adam.update(g, stepsize)
         print(i,l)
+        losslist_test.append(l)
+
+    np.testing.assert_allclose(np.array(losslist_ref), np.array(losslist_test), atol=1e-4)
+
+
+if __name__ == '__main__':
+    test_MpiAdam()
diff --git a/baselines/trpo_mpi/trpo_mpi.py b/baselines/trpo_mpi/trpo_mpi.py
index 17d9e5c..cd1e7ea 100644
--- a/baselines/trpo_mpi/trpo_mpi.py
+++ b/baselines/trpo_mpi/trpo_mpi.py
@@ -244,10 +244,13 @@ def learn(*,
 
     def allmean(x):
         assert isinstance(x, np.ndarray)
-        out = np.empty_like(x)
         if MPI is not None:
+            out = np.empty_like(x)
             MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
             out /= nworkers
+        else:
+            out = np.copy(x)
+
         return out
 
     U.initialize()
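
For illustration (not part of the patch), a minimal, self-contained sketch of the no-MPI fallback that the mpi_adam.py change introduces in MpiAdam.update: when no communicator is available, the local gradient is simply copied instead of being all-reduced and averaged. The standalone helper name average_gradient is an assumption made for this sketch; in the patch the same logic lives inside MpiAdam.update.

import numpy as np

try:
    from mpi4py import MPI
except ImportError:
    MPI = None


def average_gradient(localg, comm=None, scale_grad_by_procs=True):
    # Mirrors the patched MpiAdam.update control flow: all-reduce and
    # (optionally) average the gradient when a communicator is present,
    # otherwise fall back to a plain copy of the local gradient.
    localg = localg.astype('float32')
    if comm is not None:
        globalg = np.zeros_like(localg)
        comm.Allreduce(localg, globalg, op=MPI.SUM)
        if scale_grad_by_procs:
            globalg /= comm.Get_size()
    else:
        # No MPI: single-process case, use the local gradient as-is.
        globalg = np.copy(localg)
    return globalg


if __name__ == '__main__':
    grad = np.arange(4, dtype='float32')
    comm = MPI.COMM_WORLD if MPI is not None else None
    print(average_gradient(grad, comm))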
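
The allmean change in baselines/trpo_mpi/trpo_mpi.py can be read the same way. Before the patch, the no-MPI path returned the result of np.empty_like(x) without ever filling it, i.e. an uninitialized buffer, which is the likely cause of the cartpole failure mentioned in the subject. Below is a hedged, standalone sketch of the patched behavior; here nworkers is passed explicitly rather than captured from the enclosing learn() closure as in the real code.

import numpy as np

try:
    from mpi4py import MPI
except ImportError:
    MPI = None


def allmean(x, nworkers=1):
    # Average x across MPI workers when MPI is importable; otherwise
    # return an unchanged copy instead of an uninitialized buffer.
    assert isinstance(x, np.ndarray)
    if MPI is not None:
        out = np.empty_like(x)
        MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
        out /= nworkers
    else:
        out = np.copy(x)
    return out


if __name__ == '__main__':
    n = MPI.COMM_WORLD.Get_size() if MPI is not None else 1
    print(allmean(np.ones(3), nworkers=n))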