re-exported rl_algs - fixed problems with serialization test and test_cartpole

This commit is contained in:
Peter Zhokhov
2018-07-30 15:49:48 -07:00
parent 628ddecf6a
commit 872181d4c3
3 changed files with 10 additions and 10 deletions

View File

@@ -43,7 +43,7 @@ def test_serialization(learn_fn, network_fn):
return return
env = DummyVecEnv([lambda: MnistEnv(10, episode_len=100)]) env = DummyVecEnv([lambda: MnistEnv(10, episode_len=100)])
ob = env.reset() ob = env.reset().copy()
learn = get_learn_function(learn_fn) learn = get_learn_function(learn_fn)
kwargs = {} kwargs = {}
@@ -51,24 +51,25 @@ def test_serialization(learn_fn, network_fn):
kwargs.update(learn_kwargs[learn_fn]) kwargs.update(learn_kwargs[learn_fn])
learn = partial(learn, env=env, network=network_fn, seed=None, **kwargs) learn = partial(learn, env=env, network=network_fn, seed=0, **kwargs)
with tempfile.TemporaryDirectory() as td: with tempfile.TemporaryDirectory() as td:
model_path = os.path.join(td, 'serialization_test_model') model_path = os.path.join(td, 'serialization_test_model')
with tf.Graph().as_default(), make_session().as_default(): with tf.Graph().as_default(), make_session().as_default():
model = learn(total_timesteps=100, seed=0) model = learn(total_timesteps=100)
model.save(model_path) model.save(model_path)
mean1, std1 = _get_action_stats(model, ob) mean1, std1 = _get_action_stats(model, ob)
variables_dict1 = _serialize_variables() variables_dict1 = _serialize_variables()
with tf.Graph().as_default(), make_session().as_default(): with tf.Graph().as_default(), make_session().as_default():
model = learn(total_timesteps=0, seed=0, load_path=model_path) model = learn(total_timesteps=0, load_path=model_path)
mean2, std2 = _get_action_stats(model, ob) mean2, std2 = _get_action_stats(model, ob)
variables_dict2 = _serialize_variables() variables_dict2 = _serialize_variables()
for k, v in variables_dict1.items(): for k, v in variables_dict1.items():
np.testing.assert_allclose(v, variables_dict2[k]) np.testing.assert_allclose(v, variables_dict2[k], atol=0.01,
err_msg='saved and loaded variable {} value mismatch'.format(k))
np.testing.assert_allclose(mean1, mean2, atol=0.5) np.testing.assert_allclose(mean1, mean2, atol=0.5)
np.testing.assert_allclose(std1, std2, atol=0.5) np.testing.assert_allclose(std1, std2, atol=0.5)

View File

@@ -1,6 +1,6 @@
from baselines.deepq import models # noqa from baselines.deepq import models # noqa
from baselines.deepq.build_graph import build_act, build_train # noqa from baselines.deepq.build_graph import build_act, build_train # noqa
from baselines.deepq.deepq import learn, load # noqa from baselines.deepq.deepq import learn, load_act # noqa
from baselines.deepq.replay_buffer import ReplayBuffer, PrioritizedReplayBuffer # noqa from baselines.deepq.replay_buffer import ReplayBuffer, PrioritizedReplayBuffer # noqa
def wrap_atari_dqn(env): def wrap_atari_dqn(env):

View File

@@ -173,10 +173,6 @@ def learn(*,
with tf.variable_scope("oldpi"): with tf.variable_scope("oldpi"):
oldpi = policy(observ_placeholder=ob) oldpi = policy(observ_placeholder=ob)
if load_path is not None:
pi.load(load_path)
oldpi.load(load_path)
atarg = tf.placeholder(dtype=tf.float32, shape=[None]) # Target advantage function (if applicable) atarg = tf.placeholder(dtype=tf.float32, shape=[None]) # Target advantage function (if applicable)
ret = tf.placeholder(dtype=tf.float32, shape=[None]) # Empirical return ret = tf.placeholder(dtype=tf.float32, shape=[None]) # Empirical return
@@ -247,6 +243,9 @@ def learn(*,
return out return out
U.initialize() U.initialize()
if load_path is not None:
pi.load(load_path)
th_init = get_flat() th_init = get_flat()
MPI.COMM_WORLD.Bcast(th_init, root=0) MPI.COMM_WORLD.Bcast(th_init, root=0)
set_from_flat(th_init) set_from_flat(th_init)