mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-07 00:11:46 +00:00
* Fix seed method for Tuple and Dict * Improve stochasticity * Update test cases for seed method
82 lines
2.3 KiB
Python
82 lines
2.3 KiB
Python
import numpy as np
|
|
from .space import Space
|
|
|
|
|
|
class Tuple(Space):
    """
    A tuple (i.e., product) of simpler spaces

    Example usage:
    self.observation_space = spaces.Tuple((spaces.Discrete(2), spaces.Discrete(3)))
    """

    def __init__(self, spaces):
        """Build a product space from an iterable of sub-spaces.

        ``spaces`` may be any iterable of ``Space`` instances (tuple, list,
        generator, ...). It is stored on ``self.spaces`` as a tuple.
        """
        # Materialise once: a generator would otherwise be exhausted by the
        # validation loop below, and a list would make two otherwise identical
        # Tuple spaces compare unequal in __eq__ (list != tuple).
        spaces = tuple(spaces)
        self.spaces = spaces
        for space in spaces:
            assert isinstance(
                space, Space
            ), "Elements of the tuple must be instances of gym.Space"
        super().__init__(None, None)

    def seed(self, seed=None):
        """Seed this space and every sub-space.

        The base-class ``seed`` is called first; one sub-seed per sub-space is
        then drawn from ``self.np_random`` and passed to that sub-space's own
        ``seed``. Returns the value produced by the base class, extended with
        the first seed reported by each sub-space, in order.
        """
        seed = super().seed(seed)
        upper = np.iinfo(int).max
        count = len(self.spaces)

        try:
            subseeds = self.np_random.choice(
                upper,
                size=count,
                replace=False,  # unique subseed for each subspace
            )
        except ValueError:
            # Sampling without replacement was rejected; fall back to
            # sampling with replacement so that every sub-space still
            # receives a seed (sub-seeds may then repeat).
            subseeds = self.np_random.choice(
                upper,
                size=count,
                replace=True,  # we get more than INT_MAX subspaces
            )

        for subspace, subseed in zip(self.spaces, subseeds):
            # Each sub-space returns an indexable result; keep its first entry.
            seed.append(subspace.seed(int(subseed))[0])

        return seed

    def sample(self):
        """Return a tuple holding one sample from each sub-space, in order."""
        return tuple(space.sample() for space in self.spaces)

    def contains(self, x):
        """Return True if ``x`` has one valid element per sub-space.

        A list is accepted and treated as a tuple; any other type fails.
        """
        if isinstance(x, list):
            x = tuple(x)  # Promote list to tuple for contains check
        return (
            isinstance(x, tuple)
            and len(x) == len(self.spaces)
            and all(space.contains(part) for (space, part) in zip(self.spaces, x))
        )

    def __repr__(self):
        """Return ``Tuple(<sub-space>, <sub-space>, ...)``."""
        return "Tuple(" + ", ".join(str(s) for s in self.spaces) + ")"

    def to_jsonable(self, sample_n):
        """Convert a batch of samples into a per-sub-space list.

        Entry ``i`` of the result is sub-space ``i``'s ``to_jsonable`` applied
        to the ``i``-th component of every sample in ``sample_n``.
        """
        # serialize as list-repr of tuple of vectors
        return [
            space.to_jsonable([sample[i] for sample in sample_n])
            for i, space in enumerate(self.spaces)
        ]

    def from_jsonable(self, sample_n):
        """Inverse of ``to_jsonable``: rebuild a list of per-sample tuples.

        ``sample_n[i]`` is decoded by sub-space ``i``; the decoded columns are
        then zipped back together into one tuple per sample.
        """
        columns = (
            space.from_jsonable(sample_n[i]) for i, space in enumerate(self.spaces)
        )
        return list(zip(*columns))

    def __getitem__(self, index):
        """Return the sub-space(s) stored at ``index``."""
        return self.spaces[index]

    def __len__(self):
        """Return the number of sub-spaces."""
        return len(self.spaces)

    def __eq__(self, other):
        """Two Tuple spaces are equal when their sub-spaces compare equal."""
        return isinstance(other, Tuple) and self.spaces == other.spaces
|