import numpy as np
import gym
from gym import spaces
from gym.utils import seeding

# Unit test environment for CNNs.
# Looks like this (RGB observations):
#
#  ----------------------------
# |                            |
# |          ********          |
# |          ********          |
# |        **        **        |
# |        **        **        |
# |                  **        |
# |                  **        |
# |            ******          |
# |            ******          |
# |        ****                |
# |        ****                |
# |        ************        |
# |        ************        |
# |                            |
#  ----------------------------
#
# Each step shows one digit glyph. The agent is rewarded for answering with the
# digit it currently sees (action 2 for the picture above). An action is judged
# against the digit in the observation that preceded it, so this catches
# off-by-one errors in your agent.
#
# To see how it works, run:
#
# python examples/agents/keyboard_agent.py MemorizeDigits-v0

FIELD_W = 32
FIELD_H = 24

# Side length, in pixels, of every glyph in ``bogus_mnist``.
DIGIT_SIZE = 6

# Hand-drawn glyphs for the digits 0-9; "*" marks a foreground pixel.
# Every glyph must be exactly DIGIT_SIZE rows of DIGIT_SIZE characters.
bogus_mnist = [
    [" **** ", "*    *", "*    *", "*    *", "*    *", " **** "],
    ["  **  ", " * *  ", "   *  ", "   *  ", "   *  ", "  *** "],
    [" **** ", "*    *", "     *", "  *** ", "**    ", "******"],
    [" **** ", "*    *", "   ** ", "     *", "*    *", " **** "],
    [" *  * ", " *  * ", " *  * ", " **** ", "    * ", "    * "],
    [" **** ", " *    ", " **** ", "    * ", "    * ", " **** "],
    ["  *** ", " *    ", " **** ", " *  * ", " *  * ", " **** "],
    [" **** ", "    * ", "   *  ", "   *  ", "  *   ", "  *   "],
    [" **** ", "*    *", " **** ", "*    *", "*    *", " **** "],
    [" **** ", "*    *", "*    *", " *****", "     *", " **** "],
]

color_black = np.array((0, 0, 0)).astype("float32")
color_white = np.array((255, 255, 255)).astype("float32")


class MemorizeDigits(gym.Env):
    """Show one digit glyph per step; reward the agent for naming it.

    Observations are ``(FIELD_H, FIELD_W, 3)`` uint8 RGB images containing a
    single glyph from ``bogus_mnist``; actions are the digits 0-9.
    """

    metadata = {
        "render.modes": ["human", "rgb_array"],
        "video.frames_per_second": 60,
        "video.res_w": FIELD_W,
        "video.res_h": FIELD_H,
    }

    # When True, reset() draws random background and digit colors instead of
    # white-on-black.
    use_random_colors = False

    def __init__(self):
        self.seed()
        self.viewer = None
        self.observation_space = spaces.Box(
            0, 255, (FIELD_H, FIELD_W, 3), dtype=np.uint8
        )
        self.action_space = spaces.Discrete(10)

        # Fail early with a clear message if a glyph has the wrong size,
        # instead of an opaque broadcasting error.
        for digit, glyph in enumerate(bogus_mnist):
            rows_ok = len(glyph) == DIGIT_SIZE and all(
                len(row) == DIGIT_SIZE for row in glyph
            )
            if not rows_ok:
                raise ValueError(
                    f"bogus_mnist[{digit}] is not a {DIGIT_SIZE}x{DIGIT_SIZE} glyph"
                )
        # Character codes of every glyph, shape (10, DIGIT_SIZE, DIGIT_SIZE);
        # step() compares them against ord("*") to get the foreground mask.
        self.bogus_mnist = np.array(
            [[[ord(char) for char in row] for row in glyph] for glyph in bogus_mnist],
            dtype=np.uint8,
        )
        self.reset()

    def seed(self, seed=None):
        """Seed the environment's RNG and return ``[seed]``."""
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def random_color(self):
        """Return a random RGB color as a length-3 uint8 array.

        Each channel is uniform over 0-255 inclusive (``high`` is exclusive,
        hence 256).
        """
        return self.np_random.randint(low=0, high=256, size=3).astype("uint8")

    def reset(self):
        """Start a new episode and return its first observation.

        Picks a glyph position away from the borders and the episode colors,
        then runs one unscored step to draw the first digit.
        """
        self.digit_x = self.np_random.randint(low=FIELD_W // 5, high=FIELD_W // 5 * 4)
        self.digit_y = self.np_random.randint(low=FIELD_H // 5, high=FIELD_H // 5 * 4)
        self.color_bg = self.random_color() if self.use_random_colors else color_black
        self.step_n = 0
        # Re-draw the digit color until it stands out from the background.
        # Compare in float: subtracting uint8 colors would wrap around.
        while True:
            self.color_digit = (
                self.random_color() if self.use_random_colors else color_white
            )
            contrast = np.linalg.norm(
                np.asarray(self.color_digit, dtype=np.float32)
                - np.asarray(self.color_bg, dtype=np.float32)
            )
            if contrast >= 50:
                break
        # -1 means "no digit shown yet": the bootstrap step below is not scored.
        self.digit = -1
        return self.step(0)[0]

    def step(self, action):
        """Score ``action`` against the digit currently shown, then show a new one.

        Returns ``(obs, reward, done, info)``. ``reward`` is +1 when ``action``
        equals the displayed digit and -1 otherwise (also for the unscored
        bootstrap step issued by ``reset``). From the 21st step on, each scored
        step ends the episode with probability 1/5.
        """
        self.step_n += 1
        reward = -1
        done = False
        if self.digit != -1:
            if action == self.digit:
                reward = +1
            done = self.step_n > 20 and self.np_random.randint(low=0, high=5) == 0
        self.digit = self.np_random.randint(low=0, high=10)

        obs = np.zeros((FIELD_H, FIELD_W, 3), dtype=np.uint8)
        obs[:, :, :] = self.color_bg
        digit_img = np.zeros((DIGIT_SIZE, DIGIT_SIZE, 3), dtype=np.uint8)
        digit_img[:] = self.color_bg
        foreground = self.bogus_mnist[self.digit] == ord("*")
        digit_img[foreground] = self.color_digit
        # The glyph is centred on (digit_x, digit_y).
        top = self.digit_y - DIGIT_SIZE // 2
        left = self.digit_x - DIGIT_SIZE // 2
        obs[top : top + DIGIT_SIZE, left : left + DIGIT_SIZE] = digit_img
        self.last_obs = obs
        return obs, reward, done, {}

    def render(self, mode="human"):
        """Render the latest observation.

        ``"rgb_array"`` returns the observation array; ``"human"`` shows it in
        a window and returns whether that window is still open.
        """
        if mode == "rgb_array":
            return self.last_obs
        elif mode == "human":
            from gym.envs.classic_control import rendering

            if self.viewer is None:
                self.viewer = rendering.SimpleImageViewer()
            self.viewer.imshow(self.last_obs)
            return self.viewer.isopen
        else:
            assert 0, f"Render mode '{mode}' is not supported"

    def close(self):
        """Close the render window, if one was opened."""
        if self.viewer is not None:
            self.viewer.close()
            self.viewer = None