mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-07-31 05:44:31 +00:00
Add __get_attr__
for experimental wrappers for generic solution to optimise extra module imports (#392)
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
"""Root `__init__` of the gymnasium module setting the `__all__` of gymnasium modules."""
|
||||
# isort: skip_file
|
||||
# pyright: reportUnsupportedDunderAll=false
|
||||
|
||||
from gymnasium.core import (
|
||||
Env,
|
||||
@@ -17,7 +18,9 @@ from gymnasium.envs.registration import (
|
||||
pprint_registry,
|
||||
make_vec,
|
||||
)
|
||||
from gymnasium import envs, spaces, utils, vector, wrappers, error, logger
|
||||
|
||||
# necessary for `envs.__init__` which registers all gymnasium environments and loads plugins
|
||||
from gymnasium import envs
|
||||
|
||||
|
||||
__all__ = [
|
||||
@@ -37,6 +40,7 @@ __all__ = [
|
||||
"pprint_registry",
|
||||
# module folders
|
||||
"envs",
|
||||
"experimental",
|
||||
"spaces",
|
||||
"utils",
|
||||
"vector",
|
||||
|
@@ -1,7 +1,7 @@
|
||||
"""Root __init__ of the gym experimental wrappers."""
|
||||
|
||||
|
||||
from gymnasium.experimental import functional
|
||||
from gymnasium.experimental import functional, wrappers
|
||||
from gymnasium.experimental.functional import FuncEnv
|
||||
from gymnasium.experimental.vector.async_vector_env import AsyncVectorEnv
|
||||
from gymnasium.experimental.vector.sync_vector_env import SyncVectorEnv
|
||||
@@ -17,4 +17,6 @@ __all__ = [
|
||||
"VectorWrapper",
|
||||
"SyncVectorEnv",
|
||||
"AsyncVectorEnv",
|
||||
# wrappers
|
||||
"wrappers",
|
||||
]
|
||||
|
@@ -11,7 +11,7 @@ import numpy as np
|
||||
import gymnasium as gym
|
||||
from gymnasium.envs.registration import EnvSpec
|
||||
from gymnasium.experimental.functional import ActType, FuncEnv, StateType
|
||||
from gymnasium.experimental.wrappers.conversion.jax_to_numpy import jax_to_numpy
|
||||
from gymnasium.experimental.wrappers.jax_to_numpy import jax_to_numpy
|
||||
from gymnasium.utils import seeding
|
||||
from gymnasium.vector.utils import batch_space
|
||||
|
||||
|
@@ -1,53 +1,11 @@
|
||||
"""Experimental Wrappers."""
|
||||
# isort: skip_file
|
||||
"""`__init__` for experimental wrappers, to avoid loading the wrappers if unnecessary, we can hack python."""
|
||||
# pyright: reportUnsupportedDunderAll=false
|
||||
|
||||
from gymnasium.experimental.wrappers.lambda_action import (
|
||||
LambdaActionV0,
|
||||
ClipActionV0,
|
||||
RescaleActionV0,
|
||||
)
|
||||
from gymnasium.experimental.wrappers.lambda_observations import (
|
||||
LambdaObservationV0,
|
||||
FilterObservationV0,
|
||||
FlattenObservationV0,
|
||||
GrayscaleObservationV0,
|
||||
ResizeObservationV0,
|
||||
ReshapeObservationV0,
|
||||
RescaleObservationV0,
|
||||
DtypeObservationV0,
|
||||
PixelObservationV0,
|
||||
NormalizeObservationV0,
|
||||
)
|
||||
from gymnasium.experimental.wrappers.lambda_reward import (
|
||||
ClipRewardV0,
|
||||
LambdaRewardV0,
|
||||
NormalizeRewardV0,
|
||||
)
|
||||
from gymnasium.experimental.wrappers.stateful_action import StickyActionV0
|
||||
from gymnasium.experimental.wrappers.stateful_observation import (
|
||||
TimeAwareObservationV0,
|
||||
DelayObservationV0,
|
||||
FrameStackObservationV0,
|
||||
)
|
||||
from gymnasium.experimental.wrappers.atari_preprocessing import AtariPreprocessingV0
|
||||
from gymnasium.experimental.wrappers.common import (
|
||||
PassiveEnvCheckerV0,
|
||||
OrderEnforcingV0,
|
||||
AutoresetV0,
|
||||
RecordEpisodeStatisticsV0,
|
||||
)
|
||||
from gymnasium.experimental.wrappers.rendering import (
|
||||
RenderCollectionV0,
|
||||
RecordVideoV0,
|
||||
HumanRenderingV0,
|
||||
)
|
||||
import importlib
|
||||
|
||||
from gymnasium.experimental.wrappers.vector import (
|
||||
VectorRecordEpisodeStatistics,
|
||||
VectorListInfo,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"vector",
|
||||
# --- Observation wrappers ---
|
||||
"LambdaObservationV0",
|
||||
"FilterObservationV0",
|
||||
@@ -82,7 +40,80 @@ __all__ = [
|
||||
"RenderCollectionV0",
|
||||
"RecordVideoV0",
|
||||
"HumanRenderingV0",
|
||||
# --- Vector ---
|
||||
"VectorRecordEpisodeStatistics",
|
||||
"VectorListInfo",
|
||||
# --- Conversion ---
|
||||
"JaxToNumpyV0",
|
||||
"JaxToTorchV0",
|
||||
"NumpyToTorchV0",
|
||||
]
|
||||
|
||||
|
||||
# Maps every lazily exported wrapper class to the submodule of
# `gymnasium.experimental.wrappers` that defines it.
_wrapper_to_class = {
    # lambda_action.py
    "LambdaActionV0": "lambda_action",
    "ClipActionV0": "lambda_action",
    "RescaleActionV0": "lambda_action",
    # lambda_observations.py
    "LambdaObservationV0": "lambda_observations",
    "FilterObservationV0": "lambda_observations",
    "FlattenObservationV0": "lambda_observations",
    "GrayscaleObservationV0": "lambda_observations",
    "ResizeObservationV0": "lambda_observations",
    "ReshapeObservationV0": "lambda_observations",
    "RescaleObservationV0": "lambda_observations",
    "DtypeObservationV0": "lambda_observations",
    "PixelObservationV0": "lambda_observations",
    "NormalizeObservationV0": "lambda_observations",
    # lambda_reward.py
    "ClipRewardV0": "lambda_reward",
    "LambdaRewardV0": "lambda_reward",
    "NormalizeRewardV0": "lambda_reward",
    # stateful_action
    "StickyActionV0": "stateful_action",
    # stateful_observation
    "TimeAwareObservationV0": "stateful_observation",
    "DelayObservationV0": "stateful_observation",
    "FrameStackObservationV0": "stateful_observation",
    # atari_preprocessing
    "AtariPreprocessingV0": "atari_preprocessing",
    # common
    "PassiveEnvCheckerV0": "common",
    "OrderEnforcingV0": "common",
    "AutoresetV0": "common",
    "RecordEpisodeStatisticsV0": "common",
    # rendering
    "RenderCollectionV0": "rendering",
    "RecordVideoV0": "rendering",
    "HumanRenderingV0": "rendering",
    # jax_to_numpy
    "JaxToNumpyV0": "jax_to_numpy",
    # "jax_to_numpy": "jax_to_numpy",
    # "numpy_to_jax": "jax_to_numpy",
    # jax_to_torch
    "JaxToTorchV0": "jax_to_torch",
    # "jax_to_torch": "jax_to_torch",
    # "torch_to_jax": "jax_to_torch",
    # numpy_to_torch
    "NumpyToTorchV0": "numpy_to_torch",
    # "torch_to_numpy": "numpy_to_torch",
    # "numpy_to_torch": "numpy_to_torch",
}


def __getattr__(name: str):
    """Lazily import and return the experimental wrapper called ``name``.

    Python calls this module-level hook only when normal attribute lookup on
    the module fails, so a wrapper's submodule (and the extra modules it
    imports) is loaded the first time the wrapper is requested instead of on
    `import gymnasium`.

    Args:
        name: The name of a wrapper to load

    Returns:
        The attribute called ``name`` of the submodule listed for it in
        ``_wrapper_to_class``

    Raises:
        AttributeError: If ``name`` is not a key of ``_wrapper_to_class``. When
            ``name`` differs from a known wrapper only by its trailing digits,
            the message also names the known wrapper(s).
    """
    module_name = _wrapper_to_class.get(name)
    if module_name is not None:
        module = importlib.import_module(
            f"gymnasium.experimental.wrappers.{module_name}"
        )
        return getattr(module, name)

    message = f"module {__name__!r} has no attribute {name!r}"

    # add helpful error message if version number has changed
    base_name = name.rstrip("0123456789")
    if base_name != name:
        alternatives = [
            wrapper
            for wrapper in _wrapper_to_class
            if wrapper.rstrip("0123456789") == base_name
        ]
        if alternatives:
            suggestions = " or ".join(repr(wrapper) for wrapper in alternatives)
            message += f", did you mean {suggestions}?"

    # Must stay an AttributeError: `hasattr`, `getattr(..., default)` and
    # `from module import *` all rely on it to detect a missing attribute.
    raise AttributeError(message)
|
||||
|
@@ -8,7 +8,9 @@ from gymnasium.spaces import Box
|
||||
try:
|
||||
import cv2
|
||||
except ImportError:
|
||||
cv2 = None
|
||||
raise gym.error.DependencyNotInstalled(
|
||||
"opencv-python package not installed, run `pip install gymnasium[other]` to get dependencies for atari"
|
||||
)
|
||||
|
||||
|
||||
class AtariPreprocessingV0(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||
@@ -72,10 +74,6 @@ class AtariPreprocessingV0(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||
)
|
||||
gym.Wrapper.__init__(self, env)
|
||||
|
||||
if cv2 is None:
|
||||
raise gym.error.DependencyNotInstalled(
|
||||
"opencv-python package not installed, run `pip install gymnasium[other]` to get dependencies for atari"
|
||||
)
|
||||
assert frame_skip > 0
|
||||
assert screen_size > 0
|
||||
assert noop_max >= 0
|
||||
|
@@ -1 +0,0 @@
|
||||
"""This is deliberately empty to avoid introducing redundant imports -- import each submodule individually."""
|
@@ -1,205 +0,0 @@
|
||||
# This wrapper will convert torch inputs for the actions and observations to Jax arrays
|
||||
# for an underlying Jax environment then convert the return observations from Jax arrays
|
||||
# back to torch tensors.
|
||||
#
|
||||
# Functionality for converting between torch and jax types originally copied from
|
||||
# https://github.com/google/brax/blob/9d6b7ced2a13da0d074b5e9fbd3aad8311e26997/brax/io/torch.py
|
||||
# Under the Apache 2.0 license. Copyright is held by the authors
|
||||
|
||||
"""Helper functions and wrapper class for converting between PyTorch and Jax."""
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import numbers
|
||||
from collections import abc
|
||||
from typing import Any, Iterable, Mapping, SupportsFloat, Union
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType
|
||||
from gymnasium.error import DependencyNotInstalled
|
||||
from gymnasium.experimental.wrappers.conversion.jax_to_numpy import jax_to_numpy
|
||||
|
||||
|
||||
try:
|
||||
import jax.numpy as jnp
|
||||
from jax import dlpack as jax_dlpack
|
||||
except ImportError:
|
||||
jnp, jax_dlpack = None, None
|
||||
|
||||
try:
|
||||
import torch
|
||||
from torch.utils import dlpack as torch_dlpack
|
||||
|
||||
Device = Union[str, torch.device]
|
||||
except ImportError:
|
||||
torch, torch_dlpack, Device = None, None, None
|
||||
|
||||
|
||||
@functools.singledispatch
|
||||
def torch_to_jax(value: Any) -> Any:
|
||||
"""Converts a PyTorch Tensor into a Jax DeviceArray."""
|
||||
if torch is None:
|
||||
raise DependencyNotInstalled(
|
||||
"Torch is not installed therefore cannot call `torch_to_jax`, run `pip install torch`"
|
||||
)
|
||||
elif jnp is None:
|
||||
raise DependencyNotInstalled(
|
||||
"Jax is not installed therefore cannot call `torch_to_jax`, run `pip install gymnasium[jax]`"
|
||||
)
|
||||
else:
|
||||
raise Exception(
|
||||
f"No known conversion for Torch type ({type(value)}) to Jax registered. Report as issue on github."
|
||||
)
|
||||
|
||||
|
||||
if torch is not None and jnp is not None:
|
||||
|
||||
@torch_to_jax.register(numbers.Number)
|
||||
def _number_torch_to_jax(value: numbers.Number) -> Any:
|
||||
"""Convert a python number (int, float, complex) to a jax array."""
|
||||
assert jnp is not None
|
||||
return jnp.array(value)
|
||||
|
||||
@torch_to_jax.register(torch.Tensor)
|
||||
def _tensor_torch_to_jax(value: torch.Tensor) -> jnp.DeviceArray:
|
||||
"""Converts a PyTorch Tensor into a Jax DeviceArray."""
|
||||
assert torch_dlpack is not None and jax_dlpack is not None
|
||||
tensor = torch_dlpack.to_dlpack( # pyright: ignore[reportPrivateImportUsage]
|
||||
value
|
||||
)
|
||||
tensor = jax_dlpack.from_dlpack( # pyright: ignore[reportPrivateImportUsage]
|
||||
tensor
|
||||
)
|
||||
return tensor
|
||||
|
||||
@torch_to_jax.register(abc.Mapping)
|
||||
def _mapping_torch_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||
"""Converts a mapping of PyTorch Tensors into a Dictionary of Jax DeviceArrays."""
|
||||
return type(value)(**{k: torch_to_jax(v) for k, v in value.items()})
|
||||
|
||||
@torch_to_jax.register(abc.Iterable)
|
||||
def _iterable_torch_to_jax(value: Iterable[Any]) -> Iterable[Any]:
|
||||
"""Converts an Iterable from PyTorch Tensors to an iterable of Jax DeviceArrays."""
|
||||
return type(value)(torch_to_jax(v) for v in value)
|
||||
|
||||
|
||||
@functools.singledispatch
|
||||
def jax_to_torch(value: Any, device: Device | None = None) -> Any:
|
||||
"""Converts a Jax DeviceArray into a PyTorch Tensor."""
|
||||
if torch is None:
|
||||
raise DependencyNotInstalled(
|
||||
"Torch is not installed therefore cannot call `jax_to_torch`, run `pip install torch`"
|
||||
)
|
||||
elif jnp is None:
|
||||
raise DependencyNotInstalled(
|
||||
"Jax is not installed therefore cannot call `jax_to_torch`, run `pip install gymnasium[jax]`"
|
||||
)
|
||||
else:
|
||||
raise Exception(
|
||||
f"No known conversion for Jax type ({type(value)}) to PyTorch registered. Report as issue on github."
|
||||
)
|
||||
|
||||
|
||||
if torch is not None and jnp is not None:
|
||||
|
||||
@jax_to_torch.register(jnp.DeviceArray)
|
||||
def _devicearray_jax_to_torch(
|
||||
value: jnp.DeviceArray, device: Device | None = None
|
||||
) -> torch.Tensor:
|
||||
"""Converts a Jax DeviceArray into a PyTorch Tensor."""
|
||||
assert jax_dlpack is not None and torch_dlpack is not None
|
||||
dlpack = jax_dlpack.to_dlpack( # pyright: ignore[reportPrivateImportUsage]
|
||||
value
|
||||
)
|
||||
tensor = torch_dlpack.from_dlpack(dlpack)
|
||||
if device:
|
||||
return tensor.to(device=device)
|
||||
return tensor
|
||||
|
||||
@jax_to_torch.register(abc.Mapping)
|
||||
def _jax_mapping_to_torch(
|
||||
value: Mapping[str, Any], device: Device | None = None
|
||||
) -> Mapping[str, Any]:
|
||||
"""Converts a mapping of Jax DeviceArrays into a Dictionary of PyTorch Tensors."""
|
||||
return type(value)(**{k: jax_to_torch(v, device) for k, v in value.items()})
|
||||
|
||||
@jax_to_torch.register(abc.Iterable)
|
||||
def _jax_iterable_to_torch(
|
||||
value: Iterable[Any], device: Device | None = None
|
||||
) -> Iterable[Any]:
|
||||
"""Converts an Iterable from Jax DeviceArrays to an iterable of PyTorch Tensors."""
|
||||
return type(value)(jax_to_torch(v, device) for v in value)
|
||||
|
||||
|
||||
class JaxToTorchV0(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||
"""Wraps a jax-based environment so that it can be interacted with through PyTorch Tensors.
|
||||
|
||||
Actions must be provided as PyTorch Tensors and observations will be returned as PyTorch Tensors.
|
||||
|
||||
Note:
|
||||
For ``rendered`` this is returned as a NumPy array not a pytorch Tensor.
|
||||
"""
|
||||
|
||||
def __init__(self, env: gym.Env, device: Device | None = None):
|
||||
"""Wrapper class to change inputs and outputs of environment to PyTorch tensors.
|
||||
|
||||
Args:
|
||||
env: The Jax-based environment to wrap
|
||||
device: The device the torch Tensors should be moved to
|
||||
"""
|
||||
if torch is None:
|
||||
raise DependencyNotInstalled(
|
||||
"torch is not installed, run `pip install torch`"
|
||||
)
|
||||
elif jnp is None:
|
||||
raise DependencyNotInstalled(
|
||||
"jax is not installed, run `pip install gymnasium[jax]`"
|
||||
)
|
||||
|
||||
gym.utils.RecordConstructorArgs.__init__(self, device=device)
|
||||
gym.Wrapper.__init__(self, env)
|
||||
|
||||
self.device: Device | None = device
|
||||
|
||||
def step(
|
||||
self, action: WrapperActType
|
||||
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
|
||||
"""Performs the given action within the environment.
|
||||
|
||||
Args:
|
||||
action: The action to perform as a PyTorch Tensor
|
||||
|
||||
Returns:
|
||||
The next observation, reward, termination, truncation, and extra info
|
||||
"""
|
||||
jax_action = torch_to_jax(action)
|
||||
obs, reward, terminated, truncated, info = self.env.step(jax_action)
|
||||
|
||||
return (
|
||||
jax_to_torch(obs, self.device),
|
||||
float(reward),
|
||||
bool(terminated),
|
||||
bool(truncated),
|
||||
jax_to_torch(info, self.device),
|
||||
)
|
||||
|
||||
def reset(
|
||||
self, *, seed: int | None = None, options: dict[str, Any] | None = None
|
||||
) -> tuple[WrapperObsType, dict[str, Any]]:
|
||||
"""Resets the environment returning PyTorch-based observation and info.
|
||||
|
||||
Args:
|
||||
seed: The seed for resetting the environment
|
||||
options: The options for resetting the environment, these are converted to jax arrays.
|
||||
|
||||
Returns:
|
||||
PyTorch-based observations and info
|
||||
"""
|
||||
if options:
|
||||
options = torch_to_jax(options)
|
||||
|
||||
return jax_to_torch(self.env.reset(seed=seed, options=options), self.device)
|
||||
|
||||
def render(self) -> RenderFrame | list[RenderFrame] | None:
|
||||
"""Returns the rendered frames as a NumPy array."""
|
||||
return jax_to_numpy(self.env.render())
|
@@ -16,80 +16,73 @@ from gymnasium.error import DependencyNotInstalled
|
||||
try:
|
||||
import jax.numpy as jnp
|
||||
except ImportError:
|
||||
# We handle the error internal to the relative functions
|
||||
jnp = None
|
||||
raise DependencyNotInstalled(
|
||||
"Jax is not installed therefore cannot call `numpy_to_jax`, run `pip install gymnasium[jax]`"
|
||||
)
|
||||
|
||||
__all__ = ["jax_to_numpy", "numpy_to_jax", "JaxToNumpyV0"]
|
||||
|
||||
|
||||
@functools.singledispatch
|
||||
def numpy_to_jax(value: Any) -> Any:
|
||||
"""Converts a value to a Jax DeviceArray."""
|
||||
if jnp is None:
|
||||
raise DependencyNotInstalled(
|
||||
"Jax is not installed therefore cannot call `numpy_to_jax`, run `pip install gymnasium[jax]`"
|
||||
)
|
||||
else:
|
||||
raise Exception(
|
||||
f"No known conversion for Numpy type ({type(value)}) to Jax registered. Report as issue on github."
|
||||
)
|
||||
raise Exception(
|
||||
f"No known conversion for Numpy type ({type(value)}) to Jax registered. Report as issue on github."
|
||||
)
|
||||
|
||||
|
||||
if jnp is not None:
|
||||
@numpy_to_jax.register(numbers.Number)
|
||||
@numpy_to_jax.register(np.ndarray)
|
||||
def _number_ndarray_numpy_to_jax(
|
||||
value: np.ndarray | numbers.Number,
|
||||
) -> jnp.DeviceArray:
|
||||
"""Converts a numpy array or number (int, float, etc.) to a Jax DeviceArray."""
|
||||
assert jnp is not None
|
||||
return jnp.array(value)
|
||||
|
||||
@numpy_to_jax.register(numbers.Number)
|
||||
@numpy_to_jax.register(np.ndarray)
|
||||
def _number_ndarray_numpy_to_jax(
|
||||
value: np.ndarray | numbers.Number,
|
||||
) -> jnp.DeviceArray:
|
||||
"""Converts a numpy array or number (int, float, etc.) to a Jax DeviceArray."""
|
||||
assert jnp is not None
|
||||
return jnp.array(value)
|
||||
|
||||
@numpy_to_jax.register(abc.Mapping)
|
||||
def _mapping_numpy_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||
"""Converts a dictionary of numpy arrays to a mapping of Jax DeviceArrays."""
|
||||
return type(value)(**{k: numpy_to_jax(v) for k, v in value.items()})
|
||||
@numpy_to_jax.register(abc.Mapping)
|
||||
def _mapping_numpy_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||
"""Converts a dictionary of numpy arrays to a mapping of Jax DeviceArrays."""
|
||||
return type(value)(**{k: numpy_to_jax(v) for k, v in value.items()})
|
||||
|
||||
@numpy_to_jax.register(abc.Iterable)
|
||||
def _iterable_numpy_to_jax(
|
||||
value: Iterable[np.ndarray | Any],
|
||||
) -> Iterable[jnp.DeviceArray | Any]:
|
||||
"""Converts an Iterable from Numpy Arrays to an iterable of Jax DeviceArrays."""
|
||||
return type(value)(numpy_to_jax(v) for v in value)
|
||||
|
||||
@numpy_to_jax.register(abc.Iterable)
|
||||
def _iterable_numpy_to_jax(
|
||||
value: Iterable[np.ndarray | Any],
|
||||
) -> Iterable[jnp.DeviceArray | Any]:
|
||||
"""Converts an Iterable from Numpy Arrays to an iterable of Jax DeviceArrays."""
|
||||
return type(value)(numpy_to_jax(v) for v in value)
|
||||
|
||||
|
||||
@functools.singledispatch
|
||||
def jax_to_numpy(value: Any) -> Any:
|
||||
"""Converts a value to a numpy array."""
|
||||
if jnp is None:
|
||||
raise DependencyNotInstalled(
|
||||
"Jax is not installed therefore cannot call `jax_to_numpy`, run `pip install gymnasium[jax]`"
|
||||
)
|
||||
else:
|
||||
raise Exception(
|
||||
f"No known conversion for Jax type ({type(value)}) to NumPy registered. Report as issue on github."
|
||||
)
|
||||
raise Exception(
|
||||
f"No known conversion for Jax type ({type(value)}) to NumPy registered. Report as issue on github."
|
||||
)
|
||||
|
||||
|
||||
if jnp is not None:
|
||||
@jax_to_numpy.register(jnp.DeviceArray)
|
||||
def _devicearray_jax_to_numpy(value: jnp.DeviceArray) -> np.ndarray:
|
||||
"""Converts a Jax DeviceArray to a numpy array."""
|
||||
return np.array(value)
|
||||
|
||||
@jax_to_numpy.register(jnp.DeviceArray)
|
||||
def _devicearray_jax_to_numpy(value: jnp.DeviceArray) -> np.ndarray:
|
||||
"""Converts a Jax DeviceArray to a numpy array."""
|
||||
return np.array(value)
|
||||
|
||||
@jax_to_numpy.register(abc.Mapping)
|
||||
def _mapping_jax_to_numpy(
|
||||
value: Mapping[str, jnp.DeviceArray | Any]
|
||||
) -> Mapping[str, np.ndarray | Any]:
|
||||
"""Converts a dictionary of Jax DeviceArrays to a mapping of numpy arrays."""
|
||||
return type(value)(**{k: jax_to_numpy(v) for k, v in value.items()})
|
||||
@jax_to_numpy.register(abc.Mapping)
|
||||
def _mapping_jax_to_numpy(
|
||||
value: Mapping[str, jnp.DeviceArray | Any]
|
||||
) -> Mapping[str, np.ndarray | Any]:
|
||||
"""Converts a dictionary of Jax DeviceArrays to a mapping of numpy arrays."""
|
||||
return type(value)(**{k: jax_to_numpy(v) for k, v in value.items()})
|
||||
|
||||
@jax_to_numpy.register(abc.Iterable)
|
||||
def _iterable_jax_to_numpy(
|
||||
value: Iterable[np.ndarray | Any],
|
||||
) -> Iterable[jnp.DeviceArray | Any]:
|
||||
"""Converts an Iterable from Numpy arrays to an iterable of Jax DeviceArrays."""
|
||||
return type(value)(jax_to_numpy(v) for v in value)
|
||||
|
||||
@jax_to_numpy.register(abc.Iterable)
|
||||
def _iterable_jax_to_numpy(
|
||||
value: Iterable[np.ndarray | Any],
|
||||
) -> Iterable[jnp.DeviceArray | Any]:
|
||||
"""Converts an Iterable from Numpy arrays to an iterable of Jax DeviceArrays."""
|
||||
return type(value)(jax_to_numpy(v) for v in value)
|
||||
|
||||
|
||||
class JaxToNumpyV0(
|
178
gymnasium/experimental/wrappers/jax_to_torch.py
Normal file
178
gymnasium/experimental/wrappers/jax_to_torch.py
Normal file
@@ -0,0 +1,178 @@
|
||||
# This wrapper will convert torch inputs for the actions and observations to Jax arrays
|
||||
# for an underlying Jax environment then convert the return observations from Jax arrays
|
||||
# back to torch tensors.
|
||||
#
|
||||
# Functionality for converting between torch and jax types originally copied from
|
||||
# https://github.com/google/brax/blob/9d6b7ced2a13da0d074b5e9fbd3aad8311e26997/brax/io/torch.py
|
||||
# Under the Apache 2.0 license. Copyright is held by the authors
|
||||
|
||||
"""Helper functions and wrapper class for converting between PyTorch and Jax."""
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import numbers
|
||||
from collections import abc
|
||||
from typing import Any, Iterable, Mapping, SupportsFloat, Union
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType
|
||||
from gymnasium.error import DependencyNotInstalled
|
||||
from gymnasium.experimental.wrappers.jax_to_numpy import jax_to_numpy
|
||||
|
||||
|
||||
try:
|
||||
import jax.numpy as jnp
|
||||
from jax import dlpack as jax_dlpack
|
||||
except ImportError:
|
||||
raise DependencyNotInstalled(
|
||||
"Jax is not installed therefore cannot call `torch_to_jax`, run `pip install gymnasium[jax]`"
|
||||
)
|
||||
|
||||
try:
|
||||
import torch
|
||||
from torch.utils import dlpack as torch_dlpack
|
||||
|
||||
Device = Union[str, torch.device]
|
||||
except ImportError:
|
||||
raise DependencyNotInstalled(
|
||||
"Torch is not installed therefore cannot call `torch_to_jax`, run `pip install torch`"
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["jax_to_torch", "torch_to_jax", "JaxToTorchV0"]
|
||||
|
||||
|
||||
@functools.singledispatch
def torch_to_jax(value: Any) -> Any:
    """Converts a PyTorch value into its Jax equivalent.

    This is the single-dispatch fallback: it only runs when no converter has
    been registered for the type of ``value``.

    Args:
        value: The value to convert

    Raises:
        Exception: Always, naming the type that has no registered conversion
    """
    unsupported_type = type(value)
    raise Exception(
        f"No known conversion for Torch type ({unsupported_type}) to Jax registered. Report as issue on github."
    )
|
||||
|
||||
|
||||
@torch_to_jax.register(numbers.Number)
def _number_torch_to_jax(value: numbers.Number) -> Any:
    """Wraps a python number (int, float, complex) in a jax array."""
    as_jax = jnp.array(value)
    return as_jax
|
||||
|
||||
|
||||
@torch_to_jax.register(torch.Tensor)
def _tensor_torch_to_jax(value: torch.Tensor) -> jnp.DeviceArray:
    """Hands a PyTorch Tensor over to Jax through the DLPack exchange functions.

    Args:
        value: The PyTorch Tensor to convert

    Returns:
        The object produced by ``jax.dlpack.from_dlpack`` for the tensor
    """
    capsule = torch_dlpack.to_dlpack(value)  # pyright: ignore[reportPrivateImportUsage]
    return jax_dlpack.from_dlpack(capsule)  # pyright: ignore[reportPrivateImportUsage]
|
||||
|
||||
|
||||
@torch_to_jax.register(abc.Mapping)
def _mapping_torch_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]:
    """Converts every value of a mapping with `torch_to_jax`, keeping the mapping type.

    Args:
        value: The mapping whose values are converted

    Returns:
        A new ``type(value)`` built from the converted items passed as keyword arguments
    """
    converted = {key: torch_to_jax(item) for key, item in value.items()}
    return type(value)(**converted)
|
||||
|
||||
|
||||
@torch_to_jax.register(abc.Iterable)
def _iterable_torch_to_jax(value: Iterable[Any]) -> Iterable[Any]:
    """Converts every element of an iterable with `torch_to_jax`, keeping the container type.

    Args:
        value: The iterable whose elements are converted

    Returns:
        A new ``type(value)`` holding the converted elements. ``str`` and
        ``bytes`` are returned unchanged.
    """
    # Text is iterable but iterating a `str` yields more `str`, so recursing
    # element-wise would never terminate; there is nothing to convert anyway.
    if isinstance(value, (str, bytes)):
        return value

    converted = (torch_to_jax(element) for element in value)

    # namedtuple classes take one positional argument per field, so they have
    # to be rebuilt through `_make`, which accepts a single iterable.
    if hasattr(value, "_make"):
        return type(value)._make(converted)
    return type(value)(converted)
|
||||
|
||||
|
||||
@functools.singledispatch
def jax_to_torch(value: Any, device: Device | None = None) -> Any:
    """Converts a Jax value into its PyTorch equivalent.

    This is the single-dispatch fallback: it only runs when no converter has
    been registered for the type of ``value``.

    Args:
        value: The value to convert
        device: Unused by the fallback; the registered converters pass it on
            when converting nested values or moving tensors

    Raises:
        Exception: Always, naming the type that has no registered conversion
    """
    unsupported_type = type(value)
    raise Exception(
        f"No known conversion for Jax type ({unsupported_type}) to PyTorch registered. Report as issue on github."
    )
|
||||
|
||||
|
||||
@jax_to_torch.register(jnp.DeviceArray)
def _devicearray_jax_to_torch(
    value: jnp.DeviceArray, device: Device | None = None
) -> torch.Tensor:
    """Hands a Jax DeviceArray over to PyTorch through the DLPack exchange functions.

    Args:
        value: The Jax array to convert
        device: If truthy, the resulting tensor is moved to this device

    Returns:
        The PyTorch Tensor built from ``value``
    """
    assert jax_dlpack is not None and torch_dlpack is not None
    capsule = jax_dlpack.to_dlpack(value)  # pyright: ignore[reportPrivateImportUsage]
    tensor = torch_dlpack.from_dlpack(capsule)
    # Only relocate the tensor when a device was actually requested
    return tensor.to(device=device) if device else tensor
|
||||
|
||||
|
||||
@jax_to_torch.register(abc.Mapping)
def _jax_mapping_to_torch(
    value: Mapping[str, Any], device: Device | None = None
) -> Mapping[str, Any]:
    """Converts every value of a mapping with `jax_to_torch`, keeping the mapping type.

    Args:
        value: The mapping whose values are converted
        device: Forwarded unchanged to `jax_to_torch` for every value

    Returns:
        A new ``type(value)`` built from the converted items passed as keyword arguments
    """
    converted = {key: jax_to_torch(item, device) for key, item in value.items()}
    return type(value)(**converted)
|
||||
|
||||
|
||||
@jax_to_torch.register(abc.Iterable)
def _jax_iterable_to_torch(
    value: Iterable[Any], device: Device | None = None
) -> Iterable[Any]:
    """Converts every element of an iterable with `jax_to_torch`, keeping the container type.

    Args:
        value: The iterable whose elements are converted
        device: Forwarded unchanged to `jax_to_torch` for every element

    Returns:
        A new ``type(value)`` holding the converted elements. ``str`` and
        ``bytes`` are returned unchanged.
    """
    # Text is iterable but iterating a `str` yields more `str`, so recursing
    # element-wise would never terminate; there is nothing to convert anyway.
    if isinstance(value, (str, bytes)):
        return value

    converted = (jax_to_torch(element, device) for element in value)

    # namedtuple classes take one positional argument per field, so they have
    # to be rebuilt through `_make`, which accepts a single iterable.
    if hasattr(value, "_make"):
        return type(value)._make(converted)
    return type(value)(converted)
|
||||
|
||||
|
||||
class JaxToTorchV0(gym.Wrapper, gym.utils.RecordConstructorArgs):
    """Lets a jax-based environment be driven with PyTorch Tensors.

    Actions are converted with `torch_to_jax` before they reach the wrapped
    environment, and observations and info are converted back with `jax_to_torch`.

    Note:
        Rendered frames are converted with `jax_to_numpy`, not to PyTorch Tensors.
    """

    def __init__(self, env: gym.Env, device: Device | None = None):
        """Wraps ``env`` so that its inputs and outputs are PyTorch tensors.

        Args:
            env: The Jax-based environment to wrap
            device: The device the torch Tensors should be moved to
        """
        gym.utils.RecordConstructorArgs.__init__(self, device=device)
        gym.Wrapper.__init__(self, env)

        self.device: Device | None = device

    def step(
        self, action: WrapperActType
    ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
        """Steps the wrapped environment with ``action`` converted to Jax.

        Args:
            action: The action to perform as a PyTorch Tensor

        Returns:
            The next observation and info converted with `jax_to_torch`, the
            reward as a ``float`` and termination / truncation as ``bool``
        """
        step_result = self.env.step(torch_to_jax(action))
        obs, reward, terminated, truncated, info = step_result

        torch_obs = jax_to_torch(obs, self.device)
        reward_value = float(reward)
        is_terminated = bool(terminated)
        is_truncated = bool(truncated)
        torch_info = jax_to_torch(info, self.device)
        return torch_obs, reward_value, is_terminated, is_truncated, torch_info

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[WrapperObsType, dict[str, Any]]:
        """Resets the wrapped environment and converts what it returns to PyTorch.

        Args:
            seed: The seed for resetting the environment
            options: The options for resetting the environment; a non-empty
                value is converted with `torch_to_jax` first

        Returns:
            The result of the wrapped ``reset`` converted with `jax_to_torch`
        """
        # Empty or missing options are forwarded untouched
        jax_options = torch_to_jax(options) if options else options
        reset_result = self.env.reset(seed=seed, options=jax_options)
        return jax_to_torch(reset_result, self.device)

    def render(self) -> RenderFrame | list[RenderFrame] | None:
        """Returns the wrapped environment's render output converted with `jax_to_numpy`."""
        frames = self.env.render()
        return jax_to_numpy(frames)
|
@@ -18,80 +18,73 @@ try:
|
||||
|
||||
Device = Union[str, torch.device]
|
||||
except ImportError:
|
||||
torch, Device = None, None
|
||||
raise DependencyNotInstalled(
|
||||
"Torch is not installed therefore cannot call `torch_to_numpy`, run `pip install torch`"
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["torch_to_numpy", "numpy_to_torch", "NumpyToTorchV0"]
|
||||
|
||||
|
||||
@functools.singledispatch
|
||||
def torch_to_numpy(value: Any) -> Any:
|
||||
"""Converts a PyTorch Tensor into a NumPy Array."""
|
||||
if torch is None:
|
||||
raise DependencyNotInstalled(
|
||||
"Torch is not installed therefore cannot call `torch_to_numpy`, run `pip install torch`"
|
||||
)
|
||||
else:
|
||||
raise Exception(
|
||||
f"No known conversion for Torch type ({type(value)}) to NumPy registered. Report as issue on github."
|
||||
)
|
||||
raise Exception(
|
||||
f"No known conversion for Torch type ({type(value)}) to NumPy registered. Report as issue on github."
|
||||
)
|
||||
|
||||
|
||||
if torch is not None:
|
||||
@torch_to_numpy.register(numbers.Number)
|
||||
@torch_to_numpy.register(torch.Tensor)
|
||||
def _number_torch_to_numpy(value: numbers.Number | torch.Tensor) -> Any:
|
||||
"""Convert a python number (int, float, complex) and torch.Tensor to a numpy array."""
|
||||
return np.array(value)
|
||||
|
||||
@torch_to_numpy.register(numbers.Number)
|
||||
@torch_to_numpy.register(torch.Tensor)
|
||||
def _number_torch_to_numpy(value: numbers.Number | torch.Tensor) -> Any:
|
||||
"""Convert a python number (int, float, complex) and torch.Tensor to a numpy array."""
|
||||
return np.array(value)
|
||||
|
||||
@torch_to_numpy.register(abc.Mapping)
|
||||
def _mapping_torch_to_numpy(value: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||
"""Converts a mapping of PyTorch Tensors into a Dictionary of Jax DeviceArrays."""
|
||||
return type(value)(**{k: torch_to_numpy(v) for k, v in value.items()})
|
||||
@torch_to_numpy.register(abc.Mapping)
|
||||
def _mapping_torch_to_numpy(value: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||
"""Converts a mapping of PyTorch Tensors into a Dictionary of Jax DeviceArrays."""
|
||||
return type(value)(**{k: torch_to_numpy(v) for k, v in value.items()})
|
||||
|
||||
@torch_to_numpy.register(abc.Iterable)
|
||||
def _iterable_torch_to_numpy(value: Iterable[Any]) -> Iterable[Any]:
|
||||
"""Converts an Iterable from PyTorch Tensors to an iterable of Jax DeviceArrays."""
|
||||
return type(value)(torch_to_numpy(v) for v in value)
|
||||
|
||||
@torch_to_numpy.register(abc.Iterable)
|
||||
def _iterable_torch_to_numpy(value: Iterable[Any]) -> Iterable[Any]:
|
||||
"""Converts an Iterable from PyTorch Tensors to an iterable of Jax DeviceArrays."""
|
||||
return type(value)(torch_to_numpy(v) for v in value)
|
||||
|
||||
|
||||
@functools.singledispatch
|
||||
def numpy_to_torch(value: Any, device: Device | None = None) -> Any:
|
||||
"""Converts a Jax DeviceArray into a PyTorch Tensor."""
|
||||
if torch is None:
|
||||
raise DependencyNotInstalled(
|
||||
"Torch is not installed therefore cannot call `numpy_to_torch`, run `pip install torch`"
|
||||
)
|
||||
else:
|
||||
raise Exception(
|
||||
f"No known conversion for NumPy type ({type(value)}) to PyTorch registered. Report as issue on github."
|
||||
)
|
||||
raise Exception(
|
||||
f"No known conversion for NumPy type ({type(value)}) to PyTorch registered. Report as issue on github."
|
||||
)
|
||||
|
||||
|
||||
if torch is not None:
|
||||
@numpy_to_torch.register(np.ndarray)
|
||||
def _numpy_to_torch(value: np.ndarray, device: Device | None = None) -> torch.Tensor:
|
||||
"""Converts a NumPy array into a PyTorch Tensor."""
|
||||
assert torch is not None
|
||||
tensor = torch.tensor(value)
|
||||
if device:
|
||||
return tensor.to(device=device)
|
||||
return tensor
|
||||
|
||||
@numpy_to_torch.register(np.ndarray)
|
||||
def _numpy_to_torch(
|
||||
value: np.ndarray, device: Device | None = None
|
||||
) -> torch.Tensor:
|
||||
"""Converts a NumPy array into a PyTorch Tensor."""
|
||||
assert torch is not None
|
||||
tensor = torch.tensor(value)
|
||||
if device:
|
||||
return tensor.to(device=device)
|
||||
return tensor
|
||||
|
||||
@numpy_to_torch.register(abc.Mapping)
|
||||
def _numpy_mapping_to_torch(
|
||||
value: Mapping[str, Any], device: Device | None = None
|
||||
) -> Mapping[str, Any]:
|
||||
"""Converts a mapping of NumPy arrays into a mapping of PyTorch Tensors."""
|
||||
return type(value)(**{k: numpy_to_torch(v, device) for k, v in value.items()})
|
||||
@numpy_to_torch.register(abc.Mapping)
|
||||
def _numpy_mapping_to_torch(
|
||||
value: Mapping[str, Any], device: Device | None = None
|
||||
) -> Mapping[str, Any]:
|
||||
"""Converts a mapping of NumPy arrays into a mapping of PyTorch Tensors."""
|
||||
return type(value)(**{k: numpy_to_torch(v, device) for k, v in value.items()})
|
||||
|
||||
@numpy_to_torch.register(abc.Iterable)
|
||||
def _numpy_iterable_to_torch(
|
||||
value: Iterable[Any], device: Device | None = None
|
||||
) -> Iterable[Any]:
|
||||
"""Converts an Iterable of NumPy arrays into an iterable of PyTorch Tensors."""
|
||||
return type(value)(numpy_to_torch(v, device) for v in value)
|
||||
|
||||
@numpy_to_torch.register(abc.Iterable)
|
||||
def _numpy_iterable_to_torch(
|
||||
value: Iterable[Any], device: Device | None = None
|
||||
) -> Iterable[Any]:
|
||||
"""Converts an Iterable of NumPy arrays into an iterable of PyTorch Tensors."""
|
||||
return type(value)(numpy_to_torch(v, device) for v in value)
|
||||
|
||||
|
||||
class NumpyToTorchV0(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||
@@ -110,11 +103,6 @@ class NumpyToTorchV0(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||
env: The NumPy-based environment to wrap
|
||||
device: The device the torch Tensors should be moved to
|
||||
"""
|
||||
if torch is None:
|
||||
raise DependencyNotInstalled(
|
||||
"torch is not installed, run `pip install torch`"
|
||||
)
|
||||
|
||||
gym.utils.RecordConstructorArgs.__init__(self, device=device)
|
||||
gym.Wrapper.__init__(self, env)
|
||||
|
23
tests/experimental/wrappers/test_init_shorten_import.py
Normal file
23
tests/experimental/wrappers/test_init_shorten_import.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""Tests that all shortened imports for wrappers work."""
|
||||
|
||||
import pytest
|
||||
|
||||
import gymnasium
|
||||
from gymnasium.experimental.wrappers import (
|
||||
_wrapper_to_class, # pyright: ignore[reportPrivateUsage]
|
||||
)
|
||||
from gymnasium.experimental.wrappers import __all__
|
||||
|
||||
|
||||
def test_all_wrapper_shorten():
|
||||
"""Test that all wrappers in `__all__` are contained within the `_wrapper_to_class` conversion."""
|
||||
all_wrappers = set(__all__)
|
||||
all_wrappers.remove("vector")
|
||||
assert all_wrappers == set(_wrapper_to_class.keys())
|
||||
|
||||
|
||||
@pytest.mark.parametrize("wrapper_name", __all__)
|
||||
def test_all_wrappers_shortened(wrapper_name):
|
||||
"""Check that each element of the `__all__` wrappers can be loaded."""
|
||||
if wrapper_name != "vector":
|
||||
assert getattr(gymnasium.experimental.wrappers, wrapper_name) is not None
|
@@ -4,7 +4,7 @@ import jax.numpy as jnp
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from gymnasium.experimental.wrappers.conversion.jax_to_numpy import (
|
||||
from gymnasium.experimental.wrappers.jax_to_numpy import (
|
||||
JaxToNumpyV0,
|
||||
jax_to_numpy,
|
||||
numpy_to_jax,
|
||||
|
@@ -5,7 +5,7 @@ import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from gymnasium.experimental.wrappers.conversion.jax_to_torch import (
|
||||
from gymnasium.experimental.wrappers.jax_to_torch import (
|
||||
JaxToTorchV0,
|
||||
jax_to_torch,
|
||||
torch_to_jax,
|
||||
|
Reference in New Issue
Block a user