Files
Gymnasium/gym/spaces/tuple.py

52 lines
1.6 KiB
Python
Raw Normal View History

2019-03-25 00:47:16 +01:00
import numpy as np
from .space import Space
2016-04-27 08:00:58 -07:00
class Tuple(Space):
    """
    A tuple (i.e., product) of simpler spaces

    Example usage:
    self.observation_space = spaces.Tuple((spaces.Discrete(2), spaces.Discrete(3)))
    """

    def __init__(self, spaces):
        """Build the product space of ``spaces``.

        Args:
            spaces: an iterable of ``Space`` instances. It is materialized into a
                tuple so that one-shot iterables (e.g. generators) are not
                exhausted by the validation loop below, and so that ``__eq__``
                does not depend on whether a list or a tuple was passed in.

        Raises:
            AssertionError: if any element is not a ``Space`` instance.
        """
        # Materialize before validating: iterating a generator here would
        # otherwise leave ``self.spaces`` empty.
        self.spaces = tuple(spaces)
        for space in self.spaces:
            assert isinstance(space, Space), "Elements of the tuple must be instances of gym.Space"
        super(Tuple, self).__init__(None, None)

    def seed(self, seed=None):
        """Seed every subspace.

        A non-negative integer ``seed`` is expanded into one independent 32-bit
        sub-seed per subspace (via ``numpy.random.SeedSequence.spawn``). Passing
        the same seed to every subspace would make identically-typed subspaces
        produce identical samples. Any other value (including ``None``) is
        forwarded unchanged to each subspace, which does its own validation.

        Returns:
            list: the value returned by each subspace's ``seed`` call, in order.
        """
        if isinstance(seed, (int, np.integer)) and seed >= 0:
            children = np.random.SeedSequence(seed).spawn(len(self.spaces))
            subseeds = [int(child.generate_state(1)[0]) for child in children]
        else:
            subseeds = [seed] * len(self.spaces)
        return [space.seed(subseed) for space, subseed in zip(self.spaces, subseeds)]

    def sample(self):
        """Return a tuple holding one sample drawn from each subspace, in order."""
        return tuple(space.sample() for space in self.spaces)

    def contains(self, x):
        """Return True if ``x`` is a member of this product space.

        A list or a non-scalar ``np.ndarray`` is promoted to a tuple first
        (iterating an ndarray walks its first axis). Membership requires the
        same length as ``self.spaces``. Each element must also be contained in
        the corresponding subspace.
        """
        if isinstance(x, list) or (isinstance(x, np.ndarray) and x.ndim > 0):
            x = tuple(x)  # Promote list / array to tuple for contains check
        return (
            isinstance(x, tuple)
            and len(x) == len(self.spaces)
            and all(space.contains(part) for space, part in zip(self.spaces, x))
        )

    def __repr__(self):
        return "Tuple(" + ", ".join(str(s) for s in self.spaces) + ")"

    def to_jsonable(self, sample_n):
        """Convert a batch of tuple samples into a JSON-friendly structure.

        The batch is transposed: the result holds one entry per subspace. Each
        entry is that subspace's ``to_jsonable`` applied to the i-th component
        of every sample in ``sample_n``.
        """
        # serialize as list-repr of tuple of vectors
        return [space.to_jsonable([sample[i] for sample in sample_n]) for i, space in enumerate(self.spaces)]

    def from_jsonable(self, sample_n):
        """Inverse of ``to_jsonable``: return a list of per-sample tuples."""
        return list(zip(*[space.from_jsonable(sample_n[i]) for i, space in enumerate(self.spaces)]))

    def __getitem__(self, index):
        return self.spaces[index]

    def __len__(self):
        return len(self.spaces)

    def __eq__(self, other):
        # ``self.spaces`` is always a tuple, so equality is independent of the
        # container type originally passed to ``__init__``.
        return isinstance(other, Tuple) and self.spaces == other.spaces