Grad clipping in MpiAdamOptimizer, transformer changes (#304)
* transformer mnist experiments
* version that only builds one model
* work on inverted mnist
* Add grad clipping to MpiAdamOptimizer
* various
* transformer changes, loading
* get rid of soft labels
* transformer baseline
* minor
* experiments involving all possible training sets
* vary training
* minor
* get ready for fine-tuning experiments
* lint
* minor
committed by Peter Zhokhov
parent 5082e5d34b
commit 07cbf1e26a
@@ -2,6 +2,7 @@ import numpy as np
 import tensorflow as tf
 from baselines.common import tf_util as U
 from baselines.common.tests.test_with_mpi import with_mpi
+from baselines import logger
 try:
     from mpi4py import MPI
 except ImportError:
@@ -9,8 +10,9 @@ except ImportError:
 
 class MpiAdamOptimizer(tf.train.AdamOptimizer):
     """Adam optimizer that averages gradients across mpi processes."""
-    def __init__(self, comm, mpi_rank_weight=1, **kwargs):
+    def __init__(self, comm, grad_clip=None, mpi_rank_weight=1, **kwargs):
         self.comm = comm
+        self.grad_clip = grad_clip
         self.mpi_rank_weight = mpi_rank_weight
         tf.train.AdamOptimizer.__init__(self, **kwargs)
     def compute_gradients(self, loss, var_list, **kwargs):
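The constructor gains an optional grad_clip argument alongside the existing comm and mpi_rank_weight. Below is a minimal usage sketch, not part of this commit: it assumes TF1 graph mode, that mpi4py and baselines are installed, that the class is importable from baselines.common.mpi_adam_optimizer (assumed import path), and it uses a toy variable and loss purely for illustration.

# Minimal usage sketch (not part of this commit): enabling gradient clipping.
# Assumes TF1 graph mode with mpi4py and baselines installed; the import path,
# the toy variable, and the toy loss below are illustrative assumptions.
import tensorflow as tf
from mpi4py import MPI
from baselines.common.mpi_adam_optimizer import MpiAdamOptimizer

x = tf.Variable([1.0, 2.0], dtype=tf.float32)
loss = tf.reduce_sum(tf.square(x))      # toy loss for illustration

opt = MpiAdamOptimizer(
    comm=MPI.COMM_WORLD,    # gradients are averaged across this communicator
    grad_clip=1.0,          # any non-None value enables clipping in _collect_grads
    learning_rate=3e-4,     # forwarded to tf.train.AdamOptimizer via **kwargs
)
train_op = opt.minimize(loss, var_list=[x])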
@@ -28,6 +30,12 @@ class MpiAdamOptimizer(tf.train.AdamOptimizer):
         countholder = [0] # Counts how many times _collect_grads has been called
         stat = tf.reduce_sum(grads_and_vars[0][1]) # sum of first variable
         def _collect_grads(flat_grad, np_stat):
+            if self.grad_clip is not None:
+                gradnorm = np.linalg.norm(flat_grad)
+                if gradnorm > 1:
+                    flat_grad /= gradnorm
+                logger.logkv_mean('gradnorm', gradnorm)
+                logger.logkv_mean('gradclipfrac', float(gradnorm > 1))
             self.comm.Allreduce(flat_grad, buf, op=MPI.SUM)
             np.divide(buf, float(total_weight), out=buf)
             if countholder[0] % 100 == 0:
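The added block rescales the flattened local gradient to unit L2 norm before the Allreduce whenever its norm exceeds 1, and reports the norm and the clipping frequency through logger.logkv_mean. Note that, as written in the hunk, grad_clip only switches the behaviour on; the threshold itself is the hard-coded value 1. Below is a standalone NumPy illustration of that step; the helper name clip_flat_grad is invented for this sketch.

# Standalone NumPy illustration of the clipping step added above; the helper
# name is hypothetical and only mirrors what _collect_grads now does locally.
import numpy as np

def clip_flat_grad(flat_grad):
    # Rescale the flattened gradient to L2 norm 1 if it is larger.
    gradnorm = np.linalg.norm(flat_grad)
    if gradnorm > 1:
        flat_grad /= gradnorm
    return flat_grad, gradnorm

g = np.array([3.0, 4.0])                # norm 5, so it gets clipped
clipped, norm = clip_flat_grad(g)
print(norm, clipped)                    # 5.0 [0.6 0.8]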
@@ -56,8 +64,8 @@ def check_synced(localval, comm=None):
     comm = comm or MPI.COMM_WORLD
     vals = comm.gather(localval)
     if comm.rank == 0:
-        assert all(val==vals[0] for val in vals[1:])
+        assert all(val==vals[0] for val in vals[1:]),\
+            f'MpiAdamOptimizer detected that different workers have different weights: {vals}'
 
 @with_mpi(timeout=5)
 def test_nonfreeze():
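check_synced gathers each worker's value and, on rank 0, asserts that they all match; the change replaces the bare assert with one whose message includes the gathered values. A usage sketch follows (run under MPI, e.g. with mpirun); the checksum value is a stand-in, and the import path is assumed.

# Usage sketch for check_synced (not part of this commit): verify that all
# workers hold the same weights. Assumes mpi4py and baselines are installed
# and that the function is importable from baselines.common.mpi_adam_optimizer.
from mpi4py import MPI
from baselines.common.mpi_adam_optimizer import check_synced

comm = MPI.COMM_WORLD
local_weight_sum = 1.234   # stand-in for a checksum of this worker's parameters
check_synced(local_weight_sum, comm=comm)
# On rank 0, a mismatch now raises an AssertionError whose message lists the
# values gathered from every worker, instead of a bare, message-less assert.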