diff --git a/gymnasium/wrappers/vector/numpy_to_torch.py b/gymnasium/wrappers/vector/numpy_to_torch.py index b0e09aee2..8ae7e2728 100644 --- a/gymnasium/wrappers/vector/numpy_to_torch.py +++ b/gymnasium/wrappers/vector/numpy_to_torch.py @@ -42,7 +42,7 @@ class NumpyToTorch(VectorWrapper): """Wrapper class to change inputs and outputs of environment to PyTorch tensors. Args: - env: The Jax-based vector environment to wrap + env: The NumPy-based vector environment to wrap device: The device the torch Tensors should be moved to """ super().__init__(env) @@ -60,8 +60,8 @@ class NumpyToTorch(VectorWrapper): Returns: The PyTorch-based Tensor next observation, reward, termination, truncation, and extra info """ - jax_action = torch_to_numpy(actions) - obs, reward, terminated, truncated, info = self.env.step(jax_action) + numpy_action = torch_to_numpy(actions) + obs, reward, terminated, truncated, info = self.env.step(numpy_action) return ( numpy_to_torch(obs, self.device), @@ -81,7 +81,7 @@ class NumpyToTorch(VectorWrapper): Args: seed: The seed for resetting the environment - options: The options for resetting the environment, these are converted to jax arrays. + options: The options for resetting the environment, these are converted to NumPy arrays. Returns: PyTorch-based observations and info