"""
Classic cart-pole system implemented by Rich Sutton et al.
Copied from http://incompleteideas.net/sutton/book/code/pole.c
permalink: https://perma.cc/C9ZM-652R
"""
import math
from typing import Optional

import gym
from gym import spaces, logger
from gym.utils import seeding
import numpy as np


class CartPoleEnv(gym.Env):
    """
    ### Description

    This environment corresponds to the version of the cart-pole problem
    described by Barto, Sutton, and Anderson in ["Neuronlike Adaptive Elements That Can Solve Difficult Learning Control Problems"](https://ieeexplore.ieee.org/document/6313077).
    A pole is attached by an un-actuated joint to a cart, which moves along a
    frictionless track. The pendulum starts upright, and the goal is to prevent
    it from falling over by increasing and reducing the cart's velocity.

    ### Action Space

    The action is a single integer from the discrete space `{0, 1}`
    (`spaces.Discrete(2)`). It selects the direction in which the cart is pushed
    with a fixed amount of force (`force_mag`, 10.0):

    | Num | Action                 |
    |-----|------------------------|
    | 0   | Push cart to the left  |
    | 1   | Push cart to the right |

    **Note**: The amount the velocity is reduced or increased is not fixed, as it
    depends on the angle the pole is pointing. This is because the center of
    gravity of the pole varies the amount of energy needed to move the cart
    underneath it.

    ### Observation Space

    The observation is a `ndarray` with shape `(4,)` where the elements correspond to the following:

    | Num | Observation           | Min                  | Max                |
    |-----|-----------------------|----------------------|--------------------|
    | 0   | Cart Position         | -4.8*                | 4.8*               |
    | 1   | Cart Velocity         | -Inf                 | Inf                |
    | 2   | Pole Angle            | ~ -0.418 rad (-24°)** | ~ 0.418 rad (24°)** |
    | 3   | Pole Angular Velocity | -Inf                 | Inf                |

    **Note:** The table above gives the ranges of possible observations for each
    element. In two cases this range exceeds the range of possible values in an
    un-terminated episode:
    - `*`: the cart x-position can be observed between `(-4.8, 4.8)`, but an episode
      terminates if the cart leaves the `(-2.4, 2.4)` range.
    - `**`: Similarly, the pole angle can be observed between `(-.418, .418)` radians
      or precisely **±24°**, but an episode is terminated if the pole angle is
      outside the `(-.2095, .2095)` range or precisely **±12°**.

    ### Rewards

    Reward is 1 for every step taken, including the termination step. Any step
    taken after termination (without calling `reset()`) yields a reward of 0
    and logs a warning. The threshold is 475 for v1.

    ### Starting State

    All observations are assigned a uniform random value between (-0.05, 0.05).

    ### Episode Termination

    The episode terminates if any one of the following occurs:
    1. Pole Angle is more than ±12°
    2. Cart Position is more than ±2.4 (center of the cart reaches the edge of the display)
    3. Episode length is greater than 500 (200 for v0)

    Conditions 1 and 2 are checked by `step()`. This class keeps no step counter,
    so condition 3 must be enforced externally (e.g. by a time-limit wrapper).

    ### Arguments

    No additional arguments are currently supported.
    """

    metadata = {"render.modes": ["human", "rgb_array"], "video.frames_per_second": 50}

    def __init__(self):
        """Set the physical constants, thresholds and action/observation spaces."""
        self.gravity = 9.8
        self.masscart = 1.0
        self.masspole = 0.1
        self.total_mass = self.masspole + self.masscart
        self.length = 0.5  # actually half the pole's length
        self.polemass_length = self.masspole * self.length
        self.force_mag = 10.0
        self.tau = 0.02  # seconds between state updates
        # Any value other than "euler" selects semi-implicit Euler in step().
        self.kinematics_integrator = "euler"

        # Angle at which to fail the episode
        self.theta_threshold_radians = 12 * 2 * math.pi / 360
        self.x_threshold = 2.4

        # Angle limit set to 2 * theta_threshold_radians so failing observation
        # is still within bounds.
        high = np.array(
            [
                self.x_threshold * 2,
                np.finfo(np.float32).max,
                self.theta_threshold_radians * 2,
                np.finfo(np.float32).max,
            ],
            dtype=np.float32,
        )

        self.action_space = spaces.Discrete(2)
        self.observation_space = spaces.Box(-high, high, dtype=np.float32)

        self.viewer = None
        self.state = None

        # None while the episode is running; counts the steps taken after
        # termination otherwise (see step()).
        self.steps_beyond_done = None

    def step(self, action):
        """Advance the simulation by one time step of `tau` seconds.

        Args:
            action: 0 to push the cart to the left, 1 to push it to the right.

        Returns:
            A tuple `(observation, reward, done, info)`:
            - `observation` is the new state as a float32 array of shape (4,);
            - `reward` is 1.0 while the episode is running (including the step
              that ends it) and 0.0 for every step taken after that;
            - `done` is True once the cart position or pole angle is out of range;
            - `info` is an empty dict.

        Raises:
            AssertionError: if `action` is not in the action space, or if called
                before `reset()` has initialised the state.
        """
        err_msg = f"{action!r} ({type(action)}) invalid"
        assert self.action_space.contains(action), err_msg
        # Without this, unpacking a None state fails with an unhelpful TypeError.
        assert self.state is not None, "Call reset before using step method."

        x, x_dot, theta, theta_dot = self.state
        force = self.force_mag if action == 1 else -self.force_mag
        costheta = math.cos(theta)
        sintheta = math.sin(theta)

        # Equations of motion for the frictionless cart-pole.
        # For the interested reader:
        # https://coneural.org/florian/papers/05_cart_pole.pdf
        temp = (
            force + self.polemass_length * theta_dot ** 2 * sintheta
        ) / self.total_mass
        thetaacc = (self.gravity * sintheta - costheta * temp) / (
            self.length * (4.0 / 3.0 - self.masspole * costheta ** 2 / self.total_mass)
        )
        xacc = temp - self.polemass_length * thetaacc * costheta / self.total_mass

        if self.kinematics_integrator == "euler":
            # Explicit Euler: positions advance with the *old* velocities.
            x = x + self.tau * x_dot
            x_dot = x_dot + self.tau * xacc
            theta = theta + self.tau * theta_dot
            theta_dot = theta_dot + self.tau * thetaacc
        else:  # semi-implicit euler
            # Velocities are updated first and then used for the positions.
            x_dot = x_dot + self.tau * xacc
            x = x + self.tau * x_dot
            theta_dot = theta_dot + self.tau * thetaacc
            theta = theta + self.tau * theta_dot

        self.state = (x, x_dot, theta, theta_dot)

        done = bool(
            x < -self.x_threshold
            or x > self.x_threshold
            or theta < -self.theta_threshold_radians
            or theta > self.theta_threshold_radians
        )

        if not done:
            reward = 1.0
        elif self.steps_beyond_done is None:
            # Pole just fell!
            self.steps_beyond_done = 0
            reward = 1.0
        else:
            # Warn only once, on the first step after termination.
            if self.steps_beyond_done == 0:
                logger.warn(
                    "You are calling 'step()' even though this "
                    "environment has already returned done = True. You "
                    "should always call 'reset()' once you receive 'done = "
                    "True' -- any further steps are undefined behavior."
                )
            self.steps_beyond_done += 1
            reward = 0.0

        return np.array(self.state, dtype=np.float32), reward, done, {}

    def reset(
        self,
        *,
        seed: Optional[int] = None,
        return_info: bool = False,
        options: Optional[dict] = None,
    ):
        """Start a new episode with every state variable drawn from U(-0.05, 0.05).

        Args:
            seed: forwarded to `gym.Env.reset`, which presumably (re)seeds
                `self.np_random`.
            return_info: if True, also return an (empty) info dict.
            options: accepted for API compatibility; not used here.

        Returns:
            The initial observation as a float32 array of shape (4,), or the
            tuple `(observation, {})` when `return_info` is True.
        """
        super().reset(seed=seed)
        self.state = self.np_random.uniform(low=-0.05, high=0.05, size=(4,))
        self.steps_beyond_done = None
        if not return_info:
            return np.array(self.state, dtype=np.float32)
        else:
            return np.array(self.state, dtype=np.float32), {}

    def render(self, mode="human"):
        """Draw the current state with a pyglet viewer, creating it on first use.

        Args:
            mode: "human" or "rgb_array"; with "rgb_array" the viewer is asked
                to return an RGB array.

        Returns:
            None if there is no state yet (before `reset()`); otherwise the result
            of `self.viewer.render`.
        """
        screen_width = 600
        screen_height = 400

        # The visible world spans +-x_threshold, so the cart's center reaches the
        # window edge exactly when the episode terminates on position.
        world_width = self.x_threshold * 2
        scale = screen_width / world_width
        carty = 100  # vertical center of the cart; the track is drawn at this height
        polewidth = 10.0
        polelen = scale * (2 * self.length)
        cartwidth = 50.0
        cartheight = 30.0

        if self.viewer is None:
            from gym.utils import pyglet_rendering

            self.viewer = pyglet_rendering.Viewer(screen_width, screen_height)
            l, r, t, b = -cartwidth / 2, cartwidth / 2, cartheight / 2, -cartheight / 2
            axleoffset = cartheight / 4.0
            cart = pyglet_rendering.FilledPolygon([(l, b), (l, t), (r, t), (r, b)])
            self.carttrans = pyglet_rendering.Transform()
            cart.add_attr(self.carttrans)
            self.viewer.add_geom(cart)
            l, r, t, b = (
                -polewidth / 2,
                polewidth / 2,
                polelen - polewidth / 2,
                -polewidth / 2,
            )
            pole = pyglet_rendering.FilledPolygon([(l, b), (l, t), (r, t), (r, b)])
            pole.set_color(0.8, 0.6, 0.4)
            self.poletrans = pyglet_rendering.Transform(translation=(0, axleoffset))
            pole.add_attr(self.poletrans)
            pole.add_attr(self.carttrans)
            self.viewer.add_geom(pole)
            self.axle = pyglet_rendering.make_circle(polewidth / 2)
            self.axle.add_attr(self.poletrans)
            self.axle.add_attr(self.carttrans)
            self.axle.set_color(0.5, 0.5, 0.8)
            self.viewer.add_geom(self.axle)
            self.track = pyglet_rendering.Line((0, carty), (screen_width, carty))
            self.track.set_color(0, 0, 0)
            self.viewer.add_geom(self.track)

            self._pole_geom = pole

        if self.state is None:
            return None

        # Edit the pole polygon vertex: polelen is recomputed from self.length on
        # every call, so rebuilding the vertices keeps the drawing in sync if
        # the length changed after the viewer was created.
        pole = self._pole_geom
        l, r, t, b = (
            -polewidth / 2,
            polewidth / 2,
            polelen - polewidth / 2,
            -polewidth / 2,
        )
        pole.v = [(l, b), (l, t), (r, t), (r, b)]

        x = self.state
        cartx = x[0] * scale + screen_width / 2.0  # MIDDLE OF CART
        self.carttrans.set_translation(cartx, carty)
        self.poletrans.set_rotation(-x[2])

        return self.viewer.render(return_rgb_array=mode == "rgb_array")

    def close(self):
        """Close the viewer, if one was created, and drop the reference to it."""
        if self.viewer:
            self.viewer.close()
            self.viewer = None