mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-09-01 10:27:43 +00:00
removed calls to reset from init (#2394)
* removed all calls to reset * passing tests * fix off-by-one error * revert * merge master into branch * add OrderEnforcing Wrapper * add orderenforcing to the docs * add option for disabling * add argument to EnvSpec
This commit is contained in:
@@ -141,4 +141,15 @@ Lastly the `name_prefix` allows you to customize the name of the videos.
|
||||
`TimeLimit(env, max_episode_steps)` [text]
|
||||
* Needs review (including for good assertion messages and test coverage)
|
||||
|
||||
`OrderEnforcing(env)` [text]
|
||||
|
||||
`OrderEnforcing` is a light-weight wrapper that throws an exception when `env.step()` is called before `env.reset()`, the wrapper is enabled by default for environment specs without `max_episode_steps` and can be disabled by passing `order_enforce=False` like:
|
||||
```python3
|
||||
register(
|
||||
id="CustomEnv-v1",
|
||||
entry_point="...",
|
||||
order_enforce=False,
|
||||
)
|
||||
```
|
||||
|
||||
Some sort of vector environment conversion wrapper needs to be added here, this will be figured out after the API is changed.
|
||||
|
@@ -142,8 +142,6 @@ class BipedalWalker(gym.Env, EzPickle):
|
||||
categoryBits=0x0001,
|
||||
)
|
||||
|
||||
self.reset()
|
||||
|
||||
high = np.array([np.inf] * 24).astype(np.float32)
|
||||
self.action_space = spaces.Box(
|
||||
np.array([-1, -1, -1, -1]).astype(np.float32),
|
||||
|
@@ -117,8 +117,6 @@ class LunarLander(gym.Env, EzPickle):
|
||||
# Nop, fire left engine, main engine, right engine
|
||||
self.action_space = spaces.Discrete(4)
|
||||
|
||||
self.reset()
|
||||
|
||||
def seed(self, seed=None):
|
||||
self.np_random, seed = seeding.np_random(seed)
|
||||
return [seed]
|
||||
|
@@ -85,7 +85,6 @@ class Continuous_MountainCarEnv(gym.Env):
|
||||
)
|
||||
|
||||
self.seed()
|
||||
self.reset()
|
||||
|
||||
def seed(self, seed=None):
|
||||
self.np_random, seed = seeding.np_random(seed)
|
||||
|
@@ -40,6 +40,7 @@ class EnvSpec(object):
|
||||
reward_threshold (Optional[int]): The reward threshold before the task is considered solved
|
||||
nondeterministic (bool): Whether this environment is non-deterministic even after seeding
|
||||
max_episode_steps (Optional[int]): The maximum number of steps that an episode can consist of
|
||||
order_enforce (Optional[int]): Whether to wrap the environment in an orderEnforcing wrapper
|
||||
kwargs (dict): The kwargs to pass to the environment class
|
||||
|
||||
"""
|
||||
@@ -51,6 +52,7 @@ class EnvSpec(object):
|
||||
reward_threshold=None,
|
||||
nondeterministic=False,
|
||||
max_episode_steps=None,
|
||||
order_enforce=True,
|
||||
kwargs=None,
|
||||
):
|
||||
self.id = id
|
||||
@@ -58,6 +60,7 @@ class EnvSpec(object):
|
||||
self.reward_threshold = reward_threshold
|
||||
self.nondeterministic = nondeterministic
|
||||
self.max_episode_steps = max_episode_steps
|
||||
self.order_enforce = order_enforce
|
||||
self._kwargs = {} if kwargs is None else kwargs
|
||||
|
||||
match = env_id_re.search(id)
|
||||
@@ -77,8 +80,10 @@ class EnvSpec(object):
|
||||
self.id
|
||||
)
|
||||
)
|
||||
|
||||
_kwargs = self._kwargs.copy()
|
||||
_kwargs.update(kwargs)
|
||||
|
||||
if callable(self.entry_point):
|
||||
env = self.entry_point(**_kwargs)
|
||||
else:
|
||||
@@ -89,7 +94,15 @@ class EnvSpec(object):
|
||||
spec = copy.deepcopy(self)
|
||||
spec._kwargs = _kwargs
|
||||
env.unwrapped.spec = spec
|
||||
if env.spec.max_episode_steps is not None:
|
||||
from gym.wrappers.time_limit import TimeLimit
|
||||
|
||||
env = TimeLimit(env, max_episode_steps=env.spec.max_episode_steps)
|
||||
else:
|
||||
if self.order_enforce:
|
||||
from gym.wrappers.order_enforcing import OrderEnforcing
|
||||
|
||||
env = OrderEnforcing(env)
|
||||
return env
|
||||
|
||||
def __repr__(self):
|
||||
@@ -115,10 +128,6 @@ class EnvRegistry(object):
|
||||
logger.info("Making new env: %s", path)
|
||||
spec = self.spec(path)
|
||||
env = spec.make(**kwargs)
|
||||
if env.spec.max_episode_steps is not None:
|
||||
from gym.wrappers.time_limit import TimeLimit
|
||||
|
||||
env = TimeLimit(env, max_episode_steps=env.spec.max_episode_steps)
|
||||
return env
|
||||
|
||||
def all(self):
|
||||
|
@@ -86,8 +86,6 @@ class BlackjackEnv(gym.Env):
|
||||
|
||||
# Flag for full agreement with the (Sutton and Barto, 2018) definition. Overrides self.natural
|
||||
self.sab = sab
|
||||
# Start the first game
|
||||
self.reset()
|
||||
|
||||
def seed(self, seed=None):
|
||||
self.np_random, seed = seeding.np_random(seed)
|
||||
|
@@ -300,10 +300,14 @@ def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True) -
|
||||
|
||||
# ============= Check the spaces (observation and action) ================
|
||||
_check_spaces(env)
|
||||
|
||||
# Define aliases for convenience
|
||||
observation_space = env.observation_space
|
||||
action_space = env.action_space
|
||||
try:
|
||||
env.step(env.action_space.sample())
|
||||
|
||||
except AssertionError as e:
|
||||
assert str(e) == "Cannot call env.step() before calling reset()"
|
||||
|
||||
# Warn the user if needed.
|
||||
# A warning means that the environment may run but not work properly with popular RL libraries.
|
||||
|
16
gym/wrappers/order_enforcing.py
Normal file
16
gym/wrappers/order_enforcing.py
Normal file
@@ -0,0 +1,16 @@
|
||||
import gym
|
||||
|
||||
|
||||
class OrderEnforcing(gym.Wrapper):
|
||||
def __init__(self, env):
|
||||
super(OrderEnforcing, self).__init__(env)
|
||||
self._has_reset = False
|
||||
|
||||
def step(self, action):
|
||||
assert self._has_reset, "Cannot call env.step() before calling reset()"
|
||||
observation, reward, done, info = self.env.step(action)
|
||||
return observation, reward, done, info
|
||||
|
||||
def reset(self, **kwargs):
|
||||
self._has_reset = True
|
||||
return self.env.reset(**kwargs)
|
@@ -11,6 +11,7 @@ from gym.wrappers import (
|
||||
|
||||
|
||||
def test_record_video_using_default_trigger():
|
||||
|
||||
env = gym.make("CartPole-v1")
|
||||
env = gym.wrappers.RecordVideo(env, "videos")
|
||||
env.reset()
|
||||
@@ -66,7 +67,7 @@ def test_record_video_within_vector():
|
||||
_, _, _, infos = envs.step(envs.action_space.sample())
|
||||
for info in infos:
|
||||
if "episode" in info.keys():
|
||||
print(f"i, episode_reward={info['episode']['r']}")
|
||||
print(f"episode_reward={info['episode']['r']}")
|
||||
break
|
||||
assert os.path.isdir("videos")
|
||||
mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")]
|
||||
|
Reference in New Issue
Block a user