Go from decorator to constructor syntax for Monitor; add note to What's New

This commit is contained in:
Jie Tang
2016-12-27 16:10:23 -08:00
committed by Tom Brown
parent cf77e19c84
commit 10f7e4ffb0
7 changed files with 56 additions and 52 deletions

View File

@@ -260,6 +260,10 @@ You can also run tests in a specific directory by using the ``-s`` option, or by
What's new What's new
---------- ----------
- 2016-12-27: We've made a backwards-incompatible change to make the gym monitor
a wrapper. Instead of `env.monitor.start(directory)`, wrap your envs with an
`env = wrappers.Monitor(env, directory).` This will also give advanced users
more flexibility with what exactly is run.
- 2016-11-1: Several experimental changes to how a running monitor interacts - 2016-11-1: Several experimental changes to how a running monitor interacts
with environments. The monitor will now raise an error if reset() is called with environments. The monitor will now raise an error if reset() is called
when the env has not returned done=True. The monitor will only record complete when the env has not returned done=True. The monitor will only record complete

View File

@@ -40,7 +40,7 @@ if __name__ == '__main__':
# will be namespaced). You can also dump to a tempdir if you'd # will be namespaced). You can also dump to a tempdir if you'd
# like: tempfile.mkdtemp(). # like: tempfile.mkdtemp().
outdir = '/tmp/random-agent-results' outdir = '/tmp/random-agent-results'
env = wrappers.Monitor(directory=outdir, force=True)(env) env = wrappers.Monitor(env, directory=outdir, force=True)
env.seed(0) env.seed(0)
agent = RandomAgent(env.action_space) agent = RandomAgent(env.action_space)

View File

@@ -46,7 +46,7 @@ def main():
for trial in range(task.trials): for trial in range(task.trials):
env = gym.make(task.env_id) env = gym.make(task.env_id)
training_dir_name = "{}/{}-{}".format(args.training_dir, task.env_id, trial) training_dir_name = "{}/{}-{}".format(args.training_dir, task.env_id, trial)
env = wrappers.Monitor(training_dir_name)(env) env = wrappers.Monitor(env, training_dir_name)
env.reset() env.reset()
for _ in range(task.max_timesteps): for _ in range(task.max_timesteps):
o, r, done, _ = env.step(env.action_space.sample()) o, r, done, _ = env.step(env.action_space.sample())

View File

@@ -22,7 +22,7 @@ def test():
with helpers.tempdir() as temp: with helpers.tempdir() as temp:
env = gym.make('CartPole-v0') env = gym.make('CartPole-v0')
env = wrappers.Monitor(directory=temp, video_callable=False)(env) env = wrappers.Monitor(env, directory=temp, video_callable=False)
env.seed(0) env.seed(0)
env.set_monitor_mode('evaluation') env.set_monitor_mode('evaluation')

View File

@@ -12,7 +12,7 @@ from gym.envs.registration import register
def test_monitor_filename(): def test_monitor_filename():
with helpers.tempdir() as temp: with helpers.tempdir() as temp:
env = gym.make('CartPole-v0') env = gym.make('CartPole-v0')
env = Monitor(directory=temp)(env) env = Monitor(env, directory=temp)
env.close() env.close()
manifests = glob.glob(os.path.join(temp, '*.manifest.*')) manifests = glob.glob(os.path.join(temp, '*.manifest.*'))
@@ -21,7 +21,7 @@ def test_monitor_filename():
def test_write_upon_reset_false(): def test_write_upon_reset_false():
with helpers.tempdir() as temp: with helpers.tempdir() as temp:
env = gym.make('CartPole-v0') env = gym.make('CartPole-v0')
env = Monitor(directory=temp, video_callable=False, write_upon_reset=False)(env) env = Monitor(env, directory=temp, video_callable=False, write_upon_reset=False)
env.reset() env.reset()
files = glob.glob(os.path.join(temp, '*')) files = glob.glob(os.path.join(temp, '*'))
@@ -34,7 +34,7 @@ def test_write_upon_reset_false():
def test_write_upon_reset_true(): def test_write_upon_reset_true():
with helpers.tempdir() as temp: with helpers.tempdir() as temp:
env = gym.make('CartPole-v0') env = gym.make('CartPole-v0')
env = Monitor(directory=temp, video_callable=False, write_upon_reset=True)(env) env = Monitor(env, directory=temp, video_callable=False, write_upon_reset=True)
env.reset() env.reset()
files = glob.glob(os.path.join(temp, '*')) files = glob.glob(os.path.join(temp, '*'))
@@ -48,7 +48,7 @@ def test_video_callable_true_not_allowed():
with helpers.tempdir() as temp: with helpers.tempdir() as temp:
env = gym.make('CartPole-v0') env = gym.make('CartPole-v0')
try: try:
env = Monitor(temp, video_callable=True)(env) env = Monitor(env, temp, video_callable=True)
except error.Error: except error.Error:
pass pass
else: else:
@@ -57,7 +57,7 @@ def test_video_callable_true_not_allowed():
def test_video_callable_false_does_not_record(): def test_video_callable_false_does_not_record():
with helpers.tempdir() as temp: with helpers.tempdir() as temp:
env = gym.make('CartPole-v0') env = gym.make('CartPole-v0')
env = Monitor(temp, video_callable=False)(env) env = Monitor(env, temp, video_callable=False)
env.reset() env.reset()
env.close() env.close()
results = monitoring.load_results(temp) results = monitoring.load_results(temp)
@@ -66,7 +66,7 @@ def test_video_callable_false_does_not_record():
def test_video_callable_records_videos(): def test_video_callable_records_videos():
with helpers.tempdir() as temp: with helpers.tempdir() as temp:
env = gym.make('CartPole-v0') env = gym.make('CartPole-v0')
env = Monitor(temp)(env) env = Monitor(env, temp)
env.reset() env.reset()
env.close() env.close()
results = monitoring.load_results(temp) results = monitoring.load_results(temp)
@@ -76,7 +76,7 @@ def test_semisuper_succeeds():
"""Regression test. Ensure that this can write""" """Regression test. Ensure that this can write"""
with helpers.tempdir() as temp: with helpers.tempdir() as temp:
env = gym.make('SemisuperPendulumDecay-v0') env = gym.make('SemisuperPendulumDecay-v0')
env = Monitor(temp)(env) env = Monitor(env, temp)
env.reset() env.reset()
env.step(env.action_space.sample()) env.step(env.action_space.sample())
env.close() env.close()
@@ -106,7 +106,7 @@ gym.envs.register(
def test_env_reuse(): def test_env_reuse():
with helpers.tempdir() as temp: with helpers.tempdir() as temp:
env = gym.make('Autoreset-v0') env = gym.make('Autoreset-v0')
env = Monitor(temp)(env) env = Monitor(env, temp)
env.reset() env.reset()
@@ -140,7 +140,7 @@ def test_no_monitor_reset_unless_done():
env.reset() env.reset()
# can reset once as soon as we start # can reset once as soon as we start
env = Monitor(temp, video_callable=False)(env) env = Monitor(env, temp, video_callable=False)
env.reset() env.reset()
# can reset multiple times in a row # can reset multiple times in a row
@@ -167,7 +167,7 @@ def test_no_monitor_reset_unless_done():
def test_only_complete_episodes_written(): def test_only_complete_episodes_written():
with helpers.tempdir() as temp: with helpers.tempdir() as temp:
env = gym.make('CartPole-v0') env = gym.make('CartPole-v0')
env = Monitor(temp, video_callable=False)(env) env = Monitor(env, temp, video_callable=False)
env.reset() env.reset()
d = False d = False
while not d: while not d:
@@ -193,7 +193,7 @@ register(
def test_steps_limit_restart(): def test_steps_limit_restart():
with helpers.tempdir() as temp: with helpers.tempdir() as temp:
env = gym.make('test.StepsLimitCartpole-v0') env = gym.make('test.StepsLimitCartpole-v0')
env = Monitor(temp, video_callable=False)(env) env = Monitor(env, temp, video_callable=False)
env.reset() env.reset()
# Episode has started # Episode has started

View File

@@ -6,41 +6,41 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def Monitor(directory, video_callable=None, force=False, resume=False, class _Monitor(Wrapper):
def __init__(self, env, directory, video_callable=None, force=False, resume=False,
write_upon_reset=False, uid=None, mode=None):
super(_Monitor, self).__init__(env)
self._monitor = monitoring.MonitorManager(env)
self._monitor.start(directory, video_callable, force, resume,
write_upon_reset, uid, mode)
def _step(self, action):
self._monitor._before_step(action)
observation, reward, done, info = self.env.step(action)
done = self._monitor._after_step(observation, reward, done, info)
return observation, reward, done, info
def _reset(self):
self._monitor._before_reset()
observation = self.env.reset()
self._monitor._after_reset(observation)
return observation
def _close(self):
super(_Monitor, self)._close()
# _monitor will not be set if super(Monitor, self).__init__ raises, this check prevents a confusing error message
if getattr(self, '_monitor', None):
self._monitor.close()
def set_monitor_mode(self, mode):
logger.info("Setting the monitor mode is deprecated and will be removed soon")
self._monitor._set_mode(mode)
def Monitor(env, directory, video_callable=None, force=False, resume=False,
write_upon_reset=False, uid=None, mode=None): write_upon_reset=False, uid=None, mode=None):
class Monitor(Wrapper):
def __init__(self, env):
super(Monitor, self).__init__(env)
self._monitor = monitoring.MonitorManager(env)
self._monitor.start(directory, video_callable, force, resume,
write_upon_reset, uid, mode)
def _step(self, action): return _Monitor(TimeLimit(env), directory, video_callable, force, resume,
self._monitor._before_step(action) write_upon_reset, uid, mode)
observation, reward, done, info = self.env.step(action)
done = self._monitor._after_step(observation, reward, done, info)
return observation, reward, done, info
def _reset(self):
self._monitor._before_reset()
observation = self.env.reset()
self._monitor._after_reset(observation)
return observation
def _close(self):
super(Monitor, self)._close()
# _monitor will not be set if super(Monitor, self).__init__ raises, this check prevents a confusing error message
if getattr(self, '_monitor', None):
self._monitor.close()
def set_monitor_mode(self, mode):
logger.info("Setting the monitor mode is deprecated and will be removed soon")
self._monitor._set_mode(mode)
def monitor_wrap(env):
return Monitor(TimeLimit(env))
return monitor_wrap

View File

@@ -19,9 +19,9 @@ def test_no_double_wrapping():
temp = tempfile.mkdtemp() temp = tempfile.mkdtemp()
try: try:
env = gym.make("FrozenLake-v0") env = gym.make("FrozenLake-v0")
env = wrappers.Monitor(temp)(env) env = wrappers.Monitor(env, temp)
try: try:
env = wrappers.Monitor(temp)(env) env = wrappers.Monitor(env, temp)
except error.DoubleWrapperError: except error.DoubleWrapperError:
pass pass
else: else: