mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-18 12:57:38 +00:00
Refer to numpy instead of jax [for vars and docs] in vector NumpyToTorch (#1319)
This commit is contained in:
@@ -42,7 +42,7 @@ class NumpyToTorch(VectorWrapper):
|
|||||||
"""Wrapper class to change inputs and outputs of environment to PyTorch tensors.
|
"""Wrapper class to change inputs and outputs of environment to PyTorch tensors.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
env: The Jax-based vector environment to wrap
|
env: The NumPy-based vector environment to wrap
|
||||||
device: The device the torch Tensors should be moved to
|
device: The device the torch Tensors should be moved to
|
||||||
"""
|
"""
|
||||||
super().__init__(env)
|
super().__init__(env)
|
||||||
@@ -60,8 +60,8 @@ class NumpyToTorch(VectorWrapper):
|
|||||||
Returns:
|
Returns:
|
||||||
The PyTorch-based Tensor next observation, reward, termination, truncation, and extra info
|
The PyTorch-based Tensor next observation, reward, termination, truncation, and extra info
|
||||||
"""
|
"""
|
||||||
jax_action = torch_to_numpy(actions)
|
numpy_action = torch_to_numpy(actions)
|
||||||
obs, reward, terminated, truncated, info = self.env.step(jax_action)
|
obs, reward, terminated, truncated, info = self.env.step(numpy_action)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
numpy_to_torch(obs, self.device),
|
numpy_to_torch(obs, self.device),
|
||||||
@@ -81,7 +81,7 @@ class NumpyToTorch(VectorWrapper):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
seed: The seed for resetting the environment
|
seed: The seed for resetting the environment
|
||||||
options: The options for resetting the environment, these are converted to jax arrays.
|
options: The options for resetting the environment, these are converted to NumPy arrays.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
PyTorch-based observations and info
|
PyTorch-based observations and info
|
||||||
|
Reference in New Issue
Block a user