Grad clipping in MpiAdamOptimizer, transformer changes (#304)
* transformer mnist experiments
* version that only builds one model
* work on inverted mnist
* Add grad clipping to MpiAdamOptimizer
* various
* transformer changes, loading
* get rid of soft labels
* transformer baseline
* minor
* experiments involving all possible training sets
* vary training
* minor
* get ready for fine-tuning expers
* lint
* minor
committed by Peter Zhokhov · parent 5082e5d34b · commit 07cbf1e26a
baselines/common/mpi_adam_optimizer.py
@@ -2,6 +2,7 @@ import numpy as np
 import tensorflow as tf
 from baselines.common import tf_util as U
 from baselines.common.tests.test_with_mpi import with_mpi
+from baselines import logger
 try:
     from mpi4py import MPI
 except ImportError:
@@ -9,8 +10,9 @@ except ImportError:
 
 class MpiAdamOptimizer(tf.train.AdamOptimizer):
     """Adam optimizer that averages gradients across mpi processes."""
-    def __init__(self, comm, mpi_rank_weight=1, **kwargs):
+    def __init__(self, comm, grad_clip=None, mpi_rank_weight=1, **kwargs):
         self.comm = comm
+        self.grad_clip = grad_clip
         self.mpi_rank_weight = mpi_rank_weight
         tf.train.AdamOptimizer.__init__(self, **kwargs)
     def compute_gradients(self, loss, var_list, **kwargs):
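For context, a minimal sketch of how the new grad_clip argument could be passed when building the optimizer. The loss tensor, parameter list, and hyperparameter values below are assumptions for illustration, not part of this commit; the diff only checks grad_clip against None, so any non-None value turns the norm clipping on.

    # Illustrative usage sketch (loss, params, and hyperparameters are assumptions).
    from mpi4py import MPI
    comm = MPI.COMM_WORLD
    trainer = MpiAdamOptimizer(comm, grad_clip=True, learning_rate=3e-4, epsilon=1e-5)
    grads_and_vars = trainer.compute_gradients(loss, var_list=params)  # averages grads over workers
    train_op = trainer.apply_gradients(grads_and_vars)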
@@ -28,6 +30,12 @@ class MpiAdamOptimizer(tf.train.AdamOptimizer):
         countholder = [0] # Counts how many times _collect_grads has been called
         stat = tf.reduce_sum(grads_and_vars[0][1]) # sum of first variable
         def _collect_grads(flat_grad, np_stat):
+            if self.grad_clip is not None:
+                gradnorm = np.linalg.norm(flat_grad)
+                if gradnorm > 1:
+                    flat_grad /= gradnorm
+                logger.logkv_mean('gradnorm', gradnorm)
+                logger.logkv_mean('gradclipfrac', float(gradnorm > 1))
             self.comm.Allreduce(flat_grad, buf, op=MPI.SUM)
             np.divide(buf, float(total_weight), out=buf)
             if countholder[0] % 100 == 0:
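The added lines clip each worker's flattened local gradient to unit L2 norm before it is averaged across workers with Allreduce, and log the observed norm plus how often clipping fired. A standalone numpy sketch of that clip-by-norm step, independent of MPI and TensorFlow (the helper name is an assumption):

    import numpy as np

    def clip_flat_grad(flat_grad):
        # Mirrors the added logic: rescale the flattened gradient to unit L2 norm
        # whenever its norm exceeds 1, and report whether clipping fired.
        gradnorm = np.linalg.norm(flat_grad)
        if gradnorm > 1:
            flat_grad = flat_grad / gradnorm
        return flat_grad, gradnorm, float(gradnorm > 1)

    g, norm, clipfrac = clip_flat_grad(np.array([3.0, 4.0]))
    # norm == 5.0, clipfrac == 1.0, g == [0.6, 0.8]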
@@ -56,8 +64,8 @@ def check_synced(localval, comm=None):
     comm = comm or MPI.COMM_WORLD
     vals = comm.gather(localval)
     if comm.rank == 0:
-        assert all(val==vals[0] for val in vals[1:])
+        assert all(val==vals[0] for val in vals[1:]),\
+            f'MpiAdamOptimizer detected that different workers have different weights: {vals}'
 
 @with_mpi(timeout=5)
 def test_nonfreeze():
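The strengthened assertion now reports every rank's gathered value when the check fails, which makes weight-desync bugs easier to diagnose. As an illustration of how check_synced is meant to be called (the variable names below are assumptions; the optimizer itself periodically spot-checks a scalar summary of the first variable):

    # Illustrative call (variable names are assumptions): compare a scalar
    # summary of the local parameters across all MPI workers.
    local_stat = float(first_variable_values.sum())  # e.g. sum of one weight tensor
    check_synced(local_stat, comm=MPI.COMM_WORLD)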