From c0fa11a3a730dcf987040ea131461ccd74e7da7b Mon Sep 17 00:00:00 2001
From: pzhokhov
Date: Mon, 22 Oct 2018 09:15:04 -0700
Subject: [PATCH] minor fixes from internal (#665)

* sync internal changes. Make ddpg work with vecenvs

* B -> nenvs for consistency with other algos, small cleanups

* eval_done[d]==True -> eval_done[d]

* flake8 and numpy.random.random_integers deprecation warning

* Merge branch 'master' of github.com:openai/games into peterz_track_baselines_branch
---
 baselines/common/atari_wrappers.py          |  5 ++++-
 baselines/common/vec_env/vec_frame_stack.py |  3 ---
 baselines/logger.py                         |  3 ++-
 setup.py                                    | 16 ++++++++++------
 4 files changed, 16 insertions(+), 11 deletions(-)

diff --git a/baselines/common/atari_wrappers.py b/baselines/common/atari_wrappers.py
index 6be3582..731ee7e 100644
--- a/baselines/common/atari_wrappers.py
+++ b/baselines/common/atari_wrappers.py
@@ -213,8 +213,11 @@ class LazyFrames(object):
     def __getitem__(self, i):
         return self._force()[i]
 
-def make_atari(env_id):
+def make_atari(env_id, timelimit=True):
+    # XXX(john): remove timelimit argument after gym is upgraded to allow double wrapping
     env = gym.make(env_id)
+    if not timelimit:
+        env = env.env
     assert 'NoFrameskip' in env.spec.id
     env = NoopResetEnv(env, noop_max=30)
     env = MaxAndSkipEnv(env, skip=4)
diff --git a/baselines/common/vec_env/vec_frame_stack.py b/baselines/common/vec_env/vec_frame_stack.py
index 9185873..1b7a695 100644
--- a/baselines/common/vec_env/vec_frame_stack.py
+++ b/baselines/common/vec_env/vec_frame_stack.py
@@ -28,6 +28,3 @@ class VecFrameStack(VecEnvWrapper):
         self.stackedobs[...] = 0
         self.stackedobs[..., -obs.shape[-1]:] = obs
         return self.stackedobs
-
-    def close(self):
-        self.venv.close()
diff --git a/baselines/logger.py b/baselines/logger.py
index be38f43..95ae75b 100644
--- a/baselines/logger.py
+++ b/baselines/logger.py
@@ -106,7 +106,8 @@ class CSVOutputFormat(KVWriter):
 
     def writekvs(self, kvs):
         # Add our current row to the history
-        extra_keys = kvs.keys() - self.keys
+        extra_keys = list(kvs.keys() - self.keys)
+        extra_keys.sort()
         if extra_keys:
             self.keys.extend(extra_keys)
             self.file.seek(0)
diff --git a/setup.py b/setup.py
index 726c6a3..425a1e8 100644
--- a/setup.py
+++ b/setup.py
@@ -48,9 +48,13 @@ setup(name='baselines',
 
 
 # ensure there is some tensorflow build with version above 1.4
-try:
-    from distutils.version import StrictVersion
-    import tensorflow
-    assert StrictVersion(re.sub(r'-rc\d+$', '', tensorflow.__version__)) >= StrictVersion('1.4.0')
-except ImportError:
-    assert False, "TensorFlow needed, of version above 1.4"
+import pkg_resources
+tf_pkg = None
+for tf_pkg_name in ['tensorflow', 'tensorflow-gpu']:
+    try:
+        tf_pkg = pkg_resources.get_distribution(tf_pkg_name)
+    except pkg_resources.DistributionNotFound:
+        pass
+assert tf_pkg is not None, 'TensorFlow needed, of version above 1.4'
+from distutils.version import StrictVersion
+assert StrictVersion(re.sub(r'-rc\d+$', '', tf_pkg.version)) >= StrictVersion('1.4.0')