From 75b93b890e5a14606ea45ce783d12d04652ac833 Mon Sep 17 00:00:00 2001 From: pzhokhov Date: Thu, 6 Sep 2018 16:17:59 -0700 Subject: [PATCH] implement pdfromlatent in BernoulliPdType (#81) * implement pdfromlatent in BernoulliPdType * remove env.close() at the end of algorithms * test case for environment after learn * closing env in run.py * fixes for acktr and trpo_mpi * add make_session with new graph for every call in test_env_after_learn * remove extra prints from test_env_after_learn --- baselines/a2c/a2c.py | 1 - baselines/acer/acer.py | 1 - baselines/acktr/acktr_disc.py | 1 - baselines/common/distributions.py | 3 ++ .../common/tests/test_env_after_learn.py | 28 +++++++++++++++++++ baselines/ppo2/ppo2.py | 1 - baselines/run.py | 4 ++- 7 files changed, 34 insertions(+), 5 deletions(-) create mode 100644 baselines/common/tests/test_env_after_learn.py diff --git a/baselines/a2c/a2c.py b/baselines/a2c/a2c.py index 4c3013d..729a58b 100644 --- a/baselines/a2c/a2c.py +++ b/baselines/a2c/a2c.py @@ -173,6 +173,5 @@ def learn( logger.record_tabular("value_loss", float(value_loss)) logger.record_tabular("explained_variance", float(ev)) logger.dump_tabular() - env.close() return model diff --git a/baselines/acer/acer.py b/baselines/acer/acer.py index 1bb8129..4a865f1 100644 --- a/baselines/acer/acer.py +++ b/baselines/acer/acer.py @@ -370,5 +370,4 @@ def learn(network, env, seed=None, nsteps=20, nstack=4, total_timesteps=int(80e6 for _ in range(n): acer.call(on_policy=False) # no simulation steps in this - env.close() return model diff --git a/baselines/acktr/acktr_disc.py b/baselines/acktr/acktr_disc.py index 7e42bc6..f4be408 100644 --- a/baselines/acktr/acktr_disc.py +++ b/baselines/acktr/acktr_disc.py @@ -147,5 +147,4 @@ def learn(network, env, seed, total_timesteps=int(40e6), gamma=0.99, log_interva model.save(savepath) coord.request_stop() coord.join(enqueue_threads) - env.close() return model diff --git a/baselines/common/distributions.py b/baselines/common/distributions.py index 29f3632..4a84035 100644 --- a/baselines/common/distributions.py +++ b/baselines/common/distributions.py @@ -107,6 +107,9 @@ class BernoulliPdType(PdType): return [self.size] def sample_dtype(self): return tf.int32 + def pdfromlatent(self, latent_vector, init_scale=1.0, init_bias=0.0): + pdparam = fc(latent_vector, 'pi', self.size, init_scale=init_scale, init_bias=init_bias) + return self.pdfromflat(pdparam), pdparam # WRONG SECOND DERIVATIVES # class CategoricalPd(Pd): diff --git a/baselines/common/tests/test_env_after_learn.py b/baselines/common/tests/test_env_after_learn.py new file mode 100644 index 0000000..6b0890a --- /dev/null +++ b/baselines/common/tests/test_env_after_learn.py @@ -0,0 +1,28 @@ +import pytest +import gym +import tensorflow as tf + +from baselines.common.models import cnn +from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv +from baselines.run import get_learn_function +from baselines.common.tf_util import make_session + +algos = ['a2c', 'acer', 'acktr', 'deepq', 'ppo2', 'trpo_mpi'] + +@pytest.mark.parametrize('algo', algos) +def test_env_after_learn(algo): + def make_env(): + env = gym.make('PongNoFrameskip-v4') + return env + + make_session(make_default=True, graph=tf.Graph()) + env = SubprocVecEnv([make_env]) + + learn = get_learn_function(algo) + network = cnn(one_dim_bias=True) + + # Commenting out the following line resolves the issue, though crash happens at env.reset(). + learn(network=network, env=env, total_timesteps=0, load_path=None, seed=None) + + env.reset() + env.close() diff --git a/baselines/ppo2/ppo2.py b/baselines/ppo2/ppo2.py index 72c9289..d118a72 100644 --- a/baselines/ppo2/ppo2.py +++ b/baselines/ppo2/ppo2.py @@ -293,7 +293,6 @@ def learn(*, network, env, total_timesteps, seed=None, nsteps=2048, ent_coef=0.0 savepath = osp.join(checkdir, '%.5i'%update) print('Saving to', savepath) model.save(savepath) - env.close() return model def safemean(xs): diff --git a/baselines/run.py b/baselines/run.py index a4bdde2..0133850 100644 --- a/baselines/run.py +++ b/baselines/run.py @@ -208,7 +208,8 @@ def main(): logger.configure(format_strs=[]) rank = MPI.COMM_WORLD.Get_rank() - model, _ = train(args, extra_args) + model, env = train(args, extra_args) + env.close() if args.save_path is not None and rank == 0: save_path = osp.expanduser(args.save_path) @@ -227,6 +228,7 @@ def main(): if done: obs = env.reset() + env.close() if __name__ == '__main__': main()