mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-30 17:57:30 +00:00
Add docs (#13)
This commit is contained in:
109
docs/scripts/gen_atari_table.py
Normal file
109
docs/scripts/gen_atari_table.py
Normal file
@@ -0,0 +1,109 @@
|
||||
import tabulate
|
||||
from tqdm import tqdm
|
||||
|
||||
import gymnasium
|
||||
|
||||
|
||||
def shortened_repr(lst):
|
||||
assert all(isinstance(item, int) for item in lst)
|
||||
assert len(set(lst)) == len(lst)
|
||||
lst = sorted(lst)
|
||||
|
||||
if lst[-1] - lst[0] == len(lst) - 1 and len(lst) > 3:
|
||||
return f"`[{lst[0]}, ..., {lst[-1]}]`"
|
||||
elif len(lst) > 3 and lst[-2] - lst[0] == len(lst) - 2:
|
||||
return f"`[{lst[0]}, ..., {lst[-2]}, {lst[-1]}]`"
|
||||
return f"`{str(lst)}`"
|
||||
|
||||
|
||||
def to_gymnasium_spelling(game):
|
||||
parts = game.split("_")
|
||||
return "".join([part.capitalize() for part in parts])
|
||||
|
||||
|
||||
atari_envs = [
|
||||
"adventure",
|
||||
"air_raid",
|
||||
"alien",
|
||||
"amidar",
|
||||
"assault",
|
||||
"asterix",
|
||||
"asteroids",
|
||||
"atlantis",
|
||||
"bank_heist",
|
||||
"battle_zone",
|
||||
"beam_rider",
|
||||
"berzerk",
|
||||
"bowling",
|
||||
"boxing",
|
||||
"breakout",
|
||||
"carnival",
|
||||
"centipede",
|
||||
"chopper_command",
|
||||
"crazy_climber",
|
||||
"defender",
|
||||
"demon_attack",
|
||||
"double_dunk",
|
||||
"elevator_action",
|
||||
"enduro",
|
||||
"fishing_derby",
|
||||
"freeway",
|
||||
"frostbite",
|
||||
"gopher",
|
||||
"gravitar",
|
||||
"hero",
|
||||
"ice_hockey",
|
||||
"jamesbond",
|
||||
"journey_escape",
|
||||
"kangaroo",
|
||||
"krull",
|
||||
"kung_fu_master",
|
||||
"montezuma_revenge",
|
||||
"ms_pacman",
|
||||
"name_this_game",
|
||||
"phoenix",
|
||||
"pitfall",
|
||||
"pong",
|
||||
"pooyan",
|
||||
"private_eye",
|
||||
"qbert",
|
||||
"riverraid",
|
||||
"road_runner",
|
||||
"robotank",
|
||||
"seaquest",
|
||||
"skiing",
|
||||
"solaris",
|
||||
"space_invaders",
|
||||
"star_gunner",
|
||||
"tennis",
|
||||
"time_pilot",
|
||||
"tutankham",
|
||||
"up_n_down",
|
||||
"venture",
|
||||
"video_pinball",
|
||||
"wizard_of_wor",
|
||||
"yars_revenge",
|
||||
"zaxxon",
|
||||
]
|
||||
|
||||
|
||||
header = ["Environment", "Valid Modes", "Valid Difficulties", "Default Mode"]
|
||||
rows = []
|
||||
|
||||
for game in tqdm(atari_envs):
|
||||
env = gymnasium.make(f"ALE/{to_gymnasium_spelling(game)}-v5")
|
||||
valid_modes = env.unwrapped.ale.getAvailableModes()
|
||||
valid_difficulties = env.unwrapped.ale.getAvailableDifficulties()
|
||||
difficulty = env.unwrapped.ale.cloneState().getDifficulty()
|
||||
assert difficulty == 0, difficulty
|
||||
rows.append(
|
||||
[
|
||||
to_gymnasium_spelling(game),
|
||||
shortened_repr(valid_modes),
|
||||
shortened_repr(valid_difficulties),
|
||||
f"`{env.unwrapped.ale.cloneState().getCurrentMode()}`",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
print(tabulate.tabulate(rows, headers=header, tablefmt="github"))
|
183
docs/scripts/gen_envs_display.py
Normal file
183
docs/scripts/gen_envs_display.py
Normal file
@@ -0,0 +1,183 @@
|
||||
import sys
|
||||
|
||||
all_envs = [
|
||||
{
|
||||
"id": "mujoco",
|
||||
"list": [
|
||||
"ant",
|
||||
"half_cheetah",
|
||||
"hopper",
|
||||
"humanoid_standup",
|
||||
"humanoid",
|
||||
"inverted_double_pendulum",
|
||||
"inverted_pendulum",
|
||||
"reacher",
|
||||
"swimmer",
|
||||
"walker2d",
|
||||
],
|
||||
},
|
||||
{"id": "toy_text", "list": ["blackjack", "frozen_lake"]},
|
||||
{"id": "box2d", "list": ["bipedal_walker", "car_racing", "lunar_lander"]},
|
||||
{
|
||||
"id": "classic_control",
|
||||
"list": [
|
||||
"acrobot",
|
||||
"cart_pole",
|
||||
"mountain_car_continuous",
|
||||
"mountain_car",
|
||||
"pendulum",
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": "atari",
|
||||
"list": [
|
||||
"adventure",
|
||||
"air_raid",
|
||||
"alien",
|
||||
"amidar",
|
||||
"assault",
|
||||
"asterix",
|
||||
"asteroids",
|
||||
"atlantis",
|
||||
"bank_heist",
|
||||
"battle_zone",
|
||||
"beam_rider",
|
||||
"berzerk",
|
||||
"bowling",
|
||||
"boxing",
|
||||
"breakout",
|
||||
"carnival",
|
||||
"centipede",
|
||||
"chopper_command",
|
||||
"crazy_climber",
|
||||
"defender",
|
||||
"demon_attack",
|
||||
"double_dunk",
|
||||
"elevator_action",
|
||||
"enduro",
|
||||
"fishing_derby",
|
||||
"freeway",
|
||||
"frostbite",
|
||||
"gopher",
|
||||
"gravitar",
|
||||
"hero",
|
||||
"ice_hockey",
|
||||
"jamesbond",
|
||||
"journey_escape",
|
||||
"kangaroo",
|
||||
"krull",
|
||||
"kung_fu_master",
|
||||
"montezuma_revenge",
|
||||
"ms_pacman",
|
||||
"name_this_game",
|
||||
"phoenix",
|
||||
"pitfall",
|
||||
"pong",
|
||||
"pooyan",
|
||||
"private_eye",
|
||||
"qbert",
|
||||
"riverraid",
|
||||
"road_runner",
|
||||
"robotank",
|
||||
"seaquest",
|
||||
"skiing",
|
||||
"solaris",
|
||||
"space_invaders",
|
||||
"star_gunner",
|
||||
"tennis",
|
||||
"time_pilot",
|
||||
"tutankham",
|
||||
"up_n_down",
|
||||
"venture",
|
||||
"video_pinball",
|
||||
"wizard_of_wor",
|
||||
"yars_revenge",
|
||||
"zaxxon",
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def create_grid_cell(type_id, env_id, base_path):
|
||||
return f"""
|
||||
<a href="{base_path}{env_id}">
|
||||
<div class="env-grid__cell">
|
||||
<div class="cell__image-container">
|
||||
<img src="/_static/videos/{type_id}/{env_id}.gif">
|
||||
</div>
|
||||
<div class="cell__title">
|
||||
<span>{' '.join(env_id.split('_')).title()}</span>
|
||||
</div>
|
||||
</div>
|
||||
</a>
|
||||
"""
|
||||
|
||||
|
||||
def generate_page(env, limit=-1, base_path=""):
|
||||
env_type_id = env["id"]
|
||||
env_list = env["list"]
|
||||
cells = [create_grid_cell(env_type_id, env_id, base_path) for env_id in env_list]
|
||||
non_limited_page = limit == -1 or limit >= len(cells)
|
||||
if non_limited_page:
|
||||
cells = "\n".join(cells)
|
||||
else:
|
||||
cells = "\n".join(cells[:limit])
|
||||
|
||||
more_btn = (
|
||||
"""<a href="./complete_list">
|
||||
<button class="more-btn">
|
||||
See More Environments
|
||||
</button>
|
||||
</a>"""
|
||||
if not non_limited_page
|
||||
else ""
|
||||
)
|
||||
return f"""
|
||||
<div class="env-grid">
|
||||
{cells}
|
||||
</div>
|
||||
{more_btn}
|
||||
"""
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
python gen_envs_display [ env_type ]
|
||||
"""
|
||||
|
||||
type_dict_arr = []
|
||||
type_arg = ""
|
||||
|
||||
if len(sys.argv) > 1:
|
||||
type_arg = sys.argv[1]
|
||||
|
||||
for env in all_envs:
|
||||
if type_arg == env["id"] or type_arg == "":
|
||||
type_dict_arr.append(env)
|
||||
|
||||
for type_dict in type_dict_arr:
|
||||
type_id = type_dict["id"]
|
||||
envs_path = f"../environments/{type_id}"
|
||||
if len(type_dict["list"]) > 20:
|
||||
page = generate_page(type_dict, limit=9)
|
||||
fp = open(f"{envs_path}/index.html", "w+", encoding="utf-8")
|
||||
fp.write(page)
|
||||
fp.close()
|
||||
|
||||
page = generate_page(type_dict, base_path="../")
|
||||
fp = open(f"{envs_path}/complete_list.html", "w+", encoding="utf-8")
|
||||
fp.write(page)
|
||||
fp.close()
|
||||
|
||||
fp = open(f"{envs_path}/complete_list.md", "w+", encoding="utf-8")
|
||||
env_name = " ".join(type_id.split("_")).title()
|
||||
fp.write(
|
||||
f"# Complete List - {env_name}\n"
|
||||
+ "```{raw} html\n:file: complete_list.html\n```"
|
||||
)
|
||||
fp.close()
|
||||
else:
|
||||
page = generate_page(type_dict)
|
||||
fp = open(f"{envs_path}/index.html", "w+", encoding="utf-8")
|
||||
fp.write(page)
|
||||
fp.close()
|
89
docs/scripts/gen_gifs.py
Normal file
89
docs/scripts/gen_gifs.py
Normal file
@@ -0,0 +1,89 @@
|
||||
__author__ = "Sander Schulhoff"
|
||||
__email__ = "sanderschulhoff@gmail.com"
|
||||
|
||||
import os
|
||||
import re
|
||||
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
from utils import kill_strs
|
||||
|
||||
import gymnasium
|
||||
|
||||
# snake to camel case: https://stackoverflow.com/questions/1175208/elegant-python-function-to-convert-camelcase-to-snake-case # noqa: E501
|
||||
pattern = re.compile(r"(?<!^)(?=[A-Z])")
|
||||
# how many steps to record an env for
|
||||
LENGTH = 300
|
||||
# iterate through all envspecs
|
||||
for env_spec in tqdm(gymnasium.envs.registry.values()):
|
||||
if "Cliff" not in env_spec.id:
|
||||
continue
|
||||
|
||||
if any(x in str(env_spec.id) for x in kill_strs):
|
||||
continue
|
||||
print(env_spec.id)
|
||||
# try catch in case missing some installs
|
||||
try:
|
||||
env = gymnasium.make(env_spec.id)
|
||||
# the gymnasium needs to be rgb renderable
|
||||
if not ("rgb_array" in env.metadata["render_modes"]):
|
||||
continue
|
||||
# extract env name/type from class path
|
||||
split = str(type(env.unwrapped)).split(".")
|
||||
|
||||
# get rid of version info
|
||||
env_name = env_spec.id.split("-")[0]
|
||||
# convert NameLikeThis to name_like_this
|
||||
env_name = pattern.sub("_", env_name).lower()
|
||||
# get the env type (e.g. Box2D)
|
||||
env_type = split[2]
|
||||
|
||||
# if its an atari gymnasium
|
||||
# if env_spec.id[0:3] == "ALE":
|
||||
# continue
|
||||
# env_name = env_spec.id.split("-")[0][4:]
|
||||
# env_name = pattern.sub('_', env_name).lower()
|
||||
|
||||
# path for saving video
|
||||
# v_path = os.path.join("..", "pages", "environments", env_type, "videos") # noqa: E501
|
||||
# # create dir if it doesn't exist
|
||||
# if not path.isdir(v_path):
|
||||
# mkdir(v_path)
|
||||
|
||||
# obtain and save LENGTH frames worth of steps
|
||||
frames = []
|
||||
while True:
|
||||
state, info = env.reset()
|
||||
done = False
|
||||
while not done and len(frames) <= LENGTH:
|
||||
|
||||
frame = env.render(mode="rgb_array")
|
||||
repeat = (
|
||||
int(60 / env.metadata["render_fps"])
|
||||
if env_type == "toy_text"
|
||||
else 1
|
||||
)
|
||||
for i in range(repeat):
|
||||
frames.append(Image.fromarray(frame))
|
||||
action = env.action_space.sample()
|
||||
state_next, reward, done, info = env.step(action)
|
||||
|
||||
if len(frames) > LENGTH:
|
||||
break
|
||||
|
||||
env.close()
|
||||
|
||||
# make sure video doesn't already exist
|
||||
# if not os.path.exists(os.path.join(v_path, env_name + ".gif")):
|
||||
frames[0].save(
|
||||
os.path.join("..", "_static", "videos", env_type, env_name + ".gif"),
|
||||
save_all=True,
|
||||
append_images=frames[1:],
|
||||
duration=50,
|
||||
loop=0,
|
||||
)
|
||||
print("Saved: " + env_name)
|
||||
|
||||
except BaseException as e:
|
||||
print("ERROR", e)
|
||||
continue
|
153
docs/scripts/gen_mds.py
Normal file
153
docs/scripts/gen_mds.py
Normal file
@@ -0,0 +1,153 @@
|
||||
__author__ = "Sander Schulhoff"
|
||||
__email__ = "sanderschulhoff@gmail.com"
|
||||
|
||||
import os
|
||||
import re
|
||||
from functools import reduce
|
||||
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from utils import kill_strs, trim
|
||||
|
||||
import gymnasium
|
||||
|
||||
LAYOUT = "env"
|
||||
|
||||
pattern = re.compile(r"(?<!^)(?=[A-Z])")
|
||||
|
||||
gymnasium.logger.set_level(gymnasium.logger.DISABLED)
|
||||
|
||||
all_envs = list(gymnasium.envs.registry.values())
|
||||
filtered_envs_by_type = {}
|
||||
|
||||
# Obtain filtered list
|
||||
for env_spec in tqdm(all_envs):
|
||||
|
||||
if any(x in str(env_spec.id) for x in kill_strs):
|
||||
continue
|
||||
|
||||
try:
|
||||
env = gymnasium.make(env_spec.id)
|
||||
split = str(type(env.unwrapped)).split(".")
|
||||
env_type = split[2]
|
||||
|
||||
if env_type == "atari" or env_type == "unittest":
|
||||
continue
|
||||
|
||||
if env_type not in filtered_envs_by_type.keys():
|
||||
filtered_envs_by_type[env_type] = []
|
||||
filtered_envs_by_type[env_type].append((env_spec, env_type))
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
|
||||
# Sort
|
||||
filtered_envs = list(
|
||||
reduce(
|
||||
lambda s, x: s + x,
|
||||
map(
|
||||
lambda arr: sorted(arr, key=lambda x: x[0].name),
|
||||
list(filtered_envs_by_type.values()),
|
||||
),
|
||||
[],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# Update Docs
|
||||
for i, (env_spec, env_type) in tqdm(enumerate(filtered_envs)):
|
||||
print("ID:", env_spec.id)
|
||||
try:
|
||||
env = gymnasium.make(env_spec.id)
|
||||
|
||||
# variants dont get their own pages
|
||||
e_n = str(env_spec).lower()
|
||||
|
||||
docstring = env.unwrapped.__doc__
|
||||
if not docstring:
|
||||
docstring = env.unwrapped.__class__.__doc__
|
||||
docstring = trim(docstring)
|
||||
|
||||
# pascal case
|
||||
pascal_env_name = env_spec.id.split("-")[0]
|
||||
snake_env_name = pattern.sub("_", pascal_env_name).lower()
|
||||
title_env_name = snake_env_name.replace("_", " ").title()
|
||||
env_type_title = env_type.replace("_", " ").title()
|
||||
related_pages_meta = ""
|
||||
if i == 0 or not env_type == filtered_envs[i - 1][1]:
|
||||
related_pages_meta = "firstpage:\n"
|
||||
elif i == len(filtered_envs) - 1 or not env_type == filtered_envs[i + 1][1]:
|
||||
related_pages_meta = "lastpage:\n"
|
||||
|
||||
# path for saving video
|
||||
v_path = os.path.join(
|
||||
os.path.dirname(__file__),
|
||||
"..",
|
||||
"environments",
|
||||
env_type,
|
||||
snake_env_name + ".md",
|
||||
)
|
||||
|
||||
front_matter = f"""---
|
||||
AUTOGENERATED: DO NOT EDIT FILE DIRECTLY
|
||||
title: {title_env_name}
|
||||
{related_pages_meta}---
|
||||
"""
|
||||
title = f"# {title_env_name}"
|
||||
gif = (
|
||||
"```{figure}"
|
||||
+ f" ../../_static/videos/{env_type}/{snake_env_name}.gif"
|
||||
+ f" \n:width: 200px\n:name: {snake_env_name}\n```"
|
||||
)
|
||||
info = (
|
||||
"This environment is part of the"
|
||||
+ f"<a href='..'>{env_type_title} environments</a>."
|
||||
+ "Please read that page first for general information."
|
||||
)
|
||||
env_table = "| | |\n|---|---|\n"
|
||||
env_table += f"| Action Space | {env.action_space} |\n"
|
||||
|
||||
if env.observation_space.shape:
|
||||
env_table += f"| Observation Shape | {env.observation_space.shape} |\n"
|
||||
|
||||
if hasattr(env.observation_space, "high"):
|
||||
high = env.observation_space.high
|
||||
|
||||
if hasattr(high, "shape"):
|
||||
if len(high.shape) == 3:
|
||||
high = high[0][0][0]
|
||||
high = np.round(high, 2)
|
||||
high = str(high).replace("\n", " ")
|
||||
env_table += f"| Observation High | {high} |\n"
|
||||
|
||||
if hasattr(env.observation_space, "low"):
|
||||
low = env.observation_space.low
|
||||
if hasattr(low, "shape"):
|
||||
if len(low.shape) == 3:
|
||||
low = low[0][0][0]
|
||||
low = np.round(low, 2)
|
||||
low = str(low).replace("\n", " ")
|
||||
env_table += f"| Observation Low | {low} |\n"
|
||||
else:
|
||||
env_table += f"| Observation Space | {env.observation_space} |\n"
|
||||
|
||||
env_table += f'| Import | `gymnasium.make("{env_spec.id}")` | \n'
|
||||
|
||||
if docstring is None:
|
||||
docstring = "No information provided"
|
||||
all_text = f"""{front_matter}
|
||||
{title}
|
||||
|
||||
{gif}
|
||||
|
||||
{info}
|
||||
|
||||
{env_table}
|
||||
|
||||
{docstring}
|
||||
"""
|
||||
file = open(v_path, "w", encoding="utf-8")
|
||||
file.write(all_text)
|
||||
file.close()
|
||||
except Exception as e:
|
||||
print(e)
|
14
docs/scripts/move_404.py
Normal file
14
docs/scripts/move_404.py
Normal file
@@ -0,0 +1,14 @@
|
||||
import sys
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) < 2:
|
||||
print("Provide a path")
|
||||
filePath = sys.argv[1]
|
||||
|
||||
with open(filePath, "r+") as fp:
|
||||
content = fp.read()
|
||||
content = content.replace('href="../', 'href="/').replace('src="../', 'src="/')
|
||||
fp.seek(0)
|
||||
fp.truncate()
|
||||
|
||||
fp.write(content)
|
44
docs/scripts/utils.py
Normal file
44
docs/scripts/utils.py
Normal file
@@ -0,0 +1,44 @@
|
||||
# stolen from python docs
|
||||
def trim(docstring):
|
||||
if not docstring:
|
||||
return ""
|
||||
# Convert tabs to spaces (following the normal Python rules)
|
||||
# and split into a list of lines:
|
||||
lines = docstring.expandtabs().splitlines()
|
||||
# Determine minimum indentation (first line doesn't count):
|
||||
indent = 232323
|
||||
for line in lines[1:]:
|
||||
stripped = line.lstrip()
|
||||
if stripped:
|
||||
indent = min(indent, len(line) - len(stripped))
|
||||
# Remove indentation (first line is special):
|
||||
trimmed = [lines[0].strip()]
|
||||
if indent < 232323:
|
||||
for line in lines[1:]:
|
||||
trimmed.append(line[indent:].rstrip())
|
||||
# Strip off trailing and leading blank lines:
|
||||
while trimmed and not trimmed[-1]:
|
||||
trimmed.pop()
|
||||
while trimmed and not trimmed[0]:
|
||||
trimmed.pop(0)
|
||||
# Return a single string:
|
||||
return "\n".join(trimmed)
|
||||
|
||||
|
||||
# dont want envs which contain these
|
||||
kill_strs = [
|
||||
"eterministic",
|
||||
"ALE",
|
||||
"-ram",
|
||||
"Frameskip",
|
||||
"Hard",
|
||||
"LanderContinu",
|
||||
"8x8",
|
||||
"uessing",
|
||||
"otter",
|
||||
"oinflip",
|
||||
"hain",
|
||||
"oulette",
|
||||
"DomainRandom",
|
||||
"RacingDiscrete",
|
||||
]
|
Reference in New Issue
Block a user