mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-09-01 10:27:43 +00:00
Make Tuple and Dicts be seedable with lists and dicts of seeds + make the seed in default initialization controllable (#1774)
* Make the seed in default initialization controllable Since seed() is being called in default initialization of Space, it should be controllable for reproducibility. * Updated derived classes of Space to have their seeds controllable at initialization. * Allow Tuple's spaces to each have their own seed * Added dict based seeding for Dict space; test cases for Tuple and Dict seeding * Update discrete.py * Update test_spaces.py * Add seed to __init__() * blacked * Fix black * Fix failing tests
This commit is contained in:
@@ -23,7 +23,7 @@ class Box(Space):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, low, high, shape=None, dtype=np.float32):
|
||||
def __init__(self, low, high, shape=None, dtype=np.float32, seed=None):
|
||||
assert dtype is not None, "dtype must be explicitly provided. "
|
||||
self.dtype = np.dtype(dtype)
|
||||
|
||||
@@ -81,7 +81,7 @@ class Box(Space):
|
||||
self.bounded_below = -np.inf < self.low
|
||||
self.bounded_above = np.inf > self.high
|
||||
|
||||
super(Box, self).__init__(self.shape, self.dtype)
|
||||
super(Box, self).__init__(self.shape, self.dtype, seed)
|
||||
|
||||
def is_bounded(self, manner="both"):
|
||||
below = np.all(self.bounded_below)
|
||||
|
@@ -33,10 +33,11 @@ class Dict(Space):
|
||||
})
|
||||
"""
|
||||
|
||||
def __init__(self, spaces=None, **spaces_kwargs):
|
||||
def __init__(self, spaces=None, seed=None, **spaces_kwargs):
|
||||
assert (spaces is None) or (
|
||||
not spaces_kwargs
|
||||
), "Use either Dict(spaces=dict(...)) or Dict(foo=x, bar=z)"
|
||||
|
||||
if spaces is None:
|
||||
spaces = spaces_kwargs
|
||||
if isinstance(spaces, dict) and not isinstance(spaces, OrderedDict):
|
||||
@@ -49,11 +50,23 @@ class Dict(Space):
|
||||
space, Space
|
||||
), "Values of the dict should be instances of gym.Space"
|
||||
super(Dict, self).__init__(
|
||||
None, None
|
||||
None, None, seed
|
||||
) # None for shape and dtype, since it'll require special handling
|
||||
|
||||
def seed(self, seed=None):
|
||||
seed = super().seed(seed)
|
||||
seeds = []
|
||||
if isinstance(seed, dict):
|
||||
for key, seed_key in zip(self.spaces, seed):
|
||||
assert key == seed_key, print(
|
||||
"Key value",
|
||||
seed_key,
|
||||
"in passed seed dict did not match key value",
|
||||
key,
|
||||
"in spaces Dict.",
|
||||
)
|
||||
seeds += self.spaces[key].seed(seed[seed_key])
|
||||
elif isinstance(seed, int):
|
||||
seeds = super().seed(seed)
|
||||
try:
|
||||
subseeds = self.np_random.choice(
|
||||
np.iinfo(int).max,
|
||||
@@ -68,9 +81,14 @@ class Dict(Space):
|
||||
)
|
||||
|
||||
for subspace, subseed in zip(self.spaces.values(), subseeds):
|
||||
seed.append(subspace.seed(int(subseed))[0])
|
||||
seeds.append(subspace.seed(int(subseed))[0])
|
||||
elif seed is None:
|
||||
for space in self.spaces.values():
|
||||
seeds += space.seed(seed)
|
||||
else:
|
||||
raise TypeError("Passed seed not of an expected type: dict or int or None")
|
||||
|
||||
return seed
|
||||
return seeds
|
||||
|
||||
def sample(self):
|
||||
return OrderedDict([(k, space.sample()) for k, space in self.spaces.items()])
|
||||
|
@@ -11,10 +11,10 @@ class Discrete(Space):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, n):
|
||||
def __init__(self, n, seed=None):
|
||||
assert n >= 0
|
||||
self.n = n
|
||||
super(Discrete, self).__init__((), np.int64)
|
||||
super(Discrete, self).__init__((), np.int64, seed)
|
||||
|
||||
def sample(self):
|
||||
return self.np_random.randint(self.n)
|
||||
|
@@ -26,13 +26,13 @@ class MultiBinary(Space):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, n):
|
||||
def __init__(self, n, seed=None):
|
||||
self.n = n
|
||||
if type(n) in [tuple, list, np.ndarray]:
|
||||
input_n = n
|
||||
else:
|
||||
input_n = (n,)
|
||||
super(MultiBinary, self).__init__(input_n, np.int8)
|
||||
super(MultiBinary, self).__init__(input_n, np.int8, seed)
|
||||
|
||||
def sample(self):
|
||||
return self.np_random.randint(low=0, high=2, size=self.n, dtype=self.dtype)
|
||||
|
@@ -25,14 +25,14 @@ class MultiDiscrete(Space):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, nvec, dtype=np.int64):
|
||||
def __init__(self, nvec, dtype=np.int64, seed=None):
|
||||
"""
|
||||
nvec: vector of counts of each categorical variable
|
||||
"""
|
||||
assert (np.array(nvec) > 0).all(), "nvec (counts) have to be positive"
|
||||
self.nvec = np.asarray(nvec, dtype=dtype)
|
||||
|
||||
super(MultiDiscrete, self).__init__(self.nvec.shape, dtype)
|
||||
super(MultiDiscrete, self).__init__(self.nvec.shape, dtype, seed)
|
||||
|
||||
def sample(self):
|
||||
return (self.np_random.random_sample(self.nvec.shape) * self.nvec).astype(
|
||||
|
@@ -16,12 +16,14 @@ class Space(object):
|
||||
not handle custom spaces properly. Use custom spaces with care.
|
||||
"""
|
||||
|
||||
def __init__(self, shape=None, dtype=None):
|
||||
def __init__(self, shape=None, dtype=None, seed=None):
|
||||
import numpy as np # takes about 300-400ms to import, so we load lazily
|
||||
|
||||
self._shape = None if shape is None else tuple(shape)
|
||||
self.dtype = None if dtype is None else np.dtype(dtype)
|
||||
self._np_random = None
|
||||
if seed is not None:
|
||||
self.seed(seed)
|
||||
|
||||
@property
|
||||
def np_random(self):
|
||||
|
@@ -180,6 +180,53 @@ def test_bad_space_calls(space_fn):
|
||||
space_fn()
|
||||
|
||||
|
||||
def test_seed_Dict():
|
||||
test_space = Dict(
|
||||
{
|
||||
"a": Box(low=0, high=1, shape=(3, 3)),
|
||||
"b": Dict(
|
||||
{
|
||||
"b_1": Box(low=-100, high=100, shape=(2,)),
|
||||
"b_2": Box(low=-1, high=1, shape=(2,)),
|
||||
}
|
||||
),
|
||||
"c": Discrete(5),
|
||||
}
|
||||
)
|
||||
|
||||
seed_dict = {
|
||||
"a": 0,
|
||||
"b": {
|
||||
"b_1": 1,
|
||||
"b_2": 2,
|
||||
},
|
||||
"c": 3,
|
||||
}
|
||||
|
||||
test_space.seed(seed_dict)
|
||||
|
||||
# "Unpack" the dict sub-spaces into individual spaces
|
||||
a = Box(low=0, high=1, shape=(3, 3))
|
||||
a.seed(0)
|
||||
b_1 = Box(low=-100, high=100, shape=(2,))
|
||||
b_1.seed(1)
|
||||
b_2 = Box(low=-1, high=1, shape=(2,))
|
||||
b_2.seed(2)
|
||||
c = Discrete(5)
|
||||
c.seed(3)
|
||||
|
||||
for i in range(10):
|
||||
test_s = test_space.sample()
|
||||
a_s = a.sample()
|
||||
assert (test_s["a"] == a_s).all()
|
||||
b_1_s = b_1.sample()
|
||||
assert (test_s["b"]["b_1"] == b_1_s).all()
|
||||
b_2_s = b_2.sample()
|
||||
assert (test_s["b"]["b_2"] == b_2_s).all()
|
||||
c_s = c.sample()
|
||||
assert test_s["c"] == c_s
|
||||
|
||||
|
||||
def test_box_dtype_check():
|
||||
# Related Issues:
|
||||
# https://github.com/openai/gym/issues/2357
|
||||
|
@@ -10,16 +10,22 @@ class Tuple(Space):
|
||||
self.observation_space = spaces.Tuple((spaces.Discrete(2), spaces.Discrete(3)))
|
||||
"""
|
||||
|
||||
def __init__(self, spaces):
|
||||
def __init__(self, spaces, seed=None):
|
||||
self.spaces = spaces
|
||||
for space in spaces:
|
||||
assert isinstance(
|
||||
space, Space
|
||||
), "Elements of the tuple must be instances of gym.Space"
|
||||
super(Tuple, self).__init__(None, None)
|
||||
super(Tuple, self).__init__(None, None, seed)
|
||||
|
||||
def seed(self, seed=None):
|
||||
seed = super().seed(seed)
|
||||
seeds = []
|
||||
|
||||
if isinstance(seed, list):
|
||||
for i, space in enumerate(self.spaces):
|
||||
seeds += space.seed(seed[i])
|
||||
elif isinstance(seed, int):
|
||||
seeds = super().seed(seed)
|
||||
try:
|
||||
subseeds = self.np_random.choice(
|
||||
np.iinfo(int).max,
|
||||
@@ -34,9 +40,14 @@ class Tuple(Space):
|
||||
)
|
||||
|
||||
for subspace, subseed in zip(self.spaces, subseeds):
|
||||
seed.append(subspace.seed(int(subseed))[0])
|
||||
seeds.append(subspace.seed(int(subseed))[0])
|
||||
elif seed is None:
|
||||
for space in self.spaces:
|
||||
seeds += space.seed(seed)
|
||||
else:
|
||||
raise TypeError("Passed seed not of an expected type: list or int or None")
|
||||
|
||||
return seed
|
||||
return seeds
|
||||
|
||||
def sample(self):
|
||||
return tuple([space.sample() for space in self.spaces])
|
||||
|
Reference in New Issue
Block a user