implement pdfromlatent in BernoulliPdType (#81)
* implement pdfromlatent in BernoulliPdType * remove env.close() at the end of algorithms * test case for environment after learn * closing env in run.py * fixes for acktr and trpo_mpi * add make_session with new graph for every call in test_env_after_learn * remove extra prints from test_env_after_learn
This commit is contained in:
@@ -173,6 +173,5 @@ def learn(
|
|||||||
logger.record_tabular("value_loss", float(value_loss))
|
logger.record_tabular("value_loss", float(value_loss))
|
||||||
logger.record_tabular("explained_variance", float(ev))
|
logger.record_tabular("explained_variance", float(ev))
|
||||||
logger.dump_tabular()
|
logger.dump_tabular()
|
||||||
env.close()
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
@@ -370,5 +370,4 @@ def learn(network, env, seed=None, nsteps=20, nstack=4, total_timesteps=int(80e6
|
|||||||
for _ in range(n):
|
for _ in range(n):
|
||||||
acer.call(on_policy=False) # no simulation steps in this
|
acer.call(on_policy=False) # no simulation steps in this
|
||||||
|
|
||||||
env.close()
|
|
||||||
return model
|
return model
|
||||||
|
@@ -147,5 +147,4 @@ def learn(network, env, seed, total_timesteps=int(40e6), gamma=0.99, log_interva
|
|||||||
model.save(savepath)
|
model.save(savepath)
|
||||||
coord.request_stop()
|
coord.request_stop()
|
||||||
coord.join(enqueue_threads)
|
coord.join(enqueue_threads)
|
||||||
env.close()
|
|
||||||
return model
|
return model
|
||||||
|
@@ -107,6 +107,9 @@ class BernoulliPdType(PdType):
|
|||||||
return [self.size]
|
return [self.size]
|
||||||
def sample_dtype(self):
|
def sample_dtype(self):
|
||||||
return tf.int32
|
return tf.int32
|
||||||
|
def pdfromlatent(self, latent_vector, init_scale=1.0, init_bias=0.0):
|
||||||
|
pdparam = fc(latent_vector, 'pi', self.size, init_scale=init_scale, init_bias=init_bias)
|
||||||
|
return self.pdfromflat(pdparam), pdparam
|
||||||
|
|
||||||
# WRONG SECOND DERIVATIVES
|
# WRONG SECOND DERIVATIVES
|
||||||
# class CategoricalPd(Pd):
|
# class CategoricalPd(Pd):
|
||||||
|
28
baselines/common/tests/test_env_after_learn.py
Normal file
28
baselines/common/tests/test_env_after_learn.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
import pytest
|
||||||
|
import gym
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from baselines.common.models import cnn
|
||||||
|
from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
|
||||||
|
from baselines.run import get_learn_function
|
||||||
|
from baselines.common.tf_util import make_session
|
||||||
|
|
||||||
|
algos = ['a2c', 'acer', 'acktr', 'deepq', 'ppo2', 'trpo_mpi']
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('algo', algos)
|
||||||
|
def test_env_after_learn(algo):
|
||||||
|
def make_env():
|
||||||
|
env = gym.make('PongNoFrameskip-v4')
|
||||||
|
return env
|
||||||
|
|
||||||
|
make_session(make_default=True, graph=tf.Graph())
|
||||||
|
env = SubprocVecEnv([make_env])
|
||||||
|
|
||||||
|
learn = get_learn_function(algo)
|
||||||
|
network = cnn(one_dim_bias=True)
|
||||||
|
|
||||||
|
# Commenting out the following line resolves the issue, though crash happens at env.reset().
|
||||||
|
learn(network=network, env=env, total_timesteps=0, load_path=None, seed=None)
|
||||||
|
|
||||||
|
env.reset()
|
||||||
|
env.close()
|
@@ -293,7 +293,6 @@ def learn(*, network, env, total_timesteps, seed=None, nsteps=2048, ent_coef=0.0
|
|||||||
savepath = osp.join(checkdir, '%.5i'%update)
|
savepath = osp.join(checkdir, '%.5i'%update)
|
||||||
print('Saving to', savepath)
|
print('Saving to', savepath)
|
||||||
model.save(savepath)
|
model.save(savepath)
|
||||||
env.close()
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def safemean(xs):
|
def safemean(xs):
|
||||||
|
@@ -208,7 +208,8 @@ def main():
|
|||||||
logger.configure(format_strs=[])
|
logger.configure(format_strs=[])
|
||||||
rank = MPI.COMM_WORLD.Get_rank()
|
rank = MPI.COMM_WORLD.Get_rank()
|
||||||
|
|
||||||
model, _ = train(args, extra_args)
|
model, env = train(args, extra_args)
|
||||||
|
env.close()
|
||||||
|
|
||||||
if args.save_path is not None and rank == 0:
|
if args.save_path is not None and rank == 0:
|
||||||
save_path = osp.expanduser(args.save_path)
|
save_path = osp.expanduser(args.save_path)
|
||||||
@@ -227,6 +228,7 @@ def main():
|
|||||||
if done:
|
if done:
|
||||||
obs = env.reset()
|
obs = env.reset()
|
||||||
|
|
||||||
|
env.close()
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
||||||
|
Reference in New Issue
Block a user