Add configure method to Env, and support multiple displays in CartPole (#175)

* Add configure method to Env, and support multiple displays in CartPole

This allows people to pass runtime specification which doesn't affect
the environment semantics to environments created via `make`.

Also include an example of setting the display used for CartPole

* Provide full configure method

* Allow environments to require configuration

* Don't take arguments in make
This commit is contained in:
Greg Brockman
2016-06-12 20:56:21 -07:00
committed by GitHub
parent f7f064160e
commit 2aa03d6088
6 changed files with 55 additions and 9 deletions

View File

@@ -21,6 +21,7 @@ class Env(object):
reset
render
close
configure
seed
When implementing an environment, override the following methods
@@ -30,6 +31,7 @@ class Env(object):
_reset
_render
_close
_configure
_seed
And set the following attributes:
@@ -51,6 +53,7 @@ class Env(object):
env._closed = False
env._action_warned = False
env._observation_warned = False
env._configured = False
# Will be automatically set when creating an environment via 'make'
env.spec = None
@@ -64,6 +67,9 @@ class Env(object):
def _close(self):
pass
def _configure(self):
pass
# Set these in ALL subclasses
action_space = None
observation_space = None
@@ -127,6 +133,9 @@ class Env(object):
Returns:
observation (object): the initial observation of the space. (Initial reward is assumed to be 0.)
"""
if self.metadata.get('configure.required') and not self._configured:
raise error.Error("{} requires calling 'configure()' before 'reset()'".format(self))
self.monitor._before_reset()
observation = self._reset()
self.monitor._after_reset(observation)
@@ -216,6 +225,18 @@ class Env(object):
"""
return self._seed(seed)
def configure(self, *args, **kwargs):
"""Provides runtime configuration to the environment.
This configuration should consist of data that tells your
environment how to run (such as an address of a remote server,
or path to your ImageNet data). It should not affect the
semantics of the environment.
"""
self._configured = True
return self._configure(*args, **kwargs)
def __del__(self):
self.close()

View File

@@ -43,6 +43,12 @@ class CartPoleEnv(gym.Env):
self.steps_beyond_done = None
# Just need to initialize the relevant attributes
self._configure()
def _configure(self, display=None):
self.display = display
def _seed(self, seed=None):
self.np_random, seed = seeding.np_random(seed)
return [seed]
@@ -108,7 +114,7 @@ class CartPoleEnv(gym.Env):
if self.viewer is None:
from gym.envs.classic_control import rendering
self.viewer = rendering.Viewer(screen_width, screen_height)
self.viewer = rendering.Viewer(screen_width, screen_height, display=self.display)
l,r,t,b = -cartwidth/2, cartwidth/2, cartheight/2, -cartheight/2
axleoffset =cartheight/4.0
cart = rendering.FilledPolygon([(l,b), (l,t), (r,t), (r,b)])

View File

@@ -2,7 +2,9 @@
2D rendering framework
"""
from __future__ import division
import os, sys
import os
import six
import sys
if "Apple" in sys.version:
if 'DYLD_FALLBACK_LIBRARY_PATH' in os.environ:
@@ -27,11 +29,26 @@ import numpy as np
RAD2DEG = 57.29577951308232
def get_display(spec):
"""Convert a display specification (such as :0) into an actual Display
object.
Pyglet only supports multiple Displays on Linux.
"""
if spec is None:
return None
elif isinstance(spec, six.string_types):
return pyglet.canvas.Display(spec)
else:
raise error.Error('Invalid display specification: {}. (Must be a string like :0 or None.)'.format(spec))
class Viewer(object):
def __init__(self, width, height):
def __init__(self, width, height, display=None):
display = get_display(display)
self.width = width
self.height = height
self.window = pyglet.window.Window(width=width, height=height)
self.window = pyglet.window.Window(width=width, height=height, display=display)
self.window.on_close = self.window_closed_by_user
self.geoms = []
self.onetime_geoms = []
@@ -282,13 +299,14 @@ class Image(Geom):
# ================================================================
class SimpleImageViewer(object):
def __init__(self):
def __init__(self, display=None):
self.window = None
self.isopen = False
self.display = display
def imshow(self, arr):
if self.window is None:
height, width, channels = arr.shape
self.window = pyglet.window.Window(width=width, height=height)
self.window = pyglet.window.Window(width=width, height=height, display=self.display)
self.width = width
self.height = height
self.isopen = True

View File

@@ -2,6 +2,7 @@ import logging
import pkg_resources
import re
import sys
from gym import error
logger = logging.getLogger(__name__)

View File

@@ -1,4 +1,4 @@
numpy>=1.10.4
requests>=2.0
six
pyglet
pyglet>=1.2.0

View File

@@ -16,7 +16,7 @@ setup(name='gym',
if package.startswith('gym')],
zip_safe=False,
install_requires=[
'numpy>=1.10.4', 'requests>=2.0', 'six', 'pyglet',
'numpy>=1.10.4', 'requests>=2.0', 'six', 'pyglet>=1.2.0',
],
extras_require={
'all': ['atari_py>=0.0.17', 'Pillow', 'PyOpenGL',
@@ -28,7 +28,7 @@ setup(name='gym',
# Environment-specific dependencies. Keep these in sync with
# 'all'!
'atari': ['atari_py>=0.0.17', 'Pillow', 'pyglet', 'PyOpenGL'],
'atari': ['atari_py>=0.0.17', 'Pillow', 'PyOpenGL'],
'board_game' : ['pachi-py>=0.0.19'],
'box2d': ['box2d-py'],
'classic_control': ['PyOpenGL'],