Mirror of https://github.com/Farama-Foundation/Gymnasium.git, synced 2025-08-27 16:57:10 +00:00
PEP-8 Fixes in algorithmic environment (#1382)
Remove trailing whitespace. Make line breaks adhere to the 80-character limit (not all, but quite a few). Remove unused imports. Other miscellaneous PEP-8 fixes.
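As a sanity check (not part of this commit), the touched module can be run through pycodestyle, which reports the relevant codes here: W291 for trailing whitespace and E501 for over-long lines; unused imports would need pyflakes or flake8 instead. A rough sketch, assuming pycodestyle is installed and the usual gym layout for the (placeholder) file path:

import pycodestyle

# Placeholder path; substitute whichever module was touched.
style = pycodestyle.StyleGuide(max_line_length=79)
report = style.check_files(["gym/envs/algorithmic/algorithmic_env.py"])
print(report.total_errors)  # 0 means no remaining W291/E501-style complaints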
@@ -26,7 +26,7 @@ Reward schedule:
     otherwise: 0

 In the beginning, input strings will be fairly short. After an environment has
 been consistently solved over some window of episodes, the environment will
 increase the average length of generated strings. Typical env specs require
 leveling up many times to reach their reward threshold.
 """
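The docstring above describes the length curriculum. A rough sketch of observing it from the outside, assuming the old gym package (4-tuple step API) with the Copy-v0 registration from that era:

import gym

env = gym.make('Copy-v0')
for episode in range(5):
    obs = env.reset()
    done = False
    while not done:
        obs, reward, done, info = env.step(env.action_space.sample())
    # min_length is the curriculum knob that _check_levelup() promotes.
    print(episode, env.unwrapped.min_length)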
@@ -42,16 +42,16 @@ from six import StringIO
 class AlgorithmicEnv(Env):

     metadata = {'render.modes': ['human', 'ansi']}
     # Only 'promote' the length of generated input strings if the worst of the
     # last n episodes was no more than this far from the maximum reward
     MIN_REWARD_SHORTFALL_FOR_PROMOTION = -1.0

     def __init__(self, base=10, chars=False, starting_min_length=2):
         """
         base: Number of distinct characters.
         chars: If True, use uppercase alphabet. Otherwise, digits. Only affects
             rendering.
         starting_min_length: Minimum input string length. Ramps up as episodes
             are consistently solved.
         """
         self.base = base
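To make the constructor arguments concrete, a minimal hypothetical subclass (mirroring CopyEnv later in this diff); it assumes target_from_input_data is the only hook TapeAlgorithmicEnv requires:

from gym.envs.algorithmic import algorithmic_env


class EchoEnv(algorithmic_env.TapeAlgorithmicEnv):
    # Hypothetical example; the target tape is just the input itself.
    def target_from_input_data(self, input_data):
        return input_data


# 4 distinct symbols, rendered as letters, starting from 3-character inputs.
env = EchoEnv(base=4, chars=True, starting_min_length=3)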
@@ -59,15 +59,15 @@ class AlgorithmicEnv(Env):
         self.last = 10
         # Cumulative reward earned this episode
         self.episode_total_reward = None
-        # Running tally of reward shortfalls. e.g. if there were 10 points to earn and
-        # we got 8, we'd append -2
+        # Running tally of reward shortfalls. e.g. if there were 10 points to
+        # earn and we got 8, we'd append -2
         AlgorithmicEnv.reward_shortfalls = []
         if chars:
             self.charmap = [chr(ord('A')+i) for i in range(base)]
         else:
             self.charmap = [str(i) for i in range(base)]
         self.charmap.append(' ')
         # TODO: Not clear why this is a class variable rather than instance.
         # Could lead to some spooky action at a distance if someone is working
         # with multiple algorithmic envs at once. Also makes testing tricky.
         AlgorithmicEnv.min_length = starting_min_length
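A standalone rerun of the charmap logic above with base=5, just to show what gets rendered:

base, chars = 5, True
if chars:
    charmap = [chr(ord('A') + i) for i in range(base)]  # ['A', 'B', 'C', 'D', 'E']
else:
    charmap = [str(i) for i in range(base)]             # ['0', '1', '2', '3', '4']
charmap.append(' ')  # observation value `base` (nothing under the head) renders blank
print(charmap)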
@@ -78,7 +78,8 @@ class AlgorithmicEnv(Env):
         self.action_space = Tuple(
             [Discrete(len(self.MOVEMENTS)), Discrete(2), Discrete(self.base)]
         )
-        # Can see just what is on the input tape (one of n characters, or nothing)
+        # Can see just what is on the input tape (one of n characters, or
+        # nothing)
         self.observation_space = Discrete(self.base + 1)
         self.seed()
         self.reset()
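The action space above composes a movement choice, a write flag, and a predicted symbol. A standalone sketch with placeholder values (MOVEMENTS here stands in for the per-env movement list):

from gym.spaces import Discrete, Tuple

base = 5
MOVEMENTS = ['left', 'right']  # tape envs; grid envs add 'up'/'down'
action_space = Tuple(
    [Discrete(len(MOVEMENTS)), Discrete(2), Discrete(base)]
)
observation_space = Discrete(base + 1)  # one of `base` symbols, or blank
move, should_write, prediction = action_space.sample()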
@@ -170,10 +171,11 @@ class AlgorithmicEnv(Env):
         try:
             correct = pred == self.target[self.write_head_position]
         except IndexError:
-            logger.warn("It looks like you're calling step() even though this "+
-                "environment has already returned done=True. You should always call "+
-                "reset() once you receive done=True. Any further steps are undefined "+
-                "behaviour.")
+            logger.warn(
+                "It looks like you're calling step() even though this "
+                "environment has already returned done=True. You should "
+                "always call reset() once you receive done=True. Any "
+                "further steps are undefined behaviour.")
             correct = False
         if correct:
             reward = 1.0
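The reworded warning is about stepping a finished episode. A minimal loop that follows its advice, again assuming the old gym 4-tuple step API and the Copy-v0 registration:

import gym

env = gym.make('Copy-v0')
obs = env.reset()
while True:
    obs, reward, done, info = env.step(env.action_space.sample())
    if done:
        obs = env.reset()  # never keep calling step() once done=True is returned
        break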
@@ -201,7 +203,7 @@ class AlgorithmicEnv(Env):
         return self.input_width + len(self.target) + 4

     def _check_levelup(self):
         """Called between episodes. Update our running record of episode rewards
         and, if appropriate, 'level up' minimum input length."""
         if self.episode_total_reward is None:
             # This is before the first episode/call to reset(). Nothing to do
@@ -209,11 +211,10 @@ class AlgorithmicEnv(Env):
         AlgorithmicEnv.reward_shortfalls.append(self.episode_total_reward - len(self.target))
         AlgorithmicEnv.reward_shortfalls = AlgorithmicEnv.reward_shortfalls[-self.last:]
         if len(AlgorithmicEnv.reward_shortfalls) == self.last and \
           min(AlgorithmicEnv.reward_shortfalls) >= self.MIN_REWARD_SHORTFALL_FOR_PROMOTION and \
           AlgorithmicEnv.min_length < 30:
             AlgorithmicEnv.min_length += 1
             AlgorithmicEnv.reward_shortfalls = []


     def reset(self):
         self._check_levelup()
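A worked example of the promotion test above, with last=10 and the default MIN_REWARD_SHORTFALL_FOR_PROMOTION of -1.0 (shortfall values are made up for illustration):

last, min_length = 10, 4
# One entry per episode: total reward earned minus len(target).
reward_shortfalls = [0.0, -1.0, 0.0, -0.5, 0.0, 0.0, -1.0, 0.0, 0.0, -0.5]
if (len(reward_shortfalls) == last
        and min(reward_shortfalls) >= -1.0
        and min_length < 30):
    min_length += 1
    reward_shortfalls = []
print(min_length)  # 5: the worst of the last 10 episodes fell short by at most 1 point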
@@ -258,13 +259,13 @@ class TapeAlgorithmicEnv(AlgorithmicEnv):
             return self.input_data[pos]
         except IndexError:
             return self.base

     def generate_input_data(self, size):
         return [self.np_random.randint(self.base) for _ in range(size)]

     def render_observation(self):
         x = self.read_head_position
         x_str = "Observation Tape : "
         for i in range(-2, self.input_width + 2):
             if i == x:
                 x_str += colorize(self._get_str_obs(np.array([i])), 'green', highlight=True)
@@ -278,6 +279,7 @@ class GridAlgorithmicEnv(AlgorithmicEnv):
     """An algorithmic env with a 2-d input grid."""
     MOVEMENTS = ['left', 'right', 'up', 'down']
     READ_HEAD_START = (0, 0)

     def __init__(self, rows, *args, **kwargs):
         self.rows = rows
         AlgorithmicEnv.__init__(self, *args, **kwargs)
@@ -316,7 +318,7 @@ class GridAlgorithmicEnv(AlgorithmicEnv):

     def render_observation(self):
         x = self.read_head_position
         label = "Observation Grid : "
         x_str = ""
         for j in range(-1, self.rows+1):
             if j != -1:
@@ -4,10 +4,10 @@ the output tape. http://arxiv.org/abs/1511.07275
 """
 from gym.envs.algorithmic import algorithmic_env


 class CopyEnv(algorithmic_env.TapeAlgorithmicEnv):
     def __init__(self, base=5, chars=True):
         super(CopyEnv, self).__init__(base=base, chars=chars)

     def target_from_input_data(self, input_data):
         return input_data
@@ -5,6 +5,7 @@ http://arxiv.org/abs/1511.07275
 from __future__ import division
 from gym.envs.algorithmic import algorithmic_env


 class DuplicatedInputEnv(algorithmic_env.TapeAlgorithmicEnv):
     def __init__(self, duplication=2, base=5):
         self.duplication = duplication
@@ -4,12 +4,13 @@ the output tape. http://arxiv.org/abs/1511.07275
 """
 from gym.envs.algorithmic import algorithmic_env


 class RepeatCopyEnv(algorithmic_env.TapeAlgorithmicEnv):
     MIN_REWARD_SHORTFALL_FOR_PROMOTION = -.1

     def __init__(self, base=5):
         super(RepeatCopyEnv, self).__init__(base=base, chars=True)
         self.last = 50

     def target_from_input_data(self, input_data):
         return input_data + list(reversed(input_data)) + input_data
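A quick check of RepeatCopyEnv.target_from_input_data on a toy input (values made up):

input_data = [0, 1, 2]
target = input_data + list(reversed(input_data)) + input_data
print(target)  # [0, 1, 2, 2, 1, 0, 0, 1, 2] -- forward, backward, forward again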
@@ -2,11 +2,12 @@
 Task is to reverse content over the input tape.
 http://arxiv.org/abs/1511.07275
 """

 from gym.envs.algorithmic import algorithmic_env


 class ReverseEnv(algorithmic_env.TapeAlgorithmicEnv):
     MIN_REWARD_SHORTFALL_FOR_PROMOTION = -.1

     def __init__(self, base=2):
         super(ReverseEnv, self).__init__(base=base, chars=True, starting_min_length=1)
         self.last = 50
@@ -1,7 +1,7 @@
 from __future__ import division
-import numpy as np
 from gym.envs.algorithmic import algorithmic_env


 class ReversedAdditionEnv(algorithmic_env.GridAlgorithmicEnv):
     def __init__(self, rows=2, base=3):
         super(ReversedAdditionEnv, self).__init__(rows=rows, base=base, chars=False)