From 2aa03d6088e6a7d95c68da6d7d229b17243bc121 Mon Sep 17 00:00:00 2001 From: Greg Brockman Date: Sun, 12 Jun 2016 20:56:21 -0700 Subject: [PATCH] Add configure method to Env, and support multiple displays in CartPole (#175) * Add configure method to Env, and support multiple displays in CartPole This allows people to pass runtime specification which doesn't affect the environment semantics to environments created via `make`. Also include an example of setting the display used for CartPole * Provide full configure method * Allow environments to require configuration * Don't take arguments in make --- gym/core.py | 21 ++++++++++++++++++++ gym/envs/classic_control/cartpole.py | 8 +++++++- gym/envs/classic_control/rendering.py | 28 ++++++++++++++++++++++----- gym/envs/registration.py | 1 + requirements.txt | 2 +- setup.py | 4 ++-- 6 files changed, 55 insertions(+), 9 deletions(-) diff --git a/gym/core.py b/gym/core.py index cdb1e0ab9..9df8b9301 100644 --- a/gym/core.py +++ b/gym/core.py @@ -21,6 +21,7 @@ class Env(object): reset render close + configure seed When implementing an environment, override the following methods @@ -30,6 +31,7 @@ class Env(object): _reset _render _close + _configure _seed And set the following attributes: @@ -51,6 +53,7 @@ class Env(object): env._closed = False env._action_warned = False env._observation_warned = False + env._configured = False # Will be automatically set when creating an environment via 'make' env.spec = None @@ -64,6 +67,9 @@ class Env(object): def _close(self): pass + def _configure(self): + pass + # Set these in ALL subclasses action_space = None observation_space = None @@ -127,6 +133,9 @@ class Env(object): Returns: observation (object): the initial observation of the space. (Initial reward is assumed to be 0.) """ + if self.metadata.get('configure.required') and not self._configured: + raise error.Error("{} requires calling 'configure()' before 'reset()'".format(self)) + self.monitor._before_reset() observation = self._reset() self.monitor._after_reset(observation) @@ -216,6 +225,18 @@ class Env(object): """ return self._seed(seed) + def configure(self, *args, **kwargs): + """Provides runtime configuration to the environment. + + This configuration should consist of data that tells your + environment how to run (such as an address of a remote server, + or path to your ImageNet data). It should not affect the + semantics of the environment. + """ + + self._configured = True + return self._configure(*args, **kwargs) + def __del__(self): self.close() diff --git a/gym/envs/classic_control/cartpole.py b/gym/envs/classic_control/cartpole.py index febedc25e..8d6393869 100644 --- a/gym/envs/classic_control/cartpole.py +++ b/gym/envs/classic_control/cartpole.py @@ -43,6 +43,12 @@ class CartPoleEnv(gym.Env): self.steps_beyond_done = None + # Just need to initialize the relevant attributes + self._configure() + + def _configure(self, display=None): + self.display = display + def _seed(self, seed=None): self.np_random, seed = seeding.np_random(seed) return [seed] @@ -108,7 +114,7 @@ class CartPoleEnv(gym.Env): if self.viewer is None: from gym.envs.classic_control import rendering - self.viewer = rendering.Viewer(screen_width, screen_height) + self.viewer = rendering.Viewer(screen_width, screen_height, display=self.display) l,r,t,b = -cartwidth/2, cartwidth/2, cartheight/2, -cartheight/2 axleoffset =cartheight/4.0 cart = rendering.FilledPolygon([(l,b), (l,t), (r,t), (r,b)]) diff --git a/gym/envs/classic_control/rendering.py b/gym/envs/classic_control/rendering.py index c8a703a95..c0620203f 100644 --- a/gym/envs/classic_control/rendering.py +++ b/gym/envs/classic_control/rendering.py @@ -2,7 +2,9 @@ 2D rendering framework """ from __future__ import division -import os, sys +import os +import six +import sys if "Apple" in sys.version: if 'DYLD_FALLBACK_LIBRARY_PATH' in os.environ: @@ -27,11 +29,26 @@ import numpy as np RAD2DEG = 57.29577951308232 +def get_display(spec): + """Convert a display specification (such as :0) into an actual Display + object. + + Pyglet only supports multiple Displays on Linux. + """ + if spec is None: + return None + elif isinstance(spec, six.string_types): + return pyglet.canvas.Display(spec) + else: + raise error.Error('Invalid display specification: {}. (Must be a string like :0 or None.)'.format(spec)) + class Viewer(object): - def __init__(self, width, height): + def __init__(self, width, height, display=None): + display = get_display(display) + self.width = width self.height = height - self.window = pyglet.window.Window(width=width, height=height) + self.window = pyglet.window.Window(width=width, height=height, display=display) self.window.on_close = self.window_closed_by_user self.geoms = [] self.onetime_geoms = [] @@ -282,13 +299,14 @@ class Image(Geom): # ================================================================ class SimpleImageViewer(object): - def __init__(self): + def __init__(self, display=None): self.window = None self.isopen = False + self.display = display def imshow(self, arr): if self.window is None: height, width, channels = arr.shape - self.window = pyglet.window.Window(width=width, height=height) + self.window = pyglet.window.Window(width=width, height=height, display=self.display) self.width = width self.height = height self.isopen = True diff --git a/gym/envs/registration.py b/gym/envs/registration.py index 1487b7b24..5fa02b696 100644 --- a/gym/envs/registration.py +++ b/gym/envs/registration.py @@ -2,6 +2,7 @@ import logging import pkg_resources import re import sys + from gym import error logger = logging.getLogger(__name__) diff --git a/requirements.txt b/requirements.txt index 301f85b01..9c5175ff1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ numpy>=1.10.4 requests>=2.0 six -pyglet +pyglet>=1.2.0 diff --git a/setup.py b/setup.py index a9c85fb47..13bf0ec6d 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ setup(name='gym', if package.startswith('gym')], zip_safe=False, install_requires=[ - 'numpy>=1.10.4', 'requests>=2.0', 'six', 'pyglet', + 'numpy>=1.10.4', 'requests>=2.0', 'six', 'pyglet>=1.2.0', ], extras_require={ 'all': ['atari_py>=0.0.17', 'Pillow', 'PyOpenGL', @@ -28,7 +28,7 @@ setup(name='gym', # Environment-specific dependencies. Keep these in sync with # 'all'! - 'atari': ['atari_py>=0.0.17', 'Pillow', 'pyglet', 'PyOpenGL'], + 'atari': ['atari_py>=0.0.17', 'Pillow', 'PyOpenGL'], 'board_game' : ['pachi-py>=0.0.19'], 'box2d': ['box2d-py'], 'classic_control': ['PyOpenGL'],