store session at policy creation time (#655)
* sync internal changes. Make ddpg work with vecenvs * B -> nenvs for consistency with other algos, small cleanups * eval_done[d]==True -> eval_done[d] * flake8 and numpy.random.random_integers deprecation warning * store session at policy creation time * coexistence tests * fix a typo * autopep8 * ... and flake8 * updated todo links in test_serialization
This commit is contained in:
@@ -53,7 +53,7 @@ class PolicyWithValue(object):
|
|||||||
|
|
||||||
# Calculate the neg log of our probability
|
# Calculate the neg log of our probability
|
||||||
self.neglogp = self.pd.neglogp(self.action)
|
self.neglogp = self.pd.neglogp(self.action)
|
||||||
self.sess = sess
|
self.sess = sess or tf.get_default_session()
|
||||||
|
|
||||||
if estimate_q:
|
if estimate_q:
|
||||||
assert isinstance(env.action_space, gym.spaces.Discrete)
|
assert isinstance(env.action_space, gym.spaces.Discrete)
|
||||||
@@ -64,7 +64,7 @@ class PolicyWithValue(object):
|
|||||||
self.vf = self.vf[:,0]
|
self.vf = self.vf[:,0]
|
||||||
|
|
||||||
def _evaluate(self, variables, observation, **extra_feed):
|
def _evaluate(self, variables, observation, **extra_feed):
|
||||||
sess = self.sess or tf.get_default_session()
|
sess = self.sess
|
||||||
feed_dict = {self.X: adjust_shape(self.X, observation)}
|
feed_dict = {self.X: adjust_shape(self.X, observation)}
|
||||||
for inpt_name, data in extra_feed.items():
|
for inpt_name, data in extra_feed.items():
|
||||||
if inpt_name in self.__dict__.keys():
|
if inpt_name in self.__dict__.keys():
|
||||||
|
@@ -1,7 +1,6 @@
|
|||||||
import os.path as osp
|
import os.path as osp
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tempfile
|
import tempfile
|
||||||
import filelock
|
|
||||||
from gym import Env
|
from gym import Env
|
||||||
from gym.spaces import Discrete, Box
|
from gym.spaces import Discrete, Box
|
||||||
|
|
||||||
@@ -14,6 +13,7 @@ class MnistEnv(Env):
|
|||||||
episode_len=None,
|
episode_len=None,
|
||||||
no_images=None
|
no_images=None
|
||||||
):
|
):
|
||||||
|
import filelock
|
||||||
from tensorflow.examples.tutorials.mnist import input_data
|
from tensorflow.examples.tutorials.mnist import input_data
|
||||||
# we could use temporary directory for this with a context manager and
|
# we could use temporary directory for this with a context manager and
|
||||||
# TemporaryDirecotry, but then each test that uses mnist would re-download the data
|
# TemporaryDirecotry, but then each test that uses mnist would re-download the data
|
||||||
|
@@ -1,4 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
|
import gym
|
||||||
import tempfile
|
import tempfile
|
||||||
import pytest
|
import pytest
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
@@ -39,7 +40,7 @@ def test_serialization(learn_fn, network_fn):
|
|||||||
if network_fn.endswith('lstm') and learn_fn in ['acktr', 'trpo_mpi', 'deepq']:
|
if network_fn.endswith('lstm') and learn_fn in ['acktr', 'trpo_mpi', 'deepq']:
|
||||||
# TODO make acktr work with recurrent policies
|
# TODO make acktr work with recurrent policies
|
||||||
# and test
|
# and test
|
||||||
# github issue: https://github.com/openai/baselines/issues/194
|
# github issue: https://github.com/openai/baselines/issues/660
|
||||||
return
|
return
|
||||||
|
|
||||||
env = DummyVecEnv([lambda: MnistEnv(10, episode_len=100)])
|
env = DummyVecEnv([lambda: MnistEnv(10, episode_len=100)])
|
||||||
@@ -75,6 +76,41 @@ def test_serialization(learn_fn, network_fn):
|
|||||||
np.testing.assert_allclose(std1, std2, atol=0.5)
|
np.testing.assert_allclose(std1, std2, atol=0.5)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("learn_fn", learn_kwargs.keys())
|
||||||
|
@pytest.mark.parametrize("network_fn", ['mlp'])
|
||||||
|
def test_coexistence(learn_fn, network_fn):
|
||||||
|
'''
|
||||||
|
Test if more than one model can exist at a time
|
||||||
|
'''
|
||||||
|
|
||||||
|
if learn_fn == 'deepq':
|
||||||
|
# TODO enable multiple DQN models to be useable at the same time
|
||||||
|
# github issue https://github.com/openai/baselines/issues/656
|
||||||
|
return
|
||||||
|
|
||||||
|
if network_fn.endswith('lstm') and learn_fn in ['acktr', 'trpo_mpi', 'deepq']:
|
||||||
|
# TODO make acktr work with recurrent policies
|
||||||
|
# and test
|
||||||
|
# github issue: https://github.com/openai/baselines/issues/660
|
||||||
|
return
|
||||||
|
|
||||||
|
env = DummyVecEnv([lambda: gym.make('CartPole-v0')])
|
||||||
|
learn = get_learn_function(learn_fn)
|
||||||
|
|
||||||
|
kwargs = {}
|
||||||
|
kwargs.update(network_kwargs[network_fn])
|
||||||
|
kwargs.update(learn_kwargs[learn_fn])
|
||||||
|
|
||||||
|
learn = partial(learn, env=env, network=network_fn, total_timesteps=0, **kwargs)
|
||||||
|
make_session(make_default=True, graph=tf.Graph());
|
||||||
|
model1 = learn(seed=1)
|
||||||
|
make_session(make_default=True, graph=tf.Graph());
|
||||||
|
model2 = learn(seed=2)
|
||||||
|
|
||||||
|
model1.step(env.observation_space.sample())
|
||||||
|
model2.step(env.observation_space.sample())
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _serialize_variables():
|
def _serialize_variables():
|
||||||
sess = get_session()
|
sess = get_session()
|
||||||
|
Reference in New Issue
Block a user