Adds wrappers for jax environments converting to numpy and pytorch (#168)

Co-authored-by: Justin Deutsch <djustin8@vt.edu>
Co-authored-by: Gianluca De Cola <42657588+gianlucadecola@users.noreply.github.com>
This commit is contained in:
Mark Towers
2022-12-03 19:46:01 +00:00
committed by GitHub
parent e2caec7c06
commit 9157a97c80
7 changed files with 503 additions and 8 deletions

View File

@@ -175,21 +175,21 @@ Gymnasium already contains a large collection of wrappers, but we believe that t
* - :class:`wrappers.EnvCompatibility`
- Moved to `shimmy <https://github.com/Farama-Foundation/Shimmy/blob/main/shimmy/openai_gym_compatibility.py>`_
- Not Implemented
* - :class:`RecordEpisodeStatistics`
* - :class:`wrappers.RecordEpisodeStatistics`
- RecordEpisodeStatistics
- VectorRecordEpisodeStatistics
* - :class:`RenderCollection`
* - :class:`wrappers.RenderCollection`
- RenderCollection
- VectorRenderCollection
* - :class:`HumanRendering`
* - :class:`wrappers.HumanRendering`
- HumanRendering
- Not Implemented
* - Not Implemented
- JaxToNumpy
- VectorJaxToNumpy
- :class:`experimental.wrappers.JaxToNumpy`
- VectorJaxToNumpy (*)
* - Not Implemented
- JaxToTorch
- VectorJaxToTorch
- :class:`experimental.wrappers.JaxToTorch`
- VectorJaxToTorch (*)
```
### Vector Only Wrappers

View File

@@ -12,6 +12,8 @@ from gymnasium.experimental.wrappers.lambda_action import (
)
from gymnasium.experimental.wrappers.lambda_observations import LambdaObservationV0
from gymnasium.experimental.wrappers.lambda_reward import ClipRewardV0, LambdaRewardV0
from gymnasium.experimental.wrappers.numpy_to_jax import JaxToNumpyV0
from gymnasium.experimental.wrappers.torch_to_jax import JaxToTorchV0
from gymnasium.experimental.wrappers.sticky_action import StickyActionV0
from gymnasium.experimental.wrappers.time_aware_observation import (
TimeAwareObservationV0,
@@ -32,4 +34,7 @@ __all__ = [
# Lambda Reward
"LambdaRewardV0",
"ClipRewardV0",
# Jax conversion wrappers
"JaxToNumpyV0",
"JaxToTorchV0",
]

View File

@@ -0,0 +1,134 @@
"""Helper functions and wrapper class for converting between numpy and Jax."""
from __future__ import annotations
import functools
import numbers
from collections import abc
from typing import Any, Iterable, Mapping, SupportsFloat
import jax.numpy as jnp
import numpy as np
from gymnasium import Env, Wrapper
from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType
@functools.singledispatch
def numpy_to_jax(value: Any) -> Any:
"""Converts a value to a Jax DeviceArray."""
raise Exception(
f"No conversion for Numpy to Jax registered for type: {type(value)}"
)
@numpy_to_jax.register(numbers.Number)
@numpy_to_jax.register(np.ndarray)
def _number_ndarray_numpy_to_jax(value: np.ndarray | numbers.Number) -> jnp.DeviceArray:
"""Converts a numpy array or number (int, float, etc.) to a Jax DeviceArray."""
return jnp.array(value)
@numpy_to_jax.register(abc.Mapping)
def _mapping_numpy_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]:
"""Converts a dictionary of numpy arrays to a mapping of Jax DeviceArrays."""
return type(value)(**{k: numpy_to_jax(v) for k, v in value.items()})
@numpy_to_jax.register(abc.Iterable)
def _iterable_numpy_to_jax(
value: Iterable[np.ndarray | Any],
) -> Iterable[jnp.DeviceArray | Any]:
"""Converts an Iterable from Numpy Arrays to an iterable of Jax DeviceArrays."""
return type(value)(numpy_to_jax(v) for v in value)
@functools.singledispatch
def jax_to_numpy(value: Any) -> Any:
"""Converts a value to a numpy array."""
raise Exception(
f"No conversion for Jax to Numpy registered for type: {type(value)}"
)
@jax_to_numpy.register(jnp.DeviceArray)
def _devicearray_jax_to_numpy(value: jnp.DeviceArray) -> np.ndarray:
"""Converts a Jax DeviceArray to a numpy array."""
return np.array(value)
@jax_to_numpy.register(abc.Mapping)
def _mapping_jax_to_numpy(
value: Mapping[str, jnp.DeviceArray | Any]
) -> Mapping[str, np.ndarray | Any]:
"""Converts a dictionary of Jax DeviceArrays to a mapping of numpy arrays."""
return type(value)(**{k: jax_to_numpy(v) for k, v in value.items()})
@jax_to_numpy.register(abc.Iterable)
def _iterable_jax_to_numpy(
value: Iterable[np.ndarray | Any],
) -> Iterable[jnp.DeviceArray | Any]:
"""Converts an Iterable from Numpy arrays to an iterable of Jax DeviceArrays."""
return type(value)(jax_to_numpy(v) for v in value)
class JaxToNumpyV0(Wrapper):
"""Wraps a jax environment so that it can be interacted with through numpy arrays.
Actions must be provided as numpy arrays and observations will be returned as numpy arrays.
Notes:
The Jax To Numpy and Numpy to Jax conversion does not guarantee a roundtrip (jax -> numpy -> jax) and vice versa.
The reason for this is jax does not support non-array values, therefore numpy ``int_32(5) -> DeviceArray([5], dtype=jnp.int23)``
"""
def __init__(self, env: Env):
"""Wraps an environment such that the input and outputs are numpy arrays.
Args:
env: the environment to wrap
"""
super().__init__(env)
def step(
self, action: WrapperActType
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
"""Transforms the action to a jax array .
Args:
action: the action to perform as a numpy array
Returns:
A tuple containing the next observation, reward, termination, truncation, and extra info.
"""
jax_action = numpy_to_jax(action)
obs, reward, terminated, truncated, info = self.env.step(jax_action)
return (
jax_to_numpy(obs),
float(reward),
bool(terminated),
bool(truncated),
jax_to_numpy(info),
)
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[WrapperObsType, dict[str, Any]]:
"""Resets the environment returning numpy-based observation and info.
Args:
seed: The seed for resetting the environment
options: The options for resetting the environment, these are converted to jax arrays.
Returns:
Numpy-based observations and info
"""
if options:
options = numpy_to_jax(options)
return jax_to_numpy(self.env.reset(seed=seed, options=options))
def render(self) -> RenderFrame | list[RenderFrame] | None:
"""Returns the rendered frames as a numpy array."""
return jax_to_numpy(self.env.render())

View File

@@ -0,0 +1,163 @@
# This wrapper will convert torch inputs for the actions and observations to Jax arrays
# for an underlying Jax environment then convert the return observations from Jax arrays
# back to torch tensors.
#
# Functionality for converting between torch and jax types originally copied from
# https://github.com/google/brax/blob/9d6b7ced2a13da0d074b5e9fbd3aad8311e26997/brax/io/torch.py
# Under the Apache 2.0 license. Copyright is held by the authors
"""Helper functions and wrapper class for converting between PyTorch and Jax."""
from __future__ import annotations
import functools
import numbers
from collections import abc
from typing import Any, Iterable, Mapping, SupportsFloat, Union
import jax.numpy as jnp
from jax import dlpack as jax_dlpack
from gymnasium import Env, Wrapper
from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType
from gymnasium.error import DependencyNotInstalled
from gymnasium.experimental.wrappers.numpy_to_jax import jax_to_numpy
try:
import torch
from torch.utils import dlpack as torch_dlpack
except ImportError:
raise DependencyNotInstalled("torch is not installed, run `pip install torch`")
Device = Union[str, torch.device]
@functools.singledispatch
def torch_to_jax(value: Any) -> Any:
"""Converts a PyTorch Tensor into a Jax DeviceArray."""
raise Exception(
f"No conversion for PyTorch to Jax registered for type: {type(value)}"
)
@torch_to_jax.register(numbers.Number)
def _number_torch_to_jax(value: numbers.Number) -> Any:
return jnp.array(value)
@torch_to_jax.register(torch.Tensor)
def _tensor_torch_to_jax(value: torch.Tensor) -> jnp.DeviceArray:
"""Converts a PyTorch Tensor into a Jax DeviceArray."""
tensor = torch_dlpack.to_dlpack(value) # pyright: ignore[reportPrivateImportUsage]
tensor = jax_dlpack.from_dlpack(tensor) # pyright: ignore[reportPrivateImportUsage]
return tensor
@torch_to_jax.register(abc.Mapping)
def _mapping_torch_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]:
"""Converts a mapping of PyTorch Tensors into a Dictionary of Jax DeviceArrays."""
return type(value)(**{k: torch_to_jax(v) for k, v in value.items()})
@torch_to_jax.register(abc.Iterable)
def _iterable_torch_to_jax(value: Iterable[Any]) -> Iterable[Any]:
"""Converts an Iterable from PyTorch Tensors to an iterable of Jax DeviceArrays."""
return type(value)(torch_to_jax(v) for v in value)
@functools.singledispatch
def jax_to_torch(value: Any, device: Device | None = None) -> Any:
"""Converts a Jax DeviceArray into a PyTorch Tensor."""
raise Exception(
f"No conversion for Jax to PyTorch registered for type={type(value)} and device: {device}"
)
@jax_to_torch.register(jnp.DeviceArray)
def _devicearray_jax_to_torch(
value: jnp.DeviceArray, device: Device | None = None
) -> torch.Tensor:
"""Converts a Jax DeviceArray into a PyTorch Tensor."""
dlpack = jax_dlpack.to_dlpack(value) # pyright: ignore[reportPrivateImportUsage]
tensor = torch_dlpack.from_dlpack(dlpack)
if device:
return tensor.to(device=device)
return tensor
@jax_to_torch.register(abc.Mapping)
def _jax_mapping_to_torch(
value: Mapping[str, Any], device: Device | None = None
) -> Mapping[str, Any]:
"""Converts a mapping of Jax DeviceArrays into a Dictionary of PyTorch Tensors."""
return type(value)(**{k: jax_to_torch(v, device) for k, v in value.items()})
@jax_to_torch.register(abc.Iterable)
def _jax_iterable_to_torch(
value: Iterable[Any], device: Device | None = None
) -> Iterable[Any]:
"""Converts an Iterable from Jax DeviceArrays to an iterable of PyTorch Tensors."""
return type(value)(jax_to_torch(v, device) for v in value)
class JaxToTorchV0(Wrapper):
"""Wraps a jax-based environment so that it can be interacted with through PyTorch Tensors.
Actions must be provided as PyTorch Tensors and observations will be returned as PyTorch Tensors.
For ``rendered`` this is returned as a NumPy array not a pytorch Tensor.
"""
def __init__(self, env: Env, device: Device | None = None):
"""Wrapper class to change inputs and outputs of environment to PyTorch tensors.
Args:
env: The Jax-based environment to wrap
device: The device the torch Tensors should be moved to
"""
super().__init__(env)
self.device: Device | None = device
def step(
self, action: WrapperActType
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
"""Performs the given action within the environment.
Args:
action: The action to perform as a PyTorch Tensor
Returns:
The next observation, reward, termination, truncation, and extra info
"""
jax_action = torch_to_jax(action)
obs, reward, terminated, truncated, info = self.env.step(jax_action)
return (
jax_to_torch(obs, self.device),
float(reward),
bool(terminated),
bool(truncated),
jax_to_torch(info, self.device),
)
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[WrapperObsType, dict[str, Any]]:
"""Resets the environment returning PyTorch-based observation and info.
Args:
seed: The seed for resetting the environment
options: The options for resetting the environment, these are converted to jax arrays.
Returns:
PyTorch-based observations and info
"""
if options:
options = torch_to_jax(options)
return jax_to_torch(self.env.reset(seed=seed, options=options), self.device)
def render(self) -> RenderFrame | list[RenderFrame] | None:
"""Returns the rendered frames as a NumPy array."""
return jax_to_numpy(self.env.render())

View File

@@ -42,7 +42,14 @@ extras: Dict[str, List[str]] = {
"mujoco": ["mujoco>=2.3.0", "imageio>=2.14.1"],
"toy_text": ["pygame==2.1.0"],
"jax": ["jax==0.3.20", "jaxlib==0.3.20"],
"other": ["lz4>=3.1.0", "opencv-python>=3.0", "matplotlib>=3.0", "moviepy>=1.0.0"],
"other": [
"lz4>=3.1.0",
"opencv-python>=3.0",
"matplotlib>=3.0",
"moviepy>=1.0.0",
"tensorflow>=2.1.0",
"torch>=1.0.0",
],
}
# All dependency groups - accept rom license as requires user to run

View File

@@ -0,0 +1,83 @@
import jax.numpy as jnp
import numpy as np
import pytest
from gymnasium.experimental.wrappers import JaxToNumpyV0
from gymnasium.experimental.wrappers.numpy_to_jax import jax_to_numpy, numpy_to_jax
from gymnasium.utils.env_checker import data_equivalence
from tests.testing_env import GenericTestEnv
@pytest.mark.parametrize(
"value, expected_value",
[
(1.0, np.array(1.0)),
(2, np.array(2)),
((3.0, 4), (np.array(3.0), np.array(4))),
([3.0, 4], [np.array(3.0), np.array(4)]),
(
{
"a": 6.0,
"b": 7,
},
{"a": np.array(6.0), "b": np.array(7)},
),
(np.array(1.0), np.array(1.0)),
(np.array([1, 2]), np.array([1, 2])),
(np.array([[1.0], [2.0]]), np.array([[1.0], [2.0]])),
(
{"a": (1, np.array(2.0), np.array([3, 4])), "b": {"c": 5}},
{
"a": (np.array(1), np.array(2.0), np.array([3, 4])),
"b": {"c": np.array(5)},
},
),
],
)
def test_roundtripping(value, expected_value):
"""We test numpy -> jax -> numpy as this is direction in the NumpyToJax wrapper."""
assert data_equivalence(jax_to_numpy(numpy_to_jax(value)), expected_value)
def jax_reset_func(self, seed=None, options=None):
return jnp.array([1.0, 2.0, 3.0]), {"data": jnp.array([1, 2, 3])}
def jax_step_func(self, action):
assert isinstance(action, jnp.DeviceArray), type(action)
return (
jnp.array([1, 2, 3]),
jnp.array(5.0),
jnp.array(True),
jnp.array(False),
{"data": jnp.array([1.0, 2.0])},
)
def test_jax_to_numpy():
jax_env = GenericTestEnv(reset_fn=jax_reset_func, step_fn=jax_step_func)
# Check that the reset and step for jax environment are as expected
obs, info = jax_env.reset()
assert isinstance(obs, jnp.DeviceArray)
assert isinstance(info, dict) and isinstance(info["data"], jnp.DeviceArray)
obs, reward, terminated, truncated, info = jax_env.step(jnp.array([1, 2]))
assert isinstance(obs, jnp.DeviceArray)
assert isinstance(reward, jnp.DeviceArray)
assert isinstance(terminated, jnp.DeviceArray) and isinstance(
truncated, jnp.DeviceArray
)
assert isinstance(info, dict) and isinstance(info["data"], jnp.DeviceArray)
# Check that the wrapped version is correct.
numpy_env = JaxToNumpyV0(jax_env)
obs, info = numpy_env.reset()
assert isinstance(obs, np.ndarray)
assert isinstance(info, dict) and isinstance(info["data"], np.ndarray)
obs, reward, terminated, truncated, info = numpy_env.step(np.array([1, 2]))
assert isinstance(obs, np.ndarray)
assert isinstance(reward, float)
assert isinstance(terminated, bool) and isinstance(truncated, bool)
assert isinstance(info, dict) and isinstance(info["data"], np.ndarray)

View File

@@ -0,0 +1,103 @@
import jax.numpy as jnp
import numpy as np
import pytest
import torch
from gymnasium.experimental.wrappers import JaxToTorchV0
from gymnasium.experimental.wrappers.torch_to_jax import jax_to_torch, torch_to_jax
from tests.testing_env import GenericTestEnv
def torch_data_equivalence(data_1, data_2) -> bool:
if type(data_1) == type(data_2):
if isinstance(data_1, dict):
return data_1.keys() == data_2.keys() and all(
torch_data_equivalence(data_1[k], data_2[k]) for k in data_1.keys()
)
elif isinstance(data_1, (tuple, list)):
return len(data_1) == len(data_2) and all(
torch_data_equivalence(o_1, o_2) for o_1, o_2 in zip(data_1, data_2)
)
elif isinstance(data_1, torch.Tensor):
return data_1.shape == data_2.shape and np.allclose(
data_1, data_2, atol=0.00001
)
else:
return data_1 == data_2
else:
return False
@pytest.mark.parametrize(
"value, expected_value",
[
(1.0, torch.tensor(1.0)),
(2, torch.tensor(2)),
((3.0, 4), (torch.tensor(3.0), torch.tensor(4))),
([3.0, 4], [torch.tensor(3.0), torch.tensor(4)]),
(
{
"a": 6.0,
"b": 7,
},
{"a": torch.tensor(6.0), "b": torch.tensor(7)},
),
(torch.tensor(1.0), torch.tensor(1.0)),
(torch.tensor([1, 2]), torch.tensor([1, 2])),
(torch.tensor([[1.0], [2.0]]), torch.tensor([[1.0], [2.0]])),
(
{"a": (1, torch.tensor(2.0), torch.tensor([3, 4])), "b": {"c": 5}},
{
"a": (torch.tensor(1), torch.tensor(2.0), torch.tensor([3, 4])),
"b": {"c": torch.tensor(5)},
},
),
],
)
def test_roundtripping(value, expected_value):
"""We test numpy -> jax -> numpy as this is direction in the NumpyToJax wrapper."""
assert torch_data_equivalence(jax_to_torch(torch_to_jax(value)), expected_value)
def jax_reset_func(self, seed=None, options=None):
return jnp.array([1.0, 2.0, 3.0]), {"data": jnp.array([1, 2, 3])}
def jax_step_func(self, action):
assert isinstance(action, jnp.DeviceArray), type(action)
return (
jnp.array([1, 2, 3]),
jnp.array(5.0),
jnp.array(True),
jnp.array(False),
{"data": jnp.array([1.0, 2.0])},
)
def test_jax_to_torch():
env = GenericTestEnv(reset_fn=jax_reset_func, step_fn=jax_step_func)
# Check that the reset and step for jax environment are as expected
obs, info = env.reset()
assert isinstance(obs, jnp.DeviceArray)
assert isinstance(info, dict) and isinstance(info["data"], jnp.DeviceArray)
obs, reward, terminated, truncated, info = env.step(jnp.array([1, 2]))
assert isinstance(obs, jnp.DeviceArray)
assert isinstance(reward, jnp.DeviceArray)
assert isinstance(terminated, jnp.DeviceArray) and isinstance(
truncated, jnp.DeviceArray
)
assert isinstance(info, dict) and isinstance(info["data"], jnp.DeviceArray)
# Check that the wrapped version is correct.
wrapped_env = JaxToTorchV0(env)
obs, info = wrapped_env.reset()
assert isinstance(obs, torch.Tensor)
assert isinstance(info, dict) and isinstance(info["data"], torch.Tensor)
obs, reward, terminated, truncated, info = wrapped_env.step(torch.tensor([1, 2]))
assert isinstance(obs, torch.Tensor)
assert isinstance(reward, float)
assert isinstance(terminated, bool) and isinstance(truncated, bool)
assert isinstance(info, dict) and isinstance(info["data"], torch.Tensor)