[WIP] Start adding Filter API (#329)

Expand Wrapper API
This commit is contained in:
Greg Brockman
2016-09-04 00:38:03 -07:00
committed by GitHub
parent 1452dc3ca2
commit f30ff469b8
2 changed files with 61 additions and 16 deletions

View File

@@ -31,7 +31,7 @@ del logger_setup
sanity_check_dependencies()
from gym.core import Env, Space, Wrapper
from gym.core import Env, Space, Wrapper, ObservationWrapper, ActionWrapper, RewardWrapper
from gym.envs import make, spec
from gym.scoreboard.api import upload

View File

@@ -122,14 +122,16 @@ class Env(object):
return observation, reward, done, info
def reset(self):
"""
Resets the state of the environment and returns an initial observation.
"""Resets the state of the environment and returns an initial
observation. Will call 'configure()' if not already called.
Returns:
observation (object): the initial observation of the space. (Initial reward is assumed to be 0.)
Returns: observation (object): the initial observation of the
space. (Initial reward is assumed to be 0.)
"""
if self.metadata.get('configure.required') and not self._configured:
raise error.Error("{} requires calling 'configure()' before 'reset()'".format(self))
raise error.Error("{} requires manually calling 'configure()' before 'reset()'".format(self))
elif not self._configured:
self.configure()
self.monitor._before_reset()
observation = self._reset()
@@ -236,7 +238,7 @@ class Env(object):
self._configured = True
try:
return self._configure(*args, **kwargs)
self._configure(*args, **kwargs)
except TypeError as e:
# It can be confusing if you have the wrong environment
# and try calling with unsupported arguments, since your
@@ -301,14 +303,22 @@ class Space(object):
return sample_n
class Wrapper(Env):
def __init__(self, env):
# Make sure self.env is always defined, even if things break
# early.
env = None
def __init__(self, env=None):
self.env = env
self.metadata = env.metadata
self.action_space = env.action_space
self.observation_space = env.observation_space
self.reward_range = env.reward_range
self._spec = env.spec
self._unwrapped = env.unwrapped
# Merge with the base metadata
metadata = self.metadata
self.metadata = self.env.metadata.copy()
self.metadata.update(metadata)
self.action_space = self.env.action_space
self.observation_space = self.env.observation_space
self.reward_range = self.env.reward_range
self._spec = self.env.spec
self._unwrapped = self.env.unwrapped
def _step(self, action):
return self.env.step(action)
@@ -317,9 +327,13 @@ class Wrapper(Env):
return self.env.reset()
def _render(self, mode='human', close=False):
if self.env is None:
return
return self.env.render(mode, close)
def _close(self):
if self.env is None:
return
return self.env.close()
def _configure(self, *args, **kwargs):
@@ -329,7 +343,10 @@ class Wrapper(Env):
return self.env.seed(seed)
def __str__(self):
return '<{}{} instance>'.format(type(self).__name__, self.env)
return '<{}{}>'.format(type(self).__name__, self.env)
def __repr__(self):
return str(self)
@property
def spec(self):
@@ -340,6 +357,34 @@ class Wrapper(Env):
@spec.setter
def spec(self, spec):
# Won't have an env attr while in the __new__ from gym.Env
if hasattr(self, 'env'):
if self.env is not None:
self.env.spec = spec
self._spec = spec
class ObservationWrapper(Wrapper):
    """A Wrapper that filters every observation through ``_observation()``.

    Subclasses implement ``_observation`` to transform observations; reset
    and step results are passed through it before being returned.
    """

    def _reset(self):
        # Transform the initial observation the same way as step observations.
        return self._observation(self.env.reset())

    def _step(self, action):
        obs, reward, done, info = self.env.step(action)
        filtered = self._observation(obs)
        return filtered, reward, done, info

    def _observation(self, observation):
        # Override point: map a raw observation to the transformed one.
        raise NotImplementedError
class RewardWrapper(Wrapper):
    """A Wrapper that filters every reward through ``_reward()``.

    Subclasses implement ``_reward`` to transform the reward returned by
    the wrapped environment's step; observations pass through untouched.
    """

    def _step(self, action):
        obs, reward, done, info = self.env.step(action)
        shaped = self._reward(reward)
        return obs, shaped, done, info

    def _reward(self, reward):
        # Override point: map a raw reward to the transformed one.
        raise NotImplementedError
class ActionWrapper(Wrapper):
    """A Wrapper that filters every action through ``_action()``.

    Subclasses implement ``_action`` to transform the caller's action
    before it is forwarded to the wrapped environment's step.
    """

    def _step(self, action):
        transformed = self._action(action)
        return self.env.step(transformed)

    def _action(self, action):
        # Override point: map the caller's action to the one the inner
        # environment should receive.
        raise NotImplementedError