mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-07-31 22:04:31 +00:00
Refer to numpy instead of jax [for vars and docs] in vector NumpyToTorch (#1319)
This commit is contained in:
@@ -42,7 +42,7 @@ class NumpyToTorch(VectorWrapper):
|
||||
"""Wrapper class to change inputs and outputs of environment to PyTorch tensors.
|
||||
|
||||
Args:
|
||||
env: The Jax-based vector environment to wrap
|
||||
env: The NumPy-based vector environment to wrap
|
||||
device: The device the torch Tensors should be moved to
|
||||
"""
|
||||
super().__init__(env)
|
||||
@@ -60,8 +60,8 @@ class NumpyToTorch(VectorWrapper):
|
||||
Returns:
|
||||
The PyTorch-based Tensor next observation, reward, termination, truncation, and extra info
|
||||
"""
|
||||
jax_action = torch_to_numpy(actions)
|
||||
obs, reward, terminated, truncated, info = self.env.step(jax_action)
|
||||
numpy_action = torch_to_numpy(actions)
|
||||
obs, reward, terminated, truncated, info = self.env.step(numpy_action)
|
||||
|
||||
return (
|
||||
numpy_to_torch(obs, self.device),
|
||||
@@ -81,7 +81,7 @@ class NumpyToTorch(VectorWrapper):
|
||||
|
||||
Args:
|
||||
seed: The seed for resetting the environment
|
||||
options: The options for resetting the environment, these are converted to jax arrays.
|
||||
options: The options for resetting the environment, these are converted to NumPy arrays.
|
||||
|
||||
Returns:
|
||||
PyTorch-based observations and info
|
||||
|
Reference in New Issue
Block a user