mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-02 06:16:32 +00:00
* feat: add `isort` to `pre-commit` * ci: skip `__init__.py` file for `isort` * ci: make `isort` mandatory in lint pipeline * docs: add a section on Git hooks * ci: check isort diff * fix: isort from master branch * docs: add pre-commit badge * ci: update black + bandit versions * feat: add PR template * refactor: PR template * ci: remove bandit * docs: add Black badge * ci: try to remove all `|| true` statements * ci: remove lint_python job - Remove `lint_python` CI job - Move `pyupgrade` job to `pre-commit` workflow * fix: avoid messing with typing * docs: add a note on running `pre-commit` manually * ci: apply `pre-commit` to the whole codebase
255 lines
8.7 KiB
Python
255 lines
8.7 KiB
Python
import sys
|
|
from contextlib import closing
|
|
from io import StringIO
|
|
from typing import Optional
|
|
|
|
import numpy as np
|
|
|
|
from gym import Env, spaces, utils
|
|
from gym.envs.toy_text.utils import categorical_sample
|
|
|
|
# ASCII layout of the 5x5 Taxi grid, one string per text row.
# TaxiEnv converts it to a character array:
# - grid cell (r, c) is drawn at text position [1 + r, 2 * c + 1];
# - the character at [1 + r, 2 * c + 2] separates column c from column c + 1
#   (":" means the taxi may pass, "|" is a wall).
# The letters R, G, Y and B mark the four pickup / drop-off locations.
MAP = """\
+---------+
|R: | : :G|
| : | : : |
| : : : : |
| | : | : |
|Y| : |B: |
+---------+""".splitlines()
|
|
|
|
|
|
class TaxiEnv(Env):
    """

    The Taxi Problem
    from "Hierarchical Reinforcement Learning with the MAXQ Value Function Decomposition"
    by Tom Dietterich

    ### Description
    There are four designated locations in the grid world indicated by R(ed),
    G(reen), Y(ellow), and B(lue). When the episode starts, the taxi starts off
    at a random square and the passenger is at a random location. The taxi
    drives to the passenger's location, picks up the passenger, drives to the
    passenger's destination (another one of the four specified locations), and
    then drops off the passenger. Once the passenger is dropped off, the episode ends.

    Map:

        +---------+
        |R: | : :G|
        | : | : : |
        | : : : : |
        | | : | : |
        |Y| : |B: |
        +---------+

    ### Actions
    There are 6 discrete deterministic actions:
    - 0: move south
    - 1: move north
    - 2: move east
    - 3: move west
    - 4: pickup passenger
    - 5: drop off passenger

    ### Observations
    There are 500 discrete states since there are 25 taxi positions, 5 possible
    locations of the passenger (including the case when the passenger is in the
    taxi), and 4 destination locations.

    Note that there are 400 states that can actually be reached during an
    episode. The missing states correspond to situations in which the passenger
    is at the same location as their destination, as this typically signals the
    end of an episode. Four additional states can be observed right after a
    successful episode, when both the passenger and the taxi are at the destination.
    This gives a total of 404 reachable discrete states.

    The state is encoded from the tuple
    (taxi_row, taxi_col, passenger_location, destination).

    Passenger locations:
    - 0: R(ed)
    - 1: G(reen)
    - 2: Y(ellow)
    - 3: B(lue)
    - 4: in taxi

    Destinations:
    - 0: R(ed)
    - 1: G(reen)
    - 2: Y(ellow)
    - 3: B(lue)

    ### Rewards
    - -1 per step unless other reward is triggered.
    - +20 delivering passenger.
    - -10 executing "pickup" and "drop-off" actions illegally.

    Dropping the passenger off at one of the other designated locations (not
    the destination) is legal: the passenger is left there and the step costs -1.

    ### Rendering
    - blue: passenger
    - magenta: destination
    - yellow: empty taxi
    - green: full taxi
    - other letters (R, G, Y and B): locations for passengers and destinations

    Supported render modes are "human" (writes to stdout) and "ansi" (returns
    the rendering as a string).

    ### Arguments

    ```
    gym.make('Taxi-v3')
    ```

    ### Version History
    * v3: Map Correction + Cleaner Domain Description
    * v2: Disallow Taxi start location = goal location, Update Taxi observations in the rollout, Update Taxi reward threshold.
    * v1: Remove (3,2) from locs, add passidx<4 check
    * v0: Initial versions release
    """

    metadata = {"render_modes": ["human", "ansi"], "render_fps": 4}

    def __init__(self):
        """Build the transition table ``self.P`` and the initial-state distribution.

        ``self.P[state][action]`` is a list holding one deterministic
        transition ``(1.0, next_state, reward, done)``.
        """
        # Grid as a 2-D array of single bytes, indexed [text_row, text_col].
        self.desc = np.asarray(MAP, dtype="c")

        # Grid (row, col) coordinates of R, G, Y, B, in that index order.
        self.locs = locs = [(0, 0), (0, 4), (4, 0), (4, 3)]

        num_states = 500
        num_rows = 5
        num_columns = 5
        max_row = num_rows - 1
        max_col = num_columns - 1
        self.initial_state_distrib = np.zeros(num_states)
        num_actions = 6
        self.P = {
            state: {action: [] for action in range(num_actions)}
            for state in range(num_states)
        }
        for row in range(num_rows):
            for col in range(num_columns):
                for pass_idx in range(len(locs) + 1):  # +1 for being inside taxi
                    for dest_idx in range(len(locs)):
                        state = self.encode(row, col, pass_idx, dest_idx)
                        # Episodes start with the passenger waiting somewhere
                        # other than their destination (uniform over such states).
                        if pass_idx < 4 and pass_idx != dest_idx:
                            self.initial_state_distrib[state] += 1
                        for action in range(num_actions):
                            # defaults
                            new_row, new_col, new_pass_idx = row, col, pass_idx
                            # default reward when there is no pickup/dropoff
                            reward = -1
                            done = False
                            taxi_loc = (row, col)

                            if action == 0:
                                new_row = min(row + 1, max_row)
                            elif action == 1:
                                new_row = max(row - 1, 0)
                            # East/west moves are blocked by a "|" wall in the map;
                            # the character right of cell `col` is at text column
                            # 2 * col + 2, the one left of it at 2 * col.
                            if action == 2 and self.desc[1 + row, 2 * col + 2] == b":":
                                new_col = min(col + 1, max_col)
                            elif action == 3 and self.desc[1 + row, 2 * col] == b":":
                                new_col = max(col - 1, 0)
                            elif action == 4:  # pickup
                                if pass_idx < 4 and taxi_loc == locs[pass_idx]:
                                    new_pass_idx = 4
                                else:  # passenger not at location
                                    reward = -10
                            elif action == 5:  # dropoff
                                if (taxi_loc == locs[dest_idx]) and pass_idx == 4:
                                    new_pass_idx = dest_idx
                                    done = True
                                    reward = 20
                                elif (taxi_loc in locs) and pass_idx == 4:
                                    # Legal dropoff at another designated location:
                                    # the passenger waits there, default reward.
                                    new_pass_idx = locs.index(taxi_loc)
                                else:  # dropoff at wrong location
                                    reward = -10
                            new_state = self.encode(
                                new_row, new_col, new_pass_idx, dest_idx
                            )
                            self.P[state][action].append((1.0, new_state, reward, done))
        self.initial_state_distrib /= self.initial_state_distrib.sum()
        self.action_space = spaces.Discrete(num_actions)
        self.observation_space = spaces.Discrete(num_states)

    def encode(self, taxi_row, taxi_col, pass_loc, dest_idx):
        """Pack a state tuple into one integer.

        Mixed-radix encoding with radices (5, 5, 5, 4) for
        (taxi_row, taxi_col, pass_loc, dest_idx); in-range inputs map to 0..499.
        """
        # (5) 5, 5, 4
        i = taxi_row
        i *= 5
        i += taxi_col
        i *= 5
        i += pass_loc
        i *= 4
        i += dest_idx
        return i

    def decode(self, i):
        """Inverse of :meth:`encode`.

        Returns an iterator yielding (taxi_row, taxi_col, pass_loc, dest_idx).
        """
        out = []
        out.append(i % 4)
        i = i // 4
        out.append(i % 5)
        i = i // 5
        out.append(i % 5)
        i = i // 5
        out.append(i)
        assert 0 <= i < 5
        return reversed(out)

    def step(self, a):
        """Apply action ``a``.

        Returns (observation, reward, done, info), where info holds the
        probability of the sampled transition under key "prob".
        """
        transitions = self.P[self.s][a]
        i = categorical_sample([t[0] for t in transitions], self.np_random)
        p, s, r, d = transitions[i]
        self.s = s
        self.lastaction = a
        return (int(s), r, d, {"prob": p})

    def reset(
        self,
        *,
        seed: Optional[int] = None,
        return_info: bool = False,
        options: Optional[dict] = None,
    ):
        """Seed (via ``Env.reset``) and sample a start state from ``initial_state_distrib``.

        Returns the integer state, or ``(state, {"prob": 1})`` when
        ``return_info`` is true. ``options`` is accepted but not used.
        """
        super().reset(seed=seed)
        self.s = categorical_sample(self.initial_state_distrib, self.np_random)
        self.lastaction = None
        if not return_info:
            return int(self.s)
        else:
            return int(self.s), {"prob": 1}

    def render(self, mode="human"):
        """Draw the grid with the taxi, passenger and destination highlighted.

        Args:
            mode: "human" writes to stdout and returns None; "ansi" returns
                the rendering as a string.

        Raises:
            ValueError: if ``mode`` is not listed in ``metadata["render_modes"]``.
        """
        # Previously an unknown mode fell through to sys.stdout, which was then
        # closed by `closing(...)` before `.getvalue()` raised AttributeError.
        if mode not in self.metadata["render_modes"]:
            raise ValueError(
                f"Unsupported render mode {mode!r}; "
                f"expected one of {self.metadata['render_modes']}"
            )
        outfile = StringIO() if mode == "ansi" else sys.stdout

        out = [[c.decode("utf-8") for c in line] for line in self.desc.tolist()]
        taxi_row, taxi_col, pass_idx, dest_idx = self.decode(self.s)

        def ul(x):
            # Show an empty cell as "_" so the occupied (full) taxi stays visible.
            return "_" if x == " " else x

        if pass_idx < 4:
            out[1 + taxi_row][2 * taxi_col + 1] = utils.colorize(
                out[1 + taxi_row][2 * taxi_col + 1], "yellow", highlight=True
            )
            pi, pj = self.locs[pass_idx]
            out[1 + pi][2 * pj + 1] = utils.colorize(
                out[1 + pi][2 * pj + 1], "blue", bold=True
            )
        else:  # passenger in taxi
            out[1 + taxi_row][2 * taxi_col + 1] = utils.colorize(
                ul(out[1 + taxi_row][2 * taxi_col + 1]), "green", highlight=True
            )

        di, dj = self.locs[dest_idx]
        out[1 + di][2 * dj + 1] = utils.colorize(out[1 + di][2 * dj + 1], "magenta")
        outfile.write("\n".join(["".join(row) for row in out]) + "\n")
        if self.lastaction is not None:
            outfile.write(
                f"  ({['South', 'North', 'East', 'West', 'Pickup', 'Dropoff'][self.lastaction]})\n"
            )
        else:
            outfile.write("\n")

        # No need to return anything for human; only the private StringIO is closed.
        if mode == "ansi":
            with closing(outfile):
                return outfile.getvalue()