Files
Gymnasium/gym/spaces/discrete.py

50 lines
1.3 KiB
Python
Raw Normal View History

2016-04-27 08:00:58 -07:00
import numpy as np
from .space import Space
2016-04-27 08:00:58 -07:00
class Discrete(Space):
    r"""A finite set of consecutive integers.

    By default the space is :math:`\{ 0, 1, \dots, n-1 \}`. Passing ``start``
    shifts the range so that the space becomes
    :math:`\{ a, a+1, \dots, a+n-1 \}` with ``a = start``.

    Example::

        >>> Discrete(2)
        Discrete(2)
        >>> Discrete(3, start=-1)  # {-1, 0, 1}
        Discrete(3, start=-1)

    """

    def __init__(self, n, seed=None, start=0):
        """Create a discrete space with ``n`` elements beginning at ``start``.

        Args:
            n: Number of elements in the space; must be greater than zero.
                It is stored as ``int(n)``.
            seed: Forwarded unchanged to ``Space.__init__``.
            start: Smallest element of the space; must be a Python ``int``
                or a numpy integer. It is stored as ``int(start)``.

        Raises:
            AssertionError: If ``n`` is not positive or ``start`` is not an
                integer type.
        """
        assert n > 0, "n (counts) have to be positive"
        assert isinstance(start, (int, np.integer))
        self.n = int(n)
        self.start = int(start)
        # NOTE(review): the positional arguments are presumably the (scalar)
        # shape and the dtype of the space -- confirm against Space.__init__.
        super().__init__((), np.int64, seed)

    def sample(self):
        """Draw one element of the space.

        Uses ``self.np_random`` (not defined in this class; expected to be
        supplied by the ``Space`` base class) to pick an offset via
        ``randint(self.n)`` and shifts it by ``self.start``.
        """
        offset = self.np_random.randint(self.n)
        return self.start + offset

    def contains(self, x):
        """Return whether ``x`` is an element of this space.

        Accepted inputs are Python ``int`` instances (which includes ``bool``,
        as it subclasses ``int``) and numpy scalars / 0-d arrays whose dtype
        is one of the numpy integer typecodes. Anything else is rejected.
        """
        if isinstance(x, int):
            value = x
        else:
            if not isinstance(x, (np.generic, np.ndarray)):
                return False
            # Only integer-typed, zero-dimensional numpy values qualify.
            if x.dtype.char not in np.typecodes["AllInteger"] or x.shape != ():
                return False
            value = int(x)

        lower = self.start
        return lower <= value < lower + self.n

    def __repr__(self):
        """Return a constructor-like string; ``start`` is shown only if non-zero."""
        if self.start == 0:
            return "Discrete(%d)" % self.n
        return "Discrete(%d, start=%d)" % (self.n, self.start)

    def __eq__(self, other):
        """Two spaces are equal when both are ``Discrete`` with equal ``n`` and ``start``."""
        if not isinstance(other, Discrete):
            return False
        return self.n == other.n and self.start == other.start