mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-21 06:20:15 +00:00
redo black
This commit is contained in:
@@ -3,9 +3,7 @@ import argparse
|
|||||||
import gym
|
import gym
|
||||||
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(description="Renders a Gym environment for quick inspection.")
|
||||||
description="Renders a Gym environment for quick inspection."
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"env_id",
|
"env_id",
|
||||||
type=str,
|
type=str,
|
||||||
|
@@ -35,12 +35,7 @@ def cem(f, th_mean, batch_size, n_iter, elite_frac, initial_std=1.0):
|
|||||||
th_std = np.ones_like(th_mean) * initial_std
|
th_std = np.ones_like(th_mean) * initial_std
|
||||||
|
|
||||||
for _ in range(n_iter):
|
for _ in range(n_iter):
|
||||||
ths = np.array(
|
ths = np.array([th_mean + dth for dth in th_std[None, :] * np.random.randn(batch_size, th_mean.size)])
|
||||||
[
|
|
||||||
th_mean + dth
|
|
||||||
for dth in th_std[None, :] * np.random.randn(batch_size, th_mean.size)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
ys = np.array([f(th) for th in ths])
|
ys = np.array([f(th) for th in ths])
|
||||||
elite_inds = ys.argsort()[::-1][:n_elite]
|
elite_inds = ys.argsort()[::-1][:n_elite]
|
||||||
elite_ths = ths[elite_inds]
|
elite_ths = ths[elite_inds]
|
||||||
@@ -101,9 +96,7 @@ if __name__ == "__main__":
|
|||||||
return rew
|
return rew
|
||||||
|
|
||||||
# Train the agent, and snapshot each stage
|
# Train the agent, and snapshot each stage
|
||||||
for (i, iterdata) in enumerate(
|
for (i, iterdata) in enumerate(cem(noisy_evaluation, np.zeros(env.observation_space.shape[0] + 1), **params)):
|
||||||
cem(noisy_evaluation, np.zeros(env.observation_space.shape[0] + 1), **params)
|
|
||||||
):
|
|
||||||
print("Iteration %2i. Episode mean reward: %7.3f" % (i, iterdata["y_mean"]))
|
print("Iteration %2i. Episode mean reward: %7.3f" % (i, iterdata["y_mean"]))
|
||||||
agent = BinaryActionLinearPolicy(iterdata["theta_mean"])
|
agent = BinaryActionLinearPolicy(iterdata["theta_mean"])
|
||||||
if args.display:
|
if args.display:
|
||||||
|
@@ -17,9 +17,7 @@ class RandomAgent(object):
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description=None)
|
parser = argparse.ArgumentParser(description=None)
|
||||||
parser.add_argument(
|
parser.add_argument("env_id", nargs="?", default="CartPole-v0", help="Select the environment to run")
|
||||||
"env_id", nargs="?", default="CartPole-v0", help="Select the environment to run"
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# You can set the level to logger.DEBUG or logger.WARN if you
|
# You can set the level to logger.DEBUG or logger.WARN if you
|
||||||
|
14
gym/core.py
14
gym/core.py
@@ -173,16 +173,10 @@ class GoalEnv(Env):
|
|||||||
def reset(self):
|
def reset(self):
|
||||||
# Enforce that each GoalEnv uses a Goal-compatible observation space.
|
# Enforce that each GoalEnv uses a Goal-compatible observation space.
|
||||||
if not isinstance(self.observation_space, gym.spaces.Dict):
|
if not isinstance(self.observation_space, gym.spaces.Dict):
|
||||||
raise error.Error(
|
raise error.Error("GoalEnv requires an observation space of type gym.spaces.Dict")
|
||||||
"GoalEnv requires an observation space of type gym.spaces.Dict"
|
|
||||||
)
|
|
||||||
for key in ["observation", "achieved_goal", "desired_goal"]:
|
for key in ["observation", "achieved_goal", "desired_goal"]:
|
||||||
if key not in self.observation_space.spaces:
|
if key not in self.observation_space.spaces:
|
||||||
raise error.Error(
|
raise error.Error('GoalEnv requires the "{}" key to be part of the observation dictionary.'.format(key))
|
||||||
'GoalEnv requires the "{}" key to be part of the observation dictionary.'.format(
|
|
||||||
key
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def compute_reward(self, achieved_goal, desired_goal, info):
|
def compute_reward(self, achieved_goal, desired_goal, info):
|
||||||
"""Compute the step reward. This externalizes the reward function and makes
|
"""Compute the step reward. This externalizes the reward function and makes
|
||||||
@@ -227,9 +221,7 @@ class Wrapper(Env):
|
|||||||
|
|
||||||
def __getattr__(self, name):
|
def __getattr__(self, name):
|
||||||
if name.startswith("_"):
|
if name.startswith("_"):
|
||||||
raise AttributeError(
|
raise AttributeError("attempted to get missing private attribute '{}'".format(name))
|
||||||
"attempted to get missing private attribute '{}'".format(name)
|
|
||||||
)
|
|
||||||
return getattr(self.env, name)
|
return getattr(self.env, name)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@@ -422,9 +422,7 @@ for reward_type in ["sparse", "dense"]:
|
|||||||
register(
|
register(
|
||||||
id="HandManipulateBlockRotateParallel{}-v0".format(suffix),
|
id="HandManipulateBlockRotateParallel{}-v0".format(suffix),
|
||||||
entry_point="gym.envs.robotics:HandBlockEnv",
|
entry_point="gym.envs.robotics:HandBlockEnv",
|
||||||
kwargs=_merge(
|
kwargs=_merge({"target_position": "ignore", "target_rotation": "parallel"}, kwargs),
|
||||||
{"target_position": "ignore", "target_rotation": "parallel"}, kwargs
|
|
||||||
),
|
|
||||||
max_episode_steps=100,
|
max_episode_steps=100,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@@ -73,9 +73,7 @@ class AlgorithmicEnv(Env):
|
|||||||
# 1. Move read head left or right (or up/down)
|
# 1. Move read head left or right (or up/down)
|
||||||
# 2. Write or not
|
# 2. Write or not
|
||||||
# 3. Which character to write. (Ignored if should_write=0)
|
# 3. Which character to write. (Ignored if should_write=0)
|
||||||
self.action_space = Tuple(
|
self.action_space = Tuple([Discrete(len(self.MOVEMENTS)), Discrete(2), Discrete(self.base)])
|
||||||
[Discrete(len(self.MOVEMENTS)), Discrete(2), Discrete(self.base)]
|
|
||||||
)
|
|
||||||
# Can see just what is on the input tape (one of n characters, or
|
# Can see just what is on the input tape (one of n characters, or
|
||||||
# nothing)
|
# nothing)
|
||||||
self.observation_space = Discrete(self.base + 1)
|
self.observation_space = Discrete(self.base + 1)
|
||||||
@@ -147,10 +145,7 @@ class AlgorithmicEnv(Env):
|
|||||||
move = self.MOVEMENTS[inp_act]
|
move = self.MOVEMENTS[inp_act]
|
||||||
outfile.write("Action : Tuple(move over input: %s,\n" % move)
|
outfile.write("Action : Tuple(move over input: %s,\n" % move)
|
||||||
out_act = out_act == 1
|
out_act = out_act == 1
|
||||||
outfile.write(
|
outfile.write(" write to the output tape: %s,\n" % out_act)
|
||||||
" write to the output tape: %s,\n"
|
|
||||||
% out_act
|
|
||||||
)
|
|
||||||
outfile.write(" prediction: %s)\n" % pred_str)
|
outfile.write(" prediction: %s)\n" % pred_str)
|
||||||
else:
|
else:
|
||||||
outfile.write("\n" * 5)
|
outfile.write("\n" * 5)
|
||||||
@@ -276,9 +271,7 @@ class TapeAlgorithmicEnv(AlgorithmicEnv):
|
|||||||
x_str = "Observation Tape : "
|
x_str = "Observation Tape : "
|
||||||
for i in range(-2, self.input_width + 2):
|
for i in range(-2, self.input_width + 2):
|
||||||
if i == x:
|
if i == x:
|
||||||
x_str += colorize(
|
x_str += colorize(self._get_str_obs(np.array([i])), "green", highlight=True)
|
||||||
self._get_str_obs(np.array([i])), "green", highlight=True
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
x_str += self._get_str_obs(np.array([i]))
|
x_str += self._get_str_obs(np.array([i]))
|
||||||
x_str += "\n"
|
x_str += "\n"
|
||||||
@@ -311,10 +304,7 @@ class GridAlgorithmicEnv(AlgorithmicEnv):
|
|||||||
self.read_head_position = x, y
|
self.read_head_position = x, y
|
||||||
|
|
||||||
def generate_input_data(self, size):
|
def generate_input_data(self, size):
|
||||||
return [
|
return [[self.np_random.randint(self.base) for _ in range(self.rows)] for __ in range(size)]
|
||||||
[self.np_random.randint(self.base) for _ in range(self.rows)]
|
|
||||||
for __ in range(size)
|
|
||||||
]
|
|
||||||
|
|
||||||
def _get_obs(self, pos=None):
|
def _get_obs(self, pos=None):
|
||||||
if pos is None:
|
if pos is None:
|
||||||
@@ -336,9 +326,7 @@ class GridAlgorithmicEnv(AlgorithmicEnv):
|
|||||||
x_str += " " * len(label)
|
x_str += " " * len(label)
|
||||||
for i in range(-2, self.input_width + 2):
|
for i in range(-2, self.input_width + 2):
|
||||||
if i == x[0] and j == x[1]:
|
if i == x[0] and j == x[1]:
|
||||||
x_str += colorize(
|
x_str += colorize(self._get_str_obs((i, j)), "green", highlight=True)
|
||||||
self._get_str_obs((i, j)), "green", highlight=True
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
x_str += self._get_str_obs((i, j))
|
x_str += self._get_str_obs((i, j))
|
||||||
x_str += "\n"
|
x_str += "\n"
|
||||||
|
@@ -10,12 +10,8 @@ ALL_ENVS = [
|
|||||||
alg.reverse.ReverseEnv,
|
alg.reverse.ReverseEnv,
|
||||||
alg.reversed_addition.ReversedAdditionEnv,
|
alg.reversed_addition.ReversedAdditionEnv,
|
||||||
]
|
]
|
||||||
ALL_TAPE_ENVS = [
|
ALL_TAPE_ENVS = [env for env in ALL_ENVS if issubclass(env, alg.algorithmic_env.TapeAlgorithmicEnv)]
|
||||||
env for env in ALL_ENVS if issubclass(env, alg.algorithmic_env.TapeAlgorithmicEnv)
|
ALL_GRID_ENVS = [env for env in ALL_ENVS if issubclass(env, alg.algorithmic_env.GridAlgorithmicEnv)]
|
||||||
]
|
|
||||||
ALL_GRID_ENVS = [
|
|
||||||
env for env in ALL_ENVS if issubclass(env, alg.algorithmic_env.GridAlgorithmicEnv)
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def imprint(env, input_arr):
|
def imprint(env, input_arr):
|
||||||
@@ -92,10 +88,7 @@ class TestAlgorithmicEnvInteractions(unittest.TestCase):
|
|||||||
|
|
||||||
def test_grid_naviation(self):
|
def test_grid_naviation(self):
|
||||||
env = alg.reversed_addition.ReversedAdditionEnv(rows=2, base=6)
|
env = alg.reversed_addition.ReversedAdditionEnv(rows=2, base=6)
|
||||||
N, S, E, W = [
|
N, S, E, W = [env._movement_idx(named_dir) for named_dir in ["up", "down", "right", "left"]]
|
||||||
env._movement_idx(named_dir)
|
|
||||||
for named_dir in ["up", "down", "right", "left"]
|
|
||||||
]
|
|
||||||
# Corresponds to a grid that looks like...
|
# Corresponds to a grid that looks like...
|
||||||
# 0 1 2
|
# 0 1 2
|
||||||
# 3 4 5
|
# 3 4 5
|
||||||
@@ -204,9 +197,7 @@ class TestTargets(unittest.TestCase):
|
|||||||
|
|
||||||
def test_repeat_copy_target(self):
|
def test_repeat_copy_target(self):
|
||||||
env = alg.repeat_copy.RepeatCopyEnv()
|
env = alg.repeat_copy.RepeatCopyEnv()
|
||||||
self.assertEqual(
|
self.assertEqual(env.target_from_input_data([0, 1, 2]), [0, 1, 2, 2, 1, 0, 0, 1, 2])
|
||||||
env.target_from_input_data([0, 1, 2]), [0, 1, 2, 2, 1, 0, 0, 1, 2]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestInputGeneration(unittest.TestCase):
|
class TestInputGeneration(unittest.TestCase):
|
||||||
|
@@ -9,8 +9,7 @@ try:
|
|||||||
import atari_py
|
import atari_py
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
raise error.DependencyNotInstalled(
|
raise error.DependencyNotInstalled(
|
||||||
"{}. (HINT: you can install Atari dependencies by running "
|
"{}. (HINT: you can install Atari dependencies by running " "'pip install gym[atari]'.)".format(e)
|
||||||
"'pip install gym[atari]'.)".format(e)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -64,35 +63,23 @@ class AtariEnv(gym.Env, utils.EzPickle):
|
|||||||
|
|
||||||
# Tune (or disable) ALE's action repeat:
|
# Tune (or disable) ALE's action repeat:
|
||||||
# https://github.com/openai/gym/issues/349
|
# https://github.com/openai/gym/issues/349
|
||||||
assert isinstance(
|
assert isinstance(repeat_action_probability, (float, int)), "Invalid repeat_action_probability: {!r}".format(
|
||||||
repeat_action_probability, (float, int)
|
repeat_action_probability
|
||||||
), "Invalid repeat_action_probability: {!r}".format(repeat_action_probability)
|
|
||||||
self.ale.setFloat(
|
|
||||||
"repeat_action_probability".encode("utf-8"), repeat_action_probability
|
|
||||||
)
|
)
|
||||||
|
self.ale.setFloat("repeat_action_probability".encode("utf-8"), repeat_action_probability)
|
||||||
|
|
||||||
self.seed()
|
self.seed()
|
||||||
|
|
||||||
self._action_set = (
|
self._action_set = self.ale.getLegalActionSet() if full_action_space else self.ale.getMinimalActionSet()
|
||||||
self.ale.getLegalActionSet()
|
|
||||||
if full_action_space
|
|
||||||
else self.ale.getMinimalActionSet()
|
|
||||||
)
|
|
||||||
self.action_space = spaces.Discrete(len(self._action_set))
|
self.action_space = spaces.Discrete(len(self._action_set))
|
||||||
|
|
||||||
(screen_width, screen_height) = self.ale.getScreenDims()
|
(screen_width, screen_height) = self.ale.getScreenDims()
|
||||||
if self._obs_type == "ram":
|
if self._obs_type == "ram":
|
||||||
self.observation_space = spaces.Box(
|
self.observation_space = spaces.Box(low=0, high=255, dtype=np.uint8, shape=(128,))
|
||||||
low=0, high=255, dtype=np.uint8, shape=(128,)
|
|
||||||
)
|
|
||||||
elif self._obs_type == "image":
|
elif self._obs_type == "image":
|
||||||
self.observation_space = spaces.Box(
|
self.observation_space = spaces.Box(low=0, high=255, shape=(screen_height, screen_width, 3), dtype=np.uint8)
|
||||||
low=0, high=255, shape=(screen_height, screen_width, 3), dtype=np.uint8
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
raise error.Error(
|
raise error.Error("Unrecognized observation type: {}".format(self._obs_type))
|
||||||
"Unrecognized observation type: {}".format(self._obs_type)
|
|
||||||
)
|
|
||||||
|
|
||||||
def seed(self, seed=None):
|
def seed(self, seed=None):
|
||||||
self.np_random, seed1 = seeding.np_random(seed)
|
self.np_random, seed1 = seeding.np_random(seed)
|
||||||
@@ -107,9 +94,9 @@ class AtariEnv(gym.Env, utils.EzPickle):
|
|||||||
if self.game_mode is not None:
|
if self.game_mode is not None:
|
||||||
modes = self.ale.getAvailableModes()
|
modes = self.ale.getAvailableModes()
|
||||||
|
|
||||||
assert self.game_mode in modes, (
|
assert self.game_mode in modes, ('Invalid game mode "{}" for game {}.\nAvailable modes are: {}').format(
|
||||||
'Invalid game mode "{}" for game {}.\nAvailable modes are: {}'
|
self.game_mode, self.game, modes
|
||||||
).format(self.game_mode, self.game, modes)
|
)
|
||||||
self.ale.setMode(self.game_mode)
|
self.ale.setMode(self.game_mode)
|
||||||
|
|
||||||
if self.game_difficulty is not None:
|
if self.game_difficulty is not None:
|
||||||
|
@@ -100,10 +100,7 @@ class ContactDetector(contactListener):
|
|||||||
self.env = env
|
self.env = env
|
||||||
|
|
||||||
def BeginContact(self, contact):
|
def BeginContact(self, contact):
|
||||||
if (
|
if self.env.hull == contact.fixtureA.body or self.env.hull == contact.fixtureB.body:
|
||||||
self.env.hull == contact.fixtureA.body
|
|
||||||
or self.env.hull == contact.fixtureB.body
|
|
||||||
):
|
|
||||||
self.env.game_over = True
|
self.env.game_over = True
|
||||||
for leg in [self.env.legs[1], self.env.legs[3]]:
|
for leg in [self.env.legs[1], self.env.legs[3]]:
|
||||||
if leg in [contact.fixtureA.body, contact.fixtureB.body]:
|
if leg in [contact.fixtureA.body, contact.fixtureB.body]:
|
||||||
@@ -202,9 +199,7 @@ class BipedalWalker(gym.Env, EzPickle):
|
|||||||
t.color1, t.color2 = (1, 1, 1), (0.6, 0.6, 0.6)
|
t.color1, t.color2 = (1, 1, 1), (0.6, 0.6, 0.6)
|
||||||
self.terrain.append(t)
|
self.terrain.append(t)
|
||||||
|
|
||||||
self.fd_polygon.shape.vertices = [
|
self.fd_polygon.shape.vertices = [(p[0] + TERRAIN_STEP * counter, p[1]) for p in poly]
|
||||||
(p[0] + TERRAIN_STEP * counter, p[1]) for p in poly
|
|
||||||
]
|
|
||||||
t = self.world.CreateStaticBody(fixtures=self.fd_polygon)
|
t = self.world.CreateStaticBody(fixtures=self.fd_polygon)
|
||||||
t.color1, t.color2 = (1, 1, 1), (0.6, 0.6, 0.6)
|
t.color1, t.color2 = (1, 1, 1), (0.6, 0.6, 0.6)
|
||||||
self.terrain.append(t)
|
self.terrain.append(t)
|
||||||
@@ -301,12 +296,8 @@ class BipedalWalker(gym.Env, EzPickle):
|
|||||||
y = VIEWPORT_H / SCALE * 3 / 4
|
y = VIEWPORT_H / SCALE * 3 / 4
|
||||||
poly = [
|
poly = [
|
||||||
(
|
(
|
||||||
x
|
x + 15 * TERRAIN_STEP * math.sin(3.14 * 2 * a / 5) + self.np_random.uniform(0, 5 * TERRAIN_STEP),
|
||||||
+ 15 * TERRAIN_STEP * math.sin(3.14 * 2 * a / 5)
|
y + 5 * TERRAIN_STEP * math.cos(3.14 * 2 * a / 5) + self.np_random.uniform(0, 5 * TERRAIN_STEP),
|
||||||
+ self.np_random.uniform(0, 5 * TERRAIN_STEP),
|
|
||||||
y
|
|
||||||
+ 5 * TERRAIN_STEP * math.cos(3.14 * 2 * a / 5)
|
|
||||||
+ self.np_random.uniform(0, 5 * TERRAIN_STEP),
|
|
||||||
)
|
)
|
||||||
for a in range(5)
|
for a in range(5)
|
||||||
]
|
]
|
||||||
@@ -331,14 +322,10 @@ class BipedalWalker(gym.Env, EzPickle):
|
|||||||
|
|
||||||
init_x = TERRAIN_STEP * TERRAIN_STARTPAD / 2
|
init_x = TERRAIN_STEP * TERRAIN_STARTPAD / 2
|
||||||
init_y = TERRAIN_HEIGHT + 2 * LEG_H
|
init_y = TERRAIN_HEIGHT + 2 * LEG_H
|
||||||
self.hull = self.world.CreateDynamicBody(
|
self.hull = self.world.CreateDynamicBody(position=(init_x, init_y), fixtures=HULL_FD)
|
||||||
position=(init_x, init_y), fixtures=HULL_FD
|
|
||||||
)
|
|
||||||
self.hull.color1 = (0.5, 0.4, 0.9)
|
self.hull.color1 = (0.5, 0.4, 0.9)
|
||||||
self.hull.color2 = (0.3, 0.3, 0.5)
|
self.hull.color2 = (0.3, 0.3, 0.5)
|
||||||
self.hull.ApplyForceToCenter(
|
self.hull.ApplyForceToCenter((self.np_random.uniform(-INITIAL_RANDOM, INITIAL_RANDOM), 0), True)
|
||||||
(self.np_random.uniform(-INITIAL_RANDOM, INITIAL_RANDOM), 0), True
|
|
||||||
)
|
|
||||||
|
|
||||||
self.legs = []
|
self.legs = []
|
||||||
self.joints = []
|
self.joints = []
|
||||||
@@ -412,21 +399,13 @@ class BipedalWalker(gym.Env, EzPickle):
|
|||||||
self.joints[3].motorSpeed = float(SPEED_KNEE * np.clip(action[3], -1, 1))
|
self.joints[3].motorSpeed = float(SPEED_KNEE * np.clip(action[3], -1, 1))
|
||||||
else:
|
else:
|
||||||
self.joints[0].motorSpeed = float(SPEED_HIP * np.sign(action[0]))
|
self.joints[0].motorSpeed = float(SPEED_HIP * np.sign(action[0]))
|
||||||
self.joints[0].maxMotorTorque = float(
|
self.joints[0].maxMotorTorque = float(MOTORS_TORQUE * np.clip(np.abs(action[0]), 0, 1))
|
||||||
MOTORS_TORQUE * np.clip(np.abs(action[0]), 0, 1)
|
|
||||||
)
|
|
||||||
self.joints[1].motorSpeed = float(SPEED_KNEE * np.sign(action[1]))
|
self.joints[1].motorSpeed = float(SPEED_KNEE * np.sign(action[1]))
|
||||||
self.joints[1].maxMotorTorque = float(
|
self.joints[1].maxMotorTorque = float(MOTORS_TORQUE * np.clip(np.abs(action[1]), 0, 1))
|
||||||
MOTORS_TORQUE * np.clip(np.abs(action[1]), 0, 1)
|
|
||||||
)
|
|
||||||
self.joints[2].motorSpeed = float(SPEED_HIP * np.sign(action[2]))
|
self.joints[2].motorSpeed = float(SPEED_HIP * np.sign(action[2]))
|
||||||
self.joints[2].maxMotorTorque = float(
|
self.joints[2].maxMotorTorque = float(MOTORS_TORQUE * np.clip(np.abs(action[2]), 0, 1))
|
||||||
MOTORS_TORQUE * np.clip(np.abs(action[2]), 0, 1)
|
|
||||||
)
|
|
||||||
self.joints[3].motorSpeed = float(SPEED_KNEE * np.sign(action[3]))
|
self.joints[3].motorSpeed = float(SPEED_KNEE * np.sign(action[3]))
|
||||||
self.joints[3].maxMotorTorque = float(
|
self.joints[3].maxMotorTorque = float(MOTORS_TORQUE * np.clip(np.abs(action[3]), 0, 1))
|
||||||
MOTORS_TORQUE * np.clip(np.abs(action[3]), 0, 1)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.world.Step(1.0 / FPS, 6 * 30, 2 * 30)
|
self.world.Step(1.0 / FPS, 6 * 30, 2 * 30)
|
||||||
|
|
||||||
@@ -465,12 +444,8 @@ class BipedalWalker(gym.Env, EzPickle):
|
|||||||
|
|
||||||
self.scroll = pos.x - VIEWPORT_W / SCALE / 5
|
self.scroll = pos.x - VIEWPORT_W / SCALE / 5
|
||||||
|
|
||||||
shaping = (
|
shaping = 130 * pos[0] / SCALE # moving forward is a way to receive reward (normalized to get 300 on completion)
|
||||||
130 * pos[0] / SCALE
|
shaping -= 5.0 * abs(state[0]) # keep head straight, other than that and falling, any behavior is unpunished
|
||||||
) # moving forward is a way to receive reward (normalized to get 300 on completion)
|
|
||||||
shaping -= 5.0 * abs(
|
|
||||||
state[0]
|
|
||||||
) # keep head straight, other than that and falling, any behavior is unpunished
|
|
||||||
|
|
||||||
reward = 0
|
reward = 0
|
||||||
if self.prev_shaping is not None:
|
if self.prev_shaping is not None:
|
||||||
@@ -494,9 +469,7 @@ class BipedalWalker(gym.Env, EzPickle):
|
|||||||
|
|
||||||
if self.viewer is None:
|
if self.viewer is None:
|
||||||
self.viewer = rendering.Viewer(VIEWPORT_W, VIEWPORT_H)
|
self.viewer = rendering.Viewer(VIEWPORT_W, VIEWPORT_H)
|
||||||
self.viewer.set_bounds(
|
self.viewer.set_bounds(self.scroll, VIEWPORT_W / SCALE + self.scroll, 0, VIEWPORT_H / SCALE)
|
||||||
self.scroll, VIEWPORT_W / SCALE + self.scroll, 0, VIEWPORT_H / SCALE
|
|
||||||
)
|
|
||||||
|
|
||||||
self.viewer.draw_polygon(
|
self.viewer.draw_polygon(
|
||||||
[
|
[
|
||||||
@@ -512,9 +485,7 @@ class BipedalWalker(gym.Env, EzPickle):
|
|||||||
continue
|
continue
|
||||||
if x1 > self.scroll / 2 + VIEWPORT_W / SCALE:
|
if x1 > self.scroll / 2 + VIEWPORT_W / SCALE:
|
||||||
continue
|
continue
|
||||||
self.viewer.draw_polygon(
|
self.viewer.draw_polygon([(p[0] + self.scroll / 2, p[1]) for p in poly], color=(1, 1, 1))
|
||||||
[(p[0] + self.scroll / 2, p[1]) for p in poly], color=(1, 1, 1)
|
|
||||||
)
|
|
||||||
for poly, color in self.terrain_poly:
|
for poly, color in self.terrain_poly:
|
||||||
if poly[1][0] < self.scroll:
|
if poly[1][0] < self.scroll:
|
||||||
continue
|
continue
|
||||||
@@ -525,11 +496,7 @@ class BipedalWalker(gym.Env, EzPickle):
|
|||||||
self.lidar_render = (self.lidar_render + 1) % 100
|
self.lidar_render = (self.lidar_render + 1) % 100
|
||||||
i = self.lidar_render
|
i = self.lidar_render
|
||||||
if i < 2 * len(self.lidar):
|
if i < 2 * len(self.lidar):
|
||||||
l = (
|
l = self.lidar[i] if i < len(self.lidar) else self.lidar[len(self.lidar) - i - 1]
|
||||||
self.lidar[i]
|
|
||||||
if i < len(self.lidar)
|
|
||||||
else self.lidar[len(self.lidar) - i - 1]
|
|
||||||
)
|
|
||||||
self.viewer.draw_polyline([l.p1, l.p2], color=(1, 0, 0), linewidth=1)
|
self.viewer.draw_polyline([l.p1, l.p2], color=(1, 0, 0), linewidth=1)
|
||||||
|
|
||||||
for obj in self.drawlist:
|
for obj in self.drawlist:
|
||||||
@@ -537,12 +504,8 @@ class BipedalWalker(gym.Env, EzPickle):
|
|||||||
trans = f.body.transform
|
trans = f.body.transform
|
||||||
if type(f.shape) is circleShape:
|
if type(f.shape) is circleShape:
|
||||||
t = rendering.Transform(translation=trans * f.shape.pos)
|
t = rendering.Transform(translation=trans * f.shape.pos)
|
||||||
self.viewer.draw_circle(
|
self.viewer.draw_circle(f.shape.radius, 30, color=obj.color1).add_attr(t)
|
||||||
f.shape.radius, 30, color=obj.color1
|
self.viewer.draw_circle(f.shape.radius, 30, color=obj.color2, filled=False, linewidth=2).add_attr(t)
|
||||||
).add_attr(t)
|
|
||||||
self.viewer.draw_circle(
|
|
||||||
f.shape.radius, 30, color=obj.color2, filled=False, linewidth=2
|
|
||||||
).add_attr(t)
|
|
||||||
else:
|
else:
|
||||||
path = [trans * v for v in f.shape.vertices]
|
path = [trans * v for v in f.shape.vertices]
|
||||||
self.viewer.draw_polygon(path, color=obj.color1)
|
self.viewer.draw_polygon(path, color=obj.color1)
|
||||||
@@ -552,9 +515,7 @@ class BipedalWalker(gym.Env, EzPickle):
|
|||||||
flagy1 = TERRAIN_HEIGHT
|
flagy1 = TERRAIN_HEIGHT
|
||||||
flagy2 = flagy1 + 50 / SCALE
|
flagy2 = flagy1 + 50 / SCALE
|
||||||
x = TERRAIN_STEP * 3
|
x = TERRAIN_STEP * 3
|
||||||
self.viewer.draw_polyline(
|
self.viewer.draw_polyline([(x, flagy1), (x, flagy2)], color=(0, 0, 0), linewidth=2)
|
||||||
[(x, flagy1), (x, flagy2)], color=(0, 0, 0), linewidth=2
|
|
||||||
)
|
|
||||||
f = [
|
f = [
|
||||||
(x, flagy2),
|
(x, flagy2),
|
||||||
(x, flagy2 - 10 / SCALE),
|
(x, flagy2 - 10 / SCALE),
|
||||||
|
@@ -23,9 +23,7 @@ from Box2D.b2 import (
|
|||||||
SIZE = 0.02
|
SIZE = 0.02
|
||||||
ENGINE_POWER = 100000000 * SIZE * SIZE
|
ENGINE_POWER = 100000000 * SIZE * SIZE
|
||||||
WHEEL_MOMENT_OF_INERTIA = 4000 * SIZE * SIZE
|
WHEEL_MOMENT_OF_INERTIA = 4000 * SIZE * SIZE
|
||||||
FRICTION_LIMIT = (
|
FRICTION_LIMIT = 1000000 * SIZE * SIZE # friction ~= mass ~= size^2 (calculated implicitly using density)
|
||||||
1000000 * SIZE * SIZE
|
|
||||||
) # friction ~= mass ~= size^2 (calculated implicitly using density)
|
|
||||||
WHEEL_R = 27
|
WHEEL_R = 27
|
||||||
WHEEL_W = 14
|
WHEEL_W = 14
|
||||||
WHEELPOS = [(-55, +80), (+55, +80), (-55, -82), (+55, -82)]
|
WHEELPOS = [(-55, +80), (+55, +80), (-55, -82), (+55, -82)]
|
||||||
@@ -55,27 +53,19 @@ class Car:
|
|||||||
angle=init_angle,
|
angle=init_angle,
|
||||||
fixtures=[
|
fixtures=[
|
||||||
fixtureDef(
|
fixtureDef(
|
||||||
shape=polygonShape(
|
shape=polygonShape(vertices=[(x * SIZE, y * SIZE) for x, y in HULL_POLY1]),
|
||||||
vertices=[(x * SIZE, y * SIZE) for x, y in HULL_POLY1]
|
|
||||||
),
|
|
||||||
density=1.0,
|
density=1.0,
|
||||||
),
|
),
|
||||||
fixtureDef(
|
fixtureDef(
|
||||||
shape=polygonShape(
|
shape=polygonShape(vertices=[(x * SIZE, y * SIZE) for x, y in HULL_POLY2]),
|
||||||
vertices=[(x * SIZE, y * SIZE) for x, y in HULL_POLY2]
|
|
||||||
),
|
|
||||||
density=1.0,
|
density=1.0,
|
||||||
),
|
),
|
||||||
fixtureDef(
|
fixtureDef(
|
||||||
shape=polygonShape(
|
shape=polygonShape(vertices=[(x * SIZE, y * SIZE) for x, y in HULL_POLY3]),
|
||||||
vertices=[(x * SIZE, y * SIZE) for x, y in HULL_POLY3]
|
|
||||||
),
|
|
||||||
density=1.0,
|
density=1.0,
|
||||||
),
|
),
|
||||||
fixtureDef(
|
fixtureDef(
|
||||||
shape=polygonShape(
|
shape=polygonShape(vertices=[(x * SIZE, y * SIZE) for x, y in HULL_POLY4]),
|
||||||
vertices=[(x * SIZE, y * SIZE) for x, y in HULL_POLY4]
|
|
||||||
),
|
|
||||||
density=1.0,
|
density=1.0,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
@@ -95,12 +85,7 @@ class Car:
|
|||||||
position=(init_x + wx * SIZE, init_y + wy * SIZE),
|
position=(init_x + wx * SIZE, init_y + wy * SIZE),
|
||||||
angle=init_angle,
|
angle=init_angle,
|
||||||
fixtures=fixtureDef(
|
fixtures=fixtureDef(
|
||||||
shape=polygonShape(
|
shape=polygonShape(vertices=[(x * front_k * SIZE, y * front_k * SIZE) for x, y in WHEEL_POLY]),
|
||||||
vertices=[
|
|
||||||
(x * front_k * SIZE, y * front_k * SIZE)
|
|
||||||
for x, y in WHEEL_POLY
|
|
||||||
]
|
|
||||||
),
|
|
||||||
density=0.1,
|
density=0.1,
|
||||||
categoryBits=0x0020,
|
categoryBits=0x0020,
|
||||||
maskBits=0x001,
|
maskBits=0x001,
|
||||||
@@ -175,9 +160,7 @@ class Car:
|
|||||||
grass = True
|
grass = True
|
||||||
friction_limit = FRICTION_LIMIT * 0.6 # Grass friction if no tile
|
friction_limit = FRICTION_LIMIT * 0.6 # Grass friction if no tile
|
||||||
for tile in w.tiles:
|
for tile in w.tiles:
|
||||||
friction_limit = max(
|
friction_limit = max(friction_limit, FRICTION_LIMIT * tile.road_friction)
|
||||||
friction_limit, FRICTION_LIMIT * tile.road_friction
|
|
||||||
)
|
|
||||||
grass = False
|
grass = False
|
||||||
|
|
||||||
# Force
|
# Force
|
||||||
@@ -192,13 +175,7 @@ class Car:
|
|||||||
# domega = dt*W/WHEEL_MOMENT_OF_INERTIA/w.omega
|
# domega = dt*W/WHEEL_MOMENT_OF_INERTIA/w.omega
|
||||||
|
|
||||||
# add small coef not to divide by zero
|
# add small coef not to divide by zero
|
||||||
w.omega += (
|
w.omega += dt * ENGINE_POWER * w.gas / WHEEL_MOMENT_OF_INERTIA / (abs(w.omega) + 5.0)
|
||||||
dt
|
|
||||||
* ENGINE_POWER
|
|
||||||
* w.gas
|
|
||||||
/ WHEEL_MOMENT_OF_INERTIA
|
|
||||||
/ (abs(w.omega) + 5.0)
|
|
||||||
)
|
|
||||||
self.fuel_spent += dt * ENGINE_POWER * w.gas
|
self.fuel_spent += dt * ENGINE_POWER * w.gas
|
||||||
|
|
||||||
if w.brake >= 0.9:
|
if w.brake >= 0.9:
|
||||||
@@ -226,18 +203,12 @@ class Car:
|
|||||||
|
|
||||||
# Skid trace
|
# Skid trace
|
||||||
if abs(force) > 2.0 * friction_limit:
|
if abs(force) > 2.0 * friction_limit:
|
||||||
if (
|
if w.skid_particle and w.skid_particle.grass == grass and len(w.skid_particle.poly) < 30:
|
||||||
w.skid_particle
|
|
||||||
and w.skid_particle.grass == grass
|
|
||||||
and len(w.skid_particle.poly) < 30
|
|
||||||
):
|
|
||||||
w.skid_particle.poly.append((w.position[0], w.position[1]))
|
w.skid_particle.poly.append((w.position[0], w.position[1]))
|
||||||
elif w.skid_start is None:
|
elif w.skid_start is None:
|
||||||
w.skid_start = w.position
|
w.skid_start = w.position
|
||||||
else:
|
else:
|
||||||
w.skid_particle = self._create_particle(
|
w.skid_particle = self._create_particle(w.skid_start, w.position, grass)
|
||||||
w.skid_start, w.position, grass
|
|
||||||
)
|
|
||||||
w.skid_start = None
|
w.skid_start = None
|
||||||
else:
|
else:
|
||||||
w.skid_start = None
|
w.skid_start = None
|
||||||
|
@@ -132,18 +132,14 @@ class CarRacing(gym.Env, EzPickle):
|
|||||||
self.reward = 0.0
|
self.reward = 0.0
|
||||||
self.prev_reward = 0.0
|
self.prev_reward = 0.0
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
self.fd_tile = fixtureDef(
|
self.fd_tile = fixtureDef(shape=polygonShape(vertices=[(0, 0), (1, 0), (1, -1), (0, -1)]))
|
||||||
shape=polygonShape(vertices=[(0, 0), (1, 0), (1, -1), (0, -1)])
|
|
||||||
)
|
|
||||||
|
|
||||||
self.action_space = spaces.Box(
|
self.action_space = spaces.Box(
|
||||||
np.array([-1, 0, 0]).astype(np.float32),
|
np.array([-1, 0, 0]).astype(np.float32),
|
||||||
np.array([+1, +1, +1]).astype(np.float32),
|
np.array([+1, +1, +1]).astype(np.float32),
|
||||||
) # steer, gas, brake
|
) # steer, gas, brake
|
||||||
|
|
||||||
self.observation_space = spaces.Box(
|
self.observation_space = spaces.Box(low=0, high=255, shape=(STATE_H, STATE_W, 3), dtype=np.uint8)
|
||||||
low=0, high=255, shape=(STATE_H, STATE_W, 3), dtype=np.uint8
|
|
||||||
)
|
|
||||||
|
|
||||||
def seed(self, seed=None):
|
def seed(self, seed=None):
|
||||||
self.np_random, seed = seeding.np_random(seed)
|
self.np_random, seed = seeding.np_random(seed)
|
||||||
@@ -246,9 +242,7 @@ class CarRacing(gym.Env, EzPickle):
|
|||||||
i -= 1
|
i -= 1
|
||||||
if i == 0:
|
if i == 0:
|
||||||
return False # Failed
|
return False # Failed
|
||||||
pass_through_start = (
|
pass_through_start = track[i][0] > self.start_alpha and track[i - 1][0] <= self.start_alpha
|
||||||
track[i][0] > self.start_alpha and track[i - 1][0] <= self.start_alpha
|
|
||||||
)
|
|
||||||
if pass_through_start and i2 == -1:
|
if pass_through_start and i2 == -1:
|
||||||
i2 = i
|
i2 = i
|
||||||
elif pass_through_start and i1 == -1:
|
elif pass_through_start and i1 == -1:
|
||||||
@@ -266,8 +260,7 @@ class CarRacing(gym.Env, EzPickle):
|
|||||||
first_perp_y = math.sin(first_beta)
|
first_perp_y = math.sin(first_beta)
|
||||||
# Length of perpendicular jump to put together head and tail
|
# Length of perpendicular jump to put together head and tail
|
||||||
well_glued_together = np.sqrt(
|
well_glued_together = np.sqrt(
|
||||||
np.square(first_perp_x * (track[0][2] - track[-1][2]))
|
np.square(first_perp_x * (track[0][2] - track[-1][2])) + np.square(first_perp_y * (track[0][3] - track[-1][3]))
|
||||||
+ np.square(first_perp_y * (track[0][3] - track[-1][3]))
|
|
||||||
)
|
)
|
||||||
if well_glued_together > TRACK_DETAIL_STEP:
|
if well_glued_together > TRACK_DETAIL_STEP:
|
||||||
return False
|
return False
|
||||||
@@ -337,9 +330,7 @@ class CarRacing(gym.Env, EzPickle):
|
|||||||
x2 + side * (TRACK_WIDTH + BORDER) * math.cos(beta2),
|
x2 + side * (TRACK_WIDTH + BORDER) * math.cos(beta2),
|
||||||
y2 + side * (TRACK_WIDTH + BORDER) * math.sin(beta2),
|
y2 + side * (TRACK_WIDTH + BORDER) * math.sin(beta2),
|
||||||
)
|
)
|
||||||
self.road_poly.append(
|
self.road_poly.append(([b1_l, b1_r, b2_r, b2_l], (1, 1, 1) if i % 2 == 0 else (1, 0, 0)))
|
||||||
([b1_l, b1_r, b2_r, b2_l], (1, 1, 1) if i % 2 == 0 else (1, 0, 0))
|
|
||||||
)
|
|
||||||
self.track = track
|
self.track = track
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -356,10 +347,7 @@ class CarRacing(gym.Env, EzPickle):
|
|||||||
if success:
|
if success:
|
||||||
break
|
break
|
||||||
if self.verbose == 1:
|
if self.verbose == 1:
|
||||||
print(
|
print("retry to generate track (normal if there are not many" "instances of this message)")
|
||||||
"retry to generate track (normal if there are not many"
|
|
||||||
"instances of this message)"
|
|
||||||
)
|
|
||||||
self.car = Car(self.world, *self.track[0][1:4])
|
self.car = Car(self.world, *self.track[0][1:4])
|
||||||
|
|
||||||
return self.step(None)[0]
|
return self.step(None)[0]
|
||||||
@@ -424,10 +412,8 @@ class CarRacing(gym.Env, EzPickle):
|
|||||||
angle = math.atan2(vel[0], vel[1])
|
angle = math.atan2(vel[0], vel[1])
|
||||||
self.transform.set_scale(zoom, zoom)
|
self.transform.set_scale(zoom, zoom)
|
||||||
self.transform.set_translation(
|
self.transform.set_translation(
|
||||||
WINDOW_W / 2
|
WINDOW_W / 2 - (scroll_x * zoom * math.cos(angle) - scroll_y * zoom * math.sin(angle)),
|
||||||
- (scroll_x * zoom * math.cos(angle) - scroll_y * zoom * math.sin(angle)),
|
WINDOW_H / 4 - (scroll_x * zoom * math.sin(angle) + scroll_y * zoom * math.cos(angle)),
|
||||||
WINDOW_H / 4
|
|
||||||
- (scroll_x * zoom * math.sin(angle) + scroll_y * zoom * math.cos(angle)),
|
|
||||||
)
|
)
|
||||||
self.transform.set_rotation(angle)
|
self.transform.set_rotation(angle)
|
||||||
|
|
||||||
@@ -449,9 +435,7 @@ class CarRacing(gym.Env, EzPickle):
|
|||||||
else:
|
else:
|
||||||
pixel_scale = 1
|
pixel_scale = 1
|
||||||
if hasattr(win.context, "_nscontext"):
|
if hasattr(win.context, "_nscontext"):
|
||||||
pixel_scale = (
|
pixel_scale = win.context._nscontext.view().backingScaleFactor() # pylint: disable=protected-access
|
||||||
win.context._nscontext.view().backingScaleFactor()
|
|
||||||
) # pylint: disable=protected-access
|
|
||||||
VP_W = int(pixel_scale * WINDOW_W)
|
VP_W = int(pixel_scale * WINDOW_W)
|
||||||
VP_H = int(pixel_scale * WINDOW_H)
|
VP_H = int(pixel_scale * WINDOW_H)
|
||||||
|
|
||||||
@@ -468,9 +452,7 @@ class CarRacing(gym.Env, EzPickle):
|
|||||||
win.flip()
|
win.flip()
|
||||||
return self.viewer.isopen
|
return self.viewer.isopen
|
||||||
|
|
||||||
image_data = (
|
image_data = pyglet.image.get_buffer_manager().get_color_buffer().get_image_data()
|
||||||
pyglet.image.get_buffer_manager().get_color_buffer().get_image_data()
|
|
||||||
)
|
|
||||||
arr = np.fromstring(image_data.get_data(), dtype=np.uint8, sep="")
|
arr = np.fromstring(image_data.get_data(), dtype=np.uint8, sep="")
|
||||||
arr = arr.reshape(VP_H, VP_W, 4)
|
arr = arr.reshape(VP_H, VP_W, 4)
|
||||||
arr = arr[::-1, :, 0:3]
|
arr = arr[::-1, :, 0:3]
|
||||||
@@ -525,9 +507,7 @@ class CarRacing(gym.Env, EzPickle):
|
|||||||
for p in poly:
|
for p in poly:
|
||||||
polygons_.extend([p[0], p[1], 0])
|
polygons_.extend([p[0], p[1], 0])
|
||||||
|
|
||||||
vl = pyglet.graphics.vertex_list(
|
vl = pyglet.graphics.vertex_list(len(polygons_) // 3, ("v3f", polygons_), ("c4f", colors)) # gl.GL_QUADS,
|
||||||
len(polygons_) // 3, ("v3f", polygons_), ("c4f", colors) # gl.GL_QUADS,
|
|
||||||
)
|
|
||||||
vl.draw(gl.GL_QUADS)
|
vl.draw(gl.GL_QUADS)
|
||||||
vl.delete()
|
vl.delete()
|
||||||
|
|
||||||
@@ -575,10 +555,7 @@ class CarRacing(gym.Env, EzPickle):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
true_speed = np.sqrt(
|
true_speed = np.sqrt(np.square(self.car.hull.linearVelocity[0]) + np.square(self.car.hull.linearVelocity[1]))
|
||||||
np.square(self.car.hull.linearVelocity[0])
|
|
||||||
+ np.square(self.car.hull.linearVelocity[1])
|
|
||||||
)
|
|
||||||
|
|
||||||
vertical_ind(5, 0.02 * true_speed, (1, 1, 1))
|
vertical_ind(5, 0.02 * true_speed, (1, 1, 1))
|
||||||
vertical_ind(7, 0.01 * self.car.wheels[0].omega, (0.0, 0, 1)) # ABS sensors
|
vertical_ind(7, 0.01 * self.car.wheels[0].omega, (0.0, 0, 1)) # ABS sensors
|
||||||
@@ -587,9 +564,7 @@ class CarRacing(gym.Env, EzPickle):
|
|||||||
vertical_ind(10, 0.01 * self.car.wheels[3].omega, (0.2, 0, 1))
|
vertical_ind(10, 0.01 * self.car.wheels[3].omega, (0.2, 0, 1))
|
||||||
horiz_ind(20, -10.0 * self.car.wheels[0].joint.angle, (0, 1, 0))
|
horiz_ind(20, -10.0 * self.car.wheels[0].joint.angle, (0, 1, 0))
|
||||||
horiz_ind(30, -0.8 * self.car.hull.angularVelocity, (1, 0, 0))
|
horiz_ind(30, -0.8 * self.car.hull.angularVelocity, (1, 0, 0))
|
||||||
vl = pyglet.graphics.vertex_list(
|
vl = pyglet.graphics.vertex_list(len(polygons) // 3, ("v3f", polygons), ("c4f", colors)) # gl.GL_QUADS,
|
||||||
len(polygons) // 3, ("v3f", polygons), ("c4f", colors) # gl.GL_QUADS,
|
|
||||||
)
|
|
||||||
vl.draw(gl.GL_QUADS)
|
vl.draw(gl.GL_QUADS)
|
||||||
vl.delete()
|
vl.delete()
|
||||||
self.score_label.text = "%04i" % self.reward
|
self.score_label.text = "%04i" % self.reward
|
||||||
|
@@ -71,10 +71,7 @@ class ContactDetector(contactListener):
|
|||||||
self.env = env
|
self.env = env
|
||||||
|
|
||||||
def BeginContact(self, contact):
|
def BeginContact(self, contact):
|
||||||
if (
|
if self.env.lander == contact.fixtureA.body or self.env.lander == contact.fixtureB.body:
|
||||||
self.env.lander == contact.fixtureA.body
|
|
||||||
or self.env.lander == contact.fixtureB.body
|
|
||||||
):
|
|
||||||
self.env.game_over = True
|
self.env.game_over = True
|
||||||
for i in range(2):
|
for i in range(2):
|
||||||
if self.env.legs[i] in [contact.fixtureA.body, contact.fixtureB.body]:
|
if self.env.legs[i] in [contact.fixtureA.body, contact.fixtureB.body]:
|
||||||
@@ -104,9 +101,7 @@ class LunarLander(gym.Env, EzPickle):
|
|||||||
self.prev_reward = None
|
self.prev_reward = None
|
||||||
|
|
||||||
# useful range is -1 .. +1, but spikes can be higher
|
# useful range is -1 .. +1, but spikes can be higher
|
||||||
self.observation_space = spaces.Box(
|
self.observation_space = spaces.Box(-np.inf, np.inf, shape=(8,), dtype=np.float32)
|
||||||
-np.inf, np.inf, shape=(8,), dtype=np.float32
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.continuous:
|
if self.continuous:
|
||||||
# Action is two floats [main engine, left-right engines].
|
# Action is two floats [main engine, left-right engines].
|
||||||
@@ -157,14 +152,9 @@ class LunarLander(gym.Env, EzPickle):
|
|||||||
height[CHUNKS // 2 + 0] = self.helipad_y
|
height[CHUNKS // 2 + 0] = self.helipad_y
|
||||||
height[CHUNKS // 2 + 1] = self.helipad_y
|
height[CHUNKS // 2 + 1] = self.helipad_y
|
||||||
height[CHUNKS // 2 + 2] = self.helipad_y
|
height[CHUNKS // 2 + 2] = self.helipad_y
|
||||||
smooth_y = [
|
smooth_y = [0.33 * (height[i - 1] + height[i + 0] + height[i + 1]) for i in range(CHUNKS)]
|
||||||
0.33 * (height[i - 1] + height[i + 0] + height[i + 1])
|
|
||||||
for i in range(CHUNKS)
|
|
||||||
]
|
|
||||||
|
|
||||||
self.moon = self.world.CreateStaticBody(
|
self.moon = self.world.CreateStaticBody(shapes=edgeShape(vertices=[(0, 0), (W, 0)]))
|
||||||
shapes=edgeShape(vertices=[(0, 0), (W, 0)])
|
|
||||||
)
|
|
||||||
self.sky_polys = []
|
self.sky_polys = []
|
||||||
for i in range(CHUNKS - 1):
|
for i in range(CHUNKS - 1):
|
||||||
p1 = (chunk_x[i], smooth_y[i])
|
p1 = (chunk_x[i], smooth_y[i])
|
||||||
@@ -180,9 +170,7 @@ class LunarLander(gym.Env, EzPickle):
|
|||||||
position=(VIEWPORT_W / SCALE / 2, initial_y),
|
position=(VIEWPORT_W / SCALE / 2, initial_y),
|
||||||
angle=0.0,
|
angle=0.0,
|
||||||
fixtures=fixtureDef(
|
fixtures=fixtureDef(
|
||||||
shape=polygonShape(
|
shape=polygonShape(vertices=[(x / SCALE, y / SCALE) for x, y in LANDER_POLY]),
|
||||||
vertices=[(x / SCALE, y / SCALE) for x, y in LANDER_POLY]
|
|
||||||
),
|
|
||||||
density=5.0,
|
density=5.0,
|
||||||
friction=0.1,
|
friction=0.1,
|
||||||
categoryBits=0x0010,
|
categoryBits=0x0010,
|
||||||
@@ -227,9 +215,7 @@ class LunarLander(gym.Env, EzPickle):
|
|||||||
motorSpeed=+0.3 * i, # low enough not to jump back into the sky
|
motorSpeed=+0.3 * i, # low enough not to jump back into the sky
|
||||||
)
|
)
|
||||||
if i == -1:
|
if i == -1:
|
||||||
rjd.lowerAngle = (
|
rjd.lowerAngle = +0.9 - 0.5 # The most esoteric numbers here, angled legs have freedom to travel within
|
||||||
+0.9 - 0.5
|
|
||||||
) # The most esoteric numbers here, angled legs have freedom to travel within
|
|
||||||
rjd.upperAngle = +0.9
|
rjd.upperAngle = +0.9
|
||||||
else:
|
else:
|
||||||
rjd.lowerAngle = -0.9
|
rjd.lowerAngle = -0.9
|
||||||
@@ -278,9 +264,7 @@ class LunarLander(gym.Env, EzPickle):
|
|||||||
dispersion = [self.np_random.uniform(-1.0, +1.0) / SCALE for _ in range(2)]
|
dispersion = [self.np_random.uniform(-1.0, +1.0) / SCALE for _ in range(2)]
|
||||||
|
|
||||||
m_power = 0.0
|
m_power = 0.0
|
||||||
if (self.continuous and action[0] > 0.0) or (
|
if (self.continuous and action[0] > 0.0) or (not self.continuous and action == 2):
|
||||||
not self.continuous and action == 2
|
|
||||||
):
|
|
||||||
# Main engine
|
# Main engine
|
||||||
if self.continuous:
|
if self.continuous:
|
||||||
m_power = (np.clip(action[0], 0.0, 1.0) + 1.0) * 0.5 # 0.5..1.0
|
m_power = (np.clip(action[0], 0.0, 1.0) + 1.0) * 0.5 # 0.5..1.0
|
||||||
@@ -310,9 +294,7 @@ class LunarLander(gym.Env, EzPickle):
|
|||||||
)
|
)
|
||||||
|
|
||||||
s_power = 0.0
|
s_power = 0.0
|
||||||
if (self.continuous and np.abs(action[1]) > 0.5) or (
|
if (self.continuous and np.abs(action[1]) > 0.5) or (not self.continuous and action in [1, 3]):
|
||||||
not self.continuous and action in [1, 3]
|
|
||||||
):
|
|
||||||
# Orientation engines
|
# Orientation engines
|
||||||
if self.continuous:
|
if self.continuous:
|
||||||
direction = np.sign(action[1])
|
direction = np.sign(action[1])
|
||||||
@@ -321,12 +303,8 @@ class LunarLander(gym.Env, EzPickle):
|
|||||||
else:
|
else:
|
||||||
direction = action - 2
|
direction = action - 2
|
||||||
s_power = 1.0
|
s_power = 1.0
|
||||||
ox = tip[0] * dispersion[0] + side[0] * (
|
ox = tip[0] * dispersion[0] + side[0] * (3 * dispersion[1] + direction * SIDE_ENGINE_AWAY / SCALE)
|
||||||
3 * dispersion[1] + direction * SIDE_ENGINE_AWAY / SCALE
|
oy = -tip[1] * dispersion[0] - side[1] * (3 * dispersion[1] + direction * SIDE_ENGINE_AWAY / SCALE)
|
||||||
)
|
|
||||||
oy = -tip[1] * dispersion[0] - side[1] * (
|
|
||||||
3 * dispersion[1] + direction * SIDE_ENGINE_AWAY / SCALE
|
|
||||||
)
|
|
||||||
impulse_pos = (
|
impulse_pos = (
|
||||||
self.lander.position[0] + ox - tip[0] * 17 / SCALE,
|
self.lander.position[0] + ox - tip[0] * 17 / SCALE,
|
||||||
self.lander.position[1] + oy + tip[1] * SIDE_ENGINE_HEIGHT / SCALE,
|
self.lander.position[1] + oy + tip[1] * SIDE_ENGINE_HEIGHT / SCALE,
|
||||||
@@ -372,9 +350,7 @@ class LunarLander(gym.Env, EzPickle):
|
|||||||
reward = shaping - self.prev_shaping
|
reward = shaping - self.prev_shaping
|
||||||
self.prev_shaping = shaping
|
self.prev_shaping = shaping
|
||||||
|
|
||||||
reward -= (
|
reward -= m_power * 0.30 # less fuel spent is better, about -30 for heuristic landing
|
||||||
m_power * 0.30
|
|
||||||
) # less fuel spent is better, about -30 for heuristic landing
|
|
||||||
reward -= s_power * 0.03
|
reward -= s_power * 0.03
|
||||||
|
|
||||||
done = False
|
done = False
|
||||||
@@ -416,12 +392,8 @@ class LunarLander(gym.Env, EzPickle):
|
|||||||
trans = f.body.transform
|
trans = f.body.transform
|
||||||
if type(f.shape) is circleShape:
|
if type(f.shape) is circleShape:
|
||||||
t = rendering.Transform(translation=trans * f.shape.pos)
|
t = rendering.Transform(translation=trans * f.shape.pos)
|
||||||
self.viewer.draw_circle(
|
self.viewer.draw_circle(f.shape.radius, 20, color=obj.color1).add_attr(t)
|
||||||
f.shape.radius, 20, color=obj.color1
|
self.viewer.draw_circle(f.shape.radius, 20, color=obj.color2, filled=False, linewidth=2).add_attr(t)
|
||||||
).add_attr(t)
|
|
||||||
self.viewer.draw_circle(
|
|
||||||
f.shape.radius, 20, color=obj.color2, filled=False, linewidth=2
|
|
||||||
).add_attr(t)
|
|
||||||
else:
|
else:
|
||||||
path = [trans * v for v in f.shape.vertices]
|
path = [trans * v for v in f.shape.vertices]
|
||||||
self.viewer.draw_polygon(path, color=obj.color1)
|
self.viewer.draw_polygon(path, color=obj.color1)
|
||||||
@@ -479,18 +451,14 @@ def heuristic(env, s):
|
|||||||
angle_targ = 0.4 # more than 0.4 radians (22 degrees) is bad
|
angle_targ = 0.4 # more than 0.4 radians (22 degrees) is bad
|
||||||
if angle_targ < -0.4:
|
if angle_targ < -0.4:
|
||||||
angle_targ = -0.4
|
angle_targ = -0.4
|
||||||
hover_targ = 0.55 * np.abs(
|
hover_targ = 0.55 * np.abs(s[0]) # target y should be proportional to horizontal offset
|
||||||
s[0]
|
|
||||||
) # target y should be proportional to horizontal offset
|
|
||||||
|
|
||||||
angle_todo = (angle_targ - s[4]) * 0.5 - (s[5]) * 1.0
|
angle_todo = (angle_targ - s[4]) * 0.5 - (s[5]) * 1.0
|
||||||
hover_todo = (hover_targ - s[1]) * 0.5 - (s[3]) * 0.5
|
hover_todo = (hover_targ - s[1]) * 0.5 - (s[3]) * 0.5
|
||||||
|
|
||||||
if s[6] or s[7]: # legs have contact
|
if s[6] or s[7]: # legs have contact
|
||||||
angle_todo = 0
|
angle_todo = 0
|
||||||
hover_todo = (
|
hover_todo = -(s[3]) * 0.5 # override to reduce fall speed, that's all we need after contact
|
||||||
-(s[3]) * 0.5
|
|
||||||
) # override to reduce fall speed, that's all we need after contact
|
|
||||||
|
|
||||||
if env.continuous:
|
if env.continuous:
|
||||||
a = np.array([hover_todo * 20 - 1, -angle_todo * 20])
|
a = np.array([hover_todo * 20 - 1, -angle_todo * 20])
|
||||||
|
@@ -88,9 +88,7 @@ class AcrobotEnv(core.Env):
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.viewer = None
|
self.viewer = None
|
||||||
high = np.array(
|
high = np.array([1.0, 1.0, 1.0, 1.0, self.MAX_VEL_1, self.MAX_VEL_2], dtype=np.float32)
|
||||||
[1.0, 1.0, 1.0, 1.0, self.MAX_VEL_1, self.MAX_VEL_2], dtype=np.float32
|
|
||||||
)
|
|
||||||
low = -high
|
low = -high
|
||||||
self.observation_space = spaces.Box(low=low, high=high, dtype=np.float32)
|
self.observation_space = spaces.Box(low=low, high=high, dtype=np.float32)
|
||||||
self.action_space = spaces.Discrete(3)
|
self.action_space = spaces.Discrete(3)
|
||||||
@@ -111,9 +109,7 @@ class AcrobotEnv(core.Env):
|
|||||||
|
|
||||||
# Add noise to the force action
|
# Add noise to the force action
|
||||||
if self.torque_noise_max > 0:
|
if self.torque_noise_max > 0:
|
||||||
torque += self.np_random.uniform(
|
torque += self.np_random.uniform(-self.torque_noise_max, self.torque_noise_max)
|
||||||
-self.torque_noise_max, self.torque_noise_max
|
|
||||||
)
|
|
||||||
|
|
||||||
# Now, augment the state with our force action so it can be passed to
|
# Now, augment the state with our force action so it can be passed to
|
||||||
# _dsdt
|
# _dsdt
|
||||||
@@ -160,12 +156,7 @@ class AcrobotEnv(core.Env):
|
|||||||
theta2 = s[1]
|
theta2 = s[1]
|
||||||
dtheta1 = s[2]
|
dtheta1 = s[2]
|
||||||
dtheta2 = s[3]
|
dtheta2 = s[3]
|
||||||
d1 = (
|
d1 = m1 * lc1 ** 2 + m2 * (l1 ** 2 + lc2 ** 2 + 2 * l1 * lc2 * cos(theta2)) + I1 + I2
|
||||||
m1 * lc1 ** 2
|
|
||||||
+ m2 * (l1 ** 2 + lc2 ** 2 + 2 * l1 * lc2 * cos(theta2))
|
|
||||||
+ I1
|
|
||||||
+ I2
|
|
||||||
)
|
|
||||||
d2 = m2 * (lc2 ** 2 + l1 * lc2 * cos(theta2)) + I2
|
d2 = m2 * (lc2 ** 2 + l1 * lc2 * cos(theta2)) + I2
|
||||||
phi2 = m2 * lc2 * g * cos(theta1 + theta2 - pi / 2.0)
|
phi2 = m2 * lc2 * g * cos(theta1 + theta2 - pi / 2.0)
|
||||||
phi1 = (
|
phi1 = (
|
||||||
@@ -181,9 +172,9 @@ class AcrobotEnv(core.Env):
|
|||||||
else:
|
else:
|
||||||
# the following line is consistent with the java implementation and the
|
# the following line is consistent with the java implementation and the
|
||||||
# book
|
# book
|
||||||
ddtheta2 = (
|
ddtheta2 = (a + d2 / d1 * phi1 - m2 * l1 * lc2 * dtheta1 ** 2 * sin(theta2) - phi2) / (
|
||||||
a + d2 / d1 * phi1 - m2 * l1 * lc2 * dtheta1 ** 2 * sin(theta2) - phi2
|
m2 * lc2 ** 2 + I2 - d2 ** 2 / d1
|
||||||
) / (m2 * lc2 ** 2 + I2 - d2 ** 2 / d1)
|
)
|
||||||
ddtheta1 = -(d2 * ddtheta2 + phi1) / d1
|
ddtheta1 = -(d2 * ddtheta2 + phi1) / d1
|
||||||
return (dtheta1, dtheta2, ddtheta1, ddtheta2, 0.0)
|
return (dtheta1, dtheta2, ddtheta1, ddtheta2, 0.0)
|
||||||
|
|
||||||
|
@@ -111,9 +111,7 @@ class CartPoleEnv(gym.Env):
|
|||||||
|
|
||||||
# For the interested reader:
|
# For the interested reader:
|
||||||
# https://coneural.org/florian/papers/05_cart_pole.pdf
|
# https://coneural.org/florian/papers/05_cart_pole.pdf
|
||||||
temp = (
|
temp = (force + self.polemass_length * theta_dot ** 2 * sintheta) / self.total_mass
|
||||||
force + self.polemass_length * theta_dot ** 2 * sintheta
|
|
||||||
) / self.total_mass
|
|
||||||
thetaacc = (self.gravity * sintheta - costheta * temp) / (
|
thetaacc = (self.gravity * sintheta - costheta * temp) / (
|
||||||
self.length * (4.0 / 3.0 - self.masspole * costheta ** 2 / self.total_mass)
|
self.length * (4.0 / 3.0 - self.masspole * costheta ** 2 / self.total_mass)
|
||||||
)
|
)
|
||||||
|
@@ -62,27 +62,17 @@ class Continuous_MountainCarEnv(gym.Env):
|
|||||||
self.min_position = -1.2
|
self.min_position = -1.2
|
||||||
self.max_position = 0.6
|
self.max_position = 0.6
|
||||||
self.max_speed = 0.07
|
self.max_speed = 0.07
|
||||||
self.goal_position = (
|
self.goal_position = 0.45 # was 0.5 in gym, 0.45 in Arnaud de Broissia's version
|
||||||
0.45 # was 0.5 in gym, 0.45 in Arnaud de Broissia's version
|
|
||||||
)
|
|
||||||
self.goal_velocity = goal_velocity
|
self.goal_velocity = goal_velocity
|
||||||
self.power = 0.0015
|
self.power = 0.0015
|
||||||
|
|
||||||
self.low_state = np.array(
|
self.low_state = np.array([self.min_position, -self.max_speed], dtype=np.float32)
|
||||||
[self.min_position, -self.max_speed], dtype=np.float32
|
self.high_state = np.array([self.max_position, self.max_speed], dtype=np.float32)
|
||||||
)
|
|
||||||
self.high_state = np.array(
|
|
||||||
[self.max_position, self.max_speed], dtype=np.float32
|
|
||||||
)
|
|
||||||
|
|
||||||
self.viewer = None
|
self.viewer = None
|
||||||
|
|
||||||
self.action_space = spaces.Box(
|
self.action_space = spaces.Box(low=self.min_action, high=self.max_action, shape=(1,), dtype=np.float32)
|
||||||
low=self.min_action, high=self.max_action, shape=(1,), dtype=np.float32
|
self.observation_space = spaces.Box(low=self.low_state, high=self.high_state, dtype=np.float32)
|
||||||
)
|
|
||||||
self.observation_space = spaces.Box(
|
|
||||||
low=self.low_state, high=self.high_state, dtype=np.float32
|
|
||||||
)
|
|
||||||
|
|
||||||
self.seed()
|
self.seed()
|
||||||
self.reset()
|
self.reset()
|
||||||
@@ -159,15 +149,11 @@ class Continuous_MountainCarEnv(gym.Env):
|
|||||||
self.viewer.add_geom(car)
|
self.viewer.add_geom(car)
|
||||||
frontwheel = rendering.make_circle(carheight / 2.5)
|
frontwheel = rendering.make_circle(carheight / 2.5)
|
||||||
frontwheel.set_color(0.5, 0.5, 0.5)
|
frontwheel.set_color(0.5, 0.5, 0.5)
|
||||||
frontwheel.add_attr(
|
frontwheel.add_attr(rendering.Transform(translation=(carwidth / 4, clearance)))
|
||||||
rendering.Transform(translation=(carwidth / 4, clearance))
|
|
||||||
)
|
|
||||||
frontwheel.add_attr(self.cartrans)
|
frontwheel.add_attr(self.cartrans)
|
||||||
self.viewer.add_geom(frontwheel)
|
self.viewer.add_geom(frontwheel)
|
||||||
backwheel = rendering.make_circle(carheight / 2.5)
|
backwheel = rendering.make_circle(carheight / 2.5)
|
||||||
backwheel.add_attr(
|
backwheel.add_attr(rendering.Transform(translation=(-carwidth / 4, clearance)))
|
||||||
rendering.Transform(translation=(-carwidth / 4, clearance))
|
|
||||||
)
|
|
||||||
backwheel.add_attr(self.cartrans)
|
backwheel.add_attr(self.cartrans)
|
||||||
backwheel.set_color(0.5, 0.5, 0.5)
|
backwheel.set_color(0.5, 0.5, 0.5)
|
||||||
self.viewer.add_geom(backwheel)
|
self.viewer.add_geom(backwheel)
|
||||||
@@ -176,16 +162,12 @@ class Continuous_MountainCarEnv(gym.Env):
|
|||||||
flagy2 = flagy1 + 50
|
flagy2 = flagy1 + 50
|
||||||
flagpole = rendering.Line((flagx, flagy1), (flagx, flagy2))
|
flagpole = rendering.Line((flagx, flagy1), (flagx, flagy2))
|
||||||
self.viewer.add_geom(flagpole)
|
self.viewer.add_geom(flagpole)
|
||||||
flag = rendering.FilledPolygon(
|
flag = rendering.FilledPolygon([(flagx, flagy2), (flagx, flagy2 - 10), (flagx + 25, flagy2 - 5)])
|
||||||
[(flagx, flagy2), (flagx, flagy2 - 10), (flagx + 25, flagy2 - 5)]
|
|
||||||
)
|
|
||||||
flag.set_color(0.8, 0.8, 0)
|
flag.set_color(0.8, 0.8, 0)
|
||||||
self.viewer.add_geom(flag)
|
self.viewer.add_geom(flag)
|
||||||
|
|
||||||
pos = self.state[0]
|
pos = self.state[0]
|
||||||
self.cartrans.set_translation(
|
self.cartrans.set_translation((pos - self.min_position) * scale, self._height(pos) * scale)
|
||||||
(pos - self.min_position) * scale, self._height(pos) * scale
|
|
||||||
)
|
|
||||||
self.cartrans.set_rotation(math.cos(3 * pos))
|
self.cartrans.set_rotation(math.cos(3 * pos))
|
||||||
|
|
||||||
return self.viewer.render(return_rgb_array=mode == "rgb_array")
|
return self.viewer.render(return_rgb_array=mode == "rgb_array")
|
||||||
|
@@ -136,15 +136,11 @@ class MountainCarEnv(gym.Env):
|
|||||||
self.viewer.add_geom(car)
|
self.viewer.add_geom(car)
|
||||||
frontwheel = rendering.make_circle(carheight / 2.5)
|
frontwheel = rendering.make_circle(carheight / 2.5)
|
||||||
frontwheel.set_color(0.5, 0.5, 0.5)
|
frontwheel.set_color(0.5, 0.5, 0.5)
|
||||||
frontwheel.add_attr(
|
frontwheel.add_attr(rendering.Transform(translation=(carwidth / 4, clearance)))
|
||||||
rendering.Transform(translation=(carwidth / 4, clearance))
|
|
||||||
)
|
|
||||||
frontwheel.add_attr(self.cartrans)
|
frontwheel.add_attr(self.cartrans)
|
||||||
self.viewer.add_geom(frontwheel)
|
self.viewer.add_geom(frontwheel)
|
||||||
backwheel = rendering.make_circle(carheight / 2.5)
|
backwheel = rendering.make_circle(carheight / 2.5)
|
||||||
backwheel.add_attr(
|
backwheel.add_attr(rendering.Transform(translation=(-carwidth / 4, clearance)))
|
||||||
rendering.Transform(translation=(-carwidth / 4, clearance))
|
|
||||||
)
|
|
||||||
backwheel.add_attr(self.cartrans)
|
backwheel.add_attr(self.cartrans)
|
||||||
backwheel.set_color(0.5, 0.5, 0.5)
|
backwheel.set_color(0.5, 0.5, 0.5)
|
||||||
self.viewer.add_geom(backwheel)
|
self.viewer.add_geom(backwheel)
|
||||||
@@ -153,16 +149,12 @@ class MountainCarEnv(gym.Env):
|
|||||||
flagy2 = flagy1 + 50
|
flagy2 = flagy1 + 50
|
||||||
flagpole = rendering.Line((flagx, flagy1), (flagx, flagy2))
|
flagpole = rendering.Line((flagx, flagy1), (flagx, flagy2))
|
||||||
self.viewer.add_geom(flagpole)
|
self.viewer.add_geom(flagpole)
|
||||||
flag = rendering.FilledPolygon(
|
flag = rendering.FilledPolygon([(flagx, flagy2), (flagx, flagy2 - 10), (flagx + 25, flagy2 - 5)])
|
||||||
[(flagx, flagy2), (flagx, flagy2 - 10), (flagx + 25, flagy2 - 5)]
|
|
||||||
)
|
|
||||||
flag.set_color(0.8, 0.8, 0)
|
flag.set_color(0.8, 0.8, 0)
|
||||||
self.viewer.add_geom(flag)
|
self.viewer.add_geom(flag)
|
||||||
|
|
||||||
pos = self.state[0]
|
pos = self.state[0]
|
||||||
self.cartrans.set_translation(
|
self.cartrans.set_translation((pos - self.min_position) * scale, self._height(pos) * scale)
|
||||||
(pos - self.min_position) * scale, self._height(pos) * scale
|
|
||||||
)
|
|
||||||
self.cartrans.set_rotation(math.cos(3 * pos))
|
self.cartrans.set_rotation(math.cos(3 * pos))
|
||||||
|
|
||||||
return self.viewer.render(return_rgb_array=mode == "rgb_array")
|
return self.viewer.render(return_rgb_array=mode == "rgb_array")
|
||||||
|
@@ -18,9 +18,7 @@ class PendulumEnv(gym.Env):
|
|||||||
self.viewer = None
|
self.viewer = None
|
||||||
|
|
||||||
high = np.array([1.0, 1.0, self.max_speed], dtype=np.float32)
|
high = np.array([1.0, 1.0, self.max_speed], dtype=np.float32)
|
||||||
self.action_space = spaces.Box(
|
self.action_space = spaces.Box(low=-self.max_torque, high=self.max_torque, shape=(1,), dtype=np.float32)
|
||||||
low=-self.max_torque, high=self.max_torque, shape=(1,), dtype=np.float32
|
|
||||||
)
|
|
||||||
self.observation_space = spaces.Box(low=-high, high=high, dtype=np.float32)
|
self.observation_space = spaces.Box(low=-high, high=high, dtype=np.float32)
|
||||||
|
|
||||||
self.seed()
|
self.seed()
|
||||||
@@ -41,10 +39,7 @@ class PendulumEnv(gym.Env):
|
|||||||
self.last_u = u # for rendering
|
self.last_u = u # for rendering
|
||||||
costs = angle_normalize(th) ** 2 + 0.1 * thdot ** 2 + 0.001 * (u ** 2)
|
costs = angle_normalize(th) ** 2 + 0.1 * thdot ** 2 + 0.001 * (u ** 2)
|
||||||
|
|
||||||
newthdot = (
|
newthdot = thdot + (-3 * g / (2 * l) * np.sin(th + np.pi) + 3.0 / (m * l ** 2) * u) * dt
|
||||||
thdot
|
|
||||||
+ (-3 * g / (2 * l) * np.sin(th + np.pi) + 3.0 / (m * l ** 2) * u) * dt
|
|
||||||
)
|
|
||||||
newth = th + newthdot * dt
|
newth = th + newthdot * dt
|
||||||
newthdot = np.clip(newthdot, -self.max_speed, self.max_speed)
|
newthdot = np.clip(newthdot, -self.max_speed, self.max_speed)
|
||||||
|
|
||||||
|
@@ -54,11 +54,7 @@ def get_display(spec):
|
|||||||
elif isinstance(spec, str):
|
elif isinstance(spec, str):
|
||||||
return pyglet.canvas.Display(spec)
|
return pyglet.canvas.Display(spec)
|
||||||
else:
|
else:
|
||||||
raise error.Error(
|
raise error.Error("Invalid display specification: {}. (Must be a string like :0 or None.)".format(spec))
|
||||||
"Invalid display specification: {}. (Must be a string like :0 or None.)".format(
|
|
||||||
spec
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_window(width, height, display, **kwargs):
|
def get_window(width, height, display, **kwargs):
|
||||||
@@ -69,14 +65,7 @@ def get_window(width, height, display, **kwargs):
|
|||||||
config = screen[0].get_best_config() # selecting the first screen
|
config = screen[0].get_best_config() # selecting the first screen
|
||||||
context = config.create_context(None) # create GL context
|
context = config.create_context(None) # create GL context
|
||||||
|
|
||||||
return pyglet.window.Window(
|
return pyglet.window.Window(width=width, height=height, display=display, config=config, context=context, **kwargs)
|
||||||
width=width,
|
|
||||||
height=height,
|
|
||||||
display=display,
|
|
||||||
config=config,
|
|
||||||
context=context,
|
|
||||||
**kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Viewer(object):
|
class Viewer(object):
|
||||||
@@ -108,9 +97,7 @@ class Viewer(object):
|
|||||||
assert right > left and top > bottom
|
assert right > left and top > bottom
|
||||||
scalex = self.width / (right - left)
|
scalex = self.width / (right - left)
|
||||||
scaley = self.height / (top - bottom)
|
scaley = self.height / (top - bottom)
|
||||||
self.transform = Transform(
|
self.transform = Transform(translation=(-left * scalex, -bottom * scaley), scale=(scalex, scaley))
|
||||||
translation=(-left * scalex, -bottom * scaley), scale=(scalex, scaley)
|
|
||||||
)
|
|
||||||
|
|
||||||
def add_geom(self, geom):
|
def add_geom(self, geom):
|
||||||
self.geoms.append(geom)
|
self.geoms.append(geom)
|
||||||
@@ -173,9 +160,7 @@ class Viewer(object):
|
|||||||
|
|
||||||
def get_array(self):
|
def get_array(self):
|
||||||
self.window.flip()
|
self.window.flip()
|
||||||
image_data = (
|
image_data = pyglet.image.get_buffer_manager().get_color_buffer().get_image_data()
|
||||||
pyglet.image.get_buffer_manager().get_color_buffer().get_image_data()
|
|
||||||
)
|
|
||||||
self.window.flip()
|
self.window.flip()
|
||||||
arr = np.fromstring(image_data.get_data(), dtype=np.uint8, sep="")
|
arr = np.fromstring(image_data.get_data(), dtype=np.uint8, sep="")
|
||||||
arr = arr.reshape(self.height, self.width, 4)
|
arr = arr.reshape(self.height, self.width, 4)
|
||||||
@@ -230,9 +215,7 @@ class Transform(Attr):
|
|||||||
|
|
||||||
def enable(self):
|
def enable(self):
|
||||||
glPushMatrix()
|
glPushMatrix()
|
||||||
glTranslatef(
|
glTranslatef(self.translation[0], self.translation[1], 0) # translate to GL loc ppint
|
||||||
self.translation[0], self.translation[1], 0
|
|
||||||
) # translate to GL loc ppint
|
|
||||||
glRotatef(RAD2DEG * self.rotation, 0, 0, 1.0)
|
glRotatef(RAD2DEG * self.rotation, 0, 0, 1.0)
|
||||||
glScalef(self.scale[0], self.scale[1], 1)
|
glScalef(self.scale[0], self.scale[1], 1)
|
||||||
|
|
||||||
@@ -392,9 +375,7 @@ class Image(Geom):
|
|||||||
self.flip = False
|
self.flip = False
|
||||||
|
|
||||||
def render1(self):
|
def render1(self):
|
||||||
self.img.blit(
|
self.img.blit(-self.width / 2, -self.height / 2, width=self.width, height=self.height)
|
||||||
-self.width / 2, -self.height / 2, width=self.width, height=self.height
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ================================================================
|
# ================================================================
|
||||||
@@ -435,9 +416,7 @@ class SimpleImageViewer(object):
|
|||||||
self.isopen = False
|
self.isopen = False
|
||||||
|
|
||||||
assert len(arr.shape) == 3, "You passed in an image with the wrong number shape"
|
assert len(arr.shape) == 3, "You passed in an image with the wrong number shape"
|
||||||
image = pyglet.image.ImageData(
|
image = pyglet.image.ImageData(arr.shape[1], arr.shape[0], "RGB", arr.tobytes(), pitch=arr.shape[1] * -3)
|
||||||
arr.shape[1], arr.shape[0], "RGB", arr.tobytes(), pitch=arr.shape[1] * -3
|
|
||||||
)
|
|
||||||
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_NEAREST)
|
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_NEAREST)
|
||||||
texture = image.get_texture()
|
texture = image.get_texture()
|
||||||
texture.width = self.width
|
texture.width = self.width
|
||||||
|
@@ -14,9 +14,7 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
xposafter = self.get_body_com("torso")[0]
|
xposafter = self.get_body_com("torso")[0]
|
||||||
forward_reward = (xposafter - xposbefore) / self.dt
|
forward_reward = (xposafter - xposbefore) / self.dt
|
||||||
ctrl_cost = 0.5 * np.square(a).sum()
|
ctrl_cost = 0.5 * np.square(a).sum()
|
||||||
contact_cost = (
|
contact_cost = 0.5 * 1e-3 * np.sum(np.square(np.clip(self.sim.data.cfrc_ext, -1, 1)))
|
||||||
0.5 * 1e-3 * np.sum(np.square(np.clip(self.sim.data.cfrc_ext, -1, 1)))
|
|
||||||
)
|
|
||||||
survive_reward = 1.0
|
survive_reward = 1.0
|
||||||
reward = forward_reward - ctrl_cost - contact_cost + survive_reward
|
reward = forward_reward - ctrl_cost - contact_cost + survive_reward
|
||||||
state = self.state_vector()
|
state = self.state_vector()
|
||||||
@@ -45,9 +43,7 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def reset_model(self):
|
def reset_model(self):
|
||||||
qpos = self.init_qpos + self.np_random.uniform(
|
qpos = self.init_qpos + self.np_random.uniform(size=self.model.nq, low=-0.1, high=0.1)
|
||||||
size=self.model.nq, low=-0.1, high=0.1
|
|
||||||
)
|
|
||||||
qvel = self.init_qvel + self.np_random.randn(self.model.nv) * 0.1
|
qvel = self.init_qvel + self.np_random.randn(self.model.nv) * 0.1
|
||||||
self.set_state(qpos, qvel)
|
self.set_state(qpos, qvel)
|
||||||
return self._get_obs()
|
return self._get_obs()
|
||||||
|
@@ -34,18 +34,13 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
|
|
||||||
self._reset_noise_scale = reset_noise_scale
|
self._reset_noise_scale = reset_noise_scale
|
||||||
|
|
||||||
self._exclude_current_positions_from_observation = (
|
self._exclude_current_positions_from_observation = exclude_current_positions_from_observation
|
||||||
exclude_current_positions_from_observation
|
|
||||||
)
|
|
||||||
|
|
||||||
mujoco_env.MujocoEnv.__init__(self, xml_file, 5)
|
mujoco_env.MujocoEnv.__init__(self, xml_file, 5)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def healthy_reward(self):
|
def healthy_reward(self):
|
||||||
return (
|
return float(self.is_healthy or self._terminate_when_unhealthy) * self._healthy_reward
|
||||||
float(self.is_healthy or self._terminate_when_unhealthy)
|
|
||||||
* self._healthy_reward
|
|
||||||
)
|
|
||||||
|
|
||||||
def control_cost(self, action):
|
def control_cost(self, action):
|
||||||
control_cost = self._ctrl_cost_weight * np.sum(np.square(action))
|
control_cost = self._ctrl_cost_weight * np.sum(np.square(action))
|
||||||
@@ -60,9 +55,7 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def contact_cost(self):
|
def contact_cost(self):
|
||||||
contact_cost = self._contact_cost_weight * np.sum(
|
contact_cost = self._contact_cost_weight * np.sum(np.square(self.contact_forces))
|
||||||
np.square(self.contact_forces)
|
|
||||||
)
|
|
||||||
return contact_cost
|
return contact_cost
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -128,12 +121,8 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
noise_low = -self._reset_noise_scale
|
noise_low = -self._reset_noise_scale
|
||||||
noise_high = self._reset_noise_scale
|
noise_high = self._reset_noise_scale
|
||||||
|
|
||||||
qpos = self.init_qpos + self.np_random.uniform(
|
qpos = self.init_qpos + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nq)
|
||||||
low=noise_low, high=noise_high, size=self.model.nq
|
qvel = self.init_qvel + self._reset_noise_scale * self.np_random.randn(self.model.nv)
|
||||||
)
|
|
||||||
qvel = self.init_qvel + self._reset_noise_scale * self.np_random.randn(
|
|
||||||
self.model.nv
|
|
||||||
)
|
|
||||||
self.set_state(qpos, qvel)
|
self.set_state(qpos, qvel)
|
||||||
|
|
||||||
observation = self._get_obs()
|
observation = self._get_obs()
|
||||||
|
@@ -28,9 +28,7 @@ class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def reset_model(self):
|
def reset_model(self):
|
||||||
qpos = self.init_qpos + self.np_random.uniform(
|
qpos = self.init_qpos + self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq)
|
||||||
low=-0.1, high=0.1, size=self.model.nq
|
|
||||||
)
|
|
||||||
qvel = self.init_qvel + self.np_random.randn(self.model.nv) * 0.1
|
qvel = self.init_qvel + self.np_random.randn(self.model.nv) * 0.1
|
||||||
self.set_state(qpos, qvel)
|
self.set_state(qpos, qvel)
|
||||||
return self._get_obs()
|
return self._get_obs()
|
||||||
|
@@ -25,9 +25,7 @@ class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
|
|
||||||
self._reset_noise_scale = reset_noise_scale
|
self._reset_noise_scale = reset_noise_scale
|
||||||
|
|
||||||
self._exclude_current_positions_from_observation = (
|
self._exclude_current_positions_from_observation = exclude_current_positions_from_observation
|
||||||
exclude_current_positions_from_observation
|
|
||||||
)
|
|
||||||
|
|
||||||
mujoco_env.MujocoEnv.__init__(self, xml_file, 5)
|
mujoco_env.MujocoEnv.__init__(self, xml_file, 5)
|
||||||
|
|
||||||
@@ -71,12 +69,8 @@ class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
noise_low = -self._reset_noise_scale
|
noise_low = -self._reset_noise_scale
|
||||||
noise_high = self._reset_noise_scale
|
noise_high = self._reset_noise_scale
|
||||||
|
|
||||||
qpos = self.init_qpos + self.np_random.uniform(
|
qpos = self.init_qpos + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nq)
|
||||||
low=noise_low, high=noise_high, size=self.model.nq
|
qvel = self.init_qvel + self._reset_noise_scale * self.np_random.randn(self.model.nv)
|
||||||
)
|
|
||||||
qvel = self.init_qvel + self._reset_noise_scale * self.np_random.randn(
|
|
||||||
self.model.nv
|
|
||||||
)
|
|
||||||
|
|
||||||
self.set_state(qpos, qvel)
|
self.set_state(qpos, qvel)
|
||||||
|
|
||||||
|
@@ -17,27 +17,16 @@ class HopperEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
reward += alive_bonus
|
reward += alive_bonus
|
||||||
reward -= 1e-3 * np.square(a).sum()
|
reward -= 1e-3 * np.square(a).sum()
|
||||||
s = self.state_vector()
|
s = self.state_vector()
|
||||||
done = not (
|
done = not (np.isfinite(s).all() and (np.abs(s[2:]) < 100).all() and (height > 0.7) and (abs(ang) < 0.2))
|
||||||
np.isfinite(s).all()
|
|
||||||
and (np.abs(s[2:]) < 100).all()
|
|
||||||
and (height > 0.7)
|
|
||||||
and (abs(ang) < 0.2)
|
|
||||||
)
|
|
||||||
ob = self._get_obs()
|
ob = self._get_obs()
|
||||||
return ob, reward, done, {}
|
return ob, reward, done, {}
|
||||||
|
|
||||||
def _get_obs(self):
|
def _get_obs(self):
|
||||||
return np.concatenate(
|
return np.concatenate([self.sim.data.qpos.flat[1:], np.clip(self.sim.data.qvel.flat, -10, 10)])
|
||||||
[self.sim.data.qpos.flat[1:], np.clip(self.sim.data.qvel.flat, -10, 10)]
|
|
||||||
)
|
|
||||||
|
|
||||||
def reset_model(self):
|
def reset_model(self):
|
||||||
qpos = self.init_qpos + self.np_random.uniform(
|
qpos = self.init_qpos + self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nq)
|
||||||
low=-0.005, high=0.005, size=self.model.nq
|
qvel = self.init_qvel + self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nv)
|
||||||
)
|
|
||||||
qvel = self.init_qvel + self.np_random.uniform(
|
|
||||||
low=-0.005, high=0.005, size=self.model.nv
|
|
||||||
)
|
|
||||||
self.set_state(qpos, qvel)
|
self.set_state(qpos, qvel)
|
||||||
return self._get_obs()
|
return self._get_obs()
|
||||||
|
|
||||||
|
@@ -40,18 +40,13 @@ class HopperEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
|
|
||||||
self._reset_noise_scale = reset_noise_scale
|
self._reset_noise_scale = reset_noise_scale
|
||||||
|
|
||||||
self._exclude_current_positions_from_observation = (
|
self._exclude_current_positions_from_observation = exclude_current_positions_from_observation
|
||||||
exclude_current_positions_from_observation
|
|
||||||
)
|
|
||||||
|
|
||||||
mujoco_env.MujocoEnv.__init__(self, xml_file, 4)
|
mujoco_env.MujocoEnv.__init__(self, xml_file, 4)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def healthy_reward(self):
|
def healthy_reward(self):
|
||||||
return (
|
return float(self.is_healthy or self._terminate_when_unhealthy) * self._healthy_reward
|
||||||
float(self.is_healthy or self._terminate_when_unhealthy)
|
|
||||||
* self._healthy_reward
|
|
||||||
)
|
|
||||||
|
|
||||||
def control_cost(self, action):
|
def control_cost(self, action):
|
||||||
control_cost = self._ctrl_cost_weight * np.sum(np.square(action))
|
control_cost = self._ctrl_cost_weight * np.sum(np.square(action))
|
||||||
@@ -117,12 +112,8 @@ class HopperEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
noise_low = -self._reset_noise_scale
|
noise_low = -self._reset_noise_scale
|
||||||
noise_high = self._reset_noise_scale
|
noise_high = self._reset_noise_scale
|
||||||
|
|
||||||
qpos = self.init_qpos + self.np_random.uniform(
|
qpos = self.init_qpos + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nq)
|
||||||
low=noise_low, high=noise_high, size=self.model.nq
|
qvel = self.init_qvel + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nv)
|
||||||
)
|
|
||||||
qvel = self.init_qvel + self.np_random.uniform(
|
|
||||||
low=noise_low, high=noise_high, size=self.model.nv
|
|
||||||
)
|
|
||||||
|
|
||||||
self.set_state(qpos, qvel)
|
self.set_state(qpos, qvel)
|
||||||
|
|
||||||
|
@@ -43,18 +43,13 @@ class HumanoidEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
|
|
||||||
self._reset_noise_scale = reset_noise_scale
|
self._reset_noise_scale = reset_noise_scale
|
||||||
|
|
||||||
self._exclude_current_positions_from_observation = (
|
self._exclude_current_positions_from_observation = exclude_current_positions_from_observation
|
||||||
exclude_current_positions_from_observation
|
|
||||||
)
|
|
||||||
|
|
||||||
mujoco_env.MujocoEnv.__init__(self, xml_file, 5)
|
mujoco_env.MujocoEnv.__init__(self, xml_file, 5)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def healthy_reward(self):
|
def healthy_reward(self):
|
||||||
return (
|
return float(self.is_healthy or self._terminate_when_unhealthy) * self._healthy_reward
|
||||||
float(self.is_healthy or self._terminate_when_unhealthy)
|
|
||||||
* self._healthy_reward
|
|
||||||
)
|
|
||||||
|
|
||||||
def control_cost(self, action):
|
def control_cost(self, action):
|
||||||
control_cost = self._ctrl_cost_weight * np.sum(np.square(self.sim.data.ctrl))
|
control_cost = self._ctrl_cost_weight * np.sum(np.square(self.sim.data.ctrl))
|
||||||
@@ -143,12 +138,8 @@ class HumanoidEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
noise_low = -self._reset_noise_scale
|
noise_low = -self._reset_noise_scale
|
||||||
noise_high = self._reset_noise_scale
|
noise_high = self._reset_noise_scale
|
||||||
|
|
||||||
qpos = self.init_qpos + self.np_random.uniform(
|
qpos = self.init_qpos + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nq)
|
||||||
low=noise_low, high=noise_high, size=self.model.nq
|
qvel = self.init_qvel + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nv)
|
||||||
)
|
|
||||||
qvel = self.init_qvel + self.np_random.uniform(
|
|
||||||
low=noise_low, high=noise_high, size=self.model.nv
|
|
||||||
)
|
|
||||||
self.set_state(qpos, qvel)
|
self.set_state(qpos, qvel)
|
||||||
|
|
||||||
observation = self._get_obs()
|
observation = self._get_obs()
|
||||||
|
@@ -33,8 +33,7 @@ class InvertedDoublePendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
|
|
||||||
def reset_model(self):
|
def reset_model(self):
|
||||||
self.set_state(
|
self.set_state(
|
||||||
self.init_qpos
|
self.init_qpos + self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq),
|
||||||
+ self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq),
|
|
||||||
self.init_qvel + self.np_random.randn(self.model.nv) * 0.1,
|
self.init_qvel + self.np_random.randn(self.model.nv) * 0.1,
|
||||||
)
|
)
|
||||||
return self._get_obs()
|
return self._get_obs()
|
||||||
|
@@ -17,12 +17,8 @@ class InvertedPendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
return ob, reward, done, {}
|
return ob, reward, done, {}
|
||||||
|
|
||||||
def reset_model(self):
|
def reset_model(self):
|
||||||
qpos = self.init_qpos + self.np_random.uniform(
|
qpos = self.init_qpos + self.np_random.uniform(size=self.model.nq, low=-0.01, high=0.01)
|
||||||
size=self.model.nq, low=-0.01, high=0.01
|
qvel = self.init_qvel + self.np_random.uniform(size=self.model.nv, low=-0.01, high=0.01)
|
||||||
)
|
|
||||||
qvel = self.init_qvel + self.np_random.uniform(
|
|
||||||
size=self.model.nv, low=-0.01, high=0.01
|
|
||||||
)
|
|
||||||
self.set_state(qpos, qvel)
|
self.set_state(qpos, qvel)
|
||||||
return self._get_obs()
|
return self._get_obs()
|
||||||
|
|
||||||
|
@@ -22,14 +22,7 @@ DEFAULT_SIZE = 500
|
|||||||
|
|
||||||
def convert_observation_to_space(observation):
|
def convert_observation_to_space(observation):
|
||||||
if isinstance(observation, dict):
|
if isinstance(observation, dict):
|
||||||
space = spaces.Dict(
|
space = spaces.Dict(OrderedDict([(key, convert_observation_to_space(value)) for key, value in observation.items()]))
|
||||||
OrderedDict(
|
|
||||||
[
|
|
||||||
(key, convert_observation_to_space(value))
|
|
||||||
for key, value in observation.items()
|
|
||||||
]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
elif isinstance(observation, np.ndarray):
|
elif isinstance(observation, np.ndarray):
|
||||||
low = np.full(observation.shape, -float("inf"), dtype=np.float32)
|
low = np.full(observation.shape, -float("inf"), dtype=np.float32)
|
||||||
high = np.full(observation.shape, float("inf"), dtype=np.float32)
|
high = np.full(observation.shape, float("inf"), dtype=np.float32)
|
||||||
@@ -117,9 +110,7 @@ class MujocoEnv(gym.Env):
|
|||||||
def set_state(self, qpos, qvel):
|
def set_state(self, qpos, qvel):
|
||||||
assert qpos.shape == (self.model.nq,) and qvel.shape == (self.model.nv,)
|
assert qpos.shape == (self.model.nq,) and qvel.shape == (self.model.nv,)
|
||||||
old_state = self.sim.get_state()
|
old_state = self.sim.get_state()
|
||||||
new_state = mujoco_py.MjSimState(
|
new_state = mujoco_py.MjSimState(old_state.time, qpos, qvel, old_state.act, old_state.udd_state)
|
||||||
old_state.time, qpos, qvel, old_state.act, old_state.udd_state
|
|
||||||
)
|
|
||||||
self.sim.set_state(new_state)
|
self.sim.set_state(new_state)
|
||||||
self.sim.forward()
|
self.sim.forward()
|
||||||
|
|
||||||
@@ -142,10 +133,7 @@ class MujocoEnv(gym.Env):
|
|||||||
):
|
):
|
||||||
if mode == "rgb_array" or mode == "depth_array":
|
if mode == "rgb_array" or mode == "depth_array":
|
||||||
if camera_id is not None and camera_name is not None:
|
if camera_id is not None and camera_name is not None:
|
||||||
raise ValueError(
|
raise ValueError("Both `camera_id` and `camera_name` cannot be" " specified at the same time.")
|
||||||
"Both `camera_id` and `camera_name` cannot be"
|
|
||||||
" specified at the same time."
|
|
||||||
)
|
|
||||||
|
|
||||||
no_camera_specified = camera_name is None and camera_id is None
|
no_camera_specified = camera_name is None and camera_id is None
|
||||||
if no_camera_specified:
|
if no_camera_specified:
|
||||||
|
@@ -44,9 +44,7 @@ class PusherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
|
|
||||||
qpos[-4:-2] = self.cylinder_pos
|
qpos[-4:-2] = self.cylinder_pos
|
||||||
qpos[-2:] = self.goal_pos
|
qpos[-2:] = self.goal_pos
|
||||||
qvel = self.init_qvel + self.np_random.uniform(
|
qvel = self.init_qvel + self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nv)
|
||||||
low=-0.005, high=0.005, size=self.model.nv
|
|
||||||
)
|
|
||||||
qvel[-4:] = 0
|
qvel[-4:] = 0
|
||||||
self.set_state(qpos, qvel)
|
self.set_state(qpos, qvel)
|
||||||
return self._get_obs()
|
return self._get_obs()
|
||||||
|
@@ -22,18 +22,13 @@ class ReacherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
self.viewer.cam.trackbodyid = 0
|
self.viewer.cam.trackbodyid = 0
|
||||||
|
|
||||||
def reset_model(self):
|
def reset_model(self):
|
||||||
qpos = (
|
qpos = self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq) + self.init_qpos
|
||||||
self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq)
|
|
||||||
+ self.init_qpos
|
|
||||||
)
|
|
||||||
while True:
|
while True:
|
||||||
self.goal = self.np_random.uniform(low=-0.2, high=0.2, size=2)
|
self.goal = self.np_random.uniform(low=-0.2, high=0.2, size=2)
|
||||||
if np.linalg.norm(self.goal) < 0.2:
|
if np.linalg.norm(self.goal) < 0.2:
|
||||||
break
|
break
|
||||||
qpos[-2:] = self.goal
|
qpos[-2:] = self.goal
|
||||||
qvel = self.init_qvel + self.np_random.uniform(
|
qvel = self.init_qvel + self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nv)
|
||||||
low=-0.005, high=0.005, size=self.model.nv
|
|
||||||
)
|
|
||||||
qvel[-2:] = 0
|
qvel[-2:] = 0
|
||||||
self.set_state(qpos, qvel)
|
self.set_state(qpos, qvel)
|
||||||
return self._get_obs()
|
return self._get_obs()
|
||||||
|
@@ -62,9 +62,7 @@ class StrikerEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
diff = self.ball - self.goal
|
diff = self.ball - self.goal
|
||||||
angle = -np.arctan(diff[0] / (diff[1] + 1e-8))
|
angle = -np.arctan(diff[0] / (diff[1] + 1e-8))
|
||||||
qpos[-1] = angle / 3.14
|
qpos[-1] = angle / 3.14
|
||||||
qvel = self.init_qvel + self.np_random.uniform(
|
qvel = self.init_qvel + self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nv)
|
||||||
low=-0.1, high=0.1, size=self.model.nv
|
|
||||||
)
|
|
||||||
qvel[7:] = 0
|
qvel[7:] = 0
|
||||||
self.set_state(qpos, qvel)
|
self.set_state(qpos, qvel)
|
||||||
return self._get_obs()
|
return self._get_obs()
|
||||||
|
@@ -26,9 +26,7 @@ class SwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
|
|
||||||
def reset_model(self):
|
def reset_model(self):
|
||||||
self.set_state(
|
self.set_state(
|
||||||
self.init_qpos
|
self.init_qpos + self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq),
|
||||||
+ self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq),
|
self.init_qvel + self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nv),
|
||||||
self.init_qvel
|
|
||||||
+ self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nv),
|
|
||||||
)
|
)
|
||||||
return self._get_obs()
|
return self._get_obs()
|
||||||
|
@@ -22,9 +22,7 @@ class SwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
|
|
||||||
self._reset_noise_scale = reset_noise_scale
|
self._reset_noise_scale = reset_noise_scale
|
||||||
|
|
||||||
self._exclude_current_positions_from_observation = (
|
self._exclude_current_positions_from_observation = exclude_current_positions_from_observation
|
||||||
exclude_current_positions_from_observation
|
|
||||||
)
|
|
||||||
|
|
||||||
mujoco_env.MujocoEnv.__init__(self, xml_file, 4)
|
mujoco_env.MujocoEnv.__init__(self, xml_file, 4)
|
||||||
|
|
||||||
@@ -74,12 +72,8 @@ class SwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
noise_low = -self._reset_noise_scale
|
noise_low = -self._reset_noise_scale
|
||||||
noise_high = self._reset_noise_scale
|
noise_high = self._reset_noise_scale
|
||||||
|
|
||||||
qpos = self.init_qpos + self.np_random.uniform(
|
qpos = self.init_qpos + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nq)
|
||||||
low=noise_low, high=noise_high, size=self.model.nq
|
qvel = self.init_qvel + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nv)
|
||||||
)
|
|
||||||
qvel = self.init_qvel + self.np_random.uniform(
|
|
||||||
low=noise_low, high=noise_high, size=self.model.nv
|
|
||||||
)
|
|
||||||
|
|
||||||
self.set_state(qpos, qvel)
|
self.set_state(qpos, qvel)
|
||||||
|
|
||||||
|
@@ -48,9 +48,7 @@ class ThrowerEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
)
|
)
|
||||||
|
|
||||||
qpos[-9:-7] = self.goal
|
qpos[-9:-7] = self.goal
|
||||||
qvel = self.init_qvel + self.np_random.uniform(
|
qvel = self.init_qvel + self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nv)
|
||||||
low=-0.005, high=0.005, size=self.model.nv
|
|
||||||
)
|
|
||||||
qvel[7:] = 0
|
qvel[7:] = 0
|
||||||
self.set_state(qpos, qvel)
|
self.set_state(qpos, qvel)
|
||||||
return self._get_obs()
|
return self._get_obs()
|
||||||
|
@@ -27,10 +27,8 @@ class Walker2dEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
|
|
||||||
def reset_model(self):
|
def reset_model(self):
|
||||||
self.set_state(
|
self.set_state(
|
||||||
self.init_qpos
|
self.init_qpos + self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nq),
|
||||||
+ self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nq),
|
self.init_qvel + self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nv),
|
||||||
self.init_qvel
|
|
||||||
+ self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nv),
|
|
||||||
)
|
)
|
||||||
return self._get_obs()
|
return self._get_obs()
|
||||||
|
|
||||||
|
@@ -37,18 +37,13 @@ class Walker2dEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
|
|
||||||
self._reset_noise_scale = reset_noise_scale
|
self._reset_noise_scale = reset_noise_scale
|
||||||
|
|
||||||
self._exclude_current_positions_from_observation = (
|
self._exclude_current_positions_from_observation = exclude_current_positions_from_observation
|
||||||
exclude_current_positions_from_observation
|
|
||||||
)
|
|
||||||
|
|
||||||
mujoco_env.MujocoEnv.__init__(self, xml_file, 4)
|
mujoco_env.MujocoEnv.__init__(self, xml_file, 4)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def healthy_reward(self):
|
def healthy_reward(self):
|
||||||
return (
|
return float(self.is_healthy or self._terminate_when_unhealthy) * self._healthy_reward
|
||||||
float(self.is_healthy or self._terminate_when_unhealthy)
|
|
||||||
* self._healthy_reward
|
|
||||||
)
|
|
||||||
|
|
||||||
def control_cost(self, action):
|
def control_cost(self, action):
|
||||||
control_cost = self._ctrl_cost_weight * np.sum(np.square(action))
|
control_cost = self._ctrl_cost_weight * np.sum(np.square(action))
|
||||||
@@ -110,12 +105,8 @@ class Walker2dEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
noise_low = -self._reset_noise_scale
|
noise_low = -self._reset_noise_scale
|
||||||
noise_high = self._reset_noise_scale
|
noise_high = self._reset_noise_scale
|
||||||
|
|
||||||
qpos = self.init_qpos + self.np_random.uniform(
|
qpos = self.init_qpos + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nq)
|
||||||
low=noise_low, high=noise_high, size=self.model.nq
|
qvel = self.init_qvel + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nv)
|
||||||
)
|
|
||||||
qvel = self.init_qvel + self.np_random.uniform(
|
|
||||||
low=noise_low, high=noise_high, size=self.model.nv
|
|
||||||
)
|
|
||||||
|
|
||||||
self.set_state(qpos, qvel)
|
self.set_state(qpos, qvel)
|
||||||
|
|
||||||
|
@@ -108,11 +108,7 @@ class EnvRegistry(object):
|
|||||||
# reset/step. Set _gym_disable_underscore_compat = True on
|
# reset/step. Set _gym_disable_underscore_compat = True on
|
||||||
# your environment if you use these methods and don't want
|
# your environment if you use these methods and don't want
|
||||||
# compatibility code to be invoked.
|
# compatibility code to be invoked.
|
||||||
if (
|
if hasattr(env, "_reset") and hasattr(env, "_step") and not getattr(env, "_gym_disable_underscore_compat", False):
|
||||||
hasattr(env, "_reset")
|
|
||||||
and hasattr(env, "_step")
|
|
||||||
and not getattr(env, "_gym_disable_underscore_compat", False)
|
|
||||||
):
|
|
||||||
patch_deprecated_methods(env)
|
patch_deprecated_methods(env)
|
||||||
if env.spec.max_episode_steps is not None:
|
if env.spec.max_episode_steps is not None:
|
||||||
from gym.wrappers.time_limit import TimeLimit
|
from gym.wrappers.time_limit import TimeLimit
|
||||||
@@ -158,11 +154,7 @@ class EnvRegistry(object):
|
|||||||
if env_name == valid_env_spec._env_name
|
if env_name == valid_env_spec._env_name
|
||||||
]
|
]
|
||||||
if matching_envs:
|
if matching_envs:
|
||||||
raise error.DeprecatedEnv(
|
raise error.DeprecatedEnv("Env {} not found (valid versions include {})".format(id, matching_envs))
|
||||||
"Env {} not found (valid versions include {})".format(
|
|
||||||
id, matching_envs
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
raise error.UnregisteredEnv("No registered env with id: {}".format(id))
|
raise error.UnregisteredEnv("No registered env with id: {}".format(id))
|
||||||
|
|
||||||
|
@@ -81,9 +81,7 @@ class FetchEnv(robot_env.RobotEnv):
|
|||||||
|
|
||||||
def _set_action(self, action):
|
def _set_action(self, action):
|
||||||
assert action.shape == (4,)
|
assert action.shape == (4,)
|
||||||
action = (
|
action = action.copy() # ensure that we don't change the action outside of this scope
|
||||||
action.copy()
|
|
||||||
) # ensure that we don't change the action outside of this scope
|
|
||||||
pos_ctrl, gripper_ctrl = action[:3], action[3]
|
pos_ctrl, gripper_ctrl = action[:3], action[3]
|
||||||
|
|
||||||
pos_ctrl *= 0.05 # limit maximum change in position
|
pos_ctrl *= 0.05 # limit maximum change in position
|
||||||
@@ -120,13 +118,9 @@ class FetchEnv(robot_env.RobotEnv):
|
|||||||
object_rel_pos = object_pos - grip_pos
|
object_rel_pos = object_pos - grip_pos
|
||||||
object_velp -= grip_velp
|
object_velp -= grip_velp
|
||||||
else:
|
else:
|
||||||
object_pos = (
|
object_pos = object_rot = object_velp = object_velr = object_rel_pos = np.zeros(0)
|
||||||
object_rot
|
|
||||||
) = object_velp = object_velr = object_rel_pos = np.zeros(0)
|
|
||||||
gripper_state = robot_qpos[-2:]
|
gripper_state = robot_qpos[-2:]
|
||||||
gripper_vel = (
|
gripper_vel = robot_qvel[-2:] * dt # change to a scalar if the gripper is made symmetric
|
||||||
robot_qvel[-2:] * dt
|
|
||||||
) # change to a scalar if the gripper is made symmetric
|
|
||||||
|
|
||||||
if not self.has_object:
|
if not self.has_object:
|
||||||
achieved_goal = grip_pos.copy()
|
achieved_goal = grip_pos.copy()
|
||||||
@@ -175,9 +169,7 @@ class FetchEnv(robot_env.RobotEnv):
|
|||||||
if self.has_object:
|
if self.has_object:
|
||||||
object_xpos = self.initial_gripper_xpos[:2]
|
object_xpos = self.initial_gripper_xpos[:2]
|
||||||
while np.linalg.norm(object_xpos - self.initial_gripper_xpos[:2]) < 0.1:
|
while np.linalg.norm(object_xpos - self.initial_gripper_xpos[:2]) < 0.1:
|
||||||
object_xpos = self.initial_gripper_xpos[:2] + self.np_random.uniform(
|
object_xpos = self.initial_gripper_xpos[:2] + self.np_random.uniform(-self.obj_range, self.obj_range, size=2)
|
||||||
-self.obj_range, self.obj_range, size=2
|
|
||||||
)
|
|
||||||
object_qpos = self.sim.data.get_joint_qpos("object0:joint")
|
object_qpos = self.sim.data.get_joint_qpos("object0:joint")
|
||||||
assert object_qpos.shape == (7,)
|
assert object_qpos.shape == (7,)
|
||||||
object_qpos[:2] = object_xpos
|
object_qpos[:2] = object_xpos
|
||||||
@@ -188,17 +180,13 @@ class FetchEnv(robot_env.RobotEnv):
|
|||||||
|
|
||||||
def _sample_goal(self):
|
def _sample_goal(self):
|
||||||
if self.has_object:
|
if self.has_object:
|
||||||
goal = self.initial_gripper_xpos[:3] + self.np_random.uniform(
|
goal = self.initial_gripper_xpos[:3] + self.np_random.uniform(-self.target_range, self.target_range, size=3)
|
||||||
-self.target_range, self.target_range, size=3
|
|
||||||
)
|
|
||||||
goal += self.target_offset
|
goal += self.target_offset
|
||||||
goal[2] = self.height_offset
|
goal[2] = self.height_offset
|
||||||
if self.target_in_the_air and self.np_random.uniform() < 0.5:
|
if self.target_in_the_air and self.np_random.uniform() < 0.5:
|
||||||
goal[2] += self.np_random.uniform(0, 0.45)
|
goal[2] += self.np_random.uniform(0, 0.45)
|
||||||
else:
|
else:
|
||||||
goal = self.initial_gripper_xpos[:3] + self.np_random.uniform(
|
goal = self.initial_gripper_xpos[:3] + self.np_random.uniform(-self.target_range, self.target_range, size=3)
|
||||||
-self.target_range, self.target_range, size=3
|
|
||||||
)
|
|
||||||
return goal.copy()
|
return goal.copy()
|
||||||
|
|
||||||
def _is_success(self, achieved_goal, desired_goal):
|
def _is_success(self, achieved_goal, desired_goal):
|
||||||
@@ -212,9 +200,9 @@ class FetchEnv(robot_env.RobotEnv):
|
|||||||
self.sim.forward()
|
self.sim.forward()
|
||||||
|
|
||||||
# Move end effector into position.
|
# Move end effector into position.
|
||||||
gripper_target = np.array(
|
gripper_target = np.array([-0.498, 0.005, -0.431 + self.gripper_extra_height]) + self.sim.data.get_site_xpos(
|
||||||
[-0.498, 0.005, -0.431 + self.gripper_extra_height]
|
"robot0:grip"
|
||||||
) + self.sim.data.get_site_xpos("robot0:grip")
|
)
|
||||||
gripper_rotation = np.array([1.0, 0.0, 1.0, 0.0])
|
gripper_rotation = np.array([1.0, 0.0, 1.0, 0.0])
|
||||||
self.sim.data.set_mocap_pos("robot0:mocap", gripper_target)
|
self.sim.data.set_mocap_pos("robot0:mocap", gripper_target)
|
||||||
self.sim.data.set_mocap_quat("robot0:mocap", gripper_rotation)
|
self.sim.data.set_mocap_quat("robot0:mocap", gripper_rotation)
|
||||||
|
@@ -74,9 +74,7 @@ class ManipulateEnv(hand_env.HandEnv):
|
|||||||
self.target_position = target_position
|
self.target_position = target_position
|
||||||
self.target_rotation = target_rotation
|
self.target_rotation = target_rotation
|
||||||
self.target_position_range = target_position_range
|
self.target_position_range = target_position_range
|
||||||
self.parallel_quats = [
|
self.parallel_quats = [rotations.euler2quat(r) for r in rotations.get_parallel_rotations()]
|
||||||
rotations.euler2quat(r) for r in rotations.get_parallel_rotations()
|
|
||||||
]
|
|
||||||
self.randomize_initial_rotation = randomize_initial_rotation
|
self.randomize_initial_rotation = randomize_initial_rotation
|
||||||
self.randomize_initial_position = randomize_initial_position
|
self.randomize_initial_position = randomize_initial_position
|
||||||
self.distance_threshold = distance_threshold
|
self.distance_threshold = distance_threshold
|
||||||
@@ -182,9 +180,7 @@ class ManipulateEnv(hand_env.HandEnv):
|
|||||||
angle = self.np_random.uniform(-np.pi, np.pi)
|
angle = self.np_random.uniform(-np.pi, np.pi)
|
||||||
axis = np.array([0.0, 0.0, 1.0])
|
axis = np.array([0.0, 0.0, 1.0])
|
||||||
z_quat = quat_from_angle_and_axis(angle, axis)
|
z_quat = quat_from_angle_and_axis(angle, axis)
|
||||||
parallel_quat = self.parallel_quats[
|
parallel_quat = self.parallel_quats[self.np_random.randint(len(self.parallel_quats))]
|
||||||
self.np_random.randint(len(self.parallel_quats))
|
|
||||||
]
|
|
||||||
offset_quat = rotations.quat_mul(z_quat, parallel_quat)
|
offset_quat = rotations.quat_mul(z_quat, parallel_quat)
|
||||||
initial_quat = rotations.quat_mul(initial_quat, offset_quat)
|
initial_quat = rotations.quat_mul(initial_quat, offset_quat)
|
||||||
elif self.target_rotation in ["xyz", "ignore"]:
|
elif self.target_rotation in ["xyz", "ignore"]:
|
||||||
@@ -195,9 +191,7 @@ class ManipulateEnv(hand_env.HandEnv):
|
|||||||
elif self.target_rotation == "fixed":
|
elif self.target_rotation == "fixed":
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
raise error.Error(
|
raise error.Error('Unknown target_rotation option "{}".'.format(self.target_rotation))
|
||||||
'Unknown target_rotation option "{}".'.format(self.target_rotation)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Randomize initial position.
|
# Randomize initial position.
|
||||||
if self.randomize_initial_position:
|
if self.randomize_initial_position:
|
||||||
@@ -229,17 +223,13 @@ class ManipulateEnv(hand_env.HandEnv):
|
|||||||
target_pos = None
|
target_pos = None
|
||||||
if self.target_position == "random":
|
if self.target_position == "random":
|
||||||
assert self.target_position_range.shape == (3, 2)
|
assert self.target_position_range.shape == (3, 2)
|
||||||
offset = self.np_random.uniform(
|
offset = self.np_random.uniform(self.target_position_range[:, 0], self.target_position_range[:, 1])
|
||||||
self.target_position_range[:, 0], self.target_position_range[:, 1]
|
|
||||||
)
|
|
||||||
assert offset.shape == (3,)
|
assert offset.shape == (3,)
|
||||||
target_pos = self.sim.data.get_joint_qpos("object:joint")[:3] + offset
|
target_pos = self.sim.data.get_joint_qpos("object:joint")[:3] + offset
|
||||||
elif self.target_position in ["ignore", "fixed"]:
|
elif self.target_position in ["ignore", "fixed"]:
|
||||||
target_pos = self.sim.data.get_joint_qpos("object:joint")[:3]
|
target_pos = self.sim.data.get_joint_qpos("object:joint")[:3]
|
||||||
else:
|
else:
|
||||||
raise error.Error(
|
raise error.Error('Unknown target_position option "{}".'.format(self.target_position))
|
||||||
'Unknown target_position option "{}".'.format(self.target_position)
|
|
||||||
)
|
|
||||||
assert target_pos is not None
|
assert target_pos is not None
|
||||||
assert target_pos.shape == (3,)
|
assert target_pos.shape == (3,)
|
||||||
|
|
||||||
@@ -253,9 +243,7 @@ class ManipulateEnv(hand_env.HandEnv):
|
|||||||
angle = self.np_random.uniform(-np.pi, np.pi)
|
angle = self.np_random.uniform(-np.pi, np.pi)
|
||||||
axis = np.array([0.0, 0.0, 1.0])
|
axis = np.array([0.0, 0.0, 1.0])
|
||||||
target_quat = quat_from_angle_and_axis(angle, axis)
|
target_quat = quat_from_angle_and_axis(angle, axis)
|
||||||
parallel_quat = self.parallel_quats[
|
parallel_quat = self.parallel_quats[self.np_random.randint(len(self.parallel_quats))]
|
||||||
self.np_random.randint(len(self.parallel_quats))
|
|
||||||
]
|
|
||||||
target_quat = rotations.quat_mul(target_quat, parallel_quat)
|
target_quat = rotations.quat_mul(target_quat, parallel_quat)
|
||||||
elif self.target_rotation == "xyz":
|
elif self.target_rotation == "xyz":
|
||||||
angle = self.np_random.uniform(-np.pi, np.pi)
|
angle = self.np_random.uniform(-np.pi, np.pi)
|
||||||
@@ -264,9 +252,7 @@ class ManipulateEnv(hand_env.HandEnv):
|
|||||||
elif self.target_rotation in ["ignore", "fixed"]:
|
elif self.target_rotation in ["ignore", "fixed"]:
|
||||||
target_quat = self.sim.data.get_joint_qpos("object:joint")
|
target_quat = self.sim.data.get_joint_qpos("object:joint")
|
||||||
else:
|
else:
|
||||||
raise error.Error(
|
raise error.Error('Unknown target_rotation option "{}".'.format(self.target_rotation))
|
||||||
'Unknown target_rotation option "{}".'.format(self.target_rotation)
|
|
||||||
)
|
|
||||||
assert target_quat is not None
|
assert target_quat is not None
|
||||||
assert target_quat.shape == (4,)
|
assert target_quat.shape == (4,)
|
||||||
|
|
||||||
@@ -293,12 +279,8 @@ class ManipulateEnv(hand_env.HandEnv):
|
|||||||
def _get_obs(self):
|
def _get_obs(self):
|
||||||
robot_qpos, robot_qvel = robot_get_obs(self.sim)
|
robot_qpos, robot_qvel = robot_get_obs(self.sim)
|
||||||
object_qvel = self.sim.data.get_joint_qvel("object:joint")
|
object_qvel = self.sim.data.get_joint_qvel("object:joint")
|
||||||
achieved_goal = (
|
achieved_goal = self._get_achieved_goal().ravel() # this contains the object position + rotation
|
||||||
self._get_achieved_goal().ravel()
|
observation = np.concatenate([robot_qpos, robot_qvel, object_qvel, achieved_goal])
|
||||||
) # this contains the object position + rotation
|
|
||||||
observation = np.concatenate(
|
|
||||||
[robot_qpos, robot_qvel, object_qvel, achieved_goal]
|
|
||||||
)
|
|
||||||
return {
|
return {
|
||||||
"observation": observation.copy(),
|
"observation": observation.copy(),
|
||||||
"achieved_goal": achieved_goal.copy(),
|
"achieved_goal": achieved_goal.copy(),
|
||||||
@@ -307,9 +289,7 @@ class ManipulateEnv(hand_env.HandEnv):
|
|||||||
|
|
||||||
|
|
||||||
class HandBlockEnv(ManipulateEnv, utils.EzPickle):
|
class HandBlockEnv(ManipulateEnv, utils.EzPickle):
|
||||||
def __init__(
|
def __init__(self, target_position="random", target_rotation="xyz", reward_type="sparse"):
|
||||||
self, target_position="random", target_rotation="xyz", reward_type="sparse"
|
|
||||||
):
|
|
||||||
utils.EzPickle.__init__(self, target_position, target_rotation, reward_type)
|
utils.EzPickle.__init__(self, target_position, target_rotation, reward_type)
|
||||||
ManipulateEnv.__init__(
|
ManipulateEnv.__init__(
|
||||||
self,
|
self,
|
||||||
@@ -322,9 +302,7 @@ class HandBlockEnv(ManipulateEnv, utils.EzPickle):
|
|||||||
|
|
||||||
|
|
||||||
class HandEggEnv(ManipulateEnv, utils.EzPickle):
|
class HandEggEnv(ManipulateEnv, utils.EzPickle):
|
||||||
def __init__(
|
def __init__(self, target_position="random", target_rotation="xyz", reward_type="sparse"):
|
||||||
self, target_position="random", target_rotation="xyz", reward_type="sparse"
|
|
||||||
):
|
|
||||||
utils.EzPickle.__init__(self, target_position, target_rotation, reward_type)
|
utils.EzPickle.__init__(self, target_position, target_rotation, reward_type)
|
||||||
ManipulateEnv.__init__(
|
ManipulateEnv.__init__(
|
||||||
self,
|
self,
|
||||||
@@ -337,9 +315,7 @@ class HandEggEnv(ManipulateEnv, utils.EzPickle):
|
|||||||
|
|
||||||
|
|
||||||
class HandPenEnv(ManipulateEnv, utils.EzPickle):
|
class HandPenEnv(ManipulateEnv, utils.EzPickle):
|
||||||
def __init__(
|
def __init__(self, target_position="random", target_rotation="xyz", reward_type="sparse"):
|
||||||
self, target_position="random", target_rotation="xyz", reward_type="sparse"
|
|
||||||
):
|
|
||||||
utils.EzPickle.__init__(self, target_position, target_rotation, reward_type)
|
utils.EzPickle.__init__(self, target_position, target_rotation, reward_type)
|
||||||
ManipulateEnv.__init__(
|
ManipulateEnv.__init__(
|
||||||
self,
|
self,
|
||||||
|
@@ -70,16 +70,12 @@ class ManipulateTouchSensorsEnv(manipulate.ManipulateEnv):
|
|||||||
for (
|
for (
|
||||||
k,
|
k,
|
||||||
v,
|
v,
|
||||||
) in (
|
) in self.sim.model._sensor_name2id.items(): # get touch sensor site names and their ids
|
||||||
self.sim.model._sensor_name2id.items()
|
|
||||||
): # get touch sensor site names and their ids
|
|
||||||
if "robot0:TS_" in k:
|
if "robot0:TS_" in k:
|
||||||
self._touch_sensor_id_site_id.append(
|
self._touch_sensor_id_site_id.append(
|
||||||
(
|
(
|
||||||
v,
|
v,
|
||||||
self.sim.model._site_name2id[
|
self.sim.model._site_name2id[k.replace("robot0:TS_", "robot0:T_")],
|
||||||
k.replace("robot0:TS_", "robot0:T_")
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self._touch_sensor_id.append(v)
|
self._touch_sensor_id.append(v)
|
||||||
@@ -93,15 +89,9 @@ class ManipulateTouchSensorsEnv(manipulate.ManipulateEnv):
|
|||||||
obs = self._get_obs()
|
obs = self._get_obs()
|
||||||
self.observation_space = spaces.Dict(
|
self.observation_space = spaces.Dict(
|
||||||
dict(
|
dict(
|
||||||
desired_goal=spaces.Box(
|
desired_goal=spaces.Box(-np.inf, np.inf, shape=obs["achieved_goal"].shape, dtype="float32"),
|
||||||
-np.inf, np.inf, shape=obs["achieved_goal"].shape, dtype="float32"
|
achieved_goal=spaces.Box(-np.inf, np.inf, shape=obs["achieved_goal"].shape, dtype="float32"),
|
||||||
),
|
observation=spaces.Box(-np.inf, np.inf, shape=obs["observation"].shape, dtype="float32"),
|
||||||
achieved_goal=spaces.Box(
|
|
||||||
-np.inf, np.inf, shape=obs["achieved_goal"].shape, dtype="float32"
|
|
||||||
),
|
|
||||||
observation=spaces.Box(
|
|
||||||
-np.inf, np.inf, shape=obs["observation"].shape, dtype="float32"
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -117,9 +107,7 @@ class ManipulateTouchSensorsEnv(manipulate.ManipulateEnv):
|
|||||||
def _get_obs(self):
|
def _get_obs(self):
|
||||||
robot_qpos, robot_qvel = manipulate.robot_get_obs(self.sim)
|
robot_qpos, robot_qvel = manipulate.robot_get_obs(self.sim)
|
||||||
object_qvel = self.sim.data.get_joint_qvel("object:joint")
|
object_qvel = self.sim.data.get_joint_qvel("object:joint")
|
||||||
achieved_goal = (
|
achieved_goal = self._get_achieved_goal().ravel() # this contains the object position + rotation
|
||||||
self._get_achieved_goal().ravel()
|
|
||||||
) # this contains the object position + rotation
|
|
||||||
touch_values = [] # get touch sensor readings. if there is one, set value to 1
|
touch_values = [] # get touch sensor readings. if there is one, set value to 1
|
||||||
if self.touch_get_obs == "sensordata":
|
if self.touch_get_obs == "sensordata":
|
||||||
touch_values = self.sim.data.sensordata[self._touch_sensor_id]
|
touch_values = self.sim.data.sensordata[self._touch_sensor_id]
|
||||||
@@ -127,9 +115,7 @@ class ManipulateTouchSensorsEnv(manipulate.ManipulateEnv):
|
|||||||
touch_values = self.sim.data.sensordata[self._touch_sensor_id] > 0.0
|
touch_values = self.sim.data.sensordata[self._touch_sensor_id] > 0.0
|
||||||
elif self.touch_get_obs == "log":
|
elif self.touch_get_obs == "log":
|
||||||
touch_values = np.log(self.sim.data.sensordata[self._touch_sensor_id] + 1.0)
|
touch_values = np.log(self.sim.data.sensordata[self._touch_sensor_id] + 1.0)
|
||||||
observation = np.concatenate(
|
observation = np.concatenate([robot_qpos, robot_qvel, object_qvel, touch_values, achieved_goal])
|
||||||
[robot_qpos, robot_qvel, object_qvel, touch_values, achieved_goal]
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"observation": observation.copy(),
|
"observation": observation.copy(),
|
||||||
@@ -146,9 +132,7 @@ class HandBlockTouchSensorsEnv(ManipulateTouchSensorsEnv, utils.EzPickle):
|
|||||||
touch_get_obs="sensordata",
|
touch_get_obs="sensordata",
|
||||||
reward_type="sparse",
|
reward_type="sparse",
|
||||||
):
|
):
|
||||||
utils.EzPickle.__init__(
|
utils.EzPickle.__init__(self, target_position, target_rotation, touch_get_obs, reward_type)
|
||||||
self, target_position, target_rotation, touch_get_obs, reward_type
|
|
||||||
)
|
|
||||||
ManipulateTouchSensorsEnv.__init__(
|
ManipulateTouchSensorsEnv.__init__(
|
||||||
self,
|
self,
|
||||||
model_path=MANIPULATE_BLOCK_XML,
|
model_path=MANIPULATE_BLOCK_XML,
|
||||||
@@ -168,9 +152,7 @@ class HandEggTouchSensorsEnv(ManipulateTouchSensorsEnv, utils.EzPickle):
|
|||||||
touch_get_obs="sensordata",
|
touch_get_obs="sensordata",
|
||||||
reward_type="sparse",
|
reward_type="sparse",
|
||||||
):
|
):
|
||||||
utils.EzPickle.__init__(
|
utils.EzPickle.__init__(self, target_position, target_rotation, touch_get_obs, reward_type)
|
||||||
self, target_position, target_rotation, touch_get_obs, reward_type
|
|
||||||
)
|
|
||||||
ManipulateTouchSensorsEnv.__init__(
|
ManipulateTouchSensorsEnv.__init__(
|
||||||
self,
|
self,
|
||||||
model_path=MANIPULATE_EGG_XML,
|
model_path=MANIPULATE_EGG_XML,
|
||||||
@@ -190,9 +172,7 @@ class HandPenTouchSensorsEnv(ManipulateTouchSensorsEnv, utils.EzPickle):
|
|||||||
touch_get_obs="sensordata",
|
touch_get_obs="sensordata",
|
||||||
reward_type="sparse",
|
reward_type="sparse",
|
||||||
):
|
):
|
||||||
utils.EzPickle.__init__(
|
utils.EzPickle.__init__(self, target_position, target_rotation, touch_get_obs, reward_type)
|
||||||
self, target_position, target_rotation, touch_get_obs, reward_type
|
|
||||||
)
|
|
||||||
ManipulateTouchSensorsEnv.__init__(
|
ManipulateTouchSensorsEnv.__init__(
|
||||||
self,
|
self,
|
||||||
model_path=MANIPULATE_PEN_XML,
|
model_path=MANIPULATE_PEN_XML,
|
||||||
|
@@ -96,9 +96,7 @@ class HandReachEnv(hand_env.HandEnv, utils.EzPickle):
|
|||||||
self.sim.forward()
|
self.sim.forward()
|
||||||
|
|
||||||
self.initial_goal = self._get_achieved_goal().copy()
|
self.initial_goal = self._get_achieved_goal().copy()
|
||||||
self.palm_xpos = self.sim.data.body_xpos[
|
self.palm_xpos = self.sim.data.body_xpos[self.sim.model.body_name2id("robot0:palm")].copy()
|
||||||
self.sim.model.body_name2id("robot0:palm")
|
|
||||||
].copy()
|
|
||||||
|
|
||||||
def _get_obs(self):
|
def _get_obs(self):
|
||||||
robot_qpos, robot_qvel = robot_get_obs(self.sim)
|
robot_qpos, robot_qvel = robot_get_obs(self.sim)
|
||||||
@@ -155,7 +153,5 @@ class HandReachEnv(hand_env.HandEnv, utils.EzPickle):
|
|||||||
for finger_idx in range(5):
|
for finger_idx in range(5):
|
||||||
site_name = "finger{}".format(finger_idx)
|
site_name = "finger{}".format(finger_idx)
|
||||||
site_id = self.sim.model.site_name2id(site_name)
|
site_id = self.sim.model.site_name2id(site_name)
|
||||||
self.sim.model.site_pos[site_id] = (
|
self.sim.model.site_pos[site_id] = achieved_goal[finger_idx] - sites_offset[site_id]
|
||||||
achieved_goal[finger_idx] - sites_offset[site_id]
|
|
||||||
)
|
|
||||||
self.sim.forward()
|
self.sim.forward()
|
||||||
|
@@ -30,22 +30,14 @@ class HandEnv(robot_env.RobotEnv):
|
|||||||
if self.relative_control:
|
if self.relative_control:
|
||||||
actuation_center = np.zeros_like(action)
|
actuation_center = np.zeros_like(action)
|
||||||
for i in range(self.sim.data.ctrl.shape[0]):
|
for i in range(self.sim.data.ctrl.shape[0]):
|
||||||
actuation_center[i] = self.sim.data.get_joint_qpos(
|
actuation_center[i] = self.sim.data.get_joint_qpos(self.sim.model.actuator_names[i].replace(":A_", ":"))
|
||||||
self.sim.model.actuator_names[i].replace(":A_", ":")
|
|
||||||
)
|
|
||||||
for joint_name in ["FF", "MF", "RF", "LF"]:
|
for joint_name in ["FF", "MF", "RF", "LF"]:
|
||||||
act_idx = self.sim.model.actuator_name2id(
|
act_idx = self.sim.model.actuator_name2id("robot0:A_{}J1".format(joint_name))
|
||||||
"robot0:A_{}J1".format(joint_name)
|
actuation_center[act_idx] += self.sim.data.get_joint_qpos("robot0:{}J0".format(joint_name))
|
||||||
)
|
|
||||||
actuation_center[act_idx] += self.sim.data.get_joint_qpos(
|
|
||||||
"robot0:{}J0".format(joint_name)
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
actuation_center = (ctrlrange[:, 1] + ctrlrange[:, 0]) / 2.0
|
actuation_center = (ctrlrange[:, 1] + ctrlrange[:, 0]) / 2.0
|
||||||
self.sim.data.ctrl[:] = actuation_center + action * actuation_range
|
self.sim.data.ctrl[:] = actuation_center + action * actuation_range
|
||||||
self.sim.data.ctrl[:] = np.clip(
|
self.sim.data.ctrl[:] = np.clip(self.sim.data.ctrl, ctrlrange[:, 0], ctrlrange[:, 1])
|
||||||
self.sim.data.ctrl, ctrlrange[:, 0], ctrlrange[:, 1]
|
|
||||||
)
|
|
||||||
|
|
||||||
def _viewer_setup(self):
|
def _viewer_setup(self):
|
||||||
body_id = self.sim.model.body_name2id("robot0:palm")
|
body_id = self.sim.model.body_name2id("robot0:palm")
|
||||||
|
@@ -46,15 +46,9 @@ class RobotEnv(gym.GoalEnv):
|
|||||||
self.action_space = spaces.Box(-1.0, 1.0, shape=(n_actions,), dtype="float32")
|
self.action_space = spaces.Box(-1.0, 1.0, shape=(n_actions,), dtype="float32")
|
||||||
self.observation_space = spaces.Dict(
|
self.observation_space = spaces.Dict(
|
||||||
dict(
|
dict(
|
||||||
desired_goal=spaces.Box(
|
desired_goal=spaces.Box(-np.inf, np.inf, shape=obs["achieved_goal"].shape, dtype="float32"),
|
||||||
-np.inf, np.inf, shape=obs["achieved_goal"].shape, dtype="float32"
|
achieved_goal=spaces.Box(-np.inf, np.inf, shape=obs["achieved_goal"].shape, dtype="float32"),
|
||||||
),
|
observation=spaces.Box(-np.inf, np.inf, shape=obs["observation"].shape, dtype="float32"),
|
||||||
achieved_goal=spaces.Box(
|
|
||||||
-np.inf, np.inf, shape=obs["achieved_goal"].shape, dtype="float32"
|
|
||||||
),
|
|
||||||
observation=spaces.Box(
|
|
||||||
-np.inf, np.inf, shape=obs["observation"].shape, dtype="float32"
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@@ -164,12 +164,8 @@ def mat2euler(mat):
|
|||||||
-np.arctan2(mat[..., 0, 1], mat[..., 0, 0]),
|
-np.arctan2(mat[..., 0, 1], mat[..., 0, 0]),
|
||||||
-np.arctan2(-mat[..., 1, 0], mat[..., 1, 1]),
|
-np.arctan2(-mat[..., 1, 0], mat[..., 1, 1]),
|
||||||
)
|
)
|
||||||
euler[..., 1] = np.where(
|
euler[..., 1] = np.where(condition, -np.arctan2(-mat[..., 0, 2], cy), -np.arctan2(-mat[..., 0, 2], cy))
|
||||||
condition, -np.arctan2(-mat[..., 0, 2], cy), -np.arctan2(-mat[..., 0, 2], cy)
|
euler[..., 0] = np.where(condition, -np.arctan2(mat[..., 1, 2], mat[..., 2, 2]), 0.0)
|
||||||
)
|
|
||||||
euler[..., 0] = np.where(
|
|
||||||
condition, -np.arctan2(mat[..., 1, 2], mat[..., 2, 2]), 0.0
|
|
||||||
)
|
|
||||||
return euler
|
return euler
|
||||||
|
|
||||||
|
|
||||||
|
@@ -75,15 +75,9 @@ def reset_mocap2body_xpos(sim):
|
|||||||
values as the bodies they're welded to.
|
values as the bodies they're welded to.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if (
|
if sim.model.eq_type is None or sim.model.eq_obj1id is None or sim.model.eq_obj2id is None:
|
||||||
sim.model.eq_type is None
|
|
||||||
or sim.model.eq_obj1id is None
|
|
||||||
or sim.model.eq_obj2id is None
|
|
||||||
):
|
|
||||||
return
|
return
|
||||||
for eq_type, obj1_id, obj2_id in zip(
|
for eq_type, obj1_id, obj2_id in zip(sim.model.eq_type, sim.model.eq_obj1id, sim.model.eq_obj2id):
|
||||||
sim.model.eq_type, sim.model.eq_obj1id, sim.model.eq_obj2id
|
|
||||||
):
|
|
||||||
if eq_type != mujoco_py.const.EQ_WELD:
|
if eq_type != mujoco_py.const.EQ_WELD:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
@@ -2,10 +2,7 @@ from gym import envs, logger
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
|
|
||||||
SKIP_MUJOCO_WARNING_MESSAGE = (
|
SKIP_MUJOCO_WARNING_MESSAGE = "Cannot run mujoco test (either license key not found or mujoco not" "installed properly)."
|
||||||
"Cannot run mujoco test (either license key not found or mujoco not"
|
|
||||||
"installed properly)."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
skip_mujoco = not (os.environ.get("MUJOCO_KEY"))
|
skip_mujoco = not (os.environ.get("MUJOCO_KEY"))
|
||||||
@@ -21,9 +18,7 @@ def should_skip_env_spec_for_tests(spec):
|
|||||||
# troublesome to run frequently
|
# troublesome to run frequently
|
||||||
ep = spec.entry_point
|
ep = spec.entry_point
|
||||||
# Skip mujoco tests for pull request CI
|
# Skip mujoco tests for pull request CI
|
||||||
if skip_mujoco and (
|
if skip_mujoco and (ep.startswith("gym.envs.mujoco") or ep.startswith("gym.envs.robotics:")):
|
||||||
ep.startswith("gym.envs.mujoco") or ep.startswith("gym.envs.robotics:")
|
|
||||||
):
|
|
||||||
return True
|
return True
|
||||||
try:
|
try:
|
||||||
import atari_py
|
import atari_py
|
||||||
@@ -39,11 +34,7 @@ def should_skip_env_spec_for_tests(spec):
|
|||||||
if (
|
if (
|
||||||
"GoEnv" in ep
|
"GoEnv" in ep
|
||||||
or "HexEnv" in ep
|
or "HexEnv" in ep
|
||||||
or (
|
or (ep.startswith("gym.envs.atari") and not spec.id.startswith("Pong") and not spec.id.startswith("Seaquest"))
|
||||||
ep.startswith("gym.envs.atari")
|
|
||||||
and not spec.id.startswith("Pong")
|
|
||||||
and not spec.id.startswith("Seaquest")
|
|
||||||
)
|
|
||||||
):
|
):
|
||||||
logger.warn("Skipping tests for env {}".format(ep))
|
logger.warn("Skipping tests for env {}".format(ep))
|
||||||
return True
|
return True
|
||||||
|
@@ -25,9 +25,7 @@ def test_env(spec):
|
|||||||
step_responses2 = [env2.step(action) for action in action_samples2]
|
step_responses2 = [env2.step(action) for action in action_samples2]
|
||||||
env2.close()
|
env2.close()
|
||||||
|
|
||||||
for i, (action_sample1, action_sample2) in enumerate(
|
for i, (action_sample1, action_sample2) in enumerate(zip(action_samples1, action_samples2)):
|
||||||
zip(action_samples1, action_samples2)
|
|
||||||
):
|
|
||||||
try:
|
try:
|
||||||
assert_equals(action_sample1, action_sample2)
|
assert_equals(action_sample1, action_sample2)
|
||||||
except AssertionError:
|
except AssertionError:
|
||||||
@@ -35,11 +33,7 @@ def test_env(spec):
|
|||||||
print("env2.action_space=", env2.action_space)
|
print("env2.action_space=", env2.action_space)
|
||||||
print("action_samples1=", action_samples1)
|
print("action_samples1=", action_samples1)
|
||||||
print("action_samples2=", action_samples2)
|
print("action_samples2=", action_samples2)
|
||||||
print(
|
print("[{}] action_sample1: {}, action_sample2: {}".format(i, action_sample1, action_sample2))
|
||||||
"[{}] action_sample1: {}, action_sample2: {}".format(
|
|
||||||
i, action_sample1, action_sample2
|
|
||||||
)
|
|
||||||
)
|
|
||||||
raise
|
raise
|
||||||
|
|
||||||
# Don't check rollout equality if it's a a nondeterministic
|
# Don't check rollout equality if it's a a nondeterministic
|
||||||
@@ -49,9 +43,7 @@ def test_env(spec):
|
|||||||
|
|
||||||
assert_equals(initial_observation1, initial_observation2)
|
assert_equals(initial_observation1, initial_observation2)
|
||||||
|
|
||||||
for i, ((o1, r1, d1, i1), (o2, r2, d2, i2)) in enumerate(
|
for i, ((o1, r1, d1, i1), (o2, r2, d2, i2)) in enumerate(zip(step_responses1, step_responses2)):
|
||||||
zip(step_responses1, step_responses2)
|
|
||||||
):
|
|
||||||
assert_equals(o1, o2, "[{}] ".format(i))
|
assert_equals(o1, o2, "[{}] ".format(i))
|
||||||
assert r1 == r2, "[{}] r1: {}, r2: {}".format(i, r1, r2)
|
assert r1 == r2, "[{}] r1: {}, r2: {}".format(i, r1, r2)
|
||||||
assert d1 == d2, "[{}] d1: {}, d2: {}".format(i, d1, d2)
|
assert d1 == d2, "[{}] d1: {}, d2: {}".format(i, d1, d2)
|
||||||
@@ -66,9 +58,7 @@ def test_env(spec):
|
|||||||
def assert_equals(a, b, prefix=None):
|
def assert_equals(a, b, prefix=None):
|
||||||
assert type(a) == type(b), "{}Differing types: {} and {}".format(prefix, a, b)
|
assert type(a) == type(b), "{}Differing types: {} and {}".format(prefix, a, b)
|
||||||
if isinstance(a, dict):
|
if isinstance(a, dict):
|
||||||
assert list(a.keys()) == list(b.keys()), "{}Key sets differ: {} and {}".format(
|
assert list(a.keys()) == list(b.keys()), "{}Key sets differ: {} and {}".format(prefix, a, b)
|
||||||
prefix, a, b
|
|
||||||
)
|
|
||||||
|
|
||||||
for k in a.keys():
|
for k in a.keys():
|
||||||
v_a = a[k]
|
v_a = a[k]
|
||||||
|
@@ -24,9 +24,7 @@ def test_env(spec):
|
|||||||
assert ob_space.contains(ob), "Reset observation: {!r} not in space".format(ob)
|
assert ob_space.contains(ob), "Reset observation: {!r} not in space".format(ob)
|
||||||
a = act_space.sample()
|
a = act_space.sample()
|
||||||
observation, reward, done, _info = env.step(a)
|
observation, reward, done, _info = env.step(a)
|
||||||
assert ob_space.contains(observation), "Step observation: {!r} not in space".format(
|
assert ob_space.contains(observation), "Step observation: {!r} not in space".format(observation)
|
||||||
observation
|
|
||||||
)
|
|
||||||
assert np.isscalar(reward), "{} is not a scalar for {}".format(reward, env)
|
assert np.isscalar(reward), "{} is not a scalar for {}".format(reward, env)
|
||||||
assert isinstance(done, bool), "Expected {} to be a boolean".format(done)
|
assert isinstance(done, bool), "Expected {} to be a boolean".format(done)
|
||||||
|
|
||||||
|
@@ -81,9 +81,7 @@ def test_env_semantics(spec):
|
|||||||
if spec.id not in rollout_dict:
|
if spec.id not in rollout_dict:
|
||||||
if not spec.nondeterministic:
|
if not spec.nondeterministic:
|
||||||
logger.warn(
|
logger.warn(
|
||||||
"Rollout does not exist for {}, run generate_json.py to generate rollouts for new envs".format(
|
"Rollout does not exist for {}, run generate_json.py to generate rollouts for new envs".format(spec.id)
|
||||||
spec.id
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -100,21 +98,15 @@ def test_env_semantics(spec):
|
|||||||
)
|
)
|
||||||
if rollout_dict[spec.id]["actions"] != actions_now:
|
if rollout_dict[spec.id]["actions"] != actions_now:
|
||||||
errors.append(
|
errors.append(
|
||||||
"Actions not equal for {} -- expected {} but got {}".format(
|
"Actions not equal for {} -- expected {} but got {}".format(spec.id, rollout_dict[spec.id]["actions"], actions_now)
|
||||||
spec.id, rollout_dict[spec.id]["actions"], actions_now
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
if rollout_dict[spec.id]["rewards"] != rewards_now:
|
if rollout_dict[spec.id]["rewards"] != rewards_now:
|
||||||
errors.append(
|
errors.append(
|
||||||
"Rewards not equal for {} -- expected {} but got {}".format(
|
"Rewards not equal for {} -- expected {} but got {}".format(spec.id, rollout_dict[spec.id]["rewards"], rewards_now)
|
||||||
spec.id, rollout_dict[spec.id]["rewards"], rewards_now
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
if rollout_dict[spec.id]["dones"] != dones_now:
|
if rollout_dict[spec.id]["dones"] != dones_now:
|
||||||
errors.append(
|
errors.append(
|
||||||
"Dones not equal for {} -- expected {} but got {}".format(
|
"Dones not equal for {} -- expected {} but got {}".format(spec.id, rollout_dict[spec.id]["dones"], dones_now)
|
||||||
spec.id, rollout_dict[spec.id]["dones"], dones_now
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
if len(errors):
|
if len(errors):
|
||||||
for error in errors:
|
for error in errors:
|
||||||
|
@@ -4,9 +4,7 @@ from gym import envs
|
|||||||
from gym.envs.tests.spec_list import skip_mujoco, SKIP_MUJOCO_WARNING_MESSAGE
|
from gym.envs.tests.spec_list import skip_mujoco, SKIP_MUJOCO_WARNING_MESSAGE
|
||||||
|
|
||||||
|
|
||||||
def verify_environments_match(
|
def verify_environments_match(old_environment_id, new_environment_id, seed=1, num_actions=1000):
|
||||||
old_environment_id, new_environment_id, seed=1, num_actions=1000
|
|
||||||
):
|
|
||||||
old_environment = envs.make(old_environment_id)
|
old_environment = envs.make(old_environment_id)
|
||||||
new_environment = envs.make(new_environment_id)
|
new_environment = envs.make(new_environment_id)
|
||||||
|
|
||||||
|
@@ -83,8 +83,6 @@ def test_malformed_lookup():
|
|||||||
try:
|
try:
|
||||||
registry.spec(u"“Breakout-v0”")
|
registry.spec(u"“Breakout-v0”")
|
||||||
except error.Error as e:
|
except error.Error as e:
|
||||||
assert "malformed environment ID" in "{}".format(
|
assert "malformed environment ID" in "{}".format(e), "Unexpected message: {}".format(e)
|
||||||
e
|
|
||||||
), "Unexpected message: {}".format(e)
|
|
||||||
else:
|
else:
|
||||||
assert False
|
assert False
|
||||||
|
@@ -75,9 +75,7 @@ class BlackjackEnv(gym.Env):
|
|||||||
|
|
||||||
def __init__(self, natural=False):
|
def __init__(self, natural=False):
|
||||||
self.action_space = spaces.Discrete(2)
|
self.action_space = spaces.Discrete(2)
|
||||||
self.observation_space = spaces.Tuple(
|
self.observation_space = spaces.Tuple((spaces.Discrete(32), spaces.Discrete(11), spaces.Discrete(2)))
|
||||||
(spaces.Discrete(32), spaces.Discrete(11), spaces.Discrete(2))
|
|
||||||
)
|
|
||||||
self.seed()
|
self.seed()
|
||||||
|
|
||||||
# Flag to payout 1.5 on a "natural" blackjack win, like casino rules
|
# Flag to payout 1.5 on a "natural" blackjack win, like casino rules
|
||||||
|
@@ -141,9 +141,7 @@ class FrozenLakeEnv(discrete.DiscreteEnv):
|
|||||||
else:
|
else:
|
||||||
if is_slippery:
|
if is_slippery:
|
||||||
for b in [(a - 1) % 4, a, (a + 1) % 4]:
|
for b in [(a - 1) % 4, a, (a + 1) % 4]:
|
||||||
li.append(
|
li.append((1.0 / 3.0, *update_probability_matrix(row, col, b)))
|
||||||
(1.0 / 3.0, *update_probability_matrix(row, col, b))
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
li.append((1.0, *update_probability_matrix(row, col, a)))
|
li.append((1.0, *update_probability_matrix(row, col, a)))
|
||||||
|
|
||||||
@@ -157,9 +155,7 @@ class FrozenLakeEnv(discrete.DiscreteEnv):
|
|||||||
desc = [[c.decode("utf-8") for c in line] for line in desc]
|
desc = [[c.decode("utf-8") for c in line] for line in desc]
|
||||||
desc[row][col] = utils.colorize(desc[row][col], "red", highlight=True)
|
desc[row][col] = utils.colorize(desc[row][col], "red", highlight=True)
|
||||||
if self.lastaction is not None:
|
if self.lastaction is not None:
|
||||||
outfile.write(
|
outfile.write(" ({})\n".format(["Left", "Down", "Right", "Up"][self.lastaction]))
|
||||||
" ({})\n".format(["Left", "Down", "Right", "Up"][self.lastaction])
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
outfile.write("\n")
|
outfile.write("\n")
|
||||||
outfile.write("\n".join("".join(line) for line in desc) + "\n")
|
outfile.write("\n".join("".join(line) for line in desc) + "\n")
|
||||||
|
@@ -80,11 +80,7 @@ class GuessingGame(gym.Env):
|
|||||||
reward = 0
|
reward = 0
|
||||||
done = False
|
done = False
|
||||||
|
|
||||||
if (
|
if (self.number - self.range * 0.01) < action < (self.number + self.range * 0.01):
|
||||||
(self.number - self.range * 0.01)
|
|
||||||
< action
|
|
||||||
< (self.number + self.range * 0.01)
|
|
||||||
):
|
|
||||||
reward = 1
|
reward = 1
|
||||||
done = True
|
done = True
|
||||||
|
|
||||||
|
@@ -62,10 +62,7 @@ class HotterColder(gym.Env):
|
|||||||
elif action > self.number:
|
elif action > self.number:
|
||||||
self.observation = 3
|
self.observation = 3
|
||||||
|
|
||||||
reward = (
|
reward = ((min(action, self.number) + self.bounds) / (max(action, self.number) + self.bounds)) ** 2
|
||||||
(min(action, self.number) + self.bounds)
|
|
||||||
/ (max(action, self.number) + self.bounds)
|
|
||||||
) ** 2
|
|
||||||
|
|
||||||
self.guess_count += 1
|
self.guess_count += 1
|
||||||
done = self.guess_count >= self.guess_max
|
done = self.guess_count >= self.guess_max
|
||||||
|
@@ -65,9 +65,7 @@ class KellyCoinflipEnv(gym.Env):
|
|||||||
return [seed]
|
return [seed]
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
bet_in_dollars = min(
|
bet_in_dollars = min(action / 100.0, self.wealth) # action = desired bet in pennies
|
||||||
action / 100.0, self.wealth
|
|
||||||
) # action = desired bet in pennies
|
|
||||||
self.rounds -= 1
|
self.rounds -= 1
|
||||||
|
|
||||||
coinflip = flip(self.edge, self.np_random)
|
coinflip = flip(self.edge, self.np_random)
|
||||||
@@ -149,35 +147,19 @@ class KellyCoinflipGeneralizedEnv(gym.Env):
|
|||||||
edge = self.np_random.beta(edge_prior_alpha, edge_prior_beta)
|
edge = self.np_random.beta(edge_prior_alpha, edge_prior_beta)
|
||||||
if self.clip_distributions:
|
if self.clip_distributions:
|
||||||
# (clip/resample some parameters to be able to fix obs/action space sizes/bounds)
|
# (clip/resample some parameters to be able to fix obs/action space sizes/bounds)
|
||||||
max_wealth_bound = round(
|
max_wealth_bound = round(genpareto.ppf(0.85, max_wealth_alpha, max_wealth_m))
|
||||||
genpareto.ppf(0.85, max_wealth_alpha, max_wealth_m)
|
|
||||||
)
|
|
||||||
max_wealth = max_wealth_bound + 1.0
|
max_wealth = max_wealth_bound + 1.0
|
||||||
while max_wealth > max_wealth_bound:
|
while max_wealth > max_wealth_bound:
|
||||||
max_wealth = round(
|
max_wealth = round(genpareto.rvs(max_wealth_alpha, max_wealth_m, random_state=self.np_random))
|
||||||
genpareto.rvs(
|
max_rounds_bound = int(round(norm.ppf(0.99, max_rounds_mean, max_rounds_sd)))
|
||||||
max_wealth_alpha, max_wealth_m, random_state=self.np_random
|
|
||||||
)
|
|
||||||
)
|
|
||||||
max_rounds_bound = int(
|
|
||||||
round(norm.ppf(0.99, max_rounds_mean, max_rounds_sd))
|
|
||||||
)
|
|
||||||
max_rounds = max_rounds_bound + 1
|
max_rounds = max_rounds_bound + 1
|
||||||
while max_rounds > max_rounds_bound:
|
while max_rounds > max_rounds_bound:
|
||||||
max_rounds = int(
|
max_rounds = int(round(self.np_random.normal(max_rounds_mean, max_rounds_sd)))
|
||||||
round(self.np_random.normal(max_rounds_mean, max_rounds_sd))
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
max_wealth = round(
|
max_wealth = round(genpareto.rvs(max_wealth_alpha, max_wealth_m, random_state=self.np_random))
|
||||||
genpareto.rvs(
|
|
||||||
max_wealth_alpha, max_wealth_m, random_state=self.np_random
|
|
||||||
)
|
|
||||||
)
|
|
||||||
max_wealth_bound = max_wealth
|
max_wealth_bound = max_wealth
|
||||||
max_rounds = int(
|
max_rounds = int(round(self.np_random.normal(max_rounds_mean, max_rounds_sd)))
|
||||||
round(self.np_random.normal(max_rounds_mean, max_rounds_sd))
|
|
||||||
)
|
|
||||||
max_rounds_bound = max_rounds
|
max_rounds_bound = max_rounds
|
||||||
|
|
||||||
# add an additional global variable which is the sufficient statistic for the
|
# add an additional global variable which is the sufficient statistic for the
|
||||||
@@ -194,9 +176,7 @@ class KellyCoinflipGeneralizedEnv(gym.Env):
|
|||||||
self.action_space = spaces.Discrete(int(max_wealth_bound * 100))
|
self.action_space = spaces.Discrete(int(max_wealth_bound * 100))
|
||||||
self.observation_space = spaces.Tuple(
|
self.observation_space = spaces.Tuple(
|
||||||
(
|
(
|
||||||
spaces.Box(
|
spaces.Box(0, max_wealth_bound, shape=[1], dtype=np.float32), # current wealth
|
||||||
0, max_wealth_bound, shape=[1], dtype=np.float32
|
|
||||||
), # current wealth
|
|
||||||
spaces.Discrete(max_rounds_bound + 1), # rounds elapsed
|
spaces.Discrete(max_rounds_bound + 1), # rounds elapsed
|
||||||
spaces.Discrete(max_rounds_bound + 1), # wins
|
spaces.Discrete(max_rounds_bound + 1), # wins
|
||||||
spaces.Discrete(max_rounds_bound + 1), # losses
|
spaces.Discrete(max_rounds_bound + 1), # losses
|
||||||
|
@@ -80,10 +80,7 @@ class TaxiEnv(discrete.DiscreteEnv):
|
|||||||
max_col = num_columns - 1
|
max_col = num_columns - 1
|
||||||
initial_state_distrib = np.zeros(num_states)
|
initial_state_distrib = np.zeros(num_states)
|
||||||
num_actions = 6
|
num_actions = 6
|
||||||
P = {
|
P = {state: {action: [] for action in range(num_actions)} for state in range(num_states)}
|
||||||
state: {action: [] for action in range(num_actions)}
|
|
||||||
for state in range(num_states)
|
|
||||||
}
|
|
||||||
for row in range(num_rows):
|
for row in range(num_rows):
|
||||||
for col in range(num_columns):
|
for col in range(num_columns):
|
||||||
for pass_idx in range(len(locs) + 1): # +1 for being inside taxi
|
for pass_idx in range(len(locs) + 1): # +1 for being inside taxi
|
||||||
@@ -94,9 +91,7 @@ class TaxiEnv(discrete.DiscreteEnv):
|
|||||||
for action in range(num_actions):
|
for action in range(num_actions):
|
||||||
# defaults
|
# defaults
|
||||||
new_row, new_col, new_pass_idx = row, col, pass_idx
|
new_row, new_col, new_pass_idx = row, col, pass_idx
|
||||||
reward = (
|
reward = -1 # default reward when there is no pickup/dropoff
|
||||||
-1
|
|
||||||
) # default reward when there is no pickup/dropoff
|
|
||||||
done = False
|
done = False
|
||||||
taxi_loc = (row, col)
|
taxi_loc = (row, col)
|
||||||
|
|
||||||
@@ -122,14 +117,10 @@ class TaxiEnv(discrete.DiscreteEnv):
|
|||||||
new_pass_idx = locs.index(taxi_loc)
|
new_pass_idx = locs.index(taxi_loc)
|
||||||
else: # dropoff at wrong location
|
else: # dropoff at wrong location
|
||||||
reward = -10
|
reward = -10
|
||||||
new_state = self.encode(
|
new_state = self.encode(new_row, new_col, new_pass_idx, dest_idx)
|
||||||
new_row, new_col, new_pass_idx, dest_idx
|
|
||||||
)
|
|
||||||
P[state][action].append((1.0, new_state, reward, done))
|
P[state][action].append((1.0, new_state, reward, done))
|
||||||
initial_state_distrib /= initial_state_distrib.sum()
|
initial_state_distrib /= initial_state_distrib.sum()
|
||||||
discrete.DiscreteEnv.__init__(
|
discrete.DiscreteEnv.__init__(self, num_states, num_actions, P, initial_state_distrib)
|
||||||
self, num_states, num_actions, P, initial_state_distrib
|
|
||||||
)
|
|
||||||
|
|
||||||
def encode(self, taxi_row, taxi_col, pass_loc, dest_idx):
|
def encode(self, taxi_row, taxi_col, pass_loc, dest_idx):
|
||||||
# (5) 5, 5, 4
|
# (5) 5, 5, 4
|
||||||
@@ -165,13 +156,9 @@ class TaxiEnv(discrete.DiscreteEnv):
|
|||||||
return "_" if x == " " else x
|
return "_" if x == " " else x
|
||||||
|
|
||||||
if pass_idx < 4:
|
if pass_idx < 4:
|
||||||
out[1 + taxi_row][2 * taxi_col + 1] = utils.colorize(
|
out[1 + taxi_row][2 * taxi_col + 1] = utils.colorize(out[1 + taxi_row][2 * taxi_col + 1], "yellow", highlight=True)
|
||||||
out[1 + taxi_row][2 * taxi_col + 1], "yellow", highlight=True
|
|
||||||
)
|
|
||||||
pi, pj = self.locs[pass_idx]
|
pi, pj = self.locs[pass_idx]
|
||||||
out[1 + pi][2 * pj + 1] = utils.colorize(
|
out[1 + pi][2 * pj + 1] = utils.colorize(out[1 + pi][2 * pj + 1], "blue", bold=True)
|
||||||
out[1 + pi][2 * pj + 1], "blue", bold=True
|
|
||||||
)
|
|
||||||
else: # passenger in taxi
|
else: # passenger in taxi
|
||||||
out[1 + taxi_row][2 * taxi_col + 1] = utils.colorize(
|
out[1 + taxi_row][2 * taxi_col + 1] = utils.colorize(
|
||||||
ul(out[1 + taxi_row][2 * taxi_col + 1]), "green", highlight=True
|
ul(out[1 + taxi_row][2 * taxi_col + 1]), "green", highlight=True
|
||||||
@@ -181,13 +168,7 @@ class TaxiEnv(discrete.DiscreteEnv):
|
|||||||
out[1 + di][2 * dj + 1] = utils.colorize(out[1 + di][2 * dj + 1], "magenta")
|
out[1 + di][2 * dj + 1] = utils.colorize(out[1 + di][2 * dj + 1], "magenta")
|
||||||
outfile.write("\n".join(["".join(row) for row in out]) + "\n")
|
outfile.write("\n".join(["".join(row) for row in out]) + "\n")
|
||||||
if self.lastaction is not None:
|
if self.lastaction is not None:
|
||||||
outfile.write(
|
outfile.write(" ({})\n".format(["South", "North", "East", "West", "Pickup", "Dropoff"][self.lastaction]))
|
||||||
" ({})\n".format(
|
|
||||||
["South", "North", "East", "West", "Pickup", "Dropoff"][
|
|
||||||
self.lastaction
|
|
||||||
]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
outfile.write("\n")
|
outfile.write("\n")
|
||||||
|
|
||||||
|
@@ -55,9 +55,7 @@ class CubeCrash(gym.Env):
|
|||||||
self.seed()
|
self.seed()
|
||||||
self.viewer = None
|
self.viewer = None
|
||||||
|
|
||||||
self.observation_space = spaces.Box(
|
self.observation_space = spaces.Box(0, 255, (FIELD_H, FIELD_W, 3), dtype=np.uint8)
|
||||||
0, 255, (FIELD_H, FIELD_W, 3), dtype=np.uint8
|
|
||||||
)
|
|
||||||
self.action_space = spaces.Discrete(3)
|
self.action_space = spaces.Discrete(3)
|
||||||
|
|
||||||
self.reset()
|
self.reset()
|
||||||
@@ -83,16 +81,9 @@ class CubeCrash(gym.Env):
|
|||||||
self.potential = None
|
self.potential = None
|
||||||
self.step_n = 0
|
self.step_n = 0
|
||||||
while 1:
|
while 1:
|
||||||
self.wall_color = (
|
self.wall_color = self.random_color() if self.use_random_colors else color_white
|
||||||
self.random_color() if self.use_random_colors else color_white
|
self.cube_color = self.random_color() if self.use_random_colors else color_green
|
||||||
)
|
if np.linalg.norm(self.wall_color - self.bg_color) < 50 or np.linalg.norm(self.cube_color - self.bg_color) < 50:
|
||||||
self.cube_color = (
|
|
||||||
self.random_color() if self.use_random_colors else color_green
|
|
||||||
)
|
|
||||||
if (
|
|
||||||
np.linalg.norm(self.wall_color - self.bg_color) < 50
|
|
||||||
or np.linalg.norm(self.cube_color - self.bg_color) < 50
|
|
||||||
):
|
|
||||||
continue
|
continue
|
||||||
break
|
break
|
||||||
return self.step(0)[0]
|
return self.step(0)[0]
|
||||||
@@ -117,9 +108,7 @@ class CubeCrash(gym.Env):
|
|||||||
self.hole_x - HOLE_WIDTH // 2 : self.hole_x + HOLE_WIDTH // 2 + 1,
|
self.hole_x - HOLE_WIDTH // 2 : self.hole_x + HOLE_WIDTH // 2 + 1,
|
||||||
:,
|
:,
|
||||||
] = self.bg_color
|
] = self.bg_color
|
||||||
obs[
|
obs[self.cube_y - 1 : self.cube_y + 2, self.cube_x - 1 : self.cube_x + 2, :] = self.cube_color
|
||||||
self.cube_y - 1 : self.cube_y + 2, self.cube_x - 1 : self.cube_x + 2, :
|
|
||||||
] = self.cube_color
|
|
||||||
if self.use_black_screen and self.step_n > 4:
|
if self.use_black_screen and self.step_n > 4:
|
||||||
obs[:] = np.zeros((3,), dtype=np.uint8)
|
obs[:] = np.zeros((3,), dtype=np.uint8)
|
||||||
|
|
||||||
|
@@ -62,16 +62,12 @@ class MemorizeDigits(gym.Env):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.seed()
|
self.seed()
|
||||||
self.viewer = None
|
self.viewer = None
|
||||||
self.observation_space = spaces.Box(
|
self.observation_space = spaces.Box(0, 255, (FIELD_H, FIELD_W, 3), dtype=np.uint8)
|
||||||
0, 255, (FIELD_H, FIELD_W, 3), dtype=np.uint8
|
|
||||||
)
|
|
||||||
self.action_space = spaces.Discrete(10)
|
self.action_space = spaces.Discrete(10)
|
||||||
self.bogus_mnist = np.zeros((10, 6, 6), dtype=np.uint8)
|
self.bogus_mnist = np.zeros((10, 6, 6), dtype=np.uint8)
|
||||||
for digit in range(10):
|
for digit in range(10):
|
||||||
for y in range(6):
|
for y in range(6):
|
||||||
self.bogus_mnist[digit, y, :] = [
|
self.bogus_mnist[digit, y, :] = [ord(char) for char in bogus_mnist[digit][y]]
|
||||||
ord(char) for char in bogus_mnist[digit][y]
|
|
||||||
]
|
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
def seed(self, seed=None):
|
def seed(self, seed=None):
|
||||||
@@ -93,9 +89,7 @@ class MemorizeDigits(gym.Env):
|
|||||||
self.color_bg = self.random_color() if self.use_random_colors else color_black
|
self.color_bg = self.random_color() if self.use_random_colors else color_black
|
||||||
self.step_n = 0
|
self.step_n = 0
|
||||||
while 1:
|
while 1:
|
||||||
self.color_digit = (
|
self.color_digit = self.random_color() if self.use_random_colors else color_white
|
||||||
self.random_color() if self.use_random_colors else color_white
|
|
||||||
)
|
|
||||||
if np.linalg.norm(self.color_digit - self.color_bg) < 50:
|
if np.linalg.norm(self.color_digit - self.color_bg) < 50:
|
||||||
continue
|
continue
|
||||||
break
|
break
|
||||||
@@ -119,9 +113,7 @@ class MemorizeDigits(gym.Env):
|
|||||||
digit_img[:] = self.color_bg
|
digit_img[:] = self.color_bg
|
||||||
xxx = self.bogus_mnist[self.digit] == 42
|
xxx = self.bogus_mnist[self.digit] == 42
|
||||||
digit_img[xxx] = self.color_digit
|
digit_img[xxx] = self.color_digit
|
||||||
obs[
|
obs[self.digit_y - 3 : self.digit_y + 3, self.digit_x - 3 : self.digit_x + 3] = digit_img
|
||||||
self.digit_y - 3 : self.digit_y + 3, self.digit_x - 3 : self.digit_x + 3
|
|
||||||
] = digit_img
|
|
||||||
self.last_obs = obs
|
self.last_obs = obs
|
||||||
return obs, reward, done, {}
|
return obs, reward, done, {}
|
||||||
|
|
||||||
|
@@ -102,10 +102,7 @@ class APIError(Error):
|
|||||||
try:
|
try:
|
||||||
http_body = http_body.decode("utf-8")
|
http_body = http_body.decode("utf-8")
|
||||||
except:
|
except:
|
||||||
http_body = (
|
http_body = "<Could not decode body as utf-8. " "Please report to gym@openai.com>"
|
||||||
"<Could not decode body as utf-8. "
|
|
||||||
"Please report to gym@openai.com>"
|
|
||||||
)
|
|
||||||
|
|
||||||
self._message = message
|
self._message = message
|
||||||
self.http_body = http_body
|
self.http_body = http_body
|
||||||
@@ -142,9 +139,7 @@ class InvalidRequestError(APIError):
|
|||||||
json_body=None,
|
json_body=None,
|
||||||
headers=None,
|
headers=None,
|
||||||
):
|
):
|
||||||
super(InvalidRequestError, self).__init__(
|
super(InvalidRequestError, self).__init__(message, http_body, http_status, json_body, headers)
|
||||||
message, http_body, http_status, json_body, headers
|
|
||||||
)
|
|
||||||
self.param = param
|
self.param = param
|
||||||
|
|
||||||
|
|
||||||
|
@@ -29,26 +29,16 @@ class Box(Space):
|
|||||||
# determine shape if it isn't provided directly
|
# determine shape if it isn't provided directly
|
||||||
if shape is not None:
|
if shape is not None:
|
||||||
shape = tuple(shape)
|
shape = tuple(shape)
|
||||||
assert (
|
assert np.isscalar(low) or low.shape == shape, "low.shape doesn't match provided shape"
|
||||||
np.isscalar(low) or low.shape == shape
|
assert np.isscalar(high) or high.shape == shape, "high.shape doesn't match provided shape"
|
||||||
), "low.shape doesn't match provided shape"
|
|
||||||
assert (
|
|
||||||
np.isscalar(high) or high.shape == shape
|
|
||||||
), "high.shape doesn't match provided shape"
|
|
||||||
elif not np.isscalar(low):
|
elif not np.isscalar(low):
|
||||||
shape = low.shape
|
shape = low.shape
|
||||||
assert (
|
assert np.isscalar(high) or high.shape == shape, "high.shape doesn't match low.shape"
|
||||||
np.isscalar(high) or high.shape == shape
|
|
||||||
), "high.shape doesn't match low.shape"
|
|
||||||
elif not np.isscalar(high):
|
elif not np.isscalar(high):
|
||||||
shape = high.shape
|
shape = high.shape
|
||||||
assert (
|
assert np.isscalar(low) or low.shape == shape, "low.shape doesn't match high.shape"
|
||||||
np.isscalar(low) or low.shape == shape
|
|
||||||
), "low.shape doesn't match high.shape"
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError("shape must be provided or inferred from the shapes of low or high")
|
||||||
"shape must be provided or inferred from the shapes of low or high"
|
|
||||||
)
|
|
||||||
|
|
||||||
if np.isscalar(low):
|
if np.isscalar(low):
|
||||||
low = np.full(shape, low, dtype=dtype)
|
low = np.full(shape, low, dtype=dtype)
|
||||||
@@ -70,9 +60,7 @@ class Box(Space):
|
|||||||
high_precision = _get_precision(self.high.dtype)
|
high_precision = _get_precision(self.high.dtype)
|
||||||
dtype_precision = _get_precision(self.dtype)
|
dtype_precision = _get_precision(self.dtype)
|
||||||
if min(low_precision, high_precision) > dtype_precision:
|
if min(low_precision, high_precision) > dtype_precision:
|
||||||
logger.warn(
|
logger.warn("Box bound precision lowered by casting to {}".format(self.dtype))
|
||||||
"Box bound precision lowered by casting to {}".format(self.dtype)
|
|
||||||
)
|
|
||||||
self.low = self.low.astype(self.dtype)
|
self.low = self.low.astype(self.dtype)
|
||||||
self.high = self.high.astype(self.dtype)
|
self.high = self.high.astype(self.dtype)
|
||||||
|
|
||||||
@@ -119,19 +107,11 @@ class Box(Space):
|
|||||||
# Vectorized sampling by interval type
|
# Vectorized sampling by interval type
|
||||||
sample[unbounded] = self.np_random.normal(size=unbounded[unbounded].shape)
|
sample[unbounded] = self.np_random.normal(size=unbounded[unbounded].shape)
|
||||||
|
|
||||||
sample[low_bounded] = (
|
sample[low_bounded] = self.np_random.exponential(size=low_bounded[low_bounded].shape) + self.low[low_bounded]
|
||||||
self.np_random.exponential(size=low_bounded[low_bounded].shape)
|
|
||||||
+ self.low[low_bounded]
|
|
||||||
)
|
|
||||||
|
|
||||||
sample[upp_bounded] = (
|
sample[upp_bounded] = -self.np_random.exponential(size=upp_bounded[upp_bounded].shape) + self.high[upp_bounded]
|
||||||
-self.np_random.exponential(size=upp_bounded[upp_bounded].shape)
|
|
||||||
+ self.high[upp_bounded]
|
|
||||||
)
|
|
||||||
|
|
||||||
sample[bounded] = self.np_random.uniform(
|
sample[bounded] = self.np_random.uniform(low=self.low[bounded], high=high[bounded], size=bounded[bounded].shape)
|
||||||
low=self.low[bounded], high=high[bounded], size=bounded[bounded].shape
|
|
||||||
)
|
|
||||||
if self.dtype.kind == "i":
|
if self.dtype.kind == "i":
|
||||||
sample = np.floor(sample)
|
sample = np.floor(sample)
|
||||||
|
|
||||||
@@ -140,9 +120,7 @@ class Box(Space):
|
|||||||
def contains(self, x):
|
def contains(self, x):
|
||||||
if isinstance(x, list):
|
if isinstance(x, list):
|
||||||
x = np.array(x) # Promote list to array for contains check
|
x = np.array(x) # Promote list to array for contains check
|
||||||
return (
|
return x.shape == self.shape and np.all(x >= self.low) and np.all(x <= self.high)
|
||||||
x.shape == self.shape and np.all(x >= self.low) and np.all(x <= self.high)
|
|
||||||
)
|
|
||||||
|
|
||||||
def to_jsonable(self, sample_n):
|
def to_jsonable(self, sample_n):
|
||||||
return np.array(sample_n).tolist()
|
return np.array(sample_n).tolist()
|
||||||
@@ -151,9 +129,7 @@ class Box(Space):
|
|||||||
return [np.asarray(sample) for sample in sample_n]
|
return [np.asarray(sample) for sample in sample_n]
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return "Box({}, {}, {}, {})".format(
|
return "Box({}, {}, {}, {})".format(self.low.min(), self.high.max(), self.shape, self.dtype)
|
||||||
self.low.min(), self.high.max(), self.shape, self.dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
return (
|
return (
|
||||||
|
@@ -33,9 +33,7 @@ class Dict(Space):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, spaces=None, **spaces_kwargs):
|
def __init__(self, spaces=None, **spaces_kwargs):
|
||||||
assert (spaces is None) or (
|
assert (spaces is None) or (not spaces_kwargs), "Use either Dict(spaces=dict(...)) or Dict(foo=x, bar=z)"
|
||||||
not spaces_kwargs
|
|
||||||
), "Use either Dict(spaces=dict(...)) or Dict(foo=x, bar=z)"
|
|
||||||
if spaces is None:
|
if spaces is None:
|
||||||
spaces = spaces_kwargs
|
spaces = spaces_kwargs
|
||||||
if isinstance(spaces, dict) and not isinstance(spaces, OrderedDict):
|
if isinstance(spaces, dict) and not isinstance(spaces, OrderedDict):
|
||||||
@@ -44,12 +42,8 @@ class Dict(Space):
|
|||||||
spaces = OrderedDict(spaces)
|
spaces = OrderedDict(spaces)
|
||||||
self.spaces = spaces
|
self.spaces = spaces
|
||||||
for space in spaces.values():
|
for space in spaces.values():
|
||||||
assert isinstance(
|
assert isinstance(space, Space), "Values of the dict should be instances of gym.Space"
|
||||||
space, Space
|
super(Dict, self).__init__(None, None) # None for shape and dtype, since it'll require special handling
|
||||||
), "Values of the dict should be instances of gym.Space"
|
|
||||||
super(Dict, self).__init__(
|
|
||||||
None, None
|
|
||||||
) # None for shape and dtype, since it'll require special handling
|
|
||||||
|
|
||||||
def seed(self, seed=None):
|
def seed(self, seed=None):
|
||||||
[space.seed(seed) for space in self.spaces.values()]
|
[space.seed(seed) for space in self.spaces.values()]
|
||||||
@@ -75,18 +69,11 @@ class Dict(Space):
|
|||||||
yield key
|
yield key
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return (
|
return "Dict(" + ", ".join([str(k) + ":" + str(s) for k, s in self.spaces.items()]) + ")"
|
||||||
"Dict("
|
|
||||||
+ ", ".join([str(k) + ":" + str(s) for k, s in self.spaces.items()])
|
|
||||||
+ ")"
|
|
||||||
)
|
|
||||||
|
|
||||||
def to_jsonable(self, sample_n):
|
def to_jsonable(self, sample_n):
|
||||||
# serialize as dict-repr of vectors
|
# serialize as dict-repr of vectors
|
||||||
return {
|
return {key: space.to_jsonable([sample[key] for sample in sample_n]) for key, space in self.spaces.items()}
|
||||||
key: space.to_jsonable([sample[key] for sample in sample_n])
|
|
||||||
for key, space in self.spaces.items()
|
|
||||||
}
|
|
||||||
|
|
||||||
def from_jsonable(self, sample_n):
|
def from_jsonable(self, sample_n):
|
||||||
dict_of_list = {}
|
dict_of_list = {}
|
||||||
|
@@ -22,9 +22,7 @@ class Discrete(Space):
|
|||||||
def contains(self, x):
|
def contains(self, x):
|
||||||
if isinstance(x, int):
|
if isinstance(x, int):
|
||||||
as_int = x
|
as_int = x
|
||||||
elif isinstance(x, (np.generic, np.ndarray)) and (
|
elif isinstance(x, (np.generic, np.ndarray)) and (x.dtype.char in np.typecodes["AllInteger"] and x.shape == ()):
|
||||||
x.dtype.char in np.typecodes["AllInteger"] and x.shape == ()
|
|
||||||
):
|
|
||||||
as_int = int(x)
|
as_int = int(x)
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
@@ -34,9 +34,7 @@ class MultiDiscrete(Space):
|
|||||||
super(MultiDiscrete, self).__init__(self.nvec.shape, dtype)
|
super(MultiDiscrete, self).__init__(self.nvec.shape, dtype)
|
||||||
|
|
||||||
def sample(self):
|
def sample(self):
|
||||||
return (self.np_random.random_sample(self.nvec.shape) * self.nvec).astype(
|
return (self.np_random.random_sample(self.nvec.shape) * self.nvec).astype(self.dtype)
|
||||||
self.dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
def contains(self, x):
|
def contains(self, x):
|
||||||
if isinstance(x, list):
|
if isinstance(x, list):
|
||||||
|
@@ -25,9 +25,7 @@ from gym.spaces import Tuple, Box, Discrete, MultiDiscrete, MultiBinary, Dict
|
|||||||
Dict(
|
Dict(
|
||||||
{
|
{
|
||||||
"position": Discrete(5),
|
"position": Discrete(5),
|
||||||
"velocity": Box(
|
"velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
|
||||||
low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
@@ -71,9 +69,7 @@ def test_roundtripping(space):
|
|||||||
Dict(
|
Dict(
|
||||||
{
|
{
|
||||||
"position": Discrete(5),
|
"position": Discrete(5),
|
||||||
"velocity": Box(
|
"velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
|
||||||
low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
|
@@ -28,9 +28,7 @@ from gym.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete, Tuple, u
|
|||||||
Dict(
|
Dict(
|
||||||
{
|
{
|
||||||
"position": Discrete(5),
|
"position": Discrete(5),
|
||||||
"velocity": Box(
|
"velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
|
||||||
low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
7,
|
7,
|
||||||
@@ -60,18 +58,14 @@ def test_flatdim(space, flatdim):
|
|||||||
Dict(
|
Dict(
|
||||||
{
|
{
|
||||||
"position": Discrete(5),
|
"position": Discrete(5),
|
||||||
"velocity": Box(
|
"velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
|
||||||
low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_flatten_space_boxes(space):
|
def test_flatten_space_boxes(space):
|
||||||
flat_space = utils.flatten_space(space)
|
flat_space = utils.flatten_space(space)
|
||||||
assert isinstance(flat_space, Box), "Expected {} to equal {}".format(
|
assert isinstance(flat_space, Box), "Expected {} to equal {}".format(type(flat_space), Box)
|
||||||
type(flat_space), Box
|
|
||||||
)
|
|
||||||
flatdim = utils.flatdim(space)
|
flatdim = utils.flatdim(space)
|
||||||
(single_dim,) = flat_space.shape
|
(single_dim,) = flat_space.shape
|
||||||
assert single_dim == flatdim, "Expected {} to equal {}".format(single_dim, flatdim)
|
assert single_dim == flatdim, "Expected {} to equal {}".format(single_dim, flatdim)
|
||||||
@@ -95,9 +89,7 @@ def test_flatten_space_boxes(space):
|
|||||||
Dict(
|
Dict(
|
||||||
{
|
{
|
||||||
"position": Discrete(5),
|
"position": Discrete(5),
|
||||||
"velocity": Box(
|
"velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
|
||||||
low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
@@ -107,9 +99,7 @@ def test_flat_space_contains_flat_points(space):
|
|||||||
flattened_samples = [utils.flatten(space, sample) for sample in some_samples]
|
flattened_samples = [utils.flatten(space, sample) for sample in some_samples]
|
||||||
flat_space = utils.flatten_space(space)
|
flat_space = utils.flatten_space(space)
|
||||||
for i, flat_sample in enumerate(flattened_samples):
|
for i, flat_sample in enumerate(flattened_samples):
|
||||||
assert flat_sample in flat_space, "Expected sample #{} {} to be in {}".format(
|
assert flat_sample in flat_space, "Expected sample #{} {} to be in {}".format(i, flat_sample, flat_space)
|
||||||
i, flat_sample, flat_space
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@@ -130,9 +120,7 @@ def test_flat_space_contains_flat_points(space):
|
|||||||
Dict(
|
Dict(
|
||||||
{
|
{
|
||||||
"position": Discrete(5),
|
"position": Discrete(5),
|
||||||
"velocity": Box(
|
"velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
|
||||||
low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
@@ -162,9 +150,7 @@ def test_flatten_dim(space):
|
|||||||
Dict(
|
Dict(
|
||||||
{
|
{
|
||||||
"position": Discrete(5),
|
"position": Discrete(5),
|
||||||
"velocity": Box(
|
"velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
|
||||||
low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
@@ -172,15 +158,9 @@ def test_flatten_dim(space):
|
|||||||
def test_flatten_roundtripping(space):
|
def test_flatten_roundtripping(space):
|
||||||
some_samples = [space.sample() for _ in range(10)]
|
some_samples = [space.sample() for _ in range(10)]
|
||||||
flattened_samples = [utils.flatten(space, sample) for sample in some_samples]
|
flattened_samples = [utils.flatten(space, sample) for sample in some_samples]
|
||||||
roundtripped_samples = [
|
roundtripped_samples = [utils.unflatten(space, sample) for sample in flattened_samples]
|
||||||
utils.unflatten(space, sample) for sample in flattened_samples
|
for i, (original, roundtripped) in enumerate(zip(some_samples, roundtripped_samples)):
|
||||||
]
|
assert compare_nested(original, roundtripped), "Expected sample #{} {} to equal {}".format(i, original, roundtripped)
|
||||||
for i, (original, roundtripped) in enumerate(
|
|
||||||
zip(some_samples, roundtripped_samples)
|
|
||||||
):
|
|
||||||
assert compare_nested(
|
|
||||||
original, roundtripped
|
|
||||||
), "Expected sample #{} {} to equal {}".format(i, original, roundtripped)
|
|
||||||
|
|
||||||
|
|
||||||
def compare_nested(left, right):
|
def compare_nested(left, right):
|
||||||
@@ -188,9 +168,7 @@ def compare_nested(left, right):
|
|||||||
return np.allclose(left, right)
|
return np.allclose(left, right)
|
||||||
elif isinstance(left, OrderedDict) and isinstance(right, OrderedDict):
|
elif isinstance(left, OrderedDict) and isinstance(right, OrderedDict):
|
||||||
res = len(left) == len(right)
|
res = len(left) == len(right)
|
||||||
for ((left_key, left_value), (right_key, right_value)) in zip(
|
for ((left_key, left_value), (right_key, right_value)) in zip(left.items(), right.items()):
|
||||||
left.items(), right.items()
|
|
||||||
):
|
|
||||||
if not res:
|
if not res:
|
||||||
return False
|
return False
|
||||||
res = left_key == right_key and compare_nested(left_value, right_value)
|
res = left_key == right_key and compare_nested(left_value, right_value)
|
||||||
@@ -238,9 +216,7 @@ Expecteded flattened types are based off:
|
|||||||
Dict(
|
Dict(
|
||||||
{
|
{
|
||||||
"position": Discrete(5),
|
"position": Discrete(5),
|
||||||
"velocity": Box(
|
"velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float16),
|
||||||
low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float16
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
np.float64,
|
np.float64,
|
||||||
@@ -254,12 +230,8 @@ def test_dtypes(original_space, expected_flattened_dtype):
|
|||||||
flattened_sample = utils.flatten(original_space, original_sample)
|
flattened_sample = utils.flatten(original_space, original_sample)
|
||||||
unflattened_sample = utils.unflatten(original_space, flattened_sample)
|
unflattened_sample = utils.unflatten(original_space, flattened_sample)
|
||||||
|
|
||||||
assert flattened_space.contains(
|
assert flattened_space.contains(flattened_sample), "Expected flattened_space to contain flattened_sample"
|
||||||
flattened_sample
|
assert flattened_space.dtype == expected_flattened_dtype, "Expected flattened_space's dtype to equal " "{}".format(
|
||||||
), "Expected flattened_space to contain flattened_sample"
|
|
||||||
assert (
|
|
||||||
flattened_space.dtype == expected_flattened_dtype
|
|
||||||
), "Expected flattened_space's dtype to equal " "{}".format(
|
|
||||||
expected_flattened_dtype
|
expected_flattened_dtype
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -272,9 +244,10 @@ def test_dtypes(original_space, expected_flattened_dtype):
|
|||||||
|
|
||||||
def compare_sample_types(original_space, original_sample, unflattened_sample):
|
def compare_sample_types(original_space, original_sample, unflattened_sample):
|
||||||
if isinstance(original_space, Discrete):
|
if isinstance(original_space, Discrete):
|
||||||
assert isinstance(unflattened_sample, int), (
|
assert isinstance(
|
||||||
"Expected unflattened_sample to be an int. unflattened_sample: "
|
unflattened_sample, int
|
||||||
"{} original_sample: {}".format(unflattened_sample, original_sample)
|
), "Expected unflattened_sample to be an int. unflattened_sample: " "{} original_sample: {}".format(
|
||||||
|
unflattened_sample, original_sample
|
||||||
)
|
)
|
||||||
elif isinstance(original_space, Tuple):
|
elif isinstance(original_space, Tuple):
|
||||||
for index in range(len(original_space)):
|
for index in range(len(original_space)):
|
||||||
|
@@ -13,9 +13,7 @@ class Tuple(Space):
|
|||||||
def __init__(self, spaces):
|
def __init__(self, spaces):
|
||||||
self.spaces = spaces
|
self.spaces = spaces
|
||||||
for space in spaces:
|
for space in spaces:
|
||||||
assert isinstance(
|
assert isinstance(space, Space), "Elements of the tuple must be instances of gym.Space"
|
||||||
space, Space
|
|
||||||
), "Elements of the tuple must be instances of gym.Space"
|
|
||||||
super(Tuple, self).__init__(None, None)
|
super(Tuple, self).__init__(None, None)
|
||||||
|
|
||||||
def seed(self, seed=None):
|
def seed(self, seed=None):
|
||||||
@@ -38,21 +36,10 @@ class Tuple(Space):
|
|||||||
|
|
||||||
def to_jsonable(self, sample_n):
|
def to_jsonable(self, sample_n):
|
||||||
# serialize as list-repr of tuple of vectors
|
# serialize as list-repr of tuple of vectors
|
||||||
return [
|
return [space.to_jsonable([sample[i] for sample in sample_n]) for i, space in enumerate(self.spaces)]
|
||||||
space.to_jsonable([sample[i] for sample in sample_n])
|
|
||||||
for i, space in enumerate(self.spaces)
|
|
||||||
]
|
|
||||||
|
|
||||||
def from_jsonable(self, sample_n):
|
def from_jsonable(self, sample_n):
|
||||||
return [
|
return [sample for sample in zip(*[space.from_jsonable(sample_n[i]) for i, space in enumerate(self.spaces)])]
|
||||||
sample
|
|
||||||
for sample in zip(
|
|
||||||
*[
|
|
||||||
space.from_jsonable(sample_n[i])
|
|
||||||
for i, space in enumerate(self.spaces)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
return self.spaces[index]
|
return self.spaces[index]
|
||||||
|
@@ -49,9 +49,7 @@ def flatten(space, x):
|
|||||||
onehot[x] = 1
|
onehot[x] = 1
|
||||||
return onehot
|
return onehot
|
||||||
elif isinstance(space, Tuple):
|
elif isinstance(space, Tuple):
|
||||||
return np.concatenate(
|
return np.concatenate([flatten(s, x_part) for x_part, s in zip(x, space.spaces)])
|
||||||
[flatten(s, x_part) for x_part, s in zip(x, space.spaces)]
|
|
||||||
)
|
|
||||||
elif isinstance(space, Dict):
|
elif isinstance(space, Dict):
|
||||||
return np.concatenate([flatten(s, x[key]) for key, s in space.spaces.items()])
|
return np.concatenate([flatten(s, x[key]) for key, s in space.spaces.items()])
|
||||||
elif isinstance(space, MultiBinary):
|
elif isinstance(space, MultiBinary):
|
||||||
@@ -79,17 +77,13 @@ def unflatten(space, x):
|
|||||||
elif isinstance(space, Tuple):
|
elif isinstance(space, Tuple):
|
||||||
dims = [flatdim(s) for s in space.spaces]
|
dims = [flatdim(s) for s in space.spaces]
|
||||||
list_flattened = np.split(x, np.cumsum(dims)[:-1])
|
list_flattened = np.split(x, np.cumsum(dims)[:-1])
|
||||||
list_unflattened = [
|
list_unflattened = [unflatten(s, flattened) for flattened, s in zip(list_flattened, space.spaces)]
|
||||||
unflatten(s, flattened)
|
|
||||||
for flattened, s in zip(list_flattened, space.spaces)
|
|
||||||
]
|
|
||||||
return tuple(list_unflattened)
|
return tuple(list_unflattened)
|
||||||
elif isinstance(space, Dict):
|
elif isinstance(space, Dict):
|
||||||
dims = [flatdim(s) for s in space.spaces.values()]
|
dims = [flatdim(s) for s in space.spaces.values()]
|
||||||
list_flattened = np.split(x, np.cumsum(dims)[:-1])
|
list_flattened = np.split(x, np.cumsum(dims)[:-1])
|
||||||
list_unflattened = [
|
list_unflattened = [
|
||||||
(key, unflatten(s, flattened))
|
(key, unflatten(s, flattened)) for flattened, (key, s) in zip(list_flattened, space.spaces.items())
|
||||||
for flattened, (key, s) in zip(list_flattened, space.spaces.items())
|
|
||||||
]
|
]
|
||||||
return OrderedDict(list_unflattened)
|
return OrderedDict(list_unflattened)
|
||||||
elif isinstance(space, MultiBinary):
|
elif isinstance(space, MultiBinary):
|
||||||
|
@@ -88,11 +88,7 @@ def play(env, transpose=True, fps=30, zoom=None, callback=None, keys_to_action=N
|
|||||||
elif hasattr(env.unwrapped, "get_keys_to_action"):
|
elif hasattr(env.unwrapped, "get_keys_to_action"):
|
||||||
keys_to_action = env.unwrapped.get_keys_to_action()
|
keys_to_action = env.unwrapped.get_keys_to_action()
|
||||||
else:
|
else:
|
||||||
assert False, (
|
assert False, env.spec.id + " does not have explicit key to action mapping, " + "please specify one manually"
|
||||||
env.spec.id
|
|
||||||
+ " does not have explicit key to action mapping, "
|
|
||||||
+ "please specify one manually"
|
|
||||||
)
|
|
||||||
relevant_keys = set(sum(map(list, keys_to_action.keys()), []))
|
relevant_keys = set(sum(map(list, keys_to_action.keys()), []))
|
||||||
|
|
||||||
video_size = [rendered.shape[1], rendered.shape[0]]
|
video_size = [rendered.shape[1], rendered.shape[0]]
|
||||||
@@ -172,9 +168,7 @@ class PlayPlot(object):
|
|||||||
for i, plot in enumerate(self.cur_plot):
|
for i, plot in enumerate(self.cur_plot):
|
||||||
if plot is not None:
|
if plot is not None:
|
||||||
plot.remove()
|
plot.remove()
|
||||||
self.cur_plot[i] = self.ax[i].scatter(
|
self.cur_plot[i] = self.ax[i].scatter(range(xmin, xmax), list(self.data[i]), c="blue")
|
||||||
range(xmin, xmax), list(self.data[i]), c="blue"
|
|
||||||
)
|
|
||||||
self.ax[i].set_xlim(xmin, xmax)
|
self.ax[i].set_xlim(xmin, xmax)
|
||||||
plt.pause(0.000001)
|
plt.pause(0.000001)
|
||||||
|
|
||||||
|
@@ -10,9 +10,7 @@ from gym import error
|
|||||||
|
|
||||||
def np_random(seed=None):
|
def np_random(seed=None):
|
||||||
if seed is not None and not (isinstance(seed, int) and 0 <= seed):
|
if seed is not None and not (isinstance(seed, int) and 0 <= seed):
|
||||||
raise error.Error(
|
raise error.Error("Seed must be a non-negative integer or omitted, not {}".format(seed))
|
||||||
"Seed must be a non-negative integer or omitted, not {}".format(seed)
|
|
||||||
)
|
|
||||||
|
|
||||||
seed = create_seed(seed)
|
seed = create_seed(seed)
|
||||||
|
|
||||||
|
@@ -53,9 +53,7 @@ def make(id, num_envs=1, asynchronous=True, wrappers=None, **kwargs):
|
|||||||
if wrappers is not None:
|
if wrappers is not None:
|
||||||
if callable(wrappers):
|
if callable(wrappers):
|
||||||
env = wrappers(env)
|
env = wrappers(env)
|
||||||
elif isinstance(wrappers, Iterable) and all(
|
elif isinstance(wrappers, Iterable) and all([callable(w) for w in wrappers]):
|
||||||
[callable(w) for w in wrappers]
|
|
||||||
):
|
|
||||||
for wrapper in wrappers:
|
for wrapper in wrappers:
|
||||||
env = wrapper(env)
|
env = wrapper(env)
|
||||||
else:
|
else:
|
||||||
|
@@ -107,12 +107,8 @@ class AsyncVectorEnv(VectorEnv):
|
|||||||
|
|
||||||
if self.shared_memory:
|
if self.shared_memory:
|
||||||
try:
|
try:
|
||||||
_obs_buffer = create_shared_memory(
|
_obs_buffer = create_shared_memory(self.single_observation_space, n=self.num_envs, ctx=ctx)
|
||||||
self.single_observation_space, n=self.num_envs, ctx=ctx
|
self.observations = read_from_shared_memory(_obs_buffer, self.single_observation_space, n=self.num_envs)
|
||||||
)
|
|
||||||
self.observations = read_from_shared_memory(
|
|
||||||
_obs_buffer, self.single_observation_space, n=self.num_envs
|
|
||||||
)
|
|
||||||
except CustomSpaceError:
|
except CustomSpaceError:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Using `shared_memory=True` in `AsyncVectorEnv` "
|
"Using `shared_memory=True` in `AsyncVectorEnv` "
|
||||||
@@ -124,9 +120,7 @@ class AsyncVectorEnv(VectorEnv):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
_obs_buffer = None
|
_obs_buffer = None
|
||||||
self.observations = create_empty_array(
|
self.observations = create_empty_array(self.single_observation_space, n=self.num_envs, fn=np.zeros)
|
||||||
self.single_observation_space, n=self.num_envs, fn=np.zeros
|
|
||||||
)
|
|
||||||
|
|
||||||
self.parent_pipes, self.processes = [], []
|
self.parent_pipes, self.processes = [], []
|
||||||
self.error_queue = ctx.Queue()
|
self.error_queue = ctx.Queue()
|
||||||
@@ -168,8 +162,7 @@ class AsyncVectorEnv(VectorEnv):
|
|||||||
|
|
||||||
if self._state != AsyncState.DEFAULT:
|
if self._state != AsyncState.DEFAULT:
|
||||||
raise AlreadyPendingCallError(
|
raise AlreadyPendingCallError(
|
||||||
"Calling `seed` while waiting "
|
"Calling `seed` while waiting " "for a pending call to `{0}` to complete.".format(self._state.value),
|
||||||
"for a pending call to `{0}` to complete.".format(self._state.value),
|
|
||||||
self._state.value,
|
self._state.value,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -182,8 +175,7 @@ class AsyncVectorEnv(VectorEnv):
|
|||||||
self._assert_is_running()
|
self._assert_is_running()
|
||||||
if self._state != AsyncState.DEFAULT:
|
if self._state != AsyncState.DEFAULT:
|
||||||
raise AlreadyPendingCallError(
|
raise AlreadyPendingCallError(
|
||||||
"Calling `reset_async` while waiting "
|
"Calling `reset_async` while waiting " "for a pending call to `{0}` to complete".format(self._state.value),
|
||||||
"for a pending call to `{0}` to complete".format(self._state.value),
|
|
||||||
self._state.value,
|
self._state.value,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -214,8 +206,7 @@ class AsyncVectorEnv(VectorEnv):
|
|||||||
if not self._poll(timeout):
|
if not self._poll(timeout):
|
||||||
self._state = AsyncState.DEFAULT
|
self._state = AsyncState.DEFAULT
|
||||||
raise mp.TimeoutError(
|
raise mp.TimeoutError(
|
||||||
"The call to `reset_wait` has timed out after "
|
"The call to `reset_wait` has timed out after " "{0} second{1}.".format(timeout, "s" if timeout > 1 else "")
|
||||||
"{0} second{1}.".format(timeout, "s" if timeout > 1 else "")
|
|
||||||
)
|
)
|
||||||
|
|
||||||
results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
|
results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
|
||||||
@@ -223,9 +214,7 @@ class AsyncVectorEnv(VectorEnv):
|
|||||||
self._state = AsyncState.DEFAULT
|
self._state = AsyncState.DEFAULT
|
||||||
|
|
||||||
if not self.shared_memory:
|
if not self.shared_memory:
|
||||||
self.observations = concatenate(
|
self.observations = concatenate(results, self.observations, self.single_observation_space)
|
||||||
results, self.observations, self.single_observation_space
|
|
||||||
)
|
|
||||||
|
|
||||||
return deepcopy(self.observations) if self.copy else self.observations
|
return deepcopy(self.observations) if self.copy else self.observations
|
||||||
|
|
||||||
@@ -239,8 +228,7 @@ class AsyncVectorEnv(VectorEnv):
|
|||||||
self._assert_is_running()
|
self._assert_is_running()
|
||||||
if self._state != AsyncState.DEFAULT:
|
if self._state != AsyncState.DEFAULT:
|
||||||
raise AlreadyPendingCallError(
|
raise AlreadyPendingCallError(
|
||||||
"Calling `step_async` while waiting "
|
"Calling `step_async` while waiting " "for a pending call to `{0}` to complete.".format(self._state.value),
|
||||||
"for a pending call to `{0}` to complete.".format(self._state.value),
|
|
||||||
self._state.value,
|
self._state.value,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -280,8 +268,7 @@ class AsyncVectorEnv(VectorEnv):
|
|||||||
if not self._poll(timeout):
|
if not self._poll(timeout):
|
||||||
self._state = AsyncState.DEFAULT
|
self._state = AsyncState.DEFAULT
|
||||||
raise mp.TimeoutError(
|
raise mp.TimeoutError(
|
||||||
"The call to `step_wait` has timed out after "
|
"The call to `step_wait` has timed out after " "{0} second{1}.".format(timeout, "s" if timeout > 1 else "")
|
||||||
"{0} second{1}.".format(timeout, "s" if timeout > 1 else "")
|
|
||||||
)
|
)
|
||||||
|
|
||||||
results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
|
results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
|
||||||
@@ -290,9 +277,7 @@ class AsyncVectorEnv(VectorEnv):
|
|||||||
observations_list, rewards, dones, infos = zip(*results)
|
observations_list, rewards, dones, infos = zip(*results)
|
||||||
|
|
||||||
if not self.shared_memory:
|
if not self.shared_memory:
|
||||||
self.observations = concatenate(
|
self.observations = concatenate(observations_list, self.observations, self.single_observation_space)
|
||||||
observations_list, self.observations, self.single_observation_space
|
|
||||||
)
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
deepcopy(self.observations) if self.copy else self.observations,
|
deepcopy(self.observations) if self.copy else self.observations,
|
||||||
@@ -318,8 +303,7 @@ class AsyncVectorEnv(VectorEnv):
|
|||||||
try:
|
try:
|
||||||
if self._state != AsyncState.DEFAULT:
|
if self._state != AsyncState.DEFAULT:
|
||||||
logger.warn(
|
logger.warn(
|
||||||
"Calling `close` while waiting for a pending "
|
"Calling `close` while waiting for a pending " "call to `{0}` to complete.".format(self._state.value)
|
||||||
"call to `{0}` to complete.".format(self._state.value)
|
|
||||||
)
|
)
|
||||||
function = getattr(self, "{0}_wait".format(self._state.value))
|
function = getattr(self, "{0}_wait".format(self._state.value))
|
||||||
function(timeout)
|
function(timeout)
|
||||||
@@ -375,8 +359,7 @@ class AsyncVectorEnv(VectorEnv):
|
|||||||
def _assert_is_running(self):
|
def _assert_is_running(self):
|
||||||
if self.closed:
|
if self.closed:
|
||||||
raise ClosedEnvironmentError(
|
raise ClosedEnvironmentError(
|
||||||
"Trying to operate on `{0}`, after a "
|
"Trying to operate on `{0}`, after a " "call to `close()`.".format(type(self).__name__)
|
||||||
"call to `close()`.".format(type(self).__name__)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _raise_if_errors(self, successes):
|
def _raise_if_errors(self, successes):
|
||||||
@@ -387,10 +370,7 @@ class AsyncVectorEnv(VectorEnv):
|
|||||||
assert num_errors > 0
|
assert num_errors > 0
|
||||||
for _ in range(num_errors):
|
for _ in range(num_errors):
|
||||||
index, exctype, value = self.error_queue.get()
|
index, exctype, value = self.error_queue.get()
|
||||||
logger.error(
|
logger.error("Received the following error from Worker-{0}: " "{1}: {2}".format(index, exctype.__name__, value))
|
||||||
"Received the following error from Worker-{0}: "
|
|
||||||
"{1}: {2}".format(index, exctype.__name__, value)
|
|
||||||
)
|
|
||||||
logger.error("Shutting down Worker-{0}.".format(index))
|
logger.error("Shutting down Worker-{0}.".format(index))
|
||||||
self.parent_pipes[index].close()
|
self.parent_pipes[index].close()
|
||||||
self.parent_pipes[index] = None
|
self.parent_pipes[index] = None
|
||||||
@@ -445,17 +425,13 @@ def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error
|
|||||||
command, data = pipe.recv()
|
command, data = pipe.recv()
|
||||||
if command == "reset":
|
if command == "reset":
|
||||||
observation = env.reset()
|
observation = env.reset()
|
||||||
write_to_shared_memory(
|
write_to_shared_memory(index, observation, shared_memory, observation_space)
|
||||||
index, observation, shared_memory, observation_space
|
|
||||||
)
|
|
||||||
pipe.send((None, True))
|
pipe.send((None, True))
|
||||||
elif command == "step":
|
elif command == "step":
|
||||||
observation, reward, done, info = env.step(data)
|
observation, reward, done, info = env.step(data)
|
||||||
if done:
|
if done:
|
||||||
observation = env.reset()
|
observation = env.reset()
|
||||||
write_to_shared_memory(
|
write_to_shared_memory(index, observation, shared_memory, observation_space)
|
||||||
index, observation, shared_memory, observation_space
|
|
||||||
)
|
|
||||||
pipe.send(((None, reward, done, info), True))
|
pipe.send(((None, reward, done, info), True))
|
||||||
elif command == "seed":
|
elif command == "seed":
|
||||||
env.seed(data)
|
env.seed(data)
|
||||||
|
@@ -44,9 +44,7 @@ class SyncVectorEnv(VectorEnv):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self._check_observation_spaces()
|
self._check_observation_spaces()
|
||||||
self.observations = create_empty_array(
|
self.observations = create_empty_array(self.single_observation_space, n=self.num_envs, fn=np.zeros)
|
||||||
self.single_observation_space, n=self.num_envs, fn=np.zeros
|
|
||||||
)
|
|
||||||
self._rewards = np.zeros((self.num_envs,), dtype=np.float64)
|
self._rewards = np.zeros((self.num_envs,), dtype=np.float64)
|
||||||
self._dones = np.zeros((self.num_envs,), dtype=np.bool_)
|
self._dones = np.zeros((self.num_envs,), dtype=np.bool_)
|
||||||
self._actions = None
|
self._actions = None
|
||||||
@@ -67,9 +65,7 @@ class SyncVectorEnv(VectorEnv):
|
|||||||
for env in self.envs:
|
for env in self.envs:
|
||||||
observation = env.reset()
|
observation = env.reset()
|
||||||
observations.append(observation)
|
observations.append(observation)
|
||||||
self.observations = concatenate(
|
self.observations = concatenate(observations, self.observations, self.single_observation_space)
|
||||||
observations, self.observations, self.single_observation_space
|
|
||||||
)
|
|
||||||
|
|
||||||
return deepcopy(self.observations) if self.copy else self.observations
|
return deepcopy(self.observations) if self.copy else self.observations
|
||||||
|
|
||||||
@@ -84,9 +80,7 @@ class SyncVectorEnv(VectorEnv):
|
|||||||
observation = env.reset()
|
observation = env.reset()
|
||||||
observations.append(observation)
|
observations.append(observation)
|
||||||
infos.append(info)
|
infos.append(info)
|
||||||
self.observations = concatenate(
|
self.observations = concatenate(observations, self.observations, self.single_observation_space)
|
||||||
observations, self.observations, self.single_observation_space
|
|
||||||
)
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
deepcopy(self.observations) if self.copy else self.observations,
|
deepcopy(self.observations) if self.copy else self.observations,
|
||||||
|
@@ -10,9 +10,7 @@ from gym.vector.tests.utils import spaces
|
|||||||
from gym.vector.utils.numpy_utils import concatenate, create_empty_array
|
from gym.vector.utils.numpy_utils import concatenate, create_empty_array
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize("space", spaces, ids=[space.__class__.__name__ for space in spaces])
|
||||||
"space", spaces, ids=[space.__class__.__name__ for space in spaces]
|
|
||||||
)
|
|
||||||
def test_concatenate(space):
|
def test_concatenate(space):
|
||||||
def assert_type(lhs, rhs, n):
|
def assert_type(lhs, rhs, n):
|
||||||
# Special case: if rhs is a list of scalars, lhs must be an np.ndarray
|
# Special case: if rhs is a list of scalars, lhs must be an np.ndarray
|
||||||
@@ -53,9 +51,7 @@ def test_concatenate(space):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("n", [1, 8])
|
@pytest.mark.parametrize("n", [1, 8])
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize("space", spaces, ids=[space.__class__.__name__ for space in spaces])
|
||||||
"space", spaces, ids=[space.__class__.__name__ for space in spaces]
|
|
||||||
)
|
|
||||||
def test_create_empty_array(space, n):
|
def test_create_empty_array(space, n):
|
||||||
def assert_nested_type(arr, space, n):
|
def assert_nested_type(arr, space, n):
|
||||||
if isinstance(space, _BaseGymSpaces):
|
if isinstance(space, _BaseGymSpaces):
|
||||||
@@ -83,9 +79,7 @@ def test_create_empty_array(space, n):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("n", [1, 8])
|
@pytest.mark.parametrize("n", [1, 8])
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize("space", spaces, ids=[space.__class__.__name__ for space in spaces])
|
||||||
"space", spaces, ids=[space.__class__.__name__ for space in spaces]
|
|
||||||
)
|
|
||||||
def test_create_empty_array_zeros(space, n):
|
def test_create_empty_array_zeros(space, n):
|
||||||
def assert_nested_type(arr, space, n):
|
def assert_nested_type(arr, space, n):
|
||||||
if isinstance(space, _BaseGymSpaces):
|
if isinstance(space, _BaseGymSpaces):
|
||||||
@@ -113,9 +107,7 @@ def test_create_empty_array_zeros(space, n):
|
|||||||
assert_nested_type(array, space, n=n)
|
assert_nested_type(array, space, n=n)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize("space", spaces, ids=[space.__class__.__name__ for space in spaces])
|
||||||
"space", spaces, ids=[space.__class__.__name__ for space in spaces]
|
|
||||||
)
|
|
||||||
def test_create_empty_array_none_shape_ones(space):
|
def test_create_empty_array_none_shape_ones(space):
|
||||||
def assert_nested_type(arr, space):
|
def assert_nested_type(arr, space):
|
||||||
if isinstance(space, _BaseGymSpaces):
|
if isinstance(space, _BaseGymSpaces):
|
||||||
|
@@ -46,9 +46,7 @@ expected_types = [
|
|||||||
list(zip(spaces, expected_types)),
|
list(zip(spaces, expected_types)),
|
||||||
ids=[space.__class__.__name__ for space in spaces],
|
ids=[space.__class__.__name__ for space in spaces],
|
||||||
)
|
)
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize("ctx", [None, "fork", "spawn"], ids=["default", "fork", "spawn"])
|
||||||
"ctx", [None, "fork", "spawn"], ids=["default", "fork", "spawn"]
|
|
||||||
)
|
|
||||||
def test_create_shared_memory(space, expected_type, n, ctx):
|
def test_create_shared_memory(space, expected_type, n, ctx):
|
||||||
def assert_nested_type(lhs, rhs, n):
|
def assert_nested_type(lhs, rhs, n):
|
||||||
assert type(lhs) == type(rhs)
|
assert type(lhs) == type(rhs)
|
||||||
@@ -77,9 +75,7 @@ def test_create_shared_memory(space, expected_type, n, ctx):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("n", [1, 8])
|
@pytest.mark.parametrize("n", [1, 8])
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize("ctx", [None, "fork", "spawn"], ids=["default", "fork", "spawn"])
|
||||||
"ctx", [None, "fork", "spawn"], ids=["default", "fork", "spawn"]
|
|
||||||
)
|
|
||||||
@pytest.mark.parametrize("space", custom_spaces)
|
@pytest.mark.parametrize("space", custom_spaces)
|
||||||
def test_create_shared_memory_custom_space(n, ctx, space):
|
def test_create_shared_memory_custom_space(n, ctx, space):
|
||||||
ctx = mp if (ctx is None) else mp.get_context(ctx)
|
ctx = mp if (ctx is None) else mp.get_context(ctx)
|
||||||
@@ -87,9 +83,7 @@ def test_create_shared_memory_custom_space(n, ctx, space):
|
|||||||
shared_memory = create_shared_memory(space, n=n, ctx=ctx)
|
shared_memory = create_shared_memory(space, n=n, ctx=ctx)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize("space", spaces, ids=[space.__class__.__name__ for space in spaces])
|
||||||
"space", spaces, ids=[space.__class__.__name__ for space in spaces]
|
|
||||||
)
|
|
||||||
def test_write_to_shared_memory(space):
|
def test_write_to_shared_memory(space):
|
||||||
def assert_nested_equal(lhs, rhs):
|
def assert_nested_equal(lhs, rhs):
|
||||||
assert isinstance(rhs, list)
|
assert isinstance(rhs, list)
|
||||||
@@ -113,9 +107,7 @@ def test_write_to_shared_memory(space):
|
|||||||
shared_memory_n8 = create_shared_memory(space, n=8)
|
shared_memory_n8 = create_shared_memory(space, n=8)
|
||||||
samples = [space.sample() for _ in range(8)]
|
samples = [space.sample() for _ in range(8)]
|
||||||
|
|
||||||
processes = [
|
processes = [Process(target=write, args=(i, shared_memory_n8, samples[i])) for i in range(8)]
|
||||||
Process(target=write, args=(i, shared_memory_n8, samples[i])) for i in range(8)
|
|
||||||
]
|
|
||||||
|
|
||||||
for process in processes:
|
for process in processes:
|
||||||
process.start()
|
process.start()
|
||||||
@@ -125,25 +117,19 @@ def test_write_to_shared_memory(space):
|
|||||||
assert_nested_equal(shared_memory_n8, samples)
|
assert_nested_equal(shared_memory_n8, samples)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize("space", spaces, ids=[space.__class__.__name__ for space in spaces])
|
||||||
"space", spaces, ids=[space.__class__.__name__ for space in spaces]
|
|
||||||
)
|
|
||||||
def test_read_from_shared_memory(space):
|
def test_read_from_shared_memory(space):
|
||||||
def assert_nested_equal(lhs, rhs, space, n):
|
def assert_nested_equal(lhs, rhs, space, n):
|
||||||
assert isinstance(rhs, list)
|
assert isinstance(rhs, list)
|
||||||
if isinstance(space, Tuple):
|
if isinstance(space, Tuple):
|
||||||
assert isinstance(lhs, tuple)
|
assert isinstance(lhs, tuple)
|
||||||
for i in range(len(lhs)):
|
for i in range(len(lhs)):
|
||||||
assert_nested_equal(
|
assert_nested_equal(lhs[i], [rhs_[i] for rhs_ in rhs], space.spaces[i], n)
|
||||||
lhs[i], [rhs_[i] for rhs_ in rhs], space.spaces[i], n
|
|
||||||
)
|
|
||||||
|
|
||||||
elif isinstance(space, Dict):
|
elif isinstance(space, Dict):
|
||||||
assert isinstance(lhs, OrderedDict)
|
assert isinstance(lhs, OrderedDict)
|
||||||
for key in lhs.keys():
|
for key in lhs.keys():
|
||||||
assert_nested_equal(
|
assert_nested_equal(lhs[key], [rhs_[key] for rhs_ in rhs], space.spaces[key], n)
|
||||||
lhs[key], [rhs_[key] for rhs_ in rhs], space.spaces[key], n
|
|
||||||
)
|
|
||||||
|
|
||||||
elif isinstance(space, _BaseGymSpaces):
|
elif isinstance(space, _BaseGymSpaces):
|
||||||
assert isinstance(lhs, np.ndarray)
|
assert isinstance(lhs, np.ndarray)
|
||||||
@@ -161,9 +147,7 @@ def test_read_from_shared_memory(space):
|
|||||||
memory_view_n8 = read_from_shared_memory(shared_memory_n8, space, n=8)
|
memory_view_n8 = read_from_shared_memory(shared_memory_n8, space, n=8)
|
||||||
samples = [space.sample() for _ in range(8)]
|
samples = [space.sample() for _ in range(8)]
|
||||||
|
|
||||||
processes = [
|
processes = [Process(target=write, args=(i, shared_memory_n8, samples[i])) for i in range(8)]
|
||||||
Process(target=write, args=(i, shared_memory_n8, samples[i])) for i in range(8)
|
|
||||||
]
|
|
||||||
|
|
||||||
for process in processes:
|
for process in processes:
|
||||||
process.start()
|
process.start()
|
||||||
|
@@ -10,12 +10,8 @@ expected_batch_spaces_4 = [
|
|||||||
Box(low=-1.0, high=1.0, shape=(4,), dtype=np.float64),
|
Box(low=-1.0, high=1.0, shape=(4,), dtype=np.float64),
|
||||||
Box(low=0.0, high=10.0, shape=(4, 1), dtype=np.float32),
|
Box(low=0.0, high=10.0, shape=(4, 1), dtype=np.float32),
|
||||||
Box(
|
Box(
|
||||||
low=np.array(
|
low=np.array([[-1.0, 0.0, 0.0], [-1.0, 0.0, 0.0], [-1.0, 0.0, 0.0], [-1.0, 0.0, 0.0]]),
|
||||||
[[-1.0, 0.0, 0.0], [-1.0, 0.0, 0.0], [-1.0, 0.0, 0.0], [-1.0, 0.0, 0.0]]
|
high=np.array([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]),
|
||||||
),
|
|
||||||
high=np.array(
|
|
||||||
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
|
|
||||||
),
|
|
||||||
dtype=np.float32,
|
dtype=np.float32,
|
||||||
),
|
),
|
||||||
Box(
|
Box(
|
||||||
|
@@ -7,12 +7,8 @@ from gym.spaces import Box, Discrete, MultiDiscrete, MultiBinary, Tuple, Dict
|
|||||||
spaces = [
|
spaces = [
|
||||||
Box(low=np.array(-1.0), high=np.array(1.0), dtype=np.float64),
|
Box(low=np.array(-1.0), high=np.array(1.0), dtype=np.float64),
|
||||||
Box(low=np.array([0.0]), high=np.array([10.0]), dtype=np.float32),
|
Box(low=np.array([0.0]), high=np.array([10.0]), dtype=np.float32),
|
||||||
Box(
|
Box(low=np.array([-1.0, 0.0, 0.0]), high=np.array([1.0, 1.0, 1.0]), dtype=np.float32),
|
||||||
low=np.array([-1.0, 0.0, 0.0]), high=np.array([1.0, 1.0, 1.0]), dtype=np.float32
|
Box(low=np.array([[-1.0, 0.0], [0.0, -1.0]]), high=np.ones((2, 2)), dtype=np.float32),
|
||||||
),
|
|
||||||
Box(
|
|
||||||
low=np.array([[-1.0, 0.0], [0.0, -1.0]]), high=np.ones((2, 2)), dtype=np.float32
|
|
||||||
),
|
|
||||||
Box(low=0, high=255, shape=(), dtype=np.uint8),
|
Box(low=0, high=255, shape=(), dtype=np.uint8),
|
||||||
Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8),
|
Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8),
|
||||||
Discrete(2),
|
Discrete(2),
|
||||||
@@ -28,17 +24,13 @@ spaces = [
|
|||||||
Dict(
|
Dict(
|
||||||
{
|
{
|
||||||
"position": Discrete(23),
|
"position": Discrete(23),
|
||||||
"velocity": Box(
|
"velocity": Box(low=np.array([0.0]), high=np.array([1.0]), dtype=np.float32),
|
||||||
low=np.array([0.0]), high=np.array([1.0]), dtype=np.float32
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
Dict(
|
Dict(
|
||||||
{
|
{
|
||||||
"position": Dict({"x": Discrete(29), "y": Discrete(31)}),
|
"position": Dict({"x": Discrete(29), "y": Discrete(31)}),
|
||||||
"velocity": Tuple(
|
"velocity": Tuple((Discrete(37), Box(low=0, high=255, shape=(), dtype=np.uint8))),
|
||||||
(Discrete(37), Box(low=0, high=255, shape=(), dtype=np.uint8))
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
@@ -50,9 +42,7 @@ class UnittestSlowEnv(gym.Env):
|
|||||||
def __init__(self, slow_reset=0.3):
|
def __init__(self, slow_reset=0.3):
|
||||||
super(UnittestSlowEnv, self).__init__()
|
super(UnittestSlowEnv, self).__init__()
|
||||||
self.slow_reset = slow_reset
|
self.slow_reset = slow_reset
|
||||||
self.observation_space = Box(
|
self.observation_space = Box(low=0, high=255, shape=(HEIGHT, WIDTH, 3), dtype=np.uint8)
|
||||||
low=0, high=255, shape=(HEIGHT, WIDTH, 3), dtype=np.uint8
|
|
||||||
)
|
|
||||||
self.action_space = Box(low=0.0, high=1.0, shape=(), dtype=np.float32)
|
self.action_space = Box(low=0.0, high=1.0, shape=(), dtype=np.float32)
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
|
@@ -46,10 +46,7 @@ def concatenate(items, out, space):
|
|||||||
elif isinstance(space, Space):
|
elif isinstance(space, Space):
|
||||||
return concatenate_custom(items, out, space)
|
return concatenate_custom(items, out, space)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError("Space of type `{0}` is not a valid `gym.Space` " "instance.".format(type(space)))
|
||||||
"Space of type `{0}` is not a valid `gym.Space` "
|
|
||||||
"instance.".format(type(space))
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def concatenate_base(items, out, space):
|
def concatenate_base(items, out, space):
|
||||||
@@ -57,18 +54,12 @@ def concatenate_base(items, out, space):
|
|||||||
|
|
||||||
|
|
||||||
def concatenate_tuple(items, out, space):
|
def concatenate_tuple(items, out, space):
|
||||||
return tuple(
|
return tuple(concatenate([item[i] for item in items], out[i], subspace) for (i, subspace) in enumerate(space.spaces))
|
||||||
concatenate([item[i] for item in items], out[i], subspace)
|
|
||||||
for (i, subspace) in enumerate(space.spaces)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def concatenate_dict(items, out, space):
|
def concatenate_dict(items, out, space):
|
||||||
return OrderedDict(
|
return OrderedDict(
|
||||||
[
|
[(key, concatenate([item[key] for item in items], out[key], subspace)) for (key, subspace) in space.spaces.items()]
|
||||||
(key, concatenate([item[key] for item in items], out[key], subspace))
|
|
||||||
for (key, subspace) in space.spaces.items()
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -118,10 +109,7 @@ def create_empty_array(space, n=1, fn=np.zeros):
|
|||||||
elif isinstance(space, Space):
|
elif isinstance(space, Space):
|
||||||
return create_empty_array_custom(space, n=n, fn=fn)
|
return create_empty_array_custom(space, n=n, fn=fn)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError("Space of type `{0}` is not a valid `gym.Space` " "instance.".format(type(space)))
|
||||||
"Space of type `{0}` is not a valid `gym.Space` "
|
|
||||||
"instance.".format(type(space))
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def create_empty_array_base(space, n=1, fn=np.zeros):
|
def create_empty_array_base(space, n=1, fn=np.zeros):
|
||||||
@@ -134,12 +122,7 @@ def create_empty_array_tuple(space, n=1, fn=np.zeros):
|
|||||||
|
|
||||||
|
|
||||||
def create_empty_array_dict(space, n=1, fn=np.zeros):
|
def create_empty_array_dict(space, n=1, fn=np.zeros):
|
||||||
return OrderedDict(
|
return OrderedDict([(key, create_empty_array(subspace, n=n, fn=fn)) for (key, subspace) in space.spaces.items()])
|
||||||
[
|
|
||||||
(key, create_empty_array(subspace, n=n, fn=fn))
|
|
||||||
for (key, subspace) in space.spaces.items()
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def create_empty_array_custom(space, n=1, fn=np.zeros):
|
def create_empty_array_custom(space, n=1, fn=np.zeros):
|
||||||
|
@@ -56,18 +56,11 @@ def create_base_shared_memory(space, n=1, ctx=mp):
|
|||||||
|
|
||||||
|
|
||||||
def create_tuple_shared_memory(space, n=1, ctx=mp):
|
def create_tuple_shared_memory(space, n=1, ctx=mp):
|
||||||
return tuple(
|
return tuple(create_shared_memory(subspace, n=n, ctx=ctx) for subspace in space.spaces)
|
||||||
create_shared_memory(subspace, n=n, ctx=ctx) for subspace in space.spaces
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def create_dict_shared_memory(space, n=1, ctx=mp):
|
def create_dict_shared_memory(space, n=1, ctx=mp):
|
||||||
return OrderedDict(
|
return OrderedDict([(key, create_shared_memory(subspace, n=n, ctx=ctx)) for (key, subspace) in space.spaces.items()])
|
||||||
[
|
|
||||||
(key, create_shared_memory(subspace, n=n, ctx=ctx))
|
|
||||||
for (key, subspace) in space.spaces.items()
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def read_from_shared_memory(shared_memory, space, n=1):
|
def read_from_shared_memory(shared_memory, space, n=1):
|
||||||
@@ -114,24 +107,16 @@ def read_from_shared_memory(shared_memory, space, n=1):
|
|||||||
|
|
||||||
|
|
||||||
def read_base_from_shared_memory(shared_memory, space, n=1):
|
def read_base_from_shared_memory(shared_memory, space, n=1):
|
||||||
return np.frombuffer(shared_memory.get_obj(), dtype=space.dtype).reshape(
|
return np.frombuffer(shared_memory.get_obj(), dtype=space.dtype).reshape((n,) + space.shape)
|
||||||
(n,) + space.shape
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def read_tuple_from_shared_memory(shared_memory, space, n=1):
|
def read_tuple_from_shared_memory(shared_memory, space, n=1):
|
||||||
return tuple(
|
return tuple(read_from_shared_memory(memory, subspace, n=n) for (memory, subspace) in zip(shared_memory, space.spaces))
|
||||||
read_from_shared_memory(memory, subspace, n=n)
|
|
||||||
for (memory, subspace) in zip(shared_memory, space.spaces)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def read_dict_from_shared_memory(shared_memory, space, n=1):
|
def read_dict_from_shared_memory(shared_memory, space, n=1):
|
||||||
return OrderedDict(
|
return OrderedDict(
|
||||||
[
|
[(key, read_from_shared_memory(shared_memory[key], subspace, n=n)) for (key, subspace) in space.spaces.items()]
|
||||||
(key, read_from_shared_memory(shared_memory[key], subspace, n=n))
|
|
||||||
for (key, subspace) in space.spaces.items()
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@@ -44,8 +44,7 @@ def batch_space(space, n=1):
|
|||||||
return batch_space_custom(space, n=n)
|
return batch_space_custom(space, n=n)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Cannot batch space with type `{0}`. The space must "
|
"Cannot batch space with type `{0}`. The space must " "be a valid `gym.Space` instance.".format(type(space))
|
||||||
"be a valid `gym.Space` instance.".format(type(space))
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -75,14 +74,7 @@ def batch_space_tuple(space, n=1):
|
|||||||
|
|
||||||
|
|
||||||
def batch_space_dict(space, n=1):
|
def batch_space_dict(space, n=1):
|
||||||
return Dict(
|
return Dict(OrderedDict([(key, batch_space(subspace, n=n)) for (key, subspace) in space.spaces.items()]))
|
||||||
OrderedDict(
|
|
||||||
[
|
|
||||||
(key, batch_space(subspace, n=n))
|
|
||||||
for (key, subspace) in space.spaces.items()
|
|
||||||
]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def batch_space_custom(space, n=1):
|
def batch_space_custom(space, n=1):
|
||||||
|
@@ -141,9 +141,7 @@ class VectorEnv(gym.Env):
|
|||||||
if self.spec is None:
|
if self.spec is None:
|
||||||
return "{}({})".format(self.__class__.__name__, self.num_envs)
|
return "{}({})".format(self.__class__.__name__, self.num_envs)
|
||||||
else:
|
else:
|
||||||
return "{}({}, {})".format(
|
return "{}({}, {})".format(self.__class__.__name__, self.spec.id, self.num_envs)
|
||||||
self.__class__.__name__, self.spec.id, self.num_envs
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class VectorEnvWrapper(VectorEnv):
|
class VectorEnvWrapper(VectorEnv):
|
||||||
@@ -189,9 +187,7 @@ class VectorEnvWrapper(VectorEnv):
|
|||||||
# implicitly forward all other methods and attributes to self.env
|
# implicitly forward all other methods and attributes to self.env
|
||||||
def __getattr__(self, name):
|
def __getattr__(self, name):
|
||||||
if name.startswith("_"):
|
if name.startswith("_"):
|
||||||
raise AttributeError(
|
raise AttributeError("attempted to get missing private attribute '{}'".format(name))
|
||||||
"attempted to get missing private attribute '{}'".format(name)
|
|
||||||
)
|
|
||||||
return getattr(self.env, name)
|
return getattr(self.env, name)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@@ -62,12 +62,13 @@ class AtariPreprocessing(gym.Wrapper):
|
|||||||
assert noop_max >= 0
|
assert noop_max >= 0
|
||||||
if frame_skip > 1:
|
if frame_skip > 1:
|
||||||
assert "NoFrameskip" in env.spec.id, (
|
assert "NoFrameskip" in env.spec.id, (
|
||||||
"disable frame-skipping in the original env. for more than one"
|
"disable frame-skipping in the original env. for more than one" " frame-skip as it will be done by the wrapper"
|
||||||
" frame-skip as it will be done by the wrapper"
|
|
||||||
)
|
)
|
||||||
self.noop_max = noop_max
|
self.noop_max = noop_max
|
||||||
assert env.unwrapped.get_action_meanings()[0] == "NOOP"
|
assert env.unwrapped.get_action_meanings()[0] == "NOOP"
|
||||||
warnings.warn("Gym\'s internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit")
|
warnings.warn(
|
||||||
|
"Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit"
|
||||||
|
)
|
||||||
|
|
||||||
self.frame_skip = frame_skip
|
self.frame_skip = frame_skip
|
||||||
self.screen_size = screen_size
|
self.screen_size = screen_size
|
||||||
@@ -92,15 +93,11 @@ class AtariPreprocessing(gym.Wrapper):
|
|||||||
self.lives = 0
|
self.lives = 0
|
||||||
self.game_over = False
|
self.game_over = False
|
||||||
|
|
||||||
_low, _high, _obs_dtype = (
|
_low, _high, _obs_dtype = (0, 255, np.uint8) if not scale_obs else (0, 1, np.float32)
|
||||||
(0, 255, np.uint8) if not scale_obs else (0, 1, np.float32)
|
|
||||||
)
|
|
||||||
_shape = (screen_size, screen_size, 1 if grayscale_obs else 3)
|
_shape = (screen_size, screen_size, 1 if grayscale_obs else 3)
|
||||||
if grayscale_obs and not grayscale_newaxis:
|
if grayscale_obs and not grayscale_newaxis:
|
||||||
_shape = _shape[:-1] # Remove channel axis
|
_shape = _shape[:-1] # Remove channel axis
|
||||||
self.observation_space = Box(
|
self.observation_space = Box(low=_low, high=_high, shape=_shape, dtype=_obs_dtype)
|
||||||
low=_low, high=_high, shape=_shape, dtype=_obs_dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
R = 0.0
|
R = 0.0
|
||||||
@@ -132,11 +129,7 @@ class AtariPreprocessing(gym.Wrapper):
|
|||||||
def reset(self, **kwargs):
|
def reset(self, **kwargs):
|
||||||
# NoopReset
|
# NoopReset
|
||||||
self.env.reset(**kwargs)
|
self.env.reset(**kwargs)
|
||||||
noops = (
|
noops = self.env.unwrapped.np_random.randint(1, self.noop_max + 1) if self.noop_max > 0 else 0
|
||||||
self.env.unwrapped.np_random.randint(1, self.noop_max + 1)
|
|
||||||
if self.noop_max > 0
|
|
||||||
else 0
|
|
||||||
)
|
|
||||||
for _ in range(noops):
|
for _ in range(noops):
|
||||||
_, _, done, _ = self.env.step(0)
|
_, _, done, _ = self.env.step(0)
|
||||||
if done:
|
if done:
|
||||||
|
@@ -9,7 +9,9 @@ class ClipAction(ActionWrapper):
|
|||||||
|
|
||||||
def __init__(self, env):
|
def __init__(self, env):
|
||||||
assert isinstance(env.action_space, Box)
|
assert isinstance(env.action_space, Box)
|
||||||
warnings.warn("Gym\'s internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit")
|
warnings.warn(
|
||||||
|
"Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit"
|
||||||
|
)
|
||||||
super(ClipAction, self).__init__(env)
|
super(ClipAction, self).__init__(env)
|
||||||
|
|
||||||
def action(self, action):
|
def action(self, action):
|
||||||
|
@@ -26,7 +26,9 @@ class FilterObservation(ObservationWrapper):
|
|||||||
assert isinstance(
|
assert isinstance(
|
||||||
wrapped_observation_space, spaces.Dict
|
wrapped_observation_space, spaces.Dict
|
||||||
), "FilterObservationWrapper is only usable with dict observations."
|
), "FilterObservationWrapper is only usable with dict observations."
|
||||||
warnings.warn("Gym\'s internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit")
|
warnings.warn(
|
||||||
|
"Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit"
|
||||||
|
)
|
||||||
|
|
||||||
observation_keys = wrapped_observation_space.spaces.keys()
|
observation_keys = wrapped_observation_space.spaces.keys()
|
||||||
|
|
||||||
@@ -49,11 +51,7 @@ class FilterObservation(ObservationWrapper):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.observation_space = type(wrapped_observation_space)(
|
self.observation_space = type(wrapped_observation_space)(
|
||||||
[
|
[(name, copy.deepcopy(space)) for name, space in wrapped_observation_space.spaces.items() if name in filter_keys]
|
||||||
(name, copy.deepcopy(space))
|
|
||||||
for name, space in wrapped_observation_space.spaces.items()
|
|
||||||
if name in filter_keys
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self._env = env
|
self._env = env
|
||||||
@@ -64,11 +62,5 @@ class FilterObservation(ObservationWrapper):
|
|||||||
return filter_observation
|
return filter_observation
|
||||||
|
|
||||||
def _filter_observation(self, observation):
|
def _filter_observation(self, observation):
|
||||||
observation = type(observation)(
|
observation = type(observation)([(name, value) for name, value in observation.items() if name in self._filter_keys])
|
||||||
[
|
|
||||||
(name, value)
|
|
||||||
for name, value in observation.items()
|
|
||||||
if name in self._filter_keys
|
|
||||||
]
|
|
||||||
)
|
|
||||||
return observation
|
return observation
|
||||||
|
@@ -9,7 +9,9 @@ class FlattenObservation(ObservationWrapper):
|
|||||||
def __init__(self, env):
|
def __init__(self, env):
|
||||||
super(FlattenObservation, self).__init__(env)
|
super(FlattenObservation, self).__init__(env)
|
||||||
self.observation_space = spaces.flatten_space(env.observation_space)
|
self.observation_space = spaces.flatten_space(env.observation_space)
|
||||||
warnings.warn("Gym\'s internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit")
|
warnings.warn(
|
||||||
|
"Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit"
|
||||||
|
)
|
||||||
|
|
||||||
def observation(self, observation):
|
def observation(self, observation):
|
||||||
return spaces.flatten(self.env.observation_space, observation)
|
return spaces.flatten(self.env.observation_space, observation)
|
||||||
|
@@ -22,7 +22,9 @@ class LazyFrames(object):
|
|||||||
__slots__ = ("frame_shape", "dtype", "shape", "lz4_compress", "_frames")
|
__slots__ = ("frame_shape", "dtype", "shape", "lz4_compress", "_frames")
|
||||||
|
|
||||||
def __init__(self, frames, lz4_compress=False):
|
def __init__(self, frames, lz4_compress=False):
|
||||||
warnings.warn("Gym\'s internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit")
|
warnings.warn(
|
||||||
|
"Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit"
|
||||||
|
)
|
||||||
self.frame_shape = tuple(frames[0].shape)
|
self.frame_shape = tuple(frames[0].shape)
|
||||||
self.shape = (len(frames),) + self.frame_shape
|
self.shape = (len(frames),) + self.frame_shape
|
||||||
self.dtype = frames[0].dtype
|
self.dtype = frames[0].dtype
|
||||||
@@ -45,9 +47,7 @@ class LazyFrames(object):
|
|||||||
def __getitem__(self, int_or_slice):
|
def __getitem__(self, int_or_slice):
|
||||||
if isinstance(int_or_slice, int):
|
if isinstance(int_or_slice, int):
|
||||||
return self._check_decompress(self._frames[int_or_slice]) # single frame
|
return self._check_decompress(self._frames[int_or_slice]) # single frame
|
||||||
return np.stack(
|
return np.stack([self._check_decompress(f) for f in self._frames[int_or_slice]], axis=0)
|
||||||
[self._check_decompress(f) for f in self._frames[int_or_slice]], axis=0
|
|
||||||
)
|
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
return self.__array__() == other
|
return self.__array__() == other
|
||||||
@@ -56,9 +56,7 @@ class LazyFrames(object):
|
|||||||
if self.lz4_compress:
|
if self.lz4_compress:
|
||||||
from lz4.block import decompress
|
from lz4.block import decompress
|
||||||
|
|
||||||
return np.frombuffer(decompress(frame), dtype=self.dtype).reshape(
|
return np.frombuffer(decompress(frame), dtype=self.dtype).reshape(self.frame_shape)
|
||||||
self.frame_shape
|
|
||||||
)
|
|
||||||
return frame
|
return frame
|
||||||
|
|
||||||
|
|
||||||
@@ -102,12 +100,8 @@ class FrameStack(Wrapper):
|
|||||||
self.frames = deque(maxlen=num_stack)
|
self.frames = deque(maxlen=num_stack)
|
||||||
|
|
||||||
low = np.repeat(self.observation_space.low[np.newaxis, ...], num_stack, axis=0)
|
low = np.repeat(self.observation_space.low[np.newaxis, ...], num_stack, axis=0)
|
||||||
high = np.repeat(
|
high = np.repeat(self.observation_space.high[np.newaxis, ...], num_stack, axis=0)
|
||||||
self.observation_space.high[np.newaxis, ...], num_stack, axis=0
|
self.observation_space = Box(low=low, high=high, dtype=self.observation_space.dtype)
|
||||||
)
|
|
||||||
self.observation_space = Box(
|
|
||||||
low=low, high=high, dtype=self.observation_space.dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_observation(self):
|
def _get_observation(self):
|
||||||
assert len(self.frames) == self.num_stack, (len(self.frames), self.num_stack)
|
assert len(self.frames) == self.num_stack, (len(self.frames), self.num_stack)
|
||||||
|
@@ -11,20 +11,15 @@ class GrayScaleObservation(ObservationWrapper):
|
|||||||
super(GrayScaleObservation, self).__init__(env)
|
super(GrayScaleObservation, self).__init__(env)
|
||||||
self.keep_dim = keep_dim
|
self.keep_dim = keep_dim
|
||||||
|
|
||||||
assert (
|
assert len(env.observation_space.shape) == 3 and env.observation_space.shape[-1] == 3
|
||||||
len(env.observation_space.shape) == 3
|
warnings.warn(
|
||||||
and env.observation_space.shape[-1] == 3
|
"Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit"
|
||||||
)
|
)
|
||||||
warnings.warn("Gym\'s internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit")
|
|
||||||
obs_shape = self.observation_space.shape[:2]
|
obs_shape = self.observation_space.shape[:2]
|
||||||
if self.keep_dim:
|
if self.keep_dim:
|
||||||
self.observation_space = Box(
|
self.observation_space = Box(low=0, high=255, shape=(obs_shape[0], obs_shape[1], 1), dtype=np.uint8)
|
||||||
low=0, high=255, shape=(obs_shape[0], obs_shape[1], 1), dtype=np.uint8
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
self.observation_space = Box(
|
self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype=np.uint8)
|
||||||
low=0, high=255, shape=obs_shape, dtype=np.uint8
|
|
||||||
)
|
|
||||||
|
|
||||||
def observation(self, observation):
|
def observation(self, observation):
|
||||||
import cv2
|
import cv2
|
||||||
|
@@ -37,9 +37,7 @@ class Monitor(Wrapper):
|
|||||||
self._monitor_id = None
|
self._monitor_id = None
|
||||||
self.env_semantics_autoreset = env.metadata.get("semantics.autoreset")
|
self.env_semantics_autoreset = env.metadata.get("semantics.autoreset")
|
||||||
|
|
||||||
self._start(
|
self._start(directory, video_callable, force, resume, write_upon_reset, uid, mode)
|
||||||
directory, video_callable, force, resume, write_upon_reset, uid, mode
|
|
||||||
)
|
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
self._before_step(action)
|
self._before_step(action)
|
||||||
@@ -163,10 +161,7 @@ class Monitor(Wrapper):
|
|||||||
json.dump(
|
json.dump(
|
||||||
{
|
{
|
||||||
"stats": os.path.basename(self.stats_recorder.path),
|
"stats": os.path.basename(self.stats_recorder.path),
|
||||||
"videos": [
|
"videos": [(os.path.basename(v), os.path.basename(m)) for v, m in self.videos],
|
||||||
(os.path.basename(v), os.path.basename(m))
|
|
||||||
for v, m in self.videos
|
|
||||||
],
|
|
||||||
"env_info": self._env_info(),
|
"env_info": self._env_info(),
|
||||||
},
|
},
|
||||||
f,
|
f,
|
||||||
@@ -199,9 +194,7 @@ class Monitor(Wrapper):
|
|||||||
elif mode == "training":
|
elif mode == "training":
|
||||||
type = "t"
|
type = "t"
|
||||||
else:
|
else:
|
||||||
raise error.Error(
|
raise error.Error('Invalid mode {}: must be "training" or "evaluation"', mode)
|
||||||
'Invalid mode {}: must be "training" or "evaluation"', mode
|
|
||||||
)
|
|
||||||
self.stats_recorder.type = type
|
self.stats_recorder.type = type
|
||||||
|
|
||||||
def _before_step(self, action):
|
def _before_step(self, action):
|
||||||
@@ -257,9 +250,7 @@ class Monitor(Wrapper):
|
|||||||
env=self.env,
|
env=self.env,
|
||||||
base_path=os.path.join(
|
base_path=os.path.join(
|
||||||
self.directory,
|
self.directory,
|
||||||
"{}.video.{}.video{:06}".format(
|
"{}.video.{}.video{:06}".format(self.file_prefix, self.file_infix, self.episode_id),
|
||||||
self.file_prefix, self.file_infix, self.episode_id
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
metadata={"episode_id": self.episode_id},
|
metadata={"episode_id": self.episode_id},
|
||||||
enabled=self._video_enabled(),
|
enabled=self._video_enabled(),
|
||||||
@@ -269,9 +260,7 @@ class Monitor(Wrapper):
|
|||||||
def _close_video_recorder(self):
|
def _close_video_recorder(self):
|
||||||
self.video_recorder.close()
|
self.video_recorder.close()
|
||||||
if self.video_recorder.functional:
|
if self.video_recorder.functional:
|
||||||
self.videos.append(
|
self.videos.append((self.video_recorder.path, self.video_recorder.metadata_path))
|
||||||
(self.video_recorder.path, self.video_recorder.metadata_path)
|
|
||||||
)
|
|
||||||
|
|
||||||
def _video_enabled(self):
|
def _video_enabled(self):
|
||||||
return self.video_callable(self.episode_id)
|
return self.video_callable(self.episode_id)
|
||||||
@@ -301,19 +290,11 @@ class Monitor(Wrapper):
|
|||||||
def detect_training_manifests(training_dir, files=None):
|
def detect_training_manifests(training_dir, files=None):
|
||||||
if files is None:
|
if files is None:
|
||||||
files = os.listdir(training_dir)
|
files = os.listdir(training_dir)
|
||||||
return [
|
return [os.path.join(training_dir, f) for f in files if f.startswith(MANIFEST_PREFIX + ".")]
|
||||||
os.path.join(training_dir, f)
|
|
||||||
for f in files
|
|
||||||
if f.startswith(MANIFEST_PREFIX + ".")
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def detect_monitor_files(training_dir):
|
def detect_monitor_files(training_dir):
|
||||||
return [
|
return [os.path.join(training_dir, f) for f in os.listdir(training_dir) if f.startswith(FILE_PREFIX + ".")]
|
||||||
os.path.join(training_dir, f)
|
|
||||||
for f in os.listdir(training_dir)
|
|
||||||
if f.startswith(FILE_PREFIX + ".")
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def clear_monitor_files(training_dir):
|
def clear_monitor_files(training_dir):
|
||||||
@@ -382,10 +363,7 @@ def load_results(training_dir):
|
|||||||
contents = json.load(f)
|
contents = json.load(f)
|
||||||
# Make these paths absolute again
|
# Make these paths absolute again
|
||||||
stats_files.append(os.path.join(training_dir, contents["stats"]))
|
stats_files.append(os.path.join(training_dir, contents["stats"]))
|
||||||
videos += [
|
videos += [(os.path.join(training_dir, v), os.path.join(training_dir, m)) for v, m in contents["videos"]]
|
||||||
(os.path.join(training_dir, v), os.path.join(training_dir, m))
|
|
||||||
for v, m in contents["videos"]
|
|
||||||
]
|
|
||||||
env_infos.append(contents["env_info"])
|
env_infos.append(contents["env_info"])
|
||||||
|
|
||||||
env_info = collapse_env_infos(env_infos, training_dir)
|
env_info = collapse_env_infos(env_infos, training_dir)
|
||||||
|
@@ -48,9 +48,7 @@ class VideoRecorder(object):
|
|||||||
self.ansi_mode = True
|
self.ansi_mode = True
|
||||||
else:
|
else:
|
||||||
logger.info(
|
logger.info(
|
||||||
'Disabling video recorder because {} neither supports video mode "rgb_array" nor "ansi".'.format(
|
'Disabling video recorder because {} neither supports video mode "rgb_array" nor "ansi".'.format(env)
|
||||||
env
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
# Whoops, turns out we shouldn't be enabled after all
|
# Whoops, turns out we shouldn't be enabled after all
|
||||||
self.enabled = False
|
self.enabled = False
|
||||||
@@ -69,9 +67,7 @@ class VideoRecorder(object):
|
|||||||
path = base_path + required_ext
|
path = base_path + required_ext
|
||||||
else:
|
else:
|
||||||
# Otherwise, just generate a unique filename
|
# Otherwise, just generate a unique filename
|
||||||
with tempfile.NamedTemporaryFile(
|
with tempfile.NamedTemporaryFile(suffix=required_ext, delete=False) as f:
|
||||||
suffix=required_ext, delete=False
|
|
||||||
) as f:
|
|
||||||
path = f.name
|
path = f.name
|
||||||
self.path = path
|
self.path = path
|
||||||
|
|
||||||
@@ -83,28 +79,20 @@ class VideoRecorder(object):
|
|||||||
if self.ansi_mode
|
if self.ansi_mode
|
||||||
else ""
|
else ""
|
||||||
)
|
)
|
||||||
raise error.Error(
|
raise error.Error("Invalid path given: {} -- must have file extension {}.{}".format(self.path, required_ext, hint))
|
||||||
"Invalid path given: {} -- must have file extension {}.{}".format(
|
|
||||||
self.path, required_ext, hint
|
|
||||||
)
|
|
||||||
)
|
|
||||||
# Touch the file in any case, so we know it's present. (This
|
# Touch the file in any case, so we know it's present. (This
|
||||||
# corrects for platform platform differences. Using ffmpeg on
|
# corrects for platform platform differences. Using ffmpeg on
|
||||||
# OS X, the file is precreated, but not on Linux.
|
# OS X, the file is precreated, but not on Linux.
|
||||||
touch(path)
|
touch(path)
|
||||||
|
|
||||||
self.frames_per_sec = env.metadata.get("video.frames_per_second", 30)
|
self.frames_per_sec = env.metadata.get("video.frames_per_second", 30)
|
||||||
self.output_frames_per_sec = env.metadata.get(
|
self.output_frames_per_sec = env.metadata.get("video.output_frames_per_second", self.frames_per_sec)
|
||||||
"video.output_frames_per_second", self.frames_per_sec
|
|
||||||
)
|
|
||||||
self.encoder = None # lazily start the process
|
self.encoder = None # lazily start the process
|
||||||
self.broken = False
|
self.broken = False
|
||||||
|
|
||||||
# Dump metadata
|
# Dump metadata
|
||||||
self.metadata = metadata or {}
|
self.metadata = metadata or {}
|
||||||
self.metadata["content_type"] = (
|
self.metadata["content_type"] = "video/vnd.openai.ansivid" if self.ansi_mode else "video/mp4"
|
||||||
"video/vnd.openai.ansivid" if self.ansi_mode else "video/mp4"
|
|
||||||
)
|
|
||||||
self.metadata_path = "{}.meta.json".format(path_base)
|
self.metadata_path = "{}.meta.json".format(path_base)
|
||||||
self.write_metadata()
|
self.write_metadata()
|
||||||
|
|
||||||
@@ -191,9 +179,7 @@ class VideoRecorder(object):
|
|||||||
|
|
||||||
def _encode_image_frame(self, frame):
|
def _encode_image_frame(self, frame):
|
||||||
if not self.encoder:
|
if not self.encoder:
|
||||||
self.encoder = ImageEncoder(
|
self.encoder = ImageEncoder(self.path, frame.shape, self.frames_per_sec, self.output_frames_per_sec)
|
||||||
self.path, frame.shape, self.frames_per_sec, self.output_frames_per_sec
|
|
||||||
)
|
|
||||||
self.metadata["encoder_version"] = self.encoder.version_info
|
self.metadata["encoder_version"] = self.encoder.version_info
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -222,24 +208,16 @@ class TextEncoder(object):
|
|||||||
string = frame.getvalue()
|
string = frame.getvalue()
|
||||||
else:
|
else:
|
||||||
raise error.InvalidFrame(
|
raise error.InvalidFrame(
|
||||||
"Wrong type {} for {}: text frame must be a string or StringIO".format(
|
"Wrong type {} for {}: text frame must be a string or StringIO".format(type(frame), frame)
|
||||||
type(frame), frame
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
frame_bytes = string.encode("utf-8")
|
frame_bytes = string.encode("utf-8")
|
||||||
|
|
||||||
if frame_bytes[-1:] != b"\n":
|
if frame_bytes[-1:] != b"\n":
|
||||||
raise error.InvalidFrame(
|
raise error.InvalidFrame('Frame must end with a newline: """{}"""'.format(string))
|
||||||
'Frame must end with a newline: """{}"""'.format(string)
|
|
||||||
)
|
|
||||||
|
|
||||||
if b"\r" in frame_bytes:
|
if b"\r" in frame_bytes:
|
||||||
raise error.InvalidFrame(
|
raise error.InvalidFrame('Frame contains carriage returns (only newlines are allowed: """{}"""'.format(string))
|
||||||
'Frame contains carriage returns (only newlines are allowed: """{}"""'.format(
|
|
||||||
string
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.frames.append(frame_bytes)
|
self.frames.append(frame_bytes)
|
||||||
|
|
||||||
@@ -263,15 +241,7 @@ class TextEncoder(object):
|
|||||||
# Calculate frame size from the largest frames.
|
# Calculate frame size from the largest frames.
|
||||||
# Add some padding since we'll get cut off otherwise.
|
# Add some padding since we'll get cut off otherwise.
|
||||||
height = max([frame.count(b"\n") for frame in self.frames]) + 1
|
height = max([frame.count(b"\n") for frame in self.frames]) + 1
|
||||||
width = (
|
width = max([max([len(line) for line in frame.split(b"\n")]) for frame in self.frames]) + 2
|
||||||
max(
|
|
||||||
[
|
|
||||||
max([len(line) for line in frame.split(b"\n")])
|
|
||||||
for frame in self.frames
|
|
||||||
]
|
|
||||||
)
|
|
||||||
+ 2
|
|
||||||
)
|
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
"version": 1,
|
"version": 1,
|
||||||
@@ -325,11 +295,7 @@ class ImageEncoder(object):
|
|||||||
def version_info(self):
|
def version_info(self):
|
||||||
return {
|
return {
|
||||||
"backend": self.backend,
|
"backend": self.backend,
|
||||||
"version": str(
|
"version": str(subprocess.check_output([self.backend, "-version"], stderr=subprocess.STDOUT)),
|
||||||
subprocess.check_output(
|
|
||||||
[self.backend, "-version"], stderr=subprocess.STDOUT
|
|
||||||
)
|
|
||||||
),
|
|
||||||
"cmdline": self.cmdline,
|
"cmdline": self.cmdline,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -396,19 +362,13 @@ class ImageEncoder(object):
|
|||||||
|
|
||||||
logger.debug('Starting %s with "%s"', self.backend, " ".join(self.cmdline))
|
logger.debug('Starting %s with "%s"', self.backend, " ".join(self.cmdline))
|
||||||
if hasattr(os, "setsid"): # setsid not present on Windows
|
if hasattr(os, "setsid"): # setsid not present on Windows
|
||||||
self.proc = subprocess.Popen(
|
self.proc = subprocess.Popen(self.cmdline, stdin=subprocess.PIPE, preexec_fn=os.setsid)
|
||||||
self.cmdline, stdin=subprocess.PIPE, preexec_fn=os.setsid
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
self.proc = subprocess.Popen(self.cmdline, stdin=subprocess.PIPE)
|
self.proc = subprocess.Popen(self.cmdline, stdin=subprocess.PIPE)
|
||||||
|
|
||||||
def capture_frame(self, frame):
|
def capture_frame(self, frame):
|
||||||
if not isinstance(frame, (np.ndarray, np.generic)):
|
if not isinstance(frame, (np.ndarray, np.generic)):
|
||||||
raise error.InvalidFrame(
|
raise error.InvalidFrame("Wrong type {} for {} (must be np.ndarray or np.generic)".format(type(frame), frame))
|
||||||
"Wrong type {} for {} (must be np.ndarray or np.generic)".format(
|
|
||||||
type(frame), frame
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if frame.shape != self.frame_shape:
|
if frame.shape != self.frame_shape:
|
||||||
raise error.InvalidFrame(
|
raise error.InvalidFrame(
|
||||||
"Your frame has shape {}, but the VideoRecorder is configured for shape {}.".format(
|
"Your frame has shape {}, but the VideoRecorder is configured for shape {}.".format(
|
||||||
@@ -417,15 +377,11 @@ class ImageEncoder(object):
|
|||||||
)
|
)
|
||||||
if frame.dtype != np.uint8:
|
if frame.dtype != np.uint8:
|
||||||
raise error.InvalidFrame(
|
raise error.InvalidFrame(
|
||||||
"Your frame has data type {}, but we require uint8 (i.e. RGB values from 0-255).".format(
|
"Your frame has data type {}, but we require uint8 (i.e. RGB values from 0-255).".format(frame.dtype)
|
||||||
frame.dtype
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if distutils.version.LooseVersion(
|
if distutils.version.LooseVersion(np.__version__) >= distutils.version.LooseVersion("1.9.0"):
|
||||||
np.__version__
|
|
||||||
) >= distutils.version.LooseVersion("1.9.0"):
|
|
||||||
self.proc.stdin.write(frame.tobytes())
|
self.proc.stdin.write(frame.tobytes())
|
||||||
else:
|
else:
|
||||||
self.proc.stdin.write(frame.tostring())
|
self.proc.stdin.write(frame.tostring())
|
||||||
|
@@ -14,9 +14,7 @@ STATE_KEY = "state"
|
|||||||
class PixelObservationWrapper(ObservationWrapper):
|
class PixelObservationWrapper(ObservationWrapper):
|
||||||
"""Augment observations by pixel values."""
|
"""Augment observations by pixel values."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, env, pixels_only=True, render_kwargs=None, pixel_keys=("pixels",)):
|
||||||
self, env, pixels_only=True, render_kwargs=None, pixel_keys=("pixels",)
|
|
||||||
):
|
|
||||||
"""Initializes a new pixel Wrapper.
|
"""Initializes a new pixel Wrapper.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -52,7 +50,9 @@ class PixelObservationWrapper(ObservationWrapper):
|
|||||||
assert render_mode == "rgb_array", render_mode
|
assert render_mode == "rgb_array", render_mode
|
||||||
render_kwargs[key]["mode"] = "rgb_array"
|
render_kwargs[key]["mode"] = "rgb_array"
|
||||||
|
|
||||||
warnings.warn("Gym\'s internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit")
|
warnings.warn(
|
||||||
|
"Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit"
|
||||||
|
)
|
||||||
|
|
||||||
wrapped_observation_space = env.observation_space
|
wrapped_observation_space = env.observation_space
|
||||||
|
|
||||||
@@ -70,9 +70,7 @@ class PixelObservationWrapper(ObservationWrapper):
|
|||||||
# `observation_keys`
|
# `observation_keys`
|
||||||
overlapping_keys = set(pixel_keys) & set(invalid_keys)
|
overlapping_keys = set(pixel_keys) & set(invalid_keys)
|
||||||
if overlapping_keys:
|
if overlapping_keys:
|
||||||
raise ValueError(
|
raise ValueError("Duplicate or reserved pixel keys {!r}.".format(overlapping_keys))
|
||||||
"Duplicate or reserved pixel keys {!r}.".format(overlapping_keys)
|
|
||||||
)
|
|
||||||
|
|
||||||
if pixels_only:
|
if pixels_only:
|
||||||
self.observation_space = spaces.Dict()
|
self.observation_space = spaces.Dict()
|
||||||
@@ -95,9 +93,7 @@ class PixelObservationWrapper(ObservationWrapper):
|
|||||||
else:
|
else:
|
||||||
raise TypeError(pixels.dtype)
|
raise TypeError(pixels.dtype)
|
||||||
|
|
||||||
pixels_space = spaces.Box(
|
pixels_space = spaces.Box(shape=pixels.shape, low=low, high=high, dtype=pixels.dtype)
|
||||||
shape=pixels.shape, low=low, high=high, dtype=pixels.dtype
|
|
||||||
)
|
|
||||||
pixels_spaces[pixel_key] = pixels_space
|
pixels_spaces[pixel_key] = pixels_space
|
||||||
|
|
||||||
self.observation_space.spaces.update(pixels_spaces)
|
self.observation_space.spaces.update(pixels_spaces)
|
||||||
@@ -120,10 +116,7 @@ class PixelObservationWrapper(ObservationWrapper):
|
|||||||
observation = collections.OrderedDict()
|
observation = collections.OrderedDict()
|
||||||
observation[STATE_KEY] = wrapped_observation
|
observation[STATE_KEY] = wrapped_observation
|
||||||
|
|
||||||
pixel_observations = {
|
pixel_observations = {pixel_key: self.env.render(**self._render_kwargs[pixel_key]) for pixel_key in self._pixel_keys}
|
||||||
pixel_key: self.env.render(**self._render_kwargs[pixel_key])
|
|
||||||
for pixel_key in self._pixel_keys
|
|
||||||
}
|
|
||||||
|
|
||||||
observation.update(pixel_observations)
|
observation.update(pixel_observations)
|
||||||
|
|
||||||
|
@@ -7,14 +7,14 @@ import gym
|
|||||||
class RecordEpisodeStatistics(gym.Wrapper):
|
class RecordEpisodeStatistics(gym.Wrapper):
|
||||||
def __init__(self, env, deque_size=100):
|
def __init__(self, env, deque_size=100):
|
||||||
super(RecordEpisodeStatistics, self).__init__(env)
|
super(RecordEpisodeStatistics, self).__init__(env)
|
||||||
self.t0 = (
|
self.t0 = time.time() # TODO: use perf_counter when gym removes Python 2 support
|
||||||
time.time()
|
|
||||||
) # TODO: use perf_counter when gym removes Python 2 support
|
|
||||||
self.episode_return = 0.0
|
self.episode_return = 0.0
|
||||||
self.episode_length = 0
|
self.episode_length = 0
|
||||||
self.return_queue = deque(maxlen=deque_size)
|
self.return_queue = deque(maxlen=deque_size)
|
||||||
self.length_queue = deque(maxlen=deque_size)
|
self.length_queue = deque(maxlen=deque_size)
|
||||||
warnings.warn("Gym\'s internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit")
|
warnings.warn(
|
||||||
|
"Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit"
|
||||||
|
)
|
||||||
|
|
||||||
def reset(self, **kwargs):
|
def reset(self, **kwargs):
|
||||||
observation = super(RecordEpisodeStatistics, self).reset(**kwargs)
|
observation = super(RecordEpisodeStatistics, self).reset(**kwargs)
|
||||||
@@ -23,9 +23,7 @@ class RecordEpisodeStatistics(gym.Wrapper):
|
|||||||
return observation
|
return observation
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
observation, reward, done, info = super(RecordEpisodeStatistics, self).step(
|
observation, reward, done, info = super(RecordEpisodeStatistics, self).step(action)
|
||||||
action
|
|
||||||
)
|
|
||||||
self.episode_return += reward
|
self.episode_return += reward
|
||||||
self.episode_length += 1
|
self.episode_length += 1
|
||||||
if done:
|
if done:
|
||||||
|
@@ -15,17 +15,15 @@ class RescaleAction(gym.ActionWrapper):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, env, a, b):
|
def __init__(self, env, a, b):
|
||||||
assert isinstance(
|
assert isinstance(env.action_space, spaces.Box), "expected Box action space, got {}".format(type(env.action_space))
|
||||||
env.action_space, spaces.Box
|
|
||||||
), "expected Box action space, got {}".format(type(env.action_space))
|
|
||||||
assert np.less_equal(a, b).all(), (a, b)
|
assert np.less_equal(a, b).all(), (a, b)
|
||||||
warnings.warn("Gym\'s internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit")
|
warnings.warn(
|
||||||
|
"Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit"
|
||||||
|
)
|
||||||
super(RescaleAction, self).__init__(env)
|
super(RescaleAction, self).__init__(env)
|
||||||
self.a = np.zeros(env.action_space.shape, dtype=env.action_space.dtype) + a
|
self.a = np.zeros(env.action_space.shape, dtype=env.action_space.dtype) + a
|
||||||
self.b = np.zeros(env.action_space.shape, dtype=env.action_space.dtype) + b
|
self.b = np.zeros(env.action_space.shape, dtype=env.action_space.dtype) + b
|
||||||
self.action_space = spaces.Box(
|
self.action_space = spaces.Box(low=a, high=b, shape=env.action_space.shape, dtype=env.action_space.dtype)
|
||||||
low=a, high=b, shape=env.action_space.shape, dtype=env.action_space.dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
def action(self, action):
|
def action(self, action):
|
||||||
assert np.all(np.greater_equal(action, self.a)), (action, self.a)
|
assert np.all(np.greater_equal(action, self.a)), (action, self.a)
|
||||||
|
@@ -12,7 +12,9 @@ class ResizeObservation(ObservationWrapper):
|
|||||||
if isinstance(shape, int):
|
if isinstance(shape, int):
|
||||||
shape = (shape, shape)
|
shape = (shape, shape)
|
||||||
assert all(x > 0 for x in shape), shape
|
assert all(x > 0 for x in shape), shape
|
||||||
warnings.warn("Gym\'s internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit")
|
warnings.warn(
|
||||||
|
"Gym's internal preprocessing wrappers are now deprecated. While they will continue to work for the foreseeable future, we strongly recommend using SuperSuit instead: https://github.com/PettingZoo-Team/SuperSuit"
|
||||||
|
)
|
||||||
self.shape = tuple(shape)
|
self.shape = tuple(shape)
|
||||||
|
|
||||||
obs_shape = self.shape + self.observation_space.shape[2:]
|
obs_shape = self.shape + self.observation_space.shape[2:]
|
||||||
@@ -21,9 +23,7 @@ class ResizeObservation(ObservationWrapper):
|
|||||||
def observation(self, observation):
|
def observation(self, observation):
|
||||||
import cv2
|
import cv2
|
||||||
|
|
||||||
observation = cv2.resize(
|
observation = cv2.resize(observation, self.shape[::-1], interpolation=cv2.INTER_AREA)
|
||||||
observation, self.shape[::-1], interpolation=cv2.INTER_AREA
|
|
||||||
)
|
|
||||||
if observation.ndim == 2:
|
if observation.ndim == 2:
|
||||||
observation = np.expand_dims(observation, -1)
|
observation = np.expand_dims(observation, -1)
|
||||||
return observation
|
return observation
|
||||||
|
@@ -15,12 +15,8 @@ def test_atari_preprocessing_grayscale(env_fn):
|
|||||||
import cv2
|
import cv2
|
||||||
|
|
||||||
env1 = env_fn()
|
env1 = env_fn()
|
||||||
env2 = AtariPreprocessing(
|
env2 = AtariPreprocessing(env_fn(), screen_size=84, grayscale_obs=True, frame_skip=1, noop_max=0)
|
||||||
env_fn(), screen_size=84, grayscale_obs=True, frame_skip=1, noop_max=0
|
env3 = AtariPreprocessing(env_fn(), screen_size=84, grayscale_obs=False, frame_skip=1, noop_max=0)
|
||||||
)
|
|
||||||
env3 = AtariPreprocessing(
|
|
||||||
env_fn(), screen_size=84, grayscale_obs=False, frame_skip=1, noop_max=0
|
|
||||||
)
|
|
||||||
env4 = AtariPreprocessing(
|
env4 = AtariPreprocessing(
|
||||||
env_fn(),
|
env_fn(),
|
||||||
screen_size=84,
|
screen_size=84,
|
||||||
@@ -79,15 +75,11 @@ def test_atari_preprocessing_scale(env_fn):
|
|||||||
obs = env.reset().flatten()
|
obs = env.reset().flatten()
|
||||||
done, step_i = False, 0
|
done, step_i = False, 0
|
||||||
max_obs = 1 if scaled else 255
|
max_obs = 1 if scaled else 255
|
||||||
assert (0 <= obs).all() and (
|
assert (0 <= obs).all() and (obs <= max_obs).all(), "Obs. must be in range [0,{}]".format(max_obs)
|
||||||
obs <= max_obs
|
|
||||||
).all(), "Obs. must be in range [0,{}]".format(max_obs)
|
|
||||||
while not done or step_i <= max_test_steps:
|
while not done or step_i <= max_test_steps:
|
||||||
obs, _, done, _ = env.step(env.action_space.sample())
|
obs, _, done, _ = env.step(env.action_space.sample())
|
||||||
obs = obs.flatten()
|
obs = obs.flatten()
|
||||||
assert (0 <= obs).all() and (
|
assert (0 <= obs).all() and (obs <= max_obs).all(), "Obs. must be in range [0,{}]".format(max_obs)
|
||||||
obs <= max_obs
|
|
||||||
).all(), "Obs. must be in range [0,{}]".format(max_obs)
|
|
||||||
step_i += 1
|
step_i += 1
|
||||||
|
|
||||||
env.close()
|
env.close()
|
||||||
|
@@ -19,9 +19,7 @@ def test_clip_action():
|
|||||||
|
|
||||||
actions = [[0.4], [1.2], [-0.3], [0.0], [-2.5]]
|
actions = [[0.4], [1.2], [-0.3], [0.0], [-2.5]]
|
||||||
for action in actions:
|
for action in actions:
|
||||||
obs1, r1, d1, _ = env.step(
|
obs1, r1, d1, _ = env.step(np.clip(action, env.action_space.low, env.action_space.high))
|
||||||
np.clip(action, env.action_space.low, env.action_space.high)
|
|
||||||
)
|
|
||||||
obs2, r2, d2, _ = wrapped_env.step(action)
|
obs2, r2, d2, _ = wrapped_env.step(action)
|
||||||
assert np.allclose(r1, r2)
|
assert np.allclose(r1, r2)
|
||||||
assert np.allclose(obs1, obs2)
|
assert np.allclose(obs1, obs2)
|
||||||
|
@@ -9,10 +9,7 @@ from gym.wrappers.filter_observation import FilterObservation
|
|||||||
class FakeEnvironment(gym.Env):
|
class FakeEnvironment(gym.Env):
|
||||||
def __init__(self, observation_keys=("state")):
|
def __init__(self, observation_keys=("state")):
|
||||||
self.observation_space = spaces.Dict(
|
self.observation_space = spaces.Dict(
|
||||||
{
|
{name: spaces.Box(shape=(2,), low=-1, high=1, dtype=np.float32) for name in observation_keys}
|
||||||
name: spaces.Box(shape=(2,), low=-1, high=1, dtype=np.float32)
|
|
||||||
for name in observation_keys
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
self.action_space = spaces.Box(shape=(1,), low=-1, high=1, dtype=np.float32)
|
self.action_space = spaces.Box(shape=(1,), low=-1, high=1, dtype=np.float32)
|
||||||
|
|
||||||
@@ -48,9 +45,7 @@ ERROR_TEST_CASES = (
|
|||||||
|
|
||||||
|
|
||||||
class TestFilterObservation(object):
|
class TestFilterObservation(object):
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize("observation_keys,filter_keys", FILTER_OBSERVATION_TEST_CASES)
|
||||||
"observation_keys,filter_keys", FILTER_OBSERVATION_TEST_CASES
|
|
||||||
)
|
|
||||||
def test_filter_observation(self, observation_keys, filter_keys):
|
def test_filter_observation(self, observation_keys, filter_keys):
|
||||||
env = FakeEnvironment(observation_keys=observation_keys)
|
env = FakeEnvironment(observation_keys=observation_keys)
|
||||||
|
|
||||||
@@ -73,9 +68,7 @@ class TestFilterObservation(object):
|
|||||||
assert len(observation) == len(filter_keys)
|
assert len(observation) == len(filter_keys)
|
||||||
|
|
||||||
@pytest.mark.parametrize("filter_keys,error_type,error_match", ERROR_TEST_CASES)
|
@pytest.mark.parametrize("filter_keys,error_type,error_match", ERROR_TEST_CASES)
|
||||||
def test_raises_with_incorrect_arguments(
|
def test_raises_with_incorrect_arguments(self, filter_keys, error_type, error_match):
|
||||||
self, filter_keys, error_type, error_match
|
|
||||||
):
|
|
||||||
env = FakeEnvironment(observation_keys=("key1", "key2"))
|
env = FakeEnvironment(observation_keys=("key1", "key2"))
|
||||||
|
|
||||||
ValueError
|
ValueError
|
||||||
|
@@ -16,14 +16,10 @@ def test_flatten_observation(env_id):
|
|||||||
wrapped_obs = wrapped_env.reset()
|
wrapped_obs = wrapped_env.reset()
|
||||||
|
|
||||||
if env_id == "Blackjack-v0":
|
if env_id == "Blackjack-v0":
|
||||||
space = spaces.Tuple(
|
space = spaces.Tuple((spaces.Discrete(32), spaces.Discrete(11), spaces.Discrete(2)))
|
||||||
(spaces.Discrete(32), spaces.Discrete(11), spaces.Discrete(2))
|
|
||||||
)
|
|
||||||
wrapped_space = spaces.Box(-np.inf, np.inf, [32 + 11 + 2], dtype=np.float32)
|
wrapped_space = spaces.Box(-np.inf, np.inf, [32 + 11 + 2], dtype=np.float32)
|
||||||
elif env_id == "KellyCoinflip-v0":
|
elif env_id == "KellyCoinflip-v0":
|
||||||
space = spaces.Tuple(
|
space = spaces.Tuple((spaces.Box(0, 250.0, [1], dtype=np.float32), spaces.Discrete(300 + 1)))
|
||||||
(spaces.Box(0, 250.0, [1], dtype=np.float32), spaces.Discrete(300 + 1))
|
|
||||||
)
|
|
||||||
wrapped_space = spaces.Box(-np.inf, np.inf, [1 + (300 + 1)], dtype=np.float32)
|
wrapped_space = spaces.Box(-np.inf, np.inf, [1 + (300 + 1)], dtype=np.float32)
|
||||||
|
|
||||||
assert space.contains(obs)
|
assert space.contains(obs)
|
||||||
|
@@ -19,9 +19,7 @@ except ImportError:
|
|||||||
[
|
[
|
||||||
pytest.param(
|
pytest.param(
|
||||||
True,
|
True,
|
||||||
marks=pytest.mark.skipif(
|
marks=pytest.mark.skipif(lz4 is None, reason="Need lz4 to run tests with compression"),
|
||||||
lz4 is None, reason="Need lz4 to run tests with compression"
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
False,
|
False,
|
||||||
],
|
],
|
||||||
|
@@ -10,9 +10,7 @@ pytest.importorskip("atari_py")
|
|||||||
pytest.importorskip("cv2")
|
pytest.importorskip("cv2")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize("env_id", ["PongNoFrameskip-v0", "SpaceInvadersNoFrameskip-v0"])
|
||||||
"env_id", ["PongNoFrameskip-v0", "SpaceInvadersNoFrameskip-v0"]
|
|
||||||
)
|
|
||||||
@pytest.mark.parametrize("keep_dim", [True, False])
|
@pytest.mark.parametrize("keep_dim", [True, False])
|
||||||
def test_gray_scale_observation(env_id, keep_dim):
|
def test_gray_scale_observation(env_id, keep_dim):
|
||||||
gray_env = AtariPreprocessing(gym.make(env_id), screen_size=84, grayscale_obs=True)
|
gray_env = AtariPreprocessing(gym.make(env_id), screen_size=84, grayscale_obs=True)
|
||||||
|
@@ -32,9 +32,7 @@ class FakeEnvironment(gym.Env):
|
|||||||
|
|
||||||
class FakeArrayObservationEnvironment(FakeEnvironment):
|
class FakeArrayObservationEnvironment(FakeEnvironment):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
self.observation_space = spaces.Box(
|
self.observation_space = spaces.Box(shape=(2,), low=-1, high=1, dtype=np.float32)
|
||||||
shape=(2,), low=-1, high=1, dtype=np.float32
|
|
||||||
)
|
|
||||||
super(FakeArrayObservationEnvironment, self).__init__(*args, **kwargs)
|
super(FakeArrayObservationEnvironment, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
@@ -75,10 +73,7 @@ class TestPixelObservationWrapper(object):
|
|||||||
assert len(wrapped_env.observation_space.spaces) == 1
|
assert len(wrapped_env.observation_space.spaces) == 1
|
||||||
assert list(wrapped_env.observation_space.spaces.keys()) == [pixel_key]
|
assert list(wrapped_env.observation_space.spaces.keys()) == [pixel_key]
|
||||||
else:
|
else:
|
||||||
assert (
|
assert len(wrapped_env.observation_space.spaces) == len(observation_space.spaces) + 1
|
||||||
len(wrapped_env.observation_space.spaces)
|
|
||||||
== len(observation_space.spaces) + 1
|
|
||||||
)
|
|
||||||
expected_keys = list(observation_space.spaces.keys()) + [pixel_key]
|
expected_keys = list(observation_space.spaces.keys()) + [pixel_key]
|
||||||
assert list(wrapped_env.observation_space.spaces.keys()) == expected_keys
|
assert list(wrapped_env.observation_space.spaces.keys()) == expected_keys
|
||||||
|
|
||||||
@@ -97,9 +92,7 @@ class TestPixelObservationWrapper(object):
|
|||||||
observation_space = env.observation_space
|
observation_space = env.observation_space
|
||||||
assert isinstance(observation_space, spaces.Box)
|
assert isinstance(observation_space, spaces.Box)
|
||||||
|
|
||||||
wrapped_env = PixelObservationWrapper(
|
wrapped_env = PixelObservationWrapper(env, pixel_keys=(pixel_key,), pixels_only=pixels_only)
|
||||||
env, pixel_keys=(pixel_key,), pixels_only=pixels_only
|
|
||||||
)
|
|
||||||
wrapped_env.observation_space = wrapped_env.observation_space
|
wrapped_env.observation_space = wrapped_env.observation_space
|
||||||
assert isinstance(wrapped_env.observation_space, spaces.Dict)
|
assert isinstance(wrapped_env.observation_space, spaces.Dict)
|
||||||
|
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user