# Source file: gym/envs/toy_text/cliffwalking.py (Gymnasium / OpenAI Gym toy-text environments)

import sys
from contextlib import closing
from io import StringIO
from typing import Optional

import numpy as np

from gym import Env, spaces
from gym.envs.toy_text.utils import categorical_sample

# Discrete actions for the 4x12 grid (NumPy matrix indexing: row 0 is the top).
UP = 0
RIGHT = 1
DOWN = 2
LEFT = 3
2021-12-22 19:25:36 +01:00
class CliffWalkingEnv(Env):
    """
    This is a simple implementation of the Gridworld Cliff
    reinforcement learning task.

    Adapted from Example 6.6 (page 106) from Reinforcement Learning: An Introduction
    by Sutton and Barto:
    http://incompleteideas.net/book/bookdraft2018jan1.pdf

    With inspiration from:
    https://github.com/dennybritz/reinforcement-learning/blob/master/lib/envs/cliff_walking.py

    The board is a 4x12 matrix, with (using NumPy matrix indexing):
        [3, 0] as the start at bottom-left
        [3, 11] as the goal at bottom-right
        [3, 1..10] as the cliff at bottom-center

    Each time step incurs -1 reward, and stepping into the cliff incurs -100 reward
    and a reset to the start. An episode terminates when the agent reaches the goal.
    """

    metadata = {"render_modes": ["human", "ansi"], "render_fps": 4}

    def __init__(self):
        self.shape = (4, 12)
        # Flat index of the fixed start cell (bottom-left corner).
        self.start_state_index = np.ravel_multi_index((3, 0), self.shape)

        self.nS = np.prod(self.shape)
        self.nA = 4

        # Cliff location: the bottom row between start and goal.
        # NOTE: the `np.bool` alias was deprecated in NumPy 1.20 and removed in
        # 1.24; the builtin `bool` is the correct dtype argument.
        self._cliff = np.zeros(self.shape, dtype=bool)
        self._cliff[3, 1:-1] = True

        # Calculate transition probabilities and rewards.
        # P[s][a] is a list of (probability, next_state, reward, done) tuples;
        # the dynamics here are deterministic, so each list has one entry.
        self.P = {}
        for s in range(self.nS):
            position = np.unravel_index(s, self.shape)
            self.P[s] = {a: [] for a in range(self.nA)}
            self.P[s][UP] = self._calculate_transition_prob(position, [-1, 0])
            self.P[s][RIGHT] = self._calculate_transition_prob(position, [0, 1])
            self.P[s][DOWN] = self._calculate_transition_prob(position, [1, 0])
            self.P[s][LEFT] = self._calculate_transition_prob(position, [0, -1])

        # Calculate initial state distribution.
        # We always start in state (3, 0).
        self.initial_state_distrib = np.zeros(self.nS)
        self.initial_state_distrib[self.start_state_index] = 1.0

        self.observation_space = spaces.Discrete(self.nS)
        self.action_space = spaces.Discrete(self.nA)

    def _limit_coordinates(self, coord):
        """
        Prevent the agent from falling out of the grid world.

        :param coord: mutable (row, col) array-like; clamped in place
        :return: the clamped coordinate
        """
        coord[0] = min(coord[0], self.shape[0] - 1)
        coord[0] = max(coord[0], 0)
        coord[1] = min(coord[1], self.shape[1] - 1)
        coord[1] = max(coord[1], 0)
        return coord

    def _calculate_transition_prob(self, current, delta):
        """
        Determine the outcome for an action. Transition Prob is always 1.0.

        :param current: Current position on the grid as (row, col)
        :param delta: Change in position for transition
        :return: [(1.0, new_state, reward, done)]
        """
        new_position = np.array(current) + np.array(delta)
        new_position = self._limit_coordinates(new_position).astype(int)
        new_state = np.ravel_multi_index(tuple(new_position), self.shape)
        # Stepping into the cliff teleports back to the start with -100 reward
        # and does NOT terminate the episode.
        if self._cliff[tuple(new_position)]:
            return [(1.0, self.start_state_index, -100, False)]

        terminal_state = (self.shape[0] - 1, self.shape[1] - 1)
        is_done = tuple(new_position) == terminal_state
        return [(1.0, new_state, -1, is_done)]

    def step(self, a):
        """Take action ``a``; return ``(observation, reward, done, info)``."""
        transitions = self.P[self.s][a]
        # Deterministic env: the sample always picks the single transition.
        i = categorical_sample([t[0] for t in transitions], self.np_random)
        p, s, r, d = transitions[i]
        self.s = s
        self.lastaction = a
        return (int(s), r, d, {"prob": p})

    def reset(
        self,
        *,
        seed: Optional[int] = None,
        return_info: bool = False,
        options: Optional[dict] = None
    ):
        """Reset to the start state; optionally return an info dict."""
        super().reset(seed=seed)
        self.s = categorical_sample(self.initial_state_distrib, self.np_random)
        self.lastaction = None
        if not return_info:
            return int(self.s)
        else:
            return int(self.s), {"prob": 1}

    def render(self, mode="human"):
        """
        Render the grid: ``x`` agent, ``T`` goal, ``C`` cliff, ``o`` free cell.

        :param mode: "human" prints to stdout; "ansi" returns the string
        """
        outfile = StringIO() if mode == "ansi" else sys.stdout

        for s in range(self.nS):
            position = np.unravel_index(s, self.shape)
            if self.s == s:
                output = " x "
            # Print terminal state
            elif position == (3, 11):
                output = " T "
            elif self._cliff[position]:
                output = " C "
            else:
                output = " o "

            if position[1] == 0:
                output = output.lstrip()
            if position[1] == self.shape[1] - 1:
                output = output.rstrip()
                output += "\n"

            outfile.write(output)
        outfile.write("\n")

        # No need to return anything for human
        if mode != "human":
            with closing(outfile):
                return outfile.getvalue()