store session at policy creation time (#655)
* sync internal changes. Make ddpg work with vecenvs * B -> nenvs for consistency with other algos, small cleanups * eval_done[d]==True -> eval_done[d] * flake8 and numpy.random.random_integers deprecation warning * store session at policy creation time * coexistence tests * fix a typo * autopep8 * ... and flake8 * updated todo links in test_serialization
This commit is contained in:
@@ -53,7 +53,7 @@ class PolicyWithValue(object):
|
||||
|
||||
# Calculate the neg log of our probability
|
||||
self.neglogp = self.pd.neglogp(self.action)
|
||||
self.sess = sess
|
||||
self.sess = sess or tf.get_default_session()
|
||||
|
||||
if estimate_q:
|
||||
assert isinstance(env.action_space, gym.spaces.Discrete)
|
||||
@@ -64,7 +64,7 @@ class PolicyWithValue(object):
|
||||
self.vf = self.vf[:,0]
|
||||
|
||||
def _evaluate(self, variables, observation, **extra_feed):
|
||||
sess = self.sess or tf.get_default_session()
|
||||
sess = self.sess
|
||||
feed_dict = {self.X: adjust_shape(self.X, observation)}
|
||||
for inpt_name, data in extra_feed.items():
|
||||
if inpt_name in self.__dict__.keys():
|
||||
|
@@ -1,7 +1,6 @@
|
||||
import os.path as osp
|
||||
import numpy as np
|
||||
import tempfile
|
||||
import filelock
|
||||
from gym import Env
|
||||
from gym.spaces import Discrete, Box
|
||||
|
||||
@@ -14,6 +13,7 @@ class MnistEnv(Env):
|
||||
episode_len=None,
|
||||
no_images=None
|
||||
):
|
||||
import filelock
|
||||
from tensorflow.examples.tutorials.mnist import input_data
|
||||
# we could use temporary directory for this with a context manager and
|
||||
# TemporaryDirecotry, but then each test that uses mnist would re-download the data
|
||||
|
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
import gym
|
||||
import tempfile
|
||||
import pytest
|
||||
import tensorflow as tf
|
||||
@@ -39,7 +40,7 @@ def test_serialization(learn_fn, network_fn):
|
||||
if network_fn.endswith('lstm') and learn_fn in ['acktr', 'trpo_mpi', 'deepq']:
|
||||
# TODO make acktr work with recurrent policies
|
||||
# and test
|
||||
# github issue: https://github.com/openai/baselines/issues/194
|
||||
# github issue: https://github.com/openai/baselines/issues/660
|
||||
return
|
||||
|
||||
env = DummyVecEnv([lambda: MnistEnv(10, episode_len=100)])
|
||||
@@ -75,6 +76,41 @@ def test_serialization(learn_fn, network_fn):
|
||||
np.testing.assert_allclose(std1, std2, atol=0.5)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("learn_fn", learn_kwargs.keys())
|
||||
@pytest.mark.parametrize("network_fn", ['mlp'])
|
||||
def test_coexistence(learn_fn, network_fn):
|
||||
'''
|
||||
Test if more than one model can exist at a time
|
||||
'''
|
||||
|
||||
if learn_fn == 'deepq':
|
||||
# TODO enable multiple DQN models to be useable at the same time
|
||||
# github issue https://github.com/openai/baselines/issues/656
|
||||
return
|
||||
|
||||
if network_fn.endswith('lstm') and learn_fn in ['acktr', 'trpo_mpi', 'deepq']:
|
||||
# TODO make acktr work with recurrent policies
|
||||
# and test
|
||||
# github issue: https://github.com/openai/baselines/issues/660
|
||||
return
|
||||
|
||||
env = DummyVecEnv([lambda: gym.make('CartPole-v0')])
|
||||
learn = get_learn_function(learn_fn)
|
||||
|
||||
kwargs = {}
|
||||
kwargs.update(network_kwargs[network_fn])
|
||||
kwargs.update(learn_kwargs[learn_fn])
|
||||
|
||||
learn = partial(learn, env=env, network=network_fn, total_timesteps=0, **kwargs)
|
||||
make_session(make_default=True, graph=tf.Graph());
|
||||
model1 = learn(seed=1)
|
||||
make_session(make_default=True, graph=tf.Graph());
|
||||
model2 = learn(seed=2)
|
||||
|
||||
model1.step(env.observation_space.sample())
|
||||
model2.step(env.observation_space.sample())
|
||||
|
||||
|
||||
|
||||
def _serialize_variables():
|
||||
sess = get_session()
|
||||
|
Reference in New Issue
Block a user