diff --git a/gymnasium/wrappers/numpy_to_torch.py b/gymnasium/wrappers/numpy_to_torch.py index 13715bd64..db3870611 100644 --- a/gymnasium/wrappers/numpy_to_torch.py +++ b/gymnasium/wrappers/numpy_to_torch.py @@ -24,7 +24,7 @@ except ImportError: ) -__all__ = ["NumpyToTorch", "torch_to_numpy", "numpy_to_torch"] +__all__ = ["NumpyToTorch", "torch_to_numpy", "numpy_to_torch", "Device"] # The NoneType is not defined in Python 3.9. Remove when the minimal version is bumped to >=3.10 _NoneType = type(None) diff --git a/gymnasium/wrappers/vector/numpy_to_torch.py b/gymnasium/wrappers/vector/numpy_to_torch.py index d9c3eaf12..b0e09aee2 100644 --- a/gymnasium/wrappers/vector/numpy_to_torch.py +++ b/gymnasium/wrappers/vector/numpy_to_torch.py @@ -7,8 +7,7 @@ from typing import Any from gymnasium.core import ActType, ObsType from gymnasium.vector import VectorEnv, VectorWrapper from gymnasium.vector.vector_env import ArrayType -from gymnasium.wrappers.jax_to_torch import Device -from gymnasium.wrappers.numpy_to_torch import numpy_to_torch, torch_to_numpy +from gymnasium.wrappers.numpy_to_torch import Device, numpy_to_torch, torch_to_numpy __all__ = ["NumpyToTorch"]