redo black

This commit is contained in:
Justin Terry
2021-07-29 12:42:48 -04:00
parent d5004b7ec1
commit e9d2c41f2b
109 changed files with 459 additions and 1363 deletions

View File

@@ -62,16 +62,12 @@ class MemorizeDigits(gym.Env):
def __init__(self):
self.seed()
self.viewer = None
self.observation_space = spaces.Box(
0, 255, (FIELD_H, FIELD_W, 3), dtype=np.uint8
)
self.observation_space = spaces.Box(0, 255, (FIELD_H, FIELD_W, 3), dtype=np.uint8)
self.action_space = spaces.Discrete(10)
self.bogus_mnist = np.zeros((10, 6, 6), dtype=np.uint8)
for digit in range(10):
for y in range(6):
self.bogus_mnist[digit, y, :] = [
ord(char) for char in bogus_mnist[digit][y]
]
self.bogus_mnist[digit, y, :] = [ord(char) for char in bogus_mnist[digit][y]]
self.reset()
def seed(self, seed=None):
@@ -93,9 +89,7 @@ class MemorizeDigits(gym.Env):
self.color_bg = self.random_color() if self.use_random_colors else color_black
self.step_n = 0
while 1:
self.color_digit = (
self.random_color() if self.use_random_colors else color_white
)
self.color_digit = self.random_color() if self.use_random_colors else color_white
if np.linalg.norm(self.color_digit - self.color_bg) < 50:
continue
break
@@ -119,9 +113,7 @@ class MemorizeDigits(gym.Env):
digit_img[:] = self.color_bg
xxx = self.bogus_mnist[self.digit] == 42
digit_img[xxx] = self.color_digit
obs[
self.digit_y - 3 : self.digit_y + 3, self.digit_x - 3 : self.digit_x + 3
] = digit_img
obs[self.digit_y - 3 : self.digit_y + 3, self.digit_x - 3 : self.digit_x + 3] = digit_img
self.last_obs = obs
return obs, reward, done, {}