make ppo2 rnn test available.
baselines/common/tests/test_fixed_sequence.py

@@ -17,10 +17,10 @@ learn_kwargs = {
     # 'trpo_mpi': lambda e, p: trpo_mpi.learn(policy_fn=p(env=e), env=e, max_timesteps=30000, timesteps_per_batch=100, cg_iters=10, gamma=0.9, lam=1.0, max_kl=0.001)
 }
 
 
 alg_list = learn_kwargs.keys()
 rnn_list = ['lstm']
 
 @pytest.mark.slow
 @pytest.mark.parametrize("alg", alg_list)
 @pytest.mark.parametrize("rnn", rnn_list)
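The two stacked parametrize decorators above run test_fixed_sequence once per (alg, rnn) pair. A minimal sketch of the resulting matrix (not part of the commit), assuming learn_kwargs carries the usual a2c and ppo2 entries that this hunk does not show:

    # Sketch: the test grid produced by stacking two @pytest.mark.parametrize
    # decorators is the cartesian product of the two parameter lists.
    import itertools

    alg_list = ['a2c', 'ppo2']  # assumed: the keys of learn_kwargs
    rnn_list = ['lstm']

    for alg, rnn in itertools.product(alg_list, rnn_list):
        print('test_fixed_sequence[{}-{}]'.format(rnn, alg))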
@@ -33,6 +33,9 @@ def test_fixed_sequence(alg, rnn):
     kwargs = learn_kwargs[alg]
     kwargs.update(common_kwargs)
 
+    if alg == 'ppo2':
+        rnn = 'ppo_' + rnn
+
     env_fn = lambda: FixedSequenceEnv(n_actions=10, episode_len=5)
     learn = lambda e: get_learn_function(alg)(
         env=e,
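The remapping added here is the core of the change. A standalone sketch of what it does (not part of the commit), assuming this branch registers ppo2's recurrent policies under a 'ppo_' prefix ('ppo_lstm') while the other algorithms keep the plain names:

    # Sketch of the network-name remapping performed inside the test.
    def resolve_network_name(alg, rnn):
        if alg == 'ppo2':
            return 'ppo_' + rnn  # ppo2 resolves its own recurrent policy
        return rnn               # everyone else uses the shared name

    assert resolve_network_name('ppo2', 'lstm') == 'ppo_lstm'
    assert resolve_network_name('a2c', 'lstm') == 'lstm'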
@@ -45,6 +48,3 @@ def test_fixed_sequence(alg, rnn):
 
 if __name__ == '__main__':
     test_fixed_sequence('ppo2', 'lstm')
-
-
-
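The test is tagged @pytest.mark.slow, so a default pytest run normally deselects it. One way to run it from Python (a sketch, not part of the commit; it assumes the standard '-m slow' marker expression is enough to select slow-marked tests in this repository):

    import pytest

    # Select only tests marked @pytest.mark.slow whose name matches.
    pytest.main(['-m', 'slow', '-k', 'test_fixed_sequence'])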
baselines/common/tests/test_serialization.py

@@ -1,17 +1,15 @@
 import os
-import gym
 import tempfile
-import pytest
-import tensorflow as tf
-import numpy as np
-
-from baselines.common.tests.envs.mnist_env import MnistEnv
-from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
-from baselines.run import get_learn_function
-from baselines.common.tf_util import make_session, get_session
-
 from functools import partial
 
+import gym
+import numpy as np
+import pytest
+import tensorflow as tf
+from baselines.common.tests.envs.mnist_env import MnistEnv
+from baselines.common.tf_util import make_session, get_session
+from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
+from baselines.run import get_learn_function
 
 learn_kwargs = {
     'deepq': {},
@@ -37,12 +35,15 @@ def test_serialization(learn_fn, network_fn):
     Test if the trained model can be serialized
     '''
 
+    _network_kwargs = network_kwargs[network_fn]
 
     if network_fn.endswith('lstm') and learn_fn in ['acer', 'acktr', 'trpo_mpi', 'deepq']:
         # TODO make acktr work with recurrent policies
         # and test
         # github issue: https://github.com/openai/baselines/issues/660
         return
+    elif network_fn.endswith('lstm') and learn_fn == 'ppo2':
+        network_fn = 'ppo_' + network_fn
 
     def make_env():
         env = MnistEnv(episode_len=100)
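The new _network_kwargs local captures the kwargs lookup before network_fn is renamed; without it, the later lookup would fail for ppo2's prefixed name. A minimal sketch of the failure this avoids (not part of the commit), using a stand-in network_kwargs dict:

    # Assumed: network_kwargs is keyed only by the plain network names.
    network_kwargs = {'lstm': {}}
    network_fn = 'lstm'

    _network_kwargs = network_kwargs[network_fn]  # look up under the original name
    network_fn = 'ppo_' + network_fn              # rename for ppo2's policy lookup

    # network_kwargs[network_fn] would now raise KeyError('ppo_lstm'), so the
    # later kwargs.update() uses the saved _network_kwargs instead.
    kwargs = {}
    kwargs.update(_network_kwargs)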
@@ -54,10 +55,9 @@ def test_serialization(learn_fn, network_fn):
     learn = get_learn_function(learn_fn)
 
     kwargs = {}
-    kwargs.update(network_kwargs[network_fn])
+    kwargs.update(_network_kwargs)
     kwargs.update(learn_kwargs[learn_fn])
 
-
     learn = partial(learn, env=env, network=network_fn, seed=0, **kwargs)
 
     with tempfile.TemporaryDirectory() as td:
@@ -76,7 +76,7 @@ def test_serialization(learn_fn, network_fn):
 
     for k, v in variables_dict1.items():
         np.testing.assert_allclose(v, variables_dict2[k], atol=0.01,
                                    err_msg='saved and loaded variable {} value mismatch'.format(k))
 
     np.testing.assert_allclose(mean1, mean2, atol=0.5)
     np.testing.assert_allclose(std1, std2, atol=0.5)
@@ -90,15 +90,15 @@ def test_coexistence(learn_fn, network_fn):
     '''
 
     if learn_fn == 'deepq':
         # TODO enable multiple DQN models to be useable at the same time
        # github issue https://github.com/openai/baselines/issues/656
         return
 
     if network_fn.endswith('lstm') and learn_fn in ['acktr', 'trpo_mpi', 'deepq']:
         # TODO make acktr work with recurrent policies
         # and test
         # github issue: https://github.com/openai/baselines/issues/660
         return
 
     env = DummyVecEnv([lambda: gym.make('CartPole-v0')])
     learn = get_learn_function(learn_fn)
@@ -107,7 +107,7 @@ def test_coexistence(learn_fn, network_fn):
     kwargs.update(network_kwargs[network_fn])
     kwargs.update(learn_kwargs[learn_fn])
 
     learn = partial(learn, env=env, network=network_fn, total_timesteps=0, **kwargs)
     make_session(make_default=True, graph=tf.Graph())
     model1 = learn(seed=1)
     make_session(make_default=True, graph=tf.Graph())
@@ -117,7 +117,6 @@ def test_coexistence(learn_fn, network_fn):
     model2.step(env.observation_space.sample())
 
 
-
 def _serialize_variables():
     sess = get_session()
     variables = tf.trainable_variables()
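test_coexistence builds each model in its own TF graph and default session via make_session from baselines.common.tf_util. A sketch of that pattern in isolation (not part of the commit; TF1-style, model-building steps elided):

    import tensorflow as tf
    from baselines.common.tf_util import make_session

    # Each call installs a fresh default session on its own graph, so the
    # two models' variables live in separate graphs and never collide.
    make_session(make_default=True, graph=tf.Graph())  # session for model 1
    # ... build and step model 1 ...
    make_session(make_default=True, graph=tf.Graph())  # separate session for model 2
    # ... build and step model 2; model 1 stays usable via its own session ...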
@@ -137,3 +136,24 @@ def _get_action_stats(model, ob):
 
     return mean, std
 
+if __name__ == '__main__':
+    learn_kwargs = {
+        'deepq': {},
+        'a2c': {},
+        'acktr': {},
+        'acer': {},
+        'ppo2': {'nminibatches': 1, 'nsteps': 10},
+        'trpo_mpi': {},
+    }
+
+    network_kwargs = {
+        'mlp': {},
+        'cnn': {'pad': 'SAME'},
+        'lstm': {},
+        'cnn_lnlstm': {'pad': 'SAME'}
+    }
+
+
+    # @pytest.mark.parametrize("learn_fn", learn_kwargs.keys())
+    # @pytest.mark.parametrize("network_fn", network_kwargs.keys())
+    test_serialization('ppo2', 'cnn')