Compare commits

...

4 Commits

Author SHA1 Message Date
Peter Zhokhov
7446e6ea34 fix typo 2020-01-31 05:23:33 -08:00
Peter Zhokhov
3bd068c82c actually close the file with the results on Monitor.close() 2020-01-31 05:09:29 -08:00
Harry Uglow
ea25b9e8b2 Monitor should close what it inherits (#1076) 2020-01-31 05:06:18 -08:00
pzhokhov
9ee399f5b2 Fix build with latest gym (#1034)
* update to use latest version of gym

* fix imports

* narrow down gym version to 0.15.4 <= gym < 0.16.0
2019-11-10 11:10:01 -08:00
3 changed files with 7 additions and 8 deletions

View File

@@ -9,7 +9,6 @@ import json
class Monitor(Wrapper): class Monitor(Wrapper):
EXT = "monitor.csv" EXT = "monitor.csv"
f = None
def __init__(self, env, filename, allow_early_resets=False, reset_keywords=(), info_keywords=()): def __init__(self, env, filename, allow_early_resets=False, reset_keywords=(), info_keywords=()):
Wrapper.__init__(self, env=env) Wrapper.__init__(self, env=env)
@@ -77,8 +76,9 @@ class Monitor(Wrapper):
self.total_steps += 1 self.total_steps += 1
def close(self): def close(self):
if self.f is not None: super(Monitor, self).close()
self.f.close() if self.results_writer is not None:
self.results_writer.f.close()
def get_total_steps(self): def get_total_steps(self):
return self.total_steps return self.total_steps

View File

@@ -9,7 +9,7 @@ except ImportError:
MPI = None MPI = None
import gym import gym
from gym.wrappers import FlattenDictWrapper from gym.wrappers import FlattenObservation, FilterObservation
from baselines import logger from baselines import logger
from baselines.bench import Monitor from baselines.bench import Monitor
from baselines.common import set_global_seeds from baselines.common import set_global_seeds
@@ -81,8 +81,7 @@ def make_env(env_id, env_type, mpi_rank=0, subrank=0, seed=None, reward_scale=1.
env = gym.make(env_id, **env_kwargs) env = gym.make(env_id, **env_kwargs)
if flatten_dict_observations and isinstance(env.observation_space, gym.spaces.Dict): if flatten_dict_observations and isinstance(env.observation_space, gym.spaces.Dict):
keys = env.observation_space.spaces.keys() env = FlattenObservation(env)
env = gym.wrappers.FlattenDictWrapper(env, dict_keys=list(keys))
env.seed(seed + subrank if seed is not None else None) env.seed(seed + subrank if seed is not None else None)
env = Monitor(env, env = Monitor(env,
@@ -128,7 +127,7 @@ def make_robotics_env(env_id, seed, rank=0):
""" """
set_global_seeds(seed) set_global_seeds(seed)
env = gym.make(env_id) env = gym.make(env_id)
env = FlattenDictWrapper(env, ['observation', 'desired_goal']) env = FlattenObservation(FilterObservation(env, ['observation', 'desired_goal']))
env = Monitor( env = Monitor(
env, logger.get_dir() and os.path.join(logger.get_dir(), str(rank)), env, logger.get_dir() and os.path.join(logger.get_dir(), str(rank)),
info_keywords=('is_success',)) info_keywords=('is_success',))

View File

@@ -31,7 +31,7 @@ setup(name='baselines',
packages=[package for package in find_packages() packages=[package for package in find_packages()
if package.startswith('baselines')], if package.startswith('baselines')],
install_requires=[ install_requires=[
'gym>=0.10.0, <1.0.0', 'gym>=0.15.4, <0.16.0',
'scipy', 'scipy',
'tqdm', 'tqdm',
'joblib', 'joblib',