documentation: replace "Jax" with "NumPy" and make capitalization uniform (#1206)

This commit is contained in:
enjoh
2024-10-12 17:52:30 +02:00
committed by GitHub
parent f3fb8a5891
commit d571ed6301

View File

@@ -37,25 +37,25 @@ def torch_to_numpy(value: Any) -> Any:
@torch_to_numpy.register(numbers.Number)
def _number_to_numpy(value: numbers.Number) -> Any:
"""Convert a python number (int, float, complex) to a numpy array."""
"""Convert a Python number (int, float, complex) to a NumPy array."""
return np.array(value)
@torch_to_numpy.register(torch.Tensor)
def _torch_to_numpy(value: torch.Tensor) -> Any:
"""Convert a torch.Tensor to a numpy array."""
"""Convert a torch.Tensor to a NumPy array."""
return value.numpy(force=True)
@torch_to_numpy.register(abc.Mapping)
def _mapping_torch_to_numpy(value: Mapping[str, Any]) -> Mapping[str, Any]:
"""Converts a mapping of PyTorch Tensors into a Dictionary of Jax Array."""
"""Converts a mapping of PyTorch Tensors into a dictionary of NumPy arrays."""
return type(value)(**{k: torch_to_numpy(v) for k, v in value.items()})
@torch_to_numpy.register(abc.Iterable)
def _iterable_torch_to_numpy(value: Iterable[Any]) -> Iterable[Any]:
"""Converts an Iterable from PyTorch Tensors to an iterable of Jax Array."""
"""Converts an iterable of PyTorch Tensors to an iterable of NumPy arrays."""
if hasattr(value, "_make"):
# namedtuple - underline used to prevent potential name conflicts
# noinspection PyProtectedMember
@@ -66,7 +66,7 @@ def _iterable_torch_to_numpy(value: Iterable[Any]) -> Iterable[Any]:
@functools.singledispatch
def numpy_to_torch(value: Any, device: Device | None = None) -> Any:
"""Converts a Jax Array into a PyTorch Tensor."""
"""Converts a NumPy array into a PyTorch Tensor."""
raise Exception(
f"No known conversion for NumPy type ({type(value)}) to PyTorch registered. Report as an issue on GitHub."
)
@@ -75,7 +75,7 @@ def numpy_to_torch(value: Any, device: Device | None = None) -> Any:
@numpy_to_torch.register(numbers.Number)
@numpy_to_torch.register(np.ndarray)
def _numpy_to_torch(value: np.ndarray, device: Device | None = None) -> torch.Tensor:
"""Converts a Jax Array into a PyTorch Tensor."""
"""Converts a NumPy array into a PyTorch Tensor."""
assert torch is not None
tensor = torch.tensor(value)
if device:
@@ -87,7 +87,7 @@ def _numpy_to_torch(value: np.ndarray, device: Device | None = None) -> torch.Te
def _numpy_mapping_to_torch(
value: Mapping[str, Any], device: Device | None = None
) -> Mapping[str, Any]:
"""Converts a mapping of Jax Array into a Dictionary of PyTorch Tensors."""
"""Converts a mapping of NumPy arrays into a dictionary of PyTorch Tensors."""
return type(value)(**{k: numpy_to_torch(v, device) for k, v in value.items()})
@@ -95,7 +95,7 @@ def _numpy_mapping_to_torch(
def _numpy_iterable_to_torch(
value: Iterable[Any], device: Device | None = None
) -> Iterable[Any]:
"""Converts an Iterable from Jax Array to an iterable of PyTorch Tensors."""
"""Converts an iterable of NumPy arrays to an iterable of PyTorch Tensors."""
if hasattr(value, "_make"):
# namedtuple - underline used to prevent potential name conflicts
# noinspection PyProtectedMember
@@ -140,7 +140,7 @@ class NumpyToTorch(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""Wrapper class to change inputs and outputs of environment to PyTorch tensors.
Args:
env: The Jax-based environment to wrap
env: The NumPy-based environment to wrap
device: The device the torch Tensors should be moved to
"""
gym.utils.RecordConstructorArgs.__init__(self, device=device)