2019-03-25 00:47:16 +01:00
|
|
|
import numpy as np
|
2019-01-30 22:39:55 +01:00
|
|
|
from .space import Space
|
2016-04-27 08:00:58 -07:00
|
|
|
|
2019-01-30 22:39:55 +01:00
|
|
|
|
|
|
|
class Tuple(Space):
    """
    A tuple (i.e., product) of simpler spaces.

    Example usage:
    self.observation_space = spaces.Tuple((spaces.Discrete(2), spaces.Discrete(3)))
    """

    def __init__(self, spaces):
        """Build the product space from its component spaces.

        Args:
            spaces: an iterable of ``Space`` instances, one per tuple component.
                It is materialized into a ``tuple`` so that spaces built from a
                list and from a tuple compare equal in ``__eq__``, and so that a
                one-shot iterator (e.g. a generator) is not exhausted by the
                validation loop below.

        Raises:
            AssertionError: if any element is not a ``Space`` instance.
        """
        self.spaces = tuple(spaces)
        for space in self.spaces:
            assert isinstance(
                space, Space
            ), "Elements of the tuple must be instances of gym.Space"
        super(Tuple, self).__init__(None, None)

    def seed(self, seed=None):
        """Seed every component space with ``seed``.

        NOTE(review): every subspace receives the *same* seed, so two subspaces
        of the same kind may produce identical samples -- confirm intended.
        """
        for space in self.spaces:
            space.seed(seed)

    def sample(self):
        """Return a tuple holding one sample drawn from each component space."""
        return tuple(space.sample() for space in self.spaces)

    def contains(self, x):
        """Return True if ``x`` is an element of this product space.

        Lists and NumPy arrays are promoted to tuples first. ``x`` must then
        have exactly one entry per component space, and each entry must be
        contained in the corresponding space.
        """
        if isinstance(x, (list, np.ndarray)):
            x = tuple(x)  # Promote list/ndarray to tuple for the contains check
        return (
            isinstance(x, tuple)
            and len(x) == len(self.spaces)
            and all(space.contains(part) for space, part in zip(self.spaces, x))
        )

    def __repr__(self):
        return "Tuple(" + ", ".join(str(s) for s in self.spaces) + ")"

    def to_jsonable(self, sample_n):
        """Serialize a batch of samples as a list of per-component batches.

        Args:
            sample_n: a sequence of samples, each indexable by component position.

        Returns:
            A list with one entry per component space. Entry ``i`` is that
            space's ``to_jsonable`` applied to the i-th element of every sample.
        """
        # serialize as list-repr of tuple of vectors
        return [
            space.to_jsonable([sample[i] for sample in sample_n])
            for i, space in enumerate(self.spaces)
        ]

    def from_jsonable(self, sample_n):
        """Inverse of ``to_jsonable``: rebuild a list of per-sample tuples.

        Args:
            sample_n: a sequence with one serialized batch per component space.

        Returns:
            A list of tuples, one per sample, produced by zipping the
            per-component ``from_jsonable`` results together.
        """
        return list(
            zip(
                *[
                    space.from_jsonable(sample_n[i])
                    for i, space in enumerate(self.spaces)
                ]
            )
        )

    def __getitem__(self, index):
        """Return the component space at ``index``."""
        return self.spaces[index]

    def __len__(self):
        """Return the number of component spaces."""
        return len(self.spaces)

    def __eq__(self, other):
        """Two ``Tuple`` spaces are equal when their component spaces are equal element-wise."""
        return isinstance(other, Tuple) and self.spaces == other.spaces
|