mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-01 06:07:08 +00:00
Seeding update (#2422)
* Ditch most of the seeding.py and replace np_random with the numpy default_rng. Let's see if tests pass * Updated a bunch of RNG calls from the RandomState API to Generator API * black; didn't expect that, did ya? * Undo a typo * blaaack * More typo fixes * Fixed setting/getting state in multidiscrete spaces * Fix typo, fix a test to work with the new sampling * Correctly (?) pass the randomly generated seed if np_random is called with None as seed * Convert the Discrete sample to a python int (as opposed to np.int64) * Remove some redundant imports * First version of the compatibility layer for old-style RNG. Mainly to trigger tests. * Removed redundant f-strings * Style fixes, removing unused imports * Try to make tests pass by removing atari from the dockerfile * Try to make tests pass by removing atari from the setup * Try to make tests pass by removing atari from the setup * Try to make tests pass by removing atari from the setup * First attempt at deprecating `env.seed` and supporting `env.reset(seed=seed)` instead. Tests should hopefully pass but throw up a million warnings. * black; didn't expect that, didya? * Rename the reset parameter in VecEnvs back to `seed` * Updated tests to use the new seeding method * Removed a bunch of old `seed` calls. Fixed a bug in AsyncVectorEnv * Stop Discrete envs from doing part of the setup (and using the randomness) in init (as opposed to reset) * Add explicit seed to wrappers reset * Remove an accidental return * Re-add some legacy functions with a warning. * Use deprecation instead of regular warnings for the newly deprecated methods/functions
This commit is contained in:
committed by
GitHub
parent
b84b69c872
commit
c364506710
@@ -50,6 +50,7 @@
|
||||
|
||||
* `gym-foo/gym_foo/envs/foo_env.py` should look something like:
|
||||
```python
|
||||
from typing import Optional
|
||||
import gym
|
||||
from gym import error, spaces, utils
|
||||
from gym.utils import seeding
|
||||
@@ -61,7 +62,8 @@
|
||||
...
|
||||
def step(self, action):
|
||||
...
|
||||
def reset(self):
|
||||
def reset(self, seed: Optional[int] = None):
|
||||
super().reset(seed=seed)
|
||||
...
|
||||
def render(self, mode='human'):
|
||||
...
|
||||
|
39
gym/core.py
39
gym/core.py
@@ -1,8 +1,10 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
import gym
|
||||
from gym import error
|
||||
from gym.utils import closer
|
||||
from gym.utils import closer, seeding
|
||||
from gym.logger import deprecation
|
||||
|
||||
|
||||
class Env:
|
||||
@@ -38,6 +40,9 @@ class Env:
|
||||
action_space = None
|
||||
observation_space = None
|
||||
|
||||
# Created
|
||||
np_random = None
|
||||
|
||||
@abstractmethod
|
||||
def step(self, action):
|
||||
"""Run one timestep of the environment's dynamics. When end of
|
||||
@@ -58,7 +63,7 @@ class Env:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def reset(self):
|
||||
def reset(self, seed: Optional[int] = None):
|
||||
"""Resets the environment to an initial state and returns an initial
|
||||
observation.
|
||||
|
||||
@@ -71,7 +76,9 @@ class Env:
|
||||
Returns:
|
||||
observation (object): the initial observation.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
# Initialize the RNG if it's the first reset, or if the seed is manually passed
|
||||
if seed is not None or self.np_random is None:
|
||||
self.np_random, seed = seeding.np_random(seed)
|
||||
|
||||
@abstractmethod
|
||||
def render(self, mode="human"):
|
||||
@@ -136,7 +143,12 @@ class Env:
|
||||
'seed'. Often, the main seed equals the provided 'seed', but
|
||||
this won't be true if seed=None, for example.
|
||||
"""
|
||||
return
|
||||
deprecation(
|
||||
"Function `env.seed(seed)` is marked as deprecated and will be removed in the future. "
|
||||
"Please use `env.reset(seed=seed) instead."
|
||||
)
|
||||
self.np_random, seed = seeding.np_random(seed)
|
||||
return [seed]
|
||||
|
||||
@property
|
||||
def unwrapped(self):
|
||||
@@ -173,7 +185,8 @@ class GoalEnv(Env):
|
||||
actual observations of the environment as per usual.
|
||||
"""
|
||||
|
||||
def reset(self):
|
||||
def reset(self, seed: Optional[int] = None):
|
||||
super().reset(seed=seed)
|
||||
# Enforce that each GoalEnv uses a Goal-compatible observation space.
|
||||
if not isinstance(self.observation_space, gym.spaces.Dict):
|
||||
raise error.Error(
|
||||
@@ -286,8 +299,8 @@ class Wrapper(Env):
|
||||
def step(self, action):
|
||||
return self.env.step(action)
|
||||
|
||||
def reset(self, **kwargs):
|
||||
return self.env.reset(**kwargs)
|
||||
def reset(self, seed: Optional[int] = None, **kwargs):
|
||||
return self.env.reset(seed=seed, **kwargs)
|
||||
|
||||
def render(self, mode="human", **kwargs):
|
||||
return self.env.render(mode, **kwargs)
|
||||
@@ -313,8 +326,8 @@ class Wrapper(Env):
|
||||
|
||||
|
||||
class ObservationWrapper(Wrapper):
|
||||
def reset(self, **kwargs):
|
||||
observation = self.env.reset(**kwargs)
|
||||
def reset(self, seed: Optional[int] = None, **kwargs):
|
||||
observation = self.env.reset(seed=seed, **kwargs)
|
||||
return self.observation(observation)
|
||||
|
||||
def step(self, action):
|
||||
@@ -327,8 +340,8 @@ class ObservationWrapper(Wrapper):
|
||||
|
||||
|
||||
class RewardWrapper(Wrapper):
|
||||
def reset(self, **kwargs):
|
||||
return self.env.reset(**kwargs)
|
||||
def reset(self, seed: Optional[int] = None, **kwargs):
|
||||
return self.env.reset(seed=seed, **kwargs)
|
||||
|
||||
def step(self, action):
|
||||
observation, reward, done, info = self.env.step(action)
|
||||
@@ -340,8 +353,8 @@ class RewardWrapper(Wrapper):
|
||||
|
||||
|
||||
class ActionWrapper(Wrapper):
|
||||
def reset(self, **kwargs):
|
||||
return self.env.reset(**kwargs)
|
||||
def reset(self, seed: Optional[int] = None, **kwargs):
|
||||
return self.env.reset(seed=seed, **kwargs)
|
||||
|
||||
def step(self, action):
|
||||
return self.env.step(self.action(action))
|
||||
|
@@ -1,5 +1,6 @@
|
||||
import sys
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import Box2D
|
||||
@@ -122,7 +123,6 @@ class BipedalWalker(gym.Env, EzPickle):
|
||||
|
||||
def __init__(self):
|
||||
EzPickle.__init__(self)
|
||||
self.seed()
|
||||
self.viewer = None
|
||||
|
||||
self.world = Box2D.b2World()
|
||||
@@ -149,10 +149,6 @@ class BipedalWalker(gym.Env, EzPickle):
|
||||
)
|
||||
self.observation_space = spaces.Box(-high, high)
|
||||
|
||||
def seed(self, seed=None):
|
||||
self.np_random, seed = seeding.np_random(seed)
|
||||
return [seed]
|
||||
|
||||
def _destroy(self):
|
||||
if not self.terrain:
|
||||
return
|
||||
@@ -188,7 +184,7 @@ class BipedalWalker(gym.Env, EzPickle):
|
||||
y += velocity
|
||||
|
||||
elif state == PIT and oneshot:
|
||||
counter = self.np_random.randint(3, 5)
|
||||
counter = self.np_random.integers(3, 5)
|
||||
poly = [
|
||||
(x, y),
|
||||
(x + TERRAIN_STEP, y),
|
||||
@@ -215,7 +211,7 @@ class BipedalWalker(gym.Env, EzPickle):
|
||||
y -= 4 * TERRAIN_STEP
|
||||
|
||||
elif state == STUMP and oneshot:
|
||||
counter = self.np_random.randint(1, 3)
|
||||
counter = self.np_random.integers(1, 3)
|
||||
poly = [
|
||||
(x, y),
|
||||
(x + counter * TERRAIN_STEP, y),
|
||||
@@ -228,9 +224,9 @@ class BipedalWalker(gym.Env, EzPickle):
|
||||
self.terrain.append(t)
|
||||
|
||||
elif state == STAIRS and oneshot:
|
||||
stair_height = +1 if self.np_random.rand() > 0.5 else -1
|
||||
stair_width = self.np_random.randint(4, 5)
|
||||
stair_steps = self.np_random.randint(3, 5)
|
||||
stair_height = +1 if self.np_random.random() > 0.5 else -1
|
||||
stair_width = self.np_random.integers(4, 5)
|
||||
stair_steps = self.np_random.integers(3, 5)
|
||||
original_y = y
|
||||
for s in range(stair_steps):
|
||||
poly = [
|
||||
@@ -266,9 +262,9 @@ class BipedalWalker(gym.Env, EzPickle):
|
||||
self.terrain_y.append(y)
|
||||
counter -= 1
|
||||
if counter == 0:
|
||||
counter = self.np_random.randint(TERRAIN_GRASS / 2, TERRAIN_GRASS)
|
||||
counter = self.np_random.integers(TERRAIN_GRASS / 2, TERRAIN_GRASS)
|
||||
if state == GRASS and hardcore:
|
||||
state = self.np_random.randint(1, _STATES_)
|
||||
state = self.np_random.integers(1, _STATES_)
|
||||
oneshot = True
|
||||
else:
|
||||
state = GRASS
|
||||
@@ -312,7 +308,8 @@ class BipedalWalker(gym.Env, EzPickle):
|
||||
x2 = max(p[0] for p in poly)
|
||||
self.cloud_poly.append((poly, x1, x2))
|
||||
|
||||
def reset(self):
|
||||
def reset(self, seed: Optional[int] = None):
|
||||
super().reset(seed=seed)
|
||||
self._destroy()
|
||||
self.world.contactListener_bug_workaround = ContactDetector(self)
|
||||
self.world.contactListener = self.world.contactListener_bug_workaround
|
||||
|
@@ -32,6 +32,8 @@ Created by Oleg Klimov. Licensed on the same terms as the rest of OpenAI Gym.
|
||||
"""
|
||||
import sys
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
import Box2D
|
||||
@@ -121,7 +123,6 @@ class CarRacing(gym.Env, EzPickle):
|
||||
|
||||
def __init__(self, verbose=1):
|
||||
EzPickle.__init__(self)
|
||||
self.seed()
|
||||
self.contactListener_keepref = FrictionDetector(self)
|
||||
self.world = Box2D.b2World((0, 0), contactListener=self.contactListener_keepref)
|
||||
self.viewer = None
|
||||
@@ -145,10 +146,6 @@ class CarRacing(gym.Env, EzPickle):
|
||||
low=0, high=255, shape=(STATE_H, STATE_W, 3), dtype=np.uint8
|
||||
)
|
||||
|
||||
def seed(self, seed=None):
|
||||
self.np_random, seed = seeding.np_random(seed)
|
||||
return [seed]
|
||||
|
||||
def _destroy(self):
|
||||
if not self.road:
|
||||
return
|
||||
@@ -343,7 +340,8 @@ class CarRacing(gym.Env, EzPickle):
|
||||
self.track = track
|
||||
return True
|
||||
|
||||
def reset(self):
|
||||
def reset(self, seed: Optional[int] = None):
|
||||
super().reset(seed=seed)
|
||||
self._destroy()
|
||||
self.reward = 0.0
|
||||
self.prev_reward = 0.0
|
||||
|
@@ -28,6 +28,8 @@ Created by Oleg Klimov. Licensed on the same terms as the rest of OpenAI Gym.
|
||||
|
||||
import math
|
||||
import sys
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
import Box2D
|
||||
@@ -93,7 +95,6 @@ class LunarLander(gym.Env, EzPickle):
|
||||
|
||||
def __init__(self):
|
||||
EzPickle.__init__(self)
|
||||
self.seed()
|
||||
self.viewer = None
|
||||
|
||||
self.world = Box2D.b2World()
|
||||
@@ -117,10 +118,6 @@ class LunarLander(gym.Env, EzPickle):
|
||||
# Nop, fire left engine, main engine, right engine
|
||||
self.action_space = spaces.Discrete(4)
|
||||
|
||||
def seed(self, seed=None):
|
||||
self.np_random, seed = seeding.np_random(seed)
|
||||
return [seed]
|
||||
|
||||
def _destroy(self):
|
||||
if not self.moon:
|
||||
return
|
||||
@@ -133,7 +130,8 @@ class LunarLander(gym.Env, EzPickle):
|
||||
self.world.DestroyBody(self.legs[0])
|
||||
self.world.DestroyBody(self.legs[1])
|
||||
|
||||
def reset(self):
|
||||
def reset(self, seed: Optional[int] = None):
|
||||
super().reset(seed=seed)
|
||||
self._destroy()
|
||||
self.world.contactListener_keepref = ContactDetector(self)
|
||||
self.world.contactListener = self.world.contactListener_keepref
|
||||
@@ -504,10 +502,9 @@ def heuristic(env, s):
|
||||
|
||||
|
||||
def demo_heuristic_lander(env, seed=None, render=False):
|
||||
env.seed(seed)
|
||||
total_reward = 0
|
||||
steps = 0
|
||||
s = env.reset()
|
||||
s = env.reset(seed=seed)
|
||||
while True:
|
||||
a = heuristic(env, s)
|
||||
s, r, done, info = env.step(a)
|
||||
|
@@ -1,4 +1,6 @@
|
||||
"""classic Acrobot task"""
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
from numpy import sin, cos, pi
|
||||
|
||||
@@ -94,13 +96,9 @@ class AcrobotEnv(core.Env):
|
||||
self.observation_space = spaces.Box(low=low, high=high, dtype=np.float32)
|
||||
self.action_space = spaces.Discrete(3)
|
||||
self.state = None
|
||||
self.seed()
|
||||
|
||||
def seed(self, seed=None):
|
||||
self.np_random, seed = seeding.np_random(seed)
|
||||
return [seed]
|
||||
|
||||
def reset(self):
|
||||
def reset(self, seed: Optional[int] = None):
|
||||
super().reset(seed=seed)
|
||||
self.state = self.np_random.uniform(low=-0.1, high=0.1, size=(4,)).astype(
|
||||
np.float32
|
||||
)
|
||||
|
@@ -5,6 +5,8 @@ permalink: https://perma.cc/C9ZM-652R
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import gym
|
||||
from gym import spaces, logger
|
||||
from gym.utils import seeding
|
||||
@@ -90,16 +92,11 @@ class CartPoleEnv(gym.Env):
|
||||
self.action_space = spaces.Discrete(2)
|
||||
self.observation_space = spaces.Box(-high, high, dtype=np.float32)
|
||||
|
||||
self.seed()
|
||||
self.viewer = None
|
||||
self.state = None
|
||||
|
||||
self.steps_beyond_done = None
|
||||
|
||||
def seed(self, seed=None):
|
||||
self.np_random, seed = seeding.np_random(seed)
|
||||
return [seed]
|
||||
|
||||
def step(self, action):
|
||||
err_msg = f"{action!r} ({type(action)}) invalid"
|
||||
assert self.action_space.contains(action), err_msg
|
||||
@@ -158,7 +155,8 @@ class CartPoleEnv(gym.Env):
|
||||
|
||||
return np.array(self.state, dtype=np.float32), reward, done, {}
|
||||
|
||||
def reset(self):
|
||||
def reset(self, seed: Optional[int] = None):
|
||||
super().reset(seed=seed)
|
||||
self.state = self.np_random.uniform(low=-0.05, high=0.05, size=(4,))
|
||||
self.steps_beyond_done = None
|
||||
return np.array(self.state, dtype=np.float32)
|
||||
|
@@ -14,6 +14,7 @@ permalink: https://perma.cc/6Z2N-PFWC
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -83,12 +84,6 @@ class Continuous_MountainCarEnv(gym.Env):
|
||||
low=self.low_state, high=self.high_state, dtype=np.float32
|
||||
)
|
||||
|
||||
self.seed()
|
||||
|
||||
def seed(self, seed=None):
|
||||
self.np_random, seed = seeding.np_random(seed)
|
||||
return [seed]
|
||||
|
||||
def step(self, action):
|
||||
|
||||
position = self.state[0]
|
||||
@@ -119,7 +114,8 @@ class Continuous_MountainCarEnv(gym.Env):
|
||||
self.state = np.array([position, velocity], dtype=np.float32)
|
||||
return self.state, reward, done, {}
|
||||
|
||||
def reset(self):
|
||||
def reset(self, seed: Optional[int] = None):
|
||||
super().reset(seed=seed)
|
||||
self.state = np.array([self.np_random.uniform(low=-0.6, high=-0.4), 0])
|
||||
return np.array(self.state, dtype=np.float32)
|
||||
|
||||
|
@@ -3,6 +3,7 @@ http://incompleteideas.net/MountainCar/MountainCar1.cp
|
||||
permalink: https://perma.cc/6Z2N-PFWC
|
||||
"""
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -72,12 +73,6 @@ class MountainCarEnv(gym.Env):
|
||||
self.action_space = spaces.Discrete(3)
|
||||
self.observation_space = spaces.Box(self.low, self.high, dtype=np.float32)
|
||||
|
||||
self.seed()
|
||||
|
||||
def seed(self, seed=None):
|
||||
self.np_random, seed = seeding.np_random(seed)
|
||||
return [seed]
|
||||
|
||||
def step(self, action):
|
||||
assert self.action_space.contains(
|
||||
action
|
||||
@@ -97,7 +92,8 @@ class MountainCarEnv(gym.Env):
|
||||
self.state = (position, velocity)
|
||||
return np.array(self.state, dtype=np.float32), reward, done, {}
|
||||
|
||||
def reset(self):
|
||||
def reset(self, seed: Optional[int] = None):
|
||||
super().reset(seed=seed)
|
||||
self.state = np.array([self.np_random.uniform(low=-0.6, high=-0.4), 0])
|
||||
return np.array(self.state, dtype=np.float32)
|
||||
|
||||
|
@@ -1,3 +1,5 @@
|
||||
from typing import Optional
|
||||
|
||||
import gym
|
||||
from gym import spaces
|
||||
from gym.utils import seeding
|
||||
@@ -24,12 +26,6 @@ class PendulumEnv(gym.Env):
|
||||
)
|
||||
self.observation_space = spaces.Box(low=-high, high=high, dtype=np.float32)
|
||||
|
||||
self.seed()
|
||||
|
||||
def seed(self, seed=None):
|
||||
self.np_random, seed = seeding.np_random(seed)
|
||||
return [seed]
|
||||
|
||||
def step(self, u):
|
||||
th, thdot = self.state # th := theta
|
||||
|
||||
@@ -49,7 +45,8 @@ class PendulumEnv(gym.Env):
|
||||
self.state = np.array([newth, newthdot])
|
||||
return self._get_obs(), -costs, False, {}
|
||||
|
||||
def reset(self):
|
||||
def reset(self, seed: Optional[int] = None):
|
||||
super().reset(seed=seed)
|
||||
high = np.array([np.pi, 1])
|
||||
self.state = self.np_random.uniform(low=-high, high=high)
|
||||
self.last_u = None
|
||||
|
@@ -48,7 +48,7 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
qpos = self.init_qpos + self.np_random.uniform(
|
||||
size=self.model.nq, low=-0.1, high=0.1
|
||||
)
|
||||
qvel = self.init_qvel + self.np_random.randn(self.model.nv) * 0.1
|
||||
qvel = self.init_qvel + self.np_random.standard_normal(self.model.nv) * 0.1
|
||||
self.set_state(qpos, qvel)
|
||||
return self._get_obs()
|
||||
|
||||
|
@@ -131,8 +131,9 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
qpos = self.init_qpos + self.np_random.uniform(
|
||||
low=noise_low, high=noise_high, size=self.model.nq
|
||||
)
|
||||
qvel = self.init_qvel + self._reset_noise_scale * self.np_random.randn(
|
||||
self.model.nv
|
||||
qvel = (
|
||||
self.init_qvel
|
||||
+ self._reset_noise_scale * self.np_random.standard_normal(self.model.nv)
|
||||
)
|
||||
self.set_state(qpos, qvel)
|
||||
|
||||
|
@@ -31,7 +31,7 @@ class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
qpos = self.init_qpos + self.np_random.uniform(
|
||||
low=-0.1, high=0.1, size=self.model.nq
|
||||
)
|
||||
qvel = self.init_qvel + self.np_random.randn(self.model.nv) * 0.1
|
||||
qvel = self.init_qvel + self.np_random.standard_normal(self.model.nv) * 0.1
|
||||
self.set_state(qpos, qvel)
|
||||
return self._get_obs()
|
||||
|
||||
|
@@ -74,8 +74,9 @@ class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
qpos = self.init_qpos + self.np_random.uniform(
|
||||
low=noise_low, high=noise_high, size=self.model.nq
|
||||
)
|
||||
qvel = self.init_qvel + self._reset_noise_scale * self.np_random.randn(
|
||||
self.model.nv
|
||||
qvel = (
|
||||
self.init_qvel
|
||||
+ self._reset_noise_scale * self.np_random.standard_normal(self.model.nv)
|
||||
)
|
||||
|
||||
self.set_state(qpos, qvel)
|
||||
|
@@ -35,7 +35,7 @@ class InvertedDoublePendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
self.set_state(
|
||||
self.init_qpos
|
||||
+ self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq),
|
||||
self.init_qvel + self.np_random.randn(self.model.nv) * 0.1,
|
||||
self.init_qvel + self.np_random.standard_normal(self.model.nv) * 0.1,
|
||||
)
|
||||
return self._get_obs()
|
||||
|
||||
|
@@ -1,6 +1,6 @@
|
||||
from collections import OrderedDict
|
||||
import os
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from gym import error, spaces
|
||||
from gym.utils import seeding
|
||||
@@ -73,8 +73,6 @@ class MujocoEnv(gym.Env):
|
||||
|
||||
self._set_observation_space(observation)
|
||||
|
||||
self.seed()
|
||||
|
||||
def _set_action_space(self):
|
||||
bounds = self.model.actuator_ctrlrange.copy().astype(np.float32)
|
||||
low, high = bounds.T
|
||||
@@ -85,10 +83,6 @@ class MujocoEnv(gym.Env):
|
||||
self.observation_space = convert_observation_to_space(observation)
|
||||
return self.observation_space
|
||||
|
||||
def seed(self, seed=None):
|
||||
self.np_random, seed = seeding.np_random(seed)
|
||||
return [seed]
|
||||
|
||||
# methods to override:
|
||||
# ----------------------------
|
||||
|
||||
@@ -109,7 +103,8 @@ class MujocoEnv(gym.Env):
|
||||
|
||||
# -----------------------------
|
||||
|
||||
def reset(self):
|
||||
def reset(self, seed: Optional[int] = None):
|
||||
super().reset(seed=seed)
|
||||
self.sim.reset()
|
||||
ob = self.reset_model()
|
||||
return ob
|
||||
|
@@ -183,7 +183,7 @@ class ManipulateEnv(hand_env.HandEnv):
|
||||
axis = np.array([0.0, 0.0, 1.0])
|
||||
z_quat = quat_from_angle_and_axis(angle, axis)
|
||||
parallel_quat = self.parallel_quats[
|
||||
self.np_random.randint(len(self.parallel_quats))
|
||||
self.np_random.integers(len(self.parallel_quats))
|
||||
]
|
||||
offset_quat = rotations.quat_mul(z_quat, parallel_quat)
|
||||
initial_quat = rotations.quat_mul(initial_quat, offset_quat)
|
||||
@@ -254,7 +254,7 @@ class ManipulateEnv(hand_env.HandEnv):
|
||||
axis = np.array([0.0, 0.0, 1.0])
|
||||
target_quat = quat_from_angle_and_axis(angle, axis)
|
||||
parallel_quat = self.parallel_quats[
|
||||
self.np_random.randint(len(self.parallel_quats))
|
||||
self.np_random.integers(len(self.parallel_quats))
|
||||
]
|
||||
target_quat = rotations.quat_mul(target_quat, parallel_quat)
|
||||
elif self.target_rotation == "xyz":
|
||||
|
@@ -1,5 +1,7 @@
|
||||
import os
|
||||
import copy
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
import gym
|
||||
@@ -37,7 +39,6 @@ class RobotEnv(gym.GoalEnv):
|
||||
"video.frames_per_second": int(np.round(1.0 / self.dt)),
|
||||
}
|
||||
|
||||
self.seed()
|
||||
self._env_setup(initial_qpos=initial_qpos)
|
||||
self.initial_state = copy.deepcopy(self.sim.get_state())
|
||||
|
||||
@@ -65,10 +66,6 @@ class RobotEnv(gym.GoalEnv):
|
||||
# Env methods
|
||||
# ----------------------------
|
||||
|
||||
def seed(self, seed=None):
|
||||
self.np_random, seed = seeding.np_random(seed)
|
||||
return [seed]
|
||||
|
||||
def step(self, action):
|
||||
if np.array(action).shape != self.action_space.shape:
|
||||
raise ValueError("Action dimension mismatch")
|
||||
@@ -86,13 +83,13 @@ class RobotEnv(gym.GoalEnv):
|
||||
reward = self.compute_reward(obs["achieved_goal"], self.goal, info)
|
||||
return obs, reward, done, info
|
||||
|
||||
def reset(self):
|
||||
def reset(self, seed: Optional[int] = None):
|
||||
# Attempt to reset the simulator. Since we randomize initial conditions, it
|
||||
# is possible to get into a state with numerical issues (e.g. due to penetration or
|
||||
# Gimbel lock) or we may not achieve an initial condition (e.g. an object is within the hand).
|
||||
# In this case, we just keep randomizing until we eventually achieve a valid initial
|
||||
# configuration.
|
||||
super().reset()
|
||||
super().reset(seed=seed)
|
||||
did_reset_sim = False
|
||||
while not did_reset_sim:
|
||||
did_reset_sim = self._reset_sim()
|
||||
|
@@ -1,3 +1,5 @@
|
||||
from typing import Optional
|
||||
|
||||
import gym
|
||||
from gym import spaces
|
||||
from gym.utils import seeding
|
||||
@@ -78,7 +80,6 @@ class BlackjackEnv(gym.Env):
|
||||
self.observation_space = spaces.Tuple(
|
||||
(spaces.Discrete(32), spaces.Discrete(11), spaces.Discrete(2))
|
||||
)
|
||||
self.seed()
|
||||
|
||||
# Flag to payout 1.5 on a "natural" blackjack win, like casino rules
|
||||
# Ref: http://www.bicyclecards.com/how-to-play/blackjack/
|
||||
@@ -87,10 +88,6 @@ class BlackjackEnv(gym.Env):
|
||||
# Flag for full agreement with the (Sutton and Barto, 2018) definition. Overrides self.natural
|
||||
self.sab = sab
|
||||
|
||||
def seed(self, seed=None):
|
||||
self.np_random, seed = seeding.np_random(seed)
|
||||
return [seed]
|
||||
|
||||
def step(self, action):
|
||||
assert self.action_space.contains(action)
|
||||
if action: # hit: add a card to players hand and return
|
||||
@@ -122,7 +119,8 @@ class BlackjackEnv(gym.Env):
|
||||
def _get_obs(self):
|
||||
return (sum_hand(self.player), self.dealer[0], usable_ace(self.player))
|
||||
|
||||
def reset(self):
|
||||
def reset(self, seed: Optional[int] = None):
|
||||
super().reset(seed=seed)
|
||||
self.dealer = draw_hand(self.np_random)
|
||||
self.player = draw_hand(self.np_random)
|
||||
return self._get_obs()
|
||||
|
@@ -1,3 +1,5 @@
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gym import Env, spaces
|
||||
@@ -11,7 +13,7 @@ def categorical_sample(prob_n, np_random):
|
||||
"""
|
||||
prob_n = np.asarray(prob_n)
|
||||
csprob_n = np.cumsum(prob_n)
|
||||
return (csprob_n > np_random.rand()).argmax()
|
||||
return (csprob_n > np_random.random()).argmax()
|
||||
|
||||
|
||||
class DiscreteEnv(Env):
|
||||
@@ -40,14 +42,8 @@ class DiscreteEnv(Env):
|
||||
self.action_space = spaces.Discrete(self.nA)
|
||||
self.observation_space = spaces.Discrete(self.nS)
|
||||
|
||||
self.seed()
|
||||
self.s = categorical_sample(self.isd, self.np_random)
|
||||
|
||||
def seed(self, seed=None):
|
||||
self.np_random, seed = seeding.np_random(seed)
|
||||
return [seed]
|
||||
|
||||
def reset(self):
|
||||
def reset(self, seed: Optional[int] = None):
|
||||
super().reset(seed=seed)
|
||||
self.s = categorical_sample(self.isd, self.np_random)
|
||||
self.lastaction = None
|
||||
return int(self.s)
|
||||
|
@@ -1,3 +1,5 @@
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import gym
|
||||
from gym import spaces
|
||||
@@ -52,7 +54,6 @@ class CubeCrash(gym.Env):
|
||||
use_random_colors = False # Makes env too hard
|
||||
|
||||
def __init__(self):
|
||||
self.seed()
|
||||
self.viewer = None
|
||||
|
||||
self.observation_space = spaces.Box(
|
||||
@@ -62,23 +63,20 @@ class CubeCrash(gym.Env):
|
||||
|
||||
self.reset()
|
||||
|
||||
def seed(self, seed=None):
|
||||
self.np_random, seed = seeding.np_random(seed)
|
||||
return [seed]
|
||||
|
||||
def random_color(self):
|
||||
return np.array(
|
||||
[
|
||||
self.np_random.randint(low=0, high=255),
|
||||
self.np_random.randint(low=0, high=255),
|
||||
self.np_random.randint(low=0, high=255),
|
||||
self.np_random.integers(low=0, high=255),
|
||||
self.np_random.integers(low=0, high=255),
|
||||
self.np_random.integers(low=0, high=255),
|
||||
]
|
||||
).astype("uint8")
|
||||
|
||||
def reset(self):
|
||||
self.cube_x = self.np_random.randint(low=3, high=FIELD_W - 3)
|
||||
self.cube_y = self.np_random.randint(low=3, high=FIELD_H // 6)
|
||||
self.hole_x = self.np_random.randint(low=HOLE_WIDTH, high=FIELD_W - HOLE_WIDTH)
|
||||
def reset(self, seed: Optional[int] = None):
|
||||
super().reset(seed=seed)
|
||||
self.cube_x = self.np_random.integers(low=3, high=FIELD_W - 3)
|
||||
self.cube_y = self.np_random.integers(low=3, high=FIELD_H // 6)
|
||||
self.hole_x = self.np_random.integers(low=HOLE_WIDTH, high=FIELD_W - HOLE_WIDTH)
|
||||
self.bg_color = self.random_color() if self.use_random_colors else color_black
|
||||
self.potential = None
|
||||
self.step_n = 0
|
||||
@@ -95,6 +93,7 @@ class CubeCrash(gym.Env):
|
||||
):
|
||||
continue
|
||||
break
|
||||
|
||||
return self.step(0)[0]
|
||||
|
||||
def step(self, action):
|
||||
|
@@ -1,3 +1,5 @@
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import gym
|
||||
from gym import spaces
|
||||
@@ -60,7 +62,6 @@ class MemorizeDigits(gym.Env):
|
||||
use_random_colors = False
|
||||
|
||||
def __init__(self):
|
||||
self.seed()
|
||||
self.viewer = None
|
||||
self.observation_space = spaces.Box(
|
||||
0, 255, (FIELD_H, FIELD_W, 3), dtype=np.uint8
|
||||
@@ -74,22 +75,19 @@ class MemorizeDigits(gym.Env):
|
||||
]
|
||||
self.reset()
|
||||
|
||||
def seed(self, seed=None):
|
||||
self.np_random, seed = seeding.np_random(seed)
|
||||
return [seed]
|
||||
|
||||
def random_color(self):
|
||||
return np.array(
|
||||
[
|
||||
self.np_random.randint(low=0, high=255),
|
||||
self.np_random.randint(low=0, high=255),
|
||||
self.np_random.randint(low=0, high=255),
|
||||
self.np_random.integers(low=0, high=255),
|
||||
self.np_random.integers(low=0, high=255),
|
||||
self.np_random.integers(low=0, high=255),
|
||||
]
|
||||
).astype("uint8")
|
||||
|
||||
def reset(self):
|
||||
self.digit_x = self.np_random.randint(low=FIELD_W // 5, high=FIELD_W // 5 * 4)
|
||||
self.digit_y = self.np_random.randint(low=FIELD_H // 5, high=FIELD_H // 5 * 4)
|
||||
def reset(self, seed: Optional[int] = None):
|
||||
super().reset(seed=seed)
|
||||
self.digit_x = self.np_random.integers(low=FIELD_W // 5, high=FIELD_W // 5 * 4)
|
||||
self.digit_y = self.np_random.integers(low=FIELD_H // 5, high=FIELD_H // 5 * 4)
|
||||
self.color_bg = self.random_color() if self.use_random_colors else color_black
|
||||
self.step_n = 0
|
||||
while 1:
|
||||
@@ -111,8 +109,8 @@ class MemorizeDigits(gym.Env):
|
||||
else:
|
||||
if self.digit == action:
|
||||
reward = +1
|
||||
done = self.step_n > 20 and 0 == self.np_random.randint(low=0, high=5)
|
||||
self.digit = self.np_random.randint(low=0, high=10)
|
||||
done = self.step_n > 20 and 0 == self.np_random.integers(low=0, high=5)
|
||||
self.digit = self.np_random.integers(low=0, high=10)
|
||||
obs = np.zeros((FIELD_H, FIELD_W, 3), dtype=np.uint8)
|
||||
obs[:, :, :] = self.color_bg
|
||||
digit_img = np.zeros((6, 6, 3), dtype=np.uint8)
|
||||
|
@@ -35,7 +35,7 @@ class MultiBinary(Space):
|
||||
super().__init__(input_n, np.int8, seed)
|
||||
|
||||
def sample(self):
|
||||
return self.np_random.randint(low=0, high=2, size=self.n, dtype=self.dtype)
|
||||
return self.np_random.integers(low=0, high=2, size=self.n, dtype=self.dtype)
|
||||
|
||||
def contains(self, x):
|
||||
if isinstance(x, list) or isinstance(x, tuple):
|
||||
|
@@ -35,9 +35,7 @@ class MultiDiscrete(Space):
|
||||
super().__init__(self.nvec.shape, dtype, seed)
|
||||
|
||||
def sample(self):
|
||||
return (self.np_random.random_sample(self.nvec.shape) * self.nvec).astype(
|
||||
self.dtype
|
||||
)
|
||||
return (self.np_random.random(self.nvec.shape) * self.nvec).astype(self.dtype)
|
||||
|
||||
def contains(self, x):
|
||||
if isinstance(x, list):
|
||||
@@ -61,7 +59,7 @@ class MultiDiscrete(Space):
|
||||
subspace = Discrete(nvec)
|
||||
else:
|
||||
subspace = MultiDiscrete(nvec, self.dtype)
|
||||
subspace.np_random.set_state(self.np_random.get_state()) # for reproducibility
|
||||
subspace.np_random.bit_generator.state = self.np_random.bit_generator.state
|
||||
return subspace
|
||||
|
||||
def __len__(self):
|
||||
|
@@ -1,43 +1,118 @@
|
||||
import hashlib
|
||||
import numpy as np
|
||||
import os
|
||||
import random as _random
|
||||
import struct
|
||||
import sys
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
from numpy.random import Generator
|
||||
|
||||
from gym import error
|
||||
from gym.logger import deprecation
|
||||
|
||||
|
||||
def np_random(seed=None):
|
||||
def np_random(seed: Optional[int] = None):
|
||||
if seed is not None and not (isinstance(seed, int) and 0 <= seed):
|
||||
raise error.Error(f"Seed must be a non-negative integer or omitted, not {seed}")
|
||||
|
||||
seed = create_seed(seed)
|
||||
|
||||
rng = np.random.RandomState()
|
||||
rng.seed(_int_list_from_bigint(hash_seed(seed)))
|
||||
seed_seq = np.random.SeedSequence(seed)
|
||||
seed = seed_seq.entropy
|
||||
rng = RandomNumberGenerator(np.random.PCG64(seed_seq))
|
||||
return rng, seed
|
||||
|
||||
|
||||
# TODO: Remove this class and make it alias to `Generator` in a future Gym release
|
||||
# RandomNumberGenerator = np.random.Generator
|
||||
class RandomNumberGenerator(np.random.Generator):
|
||||
def rand(self, *size):
|
||||
deprecation(
|
||||
"Function `rng.rand(*size)` is marked as deprecated "
|
||||
"and will be removed in the future. "
|
||||
"Please use `Generator.random(size)` instead."
|
||||
)
|
||||
|
||||
return self.random(size)
|
||||
|
||||
random_sample = rand
|
||||
|
||||
def randn(self, *size):
|
||||
deprecation(
|
||||
"Function `rng.randn(*size)` is marked as deprecated "
|
||||
"and will be removed in the future. "
|
||||
"Please use `rng.standard_normal(size)` instead."
|
||||
)
|
||||
|
||||
return self.standard_normal(size)
|
||||
|
||||
def randint(self, low, high=None, size=None, dtype=int):
|
||||
deprecation(
|
||||
"Function `rng.randint(low, [high, size, dtype])` is marked as deprecated "
|
||||
"and will be removed in the future. "
|
||||
"Please use `rng.integers(low, [high, size, dtype])` instead."
|
||||
)
|
||||
|
||||
return self.integers(low=low, high=high, size=size, dtype=dtype)
|
||||
|
||||
random_integers = randint
|
||||
|
||||
def get_state(self):
|
||||
deprecation(
|
||||
"Function `rng.get_state()` is marked as deprecated "
|
||||
"and will be removed in the future. "
|
||||
"Please use `rng.bit_generator.state` instead."
|
||||
)
|
||||
|
||||
return self.bit_generator.state
|
||||
|
||||
def set_state(self, state):
|
||||
deprecation(
|
||||
"Function `rng.set_state(state)` is marked as deprecated "
|
||||
"and will be removed in the future. "
|
||||
"Please use `rng.bit_generator.state = state` instead."
|
||||
)
|
||||
|
||||
self.bit_generator.state = state
|
||||
|
||||
def seed(self, seed=None):
|
||||
deprecation(
|
||||
"Function `rng.seed(seed)` is marked as deprecated "
|
||||
"and will be removed in the future. "
|
||||
"Please use `rng, seed = gym.utils.seeding.np_random(seed)` to create a separate generator instead."
|
||||
)
|
||||
|
||||
self.bit_generator.state = type(self.bit_generator)(seed).state
|
||||
|
||||
rand.__doc__ = np.random.rand.__doc__
|
||||
randn.__doc__ = np.random.randn.__doc__
|
||||
randint.__doc__ = np.random.randint.__doc__
|
||||
get_state.__doc__ = np.random.get_state.__doc__
|
||||
set_state.__doc__ = np.random.set_state.__doc__
|
||||
seed.__doc__ = np.random.seed.__doc__
|
||||
|
||||
|
||||
RNG = RandomNumberGenerator
|
||||
|
||||
# Legacy functions
|
||||
|
||||
|
||||
def hash_seed(seed=None, max_bytes=8):
|
||||
"""Any given evaluation is likely to have many PRNG's active at
|
||||
once. (Most commonly, because the environment is running in
|
||||
multiple processes.) There's literature indicating that having
|
||||
linear correlations between seeds of multiple PRNG's can correlate
|
||||
the outputs:
|
||||
|
||||
http://blogs.unity3d.com/2015/01/07/a-primer-on-repeatable-random-numbers/
|
||||
http://stackoverflow.com/questions/1554958/how-different-do-random-seeds-need-to-be
|
||||
http://dl.acm.org/citation.cfm?id=1276928
|
||||
|
||||
Thus, for sanity we hash the seeds before using them. (This scheme
|
||||
is likely not crypto-strength, but it should be good enough to get
|
||||
rid of simple correlations.)
|
||||
|
||||
Args:
|
||||
seed (Optional[int]): None seeds from an operating system specific randomness source.
|
||||
max_bytes: Maximum number of bytes to use in the hashed seed.
|
||||
"""
|
||||
deprecation(
|
||||
"Function `hash_seed(seed, max_bytes)` is marked as deprecated and will be removed in the future. "
|
||||
)
|
||||
if seed is None:
|
||||
seed = create_seed(max_bytes=max_bytes)
|
||||
hash = hashlib.sha512(str(seed).encode("utf8")).digest()
|
||||
@@ -48,11 +123,13 @@ def create_seed(a=None, max_bytes=8):
|
||||
"""Create a strong random seed. Otherwise, Python 2 would seed using
|
||||
the system time, which might be non-robust especially in the
|
||||
presence of concurrency.
|
||||
|
||||
Args:
|
||||
a (Optional[int, str]): None seeds from an operating system specific randomness source.
|
||||
max_bytes: Maximum number of bytes to use in the seed.
|
||||
"""
|
||||
deprecation(
|
||||
"Function `create_seed(a, max_bytes)` is marked as deprecated and will be removed in the future. "
|
||||
)
|
||||
# Adapted from https://svn.python.org/projects/python/tags/r32/Lib/random.py
|
||||
if a is None:
|
||||
a = _bigint_from_bytes(os.urandom(max_bytes))
|
||||
@@ -70,6 +147,9 @@ def create_seed(a=None, max_bytes=8):
|
||||
|
||||
# TODO: don't hardcode sizeof_int here
|
||||
def _bigint_from_bytes(bytes):
|
||||
deprecation(
|
||||
"Function `_bigint_from_bytes(bytes)` is marked as deprecated and will be removed in the future. "
|
||||
)
|
||||
sizeof_int = 4
|
||||
padding = sizeof_int - len(bytes) % sizeof_int
|
||||
bytes += b"\0" * padding
|
||||
@@ -82,6 +162,9 @@ def _bigint_from_bytes(bytes):
|
||||
|
||||
|
||||
def _int_list_from_bigint(bigint):
|
||||
deprecation(
|
||||
"Function `_int_list_from_bigint` is marked as deprecated and will be removed in the future. "
|
||||
)
|
||||
# Special case 0
|
||||
if bigint < 0:
|
||||
raise error.Error(f"Seed must be non-negative, not {bigint}")
|
||||
|
@@ -1,3 +1,5 @@
|
||||
from typing import Optional, Union, List
|
||||
|
||||
import numpy as np
|
||||
import multiprocessing as mp
|
||||
import time
|
||||
@@ -6,6 +8,7 @@ from enum import Enum
|
||||
from copy import deepcopy
|
||||
|
||||
from gym import logger
|
||||
from gym.logger import warn
|
||||
from gym.vector.vector_env import VectorEnv
|
||||
from gym.error import (
|
||||
AlreadyPendingCallError,
|
||||
@@ -187,13 +190,14 @@ class AsyncVectorEnv(VectorEnv):
|
||||
self._state = AsyncState.DEFAULT
|
||||
self._check_observation_spaces()
|
||||
|
||||
def seed(self, seeds=None):
|
||||
def seed(self, seed=None):
|
||||
super().seed(seed=seed)
|
||||
self._assert_is_running()
|
||||
if seeds is None:
|
||||
seeds = [None for _ in range(self.num_envs)]
|
||||
if isinstance(seeds, int):
|
||||
seeds = [seeds + i for i in range(self.num_envs)]
|
||||
assert len(seeds) == self.num_envs
|
||||
if seed is None:
|
||||
seed = [None for _ in range(self.num_envs)]
|
||||
if isinstance(seed, int):
|
||||
seed = [seed + i for i in range(self.num_envs)]
|
||||
assert len(seed) == self.num_envs
|
||||
|
||||
if self._state != AsyncState.DEFAULT:
|
||||
raise AlreadyPendingCallError(
|
||||
@@ -201,12 +205,12 @@ class AsyncVectorEnv(VectorEnv):
|
||||
self._state.value,
|
||||
)
|
||||
|
||||
for pipe, seed in zip(self.parent_pipes, seeds):
|
||||
for pipe, seed in zip(self.parent_pipes, seed):
|
||||
pipe.send(("seed", seed))
|
||||
_, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
|
||||
self._raise_if_errors(successes)
|
||||
|
||||
def reset_async(self):
|
||||
def reset_async(self, seed: Optional[Union[int, List[int]]] = None):
|
||||
"""Send the calls to :obj:`reset` to each sub-environment.
|
||||
|
||||
Raises
|
||||
@@ -221,24 +225,31 @@ class AsyncVectorEnv(VectorEnv):
|
||||
between.
|
||||
"""
|
||||
self._assert_is_running()
|
||||
|
||||
if seed is None:
|
||||
seed = [None for _ in range(self.num_envs)]
|
||||
if isinstance(seed, int):
|
||||
seed = [seed + i for i in range(self.num_envs)]
|
||||
assert len(seed) == self.num_envs
|
||||
|
||||
if self._state != AsyncState.DEFAULT:
|
||||
raise AlreadyPendingCallError(
|
||||
f"Calling `reset_async` while waiting for a pending call to `{self._state.value}` to complete",
|
||||
self._state.value,
|
||||
)
|
||||
|
||||
for pipe in self.parent_pipes:
|
||||
pipe.send(("reset", None))
|
||||
for pipe, single_seed in zip(self.parent_pipes, seed):
|
||||
pipe.send(("reset", single_seed))
|
||||
self._state = AsyncState.WAITING_RESET
|
||||
|
||||
def reset_wait(self, timeout=None):
|
||||
"""Wait for the calls to :obj:`reset` in each sub-environment to finish.
|
||||
|
||||
def reset_wait(self, timeout=None, seed: Optional[int] = None):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
timeout : int or float, optional
|
||||
Number of seconds before the call to :meth:`reset_wait` times out.
|
||||
If ``None``, the call to :meth:`reset_wait` never times out.
|
||||
Number of seconds before the call to `reset_wait` times out. If
|
||||
`None`, the call to `reset_wait` never times out.
|
||||
seed: ignored
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -486,7 +497,7 @@ def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
|
||||
while True:
|
||||
command, data = pipe.recv()
|
||||
if command == "reset":
|
||||
observation = env.reset()
|
||||
observation = env.reset(data)
|
||||
pipe.send((observation, True))
|
||||
elif command == "step":
|
||||
observation, reward, done, info = env.step(data)
|
||||
@@ -524,7 +535,7 @@ def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error
|
||||
while True:
|
||||
command, data = pipe.recv()
|
||||
if command == "reset":
|
||||
observation = env.reset()
|
||||
observation = env.reset(data)
|
||||
write_to_shared_memory(
|
||||
index, observation, shared_memory, observation_space
|
||||
)
|
||||
|
@@ -1,7 +1,10 @@
|
||||
from typing import List, Union, Optional
|
||||
|
||||
import numpy as np
|
||||
from copy import deepcopy
|
||||
|
||||
from gym import logger
|
||||
from gym.logger import warn
|
||||
from gym.vector.vector_env import VectorEnv
|
||||
from gym.vector.utils import concatenate, create_empty_array
|
||||
|
||||
@@ -72,21 +75,28 @@ class SyncVectorEnv(VectorEnv):
|
||||
self._dones = np.zeros((self.num_envs,), dtype=np.bool_)
|
||||
self._actions = None
|
||||
|
||||
def seed(self, seeds=None):
|
||||
if seeds is None:
|
||||
seeds = [None for _ in range(self.num_envs)]
|
||||
if isinstance(seeds, int):
|
||||
seeds = [seeds + i for i in range(self.num_envs)]
|
||||
assert len(seeds) == self.num_envs
|
||||
def seed(self, seed=None):
|
||||
super().seed(seed=seed)
|
||||
if seed is None:
|
||||
seed = [None for _ in range(self.num_envs)]
|
||||
if isinstance(seed, int):
|
||||
seed = [seed + i for i in range(self.num_envs)]
|
||||
assert len(seed) == self.num_envs
|
||||
|
||||
for env, seed in zip(self.envs, seeds):
|
||||
env.seed(seed)
|
||||
for env, single_seed in zip(self.envs, seed):
|
||||
env.seed(single_seed)
|
||||
|
||||
def reset_wait(self, seed: Optional[Union[int, List[int]]] = None, **kwargs):
|
||||
if seed is None:
|
||||
seed = [None for _ in range(self.num_envs)]
|
||||
if isinstance(seed, int):
|
||||
seed = [seed + i for i in range(self.num_envs)]
|
||||
assert len(seed) == self.num_envs
|
||||
|
||||
def reset_wait(self):
|
||||
self._dones[:] = False
|
||||
observations = []
|
||||
for env in self.envs:
|
||||
observation = env.reset()
|
||||
for env, single_seed in zip(self.envs, seed):
|
||||
observation = env.reset(seed=single_seed)
|
||||
observations.append(observation)
|
||||
self.observations = concatenate(
|
||||
observations, self.observations, self.single_observation_space
|
||||
|
@@ -1,4 +1,7 @@
|
||||
from typing import Optional, Union, List
|
||||
|
||||
import gym
|
||||
from gym.logger import warn, deprecation
|
||||
from gym.spaces import Tuple
|
||||
from gym.vector.utils.spaces import batch_space
|
||||
|
||||
@@ -43,13 +46,13 @@ class VectorEnv(gym.Env):
|
||||
self.single_observation_space = observation_space
|
||||
self.single_action_space = action_space
|
||||
|
||||
def reset_async(self):
|
||||
def reset_async(self, seed: Optional[Union[int, List[int]]] = None):
|
||||
pass
|
||||
|
||||
def reset_wait(self, **kwargs):
|
||||
def reset_wait(self, seed: Optional[Union[int, List[int]]] = None, **kwargs):
|
||||
raise NotImplementedError()
|
||||
|
||||
def reset(self):
|
||||
def reset(self, seed: Optional[Union[int, List[int]]] = None):
|
||||
r"""Reset all sub-environments and return a batch of initial observations.
|
||||
|
||||
Returns
|
||||
@@ -57,8 +60,8 @@ class VectorEnv(gym.Env):
|
||||
element of :attr:`observation_space`
|
||||
A batch of observations from the vectorized environment.
|
||||
"""
|
||||
self.reset_async()
|
||||
return self.reset_wait()
|
||||
self.reset_async(seed=seed)
|
||||
return self.reset_wait(seed=seed)
|
||||
|
||||
def step_async(self, actions):
|
||||
pass
|
||||
@@ -120,19 +123,22 @@ class VectorEnv(gym.Env):
|
||||
self.close_extras(**kwargs)
|
||||
self.closed = True
|
||||
|
||||
def seed(self, seeds=None):
|
||||
def seed(self, seed=None):
|
||||
"""Set the random seed in all sub-environments.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
seeds : list of int, or int, optional
|
||||
Random seed for each sub-environment. If ``seeds`` is a list of
|
||||
seed : list of int, or int, optional
|
||||
Random seed for each sub-environment. If ``seed`` is a list of
|
||||
length ``num_envs``, then the items of the list are chosen as random
|
||||
seeds. If ``seeds`` is an int, then each sub-environment uses the random
|
||||
seed ``seeds + n``, where ``n`` is the index of the sub-environment
|
||||
seeds. If ``seed`` is an int, then each sub-environment uses the random
|
||||
seed ``seed + n``, where ``n`` is the index of the sub-environment
|
||||
(between ``0`` and ``num_envs - 1``).
|
||||
"""
|
||||
pass
|
||||
deprecation(
|
||||
"Function `env.seed(seed)` is marked as deprecated and will be removed in the future. "
|
||||
"Please use `env.reset(seed=seed) instead in VectorEnvs."
|
||||
)
|
||||
|
||||
def __del__(self):
|
||||
if not getattr(self, "closed", True):
|
||||
@@ -164,11 +170,11 @@ class VectorEnvWrapper(VectorEnv):
|
||||
|
||||
# explicitly forward the methods defined in VectorEnv
|
||||
# to self.env (instead of the base class)
|
||||
def reset_async(self):
|
||||
return self.env.reset_async()
|
||||
def reset_async(self, **kwargs):
|
||||
return self.env.reset_async(**kwargs)
|
||||
|
||||
def reset_wait(self):
|
||||
return self.env.reset_wait()
|
||||
def reset_wait(self, **kwargs):
|
||||
return self.env.reset_wait(**kwargs)
|
||||
|
||||
def step_async(self, actions):
|
||||
return self.env.step_async(actions)
|
||||
@@ -182,8 +188,8 @@ class VectorEnvWrapper(VectorEnv):
|
||||
def close_extras(self, **kwargs):
|
||||
return self.env.close_extras(**kwargs)
|
||||
|
||||
def seed(self, seeds=None):
|
||||
return self.env.seed(seeds)
|
||||
def seed(self, seed=None):
|
||||
return self.env.seed(seed)
|
||||
|
||||
# implicitly forward all other methods and attributes to self.env
|
||||
def __getattr__(self, name):
|
||||
|
@@ -1,3 +1,5 @@
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import gym
|
||||
from gym.spaces import Box
|
||||
@@ -127,11 +129,11 @@ class AtariPreprocessing(gym.Wrapper):
|
||||
self.ale.getScreenRGB(self.obs_buffer[0])
|
||||
return self._get_obs(), R, done, info
|
||||
|
||||
def reset(self, **kwargs):
|
||||
def reset(self, seed: Optional[int] = None, **kwargs):
|
||||
# NoopReset
|
||||
self.env.reset(**kwargs)
|
||||
self.env.reset(seed=seed, **kwargs)
|
||||
noops = (
|
||||
self.env.unwrapped.np_random.randint(1, self.noop_max + 1)
|
||||
self.env.unwrapped.np_random.integers(1, self.noop_max + 1)
|
||||
if self.noop_max > 0
|
||||
else 0
|
||||
)
|
||||
|
@@ -1,4 +1,6 @@
|
||||
from collections import deque
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
from gym.spaces import Box
|
||||
from gym import ObservationWrapper
|
||||
@@ -116,7 +118,7 @@ class FrameStack(ObservationWrapper):
|
||||
self.frames.append(observation)
|
||||
return self.observation(), reward, done, info
|
||||
|
||||
def reset(self, **kwargs):
|
||||
observation = self.env.reset(**kwargs)
|
||||
def reset(self, seed: Optional[int] = None, **kwargs):
|
||||
observation = self.env.reset(seed=seed, **kwargs)
|
||||
[self.frames.append(observation) for _ in range(self.num_stack)]
|
||||
return self.observation()
|
||||
|
@@ -1,5 +1,6 @@
|
||||
import json
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -50,9 +51,9 @@ class Monitor(Wrapper):
|
||||
|
||||
return observation, reward, done, info
|
||||
|
||||
def reset(self, **kwargs):
|
||||
def reset(self, seed: Optional[int] = None, **kwargs):
|
||||
self._before_reset()
|
||||
observation = self.env.reset(**kwargs)
|
||||
observation = self.env.reset(seed=seed, **kwargs)
|
||||
self._after_reset(observation)
|
||||
|
||||
return observation
|
||||
|
@@ -1,3 +1,5 @@
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import gym
|
||||
|
||||
@@ -61,8 +63,8 @@ class NormalizeObservation(gym.core.Wrapper):
|
||||
obs = self.normalize(np.array([obs]))[0]
|
||||
return obs, rews, dones, infos
|
||||
|
||||
def reset(self):
|
||||
obs = self.env.reset()
|
||||
def reset(self, seed: Optional[int] = None):
|
||||
obs = self.env.reset(seed=seed)
|
||||
if self.is_vector_env:
|
||||
obs = self.normalize(obs)
|
||||
else:
|
||||
|
@@ -1,3 +1,5 @@
|
||||
from typing import Optional
|
||||
|
||||
import gym
|
||||
|
||||
|
||||
@@ -11,6 +13,6 @@ class OrderEnforcing(gym.Wrapper):
|
||||
observation, reward, done, info = self.env.step(action)
|
||||
return observation, reward, done, info
|
||||
|
||||
def reset(self, **kwargs):
|
||||
def reset(self, seed: Optional[int] = None, **kwargs):
|
||||
self._has_reset = True
|
||||
return self.env.reset(**kwargs)
|
||||
return self.env.reset(seed=seed, **kwargs)
|
||||
|
@@ -1,5 +1,7 @@
|
||||
import time
|
||||
from collections import deque
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import gym
|
||||
|
||||
@@ -16,8 +18,8 @@ class RecordEpisodeStatistics(gym.Wrapper):
|
||||
self.length_queue = deque(maxlen=deque_size)
|
||||
self.is_vector_env = getattr(env, "is_vector_env", False)
|
||||
|
||||
def reset(self, **kwargs):
|
||||
observations = super().reset(**kwargs)
|
||||
def reset(self, seed: Optional[int] = None, **kwargs):
|
||||
observations = super().reset(seed=seed, **kwargs)
|
||||
self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
|
||||
self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
|
||||
return observations
|
||||
|
@@ -1,6 +1,6 @@
|
||||
import os
|
||||
import gym
|
||||
from typing import Callable
|
||||
from typing import Callable, Optional
|
||||
|
||||
from gym import logger
|
||||
from gym.wrappers.monitoring import video_recorder
|
||||
@@ -52,8 +52,8 @@ class RecordVideo(gym.Wrapper):
|
||||
self.is_vector_env = getattr(env, "is_vector_env", False)
|
||||
self.episode_id = 0
|
||||
|
||||
def reset(self, **kwargs):
|
||||
observations = super().reset(**kwargs)
|
||||
def reset(self, seed: Optional[int] = None, **kwargs):
|
||||
observations = super().reset(seed=seed, **kwargs)
|
||||
if not self.recording and self._video_enabled():
|
||||
self.start_video_recorder()
|
||||
return observations
|
||||
|
@@ -1,3 +1,5 @@
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
from gym.spaces import Box
|
||||
from gym import ObservationWrapper
|
||||
@@ -27,6 +29,6 @@ class TimeAwareObservation(ObservationWrapper):
|
||||
self.t += 1
|
||||
return super().step(action)
|
||||
|
||||
def reset(self, **kwargs):
|
||||
def reset(self, seed: Optional[int] = None, **kwargs):
|
||||
self.t = 0
|
||||
return super().reset(**kwargs)
|
||||
return super().reset(seed=seed, **kwargs)
|
||||
|
@@ -1,3 +1,5 @@
|
||||
from typing import Optional
|
||||
|
||||
import gym
|
||||
|
||||
|
||||
@@ -22,6 +24,6 @@ class TimeLimit(gym.Wrapper):
|
||||
done = True
|
||||
return observation, reward, done, info
|
||||
|
||||
def reset(self, **kwargs):
|
||||
def reset(self, seed: Optional[int] = None, **kwargs):
|
||||
self._elapsed_steps = 0
|
||||
return self.env.reset(**kwargs)
|
||||
return self.env.reset(seed=seed, **kwargs)
|
||||
|
@@ -15,6 +15,6 @@ ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/root/.mujoco/mjpro150/bin
|
||||
COPY . /usr/local/gym/
|
||||
WORKDIR /usr/local/gym/
|
||||
|
||||
RUN pip install .[nomujoco,accept-rom-license] && pip install -r test_requirements.txt
|
||||
RUN pip install .[nomujoco] && pip install -r test_requirements.txt
|
||||
|
||||
ENTRYPOINT ["/usr/local/gym/bin/docker_entrypoint"]
|
||||
|
2
setup.py
2
setup.py
@@ -21,7 +21,7 @@ extras = {
|
||||
}
|
||||
|
||||
# Meta dependency groups.
|
||||
nomujoco_blacklist = set(["mujoco", "robotics", "accept-rom-license"])
|
||||
nomujoco_blacklist = set(["mujoco", "robotics", "accept-rom-license", "atari"])
|
||||
nomujoco_groups = set(extras.keys()) - nomujoco_blacklist
|
||||
|
||||
extras["nomujoco"] = list(
|
||||
|
@@ -10,16 +10,14 @@ def test_env(spec):
|
||||
# threads. However, we probably already can't do multithreading
|
||||
# due to some environments.
|
||||
env1 = spec.make()
|
||||
env1.seed(0)
|
||||
initial_observation1 = env1.reset()
|
||||
initial_observation1 = env1.reset(seed=0)
|
||||
env1.action_space.seed(0)
|
||||
action_samples1 = [env1.action_space.sample() for i in range(4)]
|
||||
step_responses1 = [env1.step(action) for action in action_samples1]
|
||||
env1.close()
|
||||
|
||||
env2 = spec.make()
|
||||
env2.seed(0)
|
||||
initial_observation2 = env2.reset()
|
||||
initial_observation2 = env2.reset(seed=0)
|
||||
env2.action_space.seed(0)
|
||||
action_samples2 = [env2.action_space.sample() for i in range(4)]
|
||||
step_responses2 = [env2.step(action) for action in action_samples2]
|
||||
|
@@ -10,11 +10,8 @@ def verify_environments_match(
|
||||
old_environment = envs.make(old_environment_id)
|
||||
new_environment = envs.make(new_environment_id)
|
||||
|
||||
old_environment.seed(seed)
|
||||
new_environment.seed(seed)
|
||||
|
||||
old_reset_observation = old_environment.reset()
|
||||
new_reset_observation = new_environment.reset()
|
||||
old_reset_observation = old_environment.reset(seed=seed)
|
||||
new_reset_observation = new_environment.reset(seed=seed)
|
||||
|
||||
np.testing.assert_allclose(old_reset_observation, new_reset_observation)
|
||||
|
||||
|
@@ -370,7 +370,7 @@ def test_seed_subspace_incorrelated(space):
|
||||
|
||||
space.seed(0)
|
||||
states = [
|
||||
convert_sample_hashable(subspace.np_random.get_state())
|
||||
convert_sample_hashable(subspace.np_random.bit_generator.state)
|
||||
for subspace in subspaces
|
||||
]
|
||||
|
||||
|
@@ -1,3 +1,5 @@
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
@@ -16,7 +18,8 @@ class UnittestEnv(core.Env):
|
||||
observation_space = spaces.Box(low=0, high=255, shape=(64, 64, 3), dtype=np.uint8)
|
||||
action_space = spaces.Discrete(3)
|
||||
|
||||
def reset(self):
|
||||
def reset(self, seed: Optional[int] = None):
|
||||
super().reset(seed=seed)
|
||||
return self.observation_space.sample() # Dummy observation
|
||||
|
||||
def step(self, action):
|
||||
@@ -31,7 +34,8 @@ class UnknownSpacesEnv(core.Env):
|
||||
on external resources), it is not encouraged.
|
||||
"""
|
||||
|
||||
def reset(self):
|
||||
def reset(self, seed: Optional[int] = None):
|
||||
super().reset(seed=seed)
|
||||
self.observation_space = spaces.Box(
|
||||
low=0, high=255, shape=(64, 64, 3), dtype=np.uint8
|
||||
)
|
||||
|
@@ -1,3 +1,5 @@
|
||||
from typing import Optional
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import pytest
|
||||
@@ -16,7 +18,8 @@ class ActionDictTestEnv(gym.Env):
|
||||
done = True
|
||||
return observation, reward, done
|
||||
|
||||
def reset(self):
|
||||
def reset(self, seed: Optional[int] = None):
|
||||
super().reset(seed=seed)
|
||||
return np.array([1.0, 1.5, 0.5])
|
||||
|
||||
def render(self, mode="human"):
|
||||
|
@@ -17,17 +17,14 @@ def test_vector_env_equal(shared_memory):
|
||||
async_env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
||||
sync_env = SyncVectorEnv(env_fns)
|
||||
|
||||
async_env.seed(0)
|
||||
sync_env.seed(0)
|
||||
|
||||
assert async_env.num_envs == sync_env.num_envs
|
||||
assert async_env.observation_space == sync_env.observation_space
|
||||
assert async_env.single_observation_space == sync_env.single_observation_space
|
||||
assert async_env.action_space == sync_env.action_space
|
||||
assert async_env.single_action_space == sync_env.single_action_space
|
||||
|
||||
async_observations = async_env.reset()
|
||||
sync_observations = sync_env.reset()
|
||||
async_observations = async_env.reset(seed=0)
|
||||
sync_observations = sync_env.reset(seed=0)
|
||||
assert np.all(async_observations == sync_observations)
|
||||
|
||||
for _ in range(num_steps):
|
||||
|
@@ -8,7 +8,7 @@ class DummyWrapper(VectorEnvWrapper):
|
||||
self.env = env
|
||||
self.counter = 0
|
||||
|
||||
def reset_async(self):
|
||||
def reset_async(self, **kwargs):
|
||||
super().reset_async()
|
||||
self.counter += 1
|
||||
|
||||
|
@@ -1,3 +1,5 @@
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import gym
|
||||
import time
|
||||
@@ -55,7 +57,8 @@ class UnittestSlowEnv(gym.Env):
|
||||
)
|
||||
self.action_space = Box(low=0.0, high=1.0, shape=(), dtype=np.float32)
|
||||
|
||||
def reset(self):
|
||||
def reset(self, seed: Optional[int] = None):
|
||||
super().reset(seed=seed)
|
||||
if self.slow_reset > 0:
|
||||
time.sleep(self.slow_reset)
|
||||
return self.observation_space.sample()
|
||||
@@ -86,7 +89,8 @@ class CustomSpaceEnv(gym.Env):
|
||||
self.observation_space = CustomSpace()
|
||||
self.action_space = CustomSpace()
|
||||
|
||||
def reset(self):
|
||||
def reset(self, seed: Optional[int] = None):
|
||||
super().reset(seed=seed)
|
||||
return "reset"
|
||||
|
||||
def step(self, action):
|
||||
@@ -98,7 +102,7 @@ class CustomSpaceEnv(gym.Env):
|
||||
def make_env(env_name, seed):
|
||||
def _make():
|
||||
env = gym.make(env_name)
|
||||
env.seed(seed)
|
||||
env.reset(seed=seed)
|
||||
return env
|
||||
|
||||
return _make
|
||||
@@ -107,7 +111,7 @@ def make_env(env_name, seed):
|
||||
def make_slow_env(slow_reset, seed):
|
||||
def _make():
|
||||
env = UnittestSlowEnv(slow_reset=slow_reset)
|
||||
env.seed(seed)
|
||||
env.reset(seed=seed)
|
||||
return env
|
||||
|
||||
return _make
|
||||
@@ -116,7 +120,7 @@ def make_slow_env(slow_reset, seed):
|
||||
def make_custom_space_env(seed):
|
||||
def _make():
|
||||
env = CustomSpaceEnv()
|
||||
env.seed(seed)
|
||||
env.reset(seed=seed)
|
||||
return env
|
||||
|
||||
return _make
|
||||
|
@@ -1,6 +1,7 @@
|
||||
"""Tests for the flatten observation wrapper."""
|
||||
|
||||
from collections import OrderedDict
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
@@ -14,7 +15,8 @@ class FakeEnvironment(gym.Env):
|
||||
def __init__(self, observation_space):
|
||||
self.observation_space = observation_space
|
||||
|
||||
def reset(self):
|
||||
def reset(self, seed: Optional[int] = None):
|
||||
super().reset(seed=seed)
|
||||
self.observation = self.observation_space.sample()
|
||||
return self.observation
|
||||
|
||||
|
@@ -1,5 +1,5 @@
|
||||
"""Tests for the filter observation wrapper."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
@@ -21,7 +21,8 @@ class FakeEnvironment(gym.Env):
|
||||
image_shape = (height, width, 3)
|
||||
return np.zeros(image_shape, dtype=np.uint8)
|
||||
|
||||
def reset(self):
|
||||
def reset(self, seed: Optional[int] = None):
|
||||
super().reset(seed=seed)
|
||||
observation = self.observation_space.sample()
|
||||
return observation
|
||||
|
||||
|
@@ -29,14 +29,10 @@ def test_atari_preprocessing_grayscale(env_fn):
|
||||
noop_max=0,
|
||||
grayscale_newaxis=True,
|
||||
)
|
||||
env1.seed(0)
|
||||
env2.seed(0)
|
||||
env3.seed(0)
|
||||
env4.seed(0)
|
||||
obs1 = env1.reset()
|
||||
obs2 = env2.reset()
|
||||
obs3 = env3.reset()
|
||||
obs4 = env4.reset()
|
||||
obs1 = env1.reset(seed=0)
|
||||
obs2 = env2.reset(seed=0)
|
||||
obs3 = env3.reset(seed=0)
|
||||
obs4 = env4.reset(seed=0)
|
||||
assert env1.observation_space.shape == (210, 160, 3)
|
||||
assert env2.observation_space.shape == (84, 84)
|
||||
assert env3.observation_space.shape == (84, 84, 3)
|
||||
|
@@ -11,11 +11,9 @@ def test_clip_action():
|
||||
wrapped_env = ClipAction(make_env())
|
||||
|
||||
seed = 0
|
||||
env.seed(seed)
|
||||
wrapped_env.seed(seed)
|
||||
|
||||
env.reset()
|
||||
wrapped_env.reset()
|
||||
env.reset(seed=seed)
|
||||
wrapped_env.reset(seed=seed)
|
||||
|
||||
actions = [[0.4], [1.2], [-0.3], [0.0], [-2.5]]
|
||||
for action in actions:
|
||||
|
@@ -1,3 +1,5 @@
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
@@ -22,7 +24,8 @@ class FakeEnvironment(gym.Env):
|
||||
image_shape = (height, width, 3)
|
||||
return np.zeros(image_shape, dtype=np.uint8)
|
||||
|
||||
def reset(self):
|
||||
def reset(self, seed: Optional[int] = None):
|
||||
super().reset(seed=seed)
|
||||
observation = self.observation_space.sample()
|
||||
return observation
|
||||
|
||||
|
@@ -28,17 +28,15 @@ except ImportError:
|
||||
)
|
||||
def test_frame_stack(env_id, num_stack, lz4_compress):
|
||||
env = gym.make(env_id)
|
||||
env.seed(0)
|
||||
shape = env.observation_space.shape
|
||||
env = FrameStack(env, num_stack, lz4_compress)
|
||||
assert env.observation_space.shape == (num_stack,) + shape
|
||||
assert env.observation_space.dtype == env.env.observation_space.dtype
|
||||
|
||||
dup = gym.make(env_id)
|
||||
dup.seed(0)
|
||||
|
||||
obs = env.reset()
|
||||
dup_obs = dup.reset()
|
||||
obs = env.reset(seed=0)
|
||||
dup_obs = dup.reset(seed=0)
|
||||
assert np.allclose(obs[-1], dup_obs)
|
||||
|
||||
for _ in range(num_stack ** 2):
|
||||
|
@@ -21,11 +21,9 @@ def test_gray_scale_observation(env_id, keep_dim):
|
||||
assert rgb_env.observation_space.shape[-1] == 3
|
||||
|
||||
seed = 0
|
||||
gray_env.seed(seed)
|
||||
wrapped_env.seed(seed)
|
||||
|
||||
gray_obs = gray_env.reset()
|
||||
wrapped_obs = wrapped_env.reset()
|
||||
gray_obs = gray_env.reset(seed=seed)
|
||||
wrapped_obs = wrapped_env.reset(seed=seed)
|
||||
|
||||
if keep_dim:
|
||||
assert wrapped_env.observation_space.shape[-1] == 1
|
||||
|
@@ -1,3 +1,5 @@
|
||||
from typing import Optional
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
from numpy.testing import assert_almost_equal
|
||||
@@ -21,7 +23,8 @@ class DummyRewardEnv(gym.Env):
|
||||
self.t += 1
|
||||
return np.array([self.t]), self.t, self.t == len(self.returned_rewards), {}
|
||||
|
||||
def reset(self):
|
||||
def reset(self, seed: Optional[int] = None):
|
||||
super().reset(seed=seed)
|
||||
self.t = self.return_reward_idx
|
||||
return np.array([self.t])
|
||||
|
||||
|
@@ -1,5 +1,5 @@
|
||||
"""Tests for the pixel observation wrapper."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
@@ -19,7 +19,8 @@ class FakeEnvironment(gym.Env):
|
||||
image_shape = (height, width, 3)
|
||||
return np.zeros(image_shape, dtype=np.uint8)
|
||||
|
||||
def reset(self):
|
||||
def reset(self, seed: Optional[int] = None):
|
||||
super().reset(seed=seed)
|
||||
observation = self.observation_space.sample()
|
||||
return observation
|
||||
|
||||
|
@@ -16,11 +16,9 @@ def test_rescale_action():
|
||||
wrapped_env = RescaleAction(gym.make("Pendulum-v1"), -1, 1)
|
||||
|
||||
seed = 0
|
||||
env.seed(seed)
|
||||
wrapped_env.seed(seed)
|
||||
|
||||
obs = env.reset()
|
||||
wrapped_obs = wrapped_env.reset()
|
||||
obs = env.reset(seed=seed)
|
||||
wrapped_obs = wrapped_env.reset(seed=seed)
|
||||
assert np.allclose(obs, wrapped_obs)
|
||||
|
||||
obs, reward, _, _ = env.step([1.5])
|
||||
|
@@ -14,11 +14,8 @@ def test_transform_observation(env_id):
|
||||
gym.make(env_id), lambda obs: affine_transform(obs)
|
||||
)
|
||||
|
||||
env.seed(0)
|
||||
wrapped_env.seed(0)
|
||||
|
||||
obs = env.reset()
|
||||
wrapped_obs = wrapped_env.reset()
|
||||
obs = env.reset(seed=0)
|
||||
wrapped_obs = wrapped_env.reset(seed=0)
|
||||
assert np.allclose(wrapped_obs, affine_transform(obs))
|
||||
|
||||
action = env.action_space.sample()
|
||||
|
@@ -15,10 +15,8 @@ def test_transform_reward(env_id):
|
||||
wrapped_env = TransformReward(gym.make(env_id), lambda r: scale * r)
|
||||
action = env.action_space.sample()
|
||||
|
||||
env.seed(0)
|
||||
env.reset()
|
||||
wrapped_env.seed(0)
|
||||
wrapped_env.reset()
|
||||
env.reset(seed=0)
|
||||
wrapped_env.reset(seed=0)
|
||||
|
||||
_, reward, _, _ = env.step(action)
|
||||
_, wrapped_reward, _, _ = wrapped_env.step(action)
|
||||
@@ -33,10 +31,8 @@ def test_transform_reward(env_id):
|
||||
wrapped_env = TransformReward(gym.make(env_id), lambda r: np.clip(r, min_r, max_r))
|
||||
action = env.action_space.sample()
|
||||
|
||||
env.seed(0)
|
||||
env.reset()
|
||||
wrapped_env.seed(0)
|
||||
wrapped_env.reset()
|
||||
env.reset(seed=0)
|
||||
wrapped_env.reset(seed=0)
|
||||
|
||||
_, reward, _, _ = env.step(action)
|
||||
_, wrapped_reward, _, _ = wrapped_env.step(action)
|
||||
@@ -49,10 +45,8 @@ def test_transform_reward(env_id):
|
||||
env = gym.make(env_id)
|
||||
wrapped_env = TransformReward(gym.make(env_id), lambda r: np.sign(r))
|
||||
|
||||
env.seed(0)
|
||||
env.reset()
|
||||
wrapped_env.seed(0)
|
||||
wrapped_env.reset()
|
||||
env.reset(seed=0)
|
||||
wrapped_env.reset(seed=0)
|
||||
|
||||
for _ in range(1000):
|
||||
action = env.action_space.sample()
|
||||
|
Reference in New Issue
Block a user