Add `__getattr__` to the experimental wrappers as a generic solution to optimise extra module imports (#392)

This commit is contained in:
Mark Towers
2023-03-17 21:00:48 +00:00
committed by GitHub
parent 24a55188e5
commit a6672bad9e
13 changed files with 389 additions and 378 deletions

View File

@@ -1,5 +1,6 @@
"""Root `__init__` of the gymnasium module setting the `__all__` of gymnasium modules."""
# isort: skip_file
# pyright: reportUnsupportedDunderAll=false
from gymnasium.core import (
Env,
@@ -17,7 +18,9 @@ from gymnasium.envs.registration import (
pprint_registry,
make_vec,
)
from gymnasium import envs, spaces, utils, vector, wrappers, error, logger
# necessary for `envs.__init__` which registers all gymnasium environments and loads plugins
from gymnasium import envs
__all__ = [
@@ -37,6 +40,7 @@ __all__ = [
"pprint_registry",
# module folders
"envs",
"experimental",
"spaces",
"utils",
"vector",

View File

@@ -1,7 +1,7 @@
"""Root __init__ of the gym experimental wrappers."""
from gymnasium.experimental import functional
from gymnasium.experimental import functional, wrappers
from gymnasium.experimental.functional import FuncEnv
from gymnasium.experimental.vector.async_vector_env import AsyncVectorEnv
from gymnasium.experimental.vector.sync_vector_env import SyncVectorEnv
@@ -17,4 +17,6 @@ __all__ = [
"VectorWrapper",
"SyncVectorEnv",
"AsyncVectorEnv",
# wrappers
"wrappers",
]

View File

@@ -11,7 +11,7 @@ import numpy as np
import gymnasium as gym
from gymnasium.envs.registration import EnvSpec
from gymnasium.experimental.functional import ActType, FuncEnv, StateType
from gymnasium.experimental.wrappers.conversion.jax_to_numpy import jax_to_numpy
from gymnasium.experimental.wrappers.jax_to_numpy import jax_to_numpy
from gymnasium.utils import seeding
from gymnasium.vector.utils import batch_space

View File

@@ -1,53 +1,11 @@
"""Experimental Wrappers."""
# isort: skip_file
"""`__init__` for experimental wrappers, to avoid loading the wrappers if unnecessary, we can hack python."""
# pyright: reportUnsupportedDunderAll=false
from gymnasium.experimental.wrappers.lambda_action import (
LambdaActionV0,
ClipActionV0,
RescaleActionV0,
)
from gymnasium.experimental.wrappers.lambda_observations import (
LambdaObservationV0,
FilterObservationV0,
FlattenObservationV0,
GrayscaleObservationV0,
ResizeObservationV0,
ReshapeObservationV0,
RescaleObservationV0,
DtypeObservationV0,
PixelObservationV0,
NormalizeObservationV0,
)
from gymnasium.experimental.wrappers.lambda_reward import (
ClipRewardV0,
LambdaRewardV0,
NormalizeRewardV0,
)
from gymnasium.experimental.wrappers.stateful_action import StickyActionV0
from gymnasium.experimental.wrappers.stateful_observation import (
TimeAwareObservationV0,
DelayObservationV0,
FrameStackObservationV0,
)
from gymnasium.experimental.wrappers.atari_preprocessing import AtariPreprocessingV0
from gymnasium.experimental.wrappers.common import (
PassiveEnvCheckerV0,
OrderEnforcingV0,
AutoresetV0,
RecordEpisodeStatisticsV0,
)
from gymnasium.experimental.wrappers.rendering import (
RenderCollectionV0,
RecordVideoV0,
HumanRenderingV0,
)
import importlib
from gymnasium.experimental.wrappers.vector import (
VectorRecordEpisodeStatistics,
VectorListInfo,
)
__all__ = [
"vector",
# --- Observation wrappers ---
"LambdaObservationV0",
"FilterObservationV0",
@@ -82,7 +40,80 @@ __all__ = [
"RenderCollectionV0",
"RecordVideoV0",
"HumanRenderingV0",
# --- Vector ---
"VectorRecordEpisodeStatistics",
"VectorListInfo",
# --- Conversion ---
"JaxToNumpyV0",
"JaxToTorchV0",
"NumpyToTorchV0",
]
# Lookup table used by the module-level `__getattr__`: maps every lazily
# exported wrapper name to the submodule of `gymnasium.experimental.wrappers`
# that defines it. The table is written grouped by submodule and flattened
# into ``{wrapper_name: submodule_name}``.
# NOTE: only the wrapper classes are listed; the conversion helper functions
# (`jax_to_numpy`, `numpy_to_jax`, `jax_to_torch`, `torch_to_jax`,
# `torch_to_numpy`, `numpy_to_torch`) are not mapped here.
_wrapper_to_class = {
    wrapper_name: module_name
    for module_name, wrapper_names in (
        (
            "lambda_action",
            ("LambdaActionV0", "ClipActionV0", "RescaleActionV0"),
        ),
        (
            "lambda_observations",
            (
                "LambdaObservationV0",
                "FilterObservationV0",
                "FlattenObservationV0",
                "GrayscaleObservationV0",
                "ResizeObservationV0",
                "ReshapeObservationV0",
                "RescaleObservationV0",
                "DtypeObservationV0",
                "PixelObservationV0",
                "NormalizeObservationV0",
            ),
        ),
        (
            "lambda_reward",
            ("ClipRewardV0", "LambdaRewardV0", "NormalizeRewardV0"),
        ),
        ("stateful_action", ("StickyActionV0",)),
        (
            "stateful_observation",
            (
                "TimeAwareObservationV0",
                "DelayObservationV0",
                "FrameStackObservationV0",
            ),
        ),
        ("atari_preprocessing", ("AtariPreprocessingV0",)),
        (
            "common",
            (
                "PassiveEnvCheckerV0",
                "OrderEnforcingV0",
                "AutoresetV0",
                "RecordEpisodeStatisticsV0",
            ),
        ),
        (
            "rendering",
            ("RenderCollectionV0", "RecordVideoV0", "HumanRenderingV0"),
        ),
        ("jax_to_numpy", ("JaxToNumpyV0",)),
        ("jax_to_torch", ("JaxToTorchV0",)),
        ("numpy_to_torch", ("NumpyToTorchV0",)),
    )
    for wrapper_name in wrapper_names
}
def __getattr__(name: str):
    """Lazily load a wrapper the first time it is accessed on this package.

    This avoids having to load all wrappers (and all of their extra modules)
    on `import gymnasium`, which optimises the loading of gymnasium.

    Args:
        name: The name of a wrapper to load

    Returns:
        The wrapper class, fetched from the submodule that defines it

    Raises:
        AttributeError: If ``name`` is not a known wrapper. When ``name`` looks
            like a versioned wrapper (``...V<number>``) and other versions of the
            same wrapper exist, they are listed in the error message.
    """
    module_name = _wrapper_to_class.get(name)
    if module_name is not None:
        module = importlib.import_module(
            f"gymnasium.experimental.wrappers.{module_name}"
        )
        return getattr(module, name)

    message = f"module {__name__!r} has no attribute {name!r}"

    # Helpful hint if only the version number has changed: strip the trailing
    # digits and look for known wrappers sharing the same `...V` prefix.
    digits = "0123456789"
    base_name = name.rstrip(digits)
    if base_name != name and base_name.endswith("V"):
        other_versions = sorted(
            wrapper_name
            for wrapper_name in _wrapper_to_class
            if wrapper_name != base_name and wrapper_name.rstrip(digits) == base_name
        )
        if other_versions:
            message += f", available versions: {', '.join(other_versions)}"

    raise AttributeError(message)

View File

@@ -8,7 +8,9 @@ from gymnasium.spaces import Box
try:
import cv2
except ImportError:
cv2 = None
raise gym.error.DependencyNotInstalled(
"opencv-python package not installed, run `pip install gymnasium[other]` to get dependencies for atari"
)
class AtariPreprocessingV0(gym.Wrapper, gym.utils.RecordConstructorArgs):
@@ -72,10 +74,6 @@ class AtariPreprocessingV0(gym.Wrapper, gym.utils.RecordConstructorArgs):
)
gym.Wrapper.__init__(self, env)
if cv2 is None:
raise gym.error.DependencyNotInstalled(
"opencv-python package not installed, run `pip install gymnasium[other]` to get dependencies for atari"
)
assert frame_skip > 0
assert screen_size > 0
assert noop_max >= 0

View File

@@ -1 +0,0 @@
"""This is deliberately empty to avoid introducing redundant imports -- import each submodule individually."""

View File

@@ -1,205 +0,0 @@
# This wrapper will convert torch inputs for the actions and observations to Jax arrays
# for an underlying Jax environment then convert the return observations from Jax arrays
# back to torch tensors.
#
# Functionality for converting between torch and jax types originally copied from
# https://github.com/google/brax/blob/9d6b7ced2a13da0d074b5e9fbd3aad8311e26997/brax/io/torch.py
# Under the Apache 2.0 license. Copyright is held by the authors
"""Helper functions and wrapper class for converting between PyTorch and Jax."""
from __future__ import annotations
import functools
import numbers
from collections import abc
from typing import Any, Iterable, Mapping, SupportsFloat, Union
import gymnasium as gym
from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType
from gymnasium.error import DependencyNotInstalled
from gymnasium.experimental.wrappers.conversion.jax_to_numpy import jax_to_numpy
try:
import jax.numpy as jnp
from jax import dlpack as jax_dlpack
except ImportError:
jnp, jax_dlpack = None, None
try:
import torch
from torch.utils import dlpack as torch_dlpack
Device = Union[str, torch.device]
except ImportError:
torch, torch_dlpack, Device = None, None, None
@functools.singledispatch
def torch_to_jax(value: Any) -> Any:
"""Converts a PyTorch Tensor into a Jax DeviceArray."""
if torch is None:
raise DependencyNotInstalled(
"Torch is not installed therefore cannot call `torch_to_jax`, run `pip install torch`"
)
elif jnp is None:
raise DependencyNotInstalled(
"Jax is not installed therefore cannot call `torch_to_jax`, run `pip install gymnasium[jax]`"
)
else:
raise Exception(
f"No known conversion for Torch type ({type(value)}) to Jax registered. Report as issue on github."
)
if torch is not None and jnp is not None:
@torch_to_jax.register(numbers.Number)
def _number_torch_to_jax(value: numbers.Number) -> Any:
"""Convert a python number (int, float, complex) to a jax array."""
assert jnp is not None
return jnp.array(value)
@torch_to_jax.register(torch.Tensor)
def _tensor_torch_to_jax(value: torch.Tensor) -> jnp.DeviceArray:
"""Converts a PyTorch Tensor into a Jax DeviceArray."""
assert torch_dlpack is not None and jax_dlpack is not None
tensor = torch_dlpack.to_dlpack( # pyright: ignore[reportPrivateImportUsage]
value
)
tensor = jax_dlpack.from_dlpack( # pyright: ignore[reportPrivateImportUsage]
tensor
)
return tensor
@torch_to_jax.register(abc.Mapping)
def _mapping_torch_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]:
"""Converts a mapping of PyTorch Tensors into a Dictionary of Jax DeviceArrays."""
return type(value)(**{k: torch_to_jax(v) for k, v in value.items()})
@torch_to_jax.register(abc.Iterable)
def _iterable_torch_to_jax(value: Iterable[Any]) -> Iterable[Any]:
"""Converts an Iterable from PyTorch Tensors to an iterable of Jax DeviceArrays."""
return type(value)(torch_to_jax(v) for v in value)
@functools.singledispatch
def jax_to_torch(value: Any, device: Device | None = None) -> Any:
"""Converts a Jax DeviceArray into a PyTorch Tensor."""
if torch is None:
raise DependencyNotInstalled(
"Torch is not installed therefore cannot call `jax_to_torch`, run `pip install torch`"
)
elif jnp is None:
raise DependencyNotInstalled(
"Jax is not installed therefore cannot call `jax_to_torch`, run `pip install gymnasium[jax]`"
)
else:
raise Exception(
f"No known conversion for Jax type ({type(value)}) to PyTorch registered. Report as issue on github."
)
if torch is not None and jnp is not None:
@jax_to_torch.register(jnp.DeviceArray)
def _devicearray_jax_to_torch(
value: jnp.DeviceArray, device: Device | None = None
) -> torch.Tensor:
"""Converts a Jax DeviceArray into a PyTorch Tensor."""
assert jax_dlpack is not None and torch_dlpack is not None
dlpack = jax_dlpack.to_dlpack( # pyright: ignore[reportPrivateImportUsage]
value
)
tensor = torch_dlpack.from_dlpack(dlpack)
if device:
return tensor.to(device=device)
return tensor
@jax_to_torch.register(abc.Mapping)
def _jax_mapping_to_torch(
value: Mapping[str, Any], device: Device | None = None
) -> Mapping[str, Any]:
"""Converts a mapping of Jax DeviceArrays into a Dictionary of PyTorch Tensors."""
return type(value)(**{k: jax_to_torch(v, device) for k, v in value.items()})
@jax_to_torch.register(abc.Iterable)
def _jax_iterable_to_torch(
value: Iterable[Any], device: Device | None = None
) -> Iterable[Any]:
"""Converts an Iterable from Jax DeviceArrays to an iterable of PyTorch Tensors."""
return type(value)(jax_to_torch(v, device) for v in value)
class JaxToTorchV0(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""Wraps a jax-based environment so that it can be interacted with through PyTorch Tensors.
Actions must be provided as PyTorch Tensors and observations will be returned as PyTorch Tensors.
Note:
For ``rendered`` this is returned as a NumPy array not a pytorch Tensor.
"""
def __init__(self, env: gym.Env, device: Device | None = None):
"""Wrapper class to change inputs and outputs of environment to PyTorch tensors.
Args:
env: The Jax-based environment to wrap
device: The device the torch Tensors should be moved to
"""
if torch is None:
raise DependencyNotInstalled(
"torch is not installed, run `pip install torch`"
)
elif jnp is None:
raise DependencyNotInstalled(
"jax is not installed, run `pip install gymnasium[jax]`"
)
gym.utils.RecordConstructorArgs.__init__(self, device=device)
gym.Wrapper.__init__(self, env)
self.device: Device | None = device
def step(
self, action: WrapperActType
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
"""Performs the given action within the environment.
Args:
action: The action to perform as a PyTorch Tensor
Returns:
The next observation, reward, termination, truncation, and extra info
"""
jax_action = torch_to_jax(action)
obs, reward, terminated, truncated, info = self.env.step(jax_action)
return (
jax_to_torch(obs, self.device),
float(reward),
bool(terminated),
bool(truncated),
jax_to_torch(info, self.device),
)
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[WrapperObsType, dict[str, Any]]:
"""Resets the environment returning PyTorch-based observation and info.
Args:
seed: The seed for resetting the environment
options: The options for resetting the environment, these are converted to jax arrays.
Returns:
PyTorch-based observations and info
"""
if options:
options = torch_to_jax(options)
return jax_to_torch(self.env.reset(seed=seed, options=options), self.device)
def render(self) -> RenderFrame | list[RenderFrame] | None:
"""Returns the rendered frames as a NumPy array."""
return jax_to_numpy(self.env.render())

View File

@@ -16,80 +16,73 @@ from gymnasium.error import DependencyNotInstalled
try:
import jax.numpy as jnp
except ImportError:
# We handle the error internal to the relative functions
jnp = None
raise DependencyNotInstalled(
"Jax is not installed therefore cannot call `numpy_to_jax`, run `pip install gymnasium[jax]`"
)
__all__ = ["jax_to_numpy", "numpy_to_jax", "JaxToNumpyV0"]
@functools.singledispatch
def numpy_to_jax(value: Any) -> Any:
"""Converts a value to a Jax DeviceArray."""
if jnp is None:
raise DependencyNotInstalled(
"Jax is not installed therefore cannot call `numpy_to_jax`, run `pip install gymnasium[jax]`"
)
else:
raise Exception(
f"No known conversion for Numpy type ({type(value)}) to Jax registered. Report as issue on github."
)
raise Exception(
f"No known conversion for Numpy type ({type(value)}) to Jax registered. Report as issue on github."
)
if jnp is not None:
@numpy_to_jax.register(numbers.Number)
@numpy_to_jax.register(np.ndarray)
def _number_ndarray_numpy_to_jax(
value: np.ndarray | numbers.Number,
) -> jnp.DeviceArray:
"""Converts a numpy array or number (int, float, etc.) to a Jax DeviceArray."""
assert jnp is not None
return jnp.array(value)
@numpy_to_jax.register(numbers.Number)
@numpy_to_jax.register(np.ndarray)
def _number_ndarray_numpy_to_jax(
value: np.ndarray | numbers.Number,
) -> jnp.DeviceArray:
"""Converts a numpy array or number (int, float, etc.) to a Jax DeviceArray."""
assert jnp is not None
return jnp.array(value)
@numpy_to_jax.register(abc.Mapping)
def _mapping_numpy_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]:
"""Converts a dictionary of numpy arrays to a mapping of Jax DeviceArrays."""
return type(value)(**{k: numpy_to_jax(v) for k, v in value.items()})
@numpy_to_jax.register(abc.Mapping)
def _mapping_numpy_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]:
"""Converts a dictionary of numpy arrays to a mapping of Jax DeviceArrays."""
return type(value)(**{k: numpy_to_jax(v) for k, v in value.items()})
@numpy_to_jax.register(abc.Iterable)
def _iterable_numpy_to_jax(
value: Iterable[np.ndarray | Any],
) -> Iterable[jnp.DeviceArray | Any]:
"""Converts an Iterable from Numpy Arrays to an iterable of Jax DeviceArrays."""
return type(value)(numpy_to_jax(v) for v in value)
@numpy_to_jax.register(abc.Iterable)
def _iterable_numpy_to_jax(
value: Iterable[np.ndarray | Any],
) -> Iterable[jnp.DeviceArray | Any]:
"""Converts an Iterable from Numpy Arrays to an iterable of Jax DeviceArrays."""
return type(value)(numpy_to_jax(v) for v in value)
@functools.singledispatch
def jax_to_numpy(value: Any) -> Any:
"""Converts a value to a numpy array."""
if jnp is None:
raise DependencyNotInstalled(
"Jax is not installed therefore cannot call `jax_to_numpy`, run `pip install gymnasium[jax]`"
)
else:
raise Exception(
f"No known conversion for Jax type ({type(value)}) to NumPy registered. Report as issue on github."
)
raise Exception(
f"No known conversion for Jax type ({type(value)}) to NumPy registered. Report as issue on github."
)
if jnp is not None:
@jax_to_numpy.register(jnp.DeviceArray)
def _devicearray_jax_to_numpy(value: jnp.DeviceArray) -> np.ndarray:
"""Converts a Jax DeviceArray to a numpy array."""
return np.array(value)
@jax_to_numpy.register(jnp.DeviceArray)
def _devicearray_jax_to_numpy(value: jnp.DeviceArray) -> np.ndarray:
"""Converts a Jax DeviceArray to a numpy array."""
return np.array(value)
@jax_to_numpy.register(abc.Mapping)
def _mapping_jax_to_numpy(
value: Mapping[str, jnp.DeviceArray | Any]
) -> Mapping[str, np.ndarray | Any]:
"""Converts a dictionary of Jax DeviceArrays to a mapping of numpy arrays."""
return type(value)(**{k: jax_to_numpy(v) for k, v in value.items()})
@jax_to_numpy.register(abc.Mapping)
def _mapping_jax_to_numpy(
value: Mapping[str, jnp.DeviceArray | Any]
) -> Mapping[str, np.ndarray | Any]:
"""Converts a dictionary of Jax DeviceArrays to a mapping of numpy arrays."""
return type(value)(**{k: jax_to_numpy(v) for k, v in value.items()})
@jax_to_numpy.register(abc.Iterable)
def _iterable_jax_to_numpy(
value: Iterable[np.ndarray | Any],
) -> Iterable[jnp.DeviceArray | Any]:
"""Converts an Iterable from Numpy arrays to an iterable of Jax DeviceArrays."""
return type(value)(jax_to_numpy(v) for v in value)
@jax_to_numpy.register(abc.Iterable)
def _iterable_jax_to_numpy(
value: Iterable[np.ndarray | Any],
) -> Iterable[jnp.DeviceArray | Any]:
"""Converts an Iterable from Numpy arrays to an iterable of Jax DeviceArrays."""
return type(value)(jax_to_numpy(v) for v in value)
class JaxToNumpyV0(

View File

@@ -0,0 +1,178 @@
# This wrapper will convert torch inputs for the actions and observations to Jax arrays
# for an underlying Jax environment then convert the return observations from Jax arrays
# back to torch tensors.
#
# Functionality for converting between torch and jax types originally copied from
# https://github.com/google/brax/blob/9d6b7ced2a13da0d074b5e9fbd3aad8311e26997/brax/io/torch.py
# Under the Apache 2.0 license. Copyright is held by the authors
"""Helper functions and wrapper class for converting between PyTorch and Jax."""
from __future__ import annotations
import functools
import numbers
from collections import abc
from typing import Any, Iterable, Mapping, SupportsFloat, Union
import gymnasium as gym
from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType
from gymnasium.error import DependencyNotInstalled
from gymnasium.experimental.wrappers.jax_to_numpy import jax_to_numpy
try:
import jax.numpy as jnp
from jax import dlpack as jax_dlpack
except ImportError:
raise DependencyNotInstalled(
"Jax is not installed therefore cannot call `torch_to_jax`, run `pip install gymnasium[jax]`"
)
try:
import torch
from torch.utils import dlpack as torch_dlpack
Device = Union[str, torch.device]
except ImportError:
raise DependencyNotInstalled(
"Torch is not installed therefore cannot call `torch_to_jax`, run `pip install torch`"
)
__all__ = ["jax_to_torch", "torch_to_jax", "JaxToTorchV0"]
@functools.singledispatch
def torch_to_jax(value: Any) -> Any:
    """Fallback of the ``torch_to_jax`` dispatcher, reached for unregistered types.

    Type-specific conversions are attached through ``torch_to_jax.register``.

    Args:
        value: The value that has no registered conversion

    Raises:
        Exception: Always, naming the type of ``value``
    """
    unsupported_type = type(value)
    raise Exception(
        f"No known conversion for Torch type ({unsupported_type}) to Jax registered. Report as issue on github."
    )
@torch_to_jax.register(numbers.Number)
def _number_torch_to_jax(value: numbers.Number) -> Any:
    """Wrap a python number (int, float, complex) into a jax array.

    Args:
        value: The number to convert

    Returns:
        The result of ``jnp.array(value)``
    """
    jax_value = jnp.array(value)
    return jax_value
@torch_to_jax.register(torch.Tensor)
def _tensor_torch_to_jax(value: torch.Tensor) -> jnp.DeviceArray:
    """Hand a PyTorch Tensor over to Jax through DLPack.

    Args:
        value: The PyTorch Tensor to convert

    Returns:
        The object produced by ``jax.dlpack.from_dlpack`` for the tensor
    """
    # Export the tensor from torch, then import it on the jax side.
    torch_capsule = torch_dlpack.to_dlpack(value)  # pyright: ignore[reportPrivateImportUsage]
    return jax_dlpack.from_dlpack(torch_capsule)  # pyright: ignore[reportPrivateImportUsage]
@torch_to_jax.register(abc.Mapping)
def _mapping_torch_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]:
    """Convert every value of a mapping with ``torch_to_jax``, keeping the keys.

    Args:
        value: The mapping whose values should be converted

    Returns:
        A new instance of ``type(value)`` built from the converted items as keyword arguments
    """
    converted = {key: torch_to_jax(item) for key, item in value.items()}
    return type(value)(**converted)
@torch_to_jax.register(abc.Iterable)
def _iterable_torch_to_jax(value: Iterable[Any]) -> Iterable[Any]:
    """Converts an Iterable from PyTorch Tensors to an iterable of Jax DeviceArrays.

    Args:
        value: The iterable whose elements should be converted with ``torch_to_jax``

    Returns:
        A new instance of ``type(value)`` holding the converted elements
    """
    converted = (torch_to_jax(v) for v in value)
    # namedtuple classes take one positional argument per field, so passing a
    # single generator to the constructor raises a TypeError; `_make` is the
    # namedtuple classmethod that builds an instance from an iterable.
    if hasattr(value, "_make"):
        return type(value)._make(converted)
    return type(value)(converted)
@functools.singledispatch
def jax_to_torch(value: Any, device: Device | None = None) -> Any:
    """Fallback of the ``jax_to_torch`` dispatcher, reached for unregistered types.

    Type-specific conversions are attached through ``jax_to_torch.register``.

    Args:
        value: The value that has no registered conversion
        device: Unused by the fallback; accepted so that all implementations share one signature

    Raises:
        Exception: Always, naming the type of ``value``
    """
    unsupported_type = type(value)
    raise Exception(
        f"No known conversion for Jax type ({unsupported_type}) to PyTorch registered. Report as issue on github."
    )
@jax_to_torch.register(jnp.DeviceArray)
def _devicearray_jax_to_torch(
    value: jnp.DeviceArray, device: Device | None = None
) -> torch.Tensor:
    """Converts a Jax DeviceArray into a PyTorch Tensor through DLPack.

    Args:
        value: The Jax DeviceArray to convert
        device: Optional device the resulting tensor is moved to

    Returns:
        The PyTorch Tensor, moved to ``device`` when one is given
    """
    # No `is not None` guard on the dlpack modules: this module raises
    # `DependencyNotInstalled` at import time if jax or torch is missing.
    dlpack = jax_dlpack.to_dlpack(value)  # pyright: ignore[reportPrivateImportUsage]
    tensor = torch_dlpack.from_dlpack(dlpack)
    if device:
        return tensor.to(device=device)
    return tensor
@jax_to_torch.register(abc.Mapping)
def _jax_mapping_to_torch(
    value: Mapping[str, Any], device: Device | None = None
) -> Mapping[str, Any]:
    """Convert every value of a mapping with ``jax_to_torch``, keeping the keys.

    Args:
        value: The mapping whose values should be converted
        device: Forwarded to ``jax_to_torch`` for every value

    Returns:
        A new instance of ``type(value)`` built from the converted items as keyword arguments
    """
    converted = {key: jax_to_torch(item, device) for key, item in value.items()}
    return type(value)(**converted)
@jax_to_torch.register(abc.Iterable)
def _jax_iterable_to_torch(
    value: Iterable[Any], device: Device | None = None
) -> Iterable[Any]:
    """Converts an Iterable from Jax DeviceArrays to an iterable of PyTorch Tensors.

    Args:
        value: The iterable whose elements should be converted with ``jax_to_torch``
        device: Forwarded to ``jax_to_torch`` for every element

    Returns:
        A new instance of ``type(value)`` holding the converted elements
    """
    converted = (jax_to_torch(v, device) for v in value)
    # namedtuple classes take one positional argument per field, so passing a
    # single generator to the constructor raises a TypeError; `_make` is the
    # namedtuple classmethod that builds an instance from an iterable.
    if hasattr(value, "_make"):
        return type(value)._make(converted)
    return type(value)(converted)
class JaxToTorchV0(gym.Wrapper, gym.utils.RecordConstructorArgs):
    """Exposes a jax-based environment through a PyTorch Tensor interface.

    Actions are passed through ``torch_to_jax`` before reaching the wrapped
    environment, and observations and info are passed through ``jax_to_torch``
    on the way back.

    Note:
        ``render`` output goes through ``jax_to_numpy`` rather than being converted to a PyTorch Tensor.
    """

    def __init__(self, env: gym.Env, device: Device | None = None):
        """Wraps ``env`` so that its inputs and outputs are PyTorch tensors.

        Args:
            env: The Jax-based environment to wrap
            device: The device the torch Tensors should be moved to
        """
        gym.utils.RecordConstructorArgs.__init__(self, device=device)
        gym.Wrapper.__init__(self, env)

        self.device: Device | None = device

    def step(
        self, action: WrapperActType
    ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
        """Steps the wrapped environment with ``action`` converted via ``torch_to_jax``.

        Args:
            action: The action to perform as a PyTorch Tensor

        Returns:
            The next observation, reward, termination, truncation, and extra info
        """
        step_result = self.env.step(torch_to_jax(action))
        obs, reward, terminated, truncated, info = step_result

        # Conversions are performed in the same order as the returned tuple.
        torch_obs = jax_to_torch(obs, self.device)
        reward = float(reward)
        terminated = bool(terminated)
        truncated = bool(truncated)
        torch_info = jax_to_torch(info, self.device)
        return torch_obs, reward, terminated, truncated, torch_info

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[WrapperObsType, dict[str, Any]]:
        """Resets the wrapped environment and converts what it returns via ``jax_to_torch``.

        Args:
            seed: The seed for resetting the environment
            options: The options for resetting the environment, these are converted to jax arrays.

        Returns:
            PyTorch-based observations and info
        """
        # Falsy options (``None`` or empty) are forwarded untouched.
        jax_options = torch_to_jax(options) if options else options
        reset_result = self.env.reset(seed=seed, options=jax_options)
        return jax_to_torch(reset_result, self.device)

    def render(self) -> RenderFrame | list[RenderFrame] | None:
        """Returns the wrapped environment's render output passed through ``jax_to_numpy``."""
        frames = self.env.render()
        return jax_to_numpy(frames)

View File

@@ -18,80 +18,73 @@ try:
Device = Union[str, torch.device]
except ImportError:
torch, Device = None, None
raise DependencyNotInstalled(
"Torch is not installed therefore cannot call `torch_to_numpy`, run `pip install torch`"
)
__all__ = ["torch_to_numpy", "numpy_to_torch", "NumpyToTorchV0"]
@functools.singledispatch
def torch_to_numpy(value: Any) -> Any:
"""Converts a PyTorch Tensor into a NumPy Array."""
if torch is None:
raise DependencyNotInstalled(
"Torch is not installed therefore cannot call `torch_to_numpy`, run `pip install torch`"
)
else:
raise Exception(
f"No known conversion for Torch type ({type(value)}) to NumPy registered. Report as issue on github."
)
raise Exception(
f"No known conversion for Torch type ({type(value)}) to NumPy registered. Report as issue on github."
)
if torch is not None:
@torch_to_numpy.register(numbers.Number)
@torch_to_numpy.register(torch.Tensor)
def _number_torch_to_numpy(value: numbers.Number | torch.Tensor) -> Any:
"""Convert a python number (int, float, complex) and torch.Tensor to a numpy array."""
return np.array(value)
@torch_to_numpy.register(numbers.Number)
@torch_to_numpy.register(torch.Tensor)
def _number_torch_to_numpy(value: numbers.Number | torch.Tensor) -> Any:
"""Convert a python number (int, float, complex) and torch.Tensor to a numpy array."""
return np.array(value)
@torch_to_numpy.register(abc.Mapping)
def _mapping_torch_to_numpy(value: Mapping[str, Any]) -> Mapping[str, Any]:
"""Converts a mapping of PyTorch Tensors into a Dictionary of Jax DeviceArrays."""
return type(value)(**{k: torch_to_numpy(v) for k, v in value.items()})
@torch_to_numpy.register(abc.Mapping)
def _mapping_torch_to_numpy(value: Mapping[str, Any]) -> Mapping[str, Any]:
"""Converts a mapping of PyTorch Tensors into a Dictionary of Jax DeviceArrays."""
return type(value)(**{k: torch_to_numpy(v) for k, v in value.items()})
@torch_to_numpy.register(abc.Iterable)
def _iterable_torch_to_numpy(value: Iterable[Any]) -> Iterable[Any]:
"""Converts an Iterable from PyTorch Tensors to an iterable of Jax DeviceArrays."""
return type(value)(torch_to_numpy(v) for v in value)
@torch_to_numpy.register(abc.Iterable)
def _iterable_torch_to_numpy(value: Iterable[Any]) -> Iterable[Any]:
"""Converts an Iterable from PyTorch Tensors to an iterable of Jax DeviceArrays."""
return type(value)(torch_to_numpy(v) for v in value)
@functools.singledispatch
def numpy_to_torch(value: Any, device: Device | None = None) -> Any:
"""Converts a Jax DeviceArray into a PyTorch Tensor."""
if torch is None:
raise DependencyNotInstalled(
"Torch is not installed therefore cannot call `numpy_to_torch`, run `pip install torch`"
)
else:
raise Exception(
f"No known conversion for NumPy type ({type(value)}) to PyTorch registered. Report as issue on github."
)
raise Exception(
f"No known conversion for NumPy type ({type(value)}) to PyTorch registered. Report as issue on github."
)
if torch is not None:
@numpy_to_torch.register(np.ndarray)
def _numpy_to_torch(value: np.ndarray, device: Device | None = None) -> torch.Tensor:
"""Converts a Jax DeviceArray into a PyTorch Tensor."""
assert torch is not None
tensor = torch.tensor(value)
if device:
return tensor.to(device=device)
return tensor
@numpy_to_torch.register(np.ndarray)
def _numpy_to_torch(
value: np.ndarray, device: Device | None = None
) -> torch.Tensor:
"""Converts a Jax DeviceArray into a PyTorch Tensor."""
assert torch is not None
tensor = torch.tensor(value)
if device:
return tensor.to(device=device)
return tensor
@numpy_to_torch.register(abc.Mapping)
def _numpy_mapping_to_torch(
value: Mapping[str, Any], device: Device | None = None
) -> Mapping[str, Any]:
"""Converts a mapping of Jax DeviceArrays into a Dictionary of PyTorch Tensors."""
return type(value)(**{k: numpy_to_torch(v, device) for k, v in value.items()})
@numpy_to_torch.register(abc.Mapping)
def _numpy_mapping_to_torch(
value: Mapping[str, Any], device: Device | None = None
) -> Mapping[str, Any]:
"""Converts a mapping of Jax DeviceArrays into a Dictionary of PyTorch Tensors."""
return type(value)(**{k: numpy_to_torch(v, device) for k, v in value.items()})
@numpy_to_torch.register(abc.Iterable)
def _numpy_iterable_to_torch(
value: Iterable[Any], device: Device | None = None
) -> Iterable[Any]:
"""Converts an Iterable from Jax DeviceArrays to an iterable of PyTorch Tensors."""
return type(value)(numpy_to_torch(v, device) for v in value)
@numpy_to_torch.register(abc.Iterable)
def _numpy_iterable_to_torch(
value: Iterable[Any], device: Device | None = None
) -> Iterable[Any]:
"""Converts an Iterable from Jax DeviceArrays to an iterable of PyTorch Tensors."""
return type(value)(numpy_to_torch(v, device) for v in value)
class NumpyToTorchV0(gym.Wrapper, gym.utils.RecordConstructorArgs):
@@ -110,11 +103,6 @@ class NumpyToTorchV0(gym.Wrapper, gym.utils.RecordConstructorArgs):
env: The Jax-based environment to wrap
device: The device the torch Tensors should be moved to
"""
if torch is None:
raise DependencyNotInstalled(
"torch is not installed, run `pip install torch`"
)
gym.utils.RecordConstructorArgs.__init__(self, device=device)
gym.Wrapper.__init__(self, env)

View File

@@ -0,0 +1,23 @@
"""Tests that all shortened imports for wrappers all work."""
import pytest
import gymnasium
from gymnasium.experimental.wrappers import (
_wrapper_to_class, # pyright: ignore[reportPrivateUsage]
)
from gymnasium.experimental.wrappers import __all__
def test_all_wrapper_shorten():
    """Test that the wrappers in `__all__` are exactly the keys of the `_wrapper_to_class` conversion."""
    lazily_loaded = set(_wrapper_to_class)
    exported = set(__all__)
    # `vector` is listed in `__all__` but has no entry in `_wrapper_to_class`.
    exported.remove("vector")
    assert exported == lazily_loaded
@pytest.mark.parametrize("wrapper_name", __all__)
def test_all_wrappers_shortened(wrapper_name):
    """Check that every wrapper named in `__all__` can be fetched from the wrappers package."""
    # `vector` is listed in `__all__` but is not checked here.
    if wrapper_name == "vector":
        return

    wrapper = getattr(gymnasium.experimental.wrappers, wrapper_name)
    assert wrapper is not None

View File

@@ -4,7 +4,7 @@ import jax.numpy as jnp
import numpy as np
import pytest
from gymnasium.experimental.wrappers.conversion.jax_to_numpy import (
from gymnasium.experimental.wrappers.jax_to_numpy import (
JaxToNumpyV0,
jax_to_numpy,
numpy_to_jax,

View File

@@ -5,7 +5,7 @@ import numpy as np
import pytest
import torch
from gymnasium.experimental.wrappers.conversion.jax_to_torch import (
from gymnasium.experimental.wrappers.jax_to_torch import (
JaxToTorchV0,
jax_to_torch,
torch_to_jax,