From 872181d4c3842028de676e34a47bf36232b259b7 Mon Sep 17 00:00:00 2001
From: Peter Zhokhov
Date: Mon, 30 Jul 2018 15:49:48 -0700
Subject: [PATCH] re-exported rl_algs - fixed problems with serialization test
 and test_cartpole

---
 baselines/common/tests/test_serialization.py | 11 ++++++-----
 baselines/deepq/__init__.py                  |  2 +-
 baselines/trpo_mpi/trpo_mpi.py               |  7 +++----
 3 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/baselines/common/tests/test_serialization.py b/baselines/common/tests/test_serialization.py
index 9c2b0e9..ca3d222 100644
--- a/baselines/common/tests/test_serialization.py
+++ b/baselines/common/tests/test_serialization.py
@@ -43,7 +43,7 @@ def test_serialization(learn_fn, network_fn):
         return
 
     env = DummyVecEnv([lambda: MnistEnv(10, episode_len=100)])
-    ob = env.reset()
+    ob = env.reset().copy()
     learn = get_learn_function(learn_fn)
 
     kwargs = {}
@@ -51,24 +51,25 @@
     kwargs.update(learn_kwargs[learn_fn])
 
 
-    learn = partial(learn, env=env, network=network_fn, seed=None, **kwargs)
+    learn = partial(learn, env=env, network=network_fn, seed=0, **kwargs)
 
     with tempfile.TemporaryDirectory() as td:
         model_path = os.path.join(td, 'serialization_test_model')
 
         with tf.Graph().as_default(), make_session().as_default():
-            model = learn(total_timesteps=100, seed=0)
+            model = learn(total_timesteps=100)
             model.save(model_path)
             mean1, std1 = _get_action_stats(model, ob)
             variables_dict1 = _serialize_variables()
 
         with tf.Graph().as_default(), make_session().as_default():
-            model = learn(total_timesteps=0, seed=0, load_path=model_path)
+            model = learn(total_timesteps=0, load_path=model_path)
             mean2, std2 = _get_action_stats(model, ob)
             variables_dict2 = _serialize_variables()
 
         for k, v in variables_dict1.items():
-            np.testing.assert_allclose(v, variables_dict2[k])
+            np.testing.assert_allclose(v, variables_dict2[k], atol=0.01,
+                err_msg='saved and loaded variable {} value mismatch'.format(k))
 
         np.testing.assert_allclose(mean1, mean2, atol=0.5)
         np.testing.assert_allclose(std1, std2, atol=0.5)

diff --git a/baselines/deepq/__init__.py b/baselines/deepq/__init__.py
index 6d2e168..6859c05 100644
--- a/baselines/deepq/__init__.py
+++ b/baselines/deepq/__init__.py
@@ -1,6 +1,6 @@
 from baselines.deepq import models # noqa
 from baselines.deepq.build_graph import build_act, build_train # noqa
-from baselines.deepq.deepq import learn, load # noqa
+from baselines.deepq.deepq import learn, load_act # noqa
 from baselines.deepq.replay_buffer import ReplayBuffer, PrioritizedReplayBuffer # noqa
 
 def wrap_atari_dqn(env):

diff --git a/baselines/trpo_mpi/trpo_mpi.py b/baselines/trpo_mpi/trpo_mpi.py
index 507edff..d84b0fc 100644
--- a/baselines/trpo_mpi/trpo_mpi.py
+++ b/baselines/trpo_mpi/trpo_mpi.py
@@ -173,10 +173,6 @@ def learn(*,
     with tf.variable_scope("oldpi"):
         oldpi = policy(observ_placeholder=ob)
 
-    if load_path is not None:
-        pi.load(load_path)
-        oldpi.load(load_path)
-
     atarg = tf.placeholder(dtype=tf.float32, shape=[None]) # Target advantage function (if applicable)
     ret = tf.placeholder(dtype=tf.float32, shape=[None]) # Empirical return
 
@@ -247,6 +243,9 @@
         return out
 
     U.initialize()
+    if load_path is not None:
+        pi.load(load_path)
+
     th_init = get_flat()
     MPI.COMM_WORLD.Bcast(th_init, root=0)
     set_from_flat(th_init)