Refer to numpy instead of jax [for vars and docs] in vector NumpyToTorch (#1319)

This commit is contained in:
Petr Kuderov
2025-03-04 16:28:47 +03:00
committed by GitHub
parent 9ff8bf45dd
commit 583f6a78d6

View File

@@ -42,7 +42,7 @@ class NumpyToTorch(VectorWrapper):
"""Wrapper class to change inputs and outputs of environment to PyTorch tensors.
Args:
env: The Jax-based vector environment to wrap
env: The NumPy-based vector environment to wrap
device: The device the torch Tensors should be moved to
"""
super().__init__(env)
@@ -60,8 +60,8 @@ class NumpyToTorch(VectorWrapper):
Returns:
The PyTorch-based Tensor next observation, reward, termination, truncation, and extra info
"""
jax_action = torch_to_numpy(actions)
obs, reward, terminated, truncated, info = self.env.step(jax_action)
numpy_action = torch_to_numpy(actions)
obs, reward, terminated, truncated, info = self.env.step(numpy_action)
return (
numpy_to_torch(obs, self.device),
@@ -81,7 +81,7 @@ class NumpyToTorch(VectorWrapper):
Args:
seed: The seed for resetting the environment
options: The options for resetting the environment, these are converted to jax arrays.
options: The options for resetting the environment, these are converted to NumPy arrays.
Returns:
PyTorch-based observations and info