Type cast in spaces families (#2491)

* Type cast for `spaces.Dict`

* Type cast for `spaces.Tuple`

* Type cast for `spaces.Discrete`

* Type cast for `spaces.MultiDiscrete`

* Type cast for `spaces.MultiBinary`
This commit is contained in:
Xuehai Pan
2021-12-16 13:45:37 +08:00
committed by GitHub
parent 180d8ddd5c
commit 18c8b988d4
5 changed files with 29 additions and 16 deletions

View File

@@ -1,5 +1,5 @@
from collections import OrderedDict from collections import OrderedDict
from collections.abc import Mapping from collections.abc import Mapping, Sequence
import numpy as np import numpy as np
from .space import Space from .space import Space
@@ -42,9 +42,15 @@ class Dict(Space, Mapping):
if spaces is None: if spaces is None:
spaces = spaces_kwargs spaces = spaces_kwargs
if isinstance(spaces, dict) and not isinstance(spaces, OrderedDict): if isinstance(spaces, dict) and not isinstance(spaces, OrderedDict):
spaces = OrderedDict(sorted(list(spaces.items()))) try:
if isinstance(spaces, list): spaces = OrderedDict(sorted(spaces.items()))
except TypeError: # raise when sort by different types of keys
spaces = OrderedDict(spaces.items())
if isinstance(spaces, Sequence):
spaces = OrderedDict(spaces) spaces = OrderedDict(spaces)
assert isinstance(spaces, OrderedDict), "spaces must be a dictionary"
self.spaces = spaces self.spaces = spaces
for space in spaces.values(): for space in spaces.values():
assert isinstance( assert isinstance(

View File

@@ -16,8 +16,9 @@ class Discrete(Space):
""" """
def __init__(self, n, seed=None, start=0): def __init__(self, n, seed=None, start=0):
assert n >= 0 and isinstance(start, (int, np.integer)) assert n > 0, "n (counts) have to be positive"
self.n = n assert isinstance(start, (int, np.integer))
self.n = int(n)
self.start = int(start) self.start = int(start)
super().__init__((), np.int64, seed) super().__init__((), np.int64, seed)

View File

@@ -1,3 +1,4 @@
from collections.abc import Sequence
import numpy as np import numpy as np
from .space import Space from .space import Space
@@ -14,9 +15,9 @@ class MultiBinary(Space):
>> self.observation_space.sample() >> self.observation_space.sample()
array([0,1,0,1,0], dtype =int8) array([0, 1, 0, 1, 0], dtype=int8)
>> self.observation_space = spaces.MultiBinary([3,2]) >> self.observation_space = spaces.MultiBinary([3, 2])
>> self.observation_space.sample() >> self.observation_space.sample()
@@ -27,18 +28,21 @@ class MultiBinary(Space):
""" """
def __init__(self, n, seed=None): def __init__(self, n, seed=None):
self.n = n if isinstance(n, (Sequence, np.ndarray)):
if type(n) in [tuple, list, np.ndarray]: self.n = input_n = tuple(int(i) for i in n)
input_n = n
else: else:
self.n = n = int(n)
input_n = (n,) input_n = (n,)
assert (np.asarray(input_n) > 0).all(), "n (counts) have to be positive"
super().__init__(input_n, np.int8, seed) super().__init__(input_n, np.int8, seed)
def sample(self): def sample(self):
return self.np_random.integers(low=0, high=2, size=self.n, dtype=self.dtype) return self.np_random.integers(low=0, high=2, size=self.n, dtype=self.dtype)
def contains(self, x): def contains(self, x):
if isinstance(x, list) or isinstance(x, tuple): if isinstance(x, Sequence):
x = np.array(x) # Promote list to array for contains check x = np.array(x) # Promote list to array for contains check
if self.shape != x.shape: if self.shape != x.shape:
return False return False

View File

@@ -1,3 +1,4 @@
from collections.abc import Sequence
import numpy as np import numpy as np
from gym import logger from gym import logger
from .space import Space from .space import Space
@@ -29,8 +30,8 @@ class MultiDiscrete(Space):
""" """
nvec: vector of counts of each categorical variable nvec: vector of counts of each categorical variable
""" """
assert (np.array(nvec) > 0).all(), "nvec (counts) have to be positive" self.nvec = np.array(nvec, dtype=dtype, copy=True)
self.nvec = np.asarray(nvec, dtype=dtype) assert (self.nvec > 0).all(), "nvec (counts) have to be positive"
super().__init__(self.nvec.shape, dtype, seed) super().__init__(self.nvec.shape, dtype, seed)
@@ -38,7 +39,7 @@ class MultiDiscrete(Space):
return (self.np_random.random(self.nvec.shape) * self.nvec).astype(self.dtype) return (self.np_random.random(self.nvec.shape) * self.nvec).astype(self.dtype)
def contains(self, x): def contains(self, x):
if isinstance(x, list): if isinstance(x, Sequence):
x = np.array(x) # Promote list to array for contains check x = np.array(x) # Promote list to array for contains check
# if nvec is uint32 and space dtype is uint32, then 0 <= x < self.nvec guarantees that x # if nvec is uint32 and space dtype is uint32, then 0 <= x < self.nvec guarantees that x
# is within correct bounds for space dtype (even though x does not have to be unsigned) # is within correct bounds for space dtype (even though x does not have to be unsigned)

View File

@@ -11,6 +11,7 @@ class Tuple(Space):
""" """
def __init__(self, spaces, seed=None): def __init__(self, spaces, seed=None):
spaces = tuple(spaces)
self.spaces = spaces self.spaces = spaces
for space in spaces: for space in spaces:
assert isinstance( assert isinstance(
@@ -53,8 +54,8 @@ class Tuple(Space):
return tuple(space.sample() for space in self.spaces) return tuple(space.sample() for space in self.spaces)
def contains(self, x): def contains(self, x):
if isinstance(x, list): if isinstance(x, (list, np.ndarray)):
x = tuple(x) # Promote list to tuple for contains check x = tuple(x) # Promote list and ndarray to tuple for contains check
return ( return (
isinstance(x, tuple) isinstance(x, tuple)
and len(x) == len(self.spaces) and len(x) == len(self.spaces)