From f30ff469b8497a7dbbccdb1bfa3b3638536d21c7 Mon Sep 17 00:00:00 2001 From: Greg Brockman Date: Sun, 4 Sep 2016 00:38:03 -0700 Subject: [PATCH] [WIP] Start adding Filter API (#329) Expand Wrapper API --- gym/__init__.py | 2 +- gym/core.py | 75 +++++++++++++++++++++++++++++++++++++++---------- 2 files changed, 61 insertions(+), 16 deletions(-) diff --git a/gym/__init__.py b/gym/__init__.py index 8fc7c62ab..25989007c 100644 --- a/gym/__init__.py +++ b/gym/__init__.py @@ -31,7 +31,7 @@ del logger_setup sanity_check_dependencies() -from gym.core import Env, Space, Wrapper +from gym.core import Env, Space, Wrapper, ObservationWrapper, ActionWrapper, RewardWrapper from gym.envs import make, spec from gym.scoreboard.api import upload diff --git a/gym/core.py b/gym/core.py index a5b59f798..4d06bed34 100644 --- a/gym/core.py +++ b/gym/core.py @@ -122,14 +122,16 @@ class Env(object): return observation, reward, done, info def reset(self): - """ - Resets the state of the environment and returns an initial observation. + """Resets the state of the environment and returns an initial + observation. Will call 'configure()' if not already called. - Returns: - observation (object): the initial observation of the space. (Initial reward is assumed to be 0.) + Returns: observation (object): the initial observation of the + space. (Initial reward is assumed to be 0.) """ if self.metadata.get('configure.required') and not self._configured: - raise error.Error("{} requires calling 'configure()' before 'reset()'".format(self)) + raise error.Error("{} requires manually calling 'configure()' before 'reset()'".format(self)) + elif not self._configured: + self.configure() self.monitor._before_reset() observation = self._reset() @@ -236,7 +238,7 @@ class Env(object): self._configured = True try: - return self._configure(*args, **kwargs) + self._configure(*args, **kwargs) except TypeError as e: # It can be confusing if you have the wrong environment # and try calling with unsupported arguments, since your @@ -301,14 +303,22 @@ class Space(object): return sample_n class Wrapper(Env): - def __init__(self, env): + # Make sure self.env is always defined, even if things break + # early. + env = None + + def __init__(self, env=None): self.env = env - self.metadata = env.metadata - self.action_space = env.action_space - self.observation_space = env.observation_space - self.reward_range = env.reward_range - self._spec = env.spec - self._unwrapped = env.unwrapped + # Merge with the base metadata + metadata = self.metadata + self.metadata = self.env.metadata.copy() + self.metadata.update(metadata) + + self.action_space = self.env.action_space + self.observation_space = self.env.observation_space + self.reward_range = self.env.reward_range + self._spec = self.env.spec + self._unwrapped = self.env.unwrapped def _step(self, action): return self.env.step(action) @@ -317,9 +327,13 @@ class Wrapper(Env): return self.env.reset() def _render(self, mode='human', close=False): + if self.env is None: + return return self.env.render(mode, close) def _close(self): + if self.env is None: + return return self.env.close() def _configure(self, *args, **kwargs): @@ -329,7 +343,10 @@ class Wrapper(Env): return self.env.seed(seed) def __str__(self): - return '<{}{} instance>'.format(type(self).__name__, self.env) + return '<{}{}>'.format(type(self).__name__, self.env) + + def __repr__(self): + return str(self) @property def spec(self): @@ -340,6 +357,34 @@ class Wrapper(Env): @spec.setter def spec(self, spec): # Won't have an env attr while in the __new__ from gym.Env - if hasattr(self, 'env'): + if self.env is not None: self.env.spec = spec self._spec = spec + +class ObservationWrapper(Wrapper): + def _reset(self): + observation = self.env.reset() + return self._observation(observation) + + def _step(self, action): + observation, reward, done, info = self.env.step(action) + return self._observation(observation), reward, done, info + + def _observation(self, observation): + raise NotImplementedError + +class RewardWrapper(Wrapper): + def _step(self, action): + observation, reward, done, info = self.env.step(action) + return observation, self._reward(reward), done, info + + def _reward(self, reward): + raise NotImplementedError + +class ActionWrapper(Wrapper): + def _step(self, action): + action = self._action(action) + return self.env.step(action) + + def _action(self, action): + raise NotImplementedError