from gym import logger
import numpy as np

from gym import error
from gym.utils import closer

env_closer = closer.Closer()


# Env-related abstractions

class Env(object):
    """The main OpenAI Gym class. It encapsulates an environment with
    arbitrary behind-the-scenes dynamics. An environment can be
    partially or fully observed.

    The main API methods that users of this class need to know are:

        step
        reset
        render
        close
        seed

    And set the following attributes:

        action_space: The Space object corresponding to valid actions
        observation_space: The Space object corresponding to valid observations
        reward_range: A tuple corresponding to the min and max possible rewards

    Note: a default reward range set to [-inf,+inf] already exists. Set it
    if you want a narrower range.

    The methods are accessed publicly as "step", "reset", etc. The
    non-underscored versions are wrapper methods to which we may add
    functionality over time.
    """

    # Set this in SOME subclasses
    metadata = {'render.modes': []}
    reward_range = (-np.inf, np.inf)
    spec = None

    # Set these in ALL subclasses
    action_space = None
    observation_space = None

    def step(self, action):
        """Run one timestep of the environment's dynamics. When end of
        episode is reached, you are responsible for calling `reset()`
        to reset this environment's state.

        Accepts an action and returns a tuple (observation, reward, done, info).

        Args:
            action (object): an action provided by the agent

        Returns:
            observation (object): agent's observation of the current environment
            reward (float) : amount of reward returned after previous action
            done (boolean): whether the episode has ended, in which case further step() calls will return undefined results
            info (dict): contains auxiliary diagnostic information (helpful for debugging, and sometimes learning)

        Raises:
            NotImplementedError: always, in this base class; subclasses
                must override this method.
        """
        raise NotImplementedError

    def reset(self):
        """Resets the state of the environment and returns an initial observation.

        Returns:
            observation (object): the initial observation.

        Raises:
            NotImplementedError: always, in this base class; subclasses
                must override this method.
        """
        raise NotImplementedError

    def render(self, mode='human'):
        """Renders the environment.

        The set of supported modes varies per environment. (And some
        environments do not support rendering at all.) By convention,
        if mode is:

        - human: render to the current display or terminal and
          return nothing. Usually for human consumption.
        - rgb_array: Return a numpy.ndarray with shape (x, y, 3),
          representing RGB values for an x-by-y pixel image, suitable
          for turning into a video.
        - ansi: Return a string (str) or StringIO.StringIO containing a
          terminal-style text representation. The text can include newlines
          and ANSI escape sequences (e.g. for colors).

        Note:
            Make sure that your class's metadata 'render.modes' key includes
            the list of supported modes. It's recommended to call super()
            in implementations to use the functionality of this method.

        Args:
            mode (str): the mode to render with

        Raises:
            NotImplementedError: always, in this base class.

        Example:

        class MyEnv(Env):
            metadata = {'render.modes': ['human', 'rgb_array']}

            def render(self, mode='human'):
                if mode == 'rgb_array':
                    return np.array(...) # return RGB frame suitable for video
                elif mode == 'human':
                    ... # pop up a window and render
                else:
                    super(MyEnv, self).render(mode=mode) # just raise an exception
        """
        raise NotImplementedError

    def close(self):
        """Override close in your subclass to perform any necessary cleanup.

        Environments will automatically close() themselves when
        garbage collected or when the program exits.

        The base implementation does nothing and returns None.
        """
        return

    def seed(self, seed=None):
        """Sets the seed for this env's random number generator(s).

        Note:
            Some environments use multiple pseudorandom number generators.
            We want to capture all such seeds used in order to ensure that
            there aren't accidental correlations between multiple generators.

        Returns:
            list: Returns the list of seeds used in this env's random
              number generators. The first value in the list should be the
              "main" seed, or the value which a reproducer should pass to
              'seed'. Often, the main seed equals the provided 'seed', but
              this won't be true if seed=None, for example.

              This base implementation cannot seed anything: it only logs
              a warning and returns None. Subclasses that own random number
              generators should override it and return the list described
              above.
        """
        logger.warn("Could not seed environment %s", self)
        return

    @property
    def unwrapped(self):
        """Completely unwrap this env.

        Returns:
            gym.Env: The base non-wrapped gym.Env instance
        """
        return self

    def broadcast(self, name, args=None):
        """Allows sending of a generic event to an environment subclass.
        Will also propagate through any Wrappers surrounding env.

        Args:
            name: identifier of the event, passed through to `event`
            args: optional payload of the event, passed through to `event`

        Returns:
            object: Returns the result of event implementation.
        """
        return self.event(name, args)

    def event(self, name, args=None):
        """Override this in an Env or Wrapper subclass to handle any specific
        event propagated through the stack.

        Args:
            name: identifier of the event
            args: optional payload of the event

        Returns:
            object: Subclasses can return any desired result or None.
              The base implementation ignores the event and returns None.
        """
        return None

    def __str__(self):
        # Include the registered spec id when one is attached to the env.
        if self.spec is None:
            return '<{} instance>'.format(type(self).__name__)
        else:
            return '<{}<{}>>'.format(type(self).__name__, self.spec.id)


# Space-related abstractions

class Space(object):
    """Defines the observation and action spaces, so you can write generic
    code that applies to any Env. For example, you can choose a random
    action.
    """

    def __init__(self, shape=None, dtype=None):
        """Store the shape (as a tuple) and dtype (as a numpy dtype).

        Either value may be None, in which case it is stored as None.
        """
        self.shape = None if shape is None else tuple(shape)
        self.dtype = None if dtype is None else np.dtype(dtype)

    def sample(self):
        """
        Uniformly randomly sample a random element of this space
        """
        raise NotImplementedError

    def contains(self, x):
        """
        Return boolean specifying if x is a valid
        member of this space
        """
        raise NotImplementedError

    def to_jsonable(self, sample_n):
        """Convert a batch of samples from this space to a JSONable data type."""
        # By default, assume identity is JSONable
        return sample_n

    def from_jsonable(self, sample_n):
        """Convert a JSONable data type to a batch of samples from this space."""
        # By default, assume identity is JSONable
        return sample_n

    def compatible(self, space):
        """
        Return boolean specifying if space is compatible with this Space
        (equal shape structure, ignoring bounds).

        This base implementation requires `space` to be of exactly the same
        type as this object and to have an equal `shape`; dtypes are not
        compared.

        NOTE(review): an earlier version of this docstring stated "None
        matches any Space", but this implementation returns False when
        `space` is None because the types differ. Confirm whether
        subclasses are expected to implement that rule.
        """
        # compare classes
        if type(self) != type(space):
            return False

        # TODO - compare dtypes?

        # compare shapes
        return self.shape == space.shape


# Module-wide flag: True until the first deprecation warning has been logged.
warn_once = True


def deprecated_warn_once(text):
    """Log `text` as a warning, but only on the first call in this process.

    The flag is shared by every caller, so once any deprecation warning has
    been emitted all later ones (whatever their text) are suppressed.
    """
    global warn_once
    if not warn_once:
        return
    warn_once = False
    logger.warn(text)


class Wrapper(Env):
    """Base class for objects that wrap an Env and forward the Env API to it.

    The wrapped environment is stored in `self.env`; its spaces, reward
    range and metadata are copied onto the wrapper at construction time.
    """

    # Wrapped environment; None until __init__ has run.
    env = None

    def __init__(self, env):
        """Wrap `env`.

        Raises:
            error.DoubleWrapperError: if a wrapper of the same class is
                already present somewhere in the wrapped stack.
        """
        self.env = env
        self.action_space = self.env.action_space
        self.observation_space = self.env.observation_space
        self.reward_range = self.env.reward_range
        self.metadata = self.env.metadata
        self._warn_double_wrap()

    @classmethod
    def class_name(cls):
        """Return the name of this wrapper class."""
        return cls.__name__

    def _warn_double_wrap(self):
        """Walk down the wrapper chain and raise if this class already appears in it."""
        env = self.env
        while True:
            if isinstance(env, Wrapper):
                if env.class_name() == self.class_name():
                    raise error.DoubleWrapperError("Attempted to double wrap with Wrapper: {}".format(self.__class__.__name__))
                env = env.env
            else:
                break

    def step(self, action):
        """Deprecated default `step`; subclasses are expected to override it.

        If the subclass defines the deprecated `_step`, that method is bound
        onto the instance as `step` (so later calls bypass this shim) and
        called. Otherwise the call is forwarded to the wrapped env.
        """
        if hasattr(self, "_step"):
            deprecated_warn_once("%s doesn't implement 'step' method, but it implements deprecated '_step' method." % type(self))
            self.step = self._step
            return self.step(action)
        else:
            deprecated_warn_once("%s doesn't implement 'step' method, " % type(self) +
                "which is required for wrappers derived directly from Wrapper. Deprecated default implementation is used.")
            return self.env.step(action)

    def reset(self, **kwargs):
        """Deprecated default `reset`; subclasses are expected to override it.

        If the subclass defines the deprecated `_reset`, that method is bound
        onto the instance as `reset` (so later calls bypass this shim) and
        called. Otherwise the call is forwarded to the wrapped env. Keyword
        arguments are passed through unchanged in both cases.
        """
        if hasattr(self, "_reset"):
            deprecated_warn_once("%s doesn't implement 'reset' method, but it implements deprecated '_reset' method." % type(self))
            self.reset = self._reset
            return self._reset(**kwargs)
        else:
            deprecated_warn_once("%s doesn't implement 'reset' method, " % type(self) +
                "which is required for wrappers derived directly from Wrapper. Deprecated default implementation is used.")
            return self.env.reset(**kwargs)

    def render(self, mode='human'):
        """Forward rendering to the wrapped env."""
        return self.env.render(mode)

    def close(self):
        """Close the wrapped env, if one has been set."""
        # Compare against the None sentinel rather than relying on
        # truthiness, so an env that happens to evaluate as falsy (e.g. one
        # defining __len__ or __bool__) is still closed.
        if self.env is not None:
            return self.env.close()

    def seed(self, seed=None):
        """Forward seeding to the wrapped env and return its result."""
        return self.env.seed(seed)

    def __str__(self):
        return '<{}{}>'.format(type(self).__name__, self.env)

    def __repr__(self):
        return str(self)

    @property
    def unwrapped(self):
        """The innermost env, obtained by asking the wrapped env to unwrap itself."""
        return self.env.unwrapped

    @property
    def spec(self):
        """The spec of the wrapped env."""
        return self.env.spec

    def broadcast(self, name, args=None):
        """Allows sending of a generic event to a Wrapper or Env subclass.
        Will propagate event through all Wrappers down to and including Env.
        If any layer's 'event' method implementation returns a non-None
        result, that result will be returned from the outer-most event call.

        Returns:
            object: Outer-most non-None result from event method implementations.
        """
        result = self.event(name, args)
        # Always propagate inward, even when this layer produced a result,
        # so that every layer gets to see the event.
        inner_result = self.env.broadcast(name, args)
        return inner_result if result is None else result


class ObservationWrapper(Wrapper):
    """Wrapper that passes every observation through `observation()`."""

    def step(self, action):
        """Step the wrapped env and transform the returned observation."""
        observation, reward, done, info = self.env.step(action)
        return self.observation(observation), reward, done, info

    def reset(self, **kwargs):
        """Reset the wrapped env and transform the initial observation."""
        observation = self.env.reset(**kwargs)
        return self.observation(observation)

    def observation(self, observation):
        """Transform an observation; subclasses are expected to override this.

        The default falls back to the deprecated `_observation` method,
        which the subclass must then provide.
        """
        deprecated_warn_once("%s doesn't implement 'observation' method. Maybe it implements deprecated '_observation' method." % type(self))
        return self._observation(observation)


class RewardWrapper(Wrapper):
    """Wrapper that passes every reward through `reward()`."""

    def reset(self, **kwargs):
        """Reset the wrapped env, forwarding any keyword arguments."""
        return self.env.reset(**kwargs)

    def step(self, action):
        """Step the wrapped env and transform the returned reward."""
        observation, reward, done, info = self.env.step(action)
        return observation, self.reward(reward), done, info

    def reward(self, reward):
        """Transform a reward; subclasses are expected to override this.

        The default falls back to the deprecated `_reward` method, which the
        subclass must then provide.
        """
        deprecated_warn_once("%s doesn't implement 'reward' method. Maybe it implements deprecated '_reward' method." % type(self))
        return self._reward(reward)


class ActionWrapper(Wrapper):
    """Wrapper that passes every action through `action()` before stepping."""

    def step(self, action):
        """Transform the action, then step the wrapped env with it."""
        action = self.action(action)
        return self.env.step(action)

    def reset(self, **kwargs):
        """Reset the wrapped env, forwarding any keyword arguments."""
        return self.env.reset(**kwargs)

    def action(self, action):
        """Transform an action; subclasses are expected to override this.

        The default falls back to the deprecated `_action` method, which the
        subclass must then provide.
        """
        deprecated_warn_once("%s doesn't implement 'action' method. Maybe it implements deprecated '_action' method." % type(self))
        return self._action(action)

    def reverse_action(self, action):
        """Invert `action()`; subclasses are expected to override this.

        The default falls back to the deprecated `_reverse_action` method,
        which the subclass must then provide.
        """
        deprecated_warn_once("%s doesn't implement 'reverse_action' method. Maybe it implements deprecated '_reverse_action' method." % type(self))
        return self._reverse_action(action)


class EventWrapper(Wrapper):
    """Wrapper intended for handling events; step and reset pass straight through."""

    def step(self, action):
        """Step the wrapped env unchanged."""
        return self.env.step(action)

    def reset(self, **kwargs):
        """Reset the wrapped env, forwarding any keyword arguments."""
        return self.env.reset(**kwargs)

    def event(self, name, args=None):
        """Handle an event; subclasses are expected to override this.

        The default logs a deprecation-style warning (once) and returns None.
        """
        deprecated_warn_once("%s doesn't implement 'event' method." % type(self))
        return None