diff --git a/gymnasium/wrappers/jax_to_torch.py b/gymnasium/wrappers/jax_to_torch.py index 563160180..fb13d37eb 100644 --- a/gymnasium/wrappers/jax_to_torch.py +++ b/gymnasium/wrappers/jax_to_torch.py @@ -119,9 +119,9 @@ def _jax_iterable_to_torch( if hasattr(value, "_make"): # namedtuple - underline used to prevent potential name conflicts # noinspection PyProtectedMember - return type(value)._make(jax_to_torch(v) for v in value) + return type(value)._make(jax_to_torch(v, device) for v in value) else: - return type(value)(jax_to_torch(v) for v in value) + return type(value)(jax_to_torch(v, device) for v in value) class JaxToTorch(gym.Wrapper, gym.utils.RecordConstructorArgs): diff --git a/gymnasium/wrappers/numpy_to_torch.py b/gymnasium/wrappers/numpy_to_torch.py index 9743d387e..ac2581beb 100644 --- a/gymnasium/wrappers/numpy_to_torch.py +++ b/gymnasium/wrappers/numpy_to_torch.py @@ -94,9 +94,9 @@ def _numpy_iterable_to_torch( if hasattr(value, "_make"): # namedtuple - underline used to prevent potential name conflicts # noinspection PyProtectedMember - return type(value)._make(numpy_to_torch(v) for v in value) + return type(value)._make(numpy_to_torch(v, device) for v in value) else: - return type(value)(numpy_to_torch(v) for v in value) + return type(value)(numpy_to_torch(v, device) for v in value) class NumpyToTorch(gym.Wrapper, gym.utils.RecordConstructorArgs):