re-exported rl_algs; fixed problems with the serialization test and test_cartpole

Peter Zhokhov
2018-07-30 15:49:48 -07:00
parent 628ddecf6a
commit 872181d4c3
3 changed files with 10 additions and 10 deletions

--- a/baselines/common/tests/test_serialization.py
+++ b/baselines/common/tests/test_serialization.py
@@ -43,7 +43,7 @@ def test_serialization(learn_fn, network_fn):
         return

     env = DummyVecEnv([lambda: MnistEnv(10, episode_len=100)])
-    ob = env.reset()
+    ob = env.reset().copy()
     learn = get_learn_function(learn_fn)

     kwargs = {}
@@ -51,24 +51,25 @@ def test_serialization(learn_fn, network_fn):
     kwargs.update(learn_kwargs[learn_fn])

-    learn = partial(learn, env=env, network=network_fn, seed=None, **kwargs)
+    learn = partial(learn, env=env, network=network_fn, seed=0, **kwargs)

     with tempfile.TemporaryDirectory() as td:
         model_path = os.path.join(td, 'serialization_test_model')

         with tf.Graph().as_default(), make_session().as_default():
-            model = learn(total_timesteps=100, seed=0)
+            model = learn(total_timesteps=100)
             model.save(model_path)
             mean1, std1 = _get_action_stats(model, ob)
             variables_dict1 = _serialize_variables()

         with tf.Graph().as_default(), make_session().as_default():
-            model = learn(total_timesteps=0, seed=0, load_path=model_path)
+            model = learn(total_timesteps=0, load_path=model_path)
             mean2, std2 = _get_action_stats(model, ob)
             variables_dict2 = _serialize_variables()

         for k, v in variables_dict1.items():
-            np.testing.assert_allclose(v, variables_dict2[k])
+            np.testing.assert_allclose(v, variables_dict2[k], atol=0.01,
+                                       err_msg='saved and loaded variable {} value mismatch'.format(k))

         np.testing.assert_allclose(mean1, mean2, atol=0.5)
         np.testing.assert_allclose(std1, std2, atol=0.5)
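Two of these changes are about test determinism: seeding once through partial(learn, ..., seed=0) keeps the save run and the load run on identical seeds, and atol=0.01 plus an err_msg makes the variable comparison tolerant of small numerical drift while naming the offending variable on failure. The env.reset().copy() change guards against observation-buffer aliasing: DummyVecEnv writes observations into a preallocated buffer, so a reference returned by reset() can be mutated by later steps. A minimal sketch of that hazard, with a hypothetical BufferReusingEnv standing in for DummyVecEnv:

import numpy as np

class BufferReusingEnv:
    # Stand-in for a vectorized env that reuses one observation
    # buffer: every reset()/step() writes into the same array.
    def __init__(self):
        self._buf = np.zeros(3)

    def reset(self):
        self._buf[:] = 1.0   # in-place write
        return self._buf     # caller gets a view, not a copy

    def step(self):
        self._buf[:] += 1.0  # mutates what reset() returned
        return self._buf

env = BufferReusingEnv()
ob_aliased = env.reset()          # silently changes later
ob_snapshot = env.reset().copy()  # stable snapshot, as in the fix
env.step()
print(ob_aliased)   # [2. 2. 2.] -- mutated by step()
print(ob_snapshot)  # [1. 1. 1.] -- unchanged

Without the copy, the test would compare actions on whatever observation the env buffer happened to hold at assertion time, not the one reset() originally produced.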

--- a/baselines/deepq/__init__.py
+++ b/baselines/deepq/__init__.py
@@ -1,6 +1,6 @@
 from baselines.deepq import models  # noqa
 from baselines.deepq.build_graph import build_act, build_train  # noqa
-from baselines.deepq.deepq import learn, load  # noqa
+from baselines.deepq.deepq import learn, load_act  # noqa
 from baselines.deepq.replay_buffer import ReplayBuffer, PrioritizedReplayBuffer  # noqa

 def wrap_atari_dqn(env):
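This brings the package re-export in line with the refactored deepq module: the old top-level load helper no longer exists, and load_act (which restores a pickled act function saved via ActWrapper.save_act) is exported in its place, while full models are reloaded by passing load_path to learn. A sketch of the two entry points; the helper and the file names here are illustrative, not from this commit:

from baselines import deepq

def restore(env, act_path=None, model_path=None):
    # Hypothetical helper contrasting the two load paths after
    # this commit; 'act.pkl'-style file names are made up.
    if act_path is not None:
        # restores only the serialized act function
        return deepq.load_act(act_path)
    # restores a full model: build the graph, train for 0 steps,
    # then populate the variables from the saved load_path
    return deepq.learn(env=env, network='mlp', total_timesteps=0,
                       load_path=model_path)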

--- a/baselines/trpo_mpi/trpo_mpi.py
+++ b/baselines/trpo_mpi/trpo_mpi.py
@@ -173,10 +173,6 @@ def learn(*,
     with tf.variable_scope("oldpi"):
         oldpi = policy(observ_placeholder=ob)

-    if load_path is not None:
-        pi.load(load_path)
-        oldpi.load(load_path)
-
     atarg = tf.placeholder(dtype=tf.float32, shape=[None])  # Target advantage function (if applicable)
     ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return
@@ -247,6 +243,9 @@ def learn(*,
         return out

     U.initialize()
+    if load_path is not None:
+        pi.load(load_path)
+
     th_init = get_flat()
     MPI.COMM_WORLD.Bcast(th_init, root=0)
     set_from_flat(th_init)
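The trpo_mpi change is an ordering fix: previously pi.load(load_path) ran before the graph variables were initialized, so the subsequent U.initialize() overwrote the loaded weights with fresh random values, which is presumably why loading a saved trpo_mpi model failed in the tests. Loading after U.initialize() lets the restored weights survive, and the following MPI.COMM_WORLD.Bcast plus set_from_flat then propagates rank 0's parameters to every worker; the separate oldpi.load also becomes unnecessary, since oldpi is synced from pi during training. A minimal TF1-style sketch of the original hazard (variable names are illustrative):

import numpy as np
import tensorflow as tf  # TF1 graph mode, as used by baselines

v = tf.get_variable('w', initializer=tf.zeros(3))
load_op = v.assign(np.ones(3, dtype=np.float32))  # stand-in for pi.load

with tf.Session() as sess:
    sess.run(load_op)                             # "load" first...
    sess.run(tf.global_variables_initializer())   # ...then initialize
    print(sess.run(v))   # [0. 0. 0.] -- the loaded values were overwritten

    sess.run(tf.global_variables_initializer())   # correct order:
    sess.run(load_op)                             # initialize, then load
    print(sess.run(v))   # [1. 1. 1.] -- the load survives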