mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-17 20:39:12 +00:00
Add configure method to Env, and support multiple displays in CartPole (#175)
* Add configure method to Env, and support multiple displays in CartPole This allows people to pass runtime specification which doesn't affect the environment semantics to environments created via `make`. Also include an example of setting the display used for CartPole * Provide full configure method * Allow environments to require configuration * Don't take arguments in make
This commit is contained in:
21
gym/core.py
21
gym/core.py
@@ -21,6 +21,7 @@ class Env(object):
|
|||||||
reset
|
reset
|
||||||
render
|
render
|
||||||
close
|
close
|
||||||
|
configure
|
||||||
seed
|
seed
|
||||||
|
|
||||||
When implementing an environment, override the following methods
|
When implementing an environment, override the following methods
|
||||||
@@ -30,6 +31,7 @@ class Env(object):
|
|||||||
_reset
|
_reset
|
||||||
_render
|
_render
|
||||||
_close
|
_close
|
||||||
|
_configure
|
||||||
_seed
|
_seed
|
||||||
|
|
||||||
And set the following attributes:
|
And set the following attributes:
|
||||||
@@ -51,6 +53,7 @@ class Env(object):
|
|||||||
env._closed = False
|
env._closed = False
|
||||||
env._action_warned = False
|
env._action_warned = False
|
||||||
env._observation_warned = False
|
env._observation_warned = False
|
||||||
|
env._configured = False
|
||||||
|
|
||||||
# Will be automatically set when creating an environment via 'make'
|
# Will be automatically set when creating an environment via 'make'
|
||||||
env.spec = None
|
env.spec = None
|
||||||
@@ -64,6 +67,9 @@ class Env(object):
|
|||||||
def _close(self):
|
def _close(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def _configure(self):
|
||||||
|
pass
|
||||||
|
|
||||||
# Set these in ALL subclasses
|
# Set these in ALL subclasses
|
||||||
action_space = None
|
action_space = None
|
||||||
observation_space = None
|
observation_space = None
|
||||||
@@ -127,6 +133,9 @@ class Env(object):
|
|||||||
Returns:
|
Returns:
|
||||||
observation (object): the initial observation of the space. (Initial reward is assumed to be 0.)
|
observation (object): the initial observation of the space. (Initial reward is assumed to be 0.)
|
||||||
"""
|
"""
|
||||||
|
if self.metadata.get('configure.required') and not self._configured:
|
||||||
|
raise error.Error("{} requires calling 'configure()' before 'reset()'".format(self))
|
||||||
|
|
||||||
self.monitor._before_reset()
|
self.monitor._before_reset()
|
||||||
observation = self._reset()
|
observation = self._reset()
|
||||||
self.monitor._after_reset(observation)
|
self.monitor._after_reset(observation)
|
||||||
@@ -216,6 +225,18 @@ class Env(object):
|
|||||||
"""
|
"""
|
||||||
return self._seed(seed)
|
return self._seed(seed)
|
||||||
|
|
||||||
|
def configure(self, *args, **kwargs):
|
||||||
|
"""Provides runtime configuration to the environment.
|
||||||
|
|
||||||
|
This configuration should consist of data that tells your
|
||||||
|
environment how to run (such as an address of a remote server,
|
||||||
|
or path to your ImageNet data). It should not affect the
|
||||||
|
semantics of the environment.
|
||||||
|
"""
|
||||||
|
|
||||||
|
self._configured = True
|
||||||
|
return self._configure(*args, **kwargs)
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
self.close()
|
self.close()
|
||||||
|
|
||||||
|
@@ -43,6 +43,12 @@ class CartPoleEnv(gym.Env):
|
|||||||
|
|
||||||
self.steps_beyond_done = None
|
self.steps_beyond_done = None
|
||||||
|
|
||||||
|
# Just need to initialize the relevant attributes
|
||||||
|
self._configure()
|
||||||
|
|
||||||
|
def _configure(self, display=None):
|
||||||
|
self.display = display
|
||||||
|
|
||||||
def _seed(self, seed=None):
|
def _seed(self, seed=None):
|
||||||
self.np_random, seed = seeding.np_random(seed)
|
self.np_random, seed = seeding.np_random(seed)
|
||||||
return [seed]
|
return [seed]
|
||||||
@@ -108,7 +114,7 @@ class CartPoleEnv(gym.Env):
|
|||||||
|
|
||||||
if self.viewer is None:
|
if self.viewer is None:
|
||||||
from gym.envs.classic_control import rendering
|
from gym.envs.classic_control import rendering
|
||||||
self.viewer = rendering.Viewer(screen_width, screen_height)
|
self.viewer = rendering.Viewer(screen_width, screen_height, display=self.display)
|
||||||
l,r,t,b = -cartwidth/2, cartwidth/2, cartheight/2, -cartheight/2
|
l,r,t,b = -cartwidth/2, cartwidth/2, cartheight/2, -cartheight/2
|
||||||
axleoffset =cartheight/4.0
|
axleoffset =cartheight/4.0
|
||||||
cart = rendering.FilledPolygon([(l,b), (l,t), (r,t), (r,b)])
|
cart = rendering.FilledPolygon([(l,b), (l,t), (r,t), (r,b)])
|
||||||
|
@@ -2,7 +2,9 @@
|
|||||||
2D rendering framework
|
2D rendering framework
|
||||||
"""
|
"""
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
import os, sys
|
import os
|
||||||
|
import six
|
||||||
|
import sys
|
||||||
|
|
||||||
if "Apple" in sys.version:
|
if "Apple" in sys.version:
|
||||||
if 'DYLD_FALLBACK_LIBRARY_PATH' in os.environ:
|
if 'DYLD_FALLBACK_LIBRARY_PATH' in os.environ:
|
||||||
@@ -27,11 +29,26 @@ import numpy as np
|
|||||||
|
|
||||||
RAD2DEG = 57.29577951308232
|
RAD2DEG = 57.29577951308232
|
||||||
|
|
||||||
|
def get_display(spec):
|
||||||
|
"""Convert a display specification (such as :0) into an actual Display
|
||||||
|
object.
|
||||||
|
|
||||||
|
Pyglet only supports multiple Displays on Linux.
|
||||||
|
"""
|
||||||
|
if spec is None:
|
||||||
|
return None
|
||||||
|
elif isinstance(spec, six.string_types):
|
||||||
|
return pyglet.canvas.Display(spec)
|
||||||
|
else:
|
||||||
|
raise error.Error('Invalid display specification: {}. (Must be a string like :0 or None.)'.format(spec))
|
||||||
|
|
||||||
class Viewer(object):
|
class Viewer(object):
|
||||||
def __init__(self, width, height):
|
def __init__(self, width, height, display=None):
|
||||||
|
display = get_display(display)
|
||||||
|
|
||||||
self.width = width
|
self.width = width
|
||||||
self.height = height
|
self.height = height
|
||||||
self.window = pyglet.window.Window(width=width, height=height)
|
self.window = pyglet.window.Window(width=width, height=height, display=display)
|
||||||
self.window.on_close = self.window_closed_by_user
|
self.window.on_close = self.window_closed_by_user
|
||||||
self.geoms = []
|
self.geoms = []
|
||||||
self.onetime_geoms = []
|
self.onetime_geoms = []
|
||||||
@@ -282,13 +299,14 @@ class Image(Geom):
|
|||||||
# ================================================================
|
# ================================================================
|
||||||
|
|
||||||
class SimpleImageViewer(object):
|
class SimpleImageViewer(object):
|
||||||
def __init__(self):
|
def __init__(self, display=None):
|
||||||
self.window = None
|
self.window = None
|
||||||
self.isopen = False
|
self.isopen = False
|
||||||
|
self.display = display
|
||||||
def imshow(self, arr):
|
def imshow(self, arr):
|
||||||
if self.window is None:
|
if self.window is None:
|
||||||
height, width, channels = arr.shape
|
height, width, channels = arr.shape
|
||||||
self.window = pyglet.window.Window(width=width, height=height)
|
self.window = pyglet.window.Window(width=width, height=height, display=self.display)
|
||||||
self.width = width
|
self.width = width
|
||||||
self.height = height
|
self.height = height
|
||||||
self.isopen = True
|
self.isopen = True
|
||||||
|
@@ -2,6 +2,7 @@ import logging
|
|||||||
import pkg_resources
|
import pkg_resources
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from gym import error
|
from gym import error
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@@ -1,4 +1,4 @@
|
|||||||
numpy>=1.10.4
|
numpy>=1.10.4
|
||||||
requests>=2.0
|
requests>=2.0
|
||||||
six
|
six
|
||||||
pyglet
|
pyglet>=1.2.0
|
||||||
|
4
setup.py
4
setup.py
@@ -16,7 +16,7 @@ setup(name='gym',
|
|||||||
if package.startswith('gym')],
|
if package.startswith('gym')],
|
||||||
zip_safe=False,
|
zip_safe=False,
|
||||||
install_requires=[
|
install_requires=[
|
||||||
'numpy>=1.10.4', 'requests>=2.0', 'six', 'pyglet',
|
'numpy>=1.10.4', 'requests>=2.0', 'six', 'pyglet>=1.2.0',
|
||||||
],
|
],
|
||||||
extras_require={
|
extras_require={
|
||||||
'all': ['atari_py>=0.0.17', 'Pillow', 'PyOpenGL',
|
'all': ['atari_py>=0.0.17', 'Pillow', 'PyOpenGL',
|
||||||
@@ -28,7 +28,7 @@ setup(name='gym',
|
|||||||
|
|
||||||
# Environment-specific dependencies. Keep these in sync with
|
# Environment-specific dependencies. Keep these in sync with
|
||||||
# 'all'!
|
# 'all'!
|
||||||
'atari': ['atari_py>=0.0.17', 'Pillow', 'pyglet', 'PyOpenGL'],
|
'atari': ['atari_py>=0.0.17', 'Pillow', 'PyOpenGL'],
|
||||||
'board_game' : ['pachi-py>=0.0.19'],
|
'board_game' : ['pachi-py>=0.0.19'],
|
||||||
'box2d': ['box2d-py'],
|
'box2d': ['box2d-py'],
|
||||||
'classic_control': ['PyOpenGL'],
|
'classic_control': ['PyOpenGL'],
|
||||||
|
Reference in New Issue
Block a user