diff --git a/baselines/common/cmd_util.py b/baselines/common/cmd_util.py index 586480c..bd6ef9b 100644 --- a/baselines/common/cmd_util.py +++ b/baselines/common/cmd_util.py @@ -9,7 +9,7 @@ except ImportError: MPI = None import gym -from gym.wrappers import FlattenDictWrapper +from gym.wrappers import FlattenObservation, FilterObservation from baselines import logger from baselines.bench import Monitor from baselines.common import set_global_seeds @@ -81,8 +81,7 @@ def make_env(env_id, env_type, mpi_rank=0, subrank=0, seed=None, reward_scale=1. env = gym.make(env_id, **env_kwargs) if flatten_dict_observations and isinstance(env.observation_space, gym.spaces.Dict): - keys = env.observation_space.spaces.keys() - env = gym.wrappers.FlattenDictWrapper(env, dict_keys=list(keys)) + env = FlattenObservation(env) env.seed(seed + subrank if seed is not None else None) env = Monitor(env, @@ -128,7 +127,7 @@ def make_robotics_env(env_id, seed, rank=0): """ set_global_seeds(seed) env = gym.make(env_id) - env = FlattenDictWrapper(env, ['observation', 'desired_goal']) + env = FlattenObservation(FilterObservation(env, ['observation', 'desired_goal'])) env = Monitor( env, logger.get_dir() and os.path.join(logger.get_dir(), str(rank)), info_keywords=('is_success',)) diff --git a/setup.py b/setup.py index e48f269..e1e4610 100644 --- a/setup.py +++ b/setup.py @@ -31,7 +31,7 @@ setup(name='baselines', packages=[package for package in find_packages() if package.startswith('baselines')], install_requires=[ - 'gym>=0.10.0, <1.0.0', + 'gym>=0.15.4, <0.16.0', 'scipy', 'tqdm', 'joblib',