Skip Jax and PyTorch tests if module is missing (#466)

This commit is contained in:
ChristofKaufmann
2023-04-25 03:47:51 -07:00
committed by GitHub
parent fa5fdb435e
commit 4350e9c6b9
7 changed files with 45 additions and 30 deletions

View File

@@ -1,11 +1,13 @@
import jax
import jax.numpy as jnp
import jax.random as jrng
import numpy as np
import pytest
from gymnasium.envs.phys2d.cartpole import CartPoleFunctional
from gymnasium.envs.phys2d.pendulum import PendulumFunctional
jax = pytest.importorskip("jax")
import jax.numpy as jnp # noqa: E402
import jax.random as jrng # noqa: E402
import numpy as np # noqa: E402
from gymnasium.envs.phys2d.cartpole import CartPoleFunctional # noqa: E402
from gymnasium.envs.phys2d.pendulum import PendulumFunctional # noqa: E402
@pytest.mark.parametrize("env_class", [CartPoleFunctional, PendulumFunctional])

View File

@@ -1,12 +1,14 @@
"""Test the functional jax environment."""
import jax
import jax.numpy as jnp
import jax.random as jrng
import pytest
from gymnasium.envs.phys2d.cartpole import CartPoleFunctional
from gymnasium.envs.phys2d.pendulum import PendulumFunctional
jax = pytest.importorskip("jax")
import jax.numpy as jnp # noqa: E402
import jax.random as jrng # noqa: E402
from gymnasium.envs.phys2d.cartpole import CartPoleFunctional # noqa: E402
from gymnasium.envs.phys2d.pendulum import PendulumFunctional # noqa: E402
@pytest.mark.parametrize("env_class", [CartPoleFunctional, PendulumFunctional])

View File

@@ -1,12 +1,14 @@
"""Tests for Jax Blackjack functional env."""
import jax
import jax.numpy as jnp
import jax.random as jrng
import pytest
from gymnasium.envs.tabular.blackjack import BlackjackFunctional
jax = pytest.importorskip("jax")
import jax.numpy as jnp # noqa: E402
import jax.random as jrng # noqa: E402
from gymnasium.envs.tabular.blackjack import BlackjackFunctional # noqa: E402
def test_normal_BlackjackFunctional():

View File

@@ -1,12 +1,14 @@
"""Tests for Jax cliffwalking functional env."""
import jax
import jax.numpy as jnp
import jax.random as jrng
import pytest
from gymnasium.envs.tabular.cliffwalking import CliffWalkingFunctional
jax = pytest.importorskip("jax")
import jax.numpy as jnp # noqa: E402
import jax.random as jrng # noqa: E402
from gymnasium.envs.tabular.cliffwalking import CliffWalkingFunctional # noqa: E402
def test_normal_CliffWalkingFunctional():

View File

@@ -10,7 +10,7 @@ from gymnasium.experimental.wrappers import __all__
def test_all_wrapper_shorten():
"""Test that all wrappers in `__alL__` are contained within the `_wrapper_to_class` conversion."""
"""Test that all wrappers in `__all__` are contained within the `_wrapper_to_class` conversion."""
all_wrappers = set(__all__)
all_wrappers.remove("vector")
assert all_wrappers == set(_wrapper_to_class.keys())
@@ -18,6 +18,9 @@ def test_all_wrapper_shorten():
@pytest.mark.parametrize("wrapper_name", __all__)
def test_all_wrappers_shortened(wrapper_name):
"""Check that each element of the `__all__` wrappers can be loaded."""
"""Check that each element of the `__all__` wrappers can be loaded, provided dependencies are installed."""
if wrapper_name != "vector":
try:
assert getattr(gymnasium.experimental.wrappers, wrapper_name) is not None
except gymnasium.error.DependencyNotInstalled as e:
pytest.skip(str(e))

View File

@@ -1,16 +1,18 @@
"""Test suite for JaxToNumpyV0."""
import jax.numpy as jnp
import numpy as np
import pytest
from gymnasium.experimental.wrappers.jax_to_numpy import (
jnp = pytest.importorskip("jax.numpy")
from gymnasium.experimental.wrappers.jax_to_numpy import ( # noqa: E402
JaxToNumpyV0,
jax_to_numpy,
numpy_to_jax,
)
from gymnasium.utils.env_checker import data_equivalence
from tests.testing_env import GenericTestEnv
from gymnasium.utils.env_checker import data_equivalence # noqa: E402
from tests.testing_env import GenericTestEnv # noqa: E402
@pytest.mark.parametrize(

View File

@@ -1,16 +1,18 @@
"""Test suite for TorchToJaxV0."""
import jax.numpy as jnp
import numpy as np
import pytest
import torch
from gymnasium.experimental.wrappers.jax_to_torch import (
jnp = pytest.importorskip("jax.numpy")
torch = pytest.importorskip("torch")
from gymnasium.experimental.wrappers.jax_to_torch import ( # noqa: E402
JaxToTorchV0,
jax_to_torch,
torch_to_jax,
)
from tests.testing_env import GenericTestEnv
from tests.testing_env import GenericTestEnv # noqa: E402
def torch_data_equivalence(data_1, data_2) -> bool: