mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-20 22:12:03 +00:00
Py36+ code style in tests (#2547)
This commit is contained in:
@@ -43,7 +43,7 @@ def should_skip_env_spec_for_tests(spec):
|
|||||||
and not spec.id.startswith("Seaquest")
|
and not spec.id.startswith("Seaquest")
|
||||||
)
|
)
|
||||||
):
|
):
|
||||||
logger.warn("Skipping tests for env {}".format(ep))
|
logger.warn(f"Skipping tests for env {ep}")
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@@ -34,9 +34,7 @@ def test_env(spec):
|
|||||||
print("action_samples1=", action_samples1)
|
print("action_samples1=", action_samples1)
|
||||||
print("action_samples2=", action_samples2)
|
print("action_samples2=", action_samples2)
|
||||||
print(
|
print(
|
||||||
"[{}] action_sample1: {}, action_sample2: {}".format(
|
f"[{i}] action_sample1: {action_sample1}, action_sample2: {action_sample2}"
|
||||||
i, action_sample1, action_sample2
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@@ -50,23 +48,21 @@ def test_env(spec):
|
|||||||
for i, ((o1, r1, d1, i1), (o2, r2, d2, i2)) in enumerate(
|
for i, ((o1, r1, d1, i1), (o2, r2, d2, i2)) in enumerate(
|
||||||
zip(step_responses1, step_responses2)
|
zip(step_responses1, step_responses2)
|
||||||
):
|
):
|
||||||
assert_equals(o1, o2, "[{}] ".format(i))
|
assert_equals(o1, o2, f"[{i}] ")
|
||||||
assert r1 == r2, "[{}] r1: {}, r2: {}".format(i, r1, r2)
|
assert r1 == r2, f"[{i}] r1: {r1}, r2: {r2}"
|
||||||
assert d1 == d2, "[{}] d1: {}, d2: {}".format(i, d1, d2)
|
assert d1 == d2, f"[{i}] d1: {d1}, d2: {d2}"
|
||||||
|
|
||||||
# Go returns a Pachi game board in info, which doesn't
|
# Go returns a Pachi game board in info, which doesn't
|
||||||
# properly check equality. For now, we hack around this by
|
# properly check equality. For now, we hack around this by
|
||||||
# just skipping Go.
|
# just skipping Go.
|
||||||
if spec.id not in ["Go9x9-v0", "Go19x19-v0"]:
|
if spec.id not in ["Go9x9-v0", "Go19x19-v0"]:
|
||||||
assert_equals(i1, i2, "[{}] ".format(i))
|
assert_equals(i1, i2, f"[{i}] ")
|
||||||
|
|
||||||
|
|
||||||
def assert_equals(a, b, prefix=None):
|
def assert_equals(a, b, prefix=None):
|
||||||
assert type(a) == type(b), "{}Differing types: {} and {}".format(prefix, a, b)
|
assert type(a) == type(b), f"{prefix}Differing types: {a} and {b}"
|
||||||
if isinstance(a, dict):
|
if isinstance(a, dict):
|
||||||
assert list(a.keys()) == list(b.keys()), "{}Key sets differ: {} and {}".format(
|
assert list(a.keys()) == list(b.keys()), f"{prefix}Key sets differ: {a} and {b}"
|
||||||
prefix, a, b
|
|
||||||
)
|
|
||||||
|
|
||||||
for k in a.keys():
|
for k in a.keys():
|
||||||
v_a = a[k]
|
v_a = a[k]
|
||||||
|
@@ -26,24 +26,24 @@ def test_env(spec):
|
|||||||
ob_space = env.observation_space
|
ob_space = env.observation_space
|
||||||
act_space = env.action_space
|
act_space = env.action_space
|
||||||
ob = env.reset()
|
ob = env.reset()
|
||||||
assert ob_space.contains(ob), "Reset observation: {!r} not in space".format(ob)
|
assert ob_space.contains(ob), f"Reset observation: {ob!r} not in space"
|
||||||
if isinstance(ob_space, Box):
|
if isinstance(ob_space, Box):
|
||||||
# Only checking dtypes for Box spaces to avoid iterating through tuple entries
|
# Only checking dtypes for Box spaces to avoid iterating through tuple entries
|
||||||
assert (
|
assert (
|
||||||
ob.dtype == ob_space.dtype
|
ob.dtype == ob_space.dtype
|
||||||
), "Reset observation dtype: {}, expected: {}".format(ob.dtype, ob_space.dtype)
|
), f"Reset observation dtype: {ob.dtype}, expected: {ob_space.dtype}"
|
||||||
|
|
||||||
a = act_space.sample()
|
a = act_space.sample()
|
||||||
observation, reward, done, _info = env.step(a)
|
observation, reward, done, _info = env.step(a)
|
||||||
assert ob_space.contains(observation), "Step observation: {!r} not in space".format(
|
assert ob_space.contains(
|
||||||
observation
|
observation
|
||||||
)
|
), f"Step observation: {observation!r} not in space"
|
||||||
assert np.isscalar(reward), "{} is not a scalar for {}".format(reward, env)
|
assert np.isscalar(reward), f"{reward} is not a scalar for {env}"
|
||||||
assert isinstance(done, bool), "Expected {} to be a boolean".format(done)
|
assert isinstance(done, bool), f"Expected {done} to be a boolean"
|
||||||
if isinstance(ob_space, Box):
|
if isinstance(ob_space, Box):
|
||||||
assert (
|
assert (
|
||||||
observation.dtype == ob_space.dtype
|
observation.dtype == ob_space.dtype
|
||||||
), "Step observation dtype: {}, expected: {}".format(ob.dtype, ob_space.dtype)
|
), f"Step observation dtype: {ob.dtype}, expected: {ob_space.dtype}"
|
||||||
|
|
||||||
for mode in env.metadata.get("render.modes", []):
|
for mode in env.metadata.get("render.modes", []):
|
||||||
env.render(mode=mode)
|
env.render(mode=mode)
|
||||||
|
@@ -1,4 +1,3 @@
|
|||||||
# -*- coding: utf-8 -*-
|
|
||||||
import gym
|
import gym
|
||||||
from gym import error, envs
|
from gym import error, envs
|
||||||
from gym.envs import registration
|
from gym.envs import registration
|
||||||
@@ -89,11 +88,9 @@ def test_missing_lookup():
|
|||||||
def test_malformed_lookup():
|
def test_malformed_lookup():
|
||||||
registry = registration.EnvRegistry()
|
registry = registration.EnvRegistry()
|
||||||
try:
|
try:
|
||||||
registry.spec(u"“Breakout-v0”")
|
registry.spec("“Breakout-v0”")
|
||||||
except error.Error as e:
|
except error.Error as e:
|
||||||
assert "malformed environment ID" in "{}".format(
|
assert "malformed environment ID" in f"{e}", f"Unexpected message: {e}"
|
||||||
e
|
|
||||||
), "Unexpected message: {}".format(e)
|
|
||||||
else:
|
else:
|
||||||
assert False
|
assert False
|
||||||
|
|
||||||
|
@@ -50,8 +50,8 @@ def test_roundtripping(space):
|
|||||||
s1p = space.to_jsonable([sample_1_prime])
|
s1p = space.to_jsonable([sample_1_prime])
|
||||||
s2 = space.to_jsonable([sample_2])
|
s2 = space.to_jsonable([sample_2])
|
||||||
s2p = space.to_jsonable([sample_2_prime])
|
s2p = space.to_jsonable([sample_2_prime])
|
||||||
assert s1 == s1p, "Expected {} to equal {}".format(s1, s1p)
|
assert s1 == s1p, f"Expected {s1} to equal {s1p}"
|
||||||
assert s2 == s2p, "Expected {} to equal {}".format(s2, s2p)
|
assert s2 == s2p, f"Expected {s2} to equal {s2p}"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@@ -85,7 +85,7 @@ def test_roundtripping(space):
|
|||||||
def test_equality(space):
|
def test_equality(space):
|
||||||
space1 = space
|
space1 = space
|
||||||
space2 = copy.copy(space)
|
space2 = copy.copy(space)
|
||||||
assert space1 == space2, "Expected {} to equal {}".format(space1, space2)
|
assert space1 == space2, f"Expected {space1} to equal {space2}"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@@ -114,7 +114,7 @@ def test_equality(space):
|
|||||||
)
|
)
|
||||||
def test_inequality(spaces):
|
def test_inequality(spaces):
|
||||||
space1, space2 = spaces
|
space1, space2 = spaces
|
||||||
assert space1 != space2, "Expected {} != {}".format(space1, space2)
|
assert space1 != space2, f"Expected {space1} != {space2}"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
@@ -36,18 +36,16 @@ flatdims = [3, 4, 4, 15, 7, 9, 14, 10, 7]
|
|||||||
@pytest.mark.parametrize(["space", "flatdim"], zip(spaces, flatdims))
|
@pytest.mark.parametrize(["space", "flatdim"], zip(spaces, flatdims))
|
||||||
def test_flatdim(space, flatdim):
|
def test_flatdim(space, flatdim):
|
||||||
dim = utils.flatdim(space)
|
dim = utils.flatdim(space)
|
||||||
assert dim == flatdim, "Expected {} to equal {}".format(dim, flatdim)
|
assert dim == flatdim, f"Expected {dim} to equal {flatdim}"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("space", spaces)
|
@pytest.mark.parametrize("space", spaces)
|
||||||
def test_flatten_space_boxes(space):
|
def test_flatten_space_boxes(space):
|
||||||
flat_space = utils.flatten_space(space)
|
flat_space = utils.flatten_space(space)
|
||||||
assert isinstance(flat_space, Box), "Expected {} to equal {}".format(
|
assert isinstance(flat_space, Box), f"Expected {type(flat_space)} to equal {Box}"
|
||||||
type(flat_space), Box
|
|
||||||
)
|
|
||||||
flatdim = utils.flatdim(space)
|
flatdim = utils.flatdim(space)
|
||||||
(single_dim,) = flat_space.shape
|
(single_dim,) = flat_space.shape
|
||||||
assert single_dim == flatdim, "Expected {} to equal {}".format(single_dim, flatdim)
|
assert single_dim == flatdim, f"Expected {single_dim} to equal {flatdim}"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("space", spaces)
|
@pytest.mark.parametrize("space", spaces)
|
||||||
@@ -56,9 +54,9 @@ def test_flat_space_contains_flat_points(space):
|
|||||||
flattened_samples = [utils.flatten(space, sample) for sample in some_samples]
|
flattened_samples = [utils.flatten(space, sample) for sample in some_samples]
|
||||||
flat_space = utils.flatten_space(space)
|
flat_space = utils.flatten_space(space)
|
||||||
for i, flat_sample in enumerate(flattened_samples):
|
for i, flat_sample in enumerate(flattened_samples):
|
||||||
assert flat_sample in flat_space, "Expected sample #{} {} to be in {}".format(
|
assert (
|
||||||
i, flat_sample, flat_space
|
flat_sample in flat_space
|
||||||
)
|
), f"Expected sample #{i} {flat_sample} to be in {flat_space}"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("space", spaces)
|
@pytest.mark.parametrize("space", spaces)
|
||||||
@@ -66,7 +64,7 @@ def test_flatten_dim(space):
|
|||||||
sample = utils.flatten(space, space.sample())
|
sample = utils.flatten(space, space.sample())
|
||||||
(single_dim,) = sample.shape
|
(single_dim,) = sample.shape
|
||||||
flatdim = utils.flatdim(space)
|
flatdim = utils.flatdim(space)
|
||||||
assert single_dim == flatdim, "Expected {} to equal {}".format(single_dim, flatdim)
|
assert single_dim == flatdim, f"Expected {single_dim} to equal {flatdim}"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("space", spaces)
|
@pytest.mark.parametrize("space", spaces)
|
||||||
@@ -81,7 +79,7 @@ def test_flatten_roundtripping(space):
|
|||||||
):
|
):
|
||||||
assert compare_nested(
|
assert compare_nested(
|
||||||
original, roundtripped
|
original, roundtripped
|
||||||
), "Expected sample #{} {} to equal {}".format(i, original, roundtripped)
|
), f"Expected sample #{i} {original} to equal {roundtripped}"
|
||||||
|
|
||||||
|
|
||||||
def compare_nested(left, right):
|
def compare_nested(left, right):
|
||||||
@@ -144,9 +142,7 @@ def test_dtypes(original_space, expected_flattened_dtype):
|
|||||||
), "Expected flattened_space to contain flattened_sample"
|
), "Expected flattened_space to contain flattened_sample"
|
||||||
assert (
|
assert (
|
||||||
flattened_space.dtype == expected_flattened_dtype
|
flattened_space.dtype == expected_flattened_dtype
|
||||||
), "Expected flattened_space's dtype to equal " "{}".format(
|
), f"Expected flattened_space's dtype to equal {expected_flattened_dtype}"
|
||||||
expected_flattened_dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
assert flattened_sample.dtype == flattened_space.dtype, (
|
assert flattened_sample.dtype == flattened_space.dtype, (
|
||||||
"Expected flattened_space's dtype to equal " "flattened_sample's dtype "
|
"Expected flattened_space's dtype to equal " "flattened_sample's dtype "
|
||||||
|
@@ -1,7 +1,7 @@
|
|||||||
from gym.utils.closer import Closer
|
from gym.utils.closer import Closer
|
||||||
|
|
||||||
|
|
||||||
class Closeable(object):
|
class Closeable:
|
||||||
close_called = False
|
close_called = False
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
|
@@ -9,7 +9,7 @@ def test_invalid_seeds():
|
|||||||
except error.Error:
|
except error.Error:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
assert False, "Invalid seed {} passed validation".format(seed)
|
assert False, f"Invalid seed {seed} passed validation"
|
||||||
|
|
||||||
|
|
||||||
def test_valid_seeds():
|
def test_valid_seeds():
|
||||||
|
@@ -42,7 +42,7 @@ def test_concatenate(space):
|
|||||||
assert_nested_equal(lhs[key], rhs_T_key, n)
|
assert_nested_equal(lhs[key], rhs_T_key, n)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise TypeError("Got unknown type `{0}`.".format(type(lhs)))
|
raise TypeError(f"Got unknown type `{type(lhs)}`.")
|
||||||
|
|
||||||
samples = [space.sample() for _ in range(8)]
|
samples = [space.sample() for _ in range(8)]
|
||||||
array = create_empty_array(space, n=8)
|
array = create_empty_array(space, n=8)
|
||||||
@@ -76,7 +76,7 @@ def test_create_empty_array(space, n):
|
|||||||
assert_nested_type(arr[key], space.spaces[key], n)
|
assert_nested_type(arr[key], space.spaces[key], n)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise TypeError("Got unknown type `{0}`.".format(type(arr)))
|
raise TypeError(f"Got unknown type `{type(arr)}`.")
|
||||||
|
|
||||||
array = create_empty_array(space, n=n, fn=np.empty)
|
array = create_empty_array(space, n=n, fn=np.empty)
|
||||||
assert_nested_type(array, space, n=n)
|
assert_nested_type(array, space, n=n)
|
||||||
@@ -107,7 +107,7 @@ def test_create_empty_array_zeros(space, n):
|
|||||||
assert_nested_type(arr[key], space.spaces[key], n)
|
assert_nested_type(arr[key], space.spaces[key], n)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise TypeError("Got unknown type `{0}`.".format(type(arr)))
|
raise TypeError(f"Got unknown type `{type(arr)}`.")
|
||||||
|
|
||||||
array = create_empty_array(space, n=n, fn=np.zeros)
|
array = create_empty_array(space, n=n, fn=np.zeros)
|
||||||
assert_nested_type(array, space, n=n)
|
assert_nested_type(array, space, n=n)
|
||||||
@@ -137,7 +137,7 @@ def test_create_empty_array_none_shape_ones(space):
|
|||||||
assert_nested_type(arr[key], space.spaces[key])
|
assert_nested_type(arr[key], space.spaces[key])
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise TypeError("Got unknown type `{0}`.".format(type(arr)))
|
raise TypeError(f"Got unknown type `{type(arr)}`.")
|
||||||
|
|
||||||
array = create_empty_array(space, n=None, fn=np.ones)
|
array = create_empty_array(space, n=None, fn=np.ones)
|
||||||
assert_nested_type(array, space)
|
assert_nested_type(array, space)
|
||||||
|
@@ -69,7 +69,7 @@ def test_create_shared_memory(space, expected_type, n, ctx):
|
|||||||
assert type(lhs[0]) == type(rhs[0]) # noqa: E721
|
assert type(lhs[0]) == type(rhs[0]) # noqa: E721
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise TypeError("Got unknown type `{0}`.".format(type(lhs)))
|
raise TypeError(f"Got unknown type `{type(lhs)}`.")
|
||||||
|
|
||||||
ctx = mp if (ctx is None) else mp.get_context(ctx)
|
ctx = mp if (ctx is None) else mp.get_context(ctx)
|
||||||
shared_memory = create_shared_memory(space, n=n, ctx=ctx)
|
shared_memory = create_shared_memory(space, n=n, ctx=ctx)
|
||||||
@@ -105,7 +105,7 @@ def test_write_to_shared_memory(space):
|
|||||||
assert np.all(np.array(lhs[:]) == np.stack(rhs, axis=0).flatten())
|
assert np.all(np.array(lhs[:]) == np.stack(rhs, axis=0).flatten())
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise TypeError("Got unknown type `{0}`.".format(type(lhs)))
|
raise TypeError(f"Got unknown type `{type(lhs)}`.")
|
||||||
|
|
||||||
def write(i, shared_memory, sample):
|
def write(i, shared_memory, sample):
|
||||||
write_to_shared_memory(i, sample, shared_memory, space)
|
write_to_shared_memory(i, sample, shared_memory, space)
|
||||||
@@ -152,7 +152,7 @@ def test_read_from_shared_memory(space):
|
|||||||
assert np.all(lhs == np.stack(rhs, axis=0))
|
assert np.all(lhs == np.stack(rhs, axis=0))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise TypeError("Got unknown type `{0}`".format(type(space)))
|
raise TypeError(f"Got unknown type `{type(space)}`")
|
||||||
|
|
||||||
def write(i, shared_memory, sample):
|
def write(i, shared_memory, sample):
|
||||||
write_to_shared_memory(i, sample, shared_memory, space)
|
write_to_shared_memory(i, sample, shared_memory, space)
|
||||||
|
@@ -50,7 +50,7 @@ HEIGHT, WIDTH = 64, 64
|
|||||||
|
|
||||||
class UnittestSlowEnv(gym.Env):
|
class UnittestSlowEnv(gym.Env):
|
||||||
def __init__(self, slow_reset=0.3):
|
def __init__(self, slow_reset=0.3):
|
||||||
super(UnittestSlowEnv, self).__init__()
|
super().__init__()
|
||||||
self.slow_reset = slow_reset
|
self.slow_reset = slow_reset
|
||||||
self.observation_space = Box(
|
self.observation_space = Box(
|
||||||
low=0, high=255, shape=(HEIGHT, WIDTH, 3), dtype=np.uint8
|
low=0, high=255, shape=(HEIGHT, WIDTH, 3), dtype=np.uint8
|
||||||
@@ -91,7 +91,7 @@ custom_spaces = [
|
|||||||
|
|
||||||
class CustomSpaceEnv(gym.Env):
|
class CustomSpaceEnv(gym.Env):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(CustomSpaceEnv, self).__init__()
|
super().__init__()
|
||||||
self.observation_space = CustomSpace()
|
self.observation_space = CustomSpace()
|
||||||
self.action_space = CustomSpace()
|
self.action_space = CustomSpace()
|
||||||
|
|
||||||
@@ -100,7 +100,7 @@ class CustomSpaceEnv(gym.Env):
|
|||||||
return "reset"
|
return "reset"
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
observation = "step({0:s})".format(action)
|
observation = f"step({action:s})"
|
||||||
reward, done = 0.0, False
|
reward, done = 0.0, False
|
||||||
return observation, reward, done, {}
|
return observation, reward, done, {}
|
||||||
|
|
||||||
|
@@ -59,7 +59,7 @@ OBSERVATION_SPACES = (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestFlattenEnvironment(object):
|
class TestFlattenEnvironment:
|
||||||
@pytest.mark.parametrize("observation_space, ordered_values", OBSERVATION_SPACES)
|
@pytest.mark.parametrize("observation_space, ordered_values", OBSERVATION_SPACES)
|
||||||
def test_flattened_environment(self, observation_space, ordered_values):
|
def test_flattened_environment(self, observation_space, ordered_values):
|
||||||
"""
|
"""
|
||||||
|
@@ -98,7 +98,7 @@ NESTED_DICT_TEST_CASES = (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestNestedDictWrapper(object):
|
class TestNestedDictWrapper:
|
||||||
@pytest.mark.parametrize("observation_space, flat_shape", NESTED_DICT_TEST_CASES)
|
@pytest.mark.parametrize("observation_space, flat_shape", NESTED_DICT_TEST_CASES)
|
||||||
def test_nested_dicts_size(self, observation_space, flat_shape):
|
def test_nested_dicts_size(self, observation_space, flat_shape):
|
||||||
env = FakeEnvironment(observation_space=observation_space)
|
env = FakeEnvironment(observation_space=observation_space)
|
||||||
|
@@ -77,13 +77,13 @@ def test_atari_preprocessing_scale(env_fn):
|
|||||||
max_obs = 1 if scaled else 255
|
max_obs = 1 if scaled else 255
|
||||||
assert (0 <= obs).all() and (
|
assert (0 <= obs).all() and (
|
||||||
obs <= max_obs
|
obs <= max_obs
|
||||||
).all(), "Obs. must be in range [0,{}]".format(max_obs)
|
).all(), f"Obs. must be in range [0,{max_obs}]"
|
||||||
while not done or step_i <= max_test_steps:
|
while not done or step_i <= max_test_steps:
|
||||||
obs, _, done, _ = env.step(env.action_space.sample())
|
obs, _, done, _ = env.step(env.action_space.sample())
|
||||||
obs = obs.flatten()
|
obs = obs.flatten()
|
||||||
assert (0 <= obs).all() and (
|
assert (0 <= obs).all() and (
|
||||||
obs <= max_obs
|
obs <= max_obs
|
||||||
).all(), "Obs. must be in range [0,{}]".format(max_obs)
|
).all(), f"Obs. must be in range [0,{max_obs}]"
|
||||||
step_i += 1
|
step_i += 1
|
||||||
|
|
||||||
env.close()
|
env.close()
|
||||||
|
@@ -50,7 +50,7 @@ ERROR_TEST_CASES = (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestFilterObservation(object):
|
class TestFilterObservation:
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"observation_keys,filter_keys", FILTER_OBSERVATION_TEST_CASES
|
"observation_keys,filter_keys", FILTER_OBSERVATION_TEST_CASES
|
||||||
)
|
)
|
||||||
|
@@ -36,7 +36,7 @@ class FakeArrayObservationEnvironment(FakeEnvironment):
|
|||||||
self.observation_space = spaces.Box(
|
self.observation_space = spaces.Box(
|
||||||
shape=(2,), low=-1, high=1, dtype=np.float32
|
shape=(2,), low=-1, high=1, dtype=np.float32
|
||||||
)
|
)
|
||||||
super(FakeArrayObservationEnvironment, self).__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class FakeDictObservationEnvironment(FakeEnvironment):
|
class FakeDictObservationEnvironment(FakeEnvironment):
|
||||||
@@ -46,10 +46,10 @@ class FakeDictObservationEnvironment(FakeEnvironment):
|
|||||||
"state": spaces.Box(shape=(2,), low=-1, high=1, dtype=np.float32),
|
"state": spaces.Box(shape=(2,), low=-1, high=1, dtype=np.float32),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
super(FakeDictObservationEnvironment, self).__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class TestPixelObservationWrapper(object):
|
class TestPixelObservationWrapper:
|
||||||
@pytest.mark.parametrize("pixels_only", (True, False))
|
@pytest.mark.parametrize("pixels_only", (True, False))
|
||||||
def test_dict_observation(self, pixels_only):
|
def test_dict_observation(self, pixels_only):
|
||||||
pixel_key = "rgb"
|
pixel_key = "rgb"
|
||||||
|
@@ -24,7 +24,7 @@ def test_record_video_using_default_trigger():
|
|||||||
assert os.path.isdir("videos")
|
assert os.path.isdir("videos")
|
||||||
mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")]
|
mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")]
|
||||||
assert len(mp4_files) == sum(
|
assert len(mp4_files) == sum(
|
||||||
[capped_cubic_video_schedule(i) for i in range(env.episode_id + 1)]
|
capped_cubic_video_schedule(i) for i in range(env.episode_id + 1)
|
||||||
)
|
)
|
||||||
shutil.rmtree("videos")
|
shutil.rmtree("videos")
|
||||||
|
|
||||||
|
@@ -6,14 +6,14 @@ import gym
|
|||||||
from gym.wrappers.monitoring.video_recorder import VideoRecorder
|
from gym.wrappers.monitoring.video_recorder import VideoRecorder
|
||||||
|
|
||||||
|
|
||||||
class BrokenRecordableEnv(object):
|
class BrokenRecordableEnv:
|
||||||
metadata = {"render.modes": [None, "rgb_array"]}
|
metadata = {"render.modes": [None, "rgb_array"]}
|
||||||
|
|
||||||
def render(self, mode=None):
|
def render(self, mode=None):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class UnrecordableEnv(object):
|
class UnrecordableEnv:
|
||||||
metadata = {"render.modes": [None]}
|
metadata = {"render.modes": [None]}
|
||||||
|
|
||||||
def render(self, mode=None):
|
def render(self, mode=None):
|
||||||
|
Reference in New Issue
Block a user