mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-23 15:04:20 +00:00
Add generic conversion wrapper between Array API compatible frameworks (#1333)
This commit is contained in:
@@ -3,19 +3,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import numbers
|
||||
from collections import abc
|
||||
from typing import Any, Iterable, Mapping, SupportsFloat
|
||||
|
||||
import numpy as np
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.core import ActType, ObsType, RenderFrame, WrapperActType, WrapperObsType
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from gymnasium.error import DependencyNotInstalled
|
||||
from gymnasium.wrappers.array_conversion import (
|
||||
ArrayConversion,
|
||||
array_conversion,
|
||||
module_namespace,
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
except ImportError:
|
||||
raise DependencyNotInstalled(
|
||||
@@ -24,110 +25,13 @@ except ImportError:
|
||||
|
||||
__all__ = ["JaxToNumpy", "jax_to_numpy", "numpy_to_jax"]
|
||||
|
||||
# `types.NoneType` is only available from Python 3.10 onwards, so the type is derived from `None` itself. Remove this alias when the minimal version is bumped to >=3.10
|
||||
_NoneType = type(None)
|
||||
|
||||
# `array_conversion` with its `xp` keyword pre-bound to the namespace that `module_namespace` returns for NumPy.
# NOTE(review): presumably this converts its input to NumPy arrays - confirm against `array_conversion`.
jax_to_numpy = functools.partial(array_conversion, xp=module_namespace(np))
|
||||
|
||||
# `array_conversion` with its `xp` keyword pre-bound to the namespace that `module_namespace` returns for `jax.numpy`.
# NOTE(review): presumably this converts its input to Jax arrays - confirm against `array_conversion`.
numpy_to_jax = functools.partial(array_conversion, xp=module_namespace(jnp))
|
||||
|
||||
|
||||
@functools.singledispatch
def numpy_to_jax(value: Any) -> Any:
    """Converts a value to a Jax Array.

    This is the ``functools.singledispatch`` fallback: it only runs when no
    implementation has been registered for ``type(value)``.

    Args:
        value: The object to convert.

    Raises:
        Exception: Always, naming the type that has no registered conversion.
    """
    raise Exception(
        f"No known conversion for Numpy type ({type(value)}) to Jax registered. Report as issue on github."
    )
|
||||
|
||||
|
||||
@numpy_to_jax.register(numbers.Number)
def _number_to_jax(
    value: numbers.Number,
) -> jax.Array:
    """Converts a number (int, float, etc.) to a Jax Array.

    Args:
        value: The Python number to wrap.

    Returns:
        The result of ``jnp.array(value)``.
    """
    # No `assert jnp is not None` here: the module-level import guard already
    # raises `DependencyNotInstalled` when Jax cannot be imported.
    return jnp.array(value)
|
||||
|
||||
|
||||
@numpy_to_jax.register(np.ndarray)
def _numpy_array_to_jax(value: np.ndarray) -> jax.Array:
    """Converts a NumPy Array to a Jax Array with the same dtype (excluding float64 without being enabled).

    Args:
        value: The NumPy array to convert.

    Returns:
        The result of ``jnp.array(value, dtype=value.dtype)``.
    """
    # No `assert jnp is not None` here: the module-level import guard already
    # raises `DependencyNotInstalled` when Jax cannot be imported.
    return jnp.array(value, dtype=value.dtype)
|
||||
|
||||
|
||||
@numpy_to_jax.register(abc.Mapping)
def _mapping_numpy_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]:
    """Converts a dictionary of numpy arrays to a mapping of Jax Array.

    Every value is converted with ``numpy_to_jax`` and the result is rebuilt
    with the input's own mapping type.

    Args:
        value: The mapping whose values should be converted.

    Returns:
        A new ``type(value)`` instance with the same keys and converted values.
    """
    converted = {key: numpy_to_jax(item) for key, item in value.items()}
    if all(isinstance(key, str) for key in converted):
        # Keyword construction is kept for string keys so that mapping types
        # whose constructor only takes keyword arguments keep working.
        return type(value)(**converted)
    # `**` unpacking raises TypeError for non-string keys, so hand the
    # converted dict to the constructor positionally instead.
    return type(value)(converted)
|
||||
|
||||
|
||||
@numpy_to_jax.register(abc.Iterable)
def _iterable_numpy_to_jax(
    value: Iterable[np.ndarray | Any],
) -> Iterable[jax.Array | Any]:
    """Converts an Iterable from Numpy Arrays to an iterable of Jax Array.

    Each element is converted with ``numpy_to_jax`` and the result is rebuilt
    with the input's own container type.

    Args:
        value: The iterable whose elements should be converted.

    Returns:
        A new ``type(value)`` instance holding the converted elements, or
        ``value`` itself when it is a string.
    """
    if isinstance(value, str):
        # Strings are iterable, but `str(<generator>)` does not consume the
        # generator: it returns the generator's repr. Pass text through as is.
        return value
    if hasattr(value, "_make"):
        # namedtuple - underline used to prevent potential name conflicts
        # noinspection PyProtectedMember
        return type(value)._make(numpy_to_jax(v) for v in value)
    return type(value)(numpy_to_jax(v) for v in value)
|
||||
|
||||
|
||||
@numpy_to_jax.register(_NoneType)
def _none_numpy_to_jax(value: None) -> None:
    """Passes through None values.

    Args:
        value: Always ``None``; this implementation is registered for its type.

    Returns:
        The ``value`` it was given, unchanged.
    """
    return value
|
||||
|
||||
|
||||
@functools.singledispatch
def jax_to_numpy(value: Any) -> Any:
    """Converts a value to a numpy array.

    This is the ``functools.singledispatch`` fallback: it only runs when no
    implementation has been registered for ``type(value)``.

    Args:
        value: The object to convert.

    Raises:
        Exception: Always, naming the type that has no registered conversion.
    """
    raise Exception(
        f"No known conversion for Jax type ({type(value)}) to NumPy registered. Report as issue on github."
    )
|
||||
|
||||
|
||||
@jax_to_numpy.register(jax.Array)
def _devicearray_jax_to_numpy(value: jax.Array) -> np.ndarray:
    """Converts a Jax Array to a numpy array.

    Args:
        value: The Jax array to convert.

    Returns:
        The result of ``np.array(value)``.
    """
    return np.array(value)
|
||||
|
||||
|
||||
@jax_to_numpy.register(abc.Mapping)
def _mapping_jax_to_numpy(
    value: Mapping[str, jax.Array | Any],
) -> Mapping[str, np.ndarray | Any]:
    """Converts a dictionary of Jax Array to a mapping of numpy arrays.

    Every value is converted with ``jax_to_numpy`` and the result is rebuilt
    with the input's own mapping type.

    Args:
        value: The mapping whose values should be converted.

    Returns:
        A new ``type(value)`` instance with the same keys and converted values.
    """
    converted = {key: jax_to_numpy(item) for key, item in value.items()}
    if all(isinstance(key, str) for key in converted):
        # Keyword construction is kept for string keys so that mapping types
        # whose constructor only takes keyword arguments keep working.
        return type(value)(**converted)
    # `**` unpacking raises TypeError for non-string keys, so hand the
    # converted dict to the constructor positionally instead.
    return type(value)(converted)
|
||||
|
||||
|
||||
@jax_to_numpy.register(abc.Iterable)
def _iterable_jax_to_numpy(
    value: Iterable[jax.Array | Any],
) -> Iterable[np.ndarray | Any]:
    """Converts an Iterable from Jax Arrays to an iterable of numpy arrays.

    Each element is converted with ``jax_to_numpy`` and the result is rebuilt
    with the input's own container type.

    Args:
        value: The iterable whose elements should be converted.

    Returns:
        A numpy array when ``value`` is itself a Jax Array, ``value`` itself
        when it is a string, otherwise a new ``type(value)`` instance holding
        the converted elements.
    """
    if isinstance(value, jax.Array):
        # Since the update to jax 0.6.0, calling jax_to_numpy with a <class 'jaxlib.xla_extension.ArrayImpl'>
        # argument wrongly dispatches to _iterable_jax_to_numpy which fails with:
        # TypeError: (): incompatible function arguments.
        # See: https://github.com/Farama-Foundation/Gymnasium/issues/1360
        return _devicearray_jax_to_numpy(value)
    if isinstance(value, str):
        # Strings are iterable, but `str(<generator>)` does not consume the
        # generator: it returns the generator's repr. Pass text through as is.
        return value
    if hasattr(value, "_make"):
        # namedtuple - underline used to prevent potential name conflicts
        # noinspection PyProtectedMember
        return type(value)._make(jax_to_numpy(v) for v in value)
    return type(value)(jax_to_numpy(v) for v in value)
|
||||
|
||||
|
||||
@jax_to_numpy.register(_NoneType)
def _none_jax_to_numpy(value: None) -> None:
    """Passes through None values.

    Args:
        value: Always ``None``; this implementation is registered for its type.

    Returns:
        The ``value`` it was given, unchanged.
    """
    return value
|
||||
|
||||
|
||||
class JaxToNumpy(
|
||||
gym.Wrapper[WrapperObsType, WrapperActType, ObsType, ActType],
|
||||
gym.utils.RecordConstructorArgs,
|
||||
):
|
||||
class JaxToNumpy(ArrayConversion):
|
||||
"""Wraps a Jax-based environment such that it can be interacted with NumPy arrays.
|
||||
|
||||
Actions must be provided as numpy arrays and observations will be returned as numpy arrays.
|
||||
@@ -169,48 +73,4 @@ class JaxToNumpy(
|
||||
raise DependencyNotInstalled(
|
||||
'Jax is not installed, run `pip install "gymnasium[jax]"`'
|
||||
)
|
||||
gym.utils.RecordConstructorArgs.__init__(self)
|
||||
gym.Wrapper.__init__(self, env)
|
||||
|
||||
def step(
    self, action: WrapperActType
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
    """Transforms the action to a jax array and steps the wrapped environment.

    Args:
        action: the action to perform as a numpy array

    Returns:
        A tuple containing numpy versions of the next observation, reward, termination, truncation, and extra info.
    """
    jax_action = numpy_to_jax(action)
    obs, reward, terminated, truncated, info = self.env.step(jax_action)

    # The reward and both flags are cast to plain Python scalars; only the
    # observation and info go through the Jax -> NumPy conversion.
    return (
        jax_to_numpy(obs),
        float(reward),
        bool(terminated),
        bool(truncated),
        jax_to_numpy(info),
    )
|
||||
|
||||
def reset(
    self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[WrapperObsType, dict[str, Any]]:
    """Resets the environment returning numpy-based observation and info.

    Args:
        seed: The seed for resetting the environment
        options: The options for resetting the environment, these are converted to jax arrays.

    Returns:
        Numpy-based observations and info
    """
    # Falsy options (None or an empty dict) are forwarded untouched.
    if options:
        options = numpy_to_jax(options)

    return jax_to_numpy(self.env.reset(seed=seed, options=options))
|
||||
|
||||
def render(self) -> RenderFrame | list[RenderFrame] | None:
    """Returns the rendered frames as a numpy array.

    Returns:
        Whatever the wrapped environment's ``render()`` returns, passed
        through ``jax_to_numpy``.
    """
    return jax_to_numpy(self.env.render())
|
||||
super().__init__(env=env, env_xp=jnp, target_xp=np)
|
||||
|
Reference in New Issue
Block a user