From 51a735e752d41f190e737e1292d3f8483d69baf9 Mon Sep 17 00:00:00 2001
From: John Balis
Date: Sun, 9 Apr 2023 05:26:37 -0500
Subject: [PATCH] Jax Cliffwalking Env (#407)

---
 gymnasium/envs/__init__.py                 |   5 +
 gymnasium/envs/tabular/__init__.py         |   1 +
 gymnasium/envs/tabular/cliffwalking.py     | 386 ++++++++++++++++++
 .../functional/test_jax_cliffwalking.py    | 125 ++++++
 4 files changed, 517 insertions(+)
 create mode 100644 gymnasium/envs/tabular/cliffwalking.py
 create mode 100644 tests/experimental/functional/test_jax_cliffwalking.py

diff --git a/gymnasium/envs/__init__.py b/gymnasium/envs/__init__.py
index 4204eb7fd..eb1bf71f1 100644
--- a/gymnasium/envs/__init__.py
+++ b/gymnasium/envs/__init__.py
@@ -171,6 +171,11 @@ register(
     kwargs={"sutton_and_barto": True, "natural": False},
 )
 
+register(
+    id="tabular/CliffWalking-v0",
+    entry_point="gymnasium.envs.tabular.cliffwalking:CliffWalkingJaxEnv",
+)
+
 # Mujoco
 # ----------------------------------------
 
diff --git a/gymnasium/envs/tabular/__init__.py b/gymnasium/envs/tabular/__init__.py
index e75f64364..e15a61c04 100644
--- a/gymnasium/envs/tabular/__init__.py
+++ b/gymnasium/envs/tabular/__init__.py
@@ -1,3 +1,4 @@
 """Provides Tabular JAX FuncEnv implementations."""
 
 from gymnasium.envs.tabular.blackjack import BlackJackJaxEnv
+from gymnasium.envs.tabular.cliffwalking import CliffWalkingJaxEnv
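For orientation, a minimal usage sketch of the environment registered above, assuming the `tabular/CliffWalking-v0` id from the `register` call; the random-action loop is illustrative only:

```python
import gymnasium as gym

# Create the JAX-backed cliff walking environment registered above.
env = gym.make("tabular/CliffWalking-v0")
obs, info = env.reset(seed=0)

terminated = False
while not terminated:
    # Actions: 0 = up, 1 = right, 2 = down, 3 = left.
    action = env.action_space.sample()
    obs, reward, terminated, truncated, info = env.step(action)
env.close()
```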
diff --git a/gymnasium/envs/tabular/cliffwalking.py b/gymnasium/envs/tabular/cliffwalking.py
new file mode 100644
index 000000000..ffde1b5e4
--- /dev/null
+++ b/gymnasium/envs/tabular/cliffwalking.py
@@ -0,0 +1,386 @@
+"""This module provides a CliffWalking functional environment and Gymnasium environment wrapper CliffWalkingJaxEnv."""
+
+
+from __future__ import annotations
+
+from os import path
+from typing import TYPE_CHECKING, NamedTuple
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+from jax.random import PRNGKey
+
+from gymnasium import spaces
+from gymnasium.error import DependencyNotInstalled
+from gymnasium.experimental.functional import ActType, FuncEnv, StateType
+from gymnasium.experimental.functional_jax_env import FunctionalJaxEnv
+from gymnasium.utils import EzPickle
+from gymnasium.wrappers import HumanRendering
+
+
+if TYPE_CHECKING:
+    import numpy
+    import pygame
+
+
+class RenderStateType(NamedTuple):
+    """A named tuple which contains the full render state of the Cliffwalking Env.
+
+    This is static during the episode.
+    """
+
+    screen: pygame.Surface
+    shape: tuple[int, int]
+    nS: int
+    cell_size: tuple[int, int]
+    cliff: numpy.ndarray
+    elf_images: tuple[pygame.Surface, pygame.Surface, pygame.Surface, pygame.Surface]
+    start_img: pygame.Surface
+    goal_img: pygame.Surface
+    bg_imgs: tuple[str, str]
+    mountain_bg_img: tuple[pygame.Surface, pygame.Surface]
+    near_cliff_imgs: tuple[str, str]
+    near_cliff_img: tuple[pygame.Surface, pygame.Surface]
+    cliff_img: pygame.Surface
+
+
+class EnvState(NamedTuple):
+    """A named tuple which contains the full state of the Cliffwalking game."""
+
+    player_position: jnp.ndarray
+    last_action: int
+    fallen: bool
+
+
+def fell_off(player_position):
+    """Checks whether the player_position means the player has fallen off the cliff."""
+    return (
+        (player_position[0] == 3)
+        * (player_position[1] >= 1)
+        * (player_position[1] <= 10)
+    )
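+
+# Illustrative check of the boolean arithmetic above (multiplication stands in
+# for `and` so that fell_off stays JAX-traceable):
+#   fell_off(jnp.array([3, 5]))  -> truthy  (a cliff cell)
+#   fell_off(jnp.array([2, 5]))  -> falsy   (the row above the cliff)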
+
+
+class CliffWalkingFunctional(
+    FuncEnv[jnp.ndarray, jnp.ndarray, int, float, bool, RenderStateType]
+):
+    """Cliff walking involves crossing a gridworld from start to goal while avoiding falling off a cliff.
+
+    ## Description
+    The game starts with the player at location [3, 0] of the 4x12 grid world with the
+    goal located at [3, 11]. If the player reaches the goal the episode ends.
+
+    A cliff runs along [3, 1..10]. If the player moves to a cliff location it
+    returns to the start location.
+
+    The player makes moves until they reach the goal.
+
+    Adapted from Example 6.6 (page 132) from Reinforcement Learning: An Introduction
+    by Sutton and Barto [1].
+
+    With inspiration from:
+    [https://github.com/dennybritz/reinforcement-learning/blob/master/lib/envs/cliff_walking.py](https://github.com/dennybritz/reinforcement-learning/blob/master/lib/envs/cliff_walking.py)
+
+    ## Action Space
+    The action shape is `(1,)` in the range `{0, 3}` indicating
+    which direction to move the player.
+
+    - 0: Move up
+    - 1: Move right
+    - 2: Move down
+    - 3: Move left
+
+    ## Observation Space
+    There are 3 x 12 + 1 = 37 possible states. The player cannot be at the cliff, nor at
+    the goal, as the latter results in the end of the episode. What remains are all
+    the positions of the first 3 rows plus the bottom-left cell.
+
+    The observation is a value representing the player's current position as
+    current_row * ncols + current_col (where both the row and col start at 0).
+
+    For example, the starting position can be calculated as follows: 3 * 12 + 0 = 36.
+
+    The observation is returned as a `numpy.ndarray` with shape `(1,)` and dtype `numpy.int32`.
+
+    ## Starting State
+    The episode starts with the player in state `[36]` (location [3, 0]).
+
+    ## Reward
+    Each time step incurs -1 reward, unless the player stepped into the cliff,
+    which incurs -100 reward.
+
+    ## Episode End
+    The episode terminates when the player enters state `[47]` (location [3, 11]).
+
+
+    ## Arguments
+
+    ```python
+    import gymnasium as gym
+    gym.make('tabular/CliffWalking-v0')
+    ```
+
+    ## References
+    [1] R. Sutton and A. Barto, "Reinforcement Learning:
+    An Introduction" 2020. [Online].
+    Available: [http://www.incompleteideas.net/book/RLbook2020.pdf](http://www.incompleteideas.net/book/RLbook2020.pdf)
+
+    ## Version History
+    - v0: Initial version release
+
+    """
+
+    action_space = spaces.Box(low=0, high=3, dtype=np.int32)  # 4 directions
+    observation_space = spaces.Box(
+        low=0, high=(12 * 4) - 1, shape=(1,), dtype=np.int32
+    )  # A discrete state corresponds to each possible location
+
+    metadata = {
+        "render_modes": ["rgb_array"],
+        "render_fps": 4,
+    }
+
+    def transition(self, state: EnvState, action: int | jnp.ndarray, key: PRNGKey):
+        """The Cliffwalking environment's state transition function."""
+        new_position = state.player_position
+
+        # where is the agent trying to go?
+        new_position = jnp.array(
+            [
+                new_position[0] + (1 * (action == 2)) + (-1 * (action == 0)),
+                new_position[1] + (1 * (action == 1)) + (-1 * (action == 3)),
+            ]
+        )
+
+        # prevent out of bounds
+        new_position = jnp.array(
+            [
+                jnp.maximum(jnp.minimum(new_position[0], 3), 0),
+                jnp.maximum(jnp.minimum(new_position[1], 11), 0),
+            ]
+        )
+
+        # if we fell off, we have to start over from scratch from (3,0)
+        fallen = fell_off(new_position)
+        new_position = jnp.array(
+            [
+                new_position[0] * (1 - fallen) + 3 * fallen,
+                new_position[1] * (1 - fallen),
+            ]
+        )
+        new_state = EnvState(
+            player_position=new_position.reshape((2,)),
+            last_action=action[0],
+            fallen=fallen,
+        )
+
+        return new_state
+
+    def initial(self, rng: PRNGKey):
+        """Cliffwalking initial state function."""
+        player_position = jnp.array([3, 0])
+
+        state = EnvState(player_position=player_position, last_action=-1, fallen=False)
+
+        return state
+
+    def observation(self, state: EnvState) -> jnp.ndarray:
+        """Cliffwalking observation: the flattened grid index of the player."""
+        return jnp.array(
+            state.player_position[0] * 12 + state.player_position[1]
+        ).reshape((1,))
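+
+    # Worked example of the encoding above: the start cell [3, 0] maps to
+    # 3 * 12 + 0 = 36 and the goal cell [3, 11] maps to 3 * 12 + 11 = 47.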
+
+    def terminal(self, state: EnvState) -> jnp.ndarray:
+        """Determines if a particular Cliffwalking state is terminal."""
+        return jnp.array_equal(state.player_position, jnp.array([3, 11]))
+
+    def reward(
+        self, state: EnvState, action: ActType, next_state: StateType
+    ) -> jnp.ndarray:
+        """Calculates reward from a state."""
+        state = next_state  # the reward depends only on the resulting state
+        reward = -1 + (-99 * state.fallen[0])
+        return jax.lax.convert_element_type(reward, jnp.float32)
+
+    def render_init(
+        self, screen_width: int = 600, screen_height: int = 500
+    ) -> RenderStateType:
+        """Returns an initial render state."""
+        try:
+            import pygame
+        except ImportError:
+            raise DependencyNotInstalled(
+                "pygame is not installed, run `pip install gymnasium[toy-text]`"
+            )
+
+        cell_size = (60, 60)
+        window_size = (
+            4 * cell_size[0],
+            12 * cell_size[1],
+        )
+
+        pygame.init()
+        screen = pygame.Surface((window_size[1], window_size[0]))
+
+        shape = (4, 12)
+        nS = 4 * 12
+        # Cliff Location
+        cliff = np.zeros(shape, dtype=bool)
+        cliff[3, 1:-1] = True
+
+        hikers = [
+            path.join(path.dirname(__file__), "../toy_text/img/elf_up.png"),
+            path.join(path.dirname(__file__), "../toy_text/img/elf_right.png"),
+            path.join(path.dirname(__file__), "../toy_text/img/elf_down.png"),
+            path.join(path.dirname(__file__), "../toy_text/img/elf_left.png"),
+        ]
+
+        elf_images = [
+            pygame.transform.scale(pygame.image.load(f_name), cell_size)
+            for f_name in hikers
+        ]
+        file_name = path.join(path.dirname(__file__), "../toy_text/img/stool.png")
+        start_img = pygame.transform.scale(pygame.image.load(file_name), cell_size)
+        file_name = path.join(path.dirname(__file__), "../toy_text/img/cookie.png")
+        goal_img = pygame.transform.scale(pygame.image.load(file_name), cell_size)
+        bg_imgs = [
+            path.join(path.dirname(__file__), "../toy_text/img/mountain_bg1.png"),
+            path.join(path.dirname(__file__), "../toy_text/img/mountain_bg2.png"),
+        ]
+        mountain_bg_img = [
+            pygame.transform.scale(pygame.image.load(f_name), cell_size)
+            for f_name in bg_imgs
+        ]
+        near_cliff_imgs = [
+            path.join(
+                path.dirname(__file__), "../toy_text/img/mountain_near-cliff1.png"
+            ),
+            path.join(
+                path.dirname(__file__), "../toy_text/img/mountain_near-cliff2.png"
+            ),
+        ]
+        near_cliff_img = [
+            pygame.transform.scale(pygame.image.load(f_name), cell_size)
+            for f_name in near_cliff_imgs
+        ]
+        file_name = path.join(
+            path.dirname(__file__), "../toy_text/img/mountain_cliff.png"
+        )
+        cliff_img = pygame.transform.scale(pygame.image.load(file_name), cell_size)
+
+        return RenderStateType(
+            screen=screen,
+            shape=shape,
+            nS=nS,
+            cell_size=cell_size,
+            cliff=cliff,
+            elf_images=tuple(elf_images),
+            start_img=start_img,
+            goal_img=goal_img,
+            bg_imgs=tuple(bg_imgs),
+            mountain_bg_img=tuple(mountain_bg_img),
+            near_cliff_imgs=tuple(near_cliff_imgs),
+            near_cliff_img=tuple(near_cliff_img),
+            cliff_img=cliff_img,
+        )
+
+    def render_image(
+        self,
+        state: StateType,
+        render_state: RenderStateType,
+    ) -> tuple[RenderStateType, np.ndarray]:
+        """Renders an image from a state."""
+        try:
+            import pygame
+        except ImportError:
+            raise DependencyNotInstalled(
+                "pygame is not installed, run `pip install gymnasium[toy-text]`"
+            )
+        (
+            window_surface,
+            shape,
+            nS,
+            cell_size,
+            cliff,
+            elf_images,
+            start_img,
+            goal_img,
+            bg_imgs,
+            mountain_bg_img,
+            near_cliff_imgs,
+            near_cliff_img,
+            cliff_img,
+        ) = render_state
+
+        for s in range(nS):
+            row, col = np.unravel_index(s, shape)
+            pos = (col * cell_size[0], row * cell_size[1])
+            check_board_mask = row % 2 ^ col % 2
+            window_surface.blit(mountain_bg_img[check_board_mask], pos)
+
+            if cliff[row, col]:
+                window_surface.blit(cliff_img, pos)
+            if row < shape[0] - 1 and cliff[row + 1, col]:
+                window_surface.blit(near_cliff_img[check_board_mask], pos)
+            if s == 36:
+                window_surface.blit(start_img, pos)
+            if s == nS - 1:
+                window_surface.blit(goal_img, pos)
+            if s == state.player_position[0] * 12 + state.player_position[1]:
+                elf_pos = (pos[0], pos[1] - 0.1 * cell_size[1])
+                last_action = state.last_action if state.last_action != -1 else 2
+                window_surface.blit(elf_images[last_action], elf_pos)
+
+        return render_state, np.transpose(
+            np.array(pygame.surfarray.pixels3d(window_surface)), axes=(1, 0, 2)
+        )
+
+    def render_close(self, render_state: RenderStateType) -> None:
+        """Closes the render state."""
+        try:
+            import pygame
+        except ImportError as e:
+            raise DependencyNotInstalled(
+                "pygame is not installed, run `pip install gymnasium[toy-text]`"
+            ) from e
+        pygame.display.quit()
+        pygame.quit()
+
+
+class CliffWalkingJaxEnv(FunctionalJaxEnv, EzPickle):
+    """A Gymnasium Env wrapper for the functional cliffwalking env."""
+
+    metadata = {"render_modes": ["rgb_array"], "render_fps": 50}
+
+    def __init__(self, render_mode: str | None = None, **kwargs):
+        """Initializes Gym wrapper for cliffwalking functional env."""
+        EzPickle.__init__(self, render_mode=render_mode, **kwargs)
+        env = CliffWalkingFunctional(**kwargs)
+        env.transform(jax.jit)
+
+        super().__init__(
+            env,
+            metadata=self.metadata,
+            render_mode=render_mode,
+        )
+
+
+if __name__ == "__main__":
+    """Temporary environment tester function."""
+
+    env = HumanRendering(CliffWalkingJaxEnv(render_mode="rgb_array"))
+
+    obs, info = env.reset()
+    print(obs, info)
+
+    terminal = False
+    while not terminal:
+        # wrap the input in a shape-(1,) array to match the Box action space
+        action = jnp.array([int(input("Please input an action\n"))])
+        obs, reward, terminal, truncated, info = env.step(action)
+        print(obs, reward, terminal, truncated, info)
+
+    exit()
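As background for the tests below, a minimal sketch of driving `CliffWalkingFunctional` directly through its pure functions, assuming a fixed action and an arbitrary seed:

```python
import jax
import jax.numpy as jnp
import jax.random as jrng

from gymnasium.envs.tabular.cliffwalking import CliffWalkingFunctional

env = CliffWalkingFunctional()
env.transform(jax.jit)  # jit-compile transition/observation/terminal/reward

rng = jrng.PRNGKey(0)
split_rng, rng = jrng.split(rng)
state = env.initial(split_rng)

for _ in range(5):
    action = jnp.array([1])  # move right; shape (1,) matches the Box action space
    split_rng, rng = jrng.split(rng)
    next_state = env.transition(state, action, split_rng)
    reward = env.reward(state, action, next_state)
    done = env.terminal(next_state)
    state = next_state
```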
+ """ + + env = HumanRendering(CliffWalkingJaxEnv(render_mode="rgb_array")) + + obs, info = env.reset() + print(obs, info) + + terminal = False + while not terminal: + action = int(input("Please input an action\n")) + obs, reward, terminal, truncated, info = env.step(action) + print(obs, reward, terminal, truncated, info) + + exit() diff --git a/tests/experimental/functional/test_jax_cliffwalking.py b/tests/experimental/functional/test_jax_cliffwalking.py new file mode 100644 index 000000000..029a4a76f --- /dev/null +++ b/tests/experimental/functional/test_jax_cliffwalking.py @@ -0,0 +1,125 @@ +"""Tests for Jax cliffwalking functional env.""" + + +import jax +import jax.numpy as jnp +import jax.random as jrng +import pytest + +from gymnasium.envs.tabular.cliffwalking import CliffWalkingFunctional + + +def test_normal_CliffWalkingFunctional(): + """Tests to ensure that cliffwalking env step and reset functions return the correct types.""" + env = CliffWalkingFunctional() + rng = jrng.PRNGKey(0) + + split_rng, rng = jrng.split(rng) + + state = env.initial(split_rng) + env.action_space.seed(0) + + for t in range(10): + obs = env.observation(state) + action = env.action_space.sample() + + split_rng, rng = jrng.split(rng) + + next_state = env.transition(state, action, split_rng) + reward = env.reward(state, action, next_state) + terminal = env.terminal(next_state) + + assert len(state) == len(next_state) + try: + float(reward) + except ValueError: + pytest.fail("Reward is not castable to float") + try: + bool(terminal) + except ValueError: + pytest.fail("Terminal is not castable to bool") + + assert next_state[0].dtype == jnp.int32 + assert next_state[1].dtype == jnp.int32 + assert next_state[2].dtype == bool + + assert rng.dtype == jnp.uint32 + assert obs.dtype == jnp.int32 + + state = next_state + + +def test_jit_CliffWalkingFunctional(): + """Tests the Jax CliffWalkingFunctional env, but in a jitted context.""" + env = CliffWalkingFunctional() + rng = jrng.PRNGKey(0) + env.transform(jax.jit) + + split_rng, rng = jrng.split(rng) + + state = env.initial(split_rng) + env.action_space.seed(0) + + for t in range(10): + obs = env.observation(state) + action = env.action_space.sample() + split_rng, rng = jrng.split(rng) + next_state = env.transition(state, action, split_rng) + reward = env.reward(state, action, next_state) + terminal = env.terminal(next_state) + + assert len(state) == len(next_state) + try: + float(reward) + except ValueError: + pytest.fail("Reward is not castable to float") + try: + bool(terminal) + except ValueError: + pytest.fail("Terminal is not castable to bool") + + assert next_state[0].dtype == jnp.int32 + assert next_state[1].dtype == jnp.int32 + assert next_state[2].dtype == bool + + assert rng.dtype == jnp.uint32 + assert obs.dtype == jnp.int32 + + state = next_state + + +def test_vmap_BlackJack(): + """Tests the Jax CliffWalking env with vmap.""" + env = CliffWalkingFunctional() + num_envs = 10 + rng, *split_rng = jrng.split( + jrng.PRNGKey(0), num_envs + 1 + ) # this plus 1 is important because we want + # num_envs subkeys and a main entropy source key which necessitates an additional key + + env.transform(jax.vmap) + env.transform(jax.jit) + state = env.initial(jnp.array(split_rng)) + env.action_space.seed(0) + + for t in range(10): + obs = env.observation(state) + action = jnp.array([env.action_space.sample() for _ in range(num_envs)]) + # if isinstance(env.action_space, Discrete): + # action = action.reshape((num_envs, 1)) + rng, *split_rng = jrng.split(rng, num_envs 
+
+
+def test_vmap_CliffWalking():
+    """Tests the Jax CliffWalking env with vmap."""
+    env = CliffWalkingFunctional()
+    num_envs = 10
+    rng, *split_rng = jrng.split(
+        jrng.PRNGKey(0), num_envs + 1
+    )  # the plus 1 is important: we want num_envs subkeys plus a main
+    # entropy source key, which necessitates an additional key
+
+    env.transform(jax.vmap)
+    env.transform(jax.jit)
+    state = env.initial(jnp.array(split_rng))
+    env.action_space.seed(0)
+
+    for t in range(10):
+        obs = env.observation(state)
+        action = jnp.array([env.action_space.sample() for _ in range(num_envs)])
+        rng, *split_rng = jrng.split(rng, num_envs + 1)
+        next_state = env.transition(state, action, jnp.array(split_rng))
+        terminal = env.terminal(next_state)
+        reward = env.reward(state, action, next_state)
+
+        assert len(next_state) == len(state)
+        assert reward.shape == (num_envs,)
+        assert reward.dtype == jnp.float32
+        assert terminal.shape == (num_envs,)
+        assert terminal.dtype == bool
+        assert isinstance(obs, jnp.ndarray)
+        assert obs.shape == (num_envs, 1)
+        assert obs.dtype == jnp.int32
+
+        state = next_state
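A side note on the key handling in `test_vmap_CliffWalking`: `jax.random.split(key, num_envs + 1)` returns `num_envs + 1` subkeys, so one can be kept as the ongoing entropy source while the remaining `num_envs` seed the batched environments. A minimal sketch of that pattern:

```python
import jax.numpy as jnp
import jax.random as jrng

num_envs = 10
rng = jrng.PRNGKey(0)

# Split off one extra key: `rng` remains the main entropy source,
# `split_rng` holds one independent key per environment.
rng, *split_rng = jrng.split(rng, num_envs + 1)
batched_keys = jnp.array(split_rng)  # shape (num_envs, 2), dtype uint32
assert batched_keys.shape == (num_envs, 2)
```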