mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-21 06:20:15 +00:00
Add __get_attr__
for experimental wrappers for generic solution to optimise extra module imports (#392)
This commit is contained in:
@@ -1,5 +1,6 @@
|
|||||||
"""Root `__init__` of the gymnasium module setting the `__all__` of gymnasium modules."""
|
"""Root `__init__` of the gymnasium module setting the `__all__` of gymnasium modules."""
|
||||||
# isort: skip_file
|
# isort: skip_file
|
||||||
|
# pyright: reportUnsupportedDunderAll=false
|
||||||
|
|
||||||
from gymnasium.core import (
|
from gymnasium.core import (
|
||||||
Env,
|
Env,
|
||||||
@@ -17,7 +18,9 @@ from gymnasium.envs.registration import (
|
|||||||
pprint_registry,
|
pprint_registry,
|
||||||
make_vec,
|
make_vec,
|
||||||
)
|
)
|
||||||
from gymnasium import envs, spaces, utils, vector, wrappers, error, logger
|
|
||||||
|
# necessary for `envs.__init__` which registers all gymnasium environments and loads plugins
|
||||||
|
from gymnasium import envs
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@@ -37,6 +40,7 @@ __all__ = [
|
|||||||
"pprint_registry",
|
"pprint_registry",
|
||||||
# module folders
|
# module folders
|
||||||
"envs",
|
"envs",
|
||||||
|
"experimental",
|
||||||
"spaces",
|
"spaces",
|
||||||
"utils",
|
"utils",
|
||||||
"vector",
|
"vector",
|
||||||
|
@@ -1,7 +1,7 @@
|
|||||||
"""Root __init__ of the gym experimental wrappers."""
|
"""Root __init__ of the gym experimental wrappers."""
|
||||||
|
|
||||||
|
|
||||||
from gymnasium.experimental import functional
|
from gymnasium.experimental import functional, wrappers
|
||||||
from gymnasium.experimental.functional import FuncEnv
|
from gymnasium.experimental.functional import FuncEnv
|
||||||
from gymnasium.experimental.vector.async_vector_env import AsyncVectorEnv
|
from gymnasium.experimental.vector.async_vector_env import AsyncVectorEnv
|
||||||
from gymnasium.experimental.vector.sync_vector_env import SyncVectorEnv
|
from gymnasium.experimental.vector.sync_vector_env import SyncVectorEnv
|
||||||
@@ -17,4 +17,6 @@ __all__ = [
|
|||||||
"VectorWrapper",
|
"VectorWrapper",
|
||||||
"SyncVectorEnv",
|
"SyncVectorEnv",
|
||||||
"AsyncVectorEnv",
|
"AsyncVectorEnv",
|
||||||
|
# wrappers
|
||||||
|
"wrappers",
|
||||||
]
|
]
|
||||||
|
@@ -11,7 +11,7 @@ import numpy as np
|
|||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
from gymnasium.envs.registration import EnvSpec
|
from gymnasium.envs.registration import EnvSpec
|
||||||
from gymnasium.experimental.functional import ActType, FuncEnv, StateType
|
from gymnasium.experimental.functional import ActType, FuncEnv, StateType
|
||||||
from gymnasium.experimental.wrappers.conversion.jax_to_numpy import jax_to_numpy
|
from gymnasium.experimental.wrappers.jax_to_numpy import jax_to_numpy
|
||||||
from gymnasium.utils import seeding
|
from gymnasium.utils import seeding
|
||||||
from gymnasium.vector.utils import batch_space
|
from gymnasium.vector.utils import batch_space
|
||||||
|
|
||||||
|
@@ -1,53 +1,11 @@
|
|||||||
"""Experimental Wrappers."""
|
"""`__init__` for experimental wrappers, to avoid loading the wrappers if unnecessary, we can hack python."""
|
||||||
# isort: skip_file
|
# pyright: reportUnsupportedDunderAll=false
|
||||||
|
|
||||||
from gymnasium.experimental.wrappers.lambda_action import (
|
import importlib
|
||||||
LambdaActionV0,
|
|
||||||
ClipActionV0,
|
|
||||||
RescaleActionV0,
|
|
||||||
)
|
|
||||||
from gymnasium.experimental.wrappers.lambda_observations import (
|
|
||||||
LambdaObservationV0,
|
|
||||||
FilterObservationV0,
|
|
||||||
FlattenObservationV0,
|
|
||||||
GrayscaleObservationV0,
|
|
||||||
ResizeObservationV0,
|
|
||||||
ReshapeObservationV0,
|
|
||||||
RescaleObservationV0,
|
|
||||||
DtypeObservationV0,
|
|
||||||
PixelObservationV0,
|
|
||||||
NormalizeObservationV0,
|
|
||||||
)
|
|
||||||
from gymnasium.experimental.wrappers.lambda_reward import (
|
|
||||||
ClipRewardV0,
|
|
||||||
LambdaRewardV0,
|
|
||||||
NormalizeRewardV0,
|
|
||||||
)
|
|
||||||
from gymnasium.experimental.wrappers.stateful_action import StickyActionV0
|
|
||||||
from gymnasium.experimental.wrappers.stateful_observation import (
|
|
||||||
TimeAwareObservationV0,
|
|
||||||
DelayObservationV0,
|
|
||||||
FrameStackObservationV0,
|
|
||||||
)
|
|
||||||
from gymnasium.experimental.wrappers.atari_preprocessing import AtariPreprocessingV0
|
|
||||||
from gymnasium.experimental.wrappers.common import (
|
|
||||||
PassiveEnvCheckerV0,
|
|
||||||
OrderEnforcingV0,
|
|
||||||
AutoresetV0,
|
|
||||||
RecordEpisodeStatisticsV0,
|
|
||||||
)
|
|
||||||
from gymnasium.experimental.wrappers.rendering import (
|
|
||||||
RenderCollectionV0,
|
|
||||||
RecordVideoV0,
|
|
||||||
HumanRenderingV0,
|
|
||||||
)
|
|
||||||
|
|
||||||
from gymnasium.experimental.wrappers.vector import (
|
|
||||||
VectorRecordEpisodeStatistics,
|
|
||||||
VectorListInfo,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"vector",
|
||||||
# --- Observation wrappers ---
|
# --- Observation wrappers ---
|
||||||
"LambdaObservationV0",
|
"LambdaObservationV0",
|
||||||
"FilterObservationV0",
|
"FilterObservationV0",
|
||||||
@@ -82,7 +40,80 @@ __all__ = [
|
|||||||
"RenderCollectionV0",
|
"RenderCollectionV0",
|
||||||
"RecordVideoV0",
|
"RecordVideoV0",
|
||||||
"HumanRenderingV0",
|
"HumanRenderingV0",
|
||||||
# --- Vector ---
|
# --- Conversion ---
|
||||||
"VectorRecordEpisodeStatistics",
|
"JaxToNumpyV0",
|
||||||
"VectorListInfo",
|
"JaxToTorchV0",
|
||||||
|
"NumpyToTorchV0",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
_wrapper_to_class = {
|
||||||
|
# lambda_action.py
|
||||||
|
"LambdaActionV0": "lambda_action",
|
||||||
|
"ClipActionV0": "lambda_action",
|
||||||
|
"RescaleActionV0": "lambda_action",
|
||||||
|
# lambda_observations.py
|
||||||
|
"LambdaObservationV0": "lambda_observations",
|
||||||
|
"FilterObservationV0": "lambda_observations",
|
||||||
|
"FlattenObservationV0": "lambda_observations",
|
||||||
|
"GrayscaleObservationV0": "lambda_observations",
|
||||||
|
"ResizeObservationV0": "lambda_observations",
|
||||||
|
"ReshapeObservationV0": "lambda_observations",
|
||||||
|
"RescaleObservationV0": "lambda_observations",
|
||||||
|
"DtypeObservationV0": "lambda_observations",
|
||||||
|
"PixelObservationV0": "lambda_observations",
|
||||||
|
"NormalizeObservationV0": "lambda_observations",
|
||||||
|
# lambda_reward.py
|
||||||
|
"ClipRewardV0": "lambda_reward",
|
||||||
|
"LambdaRewardV0": "lambda_reward",
|
||||||
|
"NormalizeRewardV0": "lambda_reward",
|
||||||
|
# stateful_action
|
||||||
|
"StickyActionV0": "stateful_action",
|
||||||
|
# stateful_observation
|
||||||
|
"TimeAwareObservationV0": "stateful_observation",
|
||||||
|
"DelayObservationV0": "stateful_observation",
|
||||||
|
"FrameStackObservationV0": "stateful_observation",
|
||||||
|
# atari_preprocessing
|
||||||
|
"AtariPreprocessingV0": "atari_preprocessing",
|
||||||
|
# common
|
||||||
|
"PassiveEnvCheckerV0": "common",
|
||||||
|
"OrderEnforcingV0": "common",
|
||||||
|
"AutoresetV0": "common",
|
||||||
|
"RecordEpisodeStatisticsV0": "common",
|
||||||
|
# rendering
|
||||||
|
"RenderCollectionV0": "rendering",
|
||||||
|
"RecordVideoV0": "rendering",
|
||||||
|
"HumanRenderingV0": "rendering",
|
||||||
|
# jax_to_numpy
|
||||||
|
"JaxToNumpyV0": "jax_to_numpy",
|
||||||
|
# "jax_to_numpy": "jax_to_numpy",
|
||||||
|
# "numpy_to_jax": "jax_to_numpy",
|
||||||
|
# jax_to_torch
|
||||||
|
"JaxToTorchV0": "jax_to_torch",
|
||||||
|
# "jax_to_torch": "jax_to_torch",
|
||||||
|
# "torch_to_jax": "jax_to_torch",
|
||||||
|
# numpy_to_torch
|
||||||
|
"NumpyToTorchV0": "numpy_to_torch",
|
||||||
|
# "torch_to_numpy": "numpy_to_torch",
|
||||||
|
# "numpy_to_torch": "numpy_to_torch",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def __getattr__(name: str):
|
||||||
|
"""To avoid having to load all wrappers on `import gymnasium` with all of their extra modules.
|
||||||
|
|
||||||
|
This optimises the loading of gymnasium.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: The name of a wrapper to load
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Wrapper
|
||||||
|
"""
|
||||||
|
if name in _wrapper_to_class:
|
||||||
|
import_stmt = f"gymnasium.experimental.wrappers.{_wrapper_to_class[name]}"
|
||||||
|
module = importlib.import_module(import_stmt)
|
||||||
|
return getattr(module, name)
|
||||||
|
# add helpful error message if version number has changed
|
||||||
|
else:
|
||||||
|
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||||
|
@@ -8,7 +8,9 @@ from gymnasium.spaces import Box
|
|||||||
try:
|
try:
|
||||||
import cv2
|
import cv2
|
||||||
except ImportError:
|
except ImportError:
|
||||||
cv2 = None
|
raise gym.error.DependencyNotInstalled(
|
||||||
|
"opencv-python package not installed, run `pip install gymnasium[other]` to get dependencies for atari"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AtariPreprocessingV0(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
class AtariPreprocessingV0(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||||
@@ -72,10 +74,6 @@ class AtariPreprocessingV0(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
|||||||
)
|
)
|
||||||
gym.Wrapper.__init__(self, env)
|
gym.Wrapper.__init__(self, env)
|
||||||
|
|
||||||
if cv2 is None:
|
|
||||||
raise gym.error.DependencyNotInstalled(
|
|
||||||
"opencv-python package not installed, run `pip install gymnasium[other]` to get dependencies for atari"
|
|
||||||
)
|
|
||||||
assert frame_skip > 0
|
assert frame_skip > 0
|
||||||
assert screen_size > 0
|
assert screen_size > 0
|
||||||
assert noop_max >= 0
|
assert noop_max >= 0
|
||||||
|
@@ -1 +0,0 @@
|
|||||||
"""This is deliberately empty to avoid introducing redundant imports -- import each submodule individually."""
|
|
@@ -1,205 +0,0 @@
|
|||||||
# This wrapper will convert torch inputs for the actions and observations to Jax arrays
|
|
||||||
# for an underlying Jax environment then convert the return observations from Jax arrays
|
|
||||||
# back to torch tensors.
|
|
||||||
#
|
|
||||||
# Functionality for converting between torch and jax types originally copied from
|
|
||||||
# https://github.com/google/brax/blob/9d6b7ced2a13da0d074b5e9fbd3aad8311e26997/brax/io/torch.py
|
|
||||||
# Under the Apache 2.0 license. Copyright is held by the authors
|
|
||||||
|
|
||||||
"""Helper functions and wrapper class for converting between PyTorch and Jax."""
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import functools
|
|
||||||
import numbers
|
|
||||||
from collections import abc
|
|
||||||
from typing import Any, Iterable, Mapping, SupportsFloat, Union
|
|
||||||
|
|
||||||
import gymnasium as gym
|
|
||||||
from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType
|
|
||||||
from gymnasium.error import DependencyNotInstalled
|
|
||||||
from gymnasium.experimental.wrappers.conversion.jax_to_numpy import jax_to_numpy
|
|
||||||
|
|
||||||
|
|
||||||
try:
|
|
||||||
import jax.numpy as jnp
|
|
||||||
from jax import dlpack as jax_dlpack
|
|
||||||
except ImportError:
|
|
||||||
jnp, jax_dlpack = None, None
|
|
||||||
|
|
||||||
try:
|
|
||||||
import torch
|
|
||||||
from torch.utils import dlpack as torch_dlpack
|
|
||||||
|
|
||||||
Device = Union[str, torch.device]
|
|
||||||
except ImportError:
|
|
||||||
torch, torch_dlpack, Device = None, None, None
|
|
||||||
|
|
||||||
|
|
||||||
@functools.singledispatch
|
|
||||||
def torch_to_jax(value: Any) -> Any:
|
|
||||||
"""Converts a PyTorch Tensor into a Jax DeviceArray."""
|
|
||||||
if torch is None:
|
|
||||||
raise DependencyNotInstalled(
|
|
||||||
"Torch is not installed therefore cannot call `torch_to_jax`, run `pip install torch`"
|
|
||||||
)
|
|
||||||
elif jnp is None:
|
|
||||||
raise DependencyNotInstalled(
|
|
||||||
"Jax is not installed therefore cannot call `torch_to_jax`, run `pip install gymnasium[jax]`"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise Exception(
|
|
||||||
f"No known conversion for Torch type ({type(value)}) to Jax registered. Report as issue on github."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if torch is not None and jnp is not None:
|
|
||||||
|
|
||||||
@torch_to_jax.register(numbers.Number)
|
|
||||||
def _number_torch_to_jax(value: numbers.Number) -> Any:
|
|
||||||
"""Convert a python number (int, float, complex) to a jax array."""
|
|
||||||
assert jnp is not None
|
|
||||||
return jnp.array(value)
|
|
||||||
|
|
||||||
@torch_to_jax.register(torch.Tensor)
|
|
||||||
def _tensor_torch_to_jax(value: torch.Tensor) -> jnp.DeviceArray:
|
|
||||||
"""Converts a PyTorch Tensor into a Jax DeviceArray."""
|
|
||||||
assert torch_dlpack is not None and jax_dlpack is not None
|
|
||||||
tensor = torch_dlpack.to_dlpack( # pyright: ignore[reportPrivateImportUsage]
|
|
||||||
value
|
|
||||||
)
|
|
||||||
tensor = jax_dlpack.from_dlpack( # pyright: ignore[reportPrivateImportUsage]
|
|
||||||
tensor
|
|
||||||
)
|
|
||||||
return tensor
|
|
||||||
|
|
||||||
@torch_to_jax.register(abc.Mapping)
|
|
||||||
def _mapping_torch_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]:
|
|
||||||
"""Converts a mapping of PyTorch Tensors into a Dictionary of Jax DeviceArrays."""
|
|
||||||
return type(value)(**{k: torch_to_jax(v) for k, v in value.items()})
|
|
||||||
|
|
||||||
@torch_to_jax.register(abc.Iterable)
|
|
||||||
def _iterable_torch_to_jax(value: Iterable[Any]) -> Iterable[Any]:
|
|
||||||
"""Converts an Iterable from PyTorch Tensors to an iterable of Jax DeviceArrays."""
|
|
||||||
return type(value)(torch_to_jax(v) for v in value)
|
|
||||||
|
|
||||||
|
|
||||||
@functools.singledispatch
|
|
||||||
def jax_to_torch(value: Any, device: Device | None = None) -> Any:
|
|
||||||
"""Converts a Jax DeviceArray into a PyTorch Tensor."""
|
|
||||||
if torch is None:
|
|
||||||
raise DependencyNotInstalled(
|
|
||||||
"Torch is not installed therefore cannot call `jax_to_torch`, run `pip install torch`"
|
|
||||||
)
|
|
||||||
elif jnp is None:
|
|
||||||
raise DependencyNotInstalled(
|
|
||||||
"Jax is not installed therefore cannot call `jax_to_torch`, run `pip install gymnasium[jax]`"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise Exception(
|
|
||||||
f"No known conversion for Jax type ({type(value)}) to PyTorch registered. Report as issue on github."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if torch is not None and jnp is not None:
|
|
||||||
|
|
||||||
@jax_to_torch.register(jnp.DeviceArray)
|
|
||||||
def _devicearray_jax_to_torch(
|
|
||||||
value: jnp.DeviceArray, device: Device | None = None
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""Converts a Jax DeviceArray into a PyTorch Tensor."""
|
|
||||||
assert jax_dlpack is not None and torch_dlpack is not None
|
|
||||||
dlpack = jax_dlpack.to_dlpack( # pyright: ignore[reportPrivateImportUsage]
|
|
||||||
value
|
|
||||||
)
|
|
||||||
tensor = torch_dlpack.from_dlpack(dlpack)
|
|
||||||
if device:
|
|
||||||
return tensor.to(device=device)
|
|
||||||
return tensor
|
|
||||||
|
|
||||||
@jax_to_torch.register(abc.Mapping)
|
|
||||||
def _jax_mapping_to_torch(
|
|
||||||
value: Mapping[str, Any], device: Device | None = None
|
|
||||||
) -> Mapping[str, Any]:
|
|
||||||
"""Converts a mapping of Jax DeviceArrays into a Dictionary of PyTorch Tensors."""
|
|
||||||
return type(value)(**{k: jax_to_torch(v, device) for k, v in value.items()})
|
|
||||||
|
|
||||||
@jax_to_torch.register(abc.Iterable)
|
|
||||||
def _jax_iterable_to_torch(
|
|
||||||
value: Iterable[Any], device: Device | None = None
|
|
||||||
) -> Iterable[Any]:
|
|
||||||
"""Converts an Iterable from Jax DeviceArrays to an iterable of PyTorch Tensors."""
|
|
||||||
return type(value)(jax_to_torch(v, device) for v in value)
|
|
||||||
|
|
||||||
|
|
||||||
class JaxToTorchV0(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
|
||||||
"""Wraps a jax-based environment so that it can be interacted with through PyTorch Tensors.
|
|
||||||
|
|
||||||
Actions must be provided as PyTorch Tensors and observations will be returned as PyTorch Tensors.
|
|
||||||
|
|
||||||
Note:
|
|
||||||
For ``rendered`` this is returned as a NumPy array not a pytorch Tensor.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, env: gym.Env, device: Device | None = None):
|
|
||||||
"""Wrapper class to change inputs and outputs of environment to PyTorch tensors.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
env: The Jax-based environment to wrap
|
|
||||||
device: The device the torch Tensors should be moved to
|
|
||||||
"""
|
|
||||||
if torch is None:
|
|
||||||
raise DependencyNotInstalled(
|
|
||||||
"torch is not installed, run `pip install torch`"
|
|
||||||
)
|
|
||||||
elif jnp is None:
|
|
||||||
raise DependencyNotInstalled(
|
|
||||||
"jax is not installed, run `pip install gymnasium[jax]`"
|
|
||||||
)
|
|
||||||
|
|
||||||
gym.utils.RecordConstructorArgs.__init__(self, device=device)
|
|
||||||
gym.Wrapper.__init__(self, env)
|
|
||||||
|
|
||||||
self.device: Device | None = device
|
|
||||||
|
|
||||||
def step(
|
|
||||||
self, action: WrapperActType
|
|
||||||
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
|
|
||||||
"""Performs the given action within the environment.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
action: The action to perform as a PyTorch Tensor
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The next observation, reward, termination, truncation, and extra info
|
|
||||||
"""
|
|
||||||
jax_action = torch_to_jax(action)
|
|
||||||
obs, reward, terminated, truncated, info = self.env.step(jax_action)
|
|
||||||
|
|
||||||
return (
|
|
||||||
jax_to_torch(obs, self.device),
|
|
||||||
float(reward),
|
|
||||||
bool(terminated),
|
|
||||||
bool(truncated),
|
|
||||||
jax_to_torch(info, self.device),
|
|
||||||
)
|
|
||||||
|
|
||||||
def reset(
|
|
||||||
self, *, seed: int | None = None, options: dict[str, Any] | None = None
|
|
||||||
) -> tuple[WrapperObsType, dict[str, Any]]:
|
|
||||||
"""Resets the environment returning PyTorch-based observation and info.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
seed: The seed for resetting the environment
|
|
||||||
options: The options for resetting the environment, these are converted to jax arrays.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
PyTorch-based observations and info
|
|
||||||
"""
|
|
||||||
if options:
|
|
||||||
options = torch_to_jax(options)
|
|
||||||
|
|
||||||
return jax_to_torch(self.env.reset(seed=seed, options=options), self.device)
|
|
||||||
|
|
||||||
def render(self) -> RenderFrame | list[RenderFrame] | None:
|
|
||||||
"""Returns the rendered frames as a NumPy array."""
|
|
||||||
return jax_to_numpy(self.env.render())
|
|
@@ -16,25 +16,21 @@ from gymnasium.error import DependencyNotInstalled
|
|||||||
try:
|
try:
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
except ImportError:
|
except ImportError:
|
||||||
# We handle the error internal to the relative functions
|
raise DependencyNotInstalled(
|
||||||
jnp = None
|
"Jax is not installed therefore cannot call `numpy_to_jax`, run `pip install gymnasium[jax]`"
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = ["jax_to_numpy", "numpy_to_jax", "JaxToNumpyV0"]
|
||||||
|
|
||||||
|
|
||||||
@functools.singledispatch
|
@functools.singledispatch
|
||||||
def numpy_to_jax(value: Any) -> Any:
|
def numpy_to_jax(value: Any) -> Any:
|
||||||
"""Converts a value to a Jax DeviceArray."""
|
"""Converts a value to a Jax DeviceArray."""
|
||||||
if jnp is None:
|
|
||||||
raise DependencyNotInstalled(
|
|
||||||
"Jax is not installed therefore cannot call `numpy_to_jax`, run `pip install gymnasium[jax]`"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"No known conversion for Numpy type ({type(value)}) to Jax registered. Report as issue on github."
|
f"No known conversion for Numpy type ({type(value)}) to Jax registered. Report as issue on github."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
if jnp is not None:
|
|
||||||
|
|
||||||
@numpy_to_jax.register(numbers.Number)
|
@numpy_to_jax.register(numbers.Number)
|
||||||
@numpy_to_jax.register(np.ndarray)
|
@numpy_to_jax.register(np.ndarray)
|
||||||
def _number_ndarray_numpy_to_jax(
|
def _number_ndarray_numpy_to_jax(
|
||||||
@@ -44,11 +40,13 @@ if jnp is not None:
|
|||||||
assert jnp is not None
|
assert jnp is not None
|
||||||
return jnp.array(value)
|
return jnp.array(value)
|
||||||
|
|
||||||
|
|
||||||
@numpy_to_jax.register(abc.Mapping)
|
@numpy_to_jax.register(abc.Mapping)
|
||||||
def _mapping_numpy_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]:
|
def _mapping_numpy_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||||
"""Converts a dictionary of numpy arrays to a mapping of Jax DeviceArrays."""
|
"""Converts a dictionary of numpy arrays to a mapping of Jax DeviceArrays."""
|
||||||
return type(value)(**{k: numpy_to_jax(v) for k, v in value.items()})
|
return type(value)(**{k: numpy_to_jax(v) for k, v in value.items()})
|
||||||
|
|
||||||
|
|
||||||
@numpy_to_jax.register(abc.Iterable)
|
@numpy_to_jax.register(abc.Iterable)
|
||||||
def _iterable_numpy_to_jax(
|
def _iterable_numpy_to_jax(
|
||||||
value: Iterable[np.ndarray | Any],
|
value: Iterable[np.ndarray | Any],
|
||||||
@@ -60,23 +58,17 @@ if jnp is not None:
|
|||||||
@functools.singledispatch
|
@functools.singledispatch
|
||||||
def jax_to_numpy(value: Any) -> Any:
|
def jax_to_numpy(value: Any) -> Any:
|
||||||
"""Converts a value to a numpy array."""
|
"""Converts a value to a numpy array."""
|
||||||
if jnp is None:
|
|
||||||
raise DependencyNotInstalled(
|
|
||||||
"Jax is not installed therefore cannot call `jax_to_numpy`, run `pip install gymnasium[jax]`"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"No known conversion for Jax type ({type(value)}) to NumPy registered. Report as issue on github."
|
f"No known conversion for Jax type ({type(value)}) to NumPy registered. Report as issue on github."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
if jnp is not None:
|
|
||||||
|
|
||||||
@jax_to_numpy.register(jnp.DeviceArray)
|
@jax_to_numpy.register(jnp.DeviceArray)
|
||||||
def _devicearray_jax_to_numpy(value: jnp.DeviceArray) -> np.ndarray:
|
def _devicearray_jax_to_numpy(value: jnp.DeviceArray) -> np.ndarray:
|
||||||
"""Converts a Jax DeviceArray to a numpy array."""
|
"""Converts a Jax DeviceArray to a numpy array."""
|
||||||
return np.array(value)
|
return np.array(value)
|
||||||
|
|
||||||
|
|
||||||
@jax_to_numpy.register(abc.Mapping)
|
@jax_to_numpy.register(abc.Mapping)
|
||||||
def _mapping_jax_to_numpy(
|
def _mapping_jax_to_numpy(
|
||||||
value: Mapping[str, jnp.DeviceArray | Any]
|
value: Mapping[str, jnp.DeviceArray | Any]
|
||||||
@@ -84,6 +76,7 @@ if jnp is not None:
|
|||||||
"""Converts a dictionary of Jax DeviceArrays to a mapping of numpy arrays."""
|
"""Converts a dictionary of Jax DeviceArrays to a mapping of numpy arrays."""
|
||||||
return type(value)(**{k: jax_to_numpy(v) for k, v in value.items()})
|
return type(value)(**{k: jax_to_numpy(v) for k, v in value.items()})
|
||||||
|
|
||||||
|
|
||||||
@jax_to_numpy.register(abc.Iterable)
|
@jax_to_numpy.register(abc.Iterable)
|
||||||
def _iterable_jax_to_numpy(
|
def _iterable_jax_to_numpy(
|
||||||
value: Iterable[np.ndarray | Any],
|
value: Iterable[np.ndarray | Any],
|
178
gymnasium/experimental/wrappers/jax_to_torch.py
Normal file
178
gymnasium/experimental/wrappers/jax_to_torch.py
Normal file
@@ -0,0 +1,178 @@
|
|||||||
|
# This wrapper will convert torch inputs for the actions and observations to Jax arrays
|
||||||
|
# for an underlying Jax environment then convert the return observations from Jax arrays
|
||||||
|
# back to torch tensors.
|
||||||
|
#
|
||||||
|
# Functionality for converting between torch and jax types originally copied from
|
||||||
|
# https://github.com/google/brax/blob/9d6b7ced2a13da0d074b5e9fbd3aad8311e26997/brax/io/torch.py
|
||||||
|
# Under the Apache 2.0 license. Copyright is held by the authors
|
||||||
|
|
||||||
|
"""Helper functions and wrapper class for converting between PyTorch and Jax."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import functools
|
||||||
|
import numbers
|
||||||
|
from collections import abc
|
||||||
|
from typing import Any, Iterable, Mapping, SupportsFloat, Union
|
||||||
|
|
||||||
|
import gymnasium as gym
|
||||||
|
from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType
|
||||||
|
from gymnasium.error import DependencyNotInstalled
|
||||||
|
from gymnasium.experimental.wrappers.jax_to_numpy import jax_to_numpy
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
import jax.numpy as jnp
|
||||||
|
from jax import dlpack as jax_dlpack
|
||||||
|
except ImportError:
|
||||||
|
raise DependencyNotInstalled(
|
||||||
|
"Jax is not installed therefore cannot call `torch_to_jax`, run `pip install gymnasium[jax]`"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
from torch.utils import dlpack as torch_dlpack
|
||||||
|
|
||||||
|
Device = Union[str, torch.device]
|
||||||
|
except ImportError:
|
||||||
|
raise DependencyNotInstalled(
|
||||||
|
"Torch is not installed therefore cannot call `torch_to_jax`, run `pip install torch`"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["jax_to_torch", "torch_to_jax", "JaxToTorchV0"]
|
||||||
|
|
||||||
|
|
||||||
|
@functools.singledispatch
|
||||||
|
def torch_to_jax(value: Any) -> Any:
|
||||||
|
"""Converts a PyTorch Tensor into a Jax DeviceArray."""
|
||||||
|
raise Exception(
|
||||||
|
f"No known conversion for Torch type ({type(value)}) to Jax registered. Report as issue on github."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@torch_to_jax.register(numbers.Number)
|
||||||
|
def _number_torch_to_jax(value: numbers.Number) -> Any:
|
||||||
|
"""Convert a python number (int, float, complex) to a jax array."""
|
||||||
|
return jnp.array(value)
|
||||||
|
|
||||||
|
|
||||||
|
@torch_to_jax.register(torch.Tensor)
|
||||||
|
def _tensor_torch_to_jax(value: torch.Tensor) -> jnp.DeviceArray:
|
||||||
|
"""Converts a PyTorch Tensor into a Jax DeviceArray."""
|
||||||
|
tensor = torch_dlpack.to_dlpack(value) # pyright: ignore[reportPrivateImportUsage]
|
||||||
|
tensor = jax_dlpack.from_dlpack(tensor) # pyright: ignore[reportPrivateImportUsage]
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
|
@torch_to_jax.register(abc.Mapping)
|
||||||
|
def _mapping_torch_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||||
|
"""Converts a mapping of PyTorch Tensors into a Dictionary of Jax DeviceArrays."""
|
||||||
|
return type(value)(**{k: torch_to_jax(v) for k, v in value.items()})
|
||||||
|
|
||||||
|
|
||||||
|
@torch_to_jax.register(abc.Iterable)
|
||||||
|
def _iterable_torch_to_jax(value: Iterable[Any]) -> Iterable[Any]:
|
||||||
|
"""Converts an Iterable from PyTorch Tensors to an iterable of Jax DeviceArrays."""
|
||||||
|
return type(value)(torch_to_jax(v) for v in value)
|
||||||
|
|
||||||
|
|
||||||
|
@functools.singledispatch
|
||||||
|
def jax_to_torch(value: Any, device: Device | None = None) -> Any:
|
||||||
|
"""Converts a Jax DeviceArray into a PyTorch Tensor."""
|
||||||
|
raise Exception(
|
||||||
|
f"No known conversion for Jax type ({type(value)}) to PyTorch registered. Report as issue on github."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@jax_to_torch.register(jnp.DeviceArray)
|
||||||
|
def _devicearray_jax_to_torch(
|
||||||
|
value: jnp.DeviceArray, device: Device | None = None
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Converts a Jax DeviceArray into a PyTorch Tensor."""
|
||||||
|
assert jax_dlpack is not None and torch_dlpack is not None
|
||||||
|
dlpack = jax_dlpack.to_dlpack(value) # pyright: ignore[reportPrivateImportUsage]
|
||||||
|
tensor = torch_dlpack.from_dlpack(dlpack)
|
||||||
|
if device:
|
||||||
|
return tensor.to(device=device)
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
|
@jax_to_torch.register(abc.Mapping)
|
||||||
|
def _jax_mapping_to_torch(
|
||||||
|
value: Mapping[str, Any], device: Device | None = None
|
||||||
|
) -> Mapping[str, Any]:
|
||||||
|
"""Converts a mapping of Jax DeviceArrays into a Dictionary of PyTorch Tensors."""
|
||||||
|
return type(value)(**{k: jax_to_torch(v, device) for k, v in value.items()})
|
||||||
|
|
||||||
|
|
||||||
|
@jax_to_torch.register(abc.Iterable)
|
||||||
|
def _jax_iterable_to_torch(
|
||||||
|
value: Iterable[Any], device: Device | None = None
|
||||||
|
) -> Iterable[Any]:
|
||||||
|
"""Converts an Iterable from Jax DeviceArrays to an iterable of PyTorch Tensors."""
|
||||||
|
return type(value)(jax_to_torch(v, device) for v in value)
|
||||||
|
|
||||||
|
|
||||||
|
class JaxToTorchV0(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||||
|
"""Wraps a jax-based environment so that it can be interacted with through PyTorch Tensors.
|
||||||
|
|
||||||
|
Actions must be provided as PyTorch Tensors and observations will be returned as PyTorch Tensors.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
For ``rendered`` this is returned as a NumPy array not a pytorch Tensor.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, env: gym.Env, device: Device | None = None):
|
||||||
|
"""Wrapper class to change inputs and outputs of environment to PyTorch tensors.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
env: The Jax-based environment to wrap
|
||||||
|
device: The device the torch Tensors should be moved to
|
||||||
|
"""
|
||||||
|
gym.utils.RecordConstructorArgs.__init__(self, device=device)
|
||||||
|
gym.Wrapper.__init__(self, env)
|
||||||
|
|
||||||
|
self.device: Device | None = device
|
||||||
|
|
||||||
|
def step(
|
||||||
|
self, action: WrapperActType
|
||||||
|
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
|
||||||
|
"""Performs the given action within the environment.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
action: The action to perform as a PyTorch Tensor
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The next observation, reward, termination, truncation, and extra info
|
||||||
|
"""
|
||||||
|
jax_action = torch_to_jax(action)
|
||||||
|
obs, reward, terminated, truncated, info = self.env.step(jax_action)
|
||||||
|
|
||||||
|
return (
|
||||||
|
jax_to_torch(obs, self.device),
|
||||||
|
float(reward),
|
||||||
|
bool(terminated),
|
||||||
|
bool(truncated),
|
||||||
|
jax_to_torch(info, self.device),
|
||||||
|
)
|
||||||
|
|
||||||
|
def reset(
|
||||||
|
self, *, seed: int | None = None, options: dict[str, Any] | None = None
|
||||||
|
) -> tuple[WrapperObsType, dict[str, Any]]:
|
||||||
|
"""Resets the environment returning PyTorch-based observation and info.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seed: The seed for resetting the environment
|
||||||
|
options: The options for resetting the environment, these are converted to jax arrays.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PyTorch-based observations and info
|
||||||
|
"""
|
||||||
|
if options:
|
||||||
|
options = torch_to_jax(options)
|
||||||
|
|
||||||
|
return jax_to_torch(self.env.reset(seed=seed, options=options), self.device)
|
||||||
|
|
||||||
|
def render(self) -> RenderFrame | list[RenderFrame] | None:
|
||||||
|
"""Returns the rendered frames as a NumPy array."""
|
||||||
|
return jax_to_numpy(self.env.render())
|
@@ -18,35 +18,35 @@ try:
|
|||||||
|
|
||||||
Device = Union[str, torch.device]
|
Device = Union[str, torch.device]
|
||||||
except ImportError:
|
except ImportError:
|
||||||
torch, Device = None, None
|
raise DependencyNotInstalled(
|
||||||
|
"Torch is not installed therefore cannot call `torch_to_numpy`, run `pip install torch`"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["torch_to_numpy", "numpy_to_torch", "NumpyToTorchV0"]
|
||||||
|
|
||||||
|
|
||||||
@functools.singledispatch
|
@functools.singledispatch
|
||||||
def torch_to_numpy(value: Any) -> Any:
|
def torch_to_numpy(value: Any) -> Any:
|
||||||
"""Converts a PyTorch Tensor into a NumPy Array."""
|
"""Converts a PyTorch Tensor into a NumPy Array."""
|
||||||
if torch is None:
|
|
||||||
raise DependencyNotInstalled(
|
|
||||||
"Torch is not installed therefore cannot call `torch_to_numpy`, run `pip install torch`"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"No known conversion for Torch type ({type(value)}) to NumPy registered. Report as issue on github."
|
f"No known conversion for Torch type ({type(value)}) to NumPy registered. Report as issue on github."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
if torch is not None:
|
|
||||||
|
|
||||||
@torch_to_numpy.register(numbers.Number)
|
@torch_to_numpy.register(numbers.Number)
|
||||||
@torch_to_numpy.register(torch.Tensor)
|
@torch_to_numpy.register(torch.Tensor)
|
||||||
def _number_torch_to_numpy(value: numbers.Number | torch.Tensor) -> Any:
|
def _number_torch_to_numpy(value: numbers.Number | torch.Tensor) -> Any:
|
||||||
"""Convert a python number (int, float, complex) and torch.Tensor to a numpy array."""
|
"""Convert a python number (int, float, complex) and torch.Tensor to a numpy array."""
|
||||||
return np.array(value)
|
return np.array(value)
|
||||||
|
|
||||||
|
|
||||||
@torch_to_numpy.register(abc.Mapping)
|
@torch_to_numpy.register(abc.Mapping)
|
||||||
def _mapping_torch_to_numpy(value: Mapping[str, Any]) -> Mapping[str, Any]:
|
def _mapping_torch_to_numpy(value: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||||
"""Converts a mapping of PyTorch Tensors into a Dictionary of Jax DeviceArrays."""
|
"""Converts a mapping of PyTorch Tensors into a Dictionary of Jax DeviceArrays."""
|
||||||
return type(value)(**{k: torch_to_numpy(v) for k, v in value.items()})
|
return type(value)(**{k: torch_to_numpy(v) for k, v in value.items()})
|
||||||
|
|
||||||
|
|
||||||
@torch_to_numpy.register(abc.Iterable)
|
@torch_to_numpy.register(abc.Iterable)
|
||||||
def _iterable_torch_to_numpy(value: Iterable[Any]) -> Iterable[Any]:
|
def _iterable_torch_to_numpy(value: Iterable[Any]) -> Iterable[Any]:
|
||||||
"""Converts an Iterable from PyTorch Tensors to an iterable of Jax DeviceArrays."""
|
"""Converts an Iterable from PyTorch Tensors to an iterable of Jax DeviceArrays."""
|
||||||
@@ -56,22 +56,13 @@ if torch is not None:
|
|||||||
@functools.singledispatch
|
@functools.singledispatch
|
||||||
def numpy_to_torch(value: Any, device: Device | None = None) -> Any:
|
def numpy_to_torch(value: Any, device: Device | None = None) -> Any:
|
||||||
"""Converts a Jax DeviceArray into a PyTorch Tensor."""
|
"""Converts a Jax DeviceArray into a PyTorch Tensor."""
|
||||||
if torch is None:
|
|
||||||
raise DependencyNotInstalled(
|
|
||||||
"Torch is not installed therefore cannot call `numpy_to_torch`, run `pip install torch`"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"No known conversion for NumPy type ({type(value)}) to PyTorch registered. Report as issue on github."
|
f"No known conversion for NumPy type ({type(value)}) to PyTorch registered. Report as issue on github."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
if torch is not None:
|
|
||||||
|
|
||||||
@numpy_to_torch.register(np.ndarray)
|
@numpy_to_torch.register(np.ndarray)
|
||||||
def _numpy_to_torch(
|
def _numpy_to_torch(value: np.ndarray, device: Device | None = None) -> torch.Tensor:
|
||||||
value: np.ndarray, device: Device | None = None
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""Converts a Jax DeviceArray into a PyTorch Tensor."""
|
"""Converts a Jax DeviceArray into a PyTorch Tensor."""
|
||||||
assert torch is not None
|
assert torch is not None
|
||||||
tensor = torch.tensor(value)
|
tensor = torch.tensor(value)
|
||||||
@@ -79,6 +70,7 @@ if torch is not None:
|
|||||||
return tensor.to(device=device)
|
return tensor.to(device=device)
|
||||||
return tensor
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
@numpy_to_torch.register(abc.Mapping)
|
@numpy_to_torch.register(abc.Mapping)
|
||||||
def _numpy_mapping_to_torch(
|
def _numpy_mapping_to_torch(
|
||||||
value: Mapping[str, Any], device: Device | None = None
|
value: Mapping[str, Any], device: Device | None = None
|
||||||
@@ -86,6 +78,7 @@ if torch is not None:
|
|||||||
"""Converts a mapping of Jax DeviceArrays into a Dictionary of PyTorch Tensors."""
|
"""Converts a mapping of Jax DeviceArrays into a Dictionary of PyTorch Tensors."""
|
||||||
return type(value)(**{k: numpy_to_torch(v, device) for k, v in value.items()})
|
return type(value)(**{k: numpy_to_torch(v, device) for k, v in value.items()})
|
||||||
|
|
||||||
|
|
||||||
@numpy_to_torch.register(abc.Iterable)
|
@numpy_to_torch.register(abc.Iterable)
|
||||||
def _numpy_iterable_to_torch(
|
def _numpy_iterable_to_torch(
|
||||||
value: Iterable[Any], device: Device | None = None
|
value: Iterable[Any], device: Device | None = None
|
||||||
@@ -110,11 +103,6 @@ class NumpyToTorchV0(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
|||||||
env: The Jax-based environment to wrap
|
env: The Jax-based environment to wrap
|
||||||
device: The device the torch Tensors should be moved to
|
device: The device the torch Tensors should be moved to
|
||||||
"""
|
"""
|
||||||
if torch is None:
|
|
||||||
raise DependencyNotInstalled(
|
|
||||||
"torch is not installed, run `pip install torch`"
|
|
||||||
)
|
|
||||||
|
|
||||||
gym.utils.RecordConstructorArgs.__init__(self, device=device)
|
gym.utils.RecordConstructorArgs.__init__(self, device=device)
|
||||||
gym.Wrapper.__init__(self, env)
|
gym.Wrapper.__init__(self, env)
|
||||||
|
|
23
tests/experimental/wrappers/test_init_shorten_import.py
Normal file
23
tests/experimental/wrappers/test_init_shorten_import.py
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
"""Tests that all shortened imports for wrappers all work."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import gymnasium
|
||||||
|
from gymnasium.experimental.wrappers import (
|
||||||
|
_wrapper_to_class, # pyright: ignore[reportPrivateUsage]
|
||||||
|
)
|
||||||
|
from gymnasium.experimental.wrappers import __all__
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_wrapper_shorten():
|
||||||
|
"""Test that all wrappers in `__alL__` are contained within the `_wrapper_to_class` conversion."""
|
||||||
|
all_wrappers = set(__all__)
|
||||||
|
all_wrappers.remove("vector")
|
||||||
|
assert all_wrappers == set(_wrapper_to_class.keys())
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("wrapper_name", __all__)
|
||||||
|
def test_all_wrappers_shortened(wrapper_name):
|
||||||
|
"""Check that each element of the `__all__` wrappers can be loaded."""
|
||||||
|
if wrapper_name != "vector":
|
||||||
|
assert getattr(gymnasium.experimental.wrappers, wrapper_name) is not None
|
@@ -4,7 +4,7 @@ import jax.numpy as jnp
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from gymnasium.experimental.wrappers.conversion.jax_to_numpy import (
|
from gymnasium.experimental.wrappers.jax_to_numpy import (
|
||||||
JaxToNumpyV0,
|
JaxToNumpyV0,
|
||||||
jax_to_numpy,
|
jax_to_numpy,
|
||||||
numpy_to_jax,
|
numpy_to_jax,
|
||||||
|
@@ -5,7 +5,7 @@ import numpy as np
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from gymnasium.experimental.wrappers.conversion.jax_to_torch import (
|
from gymnasium.experimental.wrappers.jax_to_torch import (
|
||||||
JaxToTorchV0,
|
JaxToTorchV0,
|
||||||
jax_to_torch,
|
jax_to_torch,
|
||||||
torch_to_jax,
|
torch_to_jax,
|
||||||
|
Reference in New Issue
Block a user