mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-21 06:20:15 +00:00
Skip Jax and PyTorch tests if module is missing (#466)
This commit is contained in:
@@ -1,11 +1,13 @@
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import jax.random as jrng
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from gymnasium.envs.phys2d.cartpole import CartPoleFunctional
|
||||
from gymnasium.envs.phys2d.pendulum import PendulumFunctional
|
||||
|
||||
jax = pytest.importorskip("jax")
|
||||
import jax.numpy as jnp # noqa: E402
|
||||
import jax.random as jrng # noqa: E402
|
||||
import numpy as np # noqa: E402
|
||||
|
||||
from gymnasium.envs.phys2d.cartpole import CartPoleFunctional # noqa: E402
|
||||
from gymnasium.envs.phys2d.pendulum import PendulumFunctional # noqa: E402
|
||||
|
||||
|
||||
@pytest.mark.parametrize("env_class", [CartPoleFunctional, PendulumFunctional])
|
||||
|
@@ -1,12 +1,14 @@
|
||||
"""Test the functional jax environment."""
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import jax.random as jrng
|
||||
import pytest
|
||||
|
||||
from gymnasium.envs.phys2d.cartpole import CartPoleFunctional
|
||||
from gymnasium.envs.phys2d.pendulum import PendulumFunctional
|
||||
|
||||
jax = pytest.importorskip("jax")
|
||||
import jax.numpy as jnp # noqa: E402
|
||||
import jax.random as jrng # noqa: E402
|
||||
|
||||
from gymnasium.envs.phys2d.cartpole import CartPoleFunctional # noqa: E402
|
||||
from gymnasium.envs.phys2d.pendulum import PendulumFunctional # noqa: E402
|
||||
|
||||
|
||||
@pytest.mark.parametrize("env_class", [CartPoleFunctional, PendulumFunctional])
|
||||
|
@@ -1,12 +1,14 @@
|
||||
"""Tests for Jax Blackjack functional env."""
|
||||
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import jax.random as jrng
|
||||
import pytest
|
||||
|
||||
from gymnasium.envs.tabular.blackjack import BlackjackFunctional
|
||||
|
||||
jax = pytest.importorskip("jax")
|
||||
import jax.numpy as jnp # noqa: E402
|
||||
import jax.random as jrng # noqa: E402
|
||||
|
||||
from gymnasium.envs.tabular.blackjack import BlackjackFunctional # noqa: E402
|
||||
|
||||
|
||||
def test_normal_BlackjackFunctional():
|
||||
|
@@ -1,12 +1,14 @@
|
||||
"""Tests for Jax cliffwalking functional env."""
|
||||
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import jax.random as jrng
|
||||
import pytest
|
||||
|
||||
from gymnasium.envs.tabular.cliffwalking import CliffWalkingFunctional
|
||||
|
||||
jax = pytest.importorskip("jax")
|
||||
import jax.numpy as jnp # noqa: E402
|
||||
import jax.random as jrng # noqa: E402
|
||||
|
||||
from gymnasium.envs.tabular.cliffwalking import CliffWalkingFunctional # noqa: E402
|
||||
|
||||
|
||||
def test_normal_CliffWalkingFunctional():
|
||||
|
@@ -10,7 +10,7 @@ from gymnasium.experimental.wrappers import __all__
|
||||
|
||||
|
||||
def test_all_wrapper_shorten():
|
||||
"""Test that all wrappers in `__alL__` are contained within the `_wrapper_to_class` conversion."""
|
||||
"""Test that all wrappers in `__all__` are contained within the `_wrapper_to_class` conversion."""
|
||||
all_wrappers = set(__all__)
|
||||
all_wrappers.remove("vector")
|
||||
assert all_wrappers == set(_wrapper_to_class.keys())
|
||||
@@ -18,6 +18,9 @@ def test_all_wrapper_shorten():
|
||||
|
||||
@pytest.mark.parametrize("wrapper_name", __all__)
|
||||
def test_all_wrappers_shortened(wrapper_name):
|
||||
"""Check that each element of the `__all__` wrappers can be loaded."""
|
||||
"""Check that each element of the `__all__` wrappers can be loaded, provided dependencies are installed."""
|
||||
if wrapper_name != "vector":
|
||||
assert getattr(gymnasium.experimental.wrappers, wrapper_name) is not None
|
||||
try:
|
||||
assert getattr(gymnasium.experimental.wrappers, wrapper_name) is not None
|
||||
except gymnasium.error.DependencyNotInstalled as e:
|
||||
pytest.skip(str(e))
|
||||
|
@@ -1,16 +1,18 @@
|
||||
"""Test suite for JaxToNumpyV0."""
|
||||
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from gymnasium.experimental.wrappers.jax_to_numpy import (
|
||||
|
||||
jnp = pytest.importorskip("jax.numpy")
|
||||
|
||||
from gymnasium.experimental.wrappers.jax_to_numpy import ( # noqa: E402
|
||||
JaxToNumpyV0,
|
||||
jax_to_numpy,
|
||||
numpy_to_jax,
|
||||
)
|
||||
from gymnasium.utils.env_checker import data_equivalence
|
||||
from tests.testing_env import GenericTestEnv
|
||||
from gymnasium.utils.env_checker import data_equivalence # noqa: E402
|
||||
from tests.testing_env import GenericTestEnv # noqa: E402
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
@@ -1,16 +1,18 @@
|
||||
"""Test suite for TorchToJaxV0."""
|
||||
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from gymnasium.experimental.wrappers.jax_to_torch import (
|
||||
|
||||
jnp = pytest.importorskip("jax.numpy")
|
||||
torch = pytest.importorskip("torch")
|
||||
|
||||
from gymnasium.experimental.wrappers.jax_to_torch import ( # noqa: E402
|
||||
JaxToTorchV0,
|
||||
jax_to_torch,
|
||||
torch_to_jax,
|
||||
)
|
||||
from tests.testing_env import GenericTestEnv
|
||||
from tests.testing_env import GenericTestEnv # noqa: E402
|
||||
|
||||
|
||||
def torch_data_equivalence(data_1, data_2) -> bool:
|
||||
|
Reference in New Issue
Block a user