mirror of https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-23 15:04:20 +00:00
Algorithmic refactor (#383)
* Refactor/document algorithmic environments and add tests.
* test for 3 row addition
* Fix failing rollout test by reinserting quirk in reversedAddition env
* todo regarding addition3-v0
* Fix python 3 division issues
* typo fix
* Re-generate python3 rollout file to account for ReversedAddition bug fix
@@ -1,3 +1,35 @@
"""
Algorithmic environments have the following traits in common:

- A 1-d "input tape" or 2-d "input grid" of characters
- A target string which is a deterministic function of the input characters

Agents control a read head that moves over the input tape. Observations consist
of the single character currently under the read head. The read head may fall
off the end of the tape in any direction. When this happens, agents will observe
a special blank character (with index=env.base) until they get back in bounds.

Actions consist of 3 sub-actions:
- Direction to move the read head (left or right, plus up and down for 2-d envs)
- Whether to write to the output tape
- Which character to write (ignored if the above sub-action is 0)

An episode ends when:
- The agent writes the full target string to the output tape.
- The agent writes an incorrect character.
- The agent runs out the time limit. (Which is fairly conservative.)

Reward schedule:
    write a correct character: +1
    write a wrong character: -.5
    run out the clock: -1
    otherwise: 0

In the beginning, input strings will be fairly short. After an environment has
been consistently solved over some window of episodes, the environment will
increase the average length of generated strings. Typical env specs require
leveling up many times to reach their reward threshold.
"""
from gym import Env
from gym.spaces import Discrete, Tuple
from gym.utils import colorize, seeding
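The docstring above pins down the interface: a Tuple action of (movement, write flag, predicted character), an integer observation in [0, base], and a +1 / -0.5 / -1 reward schedule. A minimal random-agent sketch against that interface, assuming an id like 'Copy-v0' is registered and the classic gym reset()/step() API used in this file:

import gym

env = gym.make('Copy-v0')          # assumed registration id; the algorithmic envs all behave alike
obs = env.reset()                  # an int in [0, env.base]; env.base means "blank"
done, total = False, 0.0
while not done:
    move, write, pred = env.action_space.sample()        # the three sub-actions
    obs, reward, done, _ = env.step((move, write, pred))
    total += reward                # +1 correct write, -0.5 wrong write, -1 timeout, else 0
print(total)
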
@@ -5,93 +37,82 @@ import numpy as np
from six import StringIO
import sys
import math
import logging

hash_base = None
def ha(array):
    return (hash_base * (array + 5)).sum()
logger = logging.getLogger(__name__)

class AlgorithmicEnv(Env):

    metadata = {'render.modes': ['human', 'ansi']}
    # Only 'promote' the length of generated input strings if the worst of the
    # last n episodes was no more than this far from the maximum reward
    MIN_REWARD_SHORTFALL_FOR_PROMOTION = -1.0

    def __init__(self, inp_dim=1, base=10, chars=False):
        global hash_base

        hash_base = 50 ** np.arange(inp_dim)
    def __init__(self, base=10, chars=False, starting_min_length=2):
        """
        base: Number of distinct characters.
        chars: If True, use uppercase alphabet. Otherwise, digits. Only affects
            rendering.
        starting_min_length: Minimum input string length. Ramps up as episodes
            are consistently solved.
        """
        self.base = base
        # Keep track of this many past episodes
        self.last = 10
        self.total_reward = 0
        self.sum_reward = 0
        AlgorithmicEnv.sum_rewards = []
        self.chars = chars
        self.inp_dim = inp_dim
        AlgorithmicEnv.current_length = 2
        tape_control = []

        self.action_space = Tuple(([Discrete(2 * self.inp_dim), Discrete(2), Discrete(self.base)]))
        # Cumulative reward earned this episode
        self.episode_total_reward = None
        # Running tally of reward shortfalls. e.g. if there were 10 points to earn and
        # we got 8, we'd append -2
        AlgorithmicEnv.reward_shortfalls = []
        if chars:
            self.charmap = [chr(ord('A')+i) for i in range(base)]
        else:
            self.charmap = [str(i) for i in range(base)]
        self.charmap.append(' ')
        # TODO: Not clear why this is a class variable rather than instance.
        # Could lead to some spooky action at a distance if someone is working
        # with multiple algorithmic envs at once. Also makes testing tricky.
        AlgorithmicEnv.min_length = starting_min_length
        # Three sub-actions:
        # 1. Move read head left or right (or up/down)
        # 2. Write or not
        # 3. Which character to write. (Ignored if should_write=0)
        self.action_space = Tuple(
            [Discrete(len(self.MOVEMENTS)), Discrete(2), Discrete(self.base)]
        )
        # Can see just what is on the input tape (one of n characters, or nothing)
        self.observation_space = Discrete(self.base + 1)

        self._seed()
        self.reset()

    @classmethod
    def _movement_idx(kls, movement_name):
        return kls.MOVEMENTS.index(movement_name)

    def _seed(self, seed=None):
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def _get_obs(self, pos=None):
        if pos is None:
            pos = self.x
        assert isinstance(pos, np.ndarray) and pos.shape[0] == self.inp_dim
        if ha(pos) not in self.content:
            self.content[ha(pos)] = self.base
        return self.content[ha(pos)]
        """Return an observation corresponding to the given read head position
        (or the current read head position, if none is given)."""
        raise NotImplemented

    def _get_str_obs(self, pos=None):
        ret = self._get_obs(pos)
        if ret == self.base:
            return " "
        else:
            if self.chars:
                return chr(ret + ord('A'))
            return str(ret)
        return self.charmap[ret]

    def _get_str_target(self, pos=None):
        if pos not in self.target:
    def _get_str_target(self, pos):
        """Return the ith character of the target string (or " " if index
        out of bounds)."""
        if pos < 0 or len(self.target) <= pos:
            return " "
        else:
            ret = self.target[pos]
            if self.chars:
                return chr(ret + ord('A'))
            return str(ret)
            return self.charmap[self.target[pos]]

    def _render_observation(self):
        x = self.x
        if self.inp_dim == 1:
            x_str = "Observation Tape : "
            for i in range(-2, self.total_len + 2):
                if i == x:
                    x_str += colorize(self._get_str_obs(np.array([i])), 'green', highlight=True)
                else:
                    x_str += self._get_str_obs(np.array([i]))
            x_str += "\n"
            return x_str
        elif self.inp_dim == 2:
            label = "Observation Grid : "
            x_str = ""
            for j in range(-1, 3):
                if j != -1:
                    x_str += " " * len(label)
                for i in range(-2, self.total_len + 2):
                    if i == x[0] and j == x[1]:
                        x_str += colorize(self._get_str_obs(np.array([i, j])), 'green', highlight=True)
                    else:
                        x_str += self._get_str_obs(np.array([i, j]))
                x_str += "\n"
            x_str = label + x_str
            return x_str
        else:
            assert False

        """Return a string representation of the input tape/grid."""
        raise NotImplemented

    def _render(self, mode='human', close=False):
        if close:
@@ -99,34 +120,25 @@ class AlgorithmicEnv(Env):
            return

        outfile = StringIO() if mode == 'ansi' else sys.stdout
        inp = "Total length of input instance: %d, step: %d\n" % (self.total_len, self.time)
        inp = "Total length of input instance: %d, step: %d\n" % (self.input_width, self.time)
        outfile.write(inp)
        x, y, action = self.x, self.y, self.last_action
        x, y, action = self.read_head_position, self.write_head_position, self.last_action
        if action is not None:
            inp_act, out_act, pred = action
        outfile.write("=" * (len(inp) - 1) + "\n")
        y_str = "Output Tape : "
        target_str = "Targets : "
        target_str = "Targets : "
        if action is not None:
            if self.chars:
                pred_str = chr(pred + ord('A'))
            else:
                pred_str = str(pred)
            pred_str = self.charmap[pred]
        x_str = self._render_observation()
        max_len = int(self.total_reward) + 1
        for i in range(-2, max_len):
            if i not in self.target:
                y_str += " "
                continue
        for i in range(-2, len(self.target) + 2):
            target_str += self._get_str_target(i)
            if i < y - 1:
                y_str += self._get_str_target(i)
            elif i == (y - 1):
                if action is not None and out_act == 1:
                    if pred == self.target[i]:
                        y_str += colorize(pred_str, 'green', highlight=True)
                    else:
                        y_str += colorize(pred_str, 'red', highlight=True)
                    color = 'green' if pred == self.target[i] else 'red'
                    y_str += colorize(pred_str, color, highlight=True)
                else:
                    y_str += self._get_str_target(i)
        outfile.write(x_str)
@@ -134,77 +146,185 @@ class AlgorithmicEnv(Env):
        outfile.write(target_str + "\n\n")

        if action is not None:
            outfile.write("Current reward : %.3f\n" % self.reward)
            outfile.write("Cumulative reward : %.3f\n" % self.sum_reward)
            move = ""
            if inp_act == 0:
                move = "left"
            elif inp_act == 1:
                move = "right"
            elif inp_act == 2:
                move += "up"
            elif inp_act == 3:
                move += "down"
            outfile.write("Current reward : %.3f\n" % self.last_reward)
            outfile.write("Cumulative reward : %.3f\n" % self.episode_total_reward)
            move = self.MOVEMENTS[inp_act]
            outfile.write("Action : Tuple(move over input: %s,\n" % move)
            if out_act == 1:
                out_act = "True"
            else:
                out_act = "False"
            out_act = out_act == 1
            outfile.write(" write to the output tape: %s,\n" % out_act)
            outfile.write(" prediction: %s)\n" % pred_str)
        else:
            outfile.write("\n" * 5)
        return outfile

    @property
    def input_width(self):
        return len(self.input_data)

    def _step(self, action):
        assert self.action_space.contains(action)
        self.last_action = action
        inp_act, out_act, pred = action
        done = False
        reward = 0.0
        # We are outside the sample.
        self.time += 1
        if self.y not in self.target:
            reward = -10.0
            done = True
        else:
            if out_act == 1:
                if pred == self.target[self.y]:
                    reward = 1.0
                else:
                    reward = -0.5
                    done = True
                self.y += 1
                if self.y not in self.target:
                    done = True
            if inp_act == 0:
                self.x[0] -= 1
            elif inp_act == 1:
                self.x[0] += 1
            elif inp_act == 2:
                self.x[1] -= 1
            elif inp_act == 3:
                self.x[1] += 1
            if self.time > self.total_len + self.total_reward + 4:
                reward = -1.0
        assert 0 <= self.write_head_position
        if out_act == 1:
            try:
                correct = pred == self.target[self.write_head_position]
            except IndexError:
                logger.warn("It looks like you're calling step() even though this "+
                    "environment has already returned done=True. You should always call "+
                    "reset() once you receive done=True. Any further steps are undefined "+
                    "behaviour.")
                correct = False
            if correct:
                reward = 1.0
            else:
                # Bail as soon as a wrong character is written to the tape
                reward = -0.5
                done = True
            self.write_head_position += 1
        if self.write_head_position >= len(self.target):
            done = True
        self._move(inp_act)
        if self.time > self.time_limit:
            reward = -1.0
            done = True
        obs = self._get_obs()
        self.reward = reward
        self.sum_reward += reward
        self.last_reward = reward
        self.episode_total_reward += reward
        return (obs, reward, done, {})

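    # Example (values assumed for illustration): an action consumed by _step() is a
    # 3-tuple. On a tape env, where MOVEMENTS == ['left', 'right']:
    #     action = (0, 1, 3)   # move left, write to the output tape, predict character 3
    #     action = (1, 0, 0)   # move right without writing; the prediction is ignored
    # The movement index can also be looked up by name via _movement_idx('right').
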
    @property
    def time_limit(self):
        """If an agent takes more than this many timesteps, end the episode
        immediately and return a negative reward."""
        # (Seemingly arbitrary)
        return self.input_width + len(self.target) + 4

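    # Worked instance of the formula above (sizes assumed): an input tape of width 6
    # with a target of length 6 gives 6 + 6 + 4 = 16 steps before the -1 timeout fires.
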
    def _check_levelup(self):
        """Called between episodes. Update our running record of episode rewards
        and, if appropriate, 'level up' minimum input length."""
        if self.episode_total_reward is None:
            # This is before the first episode/call to reset(). Nothing to do
            return
        AlgorithmicEnv.reward_shortfalls.append(self.episode_total_reward - len(self.target))
        AlgorithmicEnv.reward_shortfalls = AlgorithmicEnv.reward_shortfalls[-self.last:]
        if len(AlgorithmicEnv.reward_shortfalls) == self.last and \
                min(AlgorithmicEnv.reward_shortfalls) >= self.MIN_REWARD_SHORTFALL_FOR_PROMOTION and \
                AlgorithmicEnv.min_length < 30:
            AlgorithmicEnv.min_length += 1
            AlgorithmicEnv.reward_shortfalls = []

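    # Worked instance of the promotion rule (numbers assumed): with a target of
    # length 10, writing 9 correct characters and then a wrong one scores
    # 9 - 0.5 = 8.5, so the appended shortfall is 8.5 - 10 = -1.5, which is below
    # MIN_REWARD_SHORTFALL_FOR_PROMOTION (-1.0) and blocks promotion. Only when the
    # worst of the last `self.last` (10) episodes has a shortfall >= -1.0 does
    # min_length grow by one, capped at 30.
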
    def _reset(self):
        self._check_levelup()
        self.last_action = None
        self.x = np.zeros(self.inp_dim).astype(np.int)
        self.y = 0
        AlgorithmicEnv.sum_rewards.append(self.sum_reward - self.total_reward)
        AlgorithmicEnv.sum_rewards = AlgorithmicEnv.sum_rewards[-self.last:]
        if len(AlgorithmicEnv.sum_rewards) == self.last and \
                min(AlgorithmicEnv.sum_rewards) >= -1.0 and \
                AlgorithmicEnv.current_length < 30:
            AlgorithmicEnv.current_length += 1
            AlgorithmicEnv.sum_rewards = []
        self.sum_reward = 0.0
        self.last_reward = 0
        self.read_head_position = self.READ_HEAD_START
        self.write_head_position = 0
        self.episode_total_reward = 0.0
        self.time = 0
        self.total_len = self.np_random.randint(3) + AlgorithmicEnv.current_length
        self.set_data()
        length = self.np_random.randint(3) + AlgorithmicEnv.min_length
        self.input_data = self.generate_input_data(length)
        self.target = self.target_from_input_data(self.input_data)
        return self._get_obs()

    def generate_input_data(self, size):
        raise NotImplemented

    def target_from_input_data(self, input_data):
        raise NotImplemented("Subclasses must implement")

    def _move(self, movement):
        raise NotImplemented

class TapeAlgorithmicEnv(AlgorithmicEnv):
    """An algorithmic env with a 1-d input tape."""
    MOVEMENTS = ['left', 'right']
    READ_HEAD_START = 0

    def _move(self, movement):
        named = self.MOVEMENTS[movement]
        self.read_head_position += 1 if named == 'right' else -1

    def _get_obs(self, pos=None):
        if pos is None:
            pos = self.read_head_position
        if pos < 0:
            return self.base
        try:
            return self.input_data[pos]
        except IndexError:
            return self.base

    def generate_input_data(self, size):
        return [self.np_random.randint(self.base) for _ in range(size)]

    def _render_observation(self):
        x = self.read_head_position
        x_str = "Observation Tape : "
        for i in range(-2, self.input_width + 2):
            if i == x:
                x_str += colorize(self._get_str_obs(np.array([i])), 'green', highlight=True)
            else:
                x_str += self._get_str_obs(np.array([i]))
        x_str += "\n"
        return x_str

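Concrete tasks only need to fill in the target_from_input_data() hook (and, if the default isn't suitable, generate_input_data()); the base classes supply movement, observations, rendering, rewards, and leveling. A minimal sketch of a copy-style tape task, with the class name EchoEnv chosen purely for illustration:

class EchoEnv(TapeAlgorithmicEnv):
    """Hypothetical task: reproduce the input tape verbatim on the output tape."""
    def __init__(self, base=5, **kwargs):
        TapeAlgorithmicEnv.__init__(self, base=base, **kwargs)

    def target_from_input_data(self, input_data):
        # The target is the input itself, so each correctly copied character
        # earns +1 under the schedule implemented in _step().
        return list(input_data)
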
class GridAlgorithmicEnv(AlgorithmicEnv):
    """An algorithmic env with a 2-d input grid."""
    MOVEMENTS = ['left', 'right', 'up', 'down']
    READ_HEAD_START = (0, 0)
    def __init__(self, rows, *args, **kwargs):
        self.rows = rows
        AlgorithmicEnv.__init__(self, *args, **kwargs)

    def _move(self, movement):
        named = self.MOVEMENTS[movement]
        x, y = self.read_head_position
        if named == 'left':
            x -= 1
        elif named == 'right':
            x += 1
        elif named == 'up':
            y -= 1
        elif named == 'down':
            y += 1
        else:
            raise ValueError("Unrecognized direction: {}".format(named))
        self.read_head_position = x, y

    def generate_input_data(self, size):
        return [
            [self.np_random.randint(self.base) for _ in range(self.rows)]
            for __ in range(size)
        ]

    def _get_obs(self, pos=None):
        if pos is None:
            pos = self.read_head_position
        x, y = pos
        if any(idx < 0 for idx in pos):
            return self.base
        try:
            return self.input_data[x][y]
        except IndexError:
            return self.base

    def _render_observation(self):
        x = self.read_head_position
        label = "Observation Grid : "
        x_str = ""
        for j in range(-1, self.rows+1):
            if j != -1:
                x_str += " " * len(label)
            for i in range(-2, self.input_width + 2):
                if i == x[0] and j == x[1]:
                    x_str += colorize(self._get_str_obs((i, j)), 'green', highlight=True)
                else:
                    x_str += self._get_str_obs((i, j))
            x_str += "\n"
        x_str = label + x_str
        return x_str

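For the grid case, generate_input_data() builds a list of columns, each of length self.rows, and _get_obs() reads it as input_data[x][y] with x the column under the read head and y the row. A hypothetical 2-d task in the same mould (name and target function invented for illustration):

class ColumnSumEnv(GridAlgorithmicEnv):
    """Hypothetical task: emit the sum, modulo base, of each input column."""
    def __init__(self, rows=2, base=3, **kwargs):
        GridAlgorithmicEnv.__init__(self, rows=rows, base=base, **kwargs)

    def target_from_input_data(self, input_data):
        # One target character per column of the input grid.
        return [sum(column) % self.base for column in input_data]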