mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-07-31 13:54:31 +00:00
switch to pytest (#495)
* switch to pytest * remove observation space sampling * fix test
This commit is contained in:
@@ -249,13 +249,11 @@ See the ``examples`` directory.
|
||||
Testing
|
||||
=======
|
||||
|
||||
We are using `nose2 <https://github.com/nose-devs/nose2>`_ for tests. You can run them via:
|
||||
We are using `pytest <http://doc.pytest.org>`_ for tests. You can run them via:
|
||||
|
||||
.. code:: shell
|
||||
|
||||
nose2
|
||||
|
||||
You can also run tests in a specific directory by using the ``-s`` option, or by passing in the specific name of the test. See the `nose2 docs <http://nose2.readthedocs.org/en/latest/usage.html#naming-tests>`_ for more details.
|
||||
pytest
|
||||
|
||||
What's new
|
||||
----------
|
||||
|
@@ -254,6 +254,8 @@ class TapeAlgorithmicEnv(AlgorithmicEnv):
|
||||
pos = self.read_head_position
|
||||
if pos < 0:
|
||||
return self.base
|
||||
if isinstance(pos, np.ndarray):
|
||||
pos = pos.item()
|
||||
try:
|
||||
return self.input_data[pos]
|
||||
except IndexError:
|
||||
|
25
gym/envs/tests/spec_list.py
Normal file
25
gym/envs/tests/spec_list.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from gym import envs
|
||||
import os
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def should_skip_env_spec_for_tests(spec):
|
||||
# We skip tests for envs that require dependencies or are otherwise
|
||||
# troublesome to run frequently
|
||||
ep = spec._entry_point
|
||||
# Skip mujoco tests for pull request CI
|
||||
skip_mujoco = not (os.environ.get('MUJOCO_KEY_BUNDLE') or os.path.exists(os.path.expanduser('~/.mujoco')))
|
||||
if skip_mujoco and ep.startswith('gym.envs.mujoco:'):
|
||||
return True
|
||||
if ( spec.id.startswith("Go") or
|
||||
spec.id.startswith("Hex") or
|
||||
ep.startswith('gym.envs.box2d:') or
|
||||
ep.startswith('gym.envs.parameter_tuning:') or
|
||||
ep.startswith('gym.envs.safety:Semisuper') or
|
||||
(ep.startswith("gym.envs.atari") and not spec.id.startswith("Pong"))
|
||||
):
|
||||
logger.warning("Skipping tests for env {}".format(ep))
|
||||
return True
|
||||
return False
|
||||
|
||||
spec_list = [spec for spec in sorted(envs.registry.all(), key=lambda x: x.id) if spec._entry_point is not None and not should_skip_env_spec_for_tests(spec)]
|
@@ -1,20 +1,14 @@
|
||||
import numpy as np
|
||||
from nose2 import tools
|
||||
import pytest
|
||||
import os
|
||||
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
import gym
|
||||
from gym import envs, spaces
|
||||
from gym.envs.tests.spec_list import spec_list
|
||||
|
||||
from gym.envs.tests.test_envs import should_skip_env_spec_for_tests
|
||||
|
||||
specs = [spec for spec in sorted(envs.registry.all(), key=lambda x: x.id) if spec._entry_point is not None]
|
||||
@tools.params(*specs)
|
||||
@pytest.mark.parametrize("spec", spec_list)
|
||||
def test_env(spec):
|
||||
if should_skip_env_spec_for_tests(spec):
|
||||
return
|
||||
|
||||
# Note that this precludes running this test in multiple
|
||||
# threads. However, we probably already can't do multithreading
|
||||
@@ -24,7 +18,6 @@ def test_env(spec):
|
||||
env1 = spec.make()
|
||||
env1.seed(0)
|
||||
action_samples1 = [env1.action_space.sample() for i in range(4)]
|
||||
observation_samples1 = [env1.observation_space.sample() for i in range(4)]
|
||||
initial_observation1 = env1.reset()
|
||||
step_responses1 = [env1.step(action) for action in action_samples1]
|
||||
env1.close()
|
||||
@@ -34,7 +27,6 @@ def test_env(spec):
|
||||
env2 = spec.make()
|
||||
env2.seed(0)
|
||||
action_samples2 = [env2.action_space.sample() for i in range(4)]
|
||||
observation_samples2 = [env2.observation_space.sample() for i in range(4)]
|
||||
initial_observation2 = env2.reset()
|
||||
step_responses2 = [env2.step(action) for action in action_samples2]
|
||||
env2.close()
|
||||
@@ -42,9 +34,6 @@ def test_env(spec):
|
||||
for i, (action_sample1, action_sample2) in enumerate(zip(action_samples1, action_samples2)):
|
||||
assert_equals(action_sample1, action_sample2), '[{}] action_sample1: {}, action_sample2: {}'.format(i, action_sample1, action_sample2)
|
||||
|
||||
for (observation_sample1, observation_sample2) in zip(observation_samples1, observation_samples2):
|
||||
assert_equals(observation_sample1, observation_sample2)
|
||||
|
||||
# Don't check rollout equality if it's a a nondeterministic
|
||||
# environment.
|
||||
if spec.nondeterministic:
|
||||
|
@@ -1,46 +1,18 @@
|
||||
import numpy as np
|
||||
from nose2 import tools
|
||||
import pytest
|
||||
import os
|
||||
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
import gym
|
||||
from gym import envs
|
||||
|
||||
def should_skip_env_spec_for_tests(spec):
|
||||
# We skip tests for envs that require dependencies or are otherwise
|
||||
# troublesome to run frequently
|
||||
|
||||
ep = spec._entry_point
|
||||
|
||||
# Skip mujoco tests for pull request CI
|
||||
skip_mujoco = not (os.environ.get('MUJOCO_KEY_BUNDLE') or os.path.exists(os.path.expanduser('~/.mujoco')))
|
||||
if skip_mujoco and ep.startswith('gym.envs.mujoco:'):
|
||||
return True
|
||||
if ( spec.id.startswith("Go") or
|
||||
spec.id.startswith("Hex") or
|
||||
ep.startswith('gym.envs.box2d:') or
|
||||
ep.startswith('gym.envs.parameter_tuning:') or
|
||||
ep.startswith('gym.envs.safety:Semisuper') or
|
||||
(ep.startswith("gym.envs.atari") and not spec.id.startswith("Pong"))
|
||||
):
|
||||
logger.warning("Skipping tests for env {}".format(ep))
|
||||
return True
|
||||
|
||||
return False
|
||||
from gym.envs.tests.spec_list import spec_list
|
||||
|
||||
|
||||
# This runs a smoketest on each official registered env. We may want
|
||||
# to try also running environments which are not officially registered
|
||||
# envs.
|
||||
specs = [spec for spec in sorted(envs.registry.all(), key=lambda x: x.id) if spec._entry_point is not None]
|
||||
|
||||
@tools.params(*specs)
|
||||
@pytest.mark.parametrize("spec", spec_list)
|
||||
def test_env(spec):
|
||||
if should_skip_env_spec_for_tests(spec):
|
||||
return
|
||||
|
||||
env = spec.make()
|
||||
ob_space = env.observation_space
|
||||
act_space = env.action_space
|
||||
|
@@ -3,14 +3,11 @@ import json
|
||||
import hashlib
|
||||
import os
|
||||
import sys
|
||||
|
||||
from nose2 import tools
|
||||
import logging
|
||||
import pytest
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from gym import envs, spaces
|
||||
|
||||
from gym.envs.tests.test_envs import should_skip_env_spec_for_tests
|
||||
from gym.envs.tests.spec_list import spec_list
|
||||
|
||||
DATA_DIR = os.path.dirname(__file__)
|
||||
ROLLOUT_STEPS = 100
|
||||
@@ -62,14 +59,13 @@ def generate_rollout_hash(spec):
|
||||
|
||||
return observations_hash, actions_hash, rewards_hash, dones_hash
|
||||
|
||||
specs = [spec for spec in sorted(envs.registry.all(), key=lambda x: x.id) if spec._entry_point is not None]
|
||||
@tools.params(*specs)
|
||||
@pytest.mark.parametrize("spec", spec_list)
|
||||
def test_env_semantics(spec):
|
||||
with open(ROLLOUT_FILE) as data_file:
|
||||
rollout_dict = json.load(data_file)
|
||||
|
||||
if spec.id not in rollout_dict or should_skip_env_spec_for_tests(spec):
|
||||
if not spec.nondeterministic or should_skip_env_spec_for_tests(spec):
|
||||
if spec.id not in rollout_dict:
|
||||
if not spec.nondeterministic:
|
||||
logger.warn("Rollout does not exist for {}, run generate_json.py to generate rollouts for new envs".format(spec.id))
|
||||
return
|
||||
|
||||
|
@@ -2,9 +2,7 @@ import json
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
|
||||
import numpy as np
|
||||
from nose2 import tools
|
||||
|
||||
import gym
|
||||
from gym.monitoring import VideoRecorder
|
||||
|
@@ -1,16 +1,16 @@
|
||||
import json # note: ujson fails this test due to float equality
|
||||
|
||||
import numpy as np
|
||||
from nose2 import tools
|
||||
|
||||
import pytest
|
||||
from gym.spaces import Tuple, Box, Discrete, MultiDiscrete
|
||||
|
||||
@tools.params(Discrete(3),
|
||||
|
||||
@pytest.mark.parametrize("space", [
|
||||
Discrete(3),
|
||||
Tuple([Discrete(5), Discrete(10)]),
|
||||
Tuple([Discrete(5), Box(np.array([0,0]),np.array([1,5]))]),
|
||||
Tuple((Discrete(5), Discrete(2), Discrete(2))),
|
||||
MultiDiscrete([ [0, 1], [0, 1], [0, 100] ]),
|
||||
)
|
||||
MultiDiscrete([ [0, 1], [0, 1], [0, 100] ])
|
||||
])
|
||||
def test_roundtripping(space):
|
||||
sample_1 = space.sample()
|
||||
sample_2 = space.sample()
|
||||
|
@@ -1,5 +1,5 @@
|
||||
# Testing
|
||||
nose2
|
||||
pytest
|
||||
mock
|
||||
|
||||
-e .[all]
|
||||
|
2
setup.py
2
setup.py
@@ -36,5 +36,5 @@ setup(name='gym',
|
||||
],
|
||||
extras_require=extras,
|
||||
package_data={'gym': ['envs/mujoco/assets/*.xml', 'envs/classic_control/assets/*.png']},
|
||||
tests_require=['nose2', 'mock'],
|
||||
tests_require=['pytest', 'mock'],
|
||||
)
|
||||
|
8
tox.ini
8
tox.ini
@@ -10,7 +10,7 @@ envlist = py27, py34
|
||||
whitelist_externals=make
|
||||
passenv=DISPLAY TRAVIS*
|
||||
deps =
|
||||
nose2
|
||||
pytest
|
||||
mock
|
||||
atari_py>=0.0.17
|
||||
Pillow
|
||||
@@ -27,13 +27,13 @@ deps =
|
||||
six
|
||||
pyglet>=1.2.0
|
||||
commands =
|
||||
nose2 {posargs}
|
||||
pytest {posargs}
|
||||
|
||||
[testenv:py27]
|
||||
whitelist_externals=make
|
||||
passenv=DISPLAY TRAVIS*
|
||||
deps =
|
||||
nose2
|
||||
pytest
|
||||
mock
|
||||
atari_py>=0.0.17
|
||||
Pillow
|
||||
@@ -50,4 +50,4 @@ deps =
|
||||
six
|
||||
pyglet>=1.2.0
|
||||
commands =
|
||||
nose2 {posargs}
|
||||
pytest {posargs}
|
||||
|
Reference in New Issue
Block a user