Files
Gymnasium/scripts/generate_json.py

89 lines
2.7 KiB
Python
Raw Normal View History

from __future__ import unicode_literals
from gym import envs, spaces
import json
import os
import sys
import hashlib
import logging
logger = logging.getLogger(__name__)
from gym.envs.tests.test_envs import should_skip_env_spec_for_tests
from gym.envs.tests.test_envs_semantics import generate_rollout_hash, hash_object
# Directory (relative to this script) that holds the rollout fixtures.
DATA_DIR = os.path.join(os.path.dirname(__file__), os.pardir, 'gym', 'envs', 'tests')

# Rollout length constant; `episodes` and `steps` are aliases of it that are
# not referenced in the code shown here (kept in case other code imports them).
ROLLOUT_STEPS = 100
episodes = steps = ROLLOUT_STEPS

# Rollout hashes are stored in a separate file per major Python version.
python_version = sys.version_info.major
ROLLOUT_FILE = os.path.join(
    DATA_DIR,
    'rollout_py3.json' if python_version == 3 else 'rollout_py2.json',
)

# Seed a missing rollout file with an empty JSON object so that later reads
# via json.load always find a valid document.
if not os.path.isfile(ROLLOUT_FILE):
    with open(ROLLOUT_FILE, "w") as outfile:
        json.dump({}, outfile, indent=2)
def create_rollout(spec):
    """
    Generate a rollout hash for one environment spec and record it in ROLLOUT_FILE.

    The spec is skipped (nothing is written) when:
      * should_skip_env_spec_for_tests(spec) says so, or its id contains 'Doom';
      * the spec is marked nondeterministic;
      * it is one of the known-broken environments listed below;
      * ROLLOUT_FILE already has an entry for its id;
      * generate_rollout_hash raises an exception.

    Args:
        spec: the environment spec (an entry of ``envs.registry.all()``) for
            which the rollout is to be generated.

    Returns:
        bool: True if a new rollout was added to the JSON file, False otherwise.
    """
    # Skip envs the test suite itself skips, plus platform-dependent Doom environments
    if should_skip_env_spec_for_tests(spec) or 'Doom' in spec.id:
        logger.warning("Skipping tests for %s", spec.id)
        return False
    # Skip environments that are nondeterministic
    if spec.nondeterministic:
        logger.warning("Skipping tests for nondeterministic env %s", spec.id)
        return False
    # Skip broken environments
    # TODO: look into these environments
    if spec.id in ('PredictObsCartpole-v0', 'InterpretabilityCartpoleObservations-v0'):
        logger.warning("Skipping tests for %s", spec.id)
        return False

    with open(ROLLOUT_FILE) as data_file:
        rollout_dict = json.load(data_file)

    # Skip generating rollouts that already exist
    if spec.id in rollout_dict:
        logger.warning("Rollout already exists for %s", spec.id)
        return False

    logger.info("Generating rollout for %s", spec.id)
    try:
        observations_hash, actions_hash, rewards_hash, dones_hash = generate_rollout_hash(spec)
    except Exception as exc:
        # If running the env raises, don't write to the rollout file. Catching
        # Exception (not a bare except) lets Ctrl-C / sys.exit still abort the
        # script; exc_info keeps the traceback for debugging.
        logger.warning(
            "Exception %s thrown while generating rollout for %s. Rollout not added.",
            type(exc), spec.id, exc_info=True,
        )
        return False

    rollout_dict[spec.id] = {
        'observations': observations_hash,
        'actions': actions_hash,
        'rewards': rewards_hash,
        'dones': dones_hash,
    }
    # Serialize before opening the file for writing: open(..., "w") truncates
    # immediately, so a serialization error mid-dump would otherwise erase
    # every previously recorded rollout.
    serialized = json.dumps(rollout_dict, indent=2)
    with open(ROLLOUT_FILE, "w") as outfile:
        outfile.write(serialized)
    return True
def add_new_rollouts():
    """
    Try to record a rollout for every registered environment that has an entry point.

    Each candidate spec is passed to create_rollout, which decides whether to
    skip it or append a new rollout to ROLLOUT_FILE.

    Returns:
        int: the number of new rollouts written. The function previously
        returned None, so callers that ignore the result are unaffected.
    """
    # Filter the registry up front so it is not being iterated while
    # create_rollout runs.
    environments = [spec for spec in envs.registry.all() if spec._entry_point is not None]
    added = 0
    for spec in environments:
        if create_rollout(spec):
            added += 1
    logger.info("Added %d new rollout(s) to %s", added, ROLLOUT_FILE)
    return added


# Executed at module level: running (or importing) this script generates rollouts.
add_new_rollouts()