2022-01-24 23:22:11 +01:00
|
|
|
from typing import Optional
|
|
|
|
|
2016-04-27 08:00:58 -07:00
|
|
|
import numpy as np
|
2019-01-30 22:39:55 +01:00
|
|
|
from .space import Space
|
2016-04-27 08:00:58 -07:00
|
|
|
|
2019-01-30 22:39:55 +01:00
|
|
|
|
2022-01-24 23:22:11 +01:00
|
|
|
class Discrete(Space[int]):
    r"""A discrete space in :math:`\{ 0, 1, \dots, n-1 \}`.

    A start value can be optionally specified to shift the range
    to :math:`\{ a, a+1, \dots, a+n-1 \}`.

    Example::

        >>> Discrete(2)
        >>> Discrete(3, start=-1)  # {-1, 0, 1}
    """

    def __init__(self, n: int, seed: Optional[int] = None, start: int = 0):
        """Create a space of ``n`` consecutive integers beginning at ``start``.

        Args:
            n: Number of elements in the space; must be positive. Stored as ``int``.
            seed: Optional seed, forwarded unchanged to the ``Space`` base initializer.
            start: Smallest element of the space (default 0). Must be a Python
                ``int`` or a NumPy integer; stored as ``int``.

        Raises:
            AssertionError: If ``n`` is not positive or ``start`` is not an integer.
                Both checks are ``assert`` statements, so they are skipped when
                Python runs with ``-O``.
        """
        assert n > 0, "n (counts) have to be positive"
        assert isinstance(start, (int, np.integer))
        self.n = int(n)
        self.start = int(start)
        # Empty shape and np.int64 passed to the base class -- presumably the
        # (shape, dtype) of a single scalar element; confirm against Space.
        super().__init__((), np.int64, seed)

    def sample(self) -> int:
        """Draw a random element of the space.

        ``self.np_random.randint(self.n)`` picks an offset (presumably uniform in
        ``[0, n)``; ``np_random`` is provided by the base class), which is then
        shifted by ``start``.
        """
        return self.start + self.np_random.randint(self.n)

    def contains(self, x) -> bool:
        """Return whether ``x`` is an integer in ``[start, start + n)``.

        Accepted inputs are Python ``int`` values and zero-dimensional NumPy
        integer scalars/arrays. Anything else, e.g. floats, strings or
        non-scalar arrays, yields ``False``.

        NOTE: ``bool`` is a subclass of ``int``, so ``True``/``False`` are treated
        as 1/0 here.
        """
        if isinstance(x, int):
            as_int = x
        elif isinstance(x, (np.generic, np.ndarray)) and (
            x.dtype.char in np.typecodes["AllInteger"] and x.shape == ()
        ):
            as_int = int(x)  # type: ignore
        else:
            return False
        return self.start <= as_int < self.start + self.n

    def __repr__(self) -> str:
        """Return ``Discrete(n)``, or ``Discrete(n, start=a)`` when ``start`` is nonzero."""
        if self.start != 0:
            return f"Discrete({self.n}, start={self.start})"
        return f"Discrete({self.n})"

    def __eq__(self, other) -> bool:
        """Two ``Discrete`` spaces are equal when both ``n`` and ``start`` match.

        Comparing against any non-``Discrete`` object returns ``False``.
        """
        return (
            isinstance(other, Discrete)
            and self.n == other.n
            and self.start == other.start
        )
|