2022-05-20 14:49:30 +01:00
""" Utilities of visualising an environment. """
from collections import deque
2022-05-25 15:28:19 +01:00
from typing import Callable , Dict , List , Optional , Tuple , Union
2022-03-31 12:50:38 -07:00
2022-05-20 14:49:30 +01:00
import numpy as np
2022-03-31 12:50:38 -07:00
2022-07-04 18:19:25 +01:00
import gym . error
2022-04-18 17:30:56 +02:00
from gym import Env , logger
2022-05-20 14:49:30 +01:00
from gym . core import ActType , ObsType
from gym . error import DependencyNotInstalled
2022-05-02 17:58:23 +02:00
from gym . logger import deprecation
2021-07-29 02:26:34 +02:00
2022-07-04 18:19:25 +01:00
try :
import pygame
from pygame import Surface
from pygame . event import Event
from pygame . locals import VIDEORESIZE
except ImportError :
raise gym . error . DependencyNotInstalled (
" Pygame is not installed, run `pip install gym[classic_control]` "
)
2017-02-01 13:10:59 -08:00
try :
2022-04-18 17:30:56 +02:00
import matplotlib
2021-07-29 02:26:34 +02:00
matplotlib . use ( " TkAgg " )
2017-05-11 10:44:46 -07:00
import matplotlib . pyplot as plt
2022-05-20 14:49:30 +01:00
except ImportError :
logger . warn ( " Matplotlib is not installed, run `pip install gym[other]` " )
2022-07-04 18:19:25 +01:00
plt = None
2022-05-15 15:56:06 +02:00
2021-07-29 02:26:34 +02:00
2022-04-18 17:30:56 +02:00
class MissingKeysToAction(Exception):
    """Error signalling that no key-to-action mapping could be determined.

    It is raised when no ``keys_to_action`` mapping was supplied and neither the
    environment nor its unwrapped version exposes ``get_keys_to_action``.
    """
2022-04-18 17:30:56 +02:00
class PlayableGame:
    """Wraps an environment allowing keyboard inputs to interact with the environment."""

    def __init__(
        self,
        env: Env,
        keys_to_action: Optional[Dict[Tuple[int, ...], int]] = None,
        zoom: Optional[float] = None,
    ):
        """Wraps an environment with a dictionary of keyboard buttons to action and if to zoom in on the environment.

        Creating the wrapper renders the environment once (to measure the frame size) and
        opens a pygame display of that size.

        Args:
            env: The environment to play
            keys_to_action: The dictionary of keyboard tuples and action value.
                If ``None``, the environment's ``get_keys_to_action()`` is used.
            zoom: If to zoom in on the environment render

        Raises:
            MissingKeysToAction: If ``keys_to_action`` is ``None`` and the environment
                provides no ``get_keys_to_action`` method.
        """
        if env.render_mode not in {"rgb_array", "single_rgb_array"}:
            # Only logged, not raised: construction carries on and fails later if the
            # render output turns out not to be an array.
            logger.error(
                "PlayableGame wrapper works only with rgb_array and single_rgb_array render modes, "
                f"but your environment render_mode = {env.render_mode}."
            )

        self.env = env
        self.relevant_keys = self._get_relevant_keys(keys_to_action)
        self.video_size = self._get_video_size(zoom)
        self.screen = pygame.display.set_mode(self.video_size)
        # Key codes currently held down, in the order they were pressed.
        self.pressed_keys = []
        # Cleared by ``process_event`` on ESC or when the window is closed.
        self.running = True

    def _get_relevant_keys(
        self, keys_to_action: Optional[Dict[Tuple[int], int]] = None
    ) -> set:
        """Collects every key code that appears in at least one key combination.

        Args:
            keys_to_action: Mapping from key-code combinations to actions. If ``None``,
                the mapping is taken from ``get_keys_to_action()`` of the environment
                (or of its unwrapped version).

        Returns:
            The set of all key codes used by the mapping.

        Raises:
            MissingKeysToAction: If no mapping is given and the environment has none.
        """
        if keys_to_action is None:
            if hasattr(self.env, "get_keys_to_action"):
                keys_to_action = self.env.get_keys_to_action()
            elif hasattr(self.env.unwrapped, "get_keys_to_action"):
                keys_to_action = self.env.unwrapped.get_keys_to_action()
            else:
                raise MissingKeysToAction(
                    f"{self.env.spec.id} does not have explicit key to action mapping, "
                    "please specify one manually"
                )
        assert isinstance(keys_to_action, dict)
        # Flatten all combinations in a single pass (no repeated list concatenation).
        return {key for combination in keys_to_action for key in combination}

    def _get_video_size(self, zoom: Optional[float] = None) -> Tuple[int, int]:
        """Computes the window size from one rendered frame.

        Args:
            zoom: Optional scale factor applied to both dimensions.

        Returns:
            ``(width, height)`` of the display, always as a tuple.
        """
        rendered = self.env.render()
        if isinstance(rendered, list):
            # A list of frames was returned; only the most recent one matters here.
            rendered = rendered[-1]
        assert rendered is not None and isinstance(rendered, np.ndarray)
        # Frames are indexed (row, column, ...), i.e. (height, width, ...).
        width, height = rendered.shape[1], rendered.shape[0]
        if zoom is not None:
            width, height = int(width * zoom), int(height * zoom)
        return width, height

    def process_event(self, event: Event):
        """Processes a PyGame event.

        In particular, this function is used to keep track of which buttons are currently pressed
        and to exit the :func:`play` function when the PyGame window is closed.

        Args:
            event: The event to process
        """
        if event.type == pygame.KEYDOWN:
            if event.key in self.relevant_keys:
                # Ignore repeated KEYDOWNs so a key is never tracked twice; a duplicate
                # would stop the sorted key tuple from matching any known combination.
                if event.key not in self.pressed_keys:
                    self.pressed_keys.append(event.key)
            elif event.key == pygame.K_ESCAPE:
                self.running = False
        elif event.type == pygame.KEYUP:
            # The matching KEYDOWN may never have been seen (e.g. the key was already
            # held when the window gained focus), so only remove keys being tracked.
            if event.key in self.relevant_keys and event.key in self.pressed_keys:
                self.pressed_keys.remove(event.key)
        elif event.type == pygame.QUIT:
            self.running = False
        elif event.type == VIDEORESIZE:
            self.video_size = event.size
            self.screen = pygame.display.set_mode(self.video_size)
def display_arr(
    screen: Surface, arr: np.ndarray, video_size: Tuple[int, int], transpose: bool
):
    """Displays a numpy array on screen.

    The array is contrast-stretched so that its minimum maps to 0 and its maximum to 255,
    converted to a pygame surface, scaled to ``video_size`` and blitted at the top-left
    corner of ``screen``.

    Args:
        screen: The screen to show the array on
        arr: The array to show
        video_size: The video size of the screen
        transpose: If to transpose the array on the screen
    """
    arr_min, arr_max = np.min(arr), np.max(arr)
    if arr_max > arr_min:
        arr = 255.0 * (arr - arr_min) / (arr_max - arr_min)
    else:
        # Constant frame: the stretch would divide by zero and produce NaNs. Every pixel
        # equals the minimum, which the stretch maps to 0, so show an all-zero frame.
        arr = np.zeros_like(arr, dtype=float)
    pyg_img = pygame.surfarray.make_surface(arr.swapaxes(0, 1) if transpose else arr)
    pyg_img = pygame.transform.scale(pyg_img, video_size)
    screen.blit(pyg_img, (0, 0))
2017-02-01 13:10:59 -08:00
2022-04-18 17:30:56 +02:00
def play(
    env: Env,
    transpose: Optional[bool] = True,
    fps: Optional[int] = None,
    zoom: Optional[float] = None,
    callback: Optional[Callable] = None,
    keys_to_action: Optional[Dict[Union[Tuple[Union[str, int]], str], ActType]] = None,
    seed: Optional[int] = None,
    noop: ActType = 0,
):
    """Allows one to play the game using keyboard.

    The environment is stepped with the step API that returns five values
    (``obs, reward, terminated, truncated, info``).

    Example::

        >>> import gym
        >>> from gym.utils.play import play
        >>> play(gym.make("CarRacing-v1", render_mode="single_rgb_array"), keys_to_action={
        ...     "w": np.array([0, 0.7, 0]),
        ...     "a": np.array([-1, 0, 0]),
        ...     "s": np.array([0, 0, 1]),
        ...     "d": np.array([1, 0, 0]),
        ...     "wa": np.array([-1, 0.7, 0]),
        ...     "dw": np.array([1, 0.7, 0]),
        ...     "ds": np.array([1, 0, 1]),
        ...     "as": np.array([-1, 0, 1]),
        ... }, noop=np.array([0, 0, 0]))

    Above code works also if the environment is wrapped, so it's particularly useful in
    verifying that the frame-level preprocessing does not render the game
    unplayable.

    If you wish to plot real time statistics as you play, you can use
    :class:`gym.utils.play.PlayPlot`. Here's a sample code for plotting the reward
    for last 150 steps.

        >>> def callback(obs_t, obs_tp1, action, rew, terminated, truncated, info):
        ...     return [rew, ]
        >>> plotter = PlayPlot(callback, 150, ["reward"])
        >>> play(gym.make("ALE/AirRaid-v5", render_mode="rgb_array"), callback=plotter.callback)

    Args:
        env: Environment to use for playing.
        transpose: If this is ``True``, the rendered frame is transposed (its first two axes are swapped)
            before it is displayed. Defaults to ``True``.
        fps: Maximum number of steps of the environment executed every second. If ``None`` (the default),
            ``env.metadata["render_fps"]`` (or 30, if the environment does not specify "render_fps") is used.
        zoom: Zoom the observation in, ``zoom`` amount, should be positive float
        callback: If a callback is provided, it will be executed after every step. It takes the following input:
                obs_t: observation before performing action
                obs_tp1: observation after performing action
                action: action that was executed
                rew: reward that was received
                terminated: whether the environment is terminated or not
                truncated: whether the environment is truncated or not
                info: debug info
        keys_to_action: Mapping from keys pressed to action performed.
            Different formats are supported: Key combinations can either be expressed as a tuple of unicode code
            points of the keys, as a tuple of characters, or as a string where each character of the string represents
            one key.
            For example if pressing 'w' and space at the same time is supposed
            to trigger action number 2 then ``keys_to_action`` dict could look like this:
                >>> {
                ...    # ...
                ...    (ord('w'), ord(' ')): 2
                ...    # ...
                ... }
            or like this:
                >>> {
                ...    # ...
                ...    ("w", " "): 2
                ...    # ...
                ... }
            or like this:
                >>> {
                ...    # ...
                ...    "w ": 2
                ...    # ...
                ... }
            If ``None``, default ``keys_to_action`` mapping for that environment is used, if provided.
        seed: Random seed used when resetting the environment. If None, no seed is used.
            The same value is passed to every ``env.reset`` call, including the resets between episodes.
        noop: The action used when no key input has been entered, or the entered key combination is unknown.

    Raises:
        MissingKeysToAction: If ``keys_to_action`` is ``None`` and the environment does not
            provide a ``get_keys_to_action`` method.
    """
    # The environment has to be reset before PlayableGame renders it to size the window.
    env.reset(seed=seed)

    if keys_to_action is None:
        if hasattr(env, "get_keys_to_action"):
            keys_to_action = env.get_keys_to_action()
        elif hasattr(env.unwrapped, "get_keys_to_action"):
            keys_to_action = env.unwrapped.get_keys_to_action()
        else:
            raise MissingKeysToAction(
                f"{env.spec.id} does not have explicit key to action mapping, "
                "please specify one manually"
            )
    assert keys_to_action is not None

    # Normalise every key combination to a sorted tuple of integer key codes so that
    # the lookup below does not depend on the order in which keys were pressed.
    key_code_to_action = {}
    for key_combination, action in keys_to_action.items():
        key_code = tuple(
            sorted(ord(key) if isinstance(key, str) else key for key in key_combination)
        )
        key_code_to_action[key_code] = action

    game = PlayableGame(env, key_code_to_action, zoom)

    if fps is None:
        fps = env.metadata.get("render_fps", 30)

    # Start in the "done" state so that the first loop iteration resets the environment.
    done, obs = True, None
    try:
        clock = pygame.time.Clock()

        while game.running:
            if done:
                done = False
                obs = env.reset(seed=seed)
            else:
                action = key_code_to_action.get(tuple(sorted(game.pressed_keys)), noop)
                prev_obs = obs
                obs, rew, terminated, truncated, info = env.step(action)
                done = terminated or truncated
                if callback is not None:
                    callback(prev_obs, obs, action, rew, terminated, truncated, info)
            if obs is not None:
                rendered = env.render()
                if isinstance(rendered, list):
                    # A list of frames was returned; show the most recent one.
                    rendered = rendered[-1]
                assert rendered is not None and isinstance(rendered, np.ndarray)
                display_arr(
                    game.screen, rendered, transpose=transpose, video_size=game.video_size
                )

            # process pygame events
            for event in pygame.event.get():
                game.process_event(event)

            pygame.display.flip()
            clock.tick(fps)
    finally:
        # Always shut pygame down, even if the environment or the callback raised,
        # so the window is not left open.
        pygame.quit()
2021-07-29 02:26:34 +02:00
2021-11-14 14:50:53 +01:00
class PlayPlot:
    """Provides a callback to create live plots of arbitrary metrics when using :func:`play`.

    This class is instantiated with a function that accepts information about a single environment transition:
        - obs_t: observation before performing action
        - obs_tp1: observation after performing action
        - action: action that was executed
        - rew: reward that was received
        - terminated: whether the environment is terminated or not
        - truncated: whether the environment is truncated or not
        - info: debug info

    It should return a list of metrics that are computed from this data.
    For instance, the function may look like this::

        >>> def compute_metrics(obs_t, obs_tp, action, reward, terminated, truncated, info):
        ...     return [reward, info["cumulative_reward"], np.linalg.norm(action)]

    :class:`PlayPlot` provides the method :meth:`callback` which will pass its arguments along to that function
    and uses the returned values to update live plots of the metrics.

    Typically, this :meth:`callback` will be used in conjunction with :func:`play` to see how the metrics evolve as you play::

        >>> plotter = PlayPlot(compute_metrics, horizon_timesteps=200,
        ...                    plot_names=["Immediate Rew.", "Cumulative Rew.", "Action Magnitude"])
        >>> play(your_env, callback=plotter.callback)
    """

    def __init__(
        self, callback: Callable, horizon_timesteps: int, plot_names: List[str]
    ):
        """Constructor of :class:`PlayPlot`.

        The function ``callback`` that is passed to this constructor should return
        a list of metrics that is of length ``len(plot_names)``.

        Args:
            callback: Function that computes metrics from environment transitions
            horizon_timesteps: The time horizon used for the live plots
            plot_names: List of plot titles

        Raises:
            DependencyNotInstalled: If matplotlib is not installed
        """
        deprecation(
            "`PlayPlot` is marked as deprecated and will be removed in the near future."
        )
        self.data_callback = callback
        self.horizon_timesteps = horizon_timesteps
        self.plot_names = plot_names

        if plt is None:
            raise DependencyNotInstalled(
                "matplotlib is not installed, run `pip install gym[other]`"
            )

        num_plots = len(self.plot_names)
        self.fig, self.ax = plt.subplots(num_plots)
        if num_plots == 1:
            # With a single plot ``subplots`` returns a bare axes object rather than a
            # sequence; wrap it so the rest of the class can always index ``self.ax``.
            self.ax = [self.ax]
        for axis, name in zip(self.ax, plot_names):
            axis.set_title(name)
        # Number of transitions seen so far.
        self.t = 0
        # The scatter artist currently drawn on each axis (``None`` until the first update).
        self.cur_plot: List[
            Optional["matplotlib.collections.PathCollection"]
        ] = [None for _ in range(num_plots)]
        # One bounded history per metric; old values fall out after ``horizon_timesteps``.
        self.data = [deque(maxlen=horizon_timesteps) for _ in range(num_plots)]

    def callback(
        self,
        obs_t: ObsType,
        obs_tp1: ObsType,
        action: ActType,
        rew: float,
        terminated: bool,
        truncated: bool,
        info: dict,
    ):
        """The callback that calls the provided data callback and adds the data to the plots.

        Args:
            obs_t: The observation at time step t
            obs_tp1: The observation at time step t+1
            action: The action
            rew: The reward
            terminated: If the environment is terminated
            truncated: If the environment is truncated
            info: The information from the environment

        Raises:
            ValueError: If the data callback returns fewer metrics than there are plots.
        """
        points = list(
            self.data_callback(
                obs_t, obs_tp1, action, rew, terminated, truncated, info
            )
        )
        if len(points) < len(self.data):
            # Without this check the series would go out of step with the time axis and
            # the scatter call below would fail with an opaque size-mismatch error.
            raise ValueError(
                f"The data callback returned {len(points)} metric(s), "
                f"but {len(self.data)} plot(s) were configured"
            )
        for point, data_series in zip(points, self.data):
            data_series.append(point)
        self.t += 1

        # Time steps covered by the (bounded) histories.
        xmin, xmax = max(0, self.t - self.horizon_timesteps), self.t

        for i, plot in enumerate(self.cur_plot):
            if plot is not None:
                # Drop the previous scatter before drawing the updated one.
                plot.remove()
            self.cur_plot[i] = self.ax[i].scatter(
                range(xmin, xmax), list(self.data[i]), c="blue"
            )
            self.ax[i].set_xlim(xmin, xmax)

        if plt is None:
            raise DependencyNotInstalled(
                "matplotlib is not installed, run `pip install gym[other]`"
            )
        # A tiny pause lets matplotlib process GUI events and redraw the figure.
        plt.pause(0.000001)