mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-01 14:10:30 +00:00
133 lines
5.0 KiB
Python
133 lines
5.0 KiB
Python
"""Vector wrapper for converting between Array API compatible frameworks."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from types import ModuleType
|
|
from typing import Any
|
|
|
|
import gymnasium as gym
|
|
from gymnasium.core import ActType, ObsType
|
|
from gymnasium.vector import VectorEnv, VectorWrapper
|
|
from gymnasium.vector.vector_env import ArrayType
|
|
from gymnasium.wrappers.array_conversion import (
|
|
Device,
|
|
array_conversion,
|
|
module_name_to_namespace,
|
|
)
|
|
|
|
|
|
# Names exported by a star-import of this module.
__all__ = ["ArrayConversion"]
|
|
|
|
|
|
class ArrayConversion(VectorWrapper, gym.utils.RecordConstructorArgs):
    """Wraps a vector environment returning Array API compatible arrays so that it can be interacted with through a specific framework.

    Popular Array API frameworks include ``numpy``, ``torch``, ``jax.numpy``, ``cupy`` etc. With this wrapper, you can convert outputs from your environment to
    any of these frameworks. Conversely, actions are automatically mapped back to the environment framework, if possible without moving the
    data or device transfers.

    Notes:
        A vectorized version of :class:`gymnasium.wrappers.ArrayConversion`

    Example:
        >>> import gymnasium as gym                                 # doctest: +SKIP
        >>> import jax.numpy as jnp                                 # doctest: +SKIP
        >>> import numpy as np                                      # doctest: +SKIP
        >>> envs = gym.make_vec("JaxEnv-vx", 3)                     # doctest: +SKIP
        >>> envs = ArrayConversion(envs, env_xp=jnp, target_xp=np)  # doctest: +SKIP
    """

    def __init__(
        self,
        env: VectorEnv,
        env_xp: ModuleType | str,
        target_xp: ModuleType | str,
        env_device: Device | None = None,
        target_device: Device | None = None,
    ):
        """Wrapper class to change inputs and outputs of environment to any Array API framework.

        Args:
            env: The Array API compatible environment to wrap
            env_xp: The Array API framework the environment is on, given either as the
                module itself or as its module name (e.g. ``"numpy"``)
            target_xp: The Array API framework to convert to, given either as the
                module itself or as its module name
            env_device: The device the environment is on
            target_device: The device on which Arrays should be returned
        """
        gym.utils.RecordConstructorArgs.__init__(self)
        VectorWrapper.__init__(self, env)

        # The signature accepts names as well as modules, but every later use
        # (``array_conversion`` and ``__getstate__``'s ``__name__`` lookup) needs a
        # module object, so string arguments are resolved once here.
        self._env_xp = self._resolve_namespace(env_xp)
        self._target_xp = self._resolve_namespace(target_xp)
        self._env_device = env_device
        self._target_device = target_device

    @staticmethod
    def _resolve_namespace(xp: ModuleType | str) -> ModuleType:
        """Return ``xp`` unchanged if it is a module, otherwise look it up by name."""
        if isinstance(xp, str):
            return module_name_to_namespace(xp)
        return xp

    def step(
        self, actions: ActType
    ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
        """Transforms the action to the specified xp module array type.

        Args:
            actions: The action to perform

        Returns:
            A tuple containing xp versions of the next observation, reward, termination, truncation, and extra info.
        """

        def to_target(value: Any) -> Any:
            # Every element of the step result is converted with the same settings.
            return array_conversion(
                value, xp=self._target_xp, device=self._target_device
            )

        # Actions arrive in the target framework and must be moved to the
        # framework/device of the wrapped environment before stepping it.
        actions = array_conversion(actions, xp=self._env_xp, device=self._env_device)
        obs, reward, terminated, truncated, info = self.env.step(actions)

        return (
            to_target(obs),
            to_target(reward),
            to_target(terminated),
            to_target(truncated),
            to_target(info),
        )

    def reset(
        self,
        *,
        seed: int | list[int] | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[ObsType, dict[str, Any]]:
        """Resets the environment returning xp-based observation and info.

        Args:
            seed: The seed for resetting the environment
            options: The options for resetting the environment, these are converted to xp arrays.

        Returns:
            xp-based observations and info
        """
        # ``None`` and empty option dicts are forwarded to the environment untouched.
        if options:
            options = array_conversion(
                options, xp=self._env_xp, device=self._env_device
            )

        return array_conversion(
            self.env.reset(seed=seed, options=options),
            xp=self._target_xp,
            device=self._target_device,
        )

    def __getstate__(self):
        """Returns the object pickle state with args and kwargs.

        The two frameworks are stored by module name, with any
        ``array_api_compat.`` prefix removed, rather than as module objects.
        """
        env_xp_name = self._env_xp.__name__.replace("array_api_compat.", "")
        target_xp_name = self._target_xp.__name__.replace("array_api_compat.", "")
        return {
            "env_xp_name": env_xp_name,
            "target_xp_name": target_xp_name,
            "env_device": self._env_device,
            "target_device": self._target_device,
            "env": self.env,
        }

    def __setstate__(self, d):
        """Sets the object pickle state using d.

        Args:
            d: A state dictionary with the keys written by :meth:`__getstate__`.
        """
        self.env = d["env"]
        # Turn the stored module names back into namespace modules.
        self._env_xp = module_name_to_namespace(d["env_xp_name"])
        self._target_xp = module_name_to_namespace(d["target_xp_name"])
        self._env_device = d["env_device"]
        self._target_device = d["target_device"]