# File: Gymnasium/gymnasium/experimental/wrappers/torch_to_jax.py (164 lines, 5.8 KiB, Python)

# This wrapper converts PyTorch inputs (such as actions) into Jax arrays for an
# underlying Jax environment, then converts the values the environment returns
# (such as observations) from Jax arrays back into PyTorch tensors.
#
# Functionality for converting between torch and jax types originally copied from
# https://github.com/google/brax/blob/9d6b7ced2a13da0d074b5e9fbd3aad8311e26997/brax/io/torch.py
# Under the Apache 2.0 license. Copyright is held by the authors
"""Helper functions and wrapper class for converting between PyTorch and Jax."""
from __future__ import annotations
import functools
import numbers
from collections import abc
from typing import Any, Iterable, Mapping, SupportsFloat, Union
import jax.numpy as jnp
from jax import dlpack as jax_dlpack
from gymnasium import Env, Wrapper
from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType
from gymnasium.error import DependencyNotInstalled
from gymnasium.experimental.wrappers.numpy_to_jax import jax_to_numpy
# torch is an optional dependency: turn a failed import into Gymnasium's own
# error, which tells the user how to install it.
try:
    import torch
    from torch.utils import dlpack as torch_dlpack
except ImportError as err:
    # Chain explicitly so the original import failure is preserved as the cause.
    raise DependencyNotInstalled(
        "torch is not installed, run `pip install torch`"
    ) from err

# A torch device, given either by name or as a ``torch.device`` object.
Device = Union[str, torch.device]
@functools.singledispatch
def torch_to_jax(value: Any) -> Any:
    """Convert a PyTorch value into its Jax equivalent.

    This is the fallback of a ``functools.singledispatch`` function, so it only
    runs when no converter has been registered for ``type(value)``.

    Args:
        value: The object to convert.

    Raises:
        Exception: Always, naming the type that has no registered conversion.
    """
    unsupported_type = type(value)
    message = (
        f"No conversion for PyTorch to Jax registered for type: {unsupported_type}"
    )
    raise Exception(message)
@torch_to_jax.register(numbers.Number)
def _number_torch_to_jax(value: numbers.Number) -> Any:
    """Wrap a plain Python number with ``jnp.array``."""
    as_jax = jnp.array(value)
    return as_jax
@torch_to_jax.register(torch.Tensor)
def _tensor_torch_to_jax(value: torch.Tensor) -> jnp.DeviceArray:
    """Hand a PyTorch Tensor over to Jax through the DLPack exchange functions.

    Args:
        value: The PyTorch tensor to convert.

    Returns:
        The object produced by ``jax.dlpack.from_dlpack`` for the exported tensor.
    """
    # NOTE(review): the ``jnp.DeviceArray`` annotation may not exist in newer JAX
    # releases -- confirm against the JAX version this project supports.
    capsule = torch_dlpack.to_dlpack(value)  # pyright: ignore[reportPrivateImportUsage]
    return jax_dlpack.from_dlpack(capsule)  # pyright: ignore[reportPrivateImportUsage]
@torch_to_jax.register(abc.Mapping)
def _mapping_torch_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]:
    """Convert each value of a mapping from PyTorch to Jax, keeping the mapping type.

    Args:
        value: A mapping whose values are convertible by ``torch_to_jax``.

    Returns:
        A new ``type(value)`` instance built from the converted entries.
    """
    converted = {}
    for key, item in value.items():
        converted[key] = torch_to_jax(item)
    # The result is rebuilt through keyword arguments, so keys must be strings.
    mapping_cls = type(value)
    return mapping_cls(**converted)
@torch_to_jax.register(abc.Iterable)
def _iterable_torch_to_jax(value: Iterable[Any]) -> Iterable[Any]:
    """Convert every element of an iterable from PyTorch to Jax, keeping the container type.

    Args:
        value: An iterable (for example a ``list``, ``tuple`` or namedtuple) whose
            elements are convertible by ``torch_to_jax``.

    Returns:
        A new container of ``type(value)`` holding the converted elements.
        ``str`` and ``bytes`` values are returned unchanged.
    """
    if isinstance(value, (str, bytes)):
        # Text is iterable but is not a container of tensors. Rebuilding it with
        # ``str(<generator>)`` would return the generator's repr, not the text.
        return value
    converted = (torch_to_jax(item) for item in value)
    if hasattr(value, "_make"):
        # namedtuple constructors take one positional argument per field, so
        # ``type(value)(iterable)`` fails; ``_make`` accepts an iterable instead.
        return type(value)._make(converted)
    return type(value)(converted)
@functools.singledispatch
def jax_to_torch(value: Any, device: Device | None = None) -> Any:
    """Convert a Jax value into its PyTorch equivalent.

    This is the fallback of a ``functools.singledispatch`` function, so it only
    runs when no converter has been registered for ``type(value)``.

    Args:
        value: The object to convert.
        device: The requested torch device; only used here for the error message.

    Raises:
        Exception: Always, naming the unsupported type and the requested device.
    """
    unsupported_type = type(value)
    message = f"No conversion for Jax to PyTorch registered for type={unsupported_type} and device: {device}"
    raise Exception(message)
# NOTE(review): ``jnp.DeviceArray`` and ``jax.dlpack.to_dlpack`` may not exist in
# newer JAX releases -- confirm against the JAX version this project supports.
@jax_to_torch.register(jnp.DeviceArray)
def _devicearray_jax_to_torch(
    value: jnp.DeviceArray, device: Device | None = None
) -> torch.Tensor:
    """Hand a Jax DeviceArray over to PyTorch through the DLPack exchange functions.

    Args:
        value: The Jax array to convert.
        device: Optional torch device; when truthy the tensor is moved to it.

    Returns:
        The resulting PyTorch tensor.
    """
    capsule = jax_dlpack.to_dlpack(value)  # pyright: ignore[reportPrivateImportUsage]
    torch_tensor = torch_dlpack.from_dlpack(capsule)
    if not device:
        # No target device requested: return the tensor as imported.
        return torch_tensor
    return torch_tensor.to(device=device)
@jax_to_torch.register(abc.Mapping)
def _jax_mapping_to_torch(
    value: Mapping[str, Any], device: Device | None = None
) -> Mapping[str, Any]:
    """Convert each value of a mapping from Jax to PyTorch, keeping the mapping type.

    Args:
        value: A mapping whose values are convertible by ``jax_to_torch``.
        device: Optional torch device forwarded to every nested conversion.

    Returns:
        A new ``type(value)`` instance built from the converted entries.
    """
    converted = {}
    for key, item in value.items():
        converted[key] = jax_to_torch(item, device)
    # The result is rebuilt through keyword arguments, so keys must be strings.
    mapping_cls = type(value)
    return mapping_cls(**converted)
@jax_to_torch.register(abc.Iterable)
def _jax_iterable_to_torch(
    value: Iterable[Any], device: Device | None = None
) -> Iterable[Any]:
    """Convert every element of an iterable from Jax to PyTorch, keeping the container type.

    Args:
        value: An iterable (for example a ``list``, ``tuple`` or namedtuple) whose
            elements are convertible by ``jax_to_torch``.
        device: Optional torch device forwarded to every nested conversion.

    Returns:
        A new container of ``type(value)`` holding the converted elements.
        ``str`` and ``bytes`` values are returned unchanged.
    """
    if isinstance(value, (str, bytes)):
        # Text is iterable but is not a container of arrays. Rebuilding it with
        # ``str(<generator>)`` would return the generator's repr, not the text.
        return value
    converted = (jax_to_torch(item, device) for item in value)
    if hasattr(value, "_make"):
        # namedtuple constructors take one positional argument per field, so
        # ``type(value)(iterable)`` fails; ``_make`` accepts an iterable instead.
        return type(value)._make(converted)
    return type(value)(converted)
class JaxToTorchV0(Wrapper):
    """Expose a Jax-based environment through a PyTorch tensor interface.

    Actions given to :meth:`step` are converted with ``torch_to_jax`` before they
    reach the wrapped environment, and what the environment returns is converted
    back with ``jax_to_torch``. Rendered frames are passed through ``jax_to_numpy``
    instead, so they are NumPy arrays rather than PyTorch tensors.
    """

    def __init__(self, env: Env, device: Device | None = None):
        """Wrap ``env`` so that its inputs and outputs are PyTorch tensors.

        Args:
            env: The Jax-based environment to wrap
            device: The device the returned torch Tensors should be moved to
        """
        super().__init__(env)
        self.device: Device | None = device

    def step(
        self, action: WrapperActType
    ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
        """Run one environment step using a PyTorch action.

        Args:
            action: The action to perform as a PyTorch Tensor

        Returns:
            The next observation and info converted with ``jax_to_torch``, the
            reward as a ``float``, and termination / truncation as ``bool``.
        """
        step_result = self.env.step(torch_to_jax(action))
        obs, reward, terminated, truncated, info = step_result

        torch_obs = jax_to_torch(obs, self.device)
        # Reward and flags are returned as plain Python scalars, not tensors.
        reward_value = float(reward)
        is_terminated = bool(terminated)
        is_truncated = bool(truncated)
        torch_info = jax_to_torch(info, self.device)
        return torch_obs, reward_value, is_terminated, is_truncated, torch_info

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[WrapperObsType, dict[str, Any]]:
        """Reset the environment and return PyTorch-based observation and info.

        Args:
            seed: The seed for resetting the environment
            options: The options for resetting the environment; when non-empty
                they are converted with ``torch_to_jax`` before being forwarded.

        Returns:
            The wrapped environment's reset result converted with ``jax_to_torch``.
        """
        # Empty or missing options are forwarded untouched.
        jax_options = torch_to_jax(options) if options else options
        reset_result = self.env.reset(seed=seed, options=jax_options)
        return jax_to_torch(reset_result, self.device)

    def render(self) -> RenderFrame | list[RenderFrame] | None:
        """Return the wrapped environment's rendered frames converted by ``jax_to_numpy``."""
        # NOTE(review): whether ``jax_to_numpy`` accepts a ``None`` render result is
        # not visible from this file -- confirm against its definition.
        frames = self.env.render()
        return jax_to_numpy(frames)