From d0cc325e1414b56bcc6cbd90bbc1778d49a3e950 Mon Sep 17 00:00:00 2001 From: pzhokhov Date: Fri, 19 Oct 2018 08:54:21 -0700 Subject: [PATCH] store session at policy creation time (#655) * sync internal changes. Make ddpg work with vecenvs * B -> nenvs for consistency with other algos, small cleanups * eval_done[d]==True -> eval_done[d] * flake8 and numpy.random.random_integers deprecation warning * store session at policy creation time * coexistence tests * fix a typo * autopep8 * ... and flake8 * updated todo links in test_serialization --- baselines/common/policies.py | 4 +-- baselines/common/tests/envs/mnist_env.py | 2 +- baselines/common/tests/test_serialization.py | 38 +++++++++++++++++++- 3 files changed, 40 insertions(+), 4 deletions(-) diff --git a/baselines/common/policies.py b/baselines/common/policies.py index eeac242..9c9bb8b 100644 --- a/baselines/common/policies.py +++ b/baselines/common/policies.py @@ -53,7 +53,7 @@ class PolicyWithValue(object): # Calculate the neg log of our probability self.neglogp = self.pd.neglogp(self.action) - self.sess = sess + self.sess = sess or tf.get_default_session() if estimate_q: assert isinstance(env.action_space, gym.spaces.Discrete) @@ -64,7 +64,7 @@ class PolicyWithValue(object): self.vf = self.vf[:,0] def _evaluate(self, variables, observation, **extra_feed): - sess = self.sess or tf.get_default_session() + sess = self.sess feed_dict = {self.X: adjust_shape(self.X, observation)} for inpt_name, data in extra_feed.items(): if inpt_name in self.__dict__.keys(): diff --git a/baselines/common/tests/envs/mnist_env.py b/baselines/common/tests/envs/mnist_env.py index 4f73495..473008d 100644 --- a/baselines/common/tests/envs/mnist_env.py +++ b/baselines/common/tests/envs/mnist_env.py @@ -1,7 +1,6 @@ import os.path as osp import numpy as np import tempfile -import filelock from gym import Env from gym.spaces import Discrete, Box @@ -14,6 +13,7 @@ class MnistEnv(Env): episode_len=None, no_images=None ): + import filelock from tensorflow.examples.tutorials.mnist import input_data # we could use temporary directory for this with a context manager and # TemporaryDirecotry, but then each test that uses mnist would re-download the data diff --git a/baselines/common/tests/test_serialization.py b/baselines/common/tests/test_serialization.py index 4086f2b..f46b578 100644 --- a/baselines/common/tests/test_serialization.py +++ b/baselines/common/tests/test_serialization.py @@ -1,4 +1,5 @@ import os +import gym import tempfile import pytest import tensorflow as tf @@ -39,7 +40,7 @@ def test_serialization(learn_fn, network_fn): if network_fn.endswith('lstm') and learn_fn in ['acktr', 'trpo_mpi', 'deepq']: # TODO make acktr work with recurrent policies # and test - # github issue: https://github.com/openai/baselines/issues/194 + # github issue: https://github.com/openai/baselines/issues/660 return env = DummyVecEnv([lambda: MnistEnv(10, episode_len=100)]) @@ -75,6 +76,41 @@ def test_serialization(learn_fn, network_fn): np.testing.assert_allclose(std1, std2, atol=0.5) +@pytest.mark.parametrize("learn_fn", learn_kwargs.keys()) +@pytest.mark.parametrize("network_fn", ['mlp']) +def test_coexistence(learn_fn, network_fn): + ''' + Test if more than one model can exist at a time + ''' + + if learn_fn == 'deepq': + # TODO enable multiple DQN models to be useable at the same time + # github issue https://github.com/openai/baselines/issues/656 + return + + if network_fn.endswith('lstm') and learn_fn in ['acktr', 'trpo_mpi', 'deepq']: + # TODO make acktr work with recurrent policies + # and test + # github issue: https://github.com/openai/baselines/issues/660 + return + + env = DummyVecEnv([lambda: gym.make('CartPole-v0')]) + learn = get_learn_function(learn_fn) + + kwargs = {} + kwargs.update(network_kwargs[network_fn]) + kwargs.update(learn_kwargs[learn_fn]) + + learn = partial(learn, env=env, network=network_fn, total_timesteps=0, **kwargs) + make_session(make_default=True, graph=tf.Graph()); + model1 = learn(seed=1) + make_session(make_default=True, graph=tf.Graph()); + model2 = learn(seed=2) + + model1.step(env.observation_space.sample()) + model2.step(env.observation_space.sample()) + + def _serialize_variables(): sess = get_session()