Compare commits

...

2 Commits

Author SHA1 Message Date
Harry Uglow
ea25b9e8b2 Monitor should close what it inherits (#1076) 2020-01-31 05:06:18 -08:00
pzhokhov
9ee399f5b2 Fix build with latest gym (#1034)
* update to use latest version of gym

* fix imports

* narrow down gym version to 0.15.4 <= gym < 0.16.0
2019-11-10 11:10:01 -08:00
3 changed files with 5 additions and 5 deletions

View File

@@ -77,6 +77,7 @@ class Monitor(Wrapper):
self.total_steps += 1 self.total_steps += 1
def close(self): def close(self):
super(Monitor, self).close()
if self.f is not None: if self.f is not None:
self.f.close() self.f.close()

View File

@@ -9,7 +9,7 @@ except ImportError:
MPI = None MPI = None
import gym import gym
from gym.wrappers import FlattenDictWrapper from gym.wrappers import FlattenObservation, FilterObservation
from baselines import logger from baselines import logger
from baselines.bench import Monitor from baselines.bench import Monitor
from baselines.common import set_global_seeds from baselines.common import set_global_seeds
@@ -81,8 +81,7 @@ def make_env(env_id, env_type, mpi_rank=0, subrank=0, seed=None, reward_scale=1.
env = gym.make(env_id, **env_kwargs) env = gym.make(env_id, **env_kwargs)
if flatten_dict_observations and isinstance(env.observation_space, gym.spaces.Dict): if flatten_dict_observations and isinstance(env.observation_space, gym.spaces.Dict):
keys = env.observation_space.spaces.keys() env = FlattenObservation(env)
env = gym.wrappers.FlattenDictWrapper(env, dict_keys=list(keys))
env.seed(seed + subrank if seed is not None else None) env.seed(seed + subrank if seed is not None else None)
env = Monitor(env, env = Monitor(env,
@@ -128,7 +127,7 @@ def make_robotics_env(env_id, seed, rank=0):
""" """
set_global_seeds(seed) set_global_seeds(seed)
env = gym.make(env_id) env = gym.make(env_id)
env = FlattenDictWrapper(env, ['observation', 'desired_goal']) env = FlattenObservation(FilterObservation(env, ['observation', 'desired_goal']))
env = Monitor( env = Monitor(
env, logger.get_dir() and os.path.join(logger.get_dir(), str(rank)), env, logger.get_dir() and os.path.join(logger.get_dir(), str(rank)),
info_keywords=('is_success',)) info_keywords=('is_success',))

View File

@@ -31,7 +31,7 @@ setup(name='baselines',
packages=[package for package in find_packages() packages=[package for package in find_packages()
if package.startswith('baselines')], if package.startswith('baselines')],
install_requires=[ install_requires=[
'gym>=0.10.0, <1.0.0', 'gym>=0.15.4, <0.16.0',
'scipy', 'scipy',
'tqdm', 'tqdm',
'joblib', 'joblib',