Files
Gymnasium/gymnasium/wrappers/vector/array_conversion.py

132 lines
4.9 KiB
Python

"""Vector wrapper for converting between Array API compatible frameworks."""
from __future__ import annotations
from types import ModuleType
from typing import Any
import gymnasium as gym
from gymnasium.core import ActType, ObsType
from gymnasium.vector import VectorEnv, VectorWrapper
from gymnasium.vector.vector_env import ArrayType
from gymnasium.wrappers.array_conversion import (
Device,
array_conversion,
module_name_to_namespace,
)
# Public names exported by this module (what ``from <module> import *`` binds).
__all__ = ["ArrayConversion"]
class ArrayConversion(VectorWrapper, gym.utils.RecordConstructorArgs):
    """Wraps a vector environment returning Array API compatible arrays so that it can be interacted with through a specific framework.

    Notes:
        A vectorized version of ``gymnasium.wrappers.ArrayConversion``

    Actions must be provided as Array API compatible arrays and observations, rewards, terminations and truncations will be returned in the desired framework.

    xp here is a module that is compatible with the Array API standard, e.g. ``numpy``, ``jax`` etc.

    Example:
        >>> import gymnasium as gym                                      # doctest: +SKIP
        >>> import jax.numpy as jnp                                      # doctest: +SKIP
        >>> import numpy as np                                           # doctest: +SKIP
        >>> envs = gym.make_vec("JaxEnv-vx", 3)                          # doctest: +SKIP
        >>> envs = ArrayConversion(envs, env_xp=jnp, target_xp=np)       # doctest: +SKIP
    """

    def __init__(
        self,
        env: VectorEnv,
        env_xp: ModuleType | str,
        target_xp: ModuleType | str,
        env_device: Device | None = None,
        target_device: Device | None = None,
    ):
        """Wrapper class to change inputs and outputs of environment to any Array API framework.

        Args:
            env: The Array API compatible environment to wrap
            env_xp: The Array API framework the environment is on, given either as the
                module itself or as its name (resolved with ``module_name_to_namespace``)
            target_xp: The Array API framework to convert to, given either as the module
                itself or as its name (resolved with ``module_name_to_namespace``)
            env_device: The device the environment is on
            target_device: The device on which Arrays should be returned
        """
        gym.utils.RecordConstructorArgs.__init__(self)
        VectorWrapper.__init__(self, env)

        # Resolve names eagerly: ``__getstate__`` reads ``__name__`` from these
        # attributes, which only works when they hold modules, not strings.
        self._env_xp = self._resolve_namespace(env_xp)
        self._target_xp = self._resolve_namespace(target_xp)
        self._env_device = env_device
        self._target_device = target_device

    @staticmethod
    def _resolve_namespace(xp: ModuleType | str) -> ModuleType:
        """Return ``xp`` as a module, looking it up by name when it is a string.

        Args:
            xp: An Array API namespace module or the name of one

        Returns:
            ``xp`` unchanged if it is not a string, otherwise the result of
            ``module_name_to_namespace`` for that name.
        """
        if isinstance(xp, str):
            # Same normalisation as ``__getstate__`` applies before the name is
            # handed to ``module_name_to_namespace`` in ``__setstate__``.
            return module_name_to_namespace(xp.replace("array_api_compat.", ""))
        return xp

    def step(
        self, actions: ActType
    ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
        """Transforms the action to the specified xp module array type.

        Args:
            actions: The action to perform

        Returns:
            A tuple containing xp versions of the next observation, reward, termination, truncation, and extra info.
        """
        # Actions go to the environment's framework/device ...
        actions = array_conversion(actions, xp=self._env_xp, device=self._env_device)
        obs, reward, terminated, truncated, info = self.env.step(actions)

        # ... and every element of the step result comes back in the target one.
        return (
            array_conversion(obs, xp=self._target_xp, device=self._target_device),
            array_conversion(reward, xp=self._target_xp, device=self._target_device),
            array_conversion(
                terminated, xp=self._target_xp, device=self._target_device
            ),
            array_conversion(truncated, xp=self._target_xp, device=self._target_device),
            array_conversion(info, xp=self._target_xp, device=self._target_device),
        )

    def reset(
        self,
        *,
        seed: int | list[int] | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[ObsType, dict[str, Any]]:
        """Resets the environment returning xp-based observation and info.

        Args:
            seed: The seed for resetting the environment
            options: The options for resetting the environment, these are converted to xp arrays.

        Returns:
            xp-based observations and info
        """
        # ``None`` and an empty dict are passed through to the environment untouched.
        if options:
            options = array_conversion(
                options, xp=self._env_xp, device=self._env_device
            )

        return array_conversion(
            self.env.reset(seed=seed, options=options),
            xp=self._target_xp,
            device=self._target_device,
        )

    def __getstate__(self):
        """Returns the object pickle state with args and kwargs.

        The xp namespaces are stored by module name, with any
        ``array_api_compat.`` prefix removed, rather than as module objects.
        """
        env_xp_name = self._env_xp.__name__.replace("array_api_compat.", "")
        target_xp_name = self._target_xp.__name__.replace("array_api_compat.", "")
        return {
            "env_xp_name": env_xp_name,
            "target_xp_name": target_xp_name,
            "env_device": self._env_device,
            "target_device": self._target_device,
            "env": self.env,
        }

    def __setstate__(self, d):
        """Sets the object pickle state using d.

        Args:
            d: The state dictionary produced by ``__getstate__``
        """
        self.env = d["env"]
        # Turn the stored module names back into namespaces.
        self._env_xp = module_name_to_namespace(d["env_xp_name"])
        self._target_xp = module_name_to_namespace(d["target_xp_name"])
        self._env_device = d["env_device"]
        self._target_device = d["target_device"]