[WIP] Start adding Filter API (#329)

Expand Wrapper API
This commit is contained in:
Greg Brockman
2016-09-04 00:38:03 -07:00
committed by GitHub
parent 1452dc3ca2
commit f30ff469b8
2 changed files with 61 additions and 16 deletions

View File

@@ -31,7 +31,7 @@ del logger_setup
sanity_check_dependencies() sanity_check_dependencies()
from gym.core import Env, Space, Wrapper from gym.core import Env, Space, Wrapper, ObservationWrapper, ActionWrapper, RewardWrapper
from gym.envs import make, spec from gym.envs import make, spec
from gym.scoreboard.api import upload from gym.scoreboard.api import upload

View File

@@ -122,14 +122,16 @@ class Env(object):
return observation, reward, done, info return observation, reward, done, info
def reset(self): def reset(self):
""" """Resets the state of the environment and returns an initial
Resets the state of the environment and returns an initial observation. observation. Will call 'configure()' if not already called.
Returns: Returns: observation (object): the initial observation of the
observation (object): the initial observation of the space. (Initial reward is assumed to be 0.) space. (Initial reward is assumed to be 0.)
""" """
if self.metadata.get('configure.required') and not self._configured: if self.metadata.get('configure.required') and not self._configured:
raise error.Error("{} requires calling 'configure()' before 'reset()'".format(self)) raise error.Error("{} requires manually calling 'configure()' before 'reset()'".format(self))
elif not self._configured:
self.configure()
self.monitor._before_reset() self.monitor._before_reset()
observation = self._reset() observation = self._reset()
@@ -236,7 +238,7 @@ class Env(object):
self._configured = True self._configured = True
try: try:
return self._configure(*args, **kwargs) self._configure(*args, **kwargs)
except TypeError as e: except TypeError as e:
# It can be confusing if you have the wrong environment # It can be confusing if you have the wrong environment
# and try calling with unsupported arguments, since your # and try calling with unsupported arguments, since your
@@ -301,14 +303,22 @@ class Space(object):
return sample_n return sample_n
class Wrapper(Env): class Wrapper(Env):
def __init__(self, env): # Make sure self.env is always defined, even if things break
# early.
env = None
def __init__(self, env=None):
self.env = env self.env = env
self.metadata = env.metadata # Merge with the base metadata
self.action_space = env.action_space metadata = self.metadata
self.observation_space = env.observation_space self.metadata = self.env.metadata.copy()
self.reward_range = env.reward_range self.metadata.update(metadata)
self._spec = env.spec
self._unwrapped = env.unwrapped self.action_space = self.env.action_space
self.observation_space = self.env.observation_space
self.reward_range = self.env.reward_range
self._spec = self.env.spec
self._unwrapped = self.env.unwrapped
def _step(self, action): def _step(self, action):
return self.env.step(action) return self.env.step(action)
@@ -317,9 +327,13 @@ class Wrapper(Env):
return self.env.reset() return self.env.reset()
def _render(self, mode='human', close=False): def _render(self, mode='human', close=False):
if self.env is None:
return
return self.env.render(mode, close) return self.env.render(mode, close)
def _close(self): def _close(self):
if self.env is None:
return
return self.env.close() return self.env.close()
def _configure(self, *args, **kwargs): def _configure(self, *args, **kwargs):
@@ -329,7 +343,10 @@ class Wrapper(Env):
return self.env.seed(seed) return self.env.seed(seed)
def __str__(self): def __str__(self):
return '<{}{} instance>'.format(type(self).__name__, self.env) return '<{}{}>'.format(type(self).__name__, self.env)
def __repr__(self):
return str(self)
@property @property
def spec(self): def spec(self):
@@ -340,6 +357,34 @@ class Wrapper(Env):
@spec.setter @spec.setter
def spec(self, spec): def spec(self, spec):
# Won't have an env attr while in the __new__ from gym.Env # Won't have an env attr while in the __new__ from gym.Env
if hasattr(self, 'env'): if self.env is not None:
self.env.spec = spec self.env.spec = spec
self._spec = spec self._spec = spec
class ObservationWrapper(Wrapper):
    """Wrapper base class that post-processes observations.

    Every observation coming back from ``self.env`` (via ``reset`` or
    ``step``) is passed through ``_observation`` before being returned.
    Subclasses supply the actual transformation by overriding
    ``_observation``.
    """

    def _reset(self):
        """Reset the wrapped env and return its transformed observation."""
        return self._observation(self.env.reset())

    def _step(self, action):
        """Step the wrapped env; only the observation is transformed.

        Returns:
            The ``(observation, reward, done, info)`` tuple from the wrapped
            env with the observation replaced by ``_observation(observation)``.
        """
        raw_observation, reward, done, info = self.env.step(action)
        transformed = self._observation(raw_observation)
        return transformed, reward, done, info

    def _observation(self, observation):
        """Transform a single observation; must be overridden by subclasses."""
        raise NotImplementedError
class RewardWrapper(Wrapper):
    """Wrapper base class that post-processes rewards.

    The reward returned by ``self.env.step`` is passed through ``_reward``
    before being handed back to the caller. Subclasses supply the actual
    transformation by overriding ``_reward``.
    """

    def _step(self, action):
        """Step the wrapped env; only the reward is transformed.

        Returns:
            The ``(observation, reward, done, info)`` tuple from the wrapped
            env with the reward replaced by ``_reward(reward)``.
        """
        observation, raw_reward, done, info = self.env.step(action)
        adjusted = self._reward(raw_reward)
        return observation, adjusted, done, info

    def _reward(self, reward):
        """Transform a single reward; must be overridden by subclasses."""
        raise NotImplementedError
class ActionWrapper(Wrapper):
    """Wrapper base class that pre-processes actions.

    Each action is passed through ``_action`` before being forwarded to
    ``self.env.step``. Subclasses supply the actual transformation by
    overriding ``_action``.
    """

    def _step(self, action):
        """Transform ``action`` and forward it to the wrapped env.

        Returns:
            Whatever ``self.env.step`` returns for the transformed action,
            unmodified.
        """
        return self.env.step(self._action(action))

    def _action(self, action):
        """Transform a single action; must be overridden by subclasses."""
        raise NotImplementedError