Files
Gymnasium/gym/envs/registration.py

188 lines
7.5 KiB
Python
Raw Normal View History

2016-04-27 08:00:58 -07:00
import importlib
import re

import pkg_resources

from gym import error, logger
# Environment IDs look like "[username/](env-name)-v(version)".
# Capture group 1 is the env name and capture group 2 is the numeric version.
# This format holds today, but it is *not* an official spec.
#
# 2016-10-31: the optional "username/" prefix was added experimentally.
env_id_re = re.compile(r'^(?:[\w:-]+\/)?([\w:.-]+)-v(\d+)$')
def load(name):
    """Resolve an entry-point string such as ``'package.module:Class'`` to the object it names.

    Accepts the same syntax as a setuptools entry point value:
    ``module[:attr[.attr...]] [extras]``, with optional surrounding whitespace.
    Extras are accepted but ignored. When no attribute part is given, the
    module itself is returned.

    This replaces ``pkg_resources.EntryPoint.load(False)``. That call is
    deprecated and emitted a deprecation warning every time an env was made.

    Args:
        name (str): The entry-point string.

    Returns:
        The imported module, or the attribute reached by following the dotted path.

    Raises:
        ValueError: ``name`` is not in ``module:attrs [extras]`` form.
        ImportError: the module cannot be imported or an attribute is missing
            (matching what the pkg_resources resolver raised).
    """
    match = re.match(r'^\s*([\w.]+)\s*(?::\s*([\w.]+))?\s*(?:\[.*\])?\s*$', name)
    if not match:
        raise ValueError("Entry point must be in 'module:attrs [extras]' format: {!r}".format(name))
    module_name, attrs = match.group(1), match.group(2)

    result = importlib.import_module(module_name)
    for attr in (attrs.split('.') if attrs else ()):
        try:
            result = getattr(result, attr)
        except AttributeError as exc:
            # Callers historically saw ImportError for a bad attribute path.
            raise ImportError(str(exc))
    return result
class EnvSpec(object):
    """A specification for a particular instance of the environment. Used
    to register the parameters for official evaluations.

    Args:
        id (str): The official environment ID
        entry_point (Optional[str or callable]): The Python entrypoint of the environment class
            (e.g. module.name:Class), or a callable that returns the environment
        trials (int): The number of trials to average reward over
        reward_threshold (Optional[int]): The reward threshold before the task is considered solved
        local_only: True iff the environment is to be used only on the local machine (e.g. debugging envs)
        kwargs (dict): The kwargs to pass to the environment class (or callable entry point)
        nondeterministic (bool): Whether this environment is non-deterministic even after seeding
        tags (dict[str:any]): A set of arbitrary key-value tags on this environment, including simple
            property=True tags. The dict is copied, so the caller's dict is never modified.
        max_episode_steps (Optional[int]): Per-episode step limit, or None for no limit
        max_episode_seconds: Forwarded unchanged as the TimeLimit wrapper's max_episode_seconds
        timestep_limit (Optional[int]): Deprecated alias for max_episode_steps

    The effective step limit is resolved in this order: ``timestep_limit`` (if not None)
    wins over a truthy legacy ``'wrapper_config.TimeLimit.max_episode_steps'`` tag, which
    wins over ``max_episode_steps``.

    Attributes:
        id (str): The official environment ID
        trials (int): The number of trials run in official evaluation
        max_episode_steps (Optional[int]): The resolved per-episode step limit
        max_episode_seconds: The per-episode time limit as passed in
    """

    def __init__(self, id, entry_point=None, trials=100, reward_threshold=None, local_only=False, kwargs=None, nondeterministic=False, tags=None, max_episode_steps=None, max_episode_seconds=None, timestep_limit=None):
        self.id = id
        # Evaluation parameters
        self.trials = trials
        self.reward_threshold = reward_threshold
        # Environment properties
        self.nondeterministic = nondeterministic

        # Copy the dict so the legacy tag written below never leaks into a dict
        # the caller owns (and may share between several registrations).
        tags = {} if tags is None else dict(tags)
        self.tags = tags

        # BACKWARDS COMPAT 2017/1/18: the step limit used to be passed as a tag.
        if tags.get('wrapper_config.TimeLimit.max_episode_steps'):
            max_episode_steps = tags.get('wrapper_config.TimeLimit.max_episode_steps')
            # TODO: Add the following deprecation warning after 2017/02/18
            # warnings.warn("DEPRECATION WARNING wrapper_config.TimeLimit has been deprecated. Replace any calls to `register(tags={'wrapper_config.TimeLimit.max_episode_steps': 200)}` with `register(max_episode_steps=200)`. This change was made 2017/1/31 and is included in gym version 0.8.0. If you are getting many of these warnings, you may need to update universe past version 0.21.3")

        # BACKWARDS COMPAT 2017/1/31: `timestep_limit` is the old name of `max_episode_steps`.
        if timestep_limit is not None:
            max_episode_steps = timestep_limit
            # TODO: Add the following deprecation warning after 2017/03/01
            # warnings.warn("register(timestep_limit={}) is deprecated. Use register(max_episode_steps={}) instead.".format(timestep_limit, timestep_limit))

        # Mirror the limit into the legacy tag only after every alias is
        # resolved, so the tag and the attribute can never disagree.
        tags['wrapper_config.TimeLimit.max_episode_steps'] = max_episode_steps
        ######

        self.max_episode_steps = max_episode_steps
        self.max_episode_seconds = max_episode_seconds

        # We may make some of these other parameters public if they're
        # useful.
        match = env_id_re.search(id)
        if not match:
            raise error.Error('Attempted to register malformed environment ID: {}. (Currently all IDs must be of the form {}.)'.format(id, env_id_re.pattern))
        self._env_name = match.group(1)
        self._entry_point = entry_point
        self._local_only = local_only
        self._kwargs = {} if kwargs is None else kwargs

    def make(self):
        """Instantiates an instance of the environment with appropriate kwargs.

        The registered kwargs are passed both to a callable entry point and to
        a class loaded from an entry-point string.

        Raises:
            error.Error: the spec has no entry point (a deprecated env).
        """
        if self._entry_point is None:
            raise error.Error('Attempting to make deprecated env {}. (HINT: is there a newer registered version of this env?)'.format(self.id))
        elif callable(self._entry_point):
            # Previously the kwargs were silently dropped on this path.
            env = self._entry_point(**self._kwargs)
        else:
            cls = load(self._entry_point)
            env = cls(**self._kwargs)

        # Make the environment aware of which spec it came from.
        env.unwrapped.spec = self

        return env

    def __repr__(self):
        return "EnvSpec({})".format(self.id)

    @property
    def timestep_limit(self):
        """Deprecated alias for ``max_episode_steps``."""
        return self.max_episode_steps

    @timestep_limit.setter
    def timestep_limit(self, value):
        self.max_episode_steps = value
class EnvRegistry(object):
    """Register an env by ID. IDs remain stable over time and are
    guaranteed to resolve to the same environment dynamics (or be
    desupported). The goal is that results on a particular environment
    should always be comparable, and not depend on the version of the
    code that was running.
    """

    def __init__(self):
        # Maps env ID string -> EnvSpec.
        self.env_specs = {}

    def make(self, id):
        """Build the environment registered under ``id``.

        Environments that still define ``_reset`` and ``_step`` are patched in
        place by ``patch_deprecated_methods``. The result is wrapped in
        ``TimeLimit`` when the spec has a ``max_episode_steps`` limit, unless
        the spec carries a truthy ``vnc`` tag.

        Raises:
            error.Error, error.DeprecatedEnv, error.UnregisteredEnv: see ``spec``;
                error.Error is also raised by ``EnvSpec.make`` for entry-point-less specs.
        """
        logger.info('Making new env: %s', id)
        spec = self.spec(id)
        env = spec.make()
        if hasattr(env, "_reset") and hasattr(env, "_step"):
            patch_deprecated_methods(env)
        # Read the limits from the spec we already hold. This does not rely on
        # the env forwarding ``spec`` to its unwrapped core, and it uses the
        # canonical attribute rather than the deprecated ``timestep_limit`` alias.
        if spec.max_episode_steps is not None and not spec.tags.get('vnc'):
            from gym.wrappers.time_limit import TimeLimit
            env = TimeLimit(env,
                            max_episode_steps=spec.max_episode_steps,
                            max_episode_seconds=spec.max_episode_seconds)
        return env

    def all(self):
        """Return a view of every registered EnvSpec."""
        return self.env_specs.values()

    def spec(self, id):
        """Return the EnvSpec registered under ``id``.

        Raises:
            error.Error: ``id`` does not match ``env_id_re``.
            error.DeprecatedEnv: ``id`` is not registered, but other versions of
                the same env name are.
            error.UnregisteredEnv: no version of the env name is registered.
        """
        match = env_id_re.search(id)
        if not match:
            # NOTE(review): on Python 3 the encoded id renders as b'...' in this message.
            raise error.Error('Attempted to look up malformed environment ID: {}. (Currently all IDs must be of the form {}.)'.format(id.encode('utf-8'), env_id_re.pattern))

        try:
            return self.env_specs[id]
        except KeyError:
            # Parse the env name and check to see if it matches the non-version
            # part of a valid env (could also check the exact number here)
            env_name = match.group(1)
            matching_envs = [valid_env_name for valid_env_name, valid_env_spec in self.env_specs.items()
                             if env_name == valid_env_spec._env_name]
            if matching_envs:
                raise error.DeprecatedEnv('Env {} not found (valid versions include {})'.format(id, matching_envs))
            else:
                raise error.UnregisteredEnv('No registered env with id: {}'.format(id))

    def register(self, id, **kwargs):
        """Create an EnvSpec(id, **kwargs) and store it under ``id``.

        Raises:
            error.Error: ``id`` is already registered.
        """
        if id in self.env_specs:
            raise error.Error('Cannot re-register id: {}'.format(id))
        self.env_specs[id] = EnvSpec(id, **kwargs)
# A single process-wide registry backs the module-level convenience API below.
registry = EnvRegistry()


def register(id, **kwargs):
    """Register a new environment spec under ``id`` in the global registry."""
    return registry.register(id, **kwargs)


def make(id):
    """Build the environment registered as ``id`` using the global registry."""
    return registry.make(id)


def spec(id):
    """Look up the EnvSpec registered as ``id`` in the global registry."""
    return registry.spec(id)
# True until the compatibility warning below has been logged once.
warn_once = True


def patch_deprecated_methods(env):
    """
    Methods renamed from '_method' to 'method', render() no longer has 'close' parameter, close is a separate method.
    For backward compatibility, this makes it possible to work with unmodified environments.

    ``reset`` and ``step`` are always rebound to ``_reset`` and ``_step``.
    ``seed`` is rebound only when ``_seed`` exists, and ``render``/``close``
    are rebound only when ``_render`` exists. Otherwise the env's own methods
    are left in place. The warning is logged only on the first call (tracked
    by the module-level ``warn_once`` flag).
    """
    global warn_once
    if warn_once:
        logger.warn("Environment '%s' has deprecated methods. Compatibility code invoked." % str(type(env)))
        warn_once = False
    env.reset = env._reset
    env.step = env._step
    # Not every old-style env defines _seed; patching unconditionally would
    # raise AttributeError and abort environment creation.
    if hasattr(env, '_seed'):
        env.seed = env._seed
    if hasattr(env, '_render'):
        # Default mode to 'human' so the common no-argument env.render() still works.
        def render(mode='human'):
            return env._render(mode, close=False)

        def close():
            env._render("human", close=True)

        env.render = render
        env.close = close