mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-25 15:59:06 +00:00
Fix non-used device argument in to_torch
conversion (#1107)
This commit is contained in:
@@ -119,9 +119,9 @@ def _jax_iterable_to_torch(
|
|||||||
if hasattr(value, "_make"):
|
if hasattr(value, "_make"):
|
||||||
# namedtuple - underline used to prevent potential name conflicts
|
# namedtuple - underline used to prevent potential name conflicts
|
||||||
# noinspection PyProtectedMember
|
# noinspection PyProtectedMember
|
||||||
return type(value)._make(jax_to_torch(v) for v in value)
|
return type(value)._make(jax_to_torch(v, device) for v in value)
|
||||||
else:
|
else:
|
||||||
return type(value)(jax_to_torch(v) for v in value)
|
return type(value)(jax_to_torch(v, device) for v in value)
|
||||||
|
|
||||||
|
|
||||||
class JaxToTorch(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
class JaxToTorch(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||||
|
@@ -94,9 +94,9 @@ def _numpy_iterable_to_torch(
|
|||||||
if hasattr(value, "_make"):
|
if hasattr(value, "_make"):
|
||||||
# namedtuple - underline used to prevent potential name conflicts
|
# namedtuple - underline used to prevent potential name conflicts
|
||||||
# noinspection PyProtectedMember
|
# noinspection PyProtectedMember
|
||||||
return type(value)._make(numpy_to_torch(v) for v in value)
|
return type(value)._make(numpy_to_torch(v, device) for v in value)
|
||||||
else:
|
else:
|
||||||
return type(value)(numpy_to_torch(v) for v in value)
|
return type(value)(numpy_to_torch(v, device) for v in value)
|
||||||
|
|
||||||
|
|
||||||
class NumpyToTorch(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
class NumpyToTorch(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||||
|
Reference in New Issue
Block a user