mpi-less ppo2 (resolving merge conflict)
This commit is contained in:
@@ -2,11 +2,15 @@ import tensorflow as tf
|
|||||||
import functools
|
import functools
|
||||||
|
|
||||||
from baselines.common.tf_util import get_session, save_variables, load_variables
|
from baselines.common.tf_util import get_session, save_variables, load_variables
|
||||||
from baselines.common.mpi_adam_optimizer import MpiAdamOptimizer
|
|
||||||
from baselines.common.tf_util import initialize
|
from baselines.common.tf_util import initialize
|
||||||
from baselines.common.mpi_util import sync_from_root
|
|
||||||
|
|
||||||
from mpi4py import MPI
|
from mpi4py import MPI
|
||||||
|
try:
|
||||||
|
from baselines.common.mpi_adam_optimizer import MpiAdamOptimizer
|
||||||
|
from mpi4py import MPI
|
||||||
|
from baselines.common.mpi_util import sync_from_root
|
||||||
|
except ImportError:
|
||||||
|
MPI = None
|
||||||
|
|
||||||
class Model(object):
|
class Model(object):
|
||||||
"""
|
"""
|
||||||
@@ -88,7 +92,10 @@ class Model(object):
|
|||||||
# 1. Get the model parameters
|
# 1. Get the model parameters
|
||||||
params = tf.trainable_variables('ppo2_model')
|
params = tf.trainable_variables('ppo2_model')
|
||||||
# 2. Build our trainer
|
# 2. Build our trainer
|
||||||
self.trainer = MpiAdamOptimizer(MPI.COMM_WORLD, learning_rate=LR, epsilon=1e-5)
|
if MPI is not None:
|
||||||
|
self.trainer = MpiAdamOptimizer(MPI.COMM_WORLD, learning_rate=LR, epsilon=1e-5)
|
||||||
|
else:
|
||||||
|
self.trainer = tf.train.AdamOptimizer(learning_rate=LR, epsilon=1e-5)
|
||||||
# 3. Calculate the gradients
|
# 3. Calculate the gradients
|
||||||
grads_and_var = self.trainer.compute_gradients(loss, params)
|
grads_and_var = self.trainer.compute_gradients(loss, params)
|
||||||
grads, var = zip(*grads_and_var)
|
grads, var = zip(*grads_and_var)
|
||||||
@@ -116,7 +123,7 @@ class Model(object):
|
|||||||
self.save = functools.partial(save_variables, sess=sess)
|
self.save = functools.partial(save_variables, sess=sess)
|
||||||
self.load = functools.partial(load_variables, sess=sess)
|
self.load = functools.partial(load_variables, sess=sess)
|
||||||
|
|
||||||
if MPI.COMM_WORLD.Get_rank() == 0:
|
if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
|
||||||
initialize()
|
initialize()
|
||||||
global_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="")
|
global_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="")
|
||||||
sync_from_root(sess, global_variables) #pylint: disable=E1101
|
sync_from_root(sess, global_variables) #pylint: disable=E1101
|
||||||
|
Reference in New Issue
Block a user