2023-05-23 15:46:04 +01:00
|
|
|
"""Wrapper for converting NumPy environments to PyTorch."""
|
2024-06-10 17:07:47 +01:00
|
|
|
|
2023-05-23 15:46:04 +01:00
|
|
|
from __future__ import annotations
|
|
|
|
|
2025-05-12 00:10:06 +02:00
|
|
|
import numpy as np
|
|
|
|
import torch
|
2023-05-23 15:46:04 +01:00
|
|
|
|
2025-05-12 00:10:06 +02:00
|
|
|
from gymnasium.vector import VectorEnv
|
|
|
|
from gymnasium.wrappers.numpy_to_torch import Device
|
|
|
|
from gymnasium.wrappers.vector.array_conversion import ArrayConversion
|
2023-11-07 13:27:25 +00:00
|
|
|
|
|
|
|
|
|
|
|
# Public API of this module: only the NumPy -> PyTorch vector wrapper is exported.
__all__ = ["NumpyToTorch"]
|
|
|
|
|
|
|
|
|
2025-05-12 00:10:06 +02:00
|
|
|
class NumpyToTorch(ArrayConversion):
    """Vector wrapper that lets a NumPy-based environment be driven with PyTorch tensors.

    Conversion of the environment's inputs and outputs is delegated to
    :class:`ArrayConversion`, configured with NumPy as the wrapped
    environment's array namespace and ``torch`` as the target namespace.

    Example:
        >>> import torch
        >>> import gymnasium as gym
        >>> from gymnasium.wrappers.vector import NumpyToTorch
        >>> envs = gym.make_vec("CartPole-v1", 3)
        >>> envs = NumpyToTorch(envs)
        >>> obs, _ = envs.reset(seed=123)
        >>> type(obs)
        <class 'torch.Tensor'>
        >>> action = torch.tensor(envs.action_space.sample())
        >>> obs, reward, terminated, truncated, info = envs.step(action)
        >>> envs.close()
        >>> type(obs)
        <class 'torch.Tensor'>
        >>> type(reward)
        <class 'torch.Tensor'>
        >>> type(terminated)
        <class 'torch.Tensor'>
        >>> type(truncated)
        <class 'torch.Tensor'>
    """

    def __init__(self, env: VectorEnv, device: Device | None = None):
        """Configure NumPy -> PyTorch conversion around ``env``.

        Args:
            env: The NumPy-based vector environment to wrap.
            device: Device for the produced torch tensors; forwarded to
                :class:`ArrayConversion` as ``target_device``.
        """
        # Array namespaces on each side of the wrapper, plus the target device.
        conversion_options = {
            "env_xp": np,
            "target_xp": torch,
            "target_device": device,
        }
        super().__init__(env, **conversion_options)

        # Also expose the requested device as a public attribute on the wrapper.
        self.device: Device | None = device
|