Compare commits
2 Commits
Author | SHA1 | Date | |
---|---|---|---|
|
ea25b9e8b2 | ||
|
9ee399f5b2 |
@@ -77,6 +77,7 @@ class Monitor(Wrapper):
|
|||||||
self.total_steps += 1
|
self.total_steps += 1
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
|
super(Monitor, self).close()
|
||||||
if self.f is not None:
|
if self.f is not None:
|
||||||
self.f.close()
|
self.f.close()
|
||||||
|
|
||||||
|
@@ -9,7 +9,7 @@ except ImportError:
|
|||||||
MPI = None
|
MPI = None
|
||||||
|
|
||||||
import gym
|
import gym
|
||||||
from gym.wrappers import FlattenDictWrapper
|
from gym.wrappers import FlattenObservation, FilterObservation
|
||||||
from baselines import logger
|
from baselines import logger
|
||||||
from baselines.bench import Monitor
|
from baselines.bench import Monitor
|
||||||
from baselines.common import set_global_seeds
|
from baselines.common import set_global_seeds
|
||||||
@@ -81,8 +81,7 @@ def make_env(env_id, env_type, mpi_rank=0, subrank=0, seed=None, reward_scale=1.
|
|||||||
env = gym.make(env_id, **env_kwargs)
|
env = gym.make(env_id, **env_kwargs)
|
||||||
|
|
||||||
if flatten_dict_observations and isinstance(env.observation_space, gym.spaces.Dict):
|
if flatten_dict_observations and isinstance(env.observation_space, gym.spaces.Dict):
|
||||||
keys = env.observation_space.spaces.keys()
|
env = FlattenObservation(env)
|
||||||
env = gym.wrappers.FlattenDictWrapper(env, dict_keys=list(keys))
|
|
||||||
|
|
||||||
env.seed(seed + subrank if seed is not None else None)
|
env.seed(seed + subrank if seed is not None else None)
|
||||||
env = Monitor(env,
|
env = Monitor(env,
|
||||||
@@ -128,7 +127,7 @@ def make_robotics_env(env_id, seed, rank=0):
|
|||||||
"""
|
"""
|
||||||
set_global_seeds(seed)
|
set_global_seeds(seed)
|
||||||
env = gym.make(env_id)
|
env = gym.make(env_id)
|
||||||
env = FlattenDictWrapper(env, ['observation', 'desired_goal'])
|
env = FlattenObservation(FilterObservation(env, ['observation', 'desired_goal']))
|
||||||
env = Monitor(
|
env = Monitor(
|
||||||
env, logger.get_dir() and os.path.join(logger.get_dir(), str(rank)),
|
env, logger.get_dir() and os.path.join(logger.get_dir(), str(rank)),
|
||||||
info_keywords=('is_success',))
|
info_keywords=('is_success',))
|
||||||
|
2
setup.py
2
setup.py
@@ -31,7 +31,7 @@ setup(name='baselines',
|
|||||||
packages=[package for package in find_packages()
|
packages=[package for package in find_packages()
|
||||||
if package.startswith('baselines')],
|
if package.startswith('baselines')],
|
||||||
install_requires=[
|
install_requires=[
|
||||||
'gym>=0.10.0, <1.0.0',
|
'gym>=0.15.4, <0.16.0',
|
||||||
'scipy',
|
'scipy',
|
||||||
'tqdm',
|
'tqdm',
|
||||||
'joblib',
|
'joblib',
|
||||||
|
Reference in New Issue
Block a user