Py36+ code style in tests (#2547)

This commit is contained in:
Ilya Kamen
2022-01-11 18:12:05 +01:00
committed by GitHub
parent 3746741708
commit e9df493243
18 changed files with 53 additions and 64 deletions

View File

@@ -43,7 +43,7 @@ def should_skip_env_spec_for_tests(spec):
and not spec.id.startswith("Seaquest")
)
):
logger.warn("Skipping tests for env {}".format(ep))
logger.warn(f"Skipping tests for env {ep}")
return True
return False

View File

@@ -34,9 +34,7 @@ def test_env(spec):
print("action_samples1=", action_samples1)
print("action_samples2=", action_samples2)
print(
"[{}] action_sample1: {}, action_sample2: {}".format(
i, action_sample1, action_sample2
)
f"[{i}] action_sample1: {action_sample1}, action_sample2: {action_sample2}"
)
raise
@@ -50,23 +48,21 @@ def test_env(spec):
for i, ((o1, r1, d1, i1), (o2, r2, d2, i2)) in enumerate(
zip(step_responses1, step_responses2)
):
assert_equals(o1, o2, "[{}] ".format(i))
assert r1 == r2, "[{}] r1: {}, r2: {}".format(i, r1, r2)
assert d1 == d2, "[{}] d1: {}, d2: {}".format(i, d1, d2)
assert_equals(o1, o2, f"[{i}] ")
assert r1 == r2, f"[{i}] r1: {r1}, r2: {r2}"
assert d1 == d2, f"[{i}] d1: {d1}, d2: {d2}"
# Go returns a Pachi game board in info, which doesn't
# properly check equality. For now, we hack around this by
# just skipping Go.
if spec.id not in ["Go9x9-v0", "Go19x19-v0"]:
assert_equals(i1, i2, "[{}] ".format(i))
assert_equals(i1, i2, f"[{i}] ")
def assert_equals(a, b, prefix=None):
assert type(a) == type(b), "{}Differing types: {} and {}".format(prefix, a, b)
assert type(a) == type(b), f"{prefix}Differing types: {a} and {b}"
if isinstance(a, dict):
assert list(a.keys()) == list(b.keys()), "{}Key sets differ: {} and {}".format(
prefix, a, b
)
assert list(a.keys()) == list(b.keys()), f"{prefix}Key sets differ: {a} and {b}"
for k in a.keys():
v_a = a[k]

View File

@@ -26,24 +26,24 @@ def test_env(spec):
ob_space = env.observation_space
act_space = env.action_space
ob = env.reset()
assert ob_space.contains(ob), "Reset observation: {!r} not in space".format(ob)
assert ob_space.contains(ob), f"Reset observation: {ob!r} not in space"
if isinstance(ob_space, Box):
# Only checking dtypes for Box spaces to avoid iterating through tuple entries
assert (
ob.dtype == ob_space.dtype
), "Reset observation dtype: {}, expected: {}".format(ob.dtype, ob_space.dtype)
), f"Reset observation dtype: {ob.dtype}, expected: {ob_space.dtype}"
a = act_space.sample()
observation, reward, done, _info = env.step(a)
assert ob_space.contains(observation), "Step observation: {!r} not in space".format(
assert ob_space.contains(
observation
)
assert np.isscalar(reward), "{} is not a scalar for {}".format(reward, env)
assert isinstance(done, bool), "Expected {} to be a boolean".format(done)
), f"Step observation: {observation!r} not in space"
assert np.isscalar(reward), f"{reward} is not a scalar for {env}"
assert isinstance(done, bool), f"Expected {done} to be a boolean"
if isinstance(ob_space, Box):
assert (
observation.dtype == ob_space.dtype
), "Step observation dtype: {}, expected: {}".format(ob.dtype, ob_space.dtype)
), f"Step observation dtype: {ob.dtype}, expected: {ob_space.dtype}"
for mode in env.metadata.get("render.modes", []):
env.render(mode=mode)

View File

@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
import gym
from gym import error, envs
from gym.envs import registration
@@ -89,11 +88,9 @@ def test_missing_lookup():
def test_malformed_lookup():
registry = registration.EnvRegistry()
try:
registry.spec(u"“Breakout-v0”")
registry.spec("“Breakout-v0”")
except error.Error as e:
assert "malformed environment ID" in "{}".format(
e
), "Unexpected message: {}".format(e)
assert "malformed environment ID" in f"{e}", f"Unexpected message: {e}"
else:
assert False

View File

@@ -50,8 +50,8 @@ def test_roundtripping(space):
s1p = space.to_jsonable([sample_1_prime])
s2 = space.to_jsonable([sample_2])
s2p = space.to_jsonable([sample_2_prime])
assert s1 == s1p, "Expected {} to equal {}".format(s1, s1p)
assert s2 == s2p, "Expected {} to equal {}".format(s2, s2p)
assert s1 == s1p, f"Expected {s1} to equal {s1p}"
assert s2 == s2p, f"Expected {s2} to equal {s2p}"
@pytest.mark.parametrize(
@@ -85,7 +85,7 @@ def test_roundtripping(space):
def test_equality(space):
space1 = space
space2 = copy.copy(space)
assert space1 == space2, "Expected {} to equal {}".format(space1, space2)
assert space1 == space2, f"Expected {space1} to equal {space2}"
@pytest.mark.parametrize(
@@ -114,7 +114,7 @@ def test_equality(space):
)
def test_inequality(spaces):
space1, space2 = spaces
assert space1 != space2, "Expected {} != {}".format(space1, space2)
assert space1 != space2, f"Expected {space1} != {space2}"
@pytest.mark.parametrize(

View File

@@ -36,18 +36,16 @@ flatdims = [3, 4, 4, 15, 7, 9, 14, 10, 7]
@pytest.mark.parametrize(["space", "flatdim"], zip(spaces, flatdims))
def test_flatdim(space, flatdim):
dim = utils.flatdim(space)
assert dim == flatdim, "Expected {} to equal {}".format(dim, flatdim)
assert dim == flatdim, f"Expected {dim} to equal {flatdim}"
@pytest.mark.parametrize("space", spaces)
def test_flatten_space_boxes(space):
flat_space = utils.flatten_space(space)
assert isinstance(flat_space, Box), "Expected {} to equal {}".format(
type(flat_space), Box
)
assert isinstance(flat_space, Box), f"Expected {type(flat_space)} to equal {Box}"
flatdim = utils.flatdim(space)
(single_dim,) = flat_space.shape
assert single_dim == flatdim, "Expected {} to equal {}".format(single_dim, flatdim)
assert single_dim == flatdim, f"Expected {single_dim} to equal {flatdim}"
@pytest.mark.parametrize("space", spaces)
@@ -56,9 +54,9 @@ def test_flat_space_contains_flat_points(space):
flattened_samples = [utils.flatten(space, sample) for sample in some_samples]
flat_space = utils.flatten_space(space)
for i, flat_sample in enumerate(flattened_samples):
assert flat_sample in flat_space, "Expected sample #{} {} to be in {}".format(
i, flat_sample, flat_space
)
assert (
flat_sample in flat_space
), f"Expected sample #{i} {flat_sample} to be in {flat_space}"
@pytest.mark.parametrize("space", spaces)
@@ -66,7 +64,7 @@ def test_flatten_dim(space):
sample = utils.flatten(space, space.sample())
(single_dim,) = sample.shape
flatdim = utils.flatdim(space)
assert single_dim == flatdim, "Expected {} to equal {}".format(single_dim, flatdim)
assert single_dim == flatdim, f"Expected {single_dim} to equal {flatdim}"
@pytest.mark.parametrize("space", spaces)
@@ -81,7 +79,7 @@ def test_flatten_roundtripping(space):
):
assert compare_nested(
original, roundtripped
), "Expected sample #{} {} to equal {}".format(i, original, roundtripped)
), f"Expected sample #{i} {original} to equal {roundtripped}"
def compare_nested(left, right):
@@ -144,9 +142,7 @@ def test_dtypes(original_space, expected_flattened_dtype):
), "Expected flattened_space to contain flattened_sample"
assert (
flattened_space.dtype == expected_flattened_dtype
), "Expected flattened_space's dtype to equal " "{}".format(
expected_flattened_dtype
)
), f"Expected flattened_space's dtype to equal {expected_flattened_dtype}"
assert flattened_sample.dtype == flattened_space.dtype, (
"Expected flattened_space's dtype to equal " "flattened_sample's dtype "

View File

@@ -1,7 +1,7 @@
from gym.utils.closer import Closer
class Closeable(object):
class Closeable:
close_called = False
def close(self):

View File

@@ -9,7 +9,7 @@ def test_invalid_seeds():
except error.Error:
pass
else:
assert False, "Invalid seed {} passed validation".format(seed)
assert False, f"Invalid seed {seed} passed validation"
def test_valid_seeds():

View File

@@ -42,7 +42,7 @@ def test_concatenate(space):
assert_nested_equal(lhs[key], rhs_T_key, n)
else:
raise TypeError("Got unknown type `{0}`.".format(type(lhs)))
raise TypeError(f"Got unknown type `{type(lhs)}`.")
samples = [space.sample() for _ in range(8)]
array = create_empty_array(space, n=8)
@@ -76,7 +76,7 @@ def test_create_empty_array(space, n):
assert_nested_type(arr[key], space.spaces[key], n)
else:
raise TypeError("Got unknown type `{0}`.".format(type(arr)))
raise TypeError(f"Got unknown type `{type(arr)}`.")
array = create_empty_array(space, n=n, fn=np.empty)
assert_nested_type(array, space, n=n)
@@ -107,7 +107,7 @@ def test_create_empty_array_zeros(space, n):
assert_nested_type(arr[key], space.spaces[key], n)
else:
raise TypeError("Got unknown type `{0}`.".format(type(arr)))
raise TypeError(f"Got unknown type `{type(arr)}`.")
array = create_empty_array(space, n=n, fn=np.zeros)
assert_nested_type(array, space, n=n)
@@ -137,7 +137,7 @@ def test_create_empty_array_none_shape_ones(space):
assert_nested_type(arr[key], space.spaces[key])
else:
raise TypeError("Got unknown type `{0}`.".format(type(arr)))
raise TypeError(f"Got unknown type `{type(arr)}`.")
array = create_empty_array(space, n=None, fn=np.ones)
assert_nested_type(array, space)

View File

@@ -69,7 +69,7 @@ def test_create_shared_memory(space, expected_type, n, ctx):
assert type(lhs[0]) == type(rhs[0]) # noqa: E721
else:
raise TypeError("Got unknown type `{0}`.".format(type(lhs)))
raise TypeError(f"Got unknown type `{type(lhs)}`.")
ctx = mp if (ctx is None) else mp.get_context(ctx)
shared_memory = create_shared_memory(space, n=n, ctx=ctx)
@@ -105,7 +105,7 @@ def test_write_to_shared_memory(space):
assert np.all(np.array(lhs[:]) == np.stack(rhs, axis=0).flatten())
else:
raise TypeError("Got unknown type `{0}`.".format(type(lhs)))
raise TypeError(f"Got unknown type `{type(lhs)}`.")
def write(i, shared_memory, sample):
write_to_shared_memory(i, sample, shared_memory, space)
@@ -152,7 +152,7 @@ def test_read_from_shared_memory(space):
assert np.all(lhs == np.stack(rhs, axis=0))
else:
raise TypeError("Got unknown type `{0}`".format(type(space)))
raise TypeError(f"Got unknown type `{type(space)}`")
def write(i, shared_memory, sample):
write_to_shared_memory(i, sample, shared_memory, space)

View File

@@ -50,7 +50,7 @@ HEIGHT, WIDTH = 64, 64
class UnittestSlowEnv(gym.Env):
def __init__(self, slow_reset=0.3):
super(UnittestSlowEnv, self).__init__()
super().__init__()
self.slow_reset = slow_reset
self.observation_space = Box(
low=0, high=255, shape=(HEIGHT, WIDTH, 3), dtype=np.uint8
@@ -91,7 +91,7 @@ custom_spaces = [
class CustomSpaceEnv(gym.Env):
def __init__(self):
super(CustomSpaceEnv, self).__init__()
super().__init__()
self.observation_space = CustomSpace()
self.action_space = CustomSpace()
@@ -100,7 +100,7 @@ class CustomSpaceEnv(gym.Env):
return "reset"
def step(self, action):
observation = "step({0:s})".format(action)
observation = f"step({action:s})"
reward, done = 0.0, False
return observation, reward, done, {}

View File

@@ -59,7 +59,7 @@ OBSERVATION_SPACES = (
)
class TestFlattenEnvironment(object):
class TestFlattenEnvironment:
@pytest.mark.parametrize("observation_space, ordered_values", OBSERVATION_SPACES)
def test_flattened_environment(self, observation_space, ordered_values):
"""

View File

@@ -98,7 +98,7 @@ NESTED_DICT_TEST_CASES = (
)
class TestNestedDictWrapper(object):
class TestNestedDictWrapper:
@pytest.mark.parametrize("observation_space, flat_shape", NESTED_DICT_TEST_CASES)
def test_nested_dicts_size(self, observation_space, flat_shape):
env = FakeEnvironment(observation_space=observation_space)

View File

@@ -77,13 +77,13 @@ def test_atari_preprocessing_scale(env_fn):
max_obs = 1 if scaled else 255
assert (0 <= obs).all() and (
obs <= max_obs
).all(), "Obs. must be in range [0,{}]".format(max_obs)
).all(), f"Obs. must be in range [0,{max_obs}]"
while not done or step_i <= max_test_steps:
obs, _, done, _ = env.step(env.action_space.sample())
obs = obs.flatten()
assert (0 <= obs).all() and (
obs <= max_obs
).all(), "Obs. must be in range [0,{}]".format(max_obs)
).all(), f"Obs. must be in range [0,{max_obs}]"
step_i += 1
env.close()

View File

@@ -50,7 +50,7 @@ ERROR_TEST_CASES = (
)
class TestFilterObservation(object):
class TestFilterObservation:
@pytest.mark.parametrize(
"observation_keys,filter_keys", FILTER_OBSERVATION_TEST_CASES
)

View File

@@ -36,7 +36,7 @@ class FakeArrayObservationEnvironment(FakeEnvironment):
self.observation_space = spaces.Box(
shape=(2,), low=-1, high=1, dtype=np.float32
)
super(FakeArrayObservationEnvironment, self).__init__(*args, **kwargs)
super().__init__(*args, **kwargs)
class FakeDictObservationEnvironment(FakeEnvironment):
@@ -46,10 +46,10 @@ class FakeDictObservationEnvironment(FakeEnvironment):
"state": spaces.Box(shape=(2,), low=-1, high=1, dtype=np.float32),
}
)
super(FakeDictObservationEnvironment, self).__init__(*args, **kwargs)
super().__init__(*args, **kwargs)
class TestPixelObservationWrapper(object):
class TestPixelObservationWrapper:
@pytest.mark.parametrize("pixels_only", (True, False))
def test_dict_observation(self, pixels_only):
pixel_key = "rgb"

View File

@@ -24,7 +24,7 @@ def test_record_video_using_default_trigger():
assert os.path.isdir("videos")
mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")]
assert len(mp4_files) == sum(
[capped_cubic_video_schedule(i) for i in range(env.episode_id + 1)]
capped_cubic_video_schedule(i) for i in range(env.episode_id + 1)
)
shutil.rmtree("videos")

View File

@@ -6,14 +6,14 @@ import gym
from gym.wrappers.monitoring.video_recorder import VideoRecorder
class BrokenRecordableEnv(object):
class BrokenRecordableEnv:
metadata = {"render.modes": [None, "rgb_array"]}
def render(self, mode=None):
pass
class UnrecordableEnv(object):
class UnrecordableEnv:
metadata = {"render.modes": [None]}
def render(self, mode=None):