removed calls to reset from init (#2394)

* removed all calls to reset

* passing tests

* fix off-by-one error

* revert

* merge master into branch

* add OrderEnforcing Wrapper

* add orderenforcing to the docs

* add option for disabling

* add argument to EnvSpec
Ahmed Omar
2021-09-16 16:16:49 +02:00
committed by GitHub
parent e212043a93
commit 2754d9737e
9 changed files with 47 additions and 13 deletions

@@ -141,4 +141,15 @@ Lastly the `name_prefix` allows you to customize the name of the videos.
`TimeLimit(env, max_episode_steps)` [text]
* Needs review (including for good assertion messages and test coverage)
`OrderEnforcing(env)` [text]
`OrderEnforcing` is a lightweight wrapper that throws an exception when `env.step()` is called before `env.reset()`. The wrapper is enabled by default for environment specs without `max_episode_steps` and can be disabled by passing `order_enforce=False` at registration:
```python3
register(
id="CustomEnv-v1",
entry_point="...",
order_enforce=False,
)
```
Some sort of vector environment conversion wrapper needs to be added here; this will be figured out after the API is changed.
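To make the ordering contract concrete, here is a minimal sketch of the wrapper in action. The environment is wrapped by hand because `gym.make` only applies `OrderEnforcing` automatically when the spec has no `max_episode_steps` (CartPole has one, so `gym.make` gives it `TimeLimit` instead):
```python3
import gym
from gym.wrappers.order_enforcing import OrderEnforcing

env = OrderEnforcing(gym.make("CartPole-v1"))

try:
    env.step(env.action_space.sample())  # stepping before reset
except AssertionError as e:
    print(e)  # Cannot call env.step() before calling reset()

env.reset()
env.step(env.action_space.sample())  # fine once reset has been called
```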

@@ -142,8 +142,6 @@ class BipedalWalker(gym.Env, EzPickle):
    categoryBits=0x0001,
)
- self.reset()
high = np.array([np.inf] * 24).astype(np.float32)
self.action_space = spaces.Box(
    np.array([-1, -1, -1, -1]).astype(np.float32),

@@ -117,8 +117,6 @@ class LunarLander(gym.Env, EzPickle):
# Nop, fire left engine, main engine, right engine
self.action_space = spaces.Discrete(4)
- self.reset()
def seed(self, seed=None):
    self.np_random, seed = seeding.np_random(seed)
    return [seed]

@@ -85,7 +85,6 @@ class Continuous_MountainCarEnv(gym.Env):
)
self.seed()
- self.reset()
def seed(self, seed=None):
    self.np_random, seed = seeding.np_random(seed)

@@ -40,6 +40,7 @@ class EnvSpec(object):
reward_threshold (Optional[int]): The reward threshold before the task is considered solved
nondeterministic (bool): Whether this environment is non-deterministic even after seeding
max_episode_steps (Optional[int]): The maximum number of steps that an episode can consist of
+ order_enforce (bool): Whether to wrap the environment in an OrderEnforcing wrapper
kwargs (dict): The kwargs to pass to the environment class
"""
@@ -51,6 +52,7 @@ class EnvSpec(object):
reward_threshold=None,
nondeterministic=False,
max_episode_steps=None,
+ order_enforce=True,
kwargs=None,
):
    self.id = id
@@ -58,6 +60,7 @@ class EnvSpec(object):
self.reward_threshold = reward_threshold
self.nondeterministic = nondeterministic
self.max_episode_steps = max_episode_steps
+ self.order_enforce = order_enforce
self._kwargs = {} if kwargs is None else kwargs
match = env_id_re.search(id)
@@ -77,8 +80,10 @@ class EnvSpec(object):
        self.id
    )
)
_kwargs = self._kwargs.copy()
_kwargs.update(kwargs)
if callable(self.entry_point):
    env = self.entry_point(**_kwargs)
else:
@@ -89,7 +94,15 @@ class EnvSpec(object):
spec = copy.deepcopy(self)
spec._kwargs = _kwargs
env.unwrapped.spec = spec
+ if env.spec.max_episode_steps is not None:
+     from gym.wrappers.time_limit import TimeLimit
+     env = TimeLimit(env, max_episode_steps=env.spec.max_episode_steps)
+ else:
+     if self.order_enforce:
+         from gym.wrappers.order_enforcing import OrderEnforcing
+         env = OrderEnforcing(env)
return env

def __repr__(self):
@@ -115,10 +128,6 @@ class EnvRegistry(object):
logger.info("Making new env: %s", path)
spec = self.spec(path)
env = spec.make(**kwargs)
- if env.spec.max_episode_steps is not None:
-     from gym.wrappers.time_limit import TimeLimit
-     env = TimeLimit(env, max_episode_steps=env.spec.max_episode_steps)
return env

def all(self):
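Taken together, these two hunks move the wrapping decision from `EnvRegistry.make` into `EnvSpec.make`, and an environment now receives at most one of the two wrappers. A paraphrase of the new logic (the helper name `_apply_wrappers` is illustrative, not part of the diff):
```python3
def _apply_wrappers(env, spec):
    # TimeLimit takes precedence: an env with a step limit is never
    # additionally wrapped in OrderEnforcing.
    if spec.max_episode_steps is not None:
        from gym.wrappers.time_limit import TimeLimit
        return TimeLimit(env, max_episode_steps=spec.max_episode_steps)
    if spec.order_enforce:
        from gym.wrappers.order_enforcing import OrderEnforcing
        return OrderEnforcing(env)
    return env
```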

@@ -86,8 +86,6 @@ class BlackjackEnv(gym.Env):
# Flag for full agreement with the (Sutton and Barto, 2018) definition. Overrides self.natural
self.sab = sab
- # Start the first game
- self.reset()
def seed(self, seed=None):
    self.np_random, seed = seeding.np_random(seed)

@@ -300,10 +300,14 @@ def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True) -> None:
# ============= Check the spaces (observation and action) ================
_check_spaces(env)
# Define aliases for convenience
observation_space = env.observation_space
action_space = env.action_space
+ try:
+     env.step(env.action_space.sample())
+ except AssertionError as e:
+     assert str(e) == "Cannot call env.step() before calling reset()"
# Warn the user if needed.
# A warning means that the environment may run but not work properly with popular RL libraries.
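The new `try`/`except` makes `check_env` exercise the ordering contract: it steps before any reset and verifies the expected assertion message. A usage sketch, wrapping explicitly so the assertion is guaranteed to fire (assuming the unwrapped env otherwise passes the checker):
```python3
import gym
from gym.utils.env_checker import check_env
from gym.wrappers.order_enforcing import OrderEnforcing

# Strip CartPole's TimeLimit and enforce ordering ourselves.
env = OrderEnforcing(gym.make("CartPole-v1").unwrapped)
check_env(env, warn=True, skip_render_check=True)
```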

@@ -0,0 +1,16 @@
import gym


class OrderEnforcing(gym.Wrapper):
    def __init__(self, env):
        super(OrderEnforcing, self).__init__(env)
        self._has_reset = False

    def step(self, action):
        assert self._has_reset, "Cannot call env.step() before calling reset()"
        observation, reward, done, info = self.env.step(action)
        return observation, reward, done, info

    def reset(self, **kwargs):
        self._has_reset = True
        return self.env.reset(**kwargs)
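End to end, registering an environment without `max_episode_steps` now yields an order-enforced env from `gym.make`. A sketch with a hypothetical entry point (`my_pkg.envs:MyEnv` does not exist; substitute your own class):
```python3
import gym
from gym.envs.registration import register

register(
    id="MyEnv-v0",
    entry_point="my_pkg.envs:MyEnv",  # hypothetical module path
)

env = gym.make("MyEnv-v0")
print(type(env).__name__)  # OrderEnforcing, since no max_episode_steps was set
```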

@@ -11,6 +11,7 @@ from gym.wrappers import (
def test_record_video_using_default_trigger():
    env = gym.make("CartPole-v1")
    env = gym.wrappers.RecordVideo(env, "videos")
    env.reset()
@@ -66,7 +67,7 @@ def test_record_video_within_vector():
_, _, _, infos = envs.step(envs.action_space.sample())
for info in infos:
    if "episode" in info.keys():
-         print(f"i, episode_reward={info['episode']['r']}")
+         print(f"episode_reward={info['episode']['r']}")
        break
assert os.path.isdir("videos")
mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")]