mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-23 15:04:20 +00:00
* initial draft of optional info dict in reset function, implemented for cartpole, tests seem to be passing * merged core.py * updated return type annotation for reset function in core.py * optional metadata with return_info from reset added for all first party environments, with corresponding tests. Incomplete implementation for wrappers and vector wrappers * removed Optional type for return_info arguments * added tests for return_info to normalize wrapper and sync_vector_env * autoformatted using black * added optional reset metadata tests to several wrappers * added return_info capability to async_vector_env.py and test to verify functionality * added optional return_info test for record_video.py * removed tests for mujoco environments * autoformatted * improved test coverage for optional reset return_info * re-removed unit test envs accidentally reintroduced in merge * removed unnecessary import * changes based on code-review * small fix to core wrapper typing and autoformatted record_epsisode_stats * small change to pass flake8 style
252 lines
9.4 KiB
Python
252 lines
9.4 KiB
Python
"""
|
|
Classic cart-pole system implemented by Rich Sutton et al.
|
|
Copied from http://incompleteideas.net/sutton/book/code/pole.c
|
|
permalink: https://perma.cc/C9ZM-652R
|
|
"""
|
|
|
|
import math
|
|
from typing import Optional
|
|
|
|
import gym
|
|
from gym import spaces, logger
|
|
from gym.utils import seeding
|
|
import numpy as np
|
|
|
|
|
|
class CartPoleEnv(gym.Env):
    """
    ### Description
    This environment corresponds to the version of the cart-pole problem
    described by Barto, Sutton, and Anderson in ["Neuronlike Adaptive Elements That Can Solve Difficult Learning Control Problems"](https://ieeexplore.ieee.org/document/6313077).
    A pole is attached by an un-actuated joint to a cart, which moves along a
    frictionless track. The pendulum starts upright, and the goal is to prevent
    it from falling over by increasing and reducing the cart's velocity.

    ### Action Space
    The action is a single discrete value in `{0, 1}` (`spaces.Discrete(2)`)
    that selects the direction in which the cart is pushed with a force of
    fixed magnitude:

    | Num | Action                 |
    |-----|------------------------|
    | 0   | Push cart to the left  |
    | 1   | Push cart to the right |

    Note: The amount the velocity is reduced or increased is not fixed as it depends on the angle the pole is pointing.
    This is because the center of gravity of the pole increases the amount of energy needed to move the cart underneath it.

    ### Observation Space
    The observation is a `ndarray` with shape `(4,)` where the elements correspond to the following:

    | Num | Observation           | Min                   | Max                 |
    |-----|-----------------------|-----------------------|---------------------|
    | 0   | Cart Position         | -4.8*                 | 4.8*                |
    | 1   | Cart Velocity         | -Inf                  | Inf                 |
    | 2   | Pole Angle            | ~ -0.418 rad (-24°)** | ~ 0.418 rad (24°)** |
    | 3   | Pole Angular Velocity | -Inf                  | Inf                 |

    **Note:** the table above gives the ranges of possible observations for each element, but in two cases this range
    exceeds the range of possible values in an un-terminated episode:
    - `*`: the cart x-position can be observed between `(-4.8, 4.8)`, but an episode terminates if the cart leaves the
      `(-2.4, 2.4)` range.
    - `**`: Similarly, the pole angle can be observed between `(-.418, .418)` radians or precisely **±24°**, but an episode is
      terminated if the pole angle is outside the `(-.2095, .2095)` range or precisely **±12°**.

    ### Rewards
    Reward is 1 for every step taken, including the termination step. Any
    further step taken after termination (without calling `reset()`) yields a
    reward of 0. The threshold is 475 for v1.

    ### Starting State
    All observations are assigned a uniform random value between (-0.05, 0.05).

    ### Episode Termination
    The episode terminates if any one of the following occurs:

    1. Pole Angle is more than ±12°
    2. Cart Position is more than ±2.4 (center of the cart reaches the edge of the display)
    3. Episode length is greater than 500 (200 for v0)

    Conditions 1 and 2 are checked in `step()`. The episode-length limit is
    not implemented in this class (presumably it is applied by a time-limit
    wrapper when the environment is registered).

    ### Arguments

    No additional arguments are currently supported.
    """

    metadata = {
        "render.modes": ["human", "rgb_array"],
        "video.frames_per_second": 50,
    }

    def __init__(self):
        """Set the physical constants, the spaces and the empty episode state."""
        self.gravity = 9.8
        self.masscart = 1.0
        self.masspole = 0.1
        self.total_mass = self.masspole + self.masscart
        self.length = 0.5  # actually half the pole's length
        self.polemass_length = self.masspole * self.length
        self.force_mag = 10.0
        self.tau = 0.02  # seconds between state updates
        self.kinematics_integrator = "euler"

        # Angle at which to fail the episode (12 degrees, in radians).
        self.theta_threshold_radians = 12 * 2 * math.pi / 360
        self.x_threshold = 2.4

        # Position and angle limits are set to twice the failure thresholds so
        # that the observation which ends the episode is still within bounds.
        high = np.array(
            [
                self.x_threshold * 2,
                np.finfo(np.float32).max,
                self.theta_threshold_radians * 2,
                np.finfo(np.float32).max,
            ],
            dtype=np.float32,
        )

        self.action_space = spaces.Discrete(2)
        self.observation_space = spaces.Box(-high, high, dtype=np.float32)

        self.viewer = None
        self.state = None

        # None while the episode is running; counts the steps taken after the
        # episode has ended so the misuse warning is logged only once.
        self.steps_beyond_done = None

    def step(self, action):
        """Advance the simulation by one time step of length ``self.tau``.

        Args:
            action: Element of ``self.action_space``; ``1`` pushes the cart to
                the right, any other valid action pushes it to the left.

        Returns:
            A tuple ``(observation, reward, done, info)`` where ``observation``
            is a ``float32`` array ``[x, x_dot, theta, theta_dot]``, ``reward``
            is ``1.0`` up to and including the terminating step and ``0.0``
            afterwards, ``done`` tells whether a position or angle threshold
            was exceeded, and ``info`` is an empty dict.

        Raises:
            AssertionError: If ``action`` is not in the action space, or if
                the environment has not been reset yet.
        """
        err_msg = f"{action!r} ({type(action)}) invalid"
        assert self.action_space.contains(action), err_msg
        # Without this check, stepping an un-reset environment fails on the
        # unpacking below with an unhelpful "cannot unpack NoneType" error.
        assert self.state is not None, "Call reset before using step method."

        x, x_dot, theta, theta_dot = self.state
        force = self.force_mag if action == 1 else -self.force_mag
        costheta = math.cos(theta)
        sintheta = math.sin(theta)

        # For the interested reader:
        # https://coneural.org/florian/papers/05_cart_pole.pdf
        temp = (
            force + self.polemass_length * theta_dot ** 2 * sintheta
        ) / self.total_mass
        thetaacc = (self.gravity * sintheta - costheta * temp) / (
            self.length * (4.0 / 3.0 - self.masspole * costheta ** 2 / self.total_mass)
        )
        xacc = temp - self.polemass_length * thetaacc * costheta / self.total_mass

        if self.kinematics_integrator == "euler":
            # Explicit Euler: positions advance with the *old* velocities.
            x = x + self.tau * x_dot
            x_dot = x_dot + self.tau * xacc
            theta = theta + self.tau * theta_dot
            theta_dot = theta_dot + self.tau * thetaacc
        else:  # semi-implicit euler
            # Velocities are updated first, positions then use the *new* ones.
            x_dot = x_dot + self.tau * xacc
            x = x + self.tau * x_dot
            theta_dot = theta_dot + self.tau * thetaacc
            theta = theta + self.tau * theta_dot

        self.state = (x, x_dot, theta, theta_dot)

        done = bool(
            x < -self.x_threshold
            or x > self.x_threshold
            or theta < -self.theta_threshold_radians
            or theta > self.theta_threshold_radians
        )

        if not done:
            reward = 1.0
        elif self.steps_beyond_done is None:
            # Pole just fell!
            self.steps_beyond_done = 0
            reward = 1.0
        else:
            if self.steps_beyond_done == 0:
                logger.warn(
                    "You are calling 'step()' even though this "
                    "environment has already returned done = True. You "
                    "should always call 'reset()' once you receive 'done = "
                    "True' -- any further steps are undefined behavior."
                )
            self.steps_beyond_done += 1
            reward = 0.0

        return np.array(self.state, dtype=np.float32), reward, done, {}

    def reset(
        self,
        *,
        seed: Optional[int] = None,
        return_info: bool = False,
        options: Optional[dict] = None,
    ):
        """Start a new episode from a small random state.

        Args:
            seed: Forwarded to ``super().reset`` (presumably used to seed
                ``self.np_random``; see the base class).
            return_info: When true, an (empty) info dict is returned together
                with the observation.
            options: Accepted for interface compatibility; not used here.

        Returns:
            The initial observation as a ``float32`` array of shape ``(4,)``,
            or the tuple ``(observation, {})`` when ``return_info`` is true.
        """
        super().reset(seed=seed)
        self.state = self.np_random.uniform(low=-0.05, high=0.05, size=(4,))
        self.steps_beyond_done = None

        observation = np.array(self.state, dtype=np.float32)
        if return_info:
            return observation, {}
        return observation

    def render(self, mode="human"):
        """Draw the current state with the pyglet viewer.

        The viewer and its geometry are created lazily on the first call.

        Args:
            mode: ``"rgb_array"`` asks the viewer to return the frame as an
                array; any other value is rendered like ``"human"``.

        Returns:
            ``None`` if the environment has no state yet, otherwise whatever
            ``self.viewer.render`` returns for the requested mode.
        """
        screen_width = 600
        screen_height = 400

        world_width = self.x_threshold * 2
        scale = screen_width / world_width
        carty = 100  # TOP OF CART
        polewidth = 10.0
        polelen = scale * (2 * self.length)
        cartwidth = 50.0
        cartheight = 30.0

        # Pole rectangle in its own frame. It is recomputed on every call (and
        # re-applied below) because it depends on the current ``self.length``.
        left, right = -polewidth / 2, polewidth / 2
        top, bottom = polelen - polewidth / 2, -polewidth / 2
        pole_vertices = [(left, bottom), (left, top), (right, top), (right, bottom)]

        if self.viewer is None:
            # Imported lazily so that pyglet is only needed when rendering.
            from gym.utils import pyglet_rendering

            self.viewer = pyglet_rendering.Viewer(screen_width, screen_height)

            left, right = -cartwidth / 2, cartwidth / 2
            top, bottom = cartheight / 2, -cartheight / 2
            axleoffset = cartheight / 4.0
            cart = pyglet_rendering.FilledPolygon(
                [(left, bottom), (left, top), (right, top), (right, bottom)]
            )
            self.carttrans = pyglet_rendering.Transform()
            cart.add_attr(self.carttrans)
            self.viewer.add_geom(cart)

            pole = pyglet_rendering.FilledPolygon(list(pole_vertices))
            pole.set_color(0.8, 0.6, 0.4)
            self.poletrans = pyglet_rendering.Transform(translation=(0, axleoffset))
            # The pole is first rotated about the axle, then moved with the cart.
            pole.add_attr(self.poletrans)
            pole.add_attr(self.carttrans)
            self.viewer.add_geom(pole)

            self.axle = pyglet_rendering.make_circle(polewidth / 2)
            self.axle.add_attr(self.poletrans)
            self.axle.add_attr(self.carttrans)
            self.axle.set_color(0.5, 0.5, 0.8)
            self.viewer.add_geom(self.axle)

            self.track = pyglet_rendering.Line((0, carty), (screen_width, carty))
            self.track.set_color(0, 0, 0)
            self.viewer.add_geom(self.track)

            self._pole_geom = pole

        if self.state is None:
            return None

        # Edit the pole polygon vertex
        self._pole_geom.v = pole_vertices

        state = self.state
        cartx = state[0] * scale + screen_width / 2.0  # MIDDLE OF CART
        self.carttrans.set_translation(cartx, carty)
        self.poletrans.set_rotation(-state[2])

        return self.viewer.render(return_rgb_array=mode == "rgb_array")

    def close(self):
        """Close the viewer, if one was created, and forget it."""
        if self.viewer is not None:
            self.viewer.close()
            self.viewer = None
|