redo black

Justin Terry
2021-07-29 12:42:48 -04:00
parent d5004b7ec1
commit e9d2c41f2b
109 changed files with 459 additions and 1363 deletions

View File

@@ -3,9 +3,7 @@ import argparse
import gym
parser = argparse.ArgumentParser(
description="Renders a Gym environment for quick inspection."
)
parser = argparse.ArgumentParser(description="Renders a Gym environment for quick inspection.")
parser.add_argument(
"env_id",
type=str,

View File

@@ -35,12 +35,7 @@ def cem(f, th_mean, batch_size, n_iter, elite_frac, initial_std=1.0):
th_std = np.ones_like(th_mean) * initial_std
for _ in range(n_iter):
ths = np.array(
[
th_mean + dth
for dth in th_std[None, :] * np.random.randn(batch_size, th_mean.size)
]
)
ths = np.array([th_mean + dth for dth in th_std[None, :] * np.random.randn(batch_size, th_mean.size)])
ys = np.array([f(th) for th in ths])
elite_inds = ys.argsort()[::-1][:n_elite]
elite_ths = ths[elite_inds]
@@ -101,9 +96,7 @@ if __name__ == "__main__":
return rew
# Train the agent, and snapshot each stage
for (i, iterdata) in enumerate(
cem(noisy_evaluation, np.zeros(env.observation_space.shape[0] + 1), **params)
):
for (i, iterdata) in enumerate(cem(noisy_evaluation, np.zeros(env.observation_space.shape[0] + 1), **params)):
print("Iteration %2i. Episode mean reward: %7.3f" % (i, iterdata["y_mean"]))
agent = BinaryActionLinearPolicy(iterdata["theta_mean"])
if args.display:
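
For context, the cem function shown above is the cross-entropy method: sample a Gaussian population around the current parameter mean, score each sample, keep the top elite_frac, and refit the mean and std to the elite. A minimal self-contained sketch of that loop (a toy quadratic objective is assumed here, not the file's generator/snapshot plumbing):

    import numpy as np

    def cem_sketch(f, th_mean, batch_size=50, n_iter=20, elite_frac=0.2, initial_std=1.0):
        n_elite = int(np.round(batch_size * elite_frac))
        th_std = np.ones_like(th_mean) * initial_std
        for _ in range(n_iter):
            # sample a population around the current mean
            ths = th_mean + th_std[None, :] * np.random.randn(batch_size, th_mean.size)
            ys = np.array([f(th) for th in ths])
            # keep the highest-scoring samples and refit the Gaussian to them
            elite = ths[ys.argsort()[::-1][:n_elite]]
            th_mean, th_std = elite.mean(axis=0), elite.std(axis=0)
        return th_mean

    best = cem_sketch(lambda th: -np.square(th).sum(), np.zeros(5))  # optimum at the origin
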

View File

@@ -17,9 +17,7 @@ class RandomAgent(object):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description=None)
parser.add_argument(
"env_id", nargs="?", default="CartPole-v0", help="Select the environment to run"
)
parser.add_argument("env_id", nargs="?", default="CartPole-v0", help="Select the environment to run")
args = parser.parse_args()
# You can set the level to logger.DEBUG or logger.WARN if you

View File

@@ -173,16 +173,10 @@ class GoalEnv(Env):
def reset(self):
# Enforce that each GoalEnv uses a Goal-compatible observation space.
if not isinstance(self.observation_space, gym.spaces.Dict):
raise error.Error(
"GoalEnv requires an observation space of type gym.spaces.Dict"
)
raise error.Error("GoalEnv requires an observation space of type gym.spaces.Dict")
for key in ["observation", "achieved_goal", "desired_goal"]:
if key not in self.observation_space.spaces:
raise error.Error(
'GoalEnv requires the "{}" key to be part of the observation dictionary.'.format(
key
)
)
raise error.Error('GoalEnv requires the "{}" key to be part of the observation dictionary.'.format(key))
def compute_reward(self, achieved_goal, desired_goal, info):
"""Compute the step reward. This externalizes the reward function and makes
@@ -227,9 +221,7 @@ class Wrapper(Env):
def __getattr__(self, name):
if name.startswith("_"):
raise AttributeError(
"attempted to get missing private attribute '{}'".format(name)
)
raise AttributeError("attempted to get missing private attribute '{}'".format(name))
return getattr(self.env, name)
@property
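
The reset check above means every GoalEnv observation space is a spaces.Dict containing those three keys. A compliant space would look like this (a sketch with arbitrary shapes, not any particular environment's):

    import numpy as np
    from gym import spaces

    observation_space = spaces.Dict({
        "observation": spaces.Box(-np.inf, np.inf, shape=(10,), dtype=np.float32),
        "achieved_goal": spaces.Box(-np.inf, np.inf, shape=(3,), dtype=np.float32),
        "desired_goal": spaces.Box(-np.inf, np.inf, shape=(3,), dtype=np.float32),
    })
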

View File

@@ -422,9 +422,7 @@ for reward_type in ["sparse", "dense"]:
register(
id="HandManipulateBlockRotateParallel{}-v0".format(suffix),
entry_point="gym.envs.robotics:HandBlockEnv",
kwargs=_merge(
{"target_position": "ignore", "target_rotation": "parallel"}, kwargs
),
kwargs=_merge({"target_position": "ignore", "target_rotation": "parallel"}, kwargs),
max_episode_steps=100,
)

View File

@@ -73,9 +73,7 @@ class AlgorithmicEnv(Env):
# 1. Move read head left or right (or up/down)
# 2. Write or not
# 3. Which character to write. (Ignored if should_write=0)
self.action_space = Tuple(
[Discrete(len(self.MOVEMENTS)), Discrete(2), Discrete(self.base)]
)
self.action_space = Tuple([Discrete(len(self.MOVEMENTS)), Discrete(2), Discrete(self.base)])
# Can see just what is on the input tape (one of n characters, or
# nothing)
self.observation_space = Discrete(self.base + 1)
@@ -147,10 +145,7 @@ class AlgorithmicEnv(Env):
move = self.MOVEMENTS[inp_act]
outfile.write("Action : Tuple(move over input: %s,\n" % move)
out_act = out_act == 1
outfile.write(
" write to the output tape: %s,\n"
% out_act
)
outfile.write(" write to the output tape: %s,\n" % out_act)
outfile.write(" prediction: %s)\n" % pred_str)
else:
outfile.write("\n" * 5)
@@ -276,9 +271,7 @@ class TapeAlgorithmicEnv(AlgorithmicEnv):
x_str = "Observation Tape : "
for i in range(-2, self.input_width + 2):
if i == x:
x_str += colorize(
self._get_str_obs(np.array([i])), "green", highlight=True
)
x_str += colorize(self._get_str_obs(np.array([i])), "green", highlight=True)
else:
x_str += self._get_str_obs(np.array([i]))
x_str += "\n"
@@ -311,10 +304,7 @@ class GridAlgorithmicEnv(AlgorithmicEnv):
self.read_head_position = x, y
def generate_input_data(self, size):
return [
[self.np_random.randint(self.base) for _ in range(self.rows)]
for __ in range(size)
]
return [[self.np_random.randint(self.base) for _ in range(self.rows)] for __ in range(size)]
def _get_obs(self, pos=None):
if pos is None:
@@ -336,9 +326,7 @@ class GridAlgorithmicEnv(AlgorithmicEnv):
x_str += " " * len(label)
for i in range(-2, self.input_width + 2):
if i == x[0] and j == x[1]:
x_str += colorize(
self._get_str_obs((i, j)), "green", highlight=True
)
x_str += colorize(self._get_str_obs((i, j)), "green", highlight=True)
else:
x_str += self._get_str_obs((i, j))
x_str += "\n"
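
The three-part action space above means every action is a (move, write?, character) triple; sampling and unpacking one looks like this (a sketch, with four movements and base 3 assumed):

    from gym.spaces import Discrete, Tuple

    action_space = Tuple([Discrete(4), Discrete(2), Discrete(3)])
    move, should_write, char_to_write = action_space.sample()  # e.g. (2, 1, 0)
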

View File

@@ -10,12 +10,8 @@ ALL_ENVS = [
alg.reverse.ReverseEnv,
alg.reversed_addition.ReversedAdditionEnv,
]
ALL_TAPE_ENVS = [
env for env in ALL_ENVS if issubclass(env, alg.algorithmic_env.TapeAlgorithmicEnv)
]
ALL_GRID_ENVS = [
env for env in ALL_ENVS if issubclass(env, alg.algorithmic_env.GridAlgorithmicEnv)
]
ALL_TAPE_ENVS = [env for env in ALL_ENVS if issubclass(env, alg.algorithmic_env.TapeAlgorithmicEnv)]
ALL_GRID_ENVS = [env for env in ALL_ENVS if issubclass(env, alg.algorithmic_env.GridAlgorithmicEnv)]
def imprint(env, input_arr):
@@ -92,10 +88,7 @@ class TestAlgorithmicEnvInteractions(unittest.TestCase):
def test_grid_naviation(self):
env = alg.reversed_addition.ReversedAdditionEnv(rows=2, base=6)
N, S, E, W = [
env._movement_idx(named_dir)
for named_dir in ["up", "down", "right", "left"]
]
N, S, E, W = [env._movement_idx(named_dir) for named_dir in ["up", "down", "right", "left"]]
# Corresponds to a grid that looks like...
# 0 1 2
# 3 4 5
@@ -204,9 +197,7 @@ class TestTargets(unittest.TestCase):
def test_repeat_copy_target(self):
env = alg.repeat_copy.RepeatCopyEnv()
self.assertEqual(
env.target_from_input_data([0, 1, 2]), [0, 1, 2, 2, 1, 0, 0, 1, 2]
)
self.assertEqual(env.target_from_input_data([0, 1, 2]), [0, 1, 2, 2, 1, 0, 0, 1, 2])
class TestInputGeneration(unittest.TestCase):

View File

@@ -9,8 +9,7 @@ try:
import atari_py
except ImportError as e:
raise error.DependencyNotInstalled(
"{}. (HINT: you can install Atari dependencies by running "
"'pip install gym[atari]'.)".format(e)
"{}. (HINT: you can install Atari dependencies by running " "'pip install gym[atari]'.)".format(e)
)
@@ -64,35 +63,23 @@ class AtariEnv(gym.Env, utils.EzPickle):
# Tune (or disable) ALE's action repeat:
# https://github.com/openai/gym/issues/349
assert isinstance(
repeat_action_probability, (float, int)
), "Invalid repeat_action_probability: {!r}".format(repeat_action_probability)
self.ale.setFloat(
"repeat_action_probability".encode("utf-8"), repeat_action_probability
assert isinstance(repeat_action_probability, (float, int)), "Invalid repeat_action_probability: {!r}".format(
repeat_action_probability
)
self.ale.setFloat("repeat_action_probability".encode("utf-8"), repeat_action_probability)
self.seed()
self._action_set = (
self.ale.getLegalActionSet()
if full_action_space
else self.ale.getMinimalActionSet()
)
self._action_set = self.ale.getLegalActionSet() if full_action_space else self.ale.getMinimalActionSet()
self.action_space = spaces.Discrete(len(self._action_set))
(screen_width, screen_height) = self.ale.getScreenDims()
if self._obs_type == "ram":
self.observation_space = spaces.Box(
low=0, high=255, dtype=np.uint8, shape=(128,)
)
self.observation_space = spaces.Box(low=0, high=255, dtype=np.uint8, shape=(128,))
elif self._obs_type == "image":
self.observation_space = spaces.Box(
low=0, high=255, shape=(screen_height, screen_width, 3), dtype=np.uint8
)
self.observation_space = spaces.Box(low=0, high=255, shape=(screen_height, screen_width, 3), dtype=np.uint8)
else:
raise error.Error(
"Unrecognized observation type: {}".format(self._obs_type)
)
raise error.Error("Unrecognized observation type: {}".format(self._obs_type))
def seed(self, seed=None):
self.np_random, seed1 = seeding.np_random(seed)
@@ -107,9 +94,9 @@ class AtariEnv(gym.Env, utils.EzPickle):
if self.game_mode is not None:
modes = self.ale.getAvailableModes()
assert self.game_mode in modes, (
'Invalid game mode "{}" for game {}.\nAvailable modes are: {}'
).format(self.game_mode, self.game, modes)
assert self.game_mode in modes, ('Invalid game mode "{}" for game {}.\nAvailable modes are: {}').format(
self.game_mode, self.game, modes
)
self.ale.setMode(self.game_mode)
if self.game_difficulty is not None:
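
The two observation types map to two Box spaces: 128 bytes of console RAM, or an RGB screen image. Concretely (a sketch; a 210x160 screen is assumed here, while the real dimensions come from ale.getScreenDims()):

    import numpy as np
    from gym import spaces

    ram_space = spaces.Box(low=0, high=255, dtype=np.uint8, shape=(128,))
    image_space = spaces.Box(low=0, high=255, shape=(210, 160, 3), dtype=np.uint8)
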

View File

@@ -100,10 +100,7 @@ class ContactDetector(contactListener):
self.env = env
def BeginContact(self, contact):
if (
self.env.hull == contact.fixtureA.body
or self.env.hull == contact.fixtureB.body
):
if self.env.hull == contact.fixtureA.body or self.env.hull == contact.fixtureB.body:
self.env.game_over = True
for leg in [self.env.legs[1], self.env.legs[3]]:
if leg in [contact.fixtureA.body, contact.fixtureB.body]:
@@ -202,9 +199,7 @@ class BipedalWalker(gym.Env, EzPickle):
t.color1, t.color2 = (1, 1, 1), (0.6, 0.6, 0.6)
self.terrain.append(t)
self.fd_polygon.shape.vertices = [
(p[0] + TERRAIN_STEP * counter, p[1]) for p in poly
]
self.fd_polygon.shape.vertices = [(p[0] + TERRAIN_STEP * counter, p[1]) for p in poly]
t = self.world.CreateStaticBody(fixtures=self.fd_polygon)
t.color1, t.color2 = (1, 1, 1), (0.6, 0.6, 0.6)
self.terrain.append(t)
@@ -301,12 +296,8 @@ class BipedalWalker(gym.Env, EzPickle):
y = VIEWPORT_H / SCALE * 3 / 4
poly = [
(
x
+ 15 * TERRAIN_STEP * math.sin(3.14 * 2 * a / 5)
+ self.np_random.uniform(0, 5 * TERRAIN_STEP),
y
+ 5 * TERRAIN_STEP * math.cos(3.14 * 2 * a / 5)
+ self.np_random.uniform(0, 5 * TERRAIN_STEP),
x + 15 * TERRAIN_STEP * math.sin(3.14 * 2 * a / 5) + self.np_random.uniform(0, 5 * TERRAIN_STEP),
y + 5 * TERRAIN_STEP * math.cos(3.14 * 2 * a / 5) + self.np_random.uniform(0, 5 * TERRAIN_STEP),
)
for a in range(5)
]
@@ -331,14 +322,10 @@ class BipedalWalker(gym.Env, EzPickle):
init_x = TERRAIN_STEP * TERRAIN_STARTPAD / 2
init_y = TERRAIN_HEIGHT + 2 * LEG_H
self.hull = self.world.CreateDynamicBody(
position=(init_x, init_y), fixtures=HULL_FD
)
self.hull = self.world.CreateDynamicBody(position=(init_x, init_y), fixtures=HULL_FD)
self.hull.color1 = (0.5, 0.4, 0.9)
self.hull.color2 = (0.3, 0.3, 0.5)
self.hull.ApplyForceToCenter(
(self.np_random.uniform(-INITIAL_RANDOM, INITIAL_RANDOM), 0), True
)
self.hull.ApplyForceToCenter((self.np_random.uniform(-INITIAL_RANDOM, INITIAL_RANDOM), 0), True)
self.legs = []
self.joints = []
@@ -412,21 +399,13 @@ class BipedalWalker(gym.Env, EzPickle):
self.joints[3].motorSpeed = float(SPEED_KNEE * np.clip(action[3], -1, 1))
else:
self.joints[0].motorSpeed = float(SPEED_HIP * np.sign(action[0]))
self.joints[0].maxMotorTorque = float(
MOTORS_TORQUE * np.clip(np.abs(action[0]), 0, 1)
)
self.joints[0].maxMotorTorque = float(MOTORS_TORQUE * np.clip(np.abs(action[0]), 0, 1))
self.joints[1].motorSpeed = float(SPEED_KNEE * np.sign(action[1]))
self.joints[1].maxMotorTorque = float(
MOTORS_TORQUE * np.clip(np.abs(action[1]), 0, 1)
)
self.joints[1].maxMotorTorque = float(MOTORS_TORQUE * np.clip(np.abs(action[1]), 0, 1))
self.joints[2].motorSpeed = float(SPEED_HIP * np.sign(action[2]))
self.joints[2].maxMotorTorque = float(
MOTORS_TORQUE * np.clip(np.abs(action[2]), 0, 1)
)
self.joints[2].maxMotorTorque = float(MOTORS_TORQUE * np.clip(np.abs(action[2]), 0, 1))
self.joints[3].motorSpeed = float(SPEED_KNEE * np.sign(action[3]))
self.joints[3].maxMotorTorque = float(
MOTORS_TORQUE * np.clip(np.abs(action[3]), 0, 1)
)
self.joints[3].maxMotorTorque = float(MOTORS_TORQUE * np.clip(np.abs(action[3]), 0, 1))
self.world.Step(1.0 / FPS, 6 * 30, 2 * 30)
@@ -465,12 +444,8 @@ class BipedalWalker(gym.Env, EzPickle):
self.scroll = pos.x - VIEWPORT_W / SCALE / 5
shaping = (
130 * pos[0] / SCALE
) # moving forward is a way to receive reward (normalized to get 300 on completion)
shaping -= 5.0 * abs(
state[0]
) # keep head straight, other than that and falling, any behavior is unpunished
shaping = 130 * pos[0] / SCALE # moving forward is a way to receive reward (normalized to get 300 on completion)
shaping -= 5.0 * abs(state[0]) # keep head straight, other than that and falling, any behavior is unpunished
reward = 0
if self.prev_shaping is not None:
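
The two shaping lines implement potential-based shaping: the per-step reward is the change in a potential that grows with forward progress and shrinks with head tilt. The idea in isolation (a sketch, not the env's exact bookkeeping; 30.0 is assumed for SCALE):

    def potential(pos_x, head_angle, scale=30.0):
        # forward progress is rewarded, head tilt is penalized
        return 130 * pos_x / scale - 5.0 * abs(head_angle)

    prev_shaping = None
    for pos_x, head_angle in [(0.0, 0.0), (1.0, 0.1), (2.5, 0.05)]:  # toy trajectory
        shaping = potential(pos_x, head_angle)
        reward = 0.0 if prev_shaping is None else shaping - prev_shaping
        prev_shaping = shaping
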
@@ -494,9 +469,7 @@ class BipedalWalker(gym.Env, EzPickle):
if self.viewer is None:
self.viewer = rendering.Viewer(VIEWPORT_W, VIEWPORT_H)
self.viewer.set_bounds(
self.scroll, VIEWPORT_W / SCALE + self.scroll, 0, VIEWPORT_H / SCALE
)
self.viewer.set_bounds(self.scroll, VIEWPORT_W / SCALE + self.scroll, 0, VIEWPORT_H / SCALE)
self.viewer.draw_polygon(
[
@@ -512,9 +485,7 @@ class BipedalWalker(gym.Env, EzPickle):
continue
if x1 > self.scroll / 2 + VIEWPORT_W / SCALE:
continue
self.viewer.draw_polygon(
[(p[0] + self.scroll / 2, p[1]) for p in poly], color=(1, 1, 1)
)
self.viewer.draw_polygon([(p[0] + self.scroll / 2, p[1]) for p in poly], color=(1, 1, 1))
for poly, color in self.terrain_poly:
if poly[1][0] < self.scroll:
continue
@@ -525,11 +496,7 @@ class BipedalWalker(gym.Env, EzPickle):
self.lidar_render = (self.lidar_render + 1) % 100
i = self.lidar_render
if i < 2 * len(self.lidar):
l = (
self.lidar[i]
if i < len(self.lidar)
else self.lidar[len(self.lidar) - i - 1]
)
l = self.lidar[i] if i < len(self.lidar) else self.lidar[len(self.lidar) - i - 1]
self.viewer.draw_polyline([l.p1, l.p2], color=(1, 0, 0), linewidth=1)
for obj in self.drawlist:
@@ -537,12 +504,8 @@ class BipedalWalker(gym.Env, EzPickle):
trans = f.body.transform
if type(f.shape) is circleShape:
t = rendering.Transform(translation=trans * f.shape.pos)
self.viewer.draw_circle(
f.shape.radius, 30, color=obj.color1
).add_attr(t)
self.viewer.draw_circle(
f.shape.radius, 30, color=obj.color2, filled=False, linewidth=2
).add_attr(t)
self.viewer.draw_circle(f.shape.radius, 30, color=obj.color1).add_attr(t)
self.viewer.draw_circle(f.shape.radius, 30, color=obj.color2, filled=False, linewidth=2).add_attr(t)
else:
path = [trans * v for v in f.shape.vertices]
self.viewer.draw_polygon(path, color=obj.color1)
@@ -552,9 +515,7 @@ class BipedalWalker(gym.Env, EzPickle):
flagy1 = TERRAIN_HEIGHT
flagy2 = flagy1 + 50 / SCALE
x = TERRAIN_STEP * 3
self.viewer.draw_polyline(
[(x, flagy1), (x, flagy2)], color=(0, 0, 0), linewidth=2
)
self.viewer.draw_polyline([(x, flagy1), (x, flagy2)], color=(0, 0, 0), linewidth=2)
f = [
(x, flagy2),
(x, flagy2 - 10 / SCALE),

View File

@@ -23,9 +23,7 @@ from Box2D.b2 import (
SIZE = 0.02
ENGINE_POWER = 100000000 * SIZE * SIZE
WHEEL_MOMENT_OF_INERTIA = 4000 * SIZE * SIZE
FRICTION_LIMIT = (
1000000 * SIZE * SIZE
) # friction ~= mass ~= size^2 (calculated implicitly using density)
FRICTION_LIMIT = 1000000 * SIZE * SIZE # friction ~= mass ~= size^2 (calculated implicitly using density)
WHEEL_R = 27
WHEEL_W = 14
WHEELPOS = [(-55, +80), (+55, +80), (-55, -82), (+55, -82)]
@@ -55,27 +53,19 @@ class Car:
angle=init_angle,
fixtures=[
fixtureDef(
shape=polygonShape(
vertices=[(x * SIZE, y * SIZE) for x, y in HULL_POLY1]
),
shape=polygonShape(vertices=[(x * SIZE, y * SIZE) for x, y in HULL_POLY1]),
density=1.0,
),
fixtureDef(
shape=polygonShape(
vertices=[(x * SIZE, y * SIZE) for x, y in HULL_POLY2]
),
shape=polygonShape(vertices=[(x * SIZE, y * SIZE) for x, y in HULL_POLY2]),
density=1.0,
),
fixtureDef(
shape=polygonShape(
vertices=[(x * SIZE, y * SIZE) for x, y in HULL_POLY3]
),
shape=polygonShape(vertices=[(x * SIZE, y * SIZE) for x, y in HULL_POLY3]),
density=1.0,
),
fixtureDef(
shape=polygonShape(
vertices=[(x * SIZE, y * SIZE) for x, y in HULL_POLY4]
),
shape=polygonShape(vertices=[(x * SIZE, y * SIZE) for x, y in HULL_POLY4]),
density=1.0,
),
],
@@ -95,12 +85,7 @@ class Car:
position=(init_x + wx * SIZE, init_y + wy * SIZE),
angle=init_angle,
fixtures=fixtureDef(
shape=polygonShape(
vertices=[
(x * front_k * SIZE, y * front_k * SIZE)
for x, y in WHEEL_POLY
]
),
shape=polygonShape(vertices=[(x * front_k * SIZE, y * front_k * SIZE) for x, y in WHEEL_POLY]),
density=0.1,
categoryBits=0x0020,
maskBits=0x001,
@@ -175,9 +160,7 @@ class Car:
grass = True
friction_limit = FRICTION_LIMIT * 0.6 # Grass friction if no tile
for tile in w.tiles:
friction_limit = max(
friction_limit, FRICTION_LIMIT * tile.road_friction
)
friction_limit = max(friction_limit, FRICTION_LIMIT * tile.road_friction)
grass = False
# Force
@@ -192,13 +175,7 @@ class Car:
# domega = dt*W/WHEEL_MOMENT_OF_INERTIA/w.omega
# add small coef not to divide by zero
w.omega += (
dt
* ENGINE_POWER
* w.gas
/ WHEEL_MOMENT_OF_INERTIA
/ (abs(w.omega) + 5.0)
)
w.omega += dt * ENGINE_POWER * w.gas / WHEEL_MOMENT_OF_INERTIA / (abs(w.omega) + 5.0)
self.fuel_spent += dt * ENGINE_POWER * w.gas
if w.brake >= 0.9:
@@ -226,18 +203,12 @@ class Car:
# Skid trace
if abs(force) > 2.0 * friction_limit:
if (
w.skid_particle
and w.skid_particle.grass == grass
and len(w.skid_particle.poly) < 30
):
if w.skid_particle and w.skid_particle.grass == grass and len(w.skid_particle.poly) < 30:
w.skid_particle.poly.append((w.position[0], w.position[1]))
elif w.skid_start is None:
w.skid_start = w.position
else:
w.skid_particle = self._create_particle(
w.skid_start, w.position, grass
)
w.skid_particle = self._create_particle(w.skid_start, w.position, grass)
w.skid_start = None
else:
w.skid_start = None
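
The w.omega update above applies engine torque divided by wheel inertia, with the +5.0 in the denominator acting as a softening term so the expression stays finite at zero wheel speed. In isolation (a sketch):

    def spin_up(omega, gas, dt, engine_power, wheel_moi):
        # torque / inertia, softened near omega == 0 to avoid division by zero
        return omega + dt * engine_power * gas / wheel_moi / (abs(omega) + 5.0)
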

View File

@@ -132,18 +132,14 @@ class CarRacing(gym.Env, EzPickle):
self.reward = 0.0
self.prev_reward = 0.0
self.verbose = verbose
self.fd_tile = fixtureDef(
shape=polygonShape(vertices=[(0, 0), (1, 0), (1, -1), (0, -1)])
)
self.fd_tile = fixtureDef(shape=polygonShape(vertices=[(0, 0), (1, 0), (1, -1), (0, -1)]))
self.action_space = spaces.Box(
np.array([-1, 0, 0]).astype(np.float32),
np.array([+1, +1, +1]).astype(np.float32),
) # steer, gas, brake
self.observation_space = spaces.Box(
low=0, high=255, shape=(STATE_H, STATE_W, 3), dtype=np.uint8
)
self.observation_space = spaces.Box(low=0, high=255, shape=(STATE_H, STATE_W, 3), dtype=np.uint8)
def seed(self, seed=None):
self.np_random, seed = seeding.np_random(seed)
@@ -246,9 +242,7 @@ class CarRacing(gym.Env, EzPickle):
i -= 1
if i == 0:
return False # Failed
pass_through_start = (
track[i][0] > self.start_alpha and track[i - 1][0] <= self.start_alpha
)
pass_through_start = track[i][0] > self.start_alpha and track[i - 1][0] <= self.start_alpha
if pass_through_start and i2 == -1:
i2 = i
elif pass_through_start and i1 == -1:
@@ -266,8 +260,7 @@ class CarRacing(gym.Env, EzPickle):
first_perp_y = math.sin(first_beta)
# Length of perpendicular jump to put together head and tail
well_glued_together = np.sqrt(
np.square(first_perp_x * (track[0][2] - track[-1][2]))
+ np.square(first_perp_y * (track[0][3] - track[-1][3]))
np.square(first_perp_x * (track[0][2] - track[-1][2])) + np.square(first_perp_y * (track[0][3] - track[-1][3]))
)
if well_glued_together > TRACK_DETAIL_STEP:
return False
@@ -337,9 +330,7 @@ class CarRacing(gym.Env, EzPickle):
x2 + side * (TRACK_WIDTH + BORDER) * math.cos(beta2),
y2 + side * (TRACK_WIDTH + BORDER) * math.sin(beta2),
)
self.road_poly.append(
([b1_l, b1_r, b2_r, b2_l], (1, 1, 1) if i % 2 == 0 else (1, 0, 0))
)
self.road_poly.append(([b1_l, b1_r, b2_r, b2_l], (1, 1, 1) if i % 2 == 0 else (1, 0, 0)))
self.track = track
return True
@@ -356,10 +347,7 @@ class CarRacing(gym.Env, EzPickle):
if success:
break
if self.verbose == 1:
print(
"retry to generate track (normal if there are not many"
"instances of this message)"
)
print("retry to generate track (normal if there are not many" "instances of this message)")
self.car = Car(self.world, *self.track[0][1:4])
return self.step(None)[0]
@@ -424,10 +412,8 @@ class CarRacing(gym.Env, EzPickle):
angle = math.atan2(vel[0], vel[1])
self.transform.set_scale(zoom, zoom)
self.transform.set_translation(
WINDOW_W / 2
- (scroll_x * zoom * math.cos(angle) - scroll_y * zoom * math.sin(angle)),
WINDOW_H / 4
- (scroll_x * zoom * math.sin(angle) + scroll_y * zoom * math.cos(angle)),
WINDOW_W / 2 - (scroll_x * zoom * math.cos(angle) - scroll_y * zoom * math.sin(angle)),
WINDOW_H / 4 - (scroll_x * zoom * math.sin(angle) + scroll_y * zoom * math.cos(angle)),
)
self.transform.set_rotation(angle)
@@ -449,9 +435,7 @@ class CarRacing(gym.Env, EzPickle):
else:
pixel_scale = 1
if hasattr(win.context, "_nscontext"):
pixel_scale = (
win.context._nscontext.view().backingScaleFactor()
) # pylint: disable=protected-access
pixel_scale = win.context._nscontext.view().backingScaleFactor() # pylint: disable=protected-access
VP_W = int(pixel_scale * WINDOW_W)
VP_H = int(pixel_scale * WINDOW_H)
@@ -468,9 +452,7 @@ class CarRacing(gym.Env, EzPickle):
win.flip()
return self.viewer.isopen
image_data = (
pyglet.image.get_buffer_manager().get_color_buffer().get_image_data()
)
image_data = pyglet.image.get_buffer_manager().get_color_buffer().get_image_data()
arr = np.fromstring(image_data.get_data(), dtype=np.uint8, sep="")
arr = arr.reshape(VP_H, VP_W, 4)
arr = arr[::-1, :, 0:3]
@@ -525,9 +507,7 @@ class CarRacing(gym.Env, EzPickle):
for p in poly:
polygons_.extend([p[0], p[1], 0])
vl = pyglet.graphics.vertex_list(
len(polygons_) // 3, ("v3f", polygons_), ("c4f", colors) # gl.GL_QUADS,
)
vl = pyglet.graphics.vertex_list(len(polygons_) // 3, ("v3f", polygons_), ("c4f", colors)) # gl.GL_QUADS,
vl.draw(gl.GL_QUADS)
vl.delete()
@@ -575,10 +555,7 @@ class CarRacing(gym.Env, EzPickle):
]
)
true_speed = np.sqrt(
np.square(self.car.hull.linearVelocity[0])
+ np.square(self.car.hull.linearVelocity[1])
)
true_speed = np.sqrt(np.square(self.car.hull.linearVelocity[0]) + np.square(self.car.hull.linearVelocity[1]))
vertical_ind(5, 0.02 * true_speed, (1, 1, 1))
vertical_ind(7, 0.01 * self.car.wheels[0].omega, (0.0, 0, 1)) # ABS sensors
@@ -587,9 +564,7 @@ class CarRacing(gym.Env, EzPickle):
vertical_ind(10, 0.01 * self.car.wheels[3].omega, (0.2, 0, 1))
horiz_ind(20, -10.0 * self.car.wheels[0].joint.angle, (0, 1, 0))
horiz_ind(30, -0.8 * self.car.hull.angularVelocity, (1, 0, 0))
vl = pyglet.graphics.vertex_list(
len(polygons) // 3, ("v3f", polygons), ("c4f", colors) # gl.GL_QUADS,
)
vl = pyglet.graphics.vertex_list(len(polygons) // 3, ("v3f", polygons), ("c4f", colors)) # gl.GL_QUADS,
vl.draw(gl.GL_QUADS)
vl.delete()
self.score_label.text = "%04i" % self.reward

View File

@@ -71,10 +71,7 @@ class ContactDetector(contactListener):
self.env = env
def BeginContact(self, contact):
if (
self.env.lander == contact.fixtureA.body
or self.env.lander == contact.fixtureB.body
):
if self.env.lander == contact.fixtureA.body or self.env.lander == contact.fixtureB.body:
self.env.game_over = True
for i in range(2):
if self.env.legs[i] in [contact.fixtureA.body, contact.fixtureB.body]:
@@ -104,9 +101,7 @@ class LunarLander(gym.Env, EzPickle):
self.prev_reward = None
# useful range is -1 .. +1, but spikes can be higher
self.observation_space = spaces.Box(
-np.inf, np.inf, shape=(8,), dtype=np.float32
)
self.observation_space = spaces.Box(-np.inf, np.inf, shape=(8,), dtype=np.float32)
if self.continuous:
# Action is two floats [main engine, left-right engines].
@@ -157,14 +152,9 @@ class LunarLander(gym.Env, EzPickle):
height[CHUNKS // 2 + 0] = self.helipad_y
height[CHUNKS // 2 + 1] = self.helipad_y
height[CHUNKS // 2 + 2] = self.helipad_y
smooth_y = [
0.33 * (height[i - 1] + height[i + 0] + height[i + 1])
for i in range(CHUNKS)
]
smooth_y = [0.33 * (height[i - 1] + height[i + 0] + height[i + 1]) for i in range(CHUNKS)]
self.moon = self.world.CreateStaticBody(
shapes=edgeShape(vertices=[(0, 0), (W, 0)])
)
self.moon = self.world.CreateStaticBody(shapes=edgeShape(vertices=[(0, 0), (W, 0)]))
self.sky_polys = []
for i in range(CHUNKS - 1):
p1 = (chunk_x[i], smooth_y[i])
@@ -180,9 +170,7 @@ class LunarLander(gym.Env, EzPickle):
position=(VIEWPORT_W / SCALE / 2, initial_y),
angle=0.0,
fixtures=fixtureDef(
shape=polygonShape(
vertices=[(x / SCALE, y / SCALE) for x, y in LANDER_POLY]
),
shape=polygonShape(vertices=[(x / SCALE, y / SCALE) for x, y in LANDER_POLY]),
density=5.0,
friction=0.1,
categoryBits=0x0010,
@@ -227,9 +215,7 @@ class LunarLander(gym.Env, EzPickle):
motorSpeed=+0.3 * i, # low enough not to jump back into the sky
)
if i == -1:
rjd.lowerAngle = (
+0.9 - 0.5
) # The most esoteric numbers here, angled legs have freedom to travel within
rjd.lowerAngle = +0.9 - 0.5 # The most esoteric numbers here, angled legs have freedom to travel within
rjd.upperAngle = +0.9
else:
rjd.lowerAngle = -0.9
@@ -278,9 +264,7 @@ class LunarLander(gym.Env, EzPickle):
dispersion = [self.np_random.uniform(-1.0, +1.0) / SCALE for _ in range(2)]
m_power = 0.0
if (self.continuous and action[0] > 0.0) or (
not self.continuous and action == 2
):
if (self.continuous and action[0] > 0.0) or (not self.continuous and action == 2):
# Main engine
if self.continuous:
m_power = (np.clip(action[0], 0.0, 1.0) + 1.0) * 0.5 # 0.5..1.0
@@ -310,9 +294,7 @@ class LunarLander(gym.Env, EzPickle):
)
s_power = 0.0
if (self.continuous and np.abs(action[1]) > 0.5) or (
not self.continuous and action in [1, 3]
):
if (self.continuous and np.abs(action[1]) > 0.5) or (not self.continuous and action in [1, 3]):
# Orientation engines
if self.continuous:
direction = np.sign(action[1])
@@ -321,12 +303,8 @@ class LunarLander(gym.Env, EzPickle):
else:
direction = action - 2
s_power = 1.0
ox = tip[0] * dispersion[0] + side[0] * (
3 * dispersion[1] + direction * SIDE_ENGINE_AWAY / SCALE
)
oy = -tip[1] * dispersion[0] - side[1] * (
3 * dispersion[1] + direction * SIDE_ENGINE_AWAY / SCALE
)
ox = tip[0] * dispersion[0] + side[0] * (3 * dispersion[1] + direction * SIDE_ENGINE_AWAY / SCALE)
oy = -tip[1] * dispersion[0] - side[1] * (3 * dispersion[1] + direction * SIDE_ENGINE_AWAY / SCALE)
impulse_pos = (
self.lander.position[0] + ox - tip[0] * 17 / SCALE,
self.lander.position[1] + oy + tip[1] * SIDE_ENGINE_HEIGHT / SCALE,
@@ -372,9 +350,7 @@ class LunarLander(gym.Env, EzPickle):
reward = shaping - self.prev_shaping
self.prev_shaping = shaping
reward -= (
m_power * 0.30
) # less fuel spent is better, about -30 for heuristic landing
reward -= m_power * 0.30 # less fuel spent is better, about -30 for heuristic landing
reward -= s_power * 0.03
done = False
@@ -416,12 +392,8 @@ class LunarLander(gym.Env, EzPickle):
trans = f.body.transform
if type(f.shape) is circleShape:
t = rendering.Transform(translation=trans * f.shape.pos)
self.viewer.draw_circle(
f.shape.radius, 20, color=obj.color1
).add_attr(t)
self.viewer.draw_circle(
f.shape.radius, 20, color=obj.color2, filled=False, linewidth=2
).add_attr(t)
self.viewer.draw_circle(f.shape.radius, 20, color=obj.color1).add_attr(t)
self.viewer.draw_circle(f.shape.radius, 20, color=obj.color2, filled=False, linewidth=2).add_attr(t)
else:
path = [trans * v for v in f.shape.vertices]
self.viewer.draw_polygon(path, color=obj.color1)
@@ -479,18 +451,14 @@ def heuristic(env, s):
angle_targ = 0.4 # more than 0.4 radians (22 degrees) is bad
if angle_targ < -0.4:
angle_targ = -0.4
hover_targ = 0.55 * np.abs(
s[0]
) # target y should be proportional to horizontal offset
hover_targ = 0.55 * np.abs(s[0]) # target y should be proportional to horizontal offset
angle_todo = (angle_targ - s[4]) * 0.5 - (s[5]) * 1.0
hover_todo = (hover_targ - s[1]) * 0.5 - (s[3]) * 0.5
if s[6] or s[7]: # legs have contact
angle_todo = 0
hover_todo = (
-(s[3]) * 0.5
) # override to reduce fall speed, that's all we need after contact
hover_todo = -(s[3]) * 0.5 # override to reduce fall speed, that's all we need after contact
if env.continuous:
a = np.array([hover_todo * 20 - 1, -angle_todo * 20])
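
The heuristic is two PD controllers: each *_todo term is a proportional correction toward its target minus a derivative term on the corresponding velocity. The two lines in isolation (a sketch; s is the 8-dimensional lander state):

    def pd_terms(s, angle_targ, hover_targ):
        angle_todo = (angle_targ - s[4]) * 0.5 - s[5] * 1.0  # P on angle error, D on angular velocity
        hover_todo = (hover_targ - s[1]) * 0.5 - s[3] * 0.5  # P on height error, D on vertical speed
        return angle_todo, hover_todo
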

View File

@@ -88,9 +88,7 @@ class AcrobotEnv(core.Env):
def __init__(self):
self.viewer = None
high = np.array(
[1.0, 1.0, 1.0, 1.0, self.MAX_VEL_1, self.MAX_VEL_2], dtype=np.float32
)
high = np.array([1.0, 1.0, 1.0, 1.0, self.MAX_VEL_1, self.MAX_VEL_2], dtype=np.float32)
low = -high
self.observation_space = spaces.Box(low=low, high=high, dtype=np.float32)
self.action_space = spaces.Discrete(3)
@@ -111,9 +109,7 @@ class AcrobotEnv(core.Env):
# Add noise to the force action
if self.torque_noise_max > 0:
torque += self.np_random.uniform(
-self.torque_noise_max, self.torque_noise_max
)
torque += self.np_random.uniform(-self.torque_noise_max, self.torque_noise_max)
# Now, augment the state with our force action so it can be passed to
# _dsdt
@@ -160,12 +156,7 @@ class AcrobotEnv(core.Env):
theta2 = s[1]
dtheta1 = s[2]
dtheta2 = s[3]
d1 = (
m1 * lc1 ** 2
+ m2 * (l1 ** 2 + lc2 ** 2 + 2 * l1 * lc2 * cos(theta2))
+ I1
+ I2
)
d1 = m1 * lc1 ** 2 + m2 * (l1 ** 2 + lc2 ** 2 + 2 * l1 * lc2 * cos(theta2)) + I1 + I2
d2 = m2 * (lc2 ** 2 + l1 * lc2 * cos(theta2)) + I2
phi2 = m2 * lc2 * g * cos(theta1 + theta2 - pi / 2.0)
phi1 = (
@@ -181,9 +172,9 @@ class AcrobotEnv(core.Env):
else:
# the following line is consistent with the java implementation and the
# book
ddtheta2 = (
a + d2 / d1 * phi1 - m2 * l1 * lc2 * dtheta1 ** 2 * sin(theta2) - phi2
) / (m2 * lc2 ** 2 + I2 - d2 ** 2 / d1)
ddtheta2 = (a + d2 / d1 * phi1 - m2 * l1 * lc2 * dtheta1 ** 2 * sin(theta2) - phi2) / (
m2 * lc2 ** 2 + I2 - d2 ** 2 / d1
)
ddtheta1 = -(d2 * ddtheta2 + phi1) / d1
return (dtheta1, dtheta2, ddtheta1, ddtheta2, 0.0)

View File

@@ -111,9 +111,7 @@ class CartPoleEnv(gym.Env):
# For the interested reader:
# https://coneural.org/florian/papers/05_cart_pole.pdf
temp = (
force + self.polemass_length * theta_dot ** 2 * sintheta
) / self.total_mass
temp = (force + self.polemass_length * theta_dot ** 2 * sintheta) / self.total_mass
thetaacc = (self.gravity * sintheta - costheta * temp) / (
self.length * (4.0 / 3.0 - self.masspole * costheta ** 2 / self.total_mass)
)

View File

@@ -62,27 +62,17 @@ class Continuous_MountainCarEnv(gym.Env):
self.min_position = -1.2
self.max_position = 0.6
self.max_speed = 0.07
self.goal_position = (
0.45 # was 0.5 in gym, 0.45 in Arnaud de Broissia's version
)
self.goal_position = 0.45 # was 0.5 in gym, 0.45 in Arnaud de Broissia's version
self.goal_velocity = goal_velocity
self.power = 0.0015
self.low_state = np.array(
[self.min_position, -self.max_speed], dtype=np.float32
)
self.high_state = np.array(
[self.max_position, self.max_speed], dtype=np.float32
)
self.low_state = np.array([self.min_position, -self.max_speed], dtype=np.float32)
self.high_state = np.array([self.max_position, self.max_speed], dtype=np.float32)
self.viewer = None
self.action_space = spaces.Box(
low=self.min_action, high=self.max_action, shape=(1,), dtype=np.float32
)
self.observation_space = spaces.Box(
low=self.low_state, high=self.high_state, dtype=np.float32
)
self.action_space = spaces.Box(low=self.min_action, high=self.max_action, shape=(1,), dtype=np.float32)
self.observation_space = spaces.Box(low=self.low_state, high=self.high_state, dtype=np.float32)
self.seed()
self.reset()
@@ -159,15 +149,11 @@ class Continuous_MountainCarEnv(gym.Env):
self.viewer.add_geom(car)
frontwheel = rendering.make_circle(carheight / 2.5)
frontwheel.set_color(0.5, 0.5, 0.5)
frontwheel.add_attr(
rendering.Transform(translation=(carwidth / 4, clearance))
)
frontwheel.add_attr(rendering.Transform(translation=(carwidth / 4, clearance)))
frontwheel.add_attr(self.cartrans)
self.viewer.add_geom(frontwheel)
backwheel = rendering.make_circle(carheight / 2.5)
backwheel.add_attr(
rendering.Transform(translation=(-carwidth / 4, clearance))
)
backwheel.add_attr(rendering.Transform(translation=(-carwidth / 4, clearance)))
backwheel.add_attr(self.cartrans)
backwheel.set_color(0.5, 0.5, 0.5)
self.viewer.add_geom(backwheel)
@@ -176,16 +162,12 @@ class Continuous_MountainCarEnv(gym.Env):
flagy2 = flagy1 + 50
flagpole = rendering.Line((flagx, flagy1), (flagx, flagy2))
self.viewer.add_geom(flagpole)
flag = rendering.FilledPolygon(
[(flagx, flagy2), (flagx, flagy2 - 10), (flagx + 25, flagy2 - 5)]
)
flag = rendering.FilledPolygon([(flagx, flagy2), (flagx, flagy2 - 10), (flagx + 25, flagy2 - 5)])
flag.set_color(0.8, 0.8, 0)
self.viewer.add_geom(flag)
pos = self.state[0]
self.cartrans.set_translation(
(pos - self.min_position) * scale, self._height(pos) * scale
)
self.cartrans.set_translation((pos - self.min_position) * scale, self._height(pos) * scale)
self.cartrans.set_rotation(math.cos(3 * pos))
return self.viewer.render(return_rgb_array=mode == "rgb_array")

View File

@@ -136,15 +136,11 @@ class MountainCarEnv(gym.Env):
self.viewer.add_geom(car)
frontwheel = rendering.make_circle(carheight / 2.5)
frontwheel.set_color(0.5, 0.5, 0.5)
frontwheel.add_attr(
rendering.Transform(translation=(carwidth / 4, clearance))
)
frontwheel.add_attr(rendering.Transform(translation=(carwidth / 4, clearance)))
frontwheel.add_attr(self.cartrans)
self.viewer.add_geom(frontwheel)
backwheel = rendering.make_circle(carheight / 2.5)
backwheel.add_attr(
rendering.Transform(translation=(-carwidth / 4, clearance))
)
backwheel.add_attr(rendering.Transform(translation=(-carwidth / 4, clearance)))
backwheel.add_attr(self.cartrans)
backwheel.set_color(0.5, 0.5, 0.5)
self.viewer.add_geom(backwheel)
@@ -153,16 +149,12 @@ class MountainCarEnv(gym.Env):
flagy2 = flagy1 + 50
flagpole = rendering.Line((flagx, flagy1), (flagx, flagy2))
self.viewer.add_geom(flagpole)
flag = rendering.FilledPolygon(
[(flagx, flagy2), (flagx, flagy2 - 10), (flagx + 25, flagy2 - 5)]
)
flag = rendering.FilledPolygon([(flagx, flagy2), (flagx, flagy2 - 10), (flagx + 25, flagy2 - 5)])
flag.set_color(0.8, 0.8, 0)
self.viewer.add_geom(flag)
pos = self.state[0]
self.cartrans.set_translation(
(pos - self.min_position) * scale, self._height(pos) * scale
)
self.cartrans.set_translation((pos - self.min_position) * scale, self._height(pos) * scale)
self.cartrans.set_rotation(math.cos(3 * pos))
return self.viewer.render(return_rgb_array=mode == "rgb_array")

View File

@@ -18,9 +18,7 @@ class PendulumEnv(gym.Env):
self.viewer = None
high = np.array([1.0, 1.0, self.max_speed], dtype=np.float32)
self.action_space = spaces.Box(
low=-self.max_torque, high=self.max_torque, shape=(1,), dtype=np.float32
)
self.action_space = spaces.Box(low=-self.max_torque, high=self.max_torque, shape=(1,), dtype=np.float32)
self.observation_space = spaces.Box(low=-high, high=high, dtype=np.float32)
self.seed()
@@ -41,10 +39,7 @@ class PendulumEnv(gym.Env):
self.last_u = u # for rendering
costs = angle_normalize(th) ** 2 + 0.1 * thdot ** 2 + 0.001 * (u ** 2)
newthdot = (
thdot
+ (-3 * g / (2 * l) * np.sin(th + np.pi) + 3.0 / (m * l ** 2) * u) * dt
)
newthdot = thdot + (-3 * g / (2 * l) * np.sin(th + np.pi) + 3.0 / (m * l ** 2) * u) * dt
newth = th + newthdot * dt
newthdot = np.clip(newthdot, -self.max_speed, self.max_speed)
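
The newthdot/newth lines are one explicit-Euler step of the pendulum dynamics: gravity torque plus applied torque, divided by inertia. As a standalone function (a sketch; g=10.0, m=1.0, l=1.0, dt=0.05 are assumed defaults):

    import numpy as np

    def euler_step(th, thdot, u, g=10.0, m=1.0, l=1.0, dt=0.05, max_speed=8.0):
        # angular acceleration: gravity term + torque / (m * l^2)
        newthdot = thdot + (-3 * g / (2 * l) * np.sin(th + np.pi) + 3.0 / (m * l ** 2) * u) * dt
        newth = th + newthdot * dt
        return newth, np.clip(newthdot, -max_speed, max_speed)
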

View File

@@ -54,11 +54,7 @@ def get_display(spec):
elif isinstance(spec, str):
return pyglet.canvas.Display(spec)
else:
raise error.Error(
"Invalid display specification: {}. (Must be a string like :0 or None.)".format(
spec
)
)
raise error.Error("Invalid display specification: {}. (Must be a string like :0 or None.)".format(spec))
def get_window(width, height, display, **kwargs):
@@ -69,14 +65,7 @@ def get_window(width, height, display, **kwargs):
config = screen[0].get_best_config() # selecting the first screen
context = config.create_context(None) # create GL context
return pyglet.window.Window(
width=width,
height=height,
display=display,
config=config,
context=context,
**kwargs
)
return pyglet.window.Window(width=width, height=height, display=display, config=config, context=context, **kwargs)
class Viewer(object):
@@ -108,9 +97,7 @@ class Viewer(object):
assert right > left and top > bottom
scalex = self.width / (right - left)
scaley = self.height / (top - bottom)
self.transform = Transform(
translation=(-left * scalex, -bottom * scaley), scale=(scalex, scaley)
)
self.transform = Transform(translation=(-left * scalex, -bottom * scaley), scale=(scalex, scaley))
def add_geom(self, geom):
self.geoms.append(geom)
@@ -173,9 +160,7 @@ class Viewer(object):
def get_array(self):
self.window.flip()
image_data = (
pyglet.image.get_buffer_manager().get_color_buffer().get_image_data()
)
image_data = pyglet.image.get_buffer_manager().get_color_buffer().get_image_data()
self.window.flip()
arr = np.fromstring(image_data.get_data(), dtype=np.uint8, sep="")
arr = arr.reshape(self.height, self.width, 4)
@@ -230,9 +215,7 @@ class Transform(Attr):
def enable(self):
glPushMatrix()
glTranslatef(
self.translation[0], self.translation[1], 0
) # translate to GL loc point
glTranslatef(self.translation[0], self.translation[1], 0) # translate to GL loc point
glRotatef(RAD2DEG * self.rotation, 0, 0, 1.0)
glScalef(self.scale[0], self.scale[1], 1)
@@ -392,9 +375,7 @@ class Image(Geom):
self.flip = False
def render1(self):
self.img.blit(
-self.width / 2, -self.height / 2, width=self.width, height=self.height
)
self.img.blit(-self.width / 2, -self.height / 2, width=self.width, height=self.height)
# ================================================================
@@ -435,9 +416,7 @@ class SimpleImageViewer(object):
self.isopen = False
assert len(arr.shape) == 3, "You passed in an image with the wrong number of dimensions"
image = pyglet.image.ImageData(
arr.shape[1], arr.shape[0], "RGB", arr.tobytes(), pitch=arr.shape[1] * -3
)
image = pyglet.image.ImageData(arr.shape[1], arr.shape[0], "RGB", arr.tobytes(), pitch=arr.shape[1] * -3)
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_NEAREST)
texture = image.get_texture()
texture.width = self.width

View File

@@ -14,9 +14,7 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
xposafter = self.get_body_com("torso")[0]
forward_reward = (xposafter - xposbefore) / self.dt
ctrl_cost = 0.5 * np.square(a).sum()
contact_cost = (
0.5 * 1e-3 * np.sum(np.square(np.clip(self.sim.data.cfrc_ext, -1, 1)))
)
contact_cost = 0.5 * 1e-3 * np.sum(np.square(np.clip(self.sim.data.cfrc_ext, -1, 1)))
survive_reward = 1.0
reward = forward_reward - ctrl_cost - contact_cost + survive_reward
state = self.state_vector()
@@ -45,9 +43,7 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
)
def reset_model(self):
qpos = self.init_qpos + self.np_random.uniform(
size=self.model.nq, low=-0.1, high=0.1
)
qpos = self.init_qpos + self.np_random.uniform(size=self.model.nq, low=-0.1, high=0.1)
qvel = self.init_qvel + self.np_random.randn(self.model.nv) * 0.1
self.set_state(qpos, qvel)
return self._get_obs()
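
The reset pattern above, repeated across the MuJoCo envs below, perturbs the initial state with small uniform noise on positions and Gaussian noise on velocities so episodes do not all start from an identical state. The pattern in isolation (a sketch with placeholder nq/nv sizes):

    import numpy as np

    rng = np.random.RandomState(0)
    init_qpos, init_qvel = np.zeros(15), np.zeros(14)  # placeholder nq/nv
    qpos = init_qpos + rng.uniform(low=-0.1, high=0.1, size=init_qpos.size)
    qvel = init_qvel + rng.randn(init_qvel.size) * 0.1
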

View File

@@ -34,18 +34,13 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
self._reset_noise_scale = reset_noise_scale
self._exclude_current_positions_from_observation = (
exclude_current_positions_from_observation
)
self._exclude_current_positions_from_observation = exclude_current_positions_from_observation
mujoco_env.MujocoEnv.__init__(self, xml_file, 5)
@property
def healthy_reward(self):
return (
float(self.is_healthy or self._terminate_when_unhealthy)
* self._healthy_reward
)
return float(self.is_healthy or self._terminate_when_unhealthy) * self._healthy_reward
def control_cost(self, action):
control_cost = self._ctrl_cost_weight * np.sum(np.square(action))
@@ -60,9 +55,7 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
@property
def contact_cost(self):
contact_cost = self._contact_cost_weight * np.sum(
np.square(self.contact_forces)
)
contact_cost = self._contact_cost_weight * np.sum(np.square(self.contact_forces))
return contact_cost
@property
@@ -128,12 +121,8 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
noise_low = -self._reset_noise_scale
noise_high = self._reset_noise_scale
qpos = self.init_qpos + self.np_random.uniform(
low=noise_low, high=noise_high, size=self.model.nq
)
qvel = self.init_qvel + self._reset_noise_scale * self.np_random.randn(
self.model.nv
)
qpos = self.init_qpos + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nq)
qvel = self.init_qvel + self._reset_noise_scale * self.np_random.randn(self.model.nv)
self.set_state(qpos, qvel)
observation = self._get_obs()

View File

@@ -28,9 +28,7 @@ class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):
)
def reset_model(self):
qpos = self.init_qpos + self.np_random.uniform(
low=-0.1, high=0.1, size=self.model.nq
)
qpos = self.init_qpos + self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq)
qvel = self.init_qvel + self.np_random.randn(self.model.nv) * 0.1
self.set_state(qpos, qvel)
return self._get_obs()

View File

@@ -25,9 +25,7 @@ class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):
self._reset_noise_scale = reset_noise_scale
self._exclude_current_positions_from_observation = (
exclude_current_positions_from_observation
)
self._exclude_current_positions_from_observation = exclude_current_positions_from_observation
mujoco_env.MujocoEnv.__init__(self, xml_file, 5)
@@ -71,12 +69,8 @@ class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):
noise_low = -self._reset_noise_scale
noise_high = self._reset_noise_scale
qpos = self.init_qpos + self.np_random.uniform(
low=noise_low, high=noise_high, size=self.model.nq
)
qvel = self.init_qvel + self._reset_noise_scale * self.np_random.randn(
self.model.nv
)
qpos = self.init_qpos + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nq)
qvel = self.init_qvel + self._reset_noise_scale * self.np_random.randn(self.model.nv)
self.set_state(qpos, qvel)

View File

@@ -17,27 +17,16 @@ class HopperEnv(mujoco_env.MujocoEnv, utils.EzPickle):
reward += alive_bonus
reward -= 1e-3 * np.square(a).sum()
s = self.state_vector()
done = not (
np.isfinite(s).all()
and (np.abs(s[2:]) < 100).all()
and (height > 0.7)
and (abs(ang) < 0.2)
)
done = not (np.isfinite(s).all() and (np.abs(s[2:]) < 100).all() and (height > 0.7) and (abs(ang) < 0.2))
ob = self._get_obs()
return ob, reward, done, {}
def _get_obs(self):
return np.concatenate(
[self.sim.data.qpos.flat[1:], np.clip(self.sim.data.qvel.flat, -10, 10)]
)
return np.concatenate([self.sim.data.qpos.flat[1:], np.clip(self.sim.data.qvel.flat, -10, 10)])
def reset_model(self):
qpos = self.init_qpos + self.np_random.uniform(
low=-0.005, high=0.005, size=self.model.nq
)
qvel = self.init_qvel + self.np_random.uniform(
low=-0.005, high=0.005, size=self.model.nv
)
qpos = self.init_qpos + self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nq)
qvel = self.init_qvel + self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nv)
self.set_state(qpos, qvel)
return self._get_obs()

View File

@@ -40,18 +40,13 @@ class HopperEnv(mujoco_env.MujocoEnv, utils.EzPickle):
self._reset_noise_scale = reset_noise_scale
self._exclude_current_positions_from_observation = (
exclude_current_positions_from_observation
)
self._exclude_current_positions_from_observation = exclude_current_positions_from_observation
mujoco_env.MujocoEnv.__init__(self, xml_file, 4)
@property
def healthy_reward(self):
return (
float(self.is_healthy or self._terminate_when_unhealthy)
* self._healthy_reward
)
return float(self.is_healthy or self._terminate_when_unhealthy) * self._healthy_reward
def control_cost(self, action):
control_cost = self._ctrl_cost_weight * np.sum(np.square(action))
@@ -117,12 +112,8 @@ class HopperEnv(mujoco_env.MujocoEnv, utils.EzPickle):
noise_low = -self._reset_noise_scale
noise_high = self._reset_noise_scale
qpos = self.init_qpos + self.np_random.uniform(
low=noise_low, high=noise_high, size=self.model.nq
)
qvel = self.init_qvel + self.np_random.uniform(
low=noise_low, high=noise_high, size=self.model.nv
)
qpos = self.init_qpos + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nq)
qvel = self.init_qvel + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nv)
self.set_state(qpos, qvel)

View File

@@ -43,18 +43,13 @@ class HumanoidEnv(mujoco_env.MujocoEnv, utils.EzPickle):
self._reset_noise_scale = reset_noise_scale
self._exclude_current_positions_from_observation = (
exclude_current_positions_from_observation
)
self._exclude_current_positions_from_observation = exclude_current_positions_from_observation
mujoco_env.MujocoEnv.__init__(self, xml_file, 5)
@property
def healthy_reward(self):
return (
float(self.is_healthy or self._terminate_when_unhealthy)
* self._healthy_reward
)
return float(self.is_healthy or self._terminate_when_unhealthy) * self._healthy_reward
def control_cost(self, action):
control_cost = self._ctrl_cost_weight * np.sum(np.square(self.sim.data.ctrl))
@@ -143,12 +138,8 @@ class HumanoidEnv(mujoco_env.MujocoEnv, utils.EzPickle):
noise_low = -self._reset_noise_scale
noise_high = self._reset_noise_scale
qpos = self.init_qpos + self.np_random.uniform(
low=noise_low, high=noise_high, size=self.model.nq
)
qvel = self.init_qvel + self.np_random.uniform(
low=noise_low, high=noise_high, size=self.model.nv
)
qpos = self.init_qpos + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nq)
qvel = self.init_qvel + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nv)
self.set_state(qpos, qvel)
observation = self._get_obs()

View File

@@ -33,8 +33,7 @@ class InvertedDoublePendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle):
def reset_model(self):
self.set_state(
self.init_qpos
+ self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq),
self.init_qpos + self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq),
self.init_qvel + self.np_random.randn(self.model.nv) * 0.1,
)
return self._get_obs()

View File

@@ -17,12 +17,8 @@ class InvertedPendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle):
return ob, reward, done, {}
def reset_model(self):
qpos = self.init_qpos + self.np_random.uniform(
size=self.model.nq, low=-0.01, high=0.01
)
qvel = self.init_qvel + self.np_random.uniform(
size=self.model.nv, low=-0.01, high=0.01
)
qpos = self.init_qpos + self.np_random.uniform(size=self.model.nq, low=-0.01, high=0.01)
qvel = self.init_qvel + self.np_random.uniform(size=self.model.nv, low=-0.01, high=0.01)
self.set_state(qpos, qvel)
return self._get_obs()

View File

@@ -22,14 +22,7 @@ DEFAULT_SIZE = 500
def convert_observation_to_space(observation):
if isinstance(observation, dict):
space = spaces.Dict(
OrderedDict(
[
(key, convert_observation_to_space(value))
for key, value in observation.items()
]
)
)
space = spaces.Dict(OrderedDict([(key, convert_observation_to_space(value)) for key, value in observation.items()]))
elif isinstance(observation, np.ndarray):
low = np.full(observation.shape, -float("inf"), dtype=np.float32)
high = np.full(observation.shape, float("inf"), dtype=np.float32)
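
convert_observation_to_space mirrors an observation's structure: a dict becomes a spaces.Dict of recursively converted values, and an ndarray becomes an unbounded Box of the same shape. For a flat dict observation the result is equivalent to (a sketch):

    import numpy as np
    from collections import OrderedDict
    from gym import spaces

    obs = {"qpos": np.zeros(3, dtype=np.float32), "qvel": np.zeros(2, dtype=np.float32)}
    space = spaces.Dict(OrderedDict(
        (key, spaces.Box(-np.inf, np.inf, shape=value.shape, dtype=np.float32))
        for key, value in obs.items()
    ))
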
@@ -117,9 +110,7 @@ class MujocoEnv(gym.Env):
def set_state(self, qpos, qvel):
assert qpos.shape == (self.model.nq,) and qvel.shape == (self.model.nv,)
old_state = self.sim.get_state()
new_state = mujoco_py.MjSimState(
old_state.time, qpos, qvel, old_state.act, old_state.udd_state
)
new_state = mujoco_py.MjSimState(old_state.time, qpos, qvel, old_state.act, old_state.udd_state)
self.sim.set_state(new_state)
self.sim.forward()
@@ -142,10 +133,7 @@ class MujocoEnv(gym.Env):
):
if mode == "rgb_array" or mode == "depth_array":
if camera_id is not None and camera_name is not None:
raise ValueError(
"Both `camera_id` and `camera_name` cannot be"
" specified at the same time."
)
raise ValueError("Both `camera_id` and `camera_name` cannot be" " specified at the same time.")
no_camera_specified = camera_name is None and camera_id is None
if no_camera_specified:

View File

@@ -44,9 +44,7 @@ class PusherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
qpos[-4:-2] = self.cylinder_pos
qpos[-2:] = self.goal_pos
qvel = self.init_qvel + self.np_random.uniform(
low=-0.005, high=0.005, size=self.model.nv
)
qvel = self.init_qvel + self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nv)
qvel[-4:] = 0
self.set_state(qpos, qvel)
return self._get_obs()

View File

@@ -22,18 +22,13 @@ class ReacherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
self.viewer.cam.trackbodyid = 0
def reset_model(self):
qpos = (
self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq)
+ self.init_qpos
)
qpos = self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq) + self.init_qpos
while True:
self.goal = self.np_random.uniform(low=-0.2, high=0.2, size=2)
if np.linalg.norm(self.goal) < 0.2:
break
qpos[-2:] = self.goal
qvel = self.init_qvel + self.np_random.uniform(
low=-0.005, high=0.005, size=self.model.nv
)
qvel = self.init_qvel + self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nv)
qvel[-2:] = 0
self.set_state(qpos, qvel)
return self._get_obs()

View File

@@ -62,9 +62,7 @@ class StrikerEnv(mujoco_env.MujocoEnv, utils.EzPickle):
diff = self.ball - self.goal
angle = -np.arctan(diff[0] / (diff[1] + 1e-8))
qpos[-1] = angle / 3.14
qvel = self.init_qvel + self.np_random.uniform(
low=-0.1, high=0.1, size=self.model.nv
)
qvel = self.init_qvel + self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nv)
qvel[7:] = 0
self.set_state(qpos, qvel)
return self._get_obs()

View File

@@ -26,9 +26,7 @@ class SwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle):
def reset_model(self):
self.set_state(
self.init_qpos
+ self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq),
self.init_qvel
+ self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nv),
self.init_qpos + self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq),
self.init_qvel + self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nv),
)
return self._get_obs()

View File

@@ -22,9 +22,7 @@ class SwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle):
self._reset_noise_scale = reset_noise_scale
self._exclude_current_positions_from_observation = (
exclude_current_positions_from_observation
)
self._exclude_current_positions_from_observation = exclude_current_positions_from_observation
mujoco_env.MujocoEnv.__init__(self, xml_file, 4)
@@ -74,12 +72,8 @@ class SwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle):
noise_low = -self._reset_noise_scale
noise_high = self._reset_noise_scale
qpos = self.init_qpos + self.np_random.uniform(
low=noise_low, high=noise_high, size=self.model.nq
)
qvel = self.init_qvel + self.np_random.uniform(
low=noise_low, high=noise_high, size=self.model.nv
)
qpos = self.init_qpos + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nq)
qvel = self.init_qvel + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nv)
self.set_state(qpos, qvel)

View File

@@ -48,9 +48,7 @@ class ThrowerEnv(mujoco_env.MujocoEnv, utils.EzPickle):
)
qpos[-9:-7] = self.goal
qvel = self.init_qvel + self.np_random.uniform(
low=-0.005, high=0.005, size=self.model.nv
)
qvel = self.init_qvel + self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nv)
qvel[7:] = 0
self.set_state(qpos, qvel)
return self._get_obs()

View File

@@ -27,10 +27,8 @@ class Walker2dEnv(mujoco_env.MujocoEnv, utils.EzPickle):
def reset_model(self):
self.set_state(
self.init_qpos
+ self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nq),
self.init_qvel
+ self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nv),
self.init_qpos + self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nq),
self.init_qvel + self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nv),
)
return self._get_obs()

View File

@@ -37,18 +37,13 @@ class Walker2dEnv(mujoco_env.MujocoEnv, utils.EzPickle):
self._reset_noise_scale = reset_noise_scale
self._exclude_current_positions_from_observation = (
exclude_current_positions_from_observation
)
self._exclude_current_positions_from_observation = exclude_current_positions_from_observation
mujoco_env.MujocoEnv.__init__(self, xml_file, 4)
@property
def healthy_reward(self):
return (
float(self.is_healthy or self._terminate_when_unhealthy)
* self._healthy_reward
)
return float(self.is_healthy or self._terminate_when_unhealthy) * self._healthy_reward
def control_cost(self, action):
control_cost = self._ctrl_cost_weight * np.sum(np.square(action))
@@ -110,12 +105,8 @@ class Walker2dEnv(mujoco_env.MujocoEnv, utils.EzPickle):
noise_low = -self._reset_noise_scale
noise_high = self._reset_noise_scale
qpos = self.init_qpos + self.np_random.uniform(
low=noise_low, high=noise_high, size=self.model.nq
)
qvel = self.init_qvel + self.np_random.uniform(
low=noise_low, high=noise_high, size=self.model.nv
)
qpos = self.init_qpos + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nq)
qvel = self.init_qvel + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nv)
self.set_state(qpos, qvel)

View File

@@ -108,11 +108,7 @@ class EnvRegistry(object):
# reset/step. Set _gym_disable_underscore_compat = True on
# your environment if you use these methods and don't want
# compatibility code to be invoked.
if (
hasattr(env, "_reset")
and hasattr(env, "_step")
and not getattr(env, "_gym_disable_underscore_compat", False)
):
if hasattr(env, "_reset") and hasattr(env, "_step") and not getattr(env, "_gym_disable_underscore_compat", False):
patch_deprecated_methods(env)
if env.spec.max_episode_steps is not None:
from gym.wrappers.time_limit import TimeLimit
@@ -158,11 +154,7 @@ class EnvRegistry(object):
if env_name == valid_env_spec._env_name
]
if matching_envs:
raise error.DeprecatedEnv(
"Env {} not found (valid versions include {})".format(
id, matching_envs
)
)
raise error.DeprecatedEnv("Env {} not found (valid versions include {})".format(id, matching_envs))
else:
raise error.UnregisteredEnv("No registered env with id: {}".format(id))

View File

@@ -81,9 +81,7 @@ class FetchEnv(robot_env.RobotEnv):
def _set_action(self, action):
assert action.shape == (4,)
action = (
action.copy()
) # ensure that we don't change the action outside of this scope
action = action.copy() # ensure that we don't change the action outside of this scope
pos_ctrl, gripper_ctrl = action[:3], action[3]
pos_ctrl *= 0.05 # limit maximum change in position
@@ -120,13 +118,9 @@ class FetchEnv(robot_env.RobotEnv):
object_rel_pos = object_pos - grip_pos
object_velp -= grip_velp
else:
object_pos = (
object_rot
) = object_velp = object_velr = object_rel_pos = np.zeros(0)
object_pos = object_rot = object_velp = object_velr = object_rel_pos = np.zeros(0)
gripper_state = robot_qpos[-2:]
gripper_vel = (
robot_qvel[-2:] * dt
) # change to a scalar if the gripper is made symmetric
gripper_vel = robot_qvel[-2:] * dt # change to a scalar if the gripper is made symmetric
if not self.has_object:
achieved_goal = grip_pos.copy()
@@ -175,9 +169,7 @@ class FetchEnv(robot_env.RobotEnv):
if self.has_object:
object_xpos = self.initial_gripper_xpos[:2]
while np.linalg.norm(object_xpos - self.initial_gripper_xpos[:2]) < 0.1:
object_xpos = self.initial_gripper_xpos[:2] + self.np_random.uniform(
-self.obj_range, self.obj_range, size=2
)
object_xpos = self.initial_gripper_xpos[:2] + self.np_random.uniform(-self.obj_range, self.obj_range, size=2)
object_qpos = self.sim.data.get_joint_qpos("object0:joint")
assert object_qpos.shape == (7,)
object_qpos[:2] = object_xpos
@@ -188,17 +180,13 @@ class FetchEnv(robot_env.RobotEnv):
def _sample_goal(self):
if self.has_object:
goal = self.initial_gripper_xpos[:3] + self.np_random.uniform(
-self.target_range, self.target_range, size=3
)
goal = self.initial_gripper_xpos[:3] + self.np_random.uniform(-self.target_range, self.target_range, size=3)
goal += self.target_offset
goal[2] = self.height_offset
if self.target_in_the_air and self.np_random.uniform() < 0.5:
goal[2] += self.np_random.uniform(0, 0.45)
else:
goal = self.initial_gripper_xpos[:3] + self.np_random.uniform(
-self.target_range, self.target_range, size=3
)
goal = self.initial_gripper_xpos[:3] + self.np_random.uniform(-self.target_range, self.target_range, size=3)
return goal.copy()
def _is_success(self, achieved_goal, desired_goal):
@@ -212,9 +200,9 @@ class FetchEnv(robot_env.RobotEnv):
self.sim.forward()
# Move end effector into position.
gripper_target = np.array(
[-0.498, 0.005, -0.431 + self.gripper_extra_height]
) + self.sim.data.get_site_xpos("robot0:grip")
gripper_target = np.array([-0.498, 0.005, -0.431 + self.gripper_extra_height]) + self.sim.data.get_site_xpos(
"robot0:grip"
)
gripper_rotation = np.array([1.0, 0.0, 1.0, 0.0])
self.sim.data.set_mocap_pos("robot0:mocap", gripper_target)
self.sim.data.set_mocap_quat("robot0:mocap", gripper_rotation)

View File

@@ -74,9 +74,7 @@ class ManipulateEnv(hand_env.HandEnv):
self.target_position = target_position
self.target_rotation = target_rotation
self.target_position_range = target_position_range
self.parallel_quats = [
rotations.euler2quat(r) for r in rotations.get_parallel_rotations()
]
self.parallel_quats = [rotations.euler2quat(r) for r in rotations.get_parallel_rotations()]
self.randomize_initial_rotation = randomize_initial_rotation
self.randomize_initial_position = randomize_initial_position
self.distance_threshold = distance_threshold
@@ -182,9 +180,7 @@ class ManipulateEnv(hand_env.HandEnv):
angle = self.np_random.uniform(-np.pi, np.pi)
axis = np.array([0.0, 0.0, 1.0])
z_quat = quat_from_angle_and_axis(angle, axis)
parallel_quat = self.parallel_quats[
self.np_random.randint(len(self.parallel_quats))
]
parallel_quat = self.parallel_quats[self.np_random.randint(len(self.parallel_quats))]
offset_quat = rotations.quat_mul(z_quat, parallel_quat)
initial_quat = rotations.quat_mul(initial_quat, offset_quat)
elif self.target_rotation in ["xyz", "ignore"]:
@@ -195,9 +191,7 @@ class ManipulateEnv(hand_env.HandEnv):
elif self.target_rotation == "fixed":
pass
else:
raise error.Error(
'Unknown target_rotation option "{}".'.format(self.target_rotation)
)
raise error.Error('Unknown target_rotation option "{}".'.format(self.target_rotation))
# Randomize initial position.
if self.randomize_initial_position:
@@ -229,17 +223,13 @@ class ManipulateEnv(hand_env.HandEnv):
target_pos = None
if self.target_position == "random":
assert self.target_position_range.shape == (3, 2)
offset = self.np_random.uniform(
self.target_position_range[:, 0], self.target_position_range[:, 1]
)
offset = self.np_random.uniform(self.target_position_range[:, 0], self.target_position_range[:, 1])
assert offset.shape == (3,)
target_pos = self.sim.data.get_joint_qpos("object:joint")[:3] + offset
elif self.target_position in ["ignore", "fixed"]:
target_pos = self.sim.data.get_joint_qpos("object:joint")[:3]
else:
raise error.Error(
'Unknown target_position option "{}".'.format(self.target_position)
)
raise error.Error('Unknown target_position option "{}".'.format(self.target_position))
assert target_pos is not None
assert target_pos.shape == (3,)
@@ -253,9 +243,7 @@ class ManipulateEnv(hand_env.HandEnv):
angle = self.np_random.uniform(-np.pi, np.pi)
axis = np.array([0.0, 0.0, 1.0])
target_quat = quat_from_angle_and_axis(angle, axis)
parallel_quat = self.parallel_quats[
self.np_random.randint(len(self.parallel_quats))
]
parallel_quat = self.parallel_quats[self.np_random.randint(len(self.parallel_quats))]
target_quat = rotations.quat_mul(target_quat, parallel_quat)
elif self.target_rotation == "xyz":
angle = self.np_random.uniform(-np.pi, np.pi)
@@ -264,9 +252,7 @@ class ManipulateEnv(hand_env.HandEnv):
elif self.target_rotation in ["ignore", "fixed"]:
target_quat = self.sim.data.get_joint_qpos("object:joint")
else:
raise error.Error(
'Unknown target_rotation option "{}".'.format(self.target_rotation)
)
raise error.Error('Unknown target_rotation option "{}".'.format(self.target_rotation))
assert target_quat is not None
assert target_quat.shape == (4,)
@@ -293,12 +279,8 @@ class ManipulateEnv(hand_env.HandEnv):
def _get_obs(self):
robot_qpos, robot_qvel = robot_get_obs(self.sim)
object_qvel = self.sim.data.get_joint_qvel("object:joint")
achieved_goal = (
self._get_achieved_goal().ravel()
) # this contains the object position + rotation
observation = np.concatenate(
[robot_qpos, robot_qvel, object_qvel, achieved_goal]
)
achieved_goal = self._get_achieved_goal().ravel() # this contains the object position + rotation
observation = np.concatenate([robot_qpos, robot_qvel, object_qvel, achieved_goal])
return {
"observation": observation.copy(),
"achieved_goal": achieved_goal.copy(),
@@ -307,9 +289,7 @@ class ManipulateEnv(hand_env.HandEnv):
class HandBlockEnv(ManipulateEnv, utils.EzPickle):
def __init__(
self, target_position="random", target_rotation="xyz", reward_type="sparse"
):
def __init__(self, target_position="random", target_rotation="xyz", reward_type="sparse"):
utils.EzPickle.__init__(self, target_position, target_rotation, reward_type)
ManipulateEnv.__init__(
self,
@@ -322,9 +302,7 @@ class HandBlockEnv(ManipulateEnv, utils.EzPickle):
class HandEggEnv(ManipulateEnv, utils.EzPickle):
def __init__(
self, target_position="random", target_rotation="xyz", reward_type="sparse"
):
def __init__(self, target_position="random", target_rotation="xyz", reward_type="sparse"):
utils.EzPickle.__init__(self, target_position, target_rotation, reward_type)
ManipulateEnv.__init__(
self,
@@ -337,9 +315,7 @@ class HandEggEnv(ManipulateEnv, utils.EzPickle):
class HandPenEnv(ManipulateEnv, utils.EzPickle):
def __init__(
self, target_position="random", target_rotation="xyz", reward_type="sparse"
):
def __init__(self, target_position="random", target_rotation="xyz", reward_type="sparse"):
utils.EzPickle.__init__(self, target_position, target_rotation, reward_type)
ManipulateEnv.__init__(
self,

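With mujoco_py installed (plus a MuJoCo license, as the test hunks further down note), these manipulation classes are reachable through the registry. A minimal sketch, assuming the registered id HandManipulateBlock-v0:

import gym

env = gym.make("HandManipulateBlock-v0")  # needs mujoco_py; raises DependencyNotInstalled otherwise
obs = env.reset()
print(sorted(obs.keys()))  # goal-style dict: achieved_goal, desired_goal, observation
env.close()
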
View File

@@ -70,16 +70,12 @@ class ManipulateTouchSensorsEnv(manipulate.ManipulateEnv):
for (
k,
v,
) in (
self.sim.model._sensor_name2id.items()
): # get touch sensor site names and their ids
) in self.sim.model._sensor_name2id.items(): # get touch sensor site names and their ids
if "robot0:TS_" in k:
self._touch_sensor_id_site_id.append(
(
v,
self.sim.model._site_name2id[
k.replace("robot0:TS_", "robot0:T_")
],
self.sim.model._site_name2id[k.replace("robot0:TS_", "robot0:T_")],
)
)
self._touch_sensor_id.append(v)
@@ -93,15 +89,9 @@ class ManipulateTouchSensorsEnv(manipulate.ManipulateEnv):
obs = self._get_obs()
self.observation_space = spaces.Dict(
dict(
desired_goal=spaces.Box(
-np.inf, np.inf, shape=obs["achieved_goal"].shape, dtype="float32"
),
achieved_goal=spaces.Box(
-np.inf, np.inf, shape=obs["achieved_goal"].shape, dtype="float32"
),
observation=spaces.Box(
-np.inf, np.inf, shape=obs["observation"].shape, dtype="float32"
),
desired_goal=spaces.Box(-np.inf, np.inf, shape=obs["achieved_goal"].shape, dtype="float32"),
achieved_goal=spaces.Box(-np.inf, np.inf, shape=obs["achieved_goal"].shape, dtype="float32"),
observation=spaces.Box(-np.inf, np.inf, shape=obs["observation"].shape, dtype="float32"),
)
)
@@ -117,9 +107,7 @@ class ManipulateTouchSensorsEnv(manipulate.ManipulateEnv):
def _get_obs(self):
robot_qpos, robot_qvel = manipulate.robot_get_obs(self.sim)
object_qvel = self.sim.data.get_joint_qvel("object:joint")
achieved_goal = (
self._get_achieved_goal().ravel()
) # this contains the object position + rotation
achieved_goal = self._get_achieved_goal().ravel() # this contains the object position + rotation
touch_values = [] # get touch sensor readings. if there is one, set value to 1
if self.touch_get_obs == "sensordata":
touch_values = self.sim.data.sensordata[self._touch_sensor_id]
@@ -127,9 +115,7 @@ class ManipulateTouchSensorsEnv(manipulate.ManipulateEnv):
touch_values = self.sim.data.sensordata[self._touch_sensor_id] > 0.0
elif self.touch_get_obs == "log":
touch_values = np.log(self.sim.data.sensordata[self._touch_sensor_id] + 1.0)
observation = np.concatenate(
[robot_qpos, robot_qvel, object_qvel, touch_values, achieved_goal]
)
observation = np.concatenate([robot_qpos, robot_qvel, object_qvel, touch_values, achieved_goal])
return {
"observation": observation.copy(),
@@ -146,9 +132,7 @@ class HandBlockTouchSensorsEnv(ManipulateTouchSensorsEnv, utils.EzPickle):
touch_get_obs="sensordata",
reward_type="sparse",
):
utils.EzPickle.__init__(
self, target_position, target_rotation, touch_get_obs, reward_type
)
utils.EzPickle.__init__(self, target_position, target_rotation, touch_get_obs, reward_type)
ManipulateTouchSensorsEnv.__init__(
self,
model_path=MANIPULATE_BLOCK_XML,
@@ -168,9 +152,7 @@ class HandEggTouchSensorsEnv(ManipulateTouchSensorsEnv, utils.EzPickle):
touch_get_obs="sensordata",
reward_type="sparse",
):
utils.EzPickle.__init__(
self, target_position, target_rotation, touch_get_obs, reward_type
)
utils.EzPickle.__init__(self, target_position, target_rotation, touch_get_obs, reward_type)
ManipulateTouchSensorsEnv.__init__(
self,
model_path=MANIPULATE_EGG_XML,
@@ -190,9 +172,7 @@ class HandPenTouchSensorsEnv(ManipulateTouchSensorsEnv, utils.EzPickle):
touch_get_obs="sensordata",
reward_type="sparse",
):
utils.EzPickle.__init__(
self, target_position, target_rotation, touch_get_obs, reward_type
)
utils.EzPickle.__init__(self, target_position, target_rotation, touch_get_obs, reward_type)
ManipulateTouchSensorsEnv.__init__(
self,
model_path=MANIPULATE_PEN_XML,

View File

@@ -96,9 +96,7 @@ class HandReachEnv(hand_env.HandEnv, utils.EzPickle):
self.sim.forward()
self.initial_goal = self._get_achieved_goal().copy()
self.palm_xpos = self.sim.data.body_xpos[
self.sim.model.body_name2id("robot0:palm")
].copy()
self.palm_xpos = self.sim.data.body_xpos[self.sim.model.body_name2id("robot0:palm")].copy()
def _get_obs(self):
robot_qpos, robot_qvel = robot_get_obs(self.sim)
@@ -155,7 +153,5 @@ class HandReachEnv(hand_env.HandEnv, utils.EzPickle):
for finger_idx in range(5):
site_name = "finger{}".format(finger_idx)
site_id = self.sim.model.site_name2id(site_name)
self.sim.model.site_pos[site_id] = (
achieved_goal[finger_idx] - sites_offset[site_id]
)
self.sim.model.site_pos[site_id] = achieved_goal[finger_idx] - sites_offset[site_id]
self.sim.forward()

View File

@@ -30,22 +30,14 @@ class HandEnv(robot_env.RobotEnv):
if self.relative_control:
actuation_center = np.zeros_like(action)
for i in range(self.sim.data.ctrl.shape[0]):
actuation_center[i] = self.sim.data.get_joint_qpos(
self.sim.model.actuator_names[i].replace(":A_", ":")
)
actuation_center[i] = self.sim.data.get_joint_qpos(self.sim.model.actuator_names[i].replace(":A_", ":"))
for joint_name in ["FF", "MF", "RF", "LF"]:
act_idx = self.sim.model.actuator_name2id(
"robot0:A_{}J1".format(joint_name)
)
actuation_center[act_idx] += self.sim.data.get_joint_qpos(
"robot0:{}J0".format(joint_name)
)
act_idx = self.sim.model.actuator_name2id("robot0:A_{}J1".format(joint_name))
actuation_center[act_idx] += self.sim.data.get_joint_qpos("robot0:{}J0".format(joint_name))
else:
actuation_center = (ctrlrange[:, 1] + ctrlrange[:, 0]) / 2.0
self.sim.data.ctrl[:] = actuation_center + action * actuation_range
self.sim.data.ctrl[:] = np.clip(
self.sim.data.ctrl, ctrlrange[:, 0], ctrlrange[:, 1]
)
self.sim.data.ctrl[:] = np.clip(self.sim.data.ctrl, ctrlrange[:, 0], ctrlrange[:, 1])
def _viewer_setup(self):
body_id = self.sim.model.body_name2id("robot0:palm")

View File

@@ -46,15 +46,9 @@ class RobotEnv(gym.GoalEnv):
self.action_space = spaces.Box(-1.0, 1.0, shape=(n_actions,), dtype="float32")
self.observation_space = spaces.Dict(
dict(
desired_goal=spaces.Box(
-np.inf, np.inf, shape=obs["achieved_goal"].shape, dtype="float32"
),
achieved_goal=spaces.Box(
-np.inf, np.inf, shape=obs["achieved_goal"].shape, dtype="float32"
),
observation=spaces.Box(
-np.inf, np.inf, shape=obs["observation"].shape, dtype="float32"
),
desired_goal=spaces.Box(-np.inf, np.inf, shape=obs["achieved_goal"].shape, dtype="float32"),
achieved_goal=spaces.Box(-np.inf, np.inf, shape=obs["achieved_goal"].shape, dtype="float32"),
observation=spaces.Box(-np.inf, np.inf, shape=obs["observation"].shape, dtype="float32"),
)
)

View File

@@ -164,12 +164,8 @@ def mat2euler(mat):
-np.arctan2(mat[..., 0, 1], mat[..., 0, 0]),
-np.arctan2(-mat[..., 1, 0], mat[..., 1, 1]),
)
euler[..., 1] = np.where(
condition, -np.arctan2(-mat[..., 0, 2], cy), -np.arctan2(-mat[..., 0, 2], cy)
)
euler[..., 0] = np.where(
condition, -np.arctan2(mat[..., 1, 2], mat[..., 2, 2]), 0.0
)
euler[..., 1] = np.where(condition, -np.arctan2(-mat[..., 0, 2], cy), -np.arctan2(-mat[..., 0, 2], cy))
euler[..., 0] = np.where(condition, -np.arctan2(mat[..., 1, 2], mat[..., 2, 2]), 0.0)
return euler

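mat2euler and its inverses live in gym.envs.robotics.rotations, the same module the manipulation hunks call into. A round-trip sketch (assuming mujoco_py is available, since importing the robotics package pulls it in):

import numpy as np

from gym.envs.robotics import rotations

euler = np.array([0.1, -0.2, 0.3])
mat = rotations.euler2mat(euler)      # 3x3 rotation matrix
recovered = rotations.mat2euler(mat)  # back to intrinsic Euler angles
assert np.allclose(recovered, euler)
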
View File

@@ -75,15 +75,9 @@ def reset_mocap2body_xpos(sim):
values as the bodies they're welded to.
"""
if (
sim.model.eq_type is None
or sim.model.eq_obj1id is None
or sim.model.eq_obj2id is None
):
if sim.model.eq_type is None or sim.model.eq_obj1id is None or sim.model.eq_obj2id is None:
return
for eq_type, obj1_id, obj2_id in zip(
sim.model.eq_type, sim.model.eq_obj1id, sim.model.eq_obj2id
):
for eq_type, obj1_id, obj2_id in zip(sim.model.eq_type, sim.model.eq_obj1id, sim.model.eq_obj2id):
if eq_type != mujoco_py.const.EQ_WELD:
continue

View File

@@ -2,10 +2,7 @@ from gym import envs, logger
import os
SKIP_MUJOCO_WARNING_MESSAGE = (
"Cannot run mujoco test (either license key not found or mujoco not"
"installed properly)."
)
SKIP_MUJOCO_WARNING_MESSAGE = "Cannot run mujoco test (either license key not found or mujoco not" "installed properly)."
skip_mujoco = not (os.environ.get("MUJOCO_KEY"))
@@ -21,9 +18,7 @@ def should_skip_env_spec_for_tests(spec):
# troublesome to run frequently
ep = spec.entry_point
# Skip mujoco tests for pull request CI
if skip_mujoco and (
ep.startswith("gym.envs.mujoco") or ep.startswith("gym.envs.robotics:")
):
if skip_mujoco and (ep.startswith("gym.envs.mujoco") or ep.startswith("gym.envs.robotics:")):
return True
try:
import atari_py
@@ -39,11 +34,7 @@ def should_skip_env_spec_for_tests(spec):
if (
"GoEnv" in ep
or "HexEnv" in ep
or (
ep.startswith("gym.envs.atari")
and not spec.id.startswith("Pong")
and not spec.id.startswith("Seaquest")
)
or (ep.startswith("gym.envs.atari") and not spec.id.startswith("Pong") and not spec.id.startswith("Seaquest"))
):
logger.warn("Skipping tests for env {}".format(ep))
return True

View File

@@ -25,9 +25,7 @@ def test_env(spec):
step_responses2 = [env2.step(action) for action in action_samples2]
env2.close()
for i, (action_sample1, action_sample2) in enumerate(
zip(action_samples1, action_samples2)
):
for i, (action_sample1, action_sample2) in enumerate(zip(action_samples1, action_samples2)):
try:
assert_equals(action_sample1, action_sample2)
except AssertionError:
@@ -35,11 +33,7 @@ def test_env(spec):
print("env2.action_space=", env2.action_space)
print("action_samples1=", action_samples1)
print("action_samples2=", action_samples2)
print(
"[{}] action_sample1: {}, action_sample2: {}".format(
i, action_sample1, action_sample2
)
)
print("[{}] action_sample1: {}, action_sample2: {}".format(i, action_sample1, action_sample2))
raise
# Don't check rollout equality if it's a nondeterministic
@@ -49,9 +43,7 @@ def test_env(spec):
assert_equals(initial_observation1, initial_observation2)
for i, ((o1, r1, d1, i1), (o2, r2, d2, i2)) in enumerate(
zip(step_responses1, step_responses2)
):
for i, ((o1, r1, d1, i1), (o2, r2, d2, i2)) in enumerate(zip(step_responses1, step_responses2)):
assert_equals(o1, o2, "[{}] ".format(i))
assert r1 == r2, "[{}] r1: {}, r2: {}".format(i, r1, r2)
assert d1 == d2, "[{}] d1: {}, d2: {}".format(i, d1, d2)
@@ -66,9 +58,7 @@ def test_env(spec):
def assert_equals(a, b, prefix=None):
assert type(a) == type(b), "{}Differing types: {} and {}".format(prefix, a, b)
if isinstance(a, dict):
assert list(a.keys()) == list(b.keys()), "{}Key sets differ: {} and {}".format(
prefix, a, b
)
assert list(a.keys()) == list(b.keys()), "{}Key sets differ: {} and {}".format(prefix, a, b)
for k in a.keys():
v_a = a[k]

View File

@@ -24,9 +24,7 @@ def test_env(spec):
assert ob_space.contains(ob), "Reset observation: {!r} not in space".format(ob)
a = act_space.sample()
observation, reward, done, _info = env.step(a)
assert ob_space.contains(observation), "Step observation: {!r} not in space".format(
observation
)
assert ob_space.contains(observation), "Step observation: {!r} not in space".format(observation)
assert np.isscalar(reward), "{} is not a scalar for {}".format(reward, env)
assert isinstance(done, bool), "Expected {} to be a boolean".format(done)

View File

@@ -81,9 +81,7 @@ def test_env_semantics(spec):
if spec.id not in rollout_dict:
if not spec.nondeterministic:
logger.warn(
"Rollout does not exist for {}, run generate_json.py to generate rollouts for new envs".format(
spec.id
)
"Rollout does not exist for {}, run generate_json.py to generate rollouts for new envs".format(spec.id)
)
return
@@ -100,21 +98,15 @@ def test_env_semantics(spec):
)
if rollout_dict[spec.id]["actions"] != actions_now:
errors.append(
"Actions not equal for {} -- expected {} but got {}".format(
spec.id, rollout_dict[spec.id]["actions"], actions_now
)
"Actions not equal for {} -- expected {} but got {}".format(spec.id, rollout_dict[spec.id]["actions"], actions_now)
)
if rollout_dict[spec.id]["rewards"] != rewards_now:
errors.append(
"Rewards not equal for {} -- expected {} but got {}".format(
spec.id, rollout_dict[spec.id]["rewards"], rewards_now
)
"Rewards not equal for {} -- expected {} but got {}".format(spec.id, rollout_dict[spec.id]["rewards"], rewards_now)
)
if rollout_dict[spec.id]["dones"] != dones_now:
errors.append(
"Dones not equal for {} -- expected {} but got {}".format(
spec.id, rollout_dict[spec.id]["dones"], dones_now
)
"Dones not equal for {} -- expected {} but got {}".format(spec.id, rollout_dict[spec.id]["dones"], dones_now)
)
if len(errors):
for error in errors:

View File

@@ -4,9 +4,7 @@ from gym import envs
from gym.envs.tests.spec_list import skip_mujoco, SKIP_MUJOCO_WARNING_MESSAGE
def verify_environments_match(
old_environment_id, new_environment_id, seed=1, num_actions=1000
):
def verify_environments_match(old_environment_id, new_environment_id, seed=1, num_actions=1000):
old_environment = envs.make(old_environment_id)
new_environment = envs.make(new_environment_id)

View File

@@ -83,8 +83,6 @@ def test_malformed_lookup():
try:
registry.spec(u"“Breakout-v0”")
except error.Error as e:
assert "malformed environment ID" in "{}".format(
e
), "Unexpected message: {}".format(e)
assert "malformed environment ID" in "{}".format(e), "Unexpected message: {}".format(e)
else:
assert False

View File

@@ -75,9 +75,7 @@ class BlackjackEnv(gym.Env):
def __init__(self, natural=False):
self.action_space = spaces.Discrete(2)
self.observation_space = spaces.Tuple(
(spaces.Discrete(32), spaces.Discrete(11), spaces.Discrete(2))
)
self.observation_space = spaces.Tuple((spaces.Discrete(32), spaces.Discrete(11), spaces.Discrete(2)))
self.seed()
# Flag to payout 1.5 on a "natural" blackjack win, like casino rules

View File

@@ -141,9 +141,7 @@ class FrozenLakeEnv(discrete.DiscreteEnv):
else:
if is_slippery:
for b in [(a - 1) % 4, a, (a + 1) % 4]:
li.append(
(1.0 / 3.0, *update_probability_matrix(row, col, b))
)
li.append((1.0 / 3.0, *update_probability_matrix(row, col, b)))
else:
li.append((1.0, *update_probability_matrix(row, col, a)))
@@ -157,9 +155,7 @@ class FrozenLakeEnv(discrete.DiscreteEnv):
desc = [[c.decode("utf-8") for c in line] for line in desc]
desc[row][col] = utils.colorize(desc[row][col], "red", highlight=True)
if self.lastaction is not None:
outfile.write(
" ({})\n".format(["Left", "Down", "Right", "Up"][self.lastaction])
)
outfile.write(" ({})\n".format(["Left", "Down", "Right", "Up"][self.lastaction]))
else:
outfile.write("\n")
outfile.write("\n".join("".join(line) for line in desc) + "\n")

View File

@@ -80,11 +80,7 @@ class GuessingGame(gym.Env):
reward = 0
done = False
if (
(self.number - self.range * 0.01)
< action
< (self.number + self.range * 0.01)
):
if (self.number - self.range * 0.01) < action < (self.number + self.range * 0.01):
reward = 1
done = True

View File

@@ -62,10 +62,7 @@ class HotterColder(gym.Env):
elif action > self.number:
self.observation = 3
reward = (
(min(action, self.number) + self.bounds)
/ (max(action, self.number) + self.bounds)
) ** 2
reward = ((min(action, self.number) + self.bounds) / (max(action, self.number) + self.bounds)) ** 2
self.guess_count += 1
done = self.guess_count >= self.guess_max

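The reward above is the squared ratio of the smaller to the larger of (guess, target), both shifted by bounds, so it peaks at 1.0 on an exact guess and decays with distance. A quick standalone check of the same formula:

bounds, number = 100, 40.0
for action in (40.0, 50.0, -100.0):
    reward = ((min(action, number) + bounds) / (max(action, number) + bounds)) ** 2
    print(action, round(reward, 3))  # 1.0, then ~0.871, then 0.0
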
View File

@@ -65,9 +65,7 @@ class KellyCoinflipEnv(gym.Env):
return [seed]
def step(self, action):
bet_in_dollars = min(
action / 100.0, self.wealth
) # action = desired bet in pennies
bet_in_dollars = min(action / 100.0, self.wealth) # action = desired bet in pennies
self.rounds -= 1
coinflip = flip(self.edge, self.np_random)
@@ -149,35 +147,19 @@ class KellyCoinflipGeneralizedEnv(gym.Env):
edge = self.np_random.beta(edge_prior_alpha, edge_prior_beta)
if self.clip_distributions:
# (clip/resample some parameters to be able to fix obs/action space sizes/bounds)
max_wealth_bound = round(
genpareto.ppf(0.85, max_wealth_alpha, max_wealth_m)
)
max_wealth_bound = round(genpareto.ppf(0.85, max_wealth_alpha, max_wealth_m))
max_wealth = max_wealth_bound + 1.0
while max_wealth > max_wealth_bound:
max_wealth = round(
genpareto.rvs(
max_wealth_alpha, max_wealth_m, random_state=self.np_random
)
)
max_rounds_bound = int(
round(norm.ppf(0.99, max_rounds_mean, max_rounds_sd))
)
max_wealth = round(genpareto.rvs(max_wealth_alpha, max_wealth_m, random_state=self.np_random))
max_rounds_bound = int(round(norm.ppf(0.99, max_rounds_mean, max_rounds_sd)))
max_rounds = max_rounds_bound + 1
while max_rounds > max_rounds_bound:
max_rounds = int(
round(self.np_random.normal(max_rounds_mean, max_rounds_sd))
)
max_rounds = int(round(self.np_random.normal(max_rounds_mean, max_rounds_sd)))
else:
max_wealth = round(
genpareto.rvs(
max_wealth_alpha, max_wealth_m, random_state=self.np_random
)
)
max_wealth = round(genpareto.rvs(max_wealth_alpha, max_wealth_m, random_state=self.np_random))
max_wealth_bound = max_wealth
max_rounds = int(
round(self.np_random.normal(max_rounds_mean, max_rounds_sd))
)
max_rounds = int(round(self.np_random.normal(max_rounds_mean, max_rounds_sd)))
max_rounds_bound = max_rounds
# add an additional global variable which is the sufficient statistic for the
@@ -194,9 +176,7 @@ class KellyCoinflipGeneralizedEnv(gym.Env):
self.action_space = spaces.Discrete(int(max_wealth_bound * 100))
self.observation_space = spaces.Tuple(
(
spaces.Box(
0, max_wealth_bound, shape=[1], dtype=np.float32
), # current wealth
spaces.Box(0, max_wealth_bound, shape=[1], dtype=np.float32), # current wealth
spaces.Discrete(max_rounds_bound + 1), # rounds elapsed
spaces.Discrete(max_rounds_bound + 1), # wins
spaces.Discrete(max_rounds_bound + 1), # losses

View File

@@ -80,10 +80,7 @@ class TaxiEnv(discrete.DiscreteEnv):
max_col = num_columns - 1
initial_state_distrib = np.zeros(num_states)
num_actions = 6
P = {
state: {action: [] for action in range(num_actions)}
for state in range(num_states)
}
P = {state: {action: [] for action in range(num_actions)} for state in range(num_states)}
for row in range(num_rows):
for col in range(num_columns):
for pass_idx in range(len(locs) + 1): # +1 for being inside taxi
@@ -94,9 +91,7 @@ class TaxiEnv(discrete.DiscreteEnv):
for action in range(num_actions):
# defaults
new_row, new_col, new_pass_idx = row, col, pass_idx
reward = (
-1
) # default reward when there is no pickup/dropoff
reward = -1 # default reward when there is no pickup/dropoff
done = False
taxi_loc = (row, col)
@@ -122,14 +117,10 @@ class TaxiEnv(discrete.DiscreteEnv):
new_pass_idx = locs.index(taxi_loc)
else: # dropoff at wrong location
reward = -10
new_state = self.encode(
new_row, new_col, new_pass_idx, dest_idx
)
new_state = self.encode(new_row, new_col, new_pass_idx, dest_idx)
P[state][action].append((1.0, new_state, reward, done))
initial_state_distrib /= initial_state_distrib.sum()
discrete.DiscreteEnv.__init__(
self, num_states, num_actions, P, initial_state_distrib
)
discrete.DiscreteEnv.__init__(self, num_states, num_actions, P, initial_state_distrib)
def encode(self, taxi_row, taxi_col, pass_loc, dest_idx):
# (5) 5, 5, 4
@@ -165,13 +156,9 @@ class TaxiEnv(discrete.DiscreteEnv):
return "_" if x == " " else x
if pass_idx < 4:
out[1 + taxi_row][2 * taxi_col + 1] = utils.colorize(
out[1 + taxi_row][2 * taxi_col + 1], "yellow", highlight=True
)
out[1 + taxi_row][2 * taxi_col + 1] = utils.colorize(out[1 + taxi_row][2 * taxi_col + 1], "yellow", highlight=True)
pi, pj = self.locs[pass_idx]
out[1 + pi][2 * pj + 1] = utils.colorize(
out[1 + pi][2 * pj + 1], "blue", bold=True
)
out[1 + pi][2 * pj + 1] = utils.colorize(out[1 + pi][2 * pj + 1], "blue", bold=True)
else: # passenger in taxi
out[1 + taxi_row][2 * taxi_col + 1] = utils.colorize(
ul(out[1 + taxi_row][2 * taxi_col + 1]), "green", highlight=True
@@ -181,13 +168,7 @@ class TaxiEnv(discrete.DiscreteEnv):
out[1 + di][2 * dj + 1] = utils.colorize(out[1 + di][2 * dj + 1], "magenta")
outfile.write("\n".join(["".join(row) for row in out]) + "\n")
if self.lastaction is not None:
outfile.write(
" ({})\n".format(
["South", "North", "East", "West", "Pickup", "Dropoff"][
self.lastaction
]
)
)
outfile.write(" ({})\n".format(["South", "North", "East", "West", "Pickup", "Dropoff"][self.lastaction]))
else:
outfile.write("\n")

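encode, per its "(5) 5, 5, 4" comment, packs (row, col, passenger index, destination index) into one state id with mixed radix. A standalone re-derivation of that arithmetic:

def encode(taxi_row, taxi_col, pass_loc, dest_idx):
    # ((row * 5 + col) * 5 + pass_loc) * 4 + dest_idx
    i = taxi_row
    i = i * 5 + taxi_col
    i = i * 5 + pass_loc
    i = i * 4 + dest_idx
    return i

assert encode(0, 0, 0, 0) == 0
assert encode(4, 4, 4, 3) == 499  # 500 states in total
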
View File

@@ -55,9 +55,7 @@ class CubeCrash(gym.Env):
self.seed()
self.viewer = None
self.observation_space = spaces.Box(
0, 255, (FIELD_H, FIELD_W, 3), dtype=np.uint8
)
self.observation_space = spaces.Box(0, 255, (FIELD_H, FIELD_W, 3), dtype=np.uint8)
self.action_space = spaces.Discrete(3)
self.reset()
@@ -83,16 +81,9 @@ class CubeCrash(gym.Env):
self.potential = None
self.step_n = 0
while 1:
self.wall_color = (
self.random_color() if self.use_random_colors else color_white
)
self.cube_color = (
self.random_color() if self.use_random_colors else color_green
)
if (
np.linalg.norm(self.wall_color - self.bg_color) < 50
or np.linalg.norm(self.cube_color - self.bg_color) < 50
):
self.wall_color = self.random_color() if self.use_random_colors else color_white
self.cube_color = self.random_color() if self.use_random_colors else color_green
if np.linalg.norm(self.wall_color - self.bg_color) < 50 or np.linalg.norm(self.cube_color - self.bg_color) < 50:
continue
break
return self.step(0)[0]
@@ -117,9 +108,7 @@ class CubeCrash(gym.Env):
self.hole_x - HOLE_WIDTH // 2 : self.hole_x + HOLE_WIDTH // 2 + 1,
:,
] = self.bg_color
obs[
self.cube_y - 1 : self.cube_y + 2, self.cube_x - 1 : self.cube_x + 2, :
] = self.cube_color
obs[self.cube_y - 1 : self.cube_y + 2, self.cube_x - 1 : self.cube_x + 2, :] = self.cube_color
if self.use_black_screen and self.step_n > 4:
obs[:] = np.zeros((3,), dtype=np.uint8)

View File

@@ -62,16 +62,12 @@ class MemorizeDigits(gym.Env):
def __init__(self):
self.seed()
self.viewer = None
self.observation_space = spaces.Box(
0, 255, (FIELD_H, FIELD_W, 3), dtype=np.uint8
)
self.observation_space = spaces.Box(0, 255, (FIELD_H, FIELD_W, 3), dtype=np.uint8)
self.action_space = spaces.Discrete(10)
self.bogus_mnist = np.zeros((10, 6, 6), dtype=np.uint8)
for digit in range(10):
for y in range(6):
self.bogus_mnist[digit, y, :] = [
ord(char) for char in bogus_mnist[digit][y]
]
self.bogus_mnist[digit, y, :] = [ord(char) for char in bogus_mnist[digit][y]]
self.reset()
def seed(self, seed=None):
@@ -93,9 +89,7 @@ class MemorizeDigits(gym.Env):
self.color_bg = self.random_color() if self.use_random_colors else color_black
self.step_n = 0
while 1:
self.color_digit = (
self.random_color() if self.use_random_colors else color_white
)
self.color_digit = self.random_color() if self.use_random_colors else color_white
if np.linalg.norm(self.color_digit - self.color_bg) < 50:
continue
break
@@ -119,9 +113,7 @@ class MemorizeDigits(gym.Env):
digit_img[:] = self.color_bg
xxx = self.bogus_mnist[self.digit] == 42
digit_img[xxx] = self.color_digit
obs[
self.digit_y - 3 : self.digit_y + 3, self.digit_x - 3 : self.digit_x + 3
] = digit_img
obs[self.digit_y - 3 : self.digit_y + 3, self.digit_x - 3 : self.digit_x + 3] = digit_img
self.last_obs = obs
return obs, reward, done, {}

View File

@@ -102,10 +102,7 @@ class APIError(Error):
try:
http_body = http_body.decode("utf-8")
except:
http_body = (
"<Could not decode body as utf-8. "
"Please report to gym@openai.com>"
)
http_body = "<Could not decode body as utf-8. " "Please report to gym@openai.com>"
self._message = message
self.http_body = http_body
@@ -142,9 +139,7 @@ class InvalidRequestError(APIError):
json_body=None,
headers=None,
):
super(InvalidRequestError, self).__init__(
message, http_body, http_status, json_body, headers
)
super(InvalidRequestError, self).__init__(message, http_body, http_status, json_body, headers)
self.param = param

View File

@@ -29,26 +29,16 @@ class Box(Space):
# determine shape if it isn't provided directly
if shape is not None:
shape = tuple(shape)
assert (
np.isscalar(low) or low.shape == shape
), "low.shape doesn't match provided shape"
assert (
np.isscalar(high) or high.shape == shape
), "high.shape doesn't match provided shape"
assert np.isscalar(low) or low.shape == shape, "low.shape doesn't match provided shape"
assert np.isscalar(high) or high.shape == shape, "high.shape doesn't match provided shape"
elif not np.isscalar(low):
shape = low.shape
assert (
np.isscalar(high) or high.shape == shape
), "high.shape doesn't match low.shape"
assert np.isscalar(high) or high.shape == shape, "high.shape doesn't match low.shape"
elif not np.isscalar(high):
shape = high.shape
assert (
np.isscalar(low) or low.shape == shape
), "low.shape doesn't match high.shape"
assert np.isscalar(low) or low.shape == shape, "low.shape doesn't match high.shape"
else:
raise ValueError(
"shape must be provided or inferred from the shapes of low or high"
)
raise ValueError("shape must be provided or inferred from the shapes of low or high")
if np.isscalar(low):
low = np.full(shape, low, dtype=dtype)
@@ -70,9 +60,7 @@ class Box(Space):
high_precision = _get_precision(self.high.dtype)
dtype_precision = _get_precision(self.dtype)
if min(low_precision, high_precision) > dtype_precision:
logger.warn(
"Box bound precision lowered by casting to {}".format(self.dtype)
)
logger.warn("Box bound precision lowered by casting to {}".format(self.dtype))
self.low = self.low.astype(self.dtype)
self.high = self.high.astype(self.dtype)
@@ -119,19 +107,11 @@ class Box(Space):
# Vectorized sampling by interval type
sample[unbounded] = self.np_random.normal(size=unbounded[unbounded].shape)
sample[low_bounded] = (
self.np_random.exponential(size=low_bounded[low_bounded].shape)
+ self.low[low_bounded]
)
sample[low_bounded] = self.np_random.exponential(size=low_bounded[low_bounded].shape) + self.low[low_bounded]
sample[upp_bounded] = (
-self.np_random.exponential(size=upp_bounded[upp_bounded].shape)
+ self.high[upp_bounded]
)
sample[upp_bounded] = -self.np_random.exponential(size=upp_bounded[upp_bounded].shape) + self.high[upp_bounded]
sample[bounded] = self.np_random.uniform(
low=self.low[bounded], high=high[bounded], size=bounded[bounded].shape
)
sample[bounded] = self.np_random.uniform(low=self.low[bounded], high=high[bounded], size=bounded[bounded].shape)
if self.dtype.kind == "i":
sample = np.floor(sample)
@@ -140,9 +120,7 @@ class Box(Space):
def contains(self, x):
if isinstance(x, list):
x = np.array(x) # Promote list to array for contains check
return (
x.shape == self.shape and np.all(x >= self.low) and np.all(x <= self.high)
)
return x.shape == self.shape and np.all(x >= self.low) and np.all(x <= self.high)
def to_jsonable(self, sample_n):
return np.array(sample_n).tolist()
@@ -151,9 +129,7 @@ class Box(Space):
return [np.asarray(sample) for sample in sample_n]
def __repr__(self):
return "Box({}, {}, {}, {})".format(
self.low.min(), self.high.max(), self.shape, self.dtype
)
return "Box({}, {}, {}, {})".format(self.low.min(), self.high.max(), self.shape, self.dtype)
def __eq__(self, other):
return (

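The sample hunk above draws from a different distribution per bound pattern: normal where both bounds are infinite, shifted (or negated) exponential where only one side is bounded, uniform where both are finite. The same idea outside the class, as a standalone sketch:

import numpy as np

rng = np.random.RandomState(0)
low = np.array([-np.inf, 0.0, -np.inf, -1.0])
high = np.array([np.inf, np.inf, 0.0, 1.0])

bounded_below = -np.inf < low
bounded_above = high < np.inf
unbounded = ~bounded_below & ~bounded_above
low_bounded = bounded_below & ~bounded_above
upp_bounded = ~bounded_below & bounded_above
bounded = bounded_below & bounded_above

sample = np.empty(low.shape)
sample[unbounded] = rng.normal(size=unbounded[unbounded].shape)
sample[low_bounded] = rng.exponential(size=low_bounded[low_bounded].shape) + low[low_bounded]
sample[upp_bounded] = -rng.exponential(size=upp_bounded[upp_bounded].shape) + high[upp_bounded]
sample[bounded] = rng.uniform(low=low[bounded], high=high[bounded], size=bounded[bounded].shape)
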
View File

@@ -33,9 +33,7 @@ class Dict(Space):
"""
def __init__(self, spaces=None, **spaces_kwargs):
assert (spaces is None) or (
not spaces_kwargs
), "Use either Dict(spaces=dict(...)) or Dict(foo=x, bar=z)"
assert (spaces is None) or (not spaces_kwargs), "Use either Dict(spaces=dict(...)) or Dict(foo=x, bar=z)"
if spaces is None:
spaces = spaces_kwargs
if isinstance(spaces, dict) and not isinstance(spaces, OrderedDict):
@@ -44,12 +42,8 @@ class Dict(Space):
spaces = OrderedDict(spaces)
self.spaces = spaces
for space in spaces.values():
assert isinstance(
space, Space
), "Values of the dict should be instances of gym.Space"
super(Dict, self).__init__(
None, None
) # None for shape and dtype, since it'll require special handling
assert isinstance(space, Space), "Values of the dict should be instances of gym.Space"
super(Dict, self).__init__(None, None) # None for shape and dtype, since it'll require special handling
def seed(self, seed=None):
[space.seed(seed) for space in self.spaces.values()]
@@ -75,18 +69,11 @@ class Dict(Space):
yield key
def __repr__(self):
return (
"Dict("
+ ", ".join([str(k) + ":" + str(s) for k, s in self.spaces.items()])
+ ")"
)
return "Dict(" + ", ".join([str(k) + ":" + str(s) for k, s in self.spaces.items()]) + ")"
def to_jsonable(self, sample_n):
# serialize as dict-repr of vectors
return {
key: space.to_jsonable([sample[key] for sample in sample_n])
for key, space in self.spaces.items()
}
return {key: space.to_jsonable([sample[key] for sample in sample_n]) for key, space in self.spaces.items()}
def from_jsonable(self, sample_n):
dict_of_list = {}

View File

@@ -22,9 +22,7 @@ class Discrete(Space):
def contains(self, x):
if isinstance(x, int):
as_int = x
elif isinstance(x, (np.generic, np.ndarray)) and (
x.dtype.char in np.typecodes["AllInteger"] and x.shape == ()
):
elif isinstance(x, (np.generic, np.ndarray)) and (x.dtype.char in np.typecodes["AllInteger"] and x.shape == ()):
as_int = int(x)
else:
return False

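contains, as rewritten above, admits plain ints and zero-dimensional integer arrays but nothing else. For example:

import numpy as np

from gym.spaces import Discrete

space = Discrete(4)
assert space.contains(2)
assert space.contains(np.int64(2))        # np.generic with an integer dtype
assert not space.contains(np.array([2]))  # shape (1,), not a scalar
assert not space.contains(2.0)            # floats are rejected
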
View File

@@ -34,9 +34,7 @@ class MultiDiscrete(Space):
super(MultiDiscrete, self).__init__(self.nvec.shape, dtype)
def sample(self):
return (self.np_random.random_sample(self.nvec.shape) * self.nvec).astype(
self.dtype
)
return (self.np_random.random_sample(self.nvec.shape) * self.nvec).astype(self.dtype)
def contains(self, x):
if isinstance(x, list):

View File

@@ -25,9 +25,7 @@ from gym.spaces import Tuple, Box, Discrete, MultiDiscrete, MultiBinary, Dict
Dict(
{
"position": Discrete(5),
"velocity": Box(
low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32
),
"velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
}
),
],
@@ -71,9 +69,7 @@ def test_roundtripping(space):
Dict(
{
"position": Discrete(5),
"velocity": Box(
low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32
),
"velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
}
),
],

View File

@@ -28,9 +28,7 @@ from gym.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete, Tuple, u
Dict(
{
"position": Discrete(5),
"velocity": Box(
low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32
),
"velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
}
),
7,
@@ -60,18 +58,14 @@ def test_flatdim(space, flatdim):
Dict(
{
"position": Discrete(5),
"velocity": Box(
low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32
),
"velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
}
),
],
)
def test_flatten_space_boxes(space):
flat_space = utils.flatten_space(space)
assert isinstance(flat_space, Box), "Expected {} to equal {}".format(
type(flat_space), Box
)
assert isinstance(flat_space, Box), "Expected {} to equal {}".format(type(flat_space), Box)
flatdim = utils.flatdim(space)
(single_dim,) = flat_space.shape
assert single_dim == flatdim, "Expected {} to equal {}".format(single_dim, flatdim)
@@ -95,9 +89,7 @@ def test_flatten_space_boxes(space):
Dict(
{
"position": Discrete(5),
"velocity": Box(
low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32
),
"velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
}
),
],
@@ -107,9 +99,7 @@ def test_flat_space_contains_flat_points(space):
flattened_samples = [utils.flatten(space, sample) for sample in some_samples]
flat_space = utils.flatten_space(space)
for i, flat_sample in enumerate(flattened_samples):
assert flat_sample in flat_space, "Expected sample #{} {} to be in {}".format(
i, flat_sample, flat_space
)
assert flat_sample in flat_space, "Expected sample #{} {} to be in {}".format(i, flat_sample, flat_space)
@pytest.mark.parametrize(
@@ -130,9 +120,7 @@ def test_flat_space_contains_flat_points(space):
Dict(
{
"position": Discrete(5),
"velocity": Box(
low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32
),
"velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
}
),
],
@@ -162,9 +150,7 @@ def test_flatten_dim(space):
Dict(
{
"position": Discrete(5),
"velocity": Box(
low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32
),
"velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
}
),
],
@@ -172,15 +158,9 @@ def test_flatten_dim(space):
def test_flatten_roundtripping(space):
some_samples = [space.sample() for _ in range(10)]
flattened_samples = [utils.flatten(space, sample) for sample in some_samples]
roundtripped_samples = [
utils.unflatten(space, sample) for sample in flattened_samples
]
for i, (original, roundtripped) in enumerate(
zip(some_samples, roundtripped_samples)
):
assert compare_nested(
original, roundtripped
), "Expected sample #{} {} to equal {}".format(i, original, roundtripped)
roundtripped_samples = [utils.unflatten(space, sample) for sample in flattened_samples]
for i, (original, roundtripped) in enumerate(zip(some_samples, roundtripped_samples)):
assert compare_nested(original, roundtripped), "Expected sample #{} {} to equal {}".format(i, original, roundtripped)
def compare_nested(left, right):
@@ -188,9 +168,7 @@ def compare_nested(left, right):
return np.allclose(left, right)
elif isinstance(left, OrderedDict) and isinstance(right, OrderedDict):
res = len(left) == len(right)
for ((left_key, left_value), (right_key, right_value)) in zip(
left.items(), right.items()
):
for ((left_key, left_value), (right_key, right_value)) in zip(left.items(), right.items()):
if not res:
return False
res = left_key == right_key and compare_nested(left_value, right_value)
@@ -238,9 +216,7 @@ Expected flattened types are based off:
Dict(
{
"position": Discrete(5),
"velocity": Box(
low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float16
),
"velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float16),
}
),
np.float64,
@@ -254,12 +230,8 @@ def test_dtypes(original_space, expected_flattened_dtype):
flattened_sample = utils.flatten(original_space, original_sample)
unflattened_sample = utils.unflatten(original_space, flattened_sample)
assert flattened_space.contains(
flattened_sample
), "Expected flattened_space to contain flattened_sample"
assert (
flattened_space.dtype == expected_flattened_dtype
), "Expected flattened_space's dtype to equal " "{}".format(
assert flattened_space.contains(flattened_sample), "Expected flattened_space to contain flattened_sample"
assert flattened_space.dtype == expected_flattened_dtype, "Expected flattened_space's dtype to equal " "{}".format(
expected_flattened_dtype
)
@@ -272,9 +244,10 @@ def test_dtypes(original_space, expected_flattened_dtype):
def compare_sample_types(original_space, original_sample, unflattened_sample):
if isinstance(original_space, Discrete):
assert isinstance(unflattened_sample, int), (
"Expected unflattened_sample to be an int. unflattened_sample: "
"{} original_sample: {}".format(unflattened_sample, original_sample)
assert isinstance(
unflattened_sample, int
), "Expected unflattened_sample to be an int. unflattened_sample: " "{} original_sample: {}".format(
unflattened_sample, original_sample
)
elif isinstance(original_space, Tuple):
for index in range(len(original_space)):

View File

@@ -13,9 +13,7 @@ class Tuple(Space):
def __init__(self, spaces):
self.spaces = spaces
for space in spaces:
assert isinstance(
space, Space
), "Elements of the tuple must be instances of gym.Space"
assert isinstance(space, Space), "Elements of the tuple must be instances of gym.Space"
super(Tuple, self).__init__(None, None)
def seed(self, seed=None):
@@ -38,21 +36,10 @@ class Tuple(Space):
def to_jsonable(self, sample_n):
# serialize as list-repr of tuple of vectors
return [
space.to_jsonable([sample[i] for sample in sample_n])
for i, space in enumerate(self.spaces)
]
return [space.to_jsonable([sample[i] for sample in sample_n]) for i, space in enumerate(self.spaces)]
def from_jsonable(self, sample_n):
return [
sample
for sample in zip(
*[
space.from_jsonable(sample_n[i])
for i, space in enumerate(self.spaces)
]
)
]
return [sample for sample in zip(*[space.from_jsonable(sample_n[i]) for i, space in enumerate(self.spaces)])]
def __getitem__(self, index):
return self.spaces[index]

View File

@@ -49,9 +49,7 @@ def flatten(space, x):
onehot[x] = 1
return onehot
elif isinstance(space, Tuple):
return np.concatenate(
[flatten(s, x_part) for x_part, s in zip(x, space.spaces)]
)
return np.concatenate([flatten(s, x_part) for x_part, s in zip(x, space.spaces)])
elif isinstance(space, Dict):
return np.concatenate([flatten(s, x[key]) for key, s in space.spaces.items()])
elif isinstance(space, MultiBinary):
@@ -79,17 +77,13 @@ def unflatten(space, x):
elif isinstance(space, Tuple):
dims = [flatdim(s) for s in space.spaces]
list_flattened = np.split(x, np.cumsum(dims)[:-1])
list_unflattened = [
unflatten(s, flattened)
for flattened, s in zip(list_flattened, space.spaces)
]
list_unflattened = [unflatten(s, flattened) for flattened, s in zip(list_flattened, space.spaces)]
return tuple(list_unflattened)
elif isinstance(space, Dict):
dims = [flatdim(s) for s in space.spaces.values()]
list_flattened = np.split(x, np.cumsum(dims)[:-1])
list_unflattened = [
(key, unflatten(s, flattened))
for flattened, (key, s) in zip(list_flattened, space.spaces.items())
(key, unflatten(s, flattened)) for flattened, (key, s) in zip(list_flattened, space.spaces.items())
]
return OrderedDict(list_unflattened)
elif isinstance(space, MultiBinary):

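flatten and unflatten, shown above, form a lossless round trip between structured samples and flat vectors. A usage sketch with the public gym.spaces.utils helpers:

import numpy as np

from gym.spaces import Box, Dict, Discrete, utils

space = Dict(
    {
        "position": Discrete(5),
        "velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
    }
)
sample = space.sample()
flat = utils.flatten(space, sample)
assert flat.shape == (utils.flatdim(space),)  # onehot(5) + 2 box entries = 7
roundtripped = utils.unflatten(space, flat)
assert list(roundtripped.keys()) == list(sample.keys())
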
View File

@@ -88,11 +88,7 @@ def play(env, transpose=True, fps=30, zoom=None, callback=None, keys_to_action=N
elif hasattr(env.unwrapped, "get_keys_to_action"):
keys_to_action = env.unwrapped.get_keys_to_action()
else:
assert False, (
env.spec.id
+ " does not have explicit key to action mapping, "
+ "please specify one manually"
)
assert False, env.spec.id + " does not have explicit key to action mapping, " + "please specify one manually"
relevant_keys = set(sum(map(list, keys_to_action.keys()), []))
video_size = [rendered.shape[1], rendered.shape[0]]
@@ -172,9 +168,7 @@ class PlayPlot(object):
for i, plot in enumerate(self.cur_plot):
if plot is not None:
plot.remove()
self.cur_plot[i] = self.ax[i].scatter(
range(xmin, xmax), list(self.data[i]), c="blue"
)
self.cur_plot[i] = self.ax[i].scatter(range(xmin, xmax), list(self.data[i]), c="blue")
self.ax[i].set_xlim(xmin, xmax)
plt.pause(0.000001)

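play() asserts unless the env provides its own key map or one is passed in. A sketch of supplying keys_to_action explicitly (key tuples map simultaneously pressed keys to an action; the mapping below is made up for illustration, and running it needs a display):

import gym

from gym.utils.play import play

keys_to_action = {
    (): 0,            # no key pressed: NOOP
    (ord("a"),): 1,
    (ord("d"),): 2,
}
play(gym.make("CubeCrash-v0"), keys_to_action=keys_to_action, fps=30)
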
View File

@@ -10,9 +10,7 @@ from gym import error
def np_random(seed=None):
if seed is not None and not (isinstance(seed, int) and 0 <= seed):
raise error.Error(
"Seed must be a non-negative integer or omitted, not {}".format(seed)
)
raise error.Error("Seed must be a non-negative integer or omitted, not {}".format(seed))
seed = create_seed(seed)

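np_random returns a seeded RandomState together with the seed it settled on, and rejects anything but a non-negative int (or None):

from gym.utils import seeding

rng, seed = seeding.np_random(42)
values = rng.uniform(size=3)  # reproducible given the same seed

try:
    seeding.np_random(-1)
except Exception as exc:      # gym.error.Error for negative seeds
    print(exc)
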
View File

@@ -53,9 +53,7 @@ def make(id, num_envs=1, asynchronous=True, wrappers=None, **kwargs):
if wrappers is not None:
if callable(wrappers):
env = wrappers(env)
elif isinstance(wrappers, Iterable) and all(
[callable(w) for w in wrappers]
):
elif isinstance(wrappers, Iterable) and all([callable(w) for w in wrappers]):
for wrapper in wrappers:
env = wrapper(env)
else:

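gym.vector.make accepts a single callable or an iterable of callables for wrappers, and returns a Sync or Async vectorized env. A minimal usage sketch with a registered env id:

import gym

envs = gym.vector.make("CartPole-v0", num_envs=3, asynchronous=False)
observations = envs.reset()  # batched observations, one row per sub-env
observations, rewards, dones, infos = envs.step(envs.action_space.sample())
envs.close()
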
View File

@@ -107,12 +107,8 @@ class AsyncVectorEnv(VectorEnv):
if self.shared_memory:
try:
_obs_buffer = create_shared_memory(
self.single_observation_space, n=self.num_envs, ctx=ctx
)
self.observations = read_from_shared_memory(
_obs_buffer, self.single_observation_space, n=self.num_envs
)
_obs_buffer = create_shared_memory(self.single_observation_space, n=self.num_envs, ctx=ctx)
self.observations = read_from_shared_memory(_obs_buffer, self.single_observation_space, n=self.num_envs)
except CustomSpaceError:
raise ValueError(
"Using `shared_memory=True` in `AsyncVectorEnv` "
@@ -124,9 +120,7 @@ class AsyncVectorEnv(VectorEnv):
)
else:
_obs_buffer = None
self.observations = create_empty_array(
self.single_observation_space, n=self.num_envs, fn=np.zeros
)
self.observations = create_empty_array(self.single_observation_space, n=self.num_envs, fn=np.zeros)
self.parent_pipes, self.processes = [], []
self.error_queue = ctx.Queue()
@@ -168,8 +162,7 @@ class AsyncVectorEnv(VectorEnv):
if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError(
"Calling `seed` while waiting "
"for a pending call to `{0}` to complete.".format(self._state.value),
"Calling `seed` while waiting " "for a pending call to `{0}` to complete.".format(self._state.value),
self._state.value,
)
@@ -182,8 +175,7 @@ class AsyncVectorEnv(VectorEnv):
self._assert_is_running()
if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError(
"Calling `reset_async` while waiting "
"for a pending call to `{0}` to complete".format(self._state.value),
"Calling `reset_async` while waiting " "for a pending call to `{0}` to complete".format(self._state.value),
self._state.value,
)
@@ -214,8 +206,7 @@ class AsyncVectorEnv(VectorEnv):
if not self._poll(timeout):
self._state = AsyncState.DEFAULT
raise mp.TimeoutError(
"The call to `reset_wait` has timed out after "
"{0} second{1}.".format(timeout, "s" if timeout > 1 else "")
"The call to `reset_wait` has timed out after " "{0} second{1}.".format(timeout, "s" if timeout > 1 else "")
)
results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
@@ -223,9 +214,7 @@ class AsyncVectorEnv(VectorEnv):
self._state = AsyncState.DEFAULT
if not self.shared_memory:
self.observations = concatenate(
results, self.observations, self.single_observation_space
)
self.observations = concatenate(results, self.observations, self.single_observation_space)
return deepcopy(self.observations) if self.copy else self.observations
@@ -239,8 +228,7 @@ class AsyncVectorEnv(VectorEnv):
self._assert_is_running()
if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError(
"Calling `step_async` while waiting "
"for a pending call to `{0}` to complete.".format(self._state.value),
"Calling `step_async` while waiting " "for a pending call to `{0}` to complete.".format(self._state.value),
self._state.value,
)
@@ -280,8 +268,7 @@ class AsyncVectorEnv(VectorEnv):
if not self._poll(timeout):
self._state = AsyncState.DEFAULT
raise mp.TimeoutError(
"The call to `step_wait` has timed out after "
"{0} second{1}.".format(timeout, "s" if timeout > 1 else "")
"The call to `step_wait` has timed out after " "{0} second{1}.".format(timeout, "s" if timeout > 1 else "")
)
results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
@@ -290,9 +277,7 @@ class AsyncVectorEnv(VectorEnv):
observations_list, rewards, dones, infos = zip(*results)
if not self.shared_memory:
self.observations = concatenate(
observations_list, self.observations, self.single_observation_space
)
self.observations = concatenate(observations_list, self.observations, self.single_observation_space)
return (
deepcopy(self.observations) if self.copy else self.observations,
@@ -318,8 +303,7 @@ class AsyncVectorEnv(VectorEnv):
try:
if self._state != AsyncState.DEFAULT:
logger.warn(
"Calling `close` while waiting for a pending "
"call to `{0}` to complete.".format(self._state.value)
"Calling `close` while waiting for a pending " "call to `{0}` to complete.".format(self._state.value)
)
function = getattr(self, "{0}_wait".format(self._state.value))
function(timeout)
@@ -375,8 +359,7 @@ class AsyncVectorEnv(VectorEnv):
def _assert_is_running(self):
if self.closed:
raise ClosedEnvironmentError(
"Trying to operate on `{0}`, after a "
"call to `close()`.".format(type(self).__name__)
"Trying to operate on `{0}`, after a " "call to `close()`.".format(type(self).__name__)
)
def _raise_if_errors(self, successes):
@@ -387,10 +370,7 @@ class AsyncVectorEnv(VectorEnv):
assert num_errors > 0
for _ in range(num_errors):
index, exctype, value = self.error_queue.get()
logger.error(
"Received the following error from Worker-{0}: "
"{1}: {2}".format(index, exctype.__name__, value)
)
logger.error("Received the following error from Worker-{0}: " "{1}: {2}".format(index, exctype.__name__, value))
logger.error("Shutting down Worker-{0}.".format(index))
self.parent_pipes[index].close()
self.parent_pipes[index] = None
@@ -445,17 +425,13 @@ def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error
command, data = pipe.recv()
if command == "reset":
observation = env.reset()
write_to_shared_memory(
index, observation, shared_memory, observation_space
)
write_to_shared_memory(index, observation, shared_memory, observation_space)
pipe.send((None, True))
elif command == "step":
observation, reward, done, info = env.step(data)
if done:
observation = env.reset()
write_to_shared_memory(
index, observation, shared_memory, observation_space
)
write_to_shared_memory(index, observation, shared_memory, observation_space)
pipe.send(((None, reward, done, info), True))
elif command == "seed":
env.seed(data)

View File

@@ -44,9 +44,7 @@ class SyncVectorEnv(VectorEnv):
)
self._check_observation_spaces()
self.observations = create_empty_array(
self.single_observation_space, n=self.num_envs, fn=np.zeros
)
self.observations = create_empty_array(self.single_observation_space, n=self.num_envs, fn=np.zeros)
self._rewards = np.zeros((self.num_envs,), dtype=np.float64)
self._dones = np.zeros((self.num_envs,), dtype=np.bool_)
self._actions = None
@@ -67,9 +65,7 @@ class SyncVectorEnv(VectorEnv):
for env in self.envs:
observation = env.reset()
observations.append(observation)
self.observations = concatenate(
observations, self.observations, self.single_observation_space
)
self.observations = concatenate(observations, self.observations, self.single_observation_space)
return deepcopy(self.observations) if self.copy else self.observations
@@ -84,9 +80,7 @@ class SyncVectorEnv(VectorEnv):
observation = env.reset()
observations.append(observation)
infos.append(info)
self.observations = concatenate(
observations, self.observations, self.single_observation_space
)
self.observations = concatenate(observations, self.observations, self.single_observation_space)
return (
deepcopy(self.observations) if self.copy else self.observations,

View File

@@ -10,9 +10,7 @@ from gym.vector.tests.utils import spaces
from gym.vector.utils.numpy_utils import concatenate, create_empty_array
@pytest.mark.parametrize(
"space", spaces, ids=[space.__class__.__name__ for space in spaces]
)
@pytest.mark.parametrize("space", spaces, ids=[space.__class__.__name__ for space in spaces])
def test_concatenate(space):
def assert_type(lhs, rhs, n):
# Special case: if rhs is a list of scalars, lhs must be an np.ndarray
@@ -53,9 +51,7 @@ def test_concatenate(space):
@pytest.mark.parametrize("n", [1, 8])
@pytest.mark.parametrize(
"space", spaces, ids=[space.__class__.__name__ for space in spaces]
)
@pytest.mark.parametrize("space", spaces, ids=[space.__class__.__name__ for space in spaces])
def test_create_empty_array(space, n):
def assert_nested_type(arr, space, n):
if isinstance(space, _BaseGymSpaces):
@@ -83,9 +79,7 @@ def test_create_empty_array(space, n):
@pytest.mark.parametrize("n", [1, 8])
@pytest.mark.parametrize(
"space", spaces, ids=[space.__class__.__name__ for space in spaces]
)
@pytest.mark.parametrize("space", spaces, ids=[space.__class__.__name__ for space in spaces])
def test_create_empty_array_zeros(space, n):
def assert_nested_type(arr, space, n):
if isinstance(space, _BaseGymSpaces):
@@ -113,9 +107,7 @@ def test_create_empty_array_zeros(space, n):
assert_nested_type(array, space, n=n)
@pytest.mark.parametrize(
"space", spaces, ids=[space.__class__.__name__ for space in spaces]
)
@pytest.mark.parametrize("space", spaces, ids=[space.__class__.__name__ for space in spaces])
def test_create_empty_array_none_shape_ones(space):
def assert_nested_type(arr, space):
if isinstance(space, _BaseGymSpaces):

View File

@@ -46,9 +46,7 @@ expected_types = [
list(zip(spaces, expected_types)),
ids=[space.__class__.__name__ for space in spaces],
)
@pytest.mark.parametrize(
"ctx", [None, "fork", "spawn"], ids=["default", "fork", "spawn"]
)
@pytest.mark.parametrize("ctx", [None, "fork", "spawn"], ids=["default", "fork", "spawn"])
def test_create_shared_memory(space, expected_type, n, ctx):
def assert_nested_type(lhs, rhs, n):
assert type(lhs) == type(rhs)
@@ -77,9 +75,7 @@ def test_create_shared_memory(space, expected_type, n, ctx):
@pytest.mark.parametrize("n", [1, 8])
@pytest.mark.parametrize(
"ctx", [None, "fork", "spawn"], ids=["default", "fork", "spawn"]
)
@pytest.mark.parametrize("ctx", [None, "fork", "spawn"], ids=["default", "fork", "spawn"])
@pytest.mark.parametrize("space", custom_spaces)
def test_create_shared_memory_custom_space(n, ctx, space):
ctx = mp if (ctx is None) else mp.get_context(ctx)
@@ -87,9 +83,7 @@ def test_create_shared_memory_custom_space(n, ctx, space):
shared_memory = create_shared_memory(space, n=n, ctx=ctx)
@pytest.mark.parametrize(
"space", spaces, ids=[space.__class__.__name__ for space in spaces]
)
@pytest.mark.parametrize("space", spaces, ids=[space.__class__.__name__ for space in spaces])
def test_write_to_shared_memory(space):
def assert_nested_equal(lhs, rhs):
assert isinstance(rhs, list)
@@ -113,9 +107,7 @@ def test_write_to_shared_memory(space):
shared_memory_n8 = create_shared_memory(space, n=8)
samples = [space.sample() for _ in range(8)]
processes = [
Process(target=write, args=(i, shared_memory_n8, samples[i])) for i in range(8)
]
processes = [Process(target=write, args=(i, shared_memory_n8, samples[i])) for i in range(8)]
for process in processes:
process.start()
@@ -125,25 +117,19 @@ def test_write_to_shared_memory(space):
assert_nested_equal(shared_memory_n8, samples)
@pytest.mark.parametrize(
"space", spaces, ids=[space.__class__.__name__ for space in spaces]
)
@pytest.mark.parametrize("space", spaces, ids=[space.__class__.__name__ for space in spaces])
def test_read_from_shared_memory(space):
def assert_nested_equal(lhs, rhs, space, n):
assert isinstance(rhs, list)
if isinstance(space, Tuple):
assert isinstance(lhs, tuple)
for i in range(len(lhs)):
assert_nested_equal(
lhs[i], [rhs_[i] for rhs_ in rhs], space.spaces[i], n
)
assert_nested_equal(lhs[i], [rhs_[i] for rhs_ in rhs], space.spaces[i], n)
elif isinstance(space, Dict):
assert isinstance(lhs, OrderedDict)
for key in lhs.keys():
assert_nested_equal(
lhs[key], [rhs_[key] for rhs_ in rhs], space.spaces[key], n
)
assert_nested_equal(lhs[key], [rhs_[key] for rhs_ in rhs], space.spaces[key], n)
elif isinstance(space, _BaseGymSpaces):
assert isinstance(lhs, np.ndarray)
@@ -161,9 +147,7 @@ def test_read_from_shared_memory(space):
memory_view_n8 = read_from_shared_memory(shared_memory_n8, space, n=8)
samples = [space.sample() for _ in range(8)]
processes = [
Process(target=write, args=(i, shared_memory_n8, samples[i])) for i in range(8)
]
processes = [Process(target=write, args=(i, shared_memory_n8, samples[i])) for i in range(8)]
for process in processes:
process.start()

View File

@@ -10,12 +10,8 @@ expected_batch_spaces_4 = [
Box(low=-1.0, high=1.0, shape=(4,), dtype=np.float64),
Box(low=0.0, high=10.0, shape=(4, 1), dtype=np.float32),
Box(
low=np.array(
[[-1.0, 0.0, 0.0], [-1.0, 0.0, 0.0], [-1.0, 0.0, 0.0], [-1.0, 0.0, 0.0]]
),
high=np.array(
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
),
low=np.array([[-1.0, 0.0, 0.0], [-1.0, 0.0, 0.0], [-1.0, 0.0, 0.0], [-1.0, 0.0, 0.0]]),
high=np.array([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]),
dtype=np.float32,
),
Box(

View File

@@ -7,12 +7,8 @@ from gym.spaces import Box, Discrete, MultiDiscrete, MultiBinary, Tuple, Dict
spaces = [
Box(low=np.array(-1.0), high=np.array(1.0), dtype=np.float64),
Box(low=np.array([0.0]), high=np.array([10.0]), dtype=np.float32),
Box(
low=np.array([-1.0, 0.0, 0.0]), high=np.array([1.0, 1.0, 1.0]), dtype=np.float32
),
Box(
low=np.array([[-1.0, 0.0], [0.0, -1.0]]), high=np.ones((2, 2)), dtype=np.float32
),
Box(low=np.array([-1.0, 0.0, 0.0]), high=np.array([1.0, 1.0, 1.0]), dtype=np.float32),
Box(low=np.array([[-1.0, 0.0], [0.0, -1.0]]), high=np.ones((2, 2)), dtype=np.float32),
Box(low=0, high=255, shape=(), dtype=np.uint8),
Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8),
Discrete(2),
@@ -28,17 +24,13 @@ spaces = [
Dict(
{
"position": Discrete(23),
"velocity": Box(
low=np.array([0.0]), high=np.array([1.0]), dtype=np.float32
),
"velocity": Box(low=np.array([0.0]), high=np.array([1.0]), dtype=np.float32),
}
),
Dict(
{
"position": Dict({"x": Discrete(29), "y": Discrete(31)}),
"velocity": Tuple(
(Discrete(37), Box(low=0, high=255, shape=(), dtype=np.uint8))
),
"velocity": Tuple((Discrete(37), Box(low=0, high=255, shape=(), dtype=np.uint8))),
}
),
]
@@ -50,9 +42,7 @@ class UnittestSlowEnv(gym.Env):
def __init__(self, slow_reset=0.3):
super(UnittestSlowEnv, self).__init__()
self.slow_reset = slow_reset
self.observation_space = Box(
low=0, high=255, shape=(HEIGHT, WIDTH, 3), dtype=np.uint8
)
self.observation_space = Box(low=0, high=255, shape=(HEIGHT, WIDTH, 3), dtype=np.uint8)
self.action_space = Box(low=0.0, high=1.0, shape=(), dtype=np.float32)
def reset(self):

View File

@@ -46,10 +46,7 @@ def concatenate(items, out, space):
elif isinstance(space, Space):
return concatenate_custom(items, out, space)
else:
raise ValueError(
"Space of type `{0}` is not a valid `gym.Space` "
"instance.".format(type(space))
)
raise ValueError("Space of type `{0}` is not a valid `gym.Space` " "instance.".format(type(space)))
def concatenate_base(items, out, space):
@@ -57,18 +54,12 @@ def concatenate_base(items, out, space):
def concatenate_tuple(items, out, space):
return tuple(
concatenate([item[i] for item in items], out[i], subspace)
for (i, subspace) in enumerate(space.spaces)
)
return tuple(concatenate([item[i] for item in items], out[i], subspace) for (i, subspace) in enumerate(space.spaces))
def concatenate_dict(items, out, space):
return OrderedDict(
[
(key, concatenate([item[key] for item in items], out[key], subspace))
for (key, subspace) in space.spaces.items()
]
[(key, concatenate([item[key] for item in items], out[key], subspace)) for (key, subspace) in space.spaces.items()]
)
@@ -118,10 +109,7 @@ def create_empty_array(space, n=1, fn=np.zeros):
elif isinstance(space, Space):
return create_empty_array_custom(space, n=n, fn=fn)
else:
raise ValueError(
"Space of type `{0}` is not a valid `gym.Space` "
"instance.".format(type(space))
)
raise ValueError("Space of type `{0}` is not a valid `gym.Space` " "instance.".format(type(space)))
def create_empty_array_base(space, n=1, fn=np.zeros):
@@ -134,12 +122,7 @@ def create_empty_array_tuple(space, n=1, fn=np.zeros):
def create_empty_array_dict(space, n=1, fn=np.zeros):
return OrderedDict(
[
(key, create_empty_array(subspace, n=n, fn=fn))
for (key, subspace) in space.spaces.items()
]
)
return OrderedDict([(key, create_empty_array(subspace, n=n, fn=fn)) for (key, subspace) in space.spaces.items()])
def create_empty_array_custom(space, n=1, fn=np.zeros):
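
For context, these two helpers compose: `create_empty_array` preallocates the batch buffer and `concatenate` fills it. A minimal sketch against the signatures shown in this hunk (gym of this vintage assumed):

import numpy as np
from gym.spaces import Box
from gym.vector.utils import concatenate, create_empty_array

space = Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32)
samples = [space.sample() for _ in range(4)]
out = create_empty_array(space, n=4)      # preallocated zeros of shape (4, 3)
batch = concatenate(samples, out, space)  # fills `out` in place and returns it
assert batch.shape == (4, 3)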

View File

@@ -56,18 +56,11 @@ def create_base_shared_memory(space, n=1, ctx=mp):
def create_tuple_shared_memory(space, n=1, ctx=mp):
return tuple(
create_shared_memory(subspace, n=n, ctx=ctx) for subspace in space.spaces
)
return tuple(create_shared_memory(subspace, n=n, ctx=ctx) for subspace in space.spaces)
def create_dict_shared_memory(space, n=1, ctx=mp):
return OrderedDict(
[
(key, create_shared_memory(subspace, n=n, ctx=ctx))
for (key, subspace) in space.spaces.items()
]
)
return OrderedDict([(key, create_shared_memory(subspace, n=n, ctx=ctx)) for (key, subspace) in space.spaces.items()])
def read_from_shared_memory(shared_memory, space, n=1):
@@ -114,24 +107,16 @@ def read_from_shared_memory(shared_memory, space, n=1):
def read_base_from_shared_memory(shared_memory, space, n=1):
return np.frombuffer(shared_memory.get_obj(), dtype=space.dtype).reshape(
(n,) + space.shape
)
return np.frombuffer(shared_memory.get_obj(), dtype=space.dtype).reshape((n,) + space.shape)
def read_tuple_from_shared_memory(shared_memory, space, n=1):
return tuple(
read_from_shared_memory(memory, subspace, n=n)
for (memory, subspace) in zip(shared_memory, space.spaces)
)
return tuple(read_from_shared_memory(memory, subspace, n=n) for (memory, subspace) in zip(shared_memory, space.spaces))
def read_dict_from_shared_memory(shared_memory, space, n=1):
return OrderedDict(
[
(key, read_from_shared_memory(shared_memory[key], subspace, n=n))
for (key, subspace) in space.spaces.items()
]
[(key, read_from_shared_memory(shared_memory[key], subspace, n=n)) for (key, subspace) in space.spaces.items()]
)
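
A minimal round-trip sketch for these helpers; the `(index, value, shared_memory, space)` argument order for `write_to_shared_memory` is assumed from the test hunk earlier in this diff:

import multiprocessing as mp
import numpy as np
from gym.spaces import Box
from gym.vector.utils import create_shared_memory, read_from_shared_memory, write_to_shared_memory

space = Box(low=0, high=255, shape=(2, 2), dtype=np.uint8)
shm = create_shared_memory(space, n=4, ctx=mp)
sample = space.sample()
write_to_shared_memory(0, sample, shm, space)    # argument order assumed for this gym version
view = read_from_shared_memory(shm, space, n=4)  # zero-copy view of shape (4, 2, 2)
assert np.array_equal(view[0], sample)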

View File

@@ -44,8 +44,7 @@ def batch_space(space, n=1):
return batch_space_custom(space, n=n)
else:
raise ValueError(
"Cannot batch space with type `{0}`. The space must "
"be a valid `gym.Space` instance.".format(type(space))
"Cannot batch space with type `{0}`. The space must " "be a valid `gym.Space` instance.".format(type(space))
)
@@ -75,14 +74,7 @@ def batch_space_tuple(space, n=1):
def batch_space_dict(space, n=1):
return Dict(
OrderedDict(
[
(key, batch_space(subspace, n=n))
for (key, subspace) in space.spaces.items()
]
)
)
return Dict(OrderedDict([(key, batch_space(subspace, n=n)) for (key, subspace) in space.spaces.items()]))
def batch_space_custom(space, n=1):
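
For context, `batch_space` recurses through composite spaces; a small sketch (gym of this vintage assumed):

import numpy as np
from gym.spaces import Box, Dict, Discrete
from gym.vector.utils import batch_space

space = Dict({"position": Discrete(23), "velocity": Box(low=0.0, high=1.0, shape=(1,), dtype=np.float32)})
batched = batch_space(space, n=4)
# The Discrete becomes a MultiDiscrete of four 23-way choices;
# the Box gains a leading batch axis of size 4.
print(batched)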

View File

@@ -141,9 +141,7 @@ class VectorEnv(gym.Env):
if self.spec is None:
return "{}({})".format(self.__class__.__name__, self.num_envs)
else:
return "{}({}, {})".format(
self.__class__.__name__, self.spec.id, self.num_envs
)
return "{}({}, {})".format(self.__class__.__name__, self.spec.id, self.num_envs)
class VectorEnvWrapper(VectorEnv):
@@ -189,9 +187,7 @@ class VectorEnvWrapper(VectorEnv):
# implicitly forward all other methods and attributes to self.env
def __getattr__(self, name):
if name.startswith("_"):
raise AttributeError(
"attempted to get missing private attribute '{}'".format(name)
)
raise AttributeError("attempted to get missing private attribute '{}'".format(name))
return getattr(self.env, name)
@property
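
A usage sketch for the `__repr__` above; `gym.vector.make` is assumed available in this version:

from gym.vector import make

envs = make("CartPole-v0", num_envs=3, asynchronous=False)
print(envs)   # e.g. "SyncVectorEnv(3)"; with spec set it reads "SyncVectorEnv(CartPole-v0, 3)"
envs.close()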

View File

@@ -62,12 +62,13 @@ class AtariPreprocessing(gym.Wrapper):
assert noop_max >= 0
if frame_skip > 1:
assert "NoFrameskip" in env.spec.id, (
"disable frame-skipping in the original env. for more than one"
" frame-skip as it will be done by the wrapper"
"disable frame-skipping in the original env. for more than one" " frame-skip as it will be done by the wrapper"
)
self.noop_max = noop_max
assert env.unwrapped.get_action_meanings()[0] == "NOOP"
warnings.warn("Gym\'s internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit")
warnings.warn(
"Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit"
)
self.frame_skip = frame_skip
self.screen_size = screen_size
@@ -92,15 +93,11 @@ class AtariPreprocessing(gym.Wrapper):
self.lives = 0
self.game_over = False
_low, _high, _obs_dtype = (
(0, 255, np.uint8) if not scale_obs else (0, 1, np.float32)
)
_low, _high, _obs_dtype = (0, 255, np.uint8) if not scale_obs else (0, 1, np.float32)
_shape = (screen_size, screen_size, 1 if grayscale_obs else 3)
if grayscale_obs and not grayscale_newaxis:
_shape = _shape[:-1] # Remove channel axis
self.observation_space = Box(
low=_low, high=_high, shape=_shape, dtype=_obs_dtype
)
self.observation_space = Box(low=_low, high=_high, shape=_shape, dtype=_obs_dtype)
def step(self, action):
R = 0.0
@@ -132,11 +129,7 @@ class AtariPreprocessing(gym.Wrapper):
def reset(self, **kwargs):
# NoopReset
self.env.reset(**kwargs)
noops = (
self.env.unwrapped.np_random.randint(1, self.noop_max + 1)
if self.noop_max > 0
else 0
)
noops = self.env.unwrapped.np_random.randint(1, self.noop_max + 1) if self.noop_max > 0 else 0
for _ in range(noops):
_, _, done, _ = self.env.step(0)
if done:
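
For context, a minimal sketch of the wrapper configured as in the tests further down (assumes atari_py and the Pong ROM are installed):

import gym
from gym.wrappers import AtariPreprocessing

# The base env must have frame skipping disabled, hence the NoFrameskip variant.
env = AtariPreprocessing(gym.make("PongNoFrameskip-v4"), screen_size=84, grayscale_obs=True, frame_skip=4, noop_max=30)
obs = env.reset()
assert obs.shape == (84, 84)  # channel axis dropped for grayscale by default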

View File

@@ -9,7 +9,9 @@ class ClipAction(ActionWrapper):
def __init__(self, env):
assert isinstance(env.action_space, Box)
warnings.warn("Gym\'s internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit")
warnings.warn(
"Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit"
)
super(ClipAction, self).__init__(env)
def action(self, action):
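
A short sketch of the clipping behaviour, mirroring the equivalence test further down:

import gym
import numpy as np
from gym.wrappers import ClipAction

env = ClipAction(gym.make("MountainCarContinuous-v0"))
env.reset()
# An out-of-range action is clipped to the Box bounds before reaching the env.
obs, reward, done, info = env.step(np.array([10.0]))  # acts as action_space.high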

View File

@@ -26,7 +26,9 @@ class FilterObservation(ObservationWrapper):
assert isinstance(
wrapped_observation_space, spaces.Dict
), "FilterObservationWrapper is only usable with dict observations."
warnings.warn("Gym\'s internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit")
warnings.warn(
"Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit"
)
observation_keys = wrapped_observation_space.spaces.keys()
@@ -49,11 +51,7 @@ class FilterObservation(ObservationWrapper):
)
self.observation_space = type(wrapped_observation_space)(
[
(name, copy.deepcopy(space))
for name, space in wrapped_observation_space.spaces.items()
if name in filter_keys
]
[(name, copy.deepcopy(space)) for name, space in wrapped_observation_space.spaces.items() if name in filter_keys]
)
self._env = env
@@ -64,11 +62,5 @@ class FilterObservation(ObservationWrapper):
return filter_observation
def _filter_observation(self, observation):
observation = type(observation)(
[
(name, value)
for name, value in observation.items()
if name in self._filter_keys
]
)
observation = type(observation)([(name, value) for name, value in observation.items() if name in self._filter_keys])
return observation
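
A minimal sketch (assumes a Dict-observation env; the robotics tasks need mujoco_py):

import gym
from gym.wrappers import FilterObservation

env = FilterObservation(gym.make("FetchReach-v1"), filter_keys=["observation", "desired_goal"])
obs = env.reset()
assert set(obs.keys()) == {"observation", "desired_goal"}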

View File

@@ -9,7 +9,9 @@ class FlattenObservation(ObservationWrapper):
def __init__(self, env):
super(FlattenObservation, self).__init__(env)
self.observation_space = spaces.flatten_space(env.observation_space)
warnings.warn("Gym\'s internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit")
warnings.warn(
"Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit"
)
def observation(self, observation):
return spaces.flatten(self.env.observation_space, observation)
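
A sketch matching the Blackjack case in the test further down:

import gym
from gym.wrappers import FlattenObservation

env = FlattenObservation(gym.make("Blackjack-v0"))
obs = env.reset()
# The Tuple((Discrete(32), Discrete(11), Discrete(2))) observation flattens
# to a one-hot Box of size 32 + 11 + 2 = 45.
assert obs.shape == (45,)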

View File

@@ -22,7 +22,9 @@ class LazyFrames(object):
__slots__ = ("frame_shape", "dtype", "shape", "lz4_compress", "_frames")
def __init__(self, frames, lz4_compress=False):
warnings.warn("Gym\'s internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit")
warnings.warn(
"Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit"
)
self.frame_shape = tuple(frames[0].shape)
self.shape = (len(frames),) + self.frame_shape
self.dtype = frames[0].dtype
@@ -45,9 +47,7 @@ class LazyFrames(object):
def __getitem__(self, int_or_slice):
if isinstance(int_or_slice, int):
return self._check_decompress(self._frames[int_or_slice]) # single frame
return np.stack(
[self._check_decompress(f) for f in self._frames[int_or_slice]], axis=0
)
return np.stack([self._check_decompress(f) for f in self._frames[int_or_slice]], axis=0)
def __eq__(self, other):
return self.__array__() == other
@@ -56,9 +56,7 @@ class LazyFrames(object):
if self.lz4_compress:
from lz4.block import decompress
return np.frombuffer(decompress(frame), dtype=self.dtype).reshape(
self.frame_shape
)
return np.frombuffer(decompress(frame), dtype=self.dtype).reshape(self.frame_shape)
return frame
@@ -102,12 +100,8 @@ class FrameStack(Wrapper):
self.frames = deque(maxlen=num_stack)
low = np.repeat(self.observation_space.low[np.newaxis, ...], num_stack, axis=0)
high = np.repeat(
self.observation_space.high[np.newaxis, ...], num_stack, axis=0
)
self.observation_space = Box(
low=low, high=high, dtype=self.observation_space.dtype
)
high = np.repeat(self.observation_space.high[np.newaxis, ...], num_stack, axis=0)
self.observation_space = Box(low=low, high=high, dtype=self.observation_space.dtype)
def _get_observation(self):
assert len(self.frames) == self.num_stack, (len(self.frames), self.num_stack)
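
A minimal FrameStack sketch; reset fills the deque with num_stack copies of the first observation:

import gym
import numpy as np
from gym.wrappers import FrameStack

env = FrameStack(gym.make("CartPole-v0"), num_stack=4)
obs = env.reset()  # a LazyFrames holding the last 4 observations
assert np.asarray(obs).shape == (4, 4)  # 4 stacked 4-dim CartPole observations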

View File

@@ -11,20 +11,15 @@ class GrayScaleObservation(ObservationWrapper):
super(GrayScaleObservation, self).__init__(env)
self.keep_dim = keep_dim
assert (
len(env.observation_space.shape) == 3
and env.observation_space.shape[-1] == 3
)
assert len(env.observation_space.shape) == 3 and env.observation_space.shape[-1] == 3
warnings.warn("Gym\'s internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit")
warnings.warn(
"Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit"
)
obs_shape = self.observation_space.shape[:2]
if self.keep_dim:
self.observation_space = Box(
low=0, high=255, shape=(obs_shape[0], obs_shape[1], 1), dtype=np.uint8
)
self.observation_space = Box(low=0, high=255, shape=(obs_shape[0], obs_shape[1], 1), dtype=np.uint8)
else:
self.observation_space = Box(
low=0, high=255, shape=obs_shape, dtype=np.uint8
)
self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype=np.uint8)
def observation(self, observation):
import cv2
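
A sketch of the keep_dim=True branch above (needs opencv-python at observation time):

import gym
from gym.wrappers import GrayScaleObservation

# Requires an RGB (H, W, 3) observation, e.g. a raw Atari screen.
env = GrayScaleObservation(gym.make("PongNoFrameskip-v4"), keep_dim=True)
obs = env.reset()
assert obs.shape == (210, 160, 1)  # channel axis retained because keep_dim=True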

View File

@@ -37,9 +37,7 @@ class Monitor(Wrapper):
self._monitor_id = None
self.env_semantics_autoreset = env.metadata.get("semantics.autoreset")
self._start(
directory, video_callable, force, resume, write_upon_reset, uid, mode
)
self._start(directory, video_callable, force, resume, write_upon_reset, uid, mode)
def step(self, action):
self._before_step(action)
@@ -163,10 +161,7 @@ class Monitor(Wrapper):
json.dump(
{
"stats": os.path.basename(self.stats_recorder.path),
"videos": [
(os.path.basename(v), os.path.basename(m))
for v, m in self.videos
],
"videos": [(os.path.basename(v), os.path.basename(m)) for v, m in self.videos],
"env_info": self._env_info(),
},
f,
@@ -199,9 +194,7 @@ class Monitor(Wrapper):
elif mode == "training":
type = "t"
else:
raise error.Error(
'Invalid mode {}: must be "training" or "evaluation"', mode
)
raise error.Error('Invalid mode {}: must be "training" or "evaluation"', mode)
self.stats_recorder.type = type
def _before_step(self, action):
@@ -257,9 +250,7 @@ class Monitor(Wrapper):
env=self.env,
base_path=os.path.join(
self.directory,
"{}.video.{}.video{:06}".format(
self.file_prefix, self.file_infix, self.episode_id
),
"{}.video.{}.video{:06}".format(self.file_prefix, self.file_infix, self.episode_id),
),
metadata={"episode_id": self.episode_id},
enabled=self._video_enabled(),
@@ -269,9 +260,7 @@ class Monitor(Wrapper):
def _close_video_recorder(self):
self.video_recorder.close()
if self.video_recorder.functional:
self.videos.append(
(self.video_recorder.path, self.video_recorder.metadata_path)
)
self.videos.append((self.video_recorder.path, self.video_recorder.metadata_path))
def _video_enabled(self):
return self.video_callable(self.episode_id)
@@ -301,19 +290,11 @@ class Monitor(Wrapper):
def detect_training_manifests(training_dir, files=None):
if files is None:
files = os.listdir(training_dir)
return [
os.path.join(training_dir, f)
for f in files
if f.startswith(MANIFEST_PREFIX + ".")
]
return [os.path.join(training_dir, f) for f in files if f.startswith(MANIFEST_PREFIX + ".")]
def detect_monitor_files(training_dir):
return [
os.path.join(training_dir, f)
for f in os.listdir(training_dir)
if f.startswith(FILE_PREFIX + ".")
]
return [os.path.join(training_dir, f) for f in os.listdir(training_dir) if f.startswith(FILE_PREFIX + ".")]
def clear_monitor_files(training_dir):
@@ -382,10 +363,7 @@ def load_results(training_dir):
contents = json.load(f)
# Make these paths absolute again
stats_files.append(os.path.join(training_dir, contents["stats"]))
videos += [
(os.path.join(training_dir, v), os.path.join(training_dir, m))
for v, m in contents["videos"]
]
videos += [(os.path.join(training_dir, v), os.path.join(training_dir, m)) for v, m in contents["videos"]]
env_infos.append(contents["env_info"])
env_info = collapse_env_infos(env_infos, training_dir)
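
A minimal Monitor sketch (assumes ffmpeg is available for video encoding); closing the env flushes the manifest and stats files handled above:

import gym
from gym.wrappers import Monitor

env = Monitor(gym.make("CartPole-v0"), directory="/tmp/cartpole-monitor", force=True)
env.reset()
done = False
while not done:
    _, _, done, _ = env.step(env.action_space.sample())
env.close()  # writes the manifest/stats JSON and video files described above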

View File

@@ -48,9 +48,7 @@ class VideoRecorder(object):
self.ansi_mode = True
else:
logger.info(
'Disabling video recorder because {} neither supports video mode "rgb_array" nor "ansi".'.format(
env
)
'Disabling video recorder because {} neither supports video mode "rgb_array" nor "ansi".'.format(env)
)
# Whoops, turns out we shouldn't be enabled after all
self.enabled = False
@@ -69,9 +67,7 @@ class VideoRecorder(object):
path = base_path + required_ext
else:
# Otherwise, just generate a unique filename
with tempfile.NamedTemporaryFile(
suffix=required_ext, delete=False
) as f:
with tempfile.NamedTemporaryFile(suffix=required_ext, delete=False) as f:
path = f.name
self.path = path
@@ -83,28 +79,20 @@ class VideoRecorder(object):
if self.ansi_mode
else ""
)
raise error.Error(
"Invalid path given: {} -- must have file extension {}.{}".format(
self.path, required_ext, hint
)
)
raise error.Error("Invalid path given: {} -- must have file extension {}.{}".format(self.path, required_ext, hint))
# Touch the file in any case, so we know it's present. (This
# corrects for platform differences. Using ffmpeg on
# OS X, the file is precreated, but not on Linux.)
touch(path)
self.frames_per_sec = env.metadata.get("video.frames_per_second", 30)
self.output_frames_per_sec = env.metadata.get(
"video.output_frames_per_second", self.frames_per_sec
)
self.output_frames_per_sec = env.metadata.get("video.output_frames_per_second", self.frames_per_sec)
self.encoder = None # lazily start the process
self.broken = False
# Dump metadata
self.metadata = metadata or {}
self.metadata["content_type"] = (
"video/vnd.openai.ansivid" if self.ansi_mode else "video/mp4"
)
self.metadata["content_type"] = "video/vnd.openai.ansivid" if self.ansi_mode else "video/mp4"
self.metadata_path = "{}.meta.json".format(path_base)
self.write_metadata()
@@ -191,9 +179,7 @@ class VideoRecorder(object):
def _encode_image_frame(self, frame):
if not self.encoder:
self.encoder = ImageEncoder(
self.path, frame.shape, self.frames_per_sec, self.output_frames_per_sec
)
self.encoder = ImageEncoder(self.path, frame.shape, self.frames_per_sec, self.output_frames_per_sec)
self.metadata["encoder_version"] = self.encoder.version_info
try:
@@ -222,24 +208,16 @@ class TextEncoder(object):
string = frame.getvalue()
else:
raise error.InvalidFrame(
"Wrong type {} for {}: text frame must be a string or StringIO".format(
type(frame), frame
)
"Wrong type {} for {}: text frame must be a string or StringIO".format(type(frame), frame)
)
frame_bytes = string.encode("utf-8")
if frame_bytes[-1:] != b"\n":
raise error.InvalidFrame(
'Frame must end with a newline: """{}"""'.format(string)
)
raise error.InvalidFrame('Frame must end with a newline: """{}"""'.format(string))
if b"\r" in frame_bytes:
raise error.InvalidFrame(
'Frame contains carriage returns (only newlines are allowed): """{}"""'.format(
string
)
)
raise error.InvalidFrame('Frame contains carriage returns (only newlines are allowed): """{}"""'.format(string))
self.frames.append(frame_bytes)
@@ -263,15 +241,7 @@ class TextEncoder(object):
# Calculate frame size from the largest frames.
# Add some padding since we'll get cut off otherwise.
height = max([frame.count(b"\n") for frame in self.frames]) + 1
width = (
max(
[
max([len(line) for line in frame.split(b"\n")])
for frame in self.frames
]
)
+ 2
)
width = max([max([len(line) for line in frame.split(b"\n")]) for frame in self.frames]) + 2
data = {
"version": 1,
@@ -325,11 +295,7 @@ class ImageEncoder(object):
def version_info(self):
return {
"backend": self.backend,
"version": str(
subprocess.check_output(
[self.backend, "-version"], stderr=subprocess.STDOUT
)
),
"version": str(subprocess.check_output([self.backend, "-version"], stderr=subprocess.STDOUT)),
"cmdline": self.cmdline,
}
@@ -396,19 +362,13 @@ class ImageEncoder(object):
logger.debug('Starting %s with "%s"', self.backend, " ".join(self.cmdline))
if hasattr(os, "setsid"): # setsid not present on Windows
self.proc = subprocess.Popen(
self.cmdline, stdin=subprocess.PIPE, preexec_fn=os.setsid
)
self.proc = subprocess.Popen(self.cmdline, stdin=subprocess.PIPE, preexec_fn=os.setsid)
else:
self.proc = subprocess.Popen(self.cmdline, stdin=subprocess.PIPE)
def capture_frame(self, frame):
if not isinstance(frame, (np.ndarray, np.generic)):
raise error.InvalidFrame(
"Wrong type {} for {} (must be np.ndarray or np.generic)".format(
type(frame), frame
)
)
raise error.InvalidFrame("Wrong type {} for {} (must be np.ndarray or np.generic)".format(type(frame), frame))
if frame.shape != self.frame_shape:
raise error.InvalidFrame(
"Your frame has shape {}, but the VideoRecorder is configured for shape {}.".format(
@@ -417,15 +377,11 @@ class ImageEncoder(object):
)
if frame.dtype != np.uint8:
raise error.InvalidFrame(
"Your frame has data type {}, but we require uint8 (i.e. RGB values from 0-255).".format(
frame.dtype
)
"Your frame has data type {}, but we require uint8 (i.e. RGB values from 0-255).".format(frame.dtype)
)
try:
if distutils.version.LooseVersion(
np.__version__
) >= distutils.version.LooseVersion("1.9.0"):
if distutils.version.LooseVersion(np.__version__) >= distutils.version.LooseVersion("1.9.0"):
self.proc.stdin.write(frame.tobytes())
else:
self.proc.stdin.write(frame.tostring())

View File

@@ -14,9 +14,7 @@ STATE_KEY = "state"
class PixelObservationWrapper(ObservationWrapper):
"""Augment observations by pixel values."""
def __init__(
self, env, pixels_only=True, render_kwargs=None, pixel_keys=("pixels",)
):
def __init__(self, env, pixels_only=True, render_kwargs=None, pixel_keys=("pixels",)):
"""Initializes a new pixel Wrapper.
Args:
@@ -52,7 +50,9 @@ class PixelObservationWrapper(ObservationWrapper):
assert render_mode == "rgb_array", render_mode
render_kwargs[key]["mode"] = "rgb_array"
warnings.warn("Gym\'s internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit")
warnings.warn(
"Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit"
)
wrapped_observation_space = env.observation_space
@@ -70,9 +70,7 @@ class PixelObservationWrapper(ObservationWrapper):
# `observation_keys`
overlapping_keys = set(pixel_keys) & set(invalid_keys)
if overlapping_keys:
raise ValueError(
"Duplicate or reserved pixel keys {!r}.".format(overlapping_keys)
)
raise ValueError("Duplicate or reserved pixel keys {!r}.".format(overlapping_keys))
if pixels_only:
self.observation_space = spaces.Dict()
@@ -95,9 +93,7 @@ class PixelObservationWrapper(ObservationWrapper):
else:
raise TypeError(pixels.dtype)
pixels_space = spaces.Box(
shape=pixels.shape, low=low, high=high, dtype=pixels.dtype
)
pixels_space = spaces.Box(shape=pixels.shape, low=low, high=high, dtype=pixels.dtype)
pixels_spaces[pixel_key] = pixels_space
self.observation_space.spaces.update(pixels_spaces)
@@ -120,10 +116,7 @@ class PixelObservationWrapper(ObservationWrapper):
observation = collections.OrderedDict()
observation[STATE_KEY] = wrapped_observation
pixel_observations = {
pixel_key: self.env.render(**self._render_kwargs[pixel_key])
for pixel_key in self._pixel_keys
}
pixel_observations = {pixel_key: self.env.render(**self._render_kwargs[pixel_key]) for pixel_key in self._pixel_keys}
observation.update(pixel_observations)

View File

@@ -7,14 +7,14 @@ import gym
class RecordEpisodeStatistics(gym.Wrapper):
def __init__(self, env, deque_size=100):
super(RecordEpisodeStatistics, self).__init__(env)
self.t0 = (
time.time()
) # TODO: use perf_counter when gym removes Python 2 support
self.t0 = time.time() # TODO: use perf_counter when gym removes Python 2 support
self.episode_return = 0.0
self.episode_length = 0
self.return_queue = deque(maxlen=deque_size)
self.length_queue = deque(maxlen=deque_size)
warnings.warn("Gym\'s internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit")
warnings.warn(
"Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit"
)
def reset(self, **kwargs):
observation = super(RecordEpisodeStatistics, self).reset(**kwargs)
@@ -23,9 +23,7 @@ class RecordEpisodeStatistics(gym.Wrapper):
return observation
def step(self, action):
observation, reward, done, info = super(RecordEpisodeStatistics, self).step(
action
)
observation, reward, done, info = super(RecordEpisodeStatistics, self).step(action)
self.episode_return += reward
self.episode_length += 1
if done:
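
A sketch of how the recorded statistics surface; the info["episode"] key is assumed from this gym version's step implementation:

import gym
from gym.wrappers import RecordEpisodeStatistics

env = RecordEpisodeStatistics(gym.make("CartPole-v0"))
env.reset()
done = False
while not done:
    _, _, done, info = env.step(env.action_space.sample())
# On the terminal step the wrapper adds the episode return/length to `info`
# and appends them to the bounded deques created in __init__.
print(info["episode"], len(env.return_queue))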

View File

@@ -15,17 +15,15 @@ class RescaleAction(gym.ActionWrapper):
"""
def __init__(self, env, a, b):
assert isinstance(
env.action_space, spaces.Box
), "expected Box action space, got {}".format(type(env.action_space))
assert isinstance(env.action_space, spaces.Box), "expected Box action space, got {}".format(type(env.action_space))
assert np.less_equal(a, b).all(), (a, b)
warnings.warn("Gym\'s internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit")
warnings.warn(
"Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit"
)
super(RescaleAction, self).__init__(env)
self.a = np.zeros(env.action_space.shape, dtype=env.action_space.dtype) + a
self.b = np.zeros(env.action_space.shape, dtype=env.action_space.dtype) + b
self.action_space = spaces.Box(
low=a, high=b, shape=env.action_space.shape, dtype=env.action_space.dtype
)
self.action_space = spaces.Box(low=a, high=b, shape=env.action_space.shape, dtype=env.action_space.dtype)
def action(self, action):
assert np.all(np.greater_equal(action, self.a)), (action, self.a)

View File

@@ -12,7 +12,9 @@ class ResizeObservation(ObservationWrapper):
if isinstance(shape, int):
shape = (shape, shape)
assert all(x > 0 for x in shape), shape
warnings.warn("Gym\'s internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit")
warnings.warn(
"Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit"
)
self.shape = tuple(shape)
obs_shape = self.shape + self.observation_space.shape[2:]
@@ -21,9 +23,7 @@ class ResizeObservation(ObservationWrapper):
def observation(self, observation):
import cv2
observation = cv2.resize(
observation, self.shape[::-1], interpolation=cv2.INTER_AREA
)
observation = cv2.resize(observation, self.shape[::-1], interpolation=cv2.INTER_AREA)
if observation.ndim == 2:
observation = np.expand_dims(observation, -1)
return observation
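
A sketch of the int-shape branch above (needs opencv-python):

import gym
from gym.wrappers import ResizeObservation

# An int shape is promoted to (shape, shape); cv2 does the INTER_AREA resize.
env = ResizeObservation(gym.make("PongNoFrameskip-v4"), shape=84)
obs = env.reset()
assert obs.shape == (84, 84, 3)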

View File

@@ -15,12 +15,8 @@ def test_atari_preprocessing_grayscale(env_fn):
import cv2
env1 = env_fn()
env2 = AtariPreprocessing(
env_fn(), screen_size=84, grayscale_obs=True, frame_skip=1, noop_max=0
)
env3 = AtariPreprocessing(
env_fn(), screen_size=84, grayscale_obs=False, frame_skip=1, noop_max=0
)
env2 = AtariPreprocessing(env_fn(), screen_size=84, grayscale_obs=True, frame_skip=1, noop_max=0)
env3 = AtariPreprocessing(env_fn(), screen_size=84, grayscale_obs=False, frame_skip=1, noop_max=0)
env4 = AtariPreprocessing(
env_fn(),
screen_size=84,
@@ -79,15 +75,11 @@ def test_atari_preprocessing_scale(env_fn):
obs = env.reset().flatten()
done, step_i = False, 0
max_obs = 1 if scaled else 255
assert (0 <= obs).all() and (
obs <= max_obs
).all(), "Obs. must be in range [0,{}]".format(max_obs)
assert (0 <= obs).all() and (obs <= max_obs).all(), "Obs. must be in range [0,{}]".format(max_obs)
while not done or step_i <= max_test_steps:
obs, _, done, _ = env.step(env.action_space.sample())
obs = obs.flatten()
assert (0 <= obs).all() and (
obs <= max_obs
).all(), "Obs. must be in range [0,{}]".format(max_obs)
assert (0 <= obs).all() and (obs <= max_obs).all(), "Obs. must be in range [0,{}]".format(max_obs)
step_i += 1
env.close()

View File

@@ -19,9 +19,7 @@ def test_clip_action():
actions = [[0.4], [1.2], [-0.3], [0.0], [-2.5]]
for action in actions:
obs1, r1, d1, _ = env.step(
np.clip(action, env.action_space.low, env.action_space.high)
)
obs1, r1, d1, _ = env.step(np.clip(action, env.action_space.low, env.action_space.high))
obs2, r2, d2, _ = wrapped_env.step(action)
assert np.allclose(r1, r2)
assert np.allclose(obs1, obs2)

View File

@@ -9,10 +9,7 @@ from gym.wrappers.filter_observation import FilterObservation
class FakeEnvironment(gym.Env):
def __init__(self, observation_keys=("state",)):
self.observation_space = spaces.Dict(
{
name: spaces.Box(shape=(2,), low=-1, high=1, dtype=np.float32)
for name in observation_keys
}
{name: spaces.Box(shape=(2,), low=-1, high=1, dtype=np.float32) for name in observation_keys}
)
self.action_space = spaces.Box(shape=(1,), low=-1, high=1, dtype=np.float32)
@@ -48,9 +45,7 @@ ERROR_TEST_CASES = (
class TestFilterObservation(object):
@pytest.mark.parametrize(
"observation_keys,filter_keys", FILTER_OBSERVATION_TEST_CASES
)
@pytest.mark.parametrize("observation_keys,filter_keys", FILTER_OBSERVATION_TEST_CASES)
def test_filter_observation(self, observation_keys, filter_keys):
env = FakeEnvironment(observation_keys=observation_keys)
@@ -73,9 +68,7 @@ class TestFilterObservation(object):
assert len(observation) == len(filter_keys)
@pytest.mark.parametrize("filter_keys,error_type,error_match", ERROR_TEST_CASES)
def test_raises_with_incorrect_arguments(
self, filter_keys, error_type, error_match
):
def test_raises_with_incorrect_arguments(self, filter_keys, error_type, error_match):
env = FakeEnvironment(observation_keys=("key1", "key2"))
ValueError

View File

@@ -16,14 +16,10 @@ def test_flatten_observation(env_id):
wrapped_obs = wrapped_env.reset()
if env_id == "Blackjack-v0":
space = spaces.Tuple(
(spaces.Discrete(32), spaces.Discrete(11), spaces.Discrete(2))
)
space = spaces.Tuple((spaces.Discrete(32), spaces.Discrete(11), spaces.Discrete(2)))
wrapped_space = spaces.Box(-np.inf, np.inf, [32 + 11 + 2], dtype=np.float32)
elif env_id == "KellyCoinflip-v0":
space = spaces.Tuple(
(spaces.Box(0, 250.0, [1], dtype=np.float32), spaces.Discrete(300 + 1))
)
space = spaces.Tuple((spaces.Box(0, 250.0, [1], dtype=np.float32), spaces.Discrete(300 + 1)))
wrapped_space = spaces.Box(-np.inf, np.inf, [1 + (300 + 1)], dtype=np.float32)
assert space.contains(obs)

View File

@@ -19,9 +19,7 @@ except ImportError:
[
pytest.param(
True,
marks=pytest.mark.skipif(
lz4 is None, reason="Need lz4 to run tests with compression"
),
marks=pytest.mark.skipif(lz4 is None, reason="Need lz4 to run tests with compression"),
),
False,
],

View File

@@ -10,9 +10,7 @@ pytest.importorskip("atari_py")
pytest.importorskip("cv2")
@pytest.mark.parametrize(
"env_id", ["PongNoFrameskip-v0", "SpaceInvadersNoFrameskip-v0"]
)
@pytest.mark.parametrize("env_id", ["PongNoFrameskip-v0", "SpaceInvadersNoFrameskip-v0"])
@pytest.mark.parametrize("keep_dim", [True, False])
def test_gray_scale_observation(env_id, keep_dim):
gray_env = AtariPreprocessing(gym.make(env_id), screen_size=84, grayscale_obs=True)

View File

@@ -32,9 +32,7 @@ class FakeEnvironment(gym.Env):
class FakeArrayObservationEnvironment(FakeEnvironment):
def __init__(self, *args, **kwargs):
self.observation_space = spaces.Box(
shape=(2,), low=-1, high=1, dtype=np.float32
)
self.observation_space = spaces.Box(shape=(2,), low=-1, high=1, dtype=np.float32)
super(FakeArrayObservationEnvironment, self).__init__(*args, **kwargs)
@@ -75,10 +73,7 @@ class TestPixelObservationWrapper(object):
assert len(wrapped_env.observation_space.spaces) == 1
assert list(wrapped_env.observation_space.spaces.keys()) == [pixel_key]
else:
assert (
len(wrapped_env.observation_space.spaces)
== len(observation_space.spaces) + 1
)
assert len(wrapped_env.observation_space.spaces) == len(observation_space.spaces) + 1
expected_keys = list(observation_space.spaces.keys()) + [pixel_key]
assert list(wrapped_env.observation_space.spaces.keys()) == expected_keys
@@ -97,9 +92,7 @@ class TestPixelObservationWrapper(object):
observation_space = env.observation_space
assert isinstance(observation_space, spaces.Box)
wrapped_env = PixelObservationWrapper(
env, pixel_keys=(pixel_key,), pixels_only=pixels_only
)
wrapped_env = PixelObservationWrapper(env, pixel_keys=(pixel_key,), pixels_only=pixels_only)
wrapped_env.observation_space = wrapped_env.observation_space
assert isinstance(wrapped_env.observation_space, spaces.Dict)

Some files were not shown because too many files have changed in this diff