wrappers.vector.NumpyToTorch uses Device from wrappers.NumpyTorch not JaxToTorch (#1308)

This commit is contained in:
Mark Towers
2025-02-13 22:45:30 +00:00
committed by GitHub
parent 2a8441d439
commit a8946325e8
2 changed files with 2 additions and 3 deletions

View File

@@ -24,7 +24,7 @@ except ImportError:
)
__all__ = ["NumpyToTorch", "torch_to_numpy", "numpy_to_torch"]
__all__ = ["NumpyToTorch", "torch_to_numpy", "numpy_to_torch", "Device"]
# The NoneType is not defined in Python 3.9. Remove when the minimal version is bumped to >=3.10
_NoneType = type(None)

View File

@@ -7,8 +7,7 @@ from typing import Any
from gymnasium.core import ActType, ObsType
from gymnasium.vector import VectorEnv, VectorWrapper
from gymnasium.vector.vector_env import ArrayType
from gymnasium.wrappers.jax_to_torch import Device
from gymnasium.wrappers.numpy_to_torch import numpy_to_torch, torch_to_numpy
from gymnasium.wrappers.numpy_to_torch import Device, numpy_to_torch, torch_to_numpy
__all__ = ["NumpyToTorch"]