merged master
@@ -1,6 +1,11 @@
 import numpy as np
 import tensorflow as tf
-from mpi4py import MPI
+from baselines.common import tf_util as U
+from baselines.common.tests.test_with_mpi import with_mpi
+try:
+    from mpi4py import MPI
+except ImportError:
+    MPI = None
 
 class MpiAdamOptimizer(tf.train.AdamOptimizer):
     """Adam optimizer that averages gradients across mpi processes."""
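The try/except around the mpi4py import makes MPI optional. As a minimal standalone sketch (not part of this commit; the make_optimizer helper and the baselines.common.mpi_adam_optimizer module path are assumptions for illustration), calling code can fall back to a plain Adam optimizer when mpi4py is missing or only one process is running:

    # Hypothetical illustration, not part of the diff: choose an optimizer
    # depending on whether mpi4py imported successfully.
    import tensorflow as tf
    try:
        from mpi4py import MPI
    except ImportError:
        MPI = None

    def make_optimizer(learning_rate=1e-3):
        if MPI is not None and MPI.COMM_WORLD.Get_size() > 1:
            # assumed module path for the optimizer defined in this file
            from baselines.common.mpi_adam_optimizer import MpiAdamOptimizer
            return MpiAdamOptimizer(comm=MPI.COMM_WORLD, learning_rate=learning_rate)
        return tf.train.AdamOptimizer(learning_rate=learning_rate)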
@@ -13,34 +18,60 @@ class MpiAdamOptimizer(tf.train.AdamOptimizer):
         flat_grad = tf.concat([tf.reshape(g, (-1,)) for g, v in grads_and_vars], axis=0)
         shapes = [v.shape.as_list() for g, v in grads_and_vars]
         sizes = [int(np.prod(s)) for s in shapes]
 
         num_tasks = self.comm.Get_size()
         buf = np.zeros(sum(sizes), np.float32)
 
-        sess = tf.get_default_session()
-        assert sess is not None
         countholder = [0] # Counts how many times _collect_grads has been called
         stat = tf.reduce_sum(grads_and_vars[0][1]) # sum of first variable
-        def _collect_grads(flat_grad):
+        def _collect_grads(flat_grad, np_stat):
             self.comm.Allreduce(flat_grad, buf, op=MPI.SUM)
             np.divide(buf, float(num_tasks), out=buf)
             if countholder[0] % 100 == 0:
-                check_synced(sess, self.comm, stat)
+                check_synced(np_stat, self.comm)
             countholder[0] += 1
             return buf
 
-        avg_flat_grad = tf.py_func(_collect_grads, [flat_grad], tf.float32)
+        avg_flat_grad = tf.py_func(_collect_grads, [flat_grad, stat], tf.float32)
         avg_flat_grad.set_shape(flat_grad.shape)
         avg_grads = tf.split(avg_flat_grad, sizes, axis=0)
         avg_grads_and_vars = [(tf.reshape(g, v.shape), v)
                     for g, (_, v) in zip(avg_grads, grads_and_vars)]
         return avg_grads_and_vars
 
-def check_synced(sess, comm, tfstat):
+def check_synced(localval, comm=None):
     """
-    Check that 'tfstat' evaluates to the same thing on every MPI worker
+    It's common to forget to initialize your variables to the same values, or
+    (less commonly) if you update them in some other way than adam, to get them out of sync.
+    This function checks that variables on all MPI workers are the same, and raises
+    an AssertionError otherwise
+
+    Arguments:
+        comm: MPI communicator
+        localval: list of local variables (list of variables on current worker to be compared with the other workers)
     """
-    localval = sess.run(tfstat)
+    comm = comm or MPI.COMM_WORLD
     vals = comm.gather(localval)
     if comm.rank == 0:
         assert all(val==vals[0] for val in vals[1:])
+
+
+@with_mpi(timeout=5)
+def test_nonfreeze():
+    np.random.seed(0)
+    tf.set_random_seed(0)
+
+    a = tf.Variable(np.random.randn(3).astype('float32'))
+    b = tf.Variable(np.random.randn(2,5).astype('float32'))
+    loss = tf.reduce_sum(tf.square(a)) + tf.reduce_sum(tf.sin(b))
+
+    stepsize = 1e-2
+    # for some reason the session config with inter_op_parallelism_threads was causing
+    # nested sess.run calls to freeze
+    config = tf.ConfigProto(inter_op_parallelism_threads=1)
+    sess = U.get_session(config=config)
+    update_op = MpiAdamOptimizer(comm=MPI.COMM_WORLD, learning_rate=stepsize).minimize(loss)
+    sess.run(tf.global_variables_initializer())
+    losslist_ref = []
+    for i in range(100):
+        l,_ = sess.run([loss, update_op])
+        print(i, l)
+        losslist_ref.append(l)
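check_synced now receives the already-evaluated local value (passed through tf.py_func as np_stat) instead of a session and a tensor, which avoids the nested sess.run call mentioned in the test comment. A standalone sketch of the same gather-and-compare pattern (not from this commit; the script name and sample values are hypothetical), launched with something like mpirun -np 2 python sync_check_demo.py:

    # sync_check_demo.py - hypothetical standalone example of the pattern above
    import numpy as np
    from mpi4py import MPI

    def check_synced(localval, comm=None):
        # gather every worker's value on rank 0 and require them all to match
        comm = comm or MPI.COMM_WORLD
        vals = comm.gather(localval)
        if comm.rank == 0:
            assert all(val == vals[0] for val in vals[1:])

    if __name__ == '__main__':
        check_synced(float(np.sum(np.arange(10))))  # identical on every rank: passes
        # check_synced(MPI.COMM_WORLD.rank)         # rank-dependent: raises on rank 0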
@@ -160,7 +160,6 @@ def learn(*, network, env, total_timesteps, eval_env = None, seed=None, nsteps=2
             envsperbatch = nenvs // nminibatches
             envinds = np.arange(nenvs)
             flatinds = np.arange(nenvs * nsteps).reshape(nenvs, nsteps)
-            envsperbatch = nbatch_train // nsteps
             for _ in range(noptepochs):
                 np.random.shuffle(envinds)
                 for start in range(0, nenvs, envsperbatch):
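For context, a small self-contained illustration (hypothetical values, not part of the diff) of how envsperbatch, envinds, and flatinds partition environments into minibatches in the loop above; the mbenvinds slicing is shown only for illustration:

    # Hypothetical illustration of the recurrent minibatch indexing
    import numpy as np
    nenvs, nminibatches, nsteps = 8, 4, 3
    envsperbatch = nenvs // nminibatches          # 2 environments per minibatch
    envinds = np.arange(nenvs)
    flatinds = np.arange(nenvs * nsteps).reshape(nenvs, nsteps)
    np.random.shuffle(envinds)
    for start in range(0, nenvs, envsperbatch):
        mbenvinds = envinds[start:start + envsperbatch]
        # rows of flatinds give the flattened timestep indices for those envs
        print(mbenvinds, flatinds[mbenvinds].ravel())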