Compare commits: peterz_lea...peterz_mpi

21 commits:

- 58801032fc
- b4a149a75f
- c248bf9a46
- d1f7d12743
- f0d49fb67d
- ef2e7246c9
- 3e3e2b7998
- d00f3bce34
- 72aa2f1251
- ea7a52b652
- 064c45fa76
- 6f148fdb0d
- 93c7cc202c
- de36116e3b
- e2b41828af
- 8e56ddeac2
- c3bd8cea66
- 84ea7aa1fd
- 88300ed54c
- 583ba082a2
- d96e20ff27
```diff
@@ -5,9 +5,13 @@ python:
 services:
 - docker
 
+env:
+- DOCKER_SUFFIX=py36-nompi
+- DOCKER_SUFFIX=py36-mpi
+
 install:
 - pip install flake8
-- docker build . -t baselines-test
+- docker build -f test.dockerfile.${DOCKER_SUFFIX} -t baselines-test .
 
 script:
 - flake8 . --show-source --statistics
```
Dockerfile (deleted, 25 lines):

```diff
@@ -1,25 +0,0 @@
-FROM ubuntu:16.04
-
-RUN apt-get -y update && apt-get -y install git wget python-dev python3-dev libopenmpi-dev python-pip zlib1g-dev cmake python-opencv
-ENV CODE_DIR /root/code
-ENV VENV /root/venv
-
-RUN \
-    pip install virtualenv && \
-    virtualenv $VENV --python=python3 && \
-    . $VENV/bin/activate && \
-    pip install --upgrade pip
-
-ENV PATH=$VENV/bin:$PATH
-
-COPY . $CODE_DIR/baselines
-WORKDIR $CODE_DIR/baselines
-
-# Clean up pycache and pyc files
-RUN rm -rf __pycache__ && \
-    find . -name "*.pyc" -delete && \
-    pip install tensorflow && \
-    pip install -e .[test]
-
-
-CMD /bin/bash
```
```diff
@@ -21,16 +21,16 @@ class Model(object):
 
         self.sess = sess = get_session()
         nbatch = nenvs * nsteps
-        A = tf.placeholder(ac_space.dtype, [nbatch,] + list(ac_space.shape))
+        with tf.variable_scope('acktr_model', reuse=tf.AUTO_REUSE):
+            self.model = step_model = policy(nenvs, 1, sess=sess)
+            self.model2 = train_model = policy(nenvs*nsteps, nsteps, sess=sess)
+
+        A = train_model.pdtype.sample_placeholder([None])
         ADV = tf.placeholder(tf.float32, [nbatch])
         R = tf.placeholder(tf.float32, [nbatch])
         PG_LR = tf.placeholder(tf.float32, [])
         VF_LR = tf.placeholder(tf.float32, [])
 
-        with tf.variable_scope('acktr_model', reuse=tf.AUTO_REUSE):
-            self.model = step_model = policy(nenvs, 1, sess=sess)
-            self.model2 = train_model = policy(nenvs*nsteps, nsteps, sess=sess)
-
         neglogpac = train_model.pd.neglogp(A)
         self.logits = train_model.pi
 
```
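The key change here is that the acktr `Model` now builds the policy first and takes the action placeholder from the policy's `pdtype` instead of hard-coding it from `ac_space`, so the same graph can serve action spaces whose samples are not plain Box vectors. A small check of what `sample_placeholder` produces, assuming TF1 graph mode and that `make_pdtype` maps `MultiDiscrete` to `MultiCategoricalPdType`; the space below is illustrative, not taken from the diff:

```python
# Hedged sketch (not part of the diff): the placeholder matches whatever
# action space the policy was built for, which the rewritten Model relies on.
import tensorflow as tf
from gym.spaces import MultiDiscrete
from baselines.common.distributions import make_pdtype

pdtype = make_pdtype(MultiDiscrete([3, 3]))  # illustrative 2-dimensional discrete space
A = pdtype.sample_placeholder([None])        # integer placeholder, shape (?, 2)
print(A.shape, A.dtype)
```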
```diff
@@ -43,7 +43,7 @@ def make_vec_env(env_id, env_type, num_env, seed, wrapper_kwargs=None, start_ind
         return DummyVecEnv([make_thunk(start_index)])
 
 
-def make_env(env_id, env_type, subrank=0, seed=None, reward_scale=1.0, gamestate=None, wrapper_kwargs=None):
+def make_env(env_id, env_type, subrank=0, seed=None, reward_scale=1.0, gamestate=None, wrapper_kwargs={}):
     mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
     if env_type == 'atari':
         env = make_atari(env_id)
```
```diff
@@ -39,7 +39,7 @@ class PdType(object):
         raise NotImplementedError
     def pdfromflat(self, flat):
         return self.pdclass()(flat)
-    def pdfromlatent(self, latent_vector):
+    def pdfromlatent(self, latent_vector, init_scale, init_bias):
         raise NotImplementedError
     def param_shape(self):
         raise NotImplementedError
@@ -80,6 +80,11 @@ class MultiCategoricalPdType(PdType):
         return MultiCategoricalPd
     def pdfromflat(self, flat):
         return MultiCategoricalPd(self.ncats, flat)
+
+    def pdfromlatent(self, latent, init_scale=1.0, init_bias=0.0):
+        pdparam = fc(latent, 'pi', self.ncats.sum(), init_scale=init_scale, init_bias=init_bias)
+        return self.pdfromflat(pdparam), pdparam
+
     def param_shape(self):
         return [sum(self.ncats)]
     def sample_shape(self):
```
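For readers unfamiliar with the flat parameter layout: the new `pdfromlatent` emits a single vector of `ncats.sum()` logits, which `MultiCategoricalPd` interprets as one categorical distribution per action dimension. A NumPy-only sketch of that interpretation (the `nvec` values and logits below are made up; this is not baselines code):

```python
# Split a flat logit vector into per-dimension softmaxes for nvec = [3, 3].
import numpy as np

ncats = np.array([3, 3])             # hypothetical MultiDiscrete nvec
flat = np.random.randn(ncats.sum())  # stand-in for the fc(latent, 'pi', ...) output

splits = np.split(flat, np.cumsum(ncats)[:-1])
probs = [np.exp(s - s.max()) / np.exp(s - s.max()).sum() for s in splits]  # softmax per dimension
action = np.array([np.argmax(p) for p in probs])  # greedy sub-action per categorical
print(action, action.shape)  # e.g. [2 0] (2,)
```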
```diff
@@ -1,5 +1,6 @@
+import numpy as np
 import tensorflow as tf
-from gym.spaces import Discrete, Box
+from gym.spaces import Discrete, Box, MultiDiscrete
 
 def observation_placeholder(ob_space, batch_size=None, name='Ob'):
     '''
@@ -20,10 +21,14 @@ def observation_placeholder(ob_space, batch_size=None, name='Ob'):
     tensorflow placeholder tensor
     '''
 
-    assert isinstance(ob_space, Discrete) or isinstance(ob_space, Box), \
+    assert isinstance(ob_space, Discrete) or isinstance(ob_space, Box) or isinstance(ob_space, MultiDiscrete), \
         'Can only deal with Discrete and Box observation spaces for now'
 
-    return tf.placeholder(shape=(batch_size,) + ob_space.shape, dtype=ob_space.dtype, name=name)
+    dtype = ob_space.dtype
+    if dtype == np.int8:
+        dtype = np.uint8
+
+    return tf.placeholder(shape=(batch_size,) + ob_space.shape, dtype=dtype, name=name)
 
 
 def observation_input(ob_space, batch_size=None, name='Ob'):
@@ -48,9 +53,12 @@ def encode_observation(ob_space, placeholder):
     '''
     if isinstance(ob_space, Discrete):
         return tf.to_float(tf.one_hot(placeholder, ob_space.n))
 
     elif isinstance(ob_space, Box):
         return tf.to_float(placeholder)
+    elif isinstance(ob_space, MultiDiscrete):
+        placeholder = tf.cast(placeholder, tf.int32)
+        one_hots = [tf.to_float(tf.one_hot(placeholder[..., i], ob_space.nvec[i])) for i in range(placeholder.shape[-1])]
+        return tf.concat(one_hots, axis=-1)
     else:
         raise NotImplementedError
```
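The `MultiDiscrete` branch of `encode_observation` one-hot encodes each dimension against its own cardinality and concatenates the results. The same idea in plain NumPy, as a sanity check of the expected output width (`nvec` and the observation are illustrative):

```python
# NumPy sketch mirroring the TF one-hot/concat logic above for nvec = [3, 3].
import numpy as np

nvec = np.array([3, 3])
obs = np.array([2, 0])                       # one observation, one index per dimension

one_hots = [np.eye(nvec[i])[obs[i]] for i in range(len(nvec))]
encoded = np.concatenate(one_hots, axis=-1)  # shape (6,), like tf.concat(one_hots, axis=-1)
print(encoded)                               # [0. 0. 1. 1. 0. 0.]
```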
```diff
@@ -1,7 +1,11 @@
-from mpi4py import MPI
 import baselines.common.tf_util as U
 import tensorflow as tf
 import numpy as np
+try:
+    from mpi4py import MPI
+except ImportError:
+    MPI = None
+
 
 class MpiAdam(object):
     def __init__(self, var_list, *, beta1=0.9, beta2=0.999, epsilon=1e-08, scale_grad_by_procs=True, comm=None):
@@ -16,16 +20,19 @@ class MpiAdam(object):
         self.t = 0
         self.setfromflat = U.SetFromFlat(var_list)
         self.getflat = U.GetFlat(var_list)
-        self.comm = MPI.COMM_WORLD if comm is None else comm
+        self.comm = MPI.COMM_WORLD if comm is None and MPI is not None else comm
 
     def update(self, localg, stepsize):
         if self.t % 100 == 0:
             self.check_synced()
         localg = localg.astype('float32')
-        globalg = np.zeros_like(localg)
-        self.comm.Allreduce(localg, globalg, op=MPI.SUM)
-        if self.scale_grad_by_procs:
-            globalg /= self.comm.Get_size()
+        if self.comm is not None:
+            globalg = np.zeros_like(localg)
+            self.comm.Allreduce(localg, globalg, op=MPI.SUM)
+            if self.scale_grad_by_procs:
+                globalg /= self.comm.Get_size()
+        else:
+            globalg = np.copy(localg)
 
         self.t += 1
         a = stepsize * np.sqrt(1 - self.beta2**self.t)/(1 - self.beta1**self.t)
@@ -35,11 +42,15 @@ class MpiAdam(object):
         self.setfromflat(self.getflat() + step)
 
     def sync(self):
+        if self.comm is None:
+            return
         theta = self.getflat()
         self.comm.Bcast(theta, root=0)
         self.setfromflat(theta)
 
     def check_synced(self):
+        if self.comm is None:
+            return
         if self.comm.Get_rank() == 0: # this is root
             theta = self.getflat()
             self.comm.Bcast(theta, root=0)
@@ -63,17 +74,30 @@ def test_MpiAdam():
     do_update = U.function([], loss, updates=[update_op])
 
     tf.get_default_session().run(tf.global_variables_initializer())
+    losslist_ref = []
     for i in range(10):
-        print(i,do_update())
+        l = do_update()
+        print(i, l)
+        losslist_ref.append(l)
+
+
 
     tf.set_random_seed(0)
     tf.get_default_session().run(tf.global_variables_initializer())
 
     var_list = [a,b]
-    lossandgrad = U.function([], [loss, U.flatgrad(loss, var_list)], updates=[update_op])
+    lossandgrad = U.function([], [loss, U.flatgrad(loss, var_list)])
     adam = MpiAdam(var_list)
 
+    losslist_test = []
     for i in range(10):
         l,g = lossandgrad()
         adam.update(g, stepsize)
         print(i,l)
+        losslist_test.append(l)
+
+    np.testing.assert_allclose(np.array(losslist_ref), np.array(losslist_test), atol=1e-4)
+
+
+if __name__ == '__main__':
+    test_MpiAdam()
```
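With these guards, `MpiAdam` degrades to a purely local Adam update when `mpi4py` is missing: `self.comm` ends up `None`, so `globalg` is just a copy of the local gradient and `sync`/`check_synced` become no-ops. A minimal usage sketch under that assumption, in TF1 graph mode; the toy loss and step size are illustrative:

```python
# Hedged sketch: MpiAdam used as a plain single-process Adam optimizer.
import numpy as np
import tensorflow as tf
import baselines.common.tf_util as U
from baselines.common.mpi_adam import MpiAdam

with U.single_threaded_session():
    x = tf.Variable(np.zeros(3, dtype=np.float32))
    loss = tf.reduce_sum(tf.square(x - 1.0))
    var_list = [x]
    lossandgrad = U.function([], [loss, U.flatgrad(loss, var_list)])
    adam = MpiAdam(var_list)              # comm defaults to None; works without mpi4py
    tf.get_default_session().run(tf.global_variables_initializer())
    for i in range(5):
        l, g = lossandgrad()
        adam.update(g, 1e-1)              # falls back to the local gradient path
        print(i, l)
```

Under `mpirun` with `mpi4py` installed, the same call path still averages gradients across ranks as before.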
```diff
@@ -1,4 +1,8 @@
-from mpi4py import MPI
+try:
+    from mpi4py import MPI
+except ImportError:
+    MPI = None
+
 import tensorflow as tf, baselines.common.tf_util as U, numpy as np
 
 class RunningMeanStd(object):
@@ -39,6 +43,7 @@ class RunningMeanStd(object):
         n = int(np.prod(self.shape))
         totalvec = np.zeros(n*2+1, 'float64')
         addvec = np.concatenate([x.sum(axis=0).ravel(), np.square(x).sum(axis=0).ravel(), np.array([len(x)],dtype='float64')])
-        MPI.COMM_WORLD.Allreduce(addvec, totalvec, op=MPI.SUM)
+        if MPI is not None:
+            MPI.COMM_WORLD.Allreduce(addvec, totalvec, op=MPI.SUM)
         self.incfiltparams(totalvec[0:n].reshape(self.shape), totalvec[n:2*n].reshape(self.shape), totalvec[2*n])
```
```diff
@@ -1,7 +1,7 @@
 import numpy as np
 from abc import abstractmethod
 from gym import Env
-from gym.spaces import Discrete, Box
+from gym.spaces import MultiDiscrete, Discrete, Box
 
 
 class IdentityEnv(Env):
@@ -53,6 +53,19 @@ class DiscreteIdentityEnv(IdentityEnv):
     def _get_reward(self, actions):
         return 1 if self.state == actions else 0
 
+
+class MultiDiscreteIdentityEnv(IdentityEnv):
+    def __init__(
+            self,
+            dims,
+            episode_len=None,
+    ):
+        self.action_space = MultiDiscrete(dims)
+        super().__init__(episode_len=episode_len)
+
+    def _get_reward(self, actions):
+        return 1 if all(self.state == actions) else 0
+
 
 class BoxIdentityEnv(IdentityEnv):
     def __init__(
```
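A hypothetical interactive use of the new environment (standard gym `reset`/`step` semantics from the `IdentityEnv` base class are assumed; the dimensions are illustrative):

```python
# Hedged usage sketch of MultiDiscreteIdentityEnv.
from baselines.common.tests.envs.identity_env import MultiDiscreteIdentityEnv

env = MultiDiscreteIdentityEnv((3, 3), episode_len=100)
obs = env.reset()
obs, rew, done, info = env.step(env.action_space.sample())
print(obs, rew)   # reward is 1 only when every sub-action matches the hidden state
```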
```diff
@@ -1,5 +1,5 @@
 import pytest
-from baselines.common.tests.envs.identity_env import DiscreteIdentityEnv, BoxIdentityEnv
+from baselines.common.tests.envs.identity_env import DiscreteIdentityEnv, BoxIdentityEnv, MultiDiscreteIdentityEnv
 from baselines.run import get_learn_function
 from baselines.common.tests.util import simple_test
 
@@ -21,6 +21,7 @@ learn_kwargs = {
 
 
 algos_disc = ['a2c', 'acktr', 'deepq', 'ppo2', 'trpo_mpi']
+algos_multidisc = ['a2c', 'acktr', 'ppo2', 'trpo_mpi']
 algos_cont = ['a2c', 'acktr', 'ddpg', 'ppo2', 'trpo_mpi']
 
 @pytest.mark.slow
@@ -38,6 +39,21 @@ def test_discrete_identity(alg):
     env_fn = lambda: DiscreteIdentityEnv(10, episode_len=100)
     simple_test(env_fn, learn_fn, 0.9)
 
+@pytest.mark.slow
+@pytest.mark.parametrize("alg", algos_multidisc)
+def test_multidiscrete_identity(alg):
+    '''
+    Test if the algorithm (with an mlp policy)
+    can learn an identity transformation (i.e. return observation as an action)
+    '''
+
+    kwargs = learn_kwargs[alg]
+    kwargs.update(common_kwargs)
+
+    learn_fn = lambda e: get_learn_function(alg)(env=e, **kwargs)
+    env_fn = lambda: MultiDiscreteIdentityEnv((3,3), episode_len=100)
+    simple_test(env_fn, learn_fn, 0.9)
+
 @pytest.mark.slow
 @pytest.mark.parametrize("alg", algos_cont)
 def test_continuous_identity(alg):
@@ -55,5 +71,5 @@ def test_continuous_identity(alg):
     simple_test(env_fn, learn_fn, -0.1)
 
 if __name__ == '__main__':
-    test_continuous_identity('ddpg')
+    test_multidiscrete_identity('acktr')
```
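Locally, the new parametrized case can be selected with pytest's keyword filter, e.g. `pytest -k multidiscrete` from the repository root, assuming slow-marked tests are not deselected by the local pytest configuration.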
```diff
@@ -20,8 +20,11 @@ class DummyVecEnv(VecEnv):
         env = self.envs[0]
         VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space)
         obs_space = env.observation_space
+        if isinstance(obs_space, spaces.MultiDiscrete):
+            obs_space.shape = obs_space.shape[0]
+
         self.keys, shapes, dtypes = obs_space_info(obs_space)
 
         self.buf_obs = { k: np.zeros((self.num_envs,) + tuple(shapes[k]), dtype=dtypes[k]) for k in self.keys }
         self.buf_dones = np.zeros((self.num_envs,), dtype=np.bool)
         self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32)
```
```diff
@@ -7,13 +7,16 @@ from baselines.ddpg.ddpg_learner import DDPG
 from baselines.ddpg.models import Actor, Critic
 from baselines.ddpg.memory import Memory
 from baselines.ddpg.noise import AdaptiveParamNoiseSpec, NormalActionNoise, OrnsteinUhlenbeckActionNoise
+from baselines.common import set_global_seeds
 import baselines.common.tf_util as U
 
 from baselines import logger
 import numpy as np
-from mpi4py import MPI
+
+try:
+    from mpi4py import MPI
+except ImportError:
+    MPI = None
 
 def learn(network, env,
           seed=None,
@@ -41,6 +44,7 @@ def learn(network, env,
           param_noise_adaption_interval=50,
           **network_kwargs):
 
+    set_global_seeds(seed)
 
     if total_timesteps is not None:
         assert nb_epochs is None
@@ -48,7 +52,11 @@ def learn(network, env,
     else:
         nb_epochs = 500
 
-    rank = MPI.COMM_WORLD.Get_rank()
+    if MPI is not None:
+        rank = MPI.COMM_WORLD.Get_rank()
+    else:
+        rank = 0
+
     nb_actions = env.action_space.shape[-1]
     assert (np.abs(env.action_space.low) == env.action_space.high).all() # we assume symmetric actions.
@@ -199,7 +207,11 @@ def learn(network, env,
                 eval_episode_rewards_history.append(eval_episode_reward[d])
                 eval_episode_reward[d] = 0.0
 
-        mpi_size = MPI.COMM_WORLD.Get_size()
+        if MPI is not None:
+            mpi_size = MPI.COMM_WORLD.Get_size()
+        else:
+            mpi_size = 1
+
         # Log stats.
         # XXX shouldn't call np.mean on variable length lists
         duration = time.time() - start_time
@@ -233,7 +245,10 @@ def learn(network, env,
             else:
                 raise ValueError('expected scalar, got %s'%x)
-        combined_stats_sums = MPI.COMM_WORLD.allreduce(np.array([ np.array(x).flatten()[0] for x in combined_stats.values()]))
+
+        combined_stats_sums = np.array([ np.array(x).flatten()[0] for x in combined_stats.values()])
+        if MPI is not None:
+            combined_stats_sums = MPI.COMM_WORLD.allreduce(combined_stats_sums)
+
         combined_stats = {k : v / mpi_size for (k,v) in zip(combined_stats.keys(), combined_stats_sums)}
 
         # Total statistics.
```
```diff
@@ -9,7 +9,10 @@ from baselines import logger
 from baselines.common.mpi_adam import MpiAdam
 import baselines.common.tf_util as U
 from baselines.common.mpi_running_mean_std import RunningMeanStd
-from mpi4py import MPI
+try:
+    from mpi4py import MPI
+except ImportError:
+    MPI = None
 
 def normalize(x, stats):
     if stats is None:
@@ -358,6 +361,11 @@ class DDPG(object):
         return stats
 
     def adapt_param_noise(self):
+        try:
+            from mpi4py import MPI
+        except ImportError:
+            MPI = None
+
         if self.param_noise is None:
             return 0.
@@ -371,7 +379,16 @@ class DDPG(object):
             self.param_noise_stddev: self.param_noise.current_stddev,
         })
 
-        mean_distance = MPI.COMM_WORLD.allreduce(distance, op=MPI.SUM) / MPI.COMM_WORLD.Get_size()
+        if MPI is not None:
+            mean_distance = MPI.COMM_WORLD.allreduce(distance, op=MPI.SUM) / MPI.COMM_WORLD.Get_size()
+        else:
+            mean_distance = distance
+
+        if MPI is not None:
+            mean_distance = MPI.COMM_WORLD.allreduce(distance, op=MPI.SUM) / MPI.COMM_WORLD.Get_size()
+        else:
+            mean_distance = distance
+
         self.param_noise.adapt(mean_distance)
         return mean_distance
```
|
|||||||
to 1.0. If set to None equals to total_timesteps.
|
to 1.0. If set to None equals to total_timesteps.
|
||||||
prioritized_replay_eps: float
|
prioritized_replay_eps: float
|
||||||
epsilon to add to the TD errors when updating priorities.
|
epsilon to add to the TD errors when updating priorities.
|
||||||
|
param_noise: bool
|
||||||
|
whether or not to use parameter space noise (https://arxiv.org/abs/1706.01905)
|
||||||
callback: (locals, globals) -> None
|
callback: (locals, globals) -> None
|
||||||
function called at every steps with state of the algorithm.
|
function called at every steps with state of the algorithm.
|
||||||
If callback returns true training stops.
|
If callback returns true training stops.
|
||||||
|
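The newly documented flag is an existing argument of `deepq.learn`; a hedged sketch of switching it on (the environment and hyperparameters below are illustrative, not tuned):

```python
# Hedged sketch: enabling parameter-space noise via the documented flag.
import gym
from baselines import deepq

env = gym.make('CartPole-v0')
act = deepq.learn(
    env,
    network='mlp',
    total_timesteps=10000,
    param_noise=True,   # parameter-space noise, per the docstring addition above
)
```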
```diff
@@ -18,11 +18,11 @@ class TfInput(object):
         """Return the tf variable(s) representing the possibly postprocessed value
         of placeholder(s).
         """
-        raise NotImplemented()
+        raise NotImplementedError
 
     def make_feed_dict(data):
         """Given data input it to the placeholder(s)."""
-        raise NotImplemented()
+        raise NotImplementedError
 
 
 class PlaceholderTfInput(TfInput):
```
```diff
@@ -10,11 +10,15 @@ from baselines.common import explained_variance, set_global_seeds
 from baselines.common.policies import build_policy
 from baselines.common.runners import AbstractEnvRunner
 from baselines.common.tf_util import get_session, save_variables, load_variables
-from baselines.common.mpi_adam_optimizer import MpiAdamOptimizer
 
-from mpi4py import MPI
-from baselines.common.tf_util import initialize
-from baselines.common.mpi_util import sync_from_root
+try:
+    from baselines.common.mpi_adam_optimizer import MpiAdamOptimizer
+    from mpi4py import MPI
+    from baselines.common.mpi_util import sync_from_root
+except ImportError:
+    MPI = None
+
+from baselines.common.tf_util import initialize
 
 class Model(object):
     """
@@ -93,7 +97,10 @@ class Model(object):
         # 1. Get the model parameters
         params = tf.trainable_variables('ppo2_model')
         # 2. Build our trainer
-        trainer = MpiAdamOptimizer(MPI.COMM_WORLD, learning_rate=LR, epsilon=1e-5)
+        if MPI is not None:
+            trainer = MpiAdamOptimizer(MPI.COMM_WORLD, learning_rate=LR, epsilon=1e-5)
+        else:
+            trainer = tf.train.AdamOptimizer(learning_rate=LR, epsilon=1e-5)
         # 3. Calculate the gradients
         grads_and_var = trainer.compute_gradients(loss, params)
         grads, var = zip(*grads_and_var)
@@ -136,9 +143,11 @@ class Model(object):
         self.save = functools.partial(save_variables, sess=sess)
         self.load = functools.partial(load_variables, sess=sess)
 
-        if MPI.COMM_WORLD.Get_rank() == 0:
+        if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
             initialize()
         global_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="")
-        sync_from_root(sess, global_variables) #pylint: disable=E1101
+
+        if MPI is not None:
+            sync_from_root(sess, global_variables) #pylint: disable=E1101
 
 class Runner(AbstractEnvRunner):
@@ -392,9 +401,9 @@ def learn(*, network, env, total_timesteps, eval_env = None, seed=None, nsteps=2
             logger.logkv('time_elapsed', tnow - tfirststart)
             for (lossval, lossname) in zip(lossvals, model.loss_names):
                 logger.logkv(lossname, lossval)
-            if MPI.COMM_WORLD.Get_rank() == 0:
+            if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
                 logger.dumpkvs()
-        if save_interval and (update % save_interval == 0 or update == 1) and logger.get_dir() and MPI.COMM_WORLD.Get_rank() == 0:
+        if save_interval and (update % save_interval == 0 or update == 1) and logger.get_dir() and (MPI is None or MPI.COMM_WORLD.Get_rank() == 0):
             checkdir = osp.join(logger.get_dir(), 'checkpoints')
             os.makedirs(checkdir, exist_ok=True)
             savepath = osp.join(checkdir, '%.5i'%update)
```
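With these fallbacks, ppo2 should also run in a single plain-Python process: without `mpi4py` the trainer becomes `tf.train.AdamOptimizer`, `initialize()` still runs on the lone worker, and `sync_from_root` is skipped, so a launch such as `python -m baselines.run --alg=ppo2 --env=PongNoFrameskip-v4 --num_timesteps=1e6` (the invocation documented in the project README) is expected to work without `mpirun`.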
```diff
@@ -131,7 +131,7 @@ def get_env_type(env_id):
 
 
 def get_default_network(env_type):
-    if env_type == 'atari':
+    if env_type in {'atari', 'retro'}:
         return 'cnn'
     else:
         return 'mlp'
```
```diff
@@ -4,7 +4,6 @@ import baselines.common.tf_util as U
 import tensorflow as tf, numpy as np
 import time
 from baselines.common import colorize
-from mpi4py import MPI
 from collections import deque
 from baselines.common import set_global_seeds
 from baselines.common.mpi_adam import MpiAdam
@@ -13,6 +12,11 @@ from baselines.common.input import observation_placeholder
 from baselines.common.policies import build_policy
 from contextlib import contextmanager
 
+try:
+    from mpi4py import MPI
+except ImportError:
+    MPI = None
+
 def traj_segment_generator(pi, env, horizon, stochastic):
     # Initialize state variables
     t = 0
@@ -146,9 +150,12 @@ def learn(*,
 
     '''
 
-    nworkers = MPI.COMM_WORLD.Get_size()
-    rank = MPI.COMM_WORLD.Get_rank()
+    if MPI is not None:
+        nworkers = MPI.COMM_WORLD.Get_size()
+        rank = MPI.COMM_WORLD.Get_rank()
+    else:
+        nworkers = 1
+        rank = 0
+
     cpus_per_worker = 1
     U.get_session(config=tf.ConfigProto(
@@ -237,9 +244,13 @@ def learn(*,
 
     def allmean(x):
         assert isinstance(x, np.ndarray)
-        out = np.empty_like(x)
-        MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
-        out /= nworkers
+        if MPI is not None:
+            out = np.empty_like(x)
+            MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
+            out /= nworkers
+        else:
+            out = np.copy(x)
+
         return out
 
     U.initialize()
@@ -247,7 +258,9 @@ def learn(*,
         pi.load(load_path)
 
     th_init = get_flat()
-    MPI.COMM_WORLD.Bcast(th_init, root=0)
+    if MPI is not None:
+        MPI.COMM_WORLD.Bcast(th_init, root=0)
+
     set_from_flat(th_init)
     vfadam.sync()
     print("Init param sum", th_init.sum(), flush=True)
@@ -353,7 +366,11 @@ def learn(*,
         logger.record_tabular("ev_tdlam_before", explained_variance(vpredbefore, tdlamret))
 
         lrlocal = (seg["ep_lens"], seg["ep_rets"]) # local values
-        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal) # list of tuples
+        if MPI is not None:
+            listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal) # list of tuples
+        else:
+            listoflrpairs = [lrlocal]
+
         lens, rews = map(flatten_lists, zip(*listoflrpairs))
         lenbuffer.extend(lens)
         rewbuffer.extend(rews)
```
setup.py (6 lines changed):

```diff
@@ -15,6 +15,9 @@ extras = {
     ],
     'bullet': [
         'pybullet',
+    ],
+    'mpi': [
+        'mpi4py'
     ]
 }
 
@@ -34,7 +37,6 @@ setup(name='baselines',
         'joblib',
         'dill',
         'progressbar2',
-        'mpi4py',
         'cloudpickle',
         'click',
         'opencv-python'
@@ -57,4 +59,4 @@ for tf_pkg_name in ['tensorflow', 'tensorflow-gpu']:
         pass
 assert tf_pkg is not None, 'TensorFlow needed, of version above 1.4'
 from distutils.version import StrictVersion
-assert StrictVersion(re.sub(r'-rc\d+$', '', tf_pkg.version)) >= StrictVersion('1.4.0')
+assert StrictVersion(re.sub(r'-?rc\d+$', '', tf_pkg.version)) >= StrictVersion('1.4.0')
```
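`mpi4py` thus moves from a hard install requirement to an opt-in extra: `pip install -e .` no longer pulls it in, while `pip install -e .[mpi]` (or `.[test,mpi]`, as the MPI test image below uses) restores the previous behaviour.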
test.dockerfile.py36-mpi (new file, 16 lines):

```diff
@@ -0,0 +1,16 @@
+FROM python:3.6
+
+RUN apt-get -y update && apt-get -y install ffmpeg libopenmpi-dev
+ENV CODE_DIR /root/code
+
+COPY . $CODE_DIR/baselines
+WORKDIR $CODE_DIR/baselines
+
+# Clean up pycache and pyc files
+RUN rm -rf __pycache__ && \
+    find . -name "*.pyc" -delete && \
+    pip install tensorflow && \
+    pip install -e .[test,mpi]
+
+
+CMD /bin/bash
```
test.dockerfile.py36-nompi (new file, 16 lines):

```diff
@@ -0,0 +1,16 @@
+FROM python:3.6
+
+RUN apt-get -y update && apt-get -y install ffmpeg
+ENV CODE_DIR /root/code
+
+COPY . $CODE_DIR/baselines
+WORKDIR $CODE_DIR/baselines
+
+# Clean up pycache and pyc files
+RUN rm -rf __pycache__ && \
+    find . -name "*.pyc" -delete && \
+    pip install tensorflow && \
+    pip install -e .[test]
+
+
+CMD /bin/bash
```
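These are the two images the updated CI config selects through `DOCKER_SUFFIX`; the same builds can be reproduced locally with `docker build -f test.dockerfile.py36-mpi -t baselines-test .` and `docker build -f test.dockerfile.py36-nompi -t baselines-test .`, matching the new `docker build -f test.dockerfile.${DOCKER_SUFFIX} -t baselines-test .` line in the CI install step.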