# Mirror of https://github.com/Farama-Foundation/Gymnasium.git
# (synced 2025-08-19 13:32:03 +00:00; 107 lines, 3.8 KiB, Python)
import logging
|
|
import pkg_resources
|
|
import re
|
|
import sys
|
|
from gym import error
|
|
from gym.utils.atexit_utils import env_close_registry
|
|
|
|
# Logger for this module.
logger = logging.getLogger(name=__name__)

# This format is true today, but it's *not* an official spec.
# An ID looks like ``<name>-v<version>``: group 1 is the name (word
# characters, ``:`` and ``-``), group 2 is the version digits.
# NOTE(review): ``$`` also matches just before a trailing newline, so an ID
# such as "Foo-v0\n" passes this check; consider ``\Z`` if that matters
# (changing it would also change the pattern shown in error messages).
env_id_re = re.compile('^([\\w:-]+)-v(\\d+)$')
|
|
|
|
def load(name):
    """Resolve an entry-point string to the object it names.

    Args:
        name (str): A reference of the form ``"package.module:Attr"``. The
            attribute part may be dotted (``"pkg.mod:Outer.Inner"``). If the
            ``":Attr"`` part is omitted, the module itself is returned.

    Returns:
        The imported module, or the attribute looked up on it.

    Raises:
        ImportError: If the module cannot be imported, or an attribute along
            the dotted path does not exist.
    """
    # Resolve with importlib instead of pkg_resources: EntryPoint.load(False)
    # relies on the deprecated ``require`` argument, and pkg_resources itself
    # is a deprecated setuptools API.
    import importlib

    mod_name, _, attr_path = name.partition(':')
    attr_path = attr_path.strip()
    result = importlib.import_module(mod_name.strip())
    for attr in (attr_path.split('.') if attr_path else ()):
        try:
            result = getattr(result, attr)
        except AttributeError as exc:
            # pkg_resources reported missing attributes as ImportError; keep
            # that so existing callers' ``except ImportError`` still works.
            raise ImportError(str(exc))
    return result
|
|
|
|
class EnvSpec(object):
    """Describe one registered environment and how to construct it.

    A spec records the parameters used for official evaluations of an
    environment, together with the entry point and keyword arguments
    needed to instantiate it.

    Args:
        id (str): The official environment ID; must match ``env_id_re``.
        entry_point (Optional[str]): Entry-point string for the environment
            class (e.g. ``module.name:Class``). ``None`` marks the spec as
            deprecated, and ``make`` will refuse to build it.
        timestep_limit (int): The max number of timesteps per episode during training
        trials (int): The number of trials to average reward over
        reward_threshold: The reward threshold before the task is considered
            solved, or None
        kwargs (Optional[dict]): Keyword arguments passed to the environment class

    Attributes:
        id (str): The official environment ID
        timestep_limit (int): The max number of timesteps per episode in official evaluation
        trials (int): The number of trials run in official evaluation
        reward_threshold: The reward threshold, or None

    Raises:
        error.Error: If ``id`` does not match ``env_id_re``.
    """

    def __init__(self, id, entry_point=None, timestep_limit=1000, trials=100, reward_threshold=None, kwargs=None):
        # Reject malformed IDs before storing anything.
        if env_id_re.search(id) is None:
            raise error.Error('Attempted to register malformed environment ID: {}. (Currently all IDs must be of the form {}.)'.format(id, env_id_re.pattern))

        self.id = id

        # Public evaluation parameters.
        self.timestep_limit, self.trials, self.reward_threshold = timestep_limit, trials, reward_threshold

        # Construction parameters; private for now, may be exposed later.
        self._entry_point = entry_point
        self._kwargs = kwargs if kwargs is not None else {}

    def make(self):
        """Instantiate the environment described by this spec.

        Returns:
            The new environment. Its ``spec`` attribute is set to this
            EnvSpec, and its ``env_exit_id`` attribute is set to the value
            returned by ``env_close_registry.register``.

        Raises:
            error.Error: If this spec has no entry point (deprecated env).
        """
        if self._entry_point is None:
            raise error.Error('Attempting to make deprecated env {}. (HINT: is there a newer registered version of this env?)'.format(self.id))

        env = load(self._entry_point)(**self._kwargs)

        # Let the environment know which spec produced it.
        env.spec = self
        # Hand the env to the atexit close registry and keep the returned handle.
        env.env_exit_id = env_close_registry.register(env)
        return env

    def __repr__(self):
        return 'EnvSpec({})'.format(self.id)
|
|
|
|
|
|
class EnvRegistry(object):
    """Register an env by ID. IDs remain stable over time and are
    guaranteed to resolve to the same environment dynamics (or be
    desupported). The goal is that results on a particular environment
    should always be comparable, and not depend on the version of the
    code that was running.
    """

    def __init__(self):
        # Maps environment ID -> EnvSpec.
        self.env_specs = {}

    def make(self, id):
        """Look up the spec registered under ``id`` and instantiate it.

        Raises:
            error.Error: If ``id`` is malformed, or its spec is deprecated.
            error.UnregisteredEnv: If no spec is registered under ``id``.
        """
        logger.info('Making new env: %s', id)
        spec = self.spec(id)
        return spec.make()

    def all(self):
        """Return a view of every registered EnvSpec."""
        return self.env_specs.values()

    def spec(self, id):
        """Return the EnvSpec registered under ``id``.

        Raises:
            error.Error: If ``id`` does not match ``env_id_re``.
            error.UnregisteredEnv: If ``id`` is well-formed but not registered.
        """
        match = env_id_re.search(id)
        if not match:
            # Format ``id`` directly, as EnvSpec.__init__ does. Formatting
            # ``id.encode('utf-8')`` rendered the ID as ``b'...'`` on Python 3.
            raise error.Error('Attempted to look up malformed environment ID: {}. (Currently all IDs must be of the form {}.)'.format(id, env_id_re.pattern))

        try:
            return self.env_specs[id]
        except KeyError:
            raise error.UnregisteredEnv('No registered env with id: {}'.format(id))

    def register(self, id, **kwargs):
        """Build an EnvSpec for ``id`` from ``kwargs`` and store it.

        Raises:
            error.Error: If ``id`` is already registered. Errors raised while
                constructing the EnvSpec (e.g. a malformed ID) propagate.
        """
        if id in self.env_specs:
            raise error.Error('Cannot re-register id: {}'.format(id))
        self.env_specs[id] = EnvSpec(id, **kwargs)
|
|
|
|
# A single process-wide registry. The module-level shortcuts below are bound
# methods of it, so they all operate on the same set of specs.
registry = EnvRegistry()
register, make, spec = registry.register, registry.make, registry.spec
|