Jax Cliffwalking Env (#407)
@@ -171,6 +171,11 @@ register(
    kwargs={"sutton_and_barto": True, "natural": False},
)

register(
    id="tabular/CliffWalking-v0",
    entry_point="gymnasium.envs.tabular.cliffwalking:CliffWalkingJaxEnv",
)


# Mujoco
# ----------------------------------------
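Once registered, the Jax-backed environment is created through the standard factory. A minimal usage sketch (assuming `jax` is installed; the env id matches the registration above):

```python
import gymnasium as gym

env = gym.make("tabular/CliffWalking-v0")
obs, info = env.reset(seed=0)
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
```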
@@ -1,3 +1,4 @@
"""Provides Tabular JAX FuncEnv implementations."""

from gymnasium.envs.tabular.blackjack import BlackJackJaxEnv
from gymnasium.envs.tabular.cliffwalking import CliffWalkingJaxEnv
gymnasium/envs/tabular/cliffwalking.py (new file, 386 lines)
@@ -0,0 +1,386 @@
"""This module provides a CliffWalking functional environment and Gymnasium environment wrapper CliffWalkingJaxEnv."""


from __future__ import annotations

from os import path
from typing import TYPE_CHECKING, NamedTuple

import jax
import jax.numpy as jnp
import numpy as np
from jax.random import PRNGKey

from gymnasium import spaces
from gymnasium.error import DependencyNotInstalled
from gymnasium.experimental.functional import ActType, FuncEnv, StateType
from gymnasium.experimental.functional_jax_env import FunctionalJaxEnv
from gymnasium.utils import EzPickle
from gymnasium.wrappers import HumanRendering

if TYPE_CHECKING:
    import numpy
    import pygame


class RenderStateType(NamedTuple):
    """A named tuple which contains the full render state of the Cliffwalking Env. This is static during the episode."""

    screen: pygame.Surface
    shape: tuple[int, int]
    nS: int
    cell_size: tuple[int, int]
    cliff: numpy.ndarray
    elf_images: tuple[pygame.Surface, pygame.Surface, pygame.Surface, pygame.Surface]
    start_img: pygame.Surface
    goal_img: pygame.Surface
    bg_imgs: tuple[str, str]
    mountain_bg_img: tuple[pygame.Surface, pygame.Surface]
    near_cliff_imgs: tuple[str, str]
    near_cliff_img: tuple[pygame.Surface, pygame.Surface]
    cliff_img: pygame.Surface

class EnvState(NamedTuple):
    """A named tuple which contains the full state of the Cliffwalking game."""

    player_position: jnp.ndarray
    last_action: int
    fallen: bool


def fell_off(player_position):
    """Checks to see if the player_position means the player has fallen off the cliff."""
    return (
        (player_position[0] == 3)
        * (player_position[1] >= 1)
        * (player_position[1] <= 10)
    )

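A quick sketch of `fell_off`'s behaviour: the elementwise multiplications act as a jit-friendly logical AND, so the result is truthy exactly on the cliff strip `[3, 1..10]` (illustrative only):

```python
import jax.numpy as jnp

assert bool(fell_off(jnp.array([3, 5])))      # on the cliff
assert not bool(fell_off(jnp.array([2, 5])))  # one row above the cliff
assert not bool(fell_off(jnp.array([3, 0])))  # the start cell is safe
```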
class CliffWalkingFunctional(
    FuncEnv[jnp.ndarray, jnp.ndarray, int, float, bool, RenderStateType]
):
    """Cliff walking involves crossing a gridworld from start to goal while avoiding falling off a cliff.

    ## Description
    The game starts with the player at location [3, 0] of the 4x12 grid world with the
    goal located at [3, 11]. If the player reaches the goal the episode ends.

    A cliff runs along [3, 1..10]. If the player moves to a cliff location it
    returns to the start location.

    The player makes moves until they reach the goal.

    Adapted from Example 6.6 (page 132) from Reinforcement Learning: An Introduction
    by Sutton and Barto [<a href="#cliffwalk_ref">1</a>].

    With inspiration from:
    [https://github.com/dennybritz/reinforcement-learning/blob/master/lib/envs/cliff_walking.py](https://github.com/dennybritz/reinforcement-learning/blob/master/lib/envs/cliff_walking.py)

    ## Action Space
    The action shape is `(1,)` in the range `{0, ..., 3}` indicating
    which direction to move the player.

    - 0: Move up
    - 1: Move right
    - 2: Move down
    - 3: Move left

    ## Observation Space
    There are 3 x 12 + 1 possible states. The player cannot be at the cliff, nor at
    the goal as the latter results in the end of the episode. What remains are all
    the positions of the first 3 rows plus the bottom-left cell.

    The observation is a value representing the player's current position as
    current_row * ncols + current_col (where both the row and col start at 0).

    For example, the starting position can be calculated as follows: 3 * 12 + 0 = 36.

    The observation is returned as a `numpy.ndarray` with shape `(1,)` and dtype `numpy.int32`.

    ## Starting State
    The episode starts with the player in state `[36]` (location [3, 0]).

    ## Reward
    Each time step incurs -1 reward, unless the player stepped into the cliff,
    which incurs -100 reward.

    ## Episode End
    The episode terminates when the player enters state `[47]` (location [3, 11]).

    ## Arguments

    ```python
    import gymnasium as gym
    gym.make('tabular/CliffWalking-v0')
    ```

    ## References
    <a id="cliffwalk_ref"></a>[1] R. Sutton and A. Barto, "Reinforcement Learning:
    An Introduction" 2020. [Online]. Available: [http://www.incompleteideas.net/book/RLbook2020.pdf](http://www.incompleteideas.net/book/RLbook2020.pdf)

    ## Version History
    - v0: Initial version release
    """

    action_space = spaces.Box(low=0, high=3, dtype=np.int32)  # 4 directions
    observation_space = spaces.Box(
        low=0, high=(12 * 4) - 1, shape=(1,), dtype=np.int32
    )  # A discrete state corresponds to each possible location

    metadata = {
        "render_modes": ["rgb_array"],
        "render_fps": 4,
    }

    def transition(self, state: EnvState, action: int | jnp.ndarray, key: PRNGKey):
        """The Cliffwalking environment's state transition function."""
        new_position = state.player_position

        # where is the agent trying to go?
        new_position = jnp.array(
            [
                new_position[0] + (1 * (action == 2)) + (-1 * (action == 0)),
                new_position[1] + (1 * (action == 1)) + (-1 * (action == 3)),
            ]
        )

        # prevent out of bounds
        new_position = jnp.array(
            [
                jnp.maximum(jnp.minimum(new_position[0], 3), 0),
                jnp.maximum(jnp.minimum(new_position[1], 11), 0),
            ]
        )

        # if we fell off, we have to start over from scratch from (3, 0)
        fallen = fell_off(new_position)
        new_position = jnp.array(
            [
                new_position[0] * (1 - fallen) + 3 * fallen,
                new_position[1] * (1 - fallen),
            ]
        )
        new_state = EnvState(
            player_position=new_position.reshape((2,)),
            last_action=action[0],
            fallen=fallen,
        )

        return new_state
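The masked arithmetic above replaces `if`/`else` branching so the whole transition stays traceable under `jax.jit`. A minimal sketch of the cliff reset, assuming the class above (actions are shape-`(1,)` arrays):

```python
import jax
import jax.numpy as jnp

env = CliffWalkingFunctional()
state = env.initial(jax.random.PRNGKey(0))

# Moving right (action 1) from the start (3, 0) steps onto the cliff at
# (3, 1), so `fallen` is set and the player is reset to (3, 0).
next_state = env.transition(state, jnp.array([1]), jax.random.PRNGKey(1))
assert bool(next_state.fallen)
assert (next_state.player_position == jnp.array([3, 0])).all()
```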
    def initial(self, rng: PRNGKey):
        """Cliffwalking initial observation function."""
        player_position = jnp.array([3, 0])

        state = EnvState(player_position=player_position, last_action=-1, fallen=False)

        return state

    def observation(self, state: EnvState) -> jnp.ndarray:
        """Cliffwalking observation."""
        return jnp.array(
            state.player_position[0] * 12 + state.player_position[1]
        ).reshape((1,))
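This is the `current_row * ncols + current_col` encoding described in the docstring; a worked check with plain integers (illustrative only):

```python
assert 3 * 12 + 0 == 36   # starting cell [3, 0]
assert 3 * 12 + 11 == 47  # goal cell [3, 11]
```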
    def terminal(self, state: EnvState) -> jnp.ndarray:
        """Determines if a particular Cliffwalking observation is terminal."""
        return jnp.array_equal(state.player_position, jnp.array([3, 11]))

    def reward(
        self, state: EnvState, action: ActType, next_state: StateType
    ) -> jnp.ndarray:
        """Calculates reward from a state."""
        state = next_state
        reward = -1 + (-99 * state.fallen[0])
        return jax.lax.convert_element_type(reward, jnp.float32)
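Consistent with the Reward section of the docstring, `-1 + (-99 * fallen)` takes exactly two values (illustrative arithmetic):

```python
assert -1 + (-99 * 0) == -1    # ordinary step
assert -1 + (-99 * 1) == -100  # stepped into the cliff
```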
    def render_init(
        self, screen_width: int = 600, screen_height: int = 500
    ) -> RenderStateType:
        """Returns an initial render state."""
        try:
            import pygame
        except ImportError:
            raise DependencyNotInstalled(
                "pygame is not installed, run `pip install gymnasium[toy-text]`"
            )

        cell_size = (60, 60)
        window_size = (
            4 * cell_size[0],
            12 * cell_size[1],
        )

        pygame.init()
        screen = pygame.Surface((window_size[1], window_size[0]))

        shape = (4, 12)
        nS = 4 * 12
        # Cliff Location
        cliff = np.zeros(shape, dtype=bool)
        cliff[3, 1:-1] = True

        hikers = [
            path.join(path.dirname(__file__), "../toy_text/img/elf_up.png"),
            path.join(path.dirname(__file__), "../toy_text/img/elf_right.png"),
            path.join(path.dirname(__file__), "../toy_text/img/elf_down.png"),
            path.join(path.dirname(__file__), "../toy_text/img/elf_left.png"),
        ]

        elf_images = [
            pygame.transform.scale(pygame.image.load(f_name), cell_size)
            for f_name in hikers
        ]
        file_name = path.join(path.dirname(__file__), "../toy_text/img/stool.png")
        start_img = pygame.transform.scale(pygame.image.load(file_name), cell_size)
        file_name = path.join(path.dirname(__file__), "../toy_text/img/cookie.png")
        goal_img = pygame.transform.scale(pygame.image.load(file_name), cell_size)
        bg_imgs = [
            path.join(path.dirname(__file__), "../toy_text/img/mountain_bg1.png"),
            path.join(path.dirname(__file__), "../toy_text/img/mountain_bg2.png"),
        ]
        mountain_bg_img = [
            pygame.transform.scale(pygame.image.load(f_name), cell_size)
            for f_name in bg_imgs
        ]
        near_cliff_imgs = [
            path.join(
                path.dirname(__file__), "../toy_text/img/mountain_near-cliff1.png"
            ),
            path.join(
                path.dirname(__file__), "../toy_text/img/mountain_near-cliff2.png"
            ),
        ]
        near_cliff_img = [
            pygame.transform.scale(pygame.image.load(f_name), cell_size)
            for f_name in near_cliff_imgs
        ]
        file_name = path.join(
            path.dirname(__file__), "../toy_text/img/mountain_cliff.png"
        )
        cliff_img = pygame.transform.scale(pygame.image.load(file_name), cell_size)

        return RenderStateType(
            screen=screen,
            shape=shape,
            nS=nS,
            cell_size=cell_size,
            cliff=cliff,
            elf_images=tuple(elf_images),
            start_img=start_img,
            goal_img=goal_img,
            bg_imgs=tuple(bg_imgs),
            mountain_bg_img=tuple(mountain_bg_img),
            near_cliff_imgs=tuple(near_cliff_imgs),
            near_cliff_img=tuple(near_cliff_img),
            cliff_img=cliff_img,
        )
    def render_image(
        self,
        state: StateType,
        render_state: RenderStateType,
    ) -> tuple[RenderStateType, np.ndarray]:
        """Renders an image from a state."""
        try:
            import pygame
        except ImportError:
            raise DependencyNotInstalled(
                "pygame is not installed, run `pip install gymnasium[toy-text]`"
            )
        (
            window_surface,
            shape,
            nS,
            cell_size,
            cliff,
            elf_images,
            start_img,
            goal_img,
            bg_imgs,
            mountain_bg_img,
            near_cliff_imgs,
            near_cliff_img,
            cliff_img,
        ) = render_state

        for s in range(nS):
            row, col = np.unravel_index(s, shape)
            pos = (col * cell_size[0], row * cell_size[1])
            checkerboard_mask = row % 2 ^ col % 2
            window_surface.blit(mountain_bg_img[checkerboard_mask], pos)

            if cliff[row, col]:
                window_surface.blit(cliff_img, pos)
            if row < shape[0] - 1 and cliff[row + 1, col]:
                window_surface.blit(near_cliff_img[checkerboard_mask], pos)
            if s == 36:
                window_surface.blit(start_img, pos)
            if s == nS - 1:
                window_surface.blit(goal_img, pos)
            if s == state.player_position[0] * 12 + state.player_position[1]:
                elf_pos = (pos[0], pos[1] - 0.1 * cell_size[1])
                last_action = state.last_action if state.last_action != -1 else 2
                window_surface.blit(elf_images[last_action], elf_pos)

        return render_state, np.transpose(
            np.array(pygame.surfarray.pixels3d(window_surface)), axes=(1, 0, 2)
        )
    def render_close(self, render_state: RenderStateType) -> None:
        """Closes the render state."""
        try:
            import pygame
        except ImportError as e:
            raise DependencyNotInstalled(
                "pygame is not installed, run `pip install gymnasium[toy-text]`"
            ) from e
        pygame.display.quit()
        pygame.quit()

class CliffWalkingJaxEnv(FunctionalJaxEnv, EzPickle):
    """A Gymnasium Env wrapper for the functional cliffwalking env."""

    metadata = {"render_modes": ["rgb_array"], "render_fps": 50}

    def __init__(self, render_mode: str | None = None, **kwargs):
        """Initializes Gym wrapper for cliffwalking functional env."""
        EzPickle.__init__(self, render_mode=render_mode, **kwargs)
        env = CliffWalkingFunctional(**kwargs)
        env.transform(jax.jit)

        super().__init__(
            env,
            metadata=self.metadata,
            render_mode=render_mode,
        )

if __name__ == "__main__":
    """Temporary environment tester function."""

    env = HumanRendering(CliffWalkingJaxEnv(render_mode="rgb_array"))

    obs, info = env.reset()
    print(obs, info)

    terminal = False
    while not terminal:
        action = int(input("Please input an action\n"))
        obs, reward, terminal, truncated, info = env.step(action)
        print(obs, reward, terminal, truncated, info)

    exit()
tests/experimental/functional/test_jax_cliffwalking.py (new file, 125 lines)
@@ -0,0 +1,125 @@
"""Tests for Jax cliffwalking functional env."""


import jax
import jax.numpy as jnp
import jax.random as jrng
import pytest

from gymnasium.envs.tabular.cliffwalking import CliffWalkingFunctional


def test_normal_CliffWalkingFunctional():
    """Tests to ensure that cliffwalking env step and reset functions return the correct types."""
    env = CliffWalkingFunctional()
    rng = jrng.PRNGKey(0)

    split_rng, rng = jrng.split(rng)

    state = env.initial(split_rng)
    env.action_space.seed(0)

    for t in range(10):
        obs = env.observation(state)
        action = env.action_space.sample()

        split_rng, rng = jrng.split(rng)

        next_state = env.transition(state, action, split_rng)
        reward = env.reward(state, action, next_state)
        terminal = env.terminal(next_state)

        assert len(state) == len(next_state)
        try:
            float(reward)
        except ValueError:
            pytest.fail("Reward is not castable to float")
        try:
            bool(terminal)
        except ValueError:
            pytest.fail("Terminal is not castable to bool")

        assert next_state[0].dtype == jnp.int32
        assert next_state[1].dtype == jnp.int32
        assert next_state[2].dtype == bool

        assert rng.dtype == jnp.uint32
        assert obs.dtype == jnp.int32

        state = next_state

def test_jit_CliffWalkingFunctional():
    """Tests the Jax CliffWalkingFunctional env, but in a jitted context."""
    env = CliffWalkingFunctional()
    rng = jrng.PRNGKey(0)
    env.transform(jax.jit)

    split_rng, rng = jrng.split(rng)

    state = env.initial(split_rng)
    env.action_space.seed(0)

    for t in range(10):
        obs = env.observation(state)
        action = env.action_space.sample()
        split_rng, rng = jrng.split(rng)
        next_state = env.transition(state, action, split_rng)
        reward = env.reward(state, action, next_state)
        terminal = env.terminal(next_state)

        assert len(state) == len(next_state)
        try:
            float(reward)
        except ValueError:
            pytest.fail("Reward is not castable to float")
        try:
            bool(terminal)
        except ValueError:
            pytest.fail("Terminal is not castable to bool")

        assert next_state[0].dtype == jnp.int32
        assert next_state[1].dtype == jnp.int32
        assert next_state[2].dtype == bool

        assert rng.dtype == jnp.uint32
        assert obs.dtype == jnp.int32

        state = next_state

def test_vmap_CliffWalkingFunctional():
    """Tests the Jax CliffWalking env with vmap."""
    env = CliffWalkingFunctional()
    num_envs = 10
    rng, *split_rng = jrng.split(
        jrng.PRNGKey(0), num_envs + 1
    )  # num_envs + 1: we need num_envs subkeys plus a main entropy-source key

    env.transform(jax.vmap)
    env.transform(jax.jit)
    state = env.initial(jnp.array(split_rng))
    env.action_space.seed(0)

    for t in range(10):
        obs = env.observation(state)
        action = jnp.array([env.action_space.sample() for _ in range(num_envs)])
        rng, *split_rng = jrng.split(rng, num_envs + 1)
        next_state = env.transition(state, action, jnp.array(split_rng))
        terminal = env.terminal(next_state)
        reward = env.reward(state, action, next_state)

        assert len(next_state) == len(state)
        assert reward.shape == (num_envs,)
        assert reward.dtype == jnp.float32
        assert terminal.shape == (num_envs,)
        assert terminal.dtype == bool
        assert isinstance(obs, jnp.ndarray)
        assert obs[0].dtype == jnp.int32
        assert obs[1].dtype == jnp.int32
        assert obs[2].dtype == jnp.int32

        state = next_state
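The suite can be run directly with pytest from a repository checkout (assuming `jax` and `pytest` are installed):

```
pytest tests/experimental/functional/test_jax_cliffwalking.py
```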