mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-01 06:07:08 +00:00
90 lines
2.8 KiB
Python
90 lines
2.8 KiB
Python
import gym
|
|
from gym import spaces
|
|
import numpy as np
|
|
from os import path
|
|
|
|
class PendulumEnv(gym.Env):
    """Single pendulum driven by a bounded torque.

    The internal state is ``[theta, theta_dot]``. Observations are
    ``[cos(theta), sin(theta), theta_dot]``; the action is an indexable
    value whose first element is the torque, clipped to
    ``[-max_torque, max_torque]``. The reward is the negative of a quadratic
    cost on the normalized angle, the angular velocity and the applied
    torque. ``_step`` always reports ``done=False``.
    """

    metadata = {
        'render.modes': ['human', 'rgb_array'],
        'video.frames_per_second': 30
    }

    def __init__(self):
        """Set the physical limits and build the action/observation spaces."""
        self.max_speed = 8
        self.max_torque = 2.
        self.dt = .05
        self.viewer = None

        # Bounds of the observation vector [cos(theta), sin(theta), theta_dot].
        high = np.array([1., 1., self.max_speed])
        self.action_space = spaces.Box(low=-self.max_torque, high=self.max_torque, shape=(1,))
        self.observation_space = spaces.Box(low=-high, high=high)

    def _step(self, u):
        """Advance the simulation by one time step of length ``self.dt``.

        Args:
            u: indexable action; ``u[0]`` is the requested torque.

        Returns:
            Tuple ``(observation, reward, done, info)`` where ``reward`` is
            the negated cost, ``done`` is always ``False`` and ``info`` is an
            empty dict.
        """
        th, thdot = self.state  # th := theta

        g = 10.
        m = 1.
        l = 1.
        dt = self.dt

        u = np.clip(u, -self.max_torque, self.max_torque)[0]
        # Remember the torque that is actually applied (after clipping) so the
        # renderer does not draw a torque larger than the actuator can deliver.
        self.last_u = u  # for rendering
        costs = angle_normalize(th)**2 + .1*thdot**2 + .001*(u**2)

        newthdot = thdot + (-3*g/(2*l) * np.sin(th + np.pi) + 3./(m*l**2)*u) * dt
        # The angle is integrated with the new velocity *before* that velocity
        # is clipped to the speed limit; this ordering is kept deliberately.
        newth = th + newthdot*dt
        newthdot = np.clip(newthdot, -self.max_speed, self.max_speed)  # pylint: disable=E1111

        self.state = np.array([newth, newthdot])
        return self._get_obs(), -costs, False, {}

    def _reset(self):
        """Sample a new random state and return the first observation.

        The angle is drawn uniformly from ``[-pi, pi]`` and the angular
        velocity uniformly from ``[-1, 1]``.
        """
        high = np.array([np.pi, 1])
        self.state = np.random.uniform(low=-high, high=high)
        self.last_u = None
        return self._get_obs()

    def _get_obs(self):
        """Return ``[cos(theta), sin(theta), theta_dot]`` for the current state."""
        theta, thetadot = self.state
        return np.array([np.cos(theta), np.sin(theta), thetadot])

    def _render(self, mode='human', close=False):
        """Draw the pendulum, or close the viewer when ``close`` is true.

        Args:
            mode: ``'human'`` or ``'rgb_array'``; any other value is passed
                on to the parent class's ``render``.
            close: when true, close the viewer (if any) and return.

        Returns:
            The viewer's pixel array for ``'rgb_array'``, otherwise ``None``
            (or whatever the parent ``render`` returns for other modes).
        """
        if close:
            if self.viewer is not None:
                self.viewer.close()
                # Drop the closed viewer so that a later render call builds a
                # fresh one instead of drawing into a closed window.
                self.viewer = None
            return

        if self.viewer is None:
            # Imported lazily so the rendering module is only needed when
            # something is actually drawn.
            from gym.envs.classic_control import rendering
            self.viewer = rendering.Viewer(500, 500)
            self.viewer.set_bounds(-2.2, 2.2, -2.2, 2.2)
            rod = rendering.make_capsule(1, .2)
            rod.set_color(.8, .3, .3)
            self.pole_transform = rendering.Transform()
            rod.add_attr(self.pole_transform)
            self.viewer.add_geom(rod)
            axle = rendering.make_circle(.05)
            axle.set_color(0, 0, 0)
            self.viewer.add_geom(axle)
            fname = path.join(path.dirname(__file__), "assets/clockwise.png")
            self.img = rendering.Image(fname, 1., 1.)
            self.imgtrans = rendering.Transform()
            self.img.add_attr(self.imgtrans)

        self.viewer.add_onetime(self.img)
        self.pole_transform.set_rotation(self.state[0] + np.pi/2)
        # Compare against None explicitly: a zero torque is a valid value that
        # must still update the arrow scale, and truth-testing an array is
        # ambiguous.
        if self.last_u is not None:
            self.imgtrans.scale = (-self.last_u/2, np.abs(self.last_u)/2)

        self.viewer.render()
        if mode == 'rgb_array':
            return self.viewer.get_array()
        elif mode == 'human':
            pass
        else:
            return super(PendulumEnv, self).render(mode=mode)
|
|
|
|
def angle_normalize(x):
    """Wrap an angle given in radians into the interval ``[-pi, pi)``."""
    full_turn = 2 * np.pi
    # Shift by pi so the modulo lands in [0, 2*pi), then shift back.
    shifted = x + np.pi
    return (shifted % full_turn) - np.pi
|