store session at policy creation time (#655)

* sync internal changes. Make ddpg work with vecenvs

* B -> nenvs for consistency with other algos, small cleanups

* eval_done[d]==True -> eval_done[d]

* fix flake8 complaints and the numpy.random.random_integers deprecation warning

* store session at policy creation time

* coexistence tests

* fix a typo

* autopep8

* ... and flake8

* updated todo links in test_serialization

Author: pzhokhov
Date: 2018-10-19 08:54:21 -07:00
Committed by: GitHub
Parent: fc7f9cec49
Commit: d0cc325e14

3 changed files with 40 additions and 4 deletions
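
The core change: a policy's TensorFlow session is now resolved once, when the policy object is created, rather than looked up from the current default session on every evaluation. A minimal sketch of the pattern, with illustrative names (TinyPolicy is not the actual baselines class):

    import numpy as np
    import tensorflow as tf

    class TinyPolicy:
        def __init__(self, sess=None):
            self.x = tf.placeholder(tf.float32, [None, 2])
            self.out = tf.layers.dense(self.x, 1)
            # Resolve the session once, at creation time, falling back
            # to the default session only if none is passed explicitly.
            self.sess = sess or tf.get_default_session()

        def step(self, obs):
            # Reuses the stored session; does not depend on whichever
            # session happens to be the default at call time.
            return self.sess.run(self.out, {self.x: obs})

    graph = tf.Graph()
    with graph.as_default(), tf.Session(graph=graph).as_default() as sess:
        policy = TinyPolicy()
        sess.run(tf.global_variables_initializer())

    # Still works here, even though sess is no longer the default session.
    print(policy.step(np.zeros((1, 2))))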

baselines/common/policies.py

@@ -53,7 +53,7 @@ class PolicyWithValue(object):
         # Calculate the neg log of our probability
         self.neglogp = self.pd.neglogp(self.action)
-        self.sess = sess
+        self.sess = sess or tf.get_default_session()

         if estimate_q:
             assert isinstance(env.action_space, gym.spaces.Discrete)
@@ -64,7 +64,7 @@ class PolicyWithValue(object):
             self.vf = self.vf[:,0]

     def _evaluate(self, variables, observation, **extra_feed):
-        sess = self.sess or tf.get_default_session()
+        sess = self.sess
         feed_dict = {self.X: adjust_shape(self.X, observation)}
         for inpt_name, data in extra_feed.items():
             if inpt_name in self.__dict__.keys():
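
Taken together, these two hunks move the tf.get_default_session() fallback from every _evaluate call to the constructor. The failure mode that call-time lookup creates, sketched with illustrative names (not baselines code):

    import tensorflow as tf

    def build_model():
        # Each model lives in its own graph with its own session.
        graph = tf.Graph()
        with graph.as_default():
            x = tf.placeholder(tf.float32, [None, 1])
            y = 2.0 * x
            sess = tf.Session(graph=graph)
        return sess, x, y

    sess1, x1, y1 = build_model()
    sess2, x2, y2 = build_model()

    # With call-time lookup, model 1 would fetch tf.get_default_session(),
    # which at best belongs to the most recently built model and cannot
    # run y1 (a tensor from a different graph). Running each model in the
    # session captured for it works:
    print(sess1.run(y1, {x1: [[1.0]]}))   # [[2.]]
    print(sess2.run(y2, {x2: [[3.0]]}))   # [[6.]]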

baselines/common/tests/envs/mnist_env.py

@@ -1,7 +1,6 @@
 import os.path as osp
 import numpy as np
 import tempfile
-import filelock
 from gym import Env
 from gym.spaces import Discrete, Box
@@ -14,6 +13,7 @@ class MnistEnv(Env):
             episode_len=None,
             no_images=None
     ):
+        import filelock
         from tensorflow.examples.tutorials.mnist import input_data

         # we could use a temporary directory for this with a context manager and
         # TemporaryDirectory, but then each test that uses mnist would re-download the data
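
The mnist_env change is a lazy import: filelock moves from module scope into the constructor, so merely importing the test helpers no longer requires the optional dependency. A generic sketch of the pattern (the helper name and path handling are illustrative, not the actual MnistEnv code):

    def _read_mnist(data_dir):
        # Import optional dependencies only on the code path that uses
        # them; importing this module alone stays dependency-free.
        import filelock
        from tensorflow.examples.tutorials.mnist import input_data

        # Serialize the download across parallel test processes.
        with filelock.FileLock(data_dir + '.lock'):
            return input_data.read_data_sets(data_dir)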

baselines/common/tests/test_serialization.py

@@ -1,4 +1,5 @@
 import os
+import gym
 import tempfile
 import pytest
 import tensorflow as tf
@@ -39,7 +40,7 @@ def test_serialization(learn_fn, network_fn):
     if network_fn.endswith('lstm') and learn_fn in ['acktr', 'trpo_mpi', 'deepq']:
         # TODO make acktr work with recurrent policies
         # and test
-        # github issue: https://github.com/openai/baselines/issues/194
+        # github issue: https://github.com/openai/baselines/issues/660
         return

     env = DummyVecEnv([lambda: MnistEnv(10, episode_len=100)])
@@ -75,6 +76,41 @@ def test_serialization(learn_fn, network_fn):
     np.testing.assert_allclose(std1, std2, atol=0.5)


+@pytest.mark.parametrize("learn_fn", learn_kwargs.keys())
+@pytest.mark.parametrize("network_fn", ['mlp'])
+def test_coexistence(learn_fn, network_fn):
+    '''
+    Test if more than one model can exist at a time
+    '''
+
+    if learn_fn == 'deepq':
+        # TODO enable multiple DQN models to be usable at the same time
+        # github issue https://github.com/openai/baselines/issues/656
+        return
+
+    if network_fn.endswith('lstm') and learn_fn in ['acktr', 'trpo_mpi', 'deepq']:
+        # TODO make acktr work with recurrent policies
+        # and test
+        # github issue: https://github.com/openai/baselines/issues/660
+        return
+
+    env = DummyVecEnv([lambda: gym.make('CartPole-v0')])
+    learn = get_learn_function(learn_fn)
+
+    kwargs = {}
+    kwargs.update(network_kwargs[network_fn])
+    kwargs.update(learn_kwargs[learn_fn])
+
+    learn = partial(learn, env=env, network=network_fn, total_timesteps=0, **kwargs)
+    make_session(make_default=True, graph=tf.Graph())
+    model1 = learn(seed=1)
+    make_session(make_default=True, graph=tf.Graph())
+    model2 = learn(seed=2)
+
+    model1.step(env.observation_space.sample())
+    model2.step(env.observation_space.sample())
+
+
 def _serialize_variables():
     sess = get_session()
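
A note on why test_coexistence exercises the policies.py change: make_session(make_default=True, graph=tf.Graph()) (presumably imported from baselines.common.tf_util) gives each learn call a fresh graph and default session, so by the time both step calls run, only model2's session is current. model1.step(...) succeeds anyway because its policy captured its own session at creation time, which is exactly what the first diff in this commit guarantees.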