mpi-less ppo2 (resolving merge conflict)

This commit is contained in:
Peter Zhokhov
2018-11-08 10:36:36 -08:00
parent 310fbadba3
commit 57c23cddd6

View File

@@ -2,11 +2,15 @@ import tensorflow as tf
import functools
from baselines.common.tf_util import get_session, save_variables, load_variables
from baselines.common.mpi_adam_optimizer import MpiAdamOptimizer
from baselines.common.tf_util import initialize
from baselines.common.mpi_util import sync_from_root
from mpi4py import MPI
try:
from baselines.common.mpi_adam_optimizer import MpiAdamOptimizer
from mpi4py import MPI
from baselines.common.mpi_util import sync_from_root
except ImportError:
MPI = None
class Model(object):
"""
@@ -88,7 +92,10 @@ class Model(object):
# 1. Get the model parameters
params = tf.trainable_variables('ppo2_model')
# 2. Build our trainer
self.trainer = MpiAdamOptimizer(MPI.COMM_WORLD, learning_rate=LR, epsilon=1e-5)
if MPI is not None:
self.trainer = MpiAdamOptimizer(MPI.COMM_WORLD, learning_rate=LR, epsilon=1e-5)
else:
self.trainer = tf.train.AdamOptimizer(learning_rate=LR, epsilon=1e-5)
# 3. Calculate the gradients
grads_and_var = self.trainer.compute_gradients(loss, params)
grads, var = zip(*grads_and_var)
@@ -116,7 +123,7 @@ class Model(object):
self.save = functools.partial(save_variables, sess=sess)
self.load = functools.partial(load_variables, sess=sess)
if MPI.COMM_WORLD.Get_rank() == 0:
if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
initialize()
global_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="")
sync_from_root(sess, global_variables) #pylint: disable=E1101