mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-29 01:27:29 +00:00
documentation: replace "Jax" with "NumPy" and make capitalization uniform (#1206)
This commit is contained in:
@@ -37,25 +37,25 @@ def torch_to_numpy(value: Any) -> Any:
|
||||
|
||||
@torch_to_numpy.register(numbers.Number)
|
||||
def _number_to_numpy(value: numbers.Number) -> Any:
|
||||
"""Convert a python number (int, float, complex) to a numpy array."""
|
||||
"""Convert a python number (int, float, complex) to a NumPy array."""
|
||||
return np.array(value)
|
||||
|
||||
|
||||
@torch_to_numpy.register(torch.Tensor)
|
||||
def _torch_to_numpy(value: torch.Tensor) -> Any:
|
||||
"""Convert a torch.Tensor to a numpy array."""
|
||||
"""Convert a torch.Tensor to a NumPy array."""
|
||||
return value.numpy(force=True)
|
||||
|
||||
|
||||
@torch_to_numpy.register(abc.Mapping)
|
||||
def _mapping_torch_to_numpy(value: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||
"""Converts a mapping of PyTorch Tensors into a Dictionary of Jax Array."""
|
||||
"""Converts a mapping of PyTorch Tensors into a Dictionary of NumPy Array."""
|
||||
return type(value)(**{k: torch_to_numpy(v) for k, v in value.items()})
|
||||
|
||||
|
||||
@torch_to_numpy.register(abc.Iterable)
|
||||
def _iterable_torch_to_numpy(value: Iterable[Any]) -> Iterable[Any]:
|
||||
"""Converts an Iterable from PyTorch Tensors to an iterable of Jax Array."""
|
||||
"""Converts an Iterable from PyTorch Tensors to an iterable of NumPy Array."""
|
||||
if hasattr(value, "_make"):
|
||||
# namedtuple - underline used to prevent potential name conflicts
|
||||
# noinspection PyProtectedMember
|
||||
@@ -66,7 +66,7 @@ def _iterable_torch_to_numpy(value: Iterable[Any]) -> Iterable[Any]:
|
||||
|
||||
@functools.singledispatch
|
||||
def numpy_to_torch(value: Any, device: Device | None = None) -> Any:
|
||||
"""Converts a Jax Array into a PyTorch Tensor."""
|
||||
"""Converts a NumPy Array into a PyTorch Tensor."""
|
||||
raise Exception(
|
||||
f"No known conversion for NumPy type ({type(value)}) to PyTorch registered. Report as issue on github."
|
||||
)
|
||||
@@ -75,7 +75,7 @@ def numpy_to_torch(value: Any, device: Device | None = None) -> Any:
|
||||
@numpy_to_torch.register(numbers.Number)
|
||||
@numpy_to_torch.register(np.ndarray)
|
||||
def _numpy_to_torch(value: np.ndarray, device: Device | None = None) -> torch.Tensor:
|
||||
"""Converts a Jax Array into a PyTorch Tensor."""
|
||||
"""Converts a NumPy Array into a PyTorch Tensor."""
|
||||
assert torch is not None
|
||||
tensor = torch.tensor(value)
|
||||
if device:
|
||||
@@ -87,7 +87,7 @@ def _numpy_to_torch(value: np.ndarray, device: Device | None = None) -> torch.Te
|
||||
def _numpy_mapping_to_torch(
|
||||
value: Mapping[str, Any], device: Device | None = None
|
||||
) -> Mapping[str, Any]:
|
||||
"""Converts a mapping of Jax Array into a Dictionary of PyTorch Tensors."""
|
||||
"""Converts a mapping of NumPy Array into a Dictionary of PyTorch Tensors."""
|
||||
return type(value)(**{k: numpy_to_torch(v, device) for k, v in value.items()})
|
||||
|
||||
|
||||
@@ -95,7 +95,7 @@ def _numpy_mapping_to_torch(
|
||||
def _numpy_iterable_to_torch(
|
||||
value: Iterable[Any], device: Device | None = None
|
||||
) -> Iterable[Any]:
|
||||
"""Converts an Iterable from Jax Array to an iterable of PyTorch Tensors."""
|
||||
"""Converts an Iterable from NumPy Array to an iterable of PyTorch Tensors."""
|
||||
if hasattr(value, "_make"):
|
||||
# namedtuple - underline used to prevent potential name conflicts
|
||||
# noinspection PyProtectedMember
|
||||
@@ -140,7 +140,7 @@ class NumpyToTorch(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||
"""Wrapper class to change inputs and outputs of environment to PyTorch tensors.
|
||||
|
||||
Args:
|
||||
env: The Jax-based environment to wrap
|
||||
env: The NumPy-based environment to wrap
|
||||
device: The device the torch Tensors should be moved to
|
||||
"""
|
||||
gym.utils.RecordConstructorArgs.__init__(self, device=device)
|
||||
|
Reference in New Issue
Block a user