# baselines/baselines/common/mpi_adam_optimizer.py
import numpy as np
import tensorflow as tf
try:
    from mpi4py import MPI
except ImportError:
    MPI = None


class MpiAdamOptimizer(tf.Module):
    """Adam optimizer that averages gradients across MPI processes."""
    def __init__(self, comm, var_list):
        self.var_list = var_list
        self.comm = comm
        # Standard Adam hyperparameters.
        self.beta1 = 0.9
        self.beta2 = 0.999
        self.epsilon = 1e-08
        self.t = tf.Variable(0, name='step', dtype=tf.int32)
        # The first and second moment estimates are stored as single flat
        # vectors covering all variables, concatenated in var_list order.
        var_shapes = [v.shape.as_list() for v in var_list]
        self.var_sizes = [int(np.prod(s)) for s in var_shapes]
        self.flat_var_size = sum(self.var_sizes)
        self.m = tf.Variable(np.zeros(self.flat_var_size, 'float32'))
        self.v = tf.Variable(np.zeros(self.flat_var_size, 'float32'))

    def apply_gradients(self, flat_grad, lr):
        buf = np.zeros(self.flat_var_size, np.float32)
        # Sum the flat gradients from every worker, then divide by the number
        # of workers to obtain the average gradient.
        self.comm.Allreduce(flat_grad.numpy(), buf, op=MPI.SUM)
        avg_flat_grad = np.divide(buf, float(self.comm.Get_size()))
        self._apply_gradients(tf.constant(avg_flat_grad), lr)
        # Periodically verify that the workers' parameters have not drifted apart.
        if self.t.numpy() % 100 == 0:
            check_synced(tf.reduce_sum(self.var_list[0]).numpy())

    @tf.function
    def _apply_gradients(self, avg_flat_grad, lr):
        self.t.assign_add(1)
        t = tf.cast(self.t, tf.float32)
        # Bias-corrected step size: a = lr * sqrt(1 - beta2^t) / (1 - beta1^t).
        a = lr * tf.math.sqrt(1 - tf.math.pow(self.beta2, t)) / (1 - tf.math.pow(self.beta1, t))
        # Update the flat first and second moment estimates.
        self.m.assign(self.beta1 * self.m + (1 - self.beta1) * avg_flat_grad)
        self.v.assign(self.beta2 * self.v + (1 - self.beta2) * tf.math.square(avg_flat_grad))
        flat_step = (- a) * self.m / (tf.math.sqrt(self.v) + self.epsilon)
        # Split the flat step back into per-variable pieces and apply them.
        var_steps = tf.split(flat_step, self.var_sizes, axis=0)
        for var_step, var in zip(var_steps, self.var_list):
            var.assign_add(tf.reshape(var_step, var.shape))
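

# Illustrative usage sketch (not part of the original baselines module): one
# way a training loop might drive MpiAdamOptimizer. `model`, `loss_fn` and
# `batch` are assumed, caller-supplied names; the essential point is that the
# per-variable gradients are flattened and concatenated, in var_list order,
# into the single flat vector that apply_gradients expects.
def example_train_step(model, loss_fn, batch, optimizer, lr):
    with tf.GradientTape() as tape:
        loss = loss_fn(model, batch)
    grads = tape.gradient(loss, model.trainable_variables)
    # Flatten each gradient and concatenate, matching the layout of the
    # optimizer's flat m/v buffers.
    flat_grad = tf.concat([tf.reshape(g, [-1]) for g in grads], axis=0)
    optimizer.apply_gradients(flat_grad, lr)
    return loss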


def check_synced(localval, comm=None):
    """
    It's common to forget to initialize your variables to the same values, or
    (less commonly) to get them out of sync if you update them in some way other than Adam.
    This function checks that the value passed in is identical on all MPI workers and raises
    an AssertionError otherwise.

    Arguments:
        localval: local value on the current worker (e.g. a sum of parameters), to be
            compared with the corresponding value on the other workers
        comm: MPI communicator (defaults to MPI.COMM_WORLD)
    """
    comm = comm or MPI.COMM_WORLD
    vals = comm.gather(localval)
    if comm.rank == 0:
        assert all(val == vals[0] for val in vals[1:]), \
            'MpiAdamOptimizer detected that different workers have different weights: {}'.format(vals)
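

# Illustrative sanity check (not part of the original baselines module): run
# under MPI, e.g. `mpirun -np 2 python mpi_adam_optimizer.py`, to exercise
# check_synced on a toy value that is identical on every worker by construction.
if __name__ == '__main__' and MPI is not None:
    comm = MPI.COMM_WORLD
    var = tf.Variable(tf.ones(4))  # same initial value on every worker
    check_synced(tf.reduce_sum(var).numpy(), comm)
    if comm.rank == 0:
        print('check_synced passed on {} workers'.format(comm.Get_size()))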