Skip Jax and PyTorch tests if module is missing (#466)

This commit is contained in:
ChristofKaufmann
2023-04-25 03:47:51 -07:00
committed by GitHub
parent fa5fdb435e
commit 4350e9c6b9
7 changed files with 45 additions and 30 deletions

View File

@@ -1,11 +1,13 @@
import jax
import jax.numpy as jnp
import jax.random as jrng
import numpy as np
import pytest import pytest
from gymnasium.envs.phys2d.cartpole import CartPoleFunctional
from gymnasium.envs.phys2d.pendulum import PendulumFunctional jax = pytest.importorskip("jax")
import jax.numpy as jnp # noqa: E402
import jax.random as jrng # noqa: E402
import numpy as np # noqa: E402
from gymnasium.envs.phys2d.cartpole import CartPoleFunctional # noqa: E402
from gymnasium.envs.phys2d.pendulum import PendulumFunctional # noqa: E402
@pytest.mark.parametrize("env_class", [CartPoleFunctional, PendulumFunctional]) @pytest.mark.parametrize("env_class", [CartPoleFunctional, PendulumFunctional])

View File

@@ -1,12 +1,14 @@
"""Test the functional jax environment.""" """Test the functional jax environment."""
import jax
import jax.numpy as jnp
import jax.random as jrng
import pytest import pytest
from gymnasium.envs.phys2d.cartpole import CartPoleFunctional
from gymnasium.envs.phys2d.pendulum import PendulumFunctional jax = pytest.importorskip("jax")
import jax.numpy as jnp # noqa: E402
import jax.random as jrng # noqa: E402
from gymnasium.envs.phys2d.cartpole import CartPoleFunctional # noqa: E402
from gymnasium.envs.phys2d.pendulum import PendulumFunctional # noqa: E402
@pytest.mark.parametrize("env_class", [CartPoleFunctional, PendulumFunctional]) @pytest.mark.parametrize("env_class", [CartPoleFunctional, PendulumFunctional])

View File

@@ -1,12 +1,14 @@
"""Tests for Jax Blackjack functional env.""" """Tests for Jax Blackjack functional env."""
import jax
import jax.numpy as jnp
import jax.random as jrng
import pytest import pytest
from gymnasium.envs.tabular.blackjack import BlackjackFunctional
jax = pytest.importorskip("jax")
import jax.numpy as jnp # noqa: E402
import jax.random as jrng # noqa: E402
from gymnasium.envs.tabular.blackjack import BlackjackFunctional # noqa: E402
def test_normal_BlackjackFunctional(): def test_normal_BlackjackFunctional():

View File

@@ -1,12 +1,14 @@
"""Tests for Jax cliffwalking functional env.""" """Tests for Jax cliffwalking functional env."""
import jax
import jax.numpy as jnp
import jax.random as jrng
import pytest import pytest
from gymnasium.envs.tabular.cliffwalking import CliffWalkingFunctional
jax = pytest.importorskip("jax")
import jax.numpy as jnp # noqa: E402
import jax.random as jrng # noqa: E402
from gymnasium.envs.tabular.cliffwalking import CliffWalkingFunctional # noqa: E402
def test_normal_CliffWalkingFunctional(): def test_normal_CliffWalkingFunctional():

View File

@@ -10,7 +10,7 @@ from gymnasium.experimental.wrappers import __all__
def test_all_wrapper_shorten(): def test_all_wrapper_shorten():
"""Test that all wrappers in `__alL__` are contained within the `_wrapper_to_class` conversion.""" """Test that all wrappers in `__all__` are contained within the `_wrapper_to_class` conversion."""
all_wrappers = set(__all__) all_wrappers = set(__all__)
all_wrappers.remove("vector") all_wrappers.remove("vector")
assert all_wrappers == set(_wrapper_to_class.keys()) assert all_wrappers == set(_wrapper_to_class.keys())
@@ -18,6 +18,9 @@ def test_all_wrapper_shorten():
@pytest.mark.parametrize("wrapper_name", __all__) @pytest.mark.parametrize("wrapper_name", __all__)
def test_all_wrappers_shortened(wrapper_name): def test_all_wrappers_shortened(wrapper_name):
"""Check that each element of the `__all__` wrappers can be loaded.""" """Check that each element of the `__all__` wrappers can be loaded, provided dependencies are installed."""
if wrapper_name != "vector": if wrapper_name != "vector":
assert getattr(gymnasium.experimental.wrappers, wrapper_name) is not None try:
assert getattr(gymnasium.experimental.wrappers, wrapper_name) is not None
except gymnasium.error.DependencyNotInstalled as e:
pytest.skip(str(e))

View File

@@ -1,16 +1,18 @@
"""Test suite for JaxToNumpyV0.""" """Test suite for JaxToNumpyV0."""
import jax.numpy as jnp
import numpy as np import numpy as np
import pytest import pytest
from gymnasium.experimental.wrappers.jax_to_numpy import (
jnp = pytest.importorskip("jax.numpy")
from gymnasium.experimental.wrappers.jax_to_numpy import ( # noqa: E402
JaxToNumpyV0, JaxToNumpyV0,
jax_to_numpy, jax_to_numpy,
numpy_to_jax, numpy_to_jax,
) )
from gymnasium.utils.env_checker import data_equivalence from gymnasium.utils.env_checker import data_equivalence # noqa: E402
from tests.testing_env import GenericTestEnv from tests.testing_env import GenericTestEnv # noqa: E402
@pytest.mark.parametrize( @pytest.mark.parametrize(

View File

@@ -1,16 +1,18 @@
"""Test suite for TorchToJaxV0.""" """Test suite for TorchToJaxV0."""
import jax.numpy as jnp
import numpy as np import numpy as np
import pytest import pytest
import torch
from gymnasium.experimental.wrappers.jax_to_torch import (
jnp = pytest.importorskip("jax.numpy")
torch = pytest.importorskip("torch")
from gymnasium.experimental.wrappers.jax_to_torch import ( # noqa: E402
JaxToTorchV0, JaxToTorchV0,
jax_to_torch, jax_to_torch,
torch_to_jax, torch_to_jax,
) )
from tests.testing_env import GenericTestEnv from tests.testing_env import GenericTestEnv # noqa: E402
def torch_data_equivalence(data_1, data_2) -> bool: def torch_data_equivalence(data_1, data_2) -> bool: