redo black

This commit is contained in:
Justin Terry
2021-07-29 12:42:48 -04:00
parent d5004b7ec1
commit e9d2c41f2b
109 changed files with 459 additions and 1363 deletions

View File

@@ -73,9 +73,7 @@ class AlgorithmicEnv(Env):
# 1. Move read head left or right (or up/down)
# 2. Write or not
# 3. Which character to write. (Ignored if should_write=0)
self.action_space = Tuple(
[Discrete(len(self.MOVEMENTS)), Discrete(2), Discrete(self.base)]
)
self.action_space = Tuple([Discrete(len(self.MOVEMENTS)), Discrete(2), Discrete(self.base)])
# Can see just what is on the input tape (one of n characters, or
# nothing)
self.observation_space = Discrete(self.base + 1)
@@ -147,10 +145,7 @@ class AlgorithmicEnv(Env):
move = self.MOVEMENTS[inp_act]
outfile.write("Action : Tuple(move over input: %s,\n" % move)
out_act = out_act == 1
outfile.write(
" write to the output tape: %s,\n"
% out_act
)
outfile.write(" write to the output tape: %s,\n" % out_act)
outfile.write(" prediction: %s)\n" % pred_str)
else:
outfile.write("\n" * 5)
@@ -276,9 +271,7 @@ class TapeAlgorithmicEnv(AlgorithmicEnv):
x_str = "Observation Tape : "
for i in range(-2, self.input_width + 2):
if i == x:
x_str += colorize(
self._get_str_obs(np.array([i])), "green", highlight=True
)
x_str += colorize(self._get_str_obs(np.array([i])), "green", highlight=True)
else:
x_str += self._get_str_obs(np.array([i]))
x_str += "\n"
@@ -311,10 +304,7 @@ class GridAlgorithmicEnv(AlgorithmicEnv):
self.read_head_position = x, y
def generate_input_data(self, size):
return [
[self.np_random.randint(self.base) for _ in range(self.rows)]
for __ in range(size)
]
return [[self.np_random.randint(self.base) for _ in range(self.rows)] for __ in range(size)]
def _get_obs(self, pos=None):
if pos is None:
@@ -336,9 +326,7 @@ class GridAlgorithmicEnv(AlgorithmicEnv):
x_str += " " * len(label)
for i in range(-2, self.input_width + 2):
if i == x[0] and j == x[1]:
x_str += colorize(
self._get_str_obs((i, j)), "green", highlight=True
)
x_str += colorize(self._get_str_obs((i, j)), "green", highlight=True)
else:
x_str += self._get_str_obs((i, j))
x_str += "\n"