diff --git a/bin/render.py b/bin/render.py index 257269ec7..836ee34d3 100755 --- a/bin/render.py +++ b/bin/render.py @@ -3,9 +3,7 @@ import argparse import gym -parser = argparse.ArgumentParser( - description="Renders a Gym environment for quick inspection." -) +parser = argparse.ArgumentParser(description="Renders a Gym environment for quick inspection.") parser.add_argument( "env_id", type=str, diff --git a/examples/agents/cem.py b/examples/agents/cem.py index e944f118e..2b646e06a 100644 --- a/examples/agents/cem.py +++ b/examples/agents/cem.py @@ -35,12 +35,7 @@ def cem(f, th_mean, batch_size, n_iter, elite_frac, initial_std=1.0): th_std = np.ones_like(th_mean) * initial_std for _ in range(n_iter): - ths = np.array( - [ - th_mean + dth - for dth in th_std[None, :] * np.random.randn(batch_size, th_mean.size) - ] - ) + ths = np.array([th_mean + dth for dth in th_std[None, :] * np.random.randn(batch_size, th_mean.size)]) ys = np.array([f(th) for th in ths]) elite_inds = ys.argsort()[::-1][:n_elite] elite_ths = ths[elite_inds] @@ -101,9 +96,7 @@ if __name__ == "__main__": return rew # Train the agent, and snapshot each stage - for (i, iterdata) in enumerate( - cem(noisy_evaluation, np.zeros(env.observation_space.shape[0] + 1), **params) - ): + for (i, iterdata) in enumerate(cem(noisy_evaluation, np.zeros(env.observation_space.shape[0] + 1), **params)): print("Iteration %2i. Episode mean reward: %7.3f" % (i, iterdata["y_mean"])) agent = BinaryActionLinearPolicy(iterdata["theta_mean"]) if args.display: diff --git a/examples/agents/random_agent.py b/examples/agents/random_agent.py index 32d703001..ee7c6b052 100644 --- a/examples/agents/random_agent.py +++ b/examples/agents/random_agent.py @@ -17,9 +17,7 @@ class RandomAgent(object): if __name__ == "__main__": parser = argparse.ArgumentParser(description=None) - parser.add_argument( - "env_id", nargs="?", default="CartPole-v0", help="Select the environment to run" - ) + parser.add_argument("env_id", nargs="?", default="CartPole-v0", help="Select the environment to run") args = parser.parse_args() # You can set the level to logger.DEBUG or logger.WARN if you diff --git a/gym/core.py b/gym/core.py index 0d264c770..3ed589669 100644 --- a/gym/core.py +++ b/gym/core.py @@ -173,16 +173,10 @@ class GoalEnv(Env): def reset(self): # Enforce that each GoalEnv uses a Goal-compatible observation space. if not isinstance(self.observation_space, gym.spaces.Dict): - raise error.Error( - "GoalEnv requires an observation space of type gym.spaces.Dict" - ) + raise error.Error("GoalEnv requires an observation space of type gym.spaces.Dict") for key in ["observation", "achieved_goal", "desired_goal"]: if key not in self.observation_space.spaces: - raise error.Error( - 'GoalEnv requires the "{}" key to be part of the observation dictionary.'.format( - key - ) - ) + raise error.Error('GoalEnv requires the "{}" key to be part of the observation dictionary.'.format(key)) def compute_reward(self, achieved_goal, desired_goal, info): """Compute the step reward. This externalizes the reward function and makes @@ -227,9 +221,7 @@ class Wrapper(Env): def __getattr__(self, name): if name.startswith("_"): - raise AttributeError( - "attempted to get missing private attribute '{}'".format(name) - ) + raise AttributeError("attempted to get missing private attribute '{}'".format(name)) return getattr(self.env, name) @property diff --git a/gym/envs/__init__.py b/gym/envs/__init__.py index 2e2f1f0ba..74f5bfabe 100644 --- a/gym/envs/__init__.py +++ b/gym/envs/__init__.py @@ -422,9 +422,7 @@ for reward_type in ["sparse", "dense"]: register( id="HandManipulateBlockRotateParallel{}-v0".format(suffix), entry_point="gym.envs.robotics:HandBlockEnv", - kwargs=_merge( - {"target_position": "ignore", "target_rotation": "parallel"}, kwargs - ), + kwargs=_merge({"target_position": "ignore", "target_rotation": "parallel"}, kwargs), max_episode_steps=100, ) diff --git a/gym/envs/algorithmic/algorithmic_env.py b/gym/envs/algorithmic/algorithmic_env.py index ef33db90f..5d00071e7 100644 --- a/gym/envs/algorithmic/algorithmic_env.py +++ b/gym/envs/algorithmic/algorithmic_env.py @@ -73,9 +73,7 @@ class AlgorithmicEnv(Env): # 1. Move read head left or right (or up/down) # 2. Write or not # 3. Which character to write. (Ignored if should_write=0) - self.action_space = Tuple( - [Discrete(len(self.MOVEMENTS)), Discrete(2), Discrete(self.base)] - ) + self.action_space = Tuple([Discrete(len(self.MOVEMENTS)), Discrete(2), Discrete(self.base)]) # Can see just what is on the input tape (one of n characters, or # nothing) self.observation_space = Discrete(self.base + 1) @@ -147,10 +145,7 @@ class AlgorithmicEnv(Env): move = self.MOVEMENTS[inp_act] outfile.write("Action : Tuple(move over input: %s,\n" % move) out_act = out_act == 1 - outfile.write( - " write to the output tape: %s,\n" - % out_act - ) + outfile.write(" write to the output tape: %s,\n" % out_act) outfile.write(" prediction: %s)\n" % pred_str) else: outfile.write("\n" * 5) @@ -276,9 +271,7 @@ class TapeAlgorithmicEnv(AlgorithmicEnv): x_str = "Observation Tape : " for i in range(-2, self.input_width + 2): if i == x: - x_str += colorize( - self._get_str_obs(np.array([i])), "green", highlight=True - ) + x_str += colorize(self._get_str_obs(np.array([i])), "green", highlight=True) else: x_str += self._get_str_obs(np.array([i])) x_str += "\n" @@ -311,10 +304,7 @@ class GridAlgorithmicEnv(AlgorithmicEnv): self.read_head_position = x, y def generate_input_data(self, size): - return [ - [self.np_random.randint(self.base) for _ in range(self.rows)] - for __ in range(size) - ] + return [[self.np_random.randint(self.base) for _ in range(self.rows)] for __ in range(size)] def _get_obs(self, pos=None): if pos is None: @@ -336,9 +326,7 @@ class GridAlgorithmicEnv(AlgorithmicEnv): x_str += " " * len(label) for i in range(-2, self.input_width + 2): if i == x[0] and j == x[1]: - x_str += colorize( - self._get_str_obs((i, j)), "green", highlight=True - ) + x_str += colorize(self._get_str_obs((i, j)), "green", highlight=True) else: x_str += self._get_str_obs((i, j)) x_str += "\n" diff --git a/gym/envs/algorithmic/tests/test_algorithmic.py b/gym/envs/algorithmic/tests/test_algorithmic.py index 7ee33db83..ee8c6ad80 100644 --- a/gym/envs/algorithmic/tests/test_algorithmic.py +++ b/gym/envs/algorithmic/tests/test_algorithmic.py @@ -10,12 +10,8 @@ ALL_ENVS = [ alg.reverse.ReverseEnv, alg.reversed_addition.ReversedAdditionEnv, ] -ALL_TAPE_ENVS = [ - env for env in ALL_ENVS if issubclass(env, alg.algorithmic_env.TapeAlgorithmicEnv) -] -ALL_GRID_ENVS = [ - env for env in ALL_ENVS if issubclass(env, alg.algorithmic_env.GridAlgorithmicEnv) -] +ALL_TAPE_ENVS = [env for env in ALL_ENVS if issubclass(env, alg.algorithmic_env.TapeAlgorithmicEnv)] +ALL_GRID_ENVS = [env for env in ALL_ENVS if issubclass(env, alg.algorithmic_env.GridAlgorithmicEnv)] def imprint(env, input_arr): @@ -92,10 +88,7 @@ class TestAlgorithmicEnvInteractions(unittest.TestCase): def test_grid_naviation(self): env = alg.reversed_addition.ReversedAdditionEnv(rows=2, base=6) - N, S, E, W = [ - env._movement_idx(named_dir) - for named_dir in ["up", "down", "right", "left"] - ] + N, S, E, W = [env._movement_idx(named_dir) for named_dir in ["up", "down", "right", "left"]] # Corresponds to a grid that looks like... # 0 1 2 # 3 4 5 @@ -204,9 +197,7 @@ class TestTargets(unittest.TestCase): def test_repeat_copy_target(self): env = alg.repeat_copy.RepeatCopyEnv() - self.assertEqual( - env.target_from_input_data([0, 1, 2]), [0, 1, 2, 2, 1, 0, 0, 1, 2] - ) + self.assertEqual(env.target_from_input_data([0, 1, 2]), [0, 1, 2, 2, 1, 0, 0, 1, 2]) class TestInputGeneration(unittest.TestCase): diff --git a/gym/envs/atari/atari_env.py b/gym/envs/atari/atari_env.py index 9e3e91b5c..cac994f08 100644 --- a/gym/envs/atari/atari_env.py +++ b/gym/envs/atari/atari_env.py @@ -9,8 +9,7 @@ try: import atari_py except ImportError as e: raise error.DependencyNotInstalled( - "{}. (HINT: you can install Atari dependencies by running " - "'pip install gym[atari]'.)".format(e) + "{}. (HINT: you can install Atari dependencies by running " "'pip install gym[atari]'.)".format(e) ) @@ -64,35 +63,23 @@ class AtariEnv(gym.Env, utils.EzPickle): # Tune (or disable) ALE's action repeat: # https://github.com/openai/gym/issues/349 - assert isinstance( - repeat_action_probability, (float, int) - ), "Invalid repeat_action_probability: {!r}".format(repeat_action_probability) - self.ale.setFloat( - "repeat_action_probability".encode("utf-8"), repeat_action_probability + assert isinstance(repeat_action_probability, (float, int)), "Invalid repeat_action_probability: {!r}".format( + repeat_action_probability ) + self.ale.setFloat("repeat_action_probability".encode("utf-8"), repeat_action_probability) self.seed() - self._action_set = ( - self.ale.getLegalActionSet() - if full_action_space - else self.ale.getMinimalActionSet() - ) + self._action_set = self.ale.getLegalActionSet() if full_action_space else self.ale.getMinimalActionSet() self.action_space = spaces.Discrete(len(self._action_set)) (screen_width, screen_height) = self.ale.getScreenDims() if self._obs_type == "ram": - self.observation_space = spaces.Box( - low=0, high=255, dtype=np.uint8, shape=(128,) - ) + self.observation_space = spaces.Box(low=0, high=255, dtype=np.uint8, shape=(128,)) elif self._obs_type == "image": - self.observation_space = spaces.Box( - low=0, high=255, shape=(screen_height, screen_width, 3), dtype=np.uint8 - ) + self.observation_space = spaces.Box(low=0, high=255, shape=(screen_height, screen_width, 3), dtype=np.uint8) else: - raise error.Error( - "Unrecognized observation type: {}".format(self._obs_type) - ) + raise error.Error("Unrecognized observation type: {}".format(self._obs_type)) def seed(self, seed=None): self.np_random, seed1 = seeding.np_random(seed) @@ -107,9 +94,9 @@ class AtariEnv(gym.Env, utils.EzPickle): if self.game_mode is not None: modes = self.ale.getAvailableModes() - assert self.game_mode in modes, ( - 'Invalid game mode "{}" for game {}.\nAvailable modes are: {}' - ).format(self.game_mode, self.game, modes) + assert self.game_mode in modes, ('Invalid game mode "{}" for game {}.\nAvailable modes are: {}').format( + self.game_mode, self.game, modes + ) self.ale.setMode(self.game_mode) if self.game_difficulty is not None: diff --git a/gym/envs/box2d/bipedal_walker.py b/gym/envs/box2d/bipedal_walker.py index cea2a8f1f..6b41fbdd2 100644 --- a/gym/envs/box2d/bipedal_walker.py +++ b/gym/envs/box2d/bipedal_walker.py @@ -100,10 +100,7 @@ class ContactDetector(contactListener): self.env = env def BeginContact(self, contact): - if ( - self.env.hull == contact.fixtureA.body - or self.env.hull == contact.fixtureB.body - ): + if self.env.hull == contact.fixtureA.body or self.env.hull == contact.fixtureB.body: self.env.game_over = True for leg in [self.env.legs[1], self.env.legs[3]]: if leg in [contact.fixtureA.body, contact.fixtureB.body]: @@ -202,9 +199,7 @@ class BipedalWalker(gym.Env, EzPickle): t.color1, t.color2 = (1, 1, 1), (0.6, 0.6, 0.6) self.terrain.append(t) - self.fd_polygon.shape.vertices = [ - (p[0] + TERRAIN_STEP * counter, p[1]) for p in poly - ] + self.fd_polygon.shape.vertices = [(p[0] + TERRAIN_STEP * counter, p[1]) for p in poly] t = self.world.CreateStaticBody(fixtures=self.fd_polygon) t.color1, t.color2 = (1, 1, 1), (0.6, 0.6, 0.6) self.terrain.append(t) @@ -301,12 +296,8 @@ class BipedalWalker(gym.Env, EzPickle): y = VIEWPORT_H / SCALE * 3 / 4 poly = [ ( - x - + 15 * TERRAIN_STEP * math.sin(3.14 * 2 * a / 5) - + self.np_random.uniform(0, 5 * TERRAIN_STEP), - y - + 5 * TERRAIN_STEP * math.cos(3.14 * 2 * a / 5) - + self.np_random.uniform(0, 5 * TERRAIN_STEP), + x + 15 * TERRAIN_STEP * math.sin(3.14 * 2 * a / 5) + self.np_random.uniform(0, 5 * TERRAIN_STEP), + y + 5 * TERRAIN_STEP * math.cos(3.14 * 2 * a / 5) + self.np_random.uniform(0, 5 * TERRAIN_STEP), ) for a in range(5) ] @@ -331,14 +322,10 @@ class BipedalWalker(gym.Env, EzPickle): init_x = TERRAIN_STEP * TERRAIN_STARTPAD / 2 init_y = TERRAIN_HEIGHT + 2 * LEG_H - self.hull = self.world.CreateDynamicBody( - position=(init_x, init_y), fixtures=HULL_FD - ) + self.hull = self.world.CreateDynamicBody(position=(init_x, init_y), fixtures=HULL_FD) self.hull.color1 = (0.5, 0.4, 0.9) self.hull.color2 = (0.3, 0.3, 0.5) - self.hull.ApplyForceToCenter( - (self.np_random.uniform(-INITIAL_RANDOM, INITIAL_RANDOM), 0), True - ) + self.hull.ApplyForceToCenter((self.np_random.uniform(-INITIAL_RANDOM, INITIAL_RANDOM), 0), True) self.legs = [] self.joints = [] @@ -412,21 +399,13 @@ class BipedalWalker(gym.Env, EzPickle): self.joints[3].motorSpeed = float(SPEED_KNEE * np.clip(action[3], -1, 1)) else: self.joints[0].motorSpeed = float(SPEED_HIP * np.sign(action[0])) - self.joints[0].maxMotorTorque = float( - MOTORS_TORQUE * np.clip(np.abs(action[0]), 0, 1) - ) + self.joints[0].maxMotorTorque = float(MOTORS_TORQUE * np.clip(np.abs(action[0]), 0, 1)) self.joints[1].motorSpeed = float(SPEED_KNEE * np.sign(action[1])) - self.joints[1].maxMotorTorque = float( - MOTORS_TORQUE * np.clip(np.abs(action[1]), 0, 1) - ) + self.joints[1].maxMotorTorque = float(MOTORS_TORQUE * np.clip(np.abs(action[1]), 0, 1)) self.joints[2].motorSpeed = float(SPEED_HIP * np.sign(action[2])) - self.joints[2].maxMotorTorque = float( - MOTORS_TORQUE * np.clip(np.abs(action[2]), 0, 1) - ) + self.joints[2].maxMotorTorque = float(MOTORS_TORQUE * np.clip(np.abs(action[2]), 0, 1)) self.joints[3].motorSpeed = float(SPEED_KNEE * np.sign(action[3])) - self.joints[3].maxMotorTorque = float( - MOTORS_TORQUE * np.clip(np.abs(action[3]), 0, 1) - ) + self.joints[3].maxMotorTorque = float(MOTORS_TORQUE * np.clip(np.abs(action[3]), 0, 1)) self.world.Step(1.0 / FPS, 6 * 30, 2 * 30) @@ -465,12 +444,8 @@ class BipedalWalker(gym.Env, EzPickle): self.scroll = pos.x - VIEWPORT_W / SCALE / 5 - shaping = ( - 130 * pos[0] / SCALE - ) # moving forward is a way to receive reward (normalized to get 300 on completion) - shaping -= 5.0 * abs( - state[0] - ) # keep head straight, other than that and falling, any behavior is unpunished + shaping = 130 * pos[0] / SCALE # moving forward is a way to receive reward (normalized to get 300 on completion) + shaping -= 5.0 * abs(state[0]) # keep head straight, other than that and falling, any behavior is unpunished reward = 0 if self.prev_shaping is not None: @@ -494,9 +469,7 @@ class BipedalWalker(gym.Env, EzPickle): if self.viewer is None: self.viewer = rendering.Viewer(VIEWPORT_W, VIEWPORT_H) - self.viewer.set_bounds( - self.scroll, VIEWPORT_W / SCALE + self.scroll, 0, VIEWPORT_H / SCALE - ) + self.viewer.set_bounds(self.scroll, VIEWPORT_W / SCALE + self.scroll, 0, VIEWPORT_H / SCALE) self.viewer.draw_polygon( [ @@ -512,9 +485,7 @@ class BipedalWalker(gym.Env, EzPickle): continue if x1 > self.scroll / 2 + VIEWPORT_W / SCALE: continue - self.viewer.draw_polygon( - [(p[0] + self.scroll / 2, p[1]) for p in poly], color=(1, 1, 1) - ) + self.viewer.draw_polygon([(p[0] + self.scroll / 2, p[1]) for p in poly], color=(1, 1, 1)) for poly, color in self.terrain_poly: if poly[1][0] < self.scroll: continue @@ -525,11 +496,7 @@ class BipedalWalker(gym.Env, EzPickle): self.lidar_render = (self.lidar_render + 1) % 100 i = self.lidar_render if i < 2 * len(self.lidar): - l = ( - self.lidar[i] - if i < len(self.lidar) - else self.lidar[len(self.lidar) - i - 1] - ) + l = self.lidar[i] if i < len(self.lidar) else self.lidar[len(self.lidar) - i - 1] self.viewer.draw_polyline([l.p1, l.p2], color=(1, 0, 0), linewidth=1) for obj in self.drawlist: @@ -537,12 +504,8 @@ class BipedalWalker(gym.Env, EzPickle): trans = f.body.transform if type(f.shape) is circleShape: t = rendering.Transform(translation=trans * f.shape.pos) - self.viewer.draw_circle( - f.shape.radius, 30, color=obj.color1 - ).add_attr(t) - self.viewer.draw_circle( - f.shape.radius, 30, color=obj.color2, filled=False, linewidth=2 - ).add_attr(t) + self.viewer.draw_circle(f.shape.radius, 30, color=obj.color1).add_attr(t) + self.viewer.draw_circle(f.shape.radius, 30, color=obj.color2, filled=False, linewidth=2).add_attr(t) else: path = [trans * v for v in f.shape.vertices] self.viewer.draw_polygon(path, color=obj.color1) @@ -552,9 +515,7 @@ class BipedalWalker(gym.Env, EzPickle): flagy1 = TERRAIN_HEIGHT flagy2 = flagy1 + 50 / SCALE x = TERRAIN_STEP * 3 - self.viewer.draw_polyline( - [(x, flagy1), (x, flagy2)], color=(0, 0, 0), linewidth=2 - ) + self.viewer.draw_polyline([(x, flagy1), (x, flagy2)], color=(0, 0, 0), linewidth=2) f = [ (x, flagy2), (x, flagy2 - 10 / SCALE), diff --git a/gym/envs/box2d/car_dynamics.py b/gym/envs/box2d/car_dynamics.py index 756e1e179..65aa31dc9 100644 --- a/gym/envs/box2d/car_dynamics.py +++ b/gym/envs/box2d/car_dynamics.py @@ -23,9 +23,7 @@ from Box2D.b2 import ( SIZE = 0.02 ENGINE_POWER = 100000000 * SIZE * SIZE WHEEL_MOMENT_OF_INERTIA = 4000 * SIZE * SIZE -FRICTION_LIMIT = ( - 1000000 * SIZE * SIZE -) # friction ~= mass ~= size^2 (calculated implicitly using density) +FRICTION_LIMIT = 1000000 * SIZE * SIZE # friction ~= mass ~= size^2 (calculated implicitly using density) WHEEL_R = 27 WHEEL_W = 14 WHEELPOS = [(-55, +80), (+55, +80), (-55, -82), (+55, -82)] @@ -55,27 +53,19 @@ class Car: angle=init_angle, fixtures=[ fixtureDef( - shape=polygonShape( - vertices=[(x * SIZE, y * SIZE) for x, y in HULL_POLY1] - ), + shape=polygonShape(vertices=[(x * SIZE, y * SIZE) for x, y in HULL_POLY1]), density=1.0, ), fixtureDef( - shape=polygonShape( - vertices=[(x * SIZE, y * SIZE) for x, y in HULL_POLY2] - ), + shape=polygonShape(vertices=[(x * SIZE, y * SIZE) for x, y in HULL_POLY2]), density=1.0, ), fixtureDef( - shape=polygonShape( - vertices=[(x * SIZE, y * SIZE) for x, y in HULL_POLY3] - ), + shape=polygonShape(vertices=[(x * SIZE, y * SIZE) for x, y in HULL_POLY3]), density=1.0, ), fixtureDef( - shape=polygonShape( - vertices=[(x * SIZE, y * SIZE) for x, y in HULL_POLY4] - ), + shape=polygonShape(vertices=[(x * SIZE, y * SIZE) for x, y in HULL_POLY4]), density=1.0, ), ], @@ -95,12 +85,7 @@ class Car: position=(init_x + wx * SIZE, init_y + wy * SIZE), angle=init_angle, fixtures=fixtureDef( - shape=polygonShape( - vertices=[ - (x * front_k * SIZE, y * front_k * SIZE) - for x, y in WHEEL_POLY - ] - ), + shape=polygonShape(vertices=[(x * front_k * SIZE, y * front_k * SIZE) for x, y in WHEEL_POLY]), density=0.1, categoryBits=0x0020, maskBits=0x001, @@ -175,9 +160,7 @@ class Car: grass = True friction_limit = FRICTION_LIMIT * 0.6 # Grass friction if no tile for tile in w.tiles: - friction_limit = max( - friction_limit, FRICTION_LIMIT * tile.road_friction - ) + friction_limit = max(friction_limit, FRICTION_LIMIT * tile.road_friction) grass = False # Force @@ -192,13 +175,7 @@ class Car: # domega = dt*W/WHEEL_MOMENT_OF_INERTIA/w.omega # add small coef not to divide by zero - w.omega += ( - dt - * ENGINE_POWER - * w.gas - / WHEEL_MOMENT_OF_INERTIA - / (abs(w.omega) + 5.0) - ) + w.omega += dt * ENGINE_POWER * w.gas / WHEEL_MOMENT_OF_INERTIA / (abs(w.omega) + 5.0) self.fuel_spent += dt * ENGINE_POWER * w.gas if w.brake >= 0.9: @@ -226,18 +203,12 @@ class Car: # Skid trace if abs(force) > 2.0 * friction_limit: - if ( - w.skid_particle - and w.skid_particle.grass == grass - and len(w.skid_particle.poly) < 30 - ): + if w.skid_particle and w.skid_particle.grass == grass and len(w.skid_particle.poly) < 30: w.skid_particle.poly.append((w.position[0], w.position[1])) elif w.skid_start is None: w.skid_start = w.position else: - w.skid_particle = self._create_particle( - w.skid_start, w.position, grass - ) + w.skid_particle = self._create_particle(w.skid_start, w.position, grass) w.skid_start = None else: w.skid_start = None diff --git a/gym/envs/box2d/car_racing.py b/gym/envs/box2d/car_racing.py index 1fd52043e..a055ba4f3 100644 --- a/gym/envs/box2d/car_racing.py +++ b/gym/envs/box2d/car_racing.py @@ -132,18 +132,14 @@ class CarRacing(gym.Env, EzPickle): self.reward = 0.0 self.prev_reward = 0.0 self.verbose = verbose - self.fd_tile = fixtureDef( - shape=polygonShape(vertices=[(0, 0), (1, 0), (1, -1), (0, -1)]) - ) + self.fd_tile = fixtureDef(shape=polygonShape(vertices=[(0, 0), (1, 0), (1, -1), (0, -1)])) self.action_space = spaces.Box( np.array([-1, 0, 0]).astype(np.float32), np.array([+1, +1, +1]).astype(np.float32), ) # steer, gas, brake - self.observation_space = spaces.Box( - low=0, high=255, shape=(STATE_H, STATE_W, 3), dtype=np.uint8 - ) + self.observation_space = spaces.Box(low=0, high=255, shape=(STATE_H, STATE_W, 3), dtype=np.uint8) def seed(self, seed=None): self.np_random, seed = seeding.np_random(seed) @@ -246,9 +242,7 @@ class CarRacing(gym.Env, EzPickle): i -= 1 if i == 0: return False # Failed - pass_through_start = ( - track[i][0] > self.start_alpha and track[i - 1][0] <= self.start_alpha - ) + pass_through_start = track[i][0] > self.start_alpha and track[i - 1][0] <= self.start_alpha if pass_through_start and i2 == -1: i2 = i elif pass_through_start and i1 == -1: @@ -266,8 +260,7 @@ class CarRacing(gym.Env, EzPickle): first_perp_y = math.sin(first_beta) # Length of perpendicular jump to put together head and tail well_glued_together = np.sqrt( - np.square(first_perp_x * (track[0][2] - track[-1][2])) - + np.square(first_perp_y * (track[0][3] - track[-1][3])) + np.square(first_perp_x * (track[0][2] - track[-1][2])) + np.square(first_perp_y * (track[0][3] - track[-1][3])) ) if well_glued_together > TRACK_DETAIL_STEP: return False @@ -337,9 +330,7 @@ class CarRacing(gym.Env, EzPickle): x2 + side * (TRACK_WIDTH + BORDER) * math.cos(beta2), y2 + side * (TRACK_WIDTH + BORDER) * math.sin(beta2), ) - self.road_poly.append( - ([b1_l, b1_r, b2_r, b2_l], (1, 1, 1) if i % 2 == 0 else (1, 0, 0)) - ) + self.road_poly.append(([b1_l, b1_r, b2_r, b2_l], (1, 1, 1) if i % 2 == 0 else (1, 0, 0))) self.track = track return True @@ -356,10 +347,7 @@ class CarRacing(gym.Env, EzPickle): if success: break if self.verbose == 1: - print( - "retry to generate track (normal if there are not many" - "instances of this message)" - ) + print("retry to generate track (normal if there are not many" "instances of this message)") self.car = Car(self.world, *self.track[0][1:4]) return self.step(None)[0] @@ -424,10 +412,8 @@ class CarRacing(gym.Env, EzPickle): angle = math.atan2(vel[0], vel[1]) self.transform.set_scale(zoom, zoom) self.transform.set_translation( - WINDOW_W / 2 - - (scroll_x * zoom * math.cos(angle) - scroll_y * zoom * math.sin(angle)), - WINDOW_H / 4 - - (scroll_x * zoom * math.sin(angle) + scroll_y * zoom * math.cos(angle)), + WINDOW_W / 2 - (scroll_x * zoom * math.cos(angle) - scroll_y * zoom * math.sin(angle)), + WINDOW_H / 4 - (scroll_x * zoom * math.sin(angle) + scroll_y * zoom * math.cos(angle)), ) self.transform.set_rotation(angle) @@ -449,9 +435,7 @@ class CarRacing(gym.Env, EzPickle): else: pixel_scale = 1 if hasattr(win.context, "_nscontext"): - pixel_scale = ( - win.context._nscontext.view().backingScaleFactor() - ) # pylint: disable=protected-access + pixel_scale = win.context._nscontext.view().backingScaleFactor() # pylint: disable=protected-access VP_W = int(pixel_scale * WINDOW_W) VP_H = int(pixel_scale * WINDOW_H) @@ -468,9 +452,7 @@ class CarRacing(gym.Env, EzPickle): win.flip() return self.viewer.isopen - image_data = ( - pyglet.image.get_buffer_manager().get_color_buffer().get_image_data() - ) + image_data = pyglet.image.get_buffer_manager().get_color_buffer().get_image_data() arr = np.fromstring(image_data.get_data(), dtype=np.uint8, sep="") arr = arr.reshape(VP_H, VP_W, 4) arr = arr[::-1, :, 0:3] @@ -525,9 +507,7 @@ class CarRacing(gym.Env, EzPickle): for p in poly: polygons_.extend([p[0], p[1], 0]) - vl = pyglet.graphics.vertex_list( - len(polygons_) // 3, ("v3f", polygons_), ("c4f", colors) # gl.GL_QUADS, - ) + vl = pyglet.graphics.vertex_list(len(polygons_) // 3, ("v3f", polygons_), ("c4f", colors)) # gl.GL_QUADS, vl.draw(gl.GL_QUADS) vl.delete() @@ -575,10 +555,7 @@ class CarRacing(gym.Env, EzPickle): ] ) - true_speed = np.sqrt( - np.square(self.car.hull.linearVelocity[0]) - + np.square(self.car.hull.linearVelocity[1]) - ) + true_speed = np.sqrt(np.square(self.car.hull.linearVelocity[0]) + np.square(self.car.hull.linearVelocity[1])) vertical_ind(5, 0.02 * true_speed, (1, 1, 1)) vertical_ind(7, 0.01 * self.car.wheels[0].omega, (0.0, 0, 1)) # ABS sensors @@ -587,9 +564,7 @@ class CarRacing(gym.Env, EzPickle): vertical_ind(10, 0.01 * self.car.wheels[3].omega, (0.2, 0, 1)) horiz_ind(20, -10.0 * self.car.wheels[0].joint.angle, (0, 1, 0)) horiz_ind(30, -0.8 * self.car.hull.angularVelocity, (1, 0, 0)) - vl = pyglet.graphics.vertex_list( - len(polygons) // 3, ("v3f", polygons), ("c4f", colors) # gl.GL_QUADS, - ) + vl = pyglet.graphics.vertex_list(len(polygons) // 3, ("v3f", polygons), ("c4f", colors)) # gl.GL_QUADS, vl.draw(gl.GL_QUADS) vl.delete() self.score_label.text = "%04i" % self.reward diff --git a/gym/envs/box2d/lunar_lander.py b/gym/envs/box2d/lunar_lander.py index efa68243e..eb911a776 100644 --- a/gym/envs/box2d/lunar_lander.py +++ b/gym/envs/box2d/lunar_lander.py @@ -71,10 +71,7 @@ class ContactDetector(contactListener): self.env = env def BeginContact(self, contact): - if ( - self.env.lander == contact.fixtureA.body - or self.env.lander == contact.fixtureB.body - ): + if self.env.lander == contact.fixtureA.body or self.env.lander == contact.fixtureB.body: self.env.game_over = True for i in range(2): if self.env.legs[i] in [contact.fixtureA.body, contact.fixtureB.body]: @@ -104,9 +101,7 @@ class LunarLander(gym.Env, EzPickle): self.prev_reward = None # useful range is -1 .. +1, but spikes can be higher - self.observation_space = spaces.Box( - -np.inf, np.inf, shape=(8,), dtype=np.float32 - ) + self.observation_space = spaces.Box(-np.inf, np.inf, shape=(8,), dtype=np.float32) if self.continuous: # Action is two floats [main engine, left-right engines]. @@ -157,14 +152,9 @@ class LunarLander(gym.Env, EzPickle): height[CHUNKS // 2 + 0] = self.helipad_y height[CHUNKS // 2 + 1] = self.helipad_y height[CHUNKS // 2 + 2] = self.helipad_y - smooth_y = [ - 0.33 * (height[i - 1] + height[i + 0] + height[i + 1]) - for i in range(CHUNKS) - ] + smooth_y = [0.33 * (height[i - 1] + height[i + 0] + height[i + 1]) for i in range(CHUNKS)] - self.moon = self.world.CreateStaticBody( - shapes=edgeShape(vertices=[(0, 0), (W, 0)]) - ) + self.moon = self.world.CreateStaticBody(shapes=edgeShape(vertices=[(0, 0), (W, 0)])) self.sky_polys = [] for i in range(CHUNKS - 1): p1 = (chunk_x[i], smooth_y[i]) @@ -180,9 +170,7 @@ class LunarLander(gym.Env, EzPickle): position=(VIEWPORT_W / SCALE / 2, initial_y), angle=0.0, fixtures=fixtureDef( - shape=polygonShape( - vertices=[(x / SCALE, y / SCALE) for x, y in LANDER_POLY] - ), + shape=polygonShape(vertices=[(x / SCALE, y / SCALE) for x, y in LANDER_POLY]), density=5.0, friction=0.1, categoryBits=0x0010, @@ -227,9 +215,7 @@ class LunarLander(gym.Env, EzPickle): motorSpeed=+0.3 * i, # low enough not to jump back into the sky ) if i == -1: - rjd.lowerAngle = ( - +0.9 - 0.5 - ) # The most esoteric numbers here, angled legs have freedom to travel within + rjd.lowerAngle = +0.9 - 0.5 # The most esoteric numbers here, angled legs have freedom to travel within rjd.upperAngle = +0.9 else: rjd.lowerAngle = -0.9 @@ -278,9 +264,7 @@ class LunarLander(gym.Env, EzPickle): dispersion = [self.np_random.uniform(-1.0, +1.0) / SCALE for _ in range(2)] m_power = 0.0 - if (self.continuous and action[0] > 0.0) or ( - not self.continuous and action == 2 - ): + if (self.continuous and action[0] > 0.0) or (not self.continuous and action == 2): # Main engine if self.continuous: m_power = (np.clip(action[0], 0.0, 1.0) + 1.0) * 0.5 # 0.5..1.0 @@ -310,9 +294,7 @@ class LunarLander(gym.Env, EzPickle): ) s_power = 0.0 - if (self.continuous and np.abs(action[1]) > 0.5) or ( - not self.continuous and action in [1, 3] - ): + if (self.continuous and np.abs(action[1]) > 0.5) or (not self.continuous and action in [1, 3]): # Orientation engines if self.continuous: direction = np.sign(action[1]) @@ -321,12 +303,8 @@ class LunarLander(gym.Env, EzPickle): else: direction = action - 2 s_power = 1.0 - ox = tip[0] * dispersion[0] + side[0] * ( - 3 * dispersion[1] + direction * SIDE_ENGINE_AWAY / SCALE - ) - oy = -tip[1] * dispersion[0] - side[1] * ( - 3 * dispersion[1] + direction * SIDE_ENGINE_AWAY / SCALE - ) + ox = tip[0] * dispersion[0] + side[0] * (3 * dispersion[1] + direction * SIDE_ENGINE_AWAY / SCALE) + oy = -tip[1] * dispersion[0] - side[1] * (3 * dispersion[1] + direction * SIDE_ENGINE_AWAY / SCALE) impulse_pos = ( self.lander.position[0] + ox - tip[0] * 17 / SCALE, self.lander.position[1] + oy + tip[1] * SIDE_ENGINE_HEIGHT / SCALE, @@ -372,9 +350,7 @@ class LunarLander(gym.Env, EzPickle): reward = shaping - self.prev_shaping self.prev_shaping = shaping - reward -= ( - m_power * 0.30 - ) # less fuel spent is better, about -30 for heuristic landing + reward -= m_power * 0.30 # less fuel spent is better, about -30 for heuristic landing reward -= s_power * 0.03 done = False @@ -416,12 +392,8 @@ class LunarLander(gym.Env, EzPickle): trans = f.body.transform if type(f.shape) is circleShape: t = rendering.Transform(translation=trans * f.shape.pos) - self.viewer.draw_circle( - f.shape.radius, 20, color=obj.color1 - ).add_attr(t) - self.viewer.draw_circle( - f.shape.radius, 20, color=obj.color2, filled=False, linewidth=2 - ).add_attr(t) + self.viewer.draw_circle(f.shape.radius, 20, color=obj.color1).add_attr(t) + self.viewer.draw_circle(f.shape.radius, 20, color=obj.color2, filled=False, linewidth=2).add_attr(t) else: path = [trans * v for v in f.shape.vertices] self.viewer.draw_polygon(path, color=obj.color1) @@ -479,18 +451,14 @@ def heuristic(env, s): angle_targ = 0.4 # more than 0.4 radians (22 degrees) is bad if angle_targ < -0.4: angle_targ = -0.4 - hover_targ = 0.55 * np.abs( - s[0] - ) # target y should be proportional to horizontal offset + hover_targ = 0.55 * np.abs(s[0]) # target y should be proportional to horizontal offset angle_todo = (angle_targ - s[4]) * 0.5 - (s[5]) * 1.0 hover_todo = (hover_targ - s[1]) * 0.5 - (s[3]) * 0.5 if s[6] or s[7]: # legs have contact angle_todo = 0 - hover_todo = ( - -(s[3]) * 0.5 - ) # override to reduce fall speed, that's all we need after contact + hover_todo = -(s[3]) * 0.5 # override to reduce fall speed, that's all we need after contact if env.continuous: a = np.array([hover_todo * 20 - 1, -angle_todo * 20]) diff --git a/gym/envs/classic_control/acrobot.py b/gym/envs/classic_control/acrobot.py index 00f76fc37..c54fc51f6 100644 --- a/gym/envs/classic_control/acrobot.py +++ b/gym/envs/classic_control/acrobot.py @@ -88,9 +88,7 @@ class AcrobotEnv(core.Env): def __init__(self): self.viewer = None - high = np.array( - [1.0, 1.0, 1.0, 1.0, self.MAX_VEL_1, self.MAX_VEL_2], dtype=np.float32 - ) + high = np.array([1.0, 1.0, 1.0, 1.0, self.MAX_VEL_1, self.MAX_VEL_2], dtype=np.float32) low = -high self.observation_space = spaces.Box(low=low, high=high, dtype=np.float32) self.action_space = spaces.Discrete(3) @@ -111,9 +109,7 @@ class AcrobotEnv(core.Env): # Add noise to the force action if self.torque_noise_max > 0: - torque += self.np_random.uniform( - -self.torque_noise_max, self.torque_noise_max - ) + torque += self.np_random.uniform(-self.torque_noise_max, self.torque_noise_max) # Now, augment the state with our force action so it can be passed to # _dsdt @@ -160,12 +156,7 @@ class AcrobotEnv(core.Env): theta2 = s[1] dtheta1 = s[2] dtheta2 = s[3] - d1 = ( - m1 * lc1 ** 2 - + m2 * (l1 ** 2 + lc2 ** 2 + 2 * l1 * lc2 * cos(theta2)) - + I1 - + I2 - ) + d1 = m1 * lc1 ** 2 + m2 * (l1 ** 2 + lc2 ** 2 + 2 * l1 * lc2 * cos(theta2)) + I1 + I2 d2 = m2 * (lc2 ** 2 + l1 * lc2 * cos(theta2)) + I2 phi2 = m2 * lc2 * g * cos(theta1 + theta2 - pi / 2.0) phi1 = ( @@ -181,9 +172,9 @@ class AcrobotEnv(core.Env): else: # the following line is consistent with the java implementation and the # book - ddtheta2 = ( - a + d2 / d1 * phi1 - m2 * l1 * lc2 * dtheta1 ** 2 * sin(theta2) - phi2 - ) / (m2 * lc2 ** 2 + I2 - d2 ** 2 / d1) + ddtheta2 = (a + d2 / d1 * phi1 - m2 * l1 * lc2 * dtheta1 ** 2 * sin(theta2) - phi2) / ( + m2 * lc2 ** 2 + I2 - d2 ** 2 / d1 + ) ddtheta1 = -(d2 * ddtheta2 + phi1) / d1 return (dtheta1, dtheta2, ddtheta1, ddtheta2, 0.0) diff --git a/gym/envs/classic_control/cartpole.py b/gym/envs/classic_control/cartpole.py index 20e490931..157381f50 100644 --- a/gym/envs/classic_control/cartpole.py +++ b/gym/envs/classic_control/cartpole.py @@ -111,9 +111,7 @@ class CartPoleEnv(gym.Env): # For the interested reader: # https://coneural.org/florian/papers/05_cart_pole.pdf - temp = ( - force + self.polemass_length * theta_dot ** 2 * sintheta - ) / self.total_mass + temp = (force + self.polemass_length * theta_dot ** 2 * sintheta) / self.total_mass thetaacc = (self.gravity * sintheta - costheta * temp) / ( self.length * (4.0 / 3.0 - self.masspole * costheta ** 2 / self.total_mass) ) diff --git a/gym/envs/classic_control/continuous_mountain_car.py b/gym/envs/classic_control/continuous_mountain_car.py index 4db8e5ea4..8f3a6bb16 100644 --- a/gym/envs/classic_control/continuous_mountain_car.py +++ b/gym/envs/classic_control/continuous_mountain_car.py @@ -62,27 +62,17 @@ class Continuous_MountainCarEnv(gym.Env): self.min_position = -1.2 self.max_position = 0.6 self.max_speed = 0.07 - self.goal_position = ( - 0.45 # was 0.5 in gym, 0.45 in Arnaud de Broissia's version - ) + self.goal_position = 0.45 # was 0.5 in gym, 0.45 in Arnaud de Broissia's version self.goal_velocity = goal_velocity self.power = 0.0015 - self.low_state = np.array( - [self.min_position, -self.max_speed], dtype=np.float32 - ) - self.high_state = np.array( - [self.max_position, self.max_speed], dtype=np.float32 - ) + self.low_state = np.array([self.min_position, -self.max_speed], dtype=np.float32) + self.high_state = np.array([self.max_position, self.max_speed], dtype=np.float32) self.viewer = None - self.action_space = spaces.Box( - low=self.min_action, high=self.max_action, shape=(1,), dtype=np.float32 - ) - self.observation_space = spaces.Box( - low=self.low_state, high=self.high_state, dtype=np.float32 - ) + self.action_space = spaces.Box(low=self.min_action, high=self.max_action, shape=(1,), dtype=np.float32) + self.observation_space = spaces.Box(low=self.low_state, high=self.high_state, dtype=np.float32) self.seed() self.reset() @@ -159,15 +149,11 @@ class Continuous_MountainCarEnv(gym.Env): self.viewer.add_geom(car) frontwheel = rendering.make_circle(carheight / 2.5) frontwheel.set_color(0.5, 0.5, 0.5) - frontwheel.add_attr( - rendering.Transform(translation=(carwidth / 4, clearance)) - ) + frontwheel.add_attr(rendering.Transform(translation=(carwidth / 4, clearance))) frontwheel.add_attr(self.cartrans) self.viewer.add_geom(frontwheel) backwheel = rendering.make_circle(carheight / 2.5) - backwheel.add_attr( - rendering.Transform(translation=(-carwidth / 4, clearance)) - ) + backwheel.add_attr(rendering.Transform(translation=(-carwidth / 4, clearance))) backwheel.add_attr(self.cartrans) backwheel.set_color(0.5, 0.5, 0.5) self.viewer.add_geom(backwheel) @@ -176,16 +162,12 @@ class Continuous_MountainCarEnv(gym.Env): flagy2 = flagy1 + 50 flagpole = rendering.Line((flagx, flagy1), (flagx, flagy2)) self.viewer.add_geom(flagpole) - flag = rendering.FilledPolygon( - [(flagx, flagy2), (flagx, flagy2 - 10), (flagx + 25, flagy2 - 5)] - ) + flag = rendering.FilledPolygon([(flagx, flagy2), (flagx, flagy2 - 10), (flagx + 25, flagy2 - 5)]) flag.set_color(0.8, 0.8, 0) self.viewer.add_geom(flag) pos = self.state[0] - self.cartrans.set_translation( - (pos - self.min_position) * scale, self._height(pos) * scale - ) + self.cartrans.set_translation((pos - self.min_position) * scale, self._height(pos) * scale) self.cartrans.set_rotation(math.cos(3 * pos)) return self.viewer.render(return_rgb_array=mode == "rgb_array") diff --git a/gym/envs/classic_control/mountain_car.py b/gym/envs/classic_control/mountain_car.py index fe6777948..165712b59 100644 --- a/gym/envs/classic_control/mountain_car.py +++ b/gym/envs/classic_control/mountain_car.py @@ -136,15 +136,11 @@ class MountainCarEnv(gym.Env): self.viewer.add_geom(car) frontwheel = rendering.make_circle(carheight / 2.5) frontwheel.set_color(0.5, 0.5, 0.5) - frontwheel.add_attr( - rendering.Transform(translation=(carwidth / 4, clearance)) - ) + frontwheel.add_attr(rendering.Transform(translation=(carwidth / 4, clearance))) frontwheel.add_attr(self.cartrans) self.viewer.add_geom(frontwheel) backwheel = rendering.make_circle(carheight / 2.5) - backwheel.add_attr( - rendering.Transform(translation=(-carwidth / 4, clearance)) - ) + backwheel.add_attr(rendering.Transform(translation=(-carwidth / 4, clearance))) backwheel.add_attr(self.cartrans) backwheel.set_color(0.5, 0.5, 0.5) self.viewer.add_geom(backwheel) @@ -153,16 +149,12 @@ class MountainCarEnv(gym.Env): flagy2 = flagy1 + 50 flagpole = rendering.Line((flagx, flagy1), (flagx, flagy2)) self.viewer.add_geom(flagpole) - flag = rendering.FilledPolygon( - [(flagx, flagy2), (flagx, flagy2 - 10), (flagx + 25, flagy2 - 5)] - ) + flag = rendering.FilledPolygon([(flagx, flagy2), (flagx, flagy2 - 10), (flagx + 25, flagy2 - 5)]) flag.set_color(0.8, 0.8, 0) self.viewer.add_geom(flag) pos = self.state[0] - self.cartrans.set_translation( - (pos - self.min_position) * scale, self._height(pos) * scale - ) + self.cartrans.set_translation((pos - self.min_position) * scale, self._height(pos) * scale) self.cartrans.set_rotation(math.cos(3 * pos)) return self.viewer.render(return_rgb_array=mode == "rgb_array") diff --git a/gym/envs/classic_control/pendulum.py b/gym/envs/classic_control/pendulum.py index e8b38da9c..1d64749dd 100644 --- a/gym/envs/classic_control/pendulum.py +++ b/gym/envs/classic_control/pendulum.py @@ -18,9 +18,7 @@ class PendulumEnv(gym.Env): self.viewer = None high = np.array([1.0, 1.0, self.max_speed], dtype=np.float32) - self.action_space = spaces.Box( - low=-self.max_torque, high=self.max_torque, shape=(1,), dtype=np.float32 - ) + self.action_space = spaces.Box(low=-self.max_torque, high=self.max_torque, shape=(1,), dtype=np.float32) self.observation_space = spaces.Box(low=-high, high=high, dtype=np.float32) self.seed() @@ -41,10 +39,7 @@ class PendulumEnv(gym.Env): self.last_u = u # for rendering costs = angle_normalize(th) ** 2 + 0.1 * thdot ** 2 + 0.001 * (u ** 2) - newthdot = ( - thdot - + (-3 * g / (2 * l) * np.sin(th + np.pi) + 3.0 / (m * l ** 2) * u) * dt - ) + newthdot = thdot + (-3 * g / (2 * l) * np.sin(th + np.pi) + 3.0 / (m * l ** 2) * u) * dt newth = th + newthdot * dt newthdot = np.clip(newthdot, -self.max_speed, self.max_speed) diff --git a/gym/envs/classic_control/rendering.py b/gym/envs/classic_control/rendering.py index b3390886d..465a231e2 100644 --- a/gym/envs/classic_control/rendering.py +++ b/gym/envs/classic_control/rendering.py @@ -54,11 +54,7 @@ def get_display(spec): elif isinstance(spec, str): return pyglet.canvas.Display(spec) else: - raise error.Error( - "Invalid display specification: {}. (Must be a string like :0 or None.)".format( - spec - ) - ) + raise error.Error("Invalid display specification: {}. (Must be a string like :0 or None.)".format(spec)) def get_window(width, height, display, **kwargs): @@ -69,14 +65,7 @@ def get_window(width, height, display, **kwargs): config = screen[0].get_best_config() # selecting the first screen context = config.create_context(None) # create GL context - return pyglet.window.Window( - width=width, - height=height, - display=display, - config=config, - context=context, - **kwargs - ) + return pyglet.window.Window(width=width, height=height, display=display, config=config, context=context, **kwargs) class Viewer(object): @@ -108,9 +97,7 @@ class Viewer(object): assert right > left and top > bottom scalex = self.width / (right - left) scaley = self.height / (top - bottom) - self.transform = Transform( - translation=(-left * scalex, -bottom * scaley), scale=(scalex, scaley) - ) + self.transform = Transform(translation=(-left * scalex, -bottom * scaley), scale=(scalex, scaley)) def add_geom(self, geom): self.geoms.append(geom) @@ -173,9 +160,7 @@ class Viewer(object): def get_array(self): self.window.flip() - image_data = ( - pyglet.image.get_buffer_manager().get_color_buffer().get_image_data() - ) + image_data = pyglet.image.get_buffer_manager().get_color_buffer().get_image_data() self.window.flip() arr = np.fromstring(image_data.get_data(), dtype=np.uint8, sep="") arr = arr.reshape(self.height, self.width, 4) @@ -230,9 +215,7 @@ class Transform(Attr): def enable(self): glPushMatrix() - glTranslatef( - self.translation[0], self.translation[1], 0 - ) # translate to GL loc ppint + glTranslatef(self.translation[0], self.translation[1], 0) # translate to GL loc ppint glRotatef(RAD2DEG * self.rotation, 0, 0, 1.0) glScalef(self.scale[0], self.scale[1], 1) @@ -392,9 +375,7 @@ class Image(Geom): self.flip = False def render1(self): - self.img.blit( - -self.width / 2, -self.height / 2, width=self.width, height=self.height - ) + self.img.blit(-self.width / 2, -self.height / 2, width=self.width, height=self.height) # ================================================================ @@ -435,9 +416,7 @@ class SimpleImageViewer(object): self.isopen = False assert len(arr.shape) == 3, "You passed in an image with the wrong number shape" - image = pyglet.image.ImageData( - arr.shape[1], arr.shape[0], "RGB", arr.tobytes(), pitch=arr.shape[1] * -3 - ) + image = pyglet.image.ImageData(arr.shape[1], arr.shape[0], "RGB", arr.tobytes(), pitch=arr.shape[1] * -3) gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_NEAREST) texture = image.get_texture() texture.width = self.width diff --git a/gym/envs/mujoco/ant.py b/gym/envs/mujoco/ant.py index ae0a2f8dc..5c2a65f5d 100644 --- a/gym/envs/mujoco/ant.py +++ b/gym/envs/mujoco/ant.py @@ -14,9 +14,7 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle): xposafter = self.get_body_com("torso")[0] forward_reward = (xposafter - xposbefore) / self.dt ctrl_cost = 0.5 * np.square(a).sum() - contact_cost = ( - 0.5 * 1e-3 * np.sum(np.square(np.clip(self.sim.data.cfrc_ext, -1, 1))) - ) + contact_cost = 0.5 * 1e-3 * np.sum(np.square(np.clip(self.sim.data.cfrc_ext, -1, 1))) survive_reward = 1.0 reward = forward_reward - ctrl_cost - contact_cost + survive_reward state = self.state_vector() @@ -45,9 +43,7 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle): ) def reset_model(self): - qpos = self.init_qpos + self.np_random.uniform( - size=self.model.nq, low=-0.1, high=0.1 - ) + qpos = self.init_qpos + self.np_random.uniform(size=self.model.nq, low=-0.1, high=0.1) qvel = self.init_qvel + self.np_random.randn(self.model.nv) * 0.1 self.set_state(qpos, qvel) return self._get_obs() diff --git a/gym/envs/mujoco/ant_v3.py b/gym/envs/mujoco/ant_v3.py index 473f85daa..eaf0d5a65 100644 --- a/gym/envs/mujoco/ant_v3.py +++ b/gym/envs/mujoco/ant_v3.py @@ -34,18 +34,13 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle): self._reset_noise_scale = reset_noise_scale - self._exclude_current_positions_from_observation = ( - exclude_current_positions_from_observation - ) + self._exclude_current_positions_from_observation = exclude_current_positions_from_observation mujoco_env.MujocoEnv.__init__(self, xml_file, 5) @property def healthy_reward(self): - return ( - float(self.is_healthy or self._terminate_when_unhealthy) - * self._healthy_reward - ) + return float(self.is_healthy or self._terminate_when_unhealthy) * self._healthy_reward def control_cost(self, action): control_cost = self._ctrl_cost_weight * np.sum(np.square(action)) @@ -60,9 +55,7 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle): @property def contact_cost(self): - contact_cost = self._contact_cost_weight * np.sum( - np.square(self.contact_forces) - ) + contact_cost = self._contact_cost_weight * np.sum(np.square(self.contact_forces)) return contact_cost @property @@ -128,12 +121,8 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle): noise_low = -self._reset_noise_scale noise_high = self._reset_noise_scale - qpos = self.init_qpos + self.np_random.uniform( - low=noise_low, high=noise_high, size=self.model.nq - ) - qvel = self.init_qvel + self._reset_noise_scale * self.np_random.randn( - self.model.nv - ) + qpos = self.init_qpos + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nq) + qvel = self.init_qvel + self._reset_noise_scale * self.np_random.randn(self.model.nv) self.set_state(qpos, qvel) observation = self._get_obs() diff --git a/gym/envs/mujoco/half_cheetah.py b/gym/envs/mujoco/half_cheetah.py index 7044e9459..e0e3636dc 100644 --- a/gym/envs/mujoco/half_cheetah.py +++ b/gym/envs/mujoco/half_cheetah.py @@ -28,9 +28,7 @@ class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle): ) def reset_model(self): - qpos = self.init_qpos + self.np_random.uniform( - low=-0.1, high=0.1, size=self.model.nq - ) + qpos = self.init_qpos + self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq) qvel = self.init_qvel + self.np_random.randn(self.model.nv) * 0.1 self.set_state(qpos, qvel) return self._get_obs() diff --git a/gym/envs/mujoco/half_cheetah_v3.py b/gym/envs/mujoco/half_cheetah_v3.py index 84f69b2bd..416dcd246 100644 --- a/gym/envs/mujoco/half_cheetah_v3.py +++ b/gym/envs/mujoco/half_cheetah_v3.py @@ -25,9 +25,7 @@ class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle): self._reset_noise_scale = reset_noise_scale - self._exclude_current_positions_from_observation = ( - exclude_current_positions_from_observation - ) + self._exclude_current_positions_from_observation = exclude_current_positions_from_observation mujoco_env.MujocoEnv.__init__(self, xml_file, 5) @@ -71,12 +69,8 @@ class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle): noise_low = -self._reset_noise_scale noise_high = self._reset_noise_scale - qpos = self.init_qpos + self.np_random.uniform( - low=noise_low, high=noise_high, size=self.model.nq - ) - qvel = self.init_qvel + self._reset_noise_scale * self.np_random.randn( - self.model.nv - ) + qpos = self.init_qpos + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nq) + qvel = self.init_qvel + self._reset_noise_scale * self.np_random.randn(self.model.nv) self.set_state(qpos, qvel) diff --git a/gym/envs/mujoco/hopper.py b/gym/envs/mujoco/hopper.py index 46bf0a7f5..6e440229c 100644 --- a/gym/envs/mujoco/hopper.py +++ b/gym/envs/mujoco/hopper.py @@ -17,27 +17,16 @@ class HopperEnv(mujoco_env.MujocoEnv, utils.EzPickle): reward += alive_bonus reward -= 1e-3 * np.square(a).sum() s = self.state_vector() - done = not ( - np.isfinite(s).all() - and (np.abs(s[2:]) < 100).all() - and (height > 0.7) - and (abs(ang) < 0.2) - ) + done = not (np.isfinite(s).all() and (np.abs(s[2:]) < 100).all() and (height > 0.7) and (abs(ang) < 0.2)) ob = self._get_obs() return ob, reward, done, {} def _get_obs(self): - return np.concatenate( - [self.sim.data.qpos.flat[1:], np.clip(self.sim.data.qvel.flat, -10, 10)] - ) + return np.concatenate([self.sim.data.qpos.flat[1:], np.clip(self.sim.data.qvel.flat, -10, 10)]) def reset_model(self): - qpos = self.init_qpos + self.np_random.uniform( - low=-0.005, high=0.005, size=self.model.nq - ) - qvel = self.init_qvel + self.np_random.uniform( - low=-0.005, high=0.005, size=self.model.nv - ) + qpos = self.init_qpos + self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nq) + qvel = self.init_qvel + self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nv) self.set_state(qpos, qvel) return self._get_obs() diff --git a/gym/envs/mujoco/hopper_v3.py b/gym/envs/mujoco/hopper_v3.py index 8bb800241..65e2769f4 100644 --- a/gym/envs/mujoco/hopper_v3.py +++ b/gym/envs/mujoco/hopper_v3.py @@ -40,18 +40,13 @@ class HopperEnv(mujoco_env.MujocoEnv, utils.EzPickle): self._reset_noise_scale = reset_noise_scale - self._exclude_current_positions_from_observation = ( - exclude_current_positions_from_observation - ) + self._exclude_current_positions_from_observation = exclude_current_positions_from_observation mujoco_env.MujocoEnv.__init__(self, xml_file, 4) @property def healthy_reward(self): - return ( - float(self.is_healthy or self._terminate_when_unhealthy) - * self._healthy_reward - ) + return float(self.is_healthy or self._terminate_when_unhealthy) * self._healthy_reward def control_cost(self, action): control_cost = self._ctrl_cost_weight * np.sum(np.square(action)) @@ -117,12 +112,8 @@ class HopperEnv(mujoco_env.MujocoEnv, utils.EzPickle): noise_low = -self._reset_noise_scale noise_high = self._reset_noise_scale - qpos = self.init_qpos + self.np_random.uniform( - low=noise_low, high=noise_high, size=self.model.nq - ) - qvel = self.init_qvel + self.np_random.uniform( - low=noise_low, high=noise_high, size=self.model.nv - ) + qpos = self.init_qpos + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nq) + qvel = self.init_qvel + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nv) self.set_state(qpos, qvel) diff --git a/gym/envs/mujoco/humanoid_v3.py b/gym/envs/mujoco/humanoid_v3.py index 94a809d3e..9e021533f 100644 --- a/gym/envs/mujoco/humanoid_v3.py +++ b/gym/envs/mujoco/humanoid_v3.py @@ -43,18 +43,13 @@ class HumanoidEnv(mujoco_env.MujocoEnv, utils.EzPickle): self._reset_noise_scale = reset_noise_scale - self._exclude_current_positions_from_observation = ( - exclude_current_positions_from_observation - ) + self._exclude_current_positions_from_observation = exclude_current_positions_from_observation mujoco_env.MujocoEnv.__init__(self, xml_file, 5) @property def healthy_reward(self): - return ( - float(self.is_healthy or self._terminate_when_unhealthy) - * self._healthy_reward - ) + return float(self.is_healthy or self._terminate_when_unhealthy) * self._healthy_reward def control_cost(self, action): control_cost = self._ctrl_cost_weight * np.sum(np.square(self.sim.data.ctrl)) @@ -143,12 +138,8 @@ class HumanoidEnv(mujoco_env.MujocoEnv, utils.EzPickle): noise_low = -self._reset_noise_scale noise_high = self._reset_noise_scale - qpos = self.init_qpos + self.np_random.uniform( - low=noise_low, high=noise_high, size=self.model.nq - ) - qvel = self.init_qvel + self.np_random.uniform( - low=noise_low, high=noise_high, size=self.model.nv - ) + qpos = self.init_qpos + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nq) + qvel = self.init_qvel + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nv) self.set_state(qpos, qvel) observation = self._get_obs() diff --git a/gym/envs/mujoco/inverted_double_pendulum.py b/gym/envs/mujoco/inverted_double_pendulum.py index 81dd3c561..782a5bea7 100644 --- a/gym/envs/mujoco/inverted_double_pendulum.py +++ b/gym/envs/mujoco/inverted_double_pendulum.py @@ -33,8 +33,7 @@ class InvertedDoublePendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle): def reset_model(self): self.set_state( - self.init_qpos - + self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq), + self.init_qpos + self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq), self.init_qvel + self.np_random.randn(self.model.nv) * 0.1, ) return self._get_obs() diff --git a/gym/envs/mujoco/inverted_pendulum.py b/gym/envs/mujoco/inverted_pendulum.py index b87ec498b..36eafbe70 100644 --- a/gym/envs/mujoco/inverted_pendulum.py +++ b/gym/envs/mujoco/inverted_pendulum.py @@ -17,12 +17,8 @@ class InvertedPendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle): return ob, reward, done, {} def reset_model(self): - qpos = self.init_qpos + self.np_random.uniform( - size=self.model.nq, low=-0.01, high=0.01 - ) - qvel = self.init_qvel + self.np_random.uniform( - size=self.model.nv, low=-0.01, high=0.01 - ) + qpos = self.init_qpos + self.np_random.uniform(size=self.model.nq, low=-0.01, high=0.01) + qvel = self.init_qvel + self.np_random.uniform(size=self.model.nv, low=-0.01, high=0.01) self.set_state(qpos, qvel) return self._get_obs() diff --git a/gym/envs/mujoco/mujoco_env.py b/gym/envs/mujoco/mujoco_env.py index 4d0fe8865..3aa095813 100644 --- a/gym/envs/mujoco/mujoco_env.py +++ b/gym/envs/mujoco/mujoco_env.py @@ -22,14 +22,7 @@ DEFAULT_SIZE = 500 def convert_observation_to_space(observation): if isinstance(observation, dict): - space = spaces.Dict( - OrderedDict( - [ - (key, convert_observation_to_space(value)) - for key, value in observation.items() - ] - ) - ) + space = spaces.Dict(OrderedDict([(key, convert_observation_to_space(value)) for key, value in observation.items()])) elif isinstance(observation, np.ndarray): low = np.full(observation.shape, -float("inf"), dtype=np.float32) high = np.full(observation.shape, float("inf"), dtype=np.float32) @@ -117,9 +110,7 @@ class MujocoEnv(gym.Env): def set_state(self, qpos, qvel): assert qpos.shape == (self.model.nq,) and qvel.shape == (self.model.nv,) old_state = self.sim.get_state() - new_state = mujoco_py.MjSimState( - old_state.time, qpos, qvel, old_state.act, old_state.udd_state - ) + new_state = mujoco_py.MjSimState(old_state.time, qpos, qvel, old_state.act, old_state.udd_state) self.sim.set_state(new_state) self.sim.forward() @@ -142,10 +133,7 @@ class MujocoEnv(gym.Env): ): if mode == "rgb_array" or mode == "depth_array": if camera_id is not None and camera_name is not None: - raise ValueError( - "Both `camera_id` and `camera_name` cannot be" - " specified at the same time." - ) + raise ValueError("Both `camera_id` and `camera_name` cannot be" " specified at the same time.") no_camera_specified = camera_name is None and camera_id is None if no_camera_specified: diff --git a/gym/envs/mujoco/pusher.py b/gym/envs/mujoco/pusher.py index e82adc663..ee54db64d 100644 --- a/gym/envs/mujoco/pusher.py +++ b/gym/envs/mujoco/pusher.py @@ -44,9 +44,7 @@ class PusherEnv(mujoco_env.MujocoEnv, utils.EzPickle): qpos[-4:-2] = self.cylinder_pos qpos[-2:] = self.goal_pos - qvel = self.init_qvel + self.np_random.uniform( - low=-0.005, high=0.005, size=self.model.nv - ) + qvel = self.init_qvel + self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nv) qvel[-4:] = 0 self.set_state(qpos, qvel) return self._get_obs() diff --git a/gym/envs/mujoco/reacher.py b/gym/envs/mujoco/reacher.py index 3cc1f30e0..48be6f8e5 100644 --- a/gym/envs/mujoco/reacher.py +++ b/gym/envs/mujoco/reacher.py @@ -22,18 +22,13 @@ class ReacherEnv(mujoco_env.MujocoEnv, utils.EzPickle): self.viewer.cam.trackbodyid = 0 def reset_model(self): - qpos = ( - self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq) - + self.init_qpos - ) + qpos = self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq) + self.init_qpos while True: self.goal = self.np_random.uniform(low=-0.2, high=0.2, size=2) if np.linalg.norm(self.goal) < 0.2: break qpos[-2:] = self.goal - qvel = self.init_qvel + self.np_random.uniform( - low=-0.005, high=0.005, size=self.model.nv - ) + qvel = self.init_qvel + self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nv) qvel[-2:] = 0 self.set_state(qpos, qvel) return self._get_obs() diff --git a/gym/envs/mujoco/striker.py b/gym/envs/mujoco/striker.py index 186f9d7fe..73aa8ee72 100644 --- a/gym/envs/mujoco/striker.py +++ b/gym/envs/mujoco/striker.py @@ -62,9 +62,7 @@ class StrikerEnv(mujoco_env.MujocoEnv, utils.EzPickle): diff = self.ball - self.goal angle = -np.arctan(diff[0] / (diff[1] + 1e-8)) qpos[-1] = angle / 3.14 - qvel = self.init_qvel + self.np_random.uniform( - low=-0.1, high=0.1, size=self.model.nv - ) + qvel = self.init_qvel + self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nv) qvel[7:] = 0 self.set_state(qpos, qvel) return self._get_obs() diff --git a/gym/envs/mujoco/swimmer.py b/gym/envs/mujoco/swimmer.py index f903184da..2071c3ebf 100644 --- a/gym/envs/mujoco/swimmer.py +++ b/gym/envs/mujoco/swimmer.py @@ -26,9 +26,7 @@ class SwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle): def reset_model(self): self.set_state( - self.init_qpos - + self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq), - self.init_qvel - + self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nv), + self.init_qpos + self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq), + self.init_qvel + self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nv), ) return self._get_obs() diff --git a/gym/envs/mujoco/swimmer_v3.py b/gym/envs/mujoco/swimmer_v3.py index 8cf794d47..0f73af038 100644 --- a/gym/envs/mujoco/swimmer_v3.py +++ b/gym/envs/mujoco/swimmer_v3.py @@ -22,9 +22,7 @@ class SwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle): self._reset_noise_scale = reset_noise_scale - self._exclude_current_positions_from_observation = ( - exclude_current_positions_from_observation - ) + self._exclude_current_positions_from_observation = exclude_current_positions_from_observation mujoco_env.MujocoEnv.__init__(self, xml_file, 4) @@ -74,12 +72,8 @@ class SwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle): noise_low = -self._reset_noise_scale noise_high = self._reset_noise_scale - qpos = self.init_qpos + self.np_random.uniform( - low=noise_low, high=noise_high, size=self.model.nq - ) - qvel = self.init_qvel + self.np_random.uniform( - low=noise_low, high=noise_high, size=self.model.nv - ) + qpos = self.init_qpos + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nq) + qvel = self.init_qvel + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nv) self.set_state(qpos, qvel) diff --git a/gym/envs/mujoco/thrower.py b/gym/envs/mujoco/thrower.py index b8b21a22a..1aa8f34e1 100644 --- a/gym/envs/mujoco/thrower.py +++ b/gym/envs/mujoco/thrower.py @@ -48,9 +48,7 @@ class ThrowerEnv(mujoco_env.MujocoEnv, utils.EzPickle): ) qpos[-9:-7] = self.goal - qvel = self.init_qvel + self.np_random.uniform( - low=-0.005, high=0.005, size=self.model.nv - ) + qvel = self.init_qvel + self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nv) qvel[7:] = 0 self.set_state(qpos, qvel) return self._get_obs() diff --git a/gym/envs/mujoco/walker2d.py b/gym/envs/mujoco/walker2d.py index 5f49b4bf2..efaa1207f 100644 --- a/gym/envs/mujoco/walker2d.py +++ b/gym/envs/mujoco/walker2d.py @@ -27,10 +27,8 @@ class Walker2dEnv(mujoco_env.MujocoEnv, utils.EzPickle): def reset_model(self): self.set_state( - self.init_qpos - + self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nq), - self.init_qvel - + self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nv), + self.init_qpos + self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nq), + self.init_qvel + self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nv), ) return self._get_obs() diff --git a/gym/envs/mujoco/walker2d_v3.py b/gym/envs/mujoco/walker2d_v3.py index eee6bb7d4..b6150d95b 100644 --- a/gym/envs/mujoco/walker2d_v3.py +++ b/gym/envs/mujoco/walker2d_v3.py @@ -37,18 +37,13 @@ class Walker2dEnv(mujoco_env.MujocoEnv, utils.EzPickle): self._reset_noise_scale = reset_noise_scale - self._exclude_current_positions_from_observation = ( - exclude_current_positions_from_observation - ) + self._exclude_current_positions_from_observation = exclude_current_positions_from_observation mujoco_env.MujocoEnv.__init__(self, xml_file, 4) @property def healthy_reward(self): - return ( - float(self.is_healthy or self._terminate_when_unhealthy) - * self._healthy_reward - ) + return float(self.is_healthy or self._terminate_when_unhealthy) * self._healthy_reward def control_cost(self, action): control_cost = self._ctrl_cost_weight * np.sum(np.square(action)) @@ -110,12 +105,8 @@ class Walker2dEnv(mujoco_env.MujocoEnv, utils.EzPickle): noise_low = -self._reset_noise_scale noise_high = self._reset_noise_scale - qpos = self.init_qpos + self.np_random.uniform( - low=noise_low, high=noise_high, size=self.model.nq - ) - qvel = self.init_qvel + self.np_random.uniform( - low=noise_low, high=noise_high, size=self.model.nv - ) + qpos = self.init_qpos + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nq) + qvel = self.init_qvel + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nv) self.set_state(qpos, qvel) diff --git a/gym/envs/registration.py b/gym/envs/registration.py index 347b77a2e..32a4f521a 100644 --- a/gym/envs/registration.py +++ b/gym/envs/registration.py @@ -108,11 +108,7 @@ class EnvRegistry(object): # reset/step. Set _gym_disable_underscore_compat = True on # your environment if you use these methods and don't want # compatibility code to be invoked. - if ( - hasattr(env, "_reset") - and hasattr(env, "_step") - and not getattr(env, "_gym_disable_underscore_compat", False) - ): + if hasattr(env, "_reset") and hasattr(env, "_step") and not getattr(env, "_gym_disable_underscore_compat", False): patch_deprecated_methods(env) if env.spec.max_episode_steps is not None: from gym.wrappers.time_limit import TimeLimit @@ -158,11 +154,7 @@ class EnvRegistry(object): if env_name == valid_env_spec._env_name ] if matching_envs: - raise error.DeprecatedEnv( - "Env {} not found (valid versions include {})".format( - id, matching_envs - ) - ) + raise error.DeprecatedEnv("Env {} not found (valid versions include {})".format(id, matching_envs)) else: raise error.UnregisteredEnv("No registered env with id: {}".format(id)) diff --git a/gym/envs/robotics/fetch_env.py b/gym/envs/robotics/fetch_env.py index 8142acab0..d484c7874 100644 --- a/gym/envs/robotics/fetch_env.py +++ b/gym/envs/robotics/fetch_env.py @@ -81,9 +81,7 @@ class FetchEnv(robot_env.RobotEnv): def _set_action(self, action): assert action.shape == (4,) - action = ( - action.copy() - ) # ensure that we don't change the action outside of this scope + action = action.copy() # ensure that we don't change the action outside of this scope pos_ctrl, gripper_ctrl = action[:3], action[3] pos_ctrl *= 0.05 # limit maximum change in position @@ -120,13 +118,9 @@ class FetchEnv(robot_env.RobotEnv): object_rel_pos = object_pos - grip_pos object_velp -= grip_velp else: - object_pos = ( - object_rot - ) = object_velp = object_velr = object_rel_pos = np.zeros(0) + object_pos = object_rot = object_velp = object_velr = object_rel_pos = np.zeros(0) gripper_state = robot_qpos[-2:] - gripper_vel = ( - robot_qvel[-2:] * dt - ) # change to a scalar if the gripper is made symmetric + gripper_vel = robot_qvel[-2:] * dt # change to a scalar if the gripper is made symmetric if not self.has_object: achieved_goal = grip_pos.copy() @@ -175,9 +169,7 @@ class FetchEnv(robot_env.RobotEnv): if self.has_object: object_xpos = self.initial_gripper_xpos[:2] while np.linalg.norm(object_xpos - self.initial_gripper_xpos[:2]) < 0.1: - object_xpos = self.initial_gripper_xpos[:2] + self.np_random.uniform( - -self.obj_range, self.obj_range, size=2 - ) + object_xpos = self.initial_gripper_xpos[:2] + self.np_random.uniform(-self.obj_range, self.obj_range, size=2) object_qpos = self.sim.data.get_joint_qpos("object0:joint") assert object_qpos.shape == (7,) object_qpos[:2] = object_xpos @@ -188,17 +180,13 @@ class FetchEnv(robot_env.RobotEnv): def _sample_goal(self): if self.has_object: - goal = self.initial_gripper_xpos[:3] + self.np_random.uniform( - -self.target_range, self.target_range, size=3 - ) + goal = self.initial_gripper_xpos[:3] + self.np_random.uniform(-self.target_range, self.target_range, size=3) goal += self.target_offset goal[2] = self.height_offset if self.target_in_the_air and self.np_random.uniform() < 0.5: goal[2] += self.np_random.uniform(0, 0.45) else: - goal = self.initial_gripper_xpos[:3] + self.np_random.uniform( - -self.target_range, self.target_range, size=3 - ) + goal = self.initial_gripper_xpos[:3] + self.np_random.uniform(-self.target_range, self.target_range, size=3) return goal.copy() def _is_success(self, achieved_goal, desired_goal): @@ -212,9 +200,9 @@ class FetchEnv(robot_env.RobotEnv): self.sim.forward() # Move end effector into position. - gripper_target = np.array( - [-0.498, 0.005, -0.431 + self.gripper_extra_height] - ) + self.sim.data.get_site_xpos("robot0:grip") + gripper_target = np.array([-0.498, 0.005, -0.431 + self.gripper_extra_height]) + self.sim.data.get_site_xpos( + "robot0:grip" + ) gripper_rotation = np.array([1.0, 0.0, 1.0, 0.0]) self.sim.data.set_mocap_pos("robot0:mocap", gripper_target) self.sim.data.set_mocap_quat("robot0:mocap", gripper_rotation) diff --git a/gym/envs/robotics/hand/manipulate.py b/gym/envs/robotics/hand/manipulate.py index e3eca062d..0c2a5a98e 100644 --- a/gym/envs/robotics/hand/manipulate.py +++ b/gym/envs/robotics/hand/manipulate.py @@ -74,9 +74,7 @@ class ManipulateEnv(hand_env.HandEnv): self.target_position = target_position self.target_rotation = target_rotation self.target_position_range = target_position_range - self.parallel_quats = [ - rotations.euler2quat(r) for r in rotations.get_parallel_rotations() - ] + self.parallel_quats = [rotations.euler2quat(r) for r in rotations.get_parallel_rotations()] self.randomize_initial_rotation = randomize_initial_rotation self.randomize_initial_position = randomize_initial_position self.distance_threshold = distance_threshold @@ -182,9 +180,7 @@ class ManipulateEnv(hand_env.HandEnv): angle = self.np_random.uniform(-np.pi, np.pi) axis = np.array([0.0, 0.0, 1.0]) z_quat = quat_from_angle_and_axis(angle, axis) - parallel_quat = self.parallel_quats[ - self.np_random.randint(len(self.parallel_quats)) - ] + parallel_quat = self.parallel_quats[self.np_random.randint(len(self.parallel_quats))] offset_quat = rotations.quat_mul(z_quat, parallel_quat) initial_quat = rotations.quat_mul(initial_quat, offset_quat) elif self.target_rotation in ["xyz", "ignore"]: @@ -195,9 +191,7 @@ class ManipulateEnv(hand_env.HandEnv): elif self.target_rotation == "fixed": pass else: - raise error.Error( - 'Unknown target_rotation option "{}".'.format(self.target_rotation) - ) + raise error.Error('Unknown target_rotation option "{}".'.format(self.target_rotation)) # Randomize initial position. if self.randomize_initial_position: @@ -229,17 +223,13 @@ class ManipulateEnv(hand_env.HandEnv): target_pos = None if self.target_position == "random": assert self.target_position_range.shape == (3, 2) - offset = self.np_random.uniform( - self.target_position_range[:, 0], self.target_position_range[:, 1] - ) + offset = self.np_random.uniform(self.target_position_range[:, 0], self.target_position_range[:, 1]) assert offset.shape == (3,) target_pos = self.sim.data.get_joint_qpos("object:joint")[:3] + offset elif self.target_position in ["ignore", "fixed"]: target_pos = self.sim.data.get_joint_qpos("object:joint")[:3] else: - raise error.Error( - 'Unknown target_position option "{}".'.format(self.target_position) - ) + raise error.Error('Unknown target_position option "{}".'.format(self.target_position)) assert target_pos is not None assert target_pos.shape == (3,) @@ -253,9 +243,7 @@ class ManipulateEnv(hand_env.HandEnv): angle = self.np_random.uniform(-np.pi, np.pi) axis = np.array([0.0, 0.0, 1.0]) target_quat = quat_from_angle_and_axis(angle, axis) - parallel_quat = self.parallel_quats[ - self.np_random.randint(len(self.parallel_quats)) - ] + parallel_quat = self.parallel_quats[self.np_random.randint(len(self.parallel_quats))] target_quat = rotations.quat_mul(target_quat, parallel_quat) elif self.target_rotation == "xyz": angle = self.np_random.uniform(-np.pi, np.pi) @@ -264,9 +252,7 @@ class ManipulateEnv(hand_env.HandEnv): elif self.target_rotation in ["ignore", "fixed"]: target_quat = self.sim.data.get_joint_qpos("object:joint") else: - raise error.Error( - 'Unknown target_rotation option "{}".'.format(self.target_rotation) - ) + raise error.Error('Unknown target_rotation option "{}".'.format(self.target_rotation)) assert target_quat is not None assert target_quat.shape == (4,) @@ -293,12 +279,8 @@ class ManipulateEnv(hand_env.HandEnv): def _get_obs(self): robot_qpos, robot_qvel = robot_get_obs(self.sim) object_qvel = self.sim.data.get_joint_qvel("object:joint") - achieved_goal = ( - self._get_achieved_goal().ravel() - ) # this contains the object position + rotation - observation = np.concatenate( - [robot_qpos, robot_qvel, object_qvel, achieved_goal] - ) + achieved_goal = self._get_achieved_goal().ravel() # this contains the object position + rotation + observation = np.concatenate([robot_qpos, robot_qvel, object_qvel, achieved_goal]) return { "observation": observation.copy(), "achieved_goal": achieved_goal.copy(), @@ -307,9 +289,7 @@ class ManipulateEnv(hand_env.HandEnv): class HandBlockEnv(ManipulateEnv, utils.EzPickle): - def __init__( - self, target_position="random", target_rotation="xyz", reward_type="sparse" - ): + def __init__(self, target_position="random", target_rotation="xyz", reward_type="sparse"): utils.EzPickle.__init__(self, target_position, target_rotation, reward_type) ManipulateEnv.__init__( self, @@ -322,9 +302,7 @@ class HandBlockEnv(ManipulateEnv, utils.EzPickle): class HandEggEnv(ManipulateEnv, utils.EzPickle): - def __init__( - self, target_position="random", target_rotation="xyz", reward_type="sparse" - ): + def __init__(self, target_position="random", target_rotation="xyz", reward_type="sparse"): utils.EzPickle.__init__(self, target_position, target_rotation, reward_type) ManipulateEnv.__init__( self, @@ -337,9 +315,7 @@ class HandEggEnv(ManipulateEnv, utils.EzPickle): class HandPenEnv(ManipulateEnv, utils.EzPickle): - def __init__( - self, target_position="random", target_rotation="xyz", reward_type="sparse" - ): + def __init__(self, target_position="random", target_rotation="xyz", reward_type="sparse"): utils.EzPickle.__init__(self, target_position, target_rotation, reward_type) ManipulateEnv.__init__( self, diff --git a/gym/envs/robotics/hand/manipulate_touch_sensors.py b/gym/envs/robotics/hand/manipulate_touch_sensors.py index 63d49c95a..76e7cb687 100644 --- a/gym/envs/robotics/hand/manipulate_touch_sensors.py +++ b/gym/envs/robotics/hand/manipulate_touch_sensors.py @@ -70,16 +70,12 @@ class ManipulateTouchSensorsEnv(manipulate.ManipulateEnv): for ( k, v, - ) in ( - self.sim.model._sensor_name2id.items() - ): # get touch sensor site names and their ids + ) in self.sim.model._sensor_name2id.items(): # get touch sensor site names and their ids if "robot0:TS_" in k: self._touch_sensor_id_site_id.append( ( v, - self.sim.model._site_name2id[ - k.replace("robot0:TS_", "robot0:T_") - ], + self.sim.model._site_name2id[k.replace("robot0:TS_", "robot0:T_")], ) ) self._touch_sensor_id.append(v) @@ -93,15 +89,9 @@ class ManipulateTouchSensorsEnv(manipulate.ManipulateEnv): obs = self._get_obs() self.observation_space = spaces.Dict( dict( - desired_goal=spaces.Box( - -np.inf, np.inf, shape=obs["achieved_goal"].shape, dtype="float32" - ), - achieved_goal=spaces.Box( - -np.inf, np.inf, shape=obs["achieved_goal"].shape, dtype="float32" - ), - observation=spaces.Box( - -np.inf, np.inf, shape=obs["observation"].shape, dtype="float32" - ), + desired_goal=spaces.Box(-np.inf, np.inf, shape=obs["achieved_goal"].shape, dtype="float32"), + achieved_goal=spaces.Box(-np.inf, np.inf, shape=obs["achieved_goal"].shape, dtype="float32"), + observation=spaces.Box(-np.inf, np.inf, shape=obs["observation"].shape, dtype="float32"), ) ) @@ -117,9 +107,7 @@ class ManipulateTouchSensorsEnv(manipulate.ManipulateEnv): def _get_obs(self): robot_qpos, robot_qvel = manipulate.robot_get_obs(self.sim) object_qvel = self.sim.data.get_joint_qvel("object:joint") - achieved_goal = ( - self._get_achieved_goal().ravel() - ) # this contains the object position + rotation + achieved_goal = self._get_achieved_goal().ravel() # this contains the object position + rotation touch_values = [] # get touch sensor readings. if there is one, set value to 1 if self.touch_get_obs == "sensordata": touch_values = self.sim.data.sensordata[self._touch_sensor_id] @@ -127,9 +115,7 @@ class ManipulateTouchSensorsEnv(manipulate.ManipulateEnv): touch_values = self.sim.data.sensordata[self._touch_sensor_id] > 0.0 elif self.touch_get_obs == "log": touch_values = np.log(self.sim.data.sensordata[self._touch_sensor_id] + 1.0) - observation = np.concatenate( - [robot_qpos, robot_qvel, object_qvel, touch_values, achieved_goal] - ) + observation = np.concatenate([robot_qpos, robot_qvel, object_qvel, touch_values, achieved_goal]) return { "observation": observation.copy(), @@ -146,9 +132,7 @@ class HandBlockTouchSensorsEnv(ManipulateTouchSensorsEnv, utils.EzPickle): touch_get_obs="sensordata", reward_type="sparse", ): - utils.EzPickle.__init__( - self, target_position, target_rotation, touch_get_obs, reward_type - ) + utils.EzPickle.__init__(self, target_position, target_rotation, touch_get_obs, reward_type) ManipulateTouchSensorsEnv.__init__( self, model_path=MANIPULATE_BLOCK_XML, @@ -168,9 +152,7 @@ class HandEggTouchSensorsEnv(ManipulateTouchSensorsEnv, utils.EzPickle): touch_get_obs="sensordata", reward_type="sparse", ): - utils.EzPickle.__init__( - self, target_position, target_rotation, touch_get_obs, reward_type - ) + utils.EzPickle.__init__(self, target_position, target_rotation, touch_get_obs, reward_type) ManipulateTouchSensorsEnv.__init__( self, model_path=MANIPULATE_EGG_XML, @@ -190,9 +172,7 @@ class HandPenTouchSensorsEnv(ManipulateTouchSensorsEnv, utils.EzPickle): touch_get_obs="sensordata", reward_type="sparse", ): - utils.EzPickle.__init__( - self, target_position, target_rotation, touch_get_obs, reward_type - ) + utils.EzPickle.__init__(self, target_position, target_rotation, touch_get_obs, reward_type) ManipulateTouchSensorsEnv.__init__( self, model_path=MANIPULATE_PEN_XML, diff --git a/gym/envs/robotics/hand/reach.py b/gym/envs/robotics/hand/reach.py index f09e1673b..25eda013c 100644 --- a/gym/envs/robotics/hand/reach.py +++ b/gym/envs/robotics/hand/reach.py @@ -96,9 +96,7 @@ class HandReachEnv(hand_env.HandEnv, utils.EzPickle): self.sim.forward() self.initial_goal = self._get_achieved_goal().copy() - self.palm_xpos = self.sim.data.body_xpos[ - self.sim.model.body_name2id("robot0:palm") - ].copy() + self.palm_xpos = self.sim.data.body_xpos[self.sim.model.body_name2id("robot0:palm")].copy() def _get_obs(self): robot_qpos, robot_qvel = robot_get_obs(self.sim) @@ -155,7 +153,5 @@ class HandReachEnv(hand_env.HandEnv, utils.EzPickle): for finger_idx in range(5): site_name = "finger{}".format(finger_idx) site_id = self.sim.model.site_name2id(site_name) - self.sim.model.site_pos[site_id] = ( - achieved_goal[finger_idx] - sites_offset[site_id] - ) + self.sim.model.site_pos[site_id] = achieved_goal[finger_idx] - sites_offset[site_id] self.sim.forward() diff --git a/gym/envs/robotics/hand_env.py b/gym/envs/robotics/hand_env.py index aaa0fec1d..5d0f446ce 100644 --- a/gym/envs/robotics/hand_env.py +++ b/gym/envs/robotics/hand_env.py @@ -30,22 +30,14 @@ class HandEnv(robot_env.RobotEnv): if self.relative_control: actuation_center = np.zeros_like(action) for i in range(self.sim.data.ctrl.shape[0]): - actuation_center[i] = self.sim.data.get_joint_qpos( - self.sim.model.actuator_names[i].replace(":A_", ":") - ) + actuation_center[i] = self.sim.data.get_joint_qpos(self.sim.model.actuator_names[i].replace(":A_", ":")) for joint_name in ["FF", "MF", "RF", "LF"]: - act_idx = self.sim.model.actuator_name2id( - "robot0:A_{}J1".format(joint_name) - ) - actuation_center[act_idx] += self.sim.data.get_joint_qpos( - "robot0:{}J0".format(joint_name) - ) + act_idx = self.sim.model.actuator_name2id("robot0:A_{}J1".format(joint_name)) + actuation_center[act_idx] += self.sim.data.get_joint_qpos("robot0:{}J0".format(joint_name)) else: actuation_center = (ctrlrange[:, 1] + ctrlrange[:, 0]) / 2.0 self.sim.data.ctrl[:] = actuation_center + action * actuation_range - self.sim.data.ctrl[:] = np.clip( - self.sim.data.ctrl, ctrlrange[:, 0], ctrlrange[:, 1] - ) + self.sim.data.ctrl[:] = np.clip(self.sim.data.ctrl, ctrlrange[:, 0], ctrlrange[:, 1]) def _viewer_setup(self): body_id = self.sim.model.body_name2id("robot0:palm") diff --git a/gym/envs/robotics/robot_env.py b/gym/envs/robotics/robot_env.py index 1fe6ca653..19e5145cc 100644 --- a/gym/envs/robotics/robot_env.py +++ b/gym/envs/robotics/robot_env.py @@ -46,15 +46,9 @@ class RobotEnv(gym.GoalEnv): self.action_space = spaces.Box(-1.0, 1.0, shape=(n_actions,), dtype="float32") self.observation_space = spaces.Dict( dict( - desired_goal=spaces.Box( - -np.inf, np.inf, shape=obs["achieved_goal"].shape, dtype="float32" - ), - achieved_goal=spaces.Box( - -np.inf, np.inf, shape=obs["achieved_goal"].shape, dtype="float32" - ), - observation=spaces.Box( - -np.inf, np.inf, shape=obs["observation"].shape, dtype="float32" - ), + desired_goal=spaces.Box(-np.inf, np.inf, shape=obs["achieved_goal"].shape, dtype="float32"), + achieved_goal=spaces.Box(-np.inf, np.inf, shape=obs["achieved_goal"].shape, dtype="float32"), + observation=spaces.Box(-np.inf, np.inf, shape=obs["observation"].shape, dtype="float32"), ) ) diff --git a/gym/envs/robotics/rotations.py b/gym/envs/robotics/rotations.py index fe241a108..7a5891499 100644 --- a/gym/envs/robotics/rotations.py +++ b/gym/envs/robotics/rotations.py @@ -164,12 +164,8 @@ def mat2euler(mat): -np.arctan2(mat[..., 0, 1], mat[..., 0, 0]), -np.arctan2(-mat[..., 1, 0], mat[..., 1, 1]), ) - euler[..., 1] = np.where( - condition, -np.arctan2(-mat[..., 0, 2], cy), -np.arctan2(-mat[..., 0, 2], cy) - ) - euler[..., 0] = np.where( - condition, -np.arctan2(mat[..., 1, 2], mat[..., 2, 2]), 0.0 - ) + euler[..., 1] = np.where(condition, -np.arctan2(-mat[..., 0, 2], cy), -np.arctan2(-mat[..., 0, 2], cy)) + euler[..., 0] = np.where(condition, -np.arctan2(mat[..., 1, 2], mat[..., 2, 2]), 0.0) return euler diff --git a/gym/envs/robotics/utils.py b/gym/envs/robotics/utils.py index baa767200..500031bc2 100644 --- a/gym/envs/robotics/utils.py +++ b/gym/envs/robotics/utils.py @@ -75,15 +75,9 @@ def reset_mocap2body_xpos(sim): values as the bodies they're welded to. """ - if ( - sim.model.eq_type is None - or sim.model.eq_obj1id is None - or sim.model.eq_obj2id is None - ): + if sim.model.eq_type is None or sim.model.eq_obj1id is None or sim.model.eq_obj2id is None: return - for eq_type, obj1_id, obj2_id in zip( - sim.model.eq_type, sim.model.eq_obj1id, sim.model.eq_obj2id - ): + for eq_type, obj1_id, obj2_id in zip(sim.model.eq_type, sim.model.eq_obj1id, sim.model.eq_obj2id): if eq_type != mujoco_py.const.EQ_WELD: continue diff --git a/gym/envs/tests/spec_list.py b/gym/envs/tests/spec_list.py index 5ebda7f64..bf2aeef7a 100644 --- a/gym/envs/tests/spec_list.py +++ b/gym/envs/tests/spec_list.py @@ -2,10 +2,7 @@ from gym import envs, logger import os -SKIP_MUJOCO_WARNING_MESSAGE = ( - "Cannot run mujoco test (either license key not found or mujoco not" - "installed properly)." -) +SKIP_MUJOCO_WARNING_MESSAGE = "Cannot run mujoco test (either license key not found or mujoco not" "installed properly)." skip_mujoco = not (os.environ.get("MUJOCO_KEY")) @@ -21,9 +18,7 @@ def should_skip_env_spec_for_tests(spec): # troublesome to run frequently ep = spec.entry_point # Skip mujoco tests for pull request CI - if skip_mujoco and ( - ep.startswith("gym.envs.mujoco") or ep.startswith("gym.envs.robotics:") - ): + if skip_mujoco and (ep.startswith("gym.envs.mujoco") or ep.startswith("gym.envs.robotics:")): return True try: import atari_py @@ -39,11 +34,7 @@ def should_skip_env_spec_for_tests(spec): if ( "GoEnv" in ep or "HexEnv" in ep - or ( - ep.startswith("gym.envs.atari") - and not spec.id.startswith("Pong") - and not spec.id.startswith("Seaquest") - ) + or (ep.startswith("gym.envs.atari") and not spec.id.startswith("Pong") and not spec.id.startswith("Seaquest")) ): logger.warn("Skipping tests for env {}".format(ep)) return True diff --git a/gym/envs/tests/test_determinism.py b/gym/envs/tests/test_determinism.py index 1505c559d..5d2eefb1c 100644 --- a/gym/envs/tests/test_determinism.py +++ b/gym/envs/tests/test_determinism.py @@ -25,9 +25,7 @@ def test_env(spec): step_responses2 = [env2.step(action) for action in action_samples2] env2.close() - for i, (action_sample1, action_sample2) in enumerate( - zip(action_samples1, action_samples2) - ): + for i, (action_sample1, action_sample2) in enumerate(zip(action_samples1, action_samples2)): try: assert_equals(action_sample1, action_sample2) except AssertionError: @@ -35,11 +33,7 @@ def test_env(spec): print("env2.action_space=", env2.action_space) print("action_samples1=", action_samples1) print("action_samples2=", action_samples2) - print( - "[{}] action_sample1: {}, action_sample2: {}".format( - i, action_sample1, action_sample2 - ) - ) + print("[{}] action_sample1: {}, action_sample2: {}".format(i, action_sample1, action_sample2)) raise # Don't check rollout equality if it's a a nondeterministic @@ -49,9 +43,7 @@ def test_env(spec): assert_equals(initial_observation1, initial_observation2) - for i, ((o1, r1, d1, i1), (o2, r2, d2, i2)) in enumerate( - zip(step_responses1, step_responses2) - ): + for i, ((o1, r1, d1, i1), (o2, r2, d2, i2)) in enumerate(zip(step_responses1, step_responses2)): assert_equals(o1, o2, "[{}] ".format(i)) assert r1 == r2, "[{}] r1: {}, r2: {}".format(i, r1, r2) assert d1 == d2, "[{}] d1: {}, d2: {}".format(i, d1, d2) @@ -66,9 +58,7 @@ def test_env(spec): def assert_equals(a, b, prefix=None): assert type(a) == type(b), "{}Differing types: {} and {}".format(prefix, a, b) if isinstance(a, dict): - assert list(a.keys()) == list(b.keys()), "{}Key sets differ: {} and {}".format( - prefix, a, b - ) + assert list(a.keys()) == list(b.keys()), "{}Key sets differ: {} and {}".format(prefix, a, b) for k in a.keys(): v_a = a[k] diff --git a/gym/envs/tests/test_envs.py b/gym/envs/tests/test_envs.py index 61dc0a5fa..17b3c006e 100644 --- a/gym/envs/tests/test_envs.py +++ b/gym/envs/tests/test_envs.py @@ -24,9 +24,7 @@ def test_env(spec): assert ob_space.contains(ob), "Reset observation: {!r} not in space".format(ob) a = act_space.sample() observation, reward, done, _info = env.step(a) - assert ob_space.contains(observation), "Step observation: {!r} not in space".format( - observation - ) + assert ob_space.contains(observation), "Step observation: {!r} not in space".format(observation) assert np.isscalar(reward), "{} is not a scalar for {}".format(reward, env) assert isinstance(done, bool), "Expected {} to be a boolean".format(done) diff --git a/gym/envs/tests/test_envs_semantics.py b/gym/envs/tests/test_envs_semantics.py index b82e2eca3..e2c272d5b 100644 --- a/gym/envs/tests/test_envs_semantics.py +++ b/gym/envs/tests/test_envs_semantics.py @@ -81,9 +81,7 @@ def test_env_semantics(spec): if spec.id not in rollout_dict: if not spec.nondeterministic: logger.warn( - "Rollout does not exist for {}, run generate_json.py to generate rollouts for new envs".format( - spec.id - ) + "Rollout does not exist for {}, run generate_json.py to generate rollouts for new envs".format(spec.id) ) return @@ -100,21 +98,15 @@ def test_env_semantics(spec): ) if rollout_dict[spec.id]["actions"] != actions_now: errors.append( - "Actions not equal for {} -- expected {} but got {}".format( - spec.id, rollout_dict[spec.id]["actions"], actions_now - ) + "Actions not equal for {} -- expected {} but got {}".format(spec.id, rollout_dict[spec.id]["actions"], actions_now) ) if rollout_dict[spec.id]["rewards"] != rewards_now: errors.append( - "Rewards not equal for {} -- expected {} but got {}".format( - spec.id, rollout_dict[spec.id]["rewards"], rewards_now - ) + "Rewards not equal for {} -- expected {} but got {}".format(spec.id, rollout_dict[spec.id]["rewards"], rewards_now) ) if rollout_dict[spec.id]["dones"] != dones_now: errors.append( - "Dones not equal for {} -- expected {} but got {}".format( - spec.id, rollout_dict[spec.id]["dones"], dones_now - ) + "Dones not equal for {} -- expected {} but got {}".format(spec.id, rollout_dict[spec.id]["dones"], dones_now) ) if len(errors): for error in errors: diff --git a/gym/envs/tests/test_mujoco_v2_to_v3_conversion.py b/gym/envs/tests/test_mujoco_v2_to_v3_conversion.py index a379a8454..fa8986406 100644 --- a/gym/envs/tests/test_mujoco_v2_to_v3_conversion.py +++ b/gym/envs/tests/test_mujoco_v2_to_v3_conversion.py @@ -4,9 +4,7 @@ from gym import envs from gym.envs.tests.spec_list import skip_mujoco, SKIP_MUJOCO_WARNING_MESSAGE -def verify_environments_match( - old_environment_id, new_environment_id, seed=1, num_actions=1000 -): +def verify_environments_match(old_environment_id, new_environment_id, seed=1, num_actions=1000): old_environment = envs.make(old_environment_id) new_environment = envs.make(new_environment_id) diff --git a/gym/envs/tests/test_registration.py b/gym/envs/tests/test_registration.py index 3dbf58c53..a363e2afc 100644 --- a/gym/envs/tests/test_registration.py +++ b/gym/envs/tests/test_registration.py @@ -83,8 +83,6 @@ def test_malformed_lookup(): try: registry.spec(u"“Breakout-v0”") except error.Error as e: - assert "malformed environment ID" in "{}".format( - e - ), "Unexpected message: {}".format(e) + assert "malformed environment ID" in "{}".format(e), "Unexpected message: {}".format(e) else: assert False diff --git a/gym/envs/toy_text/blackjack.py b/gym/envs/toy_text/blackjack.py index 7b8b1d07d..1bd2576d5 100644 --- a/gym/envs/toy_text/blackjack.py +++ b/gym/envs/toy_text/blackjack.py @@ -75,9 +75,7 @@ class BlackjackEnv(gym.Env): def __init__(self, natural=False): self.action_space = spaces.Discrete(2) - self.observation_space = spaces.Tuple( - (spaces.Discrete(32), spaces.Discrete(11), spaces.Discrete(2)) - ) + self.observation_space = spaces.Tuple((spaces.Discrete(32), spaces.Discrete(11), spaces.Discrete(2))) self.seed() # Flag to payout 1.5 on a "natural" blackjack win, like casino rules diff --git a/gym/envs/toy_text/frozen_lake.py b/gym/envs/toy_text/frozen_lake.py index 1b6de5985..509d56300 100644 --- a/gym/envs/toy_text/frozen_lake.py +++ b/gym/envs/toy_text/frozen_lake.py @@ -141,9 +141,7 @@ class FrozenLakeEnv(discrete.DiscreteEnv): else: if is_slippery: for b in [(a - 1) % 4, a, (a + 1) % 4]: - li.append( - (1.0 / 3.0, *update_probability_matrix(row, col, b)) - ) + li.append((1.0 / 3.0, *update_probability_matrix(row, col, b))) else: li.append((1.0, *update_probability_matrix(row, col, a))) @@ -157,9 +155,7 @@ class FrozenLakeEnv(discrete.DiscreteEnv): desc = [[c.decode("utf-8") for c in line] for line in desc] desc[row][col] = utils.colorize(desc[row][col], "red", highlight=True) if self.lastaction is not None: - outfile.write( - " ({})\n".format(["Left", "Down", "Right", "Up"][self.lastaction]) - ) + outfile.write(" ({})\n".format(["Left", "Down", "Right", "Up"][self.lastaction])) else: outfile.write("\n") outfile.write("\n".join("".join(line) for line in desc) + "\n") diff --git a/gym/envs/toy_text/guessing_game.py b/gym/envs/toy_text/guessing_game.py index 79854f179..699147b1d 100644 --- a/gym/envs/toy_text/guessing_game.py +++ b/gym/envs/toy_text/guessing_game.py @@ -80,11 +80,7 @@ class GuessingGame(gym.Env): reward = 0 done = False - if ( - (self.number - self.range * 0.01) - < action - < (self.number + self.range * 0.01) - ): + if (self.number - self.range * 0.01) < action < (self.number + self.range * 0.01): reward = 1 done = True diff --git a/gym/envs/toy_text/hotter_colder.py b/gym/envs/toy_text/hotter_colder.py index b427d9c3a..055a498d2 100644 --- a/gym/envs/toy_text/hotter_colder.py +++ b/gym/envs/toy_text/hotter_colder.py @@ -62,10 +62,7 @@ class HotterColder(gym.Env): elif action > self.number: self.observation = 3 - reward = ( - (min(action, self.number) + self.bounds) - / (max(action, self.number) + self.bounds) - ) ** 2 + reward = ((min(action, self.number) + self.bounds) / (max(action, self.number) + self.bounds)) ** 2 self.guess_count += 1 done = self.guess_count >= self.guess_max diff --git a/gym/envs/toy_text/kellycoinflip.py b/gym/envs/toy_text/kellycoinflip.py index 1a47c0ae6..cb6f7f6af 100644 --- a/gym/envs/toy_text/kellycoinflip.py +++ b/gym/envs/toy_text/kellycoinflip.py @@ -65,9 +65,7 @@ class KellyCoinflipEnv(gym.Env): return [seed] def step(self, action): - bet_in_dollars = min( - action / 100.0, self.wealth - ) # action = desired bet in pennies + bet_in_dollars = min(action / 100.0, self.wealth) # action = desired bet in pennies self.rounds -= 1 coinflip = flip(self.edge, self.np_random) @@ -149,35 +147,19 @@ class KellyCoinflipGeneralizedEnv(gym.Env): edge = self.np_random.beta(edge_prior_alpha, edge_prior_beta) if self.clip_distributions: # (clip/resample some parameters to be able to fix obs/action space sizes/bounds) - max_wealth_bound = round( - genpareto.ppf(0.85, max_wealth_alpha, max_wealth_m) - ) + max_wealth_bound = round(genpareto.ppf(0.85, max_wealth_alpha, max_wealth_m)) max_wealth = max_wealth_bound + 1.0 while max_wealth > max_wealth_bound: - max_wealth = round( - genpareto.rvs( - max_wealth_alpha, max_wealth_m, random_state=self.np_random - ) - ) - max_rounds_bound = int( - round(norm.ppf(0.99, max_rounds_mean, max_rounds_sd)) - ) + max_wealth = round(genpareto.rvs(max_wealth_alpha, max_wealth_m, random_state=self.np_random)) + max_rounds_bound = int(round(norm.ppf(0.99, max_rounds_mean, max_rounds_sd))) max_rounds = max_rounds_bound + 1 while max_rounds > max_rounds_bound: - max_rounds = int( - round(self.np_random.normal(max_rounds_mean, max_rounds_sd)) - ) + max_rounds = int(round(self.np_random.normal(max_rounds_mean, max_rounds_sd))) else: - max_wealth = round( - genpareto.rvs( - max_wealth_alpha, max_wealth_m, random_state=self.np_random - ) - ) + max_wealth = round(genpareto.rvs(max_wealth_alpha, max_wealth_m, random_state=self.np_random)) max_wealth_bound = max_wealth - max_rounds = int( - round(self.np_random.normal(max_rounds_mean, max_rounds_sd)) - ) + max_rounds = int(round(self.np_random.normal(max_rounds_mean, max_rounds_sd))) max_rounds_bound = max_rounds # add an additional global variable which is the sufficient statistic for the @@ -194,9 +176,7 @@ class KellyCoinflipGeneralizedEnv(gym.Env): self.action_space = spaces.Discrete(int(max_wealth_bound * 100)) self.observation_space = spaces.Tuple( ( - spaces.Box( - 0, max_wealth_bound, shape=[1], dtype=np.float32 - ), # current wealth + spaces.Box(0, max_wealth_bound, shape=[1], dtype=np.float32), # current wealth spaces.Discrete(max_rounds_bound + 1), # rounds elapsed spaces.Discrete(max_rounds_bound + 1), # wins spaces.Discrete(max_rounds_bound + 1), # losses diff --git a/gym/envs/toy_text/taxi.py b/gym/envs/toy_text/taxi.py index e9b729da4..c139b719e 100644 --- a/gym/envs/toy_text/taxi.py +++ b/gym/envs/toy_text/taxi.py @@ -80,10 +80,7 @@ class TaxiEnv(discrete.DiscreteEnv): max_col = num_columns - 1 initial_state_distrib = np.zeros(num_states) num_actions = 6 - P = { - state: {action: [] for action in range(num_actions)} - for state in range(num_states) - } + P = {state: {action: [] for action in range(num_actions)} for state in range(num_states)} for row in range(num_rows): for col in range(num_columns): for pass_idx in range(len(locs) + 1): # +1 for being inside taxi @@ -94,9 +91,7 @@ class TaxiEnv(discrete.DiscreteEnv): for action in range(num_actions): # defaults new_row, new_col, new_pass_idx = row, col, pass_idx - reward = ( - -1 - ) # default reward when there is no pickup/dropoff + reward = -1 # default reward when there is no pickup/dropoff done = False taxi_loc = (row, col) @@ -122,14 +117,10 @@ class TaxiEnv(discrete.DiscreteEnv): new_pass_idx = locs.index(taxi_loc) else: # dropoff at wrong location reward = -10 - new_state = self.encode( - new_row, new_col, new_pass_idx, dest_idx - ) + new_state = self.encode(new_row, new_col, new_pass_idx, dest_idx) P[state][action].append((1.0, new_state, reward, done)) initial_state_distrib /= initial_state_distrib.sum() - discrete.DiscreteEnv.__init__( - self, num_states, num_actions, P, initial_state_distrib - ) + discrete.DiscreteEnv.__init__(self, num_states, num_actions, P, initial_state_distrib) def encode(self, taxi_row, taxi_col, pass_loc, dest_idx): # (5) 5, 5, 4 @@ -165,13 +156,9 @@ class TaxiEnv(discrete.DiscreteEnv): return "_" if x == " " else x if pass_idx < 4: - out[1 + taxi_row][2 * taxi_col + 1] = utils.colorize( - out[1 + taxi_row][2 * taxi_col + 1], "yellow", highlight=True - ) + out[1 + taxi_row][2 * taxi_col + 1] = utils.colorize(out[1 + taxi_row][2 * taxi_col + 1], "yellow", highlight=True) pi, pj = self.locs[pass_idx] - out[1 + pi][2 * pj + 1] = utils.colorize( - out[1 + pi][2 * pj + 1], "blue", bold=True - ) + out[1 + pi][2 * pj + 1] = utils.colorize(out[1 + pi][2 * pj + 1], "blue", bold=True) else: # passenger in taxi out[1 + taxi_row][2 * taxi_col + 1] = utils.colorize( ul(out[1 + taxi_row][2 * taxi_col + 1]), "green", highlight=True @@ -181,13 +168,7 @@ class TaxiEnv(discrete.DiscreteEnv): out[1 + di][2 * dj + 1] = utils.colorize(out[1 + di][2 * dj + 1], "magenta") outfile.write("\n".join(["".join(row) for row in out]) + "\n") if self.lastaction is not None: - outfile.write( - " ({})\n".format( - ["South", "North", "East", "West", "Pickup", "Dropoff"][ - self.lastaction - ] - ) - ) + outfile.write(" ({})\n".format(["South", "North", "East", "West", "Pickup", "Dropoff"][self.lastaction])) else: outfile.write("\n") diff --git a/gym/envs/unittest/cube_crash.py b/gym/envs/unittest/cube_crash.py index 61a7f5313..ae288fa22 100644 --- a/gym/envs/unittest/cube_crash.py +++ b/gym/envs/unittest/cube_crash.py @@ -55,9 +55,7 @@ class CubeCrash(gym.Env): self.seed() self.viewer = None - self.observation_space = spaces.Box( - 0, 255, (FIELD_H, FIELD_W, 3), dtype=np.uint8 - ) + self.observation_space = spaces.Box(0, 255, (FIELD_H, FIELD_W, 3), dtype=np.uint8) self.action_space = spaces.Discrete(3) self.reset() @@ -83,16 +81,9 @@ class CubeCrash(gym.Env): self.potential = None self.step_n = 0 while 1: - self.wall_color = ( - self.random_color() if self.use_random_colors else color_white - ) - self.cube_color = ( - self.random_color() if self.use_random_colors else color_green - ) - if ( - np.linalg.norm(self.wall_color - self.bg_color) < 50 - or np.linalg.norm(self.cube_color - self.bg_color) < 50 - ): + self.wall_color = self.random_color() if self.use_random_colors else color_white + self.cube_color = self.random_color() if self.use_random_colors else color_green + if np.linalg.norm(self.wall_color - self.bg_color) < 50 or np.linalg.norm(self.cube_color - self.bg_color) < 50: continue break return self.step(0)[0] @@ -117,9 +108,7 @@ class CubeCrash(gym.Env): self.hole_x - HOLE_WIDTH // 2 : self.hole_x + HOLE_WIDTH // 2 + 1, :, ] = self.bg_color - obs[ - self.cube_y - 1 : self.cube_y + 2, self.cube_x - 1 : self.cube_x + 2, : - ] = self.cube_color + obs[self.cube_y - 1 : self.cube_y + 2, self.cube_x - 1 : self.cube_x + 2, :] = self.cube_color if self.use_black_screen and self.step_n > 4: obs[:] = np.zeros((3,), dtype=np.uint8) diff --git a/gym/envs/unittest/memorize_digits.py b/gym/envs/unittest/memorize_digits.py index 6b48c9f5f..e1f005a4d 100644 --- a/gym/envs/unittest/memorize_digits.py +++ b/gym/envs/unittest/memorize_digits.py @@ -62,16 +62,12 @@ class MemorizeDigits(gym.Env): def __init__(self): self.seed() self.viewer = None - self.observation_space = spaces.Box( - 0, 255, (FIELD_H, FIELD_W, 3), dtype=np.uint8 - ) + self.observation_space = spaces.Box(0, 255, (FIELD_H, FIELD_W, 3), dtype=np.uint8) self.action_space = spaces.Discrete(10) self.bogus_mnist = np.zeros((10, 6, 6), dtype=np.uint8) for digit in range(10): for y in range(6): - self.bogus_mnist[digit, y, :] = [ - ord(char) for char in bogus_mnist[digit][y] - ] + self.bogus_mnist[digit, y, :] = [ord(char) for char in bogus_mnist[digit][y]] self.reset() def seed(self, seed=None): @@ -93,9 +89,7 @@ class MemorizeDigits(gym.Env): self.color_bg = self.random_color() if self.use_random_colors else color_black self.step_n = 0 while 1: - self.color_digit = ( - self.random_color() if self.use_random_colors else color_white - ) + self.color_digit = self.random_color() if self.use_random_colors else color_white if np.linalg.norm(self.color_digit - self.color_bg) < 50: continue break @@ -119,9 +113,7 @@ class MemorizeDigits(gym.Env): digit_img[:] = self.color_bg xxx = self.bogus_mnist[self.digit] == 42 digit_img[xxx] = self.color_digit - obs[ - self.digit_y - 3 : self.digit_y + 3, self.digit_x - 3 : self.digit_x + 3 - ] = digit_img + obs[self.digit_y - 3 : self.digit_y + 3, self.digit_x - 3 : self.digit_x + 3] = digit_img self.last_obs = obs return obs, reward, done, {} diff --git a/gym/error.py b/gym/error.py index 5884e5911..16fa6d35f 100644 --- a/gym/error.py +++ b/gym/error.py @@ -102,10 +102,7 @@ class APIError(Error): try: http_body = http_body.decode("utf-8") except: - http_body = ( - "" - ) + http_body = "" self._message = message self.http_body = http_body @@ -142,9 +139,7 @@ class InvalidRequestError(APIError): json_body=None, headers=None, ): - super(InvalidRequestError, self).__init__( - message, http_body, http_status, json_body, headers - ) + super(InvalidRequestError, self).__init__(message, http_body, http_status, json_body, headers) self.param = param diff --git a/gym/spaces/box.py b/gym/spaces/box.py index 93c78aa32..adebc596e 100644 --- a/gym/spaces/box.py +++ b/gym/spaces/box.py @@ -29,26 +29,16 @@ class Box(Space): # determine shape if it isn't provided directly if shape is not None: shape = tuple(shape) - assert ( - np.isscalar(low) or low.shape == shape - ), "low.shape doesn't match provided shape" - assert ( - np.isscalar(high) or high.shape == shape - ), "high.shape doesn't match provided shape" + assert np.isscalar(low) or low.shape == shape, "low.shape doesn't match provided shape" + assert np.isscalar(high) or high.shape == shape, "high.shape doesn't match provided shape" elif not np.isscalar(low): shape = low.shape - assert ( - np.isscalar(high) or high.shape == shape - ), "high.shape doesn't match low.shape" + assert np.isscalar(high) or high.shape == shape, "high.shape doesn't match low.shape" elif not np.isscalar(high): shape = high.shape - assert ( - np.isscalar(low) or low.shape == shape - ), "low.shape doesn't match high.shape" + assert np.isscalar(low) or low.shape == shape, "low.shape doesn't match high.shape" else: - raise ValueError( - "shape must be provided or inferred from the shapes of low or high" - ) + raise ValueError("shape must be provided or inferred from the shapes of low or high") if np.isscalar(low): low = np.full(shape, low, dtype=dtype) @@ -70,9 +60,7 @@ class Box(Space): high_precision = _get_precision(self.high.dtype) dtype_precision = _get_precision(self.dtype) if min(low_precision, high_precision) > dtype_precision: - logger.warn( - "Box bound precision lowered by casting to {}".format(self.dtype) - ) + logger.warn("Box bound precision lowered by casting to {}".format(self.dtype)) self.low = self.low.astype(self.dtype) self.high = self.high.astype(self.dtype) @@ -119,19 +107,11 @@ class Box(Space): # Vectorized sampling by interval type sample[unbounded] = self.np_random.normal(size=unbounded[unbounded].shape) - sample[low_bounded] = ( - self.np_random.exponential(size=low_bounded[low_bounded].shape) - + self.low[low_bounded] - ) + sample[low_bounded] = self.np_random.exponential(size=low_bounded[low_bounded].shape) + self.low[low_bounded] - sample[upp_bounded] = ( - -self.np_random.exponential(size=upp_bounded[upp_bounded].shape) - + self.high[upp_bounded] - ) + sample[upp_bounded] = -self.np_random.exponential(size=upp_bounded[upp_bounded].shape) + self.high[upp_bounded] - sample[bounded] = self.np_random.uniform( - low=self.low[bounded], high=high[bounded], size=bounded[bounded].shape - ) + sample[bounded] = self.np_random.uniform(low=self.low[bounded], high=high[bounded], size=bounded[bounded].shape) if self.dtype.kind == "i": sample = np.floor(sample) @@ -140,9 +120,7 @@ class Box(Space): def contains(self, x): if isinstance(x, list): x = np.array(x) # Promote list to array for contains check - return ( - x.shape == self.shape and np.all(x >= self.low) and np.all(x <= self.high) - ) + return x.shape == self.shape and np.all(x >= self.low) and np.all(x <= self.high) def to_jsonable(self, sample_n): return np.array(sample_n).tolist() @@ -151,9 +129,7 @@ class Box(Space): return [np.asarray(sample) for sample in sample_n] def __repr__(self): - return "Box({}, {}, {}, {})".format( - self.low.min(), self.high.max(), self.shape, self.dtype - ) + return "Box({}, {}, {}, {})".format(self.low.min(), self.high.max(), self.shape, self.dtype) def __eq__(self, other): return ( diff --git a/gym/spaces/dict.py b/gym/spaces/dict.py index 67926f444..267a3ff2b 100644 --- a/gym/spaces/dict.py +++ b/gym/spaces/dict.py @@ -33,9 +33,7 @@ class Dict(Space): """ def __init__(self, spaces=None, **spaces_kwargs): - assert (spaces is None) or ( - not spaces_kwargs - ), "Use either Dict(spaces=dict(...)) or Dict(foo=x, bar=z)" + assert (spaces is None) or (not spaces_kwargs), "Use either Dict(spaces=dict(...)) or Dict(foo=x, bar=z)" if spaces is None: spaces = spaces_kwargs if isinstance(spaces, dict) and not isinstance(spaces, OrderedDict): @@ -44,12 +42,8 @@ class Dict(Space): spaces = OrderedDict(spaces) self.spaces = spaces for space in spaces.values(): - assert isinstance( - space, Space - ), "Values of the dict should be instances of gym.Space" - super(Dict, self).__init__( - None, None - ) # None for shape and dtype, since it'll require special handling + assert isinstance(space, Space), "Values of the dict should be instances of gym.Space" + super(Dict, self).__init__(None, None) # None for shape and dtype, since it'll require special handling def seed(self, seed=None): [space.seed(seed) for space in self.spaces.values()] @@ -75,18 +69,11 @@ class Dict(Space): yield key def __repr__(self): - return ( - "Dict(" - + ", ".join([str(k) + ":" + str(s) for k, s in self.spaces.items()]) - + ")" - ) + return "Dict(" + ", ".join([str(k) + ":" + str(s) for k, s in self.spaces.items()]) + ")" def to_jsonable(self, sample_n): # serialize as dict-repr of vectors - return { - key: space.to_jsonable([sample[key] for sample in sample_n]) - for key, space in self.spaces.items() - } + return {key: space.to_jsonable([sample[key] for sample in sample_n]) for key, space in self.spaces.items()} def from_jsonable(self, sample_n): dict_of_list = {} diff --git a/gym/spaces/discrete.py b/gym/spaces/discrete.py index bdbecd942..1061bfa9b 100644 --- a/gym/spaces/discrete.py +++ b/gym/spaces/discrete.py @@ -22,9 +22,7 @@ class Discrete(Space): def contains(self, x): if isinstance(x, int): as_int = x - elif isinstance(x, (np.generic, np.ndarray)) and ( - x.dtype.char in np.typecodes["AllInteger"] and x.shape == () - ): + elif isinstance(x, (np.generic, np.ndarray)) and (x.dtype.char in np.typecodes["AllInteger"] and x.shape == ()): as_int = int(x) else: return False diff --git a/gym/spaces/multi_discrete.py b/gym/spaces/multi_discrete.py index e4d30a0fa..71cd67706 100644 --- a/gym/spaces/multi_discrete.py +++ b/gym/spaces/multi_discrete.py @@ -34,9 +34,7 @@ class MultiDiscrete(Space): super(MultiDiscrete, self).__init__(self.nvec.shape, dtype) def sample(self): - return (self.np_random.random_sample(self.nvec.shape) * self.nvec).astype( - self.dtype - ) + return (self.np_random.random_sample(self.nvec.shape) * self.nvec).astype(self.dtype) def contains(self, x): if isinstance(x, list): diff --git a/gym/spaces/tests/test_spaces.py b/gym/spaces/tests/test_spaces.py index 9dd4c6b89..70c00b66c 100644 --- a/gym/spaces/tests/test_spaces.py +++ b/gym/spaces/tests/test_spaces.py @@ -25,9 +25,7 @@ from gym.spaces import Tuple, Box, Discrete, MultiDiscrete, MultiBinary, Dict Dict( { "position": Discrete(5), - "velocity": Box( - low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32 - ), + "velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32), } ), ], @@ -71,9 +69,7 @@ def test_roundtripping(space): Dict( { "position": Discrete(5), - "velocity": Box( - low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32 - ), + "velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32), } ), ], diff --git a/gym/spaces/tests/test_utils.py b/gym/spaces/tests/test_utils.py index f9adac46d..9281d22c3 100644 --- a/gym/spaces/tests/test_utils.py +++ b/gym/spaces/tests/test_utils.py @@ -28,9 +28,7 @@ from gym.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete, Tuple, u Dict( { "position": Discrete(5), - "velocity": Box( - low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32 - ), + "velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32), } ), 7, @@ -60,18 +58,14 @@ def test_flatdim(space, flatdim): Dict( { "position": Discrete(5), - "velocity": Box( - low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32 - ), + "velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32), } ), ], ) def test_flatten_space_boxes(space): flat_space = utils.flatten_space(space) - assert isinstance(flat_space, Box), "Expected {} to equal {}".format( - type(flat_space), Box - ) + assert isinstance(flat_space, Box), "Expected {} to equal {}".format(type(flat_space), Box) flatdim = utils.flatdim(space) (single_dim,) = flat_space.shape assert single_dim == flatdim, "Expected {} to equal {}".format(single_dim, flatdim) @@ -95,9 +89,7 @@ def test_flatten_space_boxes(space): Dict( { "position": Discrete(5), - "velocity": Box( - low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32 - ), + "velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32), } ), ], @@ -107,9 +99,7 @@ def test_flat_space_contains_flat_points(space): flattened_samples = [utils.flatten(space, sample) for sample in some_samples] flat_space = utils.flatten_space(space) for i, flat_sample in enumerate(flattened_samples): - assert flat_sample in flat_space, "Expected sample #{} {} to be in {}".format( - i, flat_sample, flat_space - ) + assert flat_sample in flat_space, "Expected sample #{} {} to be in {}".format(i, flat_sample, flat_space) @pytest.mark.parametrize( @@ -130,9 +120,7 @@ def test_flat_space_contains_flat_points(space): Dict( { "position": Discrete(5), - "velocity": Box( - low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32 - ), + "velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32), } ), ], @@ -162,9 +150,7 @@ def test_flatten_dim(space): Dict( { "position": Discrete(5), - "velocity": Box( - low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32 - ), + "velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32), } ), ], @@ -172,15 +158,9 @@ def test_flatten_dim(space): def test_flatten_roundtripping(space): some_samples = [space.sample() for _ in range(10)] flattened_samples = [utils.flatten(space, sample) for sample in some_samples] - roundtripped_samples = [ - utils.unflatten(space, sample) for sample in flattened_samples - ] - for i, (original, roundtripped) in enumerate( - zip(some_samples, roundtripped_samples) - ): - assert compare_nested( - original, roundtripped - ), "Expected sample #{} {} to equal {}".format(i, original, roundtripped) + roundtripped_samples = [utils.unflatten(space, sample) for sample in flattened_samples] + for i, (original, roundtripped) in enumerate(zip(some_samples, roundtripped_samples)): + assert compare_nested(original, roundtripped), "Expected sample #{} {} to equal {}".format(i, original, roundtripped) def compare_nested(left, right): @@ -188,9 +168,7 @@ def compare_nested(left, right): return np.allclose(left, right) elif isinstance(left, OrderedDict) and isinstance(right, OrderedDict): res = len(left) == len(right) - for ((left_key, left_value), (right_key, right_value)) in zip( - left.items(), right.items() - ): + for ((left_key, left_value), (right_key, right_value)) in zip(left.items(), right.items()): if not res: return False res = left_key == right_key and compare_nested(left_value, right_value) @@ -238,9 +216,7 @@ Expecteded flattened types are based off: Dict( { "position": Discrete(5), - "velocity": Box( - low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float16 - ), + "velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float16), } ), np.float64, @@ -254,12 +230,8 @@ def test_dtypes(original_space, expected_flattened_dtype): flattened_sample = utils.flatten(original_space, original_sample) unflattened_sample = utils.unflatten(original_space, flattened_sample) - assert flattened_space.contains( - flattened_sample - ), "Expected flattened_space to contain flattened_sample" - assert ( - flattened_space.dtype == expected_flattened_dtype - ), "Expected flattened_space's dtype to equal " "{}".format( + assert flattened_space.contains(flattened_sample), "Expected flattened_space to contain flattened_sample" + assert flattened_space.dtype == expected_flattened_dtype, "Expected flattened_space's dtype to equal " "{}".format( expected_flattened_dtype ) @@ -272,9 +244,10 @@ def test_dtypes(original_space, expected_flattened_dtype): def compare_sample_types(original_space, original_sample, unflattened_sample): if isinstance(original_space, Discrete): - assert isinstance(unflattened_sample, int), ( - "Expected unflattened_sample to be an int. unflattened_sample: " - "{} original_sample: {}".format(unflattened_sample, original_sample) + assert isinstance( + unflattened_sample, int + ), "Expected unflattened_sample to be an int. unflattened_sample: " "{} original_sample: {}".format( + unflattened_sample, original_sample ) elif isinstance(original_space, Tuple): for index in range(len(original_space)): diff --git a/gym/spaces/tuple.py b/gym/spaces/tuple.py index fec18c41b..488cf8c3b 100644 --- a/gym/spaces/tuple.py +++ b/gym/spaces/tuple.py @@ -13,9 +13,7 @@ class Tuple(Space): def __init__(self, spaces): self.spaces = spaces for space in spaces: - assert isinstance( - space, Space - ), "Elements of the tuple must be instances of gym.Space" + assert isinstance(space, Space), "Elements of the tuple must be instances of gym.Space" super(Tuple, self).__init__(None, None) def seed(self, seed=None): @@ -38,21 +36,10 @@ class Tuple(Space): def to_jsonable(self, sample_n): # serialize as list-repr of tuple of vectors - return [ - space.to_jsonable([sample[i] for sample in sample_n]) - for i, space in enumerate(self.spaces) - ] + return [space.to_jsonable([sample[i] for sample in sample_n]) for i, space in enumerate(self.spaces)] def from_jsonable(self, sample_n): - return [ - sample - for sample in zip( - *[ - space.from_jsonable(sample_n[i]) - for i, space in enumerate(self.spaces) - ] - ) - ] + return [sample for sample in zip(*[space.from_jsonable(sample_n[i]) for i, space in enumerate(self.spaces)])] def __getitem__(self, index): return self.spaces[index] diff --git a/gym/spaces/utils.py b/gym/spaces/utils.py index 1704ff996..d392b6ddc 100644 --- a/gym/spaces/utils.py +++ b/gym/spaces/utils.py @@ -49,9 +49,7 @@ def flatten(space, x): onehot[x] = 1 return onehot elif isinstance(space, Tuple): - return np.concatenate( - [flatten(s, x_part) for x_part, s in zip(x, space.spaces)] - ) + return np.concatenate([flatten(s, x_part) for x_part, s in zip(x, space.spaces)]) elif isinstance(space, Dict): return np.concatenate([flatten(s, x[key]) for key, s in space.spaces.items()]) elif isinstance(space, MultiBinary): @@ -79,17 +77,13 @@ def unflatten(space, x): elif isinstance(space, Tuple): dims = [flatdim(s) for s in space.spaces] list_flattened = np.split(x, np.cumsum(dims)[:-1]) - list_unflattened = [ - unflatten(s, flattened) - for flattened, s in zip(list_flattened, space.spaces) - ] + list_unflattened = [unflatten(s, flattened) for flattened, s in zip(list_flattened, space.spaces)] return tuple(list_unflattened) elif isinstance(space, Dict): dims = [flatdim(s) for s in space.spaces.values()] list_flattened = np.split(x, np.cumsum(dims)[:-1]) list_unflattened = [ - (key, unflatten(s, flattened)) - for flattened, (key, s) in zip(list_flattened, space.spaces.items()) + (key, unflatten(s, flattened)) for flattened, (key, s) in zip(list_flattened, space.spaces.items()) ] return OrderedDict(list_unflattened) elif isinstance(space, MultiBinary): diff --git a/gym/utils/play.py b/gym/utils/play.py index 2e06f4a32..fc447a14c 100644 --- a/gym/utils/play.py +++ b/gym/utils/play.py @@ -88,11 +88,7 @@ def play(env, transpose=True, fps=30, zoom=None, callback=None, keys_to_action=N elif hasattr(env.unwrapped, "get_keys_to_action"): keys_to_action = env.unwrapped.get_keys_to_action() else: - assert False, ( - env.spec.id - + " does not have explicit key to action mapping, " - + "please specify one manually" - ) + assert False, env.spec.id + " does not have explicit key to action mapping, " + "please specify one manually" relevant_keys = set(sum(map(list, keys_to_action.keys()), [])) video_size = [rendered.shape[1], rendered.shape[0]] @@ -172,9 +168,7 @@ class PlayPlot(object): for i, plot in enumerate(self.cur_plot): if plot is not None: plot.remove() - self.cur_plot[i] = self.ax[i].scatter( - range(xmin, xmax), list(self.data[i]), c="blue" - ) + self.cur_plot[i] = self.ax[i].scatter(range(xmin, xmax), list(self.data[i]), c="blue") self.ax[i].set_xlim(xmin, xmax) plt.pause(0.000001) diff --git a/gym/utils/seeding.py b/gym/utils/seeding.py index d7726b01f..c8120f9cd 100644 --- a/gym/utils/seeding.py +++ b/gym/utils/seeding.py @@ -10,9 +10,7 @@ from gym import error def np_random(seed=None): if seed is not None and not (isinstance(seed, int) and 0 <= seed): - raise error.Error( - "Seed must be a non-negative integer or omitted, not {}".format(seed) - ) + raise error.Error("Seed must be a non-negative integer or omitted, not {}".format(seed)) seed = create_seed(seed) diff --git a/gym/vector/__init__.py b/gym/vector/__init__.py index c42a25caf..3ff788a27 100644 --- a/gym/vector/__init__.py +++ b/gym/vector/__init__.py @@ -53,9 +53,7 @@ def make(id, num_envs=1, asynchronous=True, wrappers=None, **kwargs): if wrappers is not None: if callable(wrappers): env = wrappers(env) - elif isinstance(wrappers, Iterable) and all( - [callable(w) for w in wrappers] - ): + elif isinstance(wrappers, Iterable) and all([callable(w) for w in wrappers]): for wrapper in wrappers: env = wrapper(env) else: diff --git a/gym/vector/async_vector_env.py b/gym/vector/async_vector_env.py index b84fc2087..9a32512c7 100644 --- a/gym/vector/async_vector_env.py +++ b/gym/vector/async_vector_env.py @@ -107,12 +107,8 @@ class AsyncVectorEnv(VectorEnv): if self.shared_memory: try: - _obs_buffer = create_shared_memory( - self.single_observation_space, n=self.num_envs, ctx=ctx - ) - self.observations = read_from_shared_memory( - _obs_buffer, self.single_observation_space, n=self.num_envs - ) + _obs_buffer = create_shared_memory(self.single_observation_space, n=self.num_envs, ctx=ctx) + self.observations = read_from_shared_memory(_obs_buffer, self.single_observation_space, n=self.num_envs) except CustomSpaceError: raise ValueError( "Using `shared_memory=True` in `AsyncVectorEnv` " @@ -124,9 +120,7 @@ class AsyncVectorEnv(VectorEnv): ) else: _obs_buffer = None - self.observations = create_empty_array( - self.single_observation_space, n=self.num_envs, fn=np.zeros - ) + self.observations = create_empty_array(self.single_observation_space, n=self.num_envs, fn=np.zeros) self.parent_pipes, self.processes = [], [] self.error_queue = ctx.Queue() @@ -168,8 +162,7 @@ class AsyncVectorEnv(VectorEnv): if self._state != AsyncState.DEFAULT: raise AlreadyPendingCallError( - "Calling `seed` while waiting " - "for a pending call to `{0}` to complete.".format(self._state.value), + "Calling `seed` while waiting " "for a pending call to `{0}` to complete.".format(self._state.value), self._state.value, ) @@ -182,8 +175,7 @@ class AsyncVectorEnv(VectorEnv): self._assert_is_running() if self._state != AsyncState.DEFAULT: raise AlreadyPendingCallError( - "Calling `reset_async` while waiting " - "for a pending call to `{0}` to complete".format(self._state.value), + "Calling `reset_async` while waiting " "for a pending call to `{0}` to complete".format(self._state.value), self._state.value, ) @@ -214,8 +206,7 @@ class AsyncVectorEnv(VectorEnv): if not self._poll(timeout): self._state = AsyncState.DEFAULT raise mp.TimeoutError( - "The call to `reset_wait` has timed out after " - "{0} second{1}.".format(timeout, "s" if timeout > 1 else "") + "The call to `reset_wait` has timed out after " "{0} second{1}.".format(timeout, "s" if timeout > 1 else "") ) results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes]) @@ -223,9 +214,7 @@ class AsyncVectorEnv(VectorEnv): self._state = AsyncState.DEFAULT if not self.shared_memory: - self.observations = concatenate( - results, self.observations, self.single_observation_space - ) + self.observations = concatenate(results, self.observations, self.single_observation_space) return deepcopy(self.observations) if self.copy else self.observations @@ -239,8 +228,7 @@ class AsyncVectorEnv(VectorEnv): self._assert_is_running() if self._state != AsyncState.DEFAULT: raise AlreadyPendingCallError( - "Calling `step_async` while waiting " - "for a pending call to `{0}` to complete.".format(self._state.value), + "Calling `step_async` while waiting " "for a pending call to `{0}` to complete.".format(self._state.value), self._state.value, ) @@ -280,8 +268,7 @@ class AsyncVectorEnv(VectorEnv): if not self._poll(timeout): self._state = AsyncState.DEFAULT raise mp.TimeoutError( - "The call to `step_wait` has timed out after " - "{0} second{1}.".format(timeout, "s" if timeout > 1 else "") + "The call to `step_wait` has timed out after " "{0} second{1}.".format(timeout, "s" if timeout > 1 else "") ) results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes]) @@ -290,9 +277,7 @@ class AsyncVectorEnv(VectorEnv): observations_list, rewards, dones, infos = zip(*results) if not self.shared_memory: - self.observations = concatenate( - observations_list, self.observations, self.single_observation_space - ) + self.observations = concatenate(observations_list, self.observations, self.single_observation_space) return ( deepcopy(self.observations) if self.copy else self.observations, @@ -318,8 +303,7 @@ class AsyncVectorEnv(VectorEnv): try: if self._state != AsyncState.DEFAULT: logger.warn( - "Calling `close` while waiting for a pending " - "call to `{0}` to complete.".format(self._state.value) + "Calling `close` while waiting for a pending " "call to `{0}` to complete.".format(self._state.value) ) function = getattr(self, "{0}_wait".format(self._state.value)) function(timeout) @@ -375,8 +359,7 @@ class AsyncVectorEnv(VectorEnv): def _assert_is_running(self): if self.closed: raise ClosedEnvironmentError( - "Trying to operate on `{0}`, after a " - "call to `close()`.".format(type(self).__name__) + "Trying to operate on `{0}`, after a " "call to `close()`.".format(type(self).__name__) ) def _raise_if_errors(self, successes): @@ -387,10 +370,7 @@ class AsyncVectorEnv(VectorEnv): assert num_errors > 0 for _ in range(num_errors): index, exctype, value = self.error_queue.get() - logger.error( - "Received the following error from Worker-{0}: " - "{1}: {2}".format(index, exctype.__name__, value) - ) + logger.error("Received the following error from Worker-{0}: " "{1}: {2}".format(index, exctype.__name__, value)) logger.error("Shutting down Worker-{0}.".format(index)) self.parent_pipes[index].close() self.parent_pipes[index] = None @@ -445,17 +425,13 @@ def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error command, data = pipe.recv() if command == "reset": observation = env.reset() - write_to_shared_memory( - index, observation, shared_memory, observation_space - ) + write_to_shared_memory(index, observation, shared_memory, observation_space) pipe.send((None, True)) elif command == "step": observation, reward, done, info = env.step(data) if done: observation = env.reset() - write_to_shared_memory( - index, observation, shared_memory, observation_space - ) + write_to_shared_memory(index, observation, shared_memory, observation_space) pipe.send(((None, reward, done, info), True)) elif command == "seed": env.seed(data) diff --git a/gym/vector/sync_vector_env.py b/gym/vector/sync_vector_env.py index 49369073a..8a598a560 100644 --- a/gym/vector/sync_vector_env.py +++ b/gym/vector/sync_vector_env.py @@ -44,9 +44,7 @@ class SyncVectorEnv(VectorEnv): ) self._check_observation_spaces() - self.observations = create_empty_array( - self.single_observation_space, n=self.num_envs, fn=np.zeros - ) + self.observations = create_empty_array(self.single_observation_space, n=self.num_envs, fn=np.zeros) self._rewards = np.zeros((self.num_envs,), dtype=np.float64) self._dones = np.zeros((self.num_envs,), dtype=np.bool_) self._actions = None @@ -67,9 +65,7 @@ class SyncVectorEnv(VectorEnv): for env in self.envs: observation = env.reset() observations.append(observation) - self.observations = concatenate( - observations, self.observations, self.single_observation_space - ) + self.observations = concatenate(observations, self.observations, self.single_observation_space) return deepcopy(self.observations) if self.copy else self.observations @@ -84,9 +80,7 @@ class SyncVectorEnv(VectorEnv): observation = env.reset() observations.append(observation) infos.append(info) - self.observations = concatenate( - observations, self.observations, self.single_observation_space - ) + self.observations = concatenate(observations, self.observations, self.single_observation_space) return ( deepcopy(self.observations) if self.copy else self.observations, diff --git a/gym/vector/tests/test_numpy_utils.py b/gym/vector/tests/test_numpy_utils.py index 6c19e7ab4..875031533 100644 --- a/gym/vector/tests/test_numpy_utils.py +++ b/gym/vector/tests/test_numpy_utils.py @@ -10,9 +10,7 @@ from gym.vector.tests.utils import spaces from gym.vector.utils.numpy_utils import concatenate, create_empty_array -@pytest.mark.parametrize( - "space", spaces, ids=[space.__class__.__name__ for space in spaces] -) +@pytest.mark.parametrize("space", spaces, ids=[space.__class__.__name__ for space in spaces]) def test_concatenate(space): def assert_type(lhs, rhs, n): # Special case: if rhs is a list of scalars, lhs must be an np.ndarray @@ -53,9 +51,7 @@ def test_concatenate(space): @pytest.mark.parametrize("n", [1, 8]) -@pytest.mark.parametrize( - "space", spaces, ids=[space.__class__.__name__ for space in spaces] -) +@pytest.mark.parametrize("space", spaces, ids=[space.__class__.__name__ for space in spaces]) def test_create_empty_array(space, n): def assert_nested_type(arr, space, n): if isinstance(space, _BaseGymSpaces): @@ -83,9 +79,7 @@ def test_create_empty_array(space, n): @pytest.mark.parametrize("n", [1, 8]) -@pytest.mark.parametrize( - "space", spaces, ids=[space.__class__.__name__ for space in spaces] -) +@pytest.mark.parametrize("space", spaces, ids=[space.__class__.__name__ for space in spaces]) def test_create_empty_array_zeros(space, n): def assert_nested_type(arr, space, n): if isinstance(space, _BaseGymSpaces): @@ -113,9 +107,7 @@ def test_create_empty_array_zeros(space, n): assert_nested_type(array, space, n=n) -@pytest.mark.parametrize( - "space", spaces, ids=[space.__class__.__name__ for space in spaces] -) +@pytest.mark.parametrize("space", spaces, ids=[space.__class__.__name__ for space in spaces]) def test_create_empty_array_none_shape_ones(space): def assert_nested_type(arr, space): if isinstance(space, _BaseGymSpaces): diff --git a/gym/vector/tests/test_shared_memory.py b/gym/vector/tests/test_shared_memory.py index f59a4abd8..892893841 100644 --- a/gym/vector/tests/test_shared_memory.py +++ b/gym/vector/tests/test_shared_memory.py @@ -46,9 +46,7 @@ expected_types = [ list(zip(spaces, expected_types)), ids=[space.__class__.__name__ for space in spaces], ) -@pytest.mark.parametrize( - "ctx", [None, "fork", "spawn"], ids=["default", "fork", "spawn"] -) +@pytest.mark.parametrize("ctx", [None, "fork", "spawn"], ids=["default", "fork", "spawn"]) def test_create_shared_memory(space, expected_type, n, ctx): def assert_nested_type(lhs, rhs, n): assert type(lhs) == type(rhs) @@ -77,9 +75,7 @@ def test_create_shared_memory(space, expected_type, n, ctx): @pytest.mark.parametrize("n", [1, 8]) -@pytest.mark.parametrize( - "ctx", [None, "fork", "spawn"], ids=["default", "fork", "spawn"] -) +@pytest.mark.parametrize("ctx", [None, "fork", "spawn"], ids=["default", "fork", "spawn"]) @pytest.mark.parametrize("space", custom_spaces) def test_create_shared_memory_custom_space(n, ctx, space): ctx = mp if (ctx is None) else mp.get_context(ctx) @@ -87,9 +83,7 @@ def test_create_shared_memory_custom_space(n, ctx, space): shared_memory = create_shared_memory(space, n=n, ctx=ctx) -@pytest.mark.parametrize( - "space", spaces, ids=[space.__class__.__name__ for space in spaces] -) +@pytest.mark.parametrize("space", spaces, ids=[space.__class__.__name__ for space in spaces]) def test_write_to_shared_memory(space): def assert_nested_equal(lhs, rhs): assert isinstance(rhs, list) @@ -113,9 +107,7 @@ def test_write_to_shared_memory(space): shared_memory_n8 = create_shared_memory(space, n=8) samples = [space.sample() for _ in range(8)] - processes = [ - Process(target=write, args=(i, shared_memory_n8, samples[i])) for i in range(8) - ] + processes = [Process(target=write, args=(i, shared_memory_n8, samples[i])) for i in range(8)] for process in processes: process.start() @@ -125,25 +117,19 @@ def test_write_to_shared_memory(space): assert_nested_equal(shared_memory_n8, samples) -@pytest.mark.parametrize( - "space", spaces, ids=[space.__class__.__name__ for space in spaces] -) +@pytest.mark.parametrize("space", spaces, ids=[space.__class__.__name__ for space in spaces]) def test_read_from_shared_memory(space): def assert_nested_equal(lhs, rhs, space, n): assert isinstance(rhs, list) if isinstance(space, Tuple): assert isinstance(lhs, tuple) for i in range(len(lhs)): - assert_nested_equal( - lhs[i], [rhs_[i] for rhs_ in rhs], space.spaces[i], n - ) + assert_nested_equal(lhs[i], [rhs_[i] for rhs_ in rhs], space.spaces[i], n) elif isinstance(space, Dict): assert isinstance(lhs, OrderedDict) for key in lhs.keys(): - assert_nested_equal( - lhs[key], [rhs_[key] for rhs_ in rhs], space.spaces[key], n - ) + assert_nested_equal(lhs[key], [rhs_[key] for rhs_ in rhs], space.spaces[key], n) elif isinstance(space, _BaseGymSpaces): assert isinstance(lhs, np.ndarray) @@ -161,9 +147,7 @@ def test_read_from_shared_memory(space): memory_view_n8 = read_from_shared_memory(shared_memory_n8, space, n=8) samples = [space.sample() for _ in range(8)] - processes = [ - Process(target=write, args=(i, shared_memory_n8, samples[i])) for i in range(8) - ] + processes = [Process(target=write, args=(i, shared_memory_n8, samples[i])) for i in range(8)] for process in processes: process.start() diff --git a/gym/vector/tests/test_spaces.py b/gym/vector/tests/test_spaces.py index a0738d040..3486231b0 100644 --- a/gym/vector/tests/test_spaces.py +++ b/gym/vector/tests/test_spaces.py @@ -10,12 +10,8 @@ expected_batch_spaces_4 = [ Box(low=-1.0, high=1.0, shape=(4,), dtype=np.float64), Box(low=0.0, high=10.0, shape=(4, 1), dtype=np.float32), Box( - low=np.array( - [[-1.0, 0.0, 0.0], [-1.0, 0.0, 0.0], [-1.0, 0.0, 0.0], [-1.0, 0.0, 0.0]] - ), - high=np.array( - [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]] - ), + low=np.array([[-1.0, 0.0, 0.0], [-1.0, 0.0, 0.0], [-1.0, 0.0, 0.0], [-1.0, 0.0, 0.0]]), + high=np.array([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]), dtype=np.float32, ), Box( diff --git a/gym/vector/tests/utils.py b/gym/vector/tests/utils.py index 3fdfc84c5..f8996e51e 100644 --- a/gym/vector/tests/utils.py +++ b/gym/vector/tests/utils.py @@ -7,12 +7,8 @@ from gym.spaces import Box, Discrete, MultiDiscrete, MultiBinary, Tuple, Dict spaces = [ Box(low=np.array(-1.0), high=np.array(1.0), dtype=np.float64), Box(low=np.array([0.0]), high=np.array([10.0]), dtype=np.float32), - Box( - low=np.array([-1.0, 0.0, 0.0]), high=np.array([1.0, 1.0, 1.0]), dtype=np.float32 - ), - Box( - low=np.array([[-1.0, 0.0], [0.0, -1.0]]), high=np.ones((2, 2)), dtype=np.float32 - ), + Box(low=np.array([-1.0, 0.0, 0.0]), high=np.array([1.0, 1.0, 1.0]), dtype=np.float32), + Box(low=np.array([[-1.0, 0.0], [0.0, -1.0]]), high=np.ones((2, 2)), dtype=np.float32), Box(low=0, high=255, shape=(), dtype=np.uint8), Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8), Discrete(2), @@ -28,17 +24,13 @@ spaces = [ Dict( { "position": Discrete(23), - "velocity": Box( - low=np.array([0.0]), high=np.array([1.0]), dtype=np.float32 - ), + "velocity": Box(low=np.array([0.0]), high=np.array([1.0]), dtype=np.float32), } ), Dict( { "position": Dict({"x": Discrete(29), "y": Discrete(31)}), - "velocity": Tuple( - (Discrete(37), Box(low=0, high=255, shape=(), dtype=np.uint8)) - ), + "velocity": Tuple((Discrete(37), Box(low=0, high=255, shape=(), dtype=np.uint8))), } ), ] @@ -50,9 +42,7 @@ class UnittestSlowEnv(gym.Env): def __init__(self, slow_reset=0.3): super(UnittestSlowEnv, self).__init__() self.slow_reset = slow_reset - self.observation_space = Box( - low=0, high=255, shape=(HEIGHT, WIDTH, 3), dtype=np.uint8 - ) + self.observation_space = Box(low=0, high=255, shape=(HEIGHT, WIDTH, 3), dtype=np.uint8) self.action_space = Box(low=0.0, high=1.0, shape=(), dtype=np.float32) def reset(self): diff --git a/gym/vector/utils/numpy_utils.py b/gym/vector/utils/numpy_utils.py index 2465ab249..e09b67efa 100644 --- a/gym/vector/utils/numpy_utils.py +++ b/gym/vector/utils/numpy_utils.py @@ -46,10 +46,7 @@ def concatenate(items, out, space): elif isinstance(space, Space): return concatenate_custom(items, out, space) else: - raise ValueError( - "Space of type `{0}` is not a valid `gym.Space` " - "instance.".format(type(space)) - ) + raise ValueError("Space of type `{0}` is not a valid `gym.Space` " "instance.".format(type(space))) def concatenate_base(items, out, space): @@ -57,18 +54,12 @@ def concatenate_base(items, out, space): def concatenate_tuple(items, out, space): - return tuple( - concatenate([item[i] for item in items], out[i], subspace) - for (i, subspace) in enumerate(space.spaces) - ) + return tuple(concatenate([item[i] for item in items], out[i], subspace) for (i, subspace) in enumerate(space.spaces)) def concatenate_dict(items, out, space): return OrderedDict( - [ - (key, concatenate([item[key] for item in items], out[key], subspace)) - for (key, subspace) in space.spaces.items() - ] + [(key, concatenate([item[key] for item in items], out[key], subspace)) for (key, subspace) in space.spaces.items()] ) @@ -118,10 +109,7 @@ def create_empty_array(space, n=1, fn=np.zeros): elif isinstance(space, Space): return create_empty_array_custom(space, n=n, fn=fn) else: - raise ValueError( - "Space of type `{0}` is not a valid `gym.Space` " - "instance.".format(type(space)) - ) + raise ValueError("Space of type `{0}` is not a valid `gym.Space` " "instance.".format(type(space))) def create_empty_array_base(space, n=1, fn=np.zeros): @@ -134,12 +122,7 @@ def create_empty_array_tuple(space, n=1, fn=np.zeros): def create_empty_array_dict(space, n=1, fn=np.zeros): - return OrderedDict( - [ - (key, create_empty_array(subspace, n=n, fn=fn)) - for (key, subspace) in space.spaces.items() - ] - ) + return OrderedDict([(key, create_empty_array(subspace, n=n, fn=fn)) for (key, subspace) in space.spaces.items()]) def create_empty_array_custom(space, n=1, fn=np.zeros): diff --git a/gym/vector/utils/shared_memory.py b/gym/vector/utils/shared_memory.py index 1021c139f..ebcedc65a 100644 --- a/gym/vector/utils/shared_memory.py +++ b/gym/vector/utils/shared_memory.py @@ -56,18 +56,11 @@ def create_base_shared_memory(space, n=1, ctx=mp): def create_tuple_shared_memory(space, n=1, ctx=mp): - return tuple( - create_shared_memory(subspace, n=n, ctx=ctx) for subspace in space.spaces - ) + return tuple(create_shared_memory(subspace, n=n, ctx=ctx) for subspace in space.spaces) def create_dict_shared_memory(space, n=1, ctx=mp): - return OrderedDict( - [ - (key, create_shared_memory(subspace, n=n, ctx=ctx)) - for (key, subspace) in space.spaces.items() - ] - ) + return OrderedDict([(key, create_shared_memory(subspace, n=n, ctx=ctx)) for (key, subspace) in space.spaces.items()]) def read_from_shared_memory(shared_memory, space, n=1): @@ -114,24 +107,16 @@ def read_from_shared_memory(shared_memory, space, n=1): def read_base_from_shared_memory(shared_memory, space, n=1): - return np.frombuffer(shared_memory.get_obj(), dtype=space.dtype).reshape( - (n,) + space.shape - ) + return np.frombuffer(shared_memory.get_obj(), dtype=space.dtype).reshape((n,) + space.shape) def read_tuple_from_shared_memory(shared_memory, space, n=1): - return tuple( - read_from_shared_memory(memory, subspace, n=n) - for (memory, subspace) in zip(shared_memory, space.spaces) - ) + return tuple(read_from_shared_memory(memory, subspace, n=n) for (memory, subspace) in zip(shared_memory, space.spaces)) def read_dict_from_shared_memory(shared_memory, space, n=1): return OrderedDict( - [ - (key, read_from_shared_memory(shared_memory[key], subspace, n=n)) - for (key, subspace) in space.spaces.items() - ] + [(key, read_from_shared_memory(shared_memory[key], subspace, n=n)) for (key, subspace) in space.spaces.items()] ) diff --git a/gym/vector/utils/spaces.py b/gym/vector/utils/spaces.py index ac4727de5..2101ffb48 100644 --- a/gym/vector/utils/spaces.py +++ b/gym/vector/utils/spaces.py @@ -44,8 +44,7 @@ def batch_space(space, n=1): return batch_space_custom(space, n=n) else: raise ValueError( - "Cannot batch space with type `{0}`. The space must " - "be a valid `gym.Space` instance.".format(type(space)) + "Cannot batch space with type `{0}`. The space must " "be a valid `gym.Space` instance.".format(type(space)) ) @@ -75,14 +74,7 @@ def batch_space_tuple(space, n=1): def batch_space_dict(space, n=1): - return Dict( - OrderedDict( - [ - (key, batch_space(subspace, n=n)) - for (key, subspace) in space.spaces.items() - ] - ) - ) + return Dict(OrderedDict([(key, batch_space(subspace, n=n)) for (key, subspace) in space.spaces.items()])) def batch_space_custom(space, n=1): diff --git a/gym/vector/vector_env.py b/gym/vector/vector_env.py index 375826ff0..2d2f3b06a 100644 --- a/gym/vector/vector_env.py +++ b/gym/vector/vector_env.py @@ -141,9 +141,7 @@ class VectorEnv(gym.Env): if self.spec is None: return "{}({})".format(self.__class__.__name__, self.num_envs) else: - return "{}({}, {})".format( - self.__class__.__name__, self.spec.id, self.num_envs - ) + return "{}({}, {})".format(self.__class__.__name__, self.spec.id, self.num_envs) class VectorEnvWrapper(VectorEnv): @@ -189,9 +187,7 @@ class VectorEnvWrapper(VectorEnv): # implicitly forward all other methods and attributes to self.env def __getattr__(self, name): if name.startswith("_"): - raise AttributeError( - "attempted to get missing private attribute '{}'".format(name) - ) + raise AttributeError("attempted to get missing private attribute '{}'".format(name)) return getattr(self.env, name) @property diff --git a/gym/wrappers/atari_preprocessing.py b/gym/wrappers/atari_preprocessing.py index bbf155d71..018f95467 100644 --- a/gym/wrappers/atari_preprocessing.py +++ b/gym/wrappers/atari_preprocessing.py @@ -62,12 +62,13 @@ class AtariPreprocessing(gym.Wrapper): assert noop_max >= 0 if frame_skip > 1: assert "NoFrameskip" in env.spec.id, ( - "disable frame-skipping in the original env. for more than one" - " frame-skip as it will be done by the wrapper" + "disable frame-skipping in the original env. for more than one" " frame-skip as it will be done by the wrapper" ) self.noop_max = noop_max assert env.unwrapped.get_action_meanings()[0] == "NOOP" - warnings.warn("Gym\'s internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit") + warnings.warn( + "Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit" + ) self.frame_skip = frame_skip self.screen_size = screen_size @@ -92,15 +93,11 @@ class AtariPreprocessing(gym.Wrapper): self.lives = 0 self.game_over = False - _low, _high, _obs_dtype = ( - (0, 255, np.uint8) if not scale_obs else (0, 1, np.float32) - ) + _low, _high, _obs_dtype = (0, 255, np.uint8) if not scale_obs else (0, 1, np.float32) _shape = (screen_size, screen_size, 1 if grayscale_obs else 3) if grayscale_obs and not grayscale_newaxis: _shape = _shape[:-1] # Remove channel axis - self.observation_space = Box( - low=_low, high=_high, shape=_shape, dtype=_obs_dtype - ) + self.observation_space = Box(low=_low, high=_high, shape=_shape, dtype=_obs_dtype) def step(self, action): R = 0.0 @@ -132,11 +129,7 @@ class AtariPreprocessing(gym.Wrapper): def reset(self, **kwargs): # NoopReset self.env.reset(**kwargs) - noops = ( - self.env.unwrapped.np_random.randint(1, self.noop_max + 1) - if self.noop_max > 0 - else 0 - ) + noops = self.env.unwrapped.np_random.randint(1, self.noop_max + 1) if self.noop_max > 0 else 0 for _ in range(noops): _, _, done, _ = self.env.step(0) if done: diff --git a/gym/wrappers/clip_action.py b/gym/wrappers/clip_action.py index cb9a1c28c..bb096f13c 100644 --- a/gym/wrappers/clip_action.py +++ b/gym/wrappers/clip_action.py @@ -9,7 +9,9 @@ class ClipAction(ActionWrapper): def __init__(self, env): assert isinstance(env.action_space, Box) - warnings.warn("Gym\'s internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit") + warnings.warn( + "Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit" + ) super(ClipAction, self).__init__(env) def action(self, action): diff --git a/gym/wrappers/filter_observation.py b/gym/wrappers/filter_observation.py index c3c31c684..8c1d3fac8 100644 --- a/gym/wrappers/filter_observation.py +++ b/gym/wrappers/filter_observation.py @@ -26,7 +26,9 @@ class FilterObservation(ObservationWrapper): assert isinstance( wrapped_observation_space, spaces.Dict ), "FilterObservationWrapper is only usable with dict observations." - warnings.warn("Gym\'s internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit") + warnings.warn( + "Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit" + ) observation_keys = wrapped_observation_space.spaces.keys() @@ -49,11 +51,7 @@ class FilterObservation(ObservationWrapper): ) self.observation_space = type(wrapped_observation_space)( - [ - (name, copy.deepcopy(space)) - for name, space in wrapped_observation_space.spaces.items() - if name in filter_keys - ] + [(name, copy.deepcopy(space)) for name, space in wrapped_observation_space.spaces.items() if name in filter_keys] ) self._env = env @@ -64,11 +62,5 @@ class FilterObservation(ObservationWrapper): return filter_observation def _filter_observation(self, observation): - observation = type(observation)( - [ - (name, value) - for name, value in observation.items() - if name in self._filter_keys - ] - ) + observation = type(observation)([(name, value) for name, value in observation.items() if name in self._filter_keys]) return observation diff --git a/gym/wrappers/flatten_observation.py b/gym/wrappers/flatten_observation.py index d19e57adc..b78cb4248 100644 --- a/gym/wrappers/flatten_observation.py +++ b/gym/wrappers/flatten_observation.py @@ -9,7 +9,9 @@ class FlattenObservation(ObservationWrapper): def __init__(self, env): super(FlattenObservation, self).__init__(env) self.observation_space = spaces.flatten_space(env.observation_space) - warnings.warn("Gym\'s internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit") + warnings.warn( + "Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit" + ) def observation(self, observation): return spaces.flatten(self.env.observation_space, observation) diff --git a/gym/wrappers/frame_stack.py b/gym/wrappers/frame_stack.py index b8fa0762e..cad4dc064 100644 --- a/gym/wrappers/frame_stack.py +++ b/gym/wrappers/frame_stack.py @@ -22,7 +22,9 @@ class LazyFrames(object): __slots__ = ("frame_shape", "dtype", "shape", "lz4_compress", "_frames") def __init__(self, frames, lz4_compress=False): - warnings.warn("Gym\'s internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit") + warnings.warn( + "Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit" + ) self.frame_shape = tuple(frames[0].shape) self.shape = (len(frames),) + self.frame_shape self.dtype = frames[0].dtype @@ -45,9 +47,7 @@ class LazyFrames(object): def __getitem__(self, int_or_slice): if isinstance(int_or_slice, int): return self._check_decompress(self._frames[int_or_slice]) # single frame - return np.stack( - [self._check_decompress(f) for f in self._frames[int_or_slice]], axis=0 - ) + return np.stack([self._check_decompress(f) for f in self._frames[int_or_slice]], axis=0) def __eq__(self, other): return self.__array__() == other @@ -56,9 +56,7 @@ class LazyFrames(object): if self.lz4_compress: from lz4.block import decompress - return np.frombuffer(decompress(frame), dtype=self.dtype).reshape( - self.frame_shape - ) + return np.frombuffer(decompress(frame), dtype=self.dtype).reshape(self.frame_shape) return frame @@ -102,12 +100,8 @@ class FrameStack(Wrapper): self.frames = deque(maxlen=num_stack) low = np.repeat(self.observation_space.low[np.newaxis, ...], num_stack, axis=0) - high = np.repeat( - self.observation_space.high[np.newaxis, ...], num_stack, axis=0 - ) - self.observation_space = Box( - low=low, high=high, dtype=self.observation_space.dtype - ) + high = np.repeat(self.observation_space.high[np.newaxis, ...], num_stack, axis=0) + self.observation_space = Box(low=low, high=high, dtype=self.observation_space.dtype) def _get_observation(self): assert len(self.frames) == self.num_stack, (len(self.frames), self.num_stack) diff --git a/gym/wrappers/gray_scale_observation.py b/gym/wrappers/gray_scale_observation.py index be455ecde..890e4766b 100644 --- a/gym/wrappers/gray_scale_observation.py +++ b/gym/wrappers/gray_scale_observation.py @@ -11,20 +11,15 @@ class GrayScaleObservation(ObservationWrapper): super(GrayScaleObservation, self).__init__(env) self.keep_dim = keep_dim - assert ( - len(env.observation_space.shape) == 3 - and env.observation_space.shape[-1] == 3 + assert len(env.observation_space.shape) == 3 and env.observation_space.shape[-1] == 3 + warnings.warn( + "Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit" ) - warnings.warn("Gym\'s internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit") obs_shape = self.observation_space.shape[:2] if self.keep_dim: - self.observation_space = Box( - low=0, high=255, shape=(obs_shape[0], obs_shape[1], 1), dtype=np.uint8 - ) + self.observation_space = Box(low=0, high=255, shape=(obs_shape[0], obs_shape[1], 1), dtype=np.uint8) else: - self.observation_space = Box( - low=0, high=255, shape=obs_shape, dtype=np.uint8 - ) + self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype=np.uint8) def observation(self, observation): import cv2 diff --git a/gym/wrappers/monitor.py b/gym/wrappers/monitor.py index 9239d5d09..842f5724f 100644 --- a/gym/wrappers/monitor.py +++ b/gym/wrappers/monitor.py @@ -37,9 +37,7 @@ class Monitor(Wrapper): self._monitor_id = None self.env_semantics_autoreset = env.metadata.get("semantics.autoreset") - self._start( - directory, video_callable, force, resume, write_upon_reset, uid, mode - ) + self._start(directory, video_callable, force, resume, write_upon_reset, uid, mode) def step(self, action): self._before_step(action) @@ -163,10 +161,7 @@ class Monitor(Wrapper): json.dump( { "stats": os.path.basename(self.stats_recorder.path), - "videos": [ - (os.path.basename(v), os.path.basename(m)) - for v, m in self.videos - ], + "videos": [(os.path.basename(v), os.path.basename(m)) for v, m in self.videos], "env_info": self._env_info(), }, f, @@ -199,9 +194,7 @@ class Monitor(Wrapper): elif mode == "training": type = "t" else: - raise error.Error( - 'Invalid mode {}: must be "training" or "evaluation"', mode - ) + raise error.Error('Invalid mode {}: must be "training" or "evaluation"', mode) self.stats_recorder.type = type def _before_step(self, action): @@ -257,9 +250,7 @@ class Monitor(Wrapper): env=self.env, base_path=os.path.join( self.directory, - "{}.video.{}.video{:06}".format( - self.file_prefix, self.file_infix, self.episode_id - ), + "{}.video.{}.video{:06}".format(self.file_prefix, self.file_infix, self.episode_id), ), metadata={"episode_id": self.episode_id}, enabled=self._video_enabled(), @@ -269,9 +260,7 @@ class Monitor(Wrapper): def _close_video_recorder(self): self.video_recorder.close() if self.video_recorder.functional: - self.videos.append( - (self.video_recorder.path, self.video_recorder.metadata_path) - ) + self.videos.append((self.video_recorder.path, self.video_recorder.metadata_path)) def _video_enabled(self): return self.video_callable(self.episode_id) @@ -301,19 +290,11 @@ class Monitor(Wrapper): def detect_training_manifests(training_dir, files=None): if files is None: files = os.listdir(training_dir) - return [ - os.path.join(training_dir, f) - for f in files - if f.startswith(MANIFEST_PREFIX + ".") - ] + return [os.path.join(training_dir, f) for f in files if f.startswith(MANIFEST_PREFIX + ".")] def detect_monitor_files(training_dir): - return [ - os.path.join(training_dir, f) - for f in os.listdir(training_dir) - if f.startswith(FILE_PREFIX + ".") - ] + return [os.path.join(training_dir, f) for f in os.listdir(training_dir) if f.startswith(FILE_PREFIX + ".")] def clear_monitor_files(training_dir): @@ -382,10 +363,7 @@ def load_results(training_dir): contents = json.load(f) # Make these paths absolute again stats_files.append(os.path.join(training_dir, contents["stats"])) - videos += [ - (os.path.join(training_dir, v), os.path.join(training_dir, m)) - for v, m in contents["videos"] - ] + videos += [(os.path.join(training_dir, v), os.path.join(training_dir, m)) for v, m in contents["videos"]] env_infos.append(contents["env_info"]) env_info = collapse_env_infos(env_infos, training_dir) diff --git a/gym/wrappers/monitoring/video_recorder.py b/gym/wrappers/monitoring/video_recorder.py index b19363199..3d5d5c523 100644 --- a/gym/wrappers/monitoring/video_recorder.py +++ b/gym/wrappers/monitoring/video_recorder.py @@ -48,9 +48,7 @@ class VideoRecorder(object): self.ansi_mode = True else: logger.info( - 'Disabling video recorder because {} neither supports video mode "rgb_array" nor "ansi".'.format( - env - ) + 'Disabling video recorder because {} neither supports video mode "rgb_array" nor "ansi".'.format(env) ) # Whoops, turns out we shouldn't be enabled after all self.enabled = False @@ -69,9 +67,7 @@ class VideoRecorder(object): path = base_path + required_ext else: # Otherwise, just generate a unique filename - with tempfile.NamedTemporaryFile( - suffix=required_ext, delete=False - ) as f: + with tempfile.NamedTemporaryFile(suffix=required_ext, delete=False) as f: path = f.name self.path = path @@ -83,28 +79,20 @@ class VideoRecorder(object): if self.ansi_mode else "" ) - raise error.Error( - "Invalid path given: {} -- must have file extension {}.{}".format( - self.path, required_ext, hint - ) - ) + raise error.Error("Invalid path given: {} -- must have file extension {}.{}".format(self.path, required_ext, hint)) # Touch the file in any case, so we know it's present. (This # corrects for platform platform differences. Using ffmpeg on # OS X, the file is precreated, but not on Linux. touch(path) self.frames_per_sec = env.metadata.get("video.frames_per_second", 30) - self.output_frames_per_sec = env.metadata.get( - "video.output_frames_per_second", self.frames_per_sec - ) + self.output_frames_per_sec = env.metadata.get("video.output_frames_per_second", self.frames_per_sec) self.encoder = None # lazily start the process self.broken = False # Dump metadata self.metadata = metadata or {} - self.metadata["content_type"] = ( - "video/vnd.openai.ansivid" if self.ansi_mode else "video/mp4" - ) + self.metadata["content_type"] = "video/vnd.openai.ansivid" if self.ansi_mode else "video/mp4" self.metadata_path = "{}.meta.json".format(path_base) self.write_metadata() @@ -191,9 +179,7 @@ class VideoRecorder(object): def _encode_image_frame(self, frame): if not self.encoder: - self.encoder = ImageEncoder( - self.path, frame.shape, self.frames_per_sec, self.output_frames_per_sec - ) + self.encoder = ImageEncoder(self.path, frame.shape, self.frames_per_sec, self.output_frames_per_sec) self.metadata["encoder_version"] = self.encoder.version_info try: @@ -222,24 +208,16 @@ class TextEncoder(object): string = frame.getvalue() else: raise error.InvalidFrame( - "Wrong type {} for {}: text frame must be a string or StringIO".format( - type(frame), frame - ) + "Wrong type {} for {}: text frame must be a string or StringIO".format(type(frame), frame) ) frame_bytes = string.encode("utf-8") if frame_bytes[-1:] != b"\n": - raise error.InvalidFrame( - 'Frame must end with a newline: """{}"""'.format(string) - ) + raise error.InvalidFrame('Frame must end with a newline: """{}"""'.format(string)) if b"\r" in frame_bytes: - raise error.InvalidFrame( - 'Frame contains carriage returns (only newlines are allowed: """{}"""'.format( - string - ) - ) + raise error.InvalidFrame('Frame contains carriage returns (only newlines are allowed: """{}"""'.format(string)) self.frames.append(frame_bytes) @@ -263,15 +241,7 @@ class TextEncoder(object): # Calculate frame size from the largest frames. # Add some padding since we'll get cut off otherwise. height = max([frame.count(b"\n") for frame in self.frames]) + 1 - width = ( - max( - [ - max([len(line) for line in frame.split(b"\n")]) - for frame in self.frames - ] - ) - + 2 - ) + width = max([max([len(line) for line in frame.split(b"\n")]) for frame in self.frames]) + 2 data = { "version": 1, @@ -325,11 +295,7 @@ class ImageEncoder(object): def version_info(self): return { "backend": self.backend, - "version": str( - subprocess.check_output( - [self.backend, "-version"], stderr=subprocess.STDOUT - ) - ), + "version": str(subprocess.check_output([self.backend, "-version"], stderr=subprocess.STDOUT)), "cmdline": self.cmdline, } @@ -396,19 +362,13 @@ class ImageEncoder(object): logger.debug('Starting %s with "%s"', self.backend, " ".join(self.cmdline)) if hasattr(os, "setsid"): # setsid not present on Windows - self.proc = subprocess.Popen( - self.cmdline, stdin=subprocess.PIPE, preexec_fn=os.setsid - ) + self.proc = subprocess.Popen(self.cmdline, stdin=subprocess.PIPE, preexec_fn=os.setsid) else: self.proc = subprocess.Popen(self.cmdline, stdin=subprocess.PIPE) def capture_frame(self, frame): if not isinstance(frame, (np.ndarray, np.generic)): - raise error.InvalidFrame( - "Wrong type {} for {} (must be np.ndarray or np.generic)".format( - type(frame), frame - ) - ) + raise error.InvalidFrame("Wrong type {} for {} (must be np.ndarray or np.generic)".format(type(frame), frame)) if frame.shape != self.frame_shape: raise error.InvalidFrame( "Your frame has shape {}, but the VideoRecorder is configured for shape {}.".format( @@ -417,15 +377,11 @@ class ImageEncoder(object): ) if frame.dtype != np.uint8: raise error.InvalidFrame( - "Your frame has data type {}, but we require uint8 (i.e. RGB values from 0-255).".format( - frame.dtype - ) + "Your frame has data type {}, but we require uint8 (i.e. RGB values from 0-255).".format(frame.dtype) ) try: - if distutils.version.LooseVersion( - np.__version__ - ) >= distutils.version.LooseVersion("1.9.0"): + if distutils.version.LooseVersion(np.__version__) >= distutils.version.LooseVersion("1.9.0"): self.proc.stdin.write(frame.tobytes()) else: self.proc.stdin.write(frame.tostring()) diff --git a/gym/wrappers/pixel_observation.py b/gym/wrappers/pixel_observation.py index 5267510b9..fce81c25e 100644 --- a/gym/wrappers/pixel_observation.py +++ b/gym/wrappers/pixel_observation.py @@ -14,9 +14,7 @@ STATE_KEY = "state" class PixelObservationWrapper(ObservationWrapper): """Augment observations by pixel values.""" - def __init__( - self, env, pixels_only=True, render_kwargs=None, pixel_keys=("pixels",) - ): + def __init__(self, env, pixels_only=True, render_kwargs=None, pixel_keys=("pixels",)): """Initializes a new pixel Wrapper. Args: @@ -52,7 +50,9 @@ class PixelObservationWrapper(ObservationWrapper): assert render_mode == "rgb_array", render_mode render_kwargs[key]["mode"] = "rgb_array" - warnings.warn("Gym\'s internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit") + warnings.warn( + "Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit" + ) wrapped_observation_space = env.observation_space @@ -70,9 +70,7 @@ class PixelObservationWrapper(ObservationWrapper): # `observation_keys` overlapping_keys = set(pixel_keys) & set(invalid_keys) if overlapping_keys: - raise ValueError( - "Duplicate or reserved pixel keys {!r}.".format(overlapping_keys) - ) + raise ValueError("Duplicate or reserved pixel keys {!r}.".format(overlapping_keys)) if pixels_only: self.observation_space = spaces.Dict() @@ -95,9 +93,7 @@ class PixelObservationWrapper(ObservationWrapper): else: raise TypeError(pixels.dtype) - pixels_space = spaces.Box( - shape=pixels.shape, low=low, high=high, dtype=pixels.dtype - ) + pixels_space = spaces.Box(shape=pixels.shape, low=low, high=high, dtype=pixels.dtype) pixels_spaces[pixel_key] = pixels_space self.observation_space.spaces.update(pixels_spaces) @@ -120,10 +116,7 @@ class PixelObservationWrapper(ObservationWrapper): observation = collections.OrderedDict() observation[STATE_KEY] = wrapped_observation - pixel_observations = { - pixel_key: self.env.render(**self._render_kwargs[pixel_key]) - for pixel_key in self._pixel_keys - } + pixel_observations = {pixel_key: self.env.render(**self._render_kwargs[pixel_key]) for pixel_key in self._pixel_keys} observation.update(pixel_observations) diff --git a/gym/wrappers/record_episode_statistics.py b/gym/wrappers/record_episode_statistics.py index 863dae292..44e2fb50d 100644 --- a/gym/wrappers/record_episode_statistics.py +++ b/gym/wrappers/record_episode_statistics.py @@ -7,14 +7,14 @@ import gym class RecordEpisodeStatistics(gym.Wrapper): def __init__(self, env, deque_size=100): super(RecordEpisodeStatistics, self).__init__(env) - self.t0 = ( - time.time() - ) # TODO: use perf_counter when gym removes Python 2 support + self.t0 = time.time() # TODO: use perf_counter when gym removes Python 2 support self.episode_return = 0.0 self.episode_length = 0 self.return_queue = deque(maxlen=deque_size) self.length_queue = deque(maxlen=deque_size) - warnings.warn("Gym\'s internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit") + warnings.warn( + "Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit" + ) def reset(self, **kwargs): observation = super(RecordEpisodeStatistics, self).reset(**kwargs) @@ -23,9 +23,7 @@ class RecordEpisodeStatistics(gym.Wrapper): return observation def step(self, action): - observation, reward, done, info = super(RecordEpisodeStatistics, self).step( - action - ) + observation, reward, done, info = super(RecordEpisodeStatistics, self).step(action) self.episode_return += reward self.episode_length += 1 if done: diff --git a/gym/wrappers/rescale_action.py b/gym/wrappers/rescale_action.py index 6a74b8610..8826f5b3a 100644 --- a/gym/wrappers/rescale_action.py +++ b/gym/wrappers/rescale_action.py @@ -15,17 +15,15 @@ class RescaleAction(gym.ActionWrapper): """ def __init__(self, env, a, b): - assert isinstance( - env.action_space, spaces.Box - ), "expected Box action space, got {}".format(type(env.action_space)) + assert isinstance(env.action_space, spaces.Box), "expected Box action space, got {}".format(type(env.action_space)) assert np.less_equal(a, b).all(), (a, b) - warnings.warn("Gym\'s internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit") + warnings.warn( + "Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit" + ) super(RescaleAction, self).__init__(env) self.a = np.zeros(env.action_space.shape, dtype=env.action_space.dtype) + a self.b = np.zeros(env.action_space.shape, dtype=env.action_space.dtype) + b - self.action_space = spaces.Box( - low=a, high=b, shape=env.action_space.shape, dtype=env.action_space.dtype - ) + self.action_space = spaces.Box(low=a, high=b, shape=env.action_space.shape, dtype=env.action_space.dtype) def action(self, action): assert np.all(np.greater_equal(action, self.a)), (action, self.a) diff --git a/gym/wrappers/resize_observation.py b/gym/wrappers/resize_observation.py index c16881199..9c5f0734a 100644 --- a/gym/wrappers/resize_observation.py +++ b/gym/wrappers/resize_observation.py @@ -12,7 +12,9 @@ class ResizeObservation(ObservationWrapper): if isinstance(shape, int): shape = (shape, shape) assert all(x > 0 for x in shape), shape - warnings.warn("Gym\'s internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit") + warnings.warn( + "Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit" + ) self.shape = tuple(shape) obs_shape = self.shape + self.observation_space.shape[2:] @@ -21,9 +23,7 @@ class ResizeObservation(ObservationWrapper): def observation(self, observation): import cv2 - observation = cv2.resize( - observation, self.shape[::-1], interpolation=cv2.INTER_AREA - ) + observation = cv2.resize(observation, self.shape[::-1], interpolation=cv2.INTER_AREA) if observation.ndim == 2: observation = np.expand_dims(observation, -1) return observation diff --git a/gym/wrappers/test_atari_preprocessing.py b/gym/wrappers/test_atari_preprocessing.py index 0f5ab7c0d..bca20746f 100644 --- a/gym/wrappers/test_atari_preprocessing.py +++ b/gym/wrappers/test_atari_preprocessing.py @@ -15,12 +15,8 @@ def test_atari_preprocessing_grayscale(env_fn): import cv2 env1 = env_fn() - env2 = AtariPreprocessing( - env_fn(), screen_size=84, grayscale_obs=True, frame_skip=1, noop_max=0 - ) - env3 = AtariPreprocessing( - env_fn(), screen_size=84, grayscale_obs=False, frame_skip=1, noop_max=0 - ) + env2 = AtariPreprocessing(env_fn(), screen_size=84, grayscale_obs=True, frame_skip=1, noop_max=0) + env3 = AtariPreprocessing(env_fn(), screen_size=84, grayscale_obs=False, frame_skip=1, noop_max=0) env4 = AtariPreprocessing( env_fn(), screen_size=84, @@ -79,15 +75,11 @@ def test_atari_preprocessing_scale(env_fn): obs = env.reset().flatten() done, step_i = False, 0 max_obs = 1 if scaled else 255 - assert (0 <= obs).all() and ( - obs <= max_obs - ).all(), "Obs. must be in range [0,{}]".format(max_obs) + assert (0 <= obs).all() and (obs <= max_obs).all(), "Obs. must be in range [0,{}]".format(max_obs) while not done or step_i <= max_test_steps: obs, _, done, _ = env.step(env.action_space.sample()) obs = obs.flatten() - assert (0 <= obs).all() and ( - obs <= max_obs - ).all(), "Obs. must be in range [0,{}]".format(max_obs) + assert (0 <= obs).all() and (obs <= max_obs).all(), "Obs. must be in range [0,{}]".format(max_obs) step_i += 1 env.close() diff --git a/gym/wrappers/test_clip_action.py b/gym/wrappers/test_clip_action.py index 392a40ea8..2e6d5cc2e 100644 --- a/gym/wrappers/test_clip_action.py +++ b/gym/wrappers/test_clip_action.py @@ -19,9 +19,7 @@ def test_clip_action(): actions = [[0.4], [1.2], [-0.3], [0.0], [-2.5]] for action in actions: - obs1, r1, d1, _ = env.step( - np.clip(action, env.action_space.low, env.action_space.high) - ) + obs1, r1, d1, _ = env.step(np.clip(action, env.action_space.low, env.action_space.high)) obs2, r2, d2, _ = wrapped_env.step(action) assert np.allclose(r1, r2) assert np.allclose(obs1, obs2) diff --git a/gym/wrappers/test_filter_observation.py b/gym/wrappers/test_filter_observation.py index 35beaad56..c936c9588 100644 --- a/gym/wrappers/test_filter_observation.py +++ b/gym/wrappers/test_filter_observation.py @@ -9,10 +9,7 @@ from gym.wrappers.filter_observation import FilterObservation class FakeEnvironment(gym.Env): def __init__(self, observation_keys=("state")): self.observation_space = spaces.Dict( - { - name: spaces.Box(shape=(2,), low=-1, high=1, dtype=np.float32) - for name in observation_keys - } + {name: spaces.Box(shape=(2,), low=-1, high=1, dtype=np.float32) for name in observation_keys} ) self.action_space = spaces.Box(shape=(1,), low=-1, high=1, dtype=np.float32) @@ -48,9 +45,7 @@ ERROR_TEST_CASES = ( class TestFilterObservation(object): - @pytest.mark.parametrize( - "observation_keys,filter_keys", FILTER_OBSERVATION_TEST_CASES - ) + @pytest.mark.parametrize("observation_keys,filter_keys", FILTER_OBSERVATION_TEST_CASES) def test_filter_observation(self, observation_keys, filter_keys): env = FakeEnvironment(observation_keys=observation_keys) @@ -73,9 +68,7 @@ class TestFilterObservation(object): assert len(observation) == len(filter_keys) @pytest.mark.parametrize("filter_keys,error_type,error_match", ERROR_TEST_CASES) - def test_raises_with_incorrect_arguments( - self, filter_keys, error_type, error_match - ): + def test_raises_with_incorrect_arguments(self, filter_keys, error_type, error_match): env = FakeEnvironment(observation_keys=("key1", "key2")) ValueError diff --git a/gym/wrappers/test_flatten_observation.py b/gym/wrappers/test_flatten_observation.py index f19008107..72c3f71c1 100644 --- a/gym/wrappers/test_flatten_observation.py +++ b/gym/wrappers/test_flatten_observation.py @@ -16,14 +16,10 @@ def test_flatten_observation(env_id): wrapped_obs = wrapped_env.reset() if env_id == "Blackjack-v0": - space = spaces.Tuple( - (spaces.Discrete(32), spaces.Discrete(11), spaces.Discrete(2)) - ) + space = spaces.Tuple((spaces.Discrete(32), spaces.Discrete(11), spaces.Discrete(2))) wrapped_space = spaces.Box(-np.inf, np.inf, [32 + 11 + 2], dtype=np.float32) elif env_id == "KellyCoinflip-v0": - space = spaces.Tuple( - (spaces.Box(0, 250.0, [1], dtype=np.float32), spaces.Discrete(300 + 1)) - ) + space = spaces.Tuple((spaces.Box(0, 250.0, [1], dtype=np.float32), spaces.Discrete(300 + 1))) wrapped_space = spaces.Box(-np.inf, np.inf, [1 + (300 + 1)], dtype=np.float32) assert space.contains(obs) diff --git a/gym/wrappers/test_frame_stack.py b/gym/wrappers/test_frame_stack.py index f509f7d6b..467892ac5 100644 --- a/gym/wrappers/test_frame_stack.py +++ b/gym/wrappers/test_frame_stack.py @@ -19,9 +19,7 @@ except ImportError: [ pytest.param( True, - marks=pytest.mark.skipif( - lz4 is None, reason="Need lz4 to run tests with compression" - ), + marks=pytest.mark.skipif(lz4 is None, reason="Need lz4 to run tests with compression"), ), False, ], diff --git a/gym/wrappers/test_gray_scale_observation.py b/gym/wrappers/test_gray_scale_observation.py index cf2176ade..c56d0a6ea 100644 --- a/gym/wrappers/test_gray_scale_observation.py +++ b/gym/wrappers/test_gray_scale_observation.py @@ -10,9 +10,7 @@ pytest.importorskip("atari_py") pytest.importorskip("cv2") -@pytest.mark.parametrize( - "env_id", ["PongNoFrameskip-v0", "SpaceInvadersNoFrameskip-v0"] -) +@pytest.mark.parametrize("env_id", ["PongNoFrameskip-v0", "SpaceInvadersNoFrameskip-v0"]) @pytest.mark.parametrize("keep_dim", [True, False]) def test_gray_scale_observation(env_id, keep_dim): gray_env = AtariPreprocessing(gym.make(env_id), screen_size=84, grayscale_obs=True) diff --git a/gym/wrappers/test_pixel_observation.py b/gym/wrappers/test_pixel_observation.py index 8e3e7f218..b510a9bd0 100644 --- a/gym/wrappers/test_pixel_observation.py +++ b/gym/wrappers/test_pixel_observation.py @@ -32,9 +32,7 @@ class FakeEnvironment(gym.Env): class FakeArrayObservationEnvironment(FakeEnvironment): def __init__(self, *args, **kwargs): - self.observation_space = spaces.Box( - shape=(2,), low=-1, high=1, dtype=np.float32 - ) + self.observation_space = spaces.Box(shape=(2,), low=-1, high=1, dtype=np.float32) super(FakeArrayObservationEnvironment, self).__init__(*args, **kwargs) @@ -75,10 +73,7 @@ class TestPixelObservationWrapper(object): assert len(wrapped_env.observation_space.spaces) == 1 assert list(wrapped_env.observation_space.spaces.keys()) == [pixel_key] else: - assert ( - len(wrapped_env.observation_space.spaces) - == len(observation_space.spaces) + 1 - ) + assert len(wrapped_env.observation_space.spaces) == len(observation_space.spaces) + 1 expected_keys = list(observation_space.spaces.keys()) + [pixel_key] assert list(wrapped_env.observation_space.spaces.keys()) == expected_keys @@ -97,9 +92,7 @@ class TestPixelObservationWrapper(object): observation_space = env.observation_space assert isinstance(observation_space, spaces.Box) - wrapped_env = PixelObservationWrapper( - env, pixel_keys=(pixel_key,), pixels_only=pixels_only - ) + wrapped_env = PixelObservationWrapper(env, pixel_keys=(pixel_key,), pixels_only=pixels_only) wrapped_env.observation_space = wrapped_env.observation_space assert isinstance(wrapped_env.observation_space, spaces.Dict) diff --git a/gym/wrappers/test_resize_observation.py b/gym/wrappers/test_resize_observation.py index 8cf6f2048..a3cc295f9 100644 --- a/gym/wrappers/test_resize_observation.py +++ b/gym/wrappers/test_resize_observation.py @@ -9,12 +9,8 @@ except ImportError: atari_py = None -@pytest.mark.skipif( - atari_py is None, reason="Only run this test when atari_py is installed" -) -@pytest.mark.parametrize( - "env_id", ["PongNoFrameskip-v0", "SpaceInvadersNoFrameskip-v0"] -) +@pytest.mark.skipif(atari_py is None, reason="Only run this test when atari_py is installed") +@pytest.mark.parametrize("env_id", ["PongNoFrameskip-v0", "SpaceInvadersNoFrameskip-v0"]) @pytest.mark.parametrize("shape", [16, 32, (8, 5), [10, 7]]) def test_resize_observation(env_id, shape): env = gym.make(env_id) diff --git a/gym/wrappers/test_transform_observation.py b/gym/wrappers/test_transform_observation.py index 8c43cfb68..c4410eca6 100644 --- a/gym/wrappers/test_transform_observation.py +++ b/gym/wrappers/test_transform_observation.py @@ -10,9 +10,7 @@ from gym.wrappers import TransformObservation def test_transform_observation(env_id): affine_transform = lambda x: 3 * x + 2 env = gym.make(env_id) - wrapped_env = TransformObservation( - gym.make(env_id), lambda obs: affine_transform(obs) - ) + wrapped_env = TransformObservation(gym.make(env_id), lambda obs: affine_transform(obs)) env.seed(0) wrapped_env.seed(0) diff --git a/gym/wrappers/time_aware_observation.py b/gym/wrappers/time_aware_observation.py index 8b0b864b2..f10e01b55 100644 --- a/gym/wrappers/time_aware_observation.py +++ b/gym/wrappers/time_aware_observation.py @@ -17,7 +17,9 @@ class TimeAwareObservation(ObservationWrapper): super(TimeAwareObservation, self).__init__(env) assert isinstance(env.observation_space, Box) assert env.observation_space.dtype == np.float32 - warnings.warn("Gym\'s internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit") + warnings.warn( + "Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit" + ) low = np.append(self.observation_space.low, 0.0) high = np.append(self.observation_space.high, np.inf) self.observation_space = Box(low, high, dtype=np.float32) diff --git a/gym/wrappers/time_limit.py b/gym/wrappers/time_limit.py index a0ef43dde..d3106e9df 100644 --- a/gym/wrappers/time_limit.py +++ b/gym/wrappers/time_limit.py @@ -5,7 +5,9 @@ import warnings class TimeLimit(gym.Wrapper): def __init__(self, env, max_episode_steps=None): super(TimeLimit, self).__init__(env) - warnings.warn("Gym\'s internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit") + warnings.warn( + "Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit" + ) if max_episode_steps is None and self.env.spec is not None: max_episode_steps = env.spec.max_episode_steps if self.env.spec is not None: @@ -14,9 +16,7 @@ class TimeLimit(gym.Wrapper): self._elapsed_steps = None def step(self, action): - assert ( - self._elapsed_steps is not None - ), "Cannot call env.step() before calling reset()" + assert self._elapsed_steps is not None, "Cannot call env.step() before calling reset()" observation, reward, done, info = self.env.step(action) self._elapsed_steps += 1 if self._elapsed_steps >= self._max_episode_steps: diff --git a/gym/wrappers/transform_observation.py b/gym/wrappers/transform_observation.py index 5b56f8fab..02f4eac2c 100644 --- a/gym/wrappers/transform_observation.py +++ b/gym/wrappers/transform_observation.py @@ -22,7 +22,9 @@ class TransformObservation(ObservationWrapper): def __init__(self, env, f): super(TransformObservation, self).__init__(env) assert callable(f) - warnings.warn("Gym\'s internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit") + warnings.warn( + "Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit" + ) self.f = f def observation(self, observation): diff --git a/gym/wrappers/transform_reward.py b/gym/wrappers/transform_reward.py index be44abf76..20dde050f 100644 --- a/gym/wrappers/transform_reward.py +++ b/gym/wrappers/transform_reward.py @@ -24,7 +24,9 @@ class TransformReward(RewardWrapper): def __init__(self, env, f): super(TransformReward, self).__init__(env) assert callable(f) - warnings.warn("Gym\'s internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit") + warnings.warn( + "Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit" + ) self.f = f def reward(self, reward): diff --git a/scripts/generate_json.py b/scripts/generate_json.py index b5aa706d6..4baae1c9c 100644 --- a/scripts/generate_json.py +++ b/scripts/generate_json.py @@ -15,9 +15,7 @@ steps = ROLLOUT_STEPS ROLLOUT_FILE = os.path.join(DATA_DIR, "rollout.json") if not os.path.isfile(ROLLOUT_FILE): - logger.info( - "No rollout file found. Writing empty json file to {}".format(ROLLOUT_FILE) - ) + logger.info("No rollout file found. Writing empty json file to {}".format(ROLLOUT_FILE)) with open(ROLLOUT_FILE, "w") as outfile: json.dump({}, outfile, indent=2) @@ -50,9 +48,7 @@ def update_rollout_dict(spec, rollout_dict): except: # If running the env generates an exception, don't write to the rollout file logger.warn( - "Exception {} thrown while generating rollout for {}. Rollout not added.".format( - sys.exc_info()[0], spec.id - ) + "Exception {} thrown while generating rollout for {}. Rollout not added.".format(sys.exc_info()[0], spec.id) ) return False @@ -78,9 +74,7 @@ def update_rollout_dict(spec, rollout_dict): def add_new_rollouts(spec_ids, overwrite): - environments = [ - spec for spec in envs.registry.all() if spec.entry_point is not None - ] + environments = [spec for spec in envs.registry.all() if spec.entry_point is not None] if spec_ids: environments = [spec for spec in environments if spec.id in spec_ids] assert len(environments) == len(spec_ids), "Some specs not found" @@ -110,9 +104,7 @@ if __name__ == "__main__": help="Overwrite " + "existing rollouts if hashes differ.", ) parser.add_argument("-v", "--verbose", action="store_true") - parser.add_argument( - "specs", nargs="*", help="ids of env specs to check (default: all)" - ) + parser.add_argument("specs", nargs="*", help="ids of env specs to check (default: all)") args = parser.parse_args() if args.verbose: logger.set_level(logger.INFO) diff --git a/setup.py b/setup.py index f7bcc8493..49ac4870c 100644 --- a/setup.py +++ b/setup.py @@ -18,14 +18,7 @@ extras = { # Meta dependency groups. extras["nomujoco"] = list( - set( - [ - item - for name, group in extras.items() - if name != "mujoco" and name != "robotics" - for item in group - ] - ) + set([item for name, group in extras.items() if name != "mujoco" and name != "robotics" for item in group]) ) extras["all"] = list(set([item for group in extras.values() for item in group])) diff --git a/tests/gym/wrappers/nested_dict_test.py b/tests/gym/wrappers/nested_dict_test.py index a390675b5..13e1ae67c 100644 --- a/tests/gym/wrappers/nested_dict_test.py +++ b/tests/gym/wrappers/nested_dict_test.py @@ -85,9 +85,7 @@ NESTED_DICT_TEST_CASES = ( ( Dict( { - "key1": Tuple( - (Dict({"key9": Box(shape=(2,), low=-1, high=1, dtype=np.float32)}),) - ), + "key1": Tuple((Dict({"key9": Box(shape=(2,), low=-1, high=1, dtype=np.float32)}),)), "key2": Box(shape=(), low=-1, high=1, dtype=np.float32), "key3": Box(shape=(2,), low=-1, high=1, dtype=np.float32), }