mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-07-31 05:44:31 +00:00
Update gen-mds.py (#370)
This commit is contained in:
@@ -1,177 +1,94 @@
|
||||
__author__ = "Sander Schulhoff"
|
||||
__email__ = "sanderschulhoff@gmail.com"
|
||||
|
||||
import os
|
||||
import re
|
||||
from functools import reduce
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from docs.scripts.utils import trim
|
||||
|
||||
import gymnasium as gym
|
||||
from utils import kill_strs, trim
|
||||
from gymnasium.envs.registration import find_highest_version, get_env_id
|
||||
|
||||
|
||||
LAYOUT = "env"
|
||||
filtered_envs = defaultdict(list)
|
||||
exclude_env_names = [
|
||||
"GymV21Environment",
|
||||
"GymV26Environment",
|
||||
"FrozenLake8x8",
|
||||
"LunarLanderContinuous",
|
||||
"BipedalWalkerHardcore",
|
||||
"CartPoleJax",
|
||||
"PendulumJax",
|
||||
"Jax-Blackjack",
|
||||
]
|
||||
for env_spec in gym.registry.values():
|
||||
if env_spec.name not in exclude_env_names:
|
||||
highest_version = find_highest_version(env_spec.namespace, env_spec.name)
|
||||
env_id = get_env_id(env_spec.namespace, env_spec.name, highest_version)
|
||||
|
||||
pattern = re.compile(r"(?<!^)(?=[A-Z])")
|
||||
|
||||
gym.logger.set_level(gym.logger.DISABLED)
|
||||
|
||||
all_envs = list(gym.envs.registry.values())
|
||||
filtered_envs_by_type = {}
|
||||
|
||||
# Obtain filtered list
|
||||
for env_spec in tqdm(all_envs):
|
||||
if any(x in str(env_spec.id) for x in kill_strs):
|
||||
continue
|
||||
|
||||
# gymnasium.envs.env_type.env.EnvClass
|
||||
# ale_py.env.gym:AtariEnv
|
||||
split = env_spec.entry_point.split(".")
|
||||
# ignore gymnasium.envs.env_type:Env
|
||||
env_module = split[0]
|
||||
if len(split) < 4 and env_module != "ale_py":
|
||||
continue
|
||||
env_type = split[2] if env_module != "ale_py" else "atari"
|
||||
env_version = env_spec.version
|
||||
|
||||
# ignore unit test envs and old versions of atari envs
|
||||
if env_module == "ale_py" or env_type == "unittest":
|
||||
continue
|
||||
|
||||
try:
|
||||
env = gym.make(env_spec.id)
|
||||
split = str(type(env.unwrapped)).split(".")
|
||||
env_name = split[3]
|
||||
|
||||
if env_type not in filtered_envs_by_type.keys():
|
||||
filtered_envs_by_type[env_type] = {}
|
||||
# only store new entries and higher versions
|
||||
if env_name not in filtered_envs_by_type[env_type] or (
|
||||
env_name in filtered_envs_by_type[env_type]
|
||||
and env_version > filtered_envs_by_type[env_type][env_name].version
|
||||
env_spec = gym.spec(env_id)
|
||||
if (
|
||||
isinstance(env_spec.entry_point, str)
|
||||
and "gymnasium" in env_spec.entry_point
|
||||
):
|
||||
filtered_envs_by_type[env_type][env_name] = env_spec
|
||||
env_module = env_spec.entry_point.split(".")[2]
|
||||
|
||||
except Exception as e:
|
||||
print(e)
|
||||
if env_spec not in filtered_envs[env_module]:
|
||||
filtered_envs[env_module].append(env_spec)
|
||||
|
||||
# Sort
|
||||
filtered_envs = list(
|
||||
reduce(
|
||||
lambda s, x: s + x,
|
||||
map(
|
||||
lambda arr: sorted(arr, key=lambda x: x.name),
|
||||
map(lambda dic: list(dic.values()), list(filtered_envs_by_type.values())),
|
||||
),
|
||||
[],
|
||||
)
|
||||
)
|
||||
# print(filtered_envs.keys())
|
||||
for env_module, env_specs in filtered_envs.items():
|
||||
env_module_name = env_module.replace("_", " ").title()
|
||||
print(env_module_name)
|
||||
env_specs = sorted(env_specs, key=lambda spec: spec.name)
|
||||
|
||||
for i, env_spec in enumerate(env_specs):
|
||||
print(f"\t{i=}, {env_spec.name}")
|
||||
env = gym.make(env_spec)
|
||||
env_docstring = trim(env.unwrapped.__doc__)
|
||||
assert env_docstring
|
||||
|
||||
# Update Docs
|
||||
for i, env_spec in tqdm(enumerate(filtered_envs)):
|
||||
print("ID:", env_spec.id)
|
||||
env_type = env_spec.entry_point.split(".")[2]
|
||||
try:
|
||||
env = gym.make(env_spec.id)
|
||||
snake_env_name = re.sub(r"(?<!^)(?=[A-Z])", "_", env_spec.name).lower()
|
||||
title_env_name = re.sub(r"(?<!^)(?=[A-Z])", " ", env_spec.name).title()
|
||||
|
||||
# variants dont get their own pages
|
||||
e_n = str(env_spec).lower()
|
||||
|
||||
docstring = env.unwrapped.__doc__
|
||||
if not docstring:
|
||||
docstring = env.unwrapped.__class__.__doc__
|
||||
docstring = trim(docstring)
|
||||
|
||||
# pascal case
|
||||
pascal_env_name = env_spec.id.split("-")[0]
|
||||
snake_env_name = pattern.sub("_", pascal_env_name).lower()
|
||||
title_env_name = snake_env_name.replace("_", " ").title()
|
||||
env_type_title = env_type.replace("_", " ").title()
|
||||
related_pages_meta = ""
|
||||
if i == 0 or not env_type == filtered_envs[i - 1].entry_point.split(".")[2]:
|
||||
if i == 0:
|
||||
related_pages_meta = "firstpage:\n"
|
||||
elif (
|
||||
i == len(filtered_envs) - 1
|
||||
or not env_type == filtered_envs[i + 1].entry_point.split(".")[2]
|
||||
):
|
||||
elif i == len(env_specs) - 1:
|
||||
related_pages_meta = "lastpage:\n"
|
||||
else:
|
||||
related_pages_meta = ""
|
||||
|
||||
# path for saving video
|
||||
v_path = os.path.join(
|
||||
os.path.dirname(__file__),
|
||||
"..",
|
||||
"environments",
|
||||
env_type,
|
||||
snake_env_name + ".md",
|
||||
)
|
||||
action_space_table = env.action_space.__repr__().replace("\n", "")
|
||||
observation_space_table = env.observation_space.__repr__().replace("\n", "")
|
||||
|
||||
front_matter = f"""---
|
||||
env_page = f"""---
|
||||
autogenerated:
|
||||
title: {title_env_name}
|
||||
{related_pages_meta}---
|
||||
|
||||
# {title_env_name}
|
||||
|
||||
```{{figure}} ../../_static/videos/{env_module}/{snake_env_name}.gif
|
||||
:width: 200px
|
||||
:name: {snake_env_name}
|
||||
```
|
||||
|
||||
This environment is part of the {env_module_name} environments which contains general information about the environment.
|
||||
|
||||
| | |
|
||||
|---|---|
|
||||
| Action Space | `{re.sub(' +', ' ', action_space_table)}` |
|
||||
| Observation Space | `{re.sub(' +', ' ', observation_space_table)}` |
|
||||
| import | `gymnasium.make("{env.spec.id}")` |
|
||||
|
||||
{env_docstring}
|
||||
"""
|
||||
title = f"# {title_env_name}"
|
||||
gif = (
|
||||
"```{figure}"
|
||||
+ f" ../../_static/videos/{env_type}/{snake_env_name}.gif"
|
||||
+ f" \n:width: 200px\n:name: {snake_env_name}\n```"
|
||||
|
||||
env_md_path = os.path.join(
|
||||
os.path.dirname(__file__),
|
||||
"..",
|
||||
"environments",
|
||||
env_module,
|
||||
snake_env_name + ".md",
|
||||
)
|
||||
info = (
|
||||
"This environment is part of the "
|
||||
+ f"<a href='..'>{env_type_title} environments</a>."
|
||||
+ "Please read that page first for general information."
|
||||
)
|
||||
env_table = "| | |\n|---|---|\n"
|
||||
env_table += f"| Action Space | {env.action_space} |\n"
|
||||
|
||||
if env.observation_space.shape:
|
||||
env_table += f"| Observation Shape | {env.observation_space.shape} |\n"
|
||||
|
||||
if hasattr(env.observation_space, "high"):
|
||||
high = env.observation_space.high
|
||||
|
||||
if hasattr(high, "shape"):
|
||||
if len(high.shape) == 3:
|
||||
high = high[0][0][0]
|
||||
if env_type == "mujoco":
|
||||
high = high[0]
|
||||
high = np.round(high, 2)
|
||||
high = str(high).replace("\n", " ")
|
||||
env_table += f"| Observation High | {high} |\n"
|
||||
|
||||
if hasattr(env.observation_space, "low"):
|
||||
low = env.observation_space.low
|
||||
if hasattr(low, "shape"):
|
||||
if len(low.shape) == 3:
|
||||
low = low[0][0][0]
|
||||
if env_type == "mujoco":
|
||||
low = low[0]
|
||||
low = np.round(low, 2)
|
||||
low = str(low).replace("\n", " ")
|
||||
env_table += f"| Observation Low | {low} |\n"
|
||||
else:
|
||||
env_table += f"| Observation Space | {env.observation_space} |\n"
|
||||
|
||||
env_table += f'| Import | `gymnasium.make("{env_spec.id}")` | \n'
|
||||
|
||||
if docstring is None:
|
||||
docstring = "No information provided"
|
||||
all_text = f"""{front_matter}
|
||||
{title}
|
||||
|
||||
{gif}
|
||||
|
||||
{info}
|
||||
|
||||
{env_table}
|
||||
|
||||
{docstring}
|
||||
"""
|
||||
file = open(v_path, "w", encoding="utf-8")
|
||||
file.write(all_text)
|
||||
file = open(env_md_path, "w", encoding="utf-8")
|
||||
file.write(env_page)
|
||||
file.close()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
Reference in New Issue
Block a user