Seeding update (#2422)

* Ditch most of the seeding.py and replace np_random with the numpy default_rng. Let's see if tests pass

* Updated a bunch of RNG calls from the RandomState API to Generator API

* black; didn't expect that, did ya?

* Undo a typo

* blaaack

* More typo fixes

* Fixed setting/getting state in multidiscrete spaces

* Fix typo, fix a test to work with the new sampling

* Correctly (?) pass the randomly generated seed if np_random is called with None as seed

* Convert the Discrete sample to a python int (as opposed to np.int64)

* Remove some redundant imports

* First version of the compatibility layer for old-style RNG. Mainly to trigger tests.

* Removed redundant f-strings

* Style fixes, removing unused imports

* Try to make tests pass by removing atari from the dockerfile

* Try to make tests pass by removing atari from the setup

* Try to make tests pass by removing atari from the setup

* Try to make tests pass by removing atari from the setup

* First attempt at deprecating `env.seed` and supporting `env.reset(seed=seed)` instead. Tests should hopefully pass but throw up a million warnings.

* black; didn't expect that, did ya?

* Rename the reset parameter in VecEnvs back to `seed`

* Updated tests to use the new seeding method

* Removed a bunch of old `seed` calls.

Fixed a bug in AsyncVectorEnv

* Stop Discrete envs from doing part of the setup (and using the randomness) in init (as opposed to reset)

* Add explicit seed to wrappers reset

* Remove an accidental return

* Re-add some legacy functions with a warning.

* Use deprecation instead of regular warnings for the newly deprecated methods/functions
This commit is contained in:
Ariel Kwiatkowski
2021-12-08 22:14:15 +01:00
committed by GitHub
parent b84b69c872
commit c364506710
59 changed files with 386 additions and 294 deletions
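For downstream users, the whole migration boils down to a small change at the call site; a minimal before/after sketch (assuming a standard gym install with the classic-control environments):

```python
import gym

env = gym.make("CartPole-v1")

# Old, now-deprecated pattern (emits a deprecation warning):
#   env.seed(42)
#   obs = env.reset()

# New pattern: seed once, through reset.
obs = env.reset(seed=42)

# Later resets reuse the existing generator unless a new seed is passed.
obs = env.reset()
```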

View File

@@ -50,6 +50,7 @@
* `gym-foo/gym_foo/envs/foo_env.py` should look something like:
```python
from typing import Optional
import gym
from gym import error, spaces, utils
from gym.utils import seeding
@@ -61,7 +62,8 @@
...
def step(self, action):
...
def reset(self):
def reset(self, seed: Optional[int] = None):
super().reset(seed=seed)
...
def render(self, mode='human'):
...
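Assuming the `gym-foo` package from this snippet is installed and registers `foo-v0` (both names come from the tutorial, not this repo), the new-style environment is driven like any other:

```python
import gym
import gym_foo  # noqa: F401 -- registers the example environments on import

env = gym.make("foo-v0")
obs = env.reset(seed=0)  # seeds self.np_random via super().reset(seed=seed)
```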

View File

@@ -1,8 +1,10 @@
from abc import abstractmethod
from typing import Optional
import gym
from gym import error
from gym.utils import closer
from gym.utils import closer, seeding
from gym.logger import deprecation
class Env:
@@ -38,6 +40,9 @@ class Env:
action_space = None
observation_space = None
# Created
np_random = None
@abstractmethod
def step(self, action):
"""Run one timestep of the environment's dynamics. When end of
@@ -58,7 +63,7 @@ class Env:
raise NotImplementedError
@abstractmethod
def reset(self):
def reset(self, seed: Optional[int] = None):
"""Resets the environment to an initial state and returns an initial
observation.
@@ -71,7 +76,9 @@ class Env:
Returns:
observation (object): the initial observation.
"""
raise NotImplementedError
# Initialize the RNG if it's the first reset, or if the seed is manually passed
if seed is not None or self.np_random is None:
self.np_random, seed = seeding.np_random(seed)
@abstractmethod
def render(self, mode="human"):
@@ -136,7 +143,12 @@ class Env:
'seed'. Often, the main seed equals the provided 'seed', but
this won't be true if seed=None, for example.
"""
return
deprecation(
"Function `env.seed(seed)` is marked as deprecated and will be removed in the future. "
"Please use `env.reset(seed=seed)` instead."
)
self.np_random, seed = seeding.np_random(seed)
return [seed]
@property
def unwrapped(self):
@@ -173,7 +185,8 @@ class GoalEnv(Env):
actual observations of the environment as per usual.
"""
def reset(self):
def reset(self, seed: Optional[int] = None):
super().reset(seed=seed)
# Enforce that each GoalEnv uses a Goal-compatible observation space.
if not isinstance(self.observation_space, gym.spaces.Dict):
raise error.Error(
@@ -286,8 +299,8 @@ class Wrapper(Env):
def step(self, action):
return self.env.step(action)
def reset(self, **kwargs):
return self.env.reset(**kwargs)
def reset(self, seed: Optional[int] = None, **kwargs):
return self.env.reset(seed=seed, **kwargs)
def render(self, mode="human", **kwargs):
return self.env.render(mode, **kwargs)
@@ -313,8 +326,8 @@ class Wrapper(Env):
class ObservationWrapper(Wrapper):
def reset(self, **kwargs):
observation = self.env.reset(**kwargs)
def reset(self, seed: Optional[int] = None, **kwargs):
observation = self.env.reset(seed=seed, **kwargs)
return self.observation(observation)
def step(self, action):
@@ -327,8 +340,8 @@ class ObservationWrapper(Wrapper):
class RewardWrapper(Wrapper):
def reset(self, **kwargs):
return self.env.reset(**kwargs)
def reset(self, seed: Optional[int] = None, **kwargs):
return self.env.reset(seed=seed, **kwargs)
def step(self, action):
observation, reward, done, info = self.env.step(action)
@@ -340,8 +353,8 @@ class RewardWrapper(Wrapper):
class ActionWrapper(Wrapper):
def reset(self, **kwargs):
return self.env.reset(**kwargs)
def reset(self, seed: Optional[int] = None, **kwargs):
return self.env.reset(seed=seed, **kwargs)
def step(self, action):
return self.env.step(self.action(action))
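With the base class now owning RNG initialization, a custom environment only calls `super().reset(seed=seed)` and then draws from `self.np_random`; a minimal self-contained sketch (the `CoinFlipEnv` here is hypothetical):

```python
from typing import Optional

import gym
from gym import spaces


class CoinFlipEnv(gym.Env):
    """Hypothetical one-step env illustrating the new seeding flow."""

    observation_space = spaces.Discrete(2)
    action_space = spaces.Discrete(2)

    def reset(self, seed: Optional[int] = None):
        # Initializes self.np_random on the first reset, or when a seed is given.
        super().reset(seed=seed)
        self.state = int(self.np_random.integers(0, 2))
        return self.state

    def step(self, action):
        reward = float(action == self.state)
        return self.state, reward, True, {}
```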

View File

@@ -1,5 +1,6 @@
import sys
import math
from typing import Optional
import numpy as np
import Box2D
@@ -122,7 +123,6 @@ class BipedalWalker(gym.Env, EzPickle):
def __init__(self):
EzPickle.__init__(self)
self.seed()
self.viewer = None
self.world = Box2D.b2World()
@@ -149,10 +149,6 @@ class BipedalWalker(gym.Env, EzPickle):
)
self.observation_space = spaces.Box(-high, high)
def seed(self, seed=None):
self.np_random, seed = seeding.np_random(seed)
return [seed]
def _destroy(self):
if not self.terrain:
return
@@ -188,7 +184,7 @@ class BipedalWalker(gym.Env, EzPickle):
y += velocity
elif state == PIT and oneshot:
counter = self.np_random.randint(3, 5)
counter = self.np_random.integers(3, 5)
poly = [
(x, y),
(x + TERRAIN_STEP, y),
@@ -215,7 +211,7 @@ class BipedalWalker(gym.Env, EzPickle):
y -= 4 * TERRAIN_STEP
elif state == STUMP and oneshot:
counter = self.np_random.randint(1, 3)
counter = self.np_random.integers(1, 3)
poly = [
(x, y),
(x + counter * TERRAIN_STEP, y),
@@ -228,9 +224,9 @@ class BipedalWalker(gym.Env, EzPickle):
self.terrain.append(t)
elif state == STAIRS and oneshot:
stair_height = +1 if self.np_random.rand() > 0.5 else -1
stair_width = self.np_random.randint(4, 5)
stair_steps = self.np_random.randint(3, 5)
stair_height = +1 if self.np_random.random() > 0.5 else -1
stair_width = self.np_random.integers(4, 5)
stair_steps = self.np_random.integers(3, 5)
original_y = y
for s in range(stair_steps):
poly = [
@@ -266,9 +262,9 @@ class BipedalWalker(gym.Env, EzPickle):
self.terrain_y.append(y)
counter -= 1
if counter == 0:
counter = self.np_random.randint(TERRAIN_GRASS / 2, TERRAIN_GRASS)
counter = self.np_random.integers(TERRAIN_GRASS / 2, TERRAIN_GRASS)
if state == GRASS and hardcore:
state = self.np_random.randint(1, _STATES_)
state = self.np_random.integers(1, _STATES_)
oneshot = True
else:
state = GRASS
@@ -312,7 +308,8 @@ class BipedalWalker(gym.Env, EzPickle):
x2 = max(p[0] for p in poly)
self.cloud_poly.append((poly, x1, x2))
def reset(self):
def reset(self, seed: Optional[int] = None):
super().reset(seed=seed)
self._destroy()
self.world.contactListener_bug_workaround = ContactDetector(self)
self.world.contactListener = self.world.contactListener_bug_workaround
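The renames in this and the following files all follow the standard `RandomState` → `Generator` mapping; for reference, a sketch against a locally constructed generator:

```python
import numpy as np

rng = np.random.default_rng(0)

rng.integers(3, 5)       # was: rng.randint(3, 5)
rng.random()             # was: rng.rand()
rng.random((2, 3))       # was: rng.random_sample((2, 3))
rng.standard_normal(4)   # was: rng.randn(4)
rng.bit_generator.state  # was: rng.get_state()
```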

View File

@@ -32,6 +32,8 @@ Created by Oleg Klimov. Licensed on the same terms as the rest of OpenAI Gym.
"""
import sys
import math
from typing import Optional
import numpy as np
import Box2D
@@ -121,7 +123,6 @@ class CarRacing(gym.Env, EzPickle):
def __init__(self, verbose=1):
EzPickle.__init__(self)
self.seed()
self.contactListener_keepref = FrictionDetector(self)
self.world = Box2D.b2World((0, 0), contactListener=self.contactListener_keepref)
self.viewer = None
@@ -145,10 +146,6 @@ class CarRacing(gym.Env, EzPickle):
low=0, high=255, shape=(STATE_H, STATE_W, 3), dtype=np.uint8
)
def seed(self, seed=None):
self.np_random, seed = seeding.np_random(seed)
return [seed]
def _destroy(self):
if not self.road:
return
@@ -343,7 +340,8 @@ class CarRacing(gym.Env, EzPickle):
self.track = track
return True
def reset(self):
def reset(self, seed: Optional[int] = None):
super().reset(seed=seed)
self._destroy()
self.reward = 0.0
self.prev_reward = 0.0

View File

@@ -28,6 +28,8 @@ Created by Oleg Klimov. Licensed on the same terms as the rest of OpenAI Gym.
import math
import sys
from typing import Optional
import numpy as np
import Box2D
@@ -93,7 +95,6 @@ class LunarLander(gym.Env, EzPickle):
def __init__(self):
EzPickle.__init__(self)
self.seed()
self.viewer = None
self.world = Box2D.b2World()
@@ -117,10 +118,6 @@ class LunarLander(gym.Env, EzPickle):
# Nop, fire left engine, main engine, right engine
self.action_space = spaces.Discrete(4)
def seed(self, seed=None):
self.np_random, seed = seeding.np_random(seed)
return [seed]
def _destroy(self):
if not self.moon:
return
@@ -133,7 +130,8 @@ class LunarLander(gym.Env, EzPickle):
self.world.DestroyBody(self.legs[0])
self.world.DestroyBody(self.legs[1])
def reset(self):
def reset(self, seed: Optional[int] = None):
super().reset(seed=seed)
self._destroy()
self.world.contactListener_keepref = ContactDetector(self)
self.world.contactListener = self.world.contactListener_keepref
@@ -504,10 +502,9 @@ def heuristic(env, s):
def demo_heuristic_lander(env, seed=None, render=False):
env.seed(seed)
total_reward = 0
steps = 0
s = env.reset()
s = env.reset(seed=seed)
while True:
a = heuristic(env, s)
s, r, done, info = env.step(a)

View File

@@ -1,4 +1,6 @@
"""classic Acrobot task"""
from typing import Optional
import numpy as np
from numpy import sin, cos, pi
@@ -94,13 +96,9 @@ class AcrobotEnv(core.Env):
self.observation_space = spaces.Box(low=low, high=high, dtype=np.float32)
self.action_space = spaces.Discrete(3)
self.state = None
self.seed()
def seed(self, seed=None):
self.np_random, seed = seeding.np_random(seed)
return [seed]
def reset(self):
def reset(self, seed: Optional[int] = None):
super().reset(seed=seed)
self.state = self.np_random.uniform(low=-0.1, high=0.1, size=(4,)).astype(
np.float32
)

View File

@@ -5,6 +5,8 @@ permalink: https://perma.cc/C9ZM-652R
"""
import math
from typing import Optional
import gym
from gym import spaces, logger
from gym.utils import seeding
@@ -90,16 +92,11 @@ class CartPoleEnv(gym.Env):
self.action_space = spaces.Discrete(2)
self.observation_space = spaces.Box(-high, high, dtype=np.float32)
self.seed()
self.viewer = None
self.state = None
self.steps_beyond_done = None
def seed(self, seed=None):
self.np_random, seed = seeding.np_random(seed)
return [seed]
def step(self, action):
err_msg = f"{action!r} ({type(action)}) invalid"
assert self.action_space.contains(action), err_msg
@@ -158,7 +155,8 @@ class CartPoleEnv(gym.Env):
return np.array(self.state, dtype=np.float32), reward, done, {}
def reset(self):
def reset(self, seed: Optional[int] = None):
super().reset(seed=seed)
self.state = self.np_random.uniform(low=-0.05, high=0.05, size=(4,))
self.steps_beyond_done = None
return np.array(self.state, dtype=np.float32)

View File

@@ -14,6 +14,7 @@ permalink: https://perma.cc/6Z2N-PFWC
"""
import math
from typing import Optional
import numpy as np
@@ -83,12 +84,6 @@ class Continuous_MountainCarEnv(gym.Env):
low=self.low_state, high=self.high_state, dtype=np.float32
)
self.seed()
def seed(self, seed=None):
self.np_random, seed = seeding.np_random(seed)
return [seed]
def step(self, action):
position = self.state[0]
@@ -119,7 +114,8 @@ class Continuous_MountainCarEnv(gym.Env):
self.state = np.array([position, velocity], dtype=np.float32)
return self.state, reward, done, {}
def reset(self):
def reset(self, seed: Optional[int] = None):
super().reset(seed=seed)
self.state = np.array([self.np_random.uniform(low=-0.6, high=-0.4), 0])
return np.array(self.state, dtype=np.float32)

View File

@@ -3,6 +3,7 @@ http://incompleteideas.net/MountainCar/MountainCar1.cp
permalink: https://perma.cc/6Z2N-PFWC
"""
import math
from typing import Optional
import numpy as np
@@ -72,12 +73,6 @@ class MountainCarEnv(gym.Env):
self.action_space = spaces.Discrete(3)
self.observation_space = spaces.Box(self.low, self.high, dtype=np.float32)
self.seed()
def seed(self, seed=None):
self.np_random, seed = seeding.np_random(seed)
return [seed]
def step(self, action):
assert self.action_space.contains(
action
@@ -97,7 +92,8 @@ class MountainCarEnv(gym.Env):
self.state = (position, velocity)
return np.array(self.state, dtype=np.float32), reward, done, {}
def reset(self):
def reset(self, seed: Optional[int] = None):
super().reset(seed=seed)
self.state = np.array([self.np_random.uniform(low=-0.6, high=-0.4), 0])
return np.array(self.state, dtype=np.float32)

View File

@@ -1,3 +1,5 @@
from typing import Optional
import gym
from gym import spaces
from gym.utils import seeding
@@ -24,12 +26,6 @@ class PendulumEnv(gym.Env):
)
self.observation_space = spaces.Box(low=-high, high=high, dtype=np.float32)
self.seed()
def seed(self, seed=None):
self.np_random, seed = seeding.np_random(seed)
return [seed]
def step(self, u):
th, thdot = self.state # th := theta
@@ -49,7 +45,8 @@ class PendulumEnv(gym.Env):
self.state = np.array([newth, newthdot])
return self._get_obs(), -costs, False, {}
def reset(self):
def reset(self, seed: Optional[int] = None):
super().reset(seed=seed)
high = np.array([np.pi, 1])
self.state = self.np_random.uniform(low=-high, high=high)
self.last_u = None

View File

@@ -48,7 +48,7 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
qpos = self.init_qpos + self.np_random.uniform(
size=self.model.nq, low=-0.1, high=0.1
)
qvel = self.init_qvel + self.np_random.randn(self.model.nv) * 0.1
qvel = self.init_qvel + self.np_random.standard_normal(self.model.nv) * 0.1
self.set_state(qpos, qvel)
return self._get_obs()

View File

@@ -131,8 +131,9 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
qpos = self.init_qpos + self.np_random.uniform(
low=noise_low, high=noise_high, size=self.model.nq
)
qvel = self.init_qvel + self._reset_noise_scale * self.np_random.randn(
self.model.nv
qvel = (
self.init_qvel
+ self._reset_noise_scale * self.np_random.standard_normal(self.model.nv)
)
self.set_state(qpos, qvel)

View File

@@ -31,7 +31,7 @@ class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):
qpos = self.init_qpos + self.np_random.uniform(
low=-0.1, high=0.1, size=self.model.nq
)
qvel = self.init_qvel + self.np_random.randn(self.model.nv) * 0.1
qvel = self.init_qvel + self.np_random.standard_normal(self.model.nv) * 0.1
self.set_state(qpos, qvel)
return self._get_obs()

View File

@@ -74,8 +74,9 @@ class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):
qpos = self.init_qpos + self.np_random.uniform(
low=noise_low, high=noise_high, size=self.model.nq
)
qvel = self.init_qvel + self._reset_noise_scale * self.np_random.randn(
self.model.nv
qvel = (
self.init_qvel
+ self._reset_noise_scale * self.np_random.standard_normal(self.model.nv)
)
self.set_state(qpos, qvel)

View File

@@ -35,7 +35,7 @@ class InvertedDoublePendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle):
self.set_state(
self.init_qpos
+ self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq),
self.init_qvel + self.np_random.randn(self.model.nv) * 0.1,
self.init_qvel + self.np_random.standard_normal(self.model.nv) * 0.1,
)
return self._get_obs()

View File

@@ -1,6 +1,6 @@
from collections import OrderedDict
import os
from typing import Optional
from gym import error, spaces
from gym.utils import seeding
@@ -73,8 +73,6 @@ class MujocoEnv(gym.Env):
self._set_observation_space(observation)
self.seed()
def _set_action_space(self):
bounds = self.model.actuator_ctrlrange.copy().astype(np.float32)
low, high = bounds.T
@@ -85,10 +83,6 @@ class MujocoEnv(gym.Env):
self.observation_space = convert_observation_to_space(observation)
return self.observation_space
def seed(self, seed=None):
self.np_random, seed = seeding.np_random(seed)
return [seed]
# methods to override:
# ----------------------------
@@ -109,7 +103,8 @@ class MujocoEnv(gym.Env):
# -----------------------------
def reset(self):
def reset(self, seed: Optional[int] = None):
super().reset(seed=seed)
self.sim.reset()
ob = self.reset_model()
return ob

View File

@@ -183,7 +183,7 @@ class ManipulateEnv(hand_env.HandEnv):
axis = np.array([0.0, 0.0, 1.0])
z_quat = quat_from_angle_and_axis(angle, axis)
parallel_quat = self.parallel_quats[
self.np_random.randint(len(self.parallel_quats))
self.np_random.integers(len(self.parallel_quats))
]
offset_quat = rotations.quat_mul(z_quat, parallel_quat)
initial_quat = rotations.quat_mul(initial_quat, offset_quat)
@@ -254,7 +254,7 @@ class ManipulateEnv(hand_env.HandEnv):
axis = np.array([0.0, 0.0, 1.0])
target_quat = quat_from_angle_and_axis(angle, axis)
parallel_quat = self.parallel_quats[
self.np_random.randint(len(self.parallel_quats))
self.np_random.integers(len(self.parallel_quats))
]
target_quat = rotations.quat_mul(target_quat, parallel_quat)
elif self.target_rotation == "xyz":

View File

@@ -1,5 +1,7 @@
import os
import copy
from typing import Optional
import numpy as np
import gym
@@ -37,7 +39,6 @@ class RobotEnv(gym.GoalEnv):
"video.frames_per_second": int(np.round(1.0 / self.dt)),
}
self.seed()
self._env_setup(initial_qpos=initial_qpos)
self.initial_state = copy.deepcopy(self.sim.get_state())
@@ -65,10 +66,6 @@ class RobotEnv(gym.GoalEnv):
# Env methods
# ----------------------------
def seed(self, seed=None):
self.np_random, seed = seeding.np_random(seed)
return [seed]
def step(self, action):
if np.array(action).shape != self.action_space.shape:
raise ValueError("Action dimension mismatch")
@@ -86,13 +83,13 @@ class RobotEnv(gym.GoalEnv):
reward = self.compute_reward(obs["achieved_goal"], self.goal, info)
return obs, reward, done, info
def reset(self):
def reset(self, seed: Optional[int] = None):
# Attempt to reset the simulator. Since we randomize initial conditions, it
# is possible to get into a state with numerical issues (e.g. due to penetration or
Gimbal lock) or we may not achieve an initial condition (e.g. an object is within the hand).
# In this case, we just keep randomizing until we eventually achieve a valid initial
# configuration.
super().reset()
super().reset(seed=seed)
did_reset_sim = False
while not did_reset_sim:
did_reset_sim = self._reset_sim()

View File

@@ -1,3 +1,5 @@
from typing import Optional
import gym
from gym import spaces
from gym.utils import seeding
@@ -78,7 +80,6 @@ class BlackjackEnv(gym.Env):
self.observation_space = spaces.Tuple(
(spaces.Discrete(32), spaces.Discrete(11), spaces.Discrete(2))
)
self.seed()
# Flag to payout 1.5 on a "natural" blackjack win, like casino rules
# Ref: http://www.bicyclecards.com/how-to-play/blackjack/
@@ -87,10 +88,6 @@ class BlackjackEnv(gym.Env):
# Flag for full agreement with the (Sutton and Barto, 2018) definition. Overrides self.natural
self.sab = sab
def seed(self, seed=None):
self.np_random, seed = seeding.np_random(seed)
return [seed]
def step(self, action):
assert self.action_space.contains(action)
if action: # hit: add a card to players hand and return
@@ -122,7 +119,8 @@ class BlackjackEnv(gym.Env):
def _get_obs(self):
return (sum_hand(self.player), self.dealer[0], usable_ace(self.player))
def reset(self):
def reset(self, seed: Optional[int] = None):
super().reset(seed=seed)
self.dealer = draw_hand(self.np_random)
self.player = draw_hand(self.np_random)
return self._get_obs()

View File

@@ -1,3 +1,5 @@
from typing import Optional
import numpy as np
from gym import Env, spaces
@@ -11,7 +13,7 @@ def categorical_sample(prob_n, np_random):
"""
prob_n = np.asarray(prob_n)
csprob_n = np.cumsum(prob_n)
return (csprob_n > np_random.rand()).argmax()
return (csprob_n > np_random.random()).argmax()
class DiscreteEnv(Env):
@@ -40,14 +42,8 @@ class DiscreteEnv(Env):
self.action_space = spaces.Discrete(self.nA)
self.observation_space = spaces.Discrete(self.nS)
self.seed()
self.s = categorical_sample(self.isd, self.np_random)
def seed(self, seed=None):
self.np_random, seed = seeding.np_random(seed)
return [seed]
def reset(self):
def reset(self, seed: Optional[int] = None):
super().reset(seed=seed)
self.s = categorical_sample(self.isd, self.np_random)
self.lastaction = None
return int(self.s)
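`categorical_sample` itself only changes `rand()` to `random()`: it still inverts the cumulative distribution with a single uniform draw. A standalone sketch of the same idea:

```python
import numpy as np


def categorical_sample(prob_n, np_random):
    # Index of the first cumulative probability that exceeds a uniform draw.
    csprob_n = np.cumsum(np.asarray(prob_n))
    return (csprob_n > np_random.random()).argmax()


rng = np.random.default_rng(0)
samples = [categorical_sample([0.1, 0.2, 0.7], rng) for _ in range(10_000)]
print(np.bincount(samples, minlength=3) / 10_000)  # roughly [0.1, 0.2, 0.7]
```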

View File

@@ -1,3 +1,5 @@
from typing import Optional
import numpy as np
import gym
from gym import spaces
@@ -52,7 +54,6 @@ class CubeCrash(gym.Env):
use_random_colors = False # Makes env too hard
def __init__(self):
self.seed()
self.viewer = None
self.observation_space = spaces.Box(
@@ -62,23 +63,20 @@ class CubeCrash(gym.Env):
self.reset()
def seed(self, seed=None):
self.np_random, seed = seeding.np_random(seed)
return [seed]
def random_color(self):
return np.array(
[
self.np_random.randint(low=0, high=255),
self.np_random.randint(low=0, high=255),
self.np_random.randint(low=0, high=255),
self.np_random.integers(low=0, high=255),
self.np_random.integers(low=0, high=255),
self.np_random.integers(low=0, high=255),
]
).astype("uint8")
def reset(self):
self.cube_x = self.np_random.randint(low=3, high=FIELD_W - 3)
self.cube_y = self.np_random.randint(low=3, high=FIELD_H // 6)
self.hole_x = self.np_random.randint(low=HOLE_WIDTH, high=FIELD_W - HOLE_WIDTH)
def reset(self, seed: Optional[int] = None):
super().reset(seed=seed)
self.cube_x = self.np_random.integers(low=3, high=FIELD_W - 3)
self.cube_y = self.np_random.integers(low=3, high=FIELD_H // 6)
self.hole_x = self.np_random.integers(low=HOLE_WIDTH, high=FIELD_W - HOLE_WIDTH)
self.bg_color = self.random_color() if self.use_random_colors else color_black
self.potential = None
self.step_n = 0
@@ -95,6 +93,7 @@ class CubeCrash(gym.Env):
):
continue
break
return self.step(0)[0]
def step(self, action):

View File

@@ -1,3 +1,5 @@
from typing import Optional
import numpy as np
import gym
from gym import spaces
@@ -60,7 +62,6 @@ class MemorizeDigits(gym.Env):
use_random_colors = False
def __init__(self):
self.seed()
self.viewer = None
self.observation_space = spaces.Box(
0, 255, (FIELD_H, FIELD_W, 3), dtype=np.uint8
@@ -74,22 +75,19 @@ class MemorizeDigits(gym.Env):
]
self.reset()
def seed(self, seed=None):
self.np_random, seed = seeding.np_random(seed)
return [seed]
def random_color(self):
return np.array(
[
self.np_random.randint(low=0, high=255),
self.np_random.randint(low=0, high=255),
self.np_random.randint(low=0, high=255),
self.np_random.integers(low=0, high=255),
self.np_random.integers(low=0, high=255),
self.np_random.integers(low=0, high=255),
]
).astype("uint8")
def reset(self):
self.digit_x = self.np_random.randint(low=FIELD_W // 5, high=FIELD_W // 5 * 4)
self.digit_y = self.np_random.randint(low=FIELD_H // 5, high=FIELD_H // 5 * 4)
def reset(self, seed: Optional[int] = None):
super().reset(seed=seed)
self.digit_x = self.np_random.integers(low=FIELD_W // 5, high=FIELD_W // 5 * 4)
self.digit_y = self.np_random.integers(low=FIELD_H // 5, high=FIELD_H // 5 * 4)
self.color_bg = self.random_color() if self.use_random_colors else color_black
self.step_n = 0
while 1:
@@ -111,8 +109,8 @@ class MemorizeDigits(gym.Env):
else:
if self.digit == action:
reward = +1
done = self.step_n > 20 and 0 == self.np_random.randint(low=0, high=5)
self.digit = self.np_random.randint(low=0, high=10)
done = self.step_n > 20 and 0 == self.np_random.integers(low=0, high=5)
self.digit = self.np_random.integers(low=0, high=10)
obs = np.zeros((FIELD_H, FIELD_W, 3), dtype=np.uint8)
obs[:, :, :] = self.color_bg
digit_img = np.zeros((6, 6, 3), dtype=np.uint8)

View File

@@ -35,7 +35,7 @@ class MultiBinary(Space):
super().__init__(input_n, np.int8, seed)
def sample(self):
return self.np_random.randint(low=0, high=2, size=self.n, dtype=self.dtype)
return self.np_random.integers(low=0, high=2, size=self.n, dtype=self.dtype)
def contains(self, x):
if isinstance(x, list) or isinstance(x, tuple):

View File

@@ -35,9 +35,7 @@ class MultiDiscrete(Space):
super().__init__(self.nvec.shape, dtype, seed)
def sample(self):
return (self.np_random.random_sample(self.nvec.shape) * self.nvec).astype(
self.dtype
)
return (self.np_random.random(self.nvec.shape) * self.nvec).astype(self.dtype)
def contains(self, x):
if isinstance(x, list):
@@ -61,7 +59,7 @@ class MultiDiscrete(Space):
subspace = Discrete(nvec)
else:
subspace = MultiDiscrete(nvec, self.dtype)
subspace.np_random.set_state(self.np_random.get_state()) # for reproducibility
subspace.np_random.bit_generator.state = self.np_random.bit_generator.state
return subspace
def __len__(self):
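Copying state through `bit_generator.state` is the Generator-API replacement for `get_state`/`set_state`; a quick sketch checking that it reproduces the same stream:

```python
import numpy as np

a = np.random.default_rng(123)
b = np.random.default_rng()

# Was: b.set_state(a.get_state()) under the RandomState API.
b.bit_generator.state = a.bit_generator.state

assert a.integers(0, 10, size=5).tolist() == b.integers(0, 10, size=5).tolist()
```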

View File

@@ -1,43 +1,118 @@
import hashlib
import numpy as np
import os
import random as _random
import struct
import sys
from typing import Optional
import numpy as np
from numpy.random import Generator
from gym import error
from gym.logger import deprecation
def np_random(seed=None):
def np_random(seed: Optional[int] = None):
if seed is not None and not (isinstance(seed, int) and 0 <= seed):
raise error.Error(f"Seed must be a non-negative integer or omitted, not {seed}")
seed = create_seed(seed)
rng = np.random.RandomState()
rng.seed(_int_list_from_bigint(hash_seed(seed)))
seed_seq = np.random.SeedSequence(seed)
seed = seed_seq.entropy
rng = RandomNumberGenerator(np.random.PCG64(seed_seq))
return rng, seed
# TODO: Remove this class and make it alias to `Generator` in a future Gym release
# RandomNumberGenerator = np.random.Generator
class RandomNumberGenerator(np.random.Generator):
def rand(self, *size):
deprecation(
"Function `rng.rand(*size)` is marked as deprecated "
"and will be removed in the future. "
"Please use `Generator.random(size)` instead."
)
return self.random(size)
random_sample = rand
def randn(self, *size):
deprecation(
"Function `rng.randn(*size)` is marked as deprecated "
"and will be removed in the future. "
"Please use `rng.standard_normal(size)` instead."
)
return self.standard_normal(size)
def randint(self, low, high=None, size=None, dtype=int):
deprecation(
"Function `rng.randint(low, [high, size, dtype])` is marked as deprecated "
"and will be removed in the future. "
"Please use `rng.integers(low, [high, size, dtype])` instead."
)
return self.integers(low=low, high=high, size=size, dtype=dtype)
random_integers = randint
def get_state(self):
deprecation(
"Function `rng.get_state()` is marked as deprecated "
"and will be removed in the future. "
"Please use `rng.bit_generator.state` instead."
)
return self.bit_generator.state
def set_state(self, state):
deprecation(
"Function `rng.set_state(state)` is marked as deprecated "
"and will be removed in the future. "
"Please use `rng.bit_generator.state = state` instead."
)
self.bit_generator.state = state
def seed(self, seed=None):
deprecation(
"Function `rng.seed(seed)` is marked as deprecated "
"and will be removed in the future. "
"Please use `rng, seed = gym.utils.seeding.np_random(seed)` to create a separate generator instead."
)
self.bit_generator.state = type(self.bit_generator)(seed).state
rand.__doc__ = np.random.rand.__doc__
randn.__doc__ = np.random.randn.__doc__
randint.__doc__ = np.random.randint.__doc__
get_state.__doc__ = np.random.get_state.__doc__
set_state.__doc__ = np.random.set_state.__doc__
seed.__doc__ = np.random.seed.__doc__
RNG = RandomNumberGenerator
# Legacy functions
def hash_seed(seed=None, max_bytes=8):
"""Any given evaluation is likely to have many PRNG's active at
once. (Most commonly, because the environment is running in
multiple processes.) There's literature indicating that having
linear correlations between seeds of multiple PRNG's can correlate
the outputs:
http://blogs.unity3d.com/2015/01/07/a-primer-on-repeatable-random-numbers/
http://stackoverflow.com/questions/1554958/how-different-do-random-seeds-need-to-be
http://dl.acm.org/citation.cfm?id=1276928
Thus, for sanity we hash the seeds before using them. (This scheme
is likely not crypto-strength, but it should be good enough to get
rid of simple correlations.)
Args:
seed (Optional[int]): None seeds from an operating system specific randomness source.
max_bytes: Maximum number of bytes to use in the hashed seed.
"""
deprecation(
"Function `hash_seed(seed, max_bytes)` is marked as deprecated and will be removed in the future. "
)
if seed is None:
seed = create_seed(max_bytes=max_bytes)
hash = hashlib.sha512(str(seed).encode("utf8")).digest()
@@ -48,11 +123,13 @@ def create_seed(a=None, max_bytes=8):
"""Create a strong random seed. Otherwise, Python 2 would seed using
the system time, which might be non-robust especially in the
presence of concurrency.
Args:
a (Optional[int, str]): None seeds from an operating system specific randomness source.
max_bytes: Maximum number of bytes to use in the seed.
"""
deprecation(
"Function `create_seed(a, max_bytes)` is marked as deprecated and will be removed in the future. "
)
# Adapted from https://svn.python.org/projects/python/tags/r32/Lib/random.py
if a is None:
a = _bigint_from_bytes(os.urandom(max_bytes))
@@ -70,6 +147,9 @@ def create_seed(a=None, max_bytes=8):
# TODO: don't hardcode sizeof_int here
def _bigint_from_bytes(bytes):
deprecation(
"Function `_bigint_from_bytes(bytes)` is marked as deprecated and will be removed in the future. "
)
sizeof_int = 4
padding = sizeof_int - len(bytes) % sizeof_int
bytes += b"\0" * padding
@@ -82,6 +162,9 @@ def _bigint_from_bytes(bytes):
def _int_list_from_bigint(bigint):
deprecation(
"Function `_int_list_from_bigint` is marked as deprecated and will be removed in the future. "
)
# Special case 0
if bigint < 0:
raise error.Error(f"Seed must be non-negative, not {bigint}")
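After this rewrite, `seeding.np_random` returns a `Generator` subclass together with the entropy actually used, and the old `RandomState` methods survive only as deprecation shims; a usage sketch:

```python
from gym.utils import seeding

rng, seed = seeding.np_random(42)
print(seed)          # 42: the entropy recovered from the SeedSequence

rng.integers(0, 10)  # new Generator API
rng.randint(0, 10)   # legacy shim: still works, but emits a deprecation warning

# The same seed yields identical streams.
a, _ = seeding.np_random(7)
b, _ = seeding.np_random(7)
assert a.integers(0, 100) == b.integers(0, 100)
```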

View File

@@ -1,3 +1,5 @@
from typing import Optional, Union, List
import numpy as np
import multiprocessing as mp
import time
@@ -6,6 +8,7 @@ from enum import Enum
from copy import deepcopy
from gym import logger
from gym.logger import warn
from gym.vector.vector_env import VectorEnv
from gym.error import (
AlreadyPendingCallError,
@@ -187,13 +190,14 @@ class AsyncVectorEnv(VectorEnv):
self._state = AsyncState.DEFAULT
self._check_observation_spaces()
def seed(self, seeds=None):
def seed(self, seed=None):
super().seed(seed=seed)
self._assert_is_running()
if seeds is None:
seeds = [None for _ in range(self.num_envs)]
if isinstance(seeds, int):
seeds = [seeds + i for i in range(self.num_envs)]
assert len(seeds) == self.num_envs
if seed is None:
seed = [None for _ in range(self.num_envs)]
if isinstance(seed, int):
seed = [seed + i for i in range(self.num_envs)]
assert len(seed) == self.num_envs
if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError(
@@ -201,12 +205,12 @@ class AsyncVectorEnv(VectorEnv):
self._state.value,
)
for pipe, seed in zip(self.parent_pipes, seeds):
for pipe, single_seed in zip(self.parent_pipes, seed):
pipe.send(("seed", single_seed))
_, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
self._raise_if_errors(successes)
def reset_async(self):
def reset_async(self, seed: Optional[Union[int, List[int]]] = None):
"""Send the calls to :obj:`reset` to each sub-environment.
Raises
@@ -221,24 +225,31 @@ class AsyncVectorEnv(VectorEnv):
between.
"""
self._assert_is_running()
if seed is None:
seed = [None for _ in range(self.num_envs)]
if isinstance(seed, int):
seed = [seed + i for i in range(self.num_envs)]
assert len(seed) == self.num_envs
if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError(
f"Calling `reset_async` while waiting for a pending call to `{self._state.value}` to complete",
self._state.value,
)
for pipe in self.parent_pipes:
pipe.send(("reset", None))
for pipe, single_seed in zip(self.parent_pipes, seed):
pipe.send(("reset", single_seed))
self._state = AsyncState.WAITING_RESET
def reset_wait(self, timeout=None):
"""Wait for the calls to :obj:`reset` in each sub-environment to finish.
def reset_wait(self, timeout=None, seed: Optional[int] = None):
"""
Parameters
----------
timeout : int or float, optional
Number of seconds before the call to :meth:`reset_wait` times out.
If ``None``, the call to :meth:`reset_wait` never times out.
Number of seconds before the call to `reset_wait` times out. If
`None`, the call to `reset_wait` never times out.
seed: ignored
Returns
-------
@@ -486,7 +497,7 @@ def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
while True:
command, data = pipe.recv()
if command == "reset":
observation = env.reset()
observation = env.reset(data)
pipe.send((observation, True))
elif command == "step":
observation, reward, done, info = env.step(data)
@@ -524,7 +535,7 @@ def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error
while True:
command, data = pipe.recv()
if command == "reset":
observation = env.reset()
observation = env.reset(data)
write_to_shared_memory(
index, observation, shared_memory, observation_space
)
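On the vector side, an integer seed is expanded to `seed + i` for sub-environment `i` before being forwarded to each worker's `reset`; a sketch (assuming the classic-control environments are available):

```python
import gym

envs = gym.vector.make("CartPole-v1", num_envs=3)

# Equivalent ways of seeding the three sub-environments with 5, 6 and 7:
obs = envs.reset(seed=5)           # int: expanded to [5, 6, 7]
obs = envs.reset(seed=[5, 6, 7])   # explicit list, one seed per sub-environment
```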

View File

@@ -1,7 +1,10 @@
from typing import List, Union, Optional
import numpy as np
from copy import deepcopy
from gym import logger
from gym.logger import warn
from gym.vector.vector_env import VectorEnv
from gym.vector.utils import concatenate, create_empty_array
@@ -72,21 +75,28 @@ class SyncVectorEnv(VectorEnv):
self._dones = np.zeros((self.num_envs,), dtype=np.bool_)
self._actions = None
def seed(self, seeds=None):
if seeds is None:
seeds = [None for _ in range(self.num_envs)]
if isinstance(seeds, int):
seeds = [seeds + i for i in range(self.num_envs)]
assert len(seeds) == self.num_envs
def seed(self, seed=None):
super().seed(seed=seed)
if seed is None:
seed = [None for _ in range(self.num_envs)]
if isinstance(seed, int):
seed = [seed + i for i in range(self.num_envs)]
assert len(seed) == self.num_envs
for env, seed in zip(self.envs, seeds):
env.seed(seed)
for env, single_seed in zip(self.envs, seed):
env.seed(single_seed)
def reset_wait(self):
def reset_wait(self, seed: Optional[Union[int, List[int]]] = None, **kwargs):
if seed is None:
seed = [None for _ in range(self.num_envs)]
if isinstance(seed, int):
seed = [seed + i for i in range(self.num_envs)]
assert len(seed) == self.num_envs
self._dones[:] = False
observations = []
for env in self.envs:
observation = env.reset()
for env, single_seed in zip(self.envs, seed):
observation = env.reset(seed=single_seed)
observations.append(observation)
self.observations = concatenate(
observations, self.observations, self.single_observation_space

View File

@@ -1,4 +1,7 @@
from typing import Optional, Union, List
import gym
from gym.logger import warn, deprecation
from gym.spaces import Tuple
from gym.vector.utils.spaces import batch_space
@@ -43,13 +46,13 @@ class VectorEnv(gym.Env):
self.single_observation_space = observation_space
self.single_action_space = action_space
def reset_async(self):
def reset_async(self, seed: Optional[Union[int, List[int]]] = None):
pass
def reset_wait(self, **kwargs):
def reset_wait(self, seed: Optional[Union[int, List[int]]] = None, **kwargs):
raise NotImplementedError()
def reset(self):
def reset(self, seed: Optional[Union[int, List[int]]] = None):
r"""Reset all sub-environments and return a batch of initial observations.
Returns
@@ -57,8 +60,8 @@ class VectorEnv(gym.Env):
element of :attr:`observation_space`
A batch of observations from the vectorized environment.
"""
self.reset_async()
return self.reset_wait()
self.reset_async(seed=seed)
return self.reset_wait(seed=seed)
def step_async(self, actions):
pass
@@ -120,19 +123,22 @@ class VectorEnv(gym.Env):
self.close_extras(**kwargs)
self.closed = True
def seed(self, seeds=None):
def seed(self, seed=None):
"""Set the random seed in all sub-environments.
Parameters
----------
seeds : list of int, or int, optional
Random seed for each sub-environment. If ``seeds`` is a list of
seed : list of int, or int, optional
Random seed for each sub-environment. If ``seed`` is a list of
length ``num_envs``, then the items of the list are chosen as random
seeds. If ``seeds`` is an int, then each sub-environment uses the random
seed ``seeds + n``, where ``n`` is the index of the sub-environment
seeds. If ``seed`` is an int, then each sub-environment uses the random
seed ``seed + n``, where ``n`` is the index of the sub-environment
(between ``0`` and ``num_envs - 1``).
"""
pass
deprecation(
"Function `env.seed(seed)` is marked as deprecated and will be removed in the future. "
"Please use `env.reset(seed=seed)` instead in VectorEnvs."
)
def __del__(self):
if not getattr(self, "closed", True):
@@ -164,11 +170,11 @@ class VectorEnvWrapper(VectorEnv):
# explicitly forward the methods defined in VectorEnv
# to self.env (instead of the base class)
def reset_async(self):
return self.env.reset_async()
def reset_async(self, **kwargs):
return self.env.reset_async(**kwargs)
def reset_wait(self):
return self.env.reset_wait()
def reset_wait(self, **kwargs):
return self.env.reset_wait(**kwargs)
def step_async(self, actions):
return self.env.step_async(actions)
@@ -182,8 +188,8 @@ class VectorEnvWrapper(VectorEnv):
def close_extras(self, **kwargs):
return self.env.close_extras(**kwargs)
def seed(self, seeds=None):
return self.env.seed(seeds)
def seed(self, seed=None):
return self.env.seed(seed)
# implicitly forward all other methods and attributes to self.env
def __getattr__(self, name):

View File

@@ -1,3 +1,5 @@
from typing import Optional
import numpy as np
import gym
from gym.spaces import Box
@@ -127,11 +129,11 @@ class AtariPreprocessing(gym.Wrapper):
self.ale.getScreenRGB(self.obs_buffer[0])
return self._get_obs(), R, done, info
def reset(self, **kwargs):
def reset(self, seed: Optional[int] = None, **kwargs):
# NoopReset
self.env.reset(**kwargs)
self.env.reset(seed=seed, **kwargs)
noops = (
self.env.unwrapped.np_random.randint(1, self.noop_max + 1)
self.env.unwrapped.np_random.integers(1, self.noop_max + 1)
if self.noop_max > 0
else 0
)

View File

@@ -1,4 +1,6 @@
from collections import deque
from typing import Optional
import numpy as np
from gym.spaces import Box
from gym import ObservationWrapper
@@ -116,7 +118,7 @@ class FrameStack(ObservationWrapper):
self.frames.append(observation)
return self.observation(), reward, done, info
def reset(self, **kwargs):
observation = self.env.reset(**kwargs)
def reset(self, seed: Optional[int] = None, **kwargs):
observation = self.env.reset(seed=seed, **kwargs)
[self.frames.append(observation) for _ in range(self.num_stack)]
return self.observation()

View File

@@ -1,5 +1,6 @@
import json
import os
from typing import Optional
import numpy as np
@@ -50,9 +51,9 @@ class Monitor(Wrapper):
return observation, reward, done, info
def reset(self, **kwargs):
def reset(self, seed: Optional[int] = None, **kwargs):
self._before_reset()
observation = self.env.reset(**kwargs)
observation = self.env.reset(seed=seed, **kwargs)
self._after_reset(observation)
return observation

View File

@@ -1,3 +1,5 @@
from typing import Optional
import numpy as np
import gym
@@ -61,8 +63,8 @@ class NormalizeObservation(gym.core.Wrapper):
obs = self.normalize(np.array([obs]))[0]
return obs, rews, dones, infos
def reset(self):
obs = self.env.reset()
def reset(self, seed: Optional[int] = None):
obs = self.env.reset(seed=seed)
if self.is_vector_env:
obs = self.normalize(obs)
else:

View File

@@ -1,3 +1,5 @@
from typing import Optional
import gym
@@ -11,6 +13,6 @@ class OrderEnforcing(gym.Wrapper):
observation, reward, done, info = self.env.step(action)
return observation, reward, done, info
def reset(self, **kwargs):
def reset(self, seed: Optional[int] = None, **kwargs):
self._has_reset = True
return self.env.reset(**kwargs)
return self.env.reset(seed=seed, **kwargs)

View File

@@ -1,5 +1,7 @@
import time
from collections import deque
from typing import Optional
import numpy as np
import gym
@@ -16,8 +18,8 @@ class RecordEpisodeStatistics(gym.Wrapper):
self.length_queue = deque(maxlen=deque_size)
self.is_vector_env = getattr(env, "is_vector_env", False)
def reset(self, **kwargs):
observations = super().reset(**kwargs)
def reset(self, seed: Optional[int] = None, **kwargs):
observations = super().reset(seed=seed, **kwargs)
self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
return observations

View File

@@ -1,6 +1,6 @@
import os
import gym
from typing import Callable
from typing import Callable, Optional
from gym import logger
from gym.wrappers.monitoring import video_recorder
@@ -52,8 +52,8 @@ class RecordVideo(gym.Wrapper):
self.is_vector_env = getattr(env, "is_vector_env", False)
self.episode_id = 0
def reset(self, **kwargs):
observations = super().reset(**kwargs)
def reset(self, seed: Optional[int] = None, **kwargs):
observations = super().reset(seed=seed, **kwargs)
if not self.recording and self._video_enabled():
self.start_video_recorder()
return observations

View File

@@ -1,3 +1,5 @@
from typing import Optional
import numpy as np
from gym.spaces import Box
from gym import ObservationWrapper
@@ -27,6 +29,6 @@ class TimeAwareObservation(ObservationWrapper):
self.t += 1
return super().step(action)
def reset(self, **kwargs):
def reset(self, seed: Optional[int] = None, **kwargs):
self.t = 0
return super().reset(**kwargs)
return super().reset(seed=seed, **kwargs)

View File

@@ -1,3 +1,5 @@
from typing import Optional
import gym
@@ -22,6 +24,6 @@ class TimeLimit(gym.Wrapper):
done = True
return observation, reward, done, info
def reset(self, **kwargs):
def reset(self, seed: Optional[int] = None, **kwargs):
self._elapsed_steps = 0
return self.env.reset(**kwargs)
return self.env.reset(seed=seed, **kwargs)
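Custom wrappers follow the same pattern as the built-in ones above: accept `seed` explicitly and forward it to the wrapped `reset`. A minimal sketch (the `EpisodeCounter` wrapper is hypothetical):

```python
from typing import Optional

import gym


class EpisodeCounter(gym.Wrapper):
    """Hypothetical wrapper illustrating the new reset signature."""

    def __init__(self, env):
        super().__init__(env)
        self.episodes = 0

    def reset(self, seed: Optional[int] = None, **kwargs):
        self.episodes += 1
        return self.env.reset(seed=seed, **kwargs)
```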

View File

@@ -15,6 +15,6 @@ ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/root/.mujoco/mjpro150/bin
COPY . /usr/local/gym/
WORKDIR /usr/local/gym/
RUN pip install .[nomujoco,accept-rom-license] && pip install -r test_requirements.txt
RUN pip install .[nomujoco] && pip install -r test_requirements.txt
ENTRYPOINT ["/usr/local/gym/bin/docker_entrypoint"]

View File

@@ -21,7 +21,7 @@ extras = {
}
# Meta dependency groups.
nomujoco_blacklist = set(["mujoco", "robotics", "accept-rom-license"])
nomujoco_blacklist = set(["mujoco", "robotics", "accept-rom-license", "atari"])
nomujoco_groups = set(extras.keys()) - nomujoco_blacklist
extras["nomujoco"] = list(

View File

@@ -10,16 +10,14 @@ def test_env(spec):
# threads. However, we probably already can't do multithreading
# due to some environments.
env1 = spec.make()
env1.seed(0)
initial_observation1 = env1.reset()
initial_observation1 = env1.reset(seed=0)
env1.action_space.seed(0)
action_samples1 = [env1.action_space.sample() for i in range(4)]
step_responses1 = [env1.step(action) for action in action_samples1]
env1.close()
env2 = spec.make()
env2.seed(0)
initial_observation2 = env2.reset()
initial_observation2 = env2.reset(seed=0)
env2.action_space.seed(0)
action_samples2 = [env2.action_space.sample() for i in range(4)]
step_responses2 = [env2.step(action) for action in action_samples2]

View File

@@ -10,11 +10,8 @@ def verify_environments_match(
old_environment = envs.make(old_environment_id)
new_environment = envs.make(new_environment_id)
old_environment.seed(seed)
new_environment.seed(seed)
old_reset_observation = old_environment.reset()
new_reset_observation = new_environment.reset()
old_reset_observation = old_environment.reset(seed=seed)
new_reset_observation = new_environment.reset(seed=seed)
np.testing.assert_allclose(old_reset_observation, new_reset_observation)

View File

@@ -370,7 +370,7 @@ def test_seed_subspace_incorrelated(space):
space.seed(0)
states = [
convert_sample_hashable(subspace.np_random.get_state())
convert_sample_hashable(subspace.np_random.bit_generator.state)
for subspace in subspaces
]

View File

@@ -1,3 +1,5 @@
from typing import Optional
import pytest
import numpy as np
@@ -16,7 +18,8 @@ class UnittestEnv(core.Env):
observation_space = spaces.Box(low=0, high=255, shape=(64, 64, 3), dtype=np.uint8)
action_space = spaces.Discrete(3)
def reset(self):
def reset(self, seed: Optional[int] = None):
super().reset(seed=seed)
return self.observation_space.sample() # Dummy observation
def step(self, action):
@@ -31,7 +34,8 @@ class UnknownSpacesEnv(core.Env):
on external resources), it is not encouraged.
"""
def reset(self):
def reset(self, seed: Optional[int] = None):
super().reset(seed=seed)
self.observation_space = spaces.Box(
low=0, high=255, shape=(64, 64, 3), dtype=np.uint8
)

View File

@@ -1,3 +1,5 @@
from typing import Optional
import gym
import numpy as np
import pytest
@@ -16,7 +18,8 @@ class ActionDictTestEnv(gym.Env):
done = True
return observation, reward, done
def reset(self):
def reset(self, seed: Optional[int] = None):
super().reset(seed=seed)
return np.array([1.0, 1.5, 0.5])
def render(self, mode="human"):

View File

@@ -17,17 +17,14 @@ def test_vector_env_equal(shared_memory):
async_env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
sync_env = SyncVectorEnv(env_fns)
async_env.seed(0)
sync_env.seed(0)
assert async_env.num_envs == sync_env.num_envs
assert async_env.observation_space == sync_env.observation_space
assert async_env.single_observation_space == sync_env.single_observation_space
assert async_env.action_space == sync_env.action_space
assert async_env.single_action_space == sync_env.single_action_space
async_observations = async_env.reset()
sync_observations = sync_env.reset()
async_observations = async_env.reset(seed=0)
sync_observations = sync_env.reset(seed=0)
assert np.all(async_observations == sync_observations)
for _ in range(num_steps):

View File

@@ -8,7 +8,7 @@ class DummyWrapper(VectorEnvWrapper):
self.env = env
self.counter = 0
def reset_async(self):
def reset_async(self, **kwargs):
super().reset_async()
self.counter += 1

View File

@@ -1,3 +1,5 @@
from typing import Optional
import numpy as np
import gym
import time
@@ -55,7 +57,8 @@ class UnittestSlowEnv(gym.Env):
)
self.action_space = Box(low=0.0, high=1.0, shape=(), dtype=np.float32)
def reset(self):
def reset(self, seed: Optional[int] = None):
super().reset(seed=seed)
if self.slow_reset > 0:
time.sleep(self.slow_reset)
return self.observation_space.sample()
@@ -86,7 +89,8 @@ class CustomSpaceEnv(gym.Env):
self.observation_space = CustomSpace()
self.action_space = CustomSpace()
def reset(self):
def reset(self, seed: Optional[int] = None):
super().reset(seed=seed)
return "reset"
def step(self, action):
@@ -98,7 +102,7 @@ class CustomSpaceEnv(gym.Env):
def make_env(env_name, seed):
def _make():
env = gym.make(env_name)
env.seed(seed)
env.reset(seed=seed)
return env
return _make
@@ -107,7 +111,7 @@ def make_env(env_name, seed):
def make_slow_env(slow_reset, seed):
def _make():
env = UnittestSlowEnv(slow_reset=slow_reset)
env.seed(seed)
env.reset(seed=seed)
return env
return _make
@@ -116,7 +120,7 @@ def make_slow_env(slow_reset, seed):
def make_custom_space_env(seed):
def _make():
env = CustomSpaceEnv()
env.seed(seed)
env.reset(seed=seed)
return env
return _make

View File

@@ -1,6 +1,7 @@
"""Tests for the flatten observation wrapper."""
from collections import OrderedDict
from typing import Optional
import numpy as np
import pytest
@@ -14,7 +15,8 @@ class FakeEnvironment(gym.Env):
def __init__(self, observation_space):
self.observation_space = observation_space
def reset(self):
def reset(self, seed: Optional[int] = None):
super().reset(seed=seed)
self.observation = self.observation_space.sample()
return self.observation

View File

@@ -1,5 +1,5 @@
"""Tests for the filter observation wrapper."""
from typing import Optional
import pytest
import numpy as np
@@ -21,7 +21,8 @@ class FakeEnvironment(gym.Env):
image_shape = (height, width, 3)
return np.zeros(image_shape, dtype=np.uint8)
def reset(self):
def reset(self, seed: Optional[int] = None):
super().reset(seed=seed)
observation = self.observation_space.sample()
return observation

View File

@@ -29,14 +29,10 @@ def test_atari_preprocessing_grayscale(env_fn):
noop_max=0,
grayscale_newaxis=True,
)
env1.seed(0)
env2.seed(0)
env3.seed(0)
env4.seed(0)
obs1 = env1.reset()
obs2 = env2.reset()
obs3 = env3.reset()
obs4 = env4.reset()
obs1 = env1.reset(seed=0)
obs2 = env2.reset(seed=0)
obs3 = env3.reset(seed=0)
obs4 = env4.reset(seed=0)
assert env1.observation_space.shape == (210, 160, 3)
assert env2.observation_space.shape == (84, 84)
assert env3.observation_space.shape == (84, 84, 3)

View File

@@ -11,11 +11,9 @@ def test_clip_action():
wrapped_env = ClipAction(make_env())
seed = 0
env.seed(seed)
wrapped_env.seed(seed)
env.reset()
wrapped_env.reset()
env.reset(seed=seed)
wrapped_env.reset(seed=seed)
actions = [[0.4], [1.2], [-0.3], [0.0], [-2.5]]
for action in actions:

View File

@@ -1,3 +1,5 @@
from typing import Optional
import pytest
import numpy as np
@@ -22,7 +24,8 @@ class FakeEnvironment(gym.Env):
image_shape = (height, width, 3)
return np.zeros(image_shape, dtype=np.uint8)
def reset(self):
def reset(self, seed: Optional[int] = None):
super().reset(seed=seed)
observation = self.observation_space.sample()
return observation

View File

@@ -28,17 +28,15 @@ except ImportError:
)
def test_frame_stack(env_id, num_stack, lz4_compress):
env = gym.make(env_id)
env.seed(0)
shape = env.observation_space.shape
env = FrameStack(env, num_stack, lz4_compress)
assert env.observation_space.shape == (num_stack,) + shape
assert env.observation_space.dtype == env.env.observation_space.dtype
dup = gym.make(env_id)
dup.seed(0)
obs = env.reset()
dup_obs = dup.reset()
obs = env.reset(seed=0)
dup_obs = dup.reset(seed=0)
assert np.allclose(obs[-1], dup_obs)
for _ in range(num_stack ** 2):

View File

@@ -21,11 +21,9 @@ def test_gray_scale_observation(env_id, keep_dim):
assert rgb_env.observation_space.shape[-1] == 3
seed = 0
gray_env.seed(seed)
wrapped_env.seed(seed)
gray_obs = gray_env.reset()
wrapped_obs = wrapped_env.reset()
gray_obs = gray_env.reset(seed=seed)
wrapped_obs = wrapped_env.reset(seed=seed)
if keep_dim:
assert wrapped_env.observation_space.shape[-1] == 1

View File

@@ -1,3 +1,5 @@
from typing import Optional
import gym
import numpy as np
from numpy.testing import assert_almost_equal
@@ -21,7 +23,8 @@ class DummyRewardEnv(gym.Env):
self.t += 1
return np.array([self.t]), self.t, self.t == len(self.returned_rewards), {}
def reset(self):
def reset(self, seed: Optional[int] = None):
super().reset(seed=seed)
self.t = self.return_reward_idx
return np.array([self.t])

View File

@@ -1,5 +1,5 @@
"""Tests for the pixel observation wrapper."""
from typing import Optional
import pytest
import numpy as np
@@ -19,7 +19,8 @@ class FakeEnvironment(gym.Env):
image_shape = (height, width, 3)
return np.zeros(image_shape, dtype=np.uint8)
def reset(self):
def reset(self, seed: Optional[int] = None):
super().reset(seed=seed)
observation = self.observation_space.sample()
return observation

View File

@@ -16,11 +16,9 @@ def test_rescale_action():
wrapped_env = RescaleAction(gym.make("Pendulum-v1"), -1, 1)
seed = 0
env.seed(seed)
wrapped_env.seed(seed)
obs = env.reset()
wrapped_obs = wrapped_env.reset()
obs = env.reset(seed=seed)
wrapped_obs = wrapped_env.reset(seed=seed)
assert np.allclose(obs, wrapped_obs)
obs, reward, _, _ = env.step([1.5])

View File

@@ -14,11 +14,8 @@ def test_transform_observation(env_id):
gym.make(env_id), lambda obs: affine_transform(obs)
)
env.seed(0)
wrapped_env.seed(0)
obs = env.reset()
wrapped_obs = wrapped_env.reset()
obs = env.reset(seed=0)
wrapped_obs = wrapped_env.reset(seed=0)
assert np.allclose(wrapped_obs, affine_transform(obs))
action = env.action_space.sample()

View File

@@ -15,10 +15,8 @@ def test_transform_reward(env_id):
wrapped_env = TransformReward(gym.make(env_id), lambda r: scale * r)
action = env.action_space.sample()
env.seed(0)
env.reset()
wrapped_env.seed(0)
wrapped_env.reset()
env.reset(seed=0)
wrapped_env.reset(seed=0)
_, reward, _, _ = env.step(action)
_, wrapped_reward, _, _ = wrapped_env.step(action)
@@ -33,10 +31,8 @@ def test_transform_reward(env_id):
wrapped_env = TransformReward(gym.make(env_id), lambda r: np.clip(r, min_r, max_r))
action = env.action_space.sample()
env.seed(0)
env.reset()
wrapped_env.seed(0)
wrapped_env.reset()
env.reset(seed=0)
wrapped_env.reset(seed=0)
_, reward, _, _ = env.step(action)
_, wrapped_reward, _, _ = wrapped_env.step(action)
@@ -49,10 +45,8 @@ def test_transform_reward(env_id):
env = gym.make(env_id)
wrapped_env = TransformReward(gym.make(env_id), lambda r: np.sign(r))
env.seed(0)
env.reset()
wrapped_env.seed(0)
wrapped_env.reset()
env.reset(seed=0)
wrapped_env.reset(seed=0)
for _ in range(1000):
action = env.action_space.sample()