2016-04-27 08:00:58 -07:00
"""
Classic cart - pole system implemented by Rich Sutton et al .
Copied from https : / / webdocs . cs . ualberta . ca / ~ sutton / book / code / pole . c
"""
2016-04-28 22:31:46 -07:00
import logging
2016-04-27 08:00:58 -07:00
import math
import gym
from gym import spaces
2016-05-29 09:07:09 -07:00
from gym . utils import seeding
2016-04-27 08:00:58 -07:00
import numpy as np
2016-04-28 22:31:46 -07:00
# Module-level logger, named after this module; used to warn about API misuse.
logger = logging . getLogger ( __name__ )
2016-04-27 08:00:58 -07:00
class CartPoleEnv(gym.Env):
    """Cart-pole balancing environment.

    A pole is hinged on a cart that moves along a one-dimensional track.
    Every step applies a force of fixed magnitude ``force_mag`` to the cart,
    in the negative x direction for action ``0`` and in the positive x
    direction for action ``1``, and advances the state by ``tau`` seconds
    with an explicit Euler update.

    Observation:
        ``[x, x_dot, theta, theta_dot]`` -- cart position, cart velocity,
        pole angle and pole angular velocity.

    Reward:
        ``1.0`` for every step up to and including the step on which the
        episode terminates, ``0.0`` for any step taken after that.

    Termination:
        ``x`` leaves ``[-x_threshold, x_threshold]`` or ``theta`` leaves
        ``[-theta_threshold_radians, theta_threshold_radians]``.
    """

    metadata = {
        'render.modes': ['human', 'rgb_array'],
        'video.frames_per_second': 50,
    }

    def __init__(self):
        """Set up constants, spaces and an initial randomly sampled state."""
        self.gravity = 9.8
        self.masscart = 1.0
        self.masspole = 0.1
        self.total_mass = self.masspole + self.masscart
        self.length = 0.5  # actually half the pole's length
        self.polemass_length = self.masspole * self.length
        self.force_mag = 10.0
        self.tau = 0.02  # seconds between state updates

        # Angle at which to fail the episode (12 degrees expressed in radians)
        self.theta_threshold_radians = 12 * 2 * math.pi / 360
        self.x_threshold = 2.4

        # Position and angle limits are set to twice the failure thresholds so
        # that the failing observation is still within bounds; the two
        # velocity components are unbounded.
        high = np.array([
            self.x_threshold * 2,
            np.inf,
            self.theta_threshold_radians * 2,
            np.inf,
        ])

        self.action_space = spaces.Discrete(2)
        self.observation_space = spaces.Box(-high, high)

        # Define every instance attribute before the first reset() so the
        # object is fully formed by the time any environment method runs.
        self.viewer = None
        self.steps_beyond_done = None
        # Just need to initialize the relevant attributes
        self._configure()

        self._seed()
        self.reset()

    def _configure(self, display=None):
        """Remember the ``display`` value that is passed to the viewer when it is created."""
        self.display = display

    def _seed(self, seed=None):
        """(Re)create the environment's random number generator.

        Returns:
            A one-element list holding the seed reported by
            ``seeding.np_random``.
        """
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def _step(self, action):
        """Advance the simulation by one time step of ``tau`` seconds.

        Args:
            action: ``0`` to push with ``-force_mag``, ``1`` to push with
                ``+force_mag``.

        Returns:
            A tuple ``(observation, reward, done, info)`` where
            ``observation`` is the new state as a numpy array, ``done`` is a
            plain ``bool`` and ``info`` is an empty dict.

        Raises:
            AssertionError: if ``action`` is neither ``0`` nor ``1``.
        """
        assert action == 0 or action == 1, " %r ( %s ) invalid " % (action, type(action))

        x, x_dot, theta, theta_dot = self.state
        force = self.force_mag if action == 1 else -self.force_mag
        costheta = math.cos(theta)
        sintheta = math.sin(theta)

        # Accelerations of the pole angle and of the cart for the current state.
        temp = (force + self.polemass_length * theta_dot * theta_dot * sintheta) / self.total_mass
        thetaacc = (self.gravity * sintheta - costheta * temp) / (
            self.length * (4.0 / 3.0 - self.masspole * costheta * costheta / self.total_mass))
        xacc = temp - self.polemass_length * thetaacc * costheta / self.total_mass

        # Explicit Euler update: positions use the old velocities, velocities
        # use the accelerations computed above.
        x = x + self.tau * x_dot
        x_dot = x_dot + self.tau * xacc
        theta = theta + self.tau * theta_dot
        theta_dot = theta_dot + self.tau * thetaacc
        self.state = (x, x_dot, theta, theta_dot)

        # bool() turns a possible numpy boolean into a plain Python bool.
        done = bool(
            x < -self.x_threshold
            or x > self.x_threshold
            or theta < -self.theta_threshold_radians
            or theta > self.theta_threshold_radians
        )

        if not done:
            reward = 1.0
        elif self.steps_beyond_done is None:
            # Pole just fell!
            self.steps_beyond_done = 0
            reward = 1.0
        else:
            # Stepping after termination: warn once, then keep counting.
            if self.steps_beyond_done == 0:
                logger.warning(" You are calling ' step() ' even though this environment has already returned done = True. You should always call ' reset() ' once you receive ' done = True ' -- any further steps are undefined behavior. ")
            self.steps_beyond_done += 1
            reward = 0.0

        return np.array(self.state), reward, done, {}

    def _reset(self):
        """Sample a fresh state uniformly from ``[-0.05, 0.05]`` in all four components.

        Returns:
            The new state as a numpy array.
        """
        self.state = self.np_random.uniform(low=-0.05, high=0.05, size=(4,))
        self.steps_beyond_done = None
        return np.array(self.state)

    def _render(self, mode='human', close=False):
        """Draw the current state, creating the viewer on first use.

        Args:
            mode: ``'rgb_array'`` asks the viewer for the frame as an RGB
                array; any other value renders without that request.
            close: when true, close and drop the viewer (if one exists) and
                return ``None`` without drawing.

        Returns:
            Whatever ``self.viewer.render`` returns for the requested mode,
            or ``None`` when ``close`` is true.
        """
        if close:
            if self.viewer is not None:
                self.viewer.close()
                self.viewer = None
            return

        screen_width = 600
        screen_height = 400

        world_width = self.x_threshold * 2
        scale = screen_width / world_width
        carty = 100  # TOP OF CART
        polewidth = 10.0
        # self.length is half of the pole's length, so the drawn pole spans 2 * length.
        polelen = scale * (2 * self.length)
        cartwidth = 50.0
        cartheight = 30.0

        if self.viewer is None:
            # Imported here rather than at module level, so the rendering
            # module is only loaded once a frame is actually requested.
            from gym.envs.classic_control import rendering
            self.viewer = rendering.Viewer(screen_width, screen_height, display=self.display)

            # Cart: a rectangle centred on its own transform.
            l, r, t, b = -cartwidth / 2, cartwidth / 2, cartheight / 2, -cartheight / 2
            axleoffset = cartheight / 4.0
            cart = rendering.FilledPolygon([(l, b), (l, t), (r, t), (r, b)])
            self.carttrans = rendering.Transform()
            cart.add_attr(self.carttrans)
            self.viewer.add_geom(cart)

            # Pole: has its own transform (offset to the axle) and also
            # carries the cart's transform.
            l, r, t, b = -polewidth / 2, polewidth / 2, polelen - polewidth / 2, -polewidth / 2
            pole = rendering.FilledPolygon([(l, b), (l, t), (r, t), (r, b)])
            pole.set_color(.8, .6, .4)
            self.poletrans = rendering.Transform(translation=(0, axleoffset))
            pole.add_attr(self.poletrans)
            pole.add_attr(self.carttrans)
            self.viewer.add_geom(pole)

            # Axle: a circle sharing the pole's and the cart's transforms.
            self.axle = rendering.make_circle(polewidth / 2)
            self.axle.add_attr(self.poletrans)
            self.axle.add_attr(self.carttrans)
            self.axle.set_color(.5, .5, .8)
            self.viewer.add_geom(self.axle)

            # Track: a horizontal line across the full screen width at cart height.
            self.track = rendering.Line((0, carty), (screen_width, carty))
            self.track.set_color(0, 0, 0)
            self.viewer.add_geom(self.track)

        state = self.state
        cartx = state[0] * scale + screen_width / 2.0  # MIDDLE OF CART
        self.carttrans.set_translation(cartx, carty)
        self.poletrans.set_rotation(-state[2])

        return self.viewer.render(return_rgb_array=mode == 'rgb_array')