re-exported rl_algs - fixed problems with serialization test and test_cartpole
@@ -43,7 +43,7 @@ def test_serialization(learn_fn, network_fn):
         return
 
     env = DummyVecEnv([lambda: MnistEnv(10, episode_len=100)])
-    ob = env.reset()
+    ob = env.reset().copy()
     learn = get_learn_function(learn_fn)
 
     kwargs = {}
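Note on the .copy(): the test holds on to ob across two separate training and reload runs, and a vectorized env is free to reuse its observation buffer in place, so a bare reference taken at reset() can change underneath the test. That buffer-reuse rationale is inferred rather than stated in the commit; the toy snippet below only illustrates the aliasing issue with plain numpy, not the baselines API.

    import numpy as np

    # Toy illustration of the aliasing issue (assumed motivation, not from the
    # commit): an env that reuses one numpy buffer makes a bare reference
    # taken at reset() change once the env is stepped again.
    buf = np.zeros(3)        # stands in for the env's internal observation buffer
    ob_ref = buf             # what ob = env.reset() would hold: an alias
    ob_copy = buf.copy()     # what ob = env.reset().copy() holds: a snapshot

    buf[:] = 1.0             # the env overwrites its buffer on the next step/reset
    assert ob_ref[0] == 1.0  # the alias changed under us
    assert ob_copy[0] == 0.0 # the copy is stable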
@@ -51,24 +51,25 @@ def test_serialization(learn_fn, network_fn):
     kwargs.update(learn_kwargs[learn_fn])
 
 
-    learn = partial(learn, env=env, network=network_fn, seed=None, **kwargs)
+    learn = partial(learn, env=env, network=network_fn, seed=0, **kwargs)
 
     with tempfile.TemporaryDirectory() as td:
         model_path = os.path.join(td, 'serialization_test_model')
 
         with tf.Graph().as_default(), make_session().as_default():
-            model = learn(total_timesteps=100, seed=0)
+            model = learn(total_timesteps=100)
             model.save(model_path)
             mean1, std1 = _get_action_stats(model, ob)
             variables_dict1 = _serialize_variables()
 
         with tf.Graph().as_default(), make_session().as_default():
-            model = learn(total_timesteps=0, seed=0, load_path=model_path)
+            model = learn(total_timesteps=0, load_path=model_path)
             mean2, std2 = _get_action_stats(model, ob)
             variables_dict2 = _serialize_variables()
 
         for k, v in variables_dict1.items():
-            np.testing.assert_allclose(v, variables_dict2[k])
+            np.testing.assert_allclose(v, variables_dict2[k], atol=0.01,
+                                       err_msg='saved and loaded variable {} value mismatch'.format(k))
 
         np.testing.assert_allclose(mean1, mean2, atol=0.5)
         np.testing.assert_allclose(std1, std2, atol=0.5)
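Note on the seed change: seed=0 is now bound once when the learn function is wrapped with partial, instead of being passed to each learn() call, so the training run and the zero-timestep reload run are constructed with identical arguments; the variable comparison is also loosened to atol=0.01 and given an err_msg naming the offending variable. The snippet below is a minimal stand-in that assumes nothing about the real learn signature beyond what the diff shows, just to show how partial fixes the seed for both calls.

    from functools import partial

    # Minimal stand-in, not the baselines learn signature: the point is only
    # that a seed bound via partial applies to every later call identically.
    def learn(total_timesteps, seed=None, load_path=None):
        return {'timesteps': total_timesteps, 'seed': seed, 'load_path': load_path}

    learn = partial(learn, seed=0)                          # seed fixed once, as in the test
    trained = learn(total_timesteps=100)                    # first run: train and save
    reloaded = learn(total_timesteps=0, load_path='model')  # second run: load only
    assert trained['seed'] == reloaded['seed'] == 0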
@@ -1,6 +1,6 @@
 from baselines.deepq import models # noqa
 from baselines.deepq.build_graph import build_act, build_train # noqa
-from baselines.deepq.deepq import learn, load # noqa
+from baselines.deepq.deepq import learn, load_act # noqa
 from baselines.deepq.replay_buffer import ReplayBuffer, PrioritizedReplayBuffer # noqa
 
 def wrap_atari_dqn(env):
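Note on the re-export: baselines.deepq now exposes load_act instead of load. The usage below assumes load_act takes a path to a previously saved act function, mirroring how the old load re-export was used; that signature is an assumption, not something the diff states.

    # Assumed usage of the new re-export; the path and signature are illustrative.
    from baselines import deepq

    act = deepq.load_act('saved_act.pkl')  # hypothetical file from an earlier save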
@@ -173,10 +173,6 @@ def learn(*,
     with tf.variable_scope("oldpi"):
         oldpi = policy(observ_placeholder=ob)
 
-    if load_path is not None:
-        pi.load(load_path)
-        oldpi.load(load_path)
-
     atarg = tf.placeholder(dtype=tf.float32, shape=[None]) # Target advantage function (if applicable)
     ret = tf.placeholder(dtype=tf.float32, shape=[None]) # Empirical return
 
@@ -247,6 +243,9 @@ def learn(*,
         return out
 
     U.initialize()
+    if load_path is not None:
+        pi.load(load_path)
+
     th_init = get_flat()
     MPI.COMM_WORLD.Bcast(th_init, root=0)
     set_from_flat(th_init)
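Note on the move: previously pi.load and oldpi.load ran before U.initialize(), and the initializer can overwrite freshly restored values; here the restore happens after initialization and just before the MPI broadcast that syncs every worker to rank 0's flat parameters. That reading of the intent is inferred from the ordering, not stated in the commit. The toy snippet below only illustrates the init-then-load ordering in plain Python; it is not the baselines code.

    # Toy illustration, not the baselines code: why restoring a checkpoint has
    # to happen after the variable initializer rather than before it.
    params = {'w': None}

    def initialize():        # stands in for U.initialize()
        params['w'] = 0.0

    def load_checkpoint():   # stands in for pi.load(load_path)
        params['w'] = 3.14

    load_checkpoint(); initialize()   # old ordering: the restore is wiped out
    assert params['w'] == 0.0

    initialize(); load_checkpoint()   # new ordering: the restore survives
    assert params['w'] == 3.14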