diff --git a/baselines/common/policies.py b/baselines/common/policies.py
index eeac242..9c9bb8b 100644
--- a/baselines/common/policies.py
+++ b/baselines/common/policies.py
@@ -53,7 +53,7 @@ class PolicyWithValue(object):
 
         # Calculate the neg log of our probability
         self.neglogp = self.pd.neglogp(self.action)
-        self.sess = sess
+        self.sess = sess or tf.get_default_session()
 
         if estimate_q:
             assert isinstance(env.action_space, gym.spaces.Discrete)
@@ -64,7 +64,7 @@ class PolicyWithValue(object):
             self.vf = self.vf[:,0]
 
     def _evaluate(self, variables, observation, **extra_feed):
-        sess = self.sess or tf.get_default_session()
+        sess = self.sess
         feed_dict = {self.X: adjust_shape(self.X, observation)}
         for inpt_name, data in extra_feed.items():
             if inpt_name in self.__dict__.keys():
diff --git a/baselines/common/tests/envs/mnist_env.py b/baselines/common/tests/envs/mnist_env.py
index 4f73495..473008d 100644
--- a/baselines/common/tests/envs/mnist_env.py
+++ b/baselines/common/tests/envs/mnist_env.py
@@ -1,7 +1,6 @@
 import os.path as osp
 import numpy as np
 import tempfile
-import filelock
 from gym import Env
 from gym.spaces import Discrete, Box
 
@@ -14,6 +13,7 @@ class MnistEnv(Env):
             episode_len=None,
             no_images=None
     ):
+        import filelock
        from tensorflow.examples.tutorials.mnist import input_data
        # we could use temporary directory for this with a context manager and
        # TemporaryDirecotry, but then each test that uses mnist would re-download the data
diff --git a/baselines/common/tests/test_serialization.py b/baselines/common/tests/test_serialization.py
index 4086f2b..f46b578 100644
--- a/baselines/common/tests/test_serialization.py
+++ b/baselines/common/tests/test_serialization.py
@@ -1,4 +1,5 @@
 import os
+import gym
 import tempfile
 import pytest
 import tensorflow as tf
@@ -39,7 +40,7 @@ def test_serialization(learn_fn, network_fn):
     if network_fn.endswith('lstm') and learn_fn in ['acktr', 'trpo_mpi', 'deepq']:
         # TODO make acktr work with recurrent policies
         # and test
-        # github issue: https://github.com/openai/baselines/issues/194
+        # github issue: https://github.com/openai/baselines/issues/660
         return
 
     env = DummyVecEnv([lambda: MnistEnv(10, episode_len=100)])
@@ -75,6 +76,41 @@ def test_serialization(learn_fn, network_fn):
     np.testing.assert_allclose(std1, std2, atol=0.5)
 
 
+@pytest.mark.parametrize("learn_fn", learn_kwargs.keys())
+@pytest.mark.parametrize("network_fn", ['mlp'])
+def test_coexistence(learn_fn, network_fn):
+    '''
+    Test if more than one model can exist at a time
+    '''
+
+    if learn_fn == 'deepq':
+        # TODO enable multiple DQN models to be useable at the same time
+        # github issue https://github.com/openai/baselines/issues/656
+        return
+
+    if network_fn.endswith('lstm') and learn_fn in ['acktr', 'trpo_mpi', 'deepq']:
+        # TODO make acktr work with recurrent policies
+        # and test
+        # github issue: https://github.com/openai/baselines/issues/660
+        return
+
+    env = DummyVecEnv([lambda: gym.make('CartPole-v0')])
+    learn = get_learn_function(learn_fn)
+
+    kwargs = {}
+    kwargs.update(network_kwargs[network_fn])
+    kwargs.update(learn_kwargs[learn_fn])
+
+    learn = partial(learn, env=env, network=network_fn, total_timesteps=0, **kwargs)
+    make_session(make_default=True, graph=tf.Graph());
+    model1 = learn(seed=1)
+    make_session(make_default=True, graph=tf.Graph());
+    model2 = learn(seed=2)
+
+    model1.step(env.observation_space.sample())
+    model2.step(env.observation_space.sample())
+
+
 def _serialize_variables():
     sess = get_session()
diff --git a/baselines/ddpg/README.md b/baselines/ddpg/README.md
index 6e936dd..ed6d23f 100755
--- a/baselines/ddpg/README.md
+++ b/baselines/ddpg/README.md
@@ -2,4 +2,4 @@
 
 - Original paper: https://arxiv.org/abs/1509.02971
 - Baselines post: https://blog.openai.com/better-exploration-with-parameter-noise/
-- `python -m baselines.ddpg.main` runs the algorithm for 1M frames = 10M timesteps on a Mujoco environment. See help (`-h`) for more options.
\ No newline at end of file
+- `python -m baselines.run --alg=ddpg --env=HalfCheetah-v2 --num_timesteps=1e6` runs the algorithm for 1M frames = 10M timesteps on a Mujoco environment. See help (`-h`) for more options.
diff --git a/baselines/deepq/deepq.py b/baselines/deepq/deepq.py
index 47fe19a..c6004b2 100644
--- a/baselines/deepq/deepq.py
+++ b/baselines/deepq/deepq.py
@@ -47,6 +47,9 @@ class ActWrapper(object):
         return self._act(*args, **kwargs)
 
     def step(self, observation, **kwargs):
+        # DQN doesn't use RNNs so we ignore states and masks
+        kwargs.pop('S', None)
+        kwargs.pop('M', None)
         return self._act([observation], **kwargs), None, None, None
 
     def save_act(self, path=None):
diff --git a/baselines/run.py b/baselines/run.py
index 5dee154..8ab71ac 100644
--- a/baselines/run.py
+++ b/baselines/run.py
@@ -121,9 +121,11 @@ def build_env(args):
            env = retro_wrappers.wrap_deepmind_retro(env)
 
     else:
-        get_session(tf.ConfigProto(allow_soft_placement=True,
-                                   intra_op_parallelism_threads=1,
-                                   inter_op_parallelism_threads=1))
+        config = tf.ConfigProto(allow_soft_placement=True,
+                                intra_op_parallelism_threads=1,
+                                inter_op_parallelism_threads=1)
+        config.gpu_options.allow_growth = True
+        get_session(config=config)
 
         env = make_vec_env(env_id, env_type, args.num_env or 1, seed, reward_scale=args.reward_scale)
 
diff --git a/setup.py b/setup.py
index 5ec1fce..726c6a3 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,8 @@ if sys.version_info.major != 3:
 extras = {
     'test': [
         'filelock',
-        'pytest'
+        'pytest',
+        'atari-py'
     ],
     'bullet': [
         'pybullet',
@@ -27,7 +28,7 @@ setup(name='baselines',
       packages=[package for package in find_packages()
                if package.startswith('baselines')],
       install_requires=[
-          'gym[mujoco,atari,classic_control,robotics]',
+          'gym',
           'scipy',
           'tqdm',
           'joblib',