mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-07-31 22:04:31 +00:00
Co-authored-by: Justin Deutsch <djustin8@vt.edu> Co-authored-by: Gianluca De Cola <42657588+gianlucadecola@users.noreply.github.com>
135 lines
4.5 KiB
Python
135 lines
4.5 KiB
Python
"""Helper functions and wrapper class for converting between numpy and Jax."""
|
|
from __future__ import annotations
|
|
|
|
import functools
|
|
import numbers
|
|
from collections import abc
|
|
from typing import Any, Iterable, Mapping, SupportsFloat
|
|
|
|
import jax.numpy as jnp
|
|
import numpy as np
|
|
|
|
from gymnasium import Env, Wrapper
|
|
from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType
|
|
|
|
|
|
@functools.singledispatch
def numpy_to_jax(value: Any) -> Any:
    """Converts a value to a Jax DeviceArray.

    This is the fallback of a :func:`functools.singledispatch` function. Conversions
    for concrete types are added with ``numpy_to_jax.register``. This body only runs
    when no registered type matches ``type(value)``.

    Args:
        value: The value to convert.

    Raises:
        TypeError: Always, because no conversion is registered for ``type(value)``.
            ``TypeError`` subclasses ``Exception``, so ``except Exception`` handlers
            written against the previous behaviour still catch it.
    """
    raise TypeError(
        f"No conversion for Numpy to Jax registered for type: {type(value)}"
    )
|
|
|
|
|
|
@numpy_to_jax.register(numbers.Number)
@numpy_to_jax.register(np.ndarray)
def _number_ndarray_numpy_to_jax(value: np.ndarray | numbers.Number) -> jnp.DeviceArray:
    """Builds a Jax DeviceArray from a numpy array or a plain python number.

    Registered for both ``numbers.Number`` (int, float, etc.) and ``np.ndarray``.
    Both cases go through a single ``jnp.array`` call.
    """
    device_array = jnp.array(value)
    return device_array
|
|
|
|
|
|
@numpy_to_jax.register(abc.Mapping)
def _mapping_numpy_to_jax(value: Mapping[Any, Any]) -> Mapping[Any, Any]:
    """Converts a mapping of numpy arrays to a mapping of Jax DeviceArrays.

    Every value is converted recursively with :func:`numpy_to_jax`; keys are kept as-is.
    The result is rebuilt with the same mapping type as ``value``.

    Args:
        value: The mapping whose values should be converted.

    Returns:
        A new mapping of type ``type(value)`` holding the converted values.
    """
    converted = {k: numpy_to_jax(v) for k, v in value.items()}
    if all(isinstance(k, str) for k in converted):
        # Keyword construction keeps the original behaviour for string-keyed mappings.
        return type(value)(**converted)
    # ``**`` unpacking raises "keywords must be strings" for non-string keys,
    # so hand the converted mapping to the constructor positionally instead.
    return type(value)(converted)
|
|
|
|
|
|
@numpy_to_jax.register(abc.Iterable)
def _iterable_numpy_to_jax(
    value: Iterable[np.ndarray | Any],
) -> Iterable[jnp.DeviceArray | Any]:
    """Converts an Iterable from Numpy Arrays to an iterable of Jax DeviceArrays.

    Each element is converted recursively with :func:`numpy_to_jax`. The result is
    rebuilt with the same container type as ``value`` (e.g. ``list`` or ``tuple``).

    ``str`` and ``bytes`` are returned unchanged. Each character of a ``str`` is itself
    a ``str`` (and so an ``abc.Iterable``), so recursing into it would never terminate.

    Args:
        value: The iterable whose elements should be converted.

    Returns:
        A new container of type ``type(value)`` holding the converted elements,
        or ``value`` itself for ``str``/``bytes``.
    """
    if isinstance(value, (str, bytes)):
        return value

    converted = (numpy_to_jax(v) for v in value)
    if isinstance(value, tuple) and hasattr(value, "_make"):
        # namedtuple constructors take one positional argument per field,
        # so build them through the documented ``_make`` classmethod instead.
        return type(value)._make(converted)
    return type(value)(converted)
|
|
|
|
|
|
@functools.singledispatch
def jax_to_numpy(value: Any) -> Any:
    """Converts a value to a numpy array.

    This is the fallback of a :func:`functools.singledispatch` function. Conversions
    for concrete types are added with ``jax_to_numpy.register``. This body only runs
    when no registered type matches ``type(value)``.

    Args:
        value: The value to convert.

    Raises:
        TypeError: Always, because no conversion is registered for ``type(value)``.
            ``TypeError`` subclasses ``Exception``, so ``except Exception`` handlers
            written against the previous behaviour still catch it.
    """
    raise TypeError(
        f"No conversion for Jax to Numpy registered for type: {type(value)}"
    )
|
|
|
|
|
|
@jax_to_numpy.register(jnp.DeviceArray)
def _devicearray_jax_to_numpy(value: jnp.DeviceArray) -> np.ndarray:
    """Copies a Jax DeviceArray into a freshly allocated numpy array via ``np.array``."""
    host_array = np.array(value)
    return host_array
|
|
|
|
|
|
@jax_to_numpy.register(abc.Mapping)
def _mapping_jax_to_numpy(
    value: Mapping[Any, jnp.DeviceArray | Any]
) -> Mapping[Any, np.ndarray | Any]:
    """Converts a mapping of Jax DeviceArrays to a mapping of numpy arrays.

    Every value is converted recursively with :func:`jax_to_numpy`; keys are kept as-is.
    The result is rebuilt with the same mapping type as ``value``.

    Args:
        value: The mapping whose values should be converted.

    Returns:
        A new mapping of type ``type(value)`` holding the converted values.
    """
    converted = {k: jax_to_numpy(v) for k, v in value.items()}
    if all(isinstance(k, str) for k in converted):
        # Keyword construction keeps the original behaviour for string-keyed mappings.
        return type(value)(**converted)
    # ``**`` unpacking raises "keywords must be strings" for non-string keys,
    # so hand the converted mapping to the constructor positionally instead.
    return type(value)(converted)
|
|
|
|
|
|
@jax_to_numpy.register(abc.Iterable)
def _iterable_jax_to_numpy(
    value: Iterable[jnp.DeviceArray | Any],
) -> Iterable[np.ndarray | Any]:
    """Converts an Iterable of Jax DeviceArrays to an iterable of numpy arrays.

    Each element is converted recursively with :func:`jax_to_numpy`. The result is
    rebuilt with the same container type as ``value`` (e.g. ``list`` or ``tuple``).

    ``str`` and ``bytes`` are returned unchanged. Each character of a ``str`` is itself
    a ``str`` (and so an ``abc.Iterable``), so recursing into it would never terminate.

    Args:
        value: The iterable whose elements should be converted.

    Returns:
        A new container of type ``type(value)`` holding the converted elements,
        or ``value`` itself for ``str``/``bytes``.
    """
    if isinstance(value, (str, bytes)):
        return value

    converted = (jax_to_numpy(v) for v in value)
    if isinstance(value, tuple) and hasattr(value, "_make"):
        # namedtuple constructors take one positional argument per field,
        # so build them through the documented ``_make`` classmethod instead.
        return type(value)._make(converted)
    return type(value)(converted)
|
|
|
|
|
|
class JaxToNumpyV0(Wrapper):
    """Wraps a jax environment so that it can be interacted with through numpy arrays.

    Actions must be provided as numpy arrays and observations will be returned as numpy arrays.

    Notes:
        The Jax To Numpy and Numpy to Jax conversion does not guarantee a roundtrip (jax -> numpy -> jax) and vice versa.
        The reason for this is jax does not support non-array values. For example, the numpy scalar
        ``np.int32(5)`` becomes a 0-d ``DeviceArray(5, dtype=int32)`` and comes back as a 0-d
        ``np.ndarray`` rather than an ``np.int32`` scalar.
    """

    def __init__(self, env: Env):
        """Wraps an environment such that the input and outputs are numpy arrays.

        Args:
            env: the environment to wrap
        """
        super().__init__(env)

    def step(
        self, action: WrapperActType
    ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
        """Transforms the action to a jax array, steps the wrapped environment and converts the results to numpy.

        Args:
            action: the action to perform as a numpy array

        Returns:
            A tuple containing the next observation, reward, termination, truncation, and extra info.
            The observation and info are converted to numpy, the reward to a python ``float``,
            and the termination/truncation flags to python ``bool``.
        """
        jax_action = numpy_to_jax(action)
        obs, reward, terminated, truncated, info = self.env.step(jax_action)

        return (
            jax_to_numpy(obs),
            float(reward),
            bool(terminated),
            bool(truncated),
            jax_to_numpy(info),
        )

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[WrapperObsType, dict[str, Any]]:
        """Resets the environment returning numpy-based observation and info.

        Args:
            seed: The seed for resetting the environment
            options: The options for resetting the environment, these are converted to jax arrays.
                ``None`` or an empty dict is forwarded unchanged.

        Returns:
            Numpy-based observations and info
        """
        if options:
            options = numpy_to_jax(options)

        obs, info = self.env.reset(seed=seed, options=options)
        return jax_to_numpy(obs), jax_to_numpy(info)

    def render(self) -> RenderFrame | list[RenderFrame] | None:
        """Returns the rendered frames as a numpy array.

        Returns ``None`` when the wrapped environment's ``render`` returns ``None``.
        """
        frames = self.env.render()
        if frames is None:
            # No conversion is registered for ``None``, so ``jax_to_numpy(None)`` would raise.
            return None
        return jax_to_numpy(frames)
|