redo black (#2272)

J K Terry
2021-07-29 15:39:42 -04:00
committed by GitHub
parent ff8c269abb
commit 78d2b512d8
104 changed files with 1350 additions and 418 deletions
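The hunks below are the output of re-running the Black formatter across the repository. As a minimal sketch of how a commit like this is produced and verified, assuming Black is installed and that gym/ and examples/ are the target directories (the commit itself does not record the command):

import subprocess

# Rewrite every Python file in place using Black's default style.
subprocess.run(["black", "gym", "examples"], check=True)

# Verify without modifying anything; "black --check" exits non-zero
# if any file would be reformatted, which is what CI typically runs.
subprocess.run(["black", "--check", "gym", "examples"], check=True)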

View File

@@ -3,7 +3,9 @@ import argparse
import gym
parser = argparse.ArgumentParser(description="Renders a Gym environment for quick inspection.")
parser = argparse.ArgumentParser(
description="Renders a Gym environment for quick inspection."
)
parser.add_argument(
"env_id",
type=str,

View File

@@ -35,7 +35,12 @@ def cem(f, th_mean, batch_size, n_iter, elite_frac, initial_std=1.0):
th_std = np.ones_like(th_mean) * initial_std
for _ in range(n_iter):
ths = np.array([th_mean + dth for dth in th_std[None, :] * np.random.randn(batch_size, th_mean.size)])
ths = np.array(
[
th_mean + dth
for dth in th_std[None, :] * np.random.randn(batch_size, th_mean.size)
]
)
ys = np.array([f(th) for th in ths])
elite_inds = ys.argsort()[::-1][:n_elite]
elite_ths = ths[elite_inds]
@@ -96,7 +101,9 @@ if __name__ == "__main__":
return rew
# Train the agent, and snapshot each stage
for (i, iterdata) in enumerate(cem(noisy_evaluation, np.zeros(env.observation_space.shape[0] + 1), **params)):
for (i, iterdata) in enumerate(
cem(noisy_evaluation, np.zeros(env.observation_space.shape[0] + 1), **params)
):
print("Iteration %2i. Episode mean reward: %7.3f" % (i, iterdata["y_mean"]))
agent = BinaryActionLinearPolicy(iterdata["theta_mean"])
if args.display:
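For context, the cem generator being reformatted above implements the cross-entropy method. A self-contained sketch of the whole loop, reconstructed from the fragments visible in this hunk, is below; the yielded "theta_mean" and "y_mean" keys match the iterdata usage shown, while the rest is an assumption rather than the verbatim file:

import numpy as np

def cem(f, th_mean, batch_size, n_iter, elite_frac, initial_std=1.0):
    # Sample batch_size parameter vectors around th_mean, score them with f,
    # keep the top elite_frac fraction, refit mean and std, and repeat.
    n_elite = int(np.round(batch_size * elite_frac))
    th_std = np.ones_like(th_mean) * initial_std
    for _ in range(n_iter):
        ths = th_mean + th_std[None, :] * np.random.randn(batch_size, th_mean.size)
        ys = np.array([f(th) for th in ths])
        elite_ths = ths[ys.argsort()[::-1][:n_elite]]
        th_mean = elite_ths.mean(axis=0)
        th_std = elite_ths.std(axis=0)
        yield {"ys": ys, "theta_mean": th_mean, "y_mean": ys.mean()}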

View File

@@ -17,7 +17,9 @@ class RandomAgent(object):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description=None)
parser.add_argument("env_id", nargs="?", default="CartPole-v0", help="Select the environment to run")
parser.add_argument(
"env_id", nargs="?", default="CartPole-v0", help="Select the environment to run"
)
args = parser.parse_args()
# You can set the level to logger.DEBUG or logger.WARN if you

View File

@@ -173,10 +173,16 @@ class GoalEnv(Env):
def reset(self):
# Enforce that each GoalEnv uses a Goal-compatible observation space.
if not isinstance(self.observation_space, gym.spaces.Dict):
raise error.Error("GoalEnv requires an observation space of type gym.spaces.Dict")
raise error.Error(
"GoalEnv requires an observation space of type gym.spaces.Dict"
)
for key in ["observation", "achieved_goal", "desired_goal"]:
if key not in self.observation_space.spaces:
raise error.Error('GoalEnv requires the "{}" key to be part of the observation dictionary.'.format(key))
raise error.Error(
'GoalEnv requires the "{}" key to be part of the observation dictionary.'.format(
key
)
)
def compute_reward(self, achieved_goal, desired_goal, info):
"""Compute the step reward. This externalizes the reward function and makes
@@ -221,7 +227,9 @@ class Wrapper(Env):
def __getattr__(self, name):
if name.startswith("_"):
raise AttributeError("attempted to get missing private attribute '{}'".format(name))
raise AttributeError(
"attempted to get missing private attribute '{}'".format(name)
)
return getattr(self.env, name)
@property
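Earlier in this file, GoalEnv.reset enforces a Dict observation space containing the three goal keys. A minimal sketch of a compliant space follows; the shapes are arbitrary illustrations, since the check only inspects the keys:

import numpy as np
from gym import spaces

observation_space = spaces.Dict(
    {
        # Any per-key space passes the check; only the keys are required.
        "observation": spaces.Box(-np.inf, np.inf, shape=(10,), dtype=np.float32),
        "achieved_goal": spaces.Box(-np.inf, np.inf, shape=(3,), dtype=np.float32),
        "desired_goal": spaces.Box(-np.inf, np.inf, shape=(3,), dtype=np.float32),
    }
)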

View File

@@ -422,7 +422,9 @@ for reward_type in ["sparse", "dense"]:
register(
id="HandManipulateBlockRotateParallel{}-v0".format(suffix),
entry_point="gym.envs.robotics:HandBlockEnv",
kwargs=_merge({"target_position": "ignore", "target_rotation": "parallel"}, kwargs),
kwargs=_merge(
{"target_position": "ignore", "target_rotation": "parallel"}, kwargs
),
max_episode_steps=100,
)
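The register call above folds per-variant options into the base kwargs through a _merge helper defined outside this hunk. Presumably it simply overlays one dict on the other; a sketch, not the verbatim source:

def _merge(a, b):
    # Return a copy of a with the entries of b layered on top.
    merged = dict(a)
    merged.update(b)
    return merged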

View File

@@ -73,7 +73,9 @@ class AlgorithmicEnv(Env):
# 1. Move read head left or right (or up/down)
# 2. Write or not
# 3. Which character to write. (Ignored if should_write=0)
self.action_space = Tuple([Discrete(len(self.MOVEMENTS)), Discrete(2), Discrete(self.base)])
self.action_space = Tuple(
[Discrete(len(self.MOVEMENTS)), Discrete(2), Discrete(self.base)]
)
# Can see just what is on the input tape (one of n characters, or
# nothing)
self.observation_space = Discrete(self.base + 1)
@@ -145,7 +147,10 @@ class AlgorithmicEnv(Env):
move = self.MOVEMENTS[inp_act]
outfile.write("Action : Tuple(move over input: %s,\n" % move)
out_act = out_act == 1
outfile.write(" write to the output tape: %s,\n" % out_act)
outfile.write(
" write to the output tape: %s,\n"
% out_act
)
outfile.write(" prediction: %s)\n" % pred_str)
else:
outfile.write("\n" * 5)
@@ -271,7 +276,9 @@ class TapeAlgorithmicEnv(AlgorithmicEnv):
x_str = "Observation Tape : "
for i in range(-2, self.input_width + 2):
if i == x:
x_str += colorize(self._get_str_obs(np.array([i])), "green", highlight=True)
x_str += colorize(
self._get_str_obs(np.array([i])), "green", highlight=True
)
else:
x_str += self._get_str_obs(np.array([i]))
x_str += "\n"
@@ -304,7 +311,10 @@ class GridAlgorithmicEnv(AlgorithmicEnv):
self.read_head_position = x, y
def generate_input_data(self, size):
return [[self.np_random.randint(self.base) for _ in range(self.rows)] for __ in range(size)]
return [
[self.np_random.randint(self.base) for _ in range(self.rows)]
for __ in range(size)
]
def _get_obs(self, pos=None):
if pos is None:
@@ -326,7 +336,9 @@ class GridAlgorithmicEnv(AlgorithmicEnv):
x_str += " " * len(label)
for i in range(-2, self.input_width + 2):
if i == x[0] and j == x[1]:
x_str += colorize(self._get_str_obs((i, j)), "green", highlight=True)
x_str += colorize(
self._get_str_obs((i, j)), "green", highlight=True
)
else:
x_str += self._get_str_obs((i, j))
x_str += "\n"

View File

@@ -10,8 +10,12 @@ ALL_ENVS = [
alg.reverse.ReverseEnv,
alg.reversed_addition.ReversedAdditionEnv,
]
ALL_TAPE_ENVS = [env for env in ALL_ENVS if issubclass(env, alg.algorithmic_env.TapeAlgorithmicEnv)]
ALL_GRID_ENVS = [env for env in ALL_ENVS if issubclass(env, alg.algorithmic_env.GridAlgorithmicEnv)]
ALL_TAPE_ENVS = [
env for env in ALL_ENVS if issubclass(env, alg.algorithmic_env.TapeAlgorithmicEnv)
]
ALL_GRID_ENVS = [
env for env in ALL_ENVS if issubclass(env, alg.algorithmic_env.GridAlgorithmicEnv)
]
def imprint(env, input_arr):
@@ -88,7 +92,10 @@ class TestAlgorithmicEnvInteractions(unittest.TestCase):
def test_grid_naviation(self):
env = alg.reversed_addition.ReversedAdditionEnv(rows=2, base=6)
N, S, E, W = [env._movement_idx(named_dir) for named_dir in ["up", "down", "right", "left"]]
N, S, E, W = [
env._movement_idx(named_dir)
for named_dir in ["up", "down", "right", "left"]
]
# Corresponds to a grid that looks like...
# 0 1 2
# 3 4 5
@@ -197,7 +204,9 @@ class TestTargets(unittest.TestCase):
def test_repeat_copy_target(self):
env = alg.repeat_copy.RepeatCopyEnv()
self.assertEqual(env.target_from_input_data([0, 1, 2]), [0, 1, 2, 2, 1, 0, 0, 1, 2])
self.assertEqual(
env.target_from_input_data([0, 1, 2]), [0, 1, 2, 2, 1, 0, 0, 1, 2]
)
class TestInputGeneration(unittest.TestCase):

View File

@@ -9,7 +9,8 @@ try:
import atari_py
except ImportError as e:
raise error.DependencyNotInstalled(
"{}. (HINT: you can install Atari dependencies by running " "'pip install gym[atari]'.)".format(e)
"{}. (HINT: you can install Atari dependencies by running "
"'pip install gym[atari]'.)".format(e)
)
@@ -63,23 +64,35 @@ class AtariEnv(gym.Env, utils.EzPickle):
# Tune (or disable) ALE's action repeat:
# https://github.com/openai/gym/issues/349
assert isinstance(repeat_action_probability, (float, int)), "Invalid repeat_action_probability: {!r}".format(
repeat_action_probability
)
assert isinstance(
repeat_action_probability, (float, int)
), "Invalid repeat_action_probability: {!r}".format(repeat_action_probability)
self.ale.setFloat("repeat_action_probability".encode("utf-8"), repeat_action_probability)
self.ale.setFloat(
"repeat_action_probability".encode("utf-8"), repeat_action_probability
)
self.seed()
self._action_set = self.ale.getLegalActionSet() if full_action_space else self.ale.getMinimalActionSet()
self._action_set = (
self.ale.getLegalActionSet()
if full_action_space
else self.ale.getMinimalActionSet()
)
self.action_space = spaces.Discrete(len(self._action_set))
(screen_width, screen_height) = self.ale.getScreenDims()
if self._obs_type == "ram":
self.observation_space = spaces.Box(low=0, high=255, dtype=np.uint8, shape=(128,))
self.observation_space = spaces.Box(
low=0, high=255, dtype=np.uint8, shape=(128,)
)
elif self._obs_type == "image":
self.observation_space = spaces.Box(low=0, high=255, shape=(screen_height, screen_width, 3), dtype=np.uint8)
self.observation_space = spaces.Box(
low=0, high=255, shape=(screen_height, screen_width, 3), dtype=np.uint8
)
else:
raise error.Error("Unrecognized observation type: {}".format(self._obs_type))
raise error.Error(
"Unrecognized observation type: {}".format(self._obs_type)
)
def seed(self, seed=None):
self.np_random, seed1 = seeding.np_random(seed)
@@ -94,9 +107,9 @@ class AtariEnv(gym.Env, utils.EzPickle):
if self.game_mode is not None:
modes = self.ale.getAvailableModes()
assert self.game_mode in modes, ('Invalid game mode "{}" for game {}.\nAvailable modes are: {}').format(
self.game_mode, self.game, modes
)
assert self.game_mode in modes, (
'Invalid game mode "{}" for game {}.\nAvailable modes are: {}'
).format(self.game_mode, self.game, modes)
self.ale.setMode(self.game_mode)
if self.game_difficulty is not None:

View File

@@ -100,7 +100,10 @@ class ContactDetector(contactListener):
self.env = env
def BeginContact(self, contact):
if self.env.hull == contact.fixtureA.body or self.env.hull == contact.fixtureB.body:
if (
self.env.hull == contact.fixtureA.body
or self.env.hull == contact.fixtureB.body
):
self.env.game_over = True
for leg in [self.env.legs[1], self.env.legs[3]]:
if leg in [contact.fixtureA.body, contact.fixtureB.body]:
@@ -199,7 +202,9 @@ class BipedalWalker(gym.Env, EzPickle):
t.color1, t.color2 = (1, 1, 1), (0.6, 0.6, 0.6)
self.terrain.append(t)
self.fd_polygon.shape.vertices = [(p[0] + TERRAIN_STEP * counter, p[1]) for p in poly]
self.fd_polygon.shape.vertices = [
(p[0] + TERRAIN_STEP * counter, p[1]) for p in poly
]
t = self.world.CreateStaticBody(fixtures=self.fd_polygon)
t.color1, t.color2 = (1, 1, 1), (0.6, 0.6, 0.6)
self.terrain.append(t)
@@ -296,8 +301,12 @@ class BipedalWalker(gym.Env, EzPickle):
y = VIEWPORT_H / SCALE * 3 / 4
poly = [
(
x + 15 * TERRAIN_STEP * math.sin(3.14 * 2 * a / 5) + self.np_random.uniform(0, 5 * TERRAIN_STEP),
y + 5 * TERRAIN_STEP * math.cos(3.14 * 2 * a / 5) + self.np_random.uniform(0, 5 * TERRAIN_STEP),
x
+ 15 * TERRAIN_STEP * math.sin(3.14 * 2 * a / 5)
+ self.np_random.uniform(0, 5 * TERRAIN_STEP),
y
+ 5 * TERRAIN_STEP * math.cos(3.14 * 2 * a / 5)
+ self.np_random.uniform(0, 5 * TERRAIN_STEP),
)
for a in range(5)
]
@@ -322,10 +331,14 @@ class BipedalWalker(gym.Env, EzPickle):
init_x = TERRAIN_STEP * TERRAIN_STARTPAD / 2
init_y = TERRAIN_HEIGHT + 2 * LEG_H
self.hull = self.world.CreateDynamicBody(position=(init_x, init_y), fixtures=HULL_FD)
self.hull = self.world.CreateDynamicBody(
position=(init_x, init_y), fixtures=HULL_FD
)
self.hull.color1 = (0.5, 0.4, 0.9)
self.hull.color2 = (0.3, 0.3, 0.5)
self.hull.ApplyForceToCenter((self.np_random.uniform(-INITIAL_RANDOM, INITIAL_RANDOM), 0), True)
self.hull.ApplyForceToCenter(
(self.np_random.uniform(-INITIAL_RANDOM, INITIAL_RANDOM), 0), True
)
self.legs = []
self.joints = []
@@ -399,13 +412,21 @@ class BipedalWalker(gym.Env, EzPickle):
self.joints[3].motorSpeed = float(SPEED_KNEE * np.clip(action[3], -1, 1))
else:
self.joints[0].motorSpeed = float(SPEED_HIP * np.sign(action[0]))
self.joints[0].maxMotorTorque = float(MOTORS_TORQUE * np.clip(np.abs(action[0]), 0, 1))
self.joints[0].maxMotorTorque = float(
MOTORS_TORQUE * np.clip(np.abs(action[0]), 0, 1)
)
self.joints[1].motorSpeed = float(SPEED_KNEE * np.sign(action[1]))
self.joints[1].maxMotorTorque = float(MOTORS_TORQUE * np.clip(np.abs(action[1]), 0, 1))
self.joints[1].maxMotorTorque = float(
MOTORS_TORQUE * np.clip(np.abs(action[1]), 0, 1)
)
self.joints[2].motorSpeed = float(SPEED_HIP * np.sign(action[2]))
self.joints[2].maxMotorTorque = float(MOTORS_TORQUE * np.clip(np.abs(action[2]), 0, 1))
self.joints[2].maxMotorTorque = float(
MOTORS_TORQUE * np.clip(np.abs(action[2]), 0, 1)
)
self.joints[3].motorSpeed = float(SPEED_KNEE * np.sign(action[3]))
self.joints[3].maxMotorTorque = float(MOTORS_TORQUE * np.clip(np.abs(action[3]), 0, 1))
self.joints[3].maxMotorTorque = float(
MOTORS_TORQUE * np.clip(np.abs(action[3]), 0, 1)
)
self.world.Step(1.0 / FPS, 6 * 30, 2 * 30)
@@ -444,8 +465,12 @@ class BipedalWalker(gym.Env, EzPickle):
self.scroll = pos.x - VIEWPORT_W / SCALE / 5
shaping = 130 * pos[0] / SCALE # moving forward is a way to receive reward (normalized to get 300 on completion)
shaping -= 5.0 * abs(state[0]) # keep head straight, other than that and falling, any behavior is unpunished
shaping = (
130 * pos[0] / SCALE
) # moving forward is a way to receive reward (normalized to get 300 on completion)
shaping -= 5.0 * abs(
state[0]
) # keep head straight, other than that and falling, any behavior is unpunished
reward = 0
if self.prev_shaping is not None:
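The shaping terms above are potential-based reward shaping: the prev_shaping bookkeeping suggests each step's reward (before control costs) is the change in the potential, so reaching the end with a level hull nets roughly the advertised 300. In symbols, with x the hull position and theta the hull angle (state[0]):

\Phi(s) = \frac{130\,x}{\mathrm{SCALE}} - 5\,\lvert\theta\rvert,
\qquad
r_t = \Phi(s_t) - \Phi(s_{t-1})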
@@ -469,7 +494,9 @@ class BipedalWalker(gym.Env, EzPickle):
if self.viewer is None:
self.viewer = rendering.Viewer(VIEWPORT_W, VIEWPORT_H)
self.viewer.set_bounds(self.scroll, VIEWPORT_W / SCALE + self.scroll, 0, VIEWPORT_H / SCALE)
self.viewer.set_bounds(
self.scroll, VIEWPORT_W / SCALE + self.scroll, 0, VIEWPORT_H / SCALE
)
self.viewer.draw_polygon(
[
@@ -485,7 +512,9 @@ class BipedalWalker(gym.Env, EzPickle):
continue
if x1 > self.scroll / 2 + VIEWPORT_W / SCALE:
continue
self.viewer.draw_polygon([(p[0] + self.scroll / 2, p[1]) for p in poly], color=(1, 1, 1))
self.viewer.draw_polygon(
[(p[0] + self.scroll / 2, p[1]) for p in poly], color=(1, 1, 1)
)
for poly, color in self.terrain_poly:
if poly[1][0] < self.scroll:
continue
@@ -496,7 +525,11 @@ class BipedalWalker(gym.Env, EzPickle):
self.lidar_render = (self.lidar_render + 1) % 100
i = self.lidar_render
if i < 2 * len(self.lidar):
l = self.lidar[i] if i < len(self.lidar) else self.lidar[len(self.lidar) - i - 1]
l = (
self.lidar[i]
if i < len(self.lidar)
else self.lidar[len(self.lidar) - i - 1]
)
self.viewer.draw_polyline([l.p1, l.p2], color=(1, 0, 0), linewidth=1)
for obj in self.drawlist:
@@ -504,8 +537,12 @@ class BipedalWalker(gym.Env, EzPickle):
trans = f.body.transform
if type(f.shape) is circleShape:
t = rendering.Transform(translation=trans * f.shape.pos)
self.viewer.draw_circle(f.shape.radius, 30, color=obj.color1).add_attr(t)
self.viewer.draw_circle(f.shape.radius, 30, color=obj.color2, filled=False, linewidth=2).add_attr(t)
self.viewer.draw_circle(
f.shape.radius, 30, color=obj.color1
).add_attr(t)
self.viewer.draw_circle(
f.shape.radius, 30, color=obj.color2, filled=False, linewidth=2
).add_attr(t)
else:
path = [trans * v for v in f.shape.vertices]
self.viewer.draw_polygon(path, color=obj.color1)
@@ -515,7 +552,9 @@ class BipedalWalker(gym.Env, EzPickle):
flagy1 = TERRAIN_HEIGHT
flagy2 = flagy1 + 50 / SCALE
x = TERRAIN_STEP * 3
self.viewer.draw_polyline([(x, flagy1), (x, flagy2)], color=(0, 0, 0), linewidth=2)
self.viewer.draw_polyline(
[(x, flagy1), (x, flagy2)], color=(0, 0, 0), linewidth=2
)
f = [
(x, flagy2),
(x, flagy2 - 10 / SCALE),

View File

@@ -23,7 +23,9 @@ from Box2D.b2 import (
SIZE = 0.02
ENGINE_POWER = 100000000 * SIZE * SIZE
WHEEL_MOMENT_OF_INERTIA = 4000 * SIZE * SIZE
FRICTION_LIMIT = 1000000 * SIZE * SIZE # friction ~= mass ~= size^2 (calculated implicitly using density)
FRICTION_LIMIT = (
1000000 * SIZE * SIZE
) # friction ~= mass ~= size^2 (calculated implicitly using density)
WHEEL_R = 27
WHEEL_W = 14
WHEELPOS = [(-55, +80), (+55, +80), (-55, -82), (+55, -82)]
@@ -53,19 +55,27 @@ class Car:
angle=init_angle,
fixtures=[
fixtureDef(
shape=polygonShape(vertices=[(x * SIZE, y * SIZE) for x, y in HULL_POLY1]),
shape=polygonShape(
vertices=[(x * SIZE, y * SIZE) for x, y in HULL_POLY1]
),
density=1.0,
),
fixtureDef(
shape=polygonShape(vertices=[(x * SIZE, y * SIZE) for x, y in HULL_POLY2]),
shape=polygonShape(
vertices=[(x * SIZE, y * SIZE) for x, y in HULL_POLY2]
),
density=1.0,
),
fixtureDef(
shape=polygonShape(vertices=[(x * SIZE, y * SIZE) for x, y in HULL_POLY3]),
shape=polygonShape(
vertices=[(x * SIZE, y * SIZE) for x, y in HULL_POLY3]
),
density=1.0,
),
fixtureDef(
shape=polygonShape(vertices=[(x * SIZE, y * SIZE) for x, y in HULL_POLY4]),
shape=polygonShape(
vertices=[(x * SIZE, y * SIZE) for x, y in HULL_POLY4]
),
density=1.0,
),
],
@@ -85,7 +95,12 @@ class Car:
position=(init_x + wx * SIZE, init_y + wy * SIZE),
angle=init_angle,
fixtures=fixtureDef(
shape=polygonShape(vertices=[(x * front_k * SIZE, y * front_k * SIZE) for x, y in WHEEL_POLY]),
shape=polygonShape(
vertices=[
(x * front_k * SIZE, y * front_k * SIZE)
for x, y in WHEEL_POLY
]
),
density=0.1,
categoryBits=0x0020,
maskBits=0x001,
@@ -160,7 +175,9 @@ class Car:
grass = True
friction_limit = FRICTION_LIMIT * 0.6 # Grass friction if no tile
for tile in w.tiles:
friction_limit = max(friction_limit, FRICTION_LIMIT * tile.road_friction)
friction_limit = max(
friction_limit, FRICTION_LIMIT * tile.road_friction
)
grass = False
# Force
@@ -175,7 +192,13 @@ class Car:
# domega = dt*W/WHEEL_MOMENT_OF_INERTIA/w.omega
# add small coef not to divide by zero
w.omega += dt * ENGINE_POWER * w.gas / WHEEL_MOMENT_OF_INERTIA / (abs(w.omega) + 5.0)
w.omega += (
dt
* ENGINE_POWER
* w.gas
/ WHEEL_MOMENT_OF_INERTIA
/ (abs(w.omega) + 5.0)
)
self.fuel_spent += dt * ENGINE_POWER * w.gas
if w.brake >= 0.9:
@@ -203,12 +226,18 @@ class Car:
# Skid trace
if abs(force) > 2.0 * friction_limit:
if w.skid_particle and w.skid_particle.grass == grass and len(w.skid_particle.poly) < 30:
if (
w.skid_particle
and w.skid_particle.grass == grass
and len(w.skid_particle.poly) < 30
):
w.skid_particle.poly.append((w.position[0], w.position[1]))
elif w.skid_start is None:
w.skid_start = w.position
else:
w.skid_particle = self._create_particle(w.skid_start, w.position, grass)
w.skid_particle = self._create_particle(
w.skid_start, w.position, grass
)
w.skid_start = None
else:
w.skid_start = None

View File

@@ -132,14 +132,18 @@ class CarRacing(gym.Env, EzPickle):
self.reward = 0.0
self.prev_reward = 0.0
self.verbose = verbose
self.fd_tile = fixtureDef(shape=polygonShape(vertices=[(0, 0), (1, 0), (1, -1), (0, -1)]))
self.fd_tile = fixtureDef(
shape=polygonShape(vertices=[(0, 0), (1, 0), (1, -1), (0, -1)])
)
self.action_space = spaces.Box(
np.array([-1, 0, 0]).astype(np.float32),
np.array([+1, +1, +1]).astype(np.float32),
) # steer, gas, brake
self.observation_space = spaces.Box(low=0, high=255, shape=(STATE_H, STATE_W, 3), dtype=np.uint8)
self.observation_space = spaces.Box(
low=0, high=255, shape=(STATE_H, STATE_W, 3), dtype=np.uint8
)
def seed(self, seed=None):
self.np_random, seed = seeding.np_random(seed)
@@ -242,7 +246,9 @@ class CarRacing(gym.Env, EzPickle):
i -= 1
if i == 0:
return False # Failed
pass_through_start = track[i][0] > self.start_alpha and track[i - 1][0] <= self.start_alpha
pass_through_start = (
track[i][0] > self.start_alpha and track[i - 1][0] <= self.start_alpha
)
if pass_through_start and i2 == -1:
i2 = i
elif pass_through_start and i1 == -1:
@@ -260,7 +266,8 @@ class CarRacing(gym.Env, EzPickle):
first_perp_y = math.sin(first_beta)
# Length of perpendicular jump to put together head and tail
well_glued_together = np.sqrt(
np.square(first_perp_x * (track[0][2] - track[-1][2])) + np.square(first_perp_y * (track[0][3] - track[-1][3]))
np.square(first_perp_x * (track[0][2] - track[-1][2]))
+ np.square(first_perp_y * (track[0][3] - track[-1][3]))
)
if well_glued_together > TRACK_DETAIL_STEP:
return False
@@ -330,7 +337,9 @@ class CarRacing(gym.Env, EzPickle):
x2 + side * (TRACK_WIDTH + BORDER) * math.cos(beta2),
y2 + side * (TRACK_WIDTH + BORDER) * math.sin(beta2),
)
self.road_poly.append(([b1_l, b1_r, b2_r, b2_l], (1, 1, 1) if i % 2 == 0 else (1, 0, 0)))
self.road_poly.append(
([b1_l, b1_r, b2_r, b2_l], (1, 1, 1) if i % 2 == 0 else (1, 0, 0))
)
self.track = track
return True
@@ -347,7 +356,10 @@ class CarRacing(gym.Env, EzPickle):
if success:
break
if self.verbose == 1:
print("retry to generate track (normal if there are not many" "instances of this message)")
print(
"retry to generate track (normal if there are not many"
"instances of this message)"
)
self.car = Car(self.world, *self.track[0][1:4])
return self.step(None)[0]
@@ -412,8 +424,10 @@ class CarRacing(gym.Env, EzPickle):
angle = math.atan2(vel[0], vel[1])
self.transform.set_scale(zoom, zoom)
self.transform.set_translation(
WINDOW_W / 2 - (scroll_x * zoom * math.cos(angle) - scroll_y * zoom * math.sin(angle)),
WINDOW_H / 4 - (scroll_x * zoom * math.sin(angle) + scroll_y * zoom * math.cos(angle)),
WINDOW_W / 2
- (scroll_x * zoom * math.cos(angle) - scroll_y * zoom * math.sin(angle)),
WINDOW_H / 4
- (scroll_x * zoom * math.sin(angle) + scroll_y * zoom * math.cos(angle)),
)
self.transform.set_rotation(angle)
@@ -435,7 +449,9 @@ class CarRacing(gym.Env, EzPickle):
else:
pixel_scale = 1
if hasattr(win.context, "_nscontext"):
pixel_scale = win.context._nscontext.view().backingScaleFactor() # pylint: disable=protected-access
pixel_scale = (
win.context._nscontext.view().backingScaleFactor()
) # pylint: disable=protected-access
VP_W = int(pixel_scale * WINDOW_W)
VP_H = int(pixel_scale * WINDOW_H)
@@ -452,7 +468,9 @@ class CarRacing(gym.Env, EzPickle):
win.flip()
return self.viewer.isopen
image_data = pyglet.image.get_buffer_manager().get_color_buffer().get_image_data()
image_data = (
pyglet.image.get_buffer_manager().get_color_buffer().get_image_data()
)
arr = np.fromstring(image_data.get_data(), dtype=np.uint8, sep="")
arr = arr.reshape(VP_H, VP_W, 4)
arr = arr[::-1, :, 0:3]
@@ -507,7 +525,9 @@ class CarRacing(gym.Env, EzPickle):
for p in poly:
polygons_.extend([p[0], p[1], 0])
vl = pyglet.graphics.vertex_list(len(polygons_) // 3, ("v3f", polygons_), ("c4f", colors)) # gl.GL_QUADS,
vl = pyglet.graphics.vertex_list(
len(polygons_) // 3, ("v3f", polygons_), ("c4f", colors)
) # gl.GL_QUADS,
vl.draw(gl.GL_QUADS)
vl.delete()
@@ -555,7 +575,10 @@ class CarRacing(gym.Env, EzPickle):
]
)
true_speed = np.sqrt(np.square(self.car.hull.linearVelocity[0]) + np.square(self.car.hull.linearVelocity[1]))
true_speed = np.sqrt(
np.square(self.car.hull.linearVelocity[0])
+ np.square(self.car.hull.linearVelocity[1])
)
vertical_ind(5, 0.02 * true_speed, (1, 1, 1))
vertical_ind(7, 0.01 * self.car.wheels[0].omega, (0.0, 0, 1)) # ABS sensors
@@ -564,7 +587,9 @@ class CarRacing(gym.Env, EzPickle):
vertical_ind(10, 0.01 * self.car.wheels[3].omega, (0.2, 0, 1))
horiz_ind(20, -10.0 * self.car.wheels[0].joint.angle, (0, 1, 0))
horiz_ind(30, -0.8 * self.car.hull.angularVelocity, (1, 0, 0))
vl = pyglet.graphics.vertex_list(len(polygons) // 3, ("v3f", polygons), ("c4f", colors)) # gl.GL_QUADS,
vl = pyglet.graphics.vertex_list(
len(polygons) // 3, ("v3f", polygons), ("c4f", colors)
) # gl.GL_QUADS,
vl.draw(gl.GL_QUADS)
vl.delete()
self.score_label.text = "%04i" % self.reward

View File

@@ -71,7 +71,10 @@ class ContactDetector(contactListener):
self.env = env
def BeginContact(self, contact):
if self.env.lander == contact.fixtureA.body or self.env.lander == contact.fixtureB.body:
if (
self.env.lander == contact.fixtureA.body
or self.env.lander == contact.fixtureB.body
):
self.env.game_over = True
for i in range(2):
if self.env.legs[i] in [contact.fixtureA.body, contact.fixtureB.body]:
@@ -101,7 +104,9 @@ class LunarLander(gym.Env, EzPickle):
self.prev_reward = None
# useful range is -1 .. +1, but spikes can be higher
self.observation_space = spaces.Box(-np.inf, np.inf, shape=(8,), dtype=np.float32)
self.observation_space = spaces.Box(
-np.inf, np.inf, shape=(8,), dtype=np.float32
)
if self.continuous:
# Action is two floats [main engine, left-right engines].
@@ -152,9 +157,14 @@ class LunarLander(gym.Env, EzPickle):
height[CHUNKS // 2 + 0] = self.helipad_y
height[CHUNKS // 2 + 1] = self.helipad_y
height[CHUNKS // 2 + 2] = self.helipad_y
smooth_y = [0.33 * (height[i - 1] + height[i + 0] + height[i + 1]) for i in range(CHUNKS)]
smooth_y = [
0.33 * (height[i - 1] + height[i + 0] + height[i + 1])
for i in range(CHUNKS)
]
self.moon = self.world.CreateStaticBody(shapes=edgeShape(vertices=[(0, 0), (W, 0)]))
self.moon = self.world.CreateStaticBody(
shapes=edgeShape(vertices=[(0, 0), (W, 0)])
)
self.sky_polys = []
for i in range(CHUNKS - 1):
p1 = (chunk_x[i], smooth_y[i])
@@ -170,7 +180,9 @@ class LunarLander(gym.Env, EzPickle):
position=(VIEWPORT_W / SCALE / 2, initial_y),
angle=0.0,
fixtures=fixtureDef(
shape=polygonShape(vertices=[(x / SCALE, y / SCALE) for x, y in LANDER_POLY]),
shape=polygonShape(
vertices=[(x / SCALE, y / SCALE) for x, y in LANDER_POLY]
),
density=5.0,
friction=0.1,
categoryBits=0x0010,
@@ -215,7 +227,9 @@ class LunarLander(gym.Env, EzPickle):
motorSpeed=+0.3 * i, # low enough not to jump back into the sky
)
if i == -1:
rjd.lowerAngle = +0.9 - 0.5 # The most esoteric numbers here, angled legs have freedom to travel within
rjd.lowerAngle = (
+0.9 - 0.5
) # The most esoteric numbers here, angled legs have freedom to travel within
rjd.upperAngle = +0.9
else:
rjd.lowerAngle = -0.9
@@ -264,7 +278,9 @@ class LunarLander(gym.Env, EzPickle):
dispersion = [self.np_random.uniform(-1.0, +1.0) / SCALE for _ in range(2)]
m_power = 0.0
if (self.continuous and action[0] > 0.0) or (not self.continuous and action == 2):
if (self.continuous and action[0] > 0.0) or (
not self.continuous and action == 2
):
# Main engine
if self.continuous:
m_power = (np.clip(action[0], 0.0, 1.0) + 1.0) * 0.5 # 0.5..1.0
@@ -294,7 +310,9 @@ class LunarLander(gym.Env, EzPickle):
)
s_power = 0.0
if (self.continuous and np.abs(action[1]) > 0.5) or (not self.continuous and action in [1, 3]):
if (self.continuous and np.abs(action[1]) > 0.5) or (
not self.continuous and action in [1, 3]
):
# Orientation engines
if self.continuous:
direction = np.sign(action[1])
@@ -303,8 +321,12 @@ class LunarLander(gym.Env, EzPickle):
else:
direction = action - 2
s_power = 1.0
ox = tip[0] * dispersion[0] + side[0] * (3 * dispersion[1] + direction * SIDE_ENGINE_AWAY / SCALE)
oy = -tip[1] * dispersion[0] - side[1] * (3 * dispersion[1] + direction * SIDE_ENGINE_AWAY / SCALE)
ox = tip[0] * dispersion[0] + side[0] * (
3 * dispersion[1] + direction * SIDE_ENGINE_AWAY / SCALE
)
oy = -tip[1] * dispersion[0] - side[1] * (
3 * dispersion[1] + direction * SIDE_ENGINE_AWAY / SCALE
)
impulse_pos = (
self.lander.position[0] + ox - tip[0] * 17 / SCALE,
self.lander.position[1] + oy + tip[1] * SIDE_ENGINE_HEIGHT / SCALE,
@@ -350,7 +372,9 @@ class LunarLander(gym.Env, EzPickle):
reward = shaping - self.prev_shaping
self.prev_shaping = shaping
reward -= m_power * 0.30 # less fuel spent is better, about -30 for heuristic landing
reward -= (
m_power * 0.30
) # less fuel spent is better, about -30 for heuristic landing
reward -= s_power * 0.03
done = False
@@ -392,8 +416,12 @@ class LunarLander(gym.Env, EzPickle):
trans = f.body.transform
if type(f.shape) is circleShape:
t = rendering.Transform(translation=trans * f.shape.pos)
self.viewer.draw_circle(f.shape.radius, 20, color=obj.color1).add_attr(t)
self.viewer.draw_circle(f.shape.radius, 20, color=obj.color2, filled=False, linewidth=2).add_attr(t)
self.viewer.draw_circle(
f.shape.radius, 20, color=obj.color1
).add_attr(t)
self.viewer.draw_circle(
f.shape.radius, 20, color=obj.color2, filled=False, linewidth=2
).add_attr(t)
else:
path = [trans * v for v in f.shape.vertices]
self.viewer.draw_polygon(path, color=obj.color1)
@@ -451,14 +479,18 @@ def heuristic(env, s):
angle_targ = 0.4 # more than 0.4 radians (22 degrees) is bad
if angle_targ < -0.4:
angle_targ = -0.4
hover_targ = 0.55 * np.abs(s[0]) # target y should be proportional to horizontal offset
hover_targ = 0.55 * np.abs(
s[0]
) # target y should be proportional to horizontal offset
angle_todo = (angle_targ - s[4]) * 0.5 - (s[5]) * 1.0
hover_todo = (hover_targ - s[1]) * 0.5 - (s[3]) * 0.5
if s[6] or s[7]: # legs have contact
angle_todo = 0
hover_todo = -(s[3]) * 0.5 # override to reduce fall speed, that's all we need after contact
hover_todo = (
-(s[3]) * 0.5
) # override to reduce fall speed, that's all we need after contact
if env.continuous:
a = np.array([hover_todo * 20 - 1, -angle_todo * 20])
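The heuristic above amounts to a pair of hand-tuned PD controllers: each *_todo term is a proportional gain on the target error minus a derivative (rate) term. Reading off the lines above, with s[4], s[5] the angle and angular velocity and s[1], s[3] the height and vertical velocity:

a_\theta = 0.5\,(\theta_{\mathrm{targ}} - \theta) - 1.0\,\dot\theta,
\qquad
a_h = 0.5\,(h_{\mathrm{targ}} - y) - 0.5\,\dot y

with the hover term overridden to -0.5\,\dot y once a leg reports contact.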

View File

@@ -88,7 +88,9 @@ class AcrobotEnv(core.Env):
def __init__(self):
self.viewer = None
high = np.array([1.0, 1.0, 1.0, 1.0, self.MAX_VEL_1, self.MAX_VEL_2], dtype=np.float32)
high = np.array(
[1.0, 1.0, 1.0, 1.0, self.MAX_VEL_1, self.MAX_VEL_2], dtype=np.float32
)
low = -high
self.observation_space = spaces.Box(low=low, high=high, dtype=np.float32)
self.action_space = spaces.Discrete(3)
@@ -109,7 +111,9 @@ class AcrobotEnv(core.Env):
# Add noise to the force action
if self.torque_noise_max > 0:
torque += self.np_random.uniform(-self.torque_noise_max, self.torque_noise_max)
torque += self.np_random.uniform(
-self.torque_noise_max, self.torque_noise_max
)
# Now, augment the state with our force action so it can be passed to
# _dsdt
@@ -156,7 +160,12 @@ class AcrobotEnv(core.Env):
theta2 = s[1]
dtheta1 = s[2]
dtheta2 = s[3]
d1 = m1 * lc1 ** 2 + m2 * (l1 ** 2 + lc2 ** 2 + 2 * l1 * lc2 * cos(theta2)) + I1 + I2
d1 = (
m1 * lc1 ** 2
+ m2 * (l1 ** 2 + lc2 ** 2 + 2 * l1 * lc2 * cos(theta2))
+ I1
+ I2
)
d2 = m2 * (lc2 ** 2 + l1 * lc2 * cos(theta2)) + I2
phi2 = m2 * lc2 * g * cos(theta1 + theta2 - pi / 2.0)
phi1 = (
@@ -172,9 +181,9 @@ class AcrobotEnv(core.Env):
else:
# the following line is consistent with the java implementation and the
# book
ddtheta2 = (a + d2 / d1 * phi1 - m2 * l1 * lc2 * dtheta1 ** 2 * sin(theta2) - phi2) / (
m2 * lc2 ** 2 + I2 - d2 ** 2 / d1
)
ddtheta2 = (
a + d2 / d1 * phi1 - m2 * l1 * lc2 * dtheta1 ** 2 * sin(theta2) - phi2
) / (m2 * lc2 ** 2 + I2 - d2 ** 2 / d1)
ddtheta1 = -(d2 * ddtheta2 + phi1) / d1
return (dtheta1, dtheta2, ddtheta1, ddtheta2, 0.0)
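Reassembled from the pieces above, the dynamics computed by _dsdt (the "book" variant mentioned in the comment) are:

d_1 = m_1 l_{c1}^2 + m_2 \left(l_1^2 + l_{c2}^2 + 2 l_1 l_{c2} \cos\theta_2\right) + I_1 + I_2

\ddot\theta_2 = \frac{a + \frac{d_2}{d_1}\,\phi_1 - m_2 l_1 l_{c2}\,\dot\theta_1^2 \sin\theta_2 - \phi_2}{m_2 l_{c2}^2 + I_2 - \frac{d_2^2}{d_1}},
\qquad
\ddot\theta_1 = -\frac{d_2\,\ddot\theta_2 + \phi_1}{d_1}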

View File

@@ -111,7 +111,9 @@ class CartPoleEnv(gym.Env):
# For the interested reader:
# https://coneural.org/florian/papers/05_cart_pole.pdf
temp = (force + self.polemass_length * theta_dot ** 2 * sintheta) / self.total_mass
temp = (
force + self.polemass_length * theta_dot ** 2 * sintheta
) / self.total_mass
thetaacc = (self.gravity * sintheta - costheta * temp) / (
self.length * (4.0 / 3.0 - self.masspole * costheta ** 2 / self.total_mass)
)
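For reference, the expression Black is splitting here is the pole's angular acceleration from the cited cart-pole derivation, with temp the intermediate term and polemass_length = m_p \ell:

\mathrm{temp} = \frac{F + m_p \ell\,\dot\theta^2 \sin\theta}{m_{\mathrm{total}}},
\qquad
\ddot\theta = \frac{g \sin\theta - \cos\theta \cdot \mathrm{temp}}{\ell \left(\frac{4}{3} - \frac{m_p \cos^2\theta}{m_{\mathrm{total}}}\right)}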

View File

@@ -62,17 +62,27 @@ class Continuous_MountainCarEnv(gym.Env):
self.min_position = -1.2
self.max_position = 0.6
self.max_speed = 0.07
self.goal_position = 0.45 # was 0.5 in gym, 0.45 in Arnaud de Broissia's version
self.goal_position = (
0.45 # was 0.5 in gym, 0.45 in Arnaud de Broissia's version
)
self.goal_velocity = goal_velocity
self.power = 0.0015
self.low_state = np.array([self.min_position, -self.max_speed], dtype=np.float32)
self.high_state = np.array([self.max_position, self.max_speed], dtype=np.float32)
self.low_state = np.array(
[self.min_position, -self.max_speed], dtype=np.float32
)
self.high_state = np.array(
[self.max_position, self.max_speed], dtype=np.float32
)
self.viewer = None
self.action_space = spaces.Box(low=self.min_action, high=self.max_action, shape=(1,), dtype=np.float32)
self.observation_space = spaces.Box(low=self.low_state, high=self.high_state, dtype=np.float32)
self.action_space = spaces.Box(
low=self.min_action, high=self.max_action, shape=(1,), dtype=np.float32
)
self.observation_space = spaces.Box(
low=self.low_state, high=self.high_state, dtype=np.float32
)
self.seed()
self.reset()
@@ -149,11 +159,15 @@ class Continuous_MountainCarEnv(gym.Env):
self.viewer.add_geom(car)
frontwheel = rendering.make_circle(carheight / 2.5)
frontwheel.set_color(0.5, 0.5, 0.5)
frontwheel.add_attr(rendering.Transform(translation=(carwidth / 4, clearance)))
frontwheel.add_attr(
rendering.Transform(translation=(carwidth / 4, clearance))
)
frontwheel.add_attr(self.cartrans)
self.viewer.add_geom(frontwheel)
backwheel = rendering.make_circle(carheight / 2.5)
backwheel.add_attr(rendering.Transform(translation=(-carwidth / 4, clearance)))
backwheel.add_attr(
rendering.Transform(translation=(-carwidth / 4, clearance))
)
backwheel.add_attr(self.cartrans)
backwheel.set_color(0.5, 0.5, 0.5)
self.viewer.add_geom(backwheel)
@@ -162,12 +176,16 @@ class Continuous_MountainCarEnv(gym.Env):
flagy2 = flagy1 + 50
flagpole = rendering.Line((flagx, flagy1), (flagx, flagy2))
self.viewer.add_geom(flagpole)
flag = rendering.FilledPolygon([(flagx, flagy2), (flagx, flagy2 - 10), (flagx + 25, flagy2 - 5)])
flag = rendering.FilledPolygon(
[(flagx, flagy2), (flagx, flagy2 - 10), (flagx + 25, flagy2 - 5)]
)
flag.set_color(0.8, 0.8, 0)
self.viewer.add_geom(flag)
pos = self.state[0]
self.cartrans.set_translation((pos - self.min_position) * scale, self._height(pos) * scale)
self.cartrans.set_translation(
(pos - self.min_position) * scale, self._height(pos) * scale
)
self.cartrans.set_rotation(math.cos(3 * pos))
return self.viewer.render(return_rgb_array=mode == "rgb_array")

View File

@@ -136,11 +136,15 @@ class MountainCarEnv(gym.Env):
self.viewer.add_geom(car)
frontwheel = rendering.make_circle(carheight / 2.5)
frontwheel.set_color(0.5, 0.5, 0.5)
frontwheel.add_attr(rendering.Transform(translation=(carwidth / 4, clearance)))
frontwheel.add_attr(
rendering.Transform(translation=(carwidth / 4, clearance))
)
frontwheel.add_attr(self.cartrans)
self.viewer.add_geom(frontwheel)
backwheel = rendering.make_circle(carheight / 2.5)
backwheel.add_attr(rendering.Transform(translation=(-carwidth / 4, clearance)))
backwheel.add_attr(
rendering.Transform(translation=(-carwidth / 4, clearance))
)
backwheel.add_attr(self.cartrans)
backwheel.set_color(0.5, 0.5, 0.5)
self.viewer.add_geom(backwheel)
@@ -149,12 +153,16 @@ class MountainCarEnv(gym.Env):
flagy2 = flagy1 + 50
flagpole = rendering.Line((flagx, flagy1), (flagx, flagy2))
self.viewer.add_geom(flagpole)
flag = rendering.FilledPolygon([(flagx, flagy2), (flagx, flagy2 - 10), (flagx + 25, flagy2 - 5)])
flag = rendering.FilledPolygon(
[(flagx, flagy2), (flagx, flagy2 - 10), (flagx + 25, flagy2 - 5)]
)
flag.set_color(0.8, 0.8, 0)
self.viewer.add_geom(flag)
pos = self.state[0]
self.cartrans.set_translation((pos - self.min_position) * scale, self._height(pos) * scale)
self.cartrans.set_translation(
(pos - self.min_position) * scale, self._height(pos) * scale
)
self.cartrans.set_rotation(math.cos(3 * pos))
return self.viewer.render(return_rgb_array=mode == "rgb_array")

View File

@@ -18,7 +18,9 @@ class PendulumEnv(gym.Env):
self.viewer = None
high = np.array([1.0, 1.0, self.max_speed], dtype=np.float32)
self.action_space = spaces.Box(low=-self.max_torque, high=self.max_torque, shape=(1,), dtype=np.float32)
self.action_space = spaces.Box(
low=-self.max_torque, high=self.max_torque, shape=(1,), dtype=np.float32
)
self.observation_space = spaces.Box(low=-high, high=high, dtype=np.float32)
self.seed()
@@ -39,7 +41,10 @@ class PendulumEnv(gym.Env):
self.last_u = u # for rendering
costs = angle_normalize(th) ** 2 + 0.1 * thdot ** 2 + 0.001 * (u ** 2)
newthdot = thdot + (-3 * g / (2 * l) * np.sin(th + np.pi) + 3.0 / (m * l ** 2) * u) * dt
newthdot = (
thdot
+ (-3 * g / (2 * l) * np.sin(th + np.pi) + 3.0 / (m * l ** 2) * u) * dt
)
newth = th + newthdot * dt
newthdot = np.clip(newthdot, -self.max_speed, self.max_speed)
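The reformatted update above is one explicit-Euler step of the pendulum dynamics; in symbols, before newthdot is clipped to ±max_speed:

\dot\theta_{t+1} = \dot\theta_t + \left(-\frac{3g}{2\ell}\,\sin(\theta_t + \pi) + \frac{3}{m\ell^2}\,u\right)\Delta t,
\qquad
\theta_{t+1} = \theta_t + \dot\theta_{t+1}\,\Delta t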

View File

@@ -54,7 +54,11 @@ def get_display(spec):
elif isinstance(spec, str):
return pyglet.canvas.Display(spec)
else:
raise error.Error("Invalid display specification: {}. (Must be a string like :0 or None.)".format(spec))
raise error.Error(
"Invalid display specification: {}. (Must be a string like :0 or None.)".format(
spec
)
)
def get_window(width, height, display, **kwargs):
@@ -65,7 +69,14 @@ def get_window(width, height, display, **kwargs):
config = screen[0].get_best_config() # selecting the first screen
context = config.create_context(None) # create GL context
return pyglet.window.Window(width=width, height=height, display=display, config=config, context=context, **kwargs)
return pyglet.window.Window(
width=width,
height=height,
display=display,
config=config,
context=context,
**kwargs
)
class Viewer(object):
@@ -97,7 +108,9 @@ class Viewer(object):
assert right > left and top > bottom
scalex = self.width / (right - left)
scaley = self.height / (top - bottom)
self.transform = Transform(translation=(-left * scalex, -bottom * scaley), scale=(scalex, scaley))
self.transform = Transform(
translation=(-left * scalex, -bottom * scaley), scale=(scalex, scaley)
)
def add_geom(self, geom):
self.geoms.append(geom)
@@ -160,7 +173,9 @@ class Viewer(object):
def get_array(self):
self.window.flip()
image_data = pyglet.image.get_buffer_manager().get_color_buffer().get_image_data()
image_data = (
pyglet.image.get_buffer_manager().get_color_buffer().get_image_data()
)
self.window.flip()
arr = np.fromstring(image_data.get_data(), dtype=np.uint8, sep="")
arr = arr.reshape(self.height, self.width, 4)
@@ -215,7 +230,9 @@ class Transform(Attr):
def enable(self):
glPushMatrix()
glTranslatef(self.translation[0], self.translation[1], 0) # translate to GL loc point
glTranslatef(
self.translation[0], self.translation[1], 0
) # translate to GL loc point
glRotatef(RAD2DEG * self.rotation, 0, 0, 1.0)
glScalef(self.scale[0], self.scale[1], 1)
@@ -375,7 +392,9 @@ class Image(Geom):
self.flip = False
def render1(self):
self.img.blit(-self.width / 2, -self.height / 2, width=self.width, height=self.height)
self.img.blit(
-self.width / 2, -self.height / 2, width=self.width, height=self.height
)
# ================================================================
@@ -416,7 +435,9 @@ class SimpleImageViewer(object):
self.isopen = False
assert len(arr.shape) == 3, "You passed in an image with the wrong number of dimensions"
image = pyglet.image.ImageData(arr.shape[1], arr.shape[0], "RGB", arr.tobytes(), pitch=arr.shape[1] * -3)
image = pyglet.image.ImageData(
arr.shape[1], arr.shape[0], "RGB", arr.tobytes(), pitch=arr.shape[1] * -3
)
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_NEAREST)
texture = image.get_texture()
texture.width = self.width

View File

@@ -14,7 +14,9 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
xposafter = self.get_body_com("torso")[0]
forward_reward = (xposafter - xposbefore) / self.dt
ctrl_cost = 0.5 * np.square(a).sum()
contact_cost = 0.5 * 1e-3 * np.sum(np.square(np.clip(self.sim.data.cfrc_ext, -1, 1)))
contact_cost = (
0.5 * 1e-3 * np.sum(np.square(np.clip(self.sim.data.cfrc_ext, -1, 1)))
)
survive_reward = 1.0
reward = forward_reward - ctrl_cost - contact_cost + survive_reward
state = self.state_vector()
@@ -43,7 +45,9 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
)
def reset_model(self):
qpos = self.init_qpos + self.np_random.uniform(size=self.model.nq, low=-0.1, high=0.1)
qpos = self.init_qpos + self.np_random.uniform(
size=self.model.nq, low=-0.1, high=0.1
)
qvel = self.init_qvel + self.np_random.randn(self.model.nv) * 0.1
self.set_state(qpos, qvel)
return self._get_obs()

View File

@@ -34,13 +34,18 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
self._reset_noise_scale = reset_noise_scale
self._exclude_current_positions_from_observation = exclude_current_positions_from_observation
self._exclude_current_positions_from_observation = (
exclude_current_positions_from_observation
)
mujoco_env.MujocoEnv.__init__(self, xml_file, 5)
@property
def healthy_reward(self):
return float(self.is_healthy or self._terminate_when_unhealthy) * self._healthy_reward
return (
float(self.is_healthy or self._terminate_when_unhealthy)
* self._healthy_reward
)
def control_cost(self, action):
control_cost = self._ctrl_cost_weight * np.sum(np.square(action))
@@ -55,7 +60,9 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
@property
def contact_cost(self):
contact_cost = self._contact_cost_weight * np.sum(np.square(self.contact_forces))
contact_cost = self._contact_cost_weight * np.sum(
np.square(self.contact_forces)
)
return contact_cost
@property
@@ -121,8 +128,12 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
noise_low = -self._reset_noise_scale
noise_high = self._reset_noise_scale
qpos = self.init_qpos + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nq)
qvel = self.init_qvel + self._reset_noise_scale * self.np_random.randn(self.model.nv)
qpos = self.init_qpos + self.np_random.uniform(
low=noise_low, high=noise_high, size=self.model.nq
)
qvel = self.init_qvel + self._reset_noise_scale * self.np_random.randn(
self.model.nv
)
self.set_state(qpos, qvel)
observation = self._get_obs()

View File

@@ -28,7 +28,9 @@ class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):
)
def reset_model(self):
qpos = self.init_qpos + self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq)
qpos = self.init_qpos + self.np_random.uniform(
low=-0.1, high=0.1, size=self.model.nq
)
qvel = self.init_qvel + self.np_random.randn(self.model.nv) * 0.1
self.set_state(qpos, qvel)
return self._get_obs()

View File

@@ -25,7 +25,9 @@ class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):
self._reset_noise_scale = reset_noise_scale
self._exclude_current_positions_from_observation = exclude_current_positions_from_observation
self._exclude_current_positions_from_observation = (
exclude_current_positions_from_observation
)
mujoco_env.MujocoEnv.__init__(self, xml_file, 5)
@@ -69,8 +71,12 @@ class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):
noise_low = -self._reset_noise_scale
noise_high = self._reset_noise_scale
qpos = self.init_qpos + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nq)
qvel = self.init_qvel + self._reset_noise_scale * self.np_random.randn(self.model.nv)
qpos = self.init_qpos + self.np_random.uniform(
low=noise_low, high=noise_high, size=self.model.nq
)
qvel = self.init_qvel + self._reset_noise_scale * self.np_random.randn(
self.model.nv
)
self.set_state(qpos, qvel)

View File

@@ -17,16 +17,27 @@ class HopperEnv(mujoco_env.MujocoEnv, utils.EzPickle):
reward += alive_bonus
reward -= 1e-3 * np.square(a).sum()
s = self.state_vector()
done = not (np.isfinite(s).all() and (np.abs(s[2:]) < 100).all() and (height > 0.7) and (abs(ang) < 0.2))
done = not (
np.isfinite(s).all()
and (np.abs(s[2:]) < 100).all()
and (height > 0.7)
and (abs(ang) < 0.2)
)
ob = self._get_obs()
return ob, reward, done, {}
def _get_obs(self):
return np.concatenate([self.sim.data.qpos.flat[1:], np.clip(self.sim.data.qvel.flat, -10, 10)])
return np.concatenate(
[self.sim.data.qpos.flat[1:], np.clip(self.sim.data.qvel.flat, -10, 10)]
)
def reset_model(self):
qpos = self.init_qpos + self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nq)
qvel = self.init_qvel + self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nv)
qpos = self.init_qpos + self.np_random.uniform(
low=-0.005, high=0.005, size=self.model.nq
)
qvel = self.init_qvel + self.np_random.uniform(
low=-0.005, high=0.005, size=self.model.nv
)
self.set_state(qpos, qvel)
return self._get_obs()

View File

@@ -40,13 +40,18 @@ class HopperEnv(mujoco_env.MujocoEnv, utils.EzPickle):
self._reset_noise_scale = reset_noise_scale
self._exclude_current_positions_from_observation = exclude_current_positions_from_observation
self._exclude_current_positions_from_observation = (
exclude_current_positions_from_observation
)
mujoco_env.MujocoEnv.__init__(self, xml_file, 4)
@property
def healthy_reward(self):
return float(self.is_healthy or self._terminate_when_unhealthy) * self._healthy_reward
return (
float(self.is_healthy or self._terminate_when_unhealthy)
* self._healthy_reward
)
def control_cost(self, action):
control_cost = self._ctrl_cost_weight * np.sum(np.square(action))
@@ -112,8 +117,12 @@ class HopperEnv(mujoco_env.MujocoEnv, utils.EzPickle):
noise_low = -self._reset_noise_scale
noise_high = self._reset_noise_scale
qpos = self.init_qpos + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nq)
qvel = self.init_qvel + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nv)
qpos = self.init_qpos + self.np_random.uniform(
low=noise_low, high=noise_high, size=self.model.nq
)
qvel = self.init_qvel + self.np_random.uniform(
low=noise_low, high=noise_high, size=self.model.nv
)
self.set_state(qpos, qvel)

View File

@@ -43,13 +43,18 @@ class HumanoidEnv(mujoco_env.MujocoEnv, utils.EzPickle):
self._reset_noise_scale = reset_noise_scale
self._exclude_current_positions_from_observation = exclude_current_positions_from_observation
self._exclude_current_positions_from_observation = (
exclude_current_positions_from_observation
)
mujoco_env.MujocoEnv.__init__(self, xml_file, 5)
@property
def healthy_reward(self):
return float(self.is_healthy or self._terminate_when_unhealthy) * self._healthy_reward
return (
float(self.is_healthy or self._terminate_when_unhealthy)
* self._healthy_reward
)
def control_cost(self, action):
control_cost = self._ctrl_cost_weight * np.sum(np.square(self.sim.data.ctrl))
@@ -138,8 +143,12 @@ class HumanoidEnv(mujoco_env.MujocoEnv, utils.EzPickle):
noise_low = -self._reset_noise_scale
noise_high = self._reset_noise_scale
qpos = self.init_qpos + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nq)
qvel = self.init_qvel + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nv)
qpos = self.init_qpos + self.np_random.uniform(
low=noise_low, high=noise_high, size=self.model.nq
)
qvel = self.init_qvel + self.np_random.uniform(
low=noise_low, high=noise_high, size=self.model.nv
)
self.set_state(qpos, qvel)
observation = self._get_obs()

View File

@@ -33,7 +33,8 @@ class InvertedDoublePendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle):
def reset_model(self):
self.set_state(
self.init_qpos + self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq),
self.init_qpos
+ self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq),
self.init_qvel + self.np_random.randn(self.model.nv) * 0.1,
)
return self._get_obs()

View File

@@ -17,8 +17,12 @@ class InvertedPendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle):
return ob, reward, done, {}
def reset_model(self):
qpos = self.init_qpos + self.np_random.uniform(size=self.model.nq, low=-0.01, high=0.01)
qvel = self.init_qvel + self.np_random.uniform(size=self.model.nv, low=-0.01, high=0.01)
qpos = self.init_qpos + self.np_random.uniform(
size=self.model.nq, low=-0.01, high=0.01
)
qvel = self.init_qvel + self.np_random.uniform(
size=self.model.nv, low=-0.01, high=0.01
)
self.set_state(qpos, qvel)
return self._get_obs()

View File

@@ -22,7 +22,14 @@ DEFAULT_SIZE = 500
def convert_observation_to_space(observation):
if isinstance(observation, dict):
space = spaces.Dict(OrderedDict([(key, convert_observation_to_space(value)) for key, value in observation.items()]))
space = spaces.Dict(
OrderedDict(
[
(key, convert_observation_to_space(value))
for key, value in observation.items()
]
)
)
elif isinstance(observation, np.ndarray):
low = np.full(observation.shape, -float("inf"), dtype=np.float32)
high = np.full(observation.shape, float("inf"), dtype=np.float32)
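For context, convert_observation_to_space above recursively mirrors a sample observation: dicts become spaces.Dict and ndarrays become Box spaces with infinite bounds and matching shape. A usage sketch with a hypothetical sample:

import numpy as np

sample = {"qpos": np.zeros(5, dtype=np.float32), "qvel": np.zeros(4, dtype=np.float32)}
space = convert_observation_to_space(sample)
# Each ndarray leaf maps to a Box of the same shape.
assert space.spaces["qpos"].shape == (5,)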
@@ -110,7 +117,9 @@ class MujocoEnv(gym.Env):
def set_state(self, qpos, qvel):
assert qpos.shape == (self.model.nq,) and qvel.shape == (self.model.nv,)
old_state = self.sim.get_state()
new_state = mujoco_py.MjSimState(old_state.time, qpos, qvel, old_state.act, old_state.udd_state)
new_state = mujoco_py.MjSimState(
old_state.time, qpos, qvel, old_state.act, old_state.udd_state
)
self.sim.set_state(new_state)
self.sim.forward()
@@ -133,7 +142,10 @@ class MujocoEnv(gym.Env):
):
if mode == "rgb_array" or mode == "depth_array":
if camera_id is not None and camera_name is not None:
raise ValueError("Both `camera_id` and `camera_name` cannot be" " specified at the same time.")
raise ValueError(
"Both `camera_id` and `camera_name` cannot be"
" specified at the same time."
)
no_camera_specified = camera_name is None and camera_id is None
if no_camera_specified:

View File

@@ -44,7 +44,9 @@ class PusherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
qpos[-4:-2] = self.cylinder_pos
qpos[-2:] = self.goal_pos
qvel = self.init_qvel + self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nv)
qvel = self.init_qvel + self.np_random.uniform(
low=-0.005, high=0.005, size=self.model.nv
)
qvel[-4:] = 0
self.set_state(qpos, qvel)
return self._get_obs()

View File

@@ -22,13 +22,18 @@ class ReacherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
self.viewer.cam.trackbodyid = 0
def reset_model(self):
qpos = self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq) + self.init_qpos
qpos = (
self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq)
+ self.init_qpos
)
while True:
self.goal = self.np_random.uniform(low=-0.2, high=0.2, size=2)
if np.linalg.norm(self.goal) < 0.2:
break
qpos[-2:] = self.goal
qvel = self.init_qvel + self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nv)
qvel = self.init_qvel + self.np_random.uniform(
low=-0.005, high=0.005, size=self.model.nv
)
qvel[-2:] = 0
self.set_state(qpos, qvel)
return self._get_obs()

View File

@@ -62,7 +62,9 @@ class StrikerEnv(mujoco_env.MujocoEnv, utils.EzPickle):
diff = self.ball - self.goal
angle = -np.arctan(diff[0] / (diff[1] + 1e-8))
qpos[-1] = angle / 3.14
qvel = self.init_qvel + self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nv)
qvel = self.init_qvel + self.np_random.uniform(
low=-0.1, high=0.1, size=self.model.nv
)
qvel[7:] = 0
self.set_state(qpos, qvel)
return self._get_obs()

View File

@@ -26,7 +26,9 @@ class SwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle):
def reset_model(self):
self.set_state(
self.init_qpos + self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq),
self.init_qvel + self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nv),
self.init_qpos
+ self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq),
self.init_qvel
+ self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nv),
)
return self._get_obs()

View File

@@ -22,7 +22,9 @@ class SwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle):
self._reset_noise_scale = reset_noise_scale
self._exclude_current_positions_from_observation = exclude_current_positions_from_observation
self._exclude_current_positions_from_observation = (
exclude_current_positions_from_observation
)
mujoco_env.MujocoEnv.__init__(self, xml_file, 4)
@@ -72,8 +74,12 @@ class SwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle):
noise_low = -self._reset_noise_scale
noise_high = self._reset_noise_scale
qpos = self.init_qpos + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nq)
qvel = self.init_qvel + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nv)
qpos = self.init_qpos + self.np_random.uniform(
low=noise_low, high=noise_high, size=self.model.nq
)
qvel = self.init_qvel + self.np_random.uniform(
low=noise_low, high=noise_high, size=self.model.nv
)
self.set_state(qpos, qvel)

View File

@@ -48,7 +48,9 @@ class ThrowerEnv(mujoco_env.MujocoEnv, utils.EzPickle):
)
qpos[-9:-7] = self.goal
qvel = self.init_qvel + self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nv)
qvel = self.init_qvel + self.np_random.uniform(
low=-0.005, high=0.005, size=self.model.nv
)
qvel[7:] = 0
self.set_state(qpos, qvel)
return self._get_obs()

View File

@@ -27,8 +27,10 @@ class Walker2dEnv(mujoco_env.MujocoEnv, utils.EzPickle):
def reset_model(self):
self.set_state(
self.init_qpos + self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nq),
self.init_qvel + self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nv),
self.init_qpos
+ self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nq),
self.init_qvel
+ self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nv),
)
return self._get_obs()

View File

@@ -37,13 +37,18 @@ class Walker2dEnv(mujoco_env.MujocoEnv, utils.EzPickle):
self._reset_noise_scale = reset_noise_scale
self._exclude_current_positions_from_observation = exclude_current_positions_from_observation
self._exclude_current_positions_from_observation = (
exclude_current_positions_from_observation
)
mujoco_env.MujocoEnv.__init__(self, xml_file, 4)
@property
def healthy_reward(self):
return float(self.is_healthy or self._terminate_when_unhealthy) * self._healthy_reward
return (
float(self.is_healthy or self._terminate_when_unhealthy)
* self._healthy_reward
)
def control_cost(self, action):
control_cost = self._ctrl_cost_weight * np.sum(np.square(action))
@@ -105,8 +110,12 @@ class Walker2dEnv(mujoco_env.MujocoEnv, utils.EzPickle):
noise_low = -self._reset_noise_scale
noise_high = self._reset_noise_scale
qpos = self.init_qpos + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nq)
qvel = self.init_qvel + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nv)
qpos = self.init_qpos + self.np_random.uniform(
low=noise_low, high=noise_high, size=self.model.nq
)
qvel = self.init_qvel + self.np_random.uniform(
low=noise_low, high=noise_high, size=self.model.nv
)
self.set_state(qpos, qvel)

View File

@@ -108,7 +108,11 @@ class EnvRegistry(object):
# reset/step. Set _gym_disable_underscore_compat = True on
# your environment if you use these methods and don't want
# compatibility code to be invoked.
if hasattr(env, "_reset") and hasattr(env, "_step") and not getattr(env, "_gym_disable_underscore_compat", False):
if (
hasattr(env, "_reset")
and hasattr(env, "_step")
and not getattr(env, "_gym_disable_underscore_compat", False)
):
patch_deprecated_methods(env)
if env.spec.max_episode_steps is not None:
from gym.wrappers.time_limit import TimeLimit
@@ -154,7 +158,11 @@ class EnvRegistry(object):
if env_name == valid_env_spec._env_name
]
if matching_envs:
raise error.DeprecatedEnv("Env {} not found (valid versions include {})".format(id, matching_envs))
raise error.DeprecatedEnv(
"Env {} not found (valid versions include {})".format(
id, matching_envs
)
)
else:
raise error.UnregisteredEnv("No registered env with id: {}".format(id))

View File

@@ -81,7 +81,9 @@ class FetchEnv(robot_env.RobotEnv):
def _set_action(self, action):
assert action.shape == (4,)
action = action.copy() # ensure that we don't change the action outside of this scope
action = (
action.copy()
) # ensure that we don't change the action outside of this scope
pos_ctrl, gripper_ctrl = action[:3], action[3]
pos_ctrl *= 0.05 # limit maximum change in position
@@ -118,9 +120,13 @@ class FetchEnv(robot_env.RobotEnv):
object_rel_pos = object_pos - grip_pos
object_velp -= grip_velp
else:
object_pos = object_rot = object_velp = object_velr = object_rel_pos = np.zeros(0)
object_pos = (
object_rot
) = object_velp = object_velr = object_rel_pos = np.zeros(0)
gripper_state = robot_qpos[-2:]
gripper_vel = robot_qvel[-2:] * dt # change to a scalar if the gripper is made symmetric
gripper_vel = (
robot_qvel[-2:] * dt
) # change to a scalar if the gripper is made symmetric
if not self.has_object:
achieved_goal = grip_pos.copy()
@@ -169,7 +175,9 @@ class FetchEnv(robot_env.RobotEnv):
if self.has_object:
object_xpos = self.initial_gripper_xpos[:2]
while np.linalg.norm(object_xpos - self.initial_gripper_xpos[:2]) < 0.1:
object_xpos = self.initial_gripper_xpos[:2] + self.np_random.uniform(-self.obj_range, self.obj_range, size=2)
object_xpos = self.initial_gripper_xpos[:2] + self.np_random.uniform(
-self.obj_range, self.obj_range, size=2
)
object_qpos = self.sim.data.get_joint_qpos("object0:joint")
assert object_qpos.shape == (7,)
object_qpos[:2] = object_xpos
@@ -180,13 +188,17 @@ class FetchEnv(robot_env.RobotEnv):
def _sample_goal(self):
if self.has_object:
goal = self.initial_gripper_xpos[:3] + self.np_random.uniform(-self.target_range, self.target_range, size=3)
goal = self.initial_gripper_xpos[:3] + self.np_random.uniform(
-self.target_range, self.target_range, size=3
)
goal += self.target_offset
goal[2] = self.height_offset
if self.target_in_the_air and self.np_random.uniform() < 0.5:
goal[2] += self.np_random.uniform(0, 0.45)
else:
goal = self.initial_gripper_xpos[:3] + self.np_random.uniform(-self.target_range, self.target_range, size=3)
goal = self.initial_gripper_xpos[:3] + self.np_random.uniform(
-self.target_range, self.target_range, size=3
)
return goal.copy()
def _is_success(self, achieved_goal, desired_goal):
@@ -200,9 +212,9 @@ class FetchEnv(robot_env.RobotEnv):
self.sim.forward()
# Move end effector into position.
gripper_target = np.array([-0.498, 0.005, -0.431 + self.gripper_extra_height]) + self.sim.data.get_site_xpos(
"robot0:grip"
)
gripper_target = np.array(
[-0.498, 0.005, -0.431 + self.gripper_extra_height]
) + self.sim.data.get_site_xpos("robot0:grip")
gripper_rotation = np.array([1.0, 0.0, 1.0, 0.0])
self.sim.data.set_mocap_pos("robot0:mocap", gripper_target)
self.sim.data.set_mocap_quat("robot0:mocap", gripper_rotation)

View File

@@ -74,7 +74,9 @@ class ManipulateEnv(hand_env.HandEnv):
self.target_position = target_position
self.target_rotation = target_rotation
self.target_position_range = target_position_range
self.parallel_quats = [rotations.euler2quat(r) for r in rotations.get_parallel_rotations()]
self.parallel_quats = [
rotations.euler2quat(r) for r in rotations.get_parallel_rotations()
]
self.randomize_initial_rotation = randomize_initial_rotation
self.randomize_initial_position = randomize_initial_position
self.distance_threshold = distance_threshold
@@ -180,7 +182,9 @@ class ManipulateEnv(hand_env.HandEnv):
angle = self.np_random.uniform(-np.pi, np.pi)
axis = np.array([0.0, 0.0, 1.0])
z_quat = quat_from_angle_and_axis(angle, axis)
parallel_quat = self.parallel_quats[self.np_random.randint(len(self.parallel_quats))]
parallel_quat = self.parallel_quats[
self.np_random.randint(len(self.parallel_quats))
]
offset_quat = rotations.quat_mul(z_quat, parallel_quat)
initial_quat = rotations.quat_mul(initial_quat, offset_quat)
elif self.target_rotation in ["xyz", "ignore"]:
@@ -191,7 +195,9 @@ class ManipulateEnv(hand_env.HandEnv):
elif self.target_rotation == "fixed":
pass
else:
raise error.Error('Unknown target_rotation option "{}".'.format(self.target_rotation))
raise error.Error(
'Unknown target_rotation option "{}".'.format(self.target_rotation)
)
# Randomize initial position.
if self.randomize_initial_position:
@@ -223,13 +229,17 @@ class ManipulateEnv(hand_env.HandEnv):
target_pos = None
if self.target_position == "random":
assert self.target_position_range.shape == (3, 2)
offset = self.np_random.uniform(self.target_position_range[:, 0], self.target_position_range[:, 1])
offset = self.np_random.uniform(
self.target_position_range[:, 0], self.target_position_range[:, 1]
)
assert offset.shape == (3,)
target_pos = self.sim.data.get_joint_qpos("object:joint")[:3] + offset
elif self.target_position in ["ignore", "fixed"]:
target_pos = self.sim.data.get_joint_qpos("object:joint")[:3]
else:
raise error.Error('Unknown target_position option "{}".'.format(self.target_position))
raise error.Error(
'Unknown target_position option "{}".'.format(self.target_position)
)
assert target_pos is not None
assert target_pos.shape == (3,)
@@ -243,7 +253,9 @@ class ManipulateEnv(hand_env.HandEnv):
angle = self.np_random.uniform(-np.pi, np.pi)
axis = np.array([0.0, 0.0, 1.0])
target_quat = quat_from_angle_and_axis(angle, axis)
parallel_quat = self.parallel_quats[self.np_random.randint(len(self.parallel_quats))]
parallel_quat = self.parallel_quats[
self.np_random.randint(len(self.parallel_quats))
]
target_quat = rotations.quat_mul(target_quat, parallel_quat)
elif self.target_rotation == "xyz":
angle = self.np_random.uniform(-np.pi, np.pi)
@@ -252,7 +264,9 @@ class ManipulateEnv(hand_env.HandEnv):
elif self.target_rotation in ["ignore", "fixed"]:
target_quat = self.sim.data.get_joint_qpos("object:joint")
else:
raise error.Error('Unknown target_rotation option "{}".'.format(self.target_rotation))
raise error.Error(
'Unknown target_rotation option "{}".'.format(self.target_rotation)
)
assert target_quat is not None
assert target_quat.shape == (4,)
@@ -279,8 +293,12 @@ class ManipulateEnv(hand_env.HandEnv):
def _get_obs(self):
robot_qpos, robot_qvel = robot_get_obs(self.sim)
object_qvel = self.sim.data.get_joint_qvel("object:joint")
achieved_goal = self._get_achieved_goal().ravel() # this contains the object position + rotation
observation = np.concatenate([robot_qpos, robot_qvel, object_qvel, achieved_goal])
achieved_goal = (
self._get_achieved_goal().ravel()
) # this contains the object position + rotation
observation = np.concatenate(
[robot_qpos, robot_qvel, object_qvel, achieved_goal]
)
return {
"observation": observation.copy(),
"achieved_goal": achieved_goal.copy(),
@@ -289,7 +307,9 @@ class ManipulateEnv(hand_env.HandEnv):
class HandBlockEnv(ManipulateEnv, utils.EzPickle):
def __init__(self, target_position="random", target_rotation="xyz", reward_type="sparse"):
def __init__(
self, target_position="random", target_rotation="xyz", reward_type="sparse"
):
utils.EzPickle.__init__(self, target_position, target_rotation, reward_type)
ManipulateEnv.__init__(
self,
@@ -302,7 +322,9 @@ class HandBlockEnv(ManipulateEnv, utils.EzPickle):
class HandEggEnv(ManipulateEnv, utils.EzPickle):
def __init__(self, target_position="random", target_rotation="xyz", reward_type="sparse"):
def __init__(
self, target_position="random", target_rotation="xyz", reward_type="sparse"
):
utils.EzPickle.__init__(self, target_position, target_rotation, reward_type)
ManipulateEnv.__init__(
self,
@@ -315,7 +337,9 @@ class HandEggEnv(ManipulateEnv, utils.EzPickle):
class HandPenEnv(ManipulateEnv, utils.EzPickle):
def __init__(self, target_position="random", target_rotation="xyz", reward_type="sparse"):
def __init__(
self, target_position="random", target_rotation="xyz", reward_type="sparse"
):
utils.EzPickle.__init__(self, target_position, target_rotation, reward_type)
ManipulateEnv.__init__(
self,

View File

@@ -70,12 +70,16 @@ class ManipulateTouchSensorsEnv(manipulate.ManipulateEnv):
for (
k,
v,
) in self.sim.model._sensor_name2id.items(): # get touch sensor site names and their ids
) in (
self.sim.model._sensor_name2id.items()
): # get touch sensor site names and their ids
if "robot0:TS_" in k:
self._touch_sensor_id_site_id.append(
(
v,
self.sim.model._site_name2id[k.replace("robot0:TS_", "robot0:T_")],
self.sim.model._site_name2id[
k.replace("robot0:TS_", "robot0:T_")
],
)
)
self._touch_sensor_id.append(v)
@@ -89,9 +93,15 @@ class ManipulateTouchSensorsEnv(manipulate.ManipulateEnv):
obs = self._get_obs()
self.observation_space = spaces.Dict(
dict(
desired_goal=spaces.Box(-np.inf, np.inf, shape=obs["achieved_goal"].shape, dtype="float32"),
achieved_goal=spaces.Box(-np.inf, np.inf, shape=obs["achieved_goal"].shape, dtype="float32"),
observation=spaces.Box(-np.inf, np.inf, shape=obs["observation"].shape, dtype="float32"),
desired_goal=spaces.Box(
-np.inf, np.inf, shape=obs["achieved_goal"].shape, dtype="float32"
),
achieved_goal=spaces.Box(
-np.inf, np.inf, shape=obs["achieved_goal"].shape, dtype="float32"
),
observation=spaces.Box(
-np.inf, np.inf, shape=obs["observation"].shape, dtype="float32"
),
)
)
@@ -107,7 +117,9 @@ class ManipulateTouchSensorsEnv(manipulate.ManipulateEnv):
def _get_obs(self):
robot_qpos, robot_qvel = manipulate.robot_get_obs(self.sim)
object_qvel = self.sim.data.get_joint_qvel("object:joint")
achieved_goal = self._get_achieved_goal().ravel() # this contains the object position + rotation
achieved_goal = (
self._get_achieved_goal().ravel()
) # this contains the object position + rotation
touch_values = [] # get touch sensor readings. if there is one, set value to 1
if self.touch_get_obs == "sensordata":
touch_values = self.sim.data.sensordata[self._touch_sensor_id]
@@ -115,7 +127,9 @@ class ManipulateTouchSensorsEnv(manipulate.ManipulateEnv):
touch_values = self.sim.data.sensordata[self._touch_sensor_id] > 0.0
elif self.touch_get_obs == "log":
touch_values = np.log(self.sim.data.sensordata[self._touch_sensor_id] + 1.0)
observation = np.concatenate([robot_qpos, robot_qvel, object_qvel, touch_values, achieved_goal])
observation = np.concatenate(
[robot_qpos, robot_qvel, object_qvel, touch_values, achieved_goal]
)
return {
"observation": observation.copy(),
@@ -132,7 +146,9 @@ class HandBlockTouchSensorsEnv(ManipulateTouchSensorsEnv, utils.EzPickle):
touch_get_obs="sensordata",
reward_type="sparse",
):
utils.EzPickle.__init__(self, target_position, target_rotation, touch_get_obs, reward_type)
utils.EzPickle.__init__(
self, target_position, target_rotation, touch_get_obs, reward_type
)
ManipulateTouchSensorsEnv.__init__(
self,
model_path=MANIPULATE_BLOCK_XML,
@@ -152,7 +168,9 @@ class HandEggTouchSensorsEnv(ManipulateTouchSensorsEnv, utils.EzPickle):
touch_get_obs="sensordata",
reward_type="sparse",
):
utils.EzPickle.__init__(self, target_position, target_rotation, touch_get_obs, reward_type)
utils.EzPickle.__init__(
self, target_position, target_rotation, touch_get_obs, reward_type
)
ManipulateTouchSensorsEnv.__init__(
self,
model_path=MANIPULATE_EGG_XML,
@@ -172,7 +190,9 @@ class HandPenTouchSensorsEnv(ManipulateTouchSensorsEnv, utils.EzPickle):
touch_get_obs="sensordata",
reward_type="sparse",
):
utils.EzPickle.__init__(self, target_position, target_rotation, touch_get_obs, reward_type)
utils.EzPickle.__init__(
self, target_position, target_rotation, touch_get_obs, reward_type
)
ManipulateTouchSensorsEnv.__init__(
self,
model_path=MANIPULATE_PEN_XML,
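
Note: a minimal sketch of the goal-style spaces.Dict these envs construct; the shapes below are placeholders, not the real observation sizes:

    import numpy as np
    from gym import spaces

    observation_space = spaces.Dict(
        dict(
            desired_goal=spaces.Box(-np.inf, np.inf, shape=(7,), dtype="float32"),
            achieved_goal=spaces.Box(-np.inf, np.inf, shape=(7,), dtype="float32"),
            observation=spaces.Box(-np.inf, np.inf, shape=(61,), dtype="float32"),
        )
    )
    sample = observation_space.sample()  # OrderedDict with the three keys above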

View File

@@ -96,7 +96,9 @@ class HandReachEnv(hand_env.HandEnv, utils.EzPickle):
self.sim.forward()
self.initial_goal = self._get_achieved_goal().copy()
self.palm_xpos = self.sim.data.body_xpos[self.sim.model.body_name2id("robot0:palm")].copy()
self.palm_xpos = self.sim.data.body_xpos[
self.sim.model.body_name2id("robot0:palm")
].copy()
def _get_obs(self):
robot_qpos, robot_qvel = robot_get_obs(self.sim)
@@ -153,5 +155,7 @@ class HandReachEnv(hand_env.HandEnv, utils.EzPickle):
for finger_idx in range(5):
site_name = "finger{}".format(finger_idx)
site_id = self.sim.model.site_name2id(site_name)
self.sim.model.site_pos[site_id] = achieved_goal[finger_idx] - sites_offset[site_id]
self.sim.model.site_pos[site_id] = (
achieved_goal[finger_idx] - sites_offset[site_id]
)
self.sim.forward()

View File

@@ -30,14 +30,22 @@ class HandEnv(robot_env.RobotEnv):
if self.relative_control:
actuation_center = np.zeros_like(action)
for i in range(self.sim.data.ctrl.shape[0]):
actuation_center[i] = self.sim.data.get_joint_qpos(self.sim.model.actuator_names[i].replace(":A_", ":"))
actuation_center[i] = self.sim.data.get_joint_qpos(
self.sim.model.actuator_names[i].replace(":A_", ":")
)
for joint_name in ["FF", "MF", "RF", "LF"]:
act_idx = self.sim.model.actuator_name2id("robot0:A_{}J1".format(joint_name))
actuation_center[act_idx] += self.sim.data.get_joint_qpos("robot0:{}J0".format(joint_name))
act_idx = self.sim.model.actuator_name2id(
"robot0:A_{}J1".format(joint_name)
)
actuation_center[act_idx] += self.sim.data.get_joint_qpos(
"robot0:{}J0".format(joint_name)
)
else:
actuation_center = (ctrlrange[:, 1] + ctrlrange[:, 0]) / 2.0
self.sim.data.ctrl[:] = actuation_center + action * actuation_range
self.sim.data.ctrl[:] = np.clip(self.sim.data.ctrl, ctrlrange[:, 0], ctrlrange[:, 1])
self.sim.data.ctrl[:] = np.clip(
self.sim.data.ctrl, ctrlrange[:, 0], ctrlrange[:, 1]
)
def _viewer_setup(self):
body_id = self.sim.model.body_name2id("robot0:palm")
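
Note: the actuation arithmetic maps a policy action in [-1, 1] onto each actuator's control range, then clips; a standalone sketch with made-up bounds (actuation_range is assumed to be the half-width of ctrlrange):

    import numpy as np

    ctrlrange = np.array([[-1.0, 1.0], [0.0, 0.5]])  # made-up (low, high) rows
    actuation_range = (ctrlrange[:, 1] - ctrlrange[:, 0]) / 2.0
    actuation_center = (ctrlrange[:, 1] + ctrlrange[:, 0]) / 2.0
    action = np.array([2.0, -3.0])                   # deliberately out of range
    ctrl = np.clip(actuation_center + action * actuation_range,
                   ctrlrange[:, 0], ctrlrange[:, 1])
    # ctrl == array([1.0, 0.0]); every entry lands inside its own bounds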

View File

@@ -46,9 +46,15 @@ class RobotEnv(gym.GoalEnv):
self.action_space = spaces.Box(-1.0, 1.0, shape=(n_actions,), dtype="float32")
self.observation_space = spaces.Dict(
dict(
desired_goal=spaces.Box(-np.inf, np.inf, shape=obs["achieved_goal"].shape, dtype="float32"),
achieved_goal=spaces.Box(-np.inf, np.inf, shape=obs["achieved_goal"].shape, dtype="float32"),
observation=spaces.Box(-np.inf, np.inf, shape=obs["observation"].shape, dtype="float32"),
desired_goal=spaces.Box(
-np.inf, np.inf, shape=obs["achieved_goal"].shape, dtype="float32"
),
achieved_goal=spaces.Box(
-np.inf, np.inf, shape=obs["achieved_goal"].shape, dtype="float32"
),
observation=spaces.Box(
-np.inf, np.inf, shape=obs["observation"].shape, dtype="float32"
),
)
)

View File

@@ -164,8 +164,12 @@ def mat2euler(mat):
-np.arctan2(mat[..., 0, 1], mat[..., 0, 0]),
-np.arctan2(-mat[..., 1, 0], mat[..., 1, 1]),
)
euler[..., 1] = np.where(condition, -np.arctan2(-mat[..., 0, 2], cy), -np.arctan2(-mat[..., 0, 2], cy))
euler[..., 0] = np.where(condition, -np.arctan2(mat[..., 1, 2], mat[..., 2, 2]), 0.0)
euler[..., 1] = np.where(
condition, -np.arctan2(-mat[..., 0, 2], cy), -np.arctan2(-mat[..., 0, 2], cy)
)
euler[..., 0] = np.where(
condition, -np.arctan2(mat[..., 1, 2], mat[..., 2, 2]), 0.0
)
return euler

View File

@@ -75,9 +75,15 @@ def reset_mocap2body_xpos(sim):
values as the bodies they're welded to.
"""
if sim.model.eq_type is None or sim.model.eq_obj1id is None or sim.model.eq_obj2id is None:
if (
sim.model.eq_type is None
or sim.model.eq_obj1id is None
or sim.model.eq_obj2id is None
):
return
for eq_type, obj1_id, obj2_id in zip(sim.model.eq_type, sim.model.eq_obj1id, sim.model.eq_obj2id):
for eq_type, obj1_id, obj2_id in zip(
sim.model.eq_type, sim.model.eq_obj1id, sim.model.eq_obj2id
):
if eq_type != mujoco_py.const.EQ_WELD:
continue

View File

@@ -2,7 +2,10 @@ from gym import envs, logger
import os
SKIP_MUJOCO_WARNING_MESSAGE = "Cannot run mujoco test (either license key not found or mujoco not" "installed properly)."
SKIP_MUJOCO_WARNING_MESSAGE = (
"Cannot run mujoco test (either license key not found or mujoco not"
"installed properly)."
)
skip_mujoco = not (os.environ.get("MUJOCO_KEY"))
@@ -18,7 +21,9 @@ def should_skip_env_spec_for_tests(spec):
# troublesome to run frequently
ep = spec.entry_point
# Skip mujoco tests for pull request CI
if skip_mujoco and (ep.startswith("gym.envs.mujoco") or ep.startswith("gym.envs.robotics:")):
if skip_mujoco and (
ep.startswith("gym.envs.mujoco") or ep.startswith("gym.envs.robotics:")
):
return True
try:
import atari_py
@@ -34,7 +39,11 @@ def should_skip_env_spec_for_tests(spec):
if (
"GoEnv" in ep
or "HexEnv" in ep
or (ep.startswith("gym.envs.atari") and not spec.id.startswith("Pong") and not spec.id.startswith("Seaquest"))
or (
ep.startswith("gym.envs.atari")
and not spec.id.startswith("Pong")
and not spec.id.startswith("Seaquest")
)
):
logger.warn("Skipping tests for env {}".format(ep))
return True
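
Note: the message above relies on implicit concatenation of adjacent string literals, which inserts no separator -- hence the trailing space inside the first literal:

    msg = (
        "Cannot run mujoco test (either license key not found or mujoco not "
        "installed properly)."
    )
    assert "not installed" in msg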

View File

@@ -25,7 +25,9 @@ def test_env(spec):
step_responses2 = [env2.step(action) for action in action_samples2]
env2.close()
for i, (action_sample1, action_sample2) in enumerate(zip(action_samples1, action_samples2)):
for i, (action_sample1, action_sample2) in enumerate(
zip(action_samples1, action_samples2)
):
try:
assert_equals(action_sample1, action_sample2)
except AssertionError:
@@ -33,7 +35,11 @@ def test_env(spec):
print("env2.action_space=", env2.action_space)
print("action_samples1=", action_samples1)
print("action_samples2=", action_samples2)
print("[{}] action_sample1: {}, action_sample2: {}".format(i, action_sample1, action_sample2))
print(
"[{}] action_sample1: {}, action_sample2: {}".format(
i, action_sample1, action_sample2
)
)
raise
    # Don't check rollout equality if it's a nondeterministic
@@ -43,7 +49,9 @@ def test_env(spec):
assert_equals(initial_observation1, initial_observation2)
for i, ((o1, r1, d1, i1), (o2, r2, d2, i2)) in enumerate(zip(step_responses1, step_responses2)):
for i, ((o1, r1, d1, i1), (o2, r2, d2, i2)) in enumerate(
zip(step_responses1, step_responses2)
):
assert_equals(o1, o2, "[{}] ".format(i))
assert r1 == r2, "[{}] r1: {}, r2: {}".format(i, r1, r2)
assert d1 == d2, "[{}] d1: {}, d2: {}".format(i, d1, d2)
@@ -58,7 +66,9 @@ def test_env(spec):
def assert_equals(a, b, prefix=None):
assert type(a) == type(b), "{}Differing types: {} and {}".format(prefix, a, b)
if isinstance(a, dict):
assert list(a.keys()) == list(b.keys()), "{}Key sets differ: {} and {}".format(prefix, a, b)
assert list(a.keys()) == list(b.keys()), "{}Key sets differ: {} and {}".format(
prefix, a, b
)
for k in a.keys():
v_a = a[k]

View File

@@ -24,7 +24,9 @@ def test_env(spec):
assert ob_space.contains(ob), "Reset observation: {!r} not in space".format(ob)
a = act_space.sample()
observation, reward, done, _info = env.step(a)
assert ob_space.contains(observation), "Step observation: {!r} not in space".format(observation)
assert ob_space.contains(observation), "Step observation: {!r} not in space".format(
observation
)
assert np.isscalar(reward), "{} is not a scalar for {}".format(reward, env)
assert isinstance(done, bool), "Expected {} to be a boolean".format(done)

View File

@@ -81,7 +81,9 @@ def test_env_semantics(spec):
if spec.id not in rollout_dict:
if not spec.nondeterministic:
logger.warn(
"Rollout does not exist for {}, run generate_json.py to generate rollouts for new envs".format(spec.id)
"Rollout does not exist for {}, run generate_json.py to generate rollouts for new envs".format(
spec.id
)
)
return
@@ -98,15 +100,21 @@ def test_env_semantics(spec):
)
if rollout_dict[spec.id]["actions"] != actions_now:
errors.append(
"Actions not equal for {} -- expected {} but got {}".format(spec.id, rollout_dict[spec.id]["actions"], actions_now)
"Actions not equal for {} -- expected {} but got {}".format(
spec.id, rollout_dict[spec.id]["actions"], actions_now
)
)
if rollout_dict[spec.id]["rewards"] != rewards_now:
errors.append(
"Rewards not equal for {} -- expected {} but got {}".format(spec.id, rollout_dict[spec.id]["rewards"], rewards_now)
"Rewards not equal for {} -- expected {} but got {}".format(
spec.id, rollout_dict[spec.id]["rewards"], rewards_now
)
)
if rollout_dict[spec.id]["dones"] != dones_now:
errors.append(
"Dones not equal for {} -- expected {} but got {}".format(spec.id, rollout_dict[spec.id]["dones"], dones_now)
"Dones not equal for {} -- expected {} but got {}".format(
spec.id, rollout_dict[spec.id]["dones"], dones_now
)
)
if len(errors):
for error in errors:

View File

@@ -4,7 +4,9 @@ from gym import envs
from gym.envs.tests.spec_list import skip_mujoco, SKIP_MUJOCO_WARNING_MESSAGE
def verify_environments_match(old_environment_id, new_environment_id, seed=1, num_actions=1000):
def verify_environments_match(
old_environment_id, new_environment_id, seed=1, num_actions=1000
):
old_environment = envs.make(old_environment_id)
new_environment = envs.make(new_environment_id)

View File

@@ -83,6 +83,8 @@ def test_malformed_lookup():
try:
registry.spec(u"“Breakout-v0”")
except error.Error as e:
assert "malformed environment ID" in "{}".format(e), "Unexpected message: {}".format(e)
assert "malformed environment ID" in "{}".format(
e
), "Unexpected message: {}".format(e)
else:
assert False

View File

@@ -75,7 +75,9 @@ class BlackjackEnv(gym.Env):
def __init__(self, natural=False):
self.action_space = spaces.Discrete(2)
self.observation_space = spaces.Tuple((spaces.Discrete(32), spaces.Discrete(11), spaces.Discrete(2)))
self.observation_space = spaces.Tuple(
(spaces.Discrete(32), spaces.Discrete(11), spaces.Discrete(2))
)
self.seed()
# Flag to payout 1.5 on a "natural" blackjack win, like casino rules
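
Note: a quick sketch of sampling from the Tuple observation space defined above:

    from gym import spaces

    observation_space = spaces.Tuple(
        (spaces.Discrete(32), spaces.Discrete(11), spaces.Discrete(2))
    )
    print(observation_space.sample())  # e.g. (13, 4, 0): player sum, dealer card, usable ace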

View File

@@ -141,7 +141,9 @@ class FrozenLakeEnv(discrete.DiscreteEnv):
else:
if is_slippery:
for b in [(a - 1) % 4, a, (a + 1) % 4]:
li.append((1.0 / 3.0, *update_probability_matrix(row, col, b)))
li.append(
(1.0 / 3.0, *update_probability_matrix(row, col, b))
)
else:
li.append((1.0, *update_probability_matrix(row, col, a)))
@@ -155,7 +157,9 @@ class FrozenLakeEnv(discrete.DiscreteEnv):
desc = [[c.decode("utf-8") for c in line] for line in desc]
desc[row][col] = utils.colorize(desc[row][col], "red", highlight=True)
if self.lastaction is not None:
outfile.write(" ({})\n".format(["Left", "Down", "Right", "Up"][self.lastaction]))
outfile.write(
" ({})\n".format(["Left", "Down", "Right", "Up"][self.lastaction])
)
else:
outfile.write("\n")
outfile.write("\n".join("".join(line) for line in desc) + "\n")

View File

@@ -80,7 +80,11 @@ class GuessingGame(gym.Env):
reward = 0
done = False
if (self.number - self.range * 0.01) < action < (self.number + self.range * 0.01):
if (
(self.number - self.range * 0.01)
< action
< (self.number + self.range * 0.01)
):
reward = 1
done = True
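
Note: the rewrapped condition is a chained comparison; Python evaluates a < b < c as (a < b) and (b < c):

    number, value_range, action = 500.0, 1000.0, 495.0
    hit = (number - value_range * 0.01) < action < (number + value_range * 0.01)
    assert hit  # 490.0 < 495.0 < 510.0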

View File

@@ -62,7 +62,10 @@ class HotterColder(gym.Env):
elif action > self.number:
self.observation = 3
reward = ((min(action, self.number) + self.bounds) / (max(action, self.number) + self.bounds)) ** 2
reward = (
(min(action, self.number) + self.bounds)
/ (max(action, self.number) + self.bounds)
) ** 2
self.guess_count += 1
done = self.guess_count >= self.guess_max

View File

@@ -65,7 +65,9 @@ class KellyCoinflipEnv(gym.Env):
return [seed]
def step(self, action):
bet_in_dollars = min(action / 100.0, self.wealth) # action = desired bet in pennies
bet_in_dollars = min(
action / 100.0, self.wealth
) # action = desired bet in pennies
self.rounds -= 1
coinflip = flip(self.edge, self.np_random)
@@ -147,19 +149,35 @@ class KellyCoinflipGeneralizedEnv(gym.Env):
edge = self.np_random.beta(edge_prior_alpha, edge_prior_beta)
if self.clip_distributions:
# (clip/resample some parameters to be able to fix obs/action space sizes/bounds)
max_wealth_bound = round(genpareto.ppf(0.85, max_wealth_alpha, max_wealth_m))
max_wealth_bound = round(
genpareto.ppf(0.85, max_wealth_alpha, max_wealth_m)
)
max_wealth = max_wealth_bound + 1.0
while max_wealth > max_wealth_bound:
max_wealth = round(genpareto.rvs(max_wealth_alpha, max_wealth_m, random_state=self.np_random))
max_rounds_bound = int(round(norm.ppf(0.99, max_rounds_mean, max_rounds_sd)))
max_wealth = round(
genpareto.rvs(
max_wealth_alpha, max_wealth_m, random_state=self.np_random
)
)
max_rounds_bound = int(
round(norm.ppf(0.99, max_rounds_mean, max_rounds_sd))
)
max_rounds = max_rounds_bound + 1
while max_rounds > max_rounds_bound:
max_rounds = int(round(self.np_random.normal(max_rounds_mean, max_rounds_sd)))
max_rounds = int(
round(self.np_random.normal(max_rounds_mean, max_rounds_sd))
)
else:
max_wealth = round(genpareto.rvs(max_wealth_alpha, max_wealth_m, random_state=self.np_random))
max_wealth = round(
genpareto.rvs(
max_wealth_alpha, max_wealth_m, random_state=self.np_random
)
)
max_wealth_bound = max_wealth
max_rounds = int(round(self.np_random.normal(max_rounds_mean, max_rounds_sd)))
max_rounds = int(
round(self.np_random.normal(max_rounds_mean, max_rounds_sd))
)
max_rounds_bound = max_rounds
# add an additional global variable which is the sufficient statistic for the
@@ -176,7 +194,9 @@ class KellyCoinflipGeneralizedEnv(gym.Env):
self.action_space = spaces.Discrete(int(max_wealth_bound * 100))
self.observation_space = spaces.Tuple(
(
spaces.Box(0, max_wealth_bound, shape=[1], dtype=np.float32), # current wealth
spaces.Box(
0, max_wealth_bound, shape=[1], dtype=np.float32
), # current wealth
spaces.Discrete(max_rounds_bound + 1), # rounds elapsed
spaces.Discrete(max_rounds_bound + 1), # wins
spaces.Discrete(max_rounds_bound + 1), # losses

View File

@@ -80,7 +80,10 @@ class TaxiEnv(discrete.DiscreteEnv):
max_col = num_columns - 1
initial_state_distrib = np.zeros(num_states)
num_actions = 6
P = {state: {action: [] for action in range(num_actions)} for state in range(num_states)}
P = {
state: {action: [] for action in range(num_actions)}
for state in range(num_states)
}
for row in range(num_rows):
for col in range(num_columns):
for pass_idx in range(len(locs) + 1): # +1 for being inside taxi
@@ -91,7 +94,9 @@ class TaxiEnv(discrete.DiscreteEnv):
for action in range(num_actions):
# defaults
new_row, new_col, new_pass_idx = row, col, pass_idx
reward = -1 # default reward when there is no pickup/dropoff
reward = (
-1
) # default reward when there is no pickup/dropoff
done = False
taxi_loc = (row, col)
@@ -117,10 +122,14 @@ class TaxiEnv(discrete.DiscreteEnv):
new_pass_idx = locs.index(taxi_loc)
else: # dropoff at wrong location
reward = -10
new_state = self.encode(new_row, new_col, new_pass_idx, dest_idx)
new_state = self.encode(
new_row, new_col, new_pass_idx, dest_idx
)
P[state][action].append((1.0, new_state, reward, done))
initial_state_distrib /= initial_state_distrib.sum()
discrete.DiscreteEnv.__init__(self, num_states, num_actions, P, initial_state_distrib)
discrete.DiscreteEnv.__init__(
self, num_states, num_actions, P, initial_state_distrib
)
def encode(self, taxi_row, taxi_col, pass_loc, dest_idx):
# (5) 5, 5, 4
@@ -156,9 +165,13 @@ class TaxiEnv(discrete.DiscreteEnv):
return "_" if x == " " else x
if pass_idx < 4:
out[1 + taxi_row][2 * taxi_col + 1] = utils.colorize(out[1 + taxi_row][2 * taxi_col + 1], "yellow", highlight=True)
out[1 + taxi_row][2 * taxi_col + 1] = utils.colorize(
out[1 + taxi_row][2 * taxi_col + 1], "yellow", highlight=True
)
pi, pj = self.locs[pass_idx]
out[1 + pi][2 * pj + 1] = utils.colorize(out[1 + pi][2 * pj + 1], "blue", bold=True)
out[1 + pi][2 * pj + 1] = utils.colorize(
out[1 + pi][2 * pj + 1], "blue", bold=True
)
else: # passenger in taxi
out[1 + taxi_row][2 * taxi_col + 1] = utils.colorize(
ul(out[1 + taxi_row][2 * taxi_col + 1]), "green", highlight=True
@@ -168,7 +181,13 @@ class TaxiEnv(discrete.DiscreteEnv):
out[1 + di][2 * dj + 1] = utils.colorize(out[1 + di][2 * dj + 1], "magenta")
outfile.write("\n".join(["".join(row) for row in out]) + "\n")
if self.lastaction is not None:
outfile.write(" ({})\n".format(["South", "North", "East", "West", "Pickup", "Dropoff"][self.lastaction]))
outfile.write(
" ({})\n".format(
["South", "North", "East", "West", "Pickup", "Dropoff"][
self.lastaction
]
)
)
else:
outfile.write("\n")

View File

@@ -55,7 +55,9 @@ class CubeCrash(gym.Env):
self.seed()
self.viewer = None
self.observation_space = spaces.Box(0, 255, (FIELD_H, FIELD_W, 3), dtype=np.uint8)
self.observation_space = spaces.Box(
0, 255, (FIELD_H, FIELD_W, 3), dtype=np.uint8
)
self.action_space = spaces.Discrete(3)
self.reset()
@@ -81,9 +83,16 @@ class CubeCrash(gym.Env):
self.potential = None
self.step_n = 0
while 1:
self.wall_color = self.random_color() if self.use_random_colors else color_white
self.cube_color = self.random_color() if self.use_random_colors else color_green
if np.linalg.norm(self.wall_color - self.bg_color) < 50 or np.linalg.norm(self.cube_color - self.bg_color) < 50:
self.wall_color = (
self.random_color() if self.use_random_colors else color_white
)
self.cube_color = (
self.random_color() if self.use_random_colors else color_green
)
if (
np.linalg.norm(self.wall_color - self.bg_color) < 50
or np.linalg.norm(self.cube_color - self.bg_color) < 50
):
continue
break
return self.step(0)[0]
@@ -108,7 +117,9 @@ class CubeCrash(gym.Env):
self.hole_x - HOLE_WIDTH // 2 : self.hole_x + HOLE_WIDTH // 2 + 1,
:,
] = self.bg_color
obs[self.cube_y - 1 : self.cube_y + 2, self.cube_x - 1 : self.cube_x + 2, :] = self.cube_color
obs[
self.cube_y - 1 : self.cube_y + 2, self.cube_x - 1 : self.cube_x + 2, :
] = self.cube_color
if self.use_black_screen and self.step_n > 4:
obs[:] = np.zeros((3,), dtype=np.uint8)

View File

@@ -62,12 +62,16 @@ class MemorizeDigits(gym.Env):
def __init__(self):
self.seed()
self.viewer = None
self.observation_space = spaces.Box(0, 255, (FIELD_H, FIELD_W, 3), dtype=np.uint8)
self.observation_space = spaces.Box(
0, 255, (FIELD_H, FIELD_W, 3), dtype=np.uint8
)
self.action_space = spaces.Discrete(10)
self.bogus_mnist = np.zeros((10, 6, 6), dtype=np.uint8)
for digit in range(10):
for y in range(6):
self.bogus_mnist[digit, y, :] = [ord(char) for char in bogus_mnist[digit][y]]
self.bogus_mnist[digit, y, :] = [
ord(char) for char in bogus_mnist[digit][y]
]
self.reset()
def seed(self, seed=None):
@@ -89,7 +93,9 @@ class MemorizeDigits(gym.Env):
self.color_bg = self.random_color() if self.use_random_colors else color_black
self.step_n = 0
while 1:
self.color_digit = self.random_color() if self.use_random_colors else color_white
self.color_digit = (
self.random_color() if self.use_random_colors else color_white
)
if np.linalg.norm(self.color_digit - self.color_bg) < 50:
continue
break
@@ -113,7 +119,9 @@ class MemorizeDigits(gym.Env):
digit_img[:] = self.color_bg
xxx = self.bogus_mnist[self.digit] == 42
digit_img[xxx] = self.color_digit
obs[self.digit_y - 3 : self.digit_y + 3, self.digit_x - 3 : self.digit_x + 3] = digit_img
obs[
self.digit_y - 3 : self.digit_y + 3, self.digit_x - 3 : self.digit_x + 3
] = digit_img
self.last_obs = obs
return obs, reward, done, {}

View File

@@ -102,7 +102,10 @@ class APIError(Error):
try:
http_body = http_body.decode("utf-8")
except:
http_body = "<Could not decode body as utf-8. " "Please report to gym@openai.com>"
http_body = (
"<Could not decode body as utf-8. "
"Please report to gym@openai.com>"
)
self._message = message
self.http_body = http_body
@@ -139,7 +142,9 @@ class InvalidRequestError(APIError):
json_body=None,
headers=None,
):
super(InvalidRequestError, self).__init__(message, http_body, http_status, json_body, headers)
super(InvalidRequestError, self).__init__(
message, http_body, http_status, json_body, headers
)
self.param = param

View File

@@ -29,16 +29,26 @@ class Box(Space):
# determine shape if it isn't provided directly
if shape is not None:
shape = tuple(shape)
assert np.isscalar(low) or low.shape == shape, "low.shape doesn't match provided shape"
assert np.isscalar(high) or high.shape == shape, "high.shape doesn't match provided shape"
assert (
np.isscalar(low) or low.shape == shape
), "low.shape doesn't match provided shape"
assert (
np.isscalar(high) or high.shape == shape
), "high.shape doesn't match provided shape"
elif not np.isscalar(low):
shape = low.shape
assert np.isscalar(high) or high.shape == shape, "high.shape doesn't match low.shape"
assert (
np.isscalar(high) or high.shape == shape
), "high.shape doesn't match low.shape"
elif not np.isscalar(high):
shape = high.shape
assert np.isscalar(low) or low.shape == shape, "low.shape doesn't match high.shape"
assert (
np.isscalar(low) or low.shape == shape
), "low.shape doesn't match high.shape"
else:
raise ValueError("shape must be provided or inferred from the shapes of low or high")
raise ValueError(
"shape must be provided or inferred from the shapes of low or high"
)
if np.isscalar(low):
low = np.full(shape, low, dtype=dtype)
@@ -60,7 +70,9 @@ class Box(Space):
high_precision = _get_precision(self.high.dtype)
dtype_precision = _get_precision(self.dtype)
if min(low_precision, high_precision) > dtype_precision:
logger.warn("Box bound precision lowered by casting to {}".format(self.dtype))
logger.warn(
"Box bound precision lowered by casting to {}".format(self.dtype)
)
self.low = self.low.astype(self.dtype)
self.high = self.high.astype(self.dtype)
@@ -107,11 +119,19 @@ class Box(Space):
# Vectorized sampling by interval type
sample[unbounded] = self.np_random.normal(size=unbounded[unbounded].shape)
sample[low_bounded] = self.np_random.exponential(size=low_bounded[low_bounded].shape) + self.low[low_bounded]
sample[low_bounded] = (
self.np_random.exponential(size=low_bounded[low_bounded].shape)
+ self.low[low_bounded]
)
sample[upp_bounded] = -self.np_random.exponential(size=upp_bounded[upp_bounded].shape) + self.high[upp_bounded]
sample[upp_bounded] = (
-self.np_random.exponential(size=upp_bounded[upp_bounded].shape)
+ self.high[upp_bounded]
)
sample[bounded] = self.np_random.uniform(low=self.low[bounded], high=high[bounded], size=bounded[bounded].shape)
sample[bounded] = self.np_random.uniform(
low=self.low[bounded], high=high[bounded], size=bounded[bounded].shape
)
if self.dtype.kind == "i":
sample = np.floor(sample)
@@ -120,7 +140,9 @@ class Box(Space):
def contains(self, x):
if isinstance(x, list):
x = np.array(x) # Promote list to array for contains check
return x.shape == self.shape and np.all(x >= self.low) and np.all(x <= self.high)
return (
x.shape == self.shape and np.all(x >= self.low) and np.all(x <= self.high)
)
def to_jsonable(self, sample_n):
return np.array(sample_n).tolist()
@@ -129,7 +151,9 @@ class Box(Space):
return [np.asarray(sample) for sample in sample_n]
def __repr__(self):
return "Box({}, {}, {}, {})".format(self.low.min(), self.high.max(), self.shape, self.dtype)
return "Box({}, {}, {}, {})".format(
self.low.min(), self.high.max(), self.shape, self.dtype
)
def __eq__(self, other):
return (

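Note: a minimal sketch of constructing and sampling a Box like the ones whose bound checks were rewrapped above; passing float64 bound arrays with dtype=np.float32 triggers the precision warning shown earlier:

    import numpy as np
    from gym import spaces

    box = spaces.Box(low=np.array([0.0, -1.0]), high=np.array([1.0, 1.0]), dtype=np.float32)
    x = box.sample()
    assert box.contains(x) and x.shape == (2,)
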
View File

@@ -33,7 +33,9 @@ class Dict(Space):
"""
def __init__(self, spaces=None, **spaces_kwargs):
assert (spaces is None) or (not spaces_kwargs), "Use either Dict(spaces=dict(...)) or Dict(foo=x, bar=z)"
assert (spaces is None) or (
not spaces_kwargs
), "Use either Dict(spaces=dict(...)) or Dict(foo=x, bar=z)"
if spaces is None:
spaces = spaces_kwargs
if isinstance(spaces, dict) and not isinstance(spaces, OrderedDict):
@@ -42,8 +44,12 @@ class Dict(Space):
spaces = OrderedDict(spaces)
self.spaces = spaces
for space in spaces.values():
assert isinstance(space, Space), "Values of the dict should be instances of gym.Space"
super(Dict, self).__init__(None, None) # None for shape and dtype, since it'll require special handling
assert isinstance(
space, Space
), "Values of the dict should be instances of gym.Space"
super(Dict, self).__init__(
None, None
) # None for shape and dtype, since it'll require special handling
def seed(self, seed=None):
[space.seed(seed) for space in self.spaces.values()]
@@ -69,11 +75,18 @@ class Dict(Space):
yield key
def __repr__(self):
return "Dict(" + ", ".join([str(k) + ":" + str(s) for k, s in self.spaces.items()]) + ")"
return (
"Dict("
+ ", ".join([str(k) + ":" + str(s) for k, s in self.spaces.items()])
+ ")"
)
def to_jsonable(self, sample_n):
# serialize as dict-repr of vectors
return {key: space.to_jsonable([sample[key] for sample in sample_n]) for key, space in self.spaces.items()}
return {
key: space.to_jsonable([sample[key] for sample in sample_n])
for key, space in self.spaces.items()
}
def from_jsonable(self, sample_n):
dict_of_list = {}

View File

@@ -22,7 +22,9 @@ class Discrete(Space):
def contains(self, x):
if isinstance(x, int):
as_int = x
elif isinstance(x, (np.generic, np.ndarray)) and (x.dtype.char in np.typecodes["AllInteger"] and x.shape == ()):
elif isinstance(x, (np.generic, np.ndarray)) and (
x.dtype.char in np.typecodes["AllInteger"] and x.shape == ()
):
as_int = int(x)
else:
return False
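
Note: the reformatted check accepts native ints and zero-dimensional integer arrays, and rejects everything else:

    import numpy as np
    from gym import spaces

    d = spaces.Discrete(4)
    assert d.contains(3)
    assert d.contains(np.int64(3))  # np.generic, integer dtype, shape ()
    assert not d.contains(4)        # out of range
    assert not d.contains(3.0)      # floats are rejected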

View File

@@ -34,7 +34,9 @@ class MultiDiscrete(Space):
super(MultiDiscrete, self).__init__(self.nvec.shape, dtype)
def sample(self):
return (self.np_random.random_sample(self.nvec.shape) * self.nvec).astype(self.dtype)
return (self.np_random.random_sample(self.nvec.shape) * self.nvec).astype(
self.dtype
)
def contains(self, x):
if isinstance(x, list):

View File

@@ -25,7 +25,9 @@ from gym.spaces import Tuple, Box, Discrete, MultiDiscrete, MultiBinary, Dict
Dict(
{
"position": Discrete(5),
"velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
"velocity": Box(
low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32
),
}
),
],
@@ -69,7 +71,9 @@ def test_roundtripping(space):
Dict(
{
"position": Discrete(5),
"velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
"velocity": Box(
low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32
),
}
),
],

View File

@@ -28,7 +28,9 @@ from gym.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete, Tuple, u
Dict(
{
"position": Discrete(5),
"velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
"velocity": Box(
low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32
),
}
),
7,
@@ -58,14 +60,18 @@ def test_flatdim(space, flatdim):
Dict(
{
"position": Discrete(5),
"velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
"velocity": Box(
low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32
),
}
),
],
)
def test_flatten_space_boxes(space):
flat_space = utils.flatten_space(space)
assert isinstance(flat_space, Box), "Expected {} to equal {}".format(type(flat_space), Box)
assert isinstance(flat_space, Box), "Expected {} to equal {}".format(
type(flat_space), Box
)
flatdim = utils.flatdim(space)
(single_dim,) = flat_space.shape
assert single_dim == flatdim, "Expected {} to equal {}".format(single_dim, flatdim)
@@ -89,7 +95,9 @@ def test_flatten_space_boxes(space):
Dict(
{
"position": Discrete(5),
"velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
"velocity": Box(
low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32
),
}
),
],
@@ -99,7 +107,9 @@ def test_flat_space_contains_flat_points(space):
flattened_samples = [utils.flatten(space, sample) for sample in some_samples]
flat_space = utils.flatten_space(space)
for i, flat_sample in enumerate(flattened_samples):
assert flat_sample in flat_space, "Expected sample #{} {} to be in {}".format(i, flat_sample, flat_space)
assert flat_sample in flat_space, "Expected sample #{} {} to be in {}".format(
i, flat_sample, flat_space
)
@pytest.mark.parametrize(
@@ -120,7 +130,9 @@ def test_flat_space_contains_flat_points(space):
Dict(
{
"position": Discrete(5),
"velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
"velocity": Box(
low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32
),
}
),
],
@@ -150,7 +162,9 @@ def test_flatten_dim(space):
Dict(
{
"position": Discrete(5),
"velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
"velocity": Box(
low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32
),
}
),
],
@@ -158,9 +172,15 @@ def test_flatten_dim(space):
def test_flatten_roundtripping(space):
some_samples = [space.sample() for _ in range(10)]
flattened_samples = [utils.flatten(space, sample) for sample in some_samples]
roundtripped_samples = [utils.unflatten(space, sample) for sample in flattened_samples]
for i, (original, roundtripped) in enumerate(zip(some_samples, roundtripped_samples)):
assert compare_nested(original, roundtripped), "Expected sample #{} {} to equal {}".format(i, original, roundtripped)
roundtripped_samples = [
utils.unflatten(space, sample) for sample in flattened_samples
]
for i, (original, roundtripped) in enumerate(
zip(some_samples, roundtripped_samples)
):
assert compare_nested(
original, roundtripped
), "Expected sample #{} {} to equal {}".format(i, original, roundtripped)
def compare_nested(left, right):
@@ -168,7 +188,9 @@ def compare_nested(left, right):
return np.allclose(left, right)
elif isinstance(left, OrderedDict) and isinstance(right, OrderedDict):
res = len(left) == len(right)
for ((left_key, left_value), (right_key, right_value)) in zip(left.items(), right.items()):
for ((left_key, left_value), (right_key, right_value)) in zip(
left.items(), right.items()
):
if not res:
return False
res = left_key == right_key and compare_nested(left_value, right_value)
@@ -216,7 +238,9 @@ Expected flattened types are based off:
Dict(
{
"position": Discrete(5),
"velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float16),
"velocity": Box(
low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float16
),
}
),
np.float64,
@@ -230,8 +254,12 @@ def test_dtypes(original_space, expected_flattened_dtype):
flattened_sample = utils.flatten(original_space, original_sample)
unflattened_sample = utils.unflatten(original_space, flattened_sample)
assert flattened_space.contains(flattened_sample), "Expected flattened_space to contain flattened_sample"
assert flattened_space.dtype == expected_flattened_dtype, "Expected flattened_space's dtype to equal " "{}".format(
assert flattened_space.contains(
flattened_sample
), "Expected flattened_space to contain flattened_sample"
assert (
flattened_space.dtype == expected_flattened_dtype
), "Expected flattened_space's dtype to equal " "{}".format(
expected_flattened_dtype
)
@@ -244,10 +272,9 @@ def test_dtypes(original_space, expected_flattened_dtype):
def compare_sample_types(original_space, original_sample, unflattened_sample):
if isinstance(original_space, Discrete):
assert isinstance(
unflattened_sample, int
), "Expected unflattened_sample to be an int. unflattened_sample: " "{} original_sample: {}".format(
unflattened_sample, original_sample
assert isinstance(unflattened_sample, int), (
"Expected unflattened_sample to be an int. unflattened_sample: "
"{} original_sample: {}".format(unflattened_sample, original_sample)
)
elif isinstance(original_space, Tuple):
for index in range(len(original_space)):

View File

@@ -13,7 +13,9 @@ class Tuple(Space):
def __init__(self, spaces):
self.spaces = spaces
for space in spaces:
assert isinstance(space, Space), "Elements of the tuple must be instances of gym.Space"
assert isinstance(
space, Space
), "Elements of the tuple must be instances of gym.Space"
super(Tuple, self).__init__(None, None)
def seed(self, seed=None):
@@ -36,10 +38,21 @@ class Tuple(Space):
def to_jsonable(self, sample_n):
# serialize as list-repr of tuple of vectors
return [space.to_jsonable([sample[i] for sample in sample_n]) for i, space in enumerate(self.spaces)]
return [
space.to_jsonable([sample[i] for sample in sample_n])
for i, space in enumerate(self.spaces)
]
def from_jsonable(self, sample_n):
return [sample for sample in zip(*[space.from_jsonable(sample_n[i]) for i, space in enumerate(self.spaces)])]
return [
sample
for sample in zip(
*[
space.from_jsonable(sample_n[i])
for i, space in enumerate(self.spaces)
]
)
]
def __getitem__(self, index):
return self.spaces[index]

View File

@@ -49,7 +49,9 @@ def flatten(space, x):
onehot[x] = 1
return onehot
elif isinstance(space, Tuple):
return np.concatenate([flatten(s, x_part) for x_part, s in zip(x, space.spaces)])
return np.concatenate(
[flatten(s, x_part) for x_part, s in zip(x, space.spaces)]
)
elif isinstance(space, Dict):
return np.concatenate([flatten(s, x[key]) for key, s in space.spaces.items()])
elif isinstance(space, MultiBinary):
@@ -77,13 +79,17 @@ def unflatten(space, x):
elif isinstance(space, Tuple):
dims = [flatdim(s) for s in space.spaces]
list_flattened = np.split(x, np.cumsum(dims)[:-1])
list_unflattened = [unflatten(s, flattened) for flattened, s in zip(list_flattened, space.spaces)]
list_unflattened = [
unflatten(s, flattened)
for flattened, s in zip(list_flattened, space.spaces)
]
return tuple(list_unflattened)
elif isinstance(space, Dict):
dims = [flatdim(s) for s in space.spaces.values()]
list_flattened = np.split(x, np.cumsum(dims)[:-1])
list_unflattened = [
(key, unflatten(s, flattened)) for flattened, (key, s) in zip(list_flattened, space.spaces.items())
(key, unflatten(s, flattened))
for flattened, (key, s) in zip(list_flattened, space.spaces.items())
]
return OrderedDict(list_unflattened)
elif isinstance(space, MultiBinary):
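
Note: a minimal round-trip sketch for the flatten/unflatten helpers rewrapped above:

    from gym import spaces
    from gym.spaces.utils import flatdim, flatten, unflatten

    space = spaces.Tuple((spaces.Discrete(3), spaces.Box(low=0.0, high=1.0, shape=(2,))))
    x = space.sample()
    flat = flatten(space, x)     # 1-D array of length flatdim(space) == 5
    x2 = unflatten(space, flat)  # back to a (discrete, box) tuple
    assert x2[0] == x[0]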

View File

@@ -88,7 +88,11 @@ def play(env, transpose=True, fps=30, zoom=None, callback=None, keys_to_action=N
elif hasattr(env.unwrapped, "get_keys_to_action"):
keys_to_action = env.unwrapped.get_keys_to_action()
else:
assert False, env.spec.id + " does not have explicit key to action mapping, " + "please specify one manually"
assert False, (
env.spec.id
+ " does not have explicit key to action mapping, "
+ "please specify one manually"
)
relevant_keys = set(sum(map(list, keys_to_action.keys()), []))
video_size = [rendered.shape[1], rendered.shape[0]]
@@ -168,7 +172,9 @@ class PlayPlot(object):
for i, plot in enumerate(self.cur_plot):
if plot is not None:
plot.remove()
self.cur_plot[i] = self.ax[i].scatter(range(xmin, xmax), list(self.data[i]), c="blue")
self.cur_plot[i] = self.ax[i].scatter(
range(xmin, xmax), list(self.data[i]), c="blue"
)
self.ax[i].set_xlim(xmin, xmax)
plt.pause(0.000001)

View File

@@ -10,7 +10,9 @@ from gym import error
def np_random(seed=None):
if seed is not None and not (isinstance(seed, int) and 0 <= seed):
raise error.Error("Seed must be a non-negative integer or omitted, not {}".format(seed))
raise error.Error(
"Seed must be a non-negative integer or omitted, not {}".format(seed)
)
seed = create_seed(seed)
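
Note: a minimal sketch of the seeding helper whose error message was rewrapped above:

    from gym.utils import seeding

    rng, seed = seeding.np_random(42)  # rng is a np.random.RandomState
    print(seed, rng.uniform())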

View File

@@ -53,7 +53,9 @@ def make(id, num_envs=1, asynchronous=True, wrappers=None, **kwargs):
if wrappers is not None:
if callable(wrappers):
env = wrappers(env)
elif isinstance(wrappers, Iterable) and all([callable(w) for w in wrappers]):
elif isinstance(wrappers, Iterable) and all(
[callable(w) for w in wrappers]
):
for wrapper in wrappers:
env = wrapper(env)
else:
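
Note: a minimal usage sketch for gym.vector.make as reformatted above:

    import gym

    envs = gym.vector.make("CartPole-v1", num_envs=3, asynchronous=False)
    observations = envs.reset()  # batched, shape (3, 4) for CartPole
    observations, rewards, dones, infos = envs.step(envs.action_space.sample())
    envs.close()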

View File

@@ -107,8 +107,12 @@ class AsyncVectorEnv(VectorEnv):
if self.shared_memory:
try:
_obs_buffer = create_shared_memory(self.single_observation_space, n=self.num_envs, ctx=ctx)
self.observations = read_from_shared_memory(_obs_buffer, self.single_observation_space, n=self.num_envs)
_obs_buffer = create_shared_memory(
self.single_observation_space, n=self.num_envs, ctx=ctx
)
self.observations = read_from_shared_memory(
_obs_buffer, self.single_observation_space, n=self.num_envs
)
except CustomSpaceError:
raise ValueError(
"Using `shared_memory=True` in `AsyncVectorEnv` "
@@ -120,7 +124,9 @@ class AsyncVectorEnv(VectorEnv):
)
else:
_obs_buffer = None
self.observations = create_empty_array(self.single_observation_space, n=self.num_envs, fn=np.zeros)
self.observations = create_empty_array(
self.single_observation_space, n=self.num_envs, fn=np.zeros
)
self.parent_pipes, self.processes = [], []
self.error_queue = ctx.Queue()
@@ -162,7 +168,8 @@ class AsyncVectorEnv(VectorEnv):
if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError(
"Calling `seed` while waiting " "for a pending call to `{0}` to complete.".format(self._state.value),
"Calling `seed` while waiting "
"for a pending call to `{0}` to complete.".format(self._state.value),
self._state.value,
)
@@ -175,7 +182,8 @@ class AsyncVectorEnv(VectorEnv):
self._assert_is_running()
if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError(
"Calling `reset_async` while waiting " "for a pending call to `{0}` to complete".format(self._state.value),
"Calling `reset_async` while waiting "
"for a pending call to `{0}` to complete".format(self._state.value),
self._state.value,
)
@@ -206,7 +214,8 @@ class AsyncVectorEnv(VectorEnv):
if not self._poll(timeout):
self._state = AsyncState.DEFAULT
raise mp.TimeoutError(
"The call to `reset_wait` has timed out after " "{0} second{1}.".format(timeout, "s" if timeout > 1 else "")
"The call to `reset_wait` has timed out after "
"{0} second{1}.".format(timeout, "s" if timeout > 1 else "")
)
results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
@@ -214,7 +223,9 @@ class AsyncVectorEnv(VectorEnv):
self._state = AsyncState.DEFAULT
if not self.shared_memory:
self.observations = concatenate(results, self.observations, self.single_observation_space)
self.observations = concatenate(
results, self.observations, self.single_observation_space
)
return deepcopy(self.observations) if self.copy else self.observations
@@ -228,7 +239,8 @@ class AsyncVectorEnv(VectorEnv):
self._assert_is_running()
if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError(
"Calling `step_async` while waiting " "for a pending call to `{0}` to complete.".format(self._state.value),
"Calling `step_async` while waiting "
"for a pending call to `{0}` to complete.".format(self._state.value),
self._state.value,
)
@@ -268,7 +280,8 @@ class AsyncVectorEnv(VectorEnv):
if not self._poll(timeout):
self._state = AsyncState.DEFAULT
raise mp.TimeoutError(
"The call to `step_wait` has timed out after " "{0} second{1}.".format(timeout, "s" if timeout > 1 else "")
"The call to `step_wait` has timed out after "
"{0} second{1}.".format(timeout, "s" if timeout > 1 else "")
)
results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
@@ -277,7 +290,9 @@ class AsyncVectorEnv(VectorEnv):
observations_list, rewards, dones, infos = zip(*results)
if not self.shared_memory:
self.observations = concatenate(observations_list, self.observations, self.single_observation_space)
self.observations = concatenate(
observations_list, self.observations, self.single_observation_space
)
return (
deepcopy(self.observations) if self.copy else self.observations,
@@ -303,7 +318,8 @@ class AsyncVectorEnv(VectorEnv):
try:
if self._state != AsyncState.DEFAULT:
logger.warn(
"Calling `close` while waiting for a pending " "call to `{0}` to complete.".format(self._state.value)
"Calling `close` while waiting for a pending "
"call to `{0}` to complete.".format(self._state.value)
)
function = getattr(self, "{0}_wait".format(self._state.value))
function(timeout)
@@ -359,7 +375,8 @@ class AsyncVectorEnv(VectorEnv):
def _assert_is_running(self):
if self.closed:
raise ClosedEnvironmentError(
"Trying to operate on `{0}`, after a " "call to `close()`.".format(type(self).__name__)
"Trying to operate on `{0}`, after a "
"call to `close()`.".format(type(self).__name__)
)
def _raise_if_errors(self, successes):
@@ -370,7 +387,10 @@ class AsyncVectorEnv(VectorEnv):
assert num_errors > 0
for _ in range(num_errors):
index, exctype, value = self.error_queue.get()
logger.error("Received the following error from Worker-{0}: " "{1}: {2}".format(index, exctype.__name__, value))
logger.error(
"Received the following error from Worker-{0}: "
"{1}: {2}".format(index, exctype.__name__, value)
)
logger.error("Shutting down Worker-{0}.".format(index))
self.parent_pipes[index].close()
self.parent_pipes[index] = None
@@ -425,13 +445,17 @@ def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error
command, data = pipe.recv()
if command == "reset":
observation = env.reset()
write_to_shared_memory(index, observation, shared_memory, observation_space)
write_to_shared_memory(
index, observation, shared_memory, observation_space
)
pipe.send((None, True))
elif command == "step":
observation, reward, done, info = env.step(data)
if done:
observation = env.reset()
write_to_shared_memory(index, observation, shared_memory, observation_space)
write_to_shared_memory(
index, observation, shared_memory, observation_space
)
pipe.send(((None, reward, done, info), True))
elif command == "seed":
env.seed(data)
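
Note: a minimal sketch of constructing AsyncVectorEnv directly with shared memory, the path exercised by the worker above:

    import gym
    from gym.vector import AsyncVectorEnv

    envs = AsyncVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(4)],
                          shared_memory=True)
    observations = envs.reset()
    envs.close()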

View File

@@ -44,7 +44,9 @@ class SyncVectorEnv(VectorEnv):
)
self._check_observation_spaces()
self.observations = create_empty_array(self.single_observation_space, n=self.num_envs, fn=np.zeros)
self.observations = create_empty_array(
self.single_observation_space, n=self.num_envs, fn=np.zeros
)
self._rewards = np.zeros((self.num_envs,), dtype=np.float64)
self._dones = np.zeros((self.num_envs,), dtype=np.bool_)
self._actions = None
@@ -65,7 +67,9 @@ class SyncVectorEnv(VectorEnv):
for env in self.envs:
observation = env.reset()
observations.append(observation)
self.observations = concatenate(observations, self.observations, self.single_observation_space)
self.observations = concatenate(
observations, self.observations, self.single_observation_space
)
return deepcopy(self.observations) if self.copy else self.observations
@@ -80,7 +84,9 @@ class SyncVectorEnv(VectorEnv):
observation = env.reset()
observations.append(observation)
infos.append(info)
self.observations = concatenate(observations, self.observations, self.single_observation_space)
self.observations = concatenate(
observations, self.observations, self.single_observation_space
)
return (
deepcopy(self.observations) if self.copy else self.observations,

View File

@@ -10,7 +10,9 @@ from gym.vector.tests.utils import spaces
from gym.vector.utils.numpy_utils import concatenate, create_empty_array
@pytest.mark.parametrize("space", spaces, ids=[space.__class__.__name__ for space in spaces])
@pytest.mark.parametrize(
"space", spaces, ids=[space.__class__.__name__ for space in spaces]
)
def test_concatenate(space):
def assert_type(lhs, rhs, n):
# Special case: if rhs is a list of scalars, lhs must be an np.ndarray
@@ -51,7 +53,9 @@ def test_concatenate(space):
@pytest.mark.parametrize("n", [1, 8])
@pytest.mark.parametrize("space", spaces, ids=[space.__class__.__name__ for space in spaces])
@pytest.mark.parametrize(
"space", spaces, ids=[space.__class__.__name__ for space in spaces]
)
def test_create_empty_array(space, n):
def assert_nested_type(arr, space, n):
if isinstance(space, _BaseGymSpaces):
@@ -79,7 +83,9 @@ def test_create_empty_array(space, n):
@pytest.mark.parametrize("n", [1, 8])
@pytest.mark.parametrize("space", spaces, ids=[space.__class__.__name__ for space in spaces])
@pytest.mark.parametrize(
"space", spaces, ids=[space.__class__.__name__ for space in spaces]
)
def test_create_empty_array_zeros(space, n):
def assert_nested_type(arr, space, n):
if isinstance(space, _BaseGymSpaces):
@@ -107,7 +113,9 @@ def test_create_empty_array_zeros(space, n):
assert_nested_type(array, space, n=n)
@pytest.mark.parametrize("space", spaces, ids=[space.__class__.__name__ for space in spaces])
@pytest.mark.parametrize(
"space", spaces, ids=[space.__class__.__name__ for space in spaces]
)
def test_create_empty_array_none_shape_ones(space):
def assert_nested_type(arr, space):
if isinstance(space, _BaseGymSpaces):

View File

@@ -46,7 +46,9 @@ expected_types = [
list(zip(spaces, expected_types)),
ids=[space.__class__.__name__ for space in spaces],
)
@pytest.mark.parametrize("ctx", [None, "fork", "spawn"], ids=["default", "fork", "spawn"])
@pytest.mark.parametrize(
"ctx", [None, "fork", "spawn"], ids=["default", "fork", "spawn"]
)
def test_create_shared_memory(space, expected_type, n, ctx):
def assert_nested_type(lhs, rhs, n):
assert type(lhs) == type(rhs)
@@ -75,7 +77,9 @@ def test_create_shared_memory(space, expected_type, n, ctx):
@pytest.mark.parametrize("n", [1, 8])
@pytest.mark.parametrize("ctx", [None, "fork", "spawn"], ids=["default", "fork", "spawn"])
@pytest.mark.parametrize(
"ctx", [None, "fork", "spawn"], ids=["default", "fork", "spawn"]
)
@pytest.mark.parametrize("space", custom_spaces)
def test_create_shared_memory_custom_space(n, ctx, space):
ctx = mp if (ctx is None) else mp.get_context(ctx)
@@ -83,7 +87,9 @@ def test_create_shared_memory_custom_space(n, ctx, space):
shared_memory = create_shared_memory(space, n=n, ctx=ctx)
@pytest.mark.parametrize("space", spaces, ids=[space.__class__.__name__ for space in spaces])
@pytest.mark.parametrize(
"space", spaces, ids=[space.__class__.__name__ for space in spaces]
)
def test_write_to_shared_memory(space):
def assert_nested_equal(lhs, rhs):
assert isinstance(rhs, list)
@@ -107,7 +113,9 @@ def test_write_to_shared_memory(space):
shared_memory_n8 = create_shared_memory(space, n=8)
samples = [space.sample() for _ in range(8)]
processes = [Process(target=write, args=(i, shared_memory_n8, samples[i])) for i in range(8)]
processes = [
Process(target=write, args=(i, shared_memory_n8, samples[i])) for i in range(8)
]
for process in processes:
process.start()
@@ -117,19 +125,25 @@ def test_write_to_shared_memory(space):
assert_nested_equal(shared_memory_n8, samples)
@pytest.mark.parametrize("space", spaces, ids=[space.__class__.__name__ for space in spaces])
@pytest.mark.parametrize(
"space", spaces, ids=[space.__class__.__name__ for space in spaces]
)
def test_read_from_shared_memory(space):
def assert_nested_equal(lhs, rhs, space, n):
assert isinstance(rhs, list)
if isinstance(space, Tuple):
assert isinstance(lhs, tuple)
for i in range(len(lhs)):
assert_nested_equal(lhs[i], [rhs_[i] for rhs_ in rhs], space.spaces[i], n)
assert_nested_equal(
lhs[i], [rhs_[i] for rhs_ in rhs], space.spaces[i], n
)
elif isinstance(space, Dict):
assert isinstance(lhs, OrderedDict)
for key in lhs.keys():
assert_nested_equal(lhs[key], [rhs_[key] for rhs_ in rhs], space.spaces[key], n)
assert_nested_equal(
lhs[key], [rhs_[key] for rhs_ in rhs], space.spaces[key], n
)
elif isinstance(space, _BaseGymSpaces):
assert isinstance(lhs, np.ndarray)
@@ -147,7 +161,9 @@ def test_read_from_shared_memory(space):
memory_view_n8 = read_from_shared_memory(shared_memory_n8, space, n=8)
samples = [space.sample() for _ in range(8)]
processes = [Process(target=write, args=(i, shared_memory_n8, samples[i])) for i in range(8)]
processes = [
Process(target=write, args=(i, shared_memory_n8, samples[i])) for i in range(8)
]
for process in processes:
process.start()

View File

@@ -10,8 +10,12 @@ expected_batch_spaces_4 = [
Box(low=-1.0, high=1.0, shape=(4,), dtype=np.float64),
Box(low=0.0, high=10.0, shape=(4, 1), dtype=np.float32),
Box(
low=np.array([[-1.0, 0.0, 0.0], [-1.0, 0.0, 0.0], [-1.0, 0.0, 0.0], [-1.0, 0.0, 0.0]]),
high=np.array([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]),
low=np.array(
[[-1.0, 0.0, 0.0], [-1.0, 0.0, 0.0], [-1.0, 0.0, 0.0], [-1.0, 0.0, 0.0]]
),
high=np.array(
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
),
dtype=np.float32,
),
Box(

View File

@@ -7,8 +7,12 @@ from gym.spaces import Box, Discrete, MultiDiscrete, MultiBinary, Tuple, Dict
spaces = [
Box(low=np.array(-1.0), high=np.array(1.0), dtype=np.float64),
Box(low=np.array([0.0]), high=np.array([10.0]), dtype=np.float32),
Box(low=np.array([-1.0, 0.0, 0.0]), high=np.array([1.0, 1.0, 1.0]), dtype=np.float32),
Box(low=np.array([[-1.0, 0.0], [0.0, -1.0]]), high=np.ones((2, 2)), dtype=np.float32),
Box(
low=np.array([-1.0, 0.0, 0.0]), high=np.array([1.0, 1.0, 1.0]), dtype=np.float32
),
Box(
low=np.array([[-1.0, 0.0], [0.0, -1.0]]), high=np.ones((2, 2)), dtype=np.float32
),
Box(low=0, high=255, shape=(), dtype=np.uint8),
Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8),
Discrete(2),
@@ -24,13 +28,17 @@ spaces = [
Dict(
{
"position": Discrete(23),
"velocity": Box(low=np.array([0.0]), high=np.array([1.0]), dtype=np.float32),
"velocity": Box(
low=np.array([0.0]), high=np.array([1.0]), dtype=np.float32
),
}
),
Dict(
{
"position": Dict({"x": Discrete(29), "y": Discrete(31)}),
"velocity": Tuple((Discrete(37), Box(low=0, high=255, shape=(), dtype=np.uint8))),
"velocity": Tuple(
(Discrete(37), Box(low=0, high=255, shape=(), dtype=np.uint8))
),
}
),
]
@@ -42,7 +50,9 @@ class UnittestSlowEnv(gym.Env):
def __init__(self, slow_reset=0.3):
super(UnittestSlowEnv, self).__init__()
self.slow_reset = slow_reset
self.observation_space = Box(low=0, high=255, shape=(HEIGHT, WIDTH, 3), dtype=np.uint8)
self.observation_space = Box(
low=0, high=255, shape=(HEIGHT, WIDTH, 3), dtype=np.uint8
)
self.action_space = Box(low=0.0, high=1.0, shape=(), dtype=np.float32)
def reset(self):

View File

@@ -46,7 +46,10 @@ def concatenate(items, out, space):
elif isinstance(space, Space):
return concatenate_custom(items, out, space)
else:
raise ValueError("Space of type `{0}` is not a valid `gym.Space` " "instance.".format(type(space)))
raise ValueError(
"Space of type `{0}` is not a valid `gym.Space` "
"instance.".format(type(space))
)
def concatenate_base(items, out, space):
@@ -54,12 +57,18 @@ def concatenate_base(items, out, space):
def concatenate_tuple(items, out, space):
return tuple(concatenate([item[i] for item in items], out[i], subspace) for (i, subspace) in enumerate(space.spaces))
return tuple(
concatenate([item[i] for item in items], out[i], subspace)
for (i, subspace) in enumerate(space.spaces)
)
def concatenate_dict(items, out, space):
return OrderedDict(
[(key, concatenate([item[key] for item in items], out[key], subspace)) for (key, subspace) in space.spaces.items()]
[
(key, concatenate([item[key] for item in items], out[key], subspace))
for (key, subspace) in space.spaces.items()
]
)
@@ -109,7 +118,10 @@ def create_empty_array(space, n=1, fn=np.zeros):
elif isinstance(space, Space):
return create_empty_array_custom(space, n=n, fn=fn)
else:
raise ValueError("Space of type `{0}` is not a valid `gym.Space` " "instance.".format(type(space)))
raise ValueError(
"Space of type `{0}` is not a valid `gym.Space` "
"instance.".format(type(space))
)
def create_empty_array_base(space, n=1, fn=np.zeros):
@@ -122,7 +134,12 @@ def create_empty_array_tuple(space, n=1, fn=np.zeros):
def create_empty_array_dict(space, n=1, fn=np.zeros):
return OrderedDict([(key, create_empty_array(subspace, n=n, fn=fn)) for (key, subspace) in space.spaces.items()])
return OrderedDict(
[
(key, create_empty_array(subspace, n=n, fn=fn))
for (key, subspace) in space.spaces.items()
]
)
def create_empty_array_custom(space, n=1, fn=np.zeros):
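
For illustration, a minimal sketch of how the two helpers above compose (a sketch, assuming gym.vector.utils re-exports concatenate and create_empty_array):

import numpy as np
from gym.spaces import Box
from gym.vector.utils import concatenate, create_empty_array

space = Box(low=0.0, high=1.0, shape=(3,), dtype=np.float32)
out = create_empty_array(space, n=4)          # np.zeros((4, 3), dtype=float32)
samples = [space.sample() for _ in range(4)]
batch = concatenate(samples, out, space)      # stacks the samples into `out`
assert batch.shape == (4, 3)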

View File

@@ -56,11 +56,18 @@ def create_base_shared_memory(space, n=1, ctx=mp):
def create_tuple_shared_memory(space, n=1, ctx=mp):
return tuple(create_shared_memory(subspace, n=n, ctx=ctx) for subspace in space.spaces)
return tuple(
create_shared_memory(subspace, n=n, ctx=ctx) for subspace in space.spaces
)
def create_dict_shared_memory(space, n=1, ctx=mp):
return OrderedDict([(key, create_shared_memory(subspace, n=n, ctx=ctx)) for (key, subspace) in space.spaces.items()])
return OrderedDict(
[
(key, create_shared_memory(subspace, n=n, ctx=ctx))
for (key, subspace) in space.spaces.items()
]
)
def read_from_shared_memory(shared_memory, space, n=1):
@@ -107,16 +114,24 @@ def read_from_shared_memory(shared_memory, space, n=1):
def read_base_from_shared_memory(shared_memory, space, n=1):
return np.frombuffer(shared_memory.get_obj(), dtype=space.dtype).reshape((n,) + space.shape)
return np.frombuffer(shared_memory.get_obj(), dtype=space.dtype).reshape(
(n,) + space.shape
)
def read_tuple_from_shared_memory(shared_memory, space, n=1):
return tuple(read_from_shared_memory(memory, subspace, n=n) for (memory, subspace) in zip(shared_memory, space.spaces))
return tuple(
read_from_shared_memory(memory, subspace, n=n)
for (memory, subspace) in zip(shared_memory, space.spaces)
)
def read_dict_from_shared_memory(shared_memory, space, n=1):
return OrderedDict(
[(key, read_from_shared_memory(shared_memory[key], subspace, n=n)) for (key, subspace) in space.spaces.items()]
[
(key, read_from_shared_memory(shared_memory[key], subspace, n=n))
for (key, subspace) in space.spaces.items()
]
)
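
A hedged sketch of the shared-memory round trip defined above (again assuming gym.vector.utils re-exports these helpers):

import multiprocessing as mp
import numpy as np
from gym.spaces import Box
from gym.vector.utils import create_shared_memory, read_from_shared_memory

space = Box(low=0, high=255, shape=(2, 2), dtype=np.uint8)
shared = create_shared_memory(space, n=3, ctx=mp)    # one mp.Array per leaf space
batch = read_from_shared_memory(shared, space, n=3)  # (3, 2, 2) ndarray view, no copy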

View File

@@ -44,7 +44,8 @@ def batch_space(space, n=1):
return batch_space_custom(space, n=n)
else:
raise ValueError(
"Cannot batch space with type `{0}`. The space must " "be a valid `gym.Space` instance.".format(type(space))
"Cannot batch space with type `{0}`. The space must "
"be a valid `gym.Space` instance.".format(type(space))
)
@@ -74,7 +75,14 @@ def batch_space_tuple(space, n=1):
def batch_space_dict(space, n=1):
return Dict(OrderedDict([(key, batch_space(subspace, n=n)) for (key, subspace) in space.spaces.items()]))
return Dict(
OrderedDict(
[
(key, batch_space(subspace, n=n))
for (key, subspace) in space.spaces.items()
]
)
)
def batch_space_custom(space, n=1):
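
For example, batching a Dict space with the dispatcher above (a sketch; the exact reprs may differ by version):

import numpy as np
from gym.spaces import Box, Dict, Discrete
from gym.vector.utils import batch_space

space = Dict({"position": Discrete(23),
              "velocity": Box(low=0.0, high=1.0, shape=(1,), dtype=np.float32)})
batched = batch_space(space, n=4)
# position becomes MultiDiscrete([23 23 23 23]), velocity becomes Box(4, 1)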

View File

@@ -141,7 +141,9 @@ class VectorEnv(gym.Env):
if self.spec is None:
return "{}({})".format(self.__class__.__name__, self.num_envs)
else:
return "{}({}, {})".format(self.__class__.__name__, self.spec.id, self.num_envs)
return "{}({}, {})".format(
self.__class__.__name__, self.spec.id, self.num_envs
)
class VectorEnvWrapper(VectorEnv):
@@ -187,7 +189,9 @@ class VectorEnvWrapper(VectorEnv):
# implicitly forward all other methods and attributes to self.env
def __getattr__(self, name):
if name.startswith("_"):
raise AttributeError("attempted to get missing private attribute '{}'".format(name))
raise AttributeError(
"attempted to get missing private attribute '{}'".format(name)
)
return getattr(self.env, name)
@property
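
The __repr__ above renders as, for instance (a sketch, assuming gym.vector.make is available in this version):

import gym.vector

envs = gym.vector.make("CartPole-v1", num_envs=3, asynchronous=False)
print(envs)  # SyncVectorEnv(CartPole-v1, 3)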

View File

@@ -62,7 +62,8 @@ class AtariPreprocessing(gym.Wrapper):
assert noop_max >= 0
if frame_skip > 1:
assert "NoFrameskip" in env.spec.id, (
"disable frame-skipping in the original env. for more than one" " frame-skip as it will be done by the wrapper"
"disable frame-skipping in the original env. for more than one"
" frame-skip as it will be done by the wrapper"
)
self.noop_max = noop_max
assert env.unwrapped.get_action_meanings()[0] == "NOOP"
@@ -93,11 +94,15 @@ class AtariPreprocessing(gym.Wrapper):
self.lives = 0
self.game_over = False
_low, _high, _obs_dtype = (0, 255, np.uint8) if not scale_obs else (0, 1, np.float32)
_low, _high, _obs_dtype = (
(0, 255, np.uint8) if not scale_obs else (0, 1, np.float32)
)
_shape = (screen_size, screen_size, 1 if grayscale_obs else 3)
if grayscale_obs and not grayscale_newaxis:
_shape = _shape[:-1] # Remove channel axis
self.observation_space = Box(low=_low, high=_high, shape=_shape, dtype=_obs_dtype)
self.observation_space = Box(
low=_low, high=_high, shape=_shape, dtype=_obs_dtype
)
def step(self, action):
R = 0.0
@@ -129,7 +134,11 @@ class AtariPreprocessing(gym.Wrapper):
def reset(self, **kwargs):
# NoopReset
self.env.reset(**kwargs)
noops = self.env.unwrapped.np_random.randint(1, self.noop_max + 1) if self.noop_max > 0 else 0
noops = (
self.env.unwrapped.np_random.randint(1, self.noop_max + 1)
if self.noop_max > 0
else 0
)
for _ in range(noops):
_, _, done, _ = self.env.step(0)
if done:
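
A hedged usage sketch for the wrapper above (assumes atari_py is installed; frame_skip defaults to 4, so a NoFrameskip ROM id is required):

import gym
from gym.wrappers import AtariPreprocessing

env = AtariPreprocessing(gym.make("PongNoFrameskip-v0"), screen_size=84, grayscale_obs=True)
obs = env.reset()  # (84, 84) uint8 grayscale frame, after the noop reset shown above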

View File

@@ -51,7 +51,11 @@ class FilterObservation(ObservationWrapper):
)
self.observation_space = type(wrapped_observation_space)(
[(name, copy.deepcopy(space)) for name, space in wrapped_observation_space.spaces.items() if name in filter_keys]
[
(name, copy.deepcopy(space))
for name, space in wrapped_observation_space.spaces.items()
if name in filter_keys
]
)
self._env = env
@@ -62,5 +66,11 @@ class FilterObservation(ObservationWrapper):
return filter_observation
def _filter_observation(self, observation):
observation = type(observation)([(name, value) for name, value in observation.items() if name in self._filter_keys])
observation = type(observation)(
[
(name, value)
for name, value in observation.items()
if name in self._filter_keys
]
)
return observation
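
A minimal sketch of filtering a Dict observation space (FetchReach-v1 is an assumption here and needs mujoco_py; any env with a Dict observation space works):

import gym
from gym.wrappers import FilterObservation

env = FilterObservation(gym.make("FetchReach-v1"), filter_keys=["observation"])
obs = env.reset()  # dict observation containing only the "observation" key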

View File

@@ -47,7 +47,9 @@ class LazyFrames(object):
def __getitem__(self, int_or_slice):
if isinstance(int_or_slice, int):
return self._check_decompress(self._frames[int_or_slice]) # single frame
return np.stack([self._check_decompress(f) for f in self._frames[int_or_slice]], axis=0)
return np.stack(
[self._check_decompress(f) for f in self._frames[int_or_slice]], axis=0
)
def __eq__(self, other):
return self.__array__() == other
@@ -56,7 +58,9 @@ class LazyFrames(object):
if self.lz4_compress:
from lz4.block import decompress
return np.frombuffer(decompress(frame), dtype=self.dtype).reshape(self.frame_shape)
return np.frombuffer(decompress(frame), dtype=self.dtype).reshape(
self.frame_shape
)
return frame
@@ -100,8 +104,12 @@ class FrameStack(Wrapper):
self.frames = deque(maxlen=num_stack)
low = np.repeat(self.observation_space.low[np.newaxis, ...], num_stack, axis=0)
high = np.repeat(self.observation_space.high[np.newaxis, ...], num_stack, axis=0)
self.observation_space = Box(low=low, high=high, dtype=self.observation_space.dtype)
high = np.repeat(
self.observation_space.high[np.newaxis, ...], num_stack, axis=0
)
self.observation_space = Box(
low=low, high=high, dtype=self.observation_space.dtype
)
def _get_observation(self):
assert len(self.frames) == self.num_stack, (len(self.frames), self.num_stack)
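
For illustration, a short FrameStack sketch (CartPole observations have shape (4,), so four stacked frames give (4, 4)):

import numpy as np
import gym
from gym.wrappers import FrameStack

env = FrameStack(gym.make("CartPole-v1"), num_stack=4)
obs = env.reset()          # LazyFrames holding num_stack copies of the first frame
stacked = np.asarray(obs)  # materializes a (4, 4) array via __array__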

View File

@@ -11,15 +11,22 @@ class GrayScaleObservation(ObservationWrapper):
super(GrayScaleObservation, self).__init__(env)
self.keep_dim = keep_dim
assert len(env.observation_space.shape) == 3 and env.observation_space.shape[-1] == 3
assert (
len(env.observation_space.shape) == 3
and env.observation_space.shape[-1] == 3
)
warnings.warn(
"Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit"
)
obs_shape = self.observation_space.shape[:2]
if self.keep_dim:
self.observation_space = Box(low=0, high=255, shape=(obs_shape[0], obs_shape[1], 1), dtype=np.uint8)
self.observation_space = Box(
low=0, high=255, shape=(obs_shape[0], obs_shape[1], 1), dtype=np.uint8
)
else:
self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype=np.uint8)
self.observation_space = Box(
low=0, high=255, shape=obs_shape, dtype=np.uint8
)
def observation(self, observation):
import cv2
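
A hedged usage sketch (CarRacing-v0 is an assumption and needs Box2D; cv2 performs the conversion at observation time):

import gym
from gym.wrappers import GrayScaleObservation

env = GrayScaleObservation(gym.make("CarRacing-v0"), keep_dim=True)
obs = env.reset()  # (96, 96, 1) uint8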

View File

@@ -37,7 +37,9 @@ class Monitor(Wrapper):
self._monitor_id = None
self.env_semantics_autoreset = env.metadata.get("semantics.autoreset")
self._start(directory, video_callable, force, resume, write_upon_reset, uid, mode)
self._start(
directory, video_callable, force, resume, write_upon_reset, uid, mode
)
def step(self, action):
self._before_step(action)
@@ -161,7 +163,10 @@ class Monitor(Wrapper):
json.dump(
{
"stats": os.path.basename(self.stats_recorder.path),
"videos": [(os.path.basename(v), os.path.basename(m)) for v, m in self.videos],
"videos": [
(os.path.basename(v), os.path.basename(m))
for v, m in self.videos
],
"env_info": self._env_info(),
},
f,
@@ -194,7 +199,9 @@ class Monitor(Wrapper):
elif mode == "training":
type = "t"
else:
raise error.Error('Invalid mode {}: must be "training" or "evaluation"', mode)
raise error.Error(
'Invalid mode {}: must be "training" or "evaluation"', mode
)
self.stats_recorder.type = type
def _before_step(self, action):
@@ -250,7 +257,9 @@ class Monitor(Wrapper):
env=self.env,
base_path=os.path.join(
self.directory,
"{}.video.{}.video{:06}".format(self.file_prefix, self.file_infix, self.episode_id),
"{}.video.{}.video{:06}".format(
self.file_prefix, self.file_infix, self.episode_id
),
),
metadata={"episode_id": self.episode_id},
enabled=self._video_enabled(),
@@ -260,7 +269,9 @@ class Monitor(Wrapper):
def _close_video_recorder(self):
self.video_recorder.close()
if self.video_recorder.functional:
self.videos.append((self.video_recorder.path, self.video_recorder.metadata_path))
self.videos.append(
(self.video_recorder.path, self.video_recorder.metadata_path)
)
def _video_enabled(self):
return self.video_callable(self.episode_id)
@@ -290,11 +301,19 @@ class Monitor(Wrapper):
def detect_training_manifests(training_dir, files=None):
if files is None:
files = os.listdir(training_dir)
return [os.path.join(training_dir, f) for f in files if f.startswith(MANIFEST_PREFIX + ".")]
return [
os.path.join(training_dir, f)
for f in files
if f.startswith(MANIFEST_PREFIX + ".")
]
def detect_monitor_files(training_dir):
return [os.path.join(training_dir, f) for f in os.listdir(training_dir) if f.startswith(FILE_PREFIX + ".")]
return [
os.path.join(training_dir, f)
for f in os.listdir(training_dir)
if f.startswith(FILE_PREFIX + ".")
]
def clear_monitor_files(training_dir):
@@ -363,7 +382,10 @@ def load_results(training_dir):
contents = json.load(f)
# Make these paths absolute again
stats_files.append(os.path.join(training_dir, contents["stats"]))
videos += [(os.path.join(training_dir, v), os.path.join(training_dir, m)) for v, m in contents["videos"]]
videos += [
(os.path.join(training_dir, v), os.path.join(training_dir, m))
for v, m in contents["videos"]
]
env_infos.append(contents["env_info"])
env_info = collapse_env_infos(env_infos, training_dir)
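
A minimal Monitor sketch (the directory path is illustrative; force=True clears results from earlier runs):

import gym
from gym.wrappers import Monitor

env = Monitor(gym.make("CartPole-v1"), directory="/tmp/cartpole-monitor", force=True)
env.reset()
env.step(env.action_space.sample())
env.close()  # finalizes the manifests, stats and videos that load_results above reads back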

View File

@@ -48,7 +48,9 @@ class VideoRecorder(object):
self.ansi_mode = True
else:
logger.info(
'Disabling video recorder because {} neither supports video mode "rgb_array" nor "ansi".'.format(env)
'Disabling video recorder because {} neither supports video mode "rgb_array" nor "ansi".'.format(
env
)
)
# Whoops, turns out we shouldn't be enabled after all
self.enabled = False
@@ -67,7 +69,9 @@ class VideoRecorder(object):
path = base_path + required_ext
else:
# Otherwise, just generate a unique filename
with tempfile.NamedTemporaryFile(suffix=required_ext, delete=False) as f:
with tempfile.NamedTemporaryFile(
suffix=required_ext, delete=False
) as f:
path = f.name
self.path = path
@@ -79,20 +83,28 @@ class VideoRecorder(object):
if self.ansi_mode
else ""
)
raise error.Error("Invalid path given: {} -- must have file extension {}.{}".format(self.path, required_ext, hint))
raise error.Error(
"Invalid path given: {} -- must have file extension {}.{}".format(
self.path, required_ext, hint
)
)
# Touch the file in any case, so we know it's present. (This
# corrects for platform differences: using ffmpeg on
# OS X, the file is precreated, but not on Linux.)
touch(path)
self.frames_per_sec = env.metadata.get("video.frames_per_second", 30)
self.output_frames_per_sec = env.metadata.get("video.output_frames_per_second", self.frames_per_sec)
self.output_frames_per_sec = env.metadata.get(
"video.output_frames_per_second", self.frames_per_sec
)
self.encoder = None # lazily start the process
self.broken = False
# Dump metadata
self.metadata = metadata or {}
self.metadata["content_type"] = "video/vnd.openai.ansivid" if self.ansi_mode else "video/mp4"
self.metadata["content_type"] = (
"video/vnd.openai.ansivid" if self.ansi_mode else "video/mp4"
)
self.metadata_path = "{}.meta.json".format(path_base)
self.write_metadata()
@@ -179,7 +191,9 @@ class VideoRecorder(object):
def _encode_image_frame(self, frame):
if not self.encoder:
self.encoder = ImageEncoder(self.path, frame.shape, self.frames_per_sec, self.output_frames_per_sec)
self.encoder = ImageEncoder(
self.path, frame.shape, self.frames_per_sec, self.output_frames_per_sec
)
self.metadata["encoder_version"] = self.encoder.version_info
try:
@@ -208,16 +222,24 @@ class TextEncoder(object):
string = frame.getvalue()
else:
raise error.InvalidFrame(
"Wrong type {} for {}: text frame must be a string or StringIO".format(type(frame), frame)
"Wrong type {} for {}: text frame must be a string or StringIO".format(
type(frame), frame
)
)
frame_bytes = string.encode("utf-8")
if frame_bytes[-1:] != b"\n":
raise error.InvalidFrame('Frame must end with a newline: """{}"""'.format(string))
raise error.InvalidFrame(
'Frame must end with a newline: """{}"""'.format(string)
)
if b"\r" in frame_bytes:
raise error.InvalidFrame('Frame contains carriage returns (only newlines are allowed): """{}"""'.format(string))
raise error.InvalidFrame(
'Frame contains carriage returns (only newlines are allowed): """{}"""'.format(
string
)
)
self.frames.append(frame_bytes)
@@ -241,7 +263,15 @@ class TextEncoder(object):
# Calculate frame size from the largest frames.
# Add some padding since we'll get cut off otherwise.
height = max([frame.count(b"\n") for frame in self.frames]) + 1
width = max([max([len(line) for line in frame.split(b"\n")]) for frame in self.frames]) + 2
width = (
max(
[
max([len(line) for line in frame.split(b"\n")])
for frame in self.frames
]
)
+ 2
)
data = {
"version": 1,
@@ -295,7 +325,11 @@ class ImageEncoder(object):
def version_info(self):
return {
"backend": self.backend,
"version": str(subprocess.check_output([self.backend, "-version"], stderr=subprocess.STDOUT)),
"version": str(
subprocess.check_output(
[self.backend, "-version"], stderr=subprocess.STDOUT
)
),
"cmdline": self.cmdline,
}
@@ -362,13 +396,19 @@ class ImageEncoder(object):
logger.debug('Starting %s with "%s"', self.backend, " ".join(self.cmdline))
if hasattr(os, "setsid"): # setsid not present on Windows
self.proc = subprocess.Popen(self.cmdline, stdin=subprocess.PIPE, preexec_fn=os.setsid)
self.proc = subprocess.Popen(
self.cmdline, stdin=subprocess.PIPE, preexec_fn=os.setsid
)
else:
self.proc = subprocess.Popen(self.cmdline, stdin=subprocess.PIPE)
def capture_frame(self, frame):
if not isinstance(frame, (np.ndarray, np.generic)):
raise error.InvalidFrame("Wrong type {} for {} (must be np.ndarray or np.generic)".format(type(frame), frame))
raise error.InvalidFrame(
"Wrong type {} for {} (must be np.ndarray or np.generic)".format(
type(frame), frame
)
)
if frame.shape != self.frame_shape:
raise error.InvalidFrame(
"Your frame has shape {}, but the VideoRecorder is configured for shape {}.".format(
@@ -377,11 +417,15 @@ class ImageEncoder(object):
)
if frame.dtype != np.uint8:
raise error.InvalidFrame(
"Your frame has data type {}, but we require uint8 (i.e. RGB values from 0-255).".format(frame.dtype)
"Your frame has data type {}, but we require uint8 (i.e. RGB values from 0-255).".format(
frame.dtype
)
)
try:
if distutils.version.LooseVersion(np.__version__) >= distutils.version.LooseVersion("1.9.0"):
if distutils.version.LooseVersion(
np.__version__
) >= distutils.version.LooseVersion("1.9.0"):
self.proc.stdin.write(frame.tobytes())
else:
self.proc.stdin.write(frame.tostring())
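
A hedged sketch of driving the recorder directly (assumes ffmpeg is on PATH; base_path is illustrative):

import gym
from gym.wrappers.monitoring.video_recorder import VideoRecorder

env = gym.make("CartPole-v1")
rec = VideoRecorder(env, base_path="/tmp/cartpole-demo")  # writes /tmp/cartpole-demo.mp4
env.reset()
rec.capture_frame()  # grabs one rgb_array frame and feeds it to the ImageEncoder
rec.close()
env.close()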

View File

@@ -14,7 +14,9 @@ STATE_KEY = "state"
class PixelObservationWrapper(ObservationWrapper):
"""Augment observations by pixel values."""
def __init__(self, env, pixels_only=True, render_kwargs=None, pixel_keys=("pixels",)):
def __init__(
self, env, pixels_only=True, render_kwargs=None, pixel_keys=("pixels",)
):
"""Initializes a new pixel Wrapper.
Args:
@@ -70,7 +72,9 @@ class PixelObservationWrapper(ObservationWrapper):
# `observation_keys`
overlapping_keys = set(pixel_keys) & set(invalid_keys)
if overlapping_keys:
raise ValueError("Duplicate or reserved pixel keys {!r}.".format(overlapping_keys))
raise ValueError(
"Duplicate or reserved pixel keys {!r}.".format(overlapping_keys)
)
if pixels_only:
self.observation_space = spaces.Dict()
@@ -93,7 +97,9 @@ class PixelObservationWrapper(ObservationWrapper):
else:
raise TypeError(pixels.dtype)
pixels_space = spaces.Box(shape=pixels.shape, low=low, high=high, dtype=pixels.dtype)
pixels_space = spaces.Box(
shape=pixels.shape, low=low, high=high, dtype=pixels.dtype
)
pixels_spaces[pixel_key] = pixels_space
self.observation_space.spaces.update(pixels_spaces)
@@ -116,7 +122,10 @@ class PixelObservationWrapper(ObservationWrapper):
observation = collections.OrderedDict()
observation[STATE_KEY] = wrapped_observation
pixel_observations = {pixel_key: self.env.render(**self._render_kwargs[pixel_key]) for pixel_key in self._pixel_keys}
pixel_observations = {
pixel_key: self.env.render(**self._render_kwargs[pixel_key])
for pixel_key in self._pixel_keys
}
observation.update(pixel_observations)
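
A minimal sketch of the wrapper above (Pendulum-v0 is an assumption; resetting before wrapping lets the first render size the pixel space):

import gym
from gym.wrappers.pixel_observation import PixelObservationWrapper

env = gym.make("Pendulum-v0")
env.reset()
wrapped = PixelObservationWrapper(env, pixels_only=False)
obs = wrapped.reset()  # OrderedDict with "state" and "pixels" entries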

View File

@@ -7,7 +7,9 @@ import gym
class RecordEpisodeStatistics(gym.Wrapper):
def __init__(self, env, deque_size=100):
super(RecordEpisodeStatistics, self).__init__(env)
self.t0 = time.time() # TODO: use perf_counter when gym removes Python 2 support
self.t0 = (
time.time()
) # TODO: use perf_counter when gym removes Python 2 support
self.episode_return = 0.0
self.episode_length = 0
self.return_queue = deque(maxlen=deque_size)
@@ -23,7 +25,9 @@ class RecordEpisodeStatistics(gym.Wrapper):
return observation
def step(self, action):
observation, reward, done, info = super(RecordEpisodeStatistics, self).step(action)
observation, reward, done, info = super(RecordEpisodeStatistics, self).step(
action
)
self.episode_return += reward
self.episode_length += 1
if done:
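
A short sketch of reading the recorded statistics (the "episode" info entry carries the keys "r", "l" and "t"):

import gym
from gym.wrappers import RecordEpisodeStatistics

env = RecordEpisodeStatistics(gym.make("CartPole-v1"))
obs = env.reset()
done = False
while not done:
    obs, reward, done, info = env.step(env.action_space.sample())
print(info["episode"])  # {"r": episode return, "l": episode length, "t": elapsed time}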

View File

@@ -15,7 +15,9 @@ class RescaleAction(gym.ActionWrapper):
"""
def __init__(self, env, a, b):
assert isinstance(env.action_space, spaces.Box), "expected Box action space, got {}".format(type(env.action_space))
assert isinstance(
env.action_space, spaces.Box
), "expected Box action space, got {}".format(type(env.action_space))
assert np.less_equal(a, b).all(), (a, b)
warnings.warn(
"Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit"
@@ -23,7 +25,9 @@ class RescaleAction(gym.ActionWrapper):
super(RescaleAction, self).__init__(env)
self.a = np.zeros(env.action_space.shape, dtype=env.action_space.dtype) + a
self.b = np.zeros(env.action_space.shape, dtype=env.action_space.dtype) + b
self.action_space = spaces.Box(low=a, high=b, shape=env.action_space.shape, dtype=env.action_space.dtype)
self.action_space = spaces.Box(
low=a, high=b, shape=env.action_space.shape, dtype=env.action_space.dtype
)
def action(self, action):
assert np.all(np.greater_equal(action, self.a)), (action, self.a)
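
For example, mapping Pendulum's native [-2, 2] action range onto [-1, 1] (a sketch):

import gym
from gym.wrappers import RescaleAction

env = RescaleAction(gym.make("Pendulum-v0"), a=-1.0, b=1.0)
env.reset()
env.step([0.5])  # rescaled to 1.0 in the underlying env's [-2, 2] range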

View File

@@ -23,7 +23,9 @@ class ResizeObservation(ObservationWrapper):
def observation(self, observation):
import cv2
observation = cv2.resize(observation, self.shape[::-1], interpolation=cv2.INTER_AREA)
observation = cv2.resize(
observation, self.shape[::-1], interpolation=cv2.INTER_AREA
)
if observation.ndim == 2:
observation = np.expand_dims(observation, -1)
return observation
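
A hedged usage sketch (CarRacing-v0 is an assumption and needs Box2D; an int shape is broadcast to a square):

import gym
from gym.wrappers import ResizeObservation

env = ResizeObservation(gym.make("CarRacing-v0"), shape=64)
obs = env.reset()  # (96, 96, 3) observations resized to (64, 64, 3) uint8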

View File

@@ -15,8 +15,12 @@ def test_atari_preprocessing_grayscale(env_fn):
import cv2
env1 = env_fn()
env2 = AtariPreprocessing(env_fn(), screen_size=84, grayscale_obs=True, frame_skip=1, noop_max=0)
env3 = AtariPreprocessing(env_fn(), screen_size=84, grayscale_obs=False, frame_skip=1, noop_max=0)
env2 = AtariPreprocessing(
env_fn(), screen_size=84, grayscale_obs=True, frame_skip=1, noop_max=0
)
env3 = AtariPreprocessing(
env_fn(), screen_size=84, grayscale_obs=False, frame_skip=1, noop_max=0
)
env4 = AtariPreprocessing(
env_fn(),
screen_size=84,
@@ -75,11 +79,15 @@ def test_atari_preprocessing_scale(env_fn):
obs = env.reset().flatten()
done, step_i = False, 0
max_obs = 1 if scaled else 255
assert (0 <= obs).all() and (obs <= max_obs).all(), "Obs. must be in range [0,{}]".format(max_obs)
assert (0 <= obs).all() and (
obs <= max_obs
).all(), "Obs. must be in range [0,{}]".format(max_obs)
while not done or step_i <= max_test_steps:
obs, _, done, _ = env.step(env.action_space.sample())
obs = obs.flatten()
assert (0 <= obs).all() and (obs <= max_obs).all(), "Obs. must be in range [0,{}]".format(max_obs)
assert (0 <= obs).all() and (
obs <= max_obs
).all(), "Obs. must be in range [0,{}]".format(max_obs)
step_i += 1
env.close()

View File

@@ -19,7 +19,9 @@ def test_clip_action():
actions = [[0.4], [1.2], [-0.3], [0.0], [-2.5]]
for action in actions:
obs1, r1, d1, _ = env.step(np.clip(action, env.action_space.low, env.action_space.high))
obs1, r1, d1, _ = env.step(
np.clip(action, env.action_space.low, env.action_space.high)
)
obs2, r2, d2, _ = wrapped_env.step(action)
assert np.allclose(r1, r2)
assert np.allclose(obs1, obs2)

View File

@@ -9,7 +9,10 @@ from gym.wrappers.filter_observation import FilterObservation
class FakeEnvironment(gym.Env):
def __init__(self, observation_keys=("state",)):
self.observation_space = spaces.Dict(
{name: spaces.Box(shape=(2,), low=-1, high=1, dtype=np.float32) for name in observation_keys}
{
name: spaces.Box(shape=(2,), low=-1, high=1, dtype=np.float32)
for name in observation_keys
}
)
self.action_space = spaces.Box(shape=(1,), low=-1, high=1, dtype=np.float32)
@@ -45,7 +48,9 @@ ERROR_TEST_CASES = (
class TestFilterObservation(object):
@pytest.mark.parametrize("observation_keys,filter_keys", FILTER_OBSERVATION_TEST_CASES)
@pytest.mark.parametrize(
"observation_keys,filter_keys", FILTER_OBSERVATION_TEST_CASES
)
def test_filter_observation(self, observation_keys, filter_keys):
env = FakeEnvironment(observation_keys=observation_keys)
@@ -68,7 +73,9 @@ class TestFilterObservation(object):
assert len(observation) == len(filter_keys)
@pytest.mark.parametrize("filter_keys,error_type,error_match", ERROR_TEST_CASES)
def test_raises_with_incorrect_arguments(self, filter_keys, error_type, error_match):
def test_raises_with_incorrect_arguments(
self, filter_keys, error_type, error_match
):
env = FakeEnvironment(observation_keys=("key1", "key2"))
ValueError

View File

@@ -16,10 +16,14 @@ def test_flatten_observation(env_id):
wrapped_obs = wrapped_env.reset()
if env_id == "Blackjack-v0":
space = spaces.Tuple((spaces.Discrete(32), spaces.Discrete(11), spaces.Discrete(2)))
space = spaces.Tuple(
(spaces.Discrete(32), spaces.Discrete(11), spaces.Discrete(2))
)
wrapped_space = spaces.Box(-np.inf, np.inf, [32 + 11 + 2], dtype=np.float32)
elif env_id == "KellyCoinflip-v0":
space = spaces.Tuple((spaces.Box(0, 250.0, [1], dtype=np.float32), spaces.Discrete(300 + 1)))
space = spaces.Tuple(
(spaces.Box(0, 250.0, [1], dtype=np.float32), spaces.Discrete(300 + 1))
)
wrapped_space = spaces.Box(-np.inf, np.inf, [1 + (300 + 1)], dtype=np.float32)
assert space.contains(obs)

View File

@@ -19,7 +19,9 @@ except ImportError:
[
pytest.param(
True,
marks=pytest.mark.skipif(lz4 is None, reason="Need lz4 to run tests with compression"),
marks=pytest.mark.skipif(
lz4 is None, reason="Need lz4 to run tests with compression"
),
),
False,
],

View File

@@ -10,7 +10,9 @@ pytest.importorskip("atari_py")
pytest.importorskip("cv2")
@pytest.mark.parametrize("env_id", ["PongNoFrameskip-v0", "SpaceInvadersNoFrameskip-v0"])
@pytest.mark.parametrize(
"env_id", ["PongNoFrameskip-v0", "SpaceInvadersNoFrameskip-v0"]
)
@pytest.mark.parametrize("keep_dim", [True, False])
def test_gray_scale_observation(env_id, keep_dim):
gray_env = AtariPreprocessing(gym.make(env_id), screen_size=84, grayscale_obs=True)

View File

@@ -32,7 +32,9 @@ class FakeEnvironment(gym.Env):
class FakeArrayObservationEnvironment(FakeEnvironment):
def __init__(self, *args, **kwargs):
self.observation_space = spaces.Box(shape=(2,), low=-1, high=1, dtype=np.float32)
self.observation_space = spaces.Box(
shape=(2,), low=-1, high=1, dtype=np.float32
)
super(FakeArrayObservationEnvironment, self).__init__(*args, **kwargs)
@@ -73,7 +75,10 @@ class TestPixelObservationWrapper(object):
assert len(wrapped_env.observation_space.spaces) == 1
assert list(wrapped_env.observation_space.spaces.keys()) == [pixel_key]
else:
assert len(wrapped_env.observation_space.spaces) == len(observation_space.spaces) + 1
assert (
len(wrapped_env.observation_space.spaces)
== len(observation_space.spaces) + 1
)
expected_keys = list(observation_space.spaces.keys()) + [pixel_key]
assert list(wrapped_env.observation_space.spaces.keys()) == expected_keys
@@ -92,7 +97,9 @@ class TestPixelObservationWrapper(object):
observation_space = env.observation_space
assert isinstance(observation_space, spaces.Box)
wrapped_env = PixelObservationWrapper(env, pixel_keys=(pixel_key,), pixels_only=pixels_only)
wrapped_env = PixelObservationWrapper(
env, pixel_keys=(pixel_key,), pixels_only=pixels_only
)
wrapped_env.observation_space = wrapped_env.observation_space
assert isinstance(wrapped_env.observation_space, spaces.Dict)

View File

@@ -9,8 +9,12 @@ except ImportError:
atari_py = None
@pytest.mark.skipif(atari_py is None, reason="Only run this test when atari_py is installed")
@pytest.mark.parametrize("env_id", ["PongNoFrameskip-v0", "SpaceInvadersNoFrameskip-v0"])
@pytest.mark.skipif(
atari_py is None, reason="Only run this test when atari_py is installed"
)
@pytest.mark.parametrize(
"env_id", ["PongNoFrameskip-v0", "SpaceInvadersNoFrameskip-v0"]
)
@pytest.mark.parametrize("shape", [16, 32, (8, 5), [10, 7]])
def test_resize_observation(env_id, shape):
env = gym.make(env_id)

View File

@@ -10,7 +10,9 @@ from gym.wrappers import TransformObservation
def test_transform_observation(env_id):
affine_transform = lambda x: 3 * x + 2
env = gym.make(env_id)
wrapped_env = TransformObservation(gym.make(env_id), lambda obs: affine_transform(obs))
wrapped_env = TransformObservation(
gym.make(env_id), lambda obs: affine_transform(obs)
)
env.seed(0)
wrapped_env.seed(0)

Some files were not shown because too many files have changed in this diff