store session at policy creation time (#655)

* sync internal changes; make ddpg work with VecEnvs

* B -> nenvs for consistency with other algos, small cleanups

* eval_done[d]==True -> eval_done[d]

* fix flake8 warnings and the numpy.random.random_integers deprecation warning (see the sketch after this list)

* store session at policy creation time

* coexistence tests

* fix a typo

* autopep8

* ... and flake8

* updated todo links in test_serialization
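
For reference on the numpy deprecation fixed above: np.random.random_integers samples from a closed interval [low, high] and is deprecated in favor of the half-open np.random.randint. A minimal sketch of the replacement pattern (the affected call sites are not shown in this commit view):

import numpy as np

n = 10
# Deprecated: uniform over [0, n - 1], both endpoints inclusive.
# idx = np.random.random_integers(0, n - 1)

# Replacement: randint draws from the half-open interval [low, high),
# so the same range is written with an exclusive upper bound.
idx = np.random.randint(0, n)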
Author: pzhokhov
Committed by: GitHub
Date: 2018-10-19 08:54:21 -07:00
Commit: d0cc325e14
Parent: fc7f9cec49
3 changed files with 40 additions and 4 deletions

baselines/common/policies.py

@@ -53,7 +53,7 @@ class PolicyWithValue(object):
         # Calculate the neg log of our probability
         self.neglogp = self.pd.neglogp(self.action)
-        self.sess = sess
+        self.sess = sess or tf.get_default_session()

         if estimate_q:
             assert isinstance(env.action_space, gym.spaces.Discrete)
@@ -64,7 +64,7 @@ class PolicyWithValue(object):
             self.vf = self.vf[:,0]

     def _evaluate(self, variables, observation, **extra_feed):
-        sess = self.sess or tf.get_default_session()
+        sess = self.sess
         feed_dict = {self.X: adjust_shape(self.X, observation)}
         for inpt_name, data in extra_feed.items():
             if inpt_name in self.__dict__.keys():
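
Why resolving the session at creation time matters, as a toy sketch (make_policy below is a hypothetical stand-in for the PolicyWithValue constructor, not baselines code):

import tensorflow as tf

def make_policy(sess=None):
    # Mirrors the diff: resolve the session once, when the policy is built.
    return {'sess': sess or tf.get_default_session()}

g1, g2 = tf.Graph(), tf.Graph()
sess1, sess2 = tf.Session(graph=g1), tf.Session(graph=g2)

with sess1.as_default():
    policy1 = make_policy()
with sess2.as_default():
    policy2 = make_policy()

# Each policy keeps the session it was created under. Resolving
# tf.get_default_session() inside _evaluate instead (the old behavior)
# would return None here, because no session is default anymore.
assert policy1['sess'] is sess1
assert policy2['sess'] is sess2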

baselines/common/tests/envs/mnist_env.py

@@ -1,7 +1,6 @@
 import os.path as osp
 import numpy as np
 import tempfile
-import filelock
 from gym import Env
 from gym.spaces import Discrete, Box
@@ -14,6 +13,7 @@ class MnistEnv(Env):
             episode_len=None,
             no_images=None
     ):
+        import filelock
         from tensorflow.examples.tutorials.mnist import input_data
         # we could use a temporary directory for this with a context manager and
         # TemporaryDirectory, but then each test that uses mnist would re-download the data
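
Moving import filelock from module scope into __init__ means the module imports cleanly even when filelock is not installed; the dependency is needed only when a MnistEnv is actually constructed. A minimal sketch of the deferred-import pattern (the class name and lock path are illustrative, not from the commit):

class LazyDepEnv:
    def __init__(self, lock_path='/tmp/mnist.lock'):
        # Deferred import: an ImportError can only occur when an
        # instance is created, not when this module is imported.
        import filelock
        self._lock = filelock.FileLock(lock_path)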

baselines/common/tests/test_serialization.py

@@ -1,4 +1,5 @@
 import os
+import gym
 import tempfile
 import pytest
 import tensorflow as tf
@@ -39,7 +40,7 @@ def test_serialization(learn_fn, network_fn):
     if network_fn.endswith('lstm') and learn_fn in ['acktr', 'trpo_mpi', 'deepq']:
         # TODO make acktr work with recurrent policies
         # and test
-        # github issue: https://github.com/openai/baselines/issues/194
+        # github issue: https://github.com/openai/baselines/issues/660
         return

     env = DummyVecEnv([lambda: MnistEnv(10, episode_len=100)])
@@ -75,6 +76,41 @@ def test_serialization(learn_fn, network_fn):
     np.testing.assert_allclose(std1, std2, atol=0.5)

+@pytest.mark.parametrize("learn_fn", learn_kwargs.keys())
+@pytest.mark.parametrize("network_fn", ['mlp'])
+def test_coexistence(learn_fn, network_fn):
+    '''
+    Test if more than one model can exist at a time
+    '''
+
+    if learn_fn == 'deepq':
+        # TODO enable multiple DQN models to be usable at the same time
+        # github issue https://github.com/openai/baselines/issues/656
+        return
+
+    if network_fn.endswith('lstm') and learn_fn in ['acktr', 'trpo_mpi', 'deepq']:
+        # TODO make acktr work with recurrent policies
+        # and test
+        # github issue: https://github.com/openai/baselines/issues/660
+        return
+
+    env = DummyVecEnv([lambda: gym.make('CartPole-v0')])
+    learn = get_learn_function(learn_fn)
+
+    kwargs = {}
+    kwargs.update(network_kwargs[network_fn])
+    kwargs.update(learn_kwargs[learn_fn])
+
+    learn = partial(learn, env=env, network=network_fn, total_timesteps=0, **kwargs)
+    make_session(make_default=True, graph=tf.Graph())
+    model1 = learn(seed=1)
+    make_session(make_default=True, graph=tf.Graph())
+    model2 = learn(seed=2)
+
+    model1.step(env.observation_space.sample())
+    model2.step(env.observation_space.sample())
+
 def _serialize_variables():
     sess = get_session()
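
What the new coexistence test exercises, as a user-level sketch: each learn call builds its model under a fresh default session over a fresh graph, and because policies now capture that session at creation time, both models can still act afterwards. The imports and the ppo2.learn call below are illustrative assumptions matching baselines of this era, not part of the diff:

import gym
import tensorflow as tf
from baselines.common.tf_util import make_session
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
from baselines.ppo2 import ppo2

env = DummyVecEnv([lambda: gym.make('CartPole-v0')])

# Build each model under its own graph and default session.
make_session(make_default=True, graph=tf.Graph())
model1 = ppo2.learn(env=env, network='mlp', total_timesteps=0)
make_session(make_default=True, graph=tf.Graph())
model2 = ppo2.learn(env=env, network='mlp', total_timesteps=0)

# model1's session is no longer the default, yet it can still step,
# because its policy captured the session at creation time.
obs = env.reset()
model1.step(obs)
model2.step(obs)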