Files
Gymnasium/scripts/generate_json.py
Justin Terry e9d2c41f2b redo black
2021-07-29 12:42:48 -04:00

112 lines
3.7 KiB
Python

from gym import envs, spaces, logger
import json
import os
import sys
import argparse
from gym.envs.tests.spec_list import should_skip_env_spec_for_tests
from gym.envs.tests.test_envs_semantics import generate_rollout_hash, hash_object
# Directory that holds the environment test data, located relative to this script.
DATA_DIR = os.path.join(os.path.dirname(__file__), os.pardir, "gym", "envs", "tests")

# Shared count; `episodes` and `steps` are module-level aliases of the same value.
# NOTE(review): neither alias is referenced in the visible part of this script.
ROLLOUT_STEPS = 100
episodes = steps = ROLLOUT_STEPS

# JSON file that stores the rollout hashes.
ROLLOUT_FILE = os.path.join(DATA_DIR, "rollout.json")

# Bootstrap: make sure a rollout file exists so later reads via json.load succeed.
_rollout_file_exists = os.path.isfile(ROLLOUT_FILE)
if not _rollout_file_exists:
    logger.info("No rollout file found. Writing empty json file to {}".format(ROLLOUT_FILE))
    with open(ROLLOUT_FILE, "w") as outfile:
        json.dump({}, outfile, indent=2)
def update_rollout_dict(spec, rollout_dict):
    """
    Generate the rollout hashes for one environment spec and record them.

    Takes as input the environment spec for which the rollout is to be generated,
    and the existing dictionary of rollouts (keyed by spec id). The dictionary is
    updated in place.

    Args:
        spec: environment spec; its ``id`` and ``nondeterministic`` attributes
            are read here, and the spec itself is handed to
            ``generate_rollout_hash``.
        rollout_dict: mapping of spec id to a dict with the keys
            ``observations``, ``actions``, ``rewards`` and ``dones``.

    Returns:
        True iff the dictionary was modified. False when the spec is skipped,
        is nondeterministic, rollout generation raised an exception, or the new
        hashes match the stored ones.
    """
    # Skip platform-dependent
    if should_skip_env_spec_for_tests(spec):
        logger.info("Skipping tests for {}".format(spec.id))
        return False

    # Skip environments that are nondeterministic
    if spec.nondeterministic:
        logger.info("Skipping tests for nondeterministic env {}".format(spec.id))
        return False

    logger.info("Generating rollout for {}".format(spec.id))

    try:
        (
            observations_hash,
            actions_hash,
            rewards_hash,
            dones_hash,
        ) = generate_rollout_hash(spec)
    except Exception as exc:
        # If running the env generates an exception, don't write to the rollout file.
        # Only Exception is caught (not a bare except) so that KeyboardInterrupt and
        # SystemExit still abort the whole run instead of being logged and ignored.
        logger.warn(
            "Exception {} thrown while generating rollout for {}. Rollout not added.".format(type(exc), spec.id)
        )
        return False

    rollout = {
        "observations": observations_hash,
        "actions": actions_hash,
        "rewards": rewards_hash,
        "dones": dones_hash,
    }

    existing = rollout_dict.get(spec.id)
    if existing:
        # A key missing from the stored entry counts as a difference, so a stale or
        # partial entry gets overwritten rather than raising KeyError.
        differs = any(existing.get(key) != new_hash for key, new_hash in rollout.items())
        if not differs:
            logger.debug("Hashes match with existing for {}".format(spec.id))
            return False
        logger.warn("Got new hash for {}. Overwriting.".format(spec.id))

    rollout_dict[spec.id] = rollout
    return True
def add_new_rollouts(spec_ids, overwrite):
    """
    Generate rollouts for the selected environment specs and update ROLLOUT_FILE.

    Args:
        spec_ids: ids of the specs to process. An empty sequence selects every
            registered spec whose ``entry_point`` is not None.
        overwrite: when false, specs that already have an entry in the rollout
            file are skipped; when true, they are passed to
            ``update_rollout_dict`` again.

    Raises:
        AssertionError: if any id in ``spec_ids`` does not match a registered
            spec with an entry point.

    The rollout file is rewritten only if at least one entry was modified.
    """
    environments = [spec for spec in envs.registry.all() if spec.entry_point is not None]
    if spec_ids:
        # Compare as sets so that passing the same id twice is not reported as a
        # missing spec (a plain length comparison would fail in that case).
        requested = set(spec_ids)
        environments = [spec for spec in environments if spec.id in requested]
        missing = requested - {spec.id for spec in environments}
        if missing:
            # Raised explicitly (instead of an assert statement) so the check still
            # runs under `python -O`; AssertionError is kept for compatibility.
            raise AssertionError("Some specs not found: {}".format(", ".join(sorted(str(m) for m in missing))))

    with open(ROLLOUT_FILE) as data_file:
        rollout_dict = json.load(data_file)

    modified = False
    for spec in environments:
        if not overwrite and spec.id in rollout_dict:
            logger.debug("Rollout already exists for {}. Skipping.".format(spec.id))
        else:
            # Call first, then OR: every selected spec must be processed even after
            # an earlier one has already marked the dictionary as modified.
            modified = update_rollout_dict(spec, rollout_dict) or modified

    if modified:
        logger.info("Writing new rollout file to {}".format(ROLLOUT_FILE))
        with open(ROLLOUT_FILE, "w") as outfile:
            json.dump(rollout_dict, outfile, indent=2, sort_keys=True)
    else:
        logger.info("No modifications needed.")
if __name__ == "__main__":
    # Command-line entry point: parse the flags, optionally set the logger level
    # to INFO, then hand the requested spec ids and the force flag to
    # add_new_rollouts.
    cli = argparse.ArgumentParser()

    # -f/--force is forwarded as the `overwrite` argument of add_new_rollouts.
    cli.add_argument(
        "-f",
        "--force",
        action="store_true",
        help="Overwrite " + "existing rollouts if hashes differ.",
    )
    cli.add_argument("-v", "--verbose", action="store_true")
    # Zero or more positional spec ids; an empty list means "all specs".
    cli.add_argument("specs", nargs="*", help="ids of env specs to check (default: all)")

    options = cli.parse_args()

    if options.verbose:
        logger.set_level(logger.INFO)

    add_new_rollouts(options.specs, options.force)