mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-09-02 02:32:50 +00:00
Make Tuple and Dicts be seedable with lists and dicts of seeds + make the seed in default initialization controllable (#1774)
* Make the seed in default initialization controllable Since seed() is being called in default initialization of Space, it should be controllable for reproducibility. * Updated derived classes of Space to have their seeds controllable at initialization. * Allow Tuple's spaces to each have their own seed * Added dict based seeding for Dict space; test cases for Tuple and Dict seeding * Update discrete.py * Update test_spaces.py * Add seed to __init__() * blacked * Fix black * Fix failing tests
This commit is contained in:
@@ -23,7 +23,7 @@ class Box(Space):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, low, high, shape=None, dtype=np.float32):
|
def __init__(self, low, high, shape=None, dtype=np.float32, seed=None):
|
||||||
assert dtype is not None, "dtype must be explicitly provided. "
|
assert dtype is not None, "dtype must be explicitly provided. "
|
||||||
self.dtype = np.dtype(dtype)
|
self.dtype = np.dtype(dtype)
|
||||||
|
|
||||||
@@ -81,7 +81,7 @@ class Box(Space):
|
|||||||
self.bounded_below = -np.inf < self.low
|
self.bounded_below = -np.inf < self.low
|
||||||
self.bounded_above = np.inf > self.high
|
self.bounded_above = np.inf > self.high
|
||||||
|
|
||||||
super(Box, self).__init__(self.shape, self.dtype)
|
super(Box, self).__init__(self.shape, self.dtype, seed)
|
||||||
|
|
||||||
def is_bounded(self, manner="both"):
|
def is_bounded(self, manner="both"):
|
||||||
below = np.all(self.bounded_below)
|
below = np.all(self.bounded_below)
|
||||||
|
@@ -33,10 +33,11 @@ class Dict(Space):
|
|||||||
})
|
})
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, spaces=None, **spaces_kwargs):
|
def __init__(self, spaces=None, seed=None, **spaces_kwargs):
|
||||||
assert (spaces is None) or (
|
assert (spaces is None) or (
|
||||||
not spaces_kwargs
|
not spaces_kwargs
|
||||||
), "Use either Dict(spaces=dict(...)) or Dict(foo=x, bar=z)"
|
), "Use either Dict(spaces=dict(...)) or Dict(foo=x, bar=z)"
|
||||||
|
|
||||||
if spaces is None:
|
if spaces is None:
|
||||||
spaces = spaces_kwargs
|
spaces = spaces_kwargs
|
||||||
if isinstance(spaces, dict) and not isinstance(spaces, OrderedDict):
|
if isinstance(spaces, dict) and not isinstance(spaces, OrderedDict):
|
||||||
@@ -49,28 +50,45 @@ class Dict(Space):
|
|||||||
space, Space
|
space, Space
|
||||||
), "Values of the dict should be instances of gym.Space"
|
), "Values of the dict should be instances of gym.Space"
|
||||||
super(Dict, self).__init__(
|
super(Dict, self).__init__(
|
||||||
None, None
|
None, None, seed
|
||||||
) # None for shape and dtype, since it'll require special handling
|
) # None for shape and dtype, since it'll require special handling
|
||||||
|
|
||||||
def seed(self, seed=None):
|
def seed(self, seed=None):
|
||||||
seed = super().seed(seed)
|
seeds = []
|
||||||
try:
|
if isinstance(seed, dict):
|
||||||
subseeds = self.np_random.choice(
|
for key, seed_key in zip(self.spaces, seed):
|
||||||
np.iinfo(int).max,
|
assert key == seed_key, print(
|
||||||
size=len(self.spaces),
|
"Key value",
|
||||||
replace=False, # unique subseed for each subspace
|
seed_key,
|
||||||
)
|
"in passed seed dict did not match key value",
|
||||||
except ValueError:
|
key,
|
||||||
subseeds = self.np_random.choice(
|
"in spaces Dict.",
|
||||||
np.iinfo(int).max,
|
)
|
||||||
size=len(self.spaces),
|
seeds += self.spaces[key].seed(seed[seed_key])
|
||||||
replace=True, # we get more than INT_MAX subspaces
|
elif isinstance(seed, int):
|
||||||
)
|
seeds = super().seed(seed)
|
||||||
|
try:
|
||||||
|
subseeds = self.np_random.choice(
|
||||||
|
np.iinfo(int).max,
|
||||||
|
size=len(self.spaces),
|
||||||
|
replace=False, # unique subseed for each subspace
|
||||||
|
)
|
||||||
|
except ValueError:
|
||||||
|
subseeds = self.np_random.choice(
|
||||||
|
np.iinfo(int).max,
|
||||||
|
size=len(self.spaces),
|
||||||
|
replace=True, # we get more than INT_MAX subspaces
|
||||||
|
)
|
||||||
|
|
||||||
for subspace, subseed in zip(self.spaces.values(), subseeds):
|
for subspace, subseed in zip(self.spaces.values(), subseeds):
|
||||||
seed.append(subspace.seed(int(subseed))[0])
|
seeds.append(subspace.seed(int(subseed))[0])
|
||||||
|
elif seed is None:
|
||||||
|
for space in self.spaces.values():
|
||||||
|
seeds += space.seed(seed)
|
||||||
|
else:
|
||||||
|
raise TypeError("Passed seed not of an expected type: dict or int or None")
|
||||||
|
|
||||||
return seed
|
return seeds
|
||||||
|
|
||||||
def sample(self):
|
def sample(self):
|
||||||
return OrderedDict([(k, space.sample()) for k, space in self.spaces.items()])
|
return OrderedDict([(k, space.sample()) for k, space in self.spaces.items()])
|
||||||
|
@@ -11,10 +11,10 @@ class Discrete(Space):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, n):
|
def __init__(self, n, seed=None):
|
||||||
assert n >= 0
|
assert n >= 0
|
||||||
self.n = n
|
self.n = n
|
||||||
super(Discrete, self).__init__((), np.int64)
|
super(Discrete, self).__init__((), np.int64, seed)
|
||||||
|
|
||||||
def sample(self):
|
def sample(self):
|
||||||
return self.np_random.randint(self.n)
|
return self.np_random.randint(self.n)
|
||||||
|
@@ -26,13 +26,13 @@ class MultiBinary(Space):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, n):
|
def __init__(self, n, seed=None):
|
||||||
self.n = n
|
self.n = n
|
||||||
if type(n) in [tuple, list, np.ndarray]:
|
if type(n) in [tuple, list, np.ndarray]:
|
||||||
input_n = n
|
input_n = n
|
||||||
else:
|
else:
|
||||||
input_n = (n,)
|
input_n = (n,)
|
||||||
super(MultiBinary, self).__init__(input_n, np.int8)
|
super(MultiBinary, self).__init__(input_n, np.int8, seed)
|
||||||
|
|
||||||
def sample(self):
|
def sample(self):
|
||||||
return self.np_random.randint(low=0, high=2, size=self.n, dtype=self.dtype)
|
return self.np_random.randint(low=0, high=2, size=self.n, dtype=self.dtype)
|
||||||
|
@@ -25,14 +25,14 @@ class MultiDiscrete(Space):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, nvec, dtype=np.int64):
|
def __init__(self, nvec, dtype=np.int64, seed=None):
|
||||||
"""
|
"""
|
||||||
nvec: vector of counts of each categorical variable
|
nvec: vector of counts of each categorical variable
|
||||||
"""
|
"""
|
||||||
assert (np.array(nvec) > 0).all(), "nvec (counts) have to be positive"
|
assert (np.array(nvec) > 0).all(), "nvec (counts) have to be positive"
|
||||||
self.nvec = np.asarray(nvec, dtype=dtype)
|
self.nvec = np.asarray(nvec, dtype=dtype)
|
||||||
|
|
||||||
super(MultiDiscrete, self).__init__(self.nvec.shape, dtype)
|
super(MultiDiscrete, self).__init__(self.nvec.shape, dtype, seed)
|
||||||
|
|
||||||
def sample(self):
|
def sample(self):
|
||||||
return (self.np_random.random_sample(self.nvec.shape) * self.nvec).astype(
|
return (self.np_random.random_sample(self.nvec.shape) * self.nvec).astype(
|
||||||
|
@@ -16,12 +16,14 @@ class Space(object):
|
|||||||
not handle custom spaces properly. Use custom spaces with care.
|
not handle custom spaces properly. Use custom spaces with care.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, shape=None, dtype=None):
|
def __init__(self, shape=None, dtype=None, seed=None):
|
||||||
import numpy as np # takes about 300-400ms to import, so we load lazily
|
import numpy as np # takes about 300-400ms to import, so we load lazily
|
||||||
|
|
||||||
self._shape = None if shape is None else tuple(shape)
|
self._shape = None if shape is None else tuple(shape)
|
||||||
self.dtype = None if dtype is None else np.dtype(dtype)
|
self.dtype = None if dtype is None else np.dtype(dtype)
|
||||||
self._np_random = None
|
self._np_random = None
|
||||||
|
if seed is not None:
|
||||||
|
self.seed(seed)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def np_random(self):
|
def np_random(self):
|
||||||
|
@@ -180,6 +180,53 @@ def test_bad_space_calls(space_fn):
|
|||||||
space_fn()
|
space_fn()
|
||||||
|
|
||||||
|
|
||||||
|
def test_seed_Dict():
|
||||||
|
test_space = Dict(
|
||||||
|
{
|
||||||
|
"a": Box(low=0, high=1, shape=(3, 3)),
|
||||||
|
"b": Dict(
|
||||||
|
{
|
||||||
|
"b_1": Box(low=-100, high=100, shape=(2,)),
|
||||||
|
"b_2": Box(low=-1, high=1, shape=(2,)),
|
||||||
|
}
|
||||||
|
),
|
||||||
|
"c": Discrete(5),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
seed_dict = {
|
||||||
|
"a": 0,
|
||||||
|
"b": {
|
||||||
|
"b_1": 1,
|
||||||
|
"b_2": 2,
|
||||||
|
},
|
||||||
|
"c": 3,
|
||||||
|
}
|
||||||
|
|
||||||
|
test_space.seed(seed_dict)
|
||||||
|
|
||||||
|
# "Unpack" the dict sub-spaces into individual spaces
|
||||||
|
a = Box(low=0, high=1, shape=(3, 3))
|
||||||
|
a.seed(0)
|
||||||
|
b_1 = Box(low=-100, high=100, shape=(2,))
|
||||||
|
b_1.seed(1)
|
||||||
|
b_2 = Box(low=-1, high=1, shape=(2,))
|
||||||
|
b_2.seed(2)
|
||||||
|
c = Discrete(5)
|
||||||
|
c.seed(3)
|
||||||
|
|
||||||
|
for i in range(10):
|
||||||
|
test_s = test_space.sample()
|
||||||
|
a_s = a.sample()
|
||||||
|
assert (test_s["a"] == a_s).all()
|
||||||
|
b_1_s = b_1.sample()
|
||||||
|
assert (test_s["b"]["b_1"] == b_1_s).all()
|
||||||
|
b_2_s = b_2.sample()
|
||||||
|
assert (test_s["b"]["b_2"] == b_2_s).all()
|
||||||
|
c_s = c.sample()
|
||||||
|
assert test_s["c"] == c_s
|
||||||
|
|
||||||
|
|
||||||
def test_box_dtype_check():
|
def test_box_dtype_check():
|
||||||
# Related Issues:
|
# Related Issues:
|
||||||
# https://github.com/openai/gym/issues/2357
|
# https://github.com/openai/gym/issues/2357
|
||||||
|
@@ -10,33 +10,44 @@ class Tuple(Space):
|
|||||||
self.observation_space = spaces.Tuple((spaces.Discrete(2), spaces.Discrete(3)))
|
self.observation_space = spaces.Tuple((spaces.Discrete(2), spaces.Discrete(3)))
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, spaces):
|
def __init__(self, spaces, seed=None):
|
||||||
self.spaces = spaces
|
self.spaces = spaces
|
||||||
for space in spaces:
|
for space in spaces:
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
space, Space
|
space, Space
|
||||||
), "Elements of the tuple must be instances of gym.Space"
|
), "Elements of the tuple must be instances of gym.Space"
|
||||||
super(Tuple, self).__init__(None, None)
|
super(Tuple, self).__init__(None, None, seed)
|
||||||
|
|
||||||
def seed(self, seed=None):
|
def seed(self, seed=None):
|
||||||
seed = super().seed(seed)
|
seeds = []
|
||||||
try:
|
|
||||||
subseeds = self.np_random.choice(
|
|
||||||
np.iinfo(int).max,
|
|
||||||
size=len(self.spaces),
|
|
||||||
replace=False, # unique subseed for each subspace
|
|
||||||
)
|
|
||||||
except ValueError:
|
|
||||||
subseeds = self.np_random.choice(
|
|
||||||
np.iinfo(int).max,
|
|
||||||
size=len(self.spaces),
|
|
||||||
replace=True, # we get more than INT_MAX subspaces
|
|
||||||
)
|
|
||||||
|
|
||||||
for subspace, subseed in zip(self.spaces, subseeds):
|
if isinstance(seed, list):
|
||||||
seed.append(subspace.seed(int(subseed))[0])
|
for i, space in enumerate(self.spaces):
|
||||||
|
seeds += space.seed(seed[i])
|
||||||
|
elif isinstance(seed, int):
|
||||||
|
seeds = super().seed(seed)
|
||||||
|
try:
|
||||||
|
subseeds = self.np_random.choice(
|
||||||
|
np.iinfo(int).max,
|
||||||
|
size=len(self.spaces),
|
||||||
|
replace=False, # unique subseed for each subspace
|
||||||
|
)
|
||||||
|
except ValueError:
|
||||||
|
subseeds = self.np_random.choice(
|
||||||
|
np.iinfo(int).max,
|
||||||
|
size=len(self.spaces),
|
||||||
|
replace=True, # we get more than INT_MAX subspaces
|
||||||
|
)
|
||||||
|
|
||||||
return seed
|
for subspace, subseed in zip(self.spaces, subseeds):
|
||||||
|
seeds.append(subspace.seed(int(subseed))[0])
|
||||||
|
elif seed is None:
|
||||||
|
for space in self.spaces:
|
||||||
|
seeds += space.seed(seed)
|
||||||
|
else:
|
||||||
|
raise TypeError("Passed seed not of an expected type: list or int or None")
|
||||||
|
|
||||||
|
return seeds
|
||||||
|
|
||||||
def sample(self):
|
def sample(self):
|
||||||
return tuple([space.sample() for space in self.spaces])
|
return tuple([space.sample() for space in self.spaces])
|
||||||
|
Reference in New Issue
Block a user