mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-26 16:27:11 +00:00
Update gen-gifs to new render api (#371)
This commit is contained in:
@@ -1,89 +1,74 @@
|
||||
__author__ = "Sander Schulhoff"
|
||||
__email__ = "sanderschulhoff@gmail.com"
|
||||
|
||||
import os
|
||||
import re
|
||||
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
import gymnasium
|
||||
from utils import kill_strs
|
||||
import gymnasium as gym
|
||||
from gymnasium.envs.registration import find_highest_version, get_env_id
|
||||
|
||||
|
||||
# snake to camel case: https://stackoverflow.com/questions/1175208/elegant-python-function-to-convert-camelcase-to-snake-case # noqa: E501
|
||||
pattern = re.compile(r"(?<!^)(?=[A-Z])")
|
||||
# how many steps to record an env for
|
||||
LENGTH = 300
|
||||
# iterate through all envspecs
|
||||
for env_spec in tqdm(gymnasium.envs.registry.values()):
|
||||
if "Cliff" not in env_spec.id:
|
||||
|
||||
|
||||
exclude_env_names = [
|
||||
"GymV21Environment",
|
||||
"GymV26Environment",
|
||||
"FrozenLake8x8",
|
||||
"LunarLanderContinuous",
|
||||
"BipedalWalkerHardcore",
|
||||
"CartPoleJax",
|
||||
"PendulumJax",
|
||||
"Jax-Blackjack",
|
||||
]
|
||||
for env_spec in gym.registry.values():
|
||||
if env_spec.name in exclude_env_names:
|
||||
continue
|
||||
|
||||
if any(x in str(env_spec.id) for x in kill_strs):
|
||||
continue
|
||||
print(env_spec.id)
|
||||
# try catch in case missing some installs
|
||||
try:
|
||||
env = gymnasium.make(env_spec.id)
|
||||
# the gymnasium needs to be rgb renderable
|
||||
if not ("rgb_array" in env.metadata["render_modes"]):
|
||||
continue
|
||||
# extract env name/type from class path
|
||||
split = str(type(env.unwrapped)).split(".")
|
||||
highest_version = find_highest_version(env_spec.namespace, env_spec.name)
|
||||
env_id = get_env_id(env_spec.namespace, env_spec.name, highest_version)
|
||||
|
||||
# get rid of version info
|
||||
env_name = env_spec.id.split("-")[0]
|
||||
# convert NameLikeThis to name_like_this
|
||||
env_name = pattern.sub("_", env_name).lower()
|
||||
# get the env type (e.g. Box2D)
|
||||
env_type = split[2]
|
||||
if env_id == env_spec.id and isinstance(env_spec.entry_point, str):
|
||||
if "gymnasium" in env_spec.entry_point or (
|
||||
"ALE" == env_spec.namespace and env_spec.kwargs["obs_type"] == "rgb"
|
||||
):
|
||||
print(env_spec.id)
|
||||
env = gym.make(env_spec, render_mode="rgb_array").unwrapped
|
||||
|
||||
# if its an atari gymnasium
|
||||
# if env_spec.id[0:3] == "ALE":
|
||||
# continue
|
||||
# env_name = env_spec.id.split("-")[0][4:]
|
||||
# env_name = pattern.sub('_', env_name).lower()
|
||||
# the gymnasium needs to be rgb renderable
|
||||
if "rgb_array" not in env.metadata["render_modes"]:
|
||||
continue
|
||||
|
||||
# path for saving video
|
||||
# v_path = os.path.join("..", "pages", "environments", env_type, "videos") # noqa: E501
|
||||
# # create dir if it doesn't exist
|
||||
# if not path.isdir(v_path):
|
||||
# mkdir(v_path)
|
||||
# obtain and save LENGTH frames worth of steps
|
||||
frames = []
|
||||
env.reset()
|
||||
while len(frames) <= LENGTH:
|
||||
frames.append(Image.fromarray(env.render()))
|
||||
|
||||
# obtain and save LENGTH frames worth of steps
|
||||
frames = []
|
||||
while True:
|
||||
state, info = env.reset()
|
||||
terminated, truncated = False, False
|
||||
while not (terminated or truncated) and len(frames) <= LENGTH:
|
||||
frame = env.render(mode="rgb_array")
|
||||
repeat = (
|
||||
int(60 / env.metadata["render_fps"])
|
||||
if env_type == "toy_text"
|
||||
else 1
|
||||
)
|
||||
for i in range(repeat):
|
||||
frames.append(Image.fromarray(frame))
|
||||
action = env.action_space.sample()
|
||||
state_next, reward, terminated, truncated, info = env.step(action)
|
||||
_, _, terminated, truncated, _ = env.step(action)
|
||||
if terminated or truncated:
|
||||
env.reset()
|
||||
|
||||
if len(frames) > LENGTH:
|
||||
break
|
||||
env.close()
|
||||
|
||||
env.close()
|
||||
# make sure video doesn't already exist
|
||||
# if not os.path.exists(os.path.join(v_path, env_name + ".gif")):
|
||||
if "ALE" == env_spec.namespace:
|
||||
env_module = "atari"
|
||||
env_name = env_spec.kwargs["game"]
|
||||
else:
|
||||
env_module = env_spec.entry_point.split(".")[2]
|
||||
env_name = re.sub(r"(?<!^)(?=[A-Z])", "_", env_spec.name).lower()
|
||||
|
||||
# make sure video doesn't already exist
|
||||
# if not os.path.exists(os.path.join(v_path, env_name + ".gif")):
|
||||
frames[0].save(
|
||||
os.path.join("..", "_static", "videos", env_type, env_name + ".gif"),
|
||||
save_all=True,
|
||||
append_images=frames[1:],
|
||||
duration=50,
|
||||
loop=0,
|
||||
)
|
||||
print("Saved: " + env_name)
|
||||
|
||||
except BaseException as e:
|
||||
print("ERROR", e)
|
||||
continue
|
||||
# render_fps = env.metadata.get("render_fps", 30)
|
||||
video_path = os.path.join(
|
||||
"..", "_static", "videos", env_module, env_name + ".gif"
|
||||
)
|
||||
frames[0].save(
|
||||
video_path,
|
||||
save_all=True,
|
||||
append_images=frames[1:],
|
||||
duration=50, # milliseconds for the frame
|
||||
loop=0,
|
||||
)
|
||||
|
Reference in New Issue
Block a user