mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-21 06:20:15 +00:00
Skip Jax and PyTorch tests if module is missing (#466)
This commit is contained in:
@@ -1,11 +1,13 @@
|
|||||||
import jax
|
|
||||||
import jax.numpy as jnp
|
|
||||||
import jax.random as jrng
|
|
||||||
import numpy as np
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from gymnasium.envs.phys2d.cartpole import CartPoleFunctional
|
|
||||||
from gymnasium.envs.phys2d.pendulum import PendulumFunctional
|
jax = pytest.importorskip("jax")
|
||||||
|
import jax.numpy as jnp # noqa: E402
|
||||||
|
import jax.random as jrng # noqa: E402
|
||||||
|
import numpy as np # noqa: E402
|
||||||
|
|
||||||
|
from gymnasium.envs.phys2d.cartpole import CartPoleFunctional # noqa: E402
|
||||||
|
from gymnasium.envs.phys2d.pendulum import PendulumFunctional # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("env_class", [CartPoleFunctional, PendulumFunctional])
|
@pytest.mark.parametrize("env_class", [CartPoleFunctional, PendulumFunctional])
|
||||||
|
@@ -1,12 +1,14 @@
|
|||||||
"""Test the functional jax environment."""
|
"""Test the functional jax environment."""
|
||||||
|
|
||||||
import jax
|
|
||||||
import jax.numpy as jnp
|
|
||||||
import jax.random as jrng
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from gymnasium.envs.phys2d.cartpole import CartPoleFunctional
|
|
||||||
from gymnasium.envs.phys2d.pendulum import PendulumFunctional
|
jax = pytest.importorskip("jax")
|
||||||
|
import jax.numpy as jnp # noqa: E402
|
||||||
|
import jax.random as jrng # noqa: E402
|
||||||
|
|
||||||
|
from gymnasium.envs.phys2d.cartpole import CartPoleFunctional # noqa: E402
|
||||||
|
from gymnasium.envs.phys2d.pendulum import PendulumFunctional # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("env_class", [CartPoleFunctional, PendulumFunctional])
|
@pytest.mark.parametrize("env_class", [CartPoleFunctional, PendulumFunctional])
|
||||||
|
@@ -1,12 +1,14 @@
|
|||||||
"""Tests for Jax Blackjack functional env."""
|
"""Tests for Jax Blackjack functional env."""
|
||||||
|
|
||||||
|
|
||||||
import jax
|
|
||||||
import jax.numpy as jnp
|
|
||||||
import jax.random as jrng
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from gymnasium.envs.tabular.blackjack import BlackjackFunctional
|
|
||||||
|
jax = pytest.importorskip("jax")
|
||||||
|
import jax.numpy as jnp # noqa: E402
|
||||||
|
import jax.random as jrng # noqa: E402
|
||||||
|
|
||||||
|
from gymnasium.envs.tabular.blackjack import BlackjackFunctional # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
def test_normal_BlackjackFunctional():
|
def test_normal_BlackjackFunctional():
|
||||||
|
@@ -1,12 +1,14 @@
|
|||||||
"""Tests for Jax cliffwalking functional env."""
|
"""Tests for Jax cliffwalking functional env."""
|
||||||
|
|
||||||
|
|
||||||
import jax
|
|
||||||
import jax.numpy as jnp
|
|
||||||
import jax.random as jrng
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from gymnasium.envs.tabular.cliffwalking import CliffWalkingFunctional
|
|
||||||
|
jax = pytest.importorskip("jax")
|
||||||
|
import jax.numpy as jnp # noqa: E402
|
||||||
|
import jax.random as jrng # noqa: E402
|
||||||
|
|
||||||
|
from gymnasium.envs.tabular.cliffwalking import CliffWalkingFunctional # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
def test_normal_CliffWalkingFunctional():
|
def test_normal_CliffWalkingFunctional():
|
||||||
|
@@ -10,7 +10,7 @@ from gymnasium.experimental.wrappers import __all__
|
|||||||
|
|
||||||
|
|
||||||
def test_all_wrapper_shorten():
|
def test_all_wrapper_shorten():
|
||||||
"""Test that all wrappers in `__alL__` are contained within the `_wrapper_to_class` conversion."""
|
"""Test that all wrappers in `__all__` are contained within the `_wrapper_to_class` conversion."""
|
||||||
all_wrappers = set(__all__)
|
all_wrappers = set(__all__)
|
||||||
all_wrappers.remove("vector")
|
all_wrappers.remove("vector")
|
||||||
assert all_wrappers == set(_wrapper_to_class.keys())
|
assert all_wrappers == set(_wrapper_to_class.keys())
|
||||||
@@ -18,6 +18,9 @@ def test_all_wrapper_shorten():
|
|||||||
|
|
||||||
@pytest.mark.parametrize("wrapper_name", __all__)
|
@pytest.mark.parametrize("wrapper_name", __all__)
|
||||||
def test_all_wrappers_shortened(wrapper_name):
|
def test_all_wrappers_shortened(wrapper_name):
|
||||||
"""Check that each element of the `__all__` wrappers can be loaded."""
|
"""Check that each element of the `__all__` wrappers can be loaded, provided dependencies are installed."""
|
||||||
if wrapper_name != "vector":
|
if wrapper_name != "vector":
|
||||||
|
try:
|
||||||
assert getattr(gymnasium.experimental.wrappers, wrapper_name) is not None
|
assert getattr(gymnasium.experimental.wrappers, wrapper_name) is not None
|
||||||
|
except gymnasium.error.DependencyNotInstalled as e:
|
||||||
|
pytest.skip(str(e))
|
||||||
|
@@ -1,16 +1,18 @@
|
|||||||
"""Test suite for JaxToNumpyV0."""
|
"""Test suite for JaxToNumpyV0."""
|
||||||
|
|
||||||
import jax.numpy as jnp
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from gymnasium.experimental.wrappers.jax_to_numpy import (
|
|
||||||
|
jnp = pytest.importorskip("jax.numpy")
|
||||||
|
|
||||||
|
from gymnasium.experimental.wrappers.jax_to_numpy import ( # noqa: E402
|
||||||
JaxToNumpyV0,
|
JaxToNumpyV0,
|
||||||
jax_to_numpy,
|
jax_to_numpy,
|
||||||
numpy_to_jax,
|
numpy_to_jax,
|
||||||
)
|
)
|
||||||
from gymnasium.utils.env_checker import data_equivalence
|
from gymnasium.utils.env_checker import data_equivalence # noqa: E402
|
||||||
from tests.testing_env import GenericTestEnv
|
from tests.testing_env import GenericTestEnv # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
@@ -1,16 +1,18 @@
|
|||||||
"""Test suite for TorchToJaxV0."""
|
"""Test suite for TorchToJaxV0."""
|
||||||
|
|
||||||
import jax.numpy as jnp
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
|
||||||
|
|
||||||
from gymnasium.experimental.wrappers.jax_to_torch import (
|
|
||||||
|
jnp = pytest.importorskip("jax.numpy")
|
||||||
|
torch = pytest.importorskip("torch")
|
||||||
|
|
||||||
|
from gymnasium.experimental.wrappers.jax_to_torch import ( # noqa: E402
|
||||||
JaxToTorchV0,
|
JaxToTorchV0,
|
||||||
jax_to_torch,
|
jax_to_torch,
|
||||||
torch_to_jax,
|
torch_to_jax,
|
||||||
)
|
)
|
||||||
from tests.testing_env import GenericTestEnv
|
from tests.testing_env import GenericTestEnv # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
def torch_data_equivalence(data_1, data_2) -> bool:
|
def torch_data_equivalence(data_1, data_2) -> bool:
|
||||||
|
Reference in New Issue
Block a user