from __future__ import annotations from typing import Iterable, Optional, Sequence import numpy as np from gym import logger from gym.spaces.discrete import Discrete from gym.spaces.space import Space from gym.utils import seeding class MultiDiscrete(Space[np.ndarray]): """ The multi-discrete action space consists of a series of discrete action spaces with different number of actions in each. It is useful to represent game controllers or keyboards where each key can be represented as a discrete action space. It is parametrized by passing an array of positive integers specifying number of actions for each discrete action space. Note: Some environment wrappers assume a value of 0 always represents the NOOP action. e.g. Nintendo Game Controller - Can be conceptualized as 3 discrete action spaces: 1. Arrow Keys: Discrete 5 - NOOP[0], UP[1], RIGHT[2], DOWN[3], LEFT[4] - params: min: 0, max: 4 2. Button A: Discrete 2 - NOOP[0], Pressed[1] - params: min: 0, max: 1 3. Button B: Discrete 2 - NOOP[0], Pressed[1] - params: min: 0, max: 1 It can be initialized as ``MultiDiscrete([ 5, 2, 2 ])`` """ def __init__( self, nvec: list[int], dtype=np.int64, seed: Optional[int | seeding.RandomNumberGenerator] = None, ): """ nvec: vector of counts of each categorical variable """ self.nvec = np.array(nvec, dtype=dtype, copy=True) assert (self.nvec > 0).all(), "nvec (counts) have to be positive" super().__init__(self.nvec.shape, dtype, seed) @property def shape(self) -> tuple[int, ...]: """Has stricter type than gym.Space - never None.""" return self._shape # type: ignore def sample(self) -> np.ndarray: return (self.np_random.random(self.nvec.shape) * self.nvec).astype(self.dtype) def contains(self, x) -> bool: if isinstance(x, Sequence): x = np.array(x) # Promote list to array for contains check # if nvec is uint32 and space dtype is uint32, then 0 <= x < self.nvec guarantees that x # is within correct bounds for space dtype (even though x does not have to be unsigned) return bool(x.shape == self.shape and (0 <= x).all() and (x < self.nvec).all()) def to_jsonable(self, sample_n: Iterable[np.ndarray]): return [sample.tolist() for sample in sample_n] def from_jsonable(self, sample_n): return np.array(sample_n) def __repr__(self): return f"MultiDiscrete({self.nvec})" def __getitem__(self, index): nvec = self.nvec[index] if nvec.ndim == 0: subspace = Discrete(nvec) else: subspace = MultiDiscrete(nvec, self.dtype) # type: ignore subspace.np_random.bit_generator.state = self.np_random.bit_generator.state return subspace def __len__(self): if self.nvec.ndim >= 2: logger.warn("Get length of a multi-dimensional MultiDiscrete space.") return len(self.nvec) def __eq__(self, other): return isinstance(other, MultiDiscrete) and np.all(self.nvec == other.nvec)