mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-01 06:07:08 +00:00
Properly hide additional requirements between separate imports (#323)
Co-authored-by: Mark Towers <mark.m.towers@gmail.com>
This commit is contained in:
committed by
GitHub
parent
f6d41e85f9
commit
0bb8b0cc73
@@ -17,7 +17,7 @@ from gymnasium.envs.registration import (
|
||||
pprint_registry,
|
||||
make_vec,
|
||||
)
|
||||
from gymnasium import envs, spaces, utils, vector, wrappers, error, logger, experimental
|
||||
from gymnasium import envs, spaces, utils, vector, wrappers, error, logger
|
||||
|
||||
|
||||
__all__ = [
|
||||
@@ -43,7 +43,6 @@ __all__ = [
|
||||
"wrappers",
|
||||
"error",
|
||||
"logger",
|
||||
"experimental",
|
||||
]
|
||||
__version__ = "0.27.1"
|
||||
|
||||
|
@@ -12,6 +12,7 @@ import gymnasium as gym
|
||||
from gymnasium import logger, spaces
|
||||
from gymnasium.envs.classic_control import utils
|
||||
from gymnasium.error import DependencyNotInstalled
|
||||
from gymnasium.experimental.vector import VectorEnv
|
||||
from gymnasium.vector.utils import batch_space
|
||||
|
||||
|
||||
@@ -315,7 +316,7 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
|
||||
self.isopen = False
|
||||
|
||||
|
||||
class CartPoleVectorEnv(gym.experimental.VectorEnv):
|
||||
class CartPoleVectorEnv(VectorEnv):
|
||||
metadata = {
|
||||
"render_modes": ["human", "rgb_array"],
|
||||
"render_fps": 50,
|
||||
|
@@ -13,8 +13,8 @@ from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Iterable, Sequence
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium import Env, Wrapper, error, logger
|
||||
from gymnasium.experimental.vector import AsyncVectorEnv, SyncVectorEnv, VectorEnv
|
||||
from gymnasium.wrappers import (
|
||||
AutoResetWrapper,
|
||||
HumanRendering,
|
||||
@@ -63,7 +63,7 @@ class EnvCreator(Protocol):
|
||||
class VectorEnvCreator(Protocol):
|
||||
"""Function type expected for an environment."""
|
||||
|
||||
def __call__(self, **kwargs: Any) -> gym.experimental.VectorEnv:
|
||||
def __call__(self, **kwargs: Any) -> VectorEnv:
|
||||
...
|
||||
|
||||
|
||||
@@ -695,7 +695,7 @@ def make_vec(
|
||||
vector_kwargs: dict[str, Any] | None = None,
|
||||
wrappers: Sequence[Callable[[Env], Wrapper]] | None = None,
|
||||
**kwargs,
|
||||
) -> gym.experimental.VectorEnv:
|
||||
) -> VectorEnv:
|
||||
"""Create a vector environment according to the given ID.
|
||||
|
||||
Note:
|
||||
@@ -778,12 +778,12 @@ def make_vec(
|
||||
return _env
|
||||
|
||||
if vectorization_mode == "sync":
|
||||
env = gym.experimental.SyncVectorEnv(
|
||||
env = SyncVectorEnv(
|
||||
env_fns=[_create_env for _ in range(num_envs)],
|
||||
**vector_kwargs,
|
||||
)
|
||||
elif vectorization_mode == "async":
|
||||
env = gym.experimental.AsyncVectorEnv(
|
||||
env = AsyncVectorEnv(
|
||||
env_fns=[_create_env for _ in range(num_envs)],
|
||||
**vector_kwargs,
|
||||
)
|
||||
|
@@ -1,7 +1,7 @@
|
||||
"""Root __init__ of the gym experimental wrappers."""
|
||||
|
||||
|
||||
from gymnasium.experimental import functional, wrappers
|
||||
from gymnasium.experimental import functional
|
||||
from gymnasium.experimental.functional import FuncEnv
|
||||
from gymnasium.experimental.vector.async_vector_env import AsyncVectorEnv
|
||||
from gymnasium.experimental.vector.sync_vector_env import SyncVectorEnv
|
||||
@@ -12,12 +12,9 @@ __all__ = [
|
||||
# Functional
|
||||
"FuncEnv",
|
||||
"functional",
|
||||
# Wrappers
|
||||
"wrappers",
|
||||
# Vector
|
||||
"VectorEnv",
|
||||
"VectorWrapper",
|
||||
"SyncVectorEnv",
|
||||
"AsyncVectorEnv",
|
||||
# "vector",
|
||||
]
|
||||
|
@@ -11,7 +11,7 @@ import numpy as np
|
||||
import gymnasium as gym
|
||||
from gymnasium.envs.registration import EnvSpec
|
||||
from gymnasium.experimental.functional import ActType, FuncEnv, StateType
|
||||
from gymnasium.experimental.wrappers.jax_to_numpy import jax_to_numpy
|
||||
from gymnasium.experimental.wrappers.conversion.jax_to_numpy import jax_to_numpy
|
||||
from gymnasium.utils import seeding
|
||||
from gymnasium.vector.utils import batch_space
|
||||
|
||||
|
@@ -23,9 +23,6 @@ from gymnasium.experimental.wrappers.lambda_reward import (
|
||||
LambdaRewardV0,
|
||||
NormalizeRewardV0,
|
||||
)
|
||||
from gymnasium.experimental.wrappers.jax_to_numpy import JaxToNumpyV0
|
||||
from gymnasium.experimental.wrappers.jax_to_torch import JaxToTorchV0
|
||||
from gymnasium.experimental.wrappers.numpy_to_torch import NumpyToTorchV0
|
||||
from gymnasium.experimental.wrappers.stateful_action import StickyActionV0
|
||||
from gymnasium.experimental.wrappers.stateful_observation import (
|
||||
TimeAwareObservationV0,
|
||||
@@ -85,10 +82,6 @@ __all__ = [
|
||||
"RenderCollectionV0",
|
||||
"RecordVideoV0",
|
||||
"HumanRenderingV0",
|
||||
# --- Data Conversion ---
|
||||
"JaxToNumpyV0",
|
||||
"JaxToTorchV0",
|
||||
"NumpyToTorchV0",
|
||||
# --- Vector ---
|
||||
"VectorRecordEpisodeStatistics",
|
||||
"VectorListInfo",
|
||||
|
1
gymnasium/experimental/wrappers/conversion/__init__.py
Normal file
1
gymnasium/experimental/wrappers/conversion/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""This is deliberately empty to avoid introducing redundant imports -- import each submodule individually."""
|
@@ -17,7 +17,7 @@ from typing import Any, Iterable, Mapping, SupportsFloat, Union
|
||||
from gymnasium import Env, Wrapper
|
||||
from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType
|
||||
from gymnasium.error import DependencyNotInstalled
|
||||
from gymnasium.experimental.wrappers.jax_to_numpy import jax_to_numpy
|
||||
from gymnasium.experimental.wrappers.conversion.jax_to_numpy import jax_to_numpy
|
||||
|
||||
|
||||
try:
|
@@ -4,8 +4,11 @@ import jax.numpy as jnp
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from gymnasium.experimental.wrappers import JaxToNumpyV0
|
||||
from gymnasium.experimental.wrappers.jax_to_numpy import jax_to_numpy, numpy_to_jax
|
||||
from gymnasium.experimental.wrappers.conversion.jax_to_numpy import (
|
||||
JaxToNumpyV0,
|
||||
jax_to_numpy,
|
||||
numpy_to_jax,
|
||||
)
|
||||
from gymnasium.utils.env_checker import data_equivalence
|
||||
from tests.testing_env import GenericTestEnv
|
||||
|
||||
|
@@ -5,8 +5,11 @@ import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from gymnasium.experimental.wrappers import JaxToTorchV0
|
||||
from gymnasium.experimental.wrappers.jax_to_torch import jax_to_torch, torch_to_jax
|
||||
from gymnasium.experimental.wrappers.conversion.jax_to_torch import (
|
||||
JaxToTorchV0,
|
||||
jax_to_torch,
|
||||
torch_to_jax,
|
||||
)
|
||||
from tests.testing_env import GenericTestEnv
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user