diff --git a/gymnasium/__init__.py b/gymnasium/__init__.py index a14a22489..b239dae22 100644 --- a/gymnasium/__init__.py +++ b/gymnasium/__init__.py @@ -1,5 +1,6 @@ """Root `__init__` of the gymnasium module setting the `__all__` of gymnasium modules.""" # isort: skip_file +# pyright: reportUnsupportedDunderAll=false from gymnasium.core import ( Env, @@ -17,7 +18,9 @@ from gymnasium.envs.registration import ( pprint_registry, make_vec, ) -from gymnasium import envs, spaces, utils, vector, wrappers, error, logger + +# necessary for `envs.__init__` which registers all gymnasium environments and loads plugins +from gymnasium import envs __all__ = [ @@ -37,6 +40,7 @@ __all__ = [ "pprint_registry", # module folders "envs", + "experimental", "spaces", "utils", "vector", diff --git a/gymnasium/experimental/__init__.py b/gymnasium/experimental/__init__.py index 4d678685f..e356dbadd 100644 --- a/gymnasium/experimental/__init__.py +++ b/gymnasium/experimental/__init__.py @@ -1,7 +1,7 @@ """Root __init__ of the gym experimental wrappers.""" -from gymnasium.experimental import functional +from gymnasium.experimental import functional, wrappers from gymnasium.experimental.functional import FuncEnv from gymnasium.experimental.vector.async_vector_env import AsyncVectorEnv from gymnasium.experimental.vector.sync_vector_env import SyncVectorEnv @@ -17,4 +17,6 @@ __all__ = [ "VectorWrapper", "SyncVectorEnv", "AsyncVectorEnv", + # wrappers + "wrappers", ] diff --git a/gymnasium/experimental/functional_jax_env.py b/gymnasium/experimental/functional_jax_env.py index 05d8a13c8..059dbdd94 100644 --- a/gymnasium/experimental/functional_jax_env.py +++ b/gymnasium/experimental/functional_jax_env.py @@ -11,7 +11,7 @@ import numpy as np import gymnasium as gym from gymnasium.envs.registration import EnvSpec from gymnasium.experimental.functional import ActType, FuncEnv, StateType -from gymnasium.experimental.wrappers.conversion.jax_to_numpy import jax_to_numpy +from 
gymnasium.experimental.wrappers.jax_to_numpy import jax_to_numpy from gymnasium.utils import seeding from gymnasium.vector.utils import batch_space diff --git a/gymnasium/experimental/wrappers/__init__.py b/gymnasium/experimental/wrappers/__init__.py index 42c2b204c..8b8ace0f8 100644 --- a/gymnasium/experimental/wrappers/__init__.py +++ b/gymnasium/experimental/wrappers/__init__.py @@ -1,53 +1,11 @@ -"""Experimental Wrappers.""" -# isort: skip_file +"""`__init__` for experimental wrappers, to avoid loading the wrappers if unnecessary, we can hack python.""" +# pyright: reportUnsupportedDunderAll=false -from gymnasium.experimental.wrappers.lambda_action import ( - LambdaActionV0, - ClipActionV0, - RescaleActionV0, -) -from gymnasium.experimental.wrappers.lambda_observations import ( - LambdaObservationV0, - FilterObservationV0, - FlattenObservationV0, - GrayscaleObservationV0, - ResizeObservationV0, - ReshapeObservationV0, - RescaleObservationV0, - DtypeObservationV0, - PixelObservationV0, - NormalizeObservationV0, -) -from gymnasium.experimental.wrappers.lambda_reward import ( - ClipRewardV0, - LambdaRewardV0, - NormalizeRewardV0, -) -from gymnasium.experimental.wrappers.stateful_action import StickyActionV0 -from gymnasium.experimental.wrappers.stateful_observation import ( - TimeAwareObservationV0, - DelayObservationV0, - FrameStackObservationV0, -) -from gymnasium.experimental.wrappers.atari_preprocessing import AtariPreprocessingV0 -from gymnasium.experimental.wrappers.common import ( - PassiveEnvCheckerV0, - OrderEnforcingV0, - AutoresetV0, - RecordEpisodeStatisticsV0, -) -from gymnasium.experimental.wrappers.rendering import ( - RenderCollectionV0, - RecordVideoV0, - HumanRenderingV0, -) +import importlib -from gymnasium.experimental.wrappers.vector import ( - VectorRecordEpisodeStatistics, - VectorListInfo, -) __all__ = [ + "vector", # --- Observation wrappers --- "LambdaObservationV0", "FilterObservationV0", @@ -82,7 +40,80 @@ __all__ = [ 
"RenderCollectionV0", "RecordVideoV0", "HumanRenderingV0", - # --- Vector --- - "VectorRecordEpisodeStatistics", - "VectorListInfo", + # --- Conversion --- + "JaxToNumpyV0", + "JaxToTorchV0", + "NumpyToTorchV0", ] + + +_wrapper_to_class = { + # lambda_action.py + "LambdaActionV0": "lambda_action", + "ClipActionV0": "lambda_action", + "RescaleActionV0": "lambda_action", + # lambda_observations.py + "LambdaObservationV0": "lambda_observations", + "FilterObservationV0": "lambda_observations", + "FlattenObservationV0": "lambda_observations", + "GrayscaleObservationV0": "lambda_observations", + "ResizeObservationV0": "lambda_observations", + "ReshapeObservationV0": "lambda_observations", + "RescaleObservationV0": "lambda_observations", + "DtypeObservationV0": "lambda_observations", + "PixelObservationV0": "lambda_observations", + "NormalizeObservationV0": "lambda_observations", + # lambda_reward.py + "ClipRewardV0": "lambda_reward", + "LambdaRewardV0": "lambda_reward", + "NormalizeRewardV0": "lambda_reward", + # stateful_action + "StickyActionV0": "stateful_action", + # stateful_observation + "TimeAwareObservationV0": "stateful_observation", + "DelayObservationV0": "stateful_observation", + "FrameStackObservationV0": "stateful_observation", + # atari_preprocessing + "AtariPreprocessingV0": "atari_preprocessing", + # common + "PassiveEnvCheckerV0": "common", + "OrderEnforcingV0": "common", + "AutoresetV0": "common", + "RecordEpisodeStatisticsV0": "common", + # rendering + "RenderCollectionV0": "rendering", + "RecordVideoV0": "rendering", + "HumanRenderingV0": "rendering", + # jax_to_numpy + "JaxToNumpyV0": "jax_to_numpy", + # "jax_to_numpy": "jax_to_numpy", + # "numpy_to_jax": "jax_to_numpy", + # jax_to_torch + "JaxToTorchV0": "jax_to_torch", + # "jax_to_torch": "jax_to_torch", + # "torch_to_jax": "jax_to_torch", + # numpy_to_torch + "NumpyToTorchV0": "numpy_to_torch", + # "torch_to_numpy": "numpy_to_torch", + # "numpy_to_torch": "numpy_to_torch", +} + + +def 
__getattr__(name: str): + """To avoid having to load all wrappers on `import gymnasium` with all of their extra modules. + + This optimises the loading of gymnasium. + + Args: + name: The name of a wrapper to load + + Returns: + Wrapper + """ + if name in _wrapper_to_class: + import_stmt = f"gymnasium.experimental.wrappers.{_wrapper_to_class[name]}" + module = importlib.import_module(import_stmt) + return getattr(module, name) + # add helpful error message if version number has changed + else: + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/gymnasium/experimental/wrappers/atari_preprocessing.py b/gymnasium/experimental/wrappers/atari_preprocessing.py index ae0c2c02d..5564bc0f7 100644 --- a/gymnasium/experimental/wrappers/atari_preprocessing.py +++ b/gymnasium/experimental/wrappers/atari_preprocessing.py @@ -8,7 +8,9 @@ from gymnasium.spaces import Box try: import cv2 except ImportError: - cv2 = None + raise gym.error.DependencyNotInstalled( + "opencv-python package not installed, run `pip install gymnasium[other]` to get dependencies for atari" + ) class AtariPreprocessingV0(gym.Wrapper, gym.utils.RecordConstructorArgs): @@ -72,10 +74,6 @@ class AtariPreprocessingV0(gym.Wrapper, gym.utils.RecordConstructorArgs): ) gym.Wrapper.__init__(self, env) - if cv2 is None: - raise gym.error.DependencyNotInstalled( - "opencv-python package not installed, run `pip install gymnasium[other]` to get dependencies for atari" - ) assert frame_skip > 0 assert screen_size > 0 assert noop_max >= 0 diff --git a/gymnasium/experimental/wrappers/conversion/__init__.py b/gymnasium/experimental/wrappers/conversion/__init__.py deleted file mode 100644 index bd1b3b471..000000000 --- a/gymnasium/experimental/wrappers/conversion/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""This is deliberately empty to avoid introducing redundant imports -- import each submodule individually.""" diff --git a/gymnasium/experimental/wrappers/conversion/jax_to_torch.py 
b/gymnasium/experimental/wrappers/conversion/jax_to_torch.py deleted file mode 100644 index c35c92ed9..000000000 --- a/gymnasium/experimental/wrappers/conversion/jax_to_torch.py +++ /dev/null @@ -1,205 +0,0 @@ -# This wrapper will convert torch inputs for the actions and observations to Jax arrays -# for an underlying Jax environment then convert the return observations from Jax arrays -# back to torch tensors. -# -# Functionality for converting between torch and jax types originally copied from -# https://github.com/google/brax/blob/9d6b7ced2a13da0d074b5e9fbd3aad8311e26997/brax/io/torch.py -# Under the Apache 2.0 license. Copyright is held by the authors - -"""Helper functions and wrapper class for converting between PyTorch and Jax.""" -from __future__ import annotations - -import functools -import numbers -from collections import abc -from typing import Any, Iterable, Mapping, SupportsFloat, Union - -import gymnasium as gym -from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType -from gymnasium.error import DependencyNotInstalled -from gymnasium.experimental.wrappers.conversion.jax_to_numpy import jax_to_numpy - - -try: - import jax.numpy as jnp - from jax import dlpack as jax_dlpack -except ImportError: - jnp, jax_dlpack = None, None - -try: - import torch - from torch.utils import dlpack as torch_dlpack - - Device = Union[str, torch.device] -except ImportError: - torch, torch_dlpack, Device = None, None, None - - -@functools.singledispatch -def torch_to_jax(value: Any) -> Any: - """Converts a PyTorch Tensor into a Jax DeviceArray.""" - if torch is None: - raise DependencyNotInstalled( - "Torch is not installed therefore cannot call `torch_to_jax`, run `pip install torch`" - ) - elif jnp is None: - raise DependencyNotInstalled( - "Jax is not installed therefore cannot call `torch_to_jax`, run `pip install gymnasium[jax]`" - ) - else: - raise Exception( - f"No known conversion for Torch type ({type(value)}) to Jax registered. 
Report as issue on github." - ) - - -if torch is not None and jnp is not None: - - @torch_to_jax.register(numbers.Number) - def _number_torch_to_jax(value: numbers.Number) -> Any: - """Convert a python number (int, float, complex) to a jax array.""" - assert jnp is not None - return jnp.array(value) - - @torch_to_jax.register(torch.Tensor) - def _tensor_torch_to_jax(value: torch.Tensor) -> jnp.DeviceArray: - """Converts a PyTorch Tensor into a Jax DeviceArray.""" - assert torch_dlpack is not None and jax_dlpack is not None - tensor = torch_dlpack.to_dlpack( # pyright: ignore[reportPrivateImportUsage] - value - ) - tensor = jax_dlpack.from_dlpack( # pyright: ignore[reportPrivateImportUsage] - tensor - ) - return tensor - - @torch_to_jax.register(abc.Mapping) - def _mapping_torch_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]: - """Converts a mapping of PyTorch Tensors into a Dictionary of Jax DeviceArrays.""" - return type(value)(**{k: torch_to_jax(v) for k, v in value.items()}) - - @torch_to_jax.register(abc.Iterable) - def _iterable_torch_to_jax(value: Iterable[Any]) -> Iterable[Any]: - """Converts an Iterable from PyTorch Tensors to an iterable of Jax DeviceArrays.""" - return type(value)(torch_to_jax(v) for v in value) - - -@functools.singledispatch -def jax_to_torch(value: Any, device: Device | None = None) -> Any: - """Converts a Jax DeviceArray into a PyTorch Tensor.""" - if torch is None: - raise DependencyNotInstalled( - "Torch is not installed therefore cannot call `jax_to_torch`, run `pip install torch`" - ) - elif jnp is None: - raise DependencyNotInstalled( - "Jax is not installed therefore cannot call `jax_to_torch`, run `pip install gymnasium[jax]`" - ) - else: - raise Exception( - f"No known conversion for Jax type ({type(value)}) to PyTorch registered. Report as issue on github." 
- ) - - -if torch is not None and jnp is not None: - - @jax_to_torch.register(jnp.DeviceArray) - def _devicearray_jax_to_torch( - value: jnp.DeviceArray, device: Device | None = None - ) -> torch.Tensor: - """Converts a Jax DeviceArray into a PyTorch Tensor.""" - assert jax_dlpack is not None and torch_dlpack is not None - dlpack = jax_dlpack.to_dlpack( # pyright: ignore[reportPrivateImportUsage] - value - ) - tensor = torch_dlpack.from_dlpack(dlpack) - if device: - return tensor.to(device=device) - return tensor - - @jax_to_torch.register(abc.Mapping) - def _jax_mapping_to_torch( - value: Mapping[str, Any], device: Device | None = None - ) -> Mapping[str, Any]: - """Converts a mapping of Jax DeviceArrays into a Dictionary of PyTorch Tensors.""" - return type(value)(**{k: jax_to_torch(v, device) for k, v in value.items()}) - - @jax_to_torch.register(abc.Iterable) - def _jax_iterable_to_torch( - value: Iterable[Any], device: Device | None = None - ) -> Iterable[Any]: - """Converts an Iterable from Jax DeviceArrays to an iterable of PyTorch Tensors.""" - return type(value)(jax_to_torch(v, device) for v in value) - - -class JaxToTorchV0(gym.Wrapper, gym.utils.RecordConstructorArgs): - """Wraps a jax-based environment so that it can be interacted with through PyTorch Tensors. - - Actions must be provided as PyTorch Tensors and observations will be returned as PyTorch Tensors. - - Note: - For ``rendered`` this is returned as a NumPy array not a pytorch Tensor. - """ - - def __init__(self, env: gym.Env, device: Device | None = None): - """Wrapper class to change inputs and outputs of environment to PyTorch tensors. 
- - Args: - env: The Jax-based environment to wrap - device: The device the torch Tensors should be moved to - """ - if torch is None: - raise DependencyNotInstalled( - "torch is not installed, run `pip install torch`" - ) - elif jnp is None: - raise DependencyNotInstalled( - "jax is not installed, run `pip install gymnasium[jax]`" - ) - - gym.utils.RecordConstructorArgs.__init__(self, device=device) - gym.Wrapper.__init__(self, env) - - self.device: Device | None = device - - def step( - self, action: WrapperActType - ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]: - """Performs the given action within the environment. - - Args: - action: The action to perform as a PyTorch Tensor - - Returns: - The next observation, reward, termination, truncation, and extra info - """ - jax_action = torch_to_jax(action) - obs, reward, terminated, truncated, info = self.env.step(jax_action) - - return ( - jax_to_torch(obs, self.device), - float(reward), - bool(terminated), - bool(truncated), - jax_to_torch(info, self.device), - ) - - def reset( - self, *, seed: int | None = None, options: dict[str, Any] | None = None - ) -> tuple[WrapperObsType, dict[str, Any]]: - """Resets the environment returning PyTorch-based observation and info. - - Args: - seed: The seed for resetting the environment - options: The options for resetting the environment, these are converted to jax arrays. 
- - Returns: - PyTorch-based observations and info - """ - if options: - options = torch_to_jax(options) - - return jax_to_torch(self.env.reset(seed=seed, options=options), self.device) - - def render(self) -> RenderFrame | list[RenderFrame] | None: - """Returns the rendered frames as a NumPy array.""" - return jax_to_numpy(self.env.render()) diff --git a/gymnasium/experimental/wrappers/conversion/jax_to_numpy.py b/gymnasium/experimental/wrappers/jax_to_numpy.py similarity index 55% rename from gymnasium/experimental/wrappers/conversion/jax_to_numpy.py rename to gymnasium/experimental/wrappers/jax_to_numpy.py index 7b9a06bad..793f38407 100644 --- a/gymnasium/experimental/wrappers/conversion/jax_to_numpy.py +++ b/gymnasium/experimental/wrappers/jax_to_numpy.py @@ -16,80 +16,73 @@ from gymnasium.error import DependencyNotInstalled try: import jax.numpy as jnp except ImportError: - # We handle the error internal to the relative functions - jnp = None + raise DependencyNotInstalled( + "Jax is not installed therefore cannot call `numpy_to_jax`, run `pip install gymnasium[jax]`" + ) + +__all__ = ["jax_to_numpy", "numpy_to_jax", "JaxToNumpyV0"] @functools.singledispatch def numpy_to_jax(value: Any) -> Any: """Converts a value to a Jax DeviceArray.""" - if jnp is None: - raise DependencyNotInstalled( - "Jax is not installed therefore cannot call `numpy_to_jax`, run `pip install gymnasium[jax]`" - ) - else: - raise Exception( - f"No known conversion for Numpy type ({type(value)}) to Jax registered. Report as issue on github." - ) + raise Exception( + f"No known conversion for Numpy type ({type(value)}) to Jax registered. Report as issue on github." + ) -if jnp is not None: +@numpy_to_jax.register(numbers.Number) +@numpy_to_jax.register(np.ndarray) +def _number_ndarray_numpy_to_jax( + value: np.ndarray | numbers.Number, +) -> jnp.DeviceArray: + """Converts a numpy array or number (int, float, etc.) 
to a Jax DeviceArray.""" + assert jnp is not None + return jnp.array(value) - @numpy_to_jax.register(numbers.Number) - @numpy_to_jax.register(np.ndarray) - def _number_ndarray_numpy_to_jax( - value: np.ndarray | numbers.Number, - ) -> jnp.DeviceArray: - """Converts a numpy array or number (int, float, etc.) to a Jax DeviceArray.""" - assert jnp is not None - return jnp.array(value) - @numpy_to_jax.register(abc.Mapping) - def _mapping_numpy_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]: - """Converts a dictionary of numpy arrays to a mapping of Jax DeviceArrays.""" - return type(value)(**{k: numpy_to_jax(v) for k, v in value.items()}) +@numpy_to_jax.register(abc.Mapping) +def _mapping_numpy_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]: + """Converts a dictionary of numpy arrays to a mapping of Jax DeviceArrays.""" + return type(value)(**{k: numpy_to_jax(v) for k, v in value.items()}) - @numpy_to_jax.register(abc.Iterable) - def _iterable_numpy_to_jax( - value: Iterable[np.ndarray | Any], - ) -> Iterable[jnp.DeviceArray | Any]: - """Converts an Iterable from Numpy Arrays to an iterable of Jax DeviceArrays.""" - return type(value)(numpy_to_jax(v) for v in value) + +@numpy_to_jax.register(abc.Iterable) +def _iterable_numpy_to_jax( + value: Iterable[np.ndarray | Any], +) -> Iterable[jnp.DeviceArray | Any]: + """Converts an Iterable from Numpy Arrays to an iterable of Jax DeviceArrays.""" + return type(value)(numpy_to_jax(v) for v in value) @functools.singledispatch def jax_to_numpy(value: Any) -> Any: """Converts a value to a numpy array.""" - if jnp is None: - raise DependencyNotInstalled( - "Jax is not installed therefore cannot call `jax_to_numpy`, run `pip install gymnasium[jax]`" - ) - else: - raise Exception( - f"No known conversion for Jax type ({type(value)}) to NumPy registered. Report as issue on github." - ) + raise Exception( + f"No known conversion for Jax type ({type(value)}) to NumPy registered. Report as issue on github." 
+ ) -if jnp is not None: +@jax_to_numpy.register(jnp.DeviceArray) +def _devicearray_jax_to_numpy(value: jnp.DeviceArray) -> np.ndarray: + """Converts a Jax DeviceArray to a numpy array.""" + return np.array(value) - @jax_to_numpy.register(jnp.DeviceArray) - def _devicearray_jax_to_numpy(value: jnp.DeviceArray) -> np.ndarray: - """Converts a Jax DeviceArray to a numpy array.""" - return np.array(value) - @jax_to_numpy.register(abc.Mapping) - def _mapping_jax_to_numpy( - value: Mapping[str, jnp.DeviceArray | Any] - ) -> Mapping[str, np.ndarray | Any]: - """Converts a dictionary of Jax DeviceArrays to a mapping of numpy arrays.""" - return type(value)(**{k: jax_to_numpy(v) for k, v in value.items()}) +@jax_to_numpy.register(abc.Mapping) +def _mapping_jax_to_numpy( + value: Mapping[str, jnp.DeviceArray | Any] +) -> Mapping[str, np.ndarray | Any]: + """Converts a dictionary of Jax DeviceArrays to a mapping of numpy arrays.""" + return type(value)(**{k: jax_to_numpy(v) for k, v in value.items()}) - @jax_to_numpy.register(abc.Iterable) - def _iterable_jax_to_numpy( - value: Iterable[np.ndarray | Any], - ) -> Iterable[jnp.DeviceArray | Any]: - """Converts an Iterable from Numpy arrays to an iterable of Jax DeviceArrays.""" - return type(value)(jax_to_numpy(v) for v in value) + +@jax_to_numpy.register(abc.Iterable) +def _iterable_jax_to_numpy( + value: Iterable[np.ndarray | Any], +) -> Iterable[jnp.DeviceArray | Any]: + """Converts an Iterable from Jax DeviceArrays to an iterable of Numpy arrays.""" + return type(value)(jax_to_numpy(v) for v in value) class JaxToNumpyV0( diff --git a/gymnasium/experimental/wrappers/jax_to_torch.py b/gymnasium/experimental/wrappers/jax_to_torch.py new file mode 100644 index 000000000..8bfe89316 --- /dev/null +++ b/gymnasium/experimental/wrappers/jax_to_torch.py @@ -0,0 +1,178 @@ +# This wrapper will convert torch inputs for the actions and observations to Jax arrays +# for an underlying Jax environment then convert the return observations 
from Jax arrays +# back to torch tensors. +# +# Functionality for converting between torch and jax types originally copied from +# https://github.com/google/brax/blob/9d6b7ced2a13da0d074b5e9fbd3aad8311e26997/brax/io/torch.py +# Under the Apache 2.0 license. Copyright is held by the authors + +"""Helper functions and wrapper class for converting between PyTorch and Jax.""" +from __future__ import annotations + +import functools +import numbers +from collections import abc +from typing import Any, Iterable, Mapping, SupportsFloat, Union + +import gymnasium as gym +from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType +from gymnasium.error import DependencyNotInstalled +from gymnasium.experimental.wrappers.jax_to_numpy import jax_to_numpy + + +try: + import jax.numpy as jnp + from jax import dlpack as jax_dlpack +except ImportError: + raise DependencyNotInstalled( + "Jax is not installed therefore cannot call `torch_to_jax`, run `pip install gymnasium[jax]`" + ) + +try: + import torch + from torch.utils import dlpack as torch_dlpack + + Device = Union[str, torch.device] +except ImportError: + raise DependencyNotInstalled( + "Torch is not installed therefore cannot call `torch_to_jax`, run `pip install torch`" + ) + + +__all__ = ["jax_to_torch", "torch_to_jax", "JaxToTorchV0"] + + +@functools.singledispatch +def torch_to_jax(value: Any) -> Any: + """Converts a PyTorch Tensor into a Jax DeviceArray.""" + raise Exception( + f"No known conversion for Torch type ({type(value)}) to Jax registered. Report as issue on github." 
+ ) + + +@torch_to_jax.register(numbers.Number) +def _number_torch_to_jax(value: numbers.Number) -> Any: + """Convert a python number (int, float, complex) to a jax array.""" + return jnp.array(value) + + +@torch_to_jax.register(torch.Tensor) +def _tensor_torch_to_jax(value: torch.Tensor) -> jnp.DeviceArray: + """Converts a PyTorch Tensor into a Jax DeviceArray.""" + tensor = torch_dlpack.to_dlpack(value) # pyright: ignore[reportPrivateImportUsage] + tensor = jax_dlpack.from_dlpack(tensor) # pyright: ignore[reportPrivateImportUsage] + return tensor + + +@torch_to_jax.register(abc.Mapping) +def _mapping_torch_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]: + """Converts a mapping of PyTorch Tensors into a Dictionary of Jax DeviceArrays.""" + return type(value)(**{k: torch_to_jax(v) for k, v in value.items()}) + + +@torch_to_jax.register(abc.Iterable) +def _iterable_torch_to_jax(value: Iterable[Any]) -> Iterable[Any]: + """Converts an Iterable from PyTorch Tensors to an iterable of Jax DeviceArrays.""" + return type(value)(torch_to_jax(v) for v in value) + + +@functools.singledispatch +def jax_to_torch(value: Any, device: Device | None = None) -> Any: + """Converts a Jax DeviceArray into a PyTorch Tensor.""" + raise Exception( + f"No known conversion for Jax type ({type(value)}) to PyTorch registered. Report as issue on github." 
+ ) + + +@jax_to_torch.register(jnp.DeviceArray) +def _devicearray_jax_to_torch( + value: jnp.DeviceArray, device: Device | None = None +) -> torch.Tensor: + """Converts a Jax DeviceArray into a PyTorch Tensor.""" + assert jax_dlpack is not None and torch_dlpack is not None + dlpack = jax_dlpack.to_dlpack(value) # pyright: ignore[reportPrivateImportUsage] + tensor = torch_dlpack.from_dlpack(dlpack) + if device: + return tensor.to(device=device) + return tensor + + +@jax_to_torch.register(abc.Mapping) +def _jax_mapping_to_torch( + value: Mapping[str, Any], device: Device | None = None +) -> Mapping[str, Any]: + """Converts a mapping of Jax DeviceArrays into a Dictionary of PyTorch Tensors.""" + return type(value)(**{k: jax_to_torch(v, device) for k, v in value.items()}) + + +@jax_to_torch.register(abc.Iterable) +def _jax_iterable_to_torch( + value: Iterable[Any], device: Device | None = None +) -> Iterable[Any]: + """Converts an Iterable from Jax DeviceArrays to an iterable of PyTorch Tensors.""" + return type(value)(jax_to_torch(v, device) for v in value) + + +class JaxToTorchV0(gym.Wrapper, gym.utils.RecordConstructorArgs): + """Wraps a jax-based environment so that it can be interacted with through PyTorch Tensors. + + Actions must be provided as PyTorch Tensors and observations will be returned as PyTorch Tensors. + + Note: + For ``rendered`` this is returned as a NumPy array not a pytorch Tensor. + """ + + def __init__(self, env: gym.Env, device: Device | None = None): + """Wrapper class to change inputs and outputs of environment to PyTorch tensors. + + Args: + env: The Jax-based environment to wrap + device: The device the torch Tensors should be moved to + """ + gym.utils.RecordConstructorArgs.__init__(self, device=device) + gym.Wrapper.__init__(self, env) + + self.device: Device | None = device + + def step( + self, action: WrapperActType + ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]: + """Performs the given action within the environment. 
+ + Args: + action: The action to perform as a PyTorch Tensor + + Returns: + The next observation, reward, termination, truncation, and extra info + """ + jax_action = torch_to_jax(action) + obs, reward, terminated, truncated, info = self.env.step(jax_action) + + return ( + jax_to_torch(obs, self.device), + float(reward), + bool(terminated), + bool(truncated), + jax_to_torch(info, self.device), + ) + + def reset( + self, *, seed: int | None = None, options: dict[str, Any] | None = None + ) -> tuple[WrapperObsType, dict[str, Any]]: + """Resets the environment returning PyTorch-based observation and info. + + Args: + seed: The seed for resetting the environment + options: The options for resetting the environment, these are converted to jax arrays. + + Returns: + PyTorch-based observations and info + """ + if options: + options = torch_to_jax(options) + + return jax_to_torch(self.env.reset(seed=seed, options=options), self.device) + + def render(self) -> RenderFrame | list[RenderFrame] | None: + """Returns the rendered frames as a NumPy array.""" + return jax_to_numpy(self.env.render()) diff --git a/gymnasium/experimental/wrappers/conversion/numpy_to_torch.py b/gymnasium/experimental/wrappers/numpy_to_torch.py similarity index 51% rename from gymnasium/experimental/wrappers/conversion/numpy_to_torch.py rename to gymnasium/experimental/wrappers/numpy_to_torch.py index f91b41ce8..5c08407d9 100644 --- a/gymnasium/experimental/wrappers/conversion/numpy_to_torch.py +++ b/gymnasium/experimental/wrappers/numpy_to_torch.py @@ -18,80 +18,73 @@ try: Device = Union[str, torch.device] except ImportError: - torch, Device = None, None + raise DependencyNotInstalled( + "Torch is not installed therefore cannot call `torch_to_numpy`, run `pip install torch`" + ) + + +__all__ = ["torch_to_numpy", "numpy_to_torch", "NumpyToTorchV0"] @functools.singledispatch def torch_to_numpy(value: Any) -> Any: """Converts a PyTorch Tensor into a NumPy Array.""" - if torch is None: - raise 
DependencyNotInstalled( - "Torch is not installed therefore cannot call `torch_to_numpy`, run `pip install torch`" - ) - else: - raise Exception( - f"No known conversion for Torch type ({type(value)}) to NumPy registered. Report as issue on github." - ) + raise Exception( + f"No known conversion for Torch type ({type(value)}) to NumPy registered. Report as issue on github." + ) -if torch is not None: +@torch_to_numpy.register(numbers.Number) +@torch_to_numpy.register(torch.Tensor) +def _number_torch_to_numpy(value: numbers.Number | torch.Tensor) -> Any: + """Convert a python number (int, float, complex) and torch.Tensor to a numpy array.""" + return np.array(value) - @torch_to_numpy.register(numbers.Number) - @torch_to_numpy.register(torch.Tensor) - def _number_torch_to_numpy(value: numbers.Number | torch.Tensor) -> Any: - """Convert a python number (int, float, complex) and torch.Tensor to a numpy array.""" - return np.array(value) - @torch_to_numpy.register(abc.Mapping) - def _mapping_torch_to_numpy(value: Mapping[str, Any]) -> Mapping[str, Any]: - """Converts a mapping of PyTorch Tensors into a Dictionary of Jax DeviceArrays.""" - return type(value)(**{k: torch_to_numpy(v) for k, v in value.items()}) +@torch_to_numpy.register(abc.Mapping) +def _mapping_torch_to_numpy(value: Mapping[str, Any]) -> Mapping[str, Any]: + """Converts a mapping of PyTorch Tensors into a mapping of NumPy arrays.""" + return type(value)(**{k: torch_to_numpy(v) for k, v in value.items()}) - @torch_to_numpy.register(abc.Iterable) - def _iterable_torch_to_numpy(value: Iterable[Any]) -> Iterable[Any]: - """Converts an Iterable from PyTorch Tensors to an iterable of Jax DeviceArrays.""" - return type(value)(torch_to_numpy(v) for v in value) + +@torch_to_numpy.register(abc.Iterable) +def _iterable_torch_to_numpy(value: Iterable[Any]) -> Iterable[Any]: + """Converts an Iterable from PyTorch Tensors to an iterable of NumPy arrays.""" + return type(value)(torch_to_numpy(v) for v in 
value) @functools.singledispatch def numpy_to_torch(value: Any, device: Device | None = None) -> Any: """Converts a Jax DeviceArray into a PyTorch Tensor.""" - if torch is None: - raise DependencyNotInstalled( - "Torch is not installed therefore cannot call `numpy_to_torch`, run `pip install torch`" - ) - else: - raise Exception( - f"No known conversion for NumPy type ({type(value)}) to PyTorch registered. Report as issue on github." - ) + raise Exception( + f"No known conversion for NumPy type ({type(value)}) to PyTorch registered. Report as issue on github." + ) -if torch is not None: +@numpy_to_torch.register(np.ndarray) +def _numpy_to_torch(value: np.ndarray, device: Device | None = None) -> torch.Tensor: + """Converts a NumPy array into a PyTorch Tensor.""" + assert torch is not None + tensor = torch.tensor(value) + if device: + return tensor.to(device=device) + return tensor - @numpy_to_torch.register(np.ndarray) - def _numpy_to_torch( - value: np.ndarray, device: Device | None = None - ) -> torch.Tensor: - """Converts a Jax DeviceArray into a PyTorch Tensor.""" - assert torch is not None - tensor = torch.tensor(value) - if device: - return tensor.to(device=device) - return tensor - @numpy_to_torch.register(abc.Mapping) - def _numpy_mapping_to_torch( - value: Mapping[str, Any], device: Device | None = None - ) -> Mapping[str, Any]: - """Converts a mapping of Jax DeviceArrays into a Dictionary of PyTorch Tensors.""" - return type(value)(**{k: numpy_to_torch(v, device) for k, v in value.items()}) +@numpy_to_torch.register(abc.Mapping) +def _numpy_mapping_to_torch( + value: Mapping[str, Any], device: Device | None = None +) -> Mapping[str, Any]: + """Converts a mapping of NumPy arrays into a mapping of PyTorch Tensors.""" + return type(value)(**{k: numpy_to_torch(v, device) for k, v in value.items()}) - @numpy_to_torch.register(abc.Iterable) - def _numpy_iterable_to_torch( - value: Iterable[Any], device: Device | None = None - ) -> Iterable[Any]: - 
"""Converts an Iterable from Jax DeviceArrays to an iterable of PyTorch Tensors.""" - return type(value)(numpy_to_torch(v, device) for v in value) + +@numpy_to_torch.register(abc.Iterable) +def _numpy_iterable_to_torch( + value: Iterable[Any], device: Device | None = None +) -> Iterable[Any]: + """Converts an Iterable from NumPy arrays to an iterable of PyTorch Tensors.""" + return type(value)(numpy_to_torch(v, device) for v in value) class NumpyToTorchV0(gym.Wrapper, gym.utils.RecordConstructorArgs): @@ -110,11 +103,6 @@ class NumpyToTorchV0(gym.Wrapper, gym.utils.RecordConstructorArgs): env: The Jax-based environment to wrap device: The device the torch Tensors should be moved to """ - if torch is None: - raise DependencyNotInstalled( - "torch is not installed, run `pip install torch`" - ) - gym.utils.RecordConstructorArgs.__init__(self, device=device) gym.Wrapper.__init__(self, env) diff --git a/tests/experimental/wrappers/test_init_shorten_import.py b/tests/experimental/wrappers/test_init_shorten_import.py new file mode 100644 index 000000000..09d9d30f2 --- /dev/null +++ b/tests/experimental/wrappers/test_init_shorten_import.py @@ -0,0 +1,23 @@ +"""Tests that all shortened imports for wrappers work.""" + +import pytest + +import gymnasium +from gymnasium.experimental.wrappers import ( + _wrapper_to_class, # pyright: ignore[reportPrivateUsage] +) +from gymnasium.experimental.wrappers import __all__ + + +def test_all_wrapper_shorten(): + """Test that all wrappers in `__all__` are contained within the `_wrapper_to_class` conversion.""" + all_wrappers = set(__all__) + all_wrappers.remove("vector") + assert all_wrappers == set(_wrapper_to_class.keys()) + + +@pytest.mark.parametrize("wrapper_name", __all__) +def test_all_wrappers_shortened(wrapper_name): + """Check that each element of the `__all__` wrappers can be loaded.""" + if wrapper_name != "vector": + assert getattr(gymnasium.experimental.wrappers, wrapper_name) is not None diff --git 
a/tests/experimental/wrappers/test_jax_to_numpy.py b/tests/experimental/wrappers/test_jax_to_numpy.py index 1c9815385..b32c0dab1 100644 --- a/tests/experimental/wrappers/test_jax_to_numpy.py +++ b/tests/experimental/wrappers/test_jax_to_numpy.py @@ -4,7 +4,7 @@ import jax.numpy as jnp import numpy as np import pytest -from gymnasium.experimental.wrappers.conversion.jax_to_numpy import ( +from gymnasium.experimental.wrappers.jax_to_numpy import ( JaxToNumpyV0, jax_to_numpy, numpy_to_jax, diff --git a/tests/experimental/wrappers/test_jax_to_torch.py b/tests/experimental/wrappers/test_jax_to_torch.py index d4899c0fc..c156e4771 100644 --- a/tests/experimental/wrappers/test_jax_to_torch.py +++ b/tests/experimental/wrappers/test_jax_to_torch.py @@ -5,7 +5,7 @@ import numpy as np import pytest import torch -from gymnasium.experimental.wrappers.conversion.jax_to_torch import ( +from gymnasium.experimental.wrappers.jax_to_torch import ( JaxToTorchV0, jax_to_torch, torch_to_jax,