re-exported rl_algs - fixed problems with serialization test and test_cartpole
This commit is contained in:
@@ -43,7 +43,7 @@ def test_serialization(learn_fn, network_fn):
|
|||||||
return
|
return
|
||||||
|
|
||||||
env = DummyVecEnv([lambda: MnistEnv(10, episode_len=100)])
|
env = DummyVecEnv([lambda: MnistEnv(10, episode_len=100)])
|
||||||
ob = env.reset()
|
ob = env.reset().copy()
|
||||||
learn = get_learn_function(learn_fn)
|
learn = get_learn_function(learn_fn)
|
||||||
|
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
@@ -51,24 +51,25 @@ def test_serialization(learn_fn, network_fn):
|
|||||||
kwargs.update(learn_kwargs[learn_fn])
|
kwargs.update(learn_kwargs[learn_fn])
|
||||||
|
|
||||||
|
|
||||||
learn = partial(learn, env=env, network=network_fn, seed=None, **kwargs)
|
learn = partial(learn, env=env, network=network_fn, seed=0, **kwargs)
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as td:
|
with tempfile.TemporaryDirectory() as td:
|
||||||
model_path = os.path.join(td, 'serialization_test_model')
|
model_path = os.path.join(td, 'serialization_test_model')
|
||||||
|
|
||||||
with tf.Graph().as_default(), make_session().as_default():
|
with tf.Graph().as_default(), make_session().as_default():
|
||||||
model = learn(total_timesteps=100, seed=0)
|
model = learn(total_timesteps=100)
|
||||||
model.save(model_path)
|
model.save(model_path)
|
||||||
mean1, std1 = _get_action_stats(model, ob)
|
mean1, std1 = _get_action_stats(model, ob)
|
||||||
variables_dict1 = _serialize_variables()
|
variables_dict1 = _serialize_variables()
|
||||||
|
|
||||||
with tf.Graph().as_default(), make_session().as_default():
|
with tf.Graph().as_default(), make_session().as_default():
|
||||||
model = learn(total_timesteps=0, seed=0, load_path=model_path)
|
model = learn(total_timesteps=0, load_path=model_path)
|
||||||
mean2, std2 = _get_action_stats(model, ob)
|
mean2, std2 = _get_action_stats(model, ob)
|
||||||
variables_dict2 = _serialize_variables()
|
variables_dict2 = _serialize_variables()
|
||||||
|
|
||||||
for k, v in variables_dict1.items():
|
for k, v in variables_dict1.items():
|
||||||
np.testing.assert_allclose(v, variables_dict2[k])
|
np.testing.assert_allclose(v, variables_dict2[k], atol=0.01,
|
||||||
|
err_msg='saved and loaded variable {} value mismatch'.format(k))
|
||||||
|
|
||||||
np.testing.assert_allclose(mean1, mean2, atol=0.5)
|
np.testing.assert_allclose(mean1, mean2, atol=0.5)
|
||||||
np.testing.assert_allclose(std1, std2, atol=0.5)
|
np.testing.assert_allclose(std1, std2, atol=0.5)
|
||||||
|
@@ -1,6 +1,6 @@
|
|||||||
from baselines.deepq import models # noqa
|
from baselines.deepq import models # noqa
|
||||||
from baselines.deepq.build_graph import build_act, build_train # noqa
|
from baselines.deepq.build_graph import build_act, build_train # noqa
|
||||||
from baselines.deepq.deepq import learn, load # noqa
|
from baselines.deepq.deepq import learn, load_act # noqa
|
||||||
from baselines.deepq.replay_buffer import ReplayBuffer, PrioritizedReplayBuffer # noqa
|
from baselines.deepq.replay_buffer import ReplayBuffer, PrioritizedReplayBuffer # noqa
|
||||||
|
|
||||||
def wrap_atari_dqn(env):
|
def wrap_atari_dqn(env):
|
||||||
|
@@ -173,10 +173,6 @@ def learn(*,
|
|||||||
with tf.variable_scope("oldpi"):
|
with tf.variable_scope("oldpi"):
|
||||||
oldpi = policy(observ_placeholder=ob)
|
oldpi = policy(observ_placeholder=ob)
|
||||||
|
|
||||||
if load_path is not None:
|
|
||||||
pi.load(load_path)
|
|
||||||
oldpi.load(load_path)
|
|
||||||
|
|
||||||
atarg = tf.placeholder(dtype=tf.float32, shape=[None]) # Target advantage function (if applicable)
|
atarg = tf.placeholder(dtype=tf.float32, shape=[None]) # Target advantage function (if applicable)
|
||||||
ret = tf.placeholder(dtype=tf.float32, shape=[None]) # Empirical return
|
ret = tf.placeholder(dtype=tf.float32, shape=[None]) # Empirical return
|
||||||
|
|
||||||
@@ -247,6 +243,9 @@ def learn(*,
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
U.initialize()
|
U.initialize()
|
||||||
|
if load_path is not None:
|
||||||
|
pi.load(load_path)
|
||||||
|
|
||||||
th_init = get_flat()
|
th_init = get_flat()
|
||||||
MPI.COMM_WORLD.Bcast(th_init, root=0)
|
MPI.COMM_WORLD.Bcast(th_init, root=0)
|
||||||
set_from_flat(th_init)
|
set_from_flat(th_init)
|
||||||
|
Reference in New Issue
Block a user