Fix build with latest gym (#1034)

* update to use latest version of gym

* fix imports

* narrow down gym version to 0.15.4 <= gym < 0.16.0
This commit is contained in:
pzhokhov
2019-11-10 11:10:01 -08:00
committed by GitHub
parent 391811d98c
commit 9ee399f5b2
2 changed files with 4 additions and 5 deletions

View File

@@ -9,7 +9,7 @@ except ImportError:
MPI = None MPI = None
import gym import gym
from gym.wrappers import FlattenDictWrapper from gym.wrappers import FlattenObservation, FilterObservation
from baselines import logger from baselines import logger
from baselines.bench import Monitor from baselines.bench import Monitor
from baselines.common import set_global_seeds from baselines.common import set_global_seeds
@@ -81,8 +81,7 @@ def make_env(env_id, env_type, mpi_rank=0, subrank=0, seed=None, reward_scale=1.
env = gym.make(env_id, **env_kwargs) env = gym.make(env_id, **env_kwargs)
if flatten_dict_observations and isinstance(env.observation_space, gym.spaces.Dict): if flatten_dict_observations and isinstance(env.observation_space, gym.spaces.Dict):
keys = env.observation_space.spaces.keys() env = FlattenObservation(env)
env = gym.wrappers.FlattenDictWrapper(env, dict_keys=list(keys))
env.seed(seed + subrank if seed is not None else None) env.seed(seed + subrank if seed is not None else None)
env = Monitor(env, env = Monitor(env,
@@ -128,7 +127,7 @@ def make_robotics_env(env_id, seed, rank=0):
""" """
set_global_seeds(seed) set_global_seeds(seed)
env = gym.make(env_id) env = gym.make(env_id)
env = FlattenDictWrapper(env, ['observation', 'desired_goal']) env = FlattenObservation(FilterObservation(env, ['observation', 'desired_goal']))
env = Monitor( env = Monitor(
env, logger.get_dir() and os.path.join(logger.get_dir(), str(rank)), env, logger.get_dir() and os.path.join(logger.get_dir(), str(rank)),
info_keywords=('is_success',)) info_keywords=('is_success',))

View File

@@ -31,7 +31,7 @@ setup(name='baselines',
packages=[package for package in find_packages() packages=[package for package in find_packages()
if package.startswith('baselines')], if package.startswith('baselines')],
install_requires=[ install_requires=[
'gym>=0.10.0, <1.0.0', 'gym>=0.15.4, <0.16.0',
'scipy', 'scipy',
'tqdm', 'tqdm',
'joblib', 'joblib',