mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-01 06:07:08 +00:00
redo black
This commit is contained in:
@@ -3,9 +3,7 @@ import argparse
|
||||
import gym
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Renders a Gym environment for quick inspection."
|
||||
)
|
||||
parser = argparse.ArgumentParser(description="Renders a Gym environment for quick inspection.")
|
||||
parser.add_argument(
|
||||
"env_id",
|
||||
type=str,
|
||||
|
@@ -35,12 +35,7 @@ def cem(f, th_mean, batch_size, n_iter, elite_frac, initial_std=1.0):
|
||||
th_std = np.ones_like(th_mean) * initial_std
|
||||
|
||||
for _ in range(n_iter):
|
||||
ths = np.array(
|
||||
[
|
||||
th_mean + dth
|
||||
for dth in th_std[None, :] * np.random.randn(batch_size, th_mean.size)
|
||||
]
|
||||
)
|
||||
ths = np.array([th_mean + dth for dth in th_std[None, :] * np.random.randn(batch_size, th_mean.size)])
|
||||
ys = np.array([f(th) for th in ths])
|
||||
elite_inds = ys.argsort()[::-1][:n_elite]
|
||||
elite_ths = ths[elite_inds]
|
||||
@@ -101,9 +96,7 @@ if __name__ == "__main__":
|
||||
return rew
|
||||
|
||||
# Train the agent, and snapshot each stage
|
||||
for (i, iterdata) in enumerate(
|
||||
cem(noisy_evaluation, np.zeros(env.observation_space.shape[0] + 1), **params)
|
||||
):
|
||||
for (i, iterdata) in enumerate(cem(noisy_evaluation, np.zeros(env.observation_space.shape[0] + 1), **params)):
|
||||
print("Iteration %2i. Episode mean reward: %7.3f" % (i, iterdata["y_mean"]))
|
||||
agent = BinaryActionLinearPolicy(iterdata["theta_mean"])
|
||||
if args.display:
|
||||
|
@@ -17,9 +17,7 @@ class RandomAgent(object):
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description=None)
|
||||
parser.add_argument(
|
||||
"env_id", nargs="?", default="CartPole-v0", help="Select the environment to run"
|
||||
)
|
||||
parser.add_argument("env_id", nargs="?", default="CartPole-v0", help="Select the environment to run")
|
||||
args = parser.parse_args()
|
||||
|
||||
# You can set the level to logger.DEBUG or logger.WARN if you
|
||||
|
14
gym/core.py
14
gym/core.py
@@ -173,16 +173,10 @@ class GoalEnv(Env):
|
||||
def reset(self):
|
||||
# Enforce that each GoalEnv uses a Goal-compatible observation space.
|
||||
if not isinstance(self.observation_space, gym.spaces.Dict):
|
||||
raise error.Error(
|
||||
"GoalEnv requires an observation space of type gym.spaces.Dict"
|
||||
)
|
||||
raise error.Error("GoalEnv requires an observation space of type gym.spaces.Dict")
|
||||
for key in ["observation", "achieved_goal", "desired_goal"]:
|
||||
if key not in self.observation_space.spaces:
|
||||
raise error.Error(
|
||||
'GoalEnv requires the "{}" key to be part of the observation dictionary.'.format(
|
||||
key
|
||||
)
|
||||
)
|
||||
raise error.Error('GoalEnv requires the "{}" key to be part of the observation dictionary.'.format(key))
|
||||
|
||||
def compute_reward(self, achieved_goal, desired_goal, info):
|
||||
"""Compute the step reward. This externalizes the reward function and makes
|
||||
@@ -227,9 +221,7 @@ class Wrapper(Env):
|
||||
|
||||
def __getattr__(self, name):
|
||||
if name.startswith("_"):
|
||||
raise AttributeError(
|
||||
"attempted to get missing private attribute '{}'".format(name)
|
||||
)
|
||||
raise AttributeError("attempted to get missing private attribute '{}'".format(name))
|
||||
return getattr(self.env, name)
|
||||
|
||||
@property
|
||||
|
@@ -422,9 +422,7 @@ for reward_type in ["sparse", "dense"]:
|
||||
register(
|
||||
id="HandManipulateBlockRotateParallel{}-v0".format(suffix),
|
||||
entry_point="gym.envs.robotics:HandBlockEnv",
|
||||
kwargs=_merge(
|
||||
{"target_position": "ignore", "target_rotation": "parallel"}, kwargs
|
||||
),
|
||||
kwargs=_merge({"target_position": "ignore", "target_rotation": "parallel"}, kwargs),
|
||||
max_episode_steps=100,
|
||||
)
|
||||
|
||||
|
@@ -73,9 +73,7 @@ class AlgorithmicEnv(Env):
|
||||
# 1. Move read head left or right (or up/down)
|
||||
# 2. Write or not
|
||||
# 3. Which character to write. (Ignored if should_write=0)
|
||||
self.action_space = Tuple(
|
||||
[Discrete(len(self.MOVEMENTS)), Discrete(2), Discrete(self.base)]
|
||||
)
|
||||
self.action_space = Tuple([Discrete(len(self.MOVEMENTS)), Discrete(2), Discrete(self.base)])
|
||||
# Can see just what is on the input tape (one of n characters, or
|
||||
# nothing)
|
||||
self.observation_space = Discrete(self.base + 1)
|
||||
@@ -147,10 +145,7 @@ class AlgorithmicEnv(Env):
|
||||
move = self.MOVEMENTS[inp_act]
|
||||
outfile.write("Action : Tuple(move over input: %s,\n" % move)
|
||||
out_act = out_act == 1
|
||||
outfile.write(
|
||||
" write to the output tape: %s,\n"
|
||||
% out_act
|
||||
)
|
||||
outfile.write(" write to the output tape: %s,\n" % out_act)
|
||||
outfile.write(" prediction: %s)\n" % pred_str)
|
||||
else:
|
||||
outfile.write("\n" * 5)
|
||||
@@ -276,9 +271,7 @@ class TapeAlgorithmicEnv(AlgorithmicEnv):
|
||||
x_str = "Observation Tape : "
|
||||
for i in range(-2, self.input_width + 2):
|
||||
if i == x:
|
||||
x_str += colorize(
|
||||
self._get_str_obs(np.array([i])), "green", highlight=True
|
||||
)
|
||||
x_str += colorize(self._get_str_obs(np.array([i])), "green", highlight=True)
|
||||
else:
|
||||
x_str += self._get_str_obs(np.array([i]))
|
||||
x_str += "\n"
|
||||
@@ -311,10 +304,7 @@ class GridAlgorithmicEnv(AlgorithmicEnv):
|
||||
self.read_head_position = x, y
|
||||
|
||||
def generate_input_data(self, size):
|
||||
return [
|
||||
[self.np_random.randint(self.base) for _ in range(self.rows)]
|
||||
for __ in range(size)
|
||||
]
|
||||
return [[self.np_random.randint(self.base) for _ in range(self.rows)] for __ in range(size)]
|
||||
|
||||
def _get_obs(self, pos=None):
|
||||
if pos is None:
|
||||
@@ -336,9 +326,7 @@ class GridAlgorithmicEnv(AlgorithmicEnv):
|
||||
x_str += " " * len(label)
|
||||
for i in range(-2, self.input_width + 2):
|
||||
if i == x[0] and j == x[1]:
|
||||
x_str += colorize(
|
||||
self._get_str_obs((i, j)), "green", highlight=True
|
||||
)
|
||||
x_str += colorize(self._get_str_obs((i, j)), "green", highlight=True)
|
||||
else:
|
||||
x_str += self._get_str_obs((i, j))
|
||||
x_str += "\n"
|
||||
|
@@ -10,12 +10,8 @@ ALL_ENVS = [
|
||||
alg.reverse.ReverseEnv,
|
||||
alg.reversed_addition.ReversedAdditionEnv,
|
||||
]
|
||||
ALL_TAPE_ENVS = [
|
||||
env for env in ALL_ENVS if issubclass(env, alg.algorithmic_env.TapeAlgorithmicEnv)
|
||||
]
|
||||
ALL_GRID_ENVS = [
|
||||
env for env in ALL_ENVS if issubclass(env, alg.algorithmic_env.GridAlgorithmicEnv)
|
||||
]
|
||||
ALL_TAPE_ENVS = [env for env in ALL_ENVS if issubclass(env, alg.algorithmic_env.TapeAlgorithmicEnv)]
|
||||
ALL_GRID_ENVS = [env for env in ALL_ENVS if issubclass(env, alg.algorithmic_env.GridAlgorithmicEnv)]
|
||||
|
||||
|
||||
def imprint(env, input_arr):
|
||||
@@ -92,10 +88,7 @@ class TestAlgorithmicEnvInteractions(unittest.TestCase):
|
||||
|
||||
def test_grid_naviation(self):
|
||||
env = alg.reversed_addition.ReversedAdditionEnv(rows=2, base=6)
|
||||
N, S, E, W = [
|
||||
env._movement_idx(named_dir)
|
||||
for named_dir in ["up", "down", "right", "left"]
|
||||
]
|
||||
N, S, E, W = [env._movement_idx(named_dir) for named_dir in ["up", "down", "right", "left"]]
|
||||
# Corresponds to a grid that looks like...
|
||||
# 0 1 2
|
||||
# 3 4 5
|
||||
@@ -204,9 +197,7 @@ class TestTargets(unittest.TestCase):
|
||||
|
||||
def test_repeat_copy_target(self):
|
||||
env = alg.repeat_copy.RepeatCopyEnv()
|
||||
self.assertEqual(
|
||||
env.target_from_input_data([0, 1, 2]), [0, 1, 2, 2, 1, 0, 0, 1, 2]
|
||||
)
|
||||
self.assertEqual(env.target_from_input_data([0, 1, 2]), [0, 1, 2, 2, 1, 0, 0, 1, 2])
|
||||
|
||||
|
||||
class TestInputGeneration(unittest.TestCase):
|
||||
|
@@ -9,8 +9,7 @@ try:
|
||||
import atari_py
|
||||
except ImportError as e:
|
||||
raise error.DependencyNotInstalled(
|
||||
"{}. (HINT: you can install Atari dependencies by running "
|
||||
"'pip install gym[atari]'.)".format(e)
|
||||
"{}. (HINT: you can install Atari dependencies by running " "'pip install gym[atari]'.)".format(e)
|
||||
)
|
||||
|
||||
|
||||
@@ -64,35 +63,23 @@ class AtariEnv(gym.Env, utils.EzPickle):
|
||||
|
||||
# Tune (or disable) ALE's action repeat:
|
||||
# https://github.com/openai/gym/issues/349
|
||||
assert isinstance(
|
||||
repeat_action_probability, (float, int)
|
||||
), "Invalid repeat_action_probability: {!r}".format(repeat_action_probability)
|
||||
self.ale.setFloat(
|
||||
"repeat_action_probability".encode("utf-8"), repeat_action_probability
|
||||
assert isinstance(repeat_action_probability, (float, int)), "Invalid repeat_action_probability: {!r}".format(
|
||||
repeat_action_probability
|
||||
)
|
||||
self.ale.setFloat("repeat_action_probability".encode("utf-8"), repeat_action_probability)
|
||||
|
||||
self.seed()
|
||||
|
||||
self._action_set = (
|
||||
self.ale.getLegalActionSet()
|
||||
if full_action_space
|
||||
else self.ale.getMinimalActionSet()
|
||||
)
|
||||
self._action_set = self.ale.getLegalActionSet() if full_action_space else self.ale.getMinimalActionSet()
|
||||
self.action_space = spaces.Discrete(len(self._action_set))
|
||||
|
||||
(screen_width, screen_height) = self.ale.getScreenDims()
|
||||
if self._obs_type == "ram":
|
||||
self.observation_space = spaces.Box(
|
||||
low=0, high=255, dtype=np.uint8, shape=(128,)
|
||||
)
|
||||
self.observation_space = spaces.Box(low=0, high=255, dtype=np.uint8, shape=(128,))
|
||||
elif self._obs_type == "image":
|
||||
self.observation_space = spaces.Box(
|
||||
low=0, high=255, shape=(screen_height, screen_width, 3), dtype=np.uint8
|
||||
)
|
||||
self.observation_space = spaces.Box(low=0, high=255, shape=(screen_height, screen_width, 3), dtype=np.uint8)
|
||||
else:
|
||||
raise error.Error(
|
||||
"Unrecognized observation type: {}".format(self._obs_type)
|
||||
)
|
||||
raise error.Error("Unrecognized observation type: {}".format(self._obs_type))
|
||||
|
||||
def seed(self, seed=None):
|
||||
self.np_random, seed1 = seeding.np_random(seed)
|
||||
@@ -107,9 +94,9 @@ class AtariEnv(gym.Env, utils.EzPickle):
|
||||
if self.game_mode is not None:
|
||||
modes = self.ale.getAvailableModes()
|
||||
|
||||
assert self.game_mode in modes, (
|
||||
'Invalid game mode "{}" for game {}.\nAvailable modes are: {}'
|
||||
).format(self.game_mode, self.game, modes)
|
||||
assert self.game_mode in modes, ('Invalid game mode "{}" for game {}.\nAvailable modes are: {}').format(
|
||||
self.game_mode, self.game, modes
|
||||
)
|
||||
self.ale.setMode(self.game_mode)
|
||||
|
||||
if self.game_difficulty is not None:
|
||||
|
@@ -100,10 +100,7 @@ class ContactDetector(contactListener):
|
||||
self.env = env
|
||||
|
||||
def BeginContact(self, contact):
|
||||
if (
|
||||
self.env.hull == contact.fixtureA.body
|
||||
or self.env.hull == contact.fixtureB.body
|
||||
):
|
||||
if self.env.hull == contact.fixtureA.body or self.env.hull == contact.fixtureB.body:
|
||||
self.env.game_over = True
|
||||
for leg in [self.env.legs[1], self.env.legs[3]]:
|
||||
if leg in [contact.fixtureA.body, contact.fixtureB.body]:
|
||||
@@ -202,9 +199,7 @@ class BipedalWalker(gym.Env, EzPickle):
|
||||
t.color1, t.color2 = (1, 1, 1), (0.6, 0.6, 0.6)
|
||||
self.terrain.append(t)
|
||||
|
||||
self.fd_polygon.shape.vertices = [
|
||||
(p[0] + TERRAIN_STEP * counter, p[1]) for p in poly
|
||||
]
|
||||
self.fd_polygon.shape.vertices = [(p[0] + TERRAIN_STEP * counter, p[1]) for p in poly]
|
||||
t = self.world.CreateStaticBody(fixtures=self.fd_polygon)
|
||||
t.color1, t.color2 = (1, 1, 1), (0.6, 0.6, 0.6)
|
||||
self.terrain.append(t)
|
||||
@@ -301,12 +296,8 @@ class BipedalWalker(gym.Env, EzPickle):
|
||||
y = VIEWPORT_H / SCALE * 3 / 4
|
||||
poly = [
|
||||
(
|
||||
x
|
||||
+ 15 * TERRAIN_STEP * math.sin(3.14 * 2 * a / 5)
|
||||
+ self.np_random.uniform(0, 5 * TERRAIN_STEP),
|
||||
y
|
||||
+ 5 * TERRAIN_STEP * math.cos(3.14 * 2 * a / 5)
|
||||
+ self.np_random.uniform(0, 5 * TERRAIN_STEP),
|
||||
x + 15 * TERRAIN_STEP * math.sin(3.14 * 2 * a / 5) + self.np_random.uniform(0, 5 * TERRAIN_STEP),
|
||||
y + 5 * TERRAIN_STEP * math.cos(3.14 * 2 * a / 5) + self.np_random.uniform(0, 5 * TERRAIN_STEP),
|
||||
)
|
||||
for a in range(5)
|
||||
]
|
||||
@@ -331,14 +322,10 @@ class BipedalWalker(gym.Env, EzPickle):
|
||||
|
||||
init_x = TERRAIN_STEP * TERRAIN_STARTPAD / 2
|
||||
init_y = TERRAIN_HEIGHT + 2 * LEG_H
|
||||
self.hull = self.world.CreateDynamicBody(
|
||||
position=(init_x, init_y), fixtures=HULL_FD
|
||||
)
|
||||
self.hull = self.world.CreateDynamicBody(position=(init_x, init_y), fixtures=HULL_FD)
|
||||
self.hull.color1 = (0.5, 0.4, 0.9)
|
||||
self.hull.color2 = (0.3, 0.3, 0.5)
|
||||
self.hull.ApplyForceToCenter(
|
||||
(self.np_random.uniform(-INITIAL_RANDOM, INITIAL_RANDOM), 0), True
|
||||
)
|
||||
self.hull.ApplyForceToCenter((self.np_random.uniform(-INITIAL_RANDOM, INITIAL_RANDOM), 0), True)
|
||||
|
||||
self.legs = []
|
||||
self.joints = []
|
||||
@@ -412,21 +399,13 @@ class BipedalWalker(gym.Env, EzPickle):
|
||||
self.joints[3].motorSpeed = float(SPEED_KNEE * np.clip(action[3], -1, 1))
|
||||
else:
|
||||
self.joints[0].motorSpeed = float(SPEED_HIP * np.sign(action[0]))
|
||||
self.joints[0].maxMotorTorque = float(
|
||||
MOTORS_TORQUE * np.clip(np.abs(action[0]), 0, 1)
|
||||
)
|
||||
self.joints[0].maxMotorTorque = float(MOTORS_TORQUE * np.clip(np.abs(action[0]), 0, 1))
|
||||
self.joints[1].motorSpeed = float(SPEED_KNEE * np.sign(action[1]))
|
||||
self.joints[1].maxMotorTorque = float(
|
||||
MOTORS_TORQUE * np.clip(np.abs(action[1]), 0, 1)
|
||||
)
|
||||
self.joints[1].maxMotorTorque = float(MOTORS_TORQUE * np.clip(np.abs(action[1]), 0, 1))
|
||||
self.joints[2].motorSpeed = float(SPEED_HIP * np.sign(action[2]))
|
||||
self.joints[2].maxMotorTorque = float(
|
||||
MOTORS_TORQUE * np.clip(np.abs(action[2]), 0, 1)
|
||||
)
|
||||
self.joints[2].maxMotorTorque = float(MOTORS_TORQUE * np.clip(np.abs(action[2]), 0, 1))
|
||||
self.joints[3].motorSpeed = float(SPEED_KNEE * np.sign(action[3]))
|
||||
self.joints[3].maxMotorTorque = float(
|
||||
MOTORS_TORQUE * np.clip(np.abs(action[3]), 0, 1)
|
||||
)
|
||||
self.joints[3].maxMotorTorque = float(MOTORS_TORQUE * np.clip(np.abs(action[3]), 0, 1))
|
||||
|
||||
self.world.Step(1.0 / FPS, 6 * 30, 2 * 30)
|
||||
|
||||
@@ -465,12 +444,8 @@ class BipedalWalker(gym.Env, EzPickle):
|
||||
|
||||
self.scroll = pos.x - VIEWPORT_W / SCALE / 5
|
||||
|
||||
shaping = (
|
||||
130 * pos[0] / SCALE
|
||||
) # moving forward is a way to receive reward (normalized to get 300 on completion)
|
||||
shaping -= 5.0 * abs(
|
||||
state[0]
|
||||
) # keep head straight, other than that and falling, any behavior is unpunished
|
||||
shaping = 130 * pos[0] / SCALE # moving forward is a way to receive reward (normalized to get 300 on completion)
|
||||
shaping -= 5.0 * abs(state[0]) # keep head straight, other than that and falling, any behavior is unpunished
|
||||
|
||||
reward = 0
|
||||
if self.prev_shaping is not None:
|
||||
@@ -494,9 +469,7 @@ class BipedalWalker(gym.Env, EzPickle):
|
||||
|
||||
if self.viewer is None:
|
||||
self.viewer = rendering.Viewer(VIEWPORT_W, VIEWPORT_H)
|
||||
self.viewer.set_bounds(
|
||||
self.scroll, VIEWPORT_W / SCALE + self.scroll, 0, VIEWPORT_H / SCALE
|
||||
)
|
||||
self.viewer.set_bounds(self.scroll, VIEWPORT_W / SCALE + self.scroll, 0, VIEWPORT_H / SCALE)
|
||||
|
||||
self.viewer.draw_polygon(
|
||||
[
|
||||
@@ -512,9 +485,7 @@ class BipedalWalker(gym.Env, EzPickle):
|
||||
continue
|
||||
if x1 > self.scroll / 2 + VIEWPORT_W / SCALE:
|
||||
continue
|
||||
self.viewer.draw_polygon(
|
||||
[(p[0] + self.scroll / 2, p[1]) for p in poly], color=(1, 1, 1)
|
||||
)
|
||||
self.viewer.draw_polygon([(p[0] + self.scroll / 2, p[1]) for p in poly], color=(1, 1, 1))
|
||||
for poly, color in self.terrain_poly:
|
||||
if poly[1][0] < self.scroll:
|
||||
continue
|
||||
@@ -525,11 +496,7 @@ class BipedalWalker(gym.Env, EzPickle):
|
||||
self.lidar_render = (self.lidar_render + 1) % 100
|
||||
i = self.lidar_render
|
||||
if i < 2 * len(self.lidar):
|
||||
l = (
|
||||
self.lidar[i]
|
||||
if i < len(self.lidar)
|
||||
else self.lidar[len(self.lidar) - i - 1]
|
||||
)
|
||||
l = self.lidar[i] if i < len(self.lidar) else self.lidar[len(self.lidar) - i - 1]
|
||||
self.viewer.draw_polyline([l.p1, l.p2], color=(1, 0, 0), linewidth=1)
|
||||
|
||||
for obj in self.drawlist:
|
||||
@@ -537,12 +504,8 @@ class BipedalWalker(gym.Env, EzPickle):
|
||||
trans = f.body.transform
|
||||
if type(f.shape) is circleShape:
|
||||
t = rendering.Transform(translation=trans * f.shape.pos)
|
||||
self.viewer.draw_circle(
|
||||
f.shape.radius, 30, color=obj.color1
|
||||
).add_attr(t)
|
||||
self.viewer.draw_circle(
|
||||
f.shape.radius, 30, color=obj.color2, filled=False, linewidth=2
|
||||
).add_attr(t)
|
||||
self.viewer.draw_circle(f.shape.radius, 30, color=obj.color1).add_attr(t)
|
||||
self.viewer.draw_circle(f.shape.radius, 30, color=obj.color2, filled=False, linewidth=2).add_attr(t)
|
||||
else:
|
||||
path = [trans * v for v in f.shape.vertices]
|
||||
self.viewer.draw_polygon(path, color=obj.color1)
|
||||
@@ -552,9 +515,7 @@ class BipedalWalker(gym.Env, EzPickle):
|
||||
flagy1 = TERRAIN_HEIGHT
|
||||
flagy2 = flagy1 + 50 / SCALE
|
||||
x = TERRAIN_STEP * 3
|
||||
self.viewer.draw_polyline(
|
||||
[(x, flagy1), (x, flagy2)], color=(0, 0, 0), linewidth=2
|
||||
)
|
||||
self.viewer.draw_polyline([(x, flagy1), (x, flagy2)], color=(0, 0, 0), linewidth=2)
|
||||
f = [
|
||||
(x, flagy2),
|
||||
(x, flagy2 - 10 / SCALE),
|
||||
|
@@ -23,9 +23,7 @@ from Box2D.b2 import (
|
||||
SIZE = 0.02
|
||||
ENGINE_POWER = 100000000 * SIZE * SIZE
|
||||
WHEEL_MOMENT_OF_INERTIA = 4000 * SIZE * SIZE
|
||||
FRICTION_LIMIT = (
|
||||
1000000 * SIZE * SIZE
|
||||
) # friction ~= mass ~= size^2 (calculated implicitly using density)
|
||||
FRICTION_LIMIT = 1000000 * SIZE * SIZE # friction ~= mass ~= size^2 (calculated implicitly using density)
|
||||
WHEEL_R = 27
|
||||
WHEEL_W = 14
|
||||
WHEELPOS = [(-55, +80), (+55, +80), (-55, -82), (+55, -82)]
|
||||
@@ -55,27 +53,19 @@ class Car:
|
||||
angle=init_angle,
|
||||
fixtures=[
|
||||
fixtureDef(
|
||||
shape=polygonShape(
|
||||
vertices=[(x * SIZE, y * SIZE) for x, y in HULL_POLY1]
|
||||
),
|
||||
shape=polygonShape(vertices=[(x * SIZE, y * SIZE) for x, y in HULL_POLY1]),
|
||||
density=1.0,
|
||||
),
|
||||
fixtureDef(
|
||||
shape=polygonShape(
|
||||
vertices=[(x * SIZE, y * SIZE) for x, y in HULL_POLY2]
|
||||
),
|
||||
shape=polygonShape(vertices=[(x * SIZE, y * SIZE) for x, y in HULL_POLY2]),
|
||||
density=1.0,
|
||||
),
|
||||
fixtureDef(
|
||||
shape=polygonShape(
|
||||
vertices=[(x * SIZE, y * SIZE) for x, y in HULL_POLY3]
|
||||
),
|
||||
shape=polygonShape(vertices=[(x * SIZE, y * SIZE) for x, y in HULL_POLY3]),
|
||||
density=1.0,
|
||||
),
|
||||
fixtureDef(
|
||||
shape=polygonShape(
|
||||
vertices=[(x * SIZE, y * SIZE) for x, y in HULL_POLY4]
|
||||
),
|
||||
shape=polygonShape(vertices=[(x * SIZE, y * SIZE) for x, y in HULL_POLY4]),
|
||||
density=1.0,
|
||||
),
|
||||
],
|
||||
@@ -95,12 +85,7 @@ class Car:
|
||||
position=(init_x + wx * SIZE, init_y + wy * SIZE),
|
||||
angle=init_angle,
|
||||
fixtures=fixtureDef(
|
||||
shape=polygonShape(
|
||||
vertices=[
|
||||
(x * front_k * SIZE, y * front_k * SIZE)
|
||||
for x, y in WHEEL_POLY
|
||||
]
|
||||
),
|
||||
shape=polygonShape(vertices=[(x * front_k * SIZE, y * front_k * SIZE) for x, y in WHEEL_POLY]),
|
||||
density=0.1,
|
||||
categoryBits=0x0020,
|
||||
maskBits=0x001,
|
||||
@@ -175,9 +160,7 @@ class Car:
|
||||
grass = True
|
||||
friction_limit = FRICTION_LIMIT * 0.6 # Grass friction if no tile
|
||||
for tile in w.tiles:
|
||||
friction_limit = max(
|
||||
friction_limit, FRICTION_LIMIT * tile.road_friction
|
||||
)
|
||||
friction_limit = max(friction_limit, FRICTION_LIMIT * tile.road_friction)
|
||||
grass = False
|
||||
|
||||
# Force
|
||||
@@ -192,13 +175,7 @@ class Car:
|
||||
# domega = dt*W/WHEEL_MOMENT_OF_INERTIA/w.omega
|
||||
|
||||
# add small coef not to divide by zero
|
||||
w.omega += (
|
||||
dt
|
||||
* ENGINE_POWER
|
||||
* w.gas
|
||||
/ WHEEL_MOMENT_OF_INERTIA
|
||||
/ (abs(w.omega) + 5.0)
|
||||
)
|
||||
w.omega += dt * ENGINE_POWER * w.gas / WHEEL_MOMENT_OF_INERTIA / (abs(w.omega) + 5.0)
|
||||
self.fuel_spent += dt * ENGINE_POWER * w.gas
|
||||
|
||||
if w.brake >= 0.9:
|
||||
@@ -226,18 +203,12 @@ class Car:
|
||||
|
||||
# Skid trace
|
||||
if abs(force) > 2.0 * friction_limit:
|
||||
if (
|
||||
w.skid_particle
|
||||
and w.skid_particle.grass == grass
|
||||
and len(w.skid_particle.poly) < 30
|
||||
):
|
||||
if w.skid_particle and w.skid_particle.grass == grass and len(w.skid_particle.poly) < 30:
|
||||
w.skid_particle.poly.append((w.position[0], w.position[1]))
|
||||
elif w.skid_start is None:
|
||||
w.skid_start = w.position
|
||||
else:
|
||||
w.skid_particle = self._create_particle(
|
||||
w.skid_start, w.position, grass
|
||||
)
|
||||
w.skid_particle = self._create_particle(w.skid_start, w.position, grass)
|
||||
w.skid_start = None
|
||||
else:
|
||||
w.skid_start = None
|
||||
|
@@ -132,18 +132,14 @@ class CarRacing(gym.Env, EzPickle):
|
||||
self.reward = 0.0
|
||||
self.prev_reward = 0.0
|
||||
self.verbose = verbose
|
||||
self.fd_tile = fixtureDef(
|
||||
shape=polygonShape(vertices=[(0, 0), (1, 0), (1, -1), (0, -1)])
|
||||
)
|
||||
self.fd_tile = fixtureDef(shape=polygonShape(vertices=[(0, 0), (1, 0), (1, -1), (0, -1)]))
|
||||
|
||||
self.action_space = spaces.Box(
|
||||
np.array([-1, 0, 0]).astype(np.float32),
|
||||
np.array([+1, +1, +1]).astype(np.float32),
|
||||
) # steer, gas, brake
|
||||
|
||||
self.observation_space = spaces.Box(
|
||||
low=0, high=255, shape=(STATE_H, STATE_W, 3), dtype=np.uint8
|
||||
)
|
||||
self.observation_space = spaces.Box(low=0, high=255, shape=(STATE_H, STATE_W, 3), dtype=np.uint8)
|
||||
|
||||
def seed(self, seed=None):
|
||||
self.np_random, seed = seeding.np_random(seed)
|
||||
@@ -246,9 +242,7 @@ class CarRacing(gym.Env, EzPickle):
|
||||
i -= 1
|
||||
if i == 0:
|
||||
return False # Failed
|
||||
pass_through_start = (
|
||||
track[i][0] > self.start_alpha and track[i - 1][0] <= self.start_alpha
|
||||
)
|
||||
pass_through_start = track[i][0] > self.start_alpha and track[i - 1][0] <= self.start_alpha
|
||||
if pass_through_start and i2 == -1:
|
||||
i2 = i
|
||||
elif pass_through_start and i1 == -1:
|
||||
@@ -266,8 +260,7 @@ class CarRacing(gym.Env, EzPickle):
|
||||
first_perp_y = math.sin(first_beta)
|
||||
# Length of perpendicular jump to put together head and tail
|
||||
well_glued_together = np.sqrt(
|
||||
np.square(first_perp_x * (track[0][2] - track[-1][2]))
|
||||
+ np.square(first_perp_y * (track[0][3] - track[-1][3]))
|
||||
np.square(first_perp_x * (track[0][2] - track[-1][2])) + np.square(first_perp_y * (track[0][3] - track[-1][3]))
|
||||
)
|
||||
if well_glued_together > TRACK_DETAIL_STEP:
|
||||
return False
|
||||
@@ -337,9 +330,7 @@ class CarRacing(gym.Env, EzPickle):
|
||||
x2 + side * (TRACK_WIDTH + BORDER) * math.cos(beta2),
|
||||
y2 + side * (TRACK_WIDTH + BORDER) * math.sin(beta2),
|
||||
)
|
||||
self.road_poly.append(
|
||||
([b1_l, b1_r, b2_r, b2_l], (1, 1, 1) if i % 2 == 0 else (1, 0, 0))
|
||||
)
|
||||
self.road_poly.append(([b1_l, b1_r, b2_r, b2_l], (1, 1, 1) if i % 2 == 0 else (1, 0, 0)))
|
||||
self.track = track
|
||||
return True
|
||||
|
||||
@@ -356,10 +347,7 @@ class CarRacing(gym.Env, EzPickle):
|
||||
if success:
|
||||
break
|
||||
if self.verbose == 1:
|
||||
print(
|
||||
"retry to generate track (normal if there are not many"
|
||||
"instances of this message)"
|
||||
)
|
||||
print("retry to generate track (normal if there are not many" "instances of this message)")
|
||||
self.car = Car(self.world, *self.track[0][1:4])
|
||||
|
||||
return self.step(None)[0]
|
||||
@@ -424,10 +412,8 @@ class CarRacing(gym.Env, EzPickle):
|
||||
angle = math.atan2(vel[0], vel[1])
|
||||
self.transform.set_scale(zoom, zoom)
|
||||
self.transform.set_translation(
|
||||
WINDOW_W / 2
|
||||
- (scroll_x * zoom * math.cos(angle) - scroll_y * zoom * math.sin(angle)),
|
||||
WINDOW_H / 4
|
||||
- (scroll_x * zoom * math.sin(angle) + scroll_y * zoom * math.cos(angle)),
|
||||
WINDOW_W / 2 - (scroll_x * zoom * math.cos(angle) - scroll_y * zoom * math.sin(angle)),
|
||||
WINDOW_H / 4 - (scroll_x * zoom * math.sin(angle) + scroll_y * zoom * math.cos(angle)),
|
||||
)
|
||||
self.transform.set_rotation(angle)
|
||||
|
||||
@@ -449,9 +435,7 @@ class CarRacing(gym.Env, EzPickle):
|
||||
else:
|
||||
pixel_scale = 1
|
||||
if hasattr(win.context, "_nscontext"):
|
||||
pixel_scale = (
|
||||
win.context._nscontext.view().backingScaleFactor()
|
||||
) # pylint: disable=protected-access
|
||||
pixel_scale = win.context._nscontext.view().backingScaleFactor() # pylint: disable=protected-access
|
||||
VP_W = int(pixel_scale * WINDOW_W)
|
||||
VP_H = int(pixel_scale * WINDOW_H)
|
||||
|
||||
@@ -468,9 +452,7 @@ class CarRacing(gym.Env, EzPickle):
|
||||
win.flip()
|
||||
return self.viewer.isopen
|
||||
|
||||
image_data = (
|
||||
pyglet.image.get_buffer_manager().get_color_buffer().get_image_data()
|
||||
)
|
||||
image_data = pyglet.image.get_buffer_manager().get_color_buffer().get_image_data()
|
||||
arr = np.fromstring(image_data.get_data(), dtype=np.uint8, sep="")
|
||||
arr = arr.reshape(VP_H, VP_W, 4)
|
||||
arr = arr[::-1, :, 0:3]
|
||||
@@ -525,9 +507,7 @@ class CarRacing(gym.Env, EzPickle):
|
||||
for p in poly:
|
||||
polygons_.extend([p[0], p[1], 0])
|
||||
|
||||
vl = pyglet.graphics.vertex_list(
|
||||
len(polygons_) // 3, ("v3f", polygons_), ("c4f", colors) # gl.GL_QUADS,
|
||||
)
|
||||
vl = pyglet.graphics.vertex_list(len(polygons_) // 3, ("v3f", polygons_), ("c4f", colors)) # gl.GL_QUADS,
|
||||
vl.draw(gl.GL_QUADS)
|
||||
vl.delete()
|
||||
|
||||
@@ -575,10 +555,7 @@ class CarRacing(gym.Env, EzPickle):
|
||||
]
|
||||
)
|
||||
|
||||
true_speed = np.sqrt(
|
||||
np.square(self.car.hull.linearVelocity[0])
|
||||
+ np.square(self.car.hull.linearVelocity[1])
|
||||
)
|
||||
true_speed = np.sqrt(np.square(self.car.hull.linearVelocity[0]) + np.square(self.car.hull.linearVelocity[1]))
|
||||
|
||||
vertical_ind(5, 0.02 * true_speed, (1, 1, 1))
|
||||
vertical_ind(7, 0.01 * self.car.wheels[0].omega, (0.0, 0, 1)) # ABS sensors
|
||||
@@ -587,9 +564,7 @@ class CarRacing(gym.Env, EzPickle):
|
||||
vertical_ind(10, 0.01 * self.car.wheels[3].omega, (0.2, 0, 1))
|
||||
horiz_ind(20, -10.0 * self.car.wheels[0].joint.angle, (0, 1, 0))
|
||||
horiz_ind(30, -0.8 * self.car.hull.angularVelocity, (1, 0, 0))
|
||||
vl = pyglet.graphics.vertex_list(
|
||||
len(polygons) // 3, ("v3f", polygons), ("c4f", colors) # gl.GL_QUADS,
|
||||
)
|
||||
vl = pyglet.graphics.vertex_list(len(polygons) // 3, ("v3f", polygons), ("c4f", colors)) # gl.GL_QUADS,
|
||||
vl.draw(gl.GL_QUADS)
|
||||
vl.delete()
|
||||
self.score_label.text = "%04i" % self.reward
|
||||
|
@@ -71,10 +71,7 @@ class ContactDetector(contactListener):
|
||||
self.env = env
|
||||
|
||||
def BeginContact(self, contact):
|
||||
if (
|
||||
self.env.lander == contact.fixtureA.body
|
||||
or self.env.lander == contact.fixtureB.body
|
||||
):
|
||||
if self.env.lander == contact.fixtureA.body or self.env.lander == contact.fixtureB.body:
|
||||
self.env.game_over = True
|
||||
for i in range(2):
|
||||
if self.env.legs[i] in [contact.fixtureA.body, contact.fixtureB.body]:
|
||||
@@ -104,9 +101,7 @@ class LunarLander(gym.Env, EzPickle):
|
||||
self.prev_reward = None
|
||||
|
||||
# useful range is -1 .. +1, but spikes can be higher
|
||||
self.observation_space = spaces.Box(
|
||||
-np.inf, np.inf, shape=(8,), dtype=np.float32
|
||||
)
|
||||
self.observation_space = spaces.Box(-np.inf, np.inf, shape=(8,), dtype=np.float32)
|
||||
|
||||
if self.continuous:
|
||||
# Action is two floats [main engine, left-right engines].
|
||||
@@ -157,14 +152,9 @@ class LunarLander(gym.Env, EzPickle):
|
||||
height[CHUNKS // 2 + 0] = self.helipad_y
|
||||
height[CHUNKS // 2 + 1] = self.helipad_y
|
||||
height[CHUNKS // 2 + 2] = self.helipad_y
|
||||
smooth_y = [
|
||||
0.33 * (height[i - 1] + height[i + 0] + height[i + 1])
|
||||
for i in range(CHUNKS)
|
||||
]
|
||||
smooth_y = [0.33 * (height[i - 1] + height[i + 0] + height[i + 1]) for i in range(CHUNKS)]
|
||||
|
||||
self.moon = self.world.CreateStaticBody(
|
||||
shapes=edgeShape(vertices=[(0, 0), (W, 0)])
|
||||
)
|
||||
self.moon = self.world.CreateStaticBody(shapes=edgeShape(vertices=[(0, 0), (W, 0)]))
|
||||
self.sky_polys = []
|
||||
for i in range(CHUNKS - 1):
|
||||
p1 = (chunk_x[i], smooth_y[i])
|
||||
@@ -180,9 +170,7 @@ class LunarLander(gym.Env, EzPickle):
|
||||
position=(VIEWPORT_W / SCALE / 2, initial_y),
|
||||
angle=0.0,
|
||||
fixtures=fixtureDef(
|
||||
shape=polygonShape(
|
||||
vertices=[(x / SCALE, y / SCALE) for x, y in LANDER_POLY]
|
||||
),
|
||||
shape=polygonShape(vertices=[(x / SCALE, y / SCALE) for x, y in LANDER_POLY]),
|
||||
density=5.0,
|
||||
friction=0.1,
|
||||
categoryBits=0x0010,
|
||||
@@ -227,9 +215,7 @@ class LunarLander(gym.Env, EzPickle):
|
||||
motorSpeed=+0.3 * i, # low enough not to jump back into the sky
|
||||
)
|
||||
if i == -1:
|
||||
rjd.lowerAngle = (
|
||||
+0.9 - 0.5
|
||||
) # The most esoteric numbers here, angled legs have freedom to travel within
|
||||
rjd.lowerAngle = +0.9 - 0.5 # The most esoteric numbers here, angled legs have freedom to travel within
|
||||
rjd.upperAngle = +0.9
|
||||
else:
|
||||
rjd.lowerAngle = -0.9
|
||||
@@ -278,9 +264,7 @@ class LunarLander(gym.Env, EzPickle):
|
||||
dispersion = [self.np_random.uniform(-1.0, +1.0) / SCALE for _ in range(2)]
|
||||
|
||||
m_power = 0.0
|
||||
if (self.continuous and action[0] > 0.0) or (
|
||||
not self.continuous and action == 2
|
||||
):
|
||||
if (self.continuous and action[0] > 0.0) or (not self.continuous and action == 2):
|
||||
# Main engine
|
||||
if self.continuous:
|
||||
m_power = (np.clip(action[0], 0.0, 1.0) + 1.0) * 0.5 # 0.5..1.0
|
||||
@@ -310,9 +294,7 @@ class LunarLander(gym.Env, EzPickle):
|
||||
)
|
||||
|
||||
s_power = 0.0
|
||||
if (self.continuous and np.abs(action[1]) > 0.5) or (
|
||||
not self.continuous and action in [1, 3]
|
||||
):
|
||||
if (self.continuous and np.abs(action[1]) > 0.5) or (not self.continuous and action in [1, 3]):
|
||||
# Orientation engines
|
||||
if self.continuous:
|
||||
direction = np.sign(action[1])
|
||||
@@ -321,12 +303,8 @@ class LunarLander(gym.Env, EzPickle):
|
||||
else:
|
||||
direction = action - 2
|
||||
s_power = 1.0
|
||||
ox = tip[0] * dispersion[0] + side[0] * (
|
||||
3 * dispersion[1] + direction * SIDE_ENGINE_AWAY / SCALE
|
||||
)
|
||||
oy = -tip[1] * dispersion[0] - side[1] * (
|
||||
3 * dispersion[1] + direction * SIDE_ENGINE_AWAY / SCALE
|
||||
)
|
||||
ox = tip[0] * dispersion[0] + side[0] * (3 * dispersion[1] + direction * SIDE_ENGINE_AWAY / SCALE)
|
||||
oy = -tip[1] * dispersion[0] - side[1] * (3 * dispersion[1] + direction * SIDE_ENGINE_AWAY / SCALE)
|
||||
impulse_pos = (
|
||||
self.lander.position[0] + ox - tip[0] * 17 / SCALE,
|
||||
self.lander.position[1] + oy + tip[1] * SIDE_ENGINE_HEIGHT / SCALE,
|
||||
@@ -372,9 +350,7 @@ class LunarLander(gym.Env, EzPickle):
|
||||
reward = shaping - self.prev_shaping
|
||||
self.prev_shaping = shaping
|
||||
|
||||
reward -= (
|
||||
m_power * 0.30
|
||||
) # less fuel spent is better, about -30 for heuristic landing
|
||||
reward -= m_power * 0.30 # less fuel spent is better, about -30 for heuristic landing
|
||||
reward -= s_power * 0.03
|
||||
|
||||
done = False
|
||||
@@ -416,12 +392,8 @@ class LunarLander(gym.Env, EzPickle):
|
||||
trans = f.body.transform
|
||||
if type(f.shape) is circleShape:
|
||||
t = rendering.Transform(translation=trans * f.shape.pos)
|
||||
self.viewer.draw_circle(
|
||||
f.shape.radius, 20, color=obj.color1
|
||||
).add_attr(t)
|
||||
self.viewer.draw_circle(
|
||||
f.shape.radius, 20, color=obj.color2, filled=False, linewidth=2
|
||||
).add_attr(t)
|
||||
self.viewer.draw_circle(f.shape.radius, 20, color=obj.color1).add_attr(t)
|
||||
self.viewer.draw_circle(f.shape.radius, 20, color=obj.color2, filled=False, linewidth=2).add_attr(t)
|
||||
else:
|
||||
path = [trans * v for v in f.shape.vertices]
|
||||
self.viewer.draw_polygon(path, color=obj.color1)
|
||||
@@ -479,18 +451,14 @@ def heuristic(env, s):
|
||||
angle_targ = 0.4 # more than 0.4 radians (22 degrees) is bad
|
||||
if angle_targ < -0.4:
|
||||
angle_targ = -0.4
|
||||
hover_targ = 0.55 * np.abs(
|
||||
s[0]
|
||||
) # target y should be proportional to horizontal offset
|
||||
hover_targ = 0.55 * np.abs(s[0]) # target y should be proportional to horizontal offset
|
||||
|
||||
angle_todo = (angle_targ - s[4]) * 0.5 - (s[5]) * 1.0
|
||||
hover_todo = (hover_targ - s[1]) * 0.5 - (s[3]) * 0.5
|
||||
|
||||
if s[6] or s[7]: # legs have contact
|
||||
angle_todo = 0
|
||||
hover_todo = (
|
||||
-(s[3]) * 0.5
|
||||
) # override to reduce fall speed, that's all we need after contact
|
||||
hover_todo = -(s[3]) * 0.5 # override to reduce fall speed, that's all we need after contact
|
||||
|
||||
if env.continuous:
|
||||
a = np.array([hover_todo * 20 - 1, -angle_todo * 20])
|
||||
|
@@ -88,9 +88,7 @@ class AcrobotEnv(core.Env):
|
||||
|
||||
def __init__(self):
|
||||
self.viewer = None
|
||||
high = np.array(
|
||||
[1.0, 1.0, 1.0, 1.0, self.MAX_VEL_1, self.MAX_VEL_2], dtype=np.float32
|
||||
)
|
||||
high = np.array([1.0, 1.0, 1.0, 1.0, self.MAX_VEL_1, self.MAX_VEL_2], dtype=np.float32)
|
||||
low = -high
|
||||
self.observation_space = spaces.Box(low=low, high=high, dtype=np.float32)
|
||||
self.action_space = spaces.Discrete(3)
|
||||
@@ -111,9 +109,7 @@ class AcrobotEnv(core.Env):
|
||||
|
||||
# Add noise to the force action
|
||||
if self.torque_noise_max > 0:
|
||||
torque += self.np_random.uniform(
|
||||
-self.torque_noise_max, self.torque_noise_max
|
||||
)
|
||||
torque += self.np_random.uniform(-self.torque_noise_max, self.torque_noise_max)
|
||||
|
||||
# Now, augment the state with our force action so it can be passed to
|
||||
# _dsdt
|
||||
@@ -160,12 +156,7 @@ class AcrobotEnv(core.Env):
|
||||
theta2 = s[1]
|
||||
dtheta1 = s[2]
|
||||
dtheta2 = s[3]
|
||||
d1 = (
|
||||
m1 * lc1 ** 2
|
||||
+ m2 * (l1 ** 2 + lc2 ** 2 + 2 * l1 * lc2 * cos(theta2))
|
||||
+ I1
|
||||
+ I2
|
||||
)
|
||||
d1 = m1 * lc1 ** 2 + m2 * (l1 ** 2 + lc2 ** 2 + 2 * l1 * lc2 * cos(theta2)) + I1 + I2
|
||||
d2 = m2 * (lc2 ** 2 + l1 * lc2 * cos(theta2)) + I2
|
||||
phi2 = m2 * lc2 * g * cos(theta1 + theta2 - pi / 2.0)
|
||||
phi1 = (
|
||||
@@ -181,9 +172,9 @@ class AcrobotEnv(core.Env):
|
||||
else:
|
||||
# the following line is consistent with the java implementation and the
|
||||
# book
|
||||
ddtheta2 = (
|
||||
a + d2 / d1 * phi1 - m2 * l1 * lc2 * dtheta1 ** 2 * sin(theta2) - phi2
|
||||
) / (m2 * lc2 ** 2 + I2 - d2 ** 2 / d1)
|
||||
ddtheta2 = (a + d2 / d1 * phi1 - m2 * l1 * lc2 * dtheta1 ** 2 * sin(theta2) - phi2) / (
|
||||
m2 * lc2 ** 2 + I2 - d2 ** 2 / d1
|
||||
)
|
||||
ddtheta1 = -(d2 * ddtheta2 + phi1) / d1
|
||||
return (dtheta1, dtheta2, ddtheta1, ddtheta2, 0.0)
|
||||
|
||||
|
@@ -111,9 +111,7 @@ class CartPoleEnv(gym.Env):
|
||||
|
||||
# For the interested reader:
|
||||
# https://coneural.org/florian/papers/05_cart_pole.pdf
|
||||
temp = (
|
||||
force + self.polemass_length * theta_dot ** 2 * sintheta
|
||||
) / self.total_mass
|
||||
temp = (force + self.polemass_length * theta_dot ** 2 * sintheta) / self.total_mass
|
||||
thetaacc = (self.gravity * sintheta - costheta * temp) / (
|
||||
self.length * (4.0 / 3.0 - self.masspole * costheta ** 2 / self.total_mass)
|
||||
)
|
||||
|
@@ -62,27 +62,17 @@ class Continuous_MountainCarEnv(gym.Env):
|
||||
self.min_position = -1.2
|
||||
self.max_position = 0.6
|
||||
self.max_speed = 0.07
|
||||
self.goal_position = (
|
||||
0.45 # was 0.5 in gym, 0.45 in Arnaud de Broissia's version
|
||||
)
|
||||
self.goal_position = 0.45 # was 0.5 in gym, 0.45 in Arnaud de Broissia's version
|
||||
self.goal_velocity = goal_velocity
|
||||
self.power = 0.0015
|
||||
|
||||
self.low_state = np.array(
|
||||
[self.min_position, -self.max_speed], dtype=np.float32
|
||||
)
|
||||
self.high_state = np.array(
|
||||
[self.max_position, self.max_speed], dtype=np.float32
|
||||
)
|
||||
self.low_state = np.array([self.min_position, -self.max_speed], dtype=np.float32)
|
||||
self.high_state = np.array([self.max_position, self.max_speed], dtype=np.float32)
|
||||
|
||||
self.viewer = None
|
||||
|
||||
self.action_space = spaces.Box(
|
||||
low=self.min_action, high=self.max_action, shape=(1,), dtype=np.float32
|
||||
)
|
||||
self.observation_space = spaces.Box(
|
||||
low=self.low_state, high=self.high_state, dtype=np.float32
|
||||
)
|
||||
self.action_space = spaces.Box(low=self.min_action, high=self.max_action, shape=(1,), dtype=np.float32)
|
||||
self.observation_space = spaces.Box(low=self.low_state, high=self.high_state, dtype=np.float32)
|
||||
|
||||
self.seed()
|
||||
self.reset()
|
||||
@@ -159,15 +149,11 @@ class Continuous_MountainCarEnv(gym.Env):
|
||||
self.viewer.add_geom(car)
|
||||
frontwheel = rendering.make_circle(carheight / 2.5)
|
||||
frontwheel.set_color(0.5, 0.5, 0.5)
|
||||
frontwheel.add_attr(
|
||||
rendering.Transform(translation=(carwidth / 4, clearance))
|
||||
)
|
||||
frontwheel.add_attr(rendering.Transform(translation=(carwidth / 4, clearance)))
|
||||
frontwheel.add_attr(self.cartrans)
|
||||
self.viewer.add_geom(frontwheel)
|
||||
backwheel = rendering.make_circle(carheight / 2.5)
|
||||
backwheel.add_attr(
|
||||
rendering.Transform(translation=(-carwidth / 4, clearance))
|
||||
)
|
||||
backwheel.add_attr(rendering.Transform(translation=(-carwidth / 4, clearance)))
|
||||
backwheel.add_attr(self.cartrans)
|
||||
backwheel.set_color(0.5, 0.5, 0.5)
|
||||
self.viewer.add_geom(backwheel)
|
||||
@@ -176,16 +162,12 @@ class Continuous_MountainCarEnv(gym.Env):
|
||||
flagy2 = flagy1 + 50
|
||||
flagpole = rendering.Line((flagx, flagy1), (flagx, flagy2))
|
||||
self.viewer.add_geom(flagpole)
|
||||
flag = rendering.FilledPolygon(
|
||||
[(flagx, flagy2), (flagx, flagy2 - 10), (flagx + 25, flagy2 - 5)]
|
||||
)
|
||||
flag = rendering.FilledPolygon([(flagx, flagy2), (flagx, flagy2 - 10), (flagx + 25, flagy2 - 5)])
|
||||
flag.set_color(0.8, 0.8, 0)
|
||||
self.viewer.add_geom(flag)
|
||||
|
||||
pos = self.state[0]
|
||||
self.cartrans.set_translation(
|
||||
(pos - self.min_position) * scale, self._height(pos) * scale
|
||||
)
|
||||
self.cartrans.set_translation((pos - self.min_position) * scale, self._height(pos) * scale)
|
||||
self.cartrans.set_rotation(math.cos(3 * pos))
|
||||
|
||||
return self.viewer.render(return_rgb_array=mode == "rgb_array")
|
||||
|
@@ -136,15 +136,11 @@ class MountainCarEnv(gym.Env):
|
||||
self.viewer.add_geom(car)
|
||||
frontwheel = rendering.make_circle(carheight / 2.5)
|
||||
frontwheel.set_color(0.5, 0.5, 0.5)
|
||||
frontwheel.add_attr(
|
||||
rendering.Transform(translation=(carwidth / 4, clearance))
|
||||
)
|
||||
frontwheel.add_attr(rendering.Transform(translation=(carwidth / 4, clearance)))
|
||||
frontwheel.add_attr(self.cartrans)
|
||||
self.viewer.add_geom(frontwheel)
|
||||
backwheel = rendering.make_circle(carheight / 2.5)
|
||||
backwheel.add_attr(
|
||||
rendering.Transform(translation=(-carwidth / 4, clearance))
|
||||
)
|
||||
backwheel.add_attr(rendering.Transform(translation=(-carwidth / 4, clearance)))
|
||||
backwheel.add_attr(self.cartrans)
|
||||
backwheel.set_color(0.5, 0.5, 0.5)
|
||||
self.viewer.add_geom(backwheel)
|
||||
@@ -153,16 +149,12 @@ class MountainCarEnv(gym.Env):
|
||||
flagy2 = flagy1 + 50
|
||||
flagpole = rendering.Line((flagx, flagy1), (flagx, flagy2))
|
||||
self.viewer.add_geom(flagpole)
|
||||
flag = rendering.FilledPolygon(
|
||||
[(flagx, flagy2), (flagx, flagy2 - 10), (flagx + 25, flagy2 - 5)]
|
||||
)
|
||||
flag = rendering.FilledPolygon([(flagx, flagy2), (flagx, flagy2 - 10), (flagx + 25, flagy2 - 5)])
|
||||
flag.set_color(0.8, 0.8, 0)
|
||||
self.viewer.add_geom(flag)
|
||||
|
||||
pos = self.state[0]
|
||||
self.cartrans.set_translation(
|
||||
(pos - self.min_position) * scale, self._height(pos) * scale
|
||||
)
|
||||
self.cartrans.set_translation((pos - self.min_position) * scale, self._height(pos) * scale)
|
||||
self.cartrans.set_rotation(math.cos(3 * pos))
|
||||
|
||||
return self.viewer.render(return_rgb_array=mode == "rgb_array")
|
||||
|
@@ -18,9 +18,7 @@ class PendulumEnv(gym.Env):
|
||||
self.viewer = None
|
||||
|
||||
high = np.array([1.0, 1.0, self.max_speed], dtype=np.float32)
|
||||
self.action_space = spaces.Box(
|
||||
low=-self.max_torque, high=self.max_torque, shape=(1,), dtype=np.float32
|
||||
)
|
||||
self.action_space = spaces.Box(low=-self.max_torque, high=self.max_torque, shape=(1,), dtype=np.float32)
|
||||
self.observation_space = spaces.Box(low=-high, high=high, dtype=np.float32)
|
||||
|
||||
self.seed()
|
||||
@@ -41,10 +39,7 @@ class PendulumEnv(gym.Env):
|
||||
self.last_u = u # for rendering
|
||||
costs = angle_normalize(th) ** 2 + 0.1 * thdot ** 2 + 0.001 * (u ** 2)
|
||||
|
||||
newthdot = (
|
||||
thdot
|
||||
+ (-3 * g / (2 * l) * np.sin(th + np.pi) + 3.0 / (m * l ** 2) * u) * dt
|
||||
)
|
||||
newthdot = thdot + (-3 * g / (2 * l) * np.sin(th + np.pi) + 3.0 / (m * l ** 2) * u) * dt
|
||||
newth = th + newthdot * dt
|
||||
newthdot = np.clip(newthdot, -self.max_speed, self.max_speed)
|
||||
|
||||
|
@@ -54,11 +54,7 @@ def get_display(spec):
|
||||
elif isinstance(spec, str):
|
||||
return pyglet.canvas.Display(spec)
|
||||
else:
|
||||
raise error.Error(
|
||||
"Invalid display specification: {}. (Must be a string like :0 or None.)".format(
|
||||
spec
|
||||
)
|
||||
)
|
||||
raise error.Error("Invalid display specification: {}. (Must be a string like :0 or None.)".format(spec))
|
||||
|
||||
|
||||
def get_window(width, height, display, **kwargs):
|
||||
@@ -69,14 +65,7 @@ def get_window(width, height, display, **kwargs):
|
||||
config = screen[0].get_best_config() # selecting the first screen
|
||||
context = config.create_context(None) # create GL context
|
||||
|
||||
return pyglet.window.Window(
|
||||
width=width,
|
||||
height=height,
|
||||
display=display,
|
||||
config=config,
|
||||
context=context,
|
||||
**kwargs
|
||||
)
|
||||
return pyglet.window.Window(width=width, height=height, display=display, config=config, context=context, **kwargs)
|
||||
|
||||
|
||||
class Viewer(object):
|
||||
@@ -108,9 +97,7 @@ class Viewer(object):
|
||||
assert right > left and top > bottom
|
||||
scalex = self.width / (right - left)
|
||||
scaley = self.height / (top - bottom)
|
||||
self.transform = Transform(
|
||||
translation=(-left * scalex, -bottom * scaley), scale=(scalex, scaley)
|
||||
)
|
||||
self.transform = Transform(translation=(-left * scalex, -bottom * scaley), scale=(scalex, scaley))
|
||||
|
||||
def add_geom(self, geom):
|
||||
self.geoms.append(geom)
|
||||
@@ -173,9 +160,7 @@ class Viewer(object):
|
||||
|
||||
def get_array(self):
|
||||
self.window.flip()
|
||||
image_data = (
|
||||
pyglet.image.get_buffer_manager().get_color_buffer().get_image_data()
|
||||
)
|
||||
image_data = pyglet.image.get_buffer_manager().get_color_buffer().get_image_data()
|
||||
self.window.flip()
|
||||
arr = np.fromstring(image_data.get_data(), dtype=np.uint8, sep="")
|
||||
arr = arr.reshape(self.height, self.width, 4)
|
||||
@@ -230,9 +215,7 @@ class Transform(Attr):
|
||||
|
||||
def enable(self):
|
||||
glPushMatrix()
|
||||
glTranslatef(
|
||||
self.translation[0], self.translation[1], 0
|
||||
) # translate to GL loc ppint
|
||||
glTranslatef(self.translation[0], self.translation[1], 0) # translate to GL loc ppint
|
||||
glRotatef(RAD2DEG * self.rotation, 0, 0, 1.0)
|
||||
glScalef(self.scale[0], self.scale[1], 1)
|
||||
|
||||
@@ -392,9 +375,7 @@ class Image(Geom):
|
||||
self.flip = False
|
||||
|
||||
def render1(self):
|
||||
self.img.blit(
|
||||
-self.width / 2, -self.height / 2, width=self.width, height=self.height
|
||||
)
|
||||
self.img.blit(-self.width / 2, -self.height / 2, width=self.width, height=self.height)
|
||||
|
||||
|
||||
# ================================================================
|
||||
@@ -435,9 +416,7 @@ class SimpleImageViewer(object):
|
||||
self.isopen = False
|
||||
|
||||
assert len(arr.shape) == 3, "You passed in an image with the wrong number shape"
|
||||
image = pyglet.image.ImageData(
|
||||
arr.shape[1], arr.shape[0], "RGB", arr.tobytes(), pitch=arr.shape[1] * -3
|
||||
)
|
||||
image = pyglet.image.ImageData(arr.shape[1], arr.shape[0], "RGB", arr.tobytes(), pitch=arr.shape[1] * -3)
|
||||
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_NEAREST)
|
||||
texture = image.get_texture()
|
||||
texture.width = self.width
|
||||
|
@@ -14,9 +14,7 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
xposafter = self.get_body_com("torso")[0]
|
||||
forward_reward = (xposafter - xposbefore) / self.dt
|
||||
ctrl_cost = 0.5 * np.square(a).sum()
|
||||
contact_cost = (
|
||||
0.5 * 1e-3 * np.sum(np.square(np.clip(self.sim.data.cfrc_ext, -1, 1)))
|
||||
)
|
||||
contact_cost = 0.5 * 1e-3 * np.sum(np.square(np.clip(self.sim.data.cfrc_ext, -1, 1)))
|
||||
survive_reward = 1.0
|
||||
reward = forward_reward - ctrl_cost - contact_cost + survive_reward
|
||||
state = self.state_vector()
|
||||
@@ -45,9 +43,7 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
)
|
||||
|
||||
def reset_model(self):
|
||||
qpos = self.init_qpos + self.np_random.uniform(
|
||||
size=self.model.nq, low=-0.1, high=0.1
|
||||
)
|
||||
qpos = self.init_qpos + self.np_random.uniform(size=self.model.nq, low=-0.1, high=0.1)
|
||||
qvel = self.init_qvel + self.np_random.randn(self.model.nv) * 0.1
|
||||
self.set_state(qpos, qvel)
|
||||
return self._get_obs()
|
||||
|
@@ -34,18 +34,13 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
|
||||
self._reset_noise_scale = reset_noise_scale
|
||||
|
||||
self._exclude_current_positions_from_observation = (
|
||||
exclude_current_positions_from_observation
|
||||
)
|
||||
self._exclude_current_positions_from_observation = exclude_current_positions_from_observation
|
||||
|
||||
mujoco_env.MujocoEnv.__init__(self, xml_file, 5)
|
||||
|
||||
@property
|
||||
def healthy_reward(self):
|
||||
return (
|
||||
float(self.is_healthy or self._terminate_when_unhealthy)
|
||||
* self._healthy_reward
|
||||
)
|
||||
return float(self.is_healthy or self._terminate_when_unhealthy) * self._healthy_reward
|
||||
|
||||
def control_cost(self, action):
|
||||
control_cost = self._ctrl_cost_weight * np.sum(np.square(action))
|
||||
@@ -60,9 +55,7 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
|
||||
@property
|
||||
def contact_cost(self):
|
||||
contact_cost = self._contact_cost_weight * np.sum(
|
||||
np.square(self.contact_forces)
|
||||
)
|
||||
contact_cost = self._contact_cost_weight * np.sum(np.square(self.contact_forces))
|
||||
return contact_cost
|
||||
|
||||
@property
|
||||
@@ -128,12 +121,8 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
noise_low = -self._reset_noise_scale
|
||||
noise_high = self._reset_noise_scale
|
||||
|
||||
qpos = self.init_qpos + self.np_random.uniform(
|
||||
low=noise_low, high=noise_high, size=self.model.nq
|
||||
)
|
||||
qvel = self.init_qvel + self._reset_noise_scale * self.np_random.randn(
|
||||
self.model.nv
|
||||
)
|
||||
qpos = self.init_qpos + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nq)
|
||||
qvel = self.init_qvel + self._reset_noise_scale * self.np_random.randn(self.model.nv)
|
||||
self.set_state(qpos, qvel)
|
||||
|
||||
observation = self._get_obs()
|
||||
|
@@ -28,9 +28,7 @@ class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
)
|
||||
|
||||
def reset_model(self):
|
||||
qpos = self.init_qpos + self.np_random.uniform(
|
||||
low=-0.1, high=0.1, size=self.model.nq
|
||||
)
|
||||
qpos = self.init_qpos + self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq)
|
||||
qvel = self.init_qvel + self.np_random.randn(self.model.nv) * 0.1
|
||||
self.set_state(qpos, qvel)
|
||||
return self._get_obs()
|
||||
|
@@ -25,9 +25,7 @@ class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
|
||||
self._reset_noise_scale = reset_noise_scale
|
||||
|
||||
self._exclude_current_positions_from_observation = (
|
||||
exclude_current_positions_from_observation
|
||||
)
|
||||
self._exclude_current_positions_from_observation = exclude_current_positions_from_observation
|
||||
|
||||
mujoco_env.MujocoEnv.__init__(self, xml_file, 5)
|
||||
|
||||
@@ -71,12 +69,8 @@ class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
noise_low = -self._reset_noise_scale
|
||||
noise_high = self._reset_noise_scale
|
||||
|
||||
qpos = self.init_qpos + self.np_random.uniform(
|
||||
low=noise_low, high=noise_high, size=self.model.nq
|
||||
)
|
||||
qvel = self.init_qvel + self._reset_noise_scale * self.np_random.randn(
|
||||
self.model.nv
|
||||
)
|
||||
qpos = self.init_qpos + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nq)
|
||||
qvel = self.init_qvel + self._reset_noise_scale * self.np_random.randn(self.model.nv)
|
||||
|
||||
self.set_state(qpos, qvel)
|
||||
|
||||
|
@@ -17,27 +17,16 @@ class HopperEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
reward += alive_bonus
|
||||
reward -= 1e-3 * np.square(a).sum()
|
||||
s = self.state_vector()
|
||||
done = not (
|
||||
np.isfinite(s).all()
|
||||
and (np.abs(s[2:]) < 100).all()
|
||||
and (height > 0.7)
|
||||
and (abs(ang) < 0.2)
|
||||
)
|
||||
done = not (np.isfinite(s).all() and (np.abs(s[2:]) < 100).all() and (height > 0.7) and (abs(ang) < 0.2))
|
||||
ob = self._get_obs()
|
||||
return ob, reward, done, {}
|
||||
|
||||
def _get_obs(self):
|
||||
return np.concatenate(
|
||||
[self.sim.data.qpos.flat[1:], np.clip(self.sim.data.qvel.flat, -10, 10)]
|
||||
)
|
||||
return np.concatenate([self.sim.data.qpos.flat[1:], np.clip(self.sim.data.qvel.flat, -10, 10)])
|
||||
|
||||
def reset_model(self):
|
||||
qpos = self.init_qpos + self.np_random.uniform(
|
||||
low=-0.005, high=0.005, size=self.model.nq
|
||||
)
|
||||
qvel = self.init_qvel + self.np_random.uniform(
|
||||
low=-0.005, high=0.005, size=self.model.nv
|
||||
)
|
||||
qpos = self.init_qpos + self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nq)
|
||||
qvel = self.init_qvel + self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nv)
|
||||
self.set_state(qpos, qvel)
|
||||
return self._get_obs()
|
||||
|
||||
|
@@ -40,18 +40,13 @@ class HopperEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
|
||||
self._reset_noise_scale = reset_noise_scale
|
||||
|
||||
self._exclude_current_positions_from_observation = (
|
||||
exclude_current_positions_from_observation
|
||||
)
|
||||
self._exclude_current_positions_from_observation = exclude_current_positions_from_observation
|
||||
|
||||
mujoco_env.MujocoEnv.__init__(self, xml_file, 4)
|
||||
|
||||
@property
|
||||
def healthy_reward(self):
|
||||
return (
|
||||
float(self.is_healthy or self._terminate_when_unhealthy)
|
||||
* self._healthy_reward
|
||||
)
|
||||
return float(self.is_healthy or self._terminate_when_unhealthy) * self._healthy_reward
|
||||
|
||||
def control_cost(self, action):
|
||||
control_cost = self._ctrl_cost_weight * np.sum(np.square(action))
|
||||
@@ -117,12 +112,8 @@ class HopperEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
noise_low = -self._reset_noise_scale
|
||||
noise_high = self._reset_noise_scale
|
||||
|
||||
qpos = self.init_qpos + self.np_random.uniform(
|
||||
low=noise_low, high=noise_high, size=self.model.nq
|
||||
)
|
||||
qvel = self.init_qvel + self.np_random.uniform(
|
||||
low=noise_low, high=noise_high, size=self.model.nv
|
||||
)
|
||||
qpos = self.init_qpos + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nq)
|
||||
qvel = self.init_qvel + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nv)
|
||||
|
||||
self.set_state(qpos, qvel)
|
||||
|
||||
|
@@ -43,18 +43,13 @@ class HumanoidEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
|
||||
self._reset_noise_scale = reset_noise_scale
|
||||
|
||||
self._exclude_current_positions_from_observation = (
|
||||
exclude_current_positions_from_observation
|
||||
)
|
||||
self._exclude_current_positions_from_observation = exclude_current_positions_from_observation
|
||||
|
||||
mujoco_env.MujocoEnv.__init__(self, xml_file, 5)
|
||||
|
||||
@property
|
||||
def healthy_reward(self):
|
||||
return (
|
||||
float(self.is_healthy or self._terminate_when_unhealthy)
|
||||
* self._healthy_reward
|
||||
)
|
||||
return float(self.is_healthy or self._terminate_when_unhealthy) * self._healthy_reward
|
||||
|
||||
def control_cost(self, action):
|
||||
control_cost = self._ctrl_cost_weight * np.sum(np.square(self.sim.data.ctrl))
|
||||
@@ -143,12 +138,8 @@ class HumanoidEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
noise_low = -self._reset_noise_scale
|
||||
noise_high = self._reset_noise_scale
|
||||
|
||||
qpos = self.init_qpos + self.np_random.uniform(
|
||||
low=noise_low, high=noise_high, size=self.model.nq
|
||||
)
|
||||
qvel = self.init_qvel + self.np_random.uniform(
|
||||
low=noise_low, high=noise_high, size=self.model.nv
|
||||
)
|
||||
qpos = self.init_qpos + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nq)
|
||||
qvel = self.init_qvel + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nv)
|
||||
self.set_state(qpos, qvel)
|
||||
|
||||
observation = self._get_obs()
|
||||
|
@@ -33,8 +33,7 @@ class InvertedDoublePendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
|
||||
def reset_model(self):
|
||||
self.set_state(
|
||||
self.init_qpos
|
||||
+ self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq),
|
||||
self.init_qpos + self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq),
|
||||
self.init_qvel + self.np_random.randn(self.model.nv) * 0.1,
|
||||
)
|
||||
return self._get_obs()
|
||||
|
@@ -17,12 +17,8 @@ class InvertedPendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
return ob, reward, done, {}
|
||||
|
||||
def reset_model(self):
|
||||
qpos = self.init_qpos + self.np_random.uniform(
|
||||
size=self.model.nq, low=-0.01, high=0.01
|
||||
)
|
||||
qvel = self.init_qvel + self.np_random.uniform(
|
||||
size=self.model.nv, low=-0.01, high=0.01
|
||||
)
|
||||
qpos = self.init_qpos + self.np_random.uniform(size=self.model.nq, low=-0.01, high=0.01)
|
||||
qvel = self.init_qvel + self.np_random.uniform(size=self.model.nv, low=-0.01, high=0.01)
|
||||
self.set_state(qpos, qvel)
|
||||
return self._get_obs()
|
||||
|
||||
|
@@ -22,14 +22,7 @@ DEFAULT_SIZE = 500
|
||||
|
||||
def convert_observation_to_space(observation):
|
||||
if isinstance(observation, dict):
|
||||
space = spaces.Dict(
|
||||
OrderedDict(
|
||||
[
|
||||
(key, convert_observation_to_space(value))
|
||||
for key, value in observation.items()
|
||||
]
|
||||
)
|
||||
)
|
||||
space = spaces.Dict(OrderedDict([(key, convert_observation_to_space(value)) for key, value in observation.items()]))
|
||||
elif isinstance(observation, np.ndarray):
|
||||
low = np.full(observation.shape, -float("inf"), dtype=np.float32)
|
||||
high = np.full(observation.shape, float("inf"), dtype=np.float32)
|
||||
@@ -117,9 +110,7 @@ class MujocoEnv(gym.Env):
|
||||
def set_state(self, qpos, qvel):
|
||||
assert qpos.shape == (self.model.nq,) and qvel.shape == (self.model.nv,)
|
||||
old_state = self.sim.get_state()
|
||||
new_state = mujoco_py.MjSimState(
|
||||
old_state.time, qpos, qvel, old_state.act, old_state.udd_state
|
||||
)
|
||||
new_state = mujoco_py.MjSimState(old_state.time, qpos, qvel, old_state.act, old_state.udd_state)
|
||||
self.sim.set_state(new_state)
|
||||
self.sim.forward()
|
||||
|
||||
@@ -142,10 +133,7 @@ class MujocoEnv(gym.Env):
|
||||
):
|
||||
if mode == "rgb_array" or mode == "depth_array":
|
||||
if camera_id is not None and camera_name is not None:
|
||||
raise ValueError(
|
||||
"Both `camera_id` and `camera_name` cannot be"
|
||||
" specified at the same time."
|
||||
)
|
||||
raise ValueError("Both `camera_id` and `camera_name` cannot be" " specified at the same time.")
|
||||
|
||||
no_camera_specified = camera_name is None and camera_id is None
|
||||
if no_camera_specified:
|
||||
|
@@ -44,9 +44,7 @@ class PusherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
|
||||
qpos[-4:-2] = self.cylinder_pos
|
||||
qpos[-2:] = self.goal_pos
|
||||
qvel = self.init_qvel + self.np_random.uniform(
|
||||
low=-0.005, high=0.005, size=self.model.nv
|
||||
)
|
||||
qvel = self.init_qvel + self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nv)
|
||||
qvel[-4:] = 0
|
||||
self.set_state(qpos, qvel)
|
||||
return self._get_obs()
|
||||
|
@@ -22,18 +22,13 @@ class ReacherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
self.viewer.cam.trackbodyid = 0
|
||||
|
||||
def reset_model(self):
|
||||
qpos = (
|
||||
self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq)
|
||||
+ self.init_qpos
|
||||
)
|
||||
qpos = self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq) + self.init_qpos
|
||||
while True:
|
||||
self.goal = self.np_random.uniform(low=-0.2, high=0.2, size=2)
|
||||
if np.linalg.norm(self.goal) < 0.2:
|
||||
break
|
||||
qpos[-2:] = self.goal
|
||||
qvel = self.init_qvel + self.np_random.uniform(
|
||||
low=-0.005, high=0.005, size=self.model.nv
|
||||
)
|
||||
qvel = self.init_qvel + self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nv)
|
||||
qvel[-2:] = 0
|
||||
self.set_state(qpos, qvel)
|
||||
return self._get_obs()
|
||||
|
@@ -62,9 +62,7 @@ class StrikerEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
diff = self.ball - self.goal
|
||||
angle = -np.arctan(diff[0] / (diff[1] + 1e-8))
|
||||
qpos[-1] = angle / 3.14
|
||||
qvel = self.init_qvel + self.np_random.uniform(
|
||||
low=-0.1, high=0.1, size=self.model.nv
|
||||
)
|
||||
qvel = self.init_qvel + self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nv)
|
||||
qvel[7:] = 0
|
||||
self.set_state(qpos, qvel)
|
||||
return self._get_obs()
|
||||
|
@@ -26,9 +26,7 @@ class SwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
|
||||
def reset_model(self):
|
||||
self.set_state(
|
||||
self.init_qpos
|
||||
+ self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq),
|
||||
self.init_qvel
|
||||
+ self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nv),
|
||||
self.init_qpos + self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq),
|
||||
self.init_qvel + self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nv),
|
||||
)
|
||||
return self._get_obs()
|
||||
|
@@ -22,9 +22,7 @@ class SwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
|
||||
self._reset_noise_scale = reset_noise_scale
|
||||
|
||||
self._exclude_current_positions_from_observation = (
|
||||
exclude_current_positions_from_observation
|
||||
)
|
||||
self._exclude_current_positions_from_observation = exclude_current_positions_from_observation
|
||||
|
||||
mujoco_env.MujocoEnv.__init__(self, xml_file, 4)
|
||||
|
||||
@@ -74,12 +72,8 @@ class SwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
noise_low = -self._reset_noise_scale
|
||||
noise_high = self._reset_noise_scale
|
||||
|
||||
qpos = self.init_qpos + self.np_random.uniform(
|
||||
low=noise_low, high=noise_high, size=self.model.nq
|
||||
)
|
||||
qvel = self.init_qvel + self.np_random.uniform(
|
||||
low=noise_low, high=noise_high, size=self.model.nv
|
||||
)
|
||||
qpos = self.init_qpos + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nq)
|
||||
qvel = self.init_qvel + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nv)
|
||||
|
||||
self.set_state(qpos, qvel)
|
||||
|
||||
|
@@ -48,9 +48,7 @@ class ThrowerEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
)
|
||||
|
||||
qpos[-9:-7] = self.goal
|
||||
qvel = self.init_qvel + self.np_random.uniform(
|
||||
low=-0.005, high=0.005, size=self.model.nv
|
||||
)
|
||||
qvel = self.init_qvel + self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nv)
|
||||
qvel[7:] = 0
|
||||
self.set_state(qpos, qvel)
|
||||
return self._get_obs()
|
||||
|
@@ -27,10 +27,8 @@ class Walker2dEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
|
||||
def reset_model(self):
|
||||
self.set_state(
|
||||
self.init_qpos
|
||||
+ self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nq),
|
||||
self.init_qvel
|
||||
+ self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nv),
|
||||
self.init_qpos + self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nq),
|
||||
self.init_qvel + self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nv),
|
||||
)
|
||||
return self._get_obs()
|
||||
|
||||
|
@@ -37,18 +37,13 @@ class Walker2dEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
|
||||
self._reset_noise_scale = reset_noise_scale
|
||||
|
||||
self._exclude_current_positions_from_observation = (
|
||||
exclude_current_positions_from_observation
|
||||
)
|
||||
self._exclude_current_positions_from_observation = exclude_current_positions_from_observation
|
||||
|
||||
mujoco_env.MujocoEnv.__init__(self, xml_file, 4)
|
||||
|
||||
@property
|
||||
def healthy_reward(self):
|
||||
return (
|
||||
float(self.is_healthy or self._terminate_when_unhealthy)
|
||||
* self._healthy_reward
|
||||
)
|
||||
return float(self.is_healthy or self._terminate_when_unhealthy) * self._healthy_reward
|
||||
|
||||
def control_cost(self, action):
|
||||
control_cost = self._ctrl_cost_weight * np.sum(np.square(action))
|
||||
@@ -110,12 +105,8 @@ class Walker2dEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
noise_low = -self._reset_noise_scale
|
||||
noise_high = self._reset_noise_scale
|
||||
|
||||
qpos = self.init_qpos + self.np_random.uniform(
|
||||
low=noise_low, high=noise_high, size=self.model.nq
|
||||
)
|
||||
qvel = self.init_qvel + self.np_random.uniform(
|
||||
low=noise_low, high=noise_high, size=self.model.nv
|
||||
)
|
||||
qpos = self.init_qpos + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nq)
|
||||
qvel = self.init_qvel + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nv)
|
||||
|
||||
self.set_state(qpos, qvel)
|
||||
|
||||
|
@@ -108,11 +108,7 @@ class EnvRegistry(object):
|
||||
# reset/step. Set _gym_disable_underscore_compat = True on
|
||||
# your environment if you use these methods and don't want
|
||||
# compatibility code to be invoked.
|
||||
if (
|
||||
hasattr(env, "_reset")
|
||||
and hasattr(env, "_step")
|
||||
and not getattr(env, "_gym_disable_underscore_compat", False)
|
||||
):
|
||||
if hasattr(env, "_reset") and hasattr(env, "_step") and not getattr(env, "_gym_disable_underscore_compat", False):
|
||||
patch_deprecated_methods(env)
|
||||
if env.spec.max_episode_steps is not None:
|
||||
from gym.wrappers.time_limit import TimeLimit
|
||||
@@ -158,11 +154,7 @@ class EnvRegistry(object):
|
||||
if env_name == valid_env_spec._env_name
|
||||
]
|
||||
if matching_envs:
|
||||
raise error.DeprecatedEnv(
|
||||
"Env {} not found (valid versions include {})".format(
|
||||
id, matching_envs
|
||||
)
|
||||
)
|
||||
raise error.DeprecatedEnv("Env {} not found (valid versions include {})".format(id, matching_envs))
|
||||
else:
|
||||
raise error.UnregisteredEnv("No registered env with id: {}".format(id))
|
||||
|
||||
|
@@ -81,9 +81,7 @@ class FetchEnv(robot_env.RobotEnv):
|
||||
|
||||
def _set_action(self, action):
|
||||
assert action.shape == (4,)
|
||||
action = (
|
||||
action.copy()
|
||||
) # ensure that we don't change the action outside of this scope
|
||||
action = action.copy() # ensure that we don't change the action outside of this scope
|
||||
pos_ctrl, gripper_ctrl = action[:3], action[3]
|
||||
|
||||
pos_ctrl *= 0.05 # limit maximum change in position
|
||||
@@ -120,13 +118,9 @@ class FetchEnv(robot_env.RobotEnv):
|
||||
object_rel_pos = object_pos - grip_pos
|
||||
object_velp -= grip_velp
|
||||
else:
|
||||
object_pos = (
|
||||
object_rot
|
||||
) = object_velp = object_velr = object_rel_pos = np.zeros(0)
|
||||
object_pos = object_rot = object_velp = object_velr = object_rel_pos = np.zeros(0)
|
||||
gripper_state = robot_qpos[-2:]
|
||||
gripper_vel = (
|
||||
robot_qvel[-2:] * dt
|
||||
) # change to a scalar if the gripper is made symmetric
|
||||
gripper_vel = robot_qvel[-2:] * dt # change to a scalar if the gripper is made symmetric
|
||||
|
||||
if not self.has_object:
|
||||
achieved_goal = grip_pos.copy()
|
||||
@@ -175,9 +169,7 @@ class FetchEnv(robot_env.RobotEnv):
|
||||
if self.has_object:
|
||||
object_xpos = self.initial_gripper_xpos[:2]
|
||||
while np.linalg.norm(object_xpos - self.initial_gripper_xpos[:2]) < 0.1:
|
||||
object_xpos = self.initial_gripper_xpos[:2] + self.np_random.uniform(
|
||||
-self.obj_range, self.obj_range, size=2
|
||||
)
|
||||
object_xpos = self.initial_gripper_xpos[:2] + self.np_random.uniform(-self.obj_range, self.obj_range, size=2)
|
||||
object_qpos = self.sim.data.get_joint_qpos("object0:joint")
|
||||
assert object_qpos.shape == (7,)
|
||||
object_qpos[:2] = object_xpos
|
||||
@@ -188,17 +180,13 @@ class FetchEnv(robot_env.RobotEnv):
|
||||
|
||||
def _sample_goal(self):
|
||||
if self.has_object:
|
||||
goal = self.initial_gripper_xpos[:3] + self.np_random.uniform(
|
||||
-self.target_range, self.target_range, size=3
|
||||
)
|
||||
goal = self.initial_gripper_xpos[:3] + self.np_random.uniform(-self.target_range, self.target_range, size=3)
|
||||
goal += self.target_offset
|
||||
goal[2] = self.height_offset
|
||||
if self.target_in_the_air and self.np_random.uniform() < 0.5:
|
||||
goal[2] += self.np_random.uniform(0, 0.45)
|
||||
else:
|
||||
goal = self.initial_gripper_xpos[:3] + self.np_random.uniform(
|
||||
-self.target_range, self.target_range, size=3
|
||||
)
|
||||
goal = self.initial_gripper_xpos[:3] + self.np_random.uniform(-self.target_range, self.target_range, size=3)
|
||||
return goal.copy()
|
||||
|
||||
def _is_success(self, achieved_goal, desired_goal):
|
||||
@@ -212,9 +200,9 @@ class FetchEnv(robot_env.RobotEnv):
|
||||
self.sim.forward()
|
||||
|
||||
# Move end effector into position.
|
||||
gripper_target = np.array(
|
||||
[-0.498, 0.005, -0.431 + self.gripper_extra_height]
|
||||
) + self.sim.data.get_site_xpos("robot0:grip")
|
||||
gripper_target = np.array([-0.498, 0.005, -0.431 + self.gripper_extra_height]) + self.sim.data.get_site_xpos(
|
||||
"robot0:grip"
|
||||
)
|
||||
gripper_rotation = np.array([1.0, 0.0, 1.0, 0.0])
|
||||
self.sim.data.set_mocap_pos("robot0:mocap", gripper_target)
|
||||
self.sim.data.set_mocap_quat("robot0:mocap", gripper_rotation)
|
||||
|
@@ -74,9 +74,7 @@ class ManipulateEnv(hand_env.HandEnv):
|
||||
self.target_position = target_position
|
||||
self.target_rotation = target_rotation
|
||||
self.target_position_range = target_position_range
|
||||
self.parallel_quats = [
|
||||
rotations.euler2quat(r) for r in rotations.get_parallel_rotations()
|
||||
]
|
||||
self.parallel_quats = [rotations.euler2quat(r) for r in rotations.get_parallel_rotations()]
|
||||
self.randomize_initial_rotation = randomize_initial_rotation
|
||||
self.randomize_initial_position = randomize_initial_position
|
||||
self.distance_threshold = distance_threshold
|
||||
@@ -182,9 +180,7 @@ class ManipulateEnv(hand_env.HandEnv):
|
||||
angle = self.np_random.uniform(-np.pi, np.pi)
|
||||
axis = np.array([0.0, 0.0, 1.0])
|
||||
z_quat = quat_from_angle_and_axis(angle, axis)
|
||||
parallel_quat = self.parallel_quats[
|
||||
self.np_random.randint(len(self.parallel_quats))
|
||||
]
|
||||
parallel_quat = self.parallel_quats[self.np_random.randint(len(self.parallel_quats))]
|
||||
offset_quat = rotations.quat_mul(z_quat, parallel_quat)
|
||||
initial_quat = rotations.quat_mul(initial_quat, offset_quat)
|
||||
elif self.target_rotation in ["xyz", "ignore"]:
|
||||
@@ -195,9 +191,7 @@ class ManipulateEnv(hand_env.HandEnv):
|
||||
elif self.target_rotation == "fixed":
|
||||
pass
|
||||
else:
|
||||
raise error.Error(
|
||||
'Unknown target_rotation option "{}".'.format(self.target_rotation)
|
||||
)
|
||||
raise error.Error('Unknown target_rotation option "{}".'.format(self.target_rotation))
|
||||
|
||||
# Randomize initial position.
|
||||
if self.randomize_initial_position:
|
||||
@@ -229,17 +223,13 @@ class ManipulateEnv(hand_env.HandEnv):
|
||||
target_pos = None
|
||||
if self.target_position == "random":
|
||||
assert self.target_position_range.shape == (3, 2)
|
||||
offset = self.np_random.uniform(
|
||||
self.target_position_range[:, 0], self.target_position_range[:, 1]
|
||||
)
|
||||
offset = self.np_random.uniform(self.target_position_range[:, 0], self.target_position_range[:, 1])
|
||||
assert offset.shape == (3,)
|
||||
target_pos = self.sim.data.get_joint_qpos("object:joint")[:3] + offset
|
||||
elif self.target_position in ["ignore", "fixed"]:
|
||||
target_pos = self.sim.data.get_joint_qpos("object:joint")[:3]
|
||||
else:
|
||||
raise error.Error(
|
||||
'Unknown target_position option "{}".'.format(self.target_position)
|
||||
)
|
||||
raise error.Error('Unknown target_position option "{}".'.format(self.target_position))
|
||||
assert target_pos is not None
|
||||
assert target_pos.shape == (3,)
|
||||
|
||||
@@ -253,9 +243,7 @@ class ManipulateEnv(hand_env.HandEnv):
|
||||
angle = self.np_random.uniform(-np.pi, np.pi)
|
||||
axis = np.array([0.0, 0.0, 1.0])
|
||||
target_quat = quat_from_angle_and_axis(angle, axis)
|
||||
parallel_quat = self.parallel_quats[
|
||||
self.np_random.randint(len(self.parallel_quats))
|
||||
]
|
||||
parallel_quat = self.parallel_quats[self.np_random.randint(len(self.parallel_quats))]
|
||||
target_quat = rotations.quat_mul(target_quat, parallel_quat)
|
||||
elif self.target_rotation == "xyz":
|
||||
angle = self.np_random.uniform(-np.pi, np.pi)
|
||||
@@ -264,9 +252,7 @@ class ManipulateEnv(hand_env.HandEnv):
|
||||
elif self.target_rotation in ["ignore", "fixed"]:
|
||||
target_quat = self.sim.data.get_joint_qpos("object:joint")
|
||||
else:
|
||||
raise error.Error(
|
||||
'Unknown target_rotation option "{}".'.format(self.target_rotation)
|
||||
)
|
||||
raise error.Error('Unknown target_rotation option "{}".'.format(self.target_rotation))
|
||||
assert target_quat is not None
|
||||
assert target_quat.shape == (4,)
|
||||
|
||||
@@ -293,12 +279,8 @@ class ManipulateEnv(hand_env.HandEnv):
|
||||
def _get_obs(self):
|
||||
robot_qpos, robot_qvel = robot_get_obs(self.sim)
|
||||
object_qvel = self.sim.data.get_joint_qvel("object:joint")
|
||||
achieved_goal = (
|
||||
self._get_achieved_goal().ravel()
|
||||
) # this contains the object position + rotation
|
||||
observation = np.concatenate(
|
||||
[robot_qpos, robot_qvel, object_qvel, achieved_goal]
|
||||
)
|
||||
achieved_goal = self._get_achieved_goal().ravel() # this contains the object position + rotation
|
||||
observation = np.concatenate([robot_qpos, robot_qvel, object_qvel, achieved_goal])
|
||||
return {
|
||||
"observation": observation.copy(),
|
||||
"achieved_goal": achieved_goal.copy(),
|
||||
@@ -307,9 +289,7 @@ class ManipulateEnv(hand_env.HandEnv):
|
||||
|
||||
|
||||
class HandBlockEnv(ManipulateEnv, utils.EzPickle):
|
||||
def __init__(
|
||||
self, target_position="random", target_rotation="xyz", reward_type="sparse"
|
||||
):
|
||||
def __init__(self, target_position="random", target_rotation="xyz", reward_type="sparse"):
|
||||
utils.EzPickle.__init__(self, target_position, target_rotation, reward_type)
|
||||
ManipulateEnv.__init__(
|
||||
self,
|
||||
@@ -322,9 +302,7 @@ class HandBlockEnv(ManipulateEnv, utils.EzPickle):
|
||||
|
||||
|
||||
class HandEggEnv(ManipulateEnv, utils.EzPickle):
|
||||
def __init__(
|
||||
self, target_position="random", target_rotation="xyz", reward_type="sparse"
|
||||
):
|
||||
def __init__(self, target_position="random", target_rotation="xyz", reward_type="sparse"):
|
||||
utils.EzPickle.__init__(self, target_position, target_rotation, reward_type)
|
||||
ManipulateEnv.__init__(
|
||||
self,
|
||||
@@ -337,9 +315,7 @@ class HandEggEnv(ManipulateEnv, utils.EzPickle):
|
||||
|
||||
|
||||
class HandPenEnv(ManipulateEnv, utils.EzPickle):
|
||||
def __init__(
|
||||
self, target_position="random", target_rotation="xyz", reward_type="sparse"
|
||||
):
|
||||
def __init__(self, target_position="random", target_rotation="xyz", reward_type="sparse"):
|
||||
utils.EzPickle.__init__(self, target_position, target_rotation, reward_type)
|
||||
ManipulateEnv.__init__(
|
||||
self,
|
||||
|
@@ -70,16 +70,12 @@ class ManipulateTouchSensorsEnv(manipulate.ManipulateEnv):
|
||||
for (
|
||||
k,
|
||||
v,
|
||||
) in (
|
||||
self.sim.model._sensor_name2id.items()
|
||||
): # get touch sensor site names and their ids
|
||||
) in self.sim.model._sensor_name2id.items(): # get touch sensor site names and their ids
|
||||
if "robot0:TS_" in k:
|
||||
self._touch_sensor_id_site_id.append(
|
||||
(
|
||||
v,
|
||||
self.sim.model._site_name2id[
|
||||
k.replace("robot0:TS_", "robot0:T_")
|
||||
],
|
||||
self.sim.model._site_name2id[k.replace("robot0:TS_", "robot0:T_")],
|
||||
)
|
||||
)
|
||||
self._touch_sensor_id.append(v)
|
||||
@@ -93,15 +89,9 @@ class ManipulateTouchSensorsEnv(manipulate.ManipulateEnv):
|
||||
obs = self._get_obs()
|
||||
self.observation_space = spaces.Dict(
|
||||
dict(
|
||||
desired_goal=spaces.Box(
|
||||
-np.inf, np.inf, shape=obs["achieved_goal"].shape, dtype="float32"
|
||||
),
|
||||
achieved_goal=spaces.Box(
|
||||
-np.inf, np.inf, shape=obs["achieved_goal"].shape, dtype="float32"
|
||||
),
|
||||
observation=spaces.Box(
|
||||
-np.inf, np.inf, shape=obs["observation"].shape, dtype="float32"
|
||||
),
|
||||
desired_goal=spaces.Box(-np.inf, np.inf, shape=obs["achieved_goal"].shape, dtype="float32"),
|
||||
achieved_goal=spaces.Box(-np.inf, np.inf, shape=obs["achieved_goal"].shape, dtype="float32"),
|
||||
observation=spaces.Box(-np.inf, np.inf, shape=obs["observation"].shape, dtype="float32"),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -117,9 +107,7 @@ class ManipulateTouchSensorsEnv(manipulate.ManipulateEnv):
|
||||
def _get_obs(self):
|
||||
robot_qpos, robot_qvel = manipulate.robot_get_obs(self.sim)
|
||||
object_qvel = self.sim.data.get_joint_qvel("object:joint")
|
||||
achieved_goal = (
|
||||
self._get_achieved_goal().ravel()
|
||||
) # this contains the object position + rotation
|
||||
achieved_goal = self._get_achieved_goal().ravel() # this contains the object position + rotation
|
||||
touch_values = [] # get touch sensor readings. if there is one, set value to 1
|
||||
if self.touch_get_obs == "sensordata":
|
||||
touch_values = self.sim.data.sensordata[self._touch_sensor_id]
|
||||
@@ -127,9 +115,7 @@ class ManipulateTouchSensorsEnv(manipulate.ManipulateEnv):
|
||||
touch_values = self.sim.data.sensordata[self._touch_sensor_id] > 0.0
|
||||
elif self.touch_get_obs == "log":
|
||||
touch_values = np.log(self.sim.data.sensordata[self._touch_sensor_id] + 1.0)
|
||||
observation = np.concatenate(
|
||||
[robot_qpos, robot_qvel, object_qvel, touch_values, achieved_goal]
|
||||
)
|
||||
observation = np.concatenate([robot_qpos, robot_qvel, object_qvel, touch_values, achieved_goal])
|
||||
|
||||
return {
|
||||
"observation": observation.copy(),
|
||||
@@ -146,9 +132,7 @@ class HandBlockTouchSensorsEnv(ManipulateTouchSensorsEnv, utils.EzPickle):
|
||||
touch_get_obs="sensordata",
|
||||
reward_type="sparse",
|
||||
):
|
||||
utils.EzPickle.__init__(
|
||||
self, target_position, target_rotation, touch_get_obs, reward_type
|
||||
)
|
||||
utils.EzPickle.__init__(self, target_position, target_rotation, touch_get_obs, reward_type)
|
||||
ManipulateTouchSensorsEnv.__init__(
|
||||
self,
|
||||
model_path=MANIPULATE_BLOCK_XML,
|
||||
@@ -168,9 +152,7 @@ class HandEggTouchSensorsEnv(ManipulateTouchSensorsEnv, utils.EzPickle):
|
||||
touch_get_obs="sensordata",
|
||||
reward_type="sparse",
|
||||
):
|
||||
utils.EzPickle.__init__(
|
||||
self, target_position, target_rotation, touch_get_obs, reward_type
|
||||
)
|
||||
utils.EzPickle.__init__(self, target_position, target_rotation, touch_get_obs, reward_type)
|
||||
ManipulateTouchSensorsEnv.__init__(
|
||||
self,
|
||||
model_path=MANIPULATE_EGG_XML,
|
||||
@@ -190,9 +172,7 @@ class HandPenTouchSensorsEnv(ManipulateTouchSensorsEnv, utils.EzPickle):
|
||||
touch_get_obs="sensordata",
|
||||
reward_type="sparse",
|
||||
):
|
||||
utils.EzPickle.__init__(
|
||||
self, target_position, target_rotation, touch_get_obs, reward_type
|
||||
)
|
||||
utils.EzPickle.__init__(self, target_position, target_rotation, touch_get_obs, reward_type)
|
||||
ManipulateTouchSensorsEnv.__init__(
|
||||
self,
|
||||
model_path=MANIPULATE_PEN_XML,
|
||||
|
@@ -96,9 +96,7 @@ class HandReachEnv(hand_env.HandEnv, utils.EzPickle):
|
||||
self.sim.forward()
|
||||
|
||||
self.initial_goal = self._get_achieved_goal().copy()
|
||||
self.palm_xpos = self.sim.data.body_xpos[
|
||||
self.sim.model.body_name2id("robot0:palm")
|
||||
].copy()
|
||||
self.palm_xpos = self.sim.data.body_xpos[self.sim.model.body_name2id("robot0:palm")].copy()
|
||||
|
||||
def _get_obs(self):
|
||||
robot_qpos, robot_qvel = robot_get_obs(self.sim)
|
||||
@@ -155,7 +153,5 @@ class HandReachEnv(hand_env.HandEnv, utils.EzPickle):
|
||||
for finger_idx in range(5):
|
||||
site_name = "finger{}".format(finger_idx)
|
||||
site_id = self.sim.model.site_name2id(site_name)
|
||||
self.sim.model.site_pos[site_id] = (
|
||||
achieved_goal[finger_idx] - sites_offset[site_id]
|
||||
)
|
||||
self.sim.model.site_pos[site_id] = achieved_goal[finger_idx] - sites_offset[site_id]
|
||||
self.sim.forward()
|
||||
|
@@ -30,22 +30,14 @@ class HandEnv(robot_env.RobotEnv):
|
||||
if self.relative_control:
|
||||
actuation_center = np.zeros_like(action)
|
||||
for i in range(self.sim.data.ctrl.shape[0]):
|
||||
actuation_center[i] = self.sim.data.get_joint_qpos(
|
||||
self.sim.model.actuator_names[i].replace(":A_", ":")
|
||||
)
|
||||
actuation_center[i] = self.sim.data.get_joint_qpos(self.sim.model.actuator_names[i].replace(":A_", ":"))
|
||||
for joint_name in ["FF", "MF", "RF", "LF"]:
|
||||
act_idx = self.sim.model.actuator_name2id(
|
||||
"robot0:A_{}J1".format(joint_name)
|
||||
)
|
||||
actuation_center[act_idx] += self.sim.data.get_joint_qpos(
|
||||
"robot0:{}J0".format(joint_name)
|
||||
)
|
||||
act_idx = self.sim.model.actuator_name2id("robot0:A_{}J1".format(joint_name))
|
||||
actuation_center[act_idx] += self.sim.data.get_joint_qpos("robot0:{}J0".format(joint_name))
|
||||
else:
|
||||
actuation_center = (ctrlrange[:, 1] + ctrlrange[:, 0]) / 2.0
|
||||
self.sim.data.ctrl[:] = actuation_center + action * actuation_range
|
||||
self.sim.data.ctrl[:] = np.clip(
|
||||
self.sim.data.ctrl, ctrlrange[:, 0], ctrlrange[:, 1]
|
||||
)
|
||||
self.sim.data.ctrl[:] = np.clip(self.sim.data.ctrl, ctrlrange[:, 0], ctrlrange[:, 1])
|
||||
|
||||
def _viewer_setup(self):
|
||||
body_id = self.sim.model.body_name2id("robot0:palm")
|
||||
|
@@ -46,15 +46,9 @@ class RobotEnv(gym.GoalEnv):
|
||||
self.action_space = spaces.Box(-1.0, 1.0, shape=(n_actions,), dtype="float32")
|
||||
self.observation_space = spaces.Dict(
|
||||
dict(
|
||||
desired_goal=spaces.Box(
|
||||
-np.inf, np.inf, shape=obs["achieved_goal"].shape, dtype="float32"
|
||||
),
|
||||
achieved_goal=spaces.Box(
|
||||
-np.inf, np.inf, shape=obs["achieved_goal"].shape, dtype="float32"
|
||||
),
|
||||
observation=spaces.Box(
|
||||
-np.inf, np.inf, shape=obs["observation"].shape, dtype="float32"
|
||||
),
|
||||
desired_goal=spaces.Box(-np.inf, np.inf, shape=obs["achieved_goal"].shape, dtype="float32"),
|
||||
achieved_goal=spaces.Box(-np.inf, np.inf, shape=obs["achieved_goal"].shape, dtype="float32"),
|
||||
observation=spaces.Box(-np.inf, np.inf, shape=obs["observation"].shape, dtype="float32"),
|
||||
)
|
||||
)
|
||||
|
||||
|
@@ -164,12 +164,8 @@ def mat2euler(mat):
|
||||
-np.arctan2(mat[..., 0, 1], mat[..., 0, 0]),
|
||||
-np.arctan2(-mat[..., 1, 0], mat[..., 1, 1]),
|
||||
)
|
||||
euler[..., 1] = np.where(
|
||||
condition, -np.arctan2(-mat[..., 0, 2], cy), -np.arctan2(-mat[..., 0, 2], cy)
|
||||
)
|
||||
euler[..., 0] = np.where(
|
||||
condition, -np.arctan2(mat[..., 1, 2], mat[..., 2, 2]), 0.0
|
||||
)
|
||||
euler[..., 1] = np.where(condition, -np.arctan2(-mat[..., 0, 2], cy), -np.arctan2(-mat[..., 0, 2], cy))
|
||||
euler[..., 0] = np.where(condition, -np.arctan2(mat[..., 1, 2], mat[..., 2, 2]), 0.0)
|
||||
return euler
|
||||
|
||||
|
||||
|
@@ -75,15 +75,9 @@ def reset_mocap2body_xpos(sim):
|
||||
values as the bodies they're welded to.
|
||||
"""
|
||||
|
||||
if (
|
||||
sim.model.eq_type is None
|
||||
or sim.model.eq_obj1id is None
|
||||
or sim.model.eq_obj2id is None
|
||||
):
|
||||
if sim.model.eq_type is None or sim.model.eq_obj1id is None or sim.model.eq_obj2id is None:
|
||||
return
|
||||
for eq_type, obj1_id, obj2_id in zip(
|
||||
sim.model.eq_type, sim.model.eq_obj1id, sim.model.eq_obj2id
|
||||
):
|
||||
for eq_type, obj1_id, obj2_id in zip(sim.model.eq_type, sim.model.eq_obj1id, sim.model.eq_obj2id):
|
||||
if eq_type != mujoco_py.const.EQ_WELD:
|
||||
continue
|
||||
|
||||
|
@@ -2,10 +2,7 @@ from gym import envs, logger
|
||||
import os
|
||||
|
||||
|
||||
SKIP_MUJOCO_WARNING_MESSAGE = (
|
||||
"Cannot run mujoco test (either license key not found or mujoco not"
|
||||
"installed properly)."
|
||||
)
|
||||
SKIP_MUJOCO_WARNING_MESSAGE = "Cannot run mujoco test (either license key not found or mujoco not" "installed properly)."
|
||||
|
||||
|
||||
skip_mujoco = not (os.environ.get("MUJOCO_KEY"))
|
||||
@@ -21,9 +18,7 @@ def should_skip_env_spec_for_tests(spec):
|
||||
# troublesome to run frequently
|
||||
ep = spec.entry_point
|
||||
# Skip mujoco tests for pull request CI
|
||||
if skip_mujoco and (
|
||||
ep.startswith("gym.envs.mujoco") or ep.startswith("gym.envs.robotics:")
|
||||
):
|
||||
if skip_mujoco and (ep.startswith("gym.envs.mujoco") or ep.startswith("gym.envs.robotics:")):
|
||||
return True
|
||||
try:
|
||||
import atari_py
|
||||
@@ -39,11 +34,7 @@ def should_skip_env_spec_for_tests(spec):
|
||||
if (
|
||||
"GoEnv" in ep
|
||||
or "HexEnv" in ep
|
||||
or (
|
||||
ep.startswith("gym.envs.atari")
|
||||
and not spec.id.startswith("Pong")
|
||||
and not spec.id.startswith("Seaquest")
|
||||
)
|
||||
or (ep.startswith("gym.envs.atari") and not spec.id.startswith("Pong") and not spec.id.startswith("Seaquest"))
|
||||
):
|
||||
logger.warn("Skipping tests for env {}".format(ep))
|
||||
return True
|
||||
|
@@ -25,9 +25,7 @@ def test_env(spec):
|
||||
step_responses2 = [env2.step(action) for action in action_samples2]
|
||||
env2.close()
|
||||
|
||||
for i, (action_sample1, action_sample2) in enumerate(
|
||||
zip(action_samples1, action_samples2)
|
||||
):
|
||||
for i, (action_sample1, action_sample2) in enumerate(zip(action_samples1, action_samples2)):
|
||||
try:
|
||||
assert_equals(action_sample1, action_sample2)
|
||||
except AssertionError:
|
||||
@@ -35,11 +33,7 @@ def test_env(spec):
|
||||
print("env2.action_space=", env2.action_space)
|
||||
print("action_samples1=", action_samples1)
|
||||
print("action_samples2=", action_samples2)
|
||||
print(
|
||||
"[{}] action_sample1: {}, action_sample2: {}".format(
|
||||
i, action_sample1, action_sample2
|
||||
)
|
||||
)
|
||||
print("[{}] action_sample1: {}, action_sample2: {}".format(i, action_sample1, action_sample2))
|
||||
raise
|
||||
|
||||
# Don't check rollout equality if it's a a nondeterministic
|
||||
@@ -49,9 +43,7 @@ def test_env(spec):
|
||||
|
||||
assert_equals(initial_observation1, initial_observation2)
|
||||
|
||||
for i, ((o1, r1, d1, i1), (o2, r2, d2, i2)) in enumerate(
|
||||
zip(step_responses1, step_responses2)
|
||||
):
|
||||
for i, ((o1, r1, d1, i1), (o2, r2, d2, i2)) in enumerate(zip(step_responses1, step_responses2)):
|
||||
assert_equals(o1, o2, "[{}] ".format(i))
|
||||
assert r1 == r2, "[{}] r1: {}, r2: {}".format(i, r1, r2)
|
||||
assert d1 == d2, "[{}] d1: {}, d2: {}".format(i, d1, d2)
|
||||
@@ -66,9 +58,7 @@ def test_env(spec):
|
||||
def assert_equals(a, b, prefix=None):
|
||||
assert type(a) == type(b), "{}Differing types: {} and {}".format(prefix, a, b)
|
||||
if isinstance(a, dict):
|
||||
assert list(a.keys()) == list(b.keys()), "{}Key sets differ: {} and {}".format(
|
||||
prefix, a, b
|
||||
)
|
||||
assert list(a.keys()) == list(b.keys()), "{}Key sets differ: {} and {}".format(prefix, a, b)
|
||||
|
||||
for k in a.keys():
|
||||
v_a = a[k]
|
||||
|
@@ -24,9 +24,7 @@ def test_env(spec):
|
||||
assert ob_space.contains(ob), "Reset observation: {!r} not in space".format(ob)
|
||||
a = act_space.sample()
|
||||
observation, reward, done, _info = env.step(a)
|
||||
assert ob_space.contains(observation), "Step observation: {!r} not in space".format(
|
||||
observation
|
||||
)
|
||||
assert ob_space.contains(observation), "Step observation: {!r} not in space".format(observation)
|
||||
assert np.isscalar(reward), "{} is not a scalar for {}".format(reward, env)
|
||||
assert isinstance(done, bool), "Expected {} to be a boolean".format(done)
|
||||
|
||||
|
@@ -81,9 +81,7 @@ def test_env_semantics(spec):
|
||||
if spec.id not in rollout_dict:
|
||||
if not spec.nondeterministic:
|
||||
logger.warn(
|
||||
"Rollout does not exist for {}, run generate_json.py to generate rollouts for new envs".format(
|
||||
spec.id
|
||||
)
|
||||
"Rollout does not exist for {}, run generate_json.py to generate rollouts for new envs".format(spec.id)
|
||||
)
|
||||
return
|
||||
|
||||
@@ -100,21 +98,15 @@ def test_env_semantics(spec):
|
||||
)
|
||||
if rollout_dict[spec.id]["actions"] != actions_now:
|
||||
errors.append(
|
||||
"Actions not equal for {} -- expected {} but got {}".format(
|
||||
spec.id, rollout_dict[spec.id]["actions"], actions_now
|
||||
)
|
||||
"Actions not equal for {} -- expected {} but got {}".format(spec.id, rollout_dict[spec.id]["actions"], actions_now)
|
||||
)
|
||||
if rollout_dict[spec.id]["rewards"] != rewards_now:
|
||||
errors.append(
|
||||
"Rewards not equal for {} -- expected {} but got {}".format(
|
||||
spec.id, rollout_dict[spec.id]["rewards"], rewards_now
|
||||
)
|
||||
"Rewards not equal for {} -- expected {} but got {}".format(spec.id, rollout_dict[spec.id]["rewards"], rewards_now)
|
||||
)
|
||||
if rollout_dict[spec.id]["dones"] != dones_now:
|
||||
errors.append(
|
||||
"Dones not equal for {} -- expected {} but got {}".format(
|
||||
spec.id, rollout_dict[spec.id]["dones"], dones_now
|
||||
)
|
||||
"Dones not equal for {} -- expected {} but got {}".format(spec.id, rollout_dict[spec.id]["dones"], dones_now)
|
||||
)
|
||||
if len(errors):
|
||||
for error in errors:
|
||||
|
@@ -4,9 +4,7 @@ from gym import envs
|
||||
from gym.envs.tests.spec_list import skip_mujoco, SKIP_MUJOCO_WARNING_MESSAGE
|
||||
|
||||
|
||||
def verify_environments_match(
|
||||
old_environment_id, new_environment_id, seed=1, num_actions=1000
|
||||
):
|
||||
def verify_environments_match(old_environment_id, new_environment_id, seed=1, num_actions=1000):
|
||||
old_environment = envs.make(old_environment_id)
|
||||
new_environment = envs.make(new_environment_id)
|
||||
|
||||
|
@@ -83,8 +83,6 @@ def test_malformed_lookup():
|
||||
try:
|
||||
registry.spec(u"“Breakout-v0”")
|
||||
except error.Error as e:
|
||||
assert "malformed environment ID" in "{}".format(
|
||||
e
|
||||
), "Unexpected message: {}".format(e)
|
||||
assert "malformed environment ID" in "{}".format(e), "Unexpected message: {}".format(e)
|
||||
else:
|
||||
assert False
|
||||
|
@@ -75,9 +75,7 @@ class BlackjackEnv(gym.Env):
|
||||
|
||||
def __init__(self, natural=False):
|
||||
self.action_space = spaces.Discrete(2)
|
||||
self.observation_space = spaces.Tuple(
|
||||
(spaces.Discrete(32), spaces.Discrete(11), spaces.Discrete(2))
|
||||
)
|
||||
self.observation_space = spaces.Tuple((spaces.Discrete(32), spaces.Discrete(11), spaces.Discrete(2)))
|
||||
self.seed()
|
||||
|
||||
# Flag to payout 1.5 on a "natural" blackjack win, like casino rules
|
||||
|
@@ -141,9 +141,7 @@ class FrozenLakeEnv(discrete.DiscreteEnv):
|
||||
else:
|
||||
if is_slippery:
|
||||
for b in [(a - 1) % 4, a, (a + 1) % 4]:
|
||||
li.append(
|
||||
(1.0 / 3.0, *update_probability_matrix(row, col, b))
|
||||
)
|
||||
li.append((1.0 / 3.0, *update_probability_matrix(row, col, b)))
|
||||
else:
|
||||
li.append((1.0, *update_probability_matrix(row, col, a)))
|
||||
|
||||
@@ -157,9 +155,7 @@ class FrozenLakeEnv(discrete.DiscreteEnv):
|
||||
desc = [[c.decode("utf-8") for c in line] for line in desc]
|
||||
desc[row][col] = utils.colorize(desc[row][col], "red", highlight=True)
|
||||
if self.lastaction is not None:
|
||||
outfile.write(
|
||||
" ({})\n".format(["Left", "Down", "Right", "Up"][self.lastaction])
|
||||
)
|
||||
outfile.write(" ({})\n".format(["Left", "Down", "Right", "Up"][self.lastaction]))
|
||||
else:
|
||||
outfile.write("\n")
|
||||
outfile.write("\n".join("".join(line) for line in desc) + "\n")
|
||||
|
@@ -80,11 +80,7 @@ class GuessingGame(gym.Env):
|
||||
reward = 0
|
||||
done = False
|
||||
|
||||
if (
|
||||
(self.number - self.range * 0.01)
|
||||
< action
|
||||
< (self.number + self.range * 0.01)
|
||||
):
|
||||
if (self.number - self.range * 0.01) < action < (self.number + self.range * 0.01):
|
||||
reward = 1
|
||||
done = True
|
||||
|
||||
|
@@ -62,10 +62,7 @@ class HotterColder(gym.Env):
|
||||
elif action > self.number:
|
||||
self.observation = 3
|
||||
|
||||
reward = (
|
||||
(min(action, self.number) + self.bounds)
|
||||
/ (max(action, self.number) + self.bounds)
|
||||
) ** 2
|
||||
reward = ((min(action, self.number) + self.bounds) / (max(action, self.number) + self.bounds)) ** 2
|
||||
|
||||
self.guess_count += 1
|
||||
done = self.guess_count >= self.guess_max
|
||||
|
@@ -65,9 +65,7 @@ class KellyCoinflipEnv(gym.Env):
|
||||
return [seed]
|
||||
|
||||
def step(self, action):
|
||||
bet_in_dollars = min(
|
||||
action / 100.0, self.wealth
|
||||
) # action = desired bet in pennies
|
||||
bet_in_dollars = min(action / 100.0, self.wealth) # action = desired bet in pennies
|
||||
self.rounds -= 1
|
||||
|
||||
coinflip = flip(self.edge, self.np_random)
|
||||
@@ -149,35 +147,19 @@ class KellyCoinflipGeneralizedEnv(gym.Env):
|
||||
edge = self.np_random.beta(edge_prior_alpha, edge_prior_beta)
|
||||
if self.clip_distributions:
|
||||
# (clip/resample some parameters to be able to fix obs/action space sizes/bounds)
|
||||
max_wealth_bound = round(
|
||||
genpareto.ppf(0.85, max_wealth_alpha, max_wealth_m)
|
||||
)
|
||||
max_wealth_bound = round(genpareto.ppf(0.85, max_wealth_alpha, max_wealth_m))
|
||||
max_wealth = max_wealth_bound + 1.0
|
||||
while max_wealth > max_wealth_bound:
|
||||
max_wealth = round(
|
||||
genpareto.rvs(
|
||||
max_wealth_alpha, max_wealth_m, random_state=self.np_random
|
||||
)
|
||||
)
|
||||
max_rounds_bound = int(
|
||||
round(norm.ppf(0.99, max_rounds_mean, max_rounds_sd))
|
||||
)
|
||||
max_wealth = round(genpareto.rvs(max_wealth_alpha, max_wealth_m, random_state=self.np_random))
|
||||
max_rounds_bound = int(round(norm.ppf(0.99, max_rounds_mean, max_rounds_sd)))
|
||||
max_rounds = max_rounds_bound + 1
|
||||
while max_rounds > max_rounds_bound:
|
||||
max_rounds = int(
|
||||
round(self.np_random.normal(max_rounds_mean, max_rounds_sd))
|
||||
)
|
||||
max_rounds = int(round(self.np_random.normal(max_rounds_mean, max_rounds_sd)))
|
||||
|
||||
else:
|
||||
max_wealth = round(
|
||||
genpareto.rvs(
|
||||
max_wealth_alpha, max_wealth_m, random_state=self.np_random
|
||||
)
|
||||
)
|
||||
max_wealth = round(genpareto.rvs(max_wealth_alpha, max_wealth_m, random_state=self.np_random))
|
||||
max_wealth_bound = max_wealth
|
||||
max_rounds = int(
|
||||
round(self.np_random.normal(max_rounds_mean, max_rounds_sd))
|
||||
)
|
||||
max_rounds = int(round(self.np_random.normal(max_rounds_mean, max_rounds_sd)))
|
||||
max_rounds_bound = max_rounds
|
||||
|
||||
# add an additional global variable which is the sufficient statistic for the
|
||||
@@ -194,9 +176,7 @@ class KellyCoinflipGeneralizedEnv(gym.Env):
|
||||
self.action_space = spaces.Discrete(int(max_wealth_bound * 100))
|
||||
self.observation_space = spaces.Tuple(
|
||||
(
|
||||
spaces.Box(
|
||||
0, max_wealth_bound, shape=[1], dtype=np.float32
|
||||
), # current wealth
|
||||
spaces.Box(0, max_wealth_bound, shape=[1], dtype=np.float32), # current wealth
|
||||
spaces.Discrete(max_rounds_bound + 1), # rounds elapsed
|
||||
spaces.Discrete(max_rounds_bound + 1), # wins
|
||||
spaces.Discrete(max_rounds_bound + 1), # losses
|
||||
|
@@ -80,10 +80,7 @@ class TaxiEnv(discrete.DiscreteEnv):
|
||||
max_col = num_columns - 1
|
||||
initial_state_distrib = np.zeros(num_states)
|
||||
num_actions = 6
|
||||
P = {
|
||||
state: {action: [] for action in range(num_actions)}
|
||||
for state in range(num_states)
|
||||
}
|
||||
P = {state: {action: [] for action in range(num_actions)} for state in range(num_states)}
|
||||
for row in range(num_rows):
|
||||
for col in range(num_columns):
|
||||
for pass_idx in range(len(locs) + 1): # +1 for being inside taxi
|
||||
@@ -94,9 +91,7 @@ class TaxiEnv(discrete.DiscreteEnv):
|
||||
for action in range(num_actions):
|
||||
# defaults
|
||||
new_row, new_col, new_pass_idx = row, col, pass_idx
|
||||
reward = (
|
||||
-1
|
||||
) # default reward when there is no pickup/dropoff
|
||||
reward = -1 # default reward when there is no pickup/dropoff
|
||||
done = False
|
||||
taxi_loc = (row, col)
|
||||
|
||||
@@ -122,14 +117,10 @@ class TaxiEnv(discrete.DiscreteEnv):
|
||||
new_pass_idx = locs.index(taxi_loc)
|
||||
else: # dropoff at wrong location
|
||||
reward = -10
|
||||
new_state = self.encode(
|
||||
new_row, new_col, new_pass_idx, dest_idx
|
||||
)
|
||||
new_state = self.encode(new_row, new_col, new_pass_idx, dest_idx)
|
||||
P[state][action].append((1.0, new_state, reward, done))
|
||||
initial_state_distrib /= initial_state_distrib.sum()
|
||||
discrete.DiscreteEnv.__init__(
|
||||
self, num_states, num_actions, P, initial_state_distrib
|
||||
)
|
||||
discrete.DiscreteEnv.__init__(self, num_states, num_actions, P, initial_state_distrib)
|
||||
|
||||
def encode(self, taxi_row, taxi_col, pass_loc, dest_idx):
|
||||
# (5) 5, 5, 4
|
||||
@@ -165,13 +156,9 @@ class TaxiEnv(discrete.DiscreteEnv):
|
||||
return "_" if x == " " else x
|
||||
|
||||
if pass_idx < 4:
|
||||
out[1 + taxi_row][2 * taxi_col + 1] = utils.colorize(
|
||||
out[1 + taxi_row][2 * taxi_col + 1], "yellow", highlight=True
|
||||
)
|
||||
out[1 + taxi_row][2 * taxi_col + 1] = utils.colorize(out[1 + taxi_row][2 * taxi_col + 1], "yellow", highlight=True)
|
||||
pi, pj = self.locs[pass_idx]
|
||||
out[1 + pi][2 * pj + 1] = utils.colorize(
|
||||
out[1 + pi][2 * pj + 1], "blue", bold=True
|
||||
)
|
||||
out[1 + pi][2 * pj + 1] = utils.colorize(out[1 + pi][2 * pj + 1], "blue", bold=True)
|
||||
else: # passenger in taxi
|
||||
out[1 + taxi_row][2 * taxi_col + 1] = utils.colorize(
|
||||
ul(out[1 + taxi_row][2 * taxi_col + 1]), "green", highlight=True
|
||||
@@ -181,13 +168,7 @@ class TaxiEnv(discrete.DiscreteEnv):
|
||||
out[1 + di][2 * dj + 1] = utils.colorize(out[1 + di][2 * dj + 1], "magenta")
|
||||
outfile.write("\n".join(["".join(row) for row in out]) + "\n")
|
||||
if self.lastaction is not None:
|
||||
outfile.write(
|
||||
" ({})\n".format(
|
||||
["South", "North", "East", "West", "Pickup", "Dropoff"][
|
||||
self.lastaction
|
||||
]
|
||||
)
|
||||
)
|
||||
outfile.write(" ({})\n".format(["South", "North", "East", "West", "Pickup", "Dropoff"][self.lastaction]))
|
||||
else:
|
||||
outfile.write("\n")
|
||||
|
||||
|
@@ -55,9 +55,7 @@ class CubeCrash(gym.Env):
|
||||
self.seed()
|
||||
self.viewer = None
|
||||
|
||||
self.observation_space = spaces.Box(
|
||||
0, 255, (FIELD_H, FIELD_W, 3), dtype=np.uint8
|
||||
)
|
||||
self.observation_space = spaces.Box(0, 255, (FIELD_H, FIELD_W, 3), dtype=np.uint8)
|
||||
self.action_space = spaces.Discrete(3)
|
||||
|
||||
self.reset()
|
||||
@@ -83,16 +81,9 @@ class CubeCrash(gym.Env):
|
||||
self.potential = None
|
||||
self.step_n = 0
|
||||
while 1:
|
||||
self.wall_color = (
|
||||
self.random_color() if self.use_random_colors else color_white
|
||||
)
|
||||
self.cube_color = (
|
||||
self.random_color() if self.use_random_colors else color_green
|
||||
)
|
||||
if (
|
||||
np.linalg.norm(self.wall_color - self.bg_color) < 50
|
||||
or np.linalg.norm(self.cube_color - self.bg_color) < 50
|
||||
):
|
||||
self.wall_color = self.random_color() if self.use_random_colors else color_white
|
||||
self.cube_color = self.random_color() if self.use_random_colors else color_green
|
||||
if np.linalg.norm(self.wall_color - self.bg_color) < 50 or np.linalg.norm(self.cube_color - self.bg_color) < 50:
|
||||
continue
|
||||
break
|
||||
return self.step(0)[0]
|
||||
@@ -117,9 +108,7 @@ class CubeCrash(gym.Env):
|
||||
self.hole_x - HOLE_WIDTH // 2 : self.hole_x + HOLE_WIDTH // 2 + 1,
|
||||
:,
|
||||
] = self.bg_color
|
||||
obs[
|
||||
self.cube_y - 1 : self.cube_y + 2, self.cube_x - 1 : self.cube_x + 2, :
|
||||
] = self.cube_color
|
||||
obs[self.cube_y - 1 : self.cube_y + 2, self.cube_x - 1 : self.cube_x + 2, :] = self.cube_color
|
||||
if self.use_black_screen and self.step_n > 4:
|
||||
obs[:] = np.zeros((3,), dtype=np.uint8)
|
||||
|
||||
|
@@ -62,16 +62,12 @@ class MemorizeDigits(gym.Env):
|
||||
def __init__(self):
|
||||
self.seed()
|
||||
self.viewer = None
|
||||
self.observation_space = spaces.Box(
|
||||
0, 255, (FIELD_H, FIELD_W, 3), dtype=np.uint8
|
||||
)
|
||||
self.observation_space = spaces.Box(0, 255, (FIELD_H, FIELD_W, 3), dtype=np.uint8)
|
||||
self.action_space = spaces.Discrete(10)
|
||||
self.bogus_mnist = np.zeros((10, 6, 6), dtype=np.uint8)
|
||||
for digit in range(10):
|
||||
for y in range(6):
|
||||
self.bogus_mnist[digit, y, :] = [
|
||||
ord(char) for char in bogus_mnist[digit][y]
|
||||
]
|
||||
self.bogus_mnist[digit, y, :] = [ord(char) for char in bogus_mnist[digit][y]]
|
||||
self.reset()
|
||||
|
||||
def seed(self, seed=None):
|
||||
@@ -93,9 +89,7 @@ class MemorizeDigits(gym.Env):
|
||||
self.color_bg = self.random_color() if self.use_random_colors else color_black
|
||||
self.step_n = 0
|
||||
while 1:
|
||||
self.color_digit = (
|
||||
self.random_color() if self.use_random_colors else color_white
|
||||
)
|
||||
self.color_digit = self.random_color() if self.use_random_colors else color_white
|
||||
if np.linalg.norm(self.color_digit - self.color_bg) < 50:
|
||||
continue
|
||||
break
|
||||
@@ -119,9 +113,7 @@ class MemorizeDigits(gym.Env):
|
||||
digit_img[:] = self.color_bg
|
||||
xxx = self.bogus_mnist[self.digit] == 42
|
||||
digit_img[xxx] = self.color_digit
|
||||
obs[
|
||||
self.digit_y - 3 : self.digit_y + 3, self.digit_x - 3 : self.digit_x + 3
|
||||
] = digit_img
|
||||
obs[self.digit_y - 3 : self.digit_y + 3, self.digit_x - 3 : self.digit_x + 3] = digit_img
|
||||
self.last_obs = obs
|
||||
return obs, reward, done, {}
|
||||
|
||||
|
@@ -102,10 +102,7 @@ class APIError(Error):
|
||||
try:
|
||||
http_body = http_body.decode("utf-8")
|
||||
except:
|
||||
http_body = (
|
||||
"<Could not decode body as utf-8. "
|
||||
"Please report to gym@openai.com>"
|
||||
)
|
||||
http_body = "<Could not decode body as utf-8. " "Please report to gym@openai.com>"
|
||||
|
||||
self._message = message
|
||||
self.http_body = http_body
|
||||
@@ -142,9 +139,7 @@ class InvalidRequestError(APIError):
|
||||
json_body=None,
|
||||
headers=None,
|
||||
):
|
||||
super(InvalidRequestError, self).__init__(
|
||||
message, http_body, http_status, json_body, headers
|
||||
)
|
||||
super(InvalidRequestError, self).__init__(message, http_body, http_status, json_body, headers)
|
||||
self.param = param
|
||||
|
||||
|
||||
|
@@ -29,26 +29,16 @@ class Box(Space):
|
||||
# determine shape if it isn't provided directly
|
||||
if shape is not None:
|
||||
shape = tuple(shape)
|
||||
assert (
|
||||
np.isscalar(low) or low.shape == shape
|
||||
), "low.shape doesn't match provided shape"
|
||||
assert (
|
||||
np.isscalar(high) or high.shape == shape
|
||||
), "high.shape doesn't match provided shape"
|
||||
assert np.isscalar(low) or low.shape == shape, "low.shape doesn't match provided shape"
|
||||
assert np.isscalar(high) or high.shape == shape, "high.shape doesn't match provided shape"
|
||||
elif not np.isscalar(low):
|
||||
shape = low.shape
|
||||
assert (
|
||||
np.isscalar(high) or high.shape == shape
|
||||
), "high.shape doesn't match low.shape"
|
||||
assert np.isscalar(high) or high.shape == shape, "high.shape doesn't match low.shape"
|
||||
elif not np.isscalar(high):
|
||||
shape = high.shape
|
||||
assert (
|
||||
np.isscalar(low) or low.shape == shape
|
||||
), "low.shape doesn't match high.shape"
|
||||
assert np.isscalar(low) or low.shape == shape, "low.shape doesn't match high.shape"
|
||||
else:
|
||||
raise ValueError(
|
||||
"shape must be provided or inferred from the shapes of low or high"
|
||||
)
|
||||
raise ValueError("shape must be provided or inferred from the shapes of low or high")
|
||||
|
||||
if np.isscalar(low):
|
||||
low = np.full(shape, low, dtype=dtype)
|
||||
@@ -70,9 +60,7 @@ class Box(Space):
|
||||
high_precision = _get_precision(self.high.dtype)
|
||||
dtype_precision = _get_precision(self.dtype)
|
||||
if min(low_precision, high_precision) > dtype_precision:
|
||||
logger.warn(
|
||||
"Box bound precision lowered by casting to {}".format(self.dtype)
|
||||
)
|
||||
logger.warn("Box bound precision lowered by casting to {}".format(self.dtype))
|
||||
self.low = self.low.astype(self.dtype)
|
||||
self.high = self.high.astype(self.dtype)
|
||||
|
||||
@@ -119,19 +107,11 @@ class Box(Space):
|
||||
# Vectorized sampling by interval type
|
||||
sample[unbounded] = self.np_random.normal(size=unbounded[unbounded].shape)
|
||||
|
||||
sample[low_bounded] = (
|
||||
self.np_random.exponential(size=low_bounded[low_bounded].shape)
|
||||
+ self.low[low_bounded]
|
||||
)
|
||||
sample[low_bounded] = self.np_random.exponential(size=low_bounded[low_bounded].shape) + self.low[low_bounded]
|
||||
|
||||
sample[upp_bounded] = (
|
||||
-self.np_random.exponential(size=upp_bounded[upp_bounded].shape)
|
||||
+ self.high[upp_bounded]
|
||||
)
|
||||
sample[upp_bounded] = -self.np_random.exponential(size=upp_bounded[upp_bounded].shape) + self.high[upp_bounded]
|
||||
|
||||
sample[bounded] = self.np_random.uniform(
|
||||
low=self.low[bounded], high=high[bounded], size=bounded[bounded].shape
|
||||
)
|
||||
sample[bounded] = self.np_random.uniform(low=self.low[bounded], high=high[bounded], size=bounded[bounded].shape)
|
||||
if self.dtype.kind == "i":
|
||||
sample = np.floor(sample)
|
||||
|
||||
@@ -140,9 +120,7 @@ class Box(Space):
|
||||
def contains(self, x):
|
||||
if isinstance(x, list):
|
||||
x = np.array(x) # Promote list to array for contains check
|
||||
return (
|
||||
x.shape == self.shape and np.all(x >= self.low) and np.all(x <= self.high)
|
||||
)
|
||||
return x.shape == self.shape and np.all(x >= self.low) and np.all(x <= self.high)
|
||||
|
||||
def to_jsonable(self, sample_n):
|
||||
return np.array(sample_n).tolist()
|
||||
@@ -151,9 +129,7 @@ class Box(Space):
|
||||
return [np.asarray(sample) for sample in sample_n]
|
||||
|
||||
def __repr__(self):
|
||||
return "Box({}, {}, {}, {})".format(
|
||||
self.low.min(), self.high.max(), self.shape, self.dtype
|
||||
)
|
||||
return "Box({}, {}, {}, {})".format(self.low.min(), self.high.max(), self.shape, self.dtype)
|
||||
|
||||
def __eq__(self, other):
|
||||
return (
|
||||
|
@@ -33,9 +33,7 @@ class Dict(Space):
|
||||
"""
|
||||
|
||||
def __init__(self, spaces=None, **spaces_kwargs):
|
||||
assert (spaces is None) or (
|
||||
not spaces_kwargs
|
||||
), "Use either Dict(spaces=dict(...)) or Dict(foo=x, bar=z)"
|
||||
assert (spaces is None) or (not spaces_kwargs), "Use either Dict(spaces=dict(...)) or Dict(foo=x, bar=z)"
|
||||
if spaces is None:
|
||||
spaces = spaces_kwargs
|
||||
if isinstance(spaces, dict) and not isinstance(spaces, OrderedDict):
|
||||
@@ -44,12 +42,8 @@ class Dict(Space):
|
||||
spaces = OrderedDict(spaces)
|
||||
self.spaces = spaces
|
||||
for space in spaces.values():
|
||||
assert isinstance(
|
||||
space, Space
|
||||
), "Values of the dict should be instances of gym.Space"
|
||||
super(Dict, self).__init__(
|
||||
None, None
|
||||
) # None for shape and dtype, since it'll require special handling
|
||||
assert isinstance(space, Space), "Values of the dict should be instances of gym.Space"
|
||||
super(Dict, self).__init__(None, None) # None for shape and dtype, since it'll require special handling
|
||||
|
||||
def seed(self, seed=None):
|
||||
[space.seed(seed) for space in self.spaces.values()]
|
||||
@@ -75,18 +69,11 @@ class Dict(Space):
|
||||
yield key
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
"Dict("
|
||||
+ ", ".join([str(k) + ":" + str(s) for k, s in self.spaces.items()])
|
||||
+ ")"
|
||||
)
|
||||
return "Dict(" + ", ".join([str(k) + ":" + str(s) for k, s in self.spaces.items()]) + ")"
|
||||
|
||||
def to_jsonable(self, sample_n):
|
||||
# serialize as dict-repr of vectors
|
||||
return {
|
||||
key: space.to_jsonable([sample[key] for sample in sample_n])
|
||||
for key, space in self.spaces.items()
|
||||
}
|
||||
return {key: space.to_jsonable([sample[key] for sample in sample_n]) for key, space in self.spaces.items()}
|
||||
|
||||
def from_jsonable(self, sample_n):
|
||||
dict_of_list = {}
|
||||
|
@@ -22,9 +22,7 @@ class Discrete(Space):
|
||||
def contains(self, x):
|
||||
if isinstance(x, int):
|
||||
as_int = x
|
||||
elif isinstance(x, (np.generic, np.ndarray)) and (
|
||||
x.dtype.char in np.typecodes["AllInteger"] and x.shape == ()
|
||||
):
|
||||
elif isinstance(x, (np.generic, np.ndarray)) and (x.dtype.char in np.typecodes["AllInteger"] and x.shape == ()):
|
||||
as_int = int(x)
|
||||
else:
|
||||
return False
|
||||
|
@@ -34,9 +34,7 @@ class MultiDiscrete(Space):
|
||||
super(MultiDiscrete, self).__init__(self.nvec.shape, dtype)
|
||||
|
||||
def sample(self):
|
||||
return (self.np_random.random_sample(self.nvec.shape) * self.nvec).astype(
|
||||
self.dtype
|
||||
)
|
||||
return (self.np_random.random_sample(self.nvec.shape) * self.nvec).astype(self.dtype)
|
||||
|
||||
def contains(self, x):
|
||||
if isinstance(x, list):
|
||||
|
@@ -25,9 +25,7 @@ from gym.spaces import Tuple, Box, Discrete, MultiDiscrete, MultiBinary, Dict
|
||||
Dict(
|
||||
{
|
||||
"position": Discrete(5),
|
||||
"velocity": Box(
|
||||
low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32
|
||||
),
|
||||
"velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
|
||||
}
|
||||
),
|
||||
],
|
||||
@@ -71,9 +69,7 @@ def test_roundtripping(space):
|
||||
Dict(
|
||||
{
|
||||
"position": Discrete(5),
|
||||
"velocity": Box(
|
||||
low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32
|
||||
),
|
||||
"velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
|
||||
}
|
||||
),
|
||||
],
|
||||
|
@@ -28,9 +28,7 @@ from gym.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete, Tuple, u
|
||||
Dict(
|
||||
{
|
||||
"position": Discrete(5),
|
||||
"velocity": Box(
|
||||
low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32
|
||||
),
|
||||
"velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
|
||||
}
|
||||
),
|
||||
7,
|
||||
@@ -60,18 +58,14 @@ def test_flatdim(space, flatdim):
|
||||
Dict(
|
||||
{
|
||||
"position": Discrete(5),
|
||||
"velocity": Box(
|
||||
low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32
|
||||
),
|
||||
"velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
|
||||
}
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_flatten_space_boxes(space):
|
||||
flat_space = utils.flatten_space(space)
|
||||
assert isinstance(flat_space, Box), "Expected {} to equal {}".format(
|
||||
type(flat_space), Box
|
||||
)
|
||||
assert isinstance(flat_space, Box), "Expected {} to equal {}".format(type(flat_space), Box)
|
||||
flatdim = utils.flatdim(space)
|
||||
(single_dim,) = flat_space.shape
|
||||
assert single_dim == flatdim, "Expected {} to equal {}".format(single_dim, flatdim)
|
||||
@@ -95,9 +89,7 @@ def test_flatten_space_boxes(space):
|
||||
Dict(
|
||||
{
|
||||
"position": Discrete(5),
|
||||
"velocity": Box(
|
||||
low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32
|
||||
),
|
||||
"velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
|
||||
}
|
||||
),
|
||||
],
|
||||
@@ -107,9 +99,7 @@ def test_flat_space_contains_flat_points(space):
|
||||
flattened_samples = [utils.flatten(space, sample) for sample in some_samples]
|
||||
flat_space = utils.flatten_space(space)
|
||||
for i, flat_sample in enumerate(flattened_samples):
|
||||
assert flat_sample in flat_space, "Expected sample #{} {} to be in {}".format(
|
||||
i, flat_sample, flat_space
|
||||
)
|
||||
assert flat_sample in flat_space, "Expected sample #{} {} to be in {}".format(i, flat_sample, flat_space)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -130,9 +120,7 @@ def test_flat_space_contains_flat_points(space):
|
||||
Dict(
|
||||
{
|
||||
"position": Discrete(5),
|
||||
"velocity": Box(
|
||||
low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32
|
||||
),
|
||||
"velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
|
||||
}
|
||||
),
|
||||
],
|
||||
@@ -162,9 +150,7 @@ def test_flatten_dim(space):
|
||||
Dict(
|
||||
{
|
||||
"position": Discrete(5),
|
||||
"velocity": Box(
|
||||
low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32
|
||||
),
|
||||
"velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
|
||||
}
|
||||
),
|
||||
],
|
||||
@@ -172,15 +158,9 @@ def test_flatten_dim(space):
|
||||
def test_flatten_roundtripping(space):
|
||||
some_samples = [space.sample() for _ in range(10)]
|
||||
flattened_samples = [utils.flatten(space, sample) for sample in some_samples]
|
||||
roundtripped_samples = [
|
||||
utils.unflatten(space, sample) for sample in flattened_samples
|
||||
]
|
||||
for i, (original, roundtripped) in enumerate(
|
||||
zip(some_samples, roundtripped_samples)
|
||||
):
|
||||
assert compare_nested(
|
||||
original, roundtripped
|
||||
), "Expected sample #{} {} to equal {}".format(i, original, roundtripped)
|
||||
roundtripped_samples = [utils.unflatten(space, sample) for sample in flattened_samples]
|
||||
for i, (original, roundtripped) in enumerate(zip(some_samples, roundtripped_samples)):
|
||||
assert compare_nested(original, roundtripped), "Expected sample #{} {} to equal {}".format(i, original, roundtripped)
|
||||
|
||||
|
||||
def compare_nested(left, right):
|
||||
@@ -188,9 +168,7 @@ def compare_nested(left, right):
|
||||
return np.allclose(left, right)
|
||||
elif isinstance(left, OrderedDict) and isinstance(right, OrderedDict):
|
||||
res = len(left) == len(right)
|
||||
for ((left_key, left_value), (right_key, right_value)) in zip(
|
||||
left.items(), right.items()
|
||||
):
|
||||
for ((left_key, left_value), (right_key, right_value)) in zip(left.items(), right.items()):
|
||||
if not res:
|
||||
return False
|
||||
res = left_key == right_key and compare_nested(left_value, right_value)
|
||||
@@ -238,9 +216,7 @@ Expecteded flattened types are based off:
|
||||
Dict(
|
||||
{
|
||||
"position": Discrete(5),
|
||||
"velocity": Box(
|
||||
low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float16
|
||||
),
|
||||
"velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float16),
|
||||
}
|
||||
),
|
||||
np.float64,
|
||||
@@ -254,12 +230,8 @@ def test_dtypes(original_space, expected_flattened_dtype):
|
||||
flattened_sample = utils.flatten(original_space, original_sample)
|
||||
unflattened_sample = utils.unflatten(original_space, flattened_sample)
|
||||
|
||||
assert flattened_space.contains(
|
||||
flattened_sample
|
||||
), "Expected flattened_space to contain flattened_sample"
|
||||
assert (
|
||||
flattened_space.dtype == expected_flattened_dtype
|
||||
), "Expected flattened_space's dtype to equal " "{}".format(
|
||||
assert flattened_space.contains(flattened_sample), "Expected flattened_space to contain flattened_sample"
|
||||
assert flattened_space.dtype == expected_flattened_dtype, "Expected flattened_space's dtype to equal " "{}".format(
|
||||
expected_flattened_dtype
|
||||
)
|
||||
|
||||
@@ -272,9 +244,10 @@ def test_dtypes(original_space, expected_flattened_dtype):
|
||||
|
||||
def compare_sample_types(original_space, original_sample, unflattened_sample):
|
||||
if isinstance(original_space, Discrete):
|
||||
assert isinstance(unflattened_sample, int), (
|
||||
"Expected unflattened_sample to be an int. unflattened_sample: "
|
||||
"{} original_sample: {}".format(unflattened_sample, original_sample)
|
||||
assert isinstance(
|
||||
unflattened_sample, int
|
||||
), "Expected unflattened_sample to be an int. unflattened_sample: " "{} original_sample: {}".format(
|
||||
unflattened_sample, original_sample
|
||||
)
|
||||
elif isinstance(original_space, Tuple):
|
||||
for index in range(len(original_space)):
|
||||
|
@@ -13,9 +13,7 @@ class Tuple(Space):
|
||||
def __init__(self, spaces):
|
||||
self.spaces = spaces
|
||||
for space in spaces:
|
||||
assert isinstance(
|
||||
space, Space
|
||||
), "Elements of the tuple must be instances of gym.Space"
|
||||
assert isinstance(space, Space), "Elements of the tuple must be instances of gym.Space"
|
||||
super(Tuple, self).__init__(None, None)
|
||||
|
||||
def seed(self, seed=None):
|
||||
@@ -38,21 +36,10 @@ class Tuple(Space):
|
||||
|
||||
def to_jsonable(self, sample_n):
|
||||
# serialize as list-repr of tuple of vectors
|
||||
return [
|
||||
space.to_jsonable([sample[i] for sample in sample_n])
|
||||
for i, space in enumerate(self.spaces)
|
||||
]
|
||||
return [space.to_jsonable([sample[i] for sample in sample_n]) for i, space in enumerate(self.spaces)]
|
||||
|
||||
def from_jsonable(self, sample_n):
|
||||
return [
|
||||
sample
|
||||
for sample in zip(
|
||||
*[
|
||||
space.from_jsonable(sample_n[i])
|
||||
for i, space in enumerate(self.spaces)
|
||||
]
|
||||
)
|
||||
]
|
||||
return [sample for sample in zip(*[space.from_jsonable(sample_n[i]) for i, space in enumerate(self.spaces)])]
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.spaces[index]
|
||||
|
@@ -49,9 +49,7 @@ def flatten(space, x):
|
||||
onehot[x] = 1
|
||||
return onehot
|
||||
elif isinstance(space, Tuple):
|
||||
return np.concatenate(
|
||||
[flatten(s, x_part) for x_part, s in zip(x, space.spaces)]
|
||||
)
|
||||
return np.concatenate([flatten(s, x_part) for x_part, s in zip(x, space.spaces)])
|
||||
elif isinstance(space, Dict):
|
||||
return np.concatenate([flatten(s, x[key]) for key, s in space.spaces.items()])
|
||||
elif isinstance(space, MultiBinary):
|
||||
@@ -79,17 +77,13 @@ def unflatten(space, x):
|
||||
elif isinstance(space, Tuple):
|
||||
dims = [flatdim(s) for s in space.spaces]
|
||||
list_flattened = np.split(x, np.cumsum(dims)[:-1])
|
||||
list_unflattened = [
|
||||
unflatten(s, flattened)
|
||||
for flattened, s in zip(list_flattened, space.spaces)
|
||||
]
|
||||
list_unflattened = [unflatten(s, flattened) for flattened, s in zip(list_flattened, space.spaces)]
|
||||
return tuple(list_unflattened)
|
||||
elif isinstance(space, Dict):
|
||||
dims = [flatdim(s) for s in space.spaces.values()]
|
||||
list_flattened = np.split(x, np.cumsum(dims)[:-1])
|
||||
list_unflattened = [
|
||||
(key, unflatten(s, flattened))
|
||||
for flattened, (key, s) in zip(list_flattened, space.spaces.items())
|
||||
(key, unflatten(s, flattened)) for flattened, (key, s) in zip(list_flattened, space.spaces.items())
|
||||
]
|
||||
return OrderedDict(list_unflattened)
|
||||
elif isinstance(space, MultiBinary):
|
||||
|
@@ -88,11 +88,7 @@ def play(env, transpose=True, fps=30, zoom=None, callback=None, keys_to_action=N
|
||||
elif hasattr(env.unwrapped, "get_keys_to_action"):
|
||||
keys_to_action = env.unwrapped.get_keys_to_action()
|
||||
else:
|
||||
assert False, (
|
||||
env.spec.id
|
||||
+ " does not have explicit key to action mapping, "
|
||||
+ "please specify one manually"
|
||||
)
|
||||
assert False, env.spec.id + " does not have explicit key to action mapping, " + "please specify one manually"
|
||||
relevant_keys = set(sum(map(list, keys_to_action.keys()), []))
|
||||
|
||||
video_size = [rendered.shape[1], rendered.shape[0]]
|
||||
@@ -172,9 +168,7 @@ class PlayPlot(object):
|
||||
for i, plot in enumerate(self.cur_plot):
|
||||
if plot is not None:
|
||||
plot.remove()
|
||||
self.cur_plot[i] = self.ax[i].scatter(
|
||||
range(xmin, xmax), list(self.data[i]), c="blue"
|
||||
)
|
||||
self.cur_plot[i] = self.ax[i].scatter(range(xmin, xmax), list(self.data[i]), c="blue")
|
||||
self.ax[i].set_xlim(xmin, xmax)
|
||||
plt.pause(0.000001)
|
||||
|
||||
|
@@ -10,9 +10,7 @@ from gym import error
|
||||
|
||||
def np_random(seed=None):
|
||||
if seed is not None and not (isinstance(seed, int) and 0 <= seed):
|
||||
raise error.Error(
|
||||
"Seed must be a non-negative integer or omitted, not {}".format(seed)
|
||||
)
|
||||
raise error.Error("Seed must be a non-negative integer or omitted, not {}".format(seed))
|
||||
|
||||
seed = create_seed(seed)
|
||||
|
||||
|
@@ -53,9 +53,7 @@ def make(id, num_envs=1, asynchronous=True, wrappers=None, **kwargs):
|
||||
if wrappers is not None:
|
||||
if callable(wrappers):
|
||||
env = wrappers(env)
|
||||
elif isinstance(wrappers, Iterable) and all(
|
||||
[callable(w) for w in wrappers]
|
||||
):
|
||||
elif isinstance(wrappers, Iterable) and all([callable(w) for w in wrappers]):
|
||||
for wrapper in wrappers:
|
||||
env = wrapper(env)
|
||||
else:
|
||||
|
@@ -107,12 +107,8 @@ class AsyncVectorEnv(VectorEnv):
|
||||
|
||||
if self.shared_memory:
|
||||
try:
|
||||
_obs_buffer = create_shared_memory(
|
||||
self.single_observation_space, n=self.num_envs, ctx=ctx
|
||||
)
|
||||
self.observations = read_from_shared_memory(
|
||||
_obs_buffer, self.single_observation_space, n=self.num_envs
|
||||
)
|
||||
_obs_buffer = create_shared_memory(self.single_observation_space, n=self.num_envs, ctx=ctx)
|
||||
self.observations = read_from_shared_memory(_obs_buffer, self.single_observation_space, n=self.num_envs)
|
||||
except CustomSpaceError:
|
||||
raise ValueError(
|
||||
"Using `shared_memory=True` in `AsyncVectorEnv` "
|
||||
@@ -124,9 +120,7 @@ class AsyncVectorEnv(VectorEnv):
|
||||
)
|
||||
else:
|
||||
_obs_buffer = None
|
||||
self.observations = create_empty_array(
|
||||
self.single_observation_space, n=self.num_envs, fn=np.zeros
|
||||
)
|
||||
self.observations = create_empty_array(self.single_observation_space, n=self.num_envs, fn=np.zeros)
|
||||
|
||||
self.parent_pipes, self.processes = [], []
|
||||
self.error_queue = ctx.Queue()
|
||||
@@ -168,8 +162,7 @@ class AsyncVectorEnv(VectorEnv):
|
||||
|
||||
if self._state != AsyncState.DEFAULT:
|
||||
raise AlreadyPendingCallError(
|
||||
"Calling `seed` while waiting "
|
||||
"for a pending call to `{0}` to complete.".format(self._state.value),
|
||||
"Calling `seed` while waiting " "for a pending call to `{0}` to complete.".format(self._state.value),
|
||||
self._state.value,
|
||||
)
|
||||
|
||||
@@ -182,8 +175,7 @@ class AsyncVectorEnv(VectorEnv):
|
||||
self._assert_is_running()
|
||||
if self._state != AsyncState.DEFAULT:
|
||||
raise AlreadyPendingCallError(
|
||||
"Calling `reset_async` while waiting "
|
||||
"for a pending call to `{0}` to complete".format(self._state.value),
|
||||
"Calling `reset_async` while waiting " "for a pending call to `{0}` to complete".format(self._state.value),
|
||||
self._state.value,
|
||||
)
|
||||
|
||||
@@ -214,8 +206,7 @@ class AsyncVectorEnv(VectorEnv):
|
||||
if not self._poll(timeout):
|
||||
self._state = AsyncState.DEFAULT
|
||||
raise mp.TimeoutError(
|
||||
"The call to `reset_wait` has timed out after "
|
||||
"{0} second{1}.".format(timeout, "s" if timeout > 1 else "")
|
||||
"The call to `reset_wait` has timed out after " "{0} second{1}.".format(timeout, "s" if timeout > 1 else "")
|
||||
)
|
||||
|
||||
results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
|
||||
@@ -223,9 +214,7 @@ class AsyncVectorEnv(VectorEnv):
|
||||
self._state = AsyncState.DEFAULT
|
||||
|
||||
if not self.shared_memory:
|
||||
self.observations = concatenate(
|
||||
results, self.observations, self.single_observation_space
|
||||
)
|
||||
self.observations = concatenate(results, self.observations, self.single_observation_space)
|
||||
|
||||
return deepcopy(self.observations) if self.copy else self.observations
|
||||
|
||||
@@ -239,8 +228,7 @@ class AsyncVectorEnv(VectorEnv):
|
||||
self._assert_is_running()
|
||||
if self._state != AsyncState.DEFAULT:
|
||||
raise AlreadyPendingCallError(
|
||||
"Calling `step_async` while waiting "
|
||||
"for a pending call to `{0}` to complete.".format(self._state.value),
|
||||
"Calling `step_async` while waiting " "for a pending call to `{0}` to complete.".format(self._state.value),
|
||||
self._state.value,
|
||||
)
|
||||
|
||||
@@ -280,8 +268,7 @@ class AsyncVectorEnv(VectorEnv):
|
||||
if not self._poll(timeout):
|
||||
self._state = AsyncState.DEFAULT
|
||||
raise mp.TimeoutError(
|
||||
"The call to `step_wait` has timed out after "
|
||||
"{0} second{1}.".format(timeout, "s" if timeout > 1 else "")
|
||||
"The call to `step_wait` has timed out after " "{0} second{1}.".format(timeout, "s" if timeout > 1 else "")
|
||||
)
|
||||
|
||||
results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
|
||||
@@ -290,9 +277,7 @@ class AsyncVectorEnv(VectorEnv):
|
||||
observations_list, rewards, dones, infos = zip(*results)
|
||||
|
||||
if not self.shared_memory:
|
||||
self.observations = concatenate(
|
||||
observations_list, self.observations, self.single_observation_space
|
||||
)
|
||||
self.observations = concatenate(observations_list, self.observations, self.single_observation_space)
|
||||
|
||||
return (
|
||||
deepcopy(self.observations) if self.copy else self.observations,
|
||||
@@ -318,8 +303,7 @@ class AsyncVectorEnv(VectorEnv):
|
||||
try:
|
||||
if self._state != AsyncState.DEFAULT:
|
||||
logger.warn(
|
||||
"Calling `close` while waiting for a pending "
|
||||
"call to `{0}` to complete.".format(self._state.value)
|
||||
"Calling `close` while waiting for a pending " "call to `{0}` to complete.".format(self._state.value)
|
||||
)
|
||||
function = getattr(self, "{0}_wait".format(self._state.value))
|
||||
function(timeout)
|
||||
@@ -375,8 +359,7 @@ class AsyncVectorEnv(VectorEnv):
|
||||
def _assert_is_running(self):
|
||||
if self.closed:
|
||||
raise ClosedEnvironmentError(
|
||||
"Trying to operate on `{0}`, after a "
|
||||
"call to `close()`.".format(type(self).__name__)
|
||||
"Trying to operate on `{0}`, after a " "call to `close()`.".format(type(self).__name__)
|
||||
)
|
||||
|
||||
def _raise_if_errors(self, successes):
|
||||
@@ -387,10 +370,7 @@ class AsyncVectorEnv(VectorEnv):
|
||||
assert num_errors > 0
|
||||
for _ in range(num_errors):
|
||||
index, exctype, value = self.error_queue.get()
|
||||
logger.error(
|
||||
"Received the following error from Worker-{0}: "
|
||||
"{1}: {2}".format(index, exctype.__name__, value)
|
||||
)
|
||||
logger.error("Received the following error from Worker-{0}: " "{1}: {2}".format(index, exctype.__name__, value))
|
||||
logger.error("Shutting down Worker-{0}.".format(index))
|
||||
self.parent_pipes[index].close()
|
||||
self.parent_pipes[index] = None
|
||||
@@ -445,17 +425,13 @@ def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error
|
||||
command, data = pipe.recv()
|
||||
if command == "reset":
|
||||
observation = env.reset()
|
||||
write_to_shared_memory(
|
||||
index, observation, shared_memory, observation_space
|
||||
)
|
||||
write_to_shared_memory(index, observation, shared_memory, observation_space)
|
||||
pipe.send((None, True))
|
||||
elif command == "step":
|
||||
observation, reward, done, info = env.step(data)
|
||||
if done:
|
||||
observation = env.reset()
|
||||
write_to_shared_memory(
|
||||
index, observation, shared_memory, observation_space
|
||||
)
|
||||
write_to_shared_memory(index, observation, shared_memory, observation_space)
|
||||
pipe.send(((None, reward, done, info), True))
|
||||
elif command == "seed":
|
||||
env.seed(data)
|
||||
|
@@ -44,9 +44,7 @@ class SyncVectorEnv(VectorEnv):
|
||||
)
|
||||
|
||||
self._check_observation_spaces()
|
||||
self.observations = create_empty_array(
|
||||
self.single_observation_space, n=self.num_envs, fn=np.zeros
|
||||
)
|
||||
self.observations = create_empty_array(self.single_observation_space, n=self.num_envs, fn=np.zeros)
|
||||
self._rewards = np.zeros((self.num_envs,), dtype=np.float64)
|
||||
self._dones = np.zeros((self.num_envs,), dtype=np.bool_)
|
||||
self._actions = None
|
||||
@@ -67,9 +65,7 @@ class SyncVectorEnv(VectorEnv):
|
||||
for env in self.envs:
|
||||
observation = env.reset()
|
||||
observations.append(observation)
|
||||
self.observations = concatenate(
|
||||
observations, self.observations, self.single_observation_space
|
||||
)
|
||||
self.observations = concatenate(observations, self.observations, self.single_observation_space)
|
||||
|
||||
return deepcopy(self.observations) if self.copy else self.observations
|
||||
|
||||
@@ -84,9 +80,7 @@ class SyncVectorEnv(VectorEnv):
|
||||
observation = env.reset()
|
||||
observations.append(observation)
|
||||
infos.append(info)
|
||||
self.observations = concatenate(
|
||||
observations, self.observations, self.single_observation_space
|
||||
)
|
||||
self.observations = concatenate(observations, self.observations, self.single_observation_space)
|
||||
|
||||
return (
|
||||
deepcopy(self.observations) if self.copy else self.observations,
|
||||
|
@@ -10,9 +10,7 @@ from gym.vector.tests.utils import spaces
|
||||
from gym.vector.utils.numpy_utils import concatenate, create_empty_array
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"space", spaces, ids=[space.__class__.__name__ for space in spaces]
|
||||
)
|
||||
@pytest.mark.parametrize("space", spaces, ids=[space.__class__.__name__ for space in spaces])
|
||||
def test_concatenate(space):
|
||||
def assert_type(lhs, rhs, n):
|
||||
# Special case: if rhs is a list of scalars, lhs must be an np.ndarray
|
||||
@@ -53,9 +51,7 @@ def test_concatenate(space):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n", [1, 8])
|
||||
@pytest.mark.parametrize(
|
||||
"space", spaces, ids=[space.__class__.__name__ for space in spaces]
|
||||
)
|
||||
@pytest.mark.parametrize("space", spaces, ids=[space.__class__.__name__ for space in spaces])
|
||||
def test_create_empty_array(space, n):
|
||||
def assert_nested_type(arr, space, n):
|
||||
if isinstance(space, _BaseGymSpaces):
|
||||
@@ -83,9 +79,7 @@ def test_create_empty_array(space, n):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n", [1, 8])
|
||||
@pytest.mark.parametrize(
|
||||
"space", spaces, ids=[space.__class__.__name__ for space in spaces]
|
||||
)
|
||||
@pytest.mark.parametrize("space", spaces, ids=[space.__class__.__name__ for space in spaces])
|
||||
def test_create_empty_array_zeros(space, n):
|
||||
def assert_nested_type(arr, space, n):
|
||||
if isinstance(space, _BaseGymSpaces):
|
||||
@@ -113,9 +107,7 @@ def test_create_empty_array_zeros(space, n):
|
||||
assert_nested_type(array, space, n=n)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"space", spaces, ids=[space.__class__.__name__ for space in spaces]
|
||||
)
|
||||
@pytest.mark.parametrize("space", spaces, ids=[space.__class__.__name__ for space in spaces])
|
||||
def test_create_empty_array_none_shape_ones(space):
|
||||
def assert_nested_type(arr, space):
|
||||
if isinstance(space, _BaseGymSpaces):
|
||||
|
@@ -46,9 +46,7 @@ expected_types = [
|
||||
list(zip(spaces, expected_types)),
|
||||
ids=[space.__class__.__name__ for space in spaces],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"ctx", [None, "fork", "spawn"], ids=["default", "fork", "spawn"]
|
||||
)
|
||||
@pytest.mark.parametrize("ctx", [None, "fork", "spawn"], ids=["default", "fork", "spawn"])
|
||||
def test_create_shared_memory(space, expected_type, n, ctx):
|
||||
def assert_nested_type(lhs, rhs, n):
|
||||
assert type(lhs) == type(rhs)
|
||||
@@ -77,9 +75,7 @@ def test_create_shared_memory(space, expected_type, n, ctx):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n", [1, 8])
|
||||
@pytest.mark.parametrize(
|
||||
"ctx", [None, "fork", "spawn"], ids=["default", "fork", "spawn"]
|
||||
)
|
||||
@pytest.mark.parametrize("ctx", [None, "fork", "spawn"], ids=["default", "fork", "spawn"])
|
||||
@pytest.mark.parametrize("space", custom_spaces)
|
||||
def test_create_shared_memory_custom_space(n, ctx, space):
|
||||
ctx = mp if (ctx is None) else mp.get_context(ctx)
|
||||
@@ -87,9 +83,7 @@ def test_create_shared_memory_custom_space(n, ctx, space):
|
||||
shared_memory = create_shared_memory(space, n=n, ctx=ctx)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"space", spaces, ids=[space.__class__.__name__ for space in spaces]
|
||||
)
|
||||
@pytest.mark.parametrize("space", spaces, ids=[space.__class__.__name__ for space in spaces])
|
||||
def test_write_to_shared_memory(space):
|
||||
def assert_nested_equal(lhs, rhs):
|
||||
assert isinstance(rhs, list)
|
||||
@@ -113,9 +107,7 @@ def test_write_to_shared_memory(space):
|
||||
shared_memory_n8 = create_shared_memory(space, n=8)
|
||||
samples = [space.sample() for _ in range(8)]
|
||||
|
||||
processes = [
|
||||
Process(target=write, args=(i, shared_memory_n8, samples[i])) for i in range(8)
|
||||
]
|
||||
processes = [Process(target=write, args=(i, shared_memory_n8, samples[i])) for i in range(8)]
|
||||
|
||||
for process in processes:
|
||||
process.start()
|
||||
@@ -125,25 +117,19 @@ def test_write_to_shared_memory(space):
|
||||
assert_nested_equal(shared_memory_n8, samples)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"space", spaces, ids=[space.__class__.__name__ for space in spaces]
|
||||
)
|
||||
@pytest.mark.parametrize("space", spaces, ids=[space.__class__.__name__ for space in spaces])
|
||||
def test_read_from_shared_memory(space):
|
||||
def assert_nested_equal(lhs, rhs, space, n):
|
||||
assert isinstance(rhs, list)
|
||||
if isinstance(space, Tuple):
|
||||
assert isinstance(lhs, tuple)
|
||||
for i in range(len(lhs)):
|
||||
assert_nested_equal(
|
||||
lhs[i], [rhs_[i] for rhs_ in rhs], space.spaces[i], n
|
||||
)
|
||||
assert_nested_equal(lhs[i], [rhs_[i] for rhs_ in rhs], space.spaces[i], n)
|
||||
|
||||
elif isinstance(space, Dict):
|
||||
assert isinstance(lhs, OrderedDict)
|
||||
for key in lhs.keys():
|
||||
assert_nested_equal(
|
||||
lhs[key], [rhs_[key] for rhs_ in rhs], space.spaces[key], n
|
||||
)
|
||||
assert_nested_equal(lhs[key], [rhs_[key] for rhs_ in rhs], space.spaces[key], n)
|
||||
|
||||
elif isinstance(space, _BaseGymSpaces):
|
||||
assert isinstance(lhs, np.ndarray)
|
||||
@@ -161,9 +147,7 @@ def test_read_from_shared_memory(space):
|
||||
memory_view_n8 = read_from_shared_memory(shared_memory_n8, space, n=8)
|
||||
samples = [space.sample() for _ in range(8)]
|
||||
|
||||
processes = [
|
||||
Process(target=write, args=(i, shared_memory_n8, samples[i])) for i in range(8)
|
||||
]
|
||||
processes = [Process(target=write, args=(i, shared_memory_n8, samples[i])) for i in range(8)]
|
||||
|
||||
for process in processes:
|
||||
process.start()
|
||||
|
@@ -10,12 +10,8 @@ expected_batch_spaces_4 = [
|
||||
Box(low=-1.0, high=1.0, shape=(4,), dtype=np.float64),
|
||||
Box(low=0.0, high=10.0, shape=(4, 1), dtype=np.float32),
|
||||
Box(
|
||||
low=np.array(
|
||||
[[-1.0, 0.0, 0.0], [-1.0, 0.0, 0.0], [-1.0, 0.0, 0.0], [-1.0, 0.0, 0.0]]
|
||||
),
|
||||
high=np.array(
|
||||
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
|
||||
),
|
||||
low=np.array([[-1.0, 0.0, 0.0], [-1.0, 0.0, 0.0], [-1.0, 0.0, 0.0], [-1.0, 0.0, 0.0]]),
|
||||
high=np.array([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]),
|
||||
dtype=np.float32,
|
||||
),
|
||||
Box(
|
||||
|
@@ -7,12 +7,8 @@ from gym.spaces import Box, Discrete, MultiDiscrete, MultiBinary, Tuple, Dict
|
||||
spaces = [
|
||||
Box(low=np.array(-1.0), high=np.array(1.0), dtype=np.float64),
|
||||
Box(low=np.array([0.0]), high=np.array([10.0]), dtype=np.float32),
|
||||
Box(
|
||||
low=np.array([-1.0, 0.0, 0.0]), high=np.array([1.0, 1.0, 1.0]), dtype=np.float32
|
||||
),
|
||||
Box(
|
||||
low=np.array([[-1.0, 0.0], [0.0, -1.0]]), high=np.ones((2, 2)), dtype=np.float32
|
||||
),
|
||||
Box(low=np.array([-1.0, 0.0, 0.0]), high=np.array([1.0, 1.0, 1.0]), dtype=np.float32),
|
||||
Box(low=np.array([[-1.0, 0.0], [0.0, -1.0]]), high=np.ones((2, 2)), dtype=np.float32),
|
||||
Box(low=0, high=255, shape=(), dtype=np.uint8),
|
||||
Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8),
|
||||
Discrete(2),
|
||||
@@ -28,17 +24,13 @@ spaces = [
|
||||
Dict(
|
||||
{
|
||||
"position": Discrete(23),
|
||||
"velocity": Box(
|
||||
low=np.array([0.0]), high=np.array([1.0]), dtype=np.float32
|
||||
),
|
||||
"velocity": Box(low=np.array([0.0]), high=np.array([1.0]), dtype=np.float32),
|
||||
}
|
||||
),
|
||||
Dict(
|
||||
{
|
||||
"position": Dict({"x": Discrete(29), "y": Discrete(31)}),
|
||||
"velocity": Tuple(
|
||||
(Discrete(37), Box(low=0, high=255, shape=(), dtype=np.uint8))
|
||||
),
|
||||
"velocity": Tuple((Discrete(37), Box(low=0, high=255, shape=(), dtype=np.uint8))),
|
||||
}
|
||||
),
|
||||
]
|
||||
@@ -50,9 +42,7 @@ class UnittestSlowEnv(gym.Env):
|
||||
def __init__(self, slow_reset=0.3):
|
||||
super(UnittestSlowEnv, self).__init__()
|
||||
self.slow_reset = slow_reset
|
||||
self.observation_space = Box(
|
||||
low=0, high=255, shape=(HEIGHT, WIDTH, 3), dtype=np.uint8
|
||||
)
|
||||
self.observation_space = Box(low=0, high=255, shape=(HEIGHT, WIDTH, 3), dtype=np.uint8)
|
||||
self.action_space = Box(low=0.0, high=1.0, shape=(), dtype=np.float32)
|
||||
|
||||
def reset(self):
|
||||
|
@@ -46,10 +46,7 @@ def concatenate(items, out, space):
|
||||
elif isinstance(space, Space):
|
||||
return concatenate_custom(items, out, space)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Space of type `{0}` is not a valid `gym.Space` "
|
||||
"instance.".format(type(space))
|
||||
)
|
||||
raise ValueError("Space of type `{0}` is not a valid `gym.Space` " "instance.".format(type(space)))
|
||||
|
||||
|
||||
def concatenate_base(items, out, space):
|
||||
@@ -57,18 +54,12 @@ def concatenate_base(items, out, space):
|
||||
|
||||
|
||||
def concatenate_tuple(items, out, space):
|
||||
return tuple(
|
||||
concatenate([item[i] for item in items], out[i], subspace)
|
||||
for (i, subspace) in enumerate(space.spaces)
|
||||
)
|
||||
return tuple(concatenate([item[i] for item in items], out[i], subspace) for (i, subspace) in enumerate(space.spaces))
|
||||
|
||||
|
||||
def concatenate_dict(items, out, space):
|
||||
return OrderedDict(
|
||||
[
|
||||
(key, concatenate([item[key] for item in items], out[key], subspace))
|
||||
for (key, subspace) in space.spaces.items()
|
||||
]
|
||||
[(key, concatenate([item[key] for item in items], out[key], subspace)) for (key, subspace) in space.spaces.items()]
|
||||
)
|
||||
|
||||
|
||||
@@ -118,10 +109,7 @@ def create_empty_array(space, n=1, fn=np.zeros):
|
||||
elif isinstance(space, Space):
|
||||
return create_empty_array_custom(space, n=n, fn=fn)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Space of type `{0}` is not a valid `gym.Space` "
|
||||
"instance.".format(type(space))
|
||||
)
|
||||
raise ValueError("Space of type `{0}` is not a valid `gym.Space` " "instance.".format(type(space)))
|
||||
|
||||
|
||||
def create_empty_array_base(space, n=1, fn=np.zeros):
|
||||
@@ -134,12 +122,7 @@ def create_empty_array_tuple(space, n=1, fn=np.zeros):
|
||||
|
||||
|
||||
def create_empty_array_dict(space, n=1, fn=np.zeros):
|
||||
return OrderedDict(
|
||||
[
|
||||
(key, create_empty_array(subspace, n=n, fn=fn))
|
||||
for (key, subspace) in space.spaces.items()
|
||||
]
|
||||
)
|
||||
return OrderedDict([(key, create_empty_array(subspace, n=n, fn=fn)) for (key, subspace) in space.spaces.items()])
|
||||
|
||||
|
||||
def create_empty_array_custom(space, n=1, fn=np.zeros):
|
||||
|
@@ -56,18 +56,11 @@ def create_base_shared_memory(space, n=1, ctx=mp):
|
||||
|
||||
|
||||
def create_tuple_shared_memory(space, n=1, ctx=mp):
|
||||
return tuple(
|
||||
create_shared_memory(subspace, n=n, ctx=ctx) for subspace in space.spaces
|
||||
)
|
||||
return tuple(create_shared_memory(subspace, n=n, ctx=ctx) for subspace in space.spaces)
|
||||
|
||||
|
||||
def create_dict_shared_memory(space, n=1, ctx=mp):
|
||||
return OrderedDict(
|
||||
[
|
||||
(key, create_shared_memory(subspace, n=n, ctx=ctx))
|
||||
for (key, subspace) in space.spaces.items()
|
||||
]
|
||||
)
|
||||
return OrderedDict([(key, create_shared_memory(subspace, n=n, ctx=ctx)) for (key, subspace) in space.spaces.items()])
|
||||
|
||||
|
||||
def read_from_shared_memory(shared_memory, space, n=1):
|
||||
@@ -114,24 +107,16 @@ def read_from_shared_memory(shared_memory, space, n=1):
|
||||
|
||||
|
||||
def read_base_from_shared_memory(shared_memory, space, n=1):
|
||||
return np.frombuffer(shared_memory.get_obj(), dtype=space.dtype).reshape(
|
||||
(n,) + space.shape
|
||||
)
|
||||
return np.frombuffer(shared_memory.get_obj(), dtype=space.dtype).reshape((n,) + space.shape)
|
||||
|
||||
|
||||
def read_tuple_from_shared_memory(shared_memory, space, n=1):
|
||||
return tuple(
|
||||
read_from_shared_memory(memory, subspace, n=n)
|
||||
for (memory, subspace) in zip(shared_memory, space.spaces)
|
||||
)
|
||||
return tuple(read_from_shared_memory(memory, subspace, n=n) for (memory, subspace) in zip(shared_memory, space.spaces))
|
||||
|
||||
|
||||
def read_dict_from_shared_memory(shared_memory, space, n=1):
|
||||
return OrderedDict(
|
||||
[
|
||||
(key, read_from_shared_memory(shared_memory[key], subspace, n=n))
|
||||
for (key, subspace) in space.spaces.items()
|
||||
]
|
||||
[(key, read_from_shared_memory(shared_memory[key], subspace, n=n)) for (key, subspace) in space.spaces.items()]
|
||||
)
|
||||
|
||||
|
||||
|
@@ -44,8 +44,7 @@ def batch_space(space, n=1):
|
||||
return batch_space_custom(space, n=n)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Cannot batch space with type `{0}`. The space must "
|
||||
"be a valid `gym.Space` instance.".format(type(space))
|
||||
"Cannot batch space with type `{0}`. The space must " "be a valid `gym.Space` instance.".format(type(space))
|
||||
)
|
||||
|
||||
|
||||
@@ -75,14 +74,7 @@ def batch_space_tuple(space, n=1):
|
||||
|
||||
|
||||
def batch_space_dict(space, n=1):
|
||||
return Dict(
|
||||
OrderedDict(
|
||||
[
|
||||
(key, batch_space(subspace, n=n))
|
||||
for (key, subspace) in space.spaces.items()
|
||||
]
|
||||
)
|
||||
)
|
||||
return Dict(OrderedDict([(key, batch_space(subspace, n=n)) for (key, subspace) in space.spaces.items()]))
|
||||
|
||||
|
||||
def batch_space_custom(space, n=1):
|
||||
|
@@ -141,9 +141,7 @@ class VectorEnv(gym.Env):
|
||||
if self.spec is None:
|
||||
return "{}({})".format(self.__class__.__name__, self.num_envs)
|
||||
else:
|
||||
return "{}({}, {})".format(
|
||||
self.__class__.__name__, self.spec.id, self.num_envs
|
||||
)
|
||||
return "{}({}, {})".format(self.__class__.__name__, self.spec.id, self.num_envs)
|
||||
|
||||
|
||||
class VectorEnvWrapper(VectorEnv):
|
||||
@@ -189,9 +187,7 @@ class VectorEnvWrapper(VectorEnv):
|
||||
# implicitly forward all other methods and attributes to self.env
|
||||
def __getattr__(self, name):
|
||||
if name.startswith("_"):
|
||||
raise AttributeError(
|
||||
"attempted to get missing private attribute '{}'".format(name)
|
||||
)
|
||||
raise AttributeError("attempted to get missing private attribute '{}'".format(name))
|
||||
return getattr(self.env, name)
|
||||
|
||||
@property
|
||||
|
@@ -62,12 +62,13 @@ class AtariPreprocessing(gym.Wrapper):
|
||||
assert noop_max >= 0
|
||||
if frame_skip > 1:
|
||||
assert "NoFrameskip" in env.spec.id, (
|
||||
"disable frame-skipping in the original env. for more than one"
|
||||
" frame-skip as it will be done by the wrapper"
|
||||
"disable frame-skipping in the original env. for more than one" " frame-skip as it will be done by the wrapper"
|
||||
)
|
||||
self.noop_max = noop_max
|
||||
assert env.unwrapped.get_action_meanings()[0] == "NOOP"
|
||||
warnings.warn("Gym\'s internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit")
|
||||
warnings.warn(
|
||||
"Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit"
|
||||
)
|
||||
|
||||
self.frame_skip = frame_skip
|
||||
self.screen_size = screen_size
|
||||
@@ -92,15 +93,11 @@ class AtariPreprocessing(gym.Wrapper):
|
||||
self.lives = 0
|
||||
self.game_over = False
|
||||
|
||||
_low, _high, _obs_dtype = (
|
||||
(0, 255, np.uint8) if not scale_obs else (0, 1, np.float32)
|
||||
)
|
||||
_low, _high, _obs_dtype = (0, 255, np.uint8) if not scale_obs else (0, 1, np.float32)
|
||||
_shape = (screen_size, screen_size, 1 if grayscale_obs else 3)
|
||||
if grayscale_obs and not grayscale_newaxis:
|
||||
_shape = _shape[:-1] # Remove channel axis
|
||||
self.observation_space = Box(
|
||||
low=_low, high=_high, shape=_shape, dtype=_obs_dtype
|
||||
)
|
||||
self.observation_space = Box(low=_low, high=_high, shape=_shape, dtype=_obs_dtype)
|
||||
|
||||
def step(self, action):
|
||||
R = 0.0
|
||||
@@ -132,11 +129,7 @@ class AtariPreprocessing(gym.Wrapper):
|
||||
def reset(self, **kwargs):
|
||||
# NoopReset
|
||||
self.env.reset(**kwargs)
|
||||
noops = (
|
||||
self.env.unwrapped.np_random.randint(1, self.noop_max + 1)
|
||||
if self.noop_max > 0
|
||||
else 0
|
||||
)
|
||||
noops = self.env.unwrapped.np_random.randint(1, self.noop_max + 1) if self.noop_max > 0 else 0
|
||||
for _ in range(noops):
|
||||
_, _, done, _ = self.env.step(0)
|
||||
if done:
|
||||
|
@@ -9,7 +9,9 @@ class ClipAction(ActionWrapper):
|
||||
|
||||
def __init__(self, env):
|
||||
assert isinstance(env.action_space, Box)
|
||||
warnings.warn("Gym\'s internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit")
|
||||
warnings.warn(
|
||||
"Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit"
|
||||
)
|
||||
super(ClipAction, self).__init__(env)
|
||||
|
||||
def action(self, action):
|
||||
|
@@ -26,7 +26,9 @@ class FilterObservation(ObservationWrapper):
|
||||
assert isinstance(
|
||||
wrapped_observation_space, spaces.Dict
|
||||
), "FilterObservationWrapper is only usable with dict observations."
|
||||
warnings.warn("Gym\'s internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit")
|
||||
warnings.warn(
|
||||
"Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit"
|
||||
)
|
||||
|
||||
observation_keys = wrapped_observation_space.spaces.keys()
|
||||
|
||||
@@ -49,11 +51,7 @@ class FilterObservation(ObservationWrapper):
|
||||
)
|
||||
|
||||
self.observation_space = type(wrapped_observation_space)(
|
||||
[
|
||||
(name, copy.deepcopy(space))
|
||||
for name, space in wrapped_observation_space.spaces.items()
|
||||
if name in filter_keys
|
||||
]
|
||||
[(name, copy.deepcopy(space)) for name, space in wrapped_observation_space.spaces.items() if name in filter_keys]
|
||||
)
|
||||
|
||||
self._env = env
|
||||
@@ -64,11 +62,5 @@ class FilterObservation(ObservationWrapper):
|
||||
return filter_observation
|
||||
|
||||
def _filter_observation(self, observation):
|
||||
observation = type(observation)(
|
||||
[
|
||||
(name, value)
|
||||
for name, value in observation.items()
|
||||
if name in self._filter_keys
|
||||
]
|
||||
)
|
||||
observation = type(observation)([(name, value) for name, value in observation.items() if name in self._filter_keys])
|
||||
return observation
|
||||
|
@@ -9,7 +9,9 @@ class FlattenObservation(ObservationWrapper):
|
||||
def __init__(self, env):
|
||||
super(FlattenObservation, self).__init__(env)
|
||||
self.observation_space = spaces.flatten_space(env.observation_space)
|
||||
warnings.warn("Gym\'s internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit")
|
||||
warnings.warn(
|
||||
"Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit"
|
||||
)
|
||||
|
||||
def observation(self, observation):
|
||||
return spaces.flatten(self.env.observation_space, observation)
|
||||
|
@@ -22,7 +22,9 @@ class LazyFrames(object):
|
||||
__slots__ = ("frame_shape", "dtype", "shape", "lz4_compress", "_frames")
|
||||
|
||||
def __init__(self, frames, lz4_compress=False):
|
||||
warnings.warn("Gym\'s internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit")
|
||||
warnings.warn(
|
||||
"Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit"
|
||||
)
|
||||
self.frame_shape = tuple(frames[0].shape)
|
||||
self.shape = (len(frames),) + self.frame_shape
|
||||
self.dtype = frames[0].dtype
|
||||
@@ -45,9 +47,7 @@ class LazyFrames(object):
|
||||
def __getitem__(self, int_or_slice):
|
||||
if isinstance(int_or_slice, int):
|
||||
return self._check_decompress(self._frames[int_or_slice]) # single frame
|
||||
return np.stack(
|
||||
[self._check_decompress(f) for f in self._frames[int_or_slice]], axis=0
|
||||
)
|
||||
return np.stack([self._check_decompress(f) for f in self._frames[int_or_slice]], axis=0)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.__array__() == other
|
||||
@@ -56,9 +56,7 @@ class LazyFrames(object):
|
||||
if self.lz4_compress:
|
||||
from lz4.block import decompress
|
||||
|
||||
return np.frombuffer(decompress(frame), dtype=self.dtype).reshape(
|
||||
self.frame_shape
|
||||
)
|
||||
return np.frombuffer(decompress(frame), dtype=self.dtype).reshape(self.frame_shape)
|
||||
return frame
|
||||
|
||||
|
||||
@@ -102,12 +100,8 @@ class FrameStack(Wrapper):
|
||||
self.frames = deque(maxlen=num_stack)
|
||||
|
||||
low = np.repeat(self.observation_space.low[np.newaxis, ...], num_stack, axis=0)
|
||||
high = np.repeat(
|
||||
self.observation_space.high[np.newaxis, ...], num_stack, axis=0
|
||||
)
|
||||
self.observation_space = Box(
|
||||
low=low, high=high, dtype=self.observation_space.dtype
|
||||
)
|
||||
high = np.repeat(self.observation_space.high[np.newaxis, ...], num_stack, axis=0)
|
||||
self.observation_space = Box(low=low, high=high, dtype=self.observation_space.dtype)
|
||||
|
||||
def _get_observation(self):
|
||||
assert len(self.frames) == self.num_stack, (len(self.frames), self.num_stack)
|
||||
|
@@ -11,20 +11,15 @@ class GrayScaleObservation(ObservationWrapper):
|
||||
super(GrayScaleObservation, self).__init__(env)
|
||||
self.keep_dim = keep_dim
|
||||
|
||||
assert (
|
||||
len(env.observation_space.shape) == 3
|
||||
and env.observation_space.shape[-1] == 3
|
||||
assert len(env.observation_space.shape) == 3 and env.observation_space.shape[-1] == 3
|
||||
warnings.warn(
|
||||
"Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit"
|
||||
)
|
||||
warnings.warn("Gym\'s internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit")
|
||||
obs_shape = self.observation_space.shape[:2]
|
||||
if self.keep_dim:
|
||||
self.observation_space = Box(
|
||||
low=0, high=255, shape=(obs_shape[0], obs_shape[1], 1), dtype=np.uint8
|
||||
)
|
||||
self.observation_space = Box(low=0, high=255, shape=(obs_shape[0], obs_shape[1], 1), dtype=np.uint8)
|
||||
else:
|
||||
self.observation_space = Box(
|
||||
low=0, high=255, shape=obs_shape, dtype=np.uint8
|
||||
)
|
||||
self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype=np.uint8)
|
||||
|
||||
def observation(self, observation):
|
||||
import cv2
|
||||
|
@@ -37,9 +37,7 @@ class Monitor(Wrapper):
|
||||
self._monitor_id = None
|
||||
self.env_semantics_autoreset = env.metadata.get("semantics.autoreset")
|
||||
|
||||
self._start(
|
||||
directory, video_callable, force, resume, write_upon_reset, uid, mode
|
||||
)
|
||||
self._start(directory, video_callable, force, resume, write_upon_reset, uid, mode)
|
||||
|
||||
def step(self, action):
|
||||
self._before_step(action)
|
||||
@@ -163,10 +161,7 @@ class Monitor(Wrapper):
|
||||
json.dump(
|
||||
{
|
||||
"stats": os.path.basename(self.stats_recorder.path),
|
||||
"videos": [
|
||||
(os.path.basename(v), os.path.basename(m))
|
||||
for v, m in self.videos
|
||||
],
|
||||
"videos": [(os.path.basename(v), os.path.basename(m)) for v, m in self.videos],
|
||||
"env_info": self._env_info(),
|
||||
},
|
||||
f,
|
||||
@@ -199,9 +194,7 @@ class Monitor(Wrapper):
|
||||
elif mode == "training":
|
||||
type = "t"
|
||||
else:
|
||||
raise error.Error(
|
||||
'Invalid mode {}: must be "training" or "evaluation"', mode
|
||||
)
|
||||
raise error.Error('Invalid mode {}: must be "training" or "evaluation"', mode)
|
||||
self.stats_recorder.type = type
|
||||
|
||||
def _before_step(self, action):
|
||||
@@ -257,9 +250,7 @@ class Monitor(Wrapper):
|
||||
env=self.env,
|
||||
base_path=os.path.join(
|
||||
self.directory,
|
||||
"{}.video.{}.video{:06}".format(
|
||||
self.file_prefix, self.file_infix, self.episode_id
|
||||
),
|
||||
"{}.video.{}.video{:06}".format(self.file_prefix, self.file_infix, self.episode_id),
|
||||
),
|
||||
metadata={"episode_id": self.episode_id},
|
||||
enabled=self._video_enabled(),
|
||||
@@ -269,9 +260,7 @@ class Monitor(Wrapper):
|
||||
def _close_video_recorder(self):
|
||||
self.video_recorder.close()
|
||||
if self.video_recorder.functional:
|
||||
self.videos.append(
|
||||
(self.video_recorder.path, self.video_recorder.metadata_path)
|
||||
)
|
||||
self.videos.append((self.video_recorder.path, self.video_recorder.metadata_path))
|
||||
|
||||
def _video_enabled(self):
|
||||
return self.video_callable(self.episode_id)
|
||||
@@ -301,19 +290,11 @@ class Monitor(Wrapper):
|
||||
def detect_training_manifests(training_dir, files=None):
|
||||
if files is None:
|
||||
files = os.listdir(training_dir)
|
||||
return [
|
||||
os.path.join(training_dir, f)
|
||||
for f in files
|
||||
if f.startswith(MANIFEST_PREFIX + ".")
|
||||
]
|
||||
return [os.path.join(training_dir, f) for f in files if f.startswith(MANIFEST_PREFIX + ".")]
|
||||
|
||||
|
||||
def detect_monitor_files(training_dir):
|
||||
return [
|
||||
os.path.join(training_dir, f)
|
||||
for f in os.listdir(training_dir)
|
||||
if f.startswith(FILE_PREFIX + ".")
|
||||
]
|
||||
return [os.path.join(training_dir, f) for f in os.listdir(training_dir) if f.startswith(FILE_PREFIX + ".")]
|
||||
|
||||
|
||||
def clear_monitor_files(training_dir):
|
||||
@@ -382,10 +363,7 @@ def load_results(training_dir):
|
||||
contents = json.load(f)
|
||||
# Make these paths absolute again
|
||||
stats_files.append(os.path.join(training_dir, contents["stats"]))
|
||||
videos += [
|
||||
(os.path.join(training_dir, v), os.path.join(training_dir, m))
|
||||
for v, m in contents["videos"]
|
||||
]
|
||||
videos += [(os.path.join(training_dir, v), os.path.join(training_dir, m)) for v, m in contents["videos"]]
|
||||
env_infos.append(contents["env_info"])
|
||||
|
||||
env_info = collapse_env_infos(env_infos, training_dir)
|
||||
|
@@ -48,9 +48,7 @@ class VideoRecorder(object):
|
||||
self.ansi_mode = True
|
||||
else:
|
||||
logger.info(
|
||||
'Disabling video recorder because {} neither supports video mode "rgb_array" nor "ansi".'.format(
|
||||
env
|
||||
)
|
||||
'Disabling video recorder because {} neither supports video mode "rgb_array" nor "ansi".'.format(env)
|
||||
)
|
||||
# Whoops, turns out we shouldn't be enabled after all
|
||||
self.enabled = False
|
||||
@@ -69,9 +67,7 @@ class VideoRecorder(object):
|
||||
path = base_path + required_ext
|
||||
else:
|
||||
# Otherwise, just generate a unique filename
|
||||
with tempfile.NamedTemporaryFile(
|
||||
suffix=required_ext, delete=False
|
||||
) as f:
|
||||
with tempfile.NamedTemporaryFile(suffix=required_ext, delete=False) as f:
|
||||
path = f.name
|
||||
self.path = path
|
||||
|
||||
@@ -83,28 +79,20 @@ class VideoRecorder(object):
|
||||
if self.ansi_mode
|
||||
else ""
|
||||
)
|
||||
raise error.Error(
|
||||
"Invalid path given: {} -- must have file extension {}.{}".format(
|
||||
self.path, required_ext, hint
|
||||
)
|
||||
)
|
||||
raise error.Error("Invalid path given: {} -- must have file extension {}.{}".format(self.path, required_ext, hint))
|
||||
# Touch the file in any case, so we know it's present. (This
|
||||
# corrects for platform platform differences. Using ffmpeg on
|
||||
# OS X, the file is precreated, but not on Linux.
|
||||
touch(path)
|
||||
|
||||
self.frames_per_sec = env.metadata.get("video.frames_per_second", 30)
|
||||
self.output_frames_per_sec = env.metadata.get(
|
||||
"video.output_frames_per_second", self.frames_per_sec
|
||||
)
|
||||
self.output_frames_per_sec = env.metadata.get("video.output_frames_per_second", self.frames_per_sec)
|
||||
self.encoder = None # lazily start the process
|
||||
self.broken = False
|
||||
|
||||
# Dump metadata
|
||||
self.metadata = metadata or {}
|
||||
self.metadata["content_type"] = (
|
||||
"video/vnd.openai.ansivid" if self.ansi_mode else "video/mp4"
|
||||
)
|
||||
self.metadata["content_type"] = "video/vnd.openai.ansivid" if self.ansi_mode else "video/mp4"
|
||||
self.metadata_path = "{}.meta.json".format(path_base)
|
||||
self.write_metadata()
|
||||
|
||||
@@ -191,9 +179,7 @@ class VideoRecorder(object):
|
||||
|
||||
def _encode_image_frame(self, frame):
|
||||
if not self.encoder:
|
||||
self.encoder = ImageEncoder(
|
||||
self.path, frame.shape, self.frames_per_sec, self.output_frames_per_sec
|
||||
)
|
||||
self.encoder = ImageEncoder(self.path, frame.shape, self.frames_per_sec, self.output_frames_per_sec)
|
||||
self.metadata["encoder_version"] = self.encoder.version_info
|
||||
|
||||
try:
|
||||
@@ -222,24 +208,16 @@ class TextEncoder(object):
|
||||
string = frame.getvalue()
|
||||
else:
|
||||
raise error.InvalidFrame(
|
||||
"Wrong type {} for {}: text frame must be a string or StringIO".format(
|
||||
type(frame), frame
|
||||
)
|
||||
"Wrong type {} for {}: text frame must be a string or StringIO".format(type(frame), frame)
|
||||
)
|
||||
|
||||
frame_bytes = string.encode("utf-8")
|
||||
|
||||
if frame_bytes[-1:] != b"\n":
|
||||
raise error.InvalidFrame(
|
||||
'Frame must end with a newline: """{}"""'.format(string)
|
||||
)
|
||||
raise error.InvalidFrame('Frame must end with a newline: """{}"""'.format(string))
|
||||
|
||||
if b"\r" in frame_bytes:
|
||||
raise error.InvalidFrame(
|
||||
'Frame contains carriage returns (only newlines are allowed: """{}"""'.format(
|
||||
string
|
||||
)
|
||||
)
|
||||
raise error.InvalidFrame('Frame contains carriage returns (only newlines are allowed: """{}"""'.format(string))
|
||||
|
||||
self.frames.append(frame_bytes)
|
||||
|
||||
@@ -263,15 +241,7 @@ class TextEncoder(object):
|
||||
# Calculate frame size from the largest frames.
|
||||
# Add some padding since we'll get cut off otherwise.
|
||||
height = max([frame.count(b"\n") for frame in self.frames]) + 1
|
||||
width = (
|
||||
max(
|
||||
[
|
||||
max([len(line) for line in frame.split(b"\n")])
|
||||
for frame in self.frames
|
||||
]
|
||||
)
|
||||
+ 2
|
||||
)
|
||||
width = max([max([len(line) for line in frame.split(b"\n")]) for frame in self.frames]) + 2
|
||||
|
||||
data = {
|
||||
"version": 1,
|
||||
@@ -325,11 +295,7 @@ class ImageEncoder(object):
|
||||
def version_info(self):
|
||||
return {
|
||||
"backend": self.backend,
|
||||
"version": str(
|
||||
subprocess.check_output(
|
||||
[self.backend, "-version"], stderr=subprocess.STDOUT
|
||||
)
|
||||
),
|
||||
"version": str(subprocess.check_output([self.backend, "-version"], stderr=subprocess.STDOUT)),
|
||||
"cmdline": self.cmdline,
|
||||
}
|
||||
|
||||
@@ -396,19 +362,13 @@ class ImageEncoder(object):
|
||||
|
||||
logger.debug('Starting %s with "%s"', self.backend, " ".join(self.cmdline))
|
||||
if hasattr(os, "setsid"): # setsid not present on Windows
|
||||
self.proc = subprocess.Popen(
|
||||
self.cmdline, stdin=subprocess.PIPE, preexec_fn=os.setsid
|
||||
)
|
||||
self.proc = subprocess.Popen(self.cmdline, stdin=subprocess.PIPE, preexec_fn=os.setsid)
|
||||
else:
|
||||
self.proc = subprocess.Popen(self.cmdline, stdin=subprocess.PIPE)
|
||||
|
||||
def capture_frame(self, frame):
|
||||
if not isinstance(frame, (np.ndarray, np.generic)):
|
||||
raise error.InvalidFrame(
|
||||
"Wrong type {} for {} (must be np.ndarray or np.generic)".format(
|
||||
type(frame), frame
|
||||
)
|
||||
)
|
||||
raise error.InvalidFrame("Wrong type {} for {} (must be np.ndarray or np.generic)".format(type(frame), frame))
|
||||
if frame.shape != self.frame_shape:
|
||||
raise error.InvalidFrame(
|
||||
"Your frame has shape {}, but the VideoRecorder is configured for shape {}.".format(
|
||||
@@ -417,15 +377,11 @@ class ImageEncoder(object):
|
||||
)
|
||||
if frame.dtype != np.uint8:
|
||||
raise error.InvalidFrame(
|
||||
"Your frame has data type {}, but we require uint8 (i.e. RGB values from 0-255).".format(
|
||||
frame.dtype
|
||||
)
|
||||
"Your frame has data type {}, but we require uint8 (i.e. RGB values from 0-255).".format(frame.dtype)
|
||||
)
|
||||
|
||||
try:
|
||||
if distutils.version.LooseVersion(
|
||||
np.__version__
|
||||
) >= distutils.version.LooseVersion("1.9.0"):
|
||||
if distutils.version.LooseVersion(np.__version__) >= distutils.version.LooseVersion("1.9.0"):
|
||||
self.proc.stdin.write(frame.tobytes())
|
||||
else:
|
||||
self.proc.stdin.write(frame.tostring())
|
||||
|
@@ -14,9 +14,7 @@ STATE_KEY = "state"
|
||||
class PixelObservationWrapper(ObservationWrapper):
|
||||
"""Augment observations by pixel values."""
|
||||
|
||||
def __init__(
|
||||
self, env, pixels_only=True, render_kwargs=None, pixel_keys=("pixels",)
|
||||
):
|
||||
def __init__(self, env, pixels_only=True, render_kwargs=None, pixel_keys=("pixels",)):
|
||||
"""Initializes a new pixel Wrapper.
|
||||
|
||||
Args:
|
||||
@@ -52,7 +50,9 @@ class PixelObservationWrapper(ObservationWrapper):
|
||||
assert render_mode == "rgb_array", render_mode
|
||||
render_kwargs[key]["mode"] = "rgb_array"
|
||||
|
||||
warnings.warn("Gym\'s internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit")
|
||||
warnings.warn(
|
||||
"Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit"
|
||||
)
|
||||
|
||||
wrapped_observation_space = env.observation_space
|
||||
|
||||
@@ -70,9 +70,7 @@ class PixelObservationWrapper(ObservationWrapper):
|
||||
# `observation_keys`
|
||||
overlapping_keys = set(pixel_keys) & set(invalid_keys)
|
||||
if overlapping_keys:
|
||||
raise ValueError(
|
||||
"Duplicate or reserved pixel keys {!r}.".format(overlapping_keys)
|
||||
)
|
||||
raise ValueError("Duplicate or reserved pixel keys {!r}.".format(overlapping_keys))
|
||||
|
||||
if pixels_only:
|
||||
self.observation_space = spaces.Dict()
|
||||
@@ -95,9 +93,7 @@ class PixelObservationWrapper(ObservationWrapper):
|
||||
else:
|
||||
raise TypeError(pixels.dtype)
|
||||
|
||||
pixels_space = spaces.Box(
|
||||
shape=pixels.shape, low=low, high=high, dtype=pixels.dtype
|
||||
)
|
||||
pixels_space = spaces.Box(shape=pixels.shape, low=low, high=high, dtype=pixels.dtype)
|
||||
pixels_spaces[pixel_key] = pixels_space
|
||||
|
||||
self.observation_space.spaces.update(pixels_spaces)
|
||||
@@ -120,10 +116,7 @@ class PixelObservationWrapper(ObservationWrapper):
|
||||
observation = collections.OrderedDict()
|
||||
observation[STATE_KEY] = wrapped_observation
|
||||
|
||||
pixel_observations = {
|
||||
pixel_key: self.env.render(**self._render_kwargs[pixel_key])
|
||||
for pixel_key in self._pixel_keys
|
||||
}
|
||||
pixel_observations = {pixel_key: self.env.render(**self._render_kwargs[pixel_key]) for pixel_key in self._pixel_keys}
|
||||
|
||||
observation.update(pixel_observations)
|
||||
|
||||
|
@@ -7,14 +7,14 @@ import gym
|
||||
class RecordEpisodeStatistics(gym.Wrapper):
|
||||
def __init__(self, env, deque_size=100):
|
||||
super(RecordEpisodeStatistics, self).__init__(env)
|
||||
self.t0 = (
|
||||
time.time()
|
||||
) # TODO: use perf_counter when gym removes Python 2 support
|
||||
self.t0 = time.time() # TODO: use perf_counter when gym removes Python 2 support
|
||||
self.episode_return = 0.0
|
||||
self.episode_length = 0
|
||||
self.return_queue = deque(maxlen=deque_size)
|
||||
self.length_queue = deque(maxlen=deque_size)
|
||||
warnings.warn("Gym\'s internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit")
|
||||
warnings.warn(
|
||||
"Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit"
|
||||
)
|
||||
|
||||
def reset(self, **kwargs):
|
||||
observation = super(RecordEpisodeStatistics, self).reset(**kwargs)
|
||||
@@ -23,9 +23,7 @@ class RecordEpisodeStatistics(gym.Wrapper):
|
||||
return observation
|
||||
|
||||
def step(self, action):
|
||||
observation, reward, done, info = super(RecordEpisodeStatistics, self).step(
|
||||
action
|
||||
)
|
||||
observation, reward, done, info = super(RecordEpisodeStatistics, self).step(action)
|
||||
self.episode_return += reward
|
||||
self.episode_length += 1
|
||||
if done:
|
||||
|
@@ -15,17 +15,15 @@ class RescaleAction(gym.ActionWrapper):
|
||||
"""
|
||||
|
||||
def __init__(self, env, a, b):
|
||||
assert isinstance(
|
||||
env.action_space, spaces.Box
|
||||
), "expected Box action space, got {}".format(type(env.action_space))
|
||||
assert isinstance(env.action_space, spaces.Box), "expected Box action space, got {}".format(type(env.action_space))
|
||||
assert np.less_equal(a, b).all(), (a, b)
|
||||
warnings.warn("Gym\'s internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit")
|
||||
warnings.warn(
|
||||
"Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit"
|
||||
)
|
||||
super(RescaleAction, self).__init__(env)
|
||||
self.a = np.zeros(env.action_space.shape, dtype=env.action_space.dtype) + a
|
||||
self.b = np.zeros(env.action_space.shape, dtype=env.action_space.dtype) + b
|
||||
self.action_space = spaces.Box(
|
||||
low=a, high=b, shape=env.action_space.shape, dtype=env.action_space.dtype
|
||||
)
|
||||
self.action_space = spaces.Box(low=a, high=b, shape=env.action_space.shape, dtype=env.action_space.dtype)
|
||||
|
||||
def action(self, action):
|
||||
assert np.all(np.greater_equal(action, self.a)), (action, self.a)
|
||||
|
@@ -12,7 +12,9 @@ class ResizeObservation(ObservationWrapper):
|
||||
if isinstance(shape, int):
|
||||
shape = (shape, shape)
|
||||
assert all(x > 0 for x in shape), shape
|
||||
warnings.warn("Gym\'s internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit")
|
||||
warnings.warn(
|
||||
"Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit"
|
||||
)
|
||||
self.shape = tuple(shape)
|
||||
|
||||
obs_shape = self.shape + self.observation_space.shape[2:]
|
||||
@@ -21,9 +23,7 @@ class ResizeObservation(ObservationWrapper):
|
||||
def observation(self, observation):
|
||||
import cv2
|
||||
|
||||
observation = cv2.resize(
|
||||
observation, self.shape[::-1], interpolation=cv2.INTER_AREA
|
||||
)
|
||||
observation = cv2.resize(observation, self.shape[::-1], interpolation=cv2.INTER_AREA)
|
||||
if observation.ndim == 2:
|
||||
observation = np.expand_dims(observation, -1)
|
||||
return observation
|
||||
|
@@ -15,12 +15,8 @@ def test_atari_preprocessing_grayscale(env_fn):
|
||||
import cv2
|
||||
|
||||
env1 = env_fn()
|
||||
env2 = AtariPreprocessing(
|
||||
env_fn(), screen_size=84, grayscale_obs=True, frame_skip=1, noop_max=0
|
||||
)
|
||||
env3 = AtariPreprocessing(
|
||||
env_fn(), screen_size=84, grayscale_obs=False, frame_skip=1, noop_max=0
|
||||
)
|
||||
env2 = AtariPreprocessing(env_fn(), screen_size=84, grayscale_obs=True, frame_skip=1, noop_max=0)
|
||||
env3 = AtariPreprocessing(env_fn(), screen_size=84, grayscale_obs=False, frame_skip=1, noop_max=0)
|
||||
env4 = AtariPreprocessing(
|
||||
env_fn(),
|
||||
screen_size=84,
|
||||
@@ -79,15 +75,11 @@ def test_atari_preprocessing_scale(env_fn):
|
||||
obs = env.reset().flatten()
|
||||
done, step_i = False, 0
|
||||
max_obs = 1 if scaled else 255
|
||||
assert (0 <= obs).all() and (
|
||||
obs <= max_obs
|
||||
).all(), "Obs. must be in range [0,{}]".format(max_obs)
|
||||
assert (0 <= obs).all() and (obs <= max_obs).all(), "Obs. must be in range [0,{}]".format(max_obs)
|
||||
while not done or step_i <= max_test_steps:
|
||||
obs, _, done, _ = env.step(env.action_space.sample())
|
||||
obs = obs.flatten()
|
||||
assert (0 <= obs).all() and (
|
||||
obs <= max_obs
|
||||
).all(), "Obs. must be in range [0,{}]".format(max_obs)
|
||||
assert (0 <= obs).all() and (obs <= max_obs).all(), "Obs. must be in range [0,{}]".format(max_obs)
|
||||
step_i += 1
|
||||
|
||||
env.close()
|
||||
|
@@ -19,9 +19,7 @@ def test_clip_action():
|
||||
|
||||
actions = [[0.4], [1.2], [-0.3], [0.0], [-2.5]]
|
||||
for action in actions:
|
||||
obs1, r1, d1, _ = env.step(
|
||||
np.clip(action, env.action_space.low, env.action_space.high)
|
||||
)
|
||||
obs1, r1, d1, _ = env.step(np.clip(action, env.action_space.low, env.action_space.high))
|
||||
obs2, r2, d2, _ = wrapped_env.step(action)
|
||||
assert np.allclose(r1, r2)
|
||||
assert np.allclose(obs1, obs2)
|
||||
|
@@ -9,10 +9,7 @@ from gym.wrappers.filter_observation import FilterObservation
|
||||
class FakeEnvironment(gym.Env):
|
||||
def __init__(self, observation_keys=("state")):
|
||||
self.observation_space = spaces.Dict(
|
||||
{
|
||||
name: spaces.Box(shape=(2,), low=-1, high=1, dtype=np.float32)
|
||||
for name in observation_keys
|
||||
}
|
||||
{name: spaces.Box(shape=(2,), low=-1, high=1, dtype=np.float32) for name in observation_keys}
|
||||
)
|
||||
self.action_space = spaces.Box(shape=(1,), low=-1, high=1, dtype=np.float32)
|
||||
|
||||
@@ -48,9 +45,7 @@ ERROR_TEST_CASES = (
|
||||
|
||||
|
||||
class TestFilterObservation(object):
|
||||
@pytest.mark.parametrize(
|
||||
"observation_keys,filter_keys", FILTER_OBSERVATION_TEST_CASES
|
||||
)
|
||||
@pytest.mark.parametrize("observation_keys,filter_keys", FILTER_OBSERVATION_TEST_CASES)
|
||||
def test_filter_observation(self, observation_keys, filter_keys):
|
||||
env = FakeEnvironment(observation_keys=observation_keys)
|
||||
|
||||
@@ -73,9 +68,7 @@ class TestFilterObservation(object):
|
||||
assert len(observation) == len(filter_keys)
|
||||
|
||||
@pytest.mark.parametrize("filter_keys,error_type,error_match", ERROR_TEST_CASES)
|
||||
def test_raises_with_incorrect_arguments(
|
||||
self, filter_keys, error_type, error_match
|
||||
):
|
||||
def test_raises_with_incorrect_arguments(self, filter_keys, error_type, error_match):
|
||||
env = FakeEnvironment(observation_keys=("key1", "key2"))
|
||||
|
||||
ValueError
|
||||
|
@@ -16,14 +16,10 @@ def test_flatten_observation(env_id):
|
||||
wrapped_obs = wrapped_env.reset()
|
||||
|
||||
if env_id == "Blackjack-v0":
|
||||
space = spaces.Tuple(
|
||||
(spaces.Discrete(32), spaces.Discrete(11), spaces.Discrete(2))
|
||||
)
|
||||
space = spaces.Tuple((spaces.Discrete(32), spaces.Discrete(11), spaces.Discrete(2)))
|
||||
wrapped_space = spaces.Box(-np.inf, np.inf, [32 + 11 + 2], dtype=np.float32)
|
||||
elif env_id == "KellyCoinflip-v0":
|
||||
space = spaces.Tuple(
|
||||
(spaces.Box(0, 250.0, [1], dtype=np.float32), spaces.Discrete(300 + 1))
|
||||
)
|
||||
space = spaces.Tuple((spaces.Box(0, 250.0, [1], dtype=np.float32), spaces.Discrete(300 + 1)))
|
||||
wrapped_space = spaces.Box(-np.inf, np.inf, [1 + (300 + 1)], dtype=np.float32)
|
||||
|
||||
assert space.contains(obs)
|
||||
|
@@ -19,9 +19,7 @@ except ImportError:
|
||||
[
|
||||
pytest.param(
|
||||
True,
|
||||
marks=pytest.mark.skipif(
|
||||
lz4 is None, reason="Need lz4 to run tests with compression"
|
||||
),
|
||||
marks=pytest.mark.skipif(lz4 is None, reason="Need lz4 to run tests with compression"),
|
||||
),
|
||||
False,
|
||||
],
|
||||
|
@@ -10,9 +10,7 @@ pytest.importorskip("atari_py")
|
||||
pytest.importorskip("cv2")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"env_id", ["PongNoFrameskip-v0", "SpaceInvadersNoFrameskip-v0"]
|
||||
)
|
||||
@pytest.mark.parametrize("env_id", ["PongNoFrameskip-v0", "SpaceInvadersNoFrameskip-v0"])
|
||||
@pytest.mark.parametrize("keep_dim", [True, False])
|
||||
def test_gray_scale_observation(env_id, keep_dim):
|
||||
gray_env = AtariPreprocessing(gym.make(env_id), screen_size=84, grayscale_obs=True)
|
||||
|
@@ -32,9 +32,7 @@ class FakeEnvironment(gym.Env):
|
||||
|
||||
class FakeArrayObservationEnvironment(FakeEnvironment):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.observation_space = spaces.Box(
|
||||
shape=(2,), low=-1, high=1, dtype=np.float32
|
||||
)
|
||||
self.observation_space = spaces.Box(shape=(2,), low=-1, high=1, dtype=np.float32)
|
||||
super(FakeArrayObservationEnvironment, self).__init__(*args, **kwargs)
|
||||
|
||||
|
||||
@@ -75,10 +73,7 @@ class TestPixelObservationWrapper(object):
|
||||
assert len(wrapped_env.observation_space.spaces) == 1
|
||||
assert list(wrapped_env.observation_space.spaces.keys()) == [pixel_key]
|
||||
else:
|
||||
assert (
|
||||
len(wrapped_env.observation_space.spaces)
|
||||
== len(observation_space.spaces) + 1
|
||||
)
|
||||
assert len(wrapped_env.observation_space.spaces) == len(observation_space.spaces) + 1
|
||||
expected_keys = list(observation_space.spaces.keys()) + [pixel_key]
|
||||
assert list(wrapped_env.observation_space.spaces.keys()) == expected_keys
|
||||
|
||||
@@ -97,9 +92,7 @@ class TestPixelObservationWrapper(object):
|
||||
observation_space = env.observation_space
|
||||
assert isinstance(observation_space, spaces.Box)
|
||||
|
||||
wrapped_env = PixelObservationWrapper(
|
||||
env, pixel_keys=(pixel_key,), pixels_only=pixels_only
|
||||
)
|
||||
wrapped_env = PixelObservationWrapper(env, pixel_keys=(pixel_key,), pixels_only=pixels_only)
|
||||
wrapped_env.observation_space = wrapped_env.observation_space
|
||||
assert isinstance(wrapped_env.observation_space, spaces.Dict)
|
||||
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user