Mirror of https://github.com/Farama-Foundation/Gymnasium.git (synced 2025-08-01 06:07:08 +00:00)
Seeding update (#2422)
* Ditch most of the seeding.py and replace np_random with the numpy default_rng. Let's see if tests pass
* Updated a bunch of RNG calls from the RandomState API to Generator API
* black; didn't expect that, did ya?
* Undo a typo
* blaaack
* More typo fixes
* Fixed setting/getting state in multidiscrete spaces
* Fix typo, fix a test to work with the new sampling
* Correctly (?) pass the randomly generated seed if np_random is called with None as seed
* Convert the Discrete sample to a python int (as opposed to np.int64)
* Remove some redundant imports
* First version of the compatibility layer for old-style RNG. Mainly to trigger tests.
* Removed redundant f-strings
* Style fixes, removing unused imports
* Try to make tests pass by removing atari from the dockerfile
* Try to make tests pass by removing atari from the setup
* Try to make tests pass by removing atari from the setup
* Try to make tests pass by removing atari from the setup
* First attempt at deprecating `env.seed` and supporting `env.reset(seed=seed)` instead. Tests should hopefully pass but throw up a million warnings.
* black; didn't expect that, didya?
* Rename the reset parameter in VecEnvs back to `seed`
* Updated tests to use the new seeding method
* Removed a bunch of old `seed` calls. Fixed a bug in AsyncVectorEnv
* Stop Discrete envs from doing part of the setup (and using the randomness) in init (as opposed to reset)
* Add explicit seed to wrappers reset
* Remove an accidental return
* Re-add some legacy functions with a warning.
* Use deprecation instead of regular warnings for the newly deprecated methods/functions
Committed by GitHub
Parent: b84b69c872
Commit: c364506710
@@ -50,6 +50,7 @@
 * `gym-foo/gym_foo/envs/foo_env.py` should look something like:
   ```python
+  from typing import Optional
   import gym
   from gym import error, spaces, utils
   from gym.utils import seeding
@@ -61,7 +62,8 @@
     ...
     def step(self, action):
       ...
-    def reset(self):
+    def reset(self, seed: Optional[int] = None):
+      super().reset(seed=seed)
       ...
     def render(self, mode='human'):
       ...
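As a rough sketch of what the documented template amounts to after this change (everything here is illustrative: `FooEnv`, the spaces, and the returned values are placeholders, not part of the diff):

```python
from typing import Optional

import gym
from gym import spaces


class FooEnv(gym.Env):
    """Minimal sketch of a custom env using the new reset-based seeding."""

    def __init__(self):
        self.observation_space = spaces.Discrete(2)
        self.action_space = spaces.Discrete(2)

    def reset(self, seed: Optional[int] = None):
        # The base-class reset now creates/reseeds self.np_random for us.
        super().reset(seed=seed)
        return 0

    def step(self, action):
        # Draw randomness from the env's own generator, never from global numpy.
        obs = int(self.np_random.integers(0, 2))
        return obs, 0.0, False, {}
```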
gym/core.py
@@ -1,8 +1,10 @@
 from abc import abstractmethod
+from typing import Optional

 import gym
 from gym import error
-from gym.utils import closer
+from gym.utils import closer, seeding
+from gym.logger import deprecation


 class Env:
@@ -38,6 +40,9 @@ class Env:
     action_space = None
     observation_space = None

+    # Created
+    np_random = None
+
     @abstractmethod
     def step(self, action):
         """Run one timestep of the environment's dynamics. When end of
@@ -58,7 +63,7 @@ class Env:
         raise NotImplementedError

     @abstractmethod
-    def reset(self):
+    def reset(self, seed: Optional[int] = None):
         """Resets the environment to an initial state and returns an initial
         observation.

@@ -71,7 +76,9 @@ class Env:
         Returns:
             observation (object): the initial observation.
         """
-        raise NotImplementedError
+        # Initialize the RNG if it's the first reset, or if the seed is manually passed
+        if seed is not None or self.np_random is None:
+            self.np_random, seed = seeding.np_random(seed)

     @abstractmethod
     def render(self, mode="human"):
@@ -136,7 +143,12 @@ class Env:
           'seed'. Often, the main seed equals the provided 'seed', but
           this won't be true if seed=None, for example.
         """
-        return
+        deprecation(
+            "Function `env.seed(seed)` is marked as deprecated and will be removed in the future. "
+            "Please use `env.reset(seed=seed) instead."
+        )
+        self.np_random, seed = seeding.np_random(seed)
+        return [seed]

     @property
     def unwrapped(self):
@@ -173,7 +185,8 @@ class GoalEnv(Env):
         actual observations of the environment as per usual.
         """

-    def reset(self):
+    def reset(self, seed: Optional[int] = None):
+        super().reset(seed=seed)
         # Enforce that each GoalEnv uses a Goal-compatible observation space.
         if not isinstance(self.observation_space, gym.spaces.Dict):
             raise error.Error(
@@ -286,8 +299,8 @@ class Wrapper(Env):
     def step(self, action):
         return self.env.step(action)

-    def reset(self, **kwargs):
-        return self.env.reset(**kwargs)
+    def reset(self, seed: Optional[int] = None, **kwargs):
+        return self.env.reset(seed=seed, **kwargs)

     def render(self, mode="human", **kwargs):
         return self.env.render(mode, **kwargs)
@@ -313,8 +326,8 @@ class Wrapper(Env):


 class ObservationWrapper(Wrapper):
-    def reset(self, **kwargs):
-        observation = self.env.reset(**kwargs)
+    def reset(self, seed: Optional[int] = None, **kwargs):
+        observation = self.env.reset(seed=seed, **kwargs)
         return self.observation(observation)

     def step(self, action):
@@ -327,8 +340,8 @@ class ObservationWrapper(Wrapper):


 class RewardWrapper(Wrapper):
-    def reset(self, **kwargs):
-        return self.env.reset(**kwargs)
+    def reset(self, seed: Optional[int] = None, **kwargs):
+        return self.env.reset(seed=seed, **kwargs)

     def step(self, action):
         observation, reward, done, info = self.env.step(action)
@@ -340,8 +353,8 @@ class RewardWrapper(Wrapper):


 class ActionWrapper(Wrapper):
-    def reset(self, **kwargs):
-        return self.env.reset(**kwargs)
+    def reset(self, seed: Optional[int] = None, **kwargs):
+        return self.env.reset(seed=seed, **kwargs)

     def step(self, action):
         return self.env.step(self.action(action))
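A hedged usage sketch of the seeding flow introduced in `gym/core.py` above (the env id is illustrative): the RNG is created lazily on the first `reset`, reproducibility comes from passing `seed` to `reset`, and `env.seed()` only survives as a deprecated shim.

```python
import gym

env = gym.make("CartPole-v1")

# New style: seed once through reset; later resets reuse the same generator.
first_obs = env.reset(seed=42)
later_obs = env.reset()        # does NOT re-seed, continues the RNG stream

# Re-seeding with the same value reproduces the same initial state.
same_obs = env.reset(seed=42)

# Old style still works for now, but emits a deprecation warning:
# env.seed(42); env.reset()
```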
@@ -1,5 +1,6 @@
 import sys
 import math
+from typing import Optional

 import numpy as np
 import Box2D
@@ -122,7 +123,6 @@ class BipedalWalker(gym.Env, EzPickle):

     def __init__(self):
         EzPickle.__init__(self)
-        self.seed()
         self.viewer = None

         self.world = Box2D.b2World()
@@ -149,10 +149,6 @@ class BipedalWalker(gym.Env, EzPickle):
         )
         self.observation_space = spaces.Box(-high, high)

-    def seed(self, seed=None):
-        self.np_random, seed = seeding.np_random(seed)
-        return [seed]
-
     def _destroy(self):
         if not self.terrain:
             return
@@ -188,7 +184,7 @@ class BipedalWalker(gym.Env, EzPickle):
                 y += velocity

             elif state == PIT and oneshot:
-                counter = self.np_random.randint(3, 5)
+                counter = self.np_random.integers(3, 5)
                 poly = [
                     (x, y),
                     (x + TERRAIN_STEP, y),
@@ -215,7 +211,7 @@ class BipedalWalker(gym.Env, EzPickle):
                     y -= 4 * TERRAIN_STEP

             elif state == STUMP and oneshot:
-                counter = self.np_random.randint(1, 3)
+                counter = self.np_random.integers(1, 3)
                 poly = [
                     (x, y),
                     (x + counter * TERRAIN_STEP, y),
@@ -228,9 +224,9 @@ class BipedalWalker(gym.Env, EzPickle):
                 self.terrain.append(t)

             elif state == STAIRS and oneshot:
-                stair_height = +1 if self.np_random.rand() > 0.5 else -1
-                stair_width = self.np_random.randint(4, 5)
-                stair_steps = self.np_random.randint(3, 5)
+                stair_height = +1 if self.np_random.random() > 0.5 else -1
+                stair_width = self.np_random.integers(4, 5)
+                stair_steps = self.np_random.integers(3, 5)
                 original_y = y
                 for s in range(stair_steps):
                     poly = [
@@ -266,9 +262,9 @@ class BipedalWalker(gym.Env, EzPickle):
             self.terrain_y.append(y)
             counter -= 1
             if counter == 0:
-                counter = self.np_random.randint(TERRAIN_GRASS / 2, TERRAIN_GRASS)
+                counter = self.np_random.integers(TERRAIN_GRASS / 2, TERRAIN_GRASS)
                 if state == GRASS and hardcore:
-                    state = self.np_random.randint(1, _STATES_)
+                    state = self.np_random.integers(1, _STATES_)
                     oneshot = True
                 else:
                     state = GRASS
@@ -312,7 +308,8 @@ class BipedalWalker(gym.Env, EzPickle):
             x2 = max(p[0] for p in poly)
             self.cloud_poly.append((poly, x1, x2))

-    def reset(self):
+    def reset(self, seed: Optional[int] = None):
+        super().reset(seed=seed)
         self._destroy()
         self.world.contactListener_bug_workaround = ContactDetector(self)
         self.world.contactListener = self.world.contactListener_bug_workaround
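The environment files above and below repeat the same mechanical rename from the legacy `np.random.RandomState` API to the `numpy.random.Generator` API. A rough side-by-side, assuming the bounds semantics are unchanged (`integers` keeps the half-open interval of `randint`):

```python
import numpy as np

rng = np.random.default_rng(0)  # stand-in for env.np_random after this change

# legacy RandomState call      ->  Generator equivalent used in this commit
# rng.rand()                   ->  rng.random()
# rng.randn(n)                 ->  rng.standard_normal(n)
# rng.randint(low, high)       ->  rng.integers(low, high)   # high still exclusive
# rng.random_sample(shape)     ->  rng.random(shape)

x = rng.random()              # uniform float in [0, 1)
v = rng.standard_normal(3)    # three standard-normal samples
k = rng.integers(3, 5)        # one of {3, 4}
```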
@@ -32,6 +32,8 @@ Created by Oleg Klimov. Licensed on the same terms as the rest of OpenAI Gym.
 """
 import sys
 import math
+from typing import Optional
+
 import numpy as np

 import Box2D
@@ -121,7 +123,6 @@ class CarRacing(gym.Env, EzPickle):

     def __init__(self, verbose=1):
         EzPickle.__init__(self)
-        self.seed()
         self.contactListener_keepref = FrictionDetector(self)
         self.world = Box2D.b2World((0, 0), contactListener=self.contactListener_keepref)
         self.viewer = None
@@ -145,10 +146,6 @@ class CarRacing(gym.Env, EzPickle):
             low=0, high=255, shape=(STATE_H, STATE_W, 3), dtype=np.uint8
         )

-    def seed(self, seed=None):
-        self.np_random, seed = seeding.np_random(seed)
-        return [seed]
-
     def _destroy(self):
         if not self.road:
             return
@@ -343,7 +340,8 @@ class CarRacing(gym.Env, EzPickle):
         self.track = track
         return True

-    def reset(self):
+    def reset(self, seed: Optional[int] = None):
+        super().reset(seed=seed)
         self._destroy()
         self.reward = 0.0
         self.prev_reward = 0.0
@@ -28,6 +28,8 @@ Created by Oleg Klimov. Licensed on the same terms as the rest of OpenAI Gym.

 import math
 import sys
+from typing import Optional
+
 import numpy as np

 import Box2D
@@ -93,7 +95,6 @@ class LunarLander(gym.Env, EzPickle):

     def __init__(self):
         EzPickle.__init__(self)
-        self.seed()
         self.viewer = None

         self.world = Box2D.b2World()
@@ -117,10 +118,6 @@ class LunarLander(gym.Env, EzPickle):
         # Nop, fire left engine, main engine, right engine
         self.action_space = spaces.Discrete(4)

-    def seed(self, seed=None):
-        self.np_random, seed = seeding.np_random(seed)
-        return [seed]
-
     def _destroy(self):
         if not self.moon:
             return
@@ -133,7 +130,8 @@ class LunarLander(gym.Env, EzPickle):
         self.world.DestroyBody(self.legs[0])
         self.world.DestroyBody(self.legs[1])

-    def reset(self):
+    def reset(self, seed: Optional[int] = None):
+        super().reset(seed=seed)
         self._destroy()
         self.world.contactListener_keepref = ContactDetector(self)
         self.world.contactListener = self.world.contactListener_keepref
@@ -504,10 +502,9 @@ def heuristic(env, s):


 def demo_heuristic_lander(env, seed=None, render=False):
-    env.seed(seed)
     total_reward = 0
     steps = 0
-    s = env.reset()
+    s = env.reset(seed=seed)
     while True:
         a = heuristic(env, s)
         s, r, done, info = env.step(a)
@@ -1,4 +1,6 @@
 """classic Acrobot task"""
+from typing import Optional
+
 import numpy as np
 from numpy import sin, cos, pi

@@ -94,13 +96,9 @@ class AcrobotEnv(core.Env):
         self.observation_space = spaces.Box(low=low, high=high, dtype=np.float32)
         self.action_space = spaces.Discrete(3)
         self.state = None
-        self.seed()
-
-    def seed(self, seed=None):
-        self.np_random, seed = seeding.np_random(seed)
-        return [seed]

-    def reset(self):
+    def reset(self, seed: Optional[int] = None):
+        super().reset(seed=seed)
         self.state = self.np_random.uniform(low=-0.1, high=0.1, size=(4,)).astype(
             np.float32
         )
@@ -5,6 +5,8 @@ permalink: https://perma.cc/C9ZM-652R
 """

 import math
+from typing import Optional
+
 import gym
 from gym import spaces, logger
 from gym.utils import seeding
@@ -90,16 +92,11 @@ class CartPoleEnv(gym.Env):
         self.action_space = spaces.Discrete(2)
         self.observation_space = spaces.Box(-high, high, dtype=np.float32)

-        self.seed()
         self.viewer = None
         self.state = None

         self.steps_beyond_done = None

-    def seed(self, seed=None):
-        self.np_random, seed = seeding.np_random(seed)
-        return [seed]
-
     def step(self, action):
         err_msg = f"{action!r} ({type(action)}) invalid"
         assert self.action_space.contains(action), err_msg
@@ -158,7 +155,8 @@ class CartPoleEnv(gym.Env):

         return np.array(self.state, dtype=np.float32), reward, done, {}

-    def reset(self):
+    def reset(self, seed: Optional[int] = None):
+        super().reset(seed=seed)
         self.state = self.np_random.uniform(low=-0.05, high=0.05, size=(4,))
         self.steps_beyond_done = None
         return np.array(self.state, dtype=np.float32)
@@ -14,6 +14,7 @@ permalink: https://perma.cc/6Z2N-PFWC
 """

 import math
+from typing import Optional

 import numpy as np

@@ -83,12 +84,6 @@ class Continuous_MountainCarEnv(gym.Env):
             low=self.low_state, high=self.high_state, dtype=np.float32
         )

-        self.seed()
-
-    def seed(self, seed=None):
-        self.np_random, seed = seeding.np_random(seed)
-        return [seed]
-
     def step(self, action):

         position = self.state[0]
@@ -119,7 +114,8 @@ class Continuous_MountainCarEnv(gym.Env):
         self.state = np.array([position, velocity], dtype=np.float32)
         return self.state, reward, done, {}

-    def reset(self):
+    def reset(self, seed: Optional[int] = None):
+        super().reset(seed=seed)
         self.state = np.array([self.np_random.uniform(low=-0.6, high=-0.4), 0])
         return np.array(self.state, dtype=np.float32)

@@ -3,6 +3,7 @@ http://incompleteideas.net/MountainCar/MountainCar1.cp
 permalink: https://perma.cc/6Z2N-PFWC
 """
 import math
+from typing import Optional

 import numpy as np

@@ -72,12 +73,6 @@ class MountainCarEnv(gym.Env):
         self.action_space = spaces.Discrete(3)
         self.observation_space = spaces.Box(self.low, self.high, dtype=np.float32)

-        self.seed()
-
-    def seed(self, seed=None):
-        self.np_random, seed = seeding.np_random(seed)
-        return [seed]
-
     def step(self, action):
         assert self.action_space.contains(
             action
@@ -97,7 +92,8 @@ class MountainCarEnv(gym.Env):
         self.state = (position, velocity)
         return np.array(self.state, dtype=np.float32), reward, done, {}

-    def reset(self):
+    def reset(self, seed: Optional[int] = None):
+        super().reset(seed=seed)
         self.state = np.array([self.np_random.uniform(low=-0.6, high=-0.4), 0])
         return np.array(self.state, dtype=np.float32)

@@ -1,3 +1,5 @@
+from typing import Optional
+
 import gym
 from gym import spaces
 from gym.utils import seeding
@@ -24,12 +26,6 @@ class PendulumEnv(gym.Env):
         )
         self.observation_space = spaces.Box(low=-high, high=high, dtype=np.float32)

-        self.seed()
-
-    def seed(self, seed=None):
-        self.np_random, seed = seeding.np_random(seed)
-        return [seed]
-
     def step(self, u):
         th, thdot = self.state  # th := theta

@@ -49,7 +45,8 @@ class PendulumEnv(gym.Env):
         self.state = np.array([newth, newthdot])
         return self._get_obs(), -costs, False, {}

-    def reset(self):
+    def reset(self, seed: Optional[int] = None):
+        super().reset(seed=seed)
         high = np.array([np.pi, 1])
         self.state = self.np_random.uniform(low=-high, high=high)
         self.last_u = None
@@ -48,7 +48,7 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
         qpos = self.init_qpos + self.np_random.uniform(
             size=self.model.nq, low=-0.1, high=0.1
         )
-        qvel = self.init_qvel + self.np_random.randn(self.model.nv) * 0.1
+        qvel = self.init_qvel + self.np_random.standard_normal(self.model.nv) * 0.1
         self.set_state(qpos, qvel)
         return self._get_obs()

@@ -131,8 +131,9 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
         qpos = self.init_qpos + self.np_random.uniform(
             low=noise_low, high=noise_high, size=self.model.nq
         )
-        qvel = self.init_qvel + self._reset_noise_scale * self.np_random.randn(
-            self.model.nv
+        qvel = (
+            self.init_qvel
+            + self._reset_noise_scale * self.np_random.standard_normal(self.model.nv)
         )
         self.set_state(qpos, qvel)

@@ -31,7 +31,7 @@ class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):
         qpos = self.init_qpos + self.np_random.uniform(
             low=-0.1, high=0.1, size=self.model.nq
         )
-        qvel = self.init_qvel + self.np_random.randn(self.model.nv) * 0.1
+        qvel = self.init_qvel + self.np_random.standard_normal(self.model.nv) * 0.1
         self.set_state(qpos, qvel)
         return self._get_obs()

@@ -74,8 +74,9 @@ class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):
         qpos = self.init_qpos + self.np_random.uniform(
             low=noise_low, high=noise_high, size=self.model.nq
         )
-        qvel = self.init_qvel + self._reset_noise_scale * self.np_random.randn(
-            self.model.nv
+        qvel = (
+            self.init_qvel
+            + self._reset_noise_scale * self.np_random.standard_normal(self.model.nv)
         )

         self.set_state(qpos, qvel)
@@ -35,7 +35,7 @@ class InvertedDoublePendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle):
         self.set_state(
             self.init_qpos
             + self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq),
-            self.init_qvel + self.np_random.randn(self.model.nv) * 0.1,
+            self.init_qvel + self.np_random.standard_normal(self.model.nv) * 0.1,
         )
         return self._get_obs()

@@ -1,6 +1,6 @@
 from collections import OrderedDict
 import os
+from typing import Optional

 from gym import error, spaces
 from gym.utils import seeding
@@ -73,8 +73,6 @@ class MujocoEnv(gym.Env):

         self._set_observation_space(observation)

-        self.seed()
-
     def _set_action_space(self):
         bounds = self.model.actuator_ctrlrange.copy().astype(np.float32)
         low, high = bounds.T
@@ -85,10 +83,6 @@ class MujocoEnv(gym.Env):
         self.observation_space = convert_observation_to_space(observation)
         return self.observation_space

-    def seed(self, seed=None):
-        self.np_random, seed = seeding.np_random(seed)
-        return [seed]
-
     # methods to override:
     # ----------------------------

@@ -109,7 +103,8 @@ class MujocoEnv(gym.Env):

     # -----------------------------

-    def reset(self):
+    def reset(self, seed: Optional[int] = None):
+        super().reset(seed=seed)
         self.sim.reset()
         ob = self.reset_model()
         return ob
@@ -183,7 +183,7 @@ class ManipulateEnv(hand_env.HandEnv):
                 axis = np.array([0.0, 0.0, 1.0])
                 z_quat = quat_from_angle_and_axis(angle, axis)
                 parallel_quat = self.parallel_quats[
-                    self.np_random.randint(len(self.parallel_quats))
+                    self.np_random.integers(len(self.parallel_quats))
                 ]
                 offset_quat = rotations.quat_mul(z_quat, parallel_quat)
                 initial_quat = rotations.quat_mul(initial_quat, offset_quat)
@@ -254,7 +254,7 @@ class ManipulateEnv(hand_env.HandEnv):
                 axis = np.array([0.0, 0.0, 1.0])
                 target_quat = quat_from_angle_and_axis(angle, axis)
                 parallel_quat = self.parallel_quats[
-                    self.np_random.randint(len(self.parallel_quats))
+                    self.np_random.integers(len(self.parallel_quats))
                 ]
                 target_quat = rotations.quat_mul(target_quat, parallel_quat)
             elif self.target_rotation == "xyz":
@@ -1,5 +1,7 @@
 import os
 import copy
+from typing import Optional
+
 import numpy as np

 import gym
@@ -37,7 +39,6 @@ class RobotEnv(gym.GoalEnv):
             "video.frames_per_second": int(np.round(1.0 / self.dt)),
         }

-        self.seed()
         self._env_setup(initial_qpos=initial_qpos)
         self.initial_state = copy.deepcopy(self.sim.get_state())

@@ -65,10 +66,6 @@ class RobotEnv(gym.GoalEnv):
     # Env methods
     # ----------------------------

-    def seed(self, seed=None):
-        self.np_random, seed = seeding.np_random(seed)
-        return [seed]
-
     def step(self, action):
         if np.array(action).shape != self.action_space.shape:
             raise ValueError("Action dimension mismatch")
@@ -86,13 +83,13 @@ class RobotEnv(gym.GoalEnv):
         reward = self.compute_reward(obs["achieved_goal"], self.goal, info)
         return obs, reward, done, info

-    def reset(self):
+    def reset(self, seed: Optional[int] = None):
         # Attempt to reset the simulator. Since we randomize initial conditions, it
         # is possible to get into a state with numerical issues (e.g. due to penetration or
         # Gimbel lock) or we may not achieve an initial condition (e.g. an object is within the hand).
         # In this case, we just keep randomizing until we eventually achieve a valid initial
         # configuration.
-        super().reset()
+        super().reset(seed=seed)
         did_reset_sim = False
         while not did_reset_sim:
             did_reset_sim = self._reset_sim()
@@ -1,3 +1,5 @@
+from typing import Optional
+
 import gym
 from gym import spaces
 from gym.utils import seeding
@@ -78,7 +80,6 @@ class BlackjackEnv(gym.Env):
         self.observation_space = spaces.Tuple(
             (spaces.Discrete(32), spaces.Discrete(11), spaces.Discrete(2))
         )
-        self.seed()

         # Flag to payout 1.5 on a "natural" blackjack win, like casino rules
         # Ref: http://www.bicyclecards.com/how-to-play/blackjack/
@@ -87,10 +88,6 @@ class BlackjackEnv(gym.Env):
         # Flag for full agreement with the (Sutton and Barto, 2018) definition. Overrides self.natural
         self.sab = sab

-    def seed(self, seed=None):
-        self.np_random, seed = seeding.np_random(seed)
-        return [seed]
-
     def step(self, action):
         assert self.action_space.contains(action)
         if action:  # hit: add a card to players hand and return
@@ -122,7 +119,8 @@ class BlackjackEnv(gym.Env):
     def _get_obs(self):
         return (sum_hand(self.player), self.dealer[0], usable_ace(self.player))

-    def reset(self):
+    def reset(self, seed: Optional[int] = None):
+        super().reset(seed=seed)
         self.dealer = draw_hand(self.np_random)
         self.player = draw_hand(self.np_random)
         return self._get_obs()
@@ -1,3 +1,5 @@
+from typing import Optional
+
 import numpy as np

 from gym import Env, spaces
@@ -11,7 +13,7 @@ def categorical_sample(prob_n, np_random):
     """
     prob_n = np.asarray(prob_n)
     csprob_n = np.cumsum(prob_n)
-    return (csprob_n > np_random.rand()).argmax()
+    return (csprob_n > np_random.random()).argmax()


 class DiscreteEnv(Env):
@@ -40,14 +42,8 @@ class DiscreteEnv(Env):
         self.action_space = spaces.Discrete(self.nA)
         self.observation_space = spaces.Discrete(self.nS)

-        self.seed()
-        self.s = categorical_sample(self.isd, self.np_random)
-
-    def seed(self, seed=None):
-        self.np_random, seed = seeding.np_random(seed)
-        return [seed]
-
-    def reset(self):
+    def reset(self, seed: Optional[int] = None):
+        super().reset(seed=seed)
         self.s = categorical_sample(self.isd, self.np_random)
         self.lastaction = None
         return int(self.s)
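The `categorical_sample` helper changed above draws a single uniform number and inverts the CDF built with a cumulative sum. A small self-contained sketch of the same idea (the probabilities below are illustrative only):

```python
import numpy as np


def categorical_sample(prob_n, np_random):
    # Same technique as gym's helper: compare one uniform draw against the CDF.
    prob_n = np.asarray(prob_n)
    csprob_n = np.cumsum(prob_n)
    return int((csprob_n > np_random.random()).argmax())


rng = np.random.default_rng(7)
draws = [categorical_sample([0.2, 0.5, 0.3], rng) for _ in range(1000)]
# np.bincount(draws) is roughly proportional to [0.2, 0.5, 0.3]
```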
@@ -1,3 +1,5 @@
+from typing import Optional
+
 import numpy as np
 import gym
 from gym import spaces
@@ -52,7 +54,6 @@ class CubeCrash(gym.Env):
     use_random_colors = False  # Makes env too hard

     def __init__(self):
-        self.seed()
         self.viewer = None

         self.observation_space = spaces.Box(
@@ -62,23 +63,20 @@ class CubeCrash(gym.Env):

         self.reset()

-    def seed(self, seed=None):
-        self.np_random, seed = seeding.np_random(seed)
-        return [seed]
-
     def random_color(self):
         return np.array(
             [
-                self.np_random.randint(low=0, high=255),
-                self.np_random.randint(low=0, high=255),
-                self.np_random.randint(low=0, high=255),
+                self.np_random.integers(low=0, high=255),
+                self.np_random.integers(low=0, high=255),
+                self.np_random.integers(low=0, high=255),
             ]
         ).astype("uint8")

-    def reset(self):
-        self.cube_x = self.np_random.randint(low=3, high=FIELD_W - 3)
-        self.cube_y = self.np_random.randint(low=3, high=FIELD_H // 6)
-        self.hole_x = self.np_random.randint(low=HOLE_WIDTH, high=FIELD_W - HOLE_WIDTH)
+    def reset(self, seed: Optional[int] = None):
+        super().reset(seed=seed)
+        self.cube_x = self.np_random.integers(low=3, high=FIELD_W - 3)
+        self.cube_y = self.np_random.integers(low=3, high=FIELD_H // 6)
+        self.hole_x = self.np_random.integers(low=HOLE_WIDTH, high=FIELD_W - HOLE_WIDTH)
         self.bg_color = self.random_color() if self.use_random_colors else color_black
         self.potential = None
         self.step_n = 0
@@ -95,6 +93,7 @@ class CubeCrash(gym.Env):
             ):
                 continue
             break
+
         return self.step(0)[0]

     def step(self, action):
@@ -1,3 +1,5 @@
+from typing import Optional
+
 import numpy as np
 import gym
 from gym import spaces
@@ -60,7 +62,6 @@ class MemorizeDigits(gym.Env):
     use_random_colors = False

     def __init__(self):
-        self.seed()
         self.viewer = None
         self.observation_space = spaces.Box(
             0, 255, (FIELD_H, FIELD_W, 3), dtype=np.uint8
@@ -74,22 +75,19 @@ class MemorizeDigits(gym.Env):
         ]
         self.reset()

-    def seed(self, seed=None):
-        self.np_random, seed = seeding.np_random(seed)
-        return [seed]
-
     def random_color(self):
         return np.array(
             [
-                self.np_random.randint(low=0, high=255),
-                self.np_random.randint(low=0, high=255),
-                self.np_random.randint(low=0, high=255),
+                self.np_random.integers(low=0, high=255),
+                self.np_random.integers(low=0, high=255),
+                self.np_random.integers(low=0, high=255),
             ]
         ).astype("uint8")

-    def reset(self):
-        self.digit_x = self.np_random.randint(low=FIELD_W // 5, high=FIELD_W // 5 * 4)
-        self.digit_y = self.np_random.randint(low=FIELD_H // 5, high=FIELD_H // 5 * 4)
+    def reset(self, seed: Optional[int] = None):
+        super().reset(seed=seed)
+        self.digit_x = self.np_random.integers(low=FIELD_W // 5, high=FIELD_W // 5 * 4)
+        self.digit_y = self.np_random.integers(low=FIELD_H // 5, high=FIELD_H // 5 * 4)
         self.color_bg = self.random_color() if self.use_random_colors else color_black
         self.step_n = 0
         while 1:
@@ -111,8 +109,8 @@ class MemorizeDigits(gym.Env):
             else:
                 if self.digit == action:
                     reward = +1
-            done = self.step_n > 20 and 0 == self.np_random.randint(low=0, high=5)
-            self.digit = self.np_random.randint(low=0, high=10)
+            done = self.step_n > 20 and 0 == self.np_random.integers(low=0, high=5)
+            self.digit = self.np_random.integers(low=0, high=10)
             obs = np.zeros((FIELD_H, FIELD_W, 3), dtype=np.uint8)
             obs[:, :, :] = self.color_bg
             digit_img = np.zeros((6, 6, 3), dtype=np.uint8)
@@ -35,7 +35,7 @@ class MultiBinary(Space):
         super().__init__(input_n, np.int8, seed)

     def sample(self):
-        return self.np_random.randint(low=0, high=2, size=self.n, dtype=self.dtype)
+        return self.np_random.integers(low=0, high=2, size=self.n, dtype=self.dtype)

     def contains(self, x):
         if isinstance(x, list) or isinstance(x, tuple):
@@ -35,9 +35,7 @@ class MultiDiscrete(Space):
         super().__init__(self.nvec.shape, dtype, seed)

     def sample(self):
-        return (self.np_random.random_sample(self.nvec.shape) * self.nvec).astype(
-            self.dtype
-        )
+        return (self.np_random.random(self.nvec.shape) * self.nvec).astype(self.dtype)

     def contains(self, x):
         if isinstance(x, list):
@@ -61,7 +59,7 @@ class MultiDiscrete(Space):
             subspace = Discrete(nvec)
         else:
             subspace = MultiDiscrete(nvec, self.dtype)
-        subspace.np_random.set_state(self.np_random.get_state())  # for reproducibility
+        subspace.np_random.bit_generator.state = self.np_random.bit_generator.state
         return subspace

     def __len__(self):
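Under the Generator API there is no `get_state`/`set_state` pair; state is copied through `bit_generator.state`, which is what the MultiDiscrete change above relies on. A minimal sketch of the technique:

```python
import numpy as np

parent = np.random.default_rng(123)
child = np.random.default_rng()

# Copy the full bit-generator state so both generators produce identical streams.
child.bit_generator.state = parent.bit_generator.state

assert parent.integers(0, 10, size=5).tolist() == child.integers(0, 10, size=5).tolist()
```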
@@ -1,43 +1,118 @@
 import hashlib
-import numpy as np
 import os
-import random as _random
 import struct
-import sys
+from typing import Optional
+
+import numpy as np
+from numpy.random import Generator

 from gym import error
+from gym.logger import deprecation


-def np_random(seed=None):
+def np_random(seed: Optional[int] = None):
     if seed is not None and not (isinstance(seed, int) and 0 <= seed):
         raise error.Error(f"Seed must be a non-negative integer or omitted, not {seed}")

-    seed = create_seed(seed)
-    rng = np.random.RandomState()
-    rng.seed(_int_list_from_bigint(hash_seed(seed)))
+    seed_seq = np.random.SeedSequence(seed)
+    seed = seed_seq.entropy
+    rng = RandomNumberGenerator(np.random.PCG64(seed_seq))
     return rng, seed


+# TODO: Remove this class and make it alias to `Generator` in a future Gym release
+# RandomNumberGenerator = np.random.Generator
+class RandomNumberGenerator(np.random.Generator):
+    def rand(self, *size):
+        deprecation(
+            "Function `rng.rand(*size)` is marked as deprecated "
+            "and will be removed in the future. "
+            "Please use `Generator.random(size)` instead."
+        )
+
+        return self.random(size)
+
+    random_sample = rand
+
+    def randn(self, *size):
+        deprecation(
+            "Function `rng.randn(*size)` is marked as deprecated "
+            "and will be removed in the future. "
+            "Please use `rng.standard_normal(size)` instead."
+        )
+
+        return self.standard_normal(size)
+
+    def randint(self, low, high=None, size=None, dtype=int):
+        deprecation(
+            "Function `rng.randint(low, [high, size, dtype])` is marked as deprecated "
+            "and will be removed in the future. "
+            "Please use `rng.integers(low, [high, size, dtype])` instead."
+        )
+
+        return self.integers(low=low, high=high, size=size, dtype=dtype)
+
+    random_integers = randint
+
+    def get_state(self):
+        deprecation(
+            "Function `rng.get_state()` is marked as deprecated "
+            "and will be removed in the future. "
+            "Please use `rng.bit_generator.state` instead."
+        )
+
+        return self.bit_generator.state
+
+    def set_state(self, state):
+        deprecation(
+            "Function `rng.set_state(state)` is marked as deprecated "
+            "and will be removed in the future. "
+            "Please use `rng.bit_generator.state = state` instead."
+        )
+
+        self.bit_generator.state = state
+
+    def seed(self, seed=None):
+        deprecation(
+            "Function `rng.seed(seed)` is marked as deprecated "
+            "and will be removed in the future. "
+            "Please use `rng, seed = gym.utils.seeding.np_random(seed)` to create a separate generator instead."
+        )
+
+        self.bit_generator.state = type(self.bit_generator)(seed).state
+
+    rand.__doc__ = np.random.rand.__doc__
+    randn.__doc__ = np.random.randn.__doc__
+    randint.__doc__ = np.random.randint.__doc__
+    get_state.__doc__ = np.random.get_state.__doc__
+    set_state.__doc__ = np.random.set_state.__doc__
+    seed.__doc__ = np.random.seed.__doc__
+
+
+RNG = RandomNumberGenerator
+
+# Legacy functions
+
+
 def hash_seed(seed=None, max_bytes=8):
     """Any given evaluation is likely to have many PRNG's active at
     once. (Most commonly, because the environment is running in
     multiple processes.) There's literature indicating that having
     linear correlations between seeds of multiple PRNG's can correlate
     the outputs:

       http://blogs.unity3d.com/2015/01/07/a-primer-on-repeatable-random-numbers/
       http://stackoverflow.com/questions/1554958/how-different-do-random-seeds-need-to-be
       http://dl.acm.org/citation.cfm?id=1276928

     Thus, for sanity we hash the seeds before using them. (This scheme
     is likely not crypto-strength, but it should be good enough to get
     rid of simple correlations.)

     Args:
         seed (Optional[int]): None seeds from an operating system specific randomness source.
         max_bytes: Maximum number of bytes to use in the hashed seed.
     """
+    deprecation(
+        "Function `hash_seed(seed, max_bytes)` is marked as deprecated and will be removed in the future. "
+    )
     if seed is None:
         seed = create_seed(max_bytes=max_bytes)
     hash = hashlib.sha512(str(seed).encode("utf8")).digest()
@@ -48,11 +123,13 @@ def create_seed(a=None, max_bytes=8):
     """Create a strong random seed. Otherwise, Python 2 would seed using
     the system time, which might be non-robust especially in the
     presence of concurrency.

     Args:
         a (Optional[int, str]): None seeds from an operating system specific randomness source.
         max_bytes: Maximum number of bytes to use in the seed.
     """
+    deprecation(
+        "Function `create_seed(a, max_bytes)` is marked as deprecated and will be removed in the future. "
+    )
     # Adapted from https://svn.python.org/projects/python/tags/r32/Lib/random.py
     if a is None:
         a = _bigint_from_bytes(os.urandom(max_bytes))
@@ -70,6 +147,9 @@ def create_seed(a=None, max_bytes=8):

 # TODO: don't hardcode sizeof_int here
 def _bigint_from_bytes(bytes):
+    deprecation(
+        "Function `_bigint_from_bytes(bytes)` is marked as deprecated and will be removed in the future. "
+    )
     sizeof_int = 4
     padding = sizeof_int - len(bytes) % sizeof_int
     bytes += b"\0" * padding
@@ -82,6 +162,9 @@ def _bigint_from_bytes(bytes):


 def _int_list_from_bigint(bigint):
+    deprecation(
+        "Function `_int_list_from_bigint` is marked as deprecated and will be removed in the future. "
+    )
     # Special case 0
     if bigint < 0:
         raise error.Error(f"Seed must be non-negative, not {bigint}")
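A hedged sketch of how the rewritten `gym.utils.seeding.np_random` is meant to be used: it now returns a `Generator`-backed RNG plus the entropy actually used, and the old `RandomState`-style methods only survive as deprecation shims on `RandomNumberGenerator`.

```python
from gym.utils import seeding

rng, used_seed = seeding.np_random(42)

# New-style calls go straight to numpy.random.Generator methods.
sample = rng.integers(0, 10)
noise = rng.standard_normal(4)

# Legacy calls still work through the compatibility shims, but warn:
# rng.randint(0, 10)   -> DeprecationWarning, forwards to rng.integers(0, 10)
# rng.rand()           -> DeprecationWarning, forwards to Generator.random
```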
@@ -1,3 +1,5 @@
+from typing import Optional, Union, List
+
 import numpy as np
 import multiprocessing as mp
 import time
@@ -6,6 +8,7 @@ from enum import Enum
 from copy import deepcopy

 from gym import logger
+from gym.logger import warn
 from gym.vector.vector_env import VectorEnv
 from gym.error import (
     AlreadyPendingCallError,
@@ -187,13 +190,14 @@ class AsyncVectorEnv(VectorEnv):
         self._state = AsyncState.DEFAULT
         self._check_observation_spaces()

-    def seed(self, seeds=None):
+    def seed(self, seed=None):
+        super().seed(seed=seed)
         self._assert_is_running()
-        if seeds is None:
-            seeds = [None for _ in range(self.num_envs)]
-        if isinstance(seeds, int):
-            seeds = [seeds + i for i in range(self.num_envs)]
-        assert len(seeds) == self.num_envs
+        if seed is None:
+            seed = [None for _ in range(self.num_envs)]
+        if isinstance(seed, int):
+            seed = [seed + i for i in range(self.num_envs)]
+        assert len(seed) == self.num_envs

         if self._state != AsyncState.DEFAULT:
             raise AlreadyPendingCallError(
@@ -201,12 +205,12 @@ class AsyncVectorEnv(VectorEnv):
                 self._state.value,
             )

-        for pipe, seed in zip(self.parent_pipes, seeds):
+        for pipe, seed in zip(self.parent_pipes, seed):
             pipe.send(("seed", seed))
         _, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
         self._raise_if_errors(successes)

-    def reset_async(self):
+    def reset_async(self, seed: Optional[Union[int, List[int]]] = None):
         """Send the calls to :obj:`reset` to each sub-environment.

         Raises
@@ -221,24 +225,31 @@ class AsyncVectorEnv(VectorEnv):
             between.
         """
         self._assert_is_running()

+        if seed is None:
+            seed = [None for _ in range(self.num_envs)]
+        if isinstance(seed, int):
+            seed = [seed + i for i in range(self.num_envs)]
+        assert len(seed) == self.num_envs
+
         if self._state != AsyncState.DEFAULT:
             raise AlreadyPendingCallError(
                 f"Calling `reset_async` while waiting for a pending call to `{self._state.value}` to complete",
                 self._state.value,
             )

-        for pipe in self.parent_pipes:
-            pipe.send(("reset", None))
+        for pipe, single_seed in zip(self.parent_pipes, seed):
+            pipe.send(("reset", single_seed))
         self._state = AsyncState.WAITING_RESET

-    def reset_wait(self, timeout=None):
-        """Wait for the calls to :obj:`reset` in each sub-environment to finish.
+    def reset_wait(self, timeout=None, seed: Optional[int] = None):
+        """

         Parameters
         ----------
         timeout : int or float, optional
-            Number of seconds before the call to :meth:`reset_wait` times out.
-            If ``None``, the call to :meth:`reset_wait` never times out.
+            Number of seconds before the call to `reset_wait` times out. If
+            `None`, the call to `reset_wait` never times out.
+        seed: ignored

         Returns
         -------
@@ -486,7 +497,7 @@ def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
     while True:
         command, data = pipe.recv()
         if command == "reset":
-            observation = env.reset()
+            observation = env.reset(data)
             pipe.send((observation, True))
         elif command == "step":
             observation, reward, done, info = env.step(data)
@@ -524,7 +535,7 @@ def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error
     while True:
         command, data = pipe.recv()
         if command == "reset":
-            observation = env.reset()
+            observation = env.reset(data)
             write_to_shared_memory(
                 index, observation, shared_memory, observation_space
             )
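A usage sketch of the vector-env side of this change (env id and counts are illustrative, and it assumes the top-level `VectorEnv.reset` forwards the seed to `reset_async`/`reset_wait` as in this commit): `reset` accepts a single int, which is broadcast across sub-environments, or an explicit list.

```python
import gym

envs = gym.vector.make("CartPole-v1", num_envs=3)

# One int: sub-environment i is reseeded with 42 + i.
obs = envs.reset(seed=42)

# Or give each sub-environment its own seed explicitly.
obs = envs.reset(seed=[7, 11, 13])
```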
@@ -1,7 +1,10 @@
+from typing import List, Union, Optional
+
 import numpy as np
 from copy import deepcopy

 from gym import logger
+from gym.logger import warn
 from gym.vector.vector_env import VectorEnv
 from gym.vector.utils import concatenate, create_empty_array

@@ -72,21 +75,28 @@ class SyncVectorEnv(VectorEnv):
        self._dones = np.zeros((self.num_envs,), dtype=np.bool_)
        self._actions = None

-    def seed(self, seeds=None):
-        if seeds is None:
-            seeds = [None for _ in range(self.num_envs)]
-        if isinstance(seeds, int):
-            seeds = [seeds + i for i in range(self.num_envs)]
-        assert len(seeds) == self.num_envs
+    def seed(self, seed=None):
+        super().seed(seed=seed)
+        if seed is None:
+            seed = [None for _ in range(self.num_envs)]
+        if isinstance(seed, int):
+            seed = [seed + i for i in range(self.num_envs)]
+        assert len(seed) == self.num_envs

-        for env, seed in zip(self.envs, seeds):
-            env.seed(seed)
+        for env, single_seed in zip(self.envs, seed):
+            env.seed(single_seed)

-    def reset_wait(self):
+    def reset_wait(self, seed: Optional[Union[int, List[int]]] = None, **kwargs):
+        if seed is None:
+            seed = [None for _ in range(self.num_envs)]
+        if isinstance(seed, int):
+            seed = [seed + i for i in range(self.num_envs)]
+        assert len(seed) == self.num_envs
+
        self._dones[:] = False
        observations = []
-        for env in self.envs:
-            observation = env.reset()
+        for env, single_seed in zip(self.envs, seed):
+            observation = env.reset(seed=single_seed)
            observations.append(observation)
        self.observations = concatenate(
            observations, self.observations, self.single_observation_space

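`SyncVectorEnv` follows the same convention, but the seed list is consumed in `reset_wait`, where each sub-environment is reset with its own entry. A short sketch, again assuming `CartPole-v1` is registered:

```python
import gym
from gym.vector import SyncVectorEnv

envs = SyncVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(2)])

# Each sub-environment is reset with its own entry: env.reset(seed=7) and env.reset(seed=8).
batch_obs = envs.reset(seed=[7, 8])

# A list of the wrong length trips the `assert len(seed) == self.num_envs` check.
envs.close()
```
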
@@ -1,4 +1,7 @@
+from typing import Optional, Union, List
+
 import gym
+from gym.logger import warn, deprecation
 from gym.spaces import Tuple
 from gym.vector.utils.spaces import batch_space

@@ -43,13 +46,13 @@ class VectorEnv(gym.Env):
        self.single_observation_space = observation_space
        self.single_action_space = action_space

-    def reset_async(self):
+    def reset_async(self, seed: Optional[Union[int, List[int]]] = None):
        pass

-    def reset_wait(self, **kwargs):
+    def reset_wait(self, seed: Optional[Union[int, List[int]]] = None, **kwargs):
        raise NotImplementedError()

-    def reset(self):
+    def reset(self, seed: Optional[Union[int, List[int]]] = None):
        r"""Reset all sub-environments and return a batch of initial observations.

        Returns
@@ -57,8 +60,8 @@ class VectorEnv(gym.Env):
        element of :attr:`observation_space`
            A batch of observations from the vectorized environment.
        """
-        self.reset_async()
-        return self.reset_wait()
+        self.reset_async(seed=seed)
+        return self.reset_wait(seed=seed)

    def step_async(self, actions):
        pass
@@ -120,19 +123,22 @@ class VectorEnv(gym.Env):
        self.close_extras(**kwargs)
        self.closed = True

-    def seed(self, seeds=None):
+    def seed(self, seed=None):
        """Set the random seed in all sub-environments.

        Parameters
        ----------
-        seeds : list of int, or int, optional
-            Random seed for each sub-environment. If ``seeds`` is a list of
+        seed : list of int, or int, optional
+            Random seed for each sub-environment. If ``seed`` is a list of
            length ``num_envs``, then the items of the list are chosen as random
-            seeds. If ``seeds`` is an int, then each sub-environment uses the random
-            seed ``seeds + n``, where ``n`` is the index of the sub-environment
+            seeds. If ``seed`` is an int, then each sub-environment uses the random
+            seed ``seed + n``, where ``n`` is the index of the sub-environment
            (between ``0`` and ``num_envs - 1``).
        """
-        pass
+        deprecation(
+            "Function `env.seed(seed)` is marked as deprecated and will be removed in the future. "
+            "Please use `env.reset(seed=seed) instead in VectorEnvs."
+        )

    def __del__(self):
        if not getattr(self, "closed", True):
@@ -164,11 +170,11 @@ class VectorEnvWrapper(VectorEnv):

    # explicitly forward the methods defined in VectorEnv
    # to self.env (instead of the base class)
-    def reset_async(self):
-        return self.env.reset_async()
+    def reset_async(self, **kwargs):
+        return self.env.reset_async(**kwargs)

-    def reset_wait(self):
-        return self.env.reset_wait()
+    def reset_wait(self, **kwargs):
+        return self.env.reset_wait(**kwargs)

    def step_async(self, actions):
        return self.env.step_async(actions)
@@ -182,8 +188,8 @@ class VectorEnvWrapper(VectorEnv):
    def close_extras(self, **kwargs):
        return self.env.close_extras(**kwargs)

-    def seed(self, seeds=None):
-        return self.env.seed(seeds)
+    def seed(self, seed=None):
+        return self.env.seed(seed)

    # implicitly forward all other methods and attributes to self.env
    def __getattr__(self, name):

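At the `VectorEnv` level the old entry point is kept but now only emits a deprecation notice, so existing code keeps working while new code is nudged toward `reset(seed=...)`. A migration sketch (CartPole-v1 is just an example id):

```python
import gym
from gym.vector import SyncVectorEnv

envs = SyncVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(4)])

# Old pattern - still callable, but now only logs a deprecation message:
# envs.seed(0)
# envs.reset()

# New pattern - seeding and resetting happen in one call:
obs = envs.reset(seed=0)
envs.close()
```
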
@@ -1,3 +1,5 @@
+from typing import Optional
+
 import numpy as np
 import gym
 from gym.spaces import Box
@@ -127,11 +129,11 @@ class AtariPreprocessing(gym.Wrapper):
            self.ale.getScreenRGB(self.obs_buffer[0])
        return self._get_obs(), R, done, info

-    def reset(self, **kwargs):
+    def reset(self, seed: Optional[int] = None, **kwargs):
        # NoopReset
-        self.env.reset(**kwargs)
+        self.env.reset(seed=seed, **kwargs)
        noops = (
-            self.env.unwrapped.np_random.randint(1, self.noop_max + 1)
+            self.env.unwrapped.np_random.integers(1, self.noop_max + 1)
            if self.noop_max > 0
            else 0
        )

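The `randint` → `integers` rename in the noop logic is the NumPy legacy-`RandomState`-to-`Generator` migration mentioned in the commit message; the two calls draw from the same range and only the method name differs. A small sketch of the equivalence:

```python
import numpy as np

noop_max = 30
legacy_rng = np.random.RandomState(0)   # old-style RNG, as np_random used to be
new_rng = np.random.default_rng(0)      # Generator from default_rng, as np_random is now

legacy_noops = legacy_rng.randint(1, noop_max + 1)   # int in [1, noop_max]
new_noops = new_rng.integers(1, noop_max + 1)        # same range, Generator API
```
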
@@ -1,4 +1,6 @@
 from collections import deque
+from typing import Optional
+
 import numpy as np
 from gym.spaces import Box
 from gym import ObservationWrapper
@@ -116,7 +118,7 @@ class FrameStack(ObservationWrapper):
        self.frames.append(observation)
        return self.observation(), reward, done, info

-    def reset(self, **kwargs):
-        observation = self.env.reset(**kwargs)
+    def reset(self, seed: Optional[int] = None, **kwargs):
+        observation = self.env.reset(seed=seed, **kwargs)
        [self.frames.append(observation) for _ in range(self.num_stack)]
        return self.observation()

@@ -1,5 +1,6 @@
 import json
 import os
+from typing import Optional

 import numpy as np

@@ -50,9 +51,9 @@ class Monitor(Wrapper):

        return observation, reward, done, info

-    def reset(self, **kwargs):
+    def reset(self, seed: Optional[int] = None, **kwargs):
        self._before_reset()
-        observation = self.env.reset(**kwargs)
+        observation = self.env.reset(seed=seed, **kwargs)
        self._after_reset(observation)

        return observation

@@ -1,3 +1,5 @@
+from typing import Optional
+
 import numpy as np
 import gym

@@ -61,8 +63,8 @@ class NormalizeObservation(gym.core.Wrapper):
            obs = self.normalize(np.array([obs]))[0]
        return obs, rews, dones, infos

-    def reset(self):
-        obs = self.env.reset()
+    def reset(self, seed: Optional[int] = None):
+        obs = self.env.reset(seed=seed)
        if self.is_vector_env:
            obs = self.normalize(obs)
        else:

@@ -1,3 +1,5 @@
+from typing import Optional
+
 import gym


@@ -11,6 +13,6 @@ class OrderEnforcing(gym.Wrapper):
        observation, reward, done, info = self.env.step(action)
        return observation, reward, done, info

-    def reset(self, **kwargs):
+    def reset(self, seed: Optional[int] = None, **kwargs):
        self._has_reset = True
-        return self.env.reset(**kwargs)
+        return self.env.reset(seed=seed, **kwargs)

@@ -1,5 +1,7 @@
 import time
 from collections import deque
+from typing import Optional
+
 import numpy as np
 import gym

@@ -16,8 +18,8 @@ class RecordEpisodeStatistics(gym.Wrapper):
        self.length_queue = deque(maxlen=deque_size)
        self.is_vector_env = getattr(env, "is_vector_env", False)

-    def reset(self, **kwargs):
-        observations = super().reset(**kwargs)
+    def reset(self, seed: Optional[int] = None, **kwargs):
+        observations = super().reset(seed=seed, **kwargs)
        self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
        self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
        return observations

@@ -1,6 +1,6 @@
 import os
 import gym
-from typing import Callable
+from typing import Callable, Optional

 from gym import logger
 from gym.wrappers.monitoring import video_recorder
@@ -52,8 +52,8 @@ class RecordVideo(gym.Wrapper):
        self.is_vector_env = getattr(env, "is_vector_env", False)
        self.episode_id = 0

-    def reset(self, **kwargs):
-        observations = super().reset(**kwargs)
+    def reset(self, seed: Optional[int] = None, **kwargs):
+        observations = super().reset(seed=seed, **kwargs)
        if not self.recording and self._video_enabled():
            self.start_video_recorder()
        return observations

@@ -1,3 +1,5 @@
+from typing import Optional
+
 import numpy as np
 from gym.spaces import Box
 from gym import ObservationWrapper
@@ -27,6 +29,6 @@ class TimeAwareObservation(ObservationWrapper):
        self.t += 1
        return super().step(action)

-    def reset(self, **kwargs):
+    def reset(self, seed: Optional[int] = None, **kwargs):
        self.t = 0
-        return super().reset(**kwargs)
+        return super().reset(seed=seed, **kwargs)

@@ -1,3 +1,5 @@
+from typing import Optional
+
 import gym


@@ -22,6 +24,6 @@ class TimeLimit(gym.Wrapper):
            done = True
        return observation, reward, done, info

-    def reset(self, **kwargs):
+    def reset(self, seed: Optional[int] = None, **kwargs):
        self._elapsed_steps = 0
-        return self.env.reset(**kwargs)
+        return self.env.reset(seed=seed, **kwargs)

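All of the wrapper `reset` changes above follow one pattern: accept an optional `seed`, do the wrapper's own bookkeeping, and forward the seed to the wrapped environment. A sketch of a custom wrapper written the same way (the wrapper itself is hypothetical, shown only to illustrate the pattern):

```python
from typing import Optional

import gym


class EpisodeCounter(gym.Wrapper):
    """Hypothetical wrapper that counts resets; the point is the seed forwarding."""

    def __init__(self, env):
        super().__init__(env)
        self.episodes = 0

    def reset(self, seed: Optional[int] = None, **kwargs):
        self.episodes += 1
        # Forward the seed so the wrapped environment can re-seed its np_random.
        return self.env.reset(seed=seed, **kwargs)
```
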
@@ -15,6 +15,6 @@ ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/root/.mujoco/mjpro150/bin
 COPY . /usr/local/gym/
 WORKDIR /usr/local/gym/

-RUN pip install .[nomujoco,accept-rom-license] && pip install -r test_requirements.txt
+RUN pip install .[nomujoco] && pip install -r test_requirements.txt

 ENTRYPOINT ["/usr/local/gym/bin/docker_entrypoint"]

setup.py
@@ -21,7 +21,7 @@ extras = {
 }

 # Meta dependency groups.
-nomujoco_blacklist = set(["mujoco", "robotics", "accept-rom-license"])
+nomujoco_blacklist = set(["mujoco", "robotics", "accept-rom-license", "atari"])
 nomujoco_groups = set(extras.keys()) - nomujoco_blacklist

 extras["nomujoco"] = list(

@@ -10,16 +10,14 @@ def test_env(spec):
    # threads. However, we probably already can't do multithreading
    # due to some environments.
    env1 = spec.make()
-    env1.seed(0)
-    initial_observation1 = env1.reset()
+    initial_observation1 = env1.reset(seed=0)
    env1.action_space.seed(0)
    action_samples1 = [env1.action_space.sample() for i in range(4)]
    step_responses1 = [env1.step(action) for action in action_samples1]
    env1.close()

    env2 = spec.make()
-    env2.seed(0)
-    initial_observation2 = env2.reset()
+    initial_observation2 = env2.reset(seed=0)
    env2.action_space.seed(0)
    action_samples2 = [env2.action_space.sample() for i in range(4)]
    step_responses2 = [env2.step(action) for action in action_samples2]

@@ -10,11 +10,8 @@ def verify_environments_match(
    old_environment = envs.make(old_environment_id)
    new_environment = envs.make(new_environment_id)

-    old_environment.seed(seed)
-    new_environment.seed(seed)
-
-    old_reset_observation = old_environment.reset()
-    new_reset_observation = new_environment.reset()
+    old_reset_observation = old_environment.reset(seed=seed)
+    new_reset_observation = new_environment.reset(seed=seed)

    np.testing.assert_allclose(old_reset_observation, new_reset_observation)

@@ -370,7 +370,7 @@ def test_seed_subspace_incorrelated(space):

    space.seed(0)
    states = [
-        convert_sample_hashable(subspace.np_random.get_state())
+        convert_sample_hashable(subspace.np_random.bit_generator.state)
        for subspace in subspaces
    ]

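The assertion change reflects the switch to NumPy's `Generator`: the RNG state is read from `bit_generator.state` (a plain dict) instead of the legacy `get_state()` tuple. A small sketch of the two APIs:

```python
import numpy as np

legacy_rng = np.random.RandomState(0)
rng = np.random.default_rng(0)

legacy_state = legacy_rng.get_state()   # tuple beginning with 'MT19937'
state = rng.bit_generator.state         # dict, e.g. {'bit_generator': 'PCG64', 'state': {...}, ...}

# The dict is also writable, so a space's RNG can be snapshotted and restored.
rng.bit_generator.state = state
```
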
@@ -1,3 +1,5 @@
+from typing import Optional
+
 import pytest
 import numpy as np

@@ -16,7 +18,8 @@ class UnittestEnv(core.Env):
    observation_space = spaces.Box(low=0, high=255, shape=(64, 64, 3), dtype=np.uint8)
    action_space = spaces.Discrete(3)

-    def reset(self):
+    def reset(self, seed: Optional[int] = None):
+        super().reset(seed=seed)
        return self.observation_space.sample()  # Dummy observation

    def step(self, action):
@@ -31,7 +34,8 @@ class UnknownSpacesEnv(core.Env):
    on external resources), it is not encouraged.
    """

-    def reset(self):
+    def reset(self, seed: Optional[int] = None):
+        super().reset(seed=seed)
        self.observation_space = spaces.Box(
            low=0, high=255, shape=(64, 64, 3), dtype=np.uint8
        )

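These test environments show the pattern every custom `Env` is expected to follow after this change: take an optional `seed`, call `super().reset(seed=seed)` so the base class (re)creates `self.np_random`, then produce the initial observation from that generator. A minimal sketch under those assumptions (the environment itself is made up):

```python
from typing import Optional

import gym
from gym import spaces


class CoinFlipEnv(gym.Env):
    """Hypothetical one-step environment illustrating the new reset signature."""

    observation_space = spaces.Discrete(2)
    action_space = spaces.Discrete(2)

    def reset(self, seed: Optional[int] = None):
        super().reset(seed=seed)                       # seeds self.np_random
        self._state = int(self.np_random.integers(2))  # Generator API, not randint
        return self._state

    def step(self, action):
        reward = float(action == self._state)
        return self._state, reward, True, {}
```
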
@@ -1,3 +1,5 @@
+from typing import Optional
+
 import gym
 import numpy as np
 import pytest
@@ -16,7 +18,8 @@ class ActionDictTestEnv(gym.Env):
        done = True
        return observation, reward, done

-    def reset(self):
+    def reset(self, seed: Optional[int] = None):
+        super().reset(seed=seed)
        return np.array([1.0, 1.5, 0.5])

    def render(self, mode="human"):

@@ -17,17 +17,14 @@ def test_vector_env_equal(shared_memory):
    async_env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
    sync_env = SyncVectorEnv(env_fns)

-    async_env.seed(0)
-    sync_env.seed(0)
-
    assert async_env.num_envs == sync_env.num_envs
    assert async_env.observation_space == sync_env.observation_space
    assert async_env.single_observation_space == sync_env.single_observation_space
    assert async_env.action_space == sync_env.action_space
    assert async_env.single_action_space == sync_env.single_action_space

-    async_observations = async_env.reset()
-    sync_observations = sync_env.reset()
+    async_observations = async_env.reset(seed=0)
+    sync_observations = sync_env.reset(seed=0)
    assert np.all(async_observations == sync_observations)

    for _ in range(num_steps):

@@ -8,7 +8,7 @@ class DummyWrapper(VectorEnvWrapper):
        self.env = env
        self.counter = 0

-    def reset_async(self):
+    def reset_async(self, **kwargs):
        super().reset_async()
        self.counter += 1

@@ -1,3 +1,5 @@
+from typing import Optional
+
 import numpy as np
 import gym
 import time
@@ -55,7 +57,8 @@ class UnittestSlowEnv(gym.Env):
        )
        self.action_space = Box(low=0.0, high=1.0, shape=(), dtype=np.float32)

-    def reset(self):
+    def reset(self, seed: Optional[int] = None):
+        super().reset(seed=seed)
        if self.slow_reset > 0:
            time.sleep(self.slow_reset)
        return self.observation_space.sample()
@@ -86,7 +89,8 @@ class CustomSpaceEnv(gym.Env):
        self.observation_space = CustomSpace()
        self.action_space = CustomSpace()

-    def reset(self):
+    def reset(self, seed: Optional[int] = None):
+        super().reset(seed=seed)
        return "reset"

    def step(self, action):
@@ -98,7 +102,7 @@ class CustomSpaceEnv(gym.Env):
 def make_env(env_name, seed):
    def _make():
        env = gym.make(env_name)
-        env.seed(seed)
+        env.reset(seed=seed)
        return env

    return _make
@@ -107,7 +111,7 @@ def make_env(env_name, seed):
 def make_slow_env(slow_reset, seed):
    def _make():
        env = UnittestSlowEnv(slow_reset=slow_reset)
-        env.seed(seed)
+        env.reset(seed=seed)
        return env

    return _make
@@ -116,7 +120,7 @@ def make_slow_env(slow_reset, seed):
 def make_custom_space_env(seed):
    def _make():
        env = CustomSpaceEnv()
-        env.seed(seed)
+        env.reset(seed=seed)
        return env

    return _make

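In these factory helpers the old `env.seed(seed)` becomes `env.reset(seed=seed)`: the environment is seeded (and reset once) at construction time, and the vector environment resets it again when it starts. A sketch of the resulting factory shape, with `CartPole-v1` standing in for any registered id:

```python
import gym


def make_seeded_env(env_name: str, seed: int):
    def _make():
        env = gym.make(env_name)
        env.reset(seed=seed)  # seeds env.np_random; the vector env will call reset again
        return env

    return _make


env_fns = [make_seeded_env("CartPole-v1", seed=i) for i in range(4)]
```
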
@@ -1,6 +1,7 @@
 """Tests for the flatten observation wrapper."""

 from collections import OrderedDict
+from typing import Optional

 import numpy as np
 import pytest
@@ -14,7 +15,8 @@ class FakeEnvironment(gym.Env):
    def __init__(self, observation_space):
        self.observation_space = observation_space

-    def reset(self):
+    def reset(self, seed: Optional[int] = None):
+        super().reset(seed=seed)
        self.observation = self.observation_space.sample()
        return self.observation

@@ -1,5 +1,5 @@
 """Tests for the filter observation wrapper."""
+from typing import Optional

 import pytest
 import numpy as np
@@ -21,7 +21,8 @@ class FakeEnvironment(gym.Env):
        image_shape = (height, width, 3)
        return np.zeros(image_shape, dtype=np.uint8)

-    def reset(self):
+    def reset(self, seed: Optional[int] = None):
+        super().reset(seed=seed)
        observation = self.observation_space.sample()
        return observation

@@ -29,14 +29,10 @@ def test_atari_preprocessing_grayscale(env_fn):
        noop_max=0,
        grayscale_newaxis=True,
    )
-    env1.seed(0)
-    env2.seed(0)
-    env3.seed(0)
-    env4.seed(0)
-    obs1 = env1.reset()
-    obs2 = env2.reset()
-    obs3 = env3.reset()
-    obs4 = env4.reset()
+    obs1 = env1.reset(seed=0)
+    obs2 = env2.reset(seed=0)
+    obs3 = env3.reset(seed=0)
+    obs4 = env4.reset(seed=0)
    assert env1.observation_space.shape == (210, 160, 3)
    assert env2.observation_space.shape == (84, 84)
    assert env3.observation_space.shape == (84, 84, 3)

@@ -11,11 +11,9 @@ def test_clip_action():
    wrapped_env = ClipAction(make_env())

    seed = 0
-    env.seed(seed)
-    wrapped_env.seed(seed)

-    env.reset()
-    wrapped_env.reset()
+    env.reset(seed=seed)
+    wrapped_env.reset(seed=seed)

    actions = [[0.4], [1.2], [-0.3], [0.0], [-2.5]]
    for action in actions:

@@ -1,3 +1,5 @@
+from typing import Optional
+
 import pytest
 import numpy as np

@@ -22,7 +24,8 @@ class FakeEnvironment(gym.Env):
        image_shape = (height, width, 3)
        return np.zeros(image_shape, dtype=np.uint8)

-    def reset(self):
+    def reset(self, seed: Optional[int] = None):
+        super().reset(seed=seed)
        observation = self.observation_space.sample()
        return observation

@@ -28,17 +28,15 @@ except ImportError:
 )
 def test_frame_stack(env_id, num_stack, lz4_compress):
    env = gym.make(env_id)
-    env.seed(0)
    shape = env.observation_space.shape
    env = FrameStack(env, num_stack, lz4_compress)
    assert env.observation_space.shape == (num_stack,) + shape
    assert env.observation_space.dtype == env.env.observation_space.dtype

    dup = gym.make(env_id)
-    dup.seed(0)

-    obs = env.reset()
-    dup_obs = dup.reset()
+    obs = env.reset(seed=0)
+    dup_obs = dup.reset(seed=0)
    assert np.allclose(obs[-1], dup_obs)

    for _ in range(num_stack ** 2):

@@ -21,11 +21,9 @@ def test_gray_scale_observation(env_id, keep_dim):
    assert rgb_env.observation_space.shape[-1] == 3

    seed = 0
-    gray_env.seed(seed)
-    wrapped_env.seed(seed)

-    gray_obs = gray_env.reset()
-    wrapped_obs = wrapped_env.reset()
+    gray_obs = gray_env.reset(seed=seed)
+    wrapped_obs = wrapped_env.reset(seed=seed)

    if keep_dim:
        assert wrapped_env.observation_space.shape[-1] == 1

@@ -1,3 +1,5 @@
+from typing import Optional
+
 import gym
 import numpy as np
 from numpy.testing import assert_almost_equal
@@ -21,7 +23,8 @@ class DummyRewardEnv(gym.Env):
        self.t += 1
        return np.array([self.t]), self.t, self.t == len(self.returned_rewards), {}

-    def reset(self):
+    def reset(self, seed: Optional[int] = None):
+        super().reset(seed=seed)
        self.t = self.return_reward_idx
        return np.array([self.t])

@@ -1,5 +1,5 @@
 """Tests for the pixel observation wrapper."""
+from typing import Optional

 import pytest
 import numpy as np
@@ -19,7 +19,8 @@ class FakeEnvironment(gym.Env):
        image_shape = (height, width, 3)
        return np.zeros(image_shape, dtype=np.uint8)

-    def reset(self):
+    def reset(self, seed: Optional[int] = None):
+        super().reset(seed=seed)
        observation = self.observation_space.sample()
        return observation

@@ -16,11 +16,9 @@ def test_rescale_action():
    wrapped_env = RescaleAction(gym.make("Pendulum-v1"), -1, 1)

    seed = 0
-    env.seed(seed)
-    wrapped_env.seed(seed)

-    obs = env.reset()
-    wrapped_obs = wrapped_env.reset()
+    obs = env.reset(seed=seed)
+    wrapped_obs = wrapped_env.reset(seed=seed)
    assert np.allclose(obs, wrapped_obs)

    obs, reward, _, _ = env.step([1.5])

@@ -14,11 +14,8 @@ def test_transform_observation(env_id):
        gym.make(env_id), lambda obs: affine_transform(obs)
    )

-    env.seed(0)
-    wrapped_env.seed(0)
-
-    obs = env.reset()
-    wrapped_obs = wrapped_env.reset()
+    obs = env.reset(seed=0)
+    wrapped_obs = wrapped_env.reset(seed=0)
    assert np.allclose(wrapped_obs, affine_transform(obs))

    action = env.action_space.sample()

@@ -15,10 +15,8 @@ def test_transform_reward(env_id):
    wrapped_env = TransformReward(gym.make(env_id), lambda r: scale * r)
    action = env.action_space.sample()

-    env.seed(0)
-    env.reset()
-    wrapped_env.seed(0)
-    wrapped_env.reset()
+    env.reset(seed=0)
+    wrapped_env.reset(seed=0)

    _, reward, _, _ = env.step(action)
    _, wrapped_reward, _, _ = wrapped_env.step(action)
@@ -33,10 +31,8 @@ def test_transform_reward(env_id):
    wrapped_env = TransformReward(gym.make(env_id), lambda r: np.clip(r, min_r, max_r))
    action = env.action_space.sample()

-    env.seed(0)
-    env.reset()
-    wrapped_env.seed(0)
-    wrapped_env.reset()
+    env.reset(seed=0)
+    wrapped_env.reset(seed=0)

    _, reward, _, _ = env.step(action)
    _, wrapped_reward, _, _ = wrapped_env.step(action)
@@ -49,10 +45,8 @@ def test_transform_reward(env_id):
    env = gym.make(env_id)
    wrapped_env = TransformReward(gym.make(env_id), lambda r: np.sign(r))

-    env.seed(0)
-    env.reset()
-    wrapped_env.seed(0)
-    wrapped_env.reset()
+    env.reset(seed=0)
+    wrapped_env.reset(seed=0)

    for _ in range(1000):
        action = env.action_space.sample()