mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-21 06:20:15 +00:00
Properly hide additional requirements between separate imports (#323)
Co-authored-by: Mark Towers <mark.m.towers@gmail.com>
This commit is contained in:
committed by
GitHub
parent
f6d41e85f9
commit
0bb8b0cc73
@@ -17,7 +17,7 @@ from gymnasium.envs.registration import (
|
|||||||
pprint_registry,
|
pprint_registry,
|
||||||
make_vec,
|
make_vec,
|
||||||
)
|
)
|
||||||
from gymnasium import envs, spaces, utils, vector, wrappers, error, logger, experimental
|
from gymnasium import envs, spaces, utils, vector, wrappers, error, logger
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@@ -43,7 +43,6 @@ __all__ = [
|
|||||||
"wrappers",
|
"wrappers",
|
||||||
"error",
|
"error",
|
||||||
"logger",
|
"logger",
|
||||||
"experimental",
|
|
||||||
]
|
]
|
||||||
__version__ = "0.27.1"
|
__version__ = "0.27.1"
|
||||||
|
|
||||||
|
@@ -12,6 +12,7 @@ import gymnasium as gym
|
|||||||
from gymnasium import logger, spaces
|
from gymnasium import logger, spaces
|
||||||
from gymnasium.envs.classic_control import utils
|
from gymnasium.envs.classic_control import utils
|
||||||
from gymnasium.error import DependencyNotInstalled
|
from gymnasium.error import DependencyNotInstalled
|
||||||
|
from gymnasium.experimental.vector import VectorEnv
|
||||||
from gymnasium.vector.utils import batch_space
|
from gymnasium.vector.utils import batch_space
|
||||||
|
|
||||||
|
|
||||||
@@ -315,7 +316,7 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
|
|||||||
self.isopen = False
|
self.isopen = False
|
||||||
|
|
||||||
|
|
||||||
class CartPoleVectorEnv(gym.experimental.VectorEnv):
|
class CartPoleVectorEnv(VectorEnv):
|
||||||
metadata = {
|
metadata = {
|
||||||
"render_modes": ["human", "rgb_array"],
|
"render_modes": ["human", "rgb_array"],
|
||||||
"render_fps": 50,
|
"render_fps": 50,
|
||||||
|
@@ -13,8 +13,8 @@ from collections import defaultdict
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Callable, Iterable, Sequence
|
from typing import Any, Callable, Iterable, Sequence
|
||||||
|
|
||||||
import gymnasium as gym
|
|
||||||
from gymnasium import Env, Wrapper, error, logger
|
from gymnasium import Env, Wrapper, error, logger
|
||||||
|
from gymnasium.experimental.vector import AsyncVectorEnv, SyncVectorEnv, VectorEnv
|
||||||
from gymnasium.wrappers import (
|
from gymnasium.wrappers import (
|
||||||
AutoResetWrapper,
|
AutoResetWrapper,
|
||||||
HumanRendering,
|
HumanRendering,
|
||||||
@@ -63,7 +63,7 @@ class EnvCreator(Protocol):
|
|||||||
class VectorEnvCreator(Protocol):
|
class VectorEnvCreator(Protocol):
|
||||||
"""Function type expected for an environment."""
|
"""Function type expected for an environment."""
|
||||||
|
|
||||||
def __call__(self, **kwargs: Any) -> gym.experimental.VectorEnv:
|
def __call__(self, **kwargs: Any) -> VectorEnv:
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
@@ -695,7 +695,7 @@ def make_vec(
|
|||||||
vector_kwargs: dict[str, Any] | None = None,
|
vector_kwargs: dict[str, Any] | None = None,
|
||||||
wrappers: Sequence[Callable[[Env], Wrapper]] | None = None,
|
wrappers: Sequence[Callable[[Env], Wrapper]] | None = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> gym.experimental.VectorEnv:
|
) -> VectorEnv:
|
||||||
"""Create a vector environment according to the given ID.
|
"""Create a vector environment according to the given ID.
|
||||||
|
|
||||||
Note:
|
Note:
|
||||||
@@ -778,12 +778,12 @@ def make_vec(
|
|||||||
return _env
|
return _env
|
||||||
|
|
||||||
if vectorization_mode == "sync":
|
if vectorization_mode == "sync":
|
||||||
env = gym.experimental.SyncVectorEnv(
|
env = SyncVectorEnv(
|
||||||
env_fns=[_create_env for _ in range(num_envs)],
|
env_fns=[_create_env for _ in range(num_envs)],
|
||||||
**vector_kwargs,
|
**vector_kwargs,
|
||||||
)
|
)
|
||||||
elif vectorization_mode == "async":
|
elif vectorization_mode == "async":
|
||||||
env = gym.experimental.AsyncVectorEnv(
|
env = AsyncVectorEnv(
|
||||||
env_fns=[_create_env for _ in range(num_envs)],
|
env_fns=[_create_env for _ in range(num_envs)],
|
||||||
**vector_kwargs,
|
**vector_kwargs,
|
||||||
)
|
)
|
||||||
|
@@ -1,7 +1,7 @@
|
|||||||
"""Root __init__ of the gym experimental wrappers."""
|
"""Root __init__ of the gym experimental wrappers."""
|
||||||
|
|
||||||
|
|
||||||
from gymnasium.experimental import functional, wrappers
|
from gymnasium.experimental import functional
|
||||||
from gymnasium.experimental.functional import FuncEnv
|
from gymnasium.experimental.functional import FuncEnv
|
||||||
from gymnasium.experimental.vector.async_vector_env import AsyncVectorEnv
|
from gymnasium.experimental.vector.async_vector_env import AsyncVectorEnv
|
||||||
from gymnasium.experimental.vector.sync_vector_env import SyncVectorEnv
|
from gymnasium.experimental.vector.sync_vector_env import SyncVectorEnv
|
||||||
@@ -12,12 +12,9 @@ __all__ = [
|
|||||||
# Functional
|
# Functional
|
||||||
"FuncEnv",
|
"FuncEnv",
|
||||||
"functional",
|
"functional",
|
||||||
# Wrappers
|
|
||||||
"wrappers",
|
|
||||||
# Vector
|
# Vector
|
||||||
"VectorEnv",
|
"VectorEnv",
|
||||||
"VectorWrapper",
|
"VectorWrapper",
|
||||||
"SyncVectorEnv",
|
"SyncVectorEnv",
|
||||||
"AsyncVectorEnv",
|
"AsyncVectorEnv",
|
||||||
# "vector",
|
|
||||||
]
|
]
|
||||||
|
@@ -11,7 +11,7 @@ import numpy as np
|
|||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
from gymnasium.envs.registration import EnvSpec
|
from gymnasium.envs.registration import EnvSpec
|
||||||
from gymnasium.experimental.functional import ActType, FuncEnv, StateType
|
from gymnasium.experimental.functional import ActType, FuncEnv, StateType
|
||||||
from gymnasium.experimental.wrappers.jax_to_numpy import jax_to_numpy
|
from gymnasium.experimental.wrappers.conversion.jax_to_numpy import jax_to_numpy
|
||||||
from gymnasium.utils import seeding
|
from gymnasium.utils import seeding
|
||||||
from gymnasium.vector.utils import batch_space
|
from gymnasium.vector.utils import batch_space
|
||||||
|
|
||||||
|
@@ -23,9 +23,6 @@ from gymnasium.experimental.wrappers.lambda_reward import (
|
|||||||
LambdaRewardV0,
|
LambdaRewardV0,
|
||||||
NormalizeRewardV0,
|
NormalizeRewardV0,
|
||||||
)
|
)
|
||||||
from gymnasium.experimental.wrappers.jax_to_numpy import JaxToNumpyV0
|
|
||||||
from gymnasium.experimental.wrappers.jax_to_torch import JaxToTorchV0
|
|
||||||
from gymnasium.experimental.wrappers.numpy_to_torch import NumpyToTorchV0
|
|
||||||
from gymnasium.experimental.wrappers.stateful_action import StickyActionV0
|
from gymnasium.experimental.wrappers.stateful_action import StickyActionV0
|
||||||
from gymnasium.experimental.wrappers.stateful_observation import (
|
from gymnasium.experimental.wrappers.stateful_observation import (
|
||||||
TimeAwareObservationV0,
|
TimeAwareObservationV0,
|
||||||
@@ -85,10 +82,6 @@ __all__ = [
|
|||||||
"RenderCollectionV0",
|
"RenderCollectionV0",
|
||||||
"RecordVideoV0",
|
"RecordVideoV0",
|
||||||
"HumanRenderingV0",
|
"HumanRenderingV0",
|
||||||
# --- Data Conversion ---
|
|
||||||
"JaxToNumpyV0",
|
|
||||||
"JaxToTorchV0",
|
|
||||||
"NumpyToTorchV0",
|
|
||||||
# --- Vector ---
|
# --- Vector ---
|
||||||
"VectorRecordEpisodeStatistics",
|
"VectorRecordEpisodeStatistics",
|
||||||
"VectorListInfo",
|
"VectorListInfo",
|
||||||
|
1
gymnasium/experimental/wrappers/conversion/__init__.py
Normal file
1
gymnasium/experimental/wrappers/conversion/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""This is deliberately empty to avoid introducing redundant imports -- import each submodule individually."""
|
@@ -17,7 +17,7 @@ from typing import Any, Iterable, Mapping, SupportsFloat, Union
|
|||||||
from gymnasium import Env, Wrapper
|
from gymnasium import Env, Wrapper
|
||||||
from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType
|
from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType
|
||||||
from gymnasium.error import DependencyNotInstalled
|
from gymnasium.error import DependencyNotInstalled
|
||||||
from gymnasium.experimental.wrappers.jax_to_numpy import jax_to_numpy
|
from gymnasium.experimental.wrappers.conversion.jax_to_numpy import jax_to_numpy
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
@@ -4,8 +4,11 @@ import jax.numpy as jnp
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from gymnasium.experimental.wrappers import JaxToNumpyV0
|
from gymnasium.experimental.wrappers.conversion.jax_to_numpy import (
|
||||||
from gymnasium.experimental.wrappers.jax_to_numpy import jax_to_numpy, numpy_to_jax
|
JaxToNumpyV0,
|
||||||
|
jax_to_numpy,
|
||||||
|
numpy_to_jax,
|
||||||
|
)
|
||||||
from gymnasium.utils.env_checker import data_equivalence
|
from gymnasium.utils.env_checker import data_equivalence
|
||||||
from tests.testing_env import GenericTestEnv
|
from tests.testing_env import GenericTestEnv
|
||||||
|
|
||||||
|
@@ -5,8 +5,11 @@ import numpy as np
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from gymnasium.experimental.wrappers import JaxToTorchV0
|
from gymnasium.experimental.wrappers.conversion.jax_to_torch import (
|
||||||
from gymnasium.experimental.wrappers.jax_to_torch import jax_to_torch, torch_to_jax
|
JaxToTorchV0,
|
||||||
|
jax_to_torch,
|
||||||
|
torch_to_jax,
|
||||||
|
)
|
||||||
from tests.testing_env import GenericTestEnv
|
from tests.testing_env import GenericTestEnv
|
||||||
|
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user