Add configure method to Env, and support multiple displays in CartPole (#175)

* Add configure method to Env, and support multiple displays in CartPole

This allows people to pass runtime specification which doesn't affect
the environment semantics to environments created via `make`.

Also include an example of setting the display used for CartPole

* Provide full configure method

* Allow environments to require configuration

* Don't take arguments in make
This commit is contained in:
Greg Brockman
2016-06-12 20:56:21 -07:00
committed by GitHub
parent f7f064160e
commit 2aa03d6088
6 changed files with 55 additions and 9 deletions

View File

@@ -21,6 +21,7 @@ class Env(object):
reset reset
render render
close close
configure
seed seed
When implementing an environment, override the following methods When implementing an environment, override the following methods
@@ -30,6 +31,7 @@ class Env(object):
_reset _reset
_render _render
_close _close
_configure
_seed _seed
And set the following attributes: And set the following attributes:
@@ -51,6 +53,7 @@ class Env(object):
env._closed = False env._closed = False
env._action_warned = False env._action_warned = False
env._observation_warned = False env._observation_warned = False
env._configured = False
# Will be automatically set when creating an environment via 'make' # Will be automatically set when creating an environment via 'make'
env.spec = None env.spec = None
@@ -64,6 +67,9 @@ class Env(object):
def _close(self): def _close(self):
pass pass
def _configure(self):
pass
# Set these in ALL subclasses # Set these in ALL subclasses
action_space = None action_space = None
observation_space = None observation_space = None
@@ -127,6 +133,9 @@ class Env(object):
Returns: Returns:
observation (object): the initial observation of the space. (Initial reward is assumed to be 0.) observation (object): the initial observation of the space. (Initial reward is assumed to be 0.)
""" """
if self.metadata.get('configure.required') and not self._configured:
raise error.Error("{} requires calling 'configure()' before 'reset()'".format(self))
self.monitor._before_reset() self.monitor._before_reset()
observation = self._reset() observation = self._reset()
self.monitor._after_reset(observation) self.monitor._after_reset(observation)
@@ -216,6 +225,18 @@ class Env(object):
""" """
return self._seed(seed) return self._seed(seed)
def configure(self, *args, **kwargs):
"""Provides runtime configuration to the environment.
This configuration should consist of data that tells your
environment how to run (such as an address of a remote server,
or path to your ImageNet data). It should not affect the
semantics of the environment.
"""
self._configured = True
return self._configure(*args, **kwargs)
def __del__(self): def __del__(self):
self.close() self.close()

View File

@@ -43,6 +43,12 @@ class CartPoleEnv(gym.Env):
self.steps_beyond_done = None self.steps_beyond_done = None
# Just need to initialize the relevant attributes
self._configure()
def _configure(self, display=None):
self.display = display
def _seed(self, seed=None): def _seed(self, seed=None):
self.np_random, seed = seeding.np_random(seed) self.np_random, seed = seeding.np_random(seed)
return [seed] return [seed]
@@ -108,7 +114,7 @@ class CartPoleEnv(gym.Env):
if self.viewer is None: if self.viewer is None:
from gym.envs.classic_control import rendering from gym.envs.classic_control import rendering
self.viewer = rendering.Viewer(screen_width, screen_height) self.viewer = rendering.Viewer(screen_width, screen_height, display=self.display)
l,r,t,b = -cartwidth/2, cartwidth/2, cartheight/2, -cartheight/2 l,r,t,b = -cartwidth/2, cartwidth/2, cartheight/2, -cartheight/2
axleoffset =cartheight/4.0 axleoffset =cartheight/4.0
cart = rendering.FilledPolygon([(l,b), (l,t), (r,t), (r,b)]) cart = rendering.FilledPolygon([(l,b), (l,t), (r,t), (r,b)])

View File

@@ -2,7 +2,9 @@
2D rendering framework 2D rendering framework
""" """
from __future__ import division from __future__ import division
import os, sys import os
import six
import sys
if "Apple" in sys.version: if "Apple" in sys.version:
if 'DYLD_FALLBACK_LIBRARY_PATH' in os.environ: if 'DYLD_FALLBACK_LIBRARY_PATH' in os.environ:
@@ -27,11 +29,26 @@ import numpy as np
RAD2DEG = 57.29577951308232 RAD2DEG = 57.29577951308232
def get_display(spec):
"""Convert a display specification (such as :0) into an actual Display
object.
Pyglet only supports multiple Displays on Linux.
"""
if spec is None:
return None
elif isinstance(spec, six.string_types):
return pyglet.canvas.Display(spec)
else:
raise error.Error('Invalid display specification: {}. (Must be a string like :0 or None.)'.format(spec))
class Viewer(object): class Viewer(object):
def __init__(self, width, height): def __init__(self, width, height, display=None):
display = get_display(display)
self.width = width self.width = width
self.height = height self.height = height
self.window = pyglet.window.Window(width=width, height=height) self.window = pyglet.window.Window(width=width, height=height, display=display)
self.window.on_close = self.window_closed_by_user self.window.on_close = self.window_closed_by_user
self.geoms = [] self.geoms = []
self.onetime_geoms = [] self.onetime_geoms = []
@@ -282,13 +299,14 @@ class Image(Geom):
# ================================================================ # ================================================================
class SimpleImageViewer(object): class SimpleImageViewer(object):
def __init__(self): def __init__(self, display=None):
self.window = None self.window = None
self.isopen = False self.isopen = False
self.display = display
def imshow(self, arr): def imshow(self, arr):
if self.window is None: if self.window is None:
height, width, channels = arr.shape height, width, channels = arr.shape
self.window = pyglet.window.Window(width=width, height=height) self.window = pyglet.window.Window(width=width, height=height, display=self.display)
self.width = width self.width = width
self.height = height self.height = height
self.isopen = True self.isopen = True

View File

@@ -2,6 +2,7 @@ import logging
import pkg_resources import pkg_resources
import re import re
import sys import sys
from gym import error from gym import error
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@@ -1,4 +1,4 @@
numpy>=1.10.4 numpy>=1.10.4
requests>=2.0 requests>=2.0
six six
pyglet pyglet>=1.2.0

View File

@@ -16,7 +16,7 @@ setup(name='gym',
if package.startswith('gym')], if package.startswith('gym')],
zip_safe=False, zip_safe=False,
install_requires=[ install_requires=[
'numpy>=1.10.4', 'requests>=2.0', 'six', 'pyglet', 'numpy>=1.10.4', 'requests>=2.0', 'six', 'pyglet>=1.2.0',
], ],
extras_require={ extras_require={
'all': ['atari_py>=0.0.17', 'Pillow', 'PyOpenGL', 'all': ['atari_py>=0.0.17', 'Pillow', 'PyOpenGL',
@@ -28,7 +28,7 @@ setup(name='gym',
# Environment-specific dependencies. Keep these in sync with # Environment-specific dependencies. Keep these in sync with
# 'all'! # 'all'!
'atari': ['atari_py>=0.0.17', 'Pillow', 'pyglet', 'PyOpenGL'], 'atari': ['atari_py>=0.0.17', 'Pillow', 'PyOpenGL'],
'board_game' : ['pachi-py>=0.0.19'], 'board_game' : ['pachi-py>=0.0.19'],
'box2d': ['box2d-py'], 'box2d': ['box2d-py'],
'classic_control': ['PyOpenGL'], 'classic_control': ['PyOpenGL'],