import sys
from contextlib import closing
from io import StringIO
from typing import Optional

import numpy as np

from gym import Env, spaces
from gym.envs.toy_text.utils import categorical_sample
from gym.utils.renderer import Renderer

UP = 0
RIGHT = 1
DOWN = 2
LEFT = 3


class CliffWalkingEnv(Env):
    """
    This is a simple implementation of the Gridworld Cliff
    reinforcement learning task.

    Adapted from Example 6.6 (page 106) from [Reinforcement Learning: An Introduction
    by Sutton and Barto](http://incompleteideas.net/book/bookdraft2018jan1.pdf).

    With inspiration from:
    https://github.com/dennybritz/reinforcement-learning/blob/master/lib/envs/cliff_walking.py

    ### Description
    The board is a 4x12 matrix, with (using NumPy matrix indexing):
    - [3, 0] as the start at bottom-left
    - [3, 11] as the goal at bottom-right
    - [3, 1..10] as the cliff at bottom-center
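
    As rendered by this environment at the start of an episode
    (``x`` agent at the start, ``C`` cliff, ``T`` goal, ``o`` free cell):

    ```
    o  o  o  o  o  o  o  o  o  o  o  o
    o  o  o  o  o  o  o  o  o  o  o  o
    o  o  o  o  o  o  o  o  o  o  o  o
    x  C  C  C  C  C  C  C  C  C  C  T
    ```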

    If the agent steps on the cliff, it returns to the start.
    An episode terminates when the agent reaches the goal.

    ### Actions
    There are 4 discrete deterministic actions:
    - 0: move up
    - 1: move right
    - 2: move down
    - 3: move left

    ### Observations
    There are 3x12 + 1 possible states. The agent cannot be at the cliff, nor at
    the goal (as the latter results in the end of the episode). What remains are
    all the positions of the first 3 rows plus the bottom-left cell.
    The observation is simply the current position encoded as a
    [flattened index](https://numpy.org/doc/stable/reference/generated/numpy.unravel_index.html).
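    For example, the start (3, 0) is encoded as ``3 * 12 + 0 = 36``, and the
    cell just above the goal, (2, 11), as ``2 * 12 + 11 = 35``.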

    ### Reward
    Each time step incurs -1 reward, and stepping into the cliff incurs -100 reward.
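    Since falling into the cliff sends the agent back to the start without ending
    the episode, the best achievable return is -13: the optimal path hugs the
    cliff edge for 13 steps at -1 each.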

    ### Arguments

    ```
    gym.make('CliffWalking-v0')
    ```

    ### Version History
    - v0: Initial version release
    """

    metadata = {"render_modes": ["human", "ansi"], "render_fps": 4}

    def __init__(self, render_mode: Optional[str] = None):
        self.shape = (4, 12)
        self.start_state_index = np.ravel_multi_index((3, 0), self.shape)

        self.nS = np.prod(self.shape)
        self.nA = 4

        # Cliff Location
        self._cliff = np.zeros(self.shape, dtype=bool)
        self._cliff[3, 1:-1] = True

        # Calculate transition probabilities and rewards
        self.P = {}
        for s in range(self.nS):
            position = np.unravel_index(s, self.shape)
            self.P[s] = {a: [] for a in range(self.nA)}
            self.P[s][UP] = self._calculate_transition_prob(position, [-1, 0])
            self.P[s][RIGHT] = self._calculate_transition_prob(position, [0, 1])
            self.P[s][DOWN] = self._calculate_transition_prob(position, [1, 0])
            self.P[s][LEFT] = self._calculate_transition_prob(position, [0, -1])
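        # For example, P[36][RIGHT] (moving right from the start at (3, 0)) is
        # [(1.0, 36, -100, False)]: the agent falls into the cliff at (3, 1)
        # and is sent back to state 36 with a -100 penalty.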

        # Calculate initial state distribution
        # We always start in state (3, 0)
        self.initial_state_distrib = np.zeros(self.nS)
        self.initial_state_distrib[self.start_state_index] = 1.0

        self.observation_space = spaces.Discrete(self.nS)
        self.action_space = spaces.Discrete(self.nA)

        self.render_mode = render_mode
        self.renderer = Renderer(self.render_mode, self._render)

    def _limit_coordinates(self, coord: np.ndarray) -> np.ndarray:
        """Prevent the agent from falling out of the grid world."""
        coord[0] = min(coord[0], self.shape[0] - 1)
        coord[0] = max(coord[0], 0)
        coord[1] = min(coord[1], self.shape[1] - 1)
        coord[1] = max(coord[1], 0)
        return coord

    def _calculate_transition_prob(self, current, delta):
        """Determine the outcome for an action. The transition probability is always 1.0.

        Args:
            current: Current position on the grid as (row, col)
            delta: Change in position for transition

        Returns:
            A list containing one tuple of ``(1.0, new_state, reward, done)``
        """
        new_position = np.array(current) + np.array(delta)
        new_position = self._limit_coordinates(new_position).astype(int)
        new_state = np.ravel_multi_index(tuple(new_position), self.shape)
        if self._cliff[tuple(new_position)]:
            # Stepping into the cliff sends the agent back to the start with a
            # -100 penalty; the episode does not terminate.
            return [(1.0, self.start_state_index, -100, False)]

        terminal_state = (self.shape[0] - 1, self.shape[1] - 1)
        is_done = tuple(new_position) == terminal_state
        return [(1.0, new_state, -1, is_done)]

    def step(self, a):
        transitions = self.P[self.s][a]
        i = categorical_sample([t[0] for t in transitions], self.np_random)
        p, s, r, d = transitions[i]
        self.s = s
        self.lastaction = a
        self.renderer.render_step()
        return (int(s), r, d, {"prob": p})

    def reset(
        self,
        *,
        seed: Optional[int] = None,
        return_info: bool = False,
        options: Optional[dict] = None,
    ):
        super().reset(seed=seed)
        self.s = categorical_sample(self.initial_state_distrib, self.np_random)
        self.lastaction = None
        self.renderer.reset()
        self.renderer.render_step()
        if not return_info:
            return int(self.s)
        else:
            return int(self.s), {"prob": 1}

    def render(self, mode="human"):
        if self.render_mode is not None:
            return self.renderer.get_renders()
        else:
            return self._render(mode)

    def _render(self, mode):
        assert mode in self.metadata["render_modes"]
        outfile = StringIO() if mode == "ansi" else sys.stdout

        for s in range(self.nS):
            position = np.unravel_index(s, self.shape)
            if self.s == s:
                output = " x "
            # Print terminal state
            elif position == (3, 11):
                output = " T "
            elif self._cliff[position]:
                output = " C "
            else:
                output = " o "

            if position[1] == 0:
                output = output.lstrip()
            if position[1] == self.shape[1] - 1:
                output = output.rstrip()
                output += "\n"

            outfile.write(output)
        outfile.write("\n")

        # No need to return anything for human
        if mode != "human":
            with closing(outfile):
                return outfile.getvalue()
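

if __name__ == "__main__":
    # A minimal usage sketch (assumes direct instantiation rather than the
    # registered `gym.make('CliffWalking-v0')`): walk the optimal 13-step
    # path -- one step up, eleven right, one down -- then render the board.
    env = CliffWalkingEnv()
    obs = env.reset()
    total_reward, done = 0, False
    for action in [UP] + [RIGHT] * 11 + [DOWN]:
        obs, reward, done, info = env.step(action)
        total_reward += reward
    env.render()
    print(f"return={total_reward}, done={done}")  # expected: return=-13, done=True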