diff --git a/docs/api/experimental.md b/docs/api/experimental.md index c09dd3a18..50e22416c 100644 --- a/docs/api/experimental.md +++ b/docs/api/experimental.md @@ -175,21 +175,21 @@ Gymnasium already contains a large collection of wrappers, but we believe that t * - :class:`wrappers.EnvCompatibility` - Moved to `shimmy `_ - Not Implemented - * - :class:`RecordEpisodeStatistics` + * - :class:`wrappers.RecordEpisodeStatistics` - RecordEpisodeStatistics - VectorRecordEpisodeStatistics - * - :class:`RenderCollection` + * - :class:`wrappers.RenderCollection` - RenderCollection - VectorRenderCollection - * - :class:`HumanRendering` + * - :class:`wrappers.HumanRendering` - HumanRendering - Not Implemented * - Not Implemented - - JaxToNumpy - - VectorJaxToNumpy + - :class:`experimental.wrappers.JaxToNumpy` + - VectorJaxToNumpy (*) * - Not Implemented - - JaxToTorch - - VectorJaxToTorch + - :class:`experimental.wrappers.JaxToTorch` + - VectorJaxToTorch (*) ``` ### Vector Only Wrappers diff --git a/gymnasium/experimental/wrappers/__init__.py b/gymnasium/experimental/wrappers/__init__.py index 06ff71f66..988247db6 100644 --- a/gymnasium/experimental/wrappers/__init__.py +++ b/gymnasium/experimental/wrappers/__init__.py @@ -12,6 +12,8 @@ from gymnasium.experimental.wrappers.lambda_action import ( ) from gymnasium.experimental.wrappers.lambda_observations import LambdaObservationV0 from gymnasium.experimental.wrappers.lambda_reward import ClipRewardV0, LambdaRewardV0 +from gymnasium.experimental.wrappers.numpy_to_jax import JaxToNumpyV0 +from gymnasium.experimental.wrappers.torch_to_jax import JaxToTorchV0 from gymnasium.experimental.wrappers.sticky_action import StickyActionV0 from gymnasium.experimental.wrappers.time_aware_observation import ( TimeAwareObservationV0, @@ -32,4 +34,7 @@ __all__ = [ # Lambda Reward "LambdaRewardV0", "ClipRewardV0", + # Jax conversion wrappers + "JaxToNumpyV0", + "JaxToTorchV0", ] diff --git a/gymnasium/experimental/wrappers/numpy_to_jax.py 
b/gymnasium/experimental/wrappers/numpy_to_jax.py
new file mode 100644
index 000000000..f352766be
--- /dev/null
+++ b/gymnasium/experimental/wrappers/numpy_to_jax.py
@@ -0,0 +1,140 @@
+"""Helper functions and wrapper class for converting between numpy and Jax."""
+from __future__ import annotations
+
+import functools
+import numbers
+from collections import abc
+from typing import Any, Iterable, Mapping, SupportsFloat
+
+import jax.numpy as jnp
+import numpy as np
+
+from gymnasium import Env, Wrapper
+from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType
+
+
+@functools.singledispatch
+def numpy_to_jax(value: Any) -> Any:
+    """Converts a value to a Jax DeviceArray, raising for types without a registered conversion."""
+    raise Exception(
+        f"No conversion for Numpy to Jax registered for type: {type(value)}"
+    )
+
+
+@numpy_to_jax.register(numbers.Number)
+@numpy_to_jax.register(np.ndarray)
+def _number_ndarray_numpy_to_jax(value: np.ndarray | numbers.Number) -> jnp.DeviceArray:
+    """Converts a numpy array or number (int, float, etc.) to a Jax DeviceArray."""
+    return jnp.array(value)
+
+
+@numpy_to_jax.register(abc.Mapping)
+def _mapping_numpy_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]:
+    """Converts a mapping of numpy arrays to a mapping (of the same type) of Jax DeviceArrays."""
+    return type(value)(**{k: numpy_to_jax(v) for k, v in value.items()})
+
+
+@numpy_to_jax.register(abc.Iterable)
+def _iterable_numpy_to_jax(
+    value: Iterable[np.ndarray | Any],
+) -> Iterable[jnp.DeviceArray | Any]:
+    """Converts an Iterable from Numpy Arrays to an iterable of Jax DeviceArrays."""
+    return type(value)(numpy_to_jax(v) for v in value)
+
+
+@functools.singledispatch
+def jax_to_numpy(value: Any) -> Any:
+    """Converts a value to a numpy array, raising for types without a registered conversion."""
+    raise Exception(
+        f"No conversion for Jax to Numpy registered for type: {type(value)}"
+    )
+
+
+@jax_to_numpy.register(jnp.DeviceArray)
+def _devicearray_jax_to_numpy(value: jnp.DeviceArray) -> np.ndarray:
+    """Converts a Jax DeviceArray to a numpy array."""
+    return np.array(value)
+
+
+@jax_to_numpy.register(abc.Mapping)
+def _mapping_jax_to_numpy(
+    value: Mapping[str, jnp.DeviceArray | Any]
+) -> Mapping[str, np.ndarray | Any]:
+    """Converts a mapping of Jax DeviceArrays to a mapping (of the same type) of numpy arrays."""
+    return type(value)(**{k: jax_to_numpy(v) for k, v in value.items()})
+
+
+@jax_to_numpy.register(abc.Iterable)
+def _iterable_jax_to_numpy(
+    value: Iterable[jnp.DeviceArray | Any],
+) -> Iterable[np.ndarray | Any]:
+    """Converts an Iterable from Jax DeviceArrays to an iterable of numpy arrays."""
+    return type(value)(jax_to_numpy(v) for v in value)
+
+
+@jax_to_numpy.register(type(None))
+def _none_jax_to_numpy(value: None) -> None:
+    """Passes ``None`` through unchanged, e.g. when ``render()`` returns no frame."""
+    return value
+
+
+class JaxToNumpyV0(Wrapper):
+    """Wraps a jax environment so that it can be interacted with through numpy arrays.
+
+    Actions must be provided as numpy arrays and observations will be returned as numpy arrays.
+
+    Notes:
+        The Jax To Numpy and Numpy to Jax conversion does not guarantee a roundtrip (jax -> numpy -> jax) and vice versa.
+        The reason for this is jax does not support non-array values, therefore numpy ``int32(5) -> DeviceArray(5, dtype=jnp.int32)``
+    """
+
+    def __init__(self, env: Env):
+        """Wraps an environment such that the input and outputs are numpy arrays.
+
+        Args:
+            env: the environment to wrap
+        """
+        super().__init__(env)
+
+    def step(
+        self, action: WrapperActType
+    ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
+        """Converts the numpy action to a jax array, steps the environment and converts the results back.
+
+        Args:
+            action: the action to perform as a numpy array
+
+        Returns:
+            A tuple containing the next observation, reward, termination, truncation, and extra info.
+        """
+        jax_action = numpy_to_jax(action)
+        obs, reward, terminated, truncated, info = self.env.step(jax_action)
+
+        return (
+            jax_to_numpy(obs),
+            float(reward),
+            bool(terminated),
+            bool(truncated),
+            jax_to_numpy(info),
+        )
+
+    def reset(
+        self, *, seed: int | None = None, options: dict[str, Any] | None = None
+    ) -> tuple[WrapperObsType, dict[str, Any]]:
+        """Resets the environment returning numpy-based observation and info.
+
+        Args:
+            seed: The seed for resetting the environment
+            options: The options for resetting the environment, these are converted to jax arrays.
+
+        Returns:
+            Numpy-based observations and info
+        """
+        if options:
+            options = numpy_to_jax(options)
+
+        return jax_to_numpy(self.env.reset(seed=seed, options=options))
+
+    def render(self) -> RenderFrame | list[RenderFrame] | None:
+        """Returns the rendered frames as numpy arrays; a ``None`` result is passed through."""
+        return jax_to_numpy(self.env.render())
diff --git a/gymnasium/experimental/wrappers/torch_to_jax.py b/gymnasium/experimental/wrappers/torch_to_jax.py
new file mode 100644
index 000000000..a33a11b52
--- /dev/null
+++ b/gymnasium/experimental/wrappers/torch_to_jax.py
@@ -0,0 +1,163 @@
+# This wrapper will convert torch inputs for the actions and observations to Jax arrays
+# for an underlying Jax environment then convert the return observations from Jax arrays
+# back to torch tensors.
+#
+# Functionality for converting between torch and jax types originally copied from
+# https://github.com/google/brax/blob/9d6b7ced2a13da0d074b5e9fbd3aad8311e26997/brax/io/torch.py
+# Under the Apache 2.0 license. Copyright is held by the authors
+
+"""Helper functions and wrapper class for converting between PyTorch and Jax (importing this module requires torch)."""
+from __future__ import annotations
+
+import functools
+import numbers
+from collections import abc
+from typing import Any, Iterable, Mapping, SupportsFloat, Union
+
+import jax.numpy as jnp
+from jax import dlpack as jax_dlpack
+
+from gymnasium import Env, Wrapper
+from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType
+from gymnasium.error import DependencyNotInstalled
+from gymnasium.experimental.wrappers.numpy_to_jax import jax_to_numpy
+
+try:
+    import torch
+    from torch.utils import dlpack as torch_dlpack
+except ImportError:
+    raise DependencyNotInstalled("torch is not installed, run `pip install torch`")
+
+
+Device = Union[str, torch.device]  # a device name or torch.device object
+
+
+@functools.singledispatch
+def torch_to_jax(value: Any) -> Any:
+    """Converts a PyTorch value into its Jax equivalent, raising for types without a registered conversion."""
+    raise Exception(
+        f"No conversion for PyTorch to Jax registered for type: {type(value)}"
+    )
+
+
+@torch_to_jax.register(numbers.Number)
+def _number_torch_to_jax(value: numbers.Number) -> jnp.DeviceArray:
+    return jnp.array(value)  # numbers are wrapped directly, without going through dlpack
+
+
+@torch_to_jax.register(torch.Tensor)
+def _tensor_torch_to_jax(value: torch.Tensor) -> jnp.DeviceArray:
+    """Converts a PyTorch Tensor into a Jax DeviceArray via the dlpack exchange format."""
+    tensor = torch_dlpack.to_dlpack(value)  # pyright: ignore[reportPrivateImportUsage]
+    tensor = jax_dlpack.from_dlpack(tensor)  # pyright: ignore[reportPrivateImportUsage]
+    return tensor
+
+
+@torch_to_jax.register(abc.Mapping)
+def _mapping_torch_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]:
+    """Converts a mapping of PyTorch Tensors into a mapping (of the same type) of Jax DeviceArrays."""
+    return type(value)(**{k: torch_to_jax(v) for k, v in value.items()})
+
+
+@torch_to_jax.register(abc.Iterable)
+def _iterable_torch_to_jax(value: Iterable[Any]) -> Iterable[Any]:
+    """Converts an Iterable from PyTorch Tensors to an iterable of Jax DeviceArrays."""
+    return type(value)(torch_to_jax(v) for v in value)
+
+
+@functools.singledispatch
+def jax_to_torch(value: Any, device: Device | None = None) -> Any:
+    """Converts a Jax value into its PyTorch equivalent, raising for types without a registered conversion."""
+    raise Exception(
+        f"No conversion for Jax to PyTorch registered for type={type(value)} and device: {device}"
+    )
+
+
+@jax_to_torch.register(jnp.DeviceArray)
+def _devicearray_jax_to_torch(
+    value: jnp.DeviceArray, device: Device | None = None
+) -> torch.Tensor:
+    """Converts a Jax DeviceArray into a PyTorch Tensor, optionally moving it to ``device``."""
+    dlpack = jax_dlpack.to_dlpack(value)  # pyright: ignore[reportPrivateImportUsage]
+    tensor = torch_dlpack.from_dlpack(dlpack)
+    if device:  # only move the tensor when a device was requested
+        return tensor.to(device=device)
+    return tensor
+
+
+@jax_to_torch.register(abc.Mapping)
+def _jax_mapping_to_torch(
+    value: Mapping[str, Any], device: Device | None = None
+) -> Mapping[str, Any]:
+    """Converts a mapping of Jax DeviceArrays into a mapping (of the same type) of PyTorch Tensors."""
+    return type(value)(**{k: jax_to_torch(v, device) for k, v in value.items()})
+
+
+@jax_to_torch.register(abc.Iterable)
+def _jax_iterable_to_torch(
+    value: Iterable[Any], device: Device | None = None
+) -> Iterable[Any]:
+    """Converts an Iterable from Jax DeviceArrays to an iterable of PyTorch Tensors."""
+    return type(value)(jax_to_torch(v, device) for v in value)
+
+
+class JaxToTorchV0(Wrapper):
+    """Wraps a jax-based environment so that it can be interacted with through PyTorch Tensors.
+
+    Actions must be provided as PyTorch Tensors and observations will be returned as PyTorch Tensors.
+
+    The output of ``render`` is returned as a NumPy array, not a PyTorch Tensor.
+    """
+
+    def __init__(self, env: Env, device: Device | None = None):
+        """Wrapper class to change inputs and outputs of environment to PyTorch tensors.
+
+        Args:
+            env: The Jax-based environment to wrap
+            device: The device the torch Tensors should be moved to
+        """
+        super().__init__(env)
+        self.device: Device | None = device
+
+    def step(
+        self, action: WrapperActType
+    ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
+        """Performs the given action within the environment.
+
+        Args:
+            action: The action to perform as a PyTorch Tensor
+
+        Returns:
+            The next observation, reward, termination, truncation, and extra info
+        """
+        jax_action = torch_to_jax(action)
+        obs, reward, terminated, truncated, info = self.env.step(jax_action)
+
+        return (
+            jax_to_torch(obs, self.device),
+            float(reward),
+            bool(terminated),
+            bool(truncated),
+            jax_to_torch(info, self.device),
+        )
+
+    def reset(
+        self, *, seed: int | None = None, options: dict[str, Any] | None = None
+    ) -> tuple[WrapperObsType, dict[str, Any]]:
+        """Resets the environment returning PyTorch-based observation and info.
+
+        Args:
+            seed: The seed for resetting the environment
+            options: The options for resetting the environment, these are converted to jax arrays.
+
+        Returns:
+            PyTorch-based observations and info
+        """
+        if options:  # None or empty options are forwarded untouched
+            options = torch_to_jax(options)
+
+        return jax_to_torch(self.env.reset(seed=seed, options=options), self.device)
+
+    def render(self) -> RenderFrame | list[RenderFrame] | None:
+        """Returns the rendered frames as a NumPy array."""
+        return jax_to_numpy(self.env.render())
diff --git a/setup.py b/setup.py
index ab6d3e8db..ddbc649f4 100644
--- a/setup.py
+++ b/setup.py
@@ -42,7 +42,14 @@ extras: Dict[str, List[str]] = {
     "mujoco": ["mujoco>=2.3.0", "imageio>=2.14.1"],
     "toy_text": ["pygame==2.1.0"],
     "jax": ["jax==0.3.20", "jaxlib==0.3.20"],
-    "other": ["lz4>=3.1.0", "opencv-python>=3.0", "matplotlib>=3.0", "moviepy>=1.0.0"],
+    "other": [
+        "lz4>=3.1.0",
+        "opencv-python>=3.0",
+        "matplotlib>=3.0",
+        "moviepy>=1.0.0",
+        "tensorflow>=2.1.0",
+        "torch>=1.0.0",
+    ],
 }
 
 # All dependency groups - accept rom license as requires user to run
diff --git a/tests/experimental/wrappers/test_numpy_to_jax.py b/tests/experimental/wrappers/test_numpy_to_jax.py
new file mode 100644
index 000000000..f687b2177
--- /dev/null
+++ b/tests/experimental/wrappers/test_numpy_to_jax.py
@@ -0,0 +1,83 @@
+import jax.numpy as jnp
+import numpy as np
+import pytest
+
+from gymnasium.experimental.wrappers import JaxToNumpyV0
+from gymnasium.experimental.wrappers.numpy_to_jax import jax_to_numpy, numpy_to_jax
+from gymnasium.utils.env_checker import data_equivalence
+from tests.testing_env import GenericTestEnv
+
+
+@pytest.mark.parametrize(
+    "value, expected_value",
+    [
+        (1.0, np.array(1.0)),
+        (2, np.array(2)),
+        ((3.0, 4), (np.array(3.0), np.array(4))),
+        ([3.0, 4], [np.array(3.0), np.array(4)]),
+        (
+            {
+                "a": 6.0,
+                "b": 7,
+            },
+            {"a": np.array(6.0), "b": np.array(7)},
+        ),
+        (np.array(1.0), np.array(1.0)),
+        (np.array([1, 2]), np.array([1, 2])),
+        (np.array([[1.0], [2.0]]), np.array([[1.0], [2.0]])),
+        (
+            {"a": (1, np.array(2.0), np.array([3, 4])), "b": {"c": 5}},
+            {
+                "a": (np.array(1), np.array(2.0), np.array([3, 4])),
+                "b": {"c": np.array(5)},
+            },
+        ),
+    ],
+)
+def test_roundtripping(value, expected_value):
+    """We test numpy -> jax -> numpy as this is the direction used by the JaxToNumpyV0 wrapper."""
+    assert data_equivalence(jax_to_numpy(numpy_to_jax(value)), expected_value)
+
+
+def jax_reset_func(self, seed=None, options=None):
+    return jnp.array([1.0, 2.0, 3.0]), {"data": jnp.array([1, 2, 3])}
+
+
+def jax_step_func(self, action):
+    assert isinstance(action, jnp.DeviceArray), type(action)
+    return (
+        jnp.array([1, 2, 3]),
+        jnp.array(5.0),
+        jnp.array(True),
+        jnp.array(False),
+        {"data": jnp.array([1.0, 2.0])},
+    )
+
+
+def test_jax_to_numpy():
+    jax_env = GenericTestEnv(reset_fn=jax_reset_func, step_fn=jax_step_func)
+
+    # Check that the reset and step for jax environment are as expected
+    obs, info = jax_env.reset()
+    assert isinstance(obs, jnp.DeviceArray)
+    assert isinstance(info, dict) and isinstance(info["data"], jnp.DeviceArray)
+
+    obs, reward, terminated, truncated, info = jax_env.step(jnp.array([1, 2]))
+    assert isinstance(obs, jnp.DeviceArray)
+    assert isinstance(reward, jnp.DeviceArray)
+    assert isinstance(terminated, jnp.DeviceArray) and isinstance(
+        truncated, jnp.DeviceArray
+    )
+    assert isinstance(info, dict) and isinstance(info["data"], jnp.DeviceArray)
+
+    # Check that the wrapped version is correct.
+    numpy_env = JaxToNumpyV0(jax_env)
+    obs, info = numpy_env.reset()
+    assert isinstance(obs, np.ndarray)
+    assert isinstance(info, dict) and isinstance(info["data"], np.ndarray)
+
+    obs, reward, terminated, truncated, info = numpy_env.step(np.array([1, 2]))
+    assert isinstance(obs, np.ndarray)
+    assert isinstance(reward, float)
+    assert isinstance(terminated, bool) and isinstance(truncated, bool)
+    assert isinstance(info, dict) and isinstance(info["data"], np.ndarray)
diff --git a/tests/experimental/wrappers/test_torch_to_jax.py b/tests/experimental/wrappers/test_torch_to_jax.py
new file mode 100644
index 000000000..cb6cdf030
--- /dev/null
+++ b/tests/experimental/wrappers/test_torch_to_jax.py
@@ -0,0 +1,103 @@
+import jax.numpy as jnp
+import numpy as np
+import pytest
+import torch
+
+from gymnasium.experimental.wrappers import JaxToTorchV0
+from gymnasium.experimental.wrappers.torch_to_jax import jax_to_torch, torch_to_jax
+from tests.testing_env import GenericTestEnv
+
+
+def torch_data_equivalence(data_1, data_2) -> bool:  # recursive comparison requiring identical types
+    if type(data_1) == type(data_2):
+        if isinstance(data_1, dict):
+            return data_1.keys() == data_2.keys() and all(
+                torch_data_equivalence(data_1[k], data_2[k]) for k in data_1.keys()
+            )
+        elif isinstance(data_1, (tuple, list)):
+            return len(data_1) == len(data_2) and all(
+                torch_data_equivalence(o_1, o_2) for o_1, o_2 in zip(data_1, data_2)
+            )
+        elif isinstance(data_1, torch.Tensor):
+            return data_1.shape == data_2.shape and np.allclose(
+                data_1, data_2, atol=0.00001
+            )
+        else:
+            return data_1 == data_2
+    else:
+        return False
+
+
+@pytest.mark.parametrize(
+    "value, expected_value",
+    [
+        (1.0, torch.tensor(1.0)),
+        (2, torch.tensor(2)),
+        ((3.0, 4), (torch.tensor(3.0), torch.tensor(4))),
+        ([3.0, 4], [torch.tensor(3.0), torch.tensor(4)]),
+        (
+            {
+                "a": 6.0,
+                "b": 7,
+            },
+            {"a": torch.tensor(6.0), "b": torch.tensor(7)},
+        ),
+        (torch.tensor(1.0), torch.tensor(1.0)),
+        (torch.tensor([1, 2]), torch.tensor([1, 2])),
+        (torch.tensor([[1.0], [2.0]]), torch.tensor([[1.0], [2.0]])),
+        (
+            {"a": (1, torch.tensor(2.0), torch.tensor([3, 4])), "b": {"c": 5}},
+            {
+                "a": (torch.tensor(1), torch.tensor(2.0), torch.tensor([3, 4])),
+                "b": {"c": torch.tensor(5)},
+            },
+        ),
+    ],
+)
+def test_roundtripping(value, expected_value):
+    """We test torch -> jax -> torch as this is the direction used by the JaxToTorchV0 wrapper."""
+    assert torch_data_equivalence(jax_to_torch(torch_to_jax(value)), expected_value)
+
+
+def jax_reset_func(self, seed=None, options=None):
+    return jnp.array([1.0, 2.0, 3.0]), {"data": jnp.array([1, 2, 3])}
+
+
+def jax_step_func(self, action):
+    assert isinstance(action, jnp.DeviceArray), type(action)
+    return (
+        jnp.array([1, 2, 3]),
+        jnp.array(5.0),
+        jnp.array(True),
+        jnp.array(False),
+        {"data": jnp.array([1.0, 2.0])},
+    )
+
+
+def test_jax_to_torch():
+    env = GenericTestEnv(reset_fn=jax_reset_func, step_fn=jax_step_func)
+
+    # Check that the reset and step for jax environment are as expected
+    obs, info = env.reset()
+    assert isinstance(obs, jnp.DeviceArray)
+    assert isinstance(info, dict) and isinstance(info["data"], jnp.DeviceArray)
+
+    obs, reward, terminated, truncated, info = env.step(jnp.array([1, 2]))
+    assert isinstance(obs, jnp.DeviceArray)
+    assert isinstance(reward, jnp.DeviceArray)
+    assert isinstance(terminated, jnp.DeviceArray) and isinstance(
+        truncated, jnp.DeviceArray
+    )
+    assert isinstance(info, dict) and isinstance(info["data"], jnp.DeviceArray)
+
+    # Check that the wrapped version is correct.
+    wrapped_env = JaxToTorchV0(env)
+    obs, info = wrapped_env.reset()
+    assert isinstance(obs, torch.Tensor)
+    assert isinstance(info, dict) and isinstance(info["data"], torch.Tensor)
+
+    obs, reward, terminated, truncated, info = wrapped_env.step(torch.tensor([1, 2]))
+    assert isinstance(obs, torch.Tensor)
+    assert isinstance(reward, float)
+    assert isinstance(terminated, bool) and isinstance(truncated, bool)
+    assert isinstance(info, dict) and isinstance(info["data"], torch.Tensor)