diff --git a/baselines/common/mpi_adam_optimizer.py b/baselines/common/mpi_adam_optimizer.py
index dcbcd74..3d7cee5 100644
--- a/baselines/common/mpi_adam_optimizer.py
+++ b/baselines/common/mpi_adam_optimizer.py
@@ -2,6 +2,7 @@ import numpy as np
 import tensorflow as tf
 from baselines.common import tf_util as U
 from baselines.common.tests.test_with_mpi import with_mpi
+from baselines import logger
 try:
     from mpi4py import MPI
 except ImportError:
@@ -9,8 +10,9 @@ except ImportError:
 
 class MpiAdamOptimizer(tf.train.AdamOptimizer):
     """Adam optimizer that averages gradients across mpi processes."""
-    def __init__(self, comm, mpi_rank_weight=1, **kwargs):
+    def __init__(self, comm, grad_clip=None, mpi_rank_weight=1, **kwargs):
         self.comm = comm
+        self.grad_clip = grad_clip
         self.mpi_rank_weight = mpi_rank_weight
         tf.train.AdamOptimizer.__init__(self, **kwargs)
     def compute_gradients(self, loss, var_list, **kwargs):
@@ -28,6 +30,12 @@ class MpiAdamOptimizer(tf.train.AdamOptimizer):
         countholder = [0] # Counts how many times _collect_grads has been called
         stat = tf.reduce_sum(grads_and_vars[0][1]) # sum of first variable
         def _collect_grads(flat_grad, np_stat):
+            if self.grad_clip is not None:
+                gradnorm = np.linalg.norm(flat_grad)
+                if gradnorm > 1:
+                    flat_grad /= gradnorm
+                logger.logkv_mean('gradnorm', gradnorm)
+                logger.logkv_mean('gradclipfrac', float(gradnorm > 1))
             self.comm.Allreduce(flat_grad, buf, op=MPI.SUM)
             np.divide(buf, float(total_weight), out=buf)
             if countholder[0] % 100 == 0:
@@ -56,8 +64,8 @@ def check_synced(localval, comm=None):
     comm = comm or MPI.COMM_WORLD
     vals = comm.gather(localval)
     if comm.rank == 0:
-        assert all(val==vals[0] for val in vals[1:])
-
+        assert all(val==vals[0] for val in vals[1:]),\
+            f'MpiAdamOptimizer detected that different workers have different weights: {vals}'
 
 @with_mpi(timeout=5)
 def test_nonfreeze():
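
Minimal usage sketch for the new grad_clip option (illustrative only: the toy loss and variable are invented here; assumes TF 1.x, mpi4py, and optionally a mpirun launch). Note that in this patch the value of grad_clip acts as an enable switch: any non-None value makes each worker rescale its local flattened gradient to unit L2 norm before the allreduce and log 'gradnorm' / 'gradclipfrac'.

import tensorflow as tf
from mpi4py import MPI
from baselines.common.mpi_adam_optimizer import MpiAdamOptimizer

# Toy objective for illustration; any TF-1.x loss over trainable variables works.
var = tf.Variable(tf.random_normal([10]), name='w')
loss = tf.reduce_sum(tf.square(var))

# With grad_clip set, each worker rescales its local flattened gradient to unit
# L2 norm before the MPI allreduce and reports 'gradnorm' / 'gradclipfrac'
# through baselines.logger.
opt = MpiAdamOptimizer(MPI.COMM_WORLD, grad_clip=1.0, learning_rate=3e-4, epsilon=1e-5)
grads_and_vars = opt.compute_gradients(loss, var_list=[var])
train_op = opt.apply_gradients(grads_and_vars)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(10):
        sess.run(train_op)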