Compare commits
peterz_mpi ... peterz_lea
21 commits

Commit SHA1s:
c7a0c2781a
06c2fd2a3c
a52dcae856
a8c2e643dc
5ca31a7c25
35dcb6fd74
c1c7c469a1
b4869bd271
29cfb4a69c
bd7c479e04
3ddf69c4b5
bfdc552521
bcb4d4f795
0c9b236475
01884bb0eb
ade2d61be7
f6ef52a9df
8964d5ad45
8624bc629c
7b33af0395
4bca9158a1

12  .travis.yml
@@ -5,14 +5,10 @@ python:
services:
- docker

env:
- DOCKER_SUFFIX=py36-nompi
- DOCKER_SUFFIX=py36-mpi

install:
- pip install flake8
- docker build -f test.dockerfile.${DOCKER_SUFFIX} -t baselines-test .
- pip install flake8
- docker build . -t baselines-test

script:
- flake8 . --show-source --statistics
- docker run baselines-test pytest -v .
- flake8 . --show-source --statistics
- docker run baselines-test pytest -v .

25  Dockerfile  (new file)
@@ -0,0 +1,25 @@
FROM ubuntu:16.04

RUN apt-get -y update && apt-get -y install git wget python-dev python3-dev libopenmpi-dev python-pip zlib1g-dev cmake python-opencv
ENV CODE_DIR /root/code
ENV VENV /root/venv

RUN \
    pip install virtualenv && \
    virtualenv $VENV --python=python3 && \
    . $VENV/bin/activate && \
    pip install --upgrade pip

ENV PATH=$VENV/bin:$PATH

COPY . $CODE_DIR/baselines
WORKDIR $CODE_DIR/baselines

# Clean up pycache and pyc files
RUN rm -rf __pycache__ && \
    find . -name "*.pyc" -delete && \
    pip install tensorflow && \
    pip install -e .[test]

CMD /bin/bash

@@ -0,0 +1,12 @@
# explicitly import sub-packages to register algorithms

import baselines.a2c.a2c
import baselines.acer.acer
import baselines.acktr.acktr
import baselines.deepq.deepq
import baselines.ddpg.ddpg
import baselines.ppo2.ppo2

# not really sure why flake8 complains only about trpo_mpi here...
import baselines.trpo_mpi.trpo_mpi  # noqa: F401

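The "explicitly import sub-packages to register algorithms" comment above relies on decorators running at import time: importing each algorithm module executes its @registry.register(...) decorator (see baselines/registry.py later in this diff) and fills the registry as a side effect. A minimal self-contained sketch of that mechanism, using stand-in names rather than the real baselines.registry module:

# Illustration only: REGISTRY and register() are stand-ins for baselines.registry.
REGISTRY = {}

def register(name):
    def _thunk(fn):
        REGISTRY[name] = fn  # runs when the decorated module is imported
        return fn
    return _thunk

# In the real package this decorator sits on learn() inside a sub-module such as baselines/a2c/a2c.py:
@register('a2c')
def learn(env, network):
    pass

# Importing the module is enough to populate the registry; nothing calls learn() yet.
assert 'a2c' in REGISTRY
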
@@ -2,13 +2,12 @@ import time
import functools
import tensorflow as tf

from baselines import logger
from baselines import logger, registry

from baselines.common import set_global_seeds, explained_variance
from baselines.common import tf_util
from baselines.common.policies import build_policy

from baselines.a2c.utils import Scheduler, find_trainable_variables
from baselines.a2c.runner import Runner

@@ -114,6 +113,7 @@ class Model(object):
        tf.global_variables_initializer().run(session=sess)

@registry.register('a2c')
def learn(
    network,
    env,

@@ -2,7 +2,7 @@ import time
import functools
import numpy as np
import tensorflow as tf
from baselines import logger
from baselines import logger, registry

from baselines.common import set_global_seeds
from baselines.common.policies import build_policy
@@ -16,6 +16,7 @@ from baselines.a2c.utils import EpisodeStats
from baselines.a2c.utils import get_by_index, check_shape, avg_norm, gradient_add, q_explained_variance
from baselines.acer.buffer import Buffer
from baselines.acer.runner import Runner
from baselines.acer.defaults import defaults

# remove last step
def strip(var, nenvs, nsteps, flat = False):
@@ -270,7 +271,7 @@ class Acer():
            logger.record_tabular(name, float(val))
        logger.dump_tabular()

@registry.register('acer', defaults=defaults)
def learn(network, env, seed=None, nsteps=20, total_timesteps=int(80e6), q_coef=0.5, ent_coef=0.01,
          max_grad_norm=10, lr=7e-4, lrschedule='linear', rprop_epsilon=1e-5, rprop_alpha=0.99, gamma=0.99,
          log_interval=100, buffer_size=50000, replay_ratio=4, replay_start=10000, c=10.0,

@@ -1,4 +1,3 @@
def atari():
    return dict(
        lrschedule='constant'
    )
defaults = {
    'atari': dict(lrschedule='constant')
}

@@ -2,7 +2,7 @@ import os.path as osp
import time
import functools
import tensorflow as tf
from baselines import logger
from baselines import logger, registry

from baselines.common import set_global_seeds, explained_variance
from baselines.common.policies import build_policy
@@ -11,6 +11,7 @@ from baselines.common.tf_util import get_session, save_variables, load_variables
from baselines.a2c.runner import Runner
from baselines.a2c.utils import Scheduler, find_trainable_variables
from baselines.acktr import kfac
from baselines.acktr.defaults import defaults

class Model(object):
@@ -21,16 +22,16 @@ class Model(object):
        self.sess = sess = get_session()
        nbatch = nenvs * nsteps
        with tf.variable_scope('acktr_model', reuse=tf.AUTO_REUSE):
            self.model = step_model = policy(nenvs, 1, sess=sess)
            self.model2 = train_model = policy(nenvs*nsteps, nsteps, sess=sess)

            A = train_model.pdtype.sample_placeholder([None])
        A = tf.placeholder(ac_space.dtype, [nbatch,] + list(ac_space.shape))
        ADV = tf.placeholder(tf.float32, [nbatch])
        R = tf.placeholder(tf.float32, [nbatch])
        PG_LR = tf.placeholder(tf.float32, [])
        VF_LR = tf.placeholder(tf.float32, [])

        with tf.variable_scope('acktr_model', reuse=tf.AUTO_REUSE):
            self.model = step_model = policy(nenvs, 1, sess=sess)
            self.model2 = train_model = policy(nenvs*nsteps, nsteps, sess=sess)

        neglogpac = train_model.pd.neglogp(A)
        self.logits = train_model.pi

@@ -90,6 +91,7 @@ class Model(object):
        self.initial_state = step_model.initial_state
        tf.global_variables_initializer().run(session=sess)

@registry.register('acktr', defaults=defaults)
def learn(network, env, seed, total_timesteps=int(40e6), gamma=0.99, log_interval=1, nprocs=32, nsteps=20,
          ent_coef=0.01, vf_coef=0.5, vf_fisher_coef=1.0, lr=0.25, max_grad_norm=0.5,
          kfac_clip=0.001, save_interval=None, lrschedule='linear', load_path=None, is_async=True, **network_kwargs):

@@ -1,5 +1,6 @@
def mujoco():
    return dict(
defaults = {
    'mujoco' : dict(
        nsteps=2500,
        value_network='copy'
    )
}

@@ -16,11 +16,13 @@ from baselines.common import set_global_seeds
from baselines.common.atari_wrappers import make_atari, wrap_deepmind
from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
from baselines.common.vec_env.vec_frame_stack import VecFrameStack

from baselines.common import retro_wrappers

def make_vec_env(env_id, env_type, num_env, seed, wrapper_kwargs=None, start_index=0, reward_scale=1.0, gamestate=None):
def make_vec_env(env_id, env_type, num_env, seed, wrapper_kwargs=None, start_index=0, reward_scale=1.0, gamestate=None, frame_stack_size=1):
    """
    Create a wrapped, monitored SubprocVecEnv for Atari and MuJoCo.
    Create a wrapped, monitored SubprocVecEnv
    """
    if wrapper_kwargs is None: wrapper_kwargs = {}
    mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
@@ -38,9 +40,15 @@ def make_vec_env(env_id, env_type, num_env, seed, wrapper_kwargs=None, start_ind
    set_global_seeds(seed)
    if num_env > 1:
        return SubprocVecEnv([make_thunk(i + start_index) for i in range(num_env)])
        venv = SubprocVecEnv([make_thunk(i + start_index) for i in range(num_env)])
    else:
        return DummyVecEnv([make_thunk(start_index)])
        venv = DummyVecEnv([make_thunk(start_index)])

    if frame_stack_size > 1:
        venv = VecFrameStack(venv, frame_stack_size)

    return venv


def make_env(env_id, env_type, subrank=0, seed=None, reward_scale=1.0, gamestate=None, wrapper_kwargs={}):

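For illustration (not part of this changeset): the new frame_stack_size argument lets callers fold frame stacking into environment construction instead of wrapping the result themselves, which is how baselines/run.py uses it further down. A hedged usage sketch assuming the signature shown above; the environment id and argument values are made up:

from baselines.common.cmd_util import make_vec_env

# Previously a caller would do: env = VecFrameStack(make_vec_env(...), 4)
# With this change the stacking happens inside make_vec_env:
env = make_vec_env('PongNoFrameskip-v4', 'atari', num_env=8, seed=0,
                   frame_stack_size=4)
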
@@ -39,7 +39,7 @@ class PdType(object):
        raise NotImplementedError
    def pdfromflat(self, flat):
        return self.pdclass()(flat)
    def pdfromlatent(self, latent_vector, init_scale, init_bias):
    def pdfromlatent(self, latent_vector):
        raise NotImplementedError
    def param_shape(self):
        raise NotImplementedError
@@ -80,11 +80,6 @@ class MultiCategoricalPdType(PdType):
        return MultiCategoricalPd
    def pdfromflat(self, flat):
        return MultiCategoricalPd(self.ncats, flat)

    def pdfromlatent(self, latent, init_scale=1.0, init_bias=0.0):
        pdparam = fc(latent, 'pi', self.ncats.sum(), init_scale=init_scale, init_bias=init_bias)
        return self.pdfromflat(pdparam), pdparam

    def param_shape(self):
        return [sum(self.ncats)]
    def sample_shape(self):

@@ -1,6 +1,5 @@
import numpy as np
import tensorflow as tf
from gym.spaces import Discrete, Box, MultiDiscrete
from gym.spaces import Discrete, Box

def observation_placeholder(ob_space, batch_size=None, name='Ob'):
    '''
@@ -21,14 +20,10 @@ def observation_placeholder(ob_space, batch_size=None, name='Ob'):
    tensorflow placeholder tensor
    '''

    assert isinstance(ob_space, Discrete) or isinstance(ob_space, Box) or isinstance(ob_space, MultiDiscrete), \
    assert isinstance(ob_space, Discrete) or isinstance(ob_space, Box), \
        'Can only deal with Discrete and Box observation spaces for now'

    dtype = ob_space.dtype
    if dtype == np.int8:
        dtype = np.uint8

    return tf.placeholder(shape=(batch_size,) + ob_space.shape, dtype=dtype, name=name)
    return tf.placeholder(shape=(batch_size,) + ob_space.shape, dtype=ob_space.dtype, name=name)


def observation_input(ob_space, batch_size=None, name='Ob'):
@@ -53,12 +48,9 @@ def encode_observation(ob_space, placeholder):
    '''
    if isinstance(ob_space, Discrete):
        return tf.to_float(tf.one_hot(placeholder, ob_space.n))

    elif isinstance(ob_space, Box):
        return tf.to_float(placeholder)
    elif isinstance(ob_space, MultiDiscrete):
        placeholder = tf.cast(placeholder, tf.int32)
        one_hots = [tf.to_float(tf.one_hot(placeholder[..., i], ob_space.nvec[i])) for i in range(placeholder.shape[-1])]
        return tf.concat(one_hots, axis=-1)
    else:
        raise NotImplementedError

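For intuition about the MultiDiscrete branch of encode_observation above: each sub-dimension is one-hot encoded against its own cardinality and the results are concatenated, giving a vector of length sum(nvec). A small NumPy sketch of the same idea with a made-up nvec (illustration only, not code from this diff):

import numpy as np

nvec = [3, 2]   # hypothetical MultiDiscrete([3, 2]) space
obs = [2, 0]    # one observation from that space

# One-hot each dimension separately, then concatenate,
# mirroring the tf.one_hot / tf.concat logic above.
one_hots = [np.eye(n)[v] for n, v in zip(nvec, obs)]
encoded = np.concatenate(one_hots, axis=-1)
print(encoded)  # [0. 0. 1. 1. 0.]  (length == sum(nvec) == 5)
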
@@ -1,11 +1,7 @@
from mpi4py import MPI
import baselines.common.tf_util as U
import tensorflow as tf
import numpy as np
try:
    from mpi4py import MPI
except ImportError:
    MPI = None


class MpiAdam(object):
    def __init__(self, var_list, *, beta1=0.9, beta2=0.999, epsilon=1e-08, scale_grad_by_procs=True, comm=None):
@@ -20,19 +16,16 @@ class MpiAdam(object):
        self.t = 0
        self.setfromflat = U.SetFromFlat(var_list)
        self.getflat = U.GetFlat(var_list)
        self.comm = MPI.COMM_WORLD if comm is None and MPI is not None else comm
        self.comm = MPI.COMM_WORLD if comm is None else comm

    def update(self, localg, stepsize):
        if self.t % 100 == 0:
            self.check_synced()
        localg = localg.astype('float32')
        if self.comm is not None:
            globalg = np.zeros_like(localg)
            self.comm.Allreduce(localg, globalg, op=MPI.SUM)
            if self.scale_grad_by_procs:
                globalg /= self.comm.Get_size()
        else:
            globalg = np.copy(localg)
        globalg = np.zeros_like(localg)
        self.comm.Allreduce(localg, globalg, op=MPI.SUM)
        if self.scale_grad_by_procs:
            globalg /= self.comm.Get_size()

        self.t += 1
        a = stepsize * np.sqrt(1 - self.beta2**self.t)/(1 - self.beta1**self.t)
@@ -42,15 +35,11 @@ class MpiAdam(object):
        self.setfromflat(self.getflat() + step)

    def sync(self):
        if self.comm is None:
            return
        theta = self.getflat()
        self.comm.Bcast(theta, root=0)
        self.setfromflat(theta)

    def check_synced(self):
        if self.comm is None:
            return
        if self.comm.Get_rank() == 0: # this is root
            theta = self.getflat()
            self.comm.Bcast(theta, root=0)
@@ -74,30 +63,17 @@ def test_MpiAdam():
    do_update = U.function([], loss, updates=[update_op])

    tf.get_default_session().run(tf.global_variables_initializer())
    losslist_ref = []
    for i in range(10):
        l = do_update()
        print(i, l)
        losslist_ref.append(l)

    print(i,do_update())

    tf.set_random_seed(0)
    tf.get_default_session().run(tf.global_variables_initializer())

    var_list = [a,b]
    lossandgrad = U.function([], [loss, U.flatgrad(loss, var_list)])
    lossandgrad = U.function([], [loss, U.flatgrad(loss, var_list)], updates=[update_op])
    adam = MpiAdam(var_list)

    losslist_test = []
    for i in range(10):
        l,g = lossandgrad()
        adam.update(g, stepsize)
        print(i,l)
        losslist_test.append(l)

    np.testing.assert_allclose(np.array(losslist_ref), np.array(losslist_test), atol=1e-4)


if __name__ == '__main__':
    test_MpiAdam()

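One side of the MpiAdam hunk above guards every collective operation against mpi4py being unavailable (comm is None), while the other assumes MPI is always importable. A minimal sketch of the optional-MPI pattern in isolation, with illustrative names (this is not the library's code):

import numpy as np

try:
    from mpi4py import MPI
except ImportError:
    MPI = None  # run single-process if mpi4py is not installed

def average_gradient(localg, comm=None):
    """Average a gradient across MPI workers, or copy it when running without MPI."""
    if comm is None and MPI is not None:
        comm = MPI.COMM_WORLD
    if comm is None:
        return np.copy(localg)
    globalg = np.zeros_like(localg)
    comm.Allreduce(localg, globalg, op=MPI.SUM)
    return globalg / comm.Get_size()
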
@@ -1,8 +1,4 @@
try:
    from mpi4py import MPI
except ImportError:
    MPI = None

from mpi4py import MPI
import tensorflow as tf, baselines.common.tf_util as U, numpy as np

class RunningMeanStd(object):
@@ -43,8 +39,7 @@ class RunningMeanStd(object):
        n = int(np.prod(self.shape))
        totalvec = np.zeros(n*2+1, 'float64')
        addvec = np.concatenate([x.sum(axis=0).ravel(), np.square(x).sum(axis=0).ravel(), np.array([len(x)],dtype='float64')])
        if MPI is not None:
            MPI.COMM_WORLD.Allreduce(addvec, totalvec, op=MPI.SUM)
        MPI.COMM_WORLD.Allreduce(addvec, totalvec, op=MPI.SUM)
        self.incfiltparams(totalvec[0:n].reshape(self.shape), totalvec[n:2*n].reshape(self.shape), totalvec[2*n])

@U.in_session

@@ -1,7 +1,7 @@
import numpy as np
from abc import abstractmethod
from gym import Env
from gym.spaces import MultiDiscrete, Discrete, Box
from gym.spaces import Discrete, Box


class IdentityEnv(Env):
@@ -53,19 +53,6 @@ class DiscreteIdentityEnv(IdentityEnv):
    def _get_reward(self, actions):
        return 1 if self.state == actions else 0

class MultiDiscreteIdentityEnv(IdentityEnv):
    def __init__(
            self,
            dims,
            episode_len=None,
    ):

        self.action_space = MultiDiscrete(dims)
        super().__init__(episode_len=episode_len)

    def _get_reward(self, actions):
        return 1 if all(self.state == actions) else 0


class BoxIdentityEnv(IdentityEnv):
    def __init__(

@@ -1,5 +1,5 @@
import pytest
from baselines.common.tests.envs.identity_env import DiscreteIdentityEnv, BoxIdentityEnv, MultiDiscreteIdentityEnv
from baselines.common.tests.envs.identity_env import DiscreteIdentityEnv, BoxIdentityEnv
from baselines.run import get_learn_function
from baselines.common.tests.util import simple_test

@@ -21,7 +21,6 @@ learn_kwargs = {

algos_disc = ['a2c', 'acktr', 'deepq', 'ppo2', 'trpo_mpi']
algos_multidisc = ['a2c', 'acktr', 'ppo2', 'trpo_mpi']
algos_cont = ['a2c', 'acktr', 'ddpg', 'ppo2', 'trpo_mpi']

@pytest.mark.slow
@@ -39,21 +38,6 @@ def test_discrete_identity(alg):
    env_fn = lambda: DiscreteIdentityEnv(10, episode_len=100)
    simple_test(env_fn, learn_fn, 0.9)

@pytest.mark.slow
@pytest.mark.parametrize("alg", algos_multidisc)
def test_multidiscrete_identity(alg):
    '''
    Test if the algorithm (with an mlp policy)
    can learn an identity transformation (i.e. return observation as an action)
    '''

    kwargs = learn_kwargs[alg]
    kwargs.update(common_kwargs)

    learn_fn = lambda e: get_learn_function(alg)(env=e, **kwargs)
    env_fn = lambda: MultiDiscreteIdentityEnv((3,3), episode_len=100)
    simple_test(env_fn, learn_fn, 0.9)

@pytest.mark.slow
@pytest.mark.parametrize("alg", algos_cont)
def test_continuous_identity(alg):
@@ -71,5 +55,5 @@ def test_continuous_identity(alg):
    simple_test(env_fn, learn_fn, -0.1)

if __name__ == '__main__':
    test_multidiscrete_identity('acktr')
    test_continuous_identity('ddpg')

@@ -20,11 +20,8 @@ class DummyVecEnv(VecEnv):
        env = self.envs[0]
        VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space)
        obs_space = env.observation_space
        if isinstance(obs_space, spaces.MultiDiscrete):
            obs_space.shape = obs_space.shape[0]

        self.keys, shapes, dtypes = obs_space_info(obs_space)

        self.buf_obs = { k: np.zeros((self.num_envs,) + tuple(shapes[k]), dtype=dtypes[k]) for k in self.keys }
        self.buf_dones = np.zeros((self.num_envs,), dtype=np.bool)
        self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32)

@@ -7,17 +7,14 @@ from baselines.ddpg.ddpg_learner import DDPG
from baselines.ddpg.models import Actor, Critic
from baselines.ddpg.memory import Memory
from baselines.ddpg.noise import AdaptiveParamNoiseSpec, NormalActionNoise, OrnsteinUhlenbeckActionNoise
from baselines.common import set_global_seeds

import baselines.common.tf_util as U

from baselines import logger
from baselines import logger, registry
import numpy as np
from mpi4py import MPI

try:
    from mpi4py import MPI
except ImportError:
    MPI = None

@registry.register('ddpg')
def learn(network, env,
          seed=None,
          total_timesteps=None,
@@ -44,7 +41,6 @@ def learn(network, env,
          param_noise_adaption_interval=50,
          **network_kwargs):

    set_global_seeds(seed)

    if total_timesteps is not None:
        assert nb_epochs is None
@@ -52,11 +48,7 @@ def learn(network, env,
    else:
        nb_epochs = 500

    if MPI is not None:
        rank = MPI.COMM_WORLD.Get_rank()
    else:
        rank = 0

    rank = MPI.COMM_WORLD.Get_rank()
    nb_actions = env.action_space.shape[-1]
    assert (np.abs(env.action_space.low) == env.action_space.high).all() # we assume symmetric actions.

@@ -207,11 +199,7 @@ def learn(network, env,
                eval_episode_rewards_history.append(eval_episode_reward[d])
                eval_episode_reward[d] = 0.0

    if MPI is not None:
        mpi_size = MPI.COMM_WORLD.Get_size()
    else:
        mpi_size = 1

    mpi_size = MPI.COMM_WORLD.Get_size()
    # Log stats.
    # XXX shouldn't call np.mean on variable length lists
    duration = time.time() - start_time
@@ -245,10 +233,7 @@ def learn(network, env,
        else:
            raise ValueError('expected scalar, got %s'%x)

    combined_stats_sums = np.array([ np.array(x).flatten()[0] for x in combined_stats.values()])
    if MPI is not None:
        combined_stats_sums = MPI.COMM_WORLD.allreduce(combined_stats_sums)

    combined_stats_sums = MPI.COMM_WORLD.allreduce(np.array([ np.array(x).flatten()[0] for x in combined_stats.values()]))
    combined_stats = {k : v / mpi_size for (k,v) in zip(combined_stats.keys(), combined_stats_sums)}

    # Total statistics.

@@ -9,10 +9,7 @@ from baselines import logger
from baselines.common.mpi_adam import MpiAdam
import baselines.common.tf_util as U
from baselines.common.mpi_running_mean_std import RunningMeanStd
try:
    from mpi4py import MPI
except ImportError:
    MPI = None
from mpi4py import MPI

def normalize(x, stats):
    if stats is None:
@@ -361,11 +358,6 @@ class DDPG(object):
        return stats

    def adapt_param_noise(self):
        try:
            from mpi4py import MPI
        except ImportError:
            MPI = None

        if self.param_noise is None:
            return 0.

@@ -379,16 +371,7 @@ class DDPG(object):
            self.param_noise_stddev: self.param_noise.current_stddev,
        })

        if MPI is not None:
            mean_distance = MPI.COMM_WORLD.allreduce(distance, op=MPI.SUM) / MPI.COMM_WORLD.Get_size()
        else:
            mean_distance = distance

        if MPI is not None:
            mean_distance = MPI.COMM_WORLD.allreduce(distance, op=MPI.SUM) / MPI.COMM_WORLD.Get_size()
        else:
            mean_distance = distance

        mean_distance = MPI.COMM_WORLD.allreduce(distance, op=MPI.SUM) / MPI.COMM_WORLD.Get_size()
        self.param_noise.adapt(mean_distance)
        return mean_distance

@@ -8,7 +8,7 @@ import numpy as np

import baselines.common.tf_util as U
from baselines.common.tf_util import load_variables, save_variables
from baselines import logger
from baselines import logger, registry
from baselines.common.schedules import LinearSchedule
from baselines.common import set_global_seeds

@@ -18,6 +18,7 @@ from baselines.deepq.utils import ObservationInput

from baselines.common.tf_util import get_session
from baselines.deepq.models import build_q_func
from baselines.deepq.defaults import defaults


class ActWrapper(object):
@@ -92,6 +93,7 @@ def load_act(path):
    return ActWrapper.load_act(path)


@registry.register('deepq', supports_vecenv=False, defaults=defaults)
def learn(env,
          network,
          seed=None,
@@ -169,8 +171,6 @@ def learn(env,
        to 1.0. If set to None equals to total_timesteps.
    prioritized_replay_eps: float
        epsilon to add to the TD errors when updating priorities.
    param_noise: bool
        whether or not to use parameter space noise (https://arxiv.org/abs/1706.01905)
    callback: (locals, globals) -> None
        function called at every steps with state of the algorithm.
        If callback returns true training stops.

@@ -16,6 +16,8 @@ def atari():
        dueling=True
    )

def retro():
    return atari()

defaults = {
    'atari': atari(),
    'retro': atari()
}

@@ -18,11 +18,11 @@ class TfInput(object):
        """Return the tf variable(s) representing the possibly postprocessed value
        of placeholder(s).
        """
        raise NotImplementedError
        raise NotImplemented()

    def make_feed_dict(data):
        """Given data input it to the placeholder(s)."""
        raise NotImplementedError
        raise NotImplemented()


class PlaceholderTfInput(TfInput):

@@ -1,5 +1,5 @@
def mujoco():
    return dict(
defaults = {
    'mujoco': dict(
        nsteps=2048,
        nminibatches=32,
        lam=0.95,
@@ -10,13 +10,13 @@ def mujoco():
        lr=lambda f: 3e-4 * f,
        cliprange=0.2,
        value_network='copy'
    )
    ),

def atari():
    return dict(
    'atari': dict(
        nsteps=128, nminibatches=4,
        lam=0.95, gamma=0.99, noptepochs=4, log_interval=1,
        ent_coef=.01,
        lr=lambda f : f * 2.5e-4,
        cliprange=lambda f : f * 0.1,
    )
}

@@ -4,21 +4,18 @@ import functools
import numpy as np
import os.path as osp
import tensorflow as tf
from baselines import logger
from baselines import logger, registry
from collections import deque
from baselines.common import explained_variance, set_global_seeds
from baselines.common.policies import build_policy
from baselines.common.runners import AbstractEnvRunner
from baselines.common.tf_util import get_session, save_variables, load_variables
from baselines.common.mpi_adam_optimizer import MpiAdamOptimizer

try:
    from baselines.common.mpi_adam_optimizer import MpiAdamOptimizer
    from mpi4py import MPI
    from baselines.common.mpi_util import sync_from_root
except ImportError:
    MPI = None

from mpi4py import MPI
from baselines.common.tf_util import initialize
from baselines.common.mpi_util import sync_from_root
from baselines.ppo2.defaults import defaults

class Model(object):
    """
@@ -97,10 +94,7 @@ class Model(object):
        # 1. Get the model parameters
        params = tf.trainable_variables('ppo2_model')
        # 2. Build our trainer
        if MPI is not None:
            trainer = MpiAdamOptimizer(MPI.COMM_WORLD, learning_rate=LR, epsilon=1e-5)
        else:
            trainer = tf.train.AdamOptimizer(learning_rate=LR, epsilon=1e-5)
        trainer = MpiAdamOptimizer(MPI.COMM_WORLD, learning_rate=LR, epsilon=1e-5)
        # 3. Calculate the gradients
        grads_and_var = trainer.compute_gradients(loss, params)
        grads, var = zip(*grads_and_var)
@@ -143,12 +137,10 @@ class Model(object):
        self.save = functools.partial(save_variables, sess=sess)
        self.load = functools.partial(load_variables, sess=sess)

        if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
        if MPI.COMM_WORLD.Get_rank() == 0:
            initialize()
        global_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="")

        if MPI is not None:
            sync_from_root(sess, global_variables) #pylint: disable=E1101
        sync_from_root(sess, global_variables) #pylint: disable=E1101

class Runner(AbstractEnvRunner):
    """
@@ -227,6 +219,7 @@ def constfn(val):
        return val
    return f

@registry.register('ppo2', defaults=defaults)
def learn(*, network, env, total_timesteps, eval_env = None, seed=None, nsteps=2048, ent_coef=0.0, lr=3e-4,
          vf_coef=0.5, max_grad_norm=0.5, gamma=0.99, lam=0.95,
          log_interval=10, nminibatches=4, noptepochs=4, cliprange=0.2,
@@ -401,9 +394,9 @@ def learn(*, network, env, total_timesteps, eval_env = None, seed=None, nsteps=2
            logger.logkv('time_elapsed', tnow - tfirststart)
            for (lossval, lossname) in zip(lossvals, model.loss_names):
                logger.logkv(lossname, lossval)
            if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
            if MPI.COMM_WORLD.Get_rank() == 0:
                logger.dumpkvs()
        if save_interval and (update % save_interval == 0 or update == 1) and logger.get_dir() and (MPI is None or MPI.COMM_WORLD.Get_rank() == 0):
        if save_interval and (update % save_interval == 0 or update == 1) and logger.get_dir() and MPI.COMM_WORLD.Get_rank() == 0:
            checkdir = osp.join(logger.get_dir(), 'checkpoints')
            os.makedirs(checkdir, exist_ok=True)
            savepath = osp.join(checkdir, '%.5i'%update)

39  baselines/registry.py  (new file)
@@ -0,0 +1,39 @@
# Registry of algorithms that keeps track of algorithms supported environments and
# and fine-grained defaults for different kinds of environments (atari, retro, mujoco etc)
#
# Example usage:
#
# from baselines import registry
#
# @registry.register('fancy_algorithm', supports_vecenv=False)
# def learn(env, network):
#     return
#
# for algo_name, algo_entry in registry.registry.items():
#     if not algo_entry['supports_vecenv']:
#         print(f'{algo_name} does not support vecenvs')
#         # should print "fancy_algorithm does not support vecenvs" (among other ones)


from baselines import logger

registry = {}

def register(name, supports_vecenv=True, defaults={}):
    def get_fn_entrypoint(fn):
        import inspect
        return '.'.join([inspect.getmodule(fn).__name__, fn.__name__])

    def _thunk(learn_fn):
        old_entry = registry.get(name)
        if old_entry is not None:
            logger.warn('Re-registering learn function {} (old entrypoint {}, new entrypoint {}) '.format(
                name, get_fn_entrypoint(old_entry['fn']), get_fn_entrypoint(learn_fn)))

        registry[name] = dict(
            fn = learn_fn,
            supports_vecenv=supports_vecenv,
            defaults=defaults,
        )
        return learn_fn
    return _thunk

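To make the intended flow concrete: a learn function registers itself under a name, and callers look it up in the module-level registry dict. A short sketch based on the register/registry API shown above; the algorithm name and defaults here are invented for illustration:

from baselines import registry

@registry.register('my_algo', supports_vecenv=True, defaults={'atari': dict(lr=2.5e-4)})
def learn(env, network, lr=3e-4, **kwargs):
    ...

# Lookup mirrors what baselines/run.py does after this change:
entry = registry.registry['my_algo']
learn_fn = entry['fn']                               # the decorated learn function
env_defaults = entry['defaults'].get('atari', {})    # per-env-type defaults, if any
supports_vecenv = entry['supports_vecenv']
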
@@ -3,16 +3,12 @@ import multiprocessing
import os.path as osp
import gym
from collections import defaultdict
import tensorflow as tf
import numpy as np

from baselines.common.vec_env.vec_frame_stack import VecFrameStack
from baselines.common.cmd_util import common_arg_parser, parse_unknown_args, make_vec_env, make_env
from baselines.common.tf_util import get_session
from baselines import logger
from importlib import import_module

from baselines.common.vec_env.vec_normalize import VecNormalize
from baselines.common.cmd_util import common_arg_parser, parse_unknown_args, make_vec_env, make_env
from baselines import logger
from baselines.registry import registry

try:
    from mpi4py import MPI
@@ -89,28 +85,20 @@ def build_env(args):
    seed = args.seed

    env_type, env_id = get_env_type(args.env)
    assert alg in registry, 'Unknown algorithm {}'.format(alg)

    if env_type in {'atari', 'retro'}:
        if alg == 'deepq':
            env = make_env(env_id, env_type, seed=seed, wrapper_kwargs={'frame_stack': True})
        elif alg == 'trpo_mpi':
            env = make_env(env_id, env_type, seed=seed)
        else:
            frame_stack_size = 4
            env = make_vec_env(env_id, env_type, nenv, seed, gamestate=args.gamestate, reward_scale=args.reward_scale)
            env = VecFrameStack(env, frame_stack_size)

        frame_stack_size = 4
    else:
        config = tf.ConfigProto(allow_soft_placement=True,
                                intra_op_parallelism_threads=1,
                                inter_op_parallelism_threads=1)
        config.gpu_options.allow_growth = True
        get_session(config=config)
        frame_stack_size = 1

        env = make_vec_env(env_id, env_type, args.num_env or 1, seed, reward_scale=args.reward_scale)
    if registry[alg]['supports_vecenv']:
        env = make_vec_env(env_id, env_type, nenv, seed, gamestate=args.gamestate, reward_scale=args.reward_scale, frame_stack_size=frame_stack_size)
    else:
        env = make_env(env_id, env_type, seed=seed, wrapper_kwargs={'frame_stack': frame_stack_size > 1})

    if env_type == 'mujoco':
        env = VecNormalize(env)
    if env_type == 'mujoco' and registry[alg]['supports_vecenv']:
        env = VecNormalize(env)

    return env

@@ -131,35 +119,32 @@ def get_env_type(env_id):


def get_default_network(env_type):
    if env_type in {'atari', 'retro'}:
    if env_type == 'atari':
        return 'cnn'
    else:
        return 'mlp'

def get_alg_module(alg, submodule=None):
    submodule = submodule or alg
    try:
        # first try to import the alg module from baselines
        alg_module = import_module('.'.join(['baselines', alg, submodule]))
    except ImportError:
        # then from rl_algs
        alg_module = import_module('.'.join(['rl_' + 'algs', alg, submodule]))
    import inspect
    entry = registry.get(alg)
    assert entry is not None, 'Unregistered algorithm {}'.format(alg)
    module = inspect.getmodule(entry['fn']).__name__
    if submodule is not None:
        module = '.'.join([module, submodule])
    return module

    return alg_module


def get_learn_function(alg):
    return get_alg_module(alg).learn
    entry = registry.get(alg)
    assert entry is not None, 'Unregistered algorithm {}'.format(alg)
    return entry['fn']


def get_learn_function_defaults(alg, env_type):
    try:
        alg_defaults = get_alg_module(alg, 'defaults')
        kwargs = getattr(alg_defaults, env_type)()
    except (ImportError, AttributeError):
        kwargs = {}
    return kwargs

    entry = registry.get(alg)
    assert entry is not None, 'Unregistered algorithm {}'.format(alg)
    return entry['defaults'].get(env_type, {})


def parse_cmdline_kwargs(args):
@@ -193,6 +178,7 @@ def main():
    rank = MPI.COMM_WORLD.Get_rank()

    model, env = train(args, extra_args)

    env.close()

    if args.save_path is not None and rank == 0:

@@ -28,3 +28,8 @@ def mujoco():
        vf_stepsize=1e-3,
        normalize_observations=True,
    )

defaults = {
    'atari': atari(),
    'mujoco': mujoco(),
}

@@ -1,9 +1,10 @@
from baselines.common import explained_variance, zipsame, dataset
from baselines import logger
from baselines import logger, registry
import baselines.common.tf_util as U
import tensorflow as tf, numpy as np
import time
from baselines.common import colorize
from mpi4py import MPI
from collections import deque
from baselines.common import set_global_seeds
from baselines.common.mpi_adam import MpiAdam
@@ -12,10 +13,7 @@ from baselines.common.input import observation_placeholder
from baselines.common.policies import build_policy
from contextlib import contextmanager

try:
    from mpi4py import MPI
except ImportError:
    MPI = None
from baselines.trpo_mpi.defaults import defaults

def traj_segment_generator(pi, env, horizon, stochastic):
    # Initialize state variables
@@ -86,6 +84,7 @@ def add_vtarg_and_adv(seg, gamma, lam):
        gaelam[t] = lastgaelam = delta + gamma * lam * nonterminal * lastgaelam
    seg["tdlamret"] = seg["adv"] + seg["vpred"]

@registry.register('trpo_mpi', supports_vecenv=False, defaults=defaults)
def learn(*,
        network,
        env,
@@ -150,12 +149,9 @@ def learn(*,

    '''

    if MPI is not None:
        nworkers = MPI.COMM_WORLD.Get_size()
        rank = MPI.COMM_WORLD.Get_rank()
    else:
        nworkers = 1
        rank = 0

    nworkers = MPI.COMM_WORLD.Get_size()
    rank = MPI.COMM_WORLD.Get_rank()

    cpus_per_worker = 1
    U.get_session(config=tf.ConfigProto(
@@ -244,13 +240,9 @@ def learn(*,

    def allmean(x):
        assert isinstance(x, np.ndarray)
        if MPI is not None:
            out = np.empty_like(x)
            MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
            out /= nworkers
        else:
            out = np.copy(x)

        out = np.empty_like(x)
        MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
        out /= nworkers
        return out

    U.initialize()
@@ -258,9 +250,7 @@ def learn(*,
        pi.load(load_path)

    th_init = get_flat()
    if MPI is not None:
        MPI.COMM_WORLD.Bcast(th_init, root=0)

    MPI.COMM_WORLD.Bcast(th_init, root=0)
    set_from_flat(th_init)
    vfadam.sync()
    print("Init param sum", th_init.sum(), flush=True)
@@ -366,11 +356,7 @@ def learn(*,
    logger.record_tabular("ev_tdlam_before", explained_variance(vpredbefore, tdlamret))

    lrlocal = (seg["ep_lens"], seg["ep_rets"]) # local values
    if MPI is not None:
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal) # list of tuples
    else:
        listoflrpairs = [lrlocal]

    listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal) # list of tuples
    lens, rews = map(flatten_lists, zip(*listoflrpairs))
    lenbuffer.extend(lens)
    rewbuffer.extend(rews)

6  setup.py
@@ -15,9 +15,6 @@ extras = {
    ],
    'bullet': [
        'pybullet',
    ],
    'mpi': [
        'mpi4py'
    ]
}

@@ -37,6 +34,7 @@ setup(name='baselines',
        'joblib',
        'dill',
        'progressbar2',
        'mpi4py',
        'cloudpickle',
        'click',
        'opencv-python'
@@ -59,4 +57,4 @@ for tf_pkg_name in ['tensorflow', 'tensorflow-gpu']:
        pass
assert tf_pkg is not None, 'TensorFlow needed, of version above 1.4'
from distutils.version import StrictVersion
assert StrictVersion(re.sub(r'-?rc\d+$', '', tf_pkg.version)) >= StrictVersion('1.4.0')
assert StrictVersion(re.sub(r'-rc\d+$', '', tf_pkg.version)) >= StrictVersion('1.4.0')

@@ -1,16 +0,0 @@
FROM python:3.6

RUN apt-get -y update && apt-get -y install ffmpeg libopenmpi-dev
ENV CODE_DIR /root/code

COPY . $CODE_DIR/baselines
WORKDIR $CODE_DIR/baselines

# Clean up pycache and pyc files
RUN rm -rf __pycache__ && \
    find . -name "*.pyc" -delete && \
    pip install tensorflow && \
    pip install -e .[test,mpi]

CMD /bin/bash

@@ -1,16 +0,0 @@
FROM python:3.6

RUN apt-get -y update && apt-get -y install ffmpeg
ENV CODE_DIR /root/code

COPY . $CODE_DIR/baselines
WORKDIR $CODE_DIR/baselines

# Clean up pycache and pyc files
RUN rm -rf __pycache__ && \
    find . -name "*.pyc" -delete && \
    pip install tensorflow && \
    pip install -e .[test]

CMD /bin/bash