From 57c23cddd66e5007b844b52bd5767a74bba13a30 Mon Sep 17 00:00:00 2001
From: Peter Zhokhov
Date: Thu, 8 Nov 2018 10:36:36 -0800
Subject: [PATCH] mpi-less ppo2 (resolving merge conflict)

---
 baselines/ppo2/model.py | 19 +++++++++++++------
 1 file changed, 13 insertions(+), 6 deletions(-)

diff --git a/baselines/ppo2/model.py b/baselines/ppo2/model.py
index 105a980..f4e5bb8 100644
--- a/baselines/ppo2/model.py
+++ b/baselines/ppo2/model.py
@@ -2,11 +2,14 @@ import tensorflow as tf
 import functools
 
 from baselines.common.tf_util import get_session, save_variables, load_variables
-from baselines.common.mpi_adam_optimizer import MpiAdamOptimizer
 
 from baselines.common.tf_util import initialize
-from baselines.common.mpi_util import sync_from_root
-from mpi4py import MPI
+try:
+    from baselines.common.mpi_adam_optimizer import MpiAdamOptimizer
+    from mpi4py import MPI
+    from baselines.common.mpi_util import sync_from_root
+except ImportError:
+    MPI = None
 
 class Model(object):
     """
@@ -88,7 +91,10 @@ class Model(object):
         # 1. Get the model parameters
         params = tf.trainable_variables('ppo2_model')
         # 2. Build our trainer
-        self.trainer = MpiAdamOptimizer(MPI.COMM_WORLD, learning_rate=LR, epsilon=1e-5)
+        if MPI is not None:
+            self.trainer = MpiAdamOptimizer(MPI.COMM_WORLD, learning_rate=LR, epsilon=1e-5)
+        else:
+            self.trainer = tf.train.AdamOptimizer(learning_rate=LR, epsilon=1e-5)
         # 3. Calculate the gradients
         grads_and_var = self.trainer.compute_gradients(loss, params)
         grads, var = zip(*grads_and_var)
@@ -116,7 +122,8 @@
         self.save = functools.partial(save_variables, sess=sess)
         self.load = functools.partial(load_variables, sess=sess)
 
-        if MPI.COMM_WORLD.Get_rank() == 0:
+        if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
             initialize()
         global_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="")
-        sync_from_root(sess, global_variables) #pylint: disable=E1101
+        if MPI is not None:
+            sync_from_root(sess, global_variables) #pylint: disable=E1101
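
Note: the change follows the standard optional-dependency pattern: attempt the MPI-related imports once at module load, record failure in a sentinel (MPI = None), and branch on that sentinel at every MPI-specific call site. Below is a minimal self-contained sketch of the pattern, independent of baselines and TensorFlow; the mpi_mean helper is hypothetical, for illustration only, though averaging via allreduce mirrors what MpiAdamOptimizer does to gradients.

# Optional-MPI sketch: runs under `mpirun -np 4 python demo.py` when mpi4py
# is installed, and degrades to single-process behavior when it is not.
try:
    from mpi4py import MPI
except ImportError:
    MPI = None  # sentinel checked at every MPI-specific call site

def mpi_mean(x):
    """Average a scalar across workers; identity when MPI is unavailable."""
    if MPI is None:
        return x
    comm = MPI.COMM_WORLD
    return comm.allreduce(x, op=MPI.SUM) / comm.Get_size()

if __name__ == '__main__':
    rank = 0 if MPI is None else MPI.COMM_WORLD.Get_rank()
    # Each worker contributes its rank; the mean comes out the same everywhere.
    print(rank, mpi_mean(float(rank)))

Branching at the call sites, rather than stubbing out the imports, keeps a single code path that works under both mpirun and a plain python invocation.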
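The last hunk preserves the distributed initialization protocol: only rank 0 runs initialize(), and sync_from_root then broadcasts the initialized variables so every worker starts from identical parameters. A hedged sketch of what that broadcast accomplishes, assuming mpi4py is available; the sync_params_from_root helper is hypothetical, and baselines' real sync_from_root operates on TensorFlow variables rather than plain Python objects.

# Root-sync sketch: rank 0's parameters are broadcast to all other ranks.
# Run with e.g. `mpirun -np 4 python demo.py`.
from mpi4py import MPI

def sync_params_from_root(params):
    """Return rank 0's params on every rank (hypothetical helper)."""
    return MPI.COMM_WORLD.bcast(params, root=0)

if __name__ == '__main__':
    comm = MPI.COMM_WORLD
    params = {'w': 0.1 * comm.Get_rank()}  # differs per rank before the sync
    params = sync_params_from_root(params)
    print(comm.Get_rank(), params)  # every rank prints {'w': 0.0}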