Refer to numpy instead of jax [for vars and docs] in vector NumpyToTorch (#1319)

This commit is contained in:
Petr Kuderov
2025-03-04 16:28:47 +03:00
committed by GitHub
parent 9ff8bf45dd
commit 583f6a78d6

View File

@@ -42,7 +42,7 @@ class NumpyToTorch(VectorWrapper):
"""Wrapper class to change inputs and outputs of environment to PyTorch tensors. """Wrapper class to change inputs and outputs of environment to PyTorch tensors.
Args: Args:
env: The Jax-based vector environment to wrap env: The NumPy-based vector environment to wrap
device: The device the torch Tensors should be moved to device: The device the torch Tensors should be moved to
""" """
super().__init__(env) super().__init__(env)
@@ -60,8 +60,8 @@ class NumpyToTorch(VectorWrapper):
Returns: Returns:
The PyTorch-based Tensor next observation, reward, termination, truncation, and extra info The PyTorch-based Tensor next observation, reward, termination, truncation, and extra info
""" """
jax_action = torch_to_numpy(actions) numpy_action = torch_to_numpy(actions)
obs, reward, terminated, truncated, info = self.env.step(jax_action) obs, reward, terminated, truncated, info = self.env.step(numpy_action)
return ( return (
numpy_to_torch(obs, self.device), numpy_to_torch(obs, self.device),
@@ -81,7 +81,7 @@ class NumpyToTorch(VectorWrapper):
Args: Args:
seed: The seed for resetting the environment seed: The seed for resetting the environment
options: The options for resetting the environment, these are converted to jax arrays. options: The options for resetting the environment, these are converted to NumPy arrays.
Returns: Returns:
PyTorch-based observations and info PyTorch-based observations and info