Seeding update (#2422)

* Ditch most of seeding.py and replace np_random with NumPy's default_rng. Let's see if the tests pass

* Updated a bunch of RNG calls from the RandomState API to Generator API

* black; didn't expect that, did ya?

* Undo a typo

* blaaack

* More typo fixes

* Fixed setting/getting state in multidiscrete spaces

* Fix typo, fix a test to work with the new sampling

* Correctly (?) pass the randomly generated seed if np_random is called with None as seed

* Convert the Discrete sample to a python int (as opposed to np.int64)

* Remove some redundant imports

* First version of the compatibility layer for old-style RNG. Mainly to trigger tests.

* Removed redundant f-strings

* Style fixes, removing unused imports

* Try to make tests pass by removing atari from the dockerfile

* Try to make tests pass by removing atari from the setup

* Try to make tests pass by removing atari from the setup

* Try to make tests pass by removing atari from the setup

* First attempt at deprecating `env.seed` and supporting `env.reset(seed=seed)` instead. Tests should hopefully pass but throw up a million warnings.

* black; didn't expect that, didya?

* Rename the reset parameter in VecEnvs back to `seed`

* Updated tests to use the new seeding method

* Removed a bunch of old `seed` calls.

Fixed a bug in AsyncVectorEnv

* Stop Discrete envs from doing part of the setup (and using the randomness) in init (as opposed to reset)

* Add explicit seed to wrappers reset

* Remove an accidental return

* Re-add some legacy functions with a warning.

* Use deprecation instead of regular warnings for the newly deprecated methods/functions
This commit is contained in:
Ariel Kwiatkowski
2021-12-08 22:14:15 +01:00
committed by GitHub
parent b84b69c872
commit c364506710
59 changed files with 386 additions and 294 deletions

View File

@@ -50,6 +50,7 @@
* `gym-foo/gym_foo/envs/foo_env.py` should look something like: * `gym-foo/gym_foo/envs/foo_env.py` should look something like:
```python ```python
from typing import Optional
import gym import gym
from gym import error, spaces, utils from gym import error, spaces, utils
from gym.utils import seeding from gym.utils import seeding
@@ -61,7 +62,8 @@
... ...
def step(self, action): def step(self, action):
... ...
def reset(self): def reset(self, seed: Optional[int] = None):
super().reset(seed=seed)
... ...
def render(self, mode='human'): def render(self, mode='human'):
... ...

View File

@@ -1,8 +1,10 @@
from abc import abstractmethod from abc import abstractmethod
from typing import Optional
import gym import gym
from gym import error from gym import error
from gym.utils import closer from gym.utils import closer, seeding
from gym.logger import deprecation
class Env: class Env:
@@ -38,6 +40,9 @@ class Env:
action_space = None action_space = None
observation_space = None observation_space = None
# Created by reset() (or the deprecated seed()); stays None until then
np_random = None
@abstractmethod @abstractmethod
def step(self, action): def step(self, action):
"""Run one timestep of the environment's dynamics. When end of """Run one timestep of the environment's dynamics. When end of
@@ -58,7 +63,7 @@ class Env:
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def reset(self): def reset(self, seed: Optional[int] = None):
"""Resets the environment to an initial state and returns an initial """Resets the environment to an initial state and returns an initial
observation. observation.
@@ -71,7 +76,9 @@ class Env:
Returns: Returns:
observation (object): the initial observation. observation (object): the initial observation.
""" """
raise NotImplementedError # Initialize the RNG if it's the first reset, or if the seed is manually passed
if seed is not None or self.np_random is None:
self.np_random, seed = seeding.np_random(seed)
@abstractmethod @abstractmethod
def render(self, mode="human"): def render(self, mode="human"):
@@ -136,7 +143,12 @@ class Env:
'seed'. Often, the main seed equals the provided 'seed', but 'seed'. Often, the main seed equals the provided 'seed', but
this won't be true if seed=None, for example. this won't be true if seed=None, for example.
""" """
return deprecation(
"Function `env.seed(seed)` is marked as deprecated and will be removed in the future. "
"Please use `env.reset(seed=seed) instead."
)
self.np_random, seed = seeding.np_random(seed)
return [seed]
@property @property
def unwrapped(self): def unwrapped(self):
@@ -173,7 +185,8 @@ class GoalEnv(Env):
actual observations of the environment as per usual. actual observations of the environment as per usual.
""" """
def reset(self): def reset(self, seed: Optional[int] = None):
super().reset(seed=seed)
# Enforce that each GoalEnv uses a Goal-compatible observation space. # Enforce that each GoalEnv uses a Goal-compatible observation space.
if not isinstance(self.observation_space, gym.spaces.Dict): if not isinstance(self.observation_space, gym.spaces.Dict):
raise error.Error( raise error.Error(
@@ -286,8 +299,8 @@ class Wrapper(Env):
def step(self, action): def step(self, action):
return self.env.step(action) return self.env.step(action)
def reset(self, **kwargs): def reset(self, seed: Optional[int] = None, **kwargs):
return self.env.reset(**kwargs) return self.env.reset(seed=seed, **kwargs)
def render(self, mode="human", **kwargs): def render(self, mode="human", **kwargs):
return self.env.render(mode, **kwargs) return self.env.render(mode, **kwargs)
@@ -313,8 +326,8 @@ class Wrapper(Env):
class ObservationWrapper(Wrapper): class ObservationWrapper(Wrapper):
def reset(self, **kwargs): def reset(self, seed: Optional[int] = None, **kwargs):
observation = self.env.reset(**kwargs) observation = self.env.reset(seed=seed, **kwargs)
return self.observation(observation) return self.observation(observation)
def step(self, action): def step(self, action):
@@ -327,8 +340,8 @@ class ObservationWrapper(Wrapper):
class RewardWrapper(Wrapper): class RewardWrapper(Wrapper):
def reset(self, **kwargs): def reset(self, seed: Optional[int] = None, **kwargs):
return self.env.reset(**kwargs) return self.env.reset(seed=seed, **kwargs)
def step(self, action): def step(self, action):
observation, reward, done, info = self.env.step(action) observation, reward, done, info = self.env.step(action)
@@ -340,8 +353,8 @@ class RewardWrapper(Wrapper):
class ActionWrapper(Wrapper): class ActionWrapper(Wrapper):
def reset(self, **kwargs): def reset(self, seed: Optional[int] = None, **kwargs):
return self.env.reset(**kwargs) return self.env.reset(seed=seed, **kwargs)
def step(self, action): def step(self, action):
return self.env.step(self.action(action)) return self.env.step(self.action(action))

View File

@@ -1,5 +1,6 @@
import sys import sys
import math import math
from typing import Optional
import numpy as np import numpy as np
import Box2D import Box2D
@@ -122,7 +123,6 @@ class BipedalWalker(gym.Env, EzPickle):
def __init__(self): def __init__(self):
EzPickle.__init__(self) EzPickle.__init__(self)
self.seed()
self.viewer = None self.viewer = None
self.world = Box2D.b2World() self.world = Box2D.b2World()
@@ -149,10 +149,6 @@ class BipedalWalker(gym.Env, EzPickle):
) )
self.observation_space = spaces.Box(-high, high) self.observation_space = spaces.Box(-high, high)
def seed(self, seed=None):
self.np_random, seed = seeding.np_random(seed)
return [seed]
def _destroy(self): def _destroy(self):
if not self.terrain: if not self.terrain:
return return
@@ -188,7 +184,7 @@ class BipedalWalker(gym.Env, EzPickle):
y += velocity y += velocity
elif state == PIT and oneshot: elif state == PIT and oneshot:
counter = self.np_random.randint(3, 5) counter = self.np_random.integers(3, 5)
poly = [ poly = [
(x, y), (x, y),
(x + TERRAIN_STEP, y), (x + TERRAIN_STEP, y),
@@ -215,7 +211,7 @@ class BipedalWalker(gym.Env, EzPickle):
y -= 4 * TERRAIN_STEP y -= 4 * TERRAIN_STEP
elif state == STUMP and oneshot: elif state == STUMP and oneshot:
counter = self.np_random.randint(1, 3) counter = self.np_random.integers(1, 3)
poly = [ poly = [
(x, y), (x, y),
(x + counter * TERRAIN_STEP, y), (x + counter * TERRAIN_STEP, y),
@@ -228,9 +224,9 @@ class BipedalWalker(gym.Env, EzPickle):
self.terrain.append(t) self.terrain.append(t)
elif state == STAIRS and oneshot: elif state == STAIRS and oneshot:
stair_height = +1 if self.np_random.rand() > 0.5 else -1 stair_height = +1 if self.np_random.random() > 0.5 else -1
stair_width = self.np_random.randint(4, 5) stair_width = self.np_random.integers(4, 5)
stair_steps = self.np_random.randint(3, 5) stair_steps = self.np_random.integers(3, 5)
original_y = y original_y = y
for s in range(stair_steps): for s in range(stair_steps):
poly = [ poly = [
@@ -266,9 +262,9 @@ class BipedalWalker(gym.Env, EzPickle):
self.terrain_y.append(y) self.terrain_y.append(y)
counter -= 1 counter -= 1
if counter == 0: if counter == 0:
counter = self.np_random.randint(TERRAIN_GRASS / 2, TERRAIN_GRASS) counter = self.np_random.integers(TERRAIN_GRASS / 2, TERRAIN_GRASS)
if state == GRASS and hardcore: if state == GRASS and hardcore:
state = self.np_random.randint(1, _STATES_) state = self.np_random.integers(1, _STATES_)
oneshot = True oneshot = True
else: else:
state = GRASS state = GRASS
@@ -312,7 +308,8 @@ class BipedalWalker(gym.Env, EzPickle):
x2 = max(p[0] for p in poly) x2 = max(p[0] for p in poly)
self.cloud_poly.append((poly, x1, x2)) self.cloud_poly.append((poly, x1, x2))
def reset(self): def reset(self, seed: Optional[int] = None):
super().reset(seed=seed)
self._destroy() self._destroy()
self.world.contactListener_bug_workaround = ContactDetector(self) self.world.contactListener_bug_workaround = ContactDetector(self)
self.world.contactListener = self.world.contactListener_bug_workaround self.world.contactListener = self.world.contactListener_bug_workaround

View File

@@ -32,6 +32,8 @@ Created by Oleg Klimov. Licensed on the same terms as the rest of OpenAI Gym.
""" """
import sys import sys
import math import math
from typing import Optional
import numpy as np import numpy as np
import Box2D import Box2D
@@ -121,7 +123,6 @@ class CarRacing(gym.Env, EzPickle):
def __init__(self, verbose=1): def __init__(self, verbose=1):
EzPickle.__init__(self) EzPickle.__init__(self)
self.seed()
self.contactListener_keepref = FrictionDetector(self) self.contactListener_keepref = FrictionDetector(self)
self.world = Box2D.b2World((0, 0), contactListener=self.contactListener_keepref) self.world = Box2D.b2World((0, 0), contactListener=self.contactListener_keepref)
self.viewer = None self.viewer = None
@@ -145,10 +146,6 @@ class CarRacing(gym.Env, EzPickle):
low=0, high=255, shape=(STATE_H, STATE_W, 3), dtype=np.uint8 low=0, high=255, shape=(STATE_H, STATE_W, 3), dtype=np.uint8
) )
def seed(self, seed=None):
self.np_random, seed = seeding.np_random(seed)
return [seed]
def _destroy(self): def _destroy(self):
if not self.road: if not self.road:
return return
@@ -343,7 +340,8 @@ class CarRacing(gym.Env, EzPickle):
self.track = track self.track = track
return True return True
def reset(self): def reset(self, seed: Optional[int] = None):
super().reset(seed=seed)
self._destroy() self._destroy()
self.reward = 0.0 self.reward = 0.0
self.prev_reward = 0.0 self.prev_reward = 0.0

View File

@@ -28,6 +28,8 @@ Created by Oleg Klimov. Licensed on the same terms as the rest of OpenAI Gym.
import math import math
import sys import sys
from typing import Optional
import numpy as np import numpy as np
import Box2D import Box2D
@@ -93,7 +95,6 @@ class LunarLander(gym.Env, EzPickle):
def __init__(self): def __init__(self):
EzPickle.__init__(self) EzPickle.__init__(self)
self.seed()
self.viewer = None self.viewer = None
self.world = Box2D.b2World() self.world = Box2D.b2World()
@@ -117,10 +118,6 @@ class LunarLander(gym.Env, EzPickle):
# Nop, fire left engine, main engine, right engine # Nop, fire left engine, main engine, right engine
self.action_space = spaces.Discrete(4) self.action_space = spaces.Discrete(4)
def seed(self, seed=None):
self.np_random, seed = seeding.np_random(seed)
return [seed]
def _destroy(self): def _destroy(self):
if not self.moon: if not self.moon:
return return
@@ -133,7 +130,8 @@ class LunarLander(gym.Env, EzPickle):
self.world.DestroyBody(self.legs[0]) self.world.DestroyBody(self.legs[0])
self.world.DestroyBody(self.legs[1]) self.world.DestroyBody(self.legs[1])
def reset(self): def reset(self, seed: Optional[int] = None):
super().reset(seed=seed)
self._destroy() self._destroy()
self.world.contactListener_keepref = ContactDetector(self) self.world.contactListener_keepref = ContactDetector(self)
self.world.contactListener = self.world.contactListener_keepref self.world.contactListener = self.world.contactListener_keepref
@@ -504,10 +502,9 @@ def heuristic(env, s):
def demo_heuristic_lander(env, seed=None, render=False): def demo_heuristic_lander(env, seed=None, render=False):
env.seed(seed)
total_reward = 0 total_reward = 0
steps = 0 steps = 0
s = env.reset() s = env.reset(seed=seed)
while True: while True:
a = heuristic(env, s) a = heuristic(env, s)
s, r, done, info = env.step(a) s, r, done, info = env.step(a)

View File

@@ -1,4 +1,6 @@
"""classic Acrobot task""" """classic Acrobot task"""
from typing import Optional
import numpy as np import numpy as np
from numpy import sin, cos, pi from numpy import sin, cos, pi
@@ -94,13 +96,9 @@ class AcrobotEnv(core.Env):
self.observation_space = spaces.Box(low=low, high=high, dtype=np.float32) self.observation_space = spaces.Box(low=low, high=high, dtype=np.float32)
self.action_space = spaces.Discrete(3) self.action_space = spaces.Discrete(3)
self.state = None self.state = None
self.seed()
def seed(self, seed=None): def reset(self, seed: Optional[int] = None):
self.np_random, seed = seeding.np_random(seed) super().reset(seed=seed)
return [seed]
def reset(self):
self.state = self.np_random.uniform(low=-0.1, high=0.1, size=(4,)).astype( self.state = self.np_random.uniform(low=-0.1, high=0.1, size=(4,)).astype(
np.float32 np.float32
) )

View File

@@ -5,6 +5,8 @@ permalink: https://perma.cc/C9ZM-652R
""" """
import math import math
from typing import Optional
import gym import gym
from gym import spaces, logger from gym import spaces, logger
from gym.utils import seeding from gym.utils import seeding
@@ -90,16 +92,11 @@ class CartPoleEnv(gym.Env):
self.action_space = spaces.Discrete(2) self.action_space = spaces.Discrete(2)
self.observation_space = spaces.Box(-high, high, dtype=np.float32) self.observation_space = spaces.Box(-high, high, dtype=np.float32)
self.seed()
self.viewer = None self.viewer = None
self.state = None self.state = None
self.steps_beyond_done = None self.steps_beyond_done = None
def seed(self, seed=None):
self.np_random, seed = seeding.np_random(seed)
return [seed]
def step(self, action): def step(self, action):
err_msg = f"{action!r} ({type(action)}) invalid" err_msg = f"{action!r} ({type(action)}) invalid"
assert self.action_space.contains(action), err_msg assert self.action_space.contains(action), err_msg
@@ -158,7 +155,8 @@ class CartPoleEnv(gym.Env):
return np.array(self.state, dtype=np.float32), reward, done, {} return np.array(self.state, dtype=np.float32), reward, done, {}
def reset(self): def reset(self, seed: Optional[int] = None):
super().reset(seed=seed)
self.state = self.np_random.uniform(low=-0.05, high=0.05, size=(4,)) self.state = self.np_random.uniform(low=-0.05, high=0.05, size=(4,))
self.steps_beyond_done = None self.steps_beyond_done = None
return np.array(self.state, dtype=np.float32) return np.array(self.state, dtype=np.float32)

View File

@@ -14,6 +14,7 @@ permalink: https://perma.cc/6Z2N-PFWC
""" """
import math import math
from typing import Optional
import numpy as np import numpy as np
@@ -83,12 +84,6 @@ class Continuous_MountainCarEnv(gym.Env):
low=self.low_state, high=self.high_state, dtype=np.float32 low=self.low_state, high=self.high_state, dtype=np.float32
) )
self.seed()
def seed(self, seed=None):
self.np_random, seed = seeding.np_random(seed)
return [seed]
def step(self, action): def step(self, action):
position = self.state[0] position = self.state[0]
@@ -119,7 +114,8 @@ class Continuous_MountainCarEnv(gym.Env):
self.state = np.array([position, velocity], dtype=np.float32) self.state = np.array([position, velocity], dtype=np.float32)
return self.state, reward, done, {} return self.state, reward, done, {}
def reset(self): def reset(self, seed: Optional[int] = None):
super().reset(seed=seed)
self.state = np.array([self.np_random.uniform(low=-0.6, high=-0.4), 0]) self.state = np.array([self.np_random.uniform(low=-0.6, high=-0.4), 0])
return np.array(self.state, dtype=np.float32) return np.array(self.state, dtype=np.float32)

View File

@@ -3,6 +3,7 @@ http://incompleteideas.net/MountainCar/MountainCar1.cp
permalink: https://perma.cc/6Z2N-PFWC permalink: https://perma.cc/6Z2N-PFWC
""" """
import math import math
from typing import Optional
import numpy as np import numpy as np
@@ -72,12 +73,6 @@ class MountainCarEnv(gym.Env):
self.action_space = spaces.Discrete(3) self.action_space = spaces.Discrete(3)
self.observation_space = spaces.Box(self.low, self.high, dtype=np.float32) self.observation_space = spaces.Box(self.low, self.high, dtype=np.float32)
self.seed()
def seed(self, seed=None):
self.np_random, seed = seeding.np_random(seed)
return [seed]
def step(self, action): def step(self, action):
assert self.action_space.contains( assert self.action_space.contains(
action action
@@ -97,7 +92,8 @@ class MountainCarEnv(gym.Env):
self.state = (position, velocity) self.state = (position, velocity)
return np.array(self.state, dtype=np.float32), reward, done, {} return np.array(self.state, dtype=np.float32), reward, done, {}
def reset(self): def reset(self, seed: Optional[int] = None):
super().reset(seed=seed)
self.state = np.array([self.np_random.uniform(low=-0.6, high=-0.4), 0]) self.state = np.array([self.np_random.uniform(low=-0.6, high=-0.4), 0])
return np.array(self.state, dtype=np.float32) return np.array(self.state, dtype=np.float32)

View File

@@ -1,3 +1,5 @@
from typing import Optional
import gym import gym
from gym import spaces from gym import spaces
from gym.utils import seeding from gym.utils import seeding
@@ -24,12 +26,6 @@ class PendulumEnv(gym.Env):
) )
self.observation_space = spaces.Box(low=-high, high=high, dtype=np.float32) self.observation_space = spaces.Box(low=-high, high=high, dtype=np.float32)
self.seed()
def seed(self, seed=None):
self.np_random, seed = seeding.np_random(seed)
return [seed]
def step(self, u): def step(self, u):
th, thdot = self.state # th := theta th, thdot = self.state # th := theta
@@ -49,7 +45,8 @@ class PendulumEnv(gym.Env):
self.state = np.array([newth, newthdot]) self.state = np.array([newth, newthdot])
return self._get_obs(), -costs, False, {} return self._get_obs(), -costs, False, {}
def reset(self): def reset(self, seed: Optional[int] = None):
super().reset(seed=seed)
high = np.array([np.pi, 1]) high = np.array([np.pi, 1])
self.state = self.np_random.uniform(low=-high, high=high) self.state = self.np_random.uniform(low=-high, high=high)
self.last_u = None self.last_u = None

View File

@@ -48,7 +48,7 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
qpos = self.init_qpos + self.np_random.uniform( qpos = self.init_qpos + self.np_random.uniform(
size=self.model.nq, low=-0.1, high=0.1 size=self.model.nq, low=-0.1, high=0.1
) )
qvel = self.init_qvel + self.np_random.randn(self.model.nv) * 0.1 qvel = self.init_qvel + self.np_random.standard_normal(self.model.nv) * 0.1
self.set_state(qpos, qvel) self.set_state(qpos, qvel)
return self._get_obs() return self._get_obs()

View File

@@ -131,8 +131,9 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
qpos = self.init_qpos + self.np_random.uniform( qpos = self.init_qpos + self.np_random.uniform(
low=noise_low, high=noise_high, size=self.model.nq low=noise_low, high=noise_high, size=self.model.nq
) )
qvel = self.init_qvel + self._reset_noise_scale * self.np_random.randn( qvel = (
self.model.nv self.init_qvel
+ self._reset_noise_scale * self.np_random.standard_normal(self.model.nv)
) )
self.set_state(qpos, qvel) self.set_state(qpos, qvel)

View File

@@ -31,7 +31,7 @@ class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):
qpos = self.init_qpos + self.np_random.uniform( qpos = self.init_qpos + self.np_random.uniform(
low=-0.1, high=0.1, size=self.model.nq low=-0.1, high=0.1, size=self.model.nq
) )
qvel = self.init_qvel + self.np_random.randn(self.model.nv) * 0.1 qvel = self.init_qvel + self.np_random.standard_normal(self.model.nv) * 0.1
self.set_state(qpos, qvel) self.set_state(qpos, qvel)
return self._get_obs() return self._get_obs()

View File

@@ -74,8 +74,9 @@ class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):
qpos = self.init_qpos + self.np_random.uniform( qpos = self.init_qpos + self.np_random.uniform(
low=noise_low, high=noise_high, size=self.model.nq low=noise_low, high=noise_high, size=self.model.nq
) )
qvel = self.init_qvel + self._reset_noise_scale * self.np_random.randn( qvel = (
self.model.nv self.init_qvel
+ self._reset_noise_scale * self.np_random.standard_normal(self.model.nv)
) )
self.set_state(qpos, qvel) self.set_state(qpos, qvel)

View File

@@ -35,7 +35,7 @@ class InvertedDoublePendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle):
self.set_state( self.set_state(
self.init_qpos self.init_qpos
+ self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq), + self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq),
self.init_qvel + self.np_random.randn(self.model.nv) * 0.1, self.init_qvel + self.np_random.standard_normal(self.model.nv) * 0.1,
) )
return self._get_obs() return self._get_obs()

View File

@@ -1,6 +1,6 @@
from collections import OrderedDict from collections import OrderedDict
import os import os
from typing import Optional
from gym import error, spaces from gym import error, spaces
from gym.utils import seeding from gym.utils import seeding
@@ -73,8 +73,6 @@ class MujocoEnv(gym.Env):
self._set_observation_space(observation) self._set_observation_space(observation)
self.seed()
def _set_action_space(self): def _set_action_space(self):
bounds = self.model.actuator_ctrlrange.copy().astype(np.float32) bounds = self.model.actuator_ctrlrange.copy().astype(np.float32)
low, high = bounds.T low, high = bounds.T
@@ -85,10 +83,6 @@ class MujocoEnv(gym.Env):
self.observation_space = convert_observation_to_space(observation) self.observation_space = convert_observation_to_space(observation)
return self.observation_space return self.observation_space
def seed(self, seed=None):
self.np_random, seed = seeding.np_random(seed)
return [seed]
# methods to override: # methods to override:
# ---------------------------- # ----------------------------
@@ -109,7 +103,8 @@ class MujocoEnv(gym.Env):
# ----------------------------- # -----------------------------
def reset(self): def reset(self, seed: Optional[int] = None):
super().reset(seed=seed)
self.sim.reset() self.sim.reset()
ob = self.reset_model() ob = self.reset_model()
return ob return ob

View File

@@ -183,7 +183,7 @@ class ManipulateEnv(hand_env.HandEnv):
axis = np.array([0.0, 0.0, 1.0]) axis = np.array([0.0, 0.0, 1.0])
z_quat = quat_from_angle_and_axis(angle, axis) z_quat = quat_from_angle_and_axis(angle, axis)
parallel_quat = self.parallel_quats[ parallel_quat = self.parallel_quats[
self.np_random.randint(len(self.parallel_quats)) self.np_random.integers(len(self.parallel_quats))
] ]
offset_quat = rotations.quat_mul(z_quat, parallel_quat) offset_quat = rotations.quat_mul(z_quat, parallel_quat)
initial_quat = rotations.quat_mul(initial_quat, offset_quat) initial_quat = rotations.quat_mul(initial_quat, offset_quat)
@@ -254,7 +254,7 @@ class ManipulateEnv(hand_env.HandEnv):
axis = np.array([0.0, 0.0, 1.0]) axis = np.array([0.0, 0.0, 1.0])
target_quat = quat_from_angle_and_axis(angle, axis) target_quat = quat_from_angle_and_axis(angle, axis)
parallel_quat = self.parallel_quats[ parallel_quat = self.parallel_quats[
self.np_random.randint(len(self.parallel_quats)) self.np_random.integers(len(self.parallel_quats))
] ]
target_quat = rotations.quat_mul(target_quat, parallel_quat) target_quat = rotations.quat_mul(target_quat, parallel_quat)
elif self.target_rotation == "xyz": elif self.target_rotation == "xyz":

View File

@@ -1,5 +1,7 @@
import os import os
import copy import copy
from typing import Optional
import numpy as np import numpy as np
import gym import gym
@@ -37,7 +39,6 @@ class RobotEnv(gym.GoalEnv):
"video.frames_per_second": int(np.round(1.0 / self.dt)), "video.frames_per_second": int(np.round(1.0 / self.dt)),
} }
self.seed()
self._env_setup(initial_qpos=initial_qpos) self._env_setup(initial_qpos=initial_qpos)
self.initial_state = copy.deepcopy(self.sim.get_state()) self.initial_state = copy.deepcopy(self.sim.get_state())
@@ -65,10 +66,6 @@ class RobotEnv(gym.GoalEnv):
# Env methods # Env methods
# ---------------------------- # ----------------------------
def seed(self, seed=None):
self.np_random, seed = seeding.np_random(seed)
return [seed]
def step(self, action): def step(self, action):
if np.array(action).shape != self.action_space.shape: if np.array(action).shape != self.action_space.shape:
raise ValueError("Action dimension mismatch") raise ValueError("Action dimension mismatch")
@@ -86,13 +83,13 @@ class RobotEnv(gym.GoalEnv):
reward = self.compute_reward(obs["achieved_goal"], self.goal, info) reward = self.compute_reward(obs["achieved_goal"], self.goal, info)
return obs, reward, done, info return obs, reward, done, info
def reset(self): def reset(self, seed: Optional[int] = None):
# Attempt to reset the simulator. Since we randomize initial conditions, it # Attempt to reset the simulator. Since we randomize initial conditions, it
# is possible to get into a state with numerical issues (e.g. due to penetration or # is possible to get into a state with numerical issues (e.g. due to penetration or
# Gimbal lock) or we may not achieve an initial condition (e.g. an object is within the hand). # Gimbal lock) or we may not achieve an initial condition (e.g. an object is within the hand).
# In this case, we just keep randomizing until we eventually achieve a valid initial # In this case, we just keep randomizing until we eventually achieve a valid initial
# configuration. # configuration.
super().reset() super().reset(seed=seed)
did_reset_sim = False did_reset_sim = False
while not did_reset_sim: while not did_reset_sim:
did_reset_sim = self._reset_sim() did_reset_sim = self._reset_sim()

View File

@@ -1,3 +1,5 @@
from typing import Optional
import gym import gym
from gym import spaces from gym import spaces
from gym.utils import seeding from gym.utils import seeding
@@ -78,7 +80,6 @@ class BlackjackEnv(gym.Env):
self.observation_space = spaces.Tuple( self.observation_space = spaces.Tuple(
(spaces.Discrete(32), spaces.Discrete(11), spaces.Discrete(2)) (spaces.Discrete(32), spaces.Discrete(11), spaces.Discrete(2))
) )
self.seed()
# Flag to payout 1.5 on a "natural" blackjack win, like casino rules # Flag to payout 1.5 on a "natural" blackjack win, like casino rules
# Ref: http://www.bicyclecards.com/how-to-play/blackjack/ # Ref: http://www.bicyclecards.com/how-to-play/blackjack/
@@ -87,10 +88,6 @@ class BlackjackEnv(gym.Env):
# Flag for full agreement with the (Sutton and Barto, 2018) definition. Overrides self.natural # Flag for full agreement with the (Sutton and Barto, 2018) definition. Overrides self.natural
self.sab = sab self.sab = sab
def seed(self, seed=None):
self.np_random, seed = seeding.np_random(seed)
return [seed]
def step(self, action): def step(self, action):
assert self.action_space.contains(action) assert self.action_space.contains(action)
if action: # hit: add a card to players hand and return if action: # hit: add a card to players hand and return
@@ -122,7 +119,8 @@ class BlackjackEnv(gym.Env):
def _get_obs(self): def _get_obs(self):
return (sum_hand(self.player), self.dealer[0], usable_ace(self.player)) return (sum_hand(self.player), self.dealer[0], usable_ace(self.player))
def reset(self): def reset(self, seed: Optional[int] = None):
super().reset(seed=seed)
self.dealer = draw_hand(self.np_random) self.dealer = draw_hand(self.np_random)
self.player = draw_hand(self.np_random) self.player = draw_hand(self.np_random)
return self._get_obs() return self._get_obs()

View File

@@ -1,3 +1,5 @@
from typing import Optional
import numpy as np import numpy as np
from gym import Env, spaces from gym import Env, spaces
@@ -11,7 +13,7 @@ def categorical_sample(prob_n, np_random):
""" """
prob_n = np.asarray(prob_n) prob_n = np.asarray(prob_n)
csprob_n = np.cumsum(prob_n) csprob_n = np.cumsum(prob_n)
return (csprob_n > np_random.rand()).argmax() return (csprob_n > np_random.random()).argmax()
class DiscreteEnv(Env): class DiscreteEnv(Env):
@@ -40,14 +42,8 @@ class DiscreteEnv(Env):
self.action_space = spaces.Discrete(self.nA) self.action_space = spaces.Discrete(self.nA)
self.observation_space = spaces.Discrete(self.nS) self.observation_space = spaces.Discrete(self.nS)
self.seed() def reset(self, seed: Optional[int] = None):
self.s = categorical_sample(self.isd, self.np_random) super().reset(seed=seed)
def seed(self, seed=None):
self.np_random, seed = seeding.np_random(seed)
return [seed]
def reset(self):
self.s = categorical_sample(self.isd, self.np_random) self.s = categorical_sample(self.isd, self.np_random)
self.lastaction = None self.lastaction = None
return int(self.s) return int(self.s)

View File

@@ -1,3 +1,5 @@
from typing import Optional
import numpy as np import numpy as np
import gym import gym
from gym import spaces from gym import spaces
@@ -52,7 +54,6 @@ class CubeCrash(gym.Env):
use_random_colors = False # Makes env too hard use_random_colors = False # Makes env too hard
def __init__(self): def __init__(self):
self.seed()
self.viewer = None self.viewer = None
self.observation_space = spaces.Box( self.observation_space = spaces.Box(
@@ -62,23 +63,20 @@ class CubeCrash(gym.Env):
self.reset() self.reset()
def seed(self, seed=None):
self.np_random, seed = seeding.np_random(seed)
return [seed]
def random_color(self): def random_color(self):
return np.array( return np.array(
[ [
self.np_random.randint(low=0, high=255), self.np_random.integers(low=0, high=255),
self.np_random.randint(low=0, high=255), self.np_random.integers(low=0, high=255),
self.np_random.randint(low=0, high=255), self.np_random.integers(low=0, high=255),
] ]
).astype("uint8") ).astype("uint8")
def reset(self): def reset(self, seed: Optional[int] = None):
self.cube_x = self.np_random.randint(low=3, high=FIELD_W - 3) super().reset(seed=seed)
self.cube_y = self.np_random.randint(low=3, high=FIELD_H // 6) self.cube_x = self.np_random.integers(low=3, high=FIELD_W - 3)
self.hole_x = self.np_random.randint(low=HOLE_WIDTH, high=FIELD_W - HOLE_WIDTH) self.cube_y = self.np_random.integers(low=3, high=FIELD_H // 6)
self.hole_x = self.np_random.integers(low=HOLE_WIDTH, high=FIELD_W - HOLE_WIDTH)
self.bg_color = self.random_color() if self.use_random_colors else color_black self.bg_color = self.random_color() if self.use_random_colors else color_black
self.potential = None self.potential = None
self.step_n = 0 self.step_n = 0
@@ -95,6 +93,7 @@ class CubeCrash(gym.Env):
): ):
continue continue
break break
return self.step(0)[0] return self.step(0)[0]
def step(self, action): def step(self, action):

View File

@@ -1,3 +1,5 @@
from typing import Optional
import numpy as np import numpy as np
import gym import gym
from gym import spaces from gym import spaces
@@ -60,7 +62,6 @@ class MemorizeDigits(gym.Env):
use_random_colors = False use_random_colors = False
def __init__(self): def __init__(self):
self.seed()
self.viewer = None self.viewer = None
self.observation_space = spaces.Box( self.observation_space = spaces.Box(
0, 255, (FIELD_H, FIELD_W, 3), dtype=np.uint8 0, 255, (FIELD_H, FIELD_W, 3), dtype=np.uint8
@@ -74,22 +75,19 @@ class MemorizeDigits(gym.Env):
] ]
self.reset() self.reset()
def seed(self, seed=None):
self.np_random, seed = seeding.np_random(seed)
return [seed]
def random_color(self): def random_color(self):
return np.array( return np.array(
[ [
self.np_random.randint(low=0, high=255), self.np_random.integers(low=0, high=255),
self.np_random.randint(low=0, high=255), self.np_random.integers(low=0, high=255),
self.np_random.randint(low=0, high=255), self.np_random.integers(low=0, high=255),
] ]
).astype("uint8") ).astype("uint8")
def reset(self): def reset(self, seed: Optional[int] = None):
self.digit_x = self.np_random.randint(low=FIELD_W // 5, high=FIELD_W // 5 * 4) super().reset(seed=seed)
self.digit_y = self.np_random.randint(low=FIELD_H // 5, high=FIELD_H // 5 * 4) self.digit_x = self.np_random.integers(low=FIELD_W // 5, high=FIELD_W // 5 * 4)
self.digit_y = self.np_random.integers(low=FIELD_H // 5, high=FIELD_H // 5 * 4)
self.color_bg = self.random_color() if self.use_random_colors else color_black self.color_bg = self.random_color() if self.use_random_colors else color_black
self.step_n = 0 self.step_n = 0
while 1: while 1:
@@ -111,8 +109,8 @@ class MemorizeDigits(gym.Env):
else: else:
if self.digit == action: if self.digit == action:
reward = +1 reward = +1
done = self.step_n > 20 and 0 == self.np_random.randint(low=0, high=5) done = self.step_n > 20 and 0 == self.np_random.integers(low=0, high=5)
self.digit = self.np_random.randint(low=0, high=10) self.digit = self.np_random.integers(low=0, high=10)
obs = np.zeros((FIELD_H, FIELD_W, 3), dtype=np.uint8) obs = np.zeros((FIELD_H, FIELD_W, 3), dtype=np.uint8)
obs[:, :, :] = self.color_bg obs[:, :, :] = self.color_bg
digit_img = np.zeros((6, 6, 3), dtype=np.uint8) digit_img = np.zeros((6, 6, 3), dtype=np.uint8)

View File

@@ -35,7 +35,7 @@ class MultiBinary(Space):
super().__init__(input_n, np.int8, seed) super().__init__(input_n, np.int8, seed)
def sample(self): def sample(self):
return self.np_random.randint(low=0, high=2, size=self.n, dtype=self.dtype) return self.np_random.integers(low=0, high=2, size=self.n, dtype=self.dtype)
def contains(self, x): def contains(self, x):
if isinstance(x, list) or isinstance(x, tuple): if isinstance(x, list) or isinstance(x, tuple):

View File

@@ -35,9 +35,7 @@ class MultiDiscrete(Space):
super().__init__(self.nvec.shape, dtype, seed) super().__init__(self.nvec.shape, dtype, seed)
def sample(self): def sample(self):
return (self.np_random.random_sample(self.nvec.shape) * self.nvec).astype( return (self.np_random.random(self.nvec.shape) * self.nvec).astype(self.dtype)
self.dtype
)
def contains(self, x): def contains(self, x):
if isinstance(x, list): if isinstance(x, list):
@@ -61,7 +59,7 @@ class MultiDiscrete(Space):
subspace = Discrete(nvec) subspace = Discrete(nvec)
else: else:
subspace = MultiDiscrete(nvec, self.dtype) subspace = MultiDiscrete(nvec, self.dtype)
subspace.np_random.set_state(self.np_random.get_state()) # for reproducibility subspace.np_random.bit_generator.state = self.np_random.bit_generator.state
return subspace return subspace
def __len__(self): def __len__(self):

View File

@@ -1,43 +1,118 @@
import hashlib import hashlib
import numpy as np
import os import os
import random as _random
import struct import struct
import sys from typing import Optional
import numpy as np
from numpy.random import Generator
from gym import error from gym import error
from gym.logger import deprecation
def np_random(seed: Optional[int] = None):
    """Create a seeded random number generator.

    Args:
        seed: Non-negative integer used to seed the generator (Python ``int``
            or NumPy integer scalar). ``None`` lets ``SeedSequence`` pick its
            own entropy.

    Returns:
        A ``(rng, seed)`` tuple: ``rng`` is a ``RandomNumberGenerator`` backed
        by ``PCG64``; ``seed`` is the entropy the ``SeedSequence`` was built
        with, i.e. the given seed, or the generated one when ``seed`` was
        ``None``.

    Raises:
        error.Error: If ``seed`` is neither ``None`` nor a non-negative integer.
    """
    if seed is not None:
        # Seeds frequently come out of NumPy (e.g. drawn from another
        # generator), so accept NumPy integer scalars as well as `int`.
        if not (isinstance(seed, (int, np.integer)) and 0 <= seed):
            raise error.Error(f"Seed must be a non-negative integer or omitted, not {seed}")
        # Normalise so the returned seed is always a plain Python int.
        seed = int(seed)

    seed_seq = np.random.SeedSequence(seed)
    # Report the entropy actually used so a run seeded with None can be replayed.
    seed = seed_seq.entropy
    rng = RandomNumberGenerator(np.random.PCG64(seed_seq))
    return rng, seed
# TODO: Remove this class and make it alias to `Generator` in a future Gym release
# RandomNumberGenerator = np.random.Generator
class RandomNumberGenerator(np.random.Generator):
    """``numpy.random.Generator`` with shims for the legacy ``RandomState`` method names.

    Each legacy method reports itself through ``deprecation`` and then
    delegates to the corresponding ``Generator`` call, so code written against
    the old ``RandomState``-style API keeps running while it is migrated.
    """

    def rand(self, *size):
        deprecation(
            "Function `rng.rand(*size)` is marked as deprecated "
            "and will be removed in the future. "
            "Please use `Generator.random(size)` instead."
        )
        # With no dimensions the legacy call yields a scalar; forwarding the
        # empty tuple would ask for a zero-dimensional array instead.
        return self.random(size if size else None)

    def random_sample(self, size=None, *dims):
        """Legacy ``random_sample``: floats drawn from ``[0.0, 1.0)``.

        ``size`` follows the legacy convention (``None``, an int, or a tuple
        of ints). Extra positional ``dims`` are tolerated for callers that
        used the ``rand``-style form ``random_sample(d0, d1, ...)``.
        """
        deprecation(
            "Function `rng.random_sample(size)` is marked as deprecated "
            "and will be removed in the future. "
            "Please use `Generator.random(size)` instead."
        )
        if dims:
            size = (size, *dims)
        return self.random(size)

    def randn(self, *size):
        deprecation(
            "Function `rng.randn(*size)` is marked as deprecated "
            "and will be removed in the future. "
            "Please use `rng.standard_normal(size)` instead."
        )
        # Same scalar-vs-zero-dimensional distinction as in `rand`.
        return self.standard_normal(size if size else None)

    def randint(self, low, high=None, size=None, dtype=int):
        deprecation(
            "Function `rng.randint(low, [high, size, dtype])` is marked as deprecated "
            "and will be removed in the future. "
            "Please use `rng.integers(low, [high, size, dtype])` instead."
        )
        return self.integers(low=low, high=high, size=size, dtype=dtype)

    def random_integers(self, low, high=None, size=None, dtype=int):
        """Legacy ``random_integers``: integers from the closed interval ``[low, high]``.

        Unlike ``randint`` the upper bound is inclusive, and when ``high`` is
        omitted the interval is ``[1, low]``.
        """
        deprecation(
            "Function `rng.random_integers(low, [high, size])` is marked as deprecated "
            "and will be removed in the future. "
            "Please use `rng.integers(low, [high, size], endpoint=True)` instead."
        )
        if high is None:
            # Legacy convention: a single bound means the range starts at 1.
            low, high = 1, low
        return self.integers(low=low, high=high, size=size, dtype=dtype, endpoint=True)

    def get_state(self):
        deprecation(
            "Function `rng.get_state()` is marked as deprecated "
            "and will be removed in the future. "
            "Please use `rng.bit_generator.state` instead."
        )
        return self.bit_generator.state

    def set_state(self, state):
        deprecation(
            "Function `rng.set_state(state)` is marked as deprecated "
            "and will be removed in the future. "
            "Please use `rng.bit_generator.state = state` instead."
        )
        self.bit_generator.state = state

    def seed(self, seed=None):
        deprecation(
            "Function `rng.seed(seed)` is marked as deprecated "
            "and will be removed in the future. "
            "Please use `rng, seed = gym.utils.seeding.np_random(seed)` to create a separate generator instead."
        )
        # Re-seed in place: build a fresh bit generator of the same type and
        # copy its state over, so existing references to this object stay valid.
        self.bit_generator.state = type(self.bit_generator)(seed).state

    # The shims above reuse the documentation of the legacy functions they stand in for.
    rand.__doc__ = np.random.rand.__doc__
    randn.__doc__ = np.random.randn.__doc__
    randint.__doc__ = np.random.randint.__doc__
    get_state.__doc__ = np.random.get_state.__doc__
    set_state.__doc__ = np.random.set_state.__doc__
    seed.__doc__ = np.random.seed.__doc__
RNG = RandomNumberGenerator
# Legacy functions
def hash_seed(seed=None, max_bytes=8): def hash_seed(seed=None, max_bytes=8):
"""Any given evaluation is likely to have many PRNG's active at """Any given evaluation is likely to have many PRNG's active at
once. (Most commonly, because the environment is running in once. (Most commonly, because the environment is running in
multiple processes.) There's literature indicating that having multiple processes.) There's literature indicating that having
linear correlations between seeds of multiple PRNG's can correlate linear correlations between seeds of multiple PRNG's can correlate
the outputs: the outputs:
http://blogs.unity3d.com/2015/01/07/a-primer-on-repeatable-random-numbers/ http://blogs.unity3d.com/2015/01/07/a-primer-on-repeatable-random-numbers/
http://stackoverflow.com/questions/1554958/how-different-do-random-seeds-need-to-be http://stackoverflow.com/questions/1554958/how-different-do-random-seeds-need-to-be
http://dl.acm.org/citation.cfm?id=1276928 http://dl.acm.org/citation.cfm?id=1276928
Thus, for sanity we hash the seeds before using them. (This scheme Thus, for sanity we hash the seeds before using them. (This scheme
is likely not crypto-strength, but it should be good enough to get is likely not crypto-strength, but it should be good enough to get
rid of simple correlations.) rid of simple correlations.)
Args: Args:
seed (Optional[int]): None seeds from an operating system specific randomness source. seed (Optional[int]): None seeds from an operating system specific randomness source.
max_bytes: Maximum number of bytes to use in the hashed seed. max_bytes: Maximum number of bytes to use in the hashed seed.
""" """
deprecation(
"Function `hash_seed(seed, max_bytes)` is marked as deprecated and will be removed in the future. "
)
if seed is None: if seed is None:
seed = create_seed(max_bytes=max_bytes) seed = create_seed(max_bytes=max_bytes)
hash = hashlib.sha512(str(seed).encode("utf8")).digest() hash = hashlib.sha512(str(seed).encode("utf8")).digest()
@@ -48,11 +123,13 @@ def create_seed(a=None, max_bytes=8):
"""Create a strong random seed. Otherwise, Python 2 would seed using """Create a strong random seed. Otherwise, Python 2 would seed using
the system time, which might be non-robust especially in the the system time, which might be non-robust especially in the
presence of concurrency. presence of concurrency.
Args: Args:
a (Optional[int, str]): None seeds from an operating system specific randomness source. a (Optional[int, str]): None seeds from an operating system specific randomness source.
max_bytes: Maximum number of bytes to use in the seed. max_bytes: Maximum number of bytes to use in the seed.
""" """
deprecation(
"Function `create_seed(a, max_bytes)` is marked as deprecated and will be removed in the future. "
)
# Adapted from https://svn.python.org/projects/python/tags/r32/Lib/random.py # Adapted from https://svn.python.org/projects/python/tags/r32/Lib/random.py
if a is None: if a is None:
a = _bigint_from_bytes(os.urandom(max_bytes)) a = _bigint_from_bytes(os.urandom(max_bytes))
@@ -70,6 +147,9 @@ def create_seed(a=None, max_bytes=8):
# TODO: don't hardcode sizeof_int here # TODO: don't hardcode sizeof_int here
def _bigint_from_bytes(bytes): def _bigint_from_bytes(bytes):
deprecation(
"Function `_bigint_from_bytes(bytes)` is marked as deprecated and will be removed in the future. "
)
sizeof_int = 4 sizeof_int = 4
padding = sizeof_int - len(bytes) % sizeof_int padding = sizeof_int - len(bytes) % sizeof_int
bytes += b"\0" * padding bytes += b"\0" * padding
@@ -82,6 +162,9 @@ def _bigint_from_bytes(bytes):
def _int_list_from_bigint(bigint): def _int_list_from_bigint(bigint):
deprecation(
"Function `_int_list_from_bigint` is marked as deprecated and will be removed in the future. "
)
# Special case 0 # Special case 0
if bigint < 0: if bigint < 0:
raise error.Error(f"Seed must be non-negative, not {bigint}") raise error.Error(f"Seed must be non-negative, not {bigint}")

View File

@@ -1,3 +1,5 @@
from typing import Optional, Union, List
import numpy as np import numpy as np
import multiprocessing as mp import multiprocessing as mp
import time import time
@@ -6,6 +8,7 @@ from enum import Enum
from copy import deepcopy from copy import deepcopy
from gym import logger from gym import logger
from gym.logger import warn
from gym.vector.vector_env import VectorEnv from gym.vector.vector_env import VectorEnv
from gym.error import ( from gym.error import (
AlreadyPendingCallError, AlreadyPendingCallError,
@@ -187,13 +190,14 @@ class AsyncVectorEnv(VectorEnv):
self._state = AsyncState.DEFAULT self._state = AsyncState.DEFAULT
self._check_observation_spaces() self._check_observation_spaces()
def seed(self, seeds=None): def seed(self, seed=None):
super().seed(seed=seed)
self._assert_is_running() self._assert_is_running()
if seeds is None: if seed is None:
seeds = [None for _ in range(self.num_envs)] seed = [None for _ in range(self.num_envs)]
if isinstance(seeds, int): if isinstance(seed, int):
seeds = [seeds + i for i in range(self.num_envs)] seed = [seed + i for i in range(self.num_envs)]
assert len(seeds) == self.num_envs assert len(seed) == self.num_envs
if self._state != AsyncState.DEFAULT: if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError( raise AlreadyPendingCallError(
@@ -201,12 +205,12 @@ class AsyncVectorEnv(VectorEnv):
self._state.value, self._state.value,
) )
for pipe, seed in zip(self.parent_pipes, seeds): for pipe, seed in zip(self.parent_pipes, seed):
pipe.send(("seed", seed)) pipe.send(("seed", seed))
_, successes = zip(*[pipe.recv() for pipe in self.parent_pipes]) _, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
self._raise_if_errors(successes) self._raise_if_errors(successes)
def reset_async(self): def reset_async(self, seed: Optional[Union[int, List[int]]] = None):
"""Send the calls to :obj:`reset` to each sub-environment. """Send the calls to :obj:`reset` to each sub-environment.
Raises Raises
@@ -221,24 +225,31 @@ class AsyncVectorEnv(VectorEnv):
between. between.
""" """
self._assert_is_running() self._assert_is_running()
if seed is None:
seed = [None for _ in range(self.num_envs)]
if isinstance(seed, int):
seed = [seed + i for i in range(self.num_envs)]
assert len(seed) == self.num_envs
if self._state != AsyncState.DEFAULT: if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError( raise AlreadyPendingCallError(
f"Calling `reset_async` while waiting for a pending call to `{self._state.value}` to complete", f"Calling `reset_async` while waiting for a pending call to `{self._state.value}` to complete",
self._state.value, self._state.value,
) )
for pipe in self.parent_pipes: for pipe, single_seed in zip(self.parent_pipes, seed):
pipe.send(("reset", None)) pipe.send(("reset", single_seed))
self._state = AsyncState.WAITING_RESET self._state = AsyncState.WAITING_RESET
def reset_wait(self, timeout=None): def reset_wait(self, timeout=None, seed: Optional[int] = None):
"""Wait for the calls to :obj:`reset` in each sub-environment to finish. """
Parameters Parameters
---------- ----------
timeout : int or float, optional timeout : int or float, optional
Number of seconds before the call to :meth:`reset_wait` times out. Number of seconds before the call to `reset_wait` times out. If
If ``None``, the call to :meth:`reset_wait` never times out. `None`, the call to `reset_wait` never times out.
seed: ignored
Returns Returns
------- -------
@@ -486,7 +497,7 @@ def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
while True: while True:
command, data = pipe.recv() command, data = pipe.recv()
if command == "reset": if command == "reset":
observation = env.reset() observation = env.reset(data)
pipe.send((observation, True)) pipe.send((observation, True))
elif command == "step": elif command == "step":
observation, reward, done, info = env.step(data) observation, reward, done, info = env.step(data)
@@ -524,7 +535,7 @@ def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error
while True: while True:
command, data = pipe.recv() command, data = pipe.recv()
if command == "reset": if command == "reset":
observation = env.reset() observation = env.reset(data)
write_to_shared_memory( write_to_shared_memory(
index, observation, shared_memory, observation_space index, observation, shared_memory, observation_space
) )

View File

@@ -1,7 +1,10 @@
from typing import List, Union, Optional
import numpy as np import numpy as np
from copy import deepcopy from copy import deepcopy
from gym import logger from gym import logger
from gym.logger import warn
from gym.vector.vector_env import VectorEnv from gym.vector.vector_env import VectorEnv
from gym.vector.utils import concatenate, create_empty_array from gym.vector.utils import concatenate, create_empty_array
@@ -72,21 +75,28 @@ class SyncVectorEnv(VectorEnv):
self._dones = np.zeros((self.num_envs,), dtype=np.bool_) self._dones = np.zeros((self.num_envs,), dtype=np.bool_)
self._actions = None self._actions = None
def seed(self, seeds=None): def seed(self, seed=None):
if seeds is None: super().seed(seed=seed)
seeds = [None for _ in range(self.num_envs)] if seed is None:
if isinstance(seeds, int): seed = [None for _ in range(self.num_envs)]
seeds = [seeds + i for i in range(self.num_envs)] if isinstance(seed, int):
assert len(seeds) == self.num_envs seed = [seed + i for i in range(self.num_envs)]
assert len(seed) == self.num_envs
for env, seed in zip(self.envs, seeds): for env, single_seed in zip(self.envs, seed):
env.seed(seed) env.seed(single_seed)
def reset_wait(self, seed: Optional[Union[int, List[int]]] = None, **kwargs):
if seed is None:
seed = [None for _ in range(self.num_envs)]
if isinstance(seed, int):
seed = [seed + i for i in range(self.num_envs)]
assert len(seed) == self.num_envs
def reset_wait(self):
self._dones[:] = False self._dones[:] = False
observations = [] observations = []
for env in self.envs: for env, single_seed in zip(self.envs, seed):
observation = env.reset() observation = env.reset(seed=single_seed)
observations.append(observation) observations.append(observation)
self.observations = concatenate( self.observations = concatenate(
observations, self.observations, self.single_observation_space observations, self.observations, self.single_observation_space

View File

@@ -1,4 +1,7 @@
from typing import Optional, Union, List
import gym import gym
from gym.logger import warn, deprecation
from gym.spaces import Tuple from gym.spaces import Tuple
from gym.vector.utils.spaces import batch_space from gym.vector.utils.spaces import batch_space
@@ -43,13 +46,13 @@ class VectorEnv(gym.Env):
self.single_observation_space = observation_space self.single_observation_space = observation_space
self.single_action_space = action_space self.single_action_space = action_space
def reset_async(self): def reset_async(self, seed: Optional[Union[int, List[int]]] = None):
pass pass
def reset_wait(self, **kwargs): def reset_wait(self, seed: Optional[Union[int, List[int]]] = None, **kwargs):
raise NotImplementedError() raise NotImplementedError()
def reset(self): def reset(self, seed: Optional[Union[int, List[int]]] = None):
r"""Reset all sub-environments and return a batch of initial observations. r"""Reset all sub-environments and return a batch of initial observations.
Returns Returns
@@ -57,8 +60,8 @@ class VectorEnv(gym.Env):
element of :attr:`observation_space` element of :attr:`observation_space`
A batch of observations from the vectorized environment. A batch of observations from the vectorized environment.
""" """
self.reset_async() self.reset_async(seed=seed)
return self.reset_wait() return self.reset_wait(seed=seed)
def step_async(self, actions): def step_async(self, actions):
pass pass
@@ -120,19 +123,22 @@ class VectorEnv(gym.Env):
self.close_extras(**kwargs) self.close_extras(**kwargs)
self.closed = True self.closed = True
def seed(self, seed=None):
    """Set the random seed in all sub-environments (deprecated).

    This base implementation only emits the deprecation message; it does not
    seed anything itself. Subclasses that override it are expected to do the
    actual seeding of their sub-environments.

    Parameters
    ----------
    seed : list of int, or int, optional
        Random seed for each sub-environment. If ``seed`` is a list of
        length ``num_envs``, then the items of the list are chosen as random
        seeds. If ``seed`` is an int, then each sub-environment uses the random
        seed ``seed + n``, where ``n`` is the index of the sub-environment
        (between ``0`` and ``num_envs - 1``).
    """
    deprecation(
        "Function `env.seed(seed)` is marked as deprecated and will be removed in the future. "
        "Please use `env.reset(seed=seed)` instead in VectorEnvs."
    )
def __del__(self): def __del__(self):
if not getattr(self, "closed", True): if not getattr(self, "closed", True):
@@ -164,11 +170,11 @@ class VectorEnvWrapper(VectorEnv):
# explicitly forward the methods defined in VectorEnv # explicitly forward the methods defined in VectorEnv
# to self.env (instead of the base class) # to self.env (instead of the base class)
def reset_async(self): def reset_async(self, **kwargs):
return self.env.reset_async() return self.env.reset_async(**kwargs)
def reset_wait(self): def reset_wait(self, **kwargs):
return self.env.reset_wait() return self.env.reset_wait(**kwargs)
def step_async(self, actions): def step_async(self, actions):
return self.env.step_async(actions) return self.env.step_async(actions)
@@ -182,8 +188,8 @@ class VectorEnvWrapper(VectorEnv):
def close_extras(self, **kwargs): def close_extras(self, **kwargs):
return self.env.close_extras(**kwargs) return self.env.close_extras(**kwargs)
def seed(self, seeds=None): def seed(self, seed=None):
return self.env.seed(seeds) return self.env.seed(seed)
# implicitly forward all other methods and attributes to self.env # implicitly forward all other methods and attributes to self.env
def __getattr__(self, name): def __getattr__(self, name):

View File

@@ -1,3 +1,5 @@
from typing import Optional
import numpy as np import numpy as np
import gym import gym
from gym.spaces import Box from gym.spaces import Box
@@ -127,11 +129,11 @@ class AtariPreprocessing(gym.Wrapper):
self.ale.getScreenRGB(self.obs_buffer[0]) self.ale.getScreenRGB(self.obs_buffer[0])
return self._get_obs(), R, done, info return self._get_obs(), R, done, info
def reset(self, **kwargs): def reset(self, seed: Optional[int] = None, **kwargs):
# NoopReset # NoopReset
self.env.reset(**kwargs) self.env.reset(seed=seed, **kwargs)
noops = ( noops = (
self.env.unwrapped.np_random.randint(1, self.noop_max + 1) self.env.unwrapped.np_random.integers(1, self.noop_max + 1)
if self.noop_max > 0 if self.noop_max > 0
else 0 else 0
) )

View File

@@ -1,4 +1,6 @@
from collections import deque from collections import deque
from typing import Optional
import numpy as np import numpy as np
from gym.spaces import Box from gym.spaces import Box
from gym import ObservationWrapper from gym import ObservationWrapper
@@ -116,7 +118,7 @@ class FrameStack(ObservationWrapper):
self.frames.append(observation) self.frames.append(observation)
return self.observation(), reward, done, info return self.observation(), reward, done, info
def reset(self, **kwargs): def reset(self, seed: Optional[int] = None, **kwargs):
observation = self.env.reset(**kwargs) observation = self.env.reset(seed=seed, **kwargs)
[self.frames.append(observation) for _ in range(self.num_stack)] [self.frames.append(observation) for _ in range(self.num_stack)]
return self.observation() return self.observation()

View File

@@ -1,5 +1,6 @@
import json import json
import os import os
from typing import Optional
import numpy as np import numpy as np
@@ -50,9 +51,9 @@ class Monitor(Wrapper):
return observation, reward, done, info return observation, reward, done, info
def reset(self, **kwargs): def reset(self, seed: Optional[int] = None, **kwargs):
self._before_reset() self._before_reset()
observation = self.env.reset(**kwargs) observation = self.env.reset(seed=seed, **kwargs)
self._after_reset(observation) self._after_reset(observation)
return observation return observation

View File

@@ -1,3 +1,5 @@
from typing import Optional
import numpy as np import numpy as np
import gym import gym
@@ -61,8 +63,8 @@ class NormalizeObservation(gym.core.Wrapper):
obs = self.normalize(np.array([obs]))[0] obs = self.normalize(np.array([obs]))[0]
return obs, rews, dones, infos return obs, rews, dones, infos
def reset(self): def reset(self, seed: Optional[int] = None):
obs = self.env.reset() obs = self.env.reset(seed=seed)
if self.is_vector_env: if self.is_vector_env:
obs = self.normalize(obs) obs = self.normalize(obs)
else: else:

View File

@@ -1,3 +1,5 @@
from typing import Optional
import gym import gym
@@ -11,6 +13,6 @@ class OrderEnforcing(gym.Wrapper):
observation, reward, done, info = self.env.step(action) observation, reward, done, info = self.env.step(action)
return observation, reward, done, info return observation, reward, done, info
def reset(self, **kwargs): def reset(self, seed: Optional[int] = None, **kwargs):
self._has_reset = True self._has_reset = True
return self.env.reset(**kwargs) return self.env.reset(seed=seed, **kwargs)

View File

@@ -1,5 +1,7 @@
import time import time
from collections import deque from collections import deque
from typing import Optional
import numpy as np import numpy as np
import gym import gym
@@ -16,8 +18,8 @@ class RecordEpisodeStatistics(gym.Wrapper):
self.length_queue = deque(maxlen=deque_size) self.length_queue = deque(maxlen=deque_size)
self.is_vector_env = getattr(env, "is_vector_env", False) self.is_vector_env = getattr(env, "is_vector_env", False)
def reset(self, **kwargs): def reset(self, seed: Optional[int] = None, **kwargs):
observations = super().reset(**kwargs) observations = super().reset(seed=seed, **kwargs)
self.episode_returns = np.zeros(self.num_envs, dtype=np.float32) self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32) self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
return observations return observations

View File

@@ -1,6 +1,6 @@
import os import os
import gym import gym
from typing import Callable from typing import Callable, Optional
from gym import logger from gym import logger
from gym.wrappers.monitoring import video_recorder from gym.wrappers.monitoring import video_recorder
@@ -52,8 +52,8 @@ class RecordVideo(gym.Wrapper):
self.is_vector_env = getattr(env, "is_vector_env", False) self.is_vector_env = getattr(env, "is_vector_env", False)
self.episode_id = 0 self.episode_id = 0
def reset(self, **kwargs): def reset(self, seed: Optional[int] = None, **kwargs):
observations = super().reset(**kwargs) observations = super().reset(seed=seed, **kwargs)
if not self.recording and self._video_enabled(): if not self.recording and self._video_enabled():
self.start_video_recorder() self.start_video_recorder()
return observations return observations

View File

@@ -1,3 +1,5 @@
from typing import Optional
import numpy as np import numpy as np
from gym.spaces import Box from gym.spaces import Box
from gym import ObservationWrapper from gym import ObservationWrapper
@@ -27,6 +29,6 @@ class TimeAwareObservation(ObservationWrapper):
self.t += 1 self.t += 1
return super().step(action) return super().step(action)
def reset(self, **kwargs): def reset(self, seed: Optional[int] = None, **kwargs):
self.t = 0 self.t = 0
return super().reset(**kwargs) return super().reset(seed=seed, **kwargs)

View File

@@ -1,3 +1,5 @@
from typing import Optional
import gym import gym
@@ -22,6 +24,6 @@ class TimeLimit(gym.Wrapper):
done = True done = True
return observation, reward, done, info return observation, reward, done, info
def reset(self, **kwargs): def reset(self, seed: Optional[int] = None, **kwargs):
self._elapsed_steps = 0 self._elapsed_steps = 0
return self.env.reset(**kwargs) return self.env.reset(seed=seed, **kwargs)

View File

@@ -15,6 +15,6 @@ ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/root/.mujoco/mjpro150/bin
COPY . /usr/local/gym/ COPY . /usr/local/gym/
WORKDIR /usr/local/gym/ WORKDIR /usr/local/gym/
RUN pip install .[nomujoco,accept-rom-license] && pip install -r test_requirements.txt RUN pip install .[nomujoco] && pip install -r test_requirements.txt
ENTRYPOINT ["/usr/local/gym/bin/docker_entrypoint"] ENTRYPOINT ["/usr/local/gym/bin/docker_entrypoint"]

View File

@@ -21,7 +21,7 @@ extras = {
} }
# Meta dependency groups. # Meta dependency groups.
nomujoco_blacklist = set(["mujoco", "robotics", "accept-rom-license"]) nomujoco_blacklist = set(["mujoco", "robotics", "accept-rom-license", "atari"])
nomujoco_groups = set(extras.keys()) - nomujoco_blacklist nomujoco_groups = set(extras.keys()) - nomujoco_blacklist
extras["nomujoco"] = list( extras["nomujoco"] = list(

View File

@@ -10,16 +10,14 @@ def test_env(spec):
# threads. However, we probably already can't do multithreading # threads. However, we probably already can't do multithreading
# due to some environments. # due to some environments.
env1 = spec.make() env1 = spec.make()
env1.seed(0) initial_observation1 = env1.reset(seed=0)
initial_observation1 = env1.reset()
env1.action_space.seed(0) env1.action_space.seed(0)
action_samples1 = [env1.action_space.sample() for i in range(4)] action_samples1 = [env1.action_space.sample() for i in range(4)]
step_responses1 = [env1.step(action) for action in action_samples1] step_responses1 = [env1.step(action) for action in action_samples1]
env1.close() env1.close()
env2 = spec.make() env2 = spec.make()
env2.seed(0) initial_observation2 = env2.reset(seed=0)
initial_observation2 = env2.reset()
env2.action_space.seed(0) env2.action_space.seed(0)
action_samples2 = [env2.action_space.sample() for i in range(4)] action_samples2 = [env2.action_space.sample() for i in range(4)]
step_responses2 = [env2.step(action) for action in action_samples2] step_responses2 = [env2.step(action) for action in action_samples2]

View File

@@ -10,11 +10,8 @@ def verify_environments_match(
old_environment = envs.make(old_environment_id) old_environment = envs.make(old_environment_id)
new_environment = envs.make(new_environment_id) new_environment = envs.make(new_environment_id)
old_environment.seed(seed) old_reset_observation = old_environment.reset(seed=seed)
new_environment.seed(seed) new_reset_observation = new_environment.reset(seed=seed)
old_reset_observation = old_environment.reset()
new_reset_observation = new_environment.reset()
np.testing.assert_allclose(old_reset_observation, new_reset_observation) np.testing.assert_allclose(old_reset_observation, new_reset_observation)

View File

@@ -370,7 +370,7 @@ def test_seed_subspace_incorrelated(space):
space.seed(0) space.seed(0)
states = [ states = [
convert_sample_hashable(subspace.np_random.get_state()) convert_sample_hashable(subspace.np_random.bit_generator.state)
for subspace in subspaces for subspace in subspaces
] ]

View File

@@ -1,3 +1,5 @@
from typing import Optional
import pytest import pytest
import numpy as np import numpy as np
@@ -16,7 +18,8 @@ class UnittestEnv(core.Env):
observation_space = spaces.Box(low=0, high=255, shape=(64, 64, 3), dtype=np.uint8) observation_space = spaces.Box(low=0, high=255, shape=(64, 64, 3), dtype=np.uint8)
action_space = spaces.Discrete(3) action_space = spaces.Discrete(3)
def reset(self): def reset(self, seed: Optional[int] = None):
super().reset(seed=seed)
return self.observation_space.sample() # Dummy observation return self.observation_space.sample() # Dummy observation
def step(self, action): def step(self, action):
@@ -31,7 +34,8 @@ class UnknownSpacesEnv(core.Env):
on external resources), it is not encouraged. on external resources), it is not encouraged.
""" """
def reset(self): def reset(self, seed: Optional[int] = None):
super().reset(seed=seed)
self.observation_space = spaces.Box( self.observation_space = spaces.Box(
low=0, high=255, shape=(64, 64, 3), dtype=np.uint8 low=0, high=255, shape=(64, 64, 3), dtype=np.uint8
) )

View File

@@ -1,3 +1,5 @@
from typing import Optional
import gym import gym
import numpy as np import numpy as np
import pytest import pytest
@@ -16,7 +18,8 @@ class ActionDictTestEnv(gym.Env):
done = True done = True
return observation, reward, done return observation, reward, done
def reset(self): def reset(self, seed: Optional[int] = None):
super().reset(seed=seed)
return np.array([1.0, 1.5, 0.5]) return np.array([1.0, 1.5, 0.5])
def render(self, mode="human"): def render(self, mode="human"):

View File

@@ -17,17 +17,14 @@ def test_vector_env_equal(shared_memory):
async_env = AsyncVectorEnv(env_fns, shared_memory=shared_memory) async_env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
sync_env = SyncVectorEnv(env_fns) sync_env = SyncVectorEnv(env_fns)
async_env.seed(0)
sync_env.seed(0)
assert async_env.num_envs == sync_env.num_envs assert async_env.num_envs == sync_env.num_envs
assert async_env.observation_space == sync_env.observation_space assert async_env.observation_space == sync_env.observation_space
assert async_env.single_observation_space == sync_env.single_observation_space assert async_env.single_observation_space == sync_env.single_observation_space
assert async_env.action_space == sync_env.action_space assert async_env.action_space == sync_env.action_space
assert async_env.single_action_space == sync_env.single_action_space assert async_env.single_action_space == sync_env.single_action_space
async_observations = async_env.reset() async_observations = async_env.reset(seed=0)
sync_observations = sync_env.reset() sync_observations = sync_env.reset(seed=0)
assert np.all(async_observations == sync_observations) assert np.all(async_observations == sync_observations)
for _ in range(num_steps): for _ in range(num_steps):

View File

@@ -8,7 +8,7 @@ class DummyWrapper(VectorEnvWrapper):
self.env = env self.env = env
self.counter = 0 self.counter = 0
def reset_async(self): def reset_async(self, **kwargs):
super().reset_async() super().reset_async()
self.counter += 1 self.counter += 1

View File

@@ -1,3 +1,5 @@
from typing import Optional
import numpy as np import numpy as np
import gym import gym
import time import time
@@ -55,7 +57,8 @@ class UnittestSlowEnv(gym.Env):
) )
self.action_space = Box(low=0.0, high=1.0, shape=(), dtype=np.float32) self.action_space = Box(low=0.0, high=1.0, shape=(), dtype=np.float32)
def reset(self): def reset(self, seed: Optional[int] = None):
super().reset(seed=seed)
if self.slow_reset > 0: if self.slow_reset > 0:
time.sleep(self.slow_reset) time.sleep(self.slow_reset)
return self.observation_space.sample() return self.observation_space.sample()
@@ -86,7 +89,8 @@ class CustomSpaceEnv(gym.Env):
self.observation_space = CustomSpace() self.observation_space = CustomSpace()
self.action_space = CustomSpace() self.action_space = CustomSpace()
def reset(self): def reset(self, seed: Optional[int] = None):
super().reset(seed=seed)
return "reset" return "reset"
def step(self, action): def step(self, action):
@@ -98,7 +102,7 @@ class CustomSpaceEnv(gym.Env):
def make_env(env_name, seed): def make_env(env_name, seed):
def _make(): def _make():
env = gym.make(env_name) env = gym.make(env_name)
env.seed(seed) env.reset(seed=seed)
return env return env
return _make return _make
@@ -107,7 +111,7 @@ def make_env(env_name, seed):
def make_slow_env(slow_reset, seed): def make_slow_env(slow_reset, seed):
def _make(): def _make():
env = UnittestSlowEnv(slow_reset=slow_reset) env = UnittestSlowEnv(slow_reset=slow_reset)
env.seed(seed) env.reset(seed=seed)
return env return env
return _make return _make
@@ -116,7 +120,7 @@ def make_slow_env(slow_reset, seed):
def make_custom_space_env(seed): def make_custom_space_env(seed):
def _make(): def _make():
env = CustomSpaceEnv() env = CustomSpaceEnv()
env.seed(seed) env.reset(seed=seed)
return env return env
return _make return _make

View File

@@ -1,6 +1,7 @@
"""Tests for the flatten observation wrapper.""" """Tests for the flatten observation wrapper."""
from collections import OrderedDict from collections import OrderedDict
from typing import Optional
import numpy as np import numpy as np
import pytest import pytest
@@ -14,7 +15,8 @@ class FakeEnvironment(gym.Env):
def __init__(self, observation_space): def __init__(self, observation_space):
self.observation_space = observation_space self.observation_space = observation_space
def reset(self): def reset(self, seed: Optional[int] = None):
super().reset(seed=seed)
self.observation = self.observation_space.sample() self.observation = self.observation_space.sample()
return self.observation return self.observation

View File

@@ -1,5 +1,5 @@
"""Tests for the filter observation wrapper.""" """Tests for the filter observation wrapper."""
from typing import Optional
import pytest import pytest
import numpy as np import numpy as np
@@ -21,7 +21,8 @@ class FakeEnvironment(gym.Env):
image_shape = (height, width, 3) image_shape = (height, width, 3)
return np.zeros(image_shape, dtype=np.uint8) return np.zeros(image_shape, dtype=np.uint8)
def reset(self): def reset(self, seed: Optional[int] = None):
super().reset(seed=seed)
observation = self.observation_space.sample() observation = self.observation_space.sample()
return observation return observation

View File

@@ -29,14 +29,10 @@ def test_atari_preprocessing_grayscale(env_fn):
noop_max=0, noop_max=0,
grayscale_newaxis=True, grayscale_newaxis=True,
) )
env1.seed(0) obs1 = env1.reset(seed=0)
env2.seed(0) obs2 = env2.reset(seed=0)
env3.seed(0) obs3 = env3.reset(seed=0)
env4.seed(0) obs4 = env4.reset(seed=0)
obs1 = env1.reset()
obs2 = env2.reset()
obs3 = env3.reset()
obs4 = env4.reset()
assert env1.observation_space.shape == (210, 160, 3) assert env1.observation_space.shape == (210, 160, 3)
assert env2.observation_space.shape == (84, 84) assert env2.observation_space.shape == (84, 84)
assert env3.observation_space.shape == (84, 84, 3) assert env3.observation_space.shape == (84, 84, 3)

View File

@@ -11,11 +11,9 @@ def test_clip_action():
wrapped_env = ClipAction(make_env()) wrapped_env = ClipAction(make_env())
seed = 0 seed = 0
env.seed(seed)
wrapped_env.seed(seed)
env.reset() env.reset(seed=seed)
wrapped_env.reset() wrapped_env.reset(seed=seed)
actions = [[0.4], [1.2], [-0.3], [0.0], [-2.5]] actions = [[0.4], [1.2], [-0.3], [0.0], [-2.5]]
for action in actions: for action in actions:

View File

@@ -1,3 +1,5 @@
from typing import Optional
import pytest import pytest
import numpy as np import numpy as np
@@ -22,7 +24,8 @@ class FakeEnvironment(gym.Env):
image_shape = (height, width, 3) image_shape = (height, width, 3)
return np.zeros(image_shape, dtype=np.uint8) return np.zeros(image_shape, dtype=np.uint8)
def reset(self): def reset(self, seed: Optional[int] = None):
super().reset(seed=seed)
observation = self.observation_space.sample() observation = self.observation_space.sample()
return observation return observation

View File

@@ -28,17 +28,15 @@ except ImportError:
) )
def test_frame_stack(env_id, num_stack, lz4_compress): def test_frame_stack(env_id, num_stack, lz4_compress):
env = gym.make(env_id) env = gym.make(env_id)
env.seed(0)
shape = env.observation_space.shape shape = env.observation_space.shape
env = FrameStack(env, num_stack, lz4_compress) env = FrameStack(env, num_stack, lz4_compress)
assert env.observation_space.shape == (num_stack,) + shape assert env.observation_space.shape == (num_stack,) + shape
assert env.observation_space.dtype == env.env.observation_space.dtype assert env.observation_space.dtype == env.env.observation_space.dtype
dup = gym.make(env_id) dup = gym.make(env_id)
dup.seed(0)
obs = env.reset() obs = env.reset(seed=0)
dup_obs = dup.reset() dup_obs = dup.reset(seed=0)
assert np.allclose(obs[-1], dup_obs) assert np.allclose(obs[-1], dup_obs)
for _ in range(num_stack ** 2): for _ in range(num_stack ** 2):

View File

@@ -21,11 +21,9 @@ def test_gray_scale_observation(env_id, keep_dim):
assert rgb_env.observation_space.shape[-1] == 3 assert rgb_env.observation_space.shape[-1] == 3
seed = 0 seed = 0
gray_env.seed(seed)
wrapped_env.seed(seed)
gray_obs = gray_env.reset() gray_obs = gray_env.reset(seed=seed)
wrapped_obs = wrapped_env.reset() wrapped_obs = wrapped_env.reset(seed=seed)
if keep_dim: if keep_dim:
assert wrapped_env.observation_space.shape[-1] == 1 assert wrapped_env.observation_space.shape[-1] == 1

View File

@@ -1,3 +1,5 @@
from typing import Optional
import gym import gym
import numpy as np import numpy as np
from numpy.testing import assert_almost_equal from numpy.testing import assert_almost_equal
@@ -21,7 +23,8 @@ class DummyRewardEnv(gym.Env):
self.t += 1 self.t += 1
return np.array([self.t]), self.t, self.t == len(self.returned_rewards), {} return np.array([self.t]), self.t, self.t == len(self.returned_rewards), {}
def reset(self): def reset(self, seed: Optional[int] = None):
super().reset(seed=seed)
self.t = self.return_reward_idx self.t = self.return_reward_idx
return np.array([self.t]) return np.array([self.t])

View File

@@ -1,5 +1,5 @@
"""Tests for the pixel observation wrapper.""" """Tests for the pixel observation wrapper."""
from typing import Optional
import pytest import pytest
import numpy as np import numpy as np
@@ -19,7 +19,8 @@ class FakeEnvironment(gym.Env):
image_shape = (height, width, 3) image_shape = (height, width, 3)
return np.zeros(image_shape, dtype=np.uint8) return np.zeros(image_shape, dtype=np.uint8)
def reset(self): def reset(self, seed: Optional[int] = None):
super().reset(seed=seed)
observation = self.observation_space.sample() observation = self.observation_space.sample()
return observation return observation

View File

@@ -16,11 +16,9 @@ def test_rescale_action():
wrapped_env = RescaleAction(gym.make("Pendulum-v1"), -1, 1) wrapped_env = RescaleAction(gym.make("Pendulum-v1"), -1, 1)
seed = 0 seed = 0
env.seed(seed)
wrapped_env.seed(seed)
obs = env.reset() obs = env.reset(seed=seed)
wrapped_obs = wrapped_env.reset() wrapped_obs = wrapped_env.reset(seed=seed)
assert np.allclose(obs, wrapped_obs) assert np.allclose(obs, wrapped_obs)
obs, reward, _, _ = env.step([1.5]) obs, reward, _, _ = env.step([1.5])

View File

@@ -14,11 +14,8 @@ def test_transform_observation(env_id):
gym.make(env_id), lambda obs: affine_transform(obs) gym.make(env_id), lambda obs: affine_transform(obs)
) )
env.seed(0) obs = env.reset(seed=0)
wrapped_env.seed(0) wrapped_obs = wrapped_env.reset(seed=0)
obs = env.reset()
wrapped_obs = wrapped_env.reset()
assert np.allclose(wrapped_obs, affine_transform(obs)) assert np.allclose(wrapped_obs, affine_transform(obs))
action = env.action_space.sample() action = env.action_space.sample()

View File

@@ -15,10 +15,8 @@ def test_transform_reward(env_id):
wrapped_env = TransformReward(gym.make(env_id), lambda r: scale * r) wrapped_env = TransformReward(gym.make(env_id), lambda r: scale * r)
action = env.action_space.sample() action = env.action_space.sample()
env.seed(0) env.reset(seed=0)
env.reset() wrapped_env.reset(seed=0)
wrapped_env.seed(0)
wrapped_env.reset()
_, reward, _, _ = env.step(action) _, reward, _, _ = env.step(action)
_, wrapped_reward, _, _ = wrapped_env.step(action) _, wrapped_reward, _, _ = wrapped_env.step(action)
@@ -33,10 +31,8 @@ def test_transform_reward(env_id):
wrapped_env = TransformReward(gym.make(env_id), lambda r: np.clip(r, min_r, max_r)) wrapped_env = TransformReward(gym.make(env_id), lambda r: np.clip(r, min_r, max_r))
action = env.action_space.sample() action = env.action_space.sample()
env.seed(0) env.reset(seed=0)
env.reset() wrapped_env.reset(seed=0)
wrapped_env.seed(0)
wrapped_env.reset()
_, reward, _, _ = env.step(action) _, reward, _, _ = env.step(action)
_, wrapped_reward, _, _ = wrapped_env.step(action) _, wrapped_reward, _, _ = wrapped_env.step(action)
@@ -49,10 +45,8 @@ def test_transform_reward(env_id):
env = gym.make(env_id) env = gym.make(env_id)
wrapped_env = TransformReward(gym.make(env_id), lambda r: np.sign(r)) wrapped_env = TransformReward(gym.make(env_id), lambda r: np.sign(r))
env.seed(0) env.reset(seed=0)
env.reset() wrapped_env.reset(seed=0)
wrapped_env.seed(0)
wrapped_env.reset()
for _ in range(1000): for _ in range(1000):
action = env.action_space.sample() action = env.action_space.sample()