From a8946325e84426da8bc33cd73e0093521d309b9c Mon Sep 17 00:00:00 2001 From: Mark Towers Date: Thu, 13 Feb 2025 22:45:30 +0000 Subject: [PATCH] `wrappers.vector.NumpyToTorch` uses `Device` from `wrappers.NumpyTorch` not `JaxToTorch` (#1308) --- gymnasium/wrappers/numpy_to_torch.py | 2 +- gymnasium/wrappers/vector/numpy_to_torch.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/gymnasium/wrappers/numpy_to_torch.py b/gymnasium/wrappers/numpy_to_torch.py index 13715bd64..db3870611 100644 --- a/gymnasium/wrappers/numpy_to_torch.py +++ b/gymnasium/wrappers/numpy_to_torch.py @@ -24,7 +24,7 @@ except ImportError: ) -__all__ = ["NumpyToTorch", "torch_to_numpy", "numpy_to_torch"] +__all__ = ["NumpyToTorch", "torch_to_numpy", "numpy_to_torch", "Device"] # The NoneType is not defined in Python 3.9. Remove when the minimal version is bumped to >=3.10 _NoneType = type(None) diff --git a/gymnasium/wrappers/vector/numpy_to_torch.py b/gymnasium/wrappers/vector/numpy_to_torch.py index d9c3eaf12..b0e09aee2 100644 --- a/gymnasium/wrappers/vector/numpy_to_torch.py +++ b/gymnasium/wrappers/vector/numpy_to_torch.py @@ -7,8 +7,7 @@ from typing import Any from gymnasium.core import ActType, ObsType from gymnasium.vector import VectorEnv, VectorWrapper from gymnasium.vector.vector_env import ArrayType -from gymnasium.wrappers.jax_to_torch import Device -from gymnasium.wrappers.numpy_to_torch import numpy_to_torch, torch_to_numpy +from gymnasium.wrappers.numpy_to_torch import Device, numpy_to_torch, torch_to_numpy __all__ = ["NumpyToTorch"]