redo black

This commit is contained in:
Justin Terry
2021-07-29 12:42:48 -04:00
parent d5004b7ec1
commit e9d2c41f2b
109 changed files with 459 additions and 1363 deletions

View File

@@ -3,9 +3,7 @@ import argparse
import gym import gym
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(description="Renders a Gym environment for quick inspection.")
description="Renders a Gym environment for quick inspection."
)
parser.add_argument( parser.add_argument(
"env_id", "env_id",
type=str, type=str,

View File

@@ -35,12 +35,7 @@ def cem(f, th_mean, batch_size, n_iter, elite_frac, initial_std=1.0):
th_std = np.ones_like(th_mean) * initial_std th_std = np.ones_like(th_mean) * initial_std
for _ in range(n_iter): for _ in range(n_iter):
ths = np.array( ths = np.array([th_mean + dth for dth in th_std[None, :] * np.random.randn(batch_size, th_mean.size)])
[
th_mean + dth
for dth in th_std[None, :] * np.random.randn(batch_size, th_mean.size)
]
)
ys = np.array([f(th) for th in ths]) ys = np.array([f(th) for th in ths])
elite_inds = ys.argsort()[::-1][:n_elite] elite_inds = ys.argsort()[::-1][:n_elite]
elite_ths = ths[elite_inds] elite_ths = ths[elite_inds]
@@ -101,9 +96,7 @@ if __name__ == "__main__":
return rew return rew
# Train the agent, and snapshot each stage # Train the agent, and snapshot each stage
for (i, iterdata) in enumerate( for (i, iterdata) in enumerate(cem(noisy_evaluation, np.zeros(env.observation_space.shape[0] + 1), **params)):
cem(noisy_evaluation, np.zeros(env.observation_space.shape[0] + 1), **params)
):
print("Iteration %2i. Episode mean reward: %7.3f" % (i, iterdata["y_mean"])) print("Iteration %2i. Episode mean reward: %7.3f" % (i, iterdata["y_mean"]))
agent = BinaryActionLinearPolicy(iterdata["theta_mean"]) agent = BinaryActionLinearPolicy(iterdata["theta_mean"])
if args.display: if args.display:

View File

@@ -17,9 +17,7 @@ class RandomAgent(object):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description=None) parser = argparse.ArgumentParser(description=None)
parser.add_argument( parser.add_argument("env_id", nargs="?", default="CartPole-v0", help="Select the environment to run")
"env_id", nargs="?", default="CartPole-v0", help="Select the environment to run"
)
args = parser.parse_args() args = parser.parse_args()
# You can set the level to logger.DEBUG or logger.WARN if you # You can set the level to logger.DEBUG or logger.WARN if you

View File

@@ -173,16 +173,10 @@ class GoalEnv(Env):
def reset(self): def reset(self):
# Enforce that each GoalEnv uses a Goal-compatible observation space. # Enforce that each GoalEnv uses a Goal-compatible observation space.
if not isinstance(self.observation_space, gym.spaces.Dict): if not isinstance(self.observation_space, gym.spaces.Dict):
raise error.Error( raise error.Error("GoalEnv requires an observation space of type gym.spaces.Dict")
"GoalEnv requires an observation space of type gym.spaces.Dict"
)
for key in ["observation", "achieved_goal", "desired_goal"]: for key in ["observation", "achieved_goal", "desired_goal"]:
if key not in self.observation_space.spaces: if key not in self.observation_space.spaces:
raise error.Error( raise error.Error('GoalEnv requires the "{}" key to be part of the observation dictionary.'.format(key))
'GoalEnv requires the "{}" key to be part of the observation dictionary.'.format(
key
)
)
def compute_reward(self, achieved_goal, desired_goal, info): def compute_reward(self, achieved_goal, desired_goal, info):
"""Compute the step reward. This externalizes the reward function and makes """Compute the step reward. This externalizes the reward function and makes
@@ -227,9 +221,7 @@ class Wrapper(Env):
def __getattr__(self, name): def __getattr__(self, name):
if name.startswith("_"): if name.startswith("_"):
raise AttributeError( raise AttributeError("attempted to get missing private attribute '{}'".format(name))
"attempted to get missing private attribute '{}'".format(name)
)
return getattr(self.env, name) return getattr(self.env, name)
@property @property

View File

@@ -422,9 +422,7 @@ for reward_type in ["sparse", "dense"]:
register( register(
id="HandManipulateBlockRotateParallel{}-v0".format(suffix), id="HandManipulateBlockRotateParallel{}-v0".format(suffix),
entry_point="gym.envs.robotics:HandBlockEnv", entry_point="gym.envs.robotics:HandBlockEnv",
kwargs=_merge( kwargs=_merge({"target_position": "ignore", "target_rotation": "parallel"}, kwargs),
{"target_position": "ignore", "target_rotation": "parallel"}, kwargs
),
max_episode_steps=100, max_episode_steps=100,
) )

View File

@@ -73,9 +73,7 @@ class AlgorithmicEnv(Env):
# 1. Move read head left or right (or up/down) # 1. Move read head left or right (or up/down)
# 2. Write or not # 2. Write or not
# 3. Which character to write. (Ignored if should_write=0) # 3. Which character to write. (Ignored if should_write=0)
self.action_space = Tuple( self.action_space = Tuple([Discrete(len(self.MOVEMENTS)), Discrete(2), Discrete(self.base)])
[Discrete(len(self.MOVEMENTS)), Discrete(2), Discrete(self.base)]
)
# Can see just what is on the input tape (one of n characters, or # Can see just what is on the input tape (one of n characters, or
# nothing) # nothing)
self.observation_space = Discrete(self.base + 1) self.observation_space = Discrete(self.base + 1)
@@ -147,10 +145,7 @@ class AlgorithmicEnv(Env):
move = self.MOVEMENTS[inp_act] move = self.MOVEMENTS[inp_act]
outfile.write("Action : Tuple(move over input: %s,\n" % move) outfile.write("Action : Tuple(move over input: %s,\n" % move)
out_act = out_act == 1 out_act = out_act == 1
outfile.write( outfile.write(" write to the output tape: %s,\n" % out_act)
" write to the output tape: %s,\n"
% out_act
)
outfile.write(" prediction: %s)\n" % pred_str) outfile.write(" prediction: %s)\n" % pred_str)
else: else:
outfile.write("\n" * 5) outfile.write("\n" * 5)
@@ -276,9 +271,7 @@ class TapeAlgorithmicEnv(AlgorithmicEnv):
x_str = "Observation Tape : " x_str = "Observation Tape : "
for i in range(-2, self.input_width + 2): for i in range(-2, self.input_width + 2):
if i == x: if i == x:
x_str += colorize( x_str += colorize(self._get_str_obs(np.array([i])), "green", highlight=True)
self._get_str_obs(np.array([i])), "green", highlight=True
)
else: else:
x_str += self._get_str_obs(np.array([i])) x_str += self._get_str_obs(np.array([i]))
x_str += "\n" x_str += "\n"
@@ -311,10 +304,7 @@ class GridAlgorithmicEnv(AlgorithmicEnv):
self.read_head_position = x, y self.read_head_position = x, y
def generate_input_data(self, size): def generate_input_data(self, size):
return [ return [[self.np_random.randint(self.base) for _ in range(self.rows)] for __ in range(size)]
[self.np_random.randint(self.base) for _ in range(self.rows)]
for __ in range(size)
]
def _get_obs(self, pos=None): def _get_obs(self, pos=None):
if pos is None: if pos is None:
@@ -336,9 +326,7 @@ class GridAlgorithmicEnv(AlgorithmicEnv):
x_str += " " * len(label) x_str += " " * len(label)
for i in range(-2, self.input_width + 2): for i in range(-2, self.input_width + 2):
if i == x[0] and j == x[1]: if i == x[0] and j == x[1]:
x_str += colorize( x_str += colorize(self._get_str_obs((i, j)), "green", highlight=True)
self._get_str_obs((i, j)), "green", highlight=True
)
else: else:
x_str += self._get_str_obs((i, j)) x_str += self._get_str_obs((i, j))
x_str += "\n" x_str += "\n"

View File

@@ -10,12 +10,8 @@ ALL_ENVS = [
alg.reverse.ReverseEnv, alg.reverse.ReverseEnv,
alg.reversed_addition.ReversedAdditionEnv, alg.reversed_addition.ReversedAdditionEnv,
] ]
ALL_TAPE_ENVS = [ ALL_TAPE_ENVS = [env for env in ALL_ENVS if issubclass(env, alg.algorithmic_env.TapeAlgorithmicEnv)]
env for env in ALL_ENVS if issubclass(env, alg.algorithmic_env.TapeAlgorithmicEnv) ALL_GRID_ENVS = [env for env in ALL_ENVS if issubclass(env, alg.algorithmic_env.GridAlgorithmicEnv)]
]
ALL_GRID_ENVS = [
env for env in ALL_ENVS if issubclass(env, alg.algorithmic_env.GridAlgorithmicEnv)
]
def imprint(env, input_arr): def imprint(env, input_arr):
@@ -92,10 +88,7 @@ class TestAlgorithmicEnvInteractions(unittest.TestCase):
def test_grid_naviation(self): def test_grid_naviation(self):
env = alg.reversed_addition.ReversedAdditionEnv(rows=2, base=6) env = alg.reversed_addition.ReversedAdditionEnv(rows=2, base=6)
N, S, E, W = [ N, S, E, W = [env._movement_idx(named_dir) for named_dir in ["up", "down", "right", "left"]]
env._movement_idx(named_dir)
for named_dir in ["up", "down", "right", "left"]
]
# Corresponds to a grid that looks like... # Corresponds to a grid that looks like...
# 0 1 2 # 0 1 2
# 3 4 5 # 3 4 5
@@ -204,9 +197,7 @@ class TestTargets(unittest.TestCase):
def test_repeat_copy_target(self): def test_repeat_copy_target(self):
env = alg.repeat_copy.RepeatCopyEnv() env = alg.repeat_copy.RepeatCopyEnv()
self.assertEqual( self.assertEqual(env.target_from_input_data([0, 1, 2]), [0, 1, 2, 2, 1, 0, 0, 1, 2])
env.target_from_input_data([0, 1, 2]), [0, 1, 2, 2, 1, 0, 0, 1, 2]
)
class TestInputGeneration(unittest.TestCase): class TestInputGeneration(unittest.TestCase):

View File

@@ -9,8 +9,7 @@ try:
import atari_py import atari_py
except ImportError as e: except ImportError as e:
raise error.DependencyNotInstalled( raise error.DependencyNotInstalled(
"{}. (HINT: you can install Atari dependencies by running " "{}. (HINT: you can install Atari dependencies by running " "'pip install gym[atari]'.)".format(e)
"'pip install gym[atari]'.)".format(e)
) )
@@ -64,35 +63,23 @@ class AtariEnv(gym.Env, utils.EzPickle):
# Tune (or disable) ALE's action repeat: # Tune (or disable) ALE's action repeat:
# https://github.com/openai/gym/issues/349 # https://github.com/openai/gym/issues/349
assert isinstance( assert isinstance(repeat_action_probability, (float, int)), "Invalid repeat_action_probability: {!r}".format(
repeat_action_probability, (float, int) repeat_action_probability
), "Invalid repeat_action_probability: {!r}".format(repeat_action_probability)
self.ale.setFloat(
"repeat_action_probability".encode("utf-8"), repeat_action_probability
) )
self.ale.setFloat("repeat_action_probability".encode("utf-8"), repeat_action_probability)
self.seed() self.seed()
self._action_set = ( self._action_set = self.ale.getLegalActionSet() if full_action_space else self.ale.getMinimalActionSet()
self.ale.getLegalActionSet()
if full_action_space
else self.ale.getMinimalActionSet()
)
self.action_space = spaces.Discrete(len(self._action_set)) self.action_space = spaces.Discrete(len(self._action_set))
(screen_width, screen_height) = self.ale.getScreenDims() (screen_width, screen_height) = self.ale.getScreenDims()
if self._obs_type == "ram": if self._obs_type == "ram":
self.observation_space = spaces.Box( self.observation_space = spaces.Box(low=0, high=255, dtype=np.uint8, shape=(128,))
low=0, high=255, dtype=np.uint8, shape=(128,)
)
elif self._obs_type == "image": elif self._obs_type == "image":
self.observation_space = spaces.Box( self.observation_space = spaces.Box(low=0, high=255, shape=(screen_height, screen_width, 3), dtype=np.uint8)
low=0, high=255, shape=(screen_height, screen_width, 3), dtype=np.uint8
)
else: else:
raise error.Error( raise error.Error("Unrecognized observation type: {}".format(self._obs_type))
"Unrecognized observation type: {}".format(self._obs_type)
)
def seed(self, seed=None): def seed(self, seed=None):
self.np_random, seed1 = seeding.np_random(seed) self.np_random, seed1 = seeding.np_random(seed)
@@ -107,9 +94,9 @@ class AtariEnv(gym.Env, utils.EzPickle):
if self.game_mode is not None: if self.game_mode is not None:
modes = self.ale.getAvailableModes() modes = self.ale.getAvailableModes()
assert self.game_mode in modes, ( assert self.game_mode in modes, ('Invalid game mode "{}" for game {}.\nAvailable modes are: {}').format(
'Invalid game mode "{}" for game {}.\nAvailable modes are: {}' self.game_mode, self.game, modes
).format(self.game_mode, self.game, modes) )
self.ale.setMode(self.game_mode) self.ale.setMode(self.game_mode)
if self.game_difficulty is not None: if self.game_difficulty is not None:

View File

@@ -100,10 +100,7 @@ class ContactDetector(contactListener):
self.env = env self.env = env
def BeginContact(self, contact): def BeginContact(self, contact):
if ( if self.env.hull == contact.fixtureA.body or self.env.hull == contact.fixtureB.body:
self.env.hull == contact.fixtureA.body
or self.env.hull == contact.fixtureB.body
):
self.env.game_over = True self.env.game_over = True
for leg in [self.env.legs[1], self.env.legs[3]]: for leg in [self.env.legs[1], self.env.legs[3]]:
if leg in [contact.fixtureA.body, contact.fixtureB.body]: if leg in [contact.fixtureA.body, contact.fixtureB.body]:
@@ -202,9 +199,7 @@ class BipedalWalker(gym.Env, EzPickle):
t.color1, t.color2 = (1, 1, 1), (0.6, 0.6, 0.6) t.color1, t.color2 = (1, 1, 1), (0.6, 0.6, 0.6)
self.terrain.append(t) self.terrain.append(t)
self.fd_polygon.shape.vertices = [ self.fd_polygon.shape.vertices = [(p[0] + TERRAIN_STEP * counter, p[1]) for p in poly]
(p[0] + TERRAIN_STEP * counter, p[1]) for p in poly
]
t = self.world.CreateStaticBody(fixtures=self.fd_polygon) t = self.world.CreateStaticBody(fixtures=self.fd_polygon)
t.color1, t.color2 = (1, 1, 1), (0.6, 0.6, 0.6) t.color1, t.color2 = (1, 1, 1), (0.6, 0.6, 0.6)
self.terrain.append(t) self.terrain.append(t)
@@ -301,12 +296,8 @@ class BipedalWalker(gym.Env, EzPickle):
y = VIEWPORT_H / SCALE * 3 / 4 y = VIEWPORT_H / SCALE * 3 / 4
poly = [ poly = [
( (
x x + 15 * TERRAIN_STEP * math.sin(3.14 * 2 * a / 5) + self.np_random.uniform(0, 5 * TERRAIN_STEP),
+ 15 * TERRAIN_STEP * math.sin(3.14 * 2 * a / 5) y + 5 * TERRAIN_STEP * math.cos(3.14 * 2 * a / 5) + self.np_random.uniform(0, 5 * TERRAIN_STEP),
+ self.np_random.uniform(0, 5 * TERRAIN_STEP),
y
+ 5 * TERRAIN_STEP * math.cos(3.14 * 2 * a / 5)
+ self.np_random.uniform(0, 5 * TERRAIN_STEP),
) )
for a in range(5) for a in range(5)
] ]
@@ -331,14 +322,10 @@ class BipedalWalker(gym.Env, EzPickle):
init_x = TERRAIN_STEP * TERRAIN_STARTPAD / 2 init_x = TERRAIN_STEP * TERRAIN_STARTPAD / 2
init_y = TERRAIN_HEIGHT + 2 * LEG_H init_y = TERRAIN_HEIGHT + 2 * LEG_H
self.hull = self.world.CreateDynamicBody( self.hull = self.world.CreateDynamicBody(position=(init_x, init_y), fixtures=HULL_FD)
position=(init_x, init_y), fixtures=HULL_FD
)
self.hull.color1 = (0.5, 0.4, 0.9) self.hull.color1 = (0.5, 0.4, 0.9)
self.hull.color2 = (0.3, 0.3, 0.5) self.hull.color2 = (0.3, 0.3, 0.5)
self.hull.ApplyForceToCenter( self.hull.ApplyForceToCenter((self.np_random.uniform(-INITIAL_RANDOM, INITIAL_RANDOM), 0), True)
(self.np_random.uniform(-INITIAL_RANDOM, INITIAL_RANDOM), 0), True
)
self.legs = [] self.legs = []
self.joints = [] self.joints = []
@@ -412,21 +399,13 @@ class BipedalWalker(gym.Env, EzPickle):
self.joints[3].motorSpeed = float(SPEED_KNEE * np.clip(action[3], -1, 1)) self.joints[3].motorSpeed = float(SPEED_KNEE * np.clip(action[3], -1, 1))
else: else:
self.joints[0].motorSpeed = float(SPEED_HIP * np.sign(action[0])) self.joints[0].motorSpeed = float(SPEED_HIP * np.sign(action[0]))
self.joints[0].maxMotorTorque = float( self.joints[0].maxMotorTorque = float(MOTORS_TORQUE * np.clip(np.abs(action[0]), 0, 1))
MOTORS_TORQUE * np.clip(np.abs(action[0]), 0, 1)
)
self.joints[1].motorSpeed = float(SPEED_KNEE * np.sign(action[1])) self.joints[1].motorSpeed = float(SPEED_KNEE * np.sign(action[1]))
self.joints[1].maxMotorTorque = float( self.joints[1].maxMotorTorque = float(MOTORS_TORQUE * np.clip(np.abs(action[1]), 0, 1))
MOTORS_TORQUE * np.clip(np.abs(action[1]), 0, 1)
)
self.joints[2].motorSpeed = float(SPEED_HIP * np.sign(action[2])) self.joints[2].motorSpeed = float(SPEED_HIP * np.sign(action[2]))
self.joints[2].maxMotorTorque = float( self.joints[2].maxMotorTorque = float(MOTORS_TORQUE * np.clip(np.abs(action[2]), 0, 1))
MOTORS_TORQUE * np.clip(np.abs(action[2]), 0, 1)
)
self.joints[3].motorSpeed = float(SPEED_KNEE * np.sign(action[3])) self.joints[3].motorSpeed = float(SPEED_KNEE * np.sign(action[3]))
self.joints[3].maxMotorTorque = float( self.joints[3].maxMotorTorque = float(MOTORS_TORQUE * np.clip(np.abs(action[3]), 0, 1))
MOTORS_TORQUE * np.clip(np.abs(action[3]), 0, 1)
)
self.world.Step(1.0 / FPS, 6 * 30, 2 * 30) self.world.Step(1.0 / FPS, 6 * 30, 2 * 30)
@@ -465,12 +444,8 @@ class BipedalWalker(gym.Env, EzPickle):
self.scroll = pos.x - VIEWPORT_W / SCALE / 5 self.scroll = pos.x - VIEWPORT_W / SCALE / 5
shaping = ( shaping = 130 * pos[0] / SCALE # moving forward is a way to receive reward (normalized to get 300 on completion)
130 * pos[0] / SCALE shaping -= 5.0 * abs(state[0]) # keep head straight, other than that and falling, any behavior is unpunished
) # moving forward is a way to receive reward (normalized to get 300 on completion)
shaping -= 5.0 * abs(
state[0]
) # keep head straight, other than that and falling, any behavior is unpunished
reward = 0 reward = 0
if self.prev_shaping is not None: if self.prev_shaping is not None:
@@ -494,9 +469,7 @@ class BipedalWalker(gym.Env, EzPickle):
if self.viewer is None: if self.viewer is None:
self.viewer = rendering.Viewer(VIEWPORT_W, VIEWPORT_H) self.viewer = rendering.Viewer(VIEWPORT_W, VIEWPORT_H)
self.viewer.set_bounds( self.viewer.set_bounds(self.scroll, VIEWPORT_W / SCALE + self.scroll, 0, VIEWPORT_H / SCALE)
self.scroll, VIEWPORT_W / SCALE + self.scroll, 0, VIEWPORT_H / SCALE
)
self.viewer.draw_polygon( self.viewer.draw_polygon(
[ [
@@ -512,9 +485,7 @@ class BipedalWalker(gym.Env, EzPickle):
continue continue
if x1 > self.scroll / 2 + VIEWPORT_W / SCALE: if x1 > self.scroll / 2 + VIEWPORT_W / SCALE:
continue continue
self.viewer.draw_polygon( self.viewer.draw_polygon([(p[0] + self.scroll / 2, p[1]) for p in poly], color=(1, 1, 1))
[(p[0] + self.scroll / 2, p[1]) for p in poly], color=(1, 1, 1)
)
for poly, color in self.terrain_poly: for poly, color in self.terrain_poly:
if poly[1][0] < self.scroll: if poly[1][0] < self.scroll:
continue continue
@@ -525,11 +496,7 @@ class BipedalWalker(gym.Env, EzPickle):
self.lidar_render = (self.lidar_render + 1) % 100 self.lidar_render = (self.lidar_render + 1) % 100
i = self.lidar_render i = self.lidar_render
if i < 2 * len(self.lidar): if i < 2 * len(self.lidar):
l = ( l = self.lidar[i] if i < len(self.lidar) else self.lidar[len(self.lidar) - i - 1]
self.lidar[i]
if i < len(self.lidar)
else self.lidar[len(self.lidar) - i - 1]
)
self.viewer.draw_polyline([l.p1, l.p2], color=(1, 0, 0), linewidth=1) self.viewer.draw_polyline([l.p1, l.p2], color=(1, 0, 0), linewidth=1)
for obj in self.drawlist: for obj in self.drawlist:
@@ -537,12 +504,8 @@ class BipedalWalker(gym.Env, EzPickle):
trans = f.body.transform trans = f.body.transform
if type(f.shape) is circleShape: if type(f.shape) is circleShape:
t = rendering.Transform(translation=trans * f.shape.pos) t = rendering.Transform(translation=trans * f.shape.pos)
self.viewer.draw_circle( self.viewer.draw_circle(f.shape.radius, 30, color=obj.color1).add_attr(t)
f.shape.radius, 30, color=obj.color1 self.viewer.draw_circle(f.shape.radius, 30, color=obj.color2, filled=False, linewidth=2).add_attr(t)
).add_attr(t)
self.viewer.draw_circle(
f.shape.radius, 30, color=obj.color2, filled=False, linewidth=2
).add_attr(t)
else: else:
path = [trans * v for v in f.shape.vertices] path = [trans * v for v in f.shape.vertices]
self.viewer.draw_polygon(path, color=obj.color1) self.viewer.draw_polygon(path, color=obj.color1)
@@ -552,9 +515,7 @@ class BipedalWalker(gym.Env, EzPickle):
flagy1 = TERRAIN_HEIGHT flagy1 = TERRAIN_HEIGHT
flagy2 = flagy1 + 50 / SCALE flagy2 = flagy1 + 50 / SCALE
x = TERRAIN_STEP * 3 x = TERRAIN_STEP * 3
self.viewer.draw_polyline( self.viewer.draw_polyline([(x, flagy1), (x, flagy2)], color=(0, 0, 0), linewidth=2)
[(x, flagy1), (x, flagy2)], color=(0, 0, 0), linewidth=2
)
f = [ f = [
(x, flagy2), (x, flagy2),
(x, flagy2 - 10 / SCALE), (x, flagy2 - 10 / SCALE),

View File

@@ -23,9 +23,7 @@ from Box2D.b2 import (
SIZE = 0.02 SIZE = 0.02
ENGINE_POWER = 100000000 * SIZE * SIZE ENGINE_POWER = 100000000 * SIZE * SIZE
WHEEL_MOMENT_OF_INERTIA = 4000 * SIZE * SIZE WHEEL_MOMENT_OF_INERTIA = 4000 * SIZE * SIZE
FRICTION_LIMIT = ( FRICTION_LIMIT = 1000000 * SIZE * SIZE # friction ~= mass ~= size^2 (calculated implicitly using density)
1000000 * SIZE * SIZE
) # friction ~= mass ~= size^2 (calculated implicitly using density)
WHEEL_R = 27 WHEEL_R = 27
WHEEL_W = 14 WHEEL_W = 14
WHEELPOS = [(-55, +80), (+55, +80), (-55, -82), (+55, -82)] WHEELPOS = [(-55, +80), (+55, +80), (-55, -82), (+55, -82)]
@@ -55,27 +53,19 @@ class Car:
angle=init_angle, angle=init_angle,
fixtures=[ fixtures=[
fixtureDef( fixtureDef(
shape=polygonShape( shape=polygonShape(vertices=[(x * SIZE, y * SIZE) for x, y in HULL_POLY1]),
vertices=[(x * SIZE, y * SIZE) for x, y in HULL_POLY1]
),
density=1.0, density=1.0,
), ),
fixtureDef( fixtureDef(
shape=polygonShape( shape=polygonShape(vertices=[(x * SIZE, y * SIZE) for x, y in HULL_POLY2]),
vertices=[(x * SIZE, y * SIZE) for x, y in HULL_POLY2]
),
density=1.0, density=1.0,
), ),
fixtureDef( fixtureDef(
shape=polygonShape( shape=polygonShape(vertices=[(x * SIZE, y * SIZE) for x, y in HULL_POLY3]),
vertices=[(x * SIZE, y * SIZE) for x, y in HULL_POLY3]
),
density=1.0, density=1.0,
), ),
fixtureDef( fixtureDef(
shape=polygonShape( shape=polygonShape(vertices=[(x * SIZE, y * SIZE) for x, y in HULL_POLY4]),
vertices=[(x * SIZE, y * SIZE) for x, y in HULL_POLY4]
),
density=1.0, density=1.0,
), ),
], ],
@@ -95,12 +85,7 @@ class Car:
position=(init_x + wx * SIZE, init_y + wy * SIZE), position=(init_x + wx * SIZE, init_y + wy * SIZE),
angle=init_angle, angle=init_angle,
fixtures=fixtureDef( fixtures=fixtureDef(
shape=polygonShape( shape=polygonShape(vertices=[(x * front_k * SIZE, y * front_k * SIZE) for x, y in WHEEL_POLY]),
vertices=[
(x * front_k * SIZE, y * front_k * SIZE)
for x, y in WHEEL_POLY
]
),
density=0.1, density=0.1,
categoryBits=0x0020, categoryBits=0x0020,
maskBits=0x001, maskBits=0x001,
@@ -175,9 +160,7 @@ class Car:
grass = True grass = True
friction_limit = FRICTION_LIMIT * 0.6 # Grass friction if no tile friction_limit = FRICTION_LIMIT * 0.6 # Grass friction if no tile
for tile in w.tiles: for tile in w.tiles:
friction_limit = max( friction_limit = max(friction_limit, FRICTION_LIMIT * tile.road_friction)
friction_limit, FRICTION_LIMIT * tile.road_friction
)
grass = False grass = False
# Force # Force
@@ -192,13 +175,7 @@ class Car:
# domega = dt*W/WHEEL_MOMENT_OF_INERTIA/w.omega # domega = dt*W/WHEEL_MOMENT_OF_INERTIA/w.omega
# add small coef not to divide by zero # add small coef not to divide by zero
w.omega += ( w.omega += dt * ENGINE_POWER * w.gas / WHEEL_MOMENT_OF_INERTIA / (abs(w.omega) + 5.0)
dt
* ENGINE_POWER
* w.gas
/ WHEEL_MOMENT_OF_INERTIA
/ (abs(w.omega) + 5.0)
)
self.fuel_spent += dt * ENGINE_POWER * w.gas self.fuel_spent += dt * ENGINE_POWER * w.gas
if w.brake >= 0.9: if w.brake >= 0.9:
@@ -226,18 +203,12 @@ class Car:
# Skid trace # Skid trace
if abs(force) > 2.0 * friction_limit: if abs(force) > 2.0 * friction_limit:
if ( if w.skid_particle and w.skid_particle.grass == grass and len(w.skid_particle.poly) < 30:
w.skid_particle
and w.skid_particle.grass == grass
and len(w.skid_particle.poly) < 30
):
w.skid_particle.poly.append((w.position[0], w.position[1])) w.skid_particle.poly.append((w.position[0], w.position[1]))
elif w.skid_start is None: elif w.skid_start is None:
w.skid_start = w.position w.skid_start = w.position
else: else:
w.skid_particle = self._create_particle( w.skid_particle = self._create_particle(w.skid_start, w.position, grass)
w.skid_start, w.position, grass
)
w.skid_start = None w.skid_start = None
else: else:
w.skid_start = None w.skid_start = None

View File

@@ -132,18 +132,14 @@ class CarRacing(gym.Env, EzPickle):
self.reward = 0.0 self.reward = 0.0
self.prev_reward = 0.0 self.prev_reward = 0.0
self.verbose = verbose self.verbose = verbose
self.fd_tile = fixtureDef( self.fd_tile = fixtureDef(shape=polygonShape(vertices=[(0, 0), (1, 0), (1, -1), (0, -1)]))
shape=polygonShape(vertices=[(0, 0), (1, 0), (1, -1), (0, -1)])
)
self.action_space = spaces.Box( self.action_space = spaces.Box(
np.array([-1, 0, 0]).astype(np.float32), np.array([-1, 0, 0]).astype(np.float32),
np.array([+1, +1, +1]).astype(np.float32), np.array([+1, +1, +1]).astype(np.float32),
) # steer, gas, brake ) # steer, gas, brake
self.observation_space = spaces.Box( self.observation_space = spaces.Box(low=0, high=255, shape=(STATE_H, STATE_W, 3), dtype=np.uint8)
low=0, high=255, shape=(STATE_H, STATE_W, 3), dtype=np.uint8
)
def seed(self, seed=None): def seed(self, seed=None):
self.np_random, seed = seeding.np_random(seed) self.np_random, seed = seeding.np_random(seed)
@@ -246,9 +242,7 @@ class CarRacing(gym.Env, EzPickle):
i -= 1 i -= 1
if i == 0: if i == 0:
return False # Failed return False # Failed
pass_through_start = ( pass_through_start = track[i][0] > self.start_alpha and track[i - 1][0] <= self.start_alpha
track[i][0] > self.start_alpha and track[i - 1][0] <= self.start_alpha
)
if pass_through_start and i2 == -1: if pass_through_start and i2 == -1:
i2 = i i2 = i
elif pass_through_start and i1 == -1: elif pass_through_start and i1 == -1:
@@ -266,8 +260,7 @@ class CarRacing(gym.Env, EzPickle):
first_perp_y = math.sin(first_beta) first_perp_y = math.sin(first_beta)
# Length of perpendicular jump to put together head and tail # Length of perpendicular jump to put together head and tail
well_glued_together = np.sqrt( well_glued_together = np.sqrt(
np.square(first_perp_x * (track[0][2] - track[-1][2])) np.square(first_perp_x * (track[0][2] - track[-1][2])) + np.square(first_perp_y * (track[0][3] - track[-1][3]))
+ np.square(first_perp_y * (track[0][3] - track[-1][3]))
) )
if well_glued_together > TRACK_DETAIL_STEP: if well_glued_together > TRACK_DETAIL_STEP:
return False return False
@@ -337,9 +330,7 @@ class CarRacing(gym.Env, EzPickle):
x2 + side * (TRACK_WIDTH + BORDER) * math.cos(beta2), x2 + side * (TRACK_WIDTH + BORDER) * math.cos(beta2),
y2 + side * (TRACK_WIDTH + BORDER) * math.sin(beta2), y2 + side * (TRACK_WIDTH + BORDER) * math.sin(beta2),
) )
self.road_poly.append( self.road_poly.append(([b1_l, b1_r, b2_r, b2_l], (1, 1, 1) if i % 2 == 0 else (1, 0, 0)))
([b1_l, b1_r, b2_r, b2_l], (1, 1, 1) if i % 2 == 0 else (1, 0, 0))
)
self.track = track self.track = track
return True return True
@@ -356,10 +347,7 @@ class CarRacing(gym.Env, EzPickle):
if success: if success:
break break
if self.verbose == 1: if self.verbose == 1:
print( print("retry to generate track (normal if there are not many" "instances of this message)")
"retry to generate track (normal if there are not many"
"instances of this message)"
)
self.car = Car(self.world, *self.track[0][1:4]) self.car = Car(self.world, *self.track[0][1:4])
return self.step(None)[0] return self.step(None)[0]
@@ -424,10 +412,8 @@ class CarRacing(gym.Env, EzPickle):
angle = math.atan2(vel[0], vel[1]) angle = math.atan2(vel[0], vel[1])
self.transform.set_scale(zoom, zoom) self.transform.set_scale(zoom, zoom)
self.transform.set_translation( self.transform.set_translation(
WINDOW_W / 2 WINDOW_W / 2 - (scroll_x * zoom * math.cos(angle) - scroll_y * zoom * math.sin(angle)),
- (scroll_x * zoom * math.cos(angle) - scroll_y * zoom * math.sin(angle)), WINDOW_H / 4 - (scroll_x * zoom * math.sin(angle) + scroll_y * zoom * math.cos(angle)),
WINDOW_H / 4
- (scroll_x * zoom * math.sin(angle) + scroll_y * zoom * math.cos(angle)),
) )
self.transform.set_rotation(angle) self.transform.set_rotation(angle)
@@ -449,9 +435,7 @@ class CarRacing(gym.Env, EzPickle):
else: else:
pixel_scale = 1 pixel_scale = 1
if hasattr(win.context, "_nscontext"): if hasattr(win.context, "_nscontext"):
pixel_scale = ( pixel_scale = win.context._nscontext.view().backingScaleFactor() # pylint: disable=protected-access
win.context._nscontext.view().backingScaleFactor()
) # pylint: disable=protected-access
VP_W = int(pixel_scale * WINDOW_W) VP_W = int(pixel_scale * WINDOW_W)
VP_H = int(pixel_scale * WINDOW_H) VP_H = int(pixel_scale * WINDOW_H)
@@ -468,9 +452,7 @@ class CarRacing(gym.Env, EzPickle):
win.flip() win.flip()
return self.viewer.isopen return self.viewer.isopen
image_data = ( image_data = pyglet.image.get_buffer_manager().get_color_buffer().get_image_data()
pyglet.image.get_buffer_manager().get_color_buffer().get_image_data()
)
arr = np.fromstring(image_data.get_data(), dtype=np.uint8, sep="") arr = np.fromstring(image_data.get_data(), dtype=np.uint8, sep="")
arr = arr.reshape(VP_H, VP_W, 4) arr = arr.reshape(VP_H, VP_W, 4)
arr = arr[::-1, :, 0:3] arr = arr[::-1, :, 0:3]
@@ -525,9 +507,7 @@ class CarRacing(gym.Env, EzPickle):
for p in poly: for p in poly:
polygons_.extend([p[0], p[1], 0]) polygons_.extend([p[0], p[1], 0])
vl = pyglet.graphics.vertex_list( vl = pyglet.graphics.vertex_list(len(polygons_) // 3, ("v3f", polygons_), ("c4f", colors)) # gl.GL_QUADS,
len(polygons_) // 3, ("v3f", polygons_), ("c4f", colors) # gl.GL_QUADS,
)
vl.draw(gl.GL_QUADS) vl.draw(gl.GL_QUADS)
vl.delete() vl.delete()
@@ -575,10 +555,7 @@ class CarRacing(gym.Env, EzPickle):
] ]
) )
true_speed = np.sqrt( true_speed = np.sqrt(np.square(self.car.hull.linearVelocity[0]) + np.square(self.car.hull.linearVelocity[1]))
np.square(self.car.hull.linearVelocity[0])
+ np.square(self.car.hull.linearVelocity[1])
)
vertical_ind(5, 0.02 * true_speed, (1, 1, 1)) vertical_ind(5, 0.02 * true_speed, (1, 1, 1))
vertical_ind(7, 0.01 * self.car.wheels[0].omega, (0.0, 0, 1)) # ABS sensors vertical_ind(7, 0.01 * self.car.wheels[0].omega, (0.0, 0, 1)) # ABS sensors
@@ -587,9 +564,7 @@ class CarRacing(gym.Env, EzPickle):
vertical_ind(10, 0.01 * self.car.wheels[3].omega, (0.2, 0, 1)) vertical_ind(10, 0.01 * self.car.wheels[3].omega, (0.2, 0, 1))
horiz_ind(20, -10.0 * self.car.wheels[0].joint.angle, (0, 1, 0)) horiz_ind(20, -10.0 * self.car.wheels[0].joint.angle, (0, 1, 0))
horiz_ind(30, -0.8 * self.car.hull.angularVelocity, (1, 0, 0)) horiz_ind(30, -0.8 * self.car.hull.angularVelocity, (1, 0, 0))
vl = pyglet.graphics.vertex_list( vl = pyglet.graphics.vertex_list(len(polygons) // 3, ("v3f", polygons), ("c4f", colors)) # gl.GL_QUADS,
len(polygons) // 3, ("v3f", polygons), ("c4f", colors) # gl.GL_QUADS,
)
vl.draw(gl.GL_QUADS) vl.draw(gl.GL_QUADS)
vl.delete() vl.delete()
self.score_label.text = "%04i" % self.reward self.score_label.text = "%04i" % self.reward

View File

@@ -71,10 +71,7 @@ class ContactDetector(contactListener):
self.env = env self.env = env
def BeginContact(self, contact): def BeginContact(self, contact):
if ( if self.env.lander == contact.fixtureA.body or self.env.lander == contact.fixtureB.body:
self.env.lander == contact.fixtureA.body
or self.env.lander == contact.fixtureB.body
):
self.env.game_over = True self.env.game_over = True
for i in range(2): for i in range(2):
if self.env.legs[i] in [contact.fixtureA.body, contact.fixtureB.body]: if self.env.legs[i] in [contact.fixtureA.body, contact.fixtureB.body]:
@@ -104,9 +101,7 @@ class LunarLander(gym.Env, EzPickle):
self.prev_reward = None self.prev_reward = None
# useful range is -1 .. +1, but spikes can be higher # useful range is -1 .. +1, but spikes can be higher
self.observation_space = spaces.Box( self.observation_space = spaces.Box(-np.inf, np.inf, shape=(8,), dtype=np.float32)
-np.inf, np.inf, shape=(8,), dtype=np.float32
)
if self.continuous: if self.continuous:
# Action is two floats [main engine, left-right engines]. # Action is two floats [main engine, left-right engines].
@@ -157,14 +152,9 @@ class LunarLander(gym.Env, EzPickle):
height[CHUNKS // 2 + 0] = self.helipad_y height[CHUNKS // 2 + 0] = self.helipad_y
height[CHUNKS // 2 + 1] = self.helipad_y height[CHUNKS // 2 + 1] = self.helipad_y
height[CHUNKS // 2 + 2] = self.helipad_y height[CHUNKS // 2 + 2] = self.helipad_y
smooth_y = [ smooth_y = [0.33 * (height[i - 1] + height[i + 0] + height[i + 1]) for i in range(CHUNKS)]
0.33 * (height[i - 1] + height[i + 0] + height[i + 1])
for i in range(CHUNKS)
]
self.moon = self.world.CreateStaticBody( self.moon = self.world.CreateStaticBody(shapes=edgeShape(vertices=[(0, 0), (W, 0)]))
shapes=edgeShape(vertices=[(0, 0), (W, 0)])
)
self.sky_polys = [] self.sky_polys = []
for i in range(CHUNKS - 1): for i in range(CHUNKS - 1):
p1 = (chunk_x[i], smooth_y[i]) p1 = (chunk_x[i], smooth_y[i])
@@ -180,9 +170,7 @@ class LunarLander(gym.Env, EzPickle):
position=(VIEWPORT_W / SCALE / 2, initial_y), position=(VIEWPORT_W / SCALE / 2, initial_y),
angle=0.0, angle=0.0,
fixtures=fixtureDef( fixtures=fixtureDef(
shape=polygonShape( shape=polygonShape(vertices=[(x / SCALE, y / SCALE) for x, y in LANDER_POLY]),
vertices=[(x / SCALE, y / SCALE) for x, y in LANDER_POLY]
),
density=5.0, density=5.0,
friction=0.1, friction=0.1,
categoryBits=0x0010, categoryBits=0x0010,
@@ -227,9 +215,7 @@ class LunarLander(gym.Env, EzPickle):
motorSpeed=+0.3 * i, # low enough not to jump back into the sky motorSpeed=+0.3 * i, # low enough not to jump back into the sky
) )
if i == -1: if i == -1:
rjd.lowerAngle = ( rjd.lowerAngle = +0.9 - 0.5 # The most esoteric numbers here, angled legs have freedom to travel within
+0.9 - 0.5
) # The most esoteric numbers here, angled legs have freedom to travel within
rjd.upperAngle = +0.9 rjd.upperAngle = +0.9
else: else:
rjd.lowerAngle = -0.9 rjd.lowerAngle = -0.9
@@ -278,9 +264,7 @@ class LunarLander(gym.Env, EzPickle):
dispersion = [self.np_random.uniform(-1.0, +1.0) / SCALE for _ in range(2)] dispersion = [self.np_random.uniform(-1.0, +1.0) / SCALE for _ in range(2)]
m_power = 0.0 m_power = 0.0
if (self.continuous and action[0] > 0.0) or ( if (self.continuous and action[0] > 0.0) or (not self.continuous and action == 2):
not self.continuous and action == 2
):
# Main engine # Main engine
if self.continuous: if self.continuous:
m_power = (np.clip(action[0], 0.0, 1.0) + 1.0) * 0.5 # 0.5..1.0 m_power = (np.clip(action[0], 0.0, 1.0) + 1.0) * 0.5 # 0.5..1.0
@@ -310,9 +294,7 @@ class LunarLander(gym.Env, EzPickle):
) )
s_power = 0.0 s_power = 0.0
if (self.continuous and np.abs(action[1]) > 0.5) or ( if (self.continuous and np.abs(action[1]) > 0.5) or (not self.continuous and action in [1, 3]):
not self.continuous and action in [1, 3]
):
# Orientation engines # Orientation engines
if self.continuous: if self.continuous:
direction = np.sign(action[1]) direction = np.sign(action[1])
@@ -321,12 +303,8 @@ class LunarLander(gym.Env, EzPickle):
else: else:
direction = action - 2 direction = action - 2
s_power = 1.0 s_power = 1.0
ox = tip[0] * dispersion[0] + side[0] * ( ox = tip[0] * dispersion[0] + side[0] * (3 * dispersion[1] + direction * SIDE_ENGINE_AWAY / SCALE)
3 * dispersion[1] + direction * SIDE_ENGINE_AWAY / SCALE oy = -tip[1] * dispersion[0] - side[1] * (3 * dispersion[1] + direction * SIDE_ENGINE_AWAY / SCALE)
)
oy = -tip[1] * dispersion[0] - side[1] * (
3 * dispersion[1] + direction * SIDE_ENGINE_AWAY / SCALE
)
impulse_pos = ( impulse_pos = (
self.lander.position[0] + ox - tip[0] * 17 / SCALE, self.lander.position[0] + ox - tip[0] * 17 / SCALE,
self.lander.position[1] + oy + tip[1] * SIDE_ENGINE_HEIGHT / SCALE, self.lander.position[1] + oy + tip[1] * SIDE_ENGINE_HEIGHT / SCALE,
@@ -372,9 +350,7 @@ class LunarLander(gym.Env, EzPickle):
reward = shaping - self.prev_shaping reward = shaping - self.prev_shaping
self.prev_shaping = shaping self.prev_shaping = shaping
reward -= ( reward -= m_power * 0.30 # less fuel spent is better, about -30 for heuristic landing
m_power * 0.30
) # less fuel spent is better, about -30 for heuristic landing
reward -= s_power * 0.03 reward -= s_power * 0.03
done = False done = False
@@ -416,12 +392,8 @@ class LunarLander(gym.Env, EzPickle):
trans = f.body.transform trans = f.body.transform
if type(f.shape) is circleShape: if type(f.shape) is circleShape:
t = rendering.Transform(translation=trans * f.shape.pos) t = rendering.Transform(translation=trans * f.shape.pos)
self.viewer.draw_circle( self.viewer.draw_circle(f.shape.radius, 20, color=obj.color1).add_attr(t)
f.shape.radius, 20, color=obj.color1 self.viewer.draw_circle(f.shape.radius, 20, color=obj.color2, filled=False, linewidth=2).add_attr(t)
).add_attr(t)
self.viewer.draw_circle(
f.shape.radius, 20, color=obj.color2, filled=False, linewidth=2
).add_attr(t)
else: else:
path = [trans * v for v in f.shape.vertices] path = [trans * v for v in f.shape.vertices]
self.viewer.draw_polygon(path, color=obj.color1) self.viewer.draw_polygon(path, color=obj.color1)
@@ -479,18 +451,14 @@ def heuristic(env, s):
angle_targ = 0.4 # more than 0.4 radians (22 degrees) is bad angle_targ = 0.4 # more than 0.4 radians (22 degrees) is bad
if angle_targ < -0.4: if angle_targ < -0.4:
angle_targ = -0.4 angle_targ = -0.4
hover_targ = 0.55 * np.abs( hover_targ = 0.55 * np.abs(s[0]) # target y should be proportional to horizontal offset
s[0]
) # target y should be proportional to horizontal offset
angle_todo = (angle_targ - s[4]) * 0.5 - (s[5]) * 1.0 angle_todo = (angle_targ - s[4]) * 0.5 - (s[5]) * 1.0
hover_todo = (hover_targ - s[1]) * 0.5 - (s[3]) * 0.5 hover_todo = (hover_targ - s[1]) * 0.5 - (s[3]) * 0.5
if s[6] or s[7]: # legs have contact if s[6] or s[7]: # legs have contact
angle_todo = 0 angle_todo = 0
hover_todo = ( hover_todo = -(s[3]) * 0.5 # override to reduce fall speed, that's all we need after contact
-(s[3]) * 0.5
) # override to reduce fall speed, that's all we need after contact
if env.continuous: if env.continuous:
a = np.array([hover_todo * 20 - 1, -angle_todo * 20]) a = np.array([hover_todo * 20 - 1, -angle_todo * 20])

View File

@@ -88,9 +88,7 @@ class AcrobotEnv(core.Env):
def __init__(self): def __init__(self):
self.viewer = None self.viewer = None
high = np.array( high = np.array([1.0, 1.0, 1.0, 1.0, self.MAX_VEL_1, self.MAX_VEL_2], dtype=np.float32)
[1.0, 1.0, 1.0, 1.0, self.MAX_VEL_1, self.MAX_VEL_2], dtype=np.float32
)
low = -high low = -high
self.observation_space = spaces.Box(low=low, high=high, dtype=np.float32) self.observation_space = spaces.Box(low=low, high=high, dtype=np.float32)
self.action_space = spaces.Discrete(3) self.action_space = spaces.Discrete(3)
@@ -111,9 +109,7 @@ class AcrobotEnv(core.Env):
# Add noise to the force action # Add noise to the force action
if self.torque_noise_max > 0: if self.torque_noise_max > 0:
torque += self.np_random.uniform( torque += self.np_random.uniform(-self.torque_noise_max, self.torque_noise_max)
-self.torque_noise_max, self.torque_noise_max
)
# Now, augment the state with our force action so it can be passed to # Now, augment the state with our force action so it can be passed to
# _dsdt # _dsdt
@@ -160,12 +156,7 @@ class AcrobotEnv(core.Env):
theta2 = s[1] theta2 = s[1]
dtheta1 = s[2] dtheta1 = s[2]
dtheta2 = s[3] dtheta2 = s[3]
d1 = ( d1 = m1 * lc1 ** 2 + m2 * (l1 ** 2 + lc2 ** 2 + 2 * l1 * lc2 * cos(theta2)) + I1 + I2
m1 * lc1 ** 2
+ m2 * (l1 ** 2 + lc2 ** 2 + 2 * l1 * lc2 * cos(theta2))
+ I1
+ I2
)
d2 = m2 * (lc2 ** 2 + l1 * lc2 * cos(theta2)) + I2 d2 = m2 * (lc2 ** 2 + l1 * lc2 * cos(theta2)) + I2
phi2 = m2 * lc2 * g * cos(theta1 + theta2 - pi / 2.0) phi2 = m2 * lc2 * g * cos(theta1 + theta2 - pi / 2.0)
phi1 = ( phi1 = (
@@ -181,9 +172,9 @@ class AcrobotEnv(core.Env):
else: else:
# the following line is consistent with the java implementation and the # the following line is consistent with the java implementation and the
# book # book
ddtheta2 = ( ddtheta2 = (a + d2 / d1 * phi1 - m2 * l1 * lc2 * dtheta1 ** 2 * sin(theta2) - phi2) / (
a + d2 / d1 * phi1 - m2 * l1 * lc2 * dtheta1 ** 2 * sin(theta2) - phi2 m2 * lc2 ** 2 + I2 - d2 ** 2 / d1
) / (m2 * lc2 ** 2 + I2 - d2 ** 2 / d1) )
ddtheta1 = -(d2 * ddtheta2 + phi1) / d1 ddtheta1 = -(d2 * ddtheta2 + phi1) / d1
return (dtheta1, dtheta2, ddtheta1, ddtheta2, 0.0) return (dtheta1, dtheta2, ddtheta1, ddtheta2, 0.0)

View File

@@ -111,9 +111,7 @@ class CartPoleEnv(gym.Env):
# For the interested reader: # For the interested reader:
# https://coneural.org/florian/papers/05_cart_pole.pdf # https://coneural.org/florian/papers/05_cart_pole.pdf
temp = ( temp = (force + self.polemass_length * theta_dot ** 2 * sintheta) / self.total_mass
force + self.polemass_length * theta_dot ** 2 * sintheta
) / self.total_mass
thetaacc = (self.gravity * sintheta - costheta * temp) / ( thetaacc = (self.gravity * sintheta - costheta * temp) / (
self.length * (4.0 / 3.0 - self.masspole * costheta ** 2 / self.total_mass) self.length * (4.0 / 3.0 - self.masspole * costheta ** 2 / self.total_mass)
) )

View File

@@ -62,27 +62,17 @@ class Continuous_MountainCarEnv(gym.Env):
self.min_position = -1.2 self.min_position = -1.2
self.max_position = 0.6 self.max_position = 0.6
self.max_speed = 0.07 self.max_speed = 0.07
self.goal_position = ( self.goal_position = 0.45 # was 0.5 in gym, 0.45 in Arnaud de Broissia's version
0.45 # was 0.5 in gym, 0.45 in Arnaud de Broissia's version
)
self.goal_velocity = goal_velocity self.goal_velocity = goal_velocity
self.power = 0.0015 self.power = 0.0015
self.low_state = np.array( self.low_state = np.array([self.min_position, -self.max_speed], dtype=np.float32)
[self.min_position, -self.max_speed], dtype=np.float32 self.high_state = np.array([self.max_position, self.max_speed], dtype=np.float32)
)
self.high_state = np.array(
[self.max_position, self.max_speed], dtype=np.float32
)
self.viewer = None self.viewer = None
self.action_space = spaces.Box( self.action_space = spaces.Box(low=self.min_action, high=self.max_action, shape=(1,), dtype=np.float32)
low=self.min_action, high=self.max_action, shape=(1,), dtype=np.float32 self.observation_space = spaces.Box(low=self.low_state, high=self.high_state, dtype=np.float32)
)
self.observation_space = spaces.Box(
low=self.low_state, high=self.high_state, dtype=np.float32
)
self.seed() self.seed()
self.reset() self.reset()
@@ -159,15 +149,11 @@ class Continuous_MountainCarEnv(gym.Env):
self.viewer.add_geom(car) self.viewer.add_geom(car)
frontwheel = rendering.make_circle(carheight / 2.5) frontwheel = rendering.make_circle(carheight / 2.5)
frontwheel.set_color(0.5, 0.5, 0.5) frontwheel.set_color(0.5, 0.5, 0.5)
frontwheel.add_attr( frontwheel.add_attr(rendering.Transform(translation=(carwidth / 4, clearance)))
rendering.Transform(translation=(carwidth / 4, clearance))
)
frontwheel.add_attr(self.cartrans) frontwheel.add_attr(self.cartrans)
self.viewer.add_geom(frontwheel) self.viewer.add_geom(frontwheel)
backwheel = rendering.make_circle(carheight / 2.5) backwheel = rendering.make_circle(carheight / 2.5)
backwheel.add_attr( backwheel.add_attr(rendering.Transform(translation=(-carwidth / 4, clearance)))
rendering.Transform(translation=(-carwidth / 4, clearance))
)
backwheel.add_attr(self.cartrans) backwheel.add_attr(self.cartrans)
backwheel.set_color(0.5, 0.5, 0.5) backwheel.set_color(0.5, 0.5, 0.5)
self.viewer.add_geom(backwheel) self.viewer.add_geom(backwheel)
@@ -176,16 +162,12 @@ class Continuous_MountainCarEnv(gym.Env):
flagy2 = flagy1 + 50 flagy2 = flagy1 + 50
flagpole = rendering.Line((flagx, flagy1), (flagx, flagy2)) flagpole = rendering.Line((flagx, flagy1), (flagx, flagy2))
self.viewer.add_geom(flagpole) self.viewer.add_geom(flagpole)
flag = rendering.FilledPolygon( flag = rendering.FilledPolygon([(flagx, flagy2), (flagx, flagy2 - 10), (flagx + 25, flagy2 - 5)])
[(flagx, flagy2), (flagx, flagy2 - 10), (flagx + 25, flagy2 - 5)]
)
flag.set_color(0.8, 0.8, 0) flag.set_color(0.8, 0.8, 0)
self.viewer.add_geom(flag) self.viewer.add_geom(flag)
pos = self.state[0] pos = self.state[0]
self.cartrans.set_translation( self.cartrans.set_translation((pos - self.min_position) * scale, self._height(pos) * scale)
(pos - self.min_position) * scale, self._height(pos) * scale
)
self.cartrans.set_rotation(math.cos(3 * pos)) self.cartrans.set_rotation(math.cos(3 * pos))
return self.viewer.render(return_rgb_array=mode == "rgb_array") return self.viewer.render(return_rgb_array=mode == "rgb_array")

View File

@@ -136,15 +136,11 @@ class MountainCarEnv(gym.Env):
self.viewer.add_geom(car) self.viewer.add_geom(car)
frontwheel = rendering.make_circle(carheight / 2.5) frontwheel = rendering.make_circle(carheight / 2.5)
frontwheel.set_color(0.5, 0.5, 0.5) frontwheel.set_color(0.5, 0.5, 0.5)
frontwheel.add_attr( frontwheel.add_attr(rendering.Transform(translation=(carwidth / 4, clearance)))
rendering.Transform(translation=(carwidth / 4, clearance))
)
frontwheel.add_attr(self.cartrans) frontwheel.add_attr(self.cartrans)
self.viewer.add_geom(frontwheel) self.viewer.add_geom(frontwheel)
backwheel = rendering.make_circle(carheight / 2.5) backwheel = rendering.make_circle(carheight / 2.5)
backwheel.add_attr( backwheel.add_attr(rendering.Transform(translation=(-carwidth / 4, clearance)))
rendering.Transform(translation=(-carwidth / 4, clearance))
)
backwheel.add_attr(self.cartrans) backwheel.add_attr(self.cartrans)
backwheel.set_color(0.5, 0.5, 0.5) backwheel.set_color(0.5, 0.5, 0.5)
self.viewer.add_geom(backwheel) self.viewer.add_geom(backwheel)
@@ -153,16 +149,12 @@ class MountainCarEnv(gym.Env):
flagy2 = flagy1 + 50 flagy2 = flagy1 + 50
flagpole = rendering.Line((flagx, flagy1), (flagx, flagy2)) flagpole = rendering.Line((flagx, flagy1), (flagx, flagy2))
self.viewer.add_geom(flagpole) self.viewer.add_geom(flagpole)
flag = rendering.FilledPolygon( flag = rendering.FilledPolygon([(flagx, flagy2), (flagx, flagy2 - 10), (flagx + 25, flagy2 - 5)])
[(flagx, flagy2), (flagx, flagy2 - 10), (flagx + 25, flagy2 - 5)]
)
flag.set_color(0.8, 0.8, 0) flag.set_color(0.8, 0.8, 0)
self.viewer.add_geom(flag) self.viewer.add_geom(flag)
pos = self.state[0] pos = self.state[0]
self.cartrans.set_translation( self.cartrans.set_translation((pos - self.min_position) * scale, self._height(pos) * scale)
(pos - self.min_position) * scale, self._height(pos) * scale
)
self.cartrans.set_rotation(math.cos(3 * pos)) self.cartrans.set_rotation(math.cos(3 * pos))
return self.viewer.render(return_rgb_array=mode == "rgb_array") return self.viewer.render(return_rgb_array=mode == "rgb_array")

View File

@@ -18,9 +18,7 @@ class PendulumEnv(gym.Env):
self.viewer = None self.viewer = None
high = np.array([1.0, 1.0, self.max_speed], dtype=np.float32) high = np.array([1.0, 1.0, self.max_speed], dtype=np.float32)
self.action_space = spaces.Box( self.action_space = spaces.Box(low=-self.max_torque, high=self.max_torque, shape=(1,), dtype=np.float32)
low=-self.max_torque, high=self.max_torque, shape=(1,), dtype=np.float32
)
self.observation_space = spaces.Box(low=-high, high=high, dtype=np.float32) self.observation_space = spaces.Box(low=-high, high=high, dtype=np.float32)
self.seed() self.seed()
@@ -41,10 +39,7 @@ class PendulumEnv(gym.Env):
self.last_u = u # for rendering self.last_u = u # for rendering
costs = angle_normalize(th) ** 2 + 0.1 * thdot ** 2 + 0.001 * (u ** 2) costs = angle_normalize(th) ** 2 + 0.1 * thdot ** 2 + 0.001 * (u ** 2)
newthdot = ( newthdot = thdot + (-3 * g / (2 * l) * np.sin(th + np.pi) + 3.0 / (m * l ** 2) * u) * dt
thdot
+ (-3 * g / (2 * l) * np.sin(th + np.pi) + 3.0 / (m * l ** 2) * u) * dt
)
newth = th + newthdot * dt newth = th + newthdot * dt
newthdot = np.clip(newthdot, -self.max_speed, self.max_speed) newthdot = np.clip(newthdot, -self.max_speed, self.max_speed)

View File

@@ -54,11 +54,7 @@ def get_display(spec):
elif isinstance(spec, str): elif isinstance(spec, str):
return pyglet.canvas.Display(spec) return pyglet.canvas.Display(spec)
else: else:
raise error.Error( raise error.Error("Invalid display specification: {}. (Must be a string like :0 or None.)".format(spec))
"Invalid display specification: {}. (Must be a string like :0 or None.)".format(
spec
)
)
def get_window(width, height, display, **kwargs): def get_window(width, height, display, **kwargs):
@@ -69,14 +65,7 @@ def get_window(width, height, display, **kwargs):
config = screen[0].get_best_config() # selecting the first screen config = screen[0].get_best_config() # selecting the first screen
context = config.create_context(None) # create GL context context = config.create_context(None) # create GL context
return pyglet.window.Window( return pyglet.window.Window(width=width, height=height, display=display, config=config, context=context, **kwargs)
width=width,
height=height,
display=display,
config=config,
context=context,
**kwargs
)
class Viewer(object): class Viewer(object):
@@ -108,9 +97,7 @@ class Viewer(object):
assert right > left and top > bottom assert right > left and top > bottom
scalex = self.width / (right - left) scalex = self.width / (right - left)
scaley = self.height / (top - bottom) scaley = self.height / (top - bottom)
self.transform = Transform( self.transform = Transform(translation=(-left * scalex, -bottom * scaley), scale=(scalex, scaley))
translation=(-left * scalex, -bottom * scaley), scale=(scalex, scaley)
)
def add_geom(self, geom): def add_geom(self, geom):
self.geoms.append(geom) self.geoms.append(geom)
@@ -173,9 +160,7 @@ class Viewer(object):
def get_array(self): def get_array(self):
self.window.flip() self.window.flip()
image_data = ( image_data = pyglet.image.get_buffer_manager().get_color_buffer().get_image_data()
pyglet.image.get_buffer_manager().get_color_buffer().get_image_data()
)
self.window.flip() self.window.flip()
arr = np.fromstring(image_data.get_data(), dtype=np.uint8, sep="") arr = np.fromstring(image_data.get_data(), dtype=np.uint8, sep="")
arr = arr.reshape(self.height, self.width, 4) arr = arr.reshape(self.height, self.width, 4)
@@ -230,9 +215,7 @@ class Transform(Attr):
def enable(self): def enable(self):
glPushMatrix() glPushMatrix()
glTranslatef( glTranslatef(self.translation[0], self.translation[1], 0) # translate to GL loc ppint
self.translation[0], self.translation[1], 0
) # translate to GL loc ppint
glRotatef(RAD2DEG * self.rotation, 0, 0, 1.0) glRotatef(RAD2DEG * self.rotation, 0, 0, 1.0)
glScalef(self.scale[0], self.scale[1], 1) glScalef(self.scale[0], self.scale[1], 1)
@@ -392,9 +375,7 @@ class Image(Geom):
self.flip = False self.flip = False
def render1(self): def render1(self):
self.img.blit( self.img.blit(-self.width / 2, -self.height / 2, width=self.width, height=self.height)
-self.width / 2, -self.height / 2, width=self.width, height=self.height
)
# ================================================================ # ================================================================
@@ -435,9 +416,7 @@ class SimpleImageViewer(object):
self.isopen = False self.isopen = False
assert len(arr.shape) == 3, "You passed in an image with the wrong number shape" assert len(arr.shape) == 3, "You passed in an image with the wrong number shape"
image = pyglet.image.ImageData( image = pyglet.image.ImageData(arr.shape[1], arr.shape[0], "RGB", arr.tobytes(), pitch=arr.shape[1] * -3)
arr.shape[1], arr.shape[0], "RGB", arr.tobytes(), pitch=arr.shape[1] * -3
)
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_NEAREST) gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_NEAREST)
texture = image.get_texture() texture = image.get_texture()
texture.width = self.width texture.width = self.width

View File

@@ -14,9 +14,7 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
xposafter = self.get_body_com("torso")[0] xposafter = self.get_body_com("torso")[0]
forward_reward = (xposafter - xposbefore) / self.dt forward_reward = (xposafter - xposbefore) / self.dt
ctrl_cost = 0.5 * np.square(a).sum() ctrl_cost = 0.5 * np.square(a).sum()
contact_cost = ( contact_cost = 0.5 * 1e-3 * np.sum(np.square(np.clip(self.sim.data.cfrc_ext, -1, 1)))
0.5 * 1e-3 * np.sum(np.square(np.clip(self.sim.data.cfrc_ext, -1, 1)))
)
survive_reward = 1.0 survive_reward = 1.0
reward = forward_reward - ctrl_cost - contact_cost + survive_reward reward = forward_reward - ctrl_cost - contact_cost + survive_reward
state = self.state_vector() state = self.state_vector()
@@ -45,9 +43,7 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
) )
def reset_model(self): def reset_model(self):
qpos = self.init_qpos + self.np_random.uniform( qpos = self.init_qpos + self.np_random.uniform(size=self.model.nq, low=-0.1, high=0.1)
size=self.model.nq, low=-0.1, high=0.1
)
qvel = self.init_qvel + self.np_random.randn(self.model.nv) * 0.1 qvel = self.init_qvel + self.np_random.randn(self.model.nv) * 0.1
self.set_state(qpos, qvel) self.set_state(qpos, qvel)
return self._get_obs() return self._get_obs()

View File

@@ -34,18 +34,13 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
self._reset_noise_scale = reset_noise_scale self._reset_noise_scale = reset_noise_scale
self._exclude_current_positions_from_observation = ( self._exclude_current_positions_from_observation = exclude_current_positions_from_observation
exclude_current_positions_from_observation
)
mujoco_env.MujocoEnv.__init__(self, xml_file, 5) mujoco_env.MujocoEnv.__init__(self, xml_file, 5)
@property @property
def healthy_reward(self): def healthy_reward(self):
return ( return float(self.is_healthy or self._terminate_when_unhealthy) * self._healthy_reward
float(self.is_healthy or self._terminate_when_unhealthy)
* self._healthy_reward
)
def control_cost(self, action): def control_cost(self, action):
control_cost = self._ctrl_cost_weight * np.sum(np.square(action)) control_cost = self._ctrl_cost_weight * np.sum(np.square(action))
@@ -60,9 +55,7 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
@property @property
def contact_cost(self): def contact_cost(self):
contact_cost = self._contact_cost_weight * np.sum( contact_cost = self._contact_cost_weight * np.sum(np.square(self.contact_forces))
np.square(self.contact_forces)
)
return contact_cost return contact_cost
@property @property
@@ -128,12 +121,8 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
noise_low = -self._reset_noise_scale noise_low = -self._reset_noise_scale
noise_high = self._reset_noise_scale noise_high = self._reset_noise_scale
qpos = self.init_qpos + self.np_random.uniform( qpos = self.init_qpos + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nq)
low=noise_low, high=noise_high, size=self.model.nq qvel = self.init_qvel + self._reset_noise_scale * self.np_random.randn(self.model.nv)
)
qvel = self.init_qvel + self._reset_noise_scale * self.np_random.randn(
self.model.nv
)
self.set_state(qpos, qvel) self.set_state(qpos, qvel)
observation = self._get_obs() observation = self._get_obs()

View File

@@ -28,9 +28,7 @@ class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):
) )
def reset_model(self): def reset_model(self):
qpos = self.init_qpos + self.np_random.uniform( qpos = self.init_qpos + self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq)
low=-0.1, high=0.1, size=self.model.nq
)
qvel = self.init_qvel + self.np_random.randn(self.model.nv) * 0.1 qvel = self.init_qvel + self.np_random.randn(self.model.nv) * 0.1
self.set_state(qpos, qvel) self.set_state(qpos, qvel)
return self._get_obs() return self._get_obs()

View File

@@ -25,9 +25,7 @@ class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):
self._reset_noise_scale = reset_noise_scale self._reset_noise_scale = reset_noise_scale
self._exclude_current_positions_from_observation = ( self._exclude_current_positions_from_observation = exclude_current_positions_from_observation
exclude_current_positions_from_observation
)
mujoco_env.MujocoEnv.__init__(self, xml_file, 5) mujoco_env.MujocoEnv.__init__(self, xml_file, 5)
@@ -71,12 +69,8 @@ class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):
noise_low = -self._reset_noise_scale noise_low = -self._reset_noise_scale
noise_high = self._reset_noise_scale noise_high = self._reset_noise_scale
qpos = self.init_qpos + self.np_random.uniform( qpos = self.init_qpos + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nq)
low=noise_low, high=noise_high, size=self.model.nq qvel = self.init_qvel + self._reset_noise_scale * self.np_random.randn(self.model.nv)
)
qvel = self.init_qvel + self._reset_noise_scale * self.np_random.randn(
self.model.nv
)
self.set_state(qpos, qvel) self.set_state(qpos, qvel)

View File

@@ -17,27 +17,16 @@ class HopperEnv(mujoco_env.MujocoEnv, utils.EzPickle):
reward += alive_bonus reward += alive_bonus
reward -= 1e-3 * np.square(a).sum() reward -= 1e-3 * np.square(a).sum()
s = self.state_vector() s = self.state_vector()
done = not ( done = not (np.isfinite(s).all() and (np.abs(s[2:]) < 100).all() and (height > 0.7) and (abs(ang) < 0.2))
np.isfinite(s).all()
and (np.abs(s[2:]) < 100).all()
and (height > 0.7)
and (abs(ang) < 0.2)
)
ob = self._get_obs() ob = self._get_obs()
return ob, reward, done, {} return ob, reward, done, {}
def _get_obs(self): def _get_obs(self):
return np.concatenate( return np.concatenate([self.sim.data.qpos.flat[1:], np.clip(self.sim.data.qvel.flat, -10, 10)])
[self.sim.data.qpos.flat[1:], np.clip(self.sim.data.qvel.flat, -10, 10)]
)
def reset_model(self): def reset_model(self):
qpos = self.init_qpos + self.np_random.uniform( qpos = self.init_qpos + self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nq)
low=-0.005, high=0.005, size=self.model.nq qvel = self.init_qvel + self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nv)
)
qvel = self.init_qvel + self.np_random.uniform(
low=-0.005, high=0.005, size=self.model.nv
)
self.set_state(qpos, qvel) self.set_state(qpos, qvel)
return self._get_obs() return self._get_obs()

View File

@@ -40,18 +40,13 @@ class HopperEnv(mujoco_env.MujocoEnv, utils.EzPickle):
self._reset_noise_scale = reset_noise_scale self._reset_noise_scale = reset_noise_scale
self._exclude_current_positions_from_observation = ( self._exclude_current_positions_from_observation = exclude_current_positions_from_observation
exclude_current_positions_from_observation
)
mujoco_env.MujocoEnv.__init__(self, xml_file, 4) mujoco_env.MujocoEnv.__init__(self, xml_file, 4)
@property @property
def healthy_reward(self): def healthy_reward(self):
return ( return float(self.is_healthy or self._terminate_when_unhealthy) * self._healthy_reward
float(self.is_healthy or self._terminate_when_unhealthy)
* self._healthy_reward
)
def control_cost(self, action): def control_cost(self, action):
control_cost = self._ctrl_cost_weight * np.sum(np.square(action)) control_cost = self._ctrl_cost_weight * np.sum(np.square(action))
@@ -117,12 +112,8 @@ class HopperEnv(mujoco_env.MujocoEnv, utils.EzPickle):
noise_low = -self._reset_noise_scale noise_low = -self._reset_noise_scale
noise_high = self._reset_noise_scale noise_high = self._reset_noise_scale
qpos = self.init_qpos + self.np_random.uniform( qpos = self.init_qpos + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nq)
low=noise_low, high=noise_high, size=self.model.nq qvel = self.init_qvel + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nv)
)
qvel = self.init_qvel + self.np_random.uniform(
low=noise_low, high=noise_high, size=self.model.nv
)
self.set_state(qpos, qvel) self.set_state(qpos, qvel)

View File

@@ -43,18 +43,13 @@ class HumanoidEnv(mujoco_env.MujocoEnv, utils.EzPickle):
self._reset_noise_scale = reset_noise_scale self._reset_noise_scale = reset_noise_scale
self._exclude_current_positions_from_observation = ( self._exclude_current_positions_from_observation = exclude_current_positions_from_observation
exclude_current_positions_from_observation
)
mujoco_env.MujocoEnv.__init__(self, xml_file, 5) mujoco_env.MujocoEnv.__init__(self, xml_file, 5)
@property @property
def healthy_reward(self): def healthy_reward(self):
return ( return float(self.is_healthy or self._terminate_when_unhealthy) * self._healthy_reward
float(self.is_healthy or self._terminate_when_unhealthy)
* self._healthy_reward
)
def control_cost(self, action): def control_cost(self, action):
control_cost = self._ctrl_cost_weight * np.sum(np.square(self.sim.data.ctrl)) control_cost = self._ctrl_cost_weight * np.sum(np.square(self.sim.data.ctrl))
@@ -143,12 +138,8 @@ class HumanoidEnv(mujoco_env.MujocoEnv, utils.EzPickle):
noise_low = -self._reset_noise_scale noise_low = -self._reset_noise_scale
noise_high = self._reset_noise_scale noise_high = self._reset_noise_scale
qpos = self.init_qpos + self.np_random.uniform( qpos = self.init_qpos + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nq)
low=noise_low, high=noise_high, size=self.model.nq qvel = self.init_qvel + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nv)
)
qvel = self.init_qvel + self.np_random.uniform(
low=noise_low, high=noise_high, size=self.model.nv
)
self.set_state(qpos, qvel) self.set_state(qpos, qvel)
observation = self._get_obs() observation = self._get_obs()

View File

@@ -33,8 +33,7 @@ class InvertedDoublePendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle):
def reset_model(self): def reset_model(self):
self.set_state( self.set_state(
self.init_qpos self.init_qpos + self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq),
+ self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq),
self.init_qvel + self.np_random.randn(self.model.nv) * 0.1, self.init_qvel + self.np_random.randn(self.model.nv) * 0.1,
) )
return self._get_obs() return self._get_obs()

View File

@@ -17,12 +17,8 @@ class InvertedPendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle):
return ob, reward, done, {} return ob, reward, done, {}
def reset_model(self): def reset_model(self):
qpos = self.init_qpos + self.np_random.uniform( qpos = self.init_qpos + self.np_random.uniform(size=self.model.nq, low=-0.01, high=0.01)
size=self.model.nq, low=-0.01, high=0.01 qvel = self.init_qvel + self.np_random.uniform(size=self.model.nv, low=-0.01, high=0.01)
)
qvel = self.init_qvel + self.np_random.uniform(
size=self.model.nv, low=-0.01, high=0.01
)
self.set_state(qpos, qvel) self.set_state(qpos, qvel)
return self._get_obs() return self._get_obs()

View File

@@ -22,14 +22,7 @@ DEFAULT_SIZE = 500
def convert_observation_to_space(observation): def convert_observation_to_space(observation):
if isinstance(observation, dict): if isinstance(observation, dict):
space = spaces.Dict( space = spaces.Dict(OrderedDict([(key, convert_observation_to_space(value)) for key, value in observation.items()]))
OrderedDict(
[
(key, convert_observation_to_space(value))
for key, value in observation.items()
]
)
)
elif isinstance(observation, np.ndarray): elif isinstance(observation, np.ndarray):
low = np.full(observation.shape, -float("inf"), dtype=np.float32) low = np.full(observation.shape, -float("inf"), dtype=np.float32)
high = np.full(observation.shape, float("inf"), dtype=np.float32) high = np.full(observation.shape, float("inf"), dtype=np.float32)
@@ -117,9 +110,7 @@ class MujocoEnv(gym.Env):
def set_state(self, qpos, qvel): def set_state(self, qpos, qvel):
assert qpos.shape == (self.model.nq,) and qvel.shape == (self.model.nv,) assert qpos.shape == (self.model.nq,) and qvel.shape == (self.model.nv,)
old_state = self.sim.get_state() old_state = self.sim.get_state()
new_state = mujoco_py.MjSimState( new_state = mujoco_py.MjSimState(old_state.time, qpos, qvel, old_state.act, old_state.udd_state)
old_state.time, qpos, qvel, old_state.act, old_state.udd_state
)
self.sim.set_state(new_state) self.sim.set_state(new_state)
self.sim.forward() self.sim.forward()
@@ -142,10 +133,7 @@ class MujocoEnv(gym.Env):
): ):
if mode == "rgb_array" or mode == "depth_array": if mode == "rgb_array" or mode == "depth_array":
if camera_id is not None and camera_name is not None: if camera_id is not None and camera_name is not None:
raise ValueError( raise ValueError("Both `camera_id` and `camera_name` cannot be" " specified at the same time.")
"Both `camera_id` and `camera_name` cannot be"
" specified at the same time."
)
no_camera_specified = camera_name is None and camera_id is None no_camera_specified = camera_name is None and camera_id is None
if no_camera_specified: if no_camera_specified:

View File

@@ -44,9 +44,7 @@ class PusherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
qpos[-4:-2] = self.cylinder_pos qpos[-4:-2] = self.cylinder_pos
qpos[-2:] = self.goal_pos qpos[-2:] = self.goal_pos
qvel = self.init_qvel + self.np_random.uniform( qvel = self.init_qvel + self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nv)
low=-0.005, high=0.005, size=self.model.nv
)
qvel[-4:] = 0 qvel[-4:] = 0
self.set_state(qpos, qvel) self.set_state(qpos, qvel)
return self._get_obs() return self._get_obs()

View File

@@ -22,18 +22,13 @@ class ReacherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
self.viewer.cam.trackbodyid = 0 self.viewer.cam.trackbodyid = 0
def reset_model(self): def reset_model(self):
qpos = ( qpos = self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq) + self.init_qpos
self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq)
+ self.init_qpos
)
while True: while True:
self.goal = self.np_random.uniform(low=-0.2, high=0.2, size=2) self.goal = self.np_random.uniform(low=-0.2, high=0.2, size=2)
if np.linalg.norm(self.goal) < 0.2: if np.linalg.norm(self.goal) < 0.2:
break break
qpos[-2:] = self.goal qpos[-2:] = self.goal
qvel = self.init_qvel + self.np_random.uniform( qvel = self.init_qvel + self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nv)
low=-0.005, high=0.005, size=self.model.nv
)
qvel[-2:] = 0 qvel[-2:] = 0
self.set_state(qpos, qvel) self.set_state(qpos, qvel)
return self._get_obs() return self._get_obs()

View File

@@ -62,9 +62,7 @@ class StrikerEnv(mujoco_env.MujocoEnv, utils.EzPickle):
diff = self.ball - self.goal diff = self.ball - self.goal
angle = -np.arctan(diff[0] / (diff[1] + 1e-8)) angle = -np.arctan(diff[0] / (diff[1] + 1e-8))
qpos[-1] = angle / 3.14 qpos[-1] = angle / 3.14
qvel = self.init_qvel + self.np_random.uniform( qvel = self.init_qvel + self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nv)
low=-0.1, high=0.1, size=self.model.nv
)
qvel[7:] = 0 qvel[7:] = 0
self.set_state(qpos, qvel) self.set_state(qpos, qvel)
return self._get_obs() return self._get_obs()

View File

@@ -26,9 +26,7 @@ class SwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle):
def reset_model(self): def reset_model(self):
self.set_state( self.set_state(
self.init_qpos self.init_qpos + self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq),
+ self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq), self.init_qvel + self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nv),
self.init_qvel
+ self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nv),
) )
return self._get_obs() return self._get_obs()

View File

@@ -22,9 +22,7 @@ class SwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle):
self._reset_noise_scale = reset_noise_scale self._reset_noise_scale = reset_noise_scale
self._exclude_current_positions_from_observation = ( self._exclude_current_positions_from_observation = exclude_current_positions_from_observation
exclude_current_positions_from_observation
)
mujoco_env.MujocoEnv.__init__(self, xml_file, 4) mujoco_env.MujocoEnv.__init__(self, xml_file, 4)
@@ -74,12 +72,8 @@ class SwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle):
noise_low = -self._reset_noise_scale noise_low = -self._reset_noise_scale
noise_high = self._reset_noise_scale noise_high = self._reset_noise_scale
qpos = self.init_qpos + self.np_random.uniform( qpos = self.init_qpos + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nq)
low=noise_low, high=noise_high, size=self.model.nq qvel = self.init_qvel + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nv)
)
qvel = self.init_qvel + self.np_random.uniform(
low=noise_low, high=noise_high, size=self.model.nv
)
self.set_state(qpos, qvel) self.set_state(qpos, qvel)

View File

@@ -48,9 +48,7 @@ class ThrowerEnv(mujoco_env.MujocoEnv, utils.EzPickle):
) )
qpos[-9:-7] = self.goal qpos[-9:-7] = self.goal
qvel = self.init_qvel + self.np_random.uniform( qvel = self.init_qvel + self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nv)
low=-0.005, high=0.005, size=self.model.nv
)
qvel[7:] = 0 qvel[7:] = 0
self.set_state(qpos, qvel) self.set_state(qpos, qvel)
return self._get_obs() return self._get_obs()

View File

@@ -27,10 +27,8 @@ class Walker2dEnv(mujoco_env.MujocoEnv, utils.EzPickle):
def reset_model(self): def reset_model(self):
self.set_state( self.set_state(
self.init_qpos self.init_qpos + self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nq),
+ self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nq), self.init_qvel + self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nv),
self.init_qvel
+ self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nv),
) )
return self._get_obs() return self._get_obs()

View File

@@ -37,18 +37,13 @@ class Walker2dEnv(mujoco_env.MujocoEnv, utils.EzPickle):
self._reset_noise_scale = reset_noise_scale self._reset_noise_scale = reset_noise_scale
self._exclude_current_positions_from_observation = ( self._exclude_current_positions_from_observation = exclude_current_positions_from_observation
exclude_current_positions_from_observation
)
mujoco_env.MujocoEnv.__init__(self, xml_file, 4) mujoco_env.MujocoEnv.__init__(self, xml_file, 4)
@property @property
def healthy_reward(self): def healthy_reward(self):
return ( return float(self.is_healthy or self._terminate_when_unhealthy) * self._healthy_reward
float(self.is_healthy or self._terminate_when_unhealthy)
* self._healthy_reward
)
def control_cost(self, action): def control_cost(self, action):
control_cost = self._ctrl_cost_weight * np.sum(np.square(action)) control_cost = self._ctrl_cost_weight * np.sum(np.square(action))
@@ -110,12 +105,8 @@ class Walker2dEnv(mujoco_env.MujocoEnv, utils.EzPickle):
noise_low = -self._reset_noise_scale noise_low = -self._reset_noise_scale
noise_high = self._reset_noise_scale noise_high = self._reset_noise_scale
qpos = self.init_qpos + self.np_random.uniform( qpos = self.init_qpos + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nq)
low=noise_low, high=noise_high, size=self.model.nq qvel = self.init_qvel + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nv)
)
qvel = self.init_qvel + self.np_random.uniform(
low=noise_low, high=noise_high, size=self.model.nv
)
self.set_state(qpos, qvel) self.set_state(qpos, qvel)

View File

@@ -108,11 +108,7 @@ class EnvRegistry(object):
# reset/step. Set _gym_disable_underscore_compat = True on # reset/step. Set _gym_disable_underscore_compat = True on
# your environment if you use these methods and don't want # your environment if you use these methods and don't want
# compatibility code to be invoked. # compatibility code to be invoked.
if ( if hasattr(env, "_reset") and hasattr(env, "_step") and not getattr(env, "_gym_disable_underscore_compat", False):
hasattr(env, "_reset")
and hasattr(env, "_step")
and not getattr(env, "_gym_disable_underscore_compat", False)
):
patch_deprecated_methods(env) patch_deprecated_methods(env)
if env.spec.max_episode_steps is not None: if env.spec.max_episode_steps is not None:
from gym.wrappers.time_limit import TimeLimit from gym.wrappers.time_limit import TimeLimit
@@ -158,11 +154,7 @@ class EnvRegistry(object):
if env_name == valid_env_spec._env_name if env_name == valid_env_spec._env_name
] ]
if matching_envs: if matching_envs:
raise error.DeprecatedEnv( raise error.DeprecatedEnv("Env {} not found (valid versions include {})".format(id, matching_envs))
"Env {} not found (valid versions include {})".format(
id, matching_envs
)
)
else: else:
raise error.UnregisteredEnv("No registered env with id: {}".format(id)) raise error.UnregisteredEnv("No registered env with id: {}".format(id))

View File

@@ -81,9 +81,7 @@ class FetchEnv(robot_env.RobotEnv):
def _set_action(self, action): def _set_action(self, action):
assert action.shape == (4,) assert action.shape == (4,)
action = ( action = action.copy() # ensure that we don't change the action outside of this scope
action.copy()
) # ensure that we don't change the action outside of this scope
pos_ctrl, gripper_ctrl = action[:3], action[3] pos_ctrl, gripper_ctrl = action[:3], action[3]
pos_ctrl *= 0.05 # limit maximum change in position pos_ctrl *= 0.05 # limit maximum change in position
@@ -120,13 +118,9 @@ class FetchEnv(robot_env.RobotEnv):
object_rel_pos = object_pos - grip_pos object_rel_pos = object_pos - grip_pos
object_velp -= grip_velp object_velp -= grip_velp
else: else:
object_pos = ( object_pos = object_rot = object_velp = object_velr = object_rel_pos = np.zeros(0)
object_rot
) = object_velp = object_velr = object_rel_pos = np.zeros(0)
gripper_state = robot_qpos[-2:] gripper_state = robot_qpos[-2:]
gripper_vel = ( gripper_vel = robot_qvel[-2:] * dt # change to a scalar if the gripper is made symmetric
robot_qvel[-2:] * dt
) # change to a scalar if the gripper is made symmetric
if not self.has_object: if not self.has_object:
achieved_goal = grip_pos.copy() achieved_goal = grip_pos.copy()
@@ -175,9 +169,7 @@ class FetchEnv(robot_env.RobotEnv):
if self.has_object: if self.has_object:
object_xpos = self.initial_gripper_xpos[:2] object_xpos = self.initial_gripper_xpos[:2]
while np.linalg.norm(object_xpos - self.initial_gripper_xpos[:2]) < 0.1: while np.linalg.norm(object_xpos - self.initial_gripper_xpos[:2]) < 0.1:
object_xpos = self.initial_gripper_xpos[:2] + self.np_random.uniform( object_xpos = self.initial_gripper_xpos[:2] + self.np_random.uniform(-self.obj_range, self.obj_range, size=2)
-self.obj_range, self.obj_range, size=2
)
object_qpos = self.sim.data.get_joint_qpos("object0:joint") object_qpos = self.sim.data.get_joint_qpos("object0:joint")
assert object_qpos.shape == (7,) assert object_qpos.shape == (7,)
object_qpos[:2] = object_xpos object_qpos[:2] = object_xpos
@@ -188,17 +180,13 @@ class FetchEnv(robot_env.RobotEnv):
def _sample_goal(self): def _sample_goal(self):
if self.has_object: if self.has_object:
goal = self.initial_gripper_xpos[:3] + self.np_random.uniform( goal = self.initial_gripper_xpos[:3] + self.np_random.uniform(-self.target_range, self.target_range, size=3)
-self.target_range, self.target_range, size=3
)
goal += self.target_offset goal += self.target_offset
goal[2] = self.height_offset goal[2] = self.height_offset
if self.target_in_the_air and self.np_random.uniform() < 0.5: if self.target_in_the_air and self.np_random.uniform() < 0.5:
goal[2] += self.np_random.uniform(0, 0.45) goal[2] += self.np_random.uniform(0, 0.45)
else: else:
goal = self.initial_gripper_xpos[:3] + self.np_random.uniform( goal = self.initial_gripper_xpos[:3] + self.np_random.uniform(-self.target_range, self.target_range, size=3)
-self.target_range, self.target_range, size=3
)
return goal.copy() return goal.copy()
def _is_success(self, achieved_goal, desired_goal): def _is_success(self, achieved_goal, desired_goal):
@@ -212,9 +200,9 @@ class FetchEnv(robot_env.RobotEnv):
self.sim.forward() self.sim.forward()
# Move end effector into position. # Move end effector into position.
gripper_target = np.array( gripper_target = np.array([-0.498, 0.005, -0.431 + self.gripper_extra_height]) + self.sim.data.get_site_xpos(
[-0.498, 0.005, -0.431 + self.gripper_extra_height] "robot0:grip"
) + self.sim.data.get_site_xpos("robot0:grip") )
gripper_rotation = np.array([1.0, 0.0, 1.0, 0.0]) gripper_rotation = np.array([1.0, 0.0, 1.0, 0.0])
self.sim.data.set_mocap_pos("robot0:mocap", gripper_target) self.sim.data.set_mocap_pos("robot0:mocap", gripper_target)
self.sim.data.set_mocap_quat("robot0:mocap", gripper_rotation) self.sim.data.set_mocap_quat("robot0:mocap", gripper_rotation)

View File

@@ -74,9 +74,7 @@ class ManipulateEnv(hand_env.HandEnv):
self.target_position = target_position self.target_position = target_position
self.target_rotation = target_rotation self.target_rotation = target_rotation
self.target_position_range = target_position_range self.target_position_range = target_position_range
self.parallel_quats = [ self.parallel_quats = [rotations.euler2quat(r) for r in rotations.get_parallel_rotations()]
rotations.euler2quat(r) for r in rotations.get_parallel_rotations()
]
self.randomize_initial_rotation = randomize_initial_rotation self.randomize_initial_rotation = randomize_initial_rotation
self.randomize_initial_position = randomize_initial_position self.randomize_initial_position = randomize_initial_position
self.distance_threshold = distance_threshold self.distance_threshold = distance_threshold
@@ -182,9 +180,7 @@ class ManipulateEnv(hand_env.HandEnv):
angle = self.np_random.uniform(-np.pi, np.pi) angle = self.np_random.uniform(-np.pi, np.pi)
axis = np.array([0.0, 0.0, 1.0]) axis = np.array([0.0, 0.0, 1.0])
z_quat = quat_from_angle_and_axis(angle, axis) z_quat = quat_from_angle_and_axis(angle, axis)
parallel_quat = self.parallel_quats[ parallel_quat = self.parallel_quats[self.np_random.randint(len(self.parallel_quats))]
self.np_random.randint(len(self.parallel_quats))
]
offset_quat = rotations.quat_mul(z_quat, parallel_quat) offset_quat = rotations.quat_mul(z_quat, parallel_quat)
initial_quat = rotations.quat_mul(initial_quat, offset_quat) initial_quat = rotations.quat_mul(initial_quat, offset_quat)
elif self.target_rotation in ["xyz", "ignore"]: elif self.target_rotation in ["xyz", "ignore"]:
@@ -195,9 +191,7 @@ class ManipulateEnv(hand_env.HandEnv):
elif self.target_rotation == "fixed": elif self.target_rotation == "fixed":
pass pass
else: else:
raise error.Error( raise error.Error('Unknown target_rotation option "{}".'.format(self.target_rotation))
'Unknown target_rotation option "{}".'.format(self.target_rotation)
)
# Randomize initial position. # Randomize initial position.
if self.randomize_initial_position: if self.randomize_initial_position:
@@ -229,17 +223,13 @@ class ManipulateEnv(hand_env.HandEnv):
target_pos = None target_pos = None
if self.target_position == "random": if self.target_position == "random":
assert self.target_position_range.shape == (3, 2) assert self.target_position_range.shape == (3, 2)
offset = self.np_random.uniform( offset = self.np_random.uniform(self.target_position_range[:, 0], self.target_position_range[:, 1])
self.target_position_range[:, 0], self.target_position_range[:, 1]
)
assert offset.shape == (3,) assert offset.shape == (3,)
target_pos = self.sim.data.get_joint_qpos("object:joint")[:3] + offset target_pos = self.sim.data.get_joint_qpos("object:joint")[:3] + offset
elif self.target_position in ["ignore", "fixed"]: elif self.target_position in ["ignore", "fixed"]:
target_pos = self.sim.data.get_joint_qpos("object:joint")[:3] target_pos = self.sim.data.get_joint_qpos("object:joint")[:3]
else: else:
raise error.Error( raise error.Error('Unknown target_position option "{}".'.format(self.target_position))
'Unknown target_position option "{}".'.format(self.target_position)
)
assert target_pos is not None assert target_pos is not None
assert target_pos.shape == (3,) assert target_pos.shape == (3,)
@@ -253,9 +243,7 @@ class ManipulateEnv(hand_env.HandEnv):
angle = self.np_random.uniform(-np.pi, np.pi) angle = self.np_random.uniform(-np.pi, np.pi)
axis = np.array([0.0, 0.0, 1.0]) axis = np.array([0.0, 0.0, 1.0])
target_quat = quat_from_angle_and_axis(angle, axis) target_quat = quat_from_angle_and_axis(angle, axis)
parallel_quat = self.parallel_quats[ parallel_quat = self.parallel_quats[self.np_random.randint(len(self.parallel_quats))]
self.np_random.randint(len(self.parallel_quats))
]
target_quat = rotations.quat_mul(target_quat, parallel_quat) target_quat = rotations.quat_mul(target_quat, parallel_quat)
elif self.target_rotation == "xyz": elif self.target_rotation == "xyz":
angle = self.np_random.uniform(-np.pi, np.pi) angle = self.np_random.uniform(-np.pi, np.pi)
@@ -264,9 +252,7 @@ class ManipulateEnv(hand_env.HandEnv):
elif self.target_rotation in ["ignore", "fixed"]: elif self.target_rotation in ["ignore", "fixed"]:
target_quat = self.sim.data.get_joint_qpos("object:joint") target_quat = self.sim.data.get_joint_qpos("object:joint")
else: else:
raise error.Error( raise error.Error('Unknown target_rotation option "{}".'.format(self.target_rotation))
'Unknown target_rotation option "{}".'.format(self.target_rotation)
)
assert target_quat is not None assert target_quat is not None
assert target_quat.shape == (4,) assert target_quat.shape == (4,)
@@ -293,12 +279,8 @@ class ManipulateEnv(hand_env.HandEnv):
def _get_obs(self): def _get_obs(self):
robot_qpos, robot_qvel = robot_get_obs(self.sim) robot_qpos, robot_qvel = robot_get_obs(self.sim)
object_qvel = self.sim.data.get_joint_qvel("object:joint") object_qvel = self.sim.data.get_joint_qvel("object:joint")
achieved_goal = ( achieved_goal = self._get_achieved_goal().ravel() # this contains the object position + rotation
self._get_achieved_goal().ravel() observation = np.concatenate([robot_qpos, robot_qvel, object_qvel, achieved_goal])
) # this contains the object position + rotation
observation = np.concatenate(
[robot_qpos, robot_qvel, object_qvel, achieved_goal]
)
return { return {
"observation": observation.copy(), "observation": observation.copy(),
"achieved_goal": achieved_goal.copy(), "achieved_goal": achieved_goal.copy(),
@@ -307,9 +289,7 @@ class ManipulateEnv(hand_env.HandEnv):
class HandBlockEnv(ManipulateEnv, utils.EzPickle): class HandBlockEnv(ManipulateEnv, utils.EzPickle):
def __init__( def __init__(self, target_position="random", target_rotation="xyz", reward_type="sparse"):
self, target_position="random", target_rotation="xyz", reward_type="sparse"
):
utils.EzPickle.__init__(self, target_position, target_rotation, reward_type) utils.EzPickle.__init__(self, target_position, target_rotation, reward_type)
ManipulateEnv.__init__( ManipulateEnv.__init__(
self, self,
@@ -322,9 +302,7 @@ class HandBlockEnv(ManipulateEnv, utils.EzPickle):
class HandEggEnv(ManipulateEnv, utils.EzPickle): class HandEggEnv(ManipulateEnv, utils.EzPickle):
def __init__( def __init__(self, target_position="random", target_rotation="xyz", reward_type="sparse"):
self, target_position="random", target_rotation="xyz", reward_type="sparse"
):
utils.EzPickle.__init__(self, target_position, target_rotation, reward_type) utils.EzPickle.__init__(self, target_position, target_rotation, reward_type)
ManipulateEnv.__init__( ManipulateEnv.__init__(
self, self,
@@ -337,9 +315,7 @@ class HandEggEnv(ManipulateEnv, utils.EzPickle):
class HandPenEnv(ManipulateEnv, utils.EzPickle): class HandPenEnv(ManipulateEnv, utils.EzPickle):
def __init__( def __init__(self, target_position="random", target_rotation="xyz", reward_type="sparse"):
self, target_position="random", target_rotation="xyz", reward_type="sparse"
):
utils.EzPickle.__init__(self, target_position, target_rotation, reward_type) utils.EzPickle.__init__(self, target_position, target_rotation, reward_type)
ManipulateEnv.__init__( ManipulateEnv.__init__(
self, self,

View File

@@ -70,16 +70,12 @@ class ManipulateTouchSensorsEnv(manipulate.ManipulateEnv):
for ( for (
k, k,
v, v,
) in ( ) in self.sim.model._sensor_name2id.items(): # get touch sensor site names and their ids
self.sim.model._sensor_name2id.items()
): # get touch sensor site names and their ids
if "robot0:TS_" in k: if "robot0:TS_" in k:
self._touch_sensor_id_site_id.append( self._touch_sensor_id_site_id.append(
( (
v, v,
self.sim.model._site_name2id[ self.sim.model._site_name2id[k.replace("robot0:TS_", "robot0:T_")],
k.replace("robot0:TS_", "robot0:T_")
],
) )
) )
self._touch_sensor_id.append(v) self._touch_sensor_id.append(v)
@@ -93,15 +89,9 @@ class ManipulateTouchSensorsEnv(manipulate.ManipulateEnv):
obs = self._get_obs() obs = self._get_obs()
self.observation_space = spaces.Dict( self.observation_space = spaces.Dict(
dict( dict(
desired_goal=spaces.Box( desired_goal=spaces.Box(-np.inf, np.inf, shape=obs["achieved_goal"].shape, dtype="float32"),
-np.inf, np.inf, shape=obs["achieved_goal"].shape, dtype="float32" achieved_goal=spaces.Box(-np.inf, np.inf, shape=obs["achieved_goal"].shape, dtype="float32"),
), observation=spaces.Box(-np.inf, np.inf, shape=obs["observation"].shape, dtype="float32"),
achieved_goal=spaces.Box(
-np.inf, np.inf, shape=obs["achieved_goal"].shape, dtype="float32"
),
observation=spaces.Box(
-np.inf, np.inf, shape=obs["observation"].shape, dtype="float32"
),
) )
) )
@@ -117,9 +107,7 @@ class ManipulateTouchSensorsEnv(manipulate.ManipulateEnv):
def _get_obs(self): def _get_obs(self):
robot_qpos, robot_qvel = manipulate.robot_get_obs(self.sim) robot_qpos, robot_qvel = manipulate.robot_get_obs(self.sim)
object_qvel = self.sim.data.get_joint_qvel("object:joint") object_qvel = self.sim.data.get_joint_qvel("object:joint")
achieved_goal = ( achieved_goal = self._get_achieved_goal().ravel() # this contains the object position + rotation
self._get_achieved_goal().ravel()
) # this contains the object position + rotation
touch_values = [] # get touch sensor readings. if there is one, set value to 1 touch_values = [] # get touch sensor readings. if there is one, set value to 1
if self.touch_get_obs == "sensordata": if self.touch_get_obs == "sensordata":
touch_values = self.sim.data.sensordata[self._touch_sensor_id] touch_values = self.sim.data.sensordata[self._touch_sensor_id]
@@ -127,9 +115,7 @@ class ManipulateTouchSensorsEnv(manipulate.ManipulateEnv):
touch_values = self.sim.data.sensordata[self._touch_sensor_id] > 0.0 touch_values = self.sim.data.sensordata[self._touch_sensor_id] > 0.0
elif self.touch_get_obs == "log": elif self.touch_get_obs == "log":
touch_values = np.log(self.sim.data.sensordata[self._touch_sensor_id] + 1.0) touch_values = np.log(self.sim.data.sensordata[self._touch_sensor_id] + 1.0)
observation = np.concatenate( observation = np.concatenate([robot_qpos, robot_qvel, object_qvel, touch_values, achieved_goal])
[robot_qpos, robot_qvel, object_qvel, touch_values, achieved_goal]
)
return { return {
"observation": observation.copy(), "observation": observation.copy(),
@@ -146,9 +132,7 @@ class HandBlockTouchSensorsEnv(ManipulateTouchSensorsEnv, utils.EzPickle):
touch_get_obs="sensordata", touch_get_obs="sensordata",
reward_type="sparse", reward_type="sparse",
): ):
utils.EzPickle.__init__( utils.EzPickle.__init__(self, target_position, target_rotation, touch_get_obs, reward_type)
self, target_position, target_rotation, touch_get_obs, reward_type
)
ManipulateTouchSensorsEnv.__init__( ManipulateTouchSensorsEnv.__init__(
self, self,
model_path=MANIPULATE_BLOCK_XML, model_path=MANIPULATE_BLOCK_XML,
@@ -168,9 +152,7 @@ class HandEggTouchSensorsEnv(ManipulateTouchSensorsEnv, utils.EzPickle):
touch_get_obs="sensordata", touch_get_obs="sensordata",
reward_type="sparse", reward_type="sparse",
): ):
utils.EzPickle.__init__( utils.EzPickle.__init__(self, target_position, target_rotation, touch_get_obs, reward_type)
self, target_position, target_rotation, touch_get_obs, reward_type
)
ManipulateTouchSensorsEnv.__init__( ManipulateTouchSensorsEnv.__init__(
self, self,
model_path=MANIPULATE_EGG_XML, model_path=MANIPULATE_EGG_XML,
@@ -190,9 +172,7 @@ class HandPenTouchSensorsEnv(ManipulateTouchSensorsEnv, utils.EzPickle):
touch_get_obs="sensordata", touch_get_obs="sensordata",
reward_type="sparse", reward_type="sparse",
): ):
utils.EzPickle.__init__( utils.EzPickle.__init__(self, target_position, target_rotation, touch_get_obs, reward_type)
self, target_position, target_rotation, touch_get_obs, reward_type
)
ManipulateTouchSensorsEnv.__init__( ManipulateTouchSensorsEnv.__init__(
self, self,
model_path=MANIPULATE_PEN_XML, model_path=MANIPULATE_PEN_XML,

View File

@@ -96,9 +96,7 @@ class HandReachEnv(hand_env.HandEnv, utils.EzPickle):
self.sim.forward() self.sim.forward()
self.initial_goal = self._get_achieved_goal().copy() self.initial_goal = self._get_achieved_goal().copy()
self.palm_xpos = self.sim.data.body_xpos[ self.palm_xpos = self.sim.data.body_xpos[self.sim.model.body_name2id("robot0:palm")].copy()
self.sim.model.body_name2id("robot0:palm")
].copy()
def _get_obs(self): def _get_obs(self):
robot_qpos, robot_qvel = robot_get_obs(self.sim) robot_qpos, robot_qvel = robot_get_obs(self.sim)
@@ -155,7 +153,5 @@ class HandReachEnv(hand_env.HandEnv, utils.EzPickle):
for finger_idx in range(5): for finger_idx in range(5):
site_name = "finger{}".format(finger_idx) site_name = "finger{}".format(finger_idx)
site_id = self.sim.model.site_name2id(site_name) site_id = self.sim.model.site_name2id(site_name)
self.sim.model.site_pos[site_id] = ( self.sim.model.site_pos[site_id] = achieved_goal[finger_idx] - sites_offset[site_id]
achieved_goal[finger_idx] - sites_offset[site_id]
)
self.sim.forward() self.sim.forward()

View File

@@ -30,22 +30,14 @@ class HandEnv(robot_env.RobotEnv):
if self.relative_control: if self.relative_control:
actuation_center = np.zeros_like(action) actuation_center = np.zeros_like(action)
for i in range(self.sim.data.ctrl.shape[0]): for i in range(self.sim.data.ctrl.shape[0]):
actuation_center[i] = self.sim.data.get_joint_qpos( actuation_center[i] = self.sim.data.get_joint_qpos(self.sim.model.actuator_names[i].replace(":A_", ":"))
self.sim.model.actuator_names[i].replace(":A_", ":")
)
for joint_name in ["FF", "MF", "RF", "LF"]: for joint_name in ["FF", "MF", "RF", "LF"]:
act_idx = self.sim.model.actuator_name2id( act_idx = self.sim.model.actuator_name2id("robot0:A_{}J1".format(joint_name))
"robot0:A_{}J1".format(joint_name) actuation_center[act_idx] += self.sim.data.get_joint_qpos("robot0:{}J0".format(joint_name))
)
actuation_center[act_idx] += self.sim.data.get_joint_qpos(
"robot0:{}J0".format(joint_name)
)
else: else:
actuation_center = (ctrlrange[:, 1] + ctrlrange[:, 0]) / 2.0 actuation_center = (ctrlrange[:, 1] + ctrlrange[:, 0]) / 2.0
self.sim.data.ctrl[:] = actuation_center + action * actuation_range self.sim.data.ctrl[:] = actuation_center + action * actuation_range
self.sim.data.ctrl[:] = np.clip( self.sim.data.ctrl[:] = np.clip(self.sim.data.ctrl, ctrlrange[:, 0], ctrlrange[:, 1])
self.sim.data.ctrl, ctrlrange[:, 0], ctrlrange[:, 1]
)
def _viewer_setup(self): def _viewer_setup(self):
body_id = self.sim.model.body_name2id("robot0:palm") body_id = self.sim.model.body_name2id("robot0:palm")

View File

@@ -46,15 +46,9 @@ class RobotEnv(gym.GoalEnv):
self.action_space = spaces.Box(-1.0, 1.0, shape=(n_actions,), dtype="float32") self.action_space = spaces.Box(-1.0, 1.0, shape=(n_actions,), dtype="float32")
self.observation_space = spaces.Dict( self.observation_space = spaces.Dict(
dict( dict(
desired_goal=spaces.Box( desired_goal=spaces.Box(-np.inf, np.inf, shape=obs["achieved_goal"].shape, dtype="float32"),
-np.inf, np.inf, shape=obs["achieved_goal"].shape, dtype="float32" achieved_goal=spaces.Box(-np.inf, np.inf, shape=obs["achieved_goal"].shape, dtype="float32"),
), observation=spaces.Box(-np.inf, np.inf, shape=obs["observation"].shape, dtype="float32"),
achieved_goal=spaces.Box(
-np.inf, np.inf, shape=obs["achieved_goal"].shape, dtype="float32"
),
observation=spaces.Box(
-np.inf, np.inf, shape=obs["observation"].shape, dtype="float32"
),
) )
) )

View File

@@ -164,12 +164,8 @@ def mat2euler(mat):
-np.arctan2(mat[..., 0, 1], mat[..., 0, 0]), -np.arctan2(mat[..., 0, 1], mat[..., 0, 0]),
-np.arctan2(-mat[..., 1, 0], mat[..., 1, 1]), -np.arctan2(-mat[..., 1, 0], mat[..., 1, 1]),
) )
euler[..., 1] = np.where( euler[..., 1] = np.where(condition, -np.arctan2(-mat[..., 0, 2], cy), -np.arctan2(-mat[..., 0, 2], cy))
condition, -np.arctan2(-mat[..., 0, 2], cy), -np.arctan2(-mat[..., 0, 2], cy) euler[..., 0] = np.where(condition, -np.arctan2(mat[..., 1, 2], mat[..., 2, 2]), 0.0)
)
euler[..., 0] = np.where(
condition, -np.arctan2(mat[..., 1, 2], mat[..., 2, 2]), 0.0
)
return euler return euler

View File

@@ -75,15 +75,9 @@ def reset_mocap2body_xpos(sim):
values as the bodies they're welded to. values as the bodies they're welded to.
""" """
if ( if sim.model.eq_type is None or sim.model.eq_obj1id is None or sim.model.eq_obj2id is None:
sim.model.eq_type is None
or sim.model.eq_obj1id is None
or sim.model.eq_obj2id is None
):
return return
for eq_type, obj1_id, obj2_id in zip( for eq_type, obj1_id, obj2_id in zip(sim.model.eq_type, sim.model.eq_obj1id, sim.model.eq_obj2id):
sim.model.eq_type, sim.model.eq_obj1id, sim.model.eq_obj2id
):
if eq_type != mujoco_py.const.EQ_WELD: if eq_type != mujoco_py.const.EQ_WELD:
continue continue

View File

@@ -2,10 +2,7 @@ from gym import envs, logger
import os import os
SKIP_MUJOCO_WARNING_MESSAGE = ( SKIP_MUJOCO_WARNING_MESSAGE = "Cannot run mujoco test (either license key not found or mujoco not" "installed properly)."
"Cannot run mujoco test (either license key not found or mujoco not"
"installed properly)."
)
skip_mujoco = not (os.environ.get("MUJOCO_KEY")) skip_mujoco = not (os.environ.get("MUJOCO_KEY"))
@@ -21,9 +18,7 @@ def should_skip_env_spec_for_tests(spec):
# troublesome to run frequently # troublesome to run frequently
ep = spec.entry_point ep = spec.entry_point
# Skip mujoco tests for pull request CI # Skip mujoco tests for pull request CI
if skip_mujoco and ( if skip_mujoco and (ep.startswith("gym.envs.mujoco") or ep.startswith("gym.envs.robotics:")):
ep.startswith("gym.envs.mujoco") or ep.startswith("gym.envs.robotics:")
):
return True return True
try: try:
import atari_py import atari_py
@@ -39,11 +34,7 @@ def should_skip_env_spec_for_tests(spec):
if ( if (
"GoEnv" in ep "GoEnv" in ep
or "HexEnv" in ep or "HexEnv" in ep
or ( or (ep.startswith("gym.envs.atari") and not spec.id.startswith("Pong") and not spec.id.startswith("Seaquest"))
ep.startswith("gym.envs.atari")
and not spec.id.startswith("Pong")
and not spec.id.startswith("Seaquest")
)
): ):
logger.warn("Skipping tests for env {}".format(ep)) logger.warn("Skipping tests for env {}".format(ep))
return True return True

View File

@@ -25,9 +25,7 @@ def test_env(spec):
step_responses2 = [env2.step(action) for action in action_samples2] step_responses2 = [env2.step(action) for action in action_samples2]
env2.close() env2.close()
for i, (action_sample1, action_sample2) in enumerate( for i, (action_sample1, action_sample2) in enumerate(zip(action_samples1, action_samples2)):
zip(action_samples1, action_samples2)
):
try: try:
assert_equals(action_sample1, action_sample2) assert_equals(action_sample1, action_sample2)
except AssertionError: except AssertionError:
@@ -35,11 +33,7 @@ def test_env(spec):
print("env2.action_space=", env2.action_space) print("env2.action_space=", env2.action_space)
print("action_samples1=", action_samples1) print("action_samples1=", action_samples1)
print("action_samples2=", action_samples2) print("action_samples2=", action_samples2)
print( print("[{}] action_sample1: {}, action_sample2: {}".format(i, action_sample1, action_sample2))
"[{}] action_sample1: {}, action_sample2: {}".format(
i, action_sample1, action_sample2
)
)
raise raise
# Don't check rollout equality if it's a a nondeterministic # Don't check rollout equality if it's a a nondeterministic
@@ -49,9 +43,7 @@ def test_env(spec):
assert_equals(initial_observation1, initial_observation2) assert_equals(initial_observation1, initial_observation2)
for i, ((o1, r1, d1, i1), (o2, r2, d2, i2)) in enumerate( for i, ((o1, r1, d1, i1), (o2, r2, d2, i2)) in enumerate(zip(step_responses1, step_responses2)):
zip(step_responses1, step_responses2)
):
assert_equals(o1, o2, "[{}] ".format(i)) assert_equals(o1, o2, "[{}] ".format(i))
assert r1 == r2, "[{}] r1: {}, r2: {}".format(i, r1, r2) assert r1 == r2, "[{}] r1: {}, r2: {}".format(i, r1, r2)
assert d1 == d2, "[{}] d1: {}, d2: {}".format(i, d1, d2) assert d1 == d2, "[{}] d1: {}, d2: {}".format(i, d1, d2)
@@ -66,9 +58,7 @@ def test_env(spec):
def assert_equals(a, b, prefix=None): def assert_equals(a, b, prefix=None):
assert type(a) == type(b), "{}Differing types: {} and {}".format(prefix, a, b) assert type(a) == type(b), "{}Differing types: {} and {}".format(prefix, a, b)
if isinstance(a, dict): if isinstance(a, dict):
assert list(a.keys()) == list(b.keys()), "{}Key sets differ: {} and {}".format( assert list(a.keys()) == list(b.keys()), "{}Key sets differ: {} and {}".format(prefix, a, b)
prefix, a, b
)
for k in a.keys(): for k in a.keys():
v_a = a[k] v_a = a[k]

View File

@@ -24,9 +24,7 @@ def test_env(spec):
assert ob_space.contains(ob), "Reset observation: {!r} not in space".format(ob) assert ob_space.contains(ob), "Reset observation: {!r} not in space".format(ob)
a = act_space.sample() a = act_space.sample()
observation, reward, done, _info = env.step(a) observation, reward, done, _info = env.step(a)
assert ob_space.contains(observation), "Step observation: {!r} not in space".format( assert ob_space.contains(observation), "Step observation: {!r} not in space".format(observation)
observation
)
assert np.isscalar(reward), "{} is not a scalar for {}".format(reward, env) assert np.isscalar(reward), "{} is not a scalar for {}".format(reward, env)
assert isinstance(done, bool), "Expected {} to be a boolean".format(done) assert isinstance(done, bool), "Expected {} to be a boolean".format(done)

View File

@@ -81,9 +81,7 @@ def test_env_semantics(spec):
if spec.id not in rollout_dict: if spec.id not in rollout_dict:
if not spec.nondeterministic: if not spec.nondeterministic:
logger.warn( logger.warn(
"Rollout does not exist for {}, run generate_json.py to generate rollouts for new envs".format( "Rollout does not exist for {}, run generate_json.py to generate rollouts for new envs".format(spec.id)
spec.id
)
) )
return return
@@ -100,21 +98,15 @@ def test_env_semantics(spec):
) )
if rollout_dict[spec.id]["actions"] != actions_now: if rollout_dict[spec.id]["actions"] != actions_now:
errors.append( errors.append(
"Actions not equal for {} -- expected {} but got {}".format( "Actions not equal for {} -- expected {} but got {}".format(spec.id, rollout_dict[spec.id]["actions"], actions_now)
spec.id, rollout_dict[spec.id]["actions"], actions_now
)
) )
if rollout_dict[spec.id]["rewards"] != rewards_now: if rollout_dict[spec.id]["rewards"] != rewards_now:
errors.append( errors.append(
"Rewards not equal for {} -- expected {} but got {}".format( "Rewards not equal for {} -- expected {} but got {}".format(spec.id, rollout_dict[spec.id]["rewards"], rewards_now)
spec.id, rollout_dict[spec.id]["rewards"], rewards_now
)
) )
if rollout_dict[spec.id]["dones"] != dones_now: if rollout_dict[spec.id]["dones"] != dones_now:
errors.append( errors.append(
"Dones not equal for {} -- expected {} but got {}".format( "Dones not equal for {} -- expected {} but got {}".format(spec.id, rollout_dict[spec.id]["dones"], dones_now)
spec.id, rollout_dict[spec.id]["dones"], dones_now
)
) )
if len(errors): if len(errors):
for error in errors: for error in errors:

View File

@@ -4,9 +4,7 @@ from gym import envs
from gym.envs.tests.spec_list import skip_mujoco, SKIP_MUJOCO_WARNING_MESSAGE from gym.envs.tests.spec_list import skip_mujoco, SKIP_MUJOCO_WARNING_MESSAGE
def verify_environments_match( def verify_environments_match(old_environment_id, new_environment_id, seed=1, num_actions=1000):
old_environment_id, new_environment_id, seed=1, num_actions=1000
):
old_environment = envs.make(old_environment_id) old_environment = envs.make(old_environment_id)
new_environment = envs.make(new_environment_id) new_environment = envs.make(new_environment_id)

View File

@@ -83,8 +83,6 @@ def test_malformed_lookup():
try: try:
registry.spec(u"“Breakout-v0”") registry.spec(u"“Breakout-v0”")
except error.Error as e: except error.Error as e:
assert "malformed environment ID" in "{}".format( assert "malformed environment ID" in "{}".format(e), "Unexpected message: {}".format(e)
e
), "Unexpected message: {}".format(e)
else: else:
assert False assert False

View File

@@ -75,9 +75,7 @@ class BlackjackEnv(gym.Env):
def __init__(self, natural=False): def __init__(self, natural=False):
self.action_space = spaces.Discrete(2) self.action_space = spaces.Discrete(2)
self.observation_space = spaces.Tuple( self.observation_space = spaces.Tuple((spaces.Discrete(32), spaces.Discrete(11), spaces.Discrete(2)))
(spaces.Discrete(32), spaces.Discrete(11), spaces.Discrete(2))
)
self.seed() self.seed()
# Flag to payout 1.5 on a "natural" blackjack win, like casino rules # Flag to payout 1.5 on a "natural" blackjack win, like casino rules

View File

@@ -141,9 +141,7 @@ class FrozenLakeEnv(discrete.DiscreteEnv):
else: else:
if is_slippery: if is_slippery:
for b in [(a - 1) % 4, a, (a + 1) % 4]: for b in [(a - 1) % 4, a, (a + 1) % 4]:
li.append( li.append((1.0 / 3.0, *update_probability_matrix(row, col, b)))
(1.0 / 3.0, *update_probability_matrix(row, col, b))
)
else: else:
li.append((1.0, *update_probability_matrix(row, col, a))) li.append((1.0, *update_probability_matrix(row, col, a)))
@@ -157,9 +155,7 @@ class FrozenLakeEnv(discrete.DiscreteEnv):
desc = [[c.decode("utf-8") for c in line] for line in desc] desc = [[c.decode("utf-8") for c in line] for line in desc]
desc[row][col] = utils.colorize(desc[row][col], "red", highlight=True) desc[row][col] = utils.colorize(desc[row][col], "red", highlight=True)
if self.lastaction is not None: if self.lastaction is not None:
outfile.write( outfile.write(" ({})\n".format(["Left", "Down", "Right", "Up"][self.lastaction]))
" ({})\n".format(["Left", "Down", "Right", "Up"][self.lastaction])
)
else: else:
outfile.write("\n") outfile.write("\n")
outfile.write("\n".join("".join(line) for line in desc) + "\n") outfile.write("\n".join("".join(line) for line in desc) + "\n")

View File

@@ -80,11 +80,7 @@ class GuessingGame(gym.Env):
reward = 0 reward = 0
done = False done = False
if ( if (self.number - self.range * 0.01) < action < (self.number + self.range * 0.01):
(self.number - self.range * 0.01)
< action
< (self.number + self.range * 0.01)
):
reward = 1 reward = 1
done = True done = True

View File

@@ -62,10 +62,7 @@ class HotterColder(gym.Env):
elif action > self.number: elif action > self.number:
self.observation = 3 self.observation = 3
reward = ( reward = ((min(action, self.number) + self.bounds) / (max(action, self.number) + self.bounds)) ** 2
(min(action, self.number) + self.bounds)
/ (max(action, self.number) + self.bounds)
) ** 2
self.guess_count += 1 self.guess_count += 1
done = self.guess_count >= self.guess_max done = self.guess_count >= self.guess_max

View File

@@ -65,9 +65,7 @@ class KellyCoinflipEnv(gym.Env):
return [seed] return [seed]
def step(self, action): def step(self, action):
bet_in_dollars = min( bet_in_dollars = min(action / 100.0, self.wealth) # action = desired bet in pennies
action / 100.0, self.wealth
) # action = desired bet in pennies
self.rounds -= 1 self.rounds -= 1
coinflip = flip(self.edge, self.np_random) coinflip = flip(self.edge, self.np_random)
@@ -149,35 +147,19 @@ class KellyCoinflipGeneralizedEnv(gym.Env):
edge = self.np_random.beta(edge_prior_alpha, edge_prior_beta) edge = self.np_random.beta(edge_prior_alpha, edge_prior_beta)
if self.clip_distributions: if self.clip_distributions:
# (clip/resample some parameters to be able to fix obs/action space sizes/bounds) # (clip/resample some parameters to be able to fix obs/action space sizes/bounds)
max_wealth_bound = round( max_wealth_bound = round(genpareto.ppf(0.85, max_wealth_alpha, max_wealth_m))
genpareto.ppf(0.85, max_wealth_alpha, max_wealth_m)
)
max_wealth = max_wealth_bound + 1.0 max_wealth = max_wealth_bound + 1.0
while max_wealth > max_wealth_bound: while max_wealth > max_wealth_bound:
max_wealth = round( max_wealth = round(genpareto.rvs(max_wealth_alpha, max_wealth_m, random_state=self.np_random))
genpareto.rvs( max_rounds_bound = int(round(norm.ppf(0.99, max_rounds_mean, max_rounds_sd)))
max_wealth_alpha, max_wealth_m, random_state=self.np_random
)
)
max_rounds_bound = int(
round(norm.ppf(0.99, max_rounds_mean, max_rounds_sd))
)
max_rounds = max_rounds_bound + 1 max_rounds = max_rounds_bound + 1
while max_rounds > max_rounds_bound: while max_rounds > max_rounds_bound:
max_rounds = int( max_rounds = int(round(self.np_random.normal(max_rounds_mean, max_rounds_sd)))
round(self.np_random.normal(max_rounds_mean, max_rounds_sd))
)
else: else:
max_wealth = round( max_wealth = round(genpareto.rvs(max_wealth_alpha, max_wealth_m, random_state=self.np_random))
genpareto.rvs(
max_wealth_alpha, max_wealth_m, random_state=self.np_random
)
)
max_wealth_bound = max_wealth max_wealth_bound = max_wealth
max_rounds = int( max_rounds = int(round(self.np_random.normal(max_rounds_mean, max_rounds_sd)))
round(self.np_random.normal(max_rounds_mean, max_rounds_sd))
)
max_rounds_bound = max_rounds max_rounds_bound = max_rounds
# add an additional global variable which is the sufficient statistic for the # add an additional global variable which is the sufficient statistic for the
@@ -194,9 +176,7 @@ class KellyCoinflipGeneralizedEnv(gym.Env):
self.action_space = spaces.Discrete(int(max_wealth_bound * 100)) self.action_space = spaces.Discrete(int(max_wealth_bound * 100))
self.observation_space = spaces.Tuple( self.observation_space = spaces.Tuple(
( (
spaces.Box( spaces.Box(0, max_wealth_bound, shape=[1], dtype=np.float32), # current wealth
0, max_wealth_bound, shape=[1], dtype=np.float32
), # current wealth
spaces.Discrete(max_rounds_bound + 1), # rounds elapsed spaces.Discrete(max_rounds_bound + 1), # rounds elapsed
spaces.Discrete(max_rounds_bound + 1), # wins spaces.Discrete(max_rounds_bound + 1), # wins
spaces.Discrete(max_rounds_bound + 1), # losses spaces.Discrete(max_rounds_bound + 1), # losses

View File

@@ -80,10 +80,7 @@ class TaxiEnv(discrete.DiscreteEnv):
max_col = num_columns - 1 max_col = num_columns - 1
initial_state_distrib = np.zeros(num_states) initial_state_distrib = np.zeros(num_states)
num_actions = 6 num_actions = 6
P = { P = {state: {action: [] for action in range(num_actions)} for state in range(num_states)}
state: {action: [] for action in range(num_actions)}
for state in range(num_states)
}
for row in range(num_rows): for row in range(num_rows):
for col in range(num_columns): for col in range(num_columns):
for pass_idx in range(len(locs) + 1): # +1 for being inside taxi for pass_idx in range(len(locs) + 1): # +1 for being inside taxi
@@ -94,9 +91,7 @@ class TaxiEnv(discrete.DiscreteEnv):
for action in range(num_actions): for action in range(num_actions):
# defaults # defaults
new_row, new_col, new_pass_idx = row, col, pass_idx new_row, new_col, new_pass_idx = row, col, pass_idx
reward = ( reward = -1 # default reward when there is no pickup/dropoff
-1
) # default reward when there is no pickup/dropoff
done = False done = False
taxi_loc = (row, col) taxi_loc = (row, col)
@@ -122,14 +117,10 @@ class TaxiEnv(discrete.DiscreteEnv):
new_pass_idx = locs.index(taxi_loc) new_pass_idx = locs.index(taxi_loc)
else: # dropoff at wrong location else: # dropoff at wrong location
reward = -10 reward = -10
new_state = self.encode( new_state = self.encode(new_row, new_col, new_pass_idx, dest_idx)
new_row, new_col, new_pass_idx, dest_idx
)
P[state][action].append((1.0, new_state, reward, done)) P[state][action].append((1.0, new_state, reward, done))
initial_state_distrib /= initial_state_distrib.sum() initial_state_distrib /= initial_state_distrib.sum()
discrete.DiscreteEnv.__init__( discrete.DiscreteEnv.__init__(self, num_states, num_actions, P, initial_state_distrib)
self, num_states, num_actions, P, initial_state_distrib
)
def encode(self, taxi_row, taxi_col, pass_loc, dest_idx): def encode(self, taxi_row, taxi_col, pass_loc, dest_idx):
# (5) 5, 5, 4 # (5) 5, 5, 4
@@ -165,13 +156,9 @@ class TaxiEnv(discrete.DiscreteEnv):
return "_" if x == " " else x return "_" if x == " " else x
if pass_idx < 4: if pass_idx < 4:
out[1 + taxi_row][2 * taxi_col + 1] = utils.colorize( out[1 + taxi_row][2 * taxi_col + 1] = utils.colorize(out[1 + taxi_row][2 * taxi_col + 1], "yellow", highlight=True)
out[1 + taxi_row][2 * taxi_col + 1], "yellow", highlight=True
)
pi, pj = self.locs[pass_idx] pi, pj = self.locs[pass_idx]
out[1 + pi][2 * pj + 1] = utils.colorize( out[1 + pi][2 * pj + 1] = utils.colorize(out[1 + pi][2 * pj + 1], "blue", bold=True)
out[1 + pi][2 * pj + 1], "blue", bold=True
)
else: # passenger in taxi else: # passenger in taxi
out[1 + taxi_row][2 * taxi_col + 1] = utils.colorize( out[1 + taxi_row][2 * taxi_col + 1] = utils.colorize(
ul(out[1 + taxi_row][2 * taxi_col + 1]), "green", highlight=True ul(out[1 + taxi_row][2 * taxi_col + 1]), "green", highlight=True
@@ -181,13 +168,7 @@ class TaxiEnv(discrete.DiscreteEnv):
out[1 + di][2 * dj + 1] = utils.colorize(out[1 + di][2 * dj + 1], "magenta") out[1 + di][2 * dj + 1] = utils.colorize(out[1 + di][2 * dj + 1], "magenta")
outfile.write("\n".join(["".join(row) for row in out]) + "\n") outfile.write("\n".join(["".join(row) for row in out]) + "\n")
if self.lastaction is not None: if self.lastaction is not None:
outfile.write( outfile.write(" ({})\n".format(["South", "North", "East", "West", "Pickup", "Dropoff"][self.lastaction]))
" ({})\n".format(
["South", "North", "East", "West", "Pickup", "Dropoff"][
self.lastaction
]
)
)
else: else:
outfile.write("\n") outfile.write("\n")

View File

@@ -55,9 +55,7 @@ class CubeCrash(gym.Env):
self.seed() self.seed()
self.viewer = None self.viewer = None
self.observation_space = spaces.Box( self.observation_space = spaces.Box(0, 255, (FIELD_H, FIELD_W, 3), dtype=np.uint8)
0, 255, (FIELD_H, FIELD_W, 3), dtype=np.uint8
)
self.action_space = spaces.Discrete(3) self.action_space = spaces.Discrete(3)
self.reset() self.reset()
@@ -83,16 +81,9 @@ class CubeCrash(gym.Env):
self.potential = None self.potential = None
self.step_n = 0 self.step_n = 0
while 1: while 1:
self.wall_color = ( self.wall_color = self.random_color() if self.use_random_colors else color_white
self.random_color() if self.use_random_colors else color_white self.cube_color = self.random_color() if self.use_random_colors else color_green
) if np.linalg.norm(self.wall_color - self.bg_color) < 50 or np.linalg.norm(self.cube_color - self.bg_color) < 50:
self.cube_color = (
self.random_color() if self.use_random_colors else color_green
)
if (
np.linalg.norm(self.wall_color - self.bg_color) < 50
or np.linalg.norm(self.cube_color - self.bg_color) < 50
):
continue continue
break break
return self.step(0)[0] return self.step(0)[0]
@@ -117,9 +108,7 @@ class CubeCrash(gym.Env):
self.hole_x - HOLE_WIDTH // 2 : self.hole_x + HOLE_WIDTH // 2 + 1, self.hole_x - HOLE_WIDTH // 2 : self.hole_x + HOLE_WIDTH // 2 + 1,
:, :,
] = self.bg_color ] = self.bg_color
obs[ obs[self.cube_y - 1 : self.cube_y + 2, self.cube_x - 1 : self.cube_x + 2, :] = self.cube_color
self.cube_y - 1 : self.cube_y + 2, self.cube_x - 1 : self.cube_x + 2, :
] = self.cube_color
if self.use_black_screen and self.step_n > 4: if self.use_black_screen and self.step_n > 4:
obs[:] = np.zeros((3,), dtype=np.uint8) obs[:] = np.zeros((3,), dtype=np.uint8)

View File

@@ -62,16 +62,12 @@ class MemorizeDigits(gym.Env):
def __init__(self): def __init__(self):
self.seed() self.seed()
self.viewer = None self.viewer = None
self.observation_space = spaces.Box( self.observation_space = spaces.Box(0, 255, (FIELD_H, FIELD_W, 3), dtype=np.uint8)
0, 255, (FIELD_H, FIELD_W, 3), dtype=np.uint8
)
self.action_space = spaces.Discrete(10) self.action_space = spaces.Discrete(10)
self.bogus_mnist = np.zeros((10, 6, 6), dtype=np.uint8) self.bogus_mnist = np.zeros((10, 6, 6), dtype=np.uint8)
for digit in range(10): for digit in range(10):
for y in range(6): for y in range(6):
self.bogus_mnist[digit, y, :] = [ self.bogus_mnist[digit, y, :] = [ord(char) for char in bogus_mnist[digit][y]]
ord(char) for char in bogus_mnist[digit][y]
]
self.reset() self.reset()
def seed(self, seed=None): def seed(self, seed=None):
@@ -93,9 +89,7 @@ class MemorizeDigits(gym.Env):
self.color_bg = self.random_color() if self.use_random_colors else color_black self.color_bg = self.random_color() if self.use_random_colors else color_black
self.step_n = 0 self.step_n = 0
while 1: while 1:
self.color_digit = ( self.color_digit = self.random_color() if self.use_random_colors else color_white
self.random_color() if self.use_random_colors else color_white
)
if np.linalg.norm(self.color_digit - self.color_bg) < 50: if np.linalg.norm(self.color_digit - self.color_bg) < 50:
continue continue
break break
@@ -119,9 +113,7 @@ class MemorizeDigits(gym.Env):
digit_img[:] = self.color_bg digit_img[:] = self.color_bg
xxx = self.bogus_mnist[self.digit] == 42 xxx = self.bogus_mnist[self.digit] == 42
digit_img[xxx] = self.color_digit digit_img[xxx] = self.color_digit
obs[ obs[self.digit_y - 3 : self.digit_y + 3, self.digit_x - 3 : self.digit_x + 3] = digit_img
self.digit_y - 3 : self.digit_y + 3, self.digit_x - 3 : self.digit_x + 3
] = digit_img
self.last_obs = obs self.last_obs = obs
return obs, reward, done, {} return obs, reward, done, {}

View File

@@ -102,10 +102,7 @@ class APIError(Error):
try: try:
http_body = http_body.decode("utf-8") http_body = http_body.decode("utf-8")
except: except:
http_body = ( http_body = "<Could not decode body as utf-8. " "Please report to gym@openai.com>"
"<Could not decode body as utf-8. "
"Please report to gym@openai.com>"
)
self._message = message self._message = message
self.http_body = http_body self.http_body = http_body
@@ -142,9 +139,7 @@ class InvalidRequestError(APIError):
json_body=None, json_body=None,
headers=None, headers=None,
): ):
super(InvalidRequestError, self).__init__( super(InvalidRequestError, self).__init__(message, http_body, http_status, json_body, headers)
message, http_body, http_status, json_body, headers
)
self.param = param self.param = param

View File

@@ -29,26 +29,16 @@ class Box(Space):
# determine shape if it isn't provided directly # determine shape if it isn't provided directly
if shape is not None: if shape is not None:
shape = tuple(shape) shape = tuple(shape)
assert ( assert np.isscalar(low) or low.shape == shape, "low.shape doesn't match provided shape"
np.isscalar(low) or low.shape == shape assert np.isscalar(high) or high.shape == shape, "high.shape doesn't match provided shape"
), "low.shape doesn't match provided shape"
assert (
np.isscalar(high) or high.shape == shape
), "high.shape doesn't match provided shape"
elif not np.isscalar(low): elif not np.isscalar(low):
shape = low.shape shape = low.shape
assert ( assert np.isscalar(high) or high.shape == shape, "high.shape doesn't match low.shape"
np.isscalar(high) or high.shape == shape
), "high.shape doesn't match low.shape"
elif not np.isscalar(high): elif not np.isscalar(high):
shape = high.shape shape = high.shape
assert ( assert np.isscalar(low) or low.shape == shape, "low.shape doesn't match high.shape"
np.isscalar(low) or low.shape == shape
), "low.shape doesn't match high.shape"
else: else:
raise ValueError( raise ValueError("shape must be provided or inferred from the shapes of low or high")
"shape must be provided or inferred from the shapes of low or high"
)
if np.isscalar(low): if np.isscalar(low):
low = np.full(shape, low, dtype=dtype) low = np.full(shape, low, dtype=dtype)
@@ -70,9 +60,7 @@ class Box(Space):
high_precision = _get_precision(self.high.dtype) high_precision = _get_precision(self.high.dtype)
dtype_precision = _get_precision(self.dtype) dtype_precision = _get_precision(self.dtype)
if min(low_precision, high_precision) > dtype_precision: if min(low_precision, high_precision) > dtype_precision:
logger.warn( logger.warn("Box bound precision lowered by casting to {}".format(self.dtype))
"Box bound precision lowered by casting to {}".format(self.dtype)
)
self.low = self.low.astype(self.dtype) self.low = self.low.astype(self.dtype)
self.high = self.high.astype(self.dtype) self.high = self.high.astype(self.dtype)
@@ -119,19 +107,11 @@ class Box(Space):
# Vectorized sampling by interval type # Vectorized sampling by interval type
sample[unbounded] = self.np_random.normal(size=unbounded[unbounded].shape) sample[unbounded] = self.np_random.normal(size=unbounded[unbounded].shape)
sample[low_bounded] = ( sample[low_bounded] = self.np_random.exponential(size=low_bounded[low_bounded].shape) + self.low[low_bounded]
self.np_random.exponential(size=low_bounded[low_bounded].shape)
+ self.low[low_bounded]
)
sample[upp_bounded] = ( sample[upp_bounded] = -self.np_random.exponential(size=upp_bounded[upp_bounded].shape) + self.high[upp_bounded]
-self.np_random.exponential(size=upp_bounded[upp_bounded].shape)
+ self.high[upp_bounded]
)
sample[bounded] = self.np_random.uniform( sample[bounded] = self.np_random.uniform(low=self.low[bounded], high=high[bounded], size=bounded[bounded].shape)
low=self.low[bounded], high=high[bounded], size=bounded[bounded].shape
)
if self.dtype.kind == "i": if self.dtype.kind == "i":
sample = np.floor(sample) sample = np.floor(sample)
@@ -140,9 +120,7 @@ class Box(Space):
def contains(self, x): def contains(self, x):
if isinstance(x, list): if isinstance(x, list):
x = np.array(x) # Promote list to array for contains check x = np.array(x) # Promote list to array for contains check
return ( return x.shape == self.shape and np.all(x >= self.low) and np.all(x <= self.high)
x.shape == self.shape and np.all(x >= self.low) and np.all(x <= self.high)
)
def to_jsonable(self, sample_n): def to_jsonable(self, sample_n):
return np.array(sample_n).tolist() return np.array(sample_n).tolist()
@@ -151,9 +129,7 @@ class Box(Space):
return [np.asarray(sample) for sample in sample_n] return [np.asarray(sample) for sample in sample_n]
def __repr__(self): def __repr__(self):
return "Box({}, {}, {}, {})".format( return "Box({}, {}, {}, {})".format(self.low.min(), self.high.max(), self.shape, self.dtype)
self.low.min(), self.high.max(), self.shape, self.dtype
)
def __eq__(self, other): def __eq__(self, other):
return ( return (

View File

@@ -33,9 +33,7 @@ class Dict(Space):
""" """
def __init__(self, spaces=None, **spaces_kwargs): def __init__(self, spaces=None, **spaces_kwargs):
assert (spaces is None) or ( assert (spaces is None) or (not spaces_kwargs), "Use either Dict(spaces=dict(...)) or Dict(foo=x, bar=z)"
not spaces_kwargs
), "Use either Dict(spaces=dict(...)) or Dict(foo=x, bar=z)"
if spaces is None: if spaces is None:
spaces = spaces_kwargs spaces = spaces_kwargs
if isinstance(spaces, dict) and not isinstance(spaces, OrderedDict): if isinstance(spaces, dict) and not isinstance(spaces, OrderedDict):
@@ -44,12 +42,8 @@ class Dict(Space):
spaces = OrderedDict(spaces) spaces = OrderedDict(spaces)
self.spaces = spaces self.spaces = spaces
for space in spaces.values(): for space in spaces.values():
assert isinstance( assert isinstance(space, Space), "Values of the dict should be instances of gym.Space"
space, Space super(Dict, self).__init__(None, None) # None for shape and dtype, since it'll require special handling
), "Values of the dict should be instances of gym.Space"
super(Dict, self).__init__(
None, None
) # None for shape and dtype, since it'll require special handling
def seed(self, seed=None): def seed(self, seed=None):
[space.seed(seed) for space in self.spaces.values()] [space.seed(seed) for space in self.spaces.values()]
@@ -75,18 +69,11 @@ class Dict(Space):
yield key yield key
def __repr__(self): def __repr__(self):
return ( return "Dict(" + ", ".join([str(k) + ":" + str(s) for k, s in self.spaces.items()]) + ")"
"Dict("
+ ", ".join([str(k) + ":" + str(s) for k, s in self.spaces.items()])
+ ")"
)
def to_jsonable(self, sample_n): def to_jsonable(self, sample_n):
# serialize as dict-repr of vectors # serialize as dict-repr of vectors
return { return {key: space.to_jsonable([sample[key] for sample in sample_n]) for key, space in self.spaces.items()}
key: space.to_jsonable([sample[key] for sample in sample_n])
for key, space in self.spaces.items()
}
def from_jsonable(self, sample_n): def from_jsonable(self, sample_n):
dict_of_list = {} dict_of_list = {}

View File

@@ -22,9 +22,7 @@ class Discrete(Space):
def contains(self, x): def contains(self, x):
if isinstance(x, int): if isinstance(x, int):
as_int = x as_int = x
elif isinstance(x, (np.generic, np.ndarray)) and ( elif isinstance(x, (np.generic, np.ndarray)) and (x.dtype.char in np.typecodes["AllInteger"] and x.shape == ()):
x.dtype.char in np.typecodes["AllInteger"] and x.shape == ()
):
as_int = int(x) as_int = int(x)
else: else:
return False return False

View File

@@ -34,9 +34,7 @@ class MultiDiscrete(Space):
super(MultiDiscrete, self).__init__(self.nvec.shape, dtype) super(MultiDiscrete, self).__init__(self.nvec.shape, dtype)
def sample(self): def sample(self):
return (self.np_random.random_sample(self.nvec.shape) * self.nvec).astype( return (self.np_random.random_sample(self.nvec.shape) * self.nvec).astype(self.dtype)
self.dtype
)
def contains(self, x): def contains(self, x):
if isinstance(x, list): if isinstance(x, list):

View File

@@ -25,9 +25,7 @@ from gym.spaces import Tuple, Box, Discrete, MultiDiscrete, MultiBinary, Dict
Dict( Dict(
{ {
"position": Discrete(5), "position": Discrete(5),
"velocity": Box( "velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32
),
} }
), ),
], ],
@@ -71,9 +69,7 @@ def test_roundtripping(space):
Dict( Dict(
{ {
"position": Discrete(5), "position": Discrete(5),
"velocity": Box( "velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32
),
} }
), ),
], ],

View File

@@ -28,9 +28,7 @@ from gym.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete, Tuple, u
Dict( Dict(
{ {
"position": Discrete(5), "position": Discrete(5),
"velocity": Box( "velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32
),
} }
), ),
7, 7,
@@ -60,18 +58,14 @@ def test_flatdim(space, flatdim):
Dict( Dict(
{ {
"position": Discrete(5), "position": Discrete(5),
"velocity": Box( "velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32
),
} }
), ),
], ],
) )
def test_flatten_space_boxes(space): def test_flatten_space_boxes(space):
flat_space = utils.flatten_space(space) flat_space = utils.flatten_space(space)
assert isinstance(flat_space, Box), "Expected {} to equal {}".format( assert isinstance(flat_space, Box), "Expected {} to equal {}".format(type(flat_space), Box)
type(flat_space), Box
)
flatdim = utils.flatdim(space) flatdim = utils.flatdim(space)
(single_dim,) = flat_space.shape (single_dim,) = flat_space.shape
assert single_dim == flatdim, "Expected {} to equal {}".format(single_dim, flatdim) assert single_dim == flatdim, "Expected {} to equal {}".format(single_dim, flatdim)
@@ -95,9 +89,7 @@ def test_flatten_space_boxes(space):
Dict( Dict(
{ {
"position": Discrete(5), "position": Discrete(5),
"velocity": Box( "velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32
),
} }
), ),
], ],
@@ -107,9 +99,7 @@ def test_flat_space_contains_flat_points(space):
flattened_samples = [utils.flatten(space, sample) for sample in some_samples] flattened_samples = [utils.flatten(space, sample) for sample in some_samples]
flat_space = utils.flatten_space(space) flat_space = utils.flatten_space(space)
for i, flat_sample in enumerate(flattened_samples): for i, flat_sample in enumerate(flattened_samples):
assert flat_sample in flat_space, "Expected sample #{} {} to be in {}".format( assert flat_sample in flat_space, "Expected sample #{} {} to be in {}".format(i, flat_sample, flat_space)
i, flat_sample, flat_space
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
@@ -130,9 +120,7 @@ def test_flat_space_contains_flat_points(space):
Dict( Dict(
{ {
"position": Discrete(5), "position": Discrete(5),
"velocity": Box( "velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32
),
} }
), ),
], ],
@@ -162,9 +150,7 @@ def test_flatten_dim(space):
Dict( Dict(
{ {
"position": Discrete(5), "position": Discrete(5),
"velocity": Box( "velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32
),
} }
), ),
], ],
@@ -172,15 +158,9 @@ def test_flatten_dim(space):
def test_flatten_roundtripping(space): def test_flatten_roundtripping(space):
some_samples = [space.sample() for _ in range(10)] some_samples = [space.sample() for _ in range(10)]
flattened_samples = [utils.flatten(space, sample) for sample in some_samples] flattened_samples = [utils.flatten(space, sample) for sample in some_samples]
roundtripped_samples = [ roundtripped_samples = [utils.unflatten(space, sample) for sample in flattened_samples]
utils.unflatten(space, sample) for sample in flattened_samples for i, (original, roundtripped) in enumerate(zip(some_samples, roundtripped_samples)):
] assert compare_nested(original, roundtripped), "Expected sample #{} {} to equal {}".format(i, original, roundtripped)
for i, (original, roundtripped) in enumerate(
zip(some_samples, roundtripped_samples)
):
assert compare_nested(
original, roundtripped
), "Expected sample #{} {} to equal {}".format(i, original, roundtripped)
def compare_nested(left, right): def compare_nested(left, right):
@@ -188,9 +168,7 @@ def compare_nested(left, right):
return np.allclose(left, right) return np.allclose(left, right)
elif isinstance(left, OrderedDict) and isinstance(right, OrderedDict): elif isinstance(left, OrderedDict) and isinstance(right, OrderedDict):
res = len(left) == len(right) res = len(left) == len(right)
for ((left_key, left_value), (right_key, right_value)) in zip( for ((left_key, left_value), (right_key, right_value)) in zip(left.items(), right.items()):
left.items(), right.items()
):
if not res: if not res:
return False return False
res = left_key == right_key and compare_nested(left_value, right_value) res = left_key == right_key and compare_nested(left_value, right_value)
@@ -238,9 +216,7 @@ Expecteded flattened types are based off:
Dict( Dict(
{ {
"position": Discrete(5), "position": Discrete(5),
"velocity": Box( "velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float16),
low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float16
),
} }
), ),
np.float64, np.float64,
@@ -254,12 +230,8 @@ def test_dtypes(original_space, expected_flattened_dtype):
flattened_sample = utils.flatten(original_space, original_sample) flattened_sample = utils.flatten(original_space, original_sample)
unflattened_sample = utils.unflatten(original_space, flattened_sample) unflattened_sample = utils.unflatten(original_space, flattened_sample)
assert flattened_space.contains( assert flattened_space.contains(flattened_sample), "Expected flattened_space to contain flattened_sample"
flattened_sample assert flattened_space.dtype == expected_flattened_dtype, "Expected flattened_space's dtype to equal " "{}".format(
), "Expected flattened_space to contain flattened_sample"
assert (
flattened_space.dtype == expected_flattened_dtype
), "Expected flattened_space's dtype to equal " "{}".format(
expected_flattened_dtype expected_flattened_dtype
) )
@@ -272,9 +244,10 @@ def test_dtypes(original_space, expected_flattened_dtype):
def compare_sample_types(original_space, original_sample, unflattened_sample): def compare_sample_types(original_space, original_sample, unflattened_sample):
if isinstance(original_space, Discrete): if isinstance(original_space, Discrete):
assert isinstance(unflattened_sample, int), ( assert isinstance(
"Expected unflattened_sample to be an int. unflattened_sample: " unflattened_sample, int
"{} original_sample: {}".format(unflattened_sample, original_sample) ), "Expected unflattened_sample to be an int. unflattened_sample: " "{} original_sample: {}".format(
unflattened_sample, original_sample
) )
elif isinstance(original_space, Tuple): elif isinstance(original_space, Tuple):
for index in range(len(original_space)): for index in range(len(original_space)):

View File

@@ -13,9 +13,7 @@ class Tuple(Space):
def __init__(self, spaces): def __init__(self, spaces):
self.spaces = spaces self.spaces = spaces
for space in spaces: for space in spaces:
assert isinstance( assert isinstance(space, Space), "Elements of the tuple must be instances of gym.Space"
space, Space
), "Elements of the tuple must be instances of gym.Space"
super(Tuple, self).__init__(None, None) super(Tuple, self).__init__(None, None)
def seed(self, seed=None): def seed(self, seed=None):
@@ -38,21 +36,10 @@ class Tuple(Space):
def to_jsonable(self, sample_n): def to_jsonable(self, sample_n):
# serialize as list-repr of tuple of vectors # serialize as list-repr of tuple of vectors
return [ return [space.to_jsonable([sample[i] for sample in sample_n]) for i, space in enumerate(self.spaces)]
space.to_jsonable([sample[i] for sample in sample_n])
for i, space in enumerate(self.spaces)
]
def from_jsonable(self, sample_n): def from_jsonable(self, sample_n):
return [ return [sample for sample in zip(*[space.from_jsonable(sample_n[i]) for i, space in enumerate(self.spaces)])]
sample
for sample in zip(
*[
space.from_jsonable(sample_n[i])
for i, space in enumerate(self.spaces)
]
)
]
def __getitem__(self, index): def __getitem__(self, index):
return self.spaces[index] return self.spaces[index]

View File

@@ -49,9 +49,7 @@ def flatten(space, x):
onehot[x] = 1 onehot[x] = 1
return onehot return onehot
elif isinstance(space, Tuple): elif isinstance(space, Tuple):
return np.concatenate( return np.concatenate([flatten(s, x_part) for x_part, s in zip(x, space.spaces)])
[flatten(s, x_part) for x_part, s in zip(x, space.spaces)]
)
elif isinstance(space, Dict): elif isinstance(space, Dict):
return np.concatenate([flatten(s, x[key]) for key, s in space.spaces.items()]) return np.concatenate([flatten(s, x[key]) for key, s in space.spaces.items()])
elif isinstance(space, MultiBinary): elif isinstance(space, MultiBinary):
@@ -79,17 +77,13 @@ def unflatten(space, x):
elif isinstance(space, Tuple): elif isinstance(space, Tuple):
dims = [flatdim(s) for s in space.spaces] dims = [flatdim(s) for s in space.spaces]
list_flattened = np.split(x, np.cumsum(dims)[:-1]) list_flattened = np.split(x, np.cumsum(dims)[:-1])
list_unflattened = [ list_unflattened = [unflatten(s, flattened) for flattened, s in zip(list_flattened, space.spaces)]
unflatten(s, flattened)
for flattened, s in zip(list_flattened, space.spaces)
]
return tuple(list_unflattened) return tuple(list_unflattened)
elif isinstance(space, Dict): elif isinstance(space, Dict):
dims = [flatdim(s) for s in space.spaces.values()] dims = [flatdim(s) for s in space.spaces.values()]
list_flattened = np.split(x, np.cumsum(dims)[:-1]) list_flattened = np.split(x, np.cumsum(dims)[:-1])
list_unflattened = [ list_unflattened = [
(key, unflatten(s, flattened)) (key, unflatten(s, flattened)) for flattened, (key, s) in zip(list_flattened, space.spaces.items())
for flattened, (key, s) in zip(list_flattened, space.spaces.items())
] ]
return OrderedDict(list_unflattened) return OrderedDict(list_unflattened)
elif isinstance(space, MultiBinary): elif isinstance(space, MultiBinary):

View File

@@ -88,11 +88,7 @@ def play(env, transpose=True, fps=30, zoom=None, callback=None, keys_to_action=N
elif hasattr(env.unwrapped, "get_keys_to_action"): elif hasattr(env.unwrapped, "get_keys_to_action"):
keys_to_action = env.unwrapped.get_keys_to_action() keys_to_action = env.unwrapped.get_keys_to_action()
else: else:
assert False, ( assert False, env.spec.id + " does not have explicit key to action mapping, " + "please specify one manually"
env.spec.id
+ " does not have explicit key to action mapping, "
+ "please specify one manually"
)
relevant_keys = set(sum(map(list, keys_to_action.keys()), [])) relevant_keys = set(sum(map(list, keys_to_action.keys()), []))
video_size = [rendered.shape[1], rendered.shape[0]] video_size = [rendered.shape[1], rendered.shape[0]]
@@ -172,9 +168,7 @@ class PlayPlot(object):
for i, plot in enumerate(self.cur_plot): for i, plot in enumerate(self.cur_plot):
if plot is not None: if plot is not None:
plot.remove() plot.remove()
self.cur_plot[i] = self.ax[i].scatter( self.cur_plot[i] = self.ax[i].scatter(range(xmin, xmax), list(self.data[i]), c="blue")
range(xmin, xmax), list(self.data[i]), c="blue"
)
self.ax[i].set_xlim(xmin, xmax) self.ax[i].set_xlim(xmin, xmax)
plt.pause(0.000001) plt.pause(0.000001)

View File

@@ -10,9 +10,7 @@ from gym import error
def np_random(seed=None): def np_random(seed=None):
if seed is not None and not (isinstance(seed, int) and 0 <= seed): if seed is not None and not (isinstance(seed, int) and 0 <= seed):
raise error.Error( raise error.Error("Seed must be a non-negative integer or omitted, not {}".format(seed))
"Seed must be a non-negative integer or omitted, not {}".format(seed)
)
seed = create_seed(seed) seed = create_seed(seed)

View File

@@ -53,9 +53,7 @@ def make(id, num_envs=1, asynchronous=True, wrappers=None, **kwargs):
if wrappers is not None: if wrappers is not None:
if callable(wrappers): if callable(wrappers):
env = wrappers(env) env = wrappers(env)
elif isinstance(wrappers, Iterable) and all( elif isinstance(wrappers, Iterable) and all([callable(w) for w in wrappers]):
[callable(w) for w in wrappers]
):
for wrapper in wrappers: for wrapper in wrappers:
env = wrapper(env) env = wrapper(env)
else: else:

View File

@@ -107,12 +107,8 @@ class AsyncVectorEnv(VectorEnv):
if self.shared_memory: if self.shared_memory:
try: try:
_obs_buffer = create_shared_memory( _obs_buffer = create_shared_memory(self.single_observation_space, n=self.num_envs, ctx=ctx)
self.single_observation_space, n=self.num_envs, ctx=ctx self.observations = read_from_shared_memory(_obs_buffer, self.single_observation_space, n=self.num_envs)
)
self.observations = read_from_shared_memory(
_obs_buffer, self.single_observation_space, n=self.num_envs
)
except CustomSpaceError: except CustomSpaceError:
raise ValueError( raise ValueError(
"Using `shared_memory=True` in `AsyncVectorEnv` " "Using `shared_memory=True` in `AsyncVectorEnv` "
@@ -124,9 +120,7 @@ class AsyncVectorEnv(VectorEnv):
) )
else: else:
_obs_buffer = None _obs_buffer = None
self.observations = create_empty_array( self.observations = create_empty_array(self.single_observation_space, n=self.num_envs, fn=np.zeros)
self.single_observation_space, n=self.num_envs, fn=np.zeros
)
self.parent_pipes, self.processes = [], [] self.parent_pipes, self.processes = [], []
self.error_queue = ctx.Queue() self.error_queue = ctx.Queue()
@@ -168,8 +162,7 @@ class AsyncVectorEnv(VectorEnv):
if self._state != AsyncState.DEFAULT: if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError( raise AlreadyPendingCallError(
"Calling `seed` while waiting " "Calling `seed` while waiting " "for a pending call to `{0}` to complete.".format(self._state.value),
"for a pending call to `{0}` to complete.".format(self._state.value),
self._state.value, self._state.value,
) )
@@ -182,8 +175,7 @@ class AsyncVectorEnv(VectorEnv):
self._assert_is_running() self._assert_is_running()
if self._state != AsyncState.DEFAULT: if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError( raise AlreadyPendingCallError(
"Calling `reset_async` while waiting " "Calling `reset_async` while waiting " "for a pending call to `{0}` to complete".format(self._state.value),
"for a pending call to `{0}` to complete".format(self._state.value),
self._state.value, self._state.value,
) )
@@ -214,8 +206,7 @@ class AsyncVectorEnv(VectorEnv):
if not self._poll(timeout): if not self._poll(timeout):
self._state = AsyncState.DEFAULT self._state = AsyncState.DEFAULT
raise mp.TimeoutError( raise mp.TimeoutError(
"The call to `reset_wait` has timed out after " "The call to `reset_wait` has timed out after " "{0} second{1}.".format(timeout, "s" if timeout > 1 else "")
"{0} second{1}.".format(timeout, "s" if timeout > 1 else "")
) )
results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes]) results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
@@ -223,9 +214,7 @@ class AsyncVectorEnv(VectorEnv):
self._state = AsyncState.DEFAULT self._state = AsyncState.DEFAULT
if not self.shared_memory: if not self.shared_memory:
self.observations = concatenate( self.observations = concatenate(results, self.observations, self.single_observation_space)
results, self.observations, self.single_observation_space
)
return deepcopy(self.observations) if self.copy else self.observations return deepcopy(self.observations) if self.copy else self.observations
@@ -239,8 +228,7 @@ class AsyncVectorEnv(VectorEnv):
self._assert_is_running() self._assert_is_running()
if self._state != AsyncState.DEFAULT: if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError( raise AlreadyPendingCallError(
"Calling `step_async` while waiting " "Calling `step_async` while waiting " "for a pending call to `{0}` to complete.".format(self._state.value),
"for a pending call to `{0}` to complete.".format(self._state.value),
self._state.value, self._state.value,
) )
@@ -280,8 +268,7 @@ class AsyncVectorEnv(VectorEnv):
if not self._poll(timeout): if not self._poll(timeout):
self._state = AsyncState.DEFAULT self._state = AsyncState.DEFAULT
raise mp.TimeoutError( raise mp.TimeoutError(
"The call to `step_wait` has timed out after " "The call to `step_wait` has timed out after " "{0} second{1}.".format(timeout, "s" if timeout > 1 else "")
"{0} second{1}.".format(timeout, "s" if timeout > 1 else "")
) )
results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes]) results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
@@ -290,9 +277,7 @@ class AsyncVectorEnv(VectorEnv):
observations_list, rewards, dones, infos = zip(*results) observations_list, rewards, dones, infos = zip(*results)
if not self.shared_memory: if not self.shared_memory:
self.observations = concatenate( self.observations = concatenate(observations_list, self.observations, self.single_observation_space)
observations_list, self.observations, self.single_observation_space
)
return ( return (
deepcopy(self.observations) if self.copy else self.observations, deepcopy(self.observations) if self.copy else self.observations,
@@ -318,8 +303,7 @@ class AsyncVectorEnv(VectorEnv):
try: try:
if self._state != AsyncState.DEFAULT: if self._state != AsyncState.DEFAULT:
logger.warn( logger.warn(
"Calling `close` while waiting for a pending " "Calling `close` while waiting for a pending " "call to `{0}` to complete.".format(self._state.value)
"call to `{0}` to complete.".format(self._state.value)
) )
function = getattr(self, "{0}_wait".format(self._state.value)) function = getattr(self, "{0}_wait".format(self._state.value))
function(timeout) function(timeout)
@@ -375,8 +359,7 @@ class AsyncVectorEnv(VectorEnv):
def _assert_is_running(self): def _assert_is_running(self):
if self.closed: if self.closed:
raise ClosedEnvironmentError( raise ClosedEnvironmentError(
"Trying to operate on `{0}`, after a " "Trying to operate on `{0}`, after a " "call to `close()`.".format(type(self).__name__)
"call to `close()`.".format(type(self).__name__)
) )
def _raise_if_errors(self, successes): def _raise_if_errors(self, successes):
@@ -387,10 +370,7 @@ class AsyncVectorEnv(VectorEnv):
assert num_errors > 0 assert num_errors > 0
for _ in range(num_errors): for _ in range(num_errors):
index, exctype, value = self.error_queue.get() index, exctype, value = self.error_queue.get()
logger.error( logger.error("Received the following error from Worker-{0}: " "{1}: {2}".format(index, exctype.__name__, value))
"Received the following error from Worker-{0}: "
"{1}: {2}".format(index, exctype.__name__, value)
)
logger.error("Shutting down Worker-{0}.".format(index)) logger.error("Shutting down Worker-{0}.".format(index))
self.parent_pipes[index].close() self.parent_pipes[index].close()
self.parent_pipes[index] = None self.parent_pipes[index] = None
@@ -445,17 +425,13 @@ def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error
command, data = pipe.recv() command, data = pipe.recv()
if command == "reset": if command == "reset":
observation = env.reset() observation = env.reset()
write_to_shared_memory( write_to_shared_memory(index, observation, shared_memory, observation_space)
index, observation, shared_memory, observation_space
)
pipe.send((None, True)) pipe.send((None, True))
elif command == "step": elif command == "step":
observation, reward, done, info = env.step(data) observation, reward, done, info = env.step(data)
if done: if done:
observation = env.reset() observation = env.reset()
write_to_shared_memory( write_to_shared_memory(index, observation, shared_memory, observation_space)
index, observation, shared_memory, observation_space
)
pipe.send(((None, reward, done, info), True)) pipe.send(((None, reward, done, info), True))
elif command == "seed": elif command == "seed":
env.seed(data) env.seed(data)

View File

@@ -44,9 +44,7 @@ class SyncVectorEnv(VectorEnv):
) )
self._check_observation_spaces() self._check_observation_spaces()
self.observations = create_empty_array( self.observations = create_empty_array(self.single_observation_space, n=self.num_envs, fn=np.zeros)
self.single_observation_space, n=self.num_envs, fn=np.zeros
)
self._rewards = np.zeros((self.num_envs,), dtype=np.float64) self._rewards = np.zeros((self.num_envs,), dtype=np.float64)
self._dones = np.zeros((self.num_envs,), dtype=np.bool_) self._dones = np.zeros((self.num_envs,), dtype=np.bool_)
self._actions = None self._actions = None
@@ -67,9 +65,7 @@ class SyncVectorEnv(VectorEnv):
for env in self.envs: for env in self.envs:
observation = env.reset() observation = env.reset()
observations.append(observation) observations.append(observation)
self.observations = concatenate( self.observations = concatenate(observations, self.observations, self.single_observation_space)
observations, self.observations, self.single_observation_space
)
return deepcopy(self.observations) if self.copy else self.observations return deepcopy(self.observations) if self.copy else self.observations
@@ -84,9 +80,7 @@ class SyncVectorEnv(VectorEnv):
observation = env.reset() observation = env.reset()
observations.append(observation) observations.append(observation)
infos.append(info) infos.append(info)
self.observations = concatenate( self.observations = concatenate(observations, self.observations, self.single_observation_space)
observations, self.observations, self.single_observation_space
)
return ( return (
deepcopy(self.observations) if self.copy else self.observations, deepcopy(self.observations) if self.copy else self.observations,

View File

@@ -10,9 +10,7 @@ from gym.vector.tests.utils import spaces
from gym.vector.utils.numpy_utils import concatenate, create_empty_array from gym.vector.utils.numpy_utils import concatenate, create_empty_array
@pytest.mark.parametrize( @pytest.mark.parametrize("space", spaces, ids=[space.__class__.__name__ for space in spaces])
"space", spaces, ids=[space.__class__.__name__ for space in spaces]
)
def test_concatenate(space): def test_concatenate(space):
def assert_type(lhs, rhs, n): def assert_type(lhs, rhs, n):
# Special case: if rhs is a list of scalars, lhs must be an np.ndarray # Special case: if rhs is a list of scalars, lhs must be an np.ndarray
@@ -53,9 +51,7 @@ def test_concatenate(space):
@pytest.mark.parametrize("n", [1, 8]) @pytest.mark.parametrize("n", [1, 8])
@pytest.mark.parametrize( @pytest.mark.parametrize("space", spaces, ids=[space.__class__.__name__ for space in spaces])
"space", spaces, ids=[space.__class__.__name__ for space in spaces]
)
def test_create_empty_array(space, n): def test_create_empty_array(space, n):
def assert_nested_type(arr, space, n): def assert_nested_type(arr, space, n):
if isinstance(space, _BaseGymSpaces): if isinstance(space, _BaseGymSpaces):
@@ -83,9 +79,7 @@ def test_create_empty_array(space, n):
@pytest.mark.parametrize("n", [1, 8]) @pytest.mark.parametrize("n", [1, 8])
@pytest.mark.parametrize( @pytest.mark.parametrize("space", spaces, ids=[space.__class__.__name__ for space in spaces])
"space", spaces, ids=[space.__class__.__name__ for space in spaces]
)
def test_create_empty_array_zeros(space, n): def test_create_empty_array_zeros(space, n):
def assert_nested_type(arr, space, n): def assert_nested_type(arr, space, n):
if isinstance(space, _BaseGymSpaces): if isinstance(space, _BaseGymSpaces):
@@ -113,9 +107,7 @@ def test_create_empty_array_zeros(space, n):
assert_nested_type(array, space, n=n) assert_nested_type(array, space, n=n)
@pytest.mark.parametrize( @pytest.mark.parametrize("space", spaces, ids=[space.__class__.__name__ for space in spaces])
"space", spaces, ids=[space.__class__.__name__ for space in spaces]
)
def test_create_empty_array_none_shape_ones(space): def test_create_empty_array_none_shape_ones(space):
def assert_nested_type(arr, space): def assert_nested_type(arr, space):
if isinstance(space, _BaseGymSpaces): if isinstance(space, _BaseGymSpaces):

View File

@@ -46,9 +46,7 @@ expected_types = [
list(zip(spaces, expected_types)), list(zip(spaces, expected_types)),
ids=[space.__class__.__name__ for space in spaces], ids=[space.__class__.__name__ for space in spaces],
) )
@pytest.mark.parametrize( @pytest.mark.parametrize("ctx", [None, "fork", "spawn"], ids=["default", "fork", "spawn"])
"ctx", [None, "fork", "spawn"], ids=["default", "fork", "spawn"]
)
def test_create_shared_memory(space, expected_type, n, ctx): def test_create_shared_memory(space, expected_type, n, ctx):
def assert_nested_type(lhs, rhs, n): def assert_nested_type(lhs, rhs, n):
assert type(lhs) == type(rhs) assert type(lhs) == type(rhs)
@@ -77,9 +75,7 @@ def test_create_shared_memory(space, expected_type, n, ctx):
@pytest.mark.parametrize("n", [1, 8]) @pytest.mark.parametrize("n", [1, 8])
@pytest.mark.parametrize( @pytest.mark.parametrize("ctx", [None, "fork", "spawn"], ids=["default", "fork", "spawn"])
"ctx", [None, "fork", "spawn"], ids=["default", "fork", "spawn"]
)
@pytest.mark.parametrize("space", custom_spaces) @pytest.mark.parametrize("space", custom_spaces)
def test_create_shared_memory_custom_space(n, ctx, space): def test_create_shared_memory_custom_space(n, ctx, space):
ctx = mp if (ctx is None) else mp.get_context(ctx) ctx = mp if (ctx is None) else mp.get_context(ctx)
@@ -87,9 +83,7 @@ def test_create_shared_memory_custom_space(n, ctx, space):
shared_memory = create_shared_memory(space, n=n, ctx=ctx) shared_memory = create_shared_memory(space, n=n, ctx=ctx)
@pytest.mark.parametrize( @pytest.mark.parametrize("space", spaces, ids=[space.__class__.__name__ for space in spaces])
"space", spaces, ids=[space.__class__.__name__ for space in spaces]
)
def test_write_to_shared_memory(space): def test_write_to_shared_memory(space):
def assert_nested_equal(lhs, rhs): def assert_nested_equal(lhs, rhs):
assert isinstance(rhs, list) assert isinstance(rhs, list)
@@ -113,9 +107,7 @@ def test_write_to_shared_memory(space):
shared_memory_n8 = create_shared_memory(space, n=8) shared_memory_n8 = create_shared_memory(space, n=8)
samples = [space.sample() for _ in range(8)] samples = [space.sample() for _ in range(8)]
processes = [ processes = [Process(target=write, args=(i, shared_memory_n8, samples[i])) for i in range(8)]
Process(target=write, args=(i, shared_memory_n8, samples[i])) for i in range(8)
]
for process in processes: for process in processes:
process.start() process.start()
@@ -125,25 +117,19 @@ def test_write_to_shared_memory(space):
assert_nested_equal(shared_memory_n8, samples) assert_nested_equal(shared_memory_n8, samples)
@pytest.mark.parametrize( @pytest.mark.parametrize("space", spaces, ids=[space.__class__.__name__ for space in spaces])
"space", spaces, ids=[space.__class__.__name__ for space in spaces]
)
def test_read_from_shared_memory(space): def test_read_from_shared_memory(space):
def assert_nested_equal(lhs, rhs, space, n): def assert_nested_equal(lhs, rhs, space, n):
assert isinstance(rhs, list) assert isinstance(rhs, list)
if isinstance(space, Tuple): if isinstance(space, Tuple):
assert isinstance(lhs, tuple) assert isinstance(lhs, tuple)
for i in range(len(lhs)): for i in range(len(lhs)):
assert_nested_equal( assert_nested_equal(lhs[i], [rhs_[i] for rhs_ in rhs], space.spaces[i], n)
lhs[i], [rhs_[i] for rhs_ in rhs], space.spaces[i], n
)
elif isinstance(space, Dict): elif isinstance(space, Dict):
assert isinstance(lhs, OrderedDict) assert isinstance(lhs, OrderedDict)
for key in lhs.keys(): for key in lhs.keys():
assert_nested_equal( assert_nested_equal(lhs[key], [rhs_[key] for rhs_ in rhs], space.spaces[key], n)
lhs[key], [rhs_[key] for rhs_ in rhs], space.spaces[key], n
)
elif isinstance(space, _BaseGymSpaces): elif isinstance(space, _BaseGymSpaces):
assert isinstance(lhs, np.ndarray) assert isinstance(lhs, np.ndarray)
@@ -161,9 +147,7 @@ def test_read_from_shared_memory(space):
memory_view_n8 = read_from_shared_memory(shared_memory_n8, space, n=8) memory_view_n8 = read_from_shared_memory(shared_memory_n8, space, n=8)
samples = [space.sample() for _ in range(8)] samples = [space.sample() for _ in range(8)]
processes = [ processes = [Process(target=write, args=(i, shared_memory_n8, samples[i])) for i in range(8)]
Process(target=write, args=(i, shared_memory_n8, samples[i])) for i in range(8)
]
for process in processes: for process in processes:
process.start() process.start()

View File

@@ -10,12 +10,8 @@ expected_batch_spaces_4 = [
Box(low=-1.0, high=1.0, shape=(4,), dtype=np.float64), Box(low=-1.0, high=1.0, shape=(4,), dtype=np.float64),
Box(low=0.0, high=10.0, shape=(4, 1), dtype=np.float32), Box(low=0.0, high=10.0, shape=(4, 1), dtype=np.float32),
Box( Box(
low=np.array( low=np.array([[-1.0, 0.0, 0.0], [-1.0, 0.0, 0.0], [-1.0, 0.0, 0.0], [-1.0, 0.0, 0.0]]),
[[-1.0, 0.0, 0.0], [-1.0, 0.0, 0.0], [-1.0, 0.0, 0.0], [-1.0, 0.0, 0.0]] high=np.array([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]),
),
high=np.array(
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
),
dtype=np.float32, dtype=np.float32,
), ),
Box( Box(

View File

@@ -7,12 +7,8 @@ from gym.spaces import Box, Discrete, MultiDiscrete, MultiBinary, Tuple, Dict
spaces = [ spaces = [
Box(low=np.array(-1.0), high=np.array(1.0), dtype=np.float64), Box(low=np.array(-1.0), high=np.array(1.0), dtype=np.float64),
Box(low=np.array([0.0]), high=np.array([10.0]), dtype=np.float32), Box(low=np.array([0.0]), high=np.array([10.0]), dtype=np.float32),
Box( Box(low=np.array([-1.0, 0.0, 0.0]), high=np.array([1.0, 1.0, 1.0]), dtype=np.float32),
low=np.array([-1.0, 0.0, 0.0]), high=np.array([1.0, 1.0, 1.0]), dtype=np.float32 Box(low=np.array([[-1.0, 0.0], [0.0, -1.0]]), high=np.ones((2, 2)), dtype=np.float32),
),
Box(
low=np.array([[-1.0, 0.0], [0.0, -1.0]]), high=np.ones((2, 2)), dtype=np.float32
),
Box(low=0, high=255, shape=(), dtype=np.uint8), Box(low=0, high=255, shape=(), dtype=np.uint8),
Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8), Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8),
Discrete(2), Discrete(2),
@@ -28,17 +24,13 @@ spaces = [
Dict( Dict(
{ {
"position": Discrete(23), "position": Discrete(23),
"velocity": Box( "velocity": Box(low=np.array([0.0]), high=np.array([1.0]), dtype=np.float32),
low=np.array([0.0]), high=np.array([1.0]), dtype=np.float32
),
} }
), ),
Dict( Dict(
{ {
"position": Dict({"x": Discrete(29), "y": Discrete(31)}), "position": Dict({"x": Discrete(29), "y": Discrete(31)}),
"velocity": Tuple( "velocity": Tuple((Discrete(37), Box(low=0, high=255, shape=(), dtype=np.uint8))),
(Discrete(37), Box(low=0, high=255, shape=(), dtype=np.uint8))
),
} }
), ),
] ]
@@ -50,9 +42,7 @@ class UnittestSlowEnv(gym.Env):
def __init__(self, slow_reset=0.3): def __init__(self, slow_reset=0.3):
super(UnittestSlowEnv, self).__init__() super(UnittestSlowEnv, self).__init__()
self.slow_reset = slow_reset self.slow_reset = slow_reset
self.observation_space = Box( self.observation_space = Box(low=0, high=255, shape=(HEIGHT, WIDTH, 3), dtype=np.uint8)
low=0, high=255, shape=(HEIGHT, WIDTH, 3), dtype=np.uint8
)
self.action_space = Box(low=0.0, high=1.0, shape=(), dtype=np.float32) self.action_space = Box(low=0.0, high=1.0, shape=(), dtype=np.float32)
def reset(self): def reset(self):

View File

@@ -46,10 +46,7 @@ def concatenate(items, out, space):
elif isinstance(space, Space): elif isinstance(space, Space):
return concatenate_custom(items, out, space) return concatenate_custom(items, out, space)
else: else:
raise ValueError( raise ValueError("Space of type `{0}` is not a valid `gym.Space` " "instance.".format(type(space)))
"Space of type `{0}` is not a valid `gym.Space` "
"instance.".format(type(space))
)
def concatenate_base(items, out, space): def concatenate_base(items, out, space):
@@ -57,18 +54,12 @@ def concatenate_base(items, out, space):
def concatenate_tuple(items, out, space): def concatenate_tuple(items, out, space):
return tuple( return tuple(concatenate([item[i] for item in items], out[i], subspace) for (i, subspace) in enumerate(space.spaces))
concatenate([item[i] for item in items], out[i], subspace)
for (i, subspace) in enumerate(space.spaces)
)
def concatenate_dict(items, out, space): def concatenate_dict(items, out, space):
return OrderedDict( return OrderedDict(
[ [(key, concatenate([item[key] for item in items], out[key], subspace)) for (key, subspace) in space.spaces.items()]
(key, concatenate([item[key] for item in items], out[key], subspace))
for (key, subspace) in space.spaces.items()
]
) )
@@ -118,10 +109,7 @@ def create_empty_array(space, n=1, fn=np.zeros):
elif isinstance(space, Space): elif isinstance(space, Space):
return create_empty_array_custom(space, n=n, fn=fn) return create_empty_array_custom(space, n=n, fn=fn)
else: else:
raise ValueError( raise ValueError("Space of type `{0}` is not a valid `gym.Space` " "instance.".format(type(space)))
"Space of type `{0}` is not a valid `gym.Space` "
"instance.".format(type(space))
)
def create_empty_array_base(space, n=1, fn=np.zeros): def create_empty_array_base(space, n=1, fn=np.zeros):
@@ -134,12 +122,7 @@ def create_empty_array_tuple(space, n=1, fn=np.zeros):
def create_empty_array_dict(space, n=1, fn=np.zeros): def create_empty_array_dict(space, n=1, fn=np.zeros):
return OrderedDict( return OrderedDict([(key, create_empty_array(subspace, n=n, fn=fn)) for (key, subspace) in space.spaces.items()])
[
(key, create_empty_array(subspace, n=n, fn=fn))
for (key, subspace) in space.spaces.items()
]
)
def create_empty_array_custom(space, n=1, fn=np.zeros): def create_empty_array_custom(space, n=1, fn=np.zeros):

View File

@@ -56,18 +56,11 @@ def create_base_shared_memory(space, n=1, ctx=mp):
def create_tuple_shared_memory(space, n=1, ctx=mp): def create_tuple_shared_memory(space, n=1, ctx=mp):
return tuple( return tuple(create_shared_memory(subspace, n=n, ctx=ctx) for subspace in space.spaces)
create_shared_memory(subspace, n=n, ctx=ctx) for subspace in space.spaces
)
def create_dict_shared_memory(space, n=1, ctx=mp): def create_dict_shared_memory(space, n=1, ctx=mp):
return OrderedDict( return OrderedDict([(key, create_shared_memory(subspace, n=n, ctx=ctx)) for (key, subspace) in space.spaces.items()])
[
(key, create_shared_memory(subspace, n=n, ctx=ctx))
for (key, subspace) in space.spaces.items()
]
)
def read_from_shared_memory(shared_memory, space, n=1): def read_from_shared_memory(shared_memory, space, n=1):
@@ -114,24 +107,16 @@ def read_from_shared_memory(shared_memory, space, n=1):
def read_base_from_shared_memory(shared_memory, space, n=1): def read_base_from_shared_memory(shared_memory, space, n=1):
return np.frombuffer(shared_memory.get_obj(), dtype=space.dtype).reshape( return np.frombuffer(shared_memory.get_obj(), dtype=space.dtype).reshape((n,) + space.shape)
(n,) + space.shape
)
def read_tuple_from_shared_memory(shared_memory, space, n=1): def read_tuple_from_shared_memory(shared_memory, space, n=1):
return tuple( return tuple(read_from_shared_memory(memory, subspace, n=n) for (memory, subspace) in zip(shared_memory, space.spaces))
read_from_shared_memory(memory, subspace, n=n)
for (memory, subspace) in zip(shared_memory, space.spaces)
)
def read_dict_from_shared_memory(shared_memory, space, n=1): def read_dict_from_shared_memory(shared_memory, space, n=1):
return OrderedDict( return OrderedDict(
[ [(key, read_from_shared_memory(shared_memory[key], subspace, n=n)) for (key, subspace) in space.spaces.items()]
(key, read_from_shared_memory(shared_memory[key], subspace, n=n))
for (key, subspace) in space.spaces.items()
]
) )

View File

@@ -44,8 +44,7 @@ def batch_space(space, n=1):
return batch_space_custom(space, n=n) return batch_space_custom(space, n=n)
else: else:
raise ValueError( raise ValueError(
"Cannot batch space with type `{0}`. The space must " "Cannot batch space with type `{0}`. The space must " "be a valid `gym.Space` instance.".format(type(space))
"be a valid `gym.Space` instance.".format(type(space))
) )
@@ -75,14 +74,7 @@ def batch_space_tuple(space, n=1):
def batch_space_dict(space, n=1): def batch_space_dict(space, n=1):
return Dict( return Dict(OrderedDict([(key, batch_space(subspace, n=n)) for (key, subspace) in space.spaces.items()]))
OrderedDict(
[
(key, batch_space(subspace, n=n))
for (key, subspace) in space.spaces.items()
]
)
)
def batch_space_custom(space, n=1): def batch_space_custom(space, n=1):

View File

@@ -141,9 +141,7 @@ class VectorEnv(gym.Env):
if self.spec is None: if self.spec is None:
return "{}({})".format(self.__class__.__name__, self.num_envs) return "{}({})".format(self.__class__.__name__, self.num_envs)
else: else:
return "{}({}, {})".format( return "{}({}, {})".format(self.__class__.__name__, self.spec.id, self.num_envs)
self.__class__.__name__, self.spec.id, self.num_envs
)
class VectorEnvWrapper(VectorEnv): class VectorEnvWrapper(VectorEnv):
@@ -189,9 +187,7 @@ class VectorEnvWrapper(VectorEnv):
# implicitly forward all other methods and attributes to self.env # implicitly forward all other methods and attributes to self.env
def __getattr__(self, name): def __getattr__(self, name):
if name.startswith("_"): if name.startswith("_"):
raise AttributeError( raise AttributeError("attempted to get missing private attribute '{}'".format(name))
"attempted to get missing private attribute '{}'".format(name)
)
return getattr(self.env, name) return getattr(self.env, name)
@property @property

View File

@@ -62,12 +62,13 @@ class AtariPreprocessing(gym.Wrapper):
assert noop_max >= 0 assert noop_max >= 0
if frame_skip > 1: if frame_skip > 1:
assert "NoFrameskip" in env.spec.id, ( assert "NoFrameskip" in env.spec.id, (
"disable frame-skipping in the original env. for more than one" "disable frame-skipping in the original env. for more than one" " frame-skip as it will be done by the wrapper"
" frame-skip as it will be done by the wrapper"
) )
self.noop_max = noop_max self.noop_max = noop_max
assert env.unwrapped.get_action_meanings()[0] == "NOOP" assert env.unwrapped.get_action_meanings()[0] == "NOOP"
warnings.warn("Gym\'s internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit") warnings.warn(
"Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit"
)
self.frame_skip = frame_skip self.frame_skip = frame_skip
self.screen_size = screen_size self.screen_size = screen_size
@@ -92,15 +93,11 @@ class AtariPreprocessing(gym.Wrapper):
self.lives = 0 self.lives = 0
self.game_over = False self.game_over = False
_low, _high, _obs_dtype = ( _low, _high, _obs_dtype = (0, 255, np.uint8) if not scale_obs else (0, 1, np.float32)
(0, 255, np.uint8) if not scale_obs else (0, 1, np.float32)
)
_shape = (screen_size, screen_size, 1 if grayscale_obs else 3) _shape = (screen_size, screen_size, 1 if grayscale_obs else 3)
if grayscale_obs and not grayscale_newaxis: if grayscale_obs and not grayscale_newaxis:
_shape = _shape[:-1] # Remove channel axis _shape = _shape[:-1] # Remove channel axis
self.observation_space = Box( self.observation_space = Box(low=_low, high=_high, shape=_shape, dtype=_obs_dtype)
low=_low, high=_high, shape=_shape, dtype=_obs_dtype
)
def step(self, action): def step(self, action):
R = 0.0 R = 0.0
@@ -132,11 +129,7 @@ class AtariPreprocessing(gym.Wrapper):
def reset(self, **kwargs): def reset(self, **kwargs):
# NoopReset # NoopReset
self.env.reset(**kwargs) self.env.reset(**kwargs)
noops = ( noops = self.env.unwrapped.np_random.randint(1, self.noop_max + 1) if self.noop_max > 0 else 0
self.env.unwrapped.np_random.randint(1, self.noop_max + 1)
if self.noop_max > 0
else 0
)
for _ in range(noops): for _ in range(noops):
_, _, done, _ = self.env.step(0) _, _, done, _ = self.env.step(0)
if done: if done:

View File

@@ -9,7 +9,9 @@ class ClipAction(ActionWrapper):
def __init__(self, env): def __init__(self, env):
assert isinstance(env.action_space, Box) assert isinstance(env.action_space, Box)
warnings.warn("Gym\'s internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit") warnings.warn(
"Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit"
)
super(ClipAction, self).__init__(env) super(ClipAction, self).__init__(env)
def action(self, action): def action(self, action):

View File

@@ -26,7 +26,9 @@ class FilterObservation(ObservationWrapper):
assert isinstance( assert isinstance(
wrapped_observation_space, spaces.Dict wrapped_observation_space, spaces.Dict
), "FilterObservationWrapper is only usable with dict observations." ), "FilterObservationWrapper is only usable with dict observations."
warnings.warn("Gym\'s internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit") warnings.warn(
"Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit"
)
observation_keys = wrapped_observation_space.spaces.keys() observation_keys = wrapped_observation_space.spaces.keys()
@@ -49,11 +51,7 @@ class FilterObservation(ObservationWrapper):
) )
self.observation_space = type(wrapped_observation_space)( self.observation_space = type(wrapped_observation_space)(
[ [(name, copy.deepcopy(space)) for name, space in wrapped_observation_space.spaces.items() if name in filter_keys]
(name, copy.deepcopy(space))
for name, space in wrapped_observation_space.spaces.items()
if name in filter_keys
]
) )
self._env = env self._env = env
@@ -64,11 +62,5 @@ class FilterObservation(ObservationWrapper):
return filter_observation return filter_observation
def _filter_observation(self, observation): def _filter_observation(self, observation):
observation = type(observation)( observation = type(observation)([(name, value) for name, value in observation.items() if name in self._filter_keys])
[
(name, value)
for name, value in observation.items()
if name in self._filter_keys
]
)
return observation return observation

View File

@@ -9,7 +9,9 @@ class FlattenObservation(ObservationWrapper):
def __init__(self, env): def __init__(self, env):
super(FlattenObservation, self).__init__(env) super(FlattenObservation, self).__init__(env)
self.observation_space = spaces.flatten_space(env.observation_space) self.observation_space = spaces.flatten_space(env.observation_space)
warnings.warn("Gym\'s internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit") warnings.warn(
"Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit"
)
def observation(self, observation): def observation(self, observation):
return spaces.flatten(self.env.observation_space, observation) return spaces.flatten(self.env.observation_space, observation)

View File

@@ -22,7 +22,9 @@ class LazyFrames(object):
__slots__ = ("frame_shape", "dtype", "shape", "lz4_compress", "_frames") __slots__ = ("frame_shape", "dtype", "shape", "lz4_compress", "_frames")
def __init__(self, frames, lz4_compress=False): def __init__(self, frames, lz4_compress=False):
warnings.warn("Gym\'s internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit") warnings.warn(
"Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit"
)
self.frame_shape = tuple(frames[0].shape) self.frame_shape = tuple(frames[0].shape)
self.shape = (len(frames),) + self.frame_shape self.shape = (len(frames),) + self.frame_shape
self.dtype = frames[0].dtype self.dtype = frames[0].dtype
@@ -45,9 +47,7 @@ class LazyFrames(object):
def __getitem__(self, int_or_slice): def __getitem__(self, int_or_slice):
if isinstance(int_or_slice, int): if isinstance(int_or_slice, int):
return self._check_decompress(self._frames[int_or_slice]) # single frame return self._check_decompress(self._frames[int_or_slice]) # single frame
return np.stack( return np.stack([self._check_decompress(f) for f in self._frames[int_or_slice]], axis=0)
[self._check_decompress(f) for f in self._frames[int_or_slice]], axis=0
)
def __eq__(self, other): def __eq__(self, other):
return self.__array__() == other return self.__array__() == other
@@ -56,9 +56,7 @@ class LazyFrames(object):
if self.lz4_compress: if self.lz4_compress:
from lz4.block import decompress from lz4.block import decompress
return np.frombuffer(decompress(frame), dtype=self.dtype).reshape( return np.frombuffer(decompress(frame), dtype=self.dtype).reshape(self.frame_shape)
self.frame_shape
)
return frame return frame
@@ -102,12 +100,8 @@ class FrameStack(Wrapper):
self.frames = deque(maxlen=num_stack) self.frames = deque(maxlen=num_stack)
low = np.repeat(self.observation_space.low[np.newaxis, ...], num_stack, axis=0) low = np.repeat(self.observation_space.low[np.newaxis, ...], num_stack, axis=0)
high = np.repeat( high = np.repeat(self.observation_space.high[np.newaxis, ...], num_stack, axis=0)
self.observation_space.high[np.newaxis, ...], num_stack, axis=0 self.observation_space = Box(low=low, high=high, dtype=self.observation_space.dtype)
)
self.observation_space = Box(
low=low, high=high, dtype=self.observation_space.dtype
)
def _get_observation(self): def _get_observation(self):
assert len(self.frames) == self.num_stack, (len(self.frames), self.num_stack) assert len(self.frames) == self.num_stack, (len(self.frames), self.num_stack)

View File

@@ -11,20 +11,15 @@ class GrayScaleObservation(ObservationWrapper):
super(GrayScaleObservation, self).__init__(env) super(GrayScaleObservation, self).__init__(env)
self.keep_dim = keep_dim self.keep_dim = keep_dim
assert ( assert len(env.observation_space.shape) == 3 and env.observation_space.shape[-1] == 3
len(env.observation_space.shape) == 3 warnings.warn(
and env.observation_space.shape[-1] == 3 "Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit"
) )
warnings.warn("Gym\'s internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit")
obs_shape = self.observation_space.shape[:2] obs_shape = self.observation_space.shape[:2]
if self.keep_dim: if self.keep_dim:
self.observation_space = Box( self.observation_space = Box(low=0, high=255, shape=(obs_shape[0], obs_shape[1], 1), dtype=np.uint8)
low=0, high=255, shape=(obs_shape[0], obs_shape[1], 1), dtype=np.uint8
)
else: else:
self.observation_space = Box( self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype=np.uint8)
low=0, high=255, shape=obs_shape, dtype=np.uint8
)
def observation(self, observation): def observation(self, observation):
import cv2 import cv2

View File

@@ -37,9 +37,7 @@ class Monitor(Wrapper):
self._monitor_id = None self._monitor_id = None
self.env_semantics_autoreset = env.metadata.get("semantics.autoreset") self.env_semantics_autoreset = env.metadata.get("semantics.autoreset")
self._start( self._start(directory, video_callable, force, resume, write_upon_reset, uid, mode)
directory, video_callable, force, resume, write_upon_reset, uid, mode
)
def step(self, action): def step(self, action):
self._before_step(action) self._before_step(action)
@@ -163,10 +161,7 @@ class Monitor(Wrapper):
json.dump( json.dump(
{ {
"stats": os.path.basename(self.stats_recorder.path), "stats": os.path.basename(self.stats_recorder.path),
"videos": [ "videos": [(os.path.basename(v), os.path.basename(m)) for v, m in self.videos],
(os.path.basename(v), os.path.basename(m))
for v, m in self.videos
],
"env_info": self._env_info(), "env_info": self._env_info(),
}, },
f, f,
@@ -199,9 +194,7 @@ class Monitor(Wrapper):
elif mode == "training": elif mode == "training":
type = "t" type = "t"
else: else:
raise error.Error( raise error.Error('Invalid mode {}: must be "training" or "evaluation"', mode)
'Invalid mode {}: must be "training" or "evaluation"', mode
)
self.stats_recorder.type = type self.stats_recorder.type = type
def _before_step(self, action): def _before_step(self, action):
@@ -257,9 +250,7 @@ class Monitor(Wrapper):
env=self.env, env=self.env,
base_path=os.path.join( base_path=os.path.join(
self.directory, self.directory,
"{}.video.{}.video{:06}".format( "{}.video.{}.video{:06}".format(self.file_prefix, self.file_infix, self.episode_id),
self.file_prefix, self.file_infix, self.episode_id
),
), ),
metadata={"episode_id": self.episode_id}, metadata={"episode_id": self.episode_id},
enabled=self._video_enabled(), enabled=self._video_enabled(),
@@ -269,9 +260,7 @@ class Monitor(Wrapper):
def _close_video_recorder(self): def _close_video_recorder(self):
self.video_recorder.close() self.video_recorder.close()
if self.video_recorder.functional: if self.video_recorder.functional:
self.videos.append( self.videos.append((self.video_recorder.path, self.video_recorder.metadata_path))
(self.video_recorder.path, self.video_recorder.metadata_path)
)
def _video_enabled(self): def _video_enabled(self):
return self.video_callable(self.episode_id) return self.video_callable(self.episode_id)
@@ -301,19 +290,11 @@ class Monitor(Wrapper):
def detect_training_manifests(training_dir, files=None): def detect_training_manifests(training_dir, files=None):
if files is None: if files is None:
files = os.listdir(training_dir) files = os.listdir(training_dir)
return [ return [os.path.join(training_dir, f) for f in files if f.startswith(MANIFEST_PREFIX + ".")]
os.path.join(training_dir, f)
for f in files
if f.startswith(MANIFEST_PREFIX + ".")
]
def detect_monitor_files(training_dir): def detect_monitor_files(training_dir):
return [ return [os.path.join(training_dir, f) for f in os.listdir(training_dir) if f.startswith(FILE_PREFIX + ".")]
os.path.join(training_dir, f)
for f in os.listdir(training_dir)
if f.startswith(FILE_PREFIX + ".")
]
def clear_monitor_files(training_dir): def clear_monitor_files(training_dir):
@@ -382,10 +363,7 @@ def load_results(training_dir):
contents = json.load(f) contents = json.load(f)
# Make these paths absolute again # Make these paths absolute again
stats_files.append(os.path.join(training_dir, contents["stats"])) stats_files.append(os.path.join(training_dir, contents["stats"]))
videos += [ videos += [(os.path.join(training_dir, v), os.path.join(training_dir, m)) for v, m in contents["videos"]]
(os.path.join(training_dir, v), os.path.join(training_dir, m))
for v, m in contents["videos"]
]
env_infos.append(contents["env_info"]) env_infos.append(contents["env_info"])
env_info = collapse_env_infos(env_infos, training_dir) env_info = collapse_env_infos(env_infos, training_dir)

View File

@@ -48,9 +48,7 @@ class VideoRecorder(object):
self.ansi_mode = True self.ansi_mode = True
else: else:
logger.info( logger.info(
'Disabling video recorder because {} neither supports video mode "rgb_array" nor "ansi".'.format( 'Disabling video recorder because {} neither supports video mode "rgb_array" nor "ansi".'.format(env)
env
)
) )
# Whoops, turns out we shouldn't be enabled after all # Whoops, turns out we shouldn't be enabled after all
self.enabled = False self.enabled = False
@@ -69,9 +67,7 @@ class VideoRecorder(object):
path = base_path + required_ext path = base_path + required_ext
else: else:
# Otherwise, just generate a unique filename # Otherwise, just generate a unique filename
with tempfile.NamedTemporaryFile( with tempfile.NamedTemporaryFile(suffix=required_ext, delete=False) as f:
suffix=required_ext, delete=False
) as f:
path = f.name path = f.name
self.path = path self.path = path
@@ -83,28 +79,20 @@ class VideoRecorder(object):
if self.ansi_mode if self.ansi_mode
else "" else ""
) )
raise error.Error( raise error.Error("Invalid path given: {} -- must have file extension {}.{}".format(self.path, required_ext, hint))
"Invalid path given: {} -- must have file extension {}.{}".format(
self.path, required_ext, hint
)
)
# Touch the file in any case, so we know it's present. (This # Touch the file in any case, so we know it's present. (This
# corrects for platform platform differences. Using ffmpeg on # corrects for platform platform differences. Using ffmpeg on
# OS X, the file is precreated, but not on Linux. # OS X, the file is precreated, but not on Linux.
touch(path) touch(path)
self.frames_per_sec = env.metadata.get("video.frames_per_second", 30) self.frames_per_sec = env.metadata.get("video.frames_per_second", 30)
self.output_frames_per_sec = env.metadata.get( self.output_frames_per_sec = env.metadata.get("video.output_frames_per_second", self.frames_per_sec)
"video.output_frames_per_second", self.frames_per_sec
)
self.encoder = None # lazily start the process self.encoder = None # lazily start the process
self.broken = False self.broken = False
# Dump metadata # Dump metadata
self.metadata = metadata or {} self.metadata = metadata or {}
self.metadata["content_type"] = ( self.metadata["content_type"] = "video/vnd.openai.ansivid" if self.ansi_mode else "video/mp4"
"video/vnd.openai.ansivid" if self.ansi_mode else "video/mp4"
)
self.metadata_path = "{}.meta.json".format(path_base) self.metadata_path = "{}.meta.json".format(path_base)
self.write_metadata() self.write_metadata()
@@ -191,9 +179,7 @@ class VideoRecorder(object):
def _encode_image_frame(self, frame): def _encode_image_frame(self, frame):
if not self.encoder: if not self.encoder:
self.encoder = ImageEncoder( self.encoder = ImageEncoder(self.path, frame.shape, self.frames_per_sec, self.output_frames_per_sec)
self.path, frame.shape, self.frames_per_sec, self.output_frames_per_sec
)
self.metadata["encoder_version"] = self.encoder.version_info self.metadata["encoder_version"] = self.encoder.version_info
try: try:
@@ -222,24 +208,16 @@ class TextEncoder(object):
string = frame.getvalue() string = frame.getvalue()
else: else:
raise error.InvalidFrame( raise error.InvalidFrame(
"Wrong type {} for {}: text frame must be a string or StringIO".format( "Wrong type {} for {}: text frame must be a string or StringIO".format(type(frame), frame)
type(frame), frame
)
) )
frame_bytes = string.encode("utf-8") frame_bytes = string.encode("utf-8")
if frame_bytes[-1:] != b"\n": if frame_bytes[-1:] != b"\n":
raise error.InvalidFrame( raise error.InvalidFrame('Frame must end with a newline: """{}"""'.format(string))
'Frame must end with a newline: """{}"""'.format(string)
)
if b"\r" in frame_bytes: if b"\r" in frame_bytes:
raise error.InvalidFrame( raise error.InvalidFrame('Frame contains carriage returns (only newlines are allowed: """{}"""'.format(string))
'Frame contains carriage returns (only newlines are allowed: """{}"""'.format(
string
)
)
self.frames.append(frame_bytes) self.frames.append(frame_bytes)
@@ -263,15 +241,7 @@ class TextEncoder(object):
# Calculate frame size from the largest frames. # Calculate frame size from the largest frames.
# Add some padding since we'll get cut off otherwise. # Add some padding since we'll get cut off otherwise.
height = max([frame.count(b"\n") for frame in self.frames]) + 1 height = max([frame.count(b"\n") for frame in self.frames]) + 1
width = ( width = max([max([len(line) for line in frame.split(b"\n")]) for frame in self.frames]) + 2
max(
[
max([len(line) for line in frame.split(b"\n")])
for frame in self.frames
]
)
+ 2
)
data = { data = {
"version": 1, "version": 1,
@@ -325,11 +295,7 @@ class ImageEncoder(object):
def version_info(self): def version_info(self):
return { return {
"backend": self.backend, "backend": self.backend,
"version": str( "version": str(subprocess.check_output([self.backend, "-version"], stderr=subprocess.STDOUT)),
subprocess.check_output(
[self.backend, "-version"], stderr=subprocess.STDOUT
)
),
"cmdline": self.cmdline, "cmdline": self.cmdline,
} }
@@ -396,19 +362,13 @@ class ImageEncoder(object):
logger.debug('Starting %s with "%s"', self.backend, " ".join(self.cmdline)) logger.debug('Starting %s with "%s"', self.backend, " ".join(self.cmdline))
if hasattr(os, "setsid"): # setsid not present on Windows if hasattr(os, "setsid"): # setsid not present on Windows
self.proc = subprocess.Popen( self.proc = subprocess.Popen(self.cmdline, stdin=subprocess.PIPE, preexec_fn=os.setsid)
self.cmdline, stdin=subprocess.PIPE, preexec_fn=os.setsid
)
else: else:
self.proc = subprocess.Popen(self.cmdline, stdin=subprocess.PIPE) self.proc = subprocess.Popen(self.cmdline, stdin=subprocess.PIPE)
def capture_frame(self, frame): def capture_frame(self, frame):
if not isinstance(frame, (np.ndarray, np.generic)): if not isinstance(frame, (np.ndarray, np.generic)):
raise error.InvalidFrame( raise error.InvalidFrame("Wrong type {} for {} (must be np.ndarray or np.generic)".format(type(frame), frame))
"Wrong type {} for {} (must be np.ndarray or np.generic)".format(
type(frame), frame
)
)
if frame.shape != self.frame_shape: if frame.shape != self.frame_shape:
raise error.InvalidFrame( raise error.InvalidFrame(
"Your frame has shape {}, but the VideoRecorder is configured for shape {}.".format( "Your frame has shape {}, but the VideoRecorder is configured for shape {}.".format(
@@ -417,15 +377,11 @@ class ImageEncoder(object):
) )
if frame.dtype != np.uint8: if frame.dtype != np.uint8:
raise error.InvalidFrame( raise error.InvalidFrame(
"Your frame has data type {}, but we require uint8 (i.e. RGB values from 0-255).".format( "Your frame has data type {}, but we require uint8 (i.e. RGB values from 0-255).".format(frame.dtype)
frame.dtype
)
) )
try: try:
if distutils.version.LooseVersion( if distutils.version.LooseVersion(np.__version__) >= distutils.version.LooseVersion("1.9.0"):
np.__version__
) >= distutils.version.LooseVersion("1.9.0"):
self.proc.stdin.write(frame.tobytes()) self.proc.stdin.write(frame.tobytes())
else: else:
self.proc.stdin.write(frame.tostring()) self.proc.stdin.write(frame.tostring())

View File

@@ -14,9 +14,7 @@ STATE_KEY = "state"
class PixelObservationWrapper(ObservationWrapper): class PixelObservationWrapper(ObservationWrapper):
"""Augment observations by pixel values.""" """Augment observations by pixel values."""
def __init__( def __init__(self, env, pixels_only=True, render_kwargs=None, pixel_keys=("pixels",)):
self, env, pixels_only=True, render_kwargs=None, pixel_keys=("pixels",)
):
"""Initializes a new pixel Wrapper. """Initializes a new pixel Wrapper.
Args: Args:
@@ -52,7 +50,9 @@ class PixelObservationWrapper(ObservationWrapper):
assert render_mode == "rgb_array", render_mode assert render_mode == "rgb_array", render_mode
render_kwargs[key]["mode"] = "rgb_array" render_kwargs[key]["mode"] = "rgb_array"
warnings.warn("Gym\'s internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit") warnings.warn(
"Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit"
)
wrapped_observation_space = env.observation_space wrapped_observation_space = env.observation_space
@@ -70,9 +70,7 @@ class PixelObservationWrapper(ObservationWrapper):
# `observation_keys` # `observation_keys`
overlapping_keys = set(pixel_keys) & set(invalid_keys) overlapping_keys = set(pixel_keys) & set(invalid_keys)
if overlapping_keys: if overlapping_keys:
raise ValueError( raise ValueError("Duplicate or reserved pixel keys {!r}.".format(overlapping_keys))
"Duplicate or reserved pixel keys {!r}.".format(overlapping_keys)
)
if pixels_only: if pixels_only:
self.observation_space = spaces.Dict() self.observation_space = spaces.Dict()
@@ -95,9 +93,7 @@ class PixelObservationWrapper(ObservationWrapper):
else: else:
raise TypeError(pixels.dtype) raise TypeError(pixels.dtype)
pixels_space = spaces.Box( pixels_space = spaces.Box(shape=pixels.shape, low=low, high=high, dtype=pixels.dtype)
shape=pixels.shape, low=low, high=high, dtype=pixels.dtype
)
pixels_spaces[pixel_key] = pixels_space pixels_spaces[pixel_key] = pixels_space
self.observation_space.spaces.update(pixels_spaces) self.observation_space.spaces.update(pixels_spaces)
@@ -120,10 +116,7 @@ class PixelObservationWrapper(ObservationWrapper):
observation = collections.OrderedDict() observation = collections.OrderedDict()
observation[STATE_KEY] = wrapped_observation observation[STATE_KEY] = wrapped_observation
pixel_observations = { pixel_observations = {pixel_key: self.env.render(**self._render_kwargs[pixel_key]) for pixel_key in self._pixel_keys}
pixel_key: self.env.render(**self._render_kwargs[pixel_key])
for pixel_key in self._pixel_keys
}
observation.update(pixel_observations) observation.update(pixel_observations)

View File

@@ -7,14 +7,14 @@ import gym
class RecordEpisodeStatistics(gym.Wrapper): class RecordEpisodeStatistics(gym.Wrapper):
def __init__(self, env, deque_size=100): def __init__(self, env, deque_size=100):
super(RecordEpisodeStatistics, self).__init__(env) super(RecordEpisodeStatistics, self).__init__(env)
self.t0 = ( self.t0 = time.time() # TODO: use perf_counter when gym removes Python 2 support
time.time()
) # TODO: use perf_counter when gym removes Python 2 support
self.episode_return = 0.0 self.episode_return = 0.0
self.episode_length = 0 self.episode_length = 0
self.return_queue = deque(maxlen=deque_size) self.return_queue = deque(maxlen=deque_size)
self.length_queue = deque(maxlen=deque_size) self.length_queue = deque(maxlen=deque_size)
warnings.warn("Gym\'s internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit") warnings.warn(
"Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit"
)
def reset(self, **kwargs): def reset(self, **kwargs):
observation = super(RecordEpisodeStatistics, self).reset(**kwargs) observation = super(RecordEpisodeStatistics, self).reset(**kwargs)
@@ -23,9 +23,7 @@ class RecordEpisodeStatistics(gym.Wrapper):
return observation return observation
def step(self, action): def step(self, action):
observation, reward, done, info = super(RecordEpisodeStatistics, self).step( observation, reward, done, info = super(RecordEpisodeStatistics, self).step(action)
action
)
self.episode_return += reward self.episode_return += reward
self.episode_length += 1 self.episode_length += 1
if done: if done:

View File

@@ -15,17 +15,15 @@ class RescaleAction(gym.ActionWrapper):
""" """
def __init__(self, env, a, b): def __init__(self, env, a, b):
assert isinstance( assert isinstance(env.action_space, spaces.Box), "expected Box action space, got {}".format(type(env.action_space))
env.action_space, spaces.Box
), "expected Box action space, got {}".format(type(env.action_space))
assert np.less_equal(a, b).all(), (a, b) assert np.less_equal(a, b).all(), (a, b)
warnings.warn("Gym\'s internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit") warnings.warn(
"Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit"
)
super(RescaleAction, self).__init__(env) super(RescaleAction, self).__init__(env)
self.a = np.zeros(env.action_space.shape, dtype=env.action_space.dtype) + a self.a = np.zeros(env.action_space.shape, dtype=env.action_space.dtype) + a
self.b = np.zeros(env.action_space.shape, dtype=env.action_space.dtype) + b self.b = np.zeros(env.action_space.shape, dtype=env.action_space.dtype) + b
self.action_space = spaces.Box( self.action_space = spaces.Box(low=a, high=b, shape=env.action_space.shape, dtype=env.action_space.dtype)
low=a, high=b, shape=env.action_space.shape, dtype=env.action_space.dtype
)
def action(self, action): def action(self, action):
assert np.all(np.greater_equal(action, self.a)), (action, self.a) assert np.all(np.greater_equal(action, self.a)), (action, self.a)

View File

@@ -12,7 +12,9 @@ class ResizeObservation(ObservationWrapper):
if isinstance(shape, int): if isinstance(shape, int):
shape = (shape, shape) shape = (shape, shape)
assert all(x > 0 for x in shape), shape assert all(x > 0 for x in shape), shape
warnings.warn("Gym\'s internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit") warnings.warn(
"Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit"
)
self.shape = tuple(shape) self.shape = tuple(shape)
obs_shape = self.shape + self.observation_space.shape[2:] obs_shape = self.shape + self.observation_space.shape[2:]
@@ -21,9 +23,7 @@ class ResizeObservation(ObservationWrapper):
def observation(self, observation): def observation(self, observation):
import cv2 import cv2
observation = cv2.resize( observation = cv2.resize(observation, self.shape[::-1], interpolation=cv2.INTER_AREA)
observation, self.shape[::-1], interpolation=cv2.INTER_AREA
)
if observation.ndim == 2: if observation.ndim == 2:
observation = np.expand_dims(observation, -1) observation = np.expand_dims(observation, -1)
return observation return observation

View File

@@ -15,12 +15,8 @@ def test_atari_preprocessing_grayscale(env_fn):
import cv2 import cv2
env1 = env_fn() env1 = env_fn()
env2 = AtariPreprocessing( env2 = AtariPreprocessing(env_fn(), screen_size=84, grayscale_obs=True, frame_skip=1, noop_max=0)
env_fn(), screen_size=84, grayscale_obs=True, frame_skip=1, noop_max=0 env3 = AtariPreprocessing(env_fn(), screen_size=84, grayscale_obs=False, frame_skip=1, noop_max=0)
)
env3 = AtariPreprocessing(
env_fn(), screen_size=84, grayscale_obs=False, frame_skip=1, noop_max=0
)
env4 = AtariPreprocessing( env4 = AtariPreprocessing(
env_fn(), env_fn(),
screen_size=84, screen_size=84,
@@ -79,15 +75,11 @@ def test_atari_preprocessing_scale(env_fn):
obs = env.reset().flatten() obs = env.reset().flatten()
done, step_i = False, 0 done, step_i = False, 0
max_obs = 1 if scaled else 255 max_obs = 1 if scaled else 255
assert (0 <= obs).all() and ( assert (0 <= obs).all() and (obs <= max_obs).all(), "Obs. must be in range [0,{}]".format(max_obs)
obs <= max_obs
).all(), "Obs. must be in range [0,{}]".format(max_obs)
while not done or step_i <= max_test_steps: while not done or step_i <= max_test_steps:
obs, _, done, _ = env.step(env.action_space.sample()) obs, _, done, _ = env.step(env.action_space.sample())
obs = obs.flatten() obs = obs.flatten()
assert (0 <= obs).all() and ( assert (0 <= obs).all() and (obs <= max_obs).all(), "Obs. must be in range [0,{}]".format(max_obs)
obs <= max_obs
).all(), "Obs. must be in range [0,{}]".format(max_obs)
step_i += 1 step_i += 1
env.close() env.close()

View File

@@ -19,9 +19,7 @@ def test_clip_action():
actions = [[0.4], [1.2], [-0.3], [0.0], [-2.5]] actions = [[0.4], [1.2], [-0.3], [0.0], [-2.5]]
for action in actions: for action in actions:
obs1, r1, d1, _ = env.step( obs1, r1, d1, _ = env.step(np.clip(action, env.action_space.low, env.action_space.high))
np.clip(action, env.action_space.low, env.action_space.high)
)
obs2, r2, d2, _ = wrapped_env.step(action) obs2, r2, d2, _ = wrapped_env.step(action)
assert np.allclose(r1, r2) assert np.allclose(r1, r2)
assert np.allclose(obs1, obs2) assert np.allclose(obs1, obs2)

View File

@@ -9,10 +9,7 @@ from gym.wrappers.filter_observation import FilterObservation
class FakeEnvironment(gym.Env): class FakeEnvironment(gym.Env):
def __init__(self, observation_keys=("state")): def __init__(self, observation_keys=("state")):
self.observation_space = spaces.Dict( self.observation_space = spaces.Dict(
{ {name: spaces.Box(shape=(2,), low=-1, high=1, dtype=np.float32) for name in observation_keys}
name: spaces.Box(shape=(2,), low=-1, high=1, dtype=np.float32)
for name in observation_keys
}
) )
self.action_space = spaces.Box(shape=(1,), low=-1, high=1, dtype=np.float32) self.action_space = spaces.Box(shape=(1,), low=-1, high=1, dtype=np.float32)
@@ -48,9 +45,7 @@ ERROR_TEST_CASES = (
class TestFilterObservation(object): class TestFilterObservation(object):
@pytest.mark.parametrize( @pytest.mark.parametrize("observation_keys,filter_keys", FILTER_OBSERVATION_TEST_CASES)
"observation_keys,filter_keys", FILTER_OBSERVATION_TEST_CASES
)
def test_filter_observation(self, observation_keys, filter_keys): def test_filter_observation(self, observation_keys, filter_keys):
env = FakeEnvironment(observation_keys=observation_keys) env = FakeEnvironment(observation_keys=observation_keys)
@@ -73,9 +68,7 @@ class TestFilterObservation(object):
assert len(observation) == len(filter_keys) assert len(observation) == len(filter_keys)
@pytest.mark.parametrize("filter_keys,error_type,error_match", ERROR_TEST_CASES) @pytest.mark.parametrize("filter_keys,error_type,error_match", ERROR_TEST_CASES)
def test_raises_with_incorrect_arguments( def test_raises_with_incorrect_arguments(self, filter_keys, error_type, error_match):
self, filter_keys, error_type, error_match
):
env = FakeEnvironment(observation_keys=("key1", "key2")) env = FakeEnvironment(observation_keys=("key1", "key2"))
ValueError ValueError

View File

@@ -16,14 +16,10 @@ def test_flatten_observation(env_id):
wrapped_obs = wrapped_env.reset() wrapped_obs = wrapped_env.reset()
if env_id == "Blackjack-v0": if env_id == "Blackjack-v0":
space = spaces.Tuple( space = spaces.Tuple((spaces.Discrete(32), spaces.Discrete(11), spaces.Discrete(2)))
(spaces.Discrete(32), spaces.Discrete(11), spaces.Discrete(2))
)
wrapped_space = spaces.Box(-np.inf, np.inf, [32 + 11 + 2], dtype=np.float32) wrapped_space = spaces.Box(-np.inf, np.inf, [32 + 11 + 2], dtype=np.float32)
elif env_id == "KellyCoinflip-v0": elif env_id == "KellyCoinflip-v0":
space = spaces.Tuple( space = spaces.Tuple((spaces.Box(0, 250.0, [1], dtype=np.float32), spaces.Discrete(300 + 1)))
(spaces.Box(0, 250.0, [1], dtype=np.float32), spaces.Discrete(300 + 1))
)
wrapped_space = spaces.Box(-np.inf, np.inf, [1 + (300 + 1)], dtype=np.float32) wrapped_space = spaces.Box(-np.inf, np.inf, [1 + (300 + 1)], dtype=np.float32)
assert space.contains(obs) assert space.contains(obs)

View File

@@ -19,9 +19,7 @@ except ImportError:
[ [
pytest.param( pytest.param(
True, True,
marks=pytest.mark.skipif( marks=pytest.mark.skipif(lz4 is None, reason="Need lz4 to run tests with compression"),
lz4 is None, reason="Need lz4 to run tests with compression"
),
), ),
False, False,
], ],

View File

@@ -10,9 +10,7 @@ pytest.importorskip("atari_py")
pytest.importorskip("cv2") pytest.importorskip("cv2")
@pytest.mark.parametrize( @pytest.mark.parametrize("env_id", ["PongNoFrameskip-v0", "SpaceInvadersNoFrameskip-v0"])
"env_id", ["PongNoFrameskip-v0", "SpaceInvadersNoFrameskip-v0"]
)
@pytest.mark.parametrize("keep_dim", [True, False]) @pytest.mark.parametrize("keep_dim", [True, False])
def test_gray_scale_observation(env_id, keep_dim): def test_gray_scale_observation(env_id, keep_dim):
gray_env = AtariPreprocessing(gym.make(env_id), screen_size=84, grayscale_obs=True) gray_env = AtariPreprocessing(gym.make(env_id), screen_size=84, grayscale_obs=True)

View File

@@ -32,9 +32,7 @@ class FakeEnvironment(gym.Env):
class FakeArrayObservationEnvironment(FakeEnvironment): class FakeArrayObservationEnvironment(FakeEnvironment):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.observation_space = spaces.Box( self.observation_space = spaces.Box(shape=(2,), low=-1, high=1, dtype=np.float32)
shape=(2,), low=-1, high=1, dtype=np.float32
)
super(FakeArrayObservationEnvironment, self).__init__(*args, **kwargs) super(FakeArrayObservationEnvironment, self).__init__(*args, **kwargs)
@@ -75,10 +73,7 @@ class TestPixelObservationWrapper(object):
assert len(wrapped_env.observation_space.spaces) == 1 assert len(wrapped_env.observation_space.spaces) == 1
assert list(wrapped_env.observation_space.spaces.keys()) == [pixel_key] assert list(wrapped_env.observation_space.spaces.keys()) == [pixel_key]
else: else:
assert ( assert len(wrapped_env.observation_space.spaces) == len(observation_space.spaces) + 1
len(wrapped_env.observation_space.spaces)
== len(observation_space.spaces) + 1
)
expected_keys = list(observation_space.spaces.keys()) + [pixel_key] expected_keys = list(observation_space.spaces.keys()) + [pixel_key]
assert list(wrapped_env.observation_space.spaces.keys()) == expected_keys assert list(wrapped_env.observation_space.spaces.keys()) == expected_keys
@@ -97,9 +92,7 @@ class TestPixelObservationWrapper(object):
observation_space = env.observation_space observation_space = env.observation_space
assert isinstance(observation_space, spaces.Box) assert isinstance(observation_space, spaces.Box)
wrapped_env = PixelObservationWrapper( wrapped_env = PixelObservationWrapper(env, pixel_keys=(pixel_key,), pixels_only=pixels_only)
env, pixel_keys=(pixel_key,), pixels_only=pixels_only
)
wrapped_env.observation_space = wrapped_env.observation_space wrapped_env.observation_space = wrapped_env.observation_space
assert isinstance(wrapped_env.observation_space, spaces.Dict) assert isinstance(wrapped_env.observation_space, spaces.Dict)

Some files were not shown because too many files have changed in this diff Show More