mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-06 16:01:45 +00:00
214 lines
6.0 KiB
Python
214 lines
6.0 KiB
Python
![]() |
from dataclasses import dataclass, field
|
||
|
from typing import Callable, Optional, Tuple
|
||
|
|
||
|
import numpy as np
|
||
|
import pygame
|
||
|
import pytest
|
||
|
from pygame import KEYDOWN, KEYUP, QUIT, event
|
||
|
from pygame.event import Event
|
||
|
|
||
|
import gym
|
||
|
from gym.utils.play import MissingKeysToAction, PlayableGame, play
|
||
|
|
||
|
# Key code of "a"; dummy_keys_to_action() maps the 1-tuple of this key to action 0.
RELEVANT_KEY_1 = ord("a") # 97
|
||
|
# Key code of "d"; dummy_keys_to_action() maps the 1-tuple of this key to action 1.
RELEVANT_KEY_2 = ord("d") # 100
|
||
|
# Key code that does not appear in the mapping returned by dummy_keys_to_action().
IRRELEVANT_KEY = 1
|
||
|
|
||
|
|
||
|
@dataclass
class DummyEnvSpec:
    """Stand-in environment spec that carries nothing but an ``id``."""

    # Identifier string; the only field of this stub.
    id: str
|
||
|
|
||
|
|
||
|
class DummyPlayEnv(gym.Env):
    """Trivial environment stub used by the play tests.

    Every step yields a 1x1 zero observation with reward ``1`` and
    ``done=False``; rendering yields a 1x1 zero array.
    """

    def step(self, action):
        """Ignore ``action`` and return a fixed ``(obs, rew, done, info)`` tuple."""
        return np.zeros((1, 1)), 1, False, {}

    def reset(self, seed=None):
        """Do nothing; ``seed`` is accepted but unused and ``None`` is returned."""
        return None

    def render(self, mode="rgb_array"):
        """Return a 1x1 zero array regardless of ``mode``."""
        frame = np.zeros((1, 1))
        return frame
|
||
|
|
||
|
|
||
|
class PlayStatus:
    """Wrap a data callback and record what passed through it.

    Attributes:
        data_callback: The wrapped callable; it is invoked with the six
            transition values and must return a 6-item sequence.
        cumulative_reward: Running sum of the rewards returned by
            ``data_callback`` (starts at ``0``).
        last_observation: The next-observation value returned by the most
            recent ``data_callback`` call (``None`` before any call).
    """

    def __init__(self, callback: Callable):
        self.data_callback = callback
        self.cumulative_reward = 0
        self.last_observation = None

    def callback(self, obs_t, obs_tp1, action, rew, done, info):
        """Forward the transition to ``data_callback`` and update the tallies."""
        forwarded = self.data_callback(obs_t, obs_tp1, action, rew, done, info)
        # Only the second (next observation) and fourth (reward) items are used.
        _, next_obs, _, reward, _, _ = forwarded
        self.cumulative_reward += reward
        self.last_observation = next_obs
|
||
|
|
||
|
|
||
|
def dummy_keys_to_action():
    """Return a fresh key-tuple -> action mapping for the two relevant keys."""
    mapping = {}
    mapping[(RELEVANT_KEY_1,)] = 0
    mapping[(RELEVANT_KEY_2,)] = 1
    return mapping
|
||
|
|
||
|
|
||
|
@pytest.fixture(autouse=True)
def close_pygame():
    """Autouse fixture: run the test, then call ``pygame.quit()``."""
    # Everything before the yield is setup (none here); the test body runs at
    # the yield and the teardown below follows it.
    yield
    pygame.quit()
|
||
|
|
||
|
|
||
|
def test_play_relevant_keys():
    """An explicit mapping makes exactly its keys the game's relevant keys."""
    game = PlayableGame(DummyPlayEnv(), dummy_keys_to_action())
    expected_keys = {RELEVANT_KEY_1, RELEVANT_KEY_2}
    assert game.relevant_keys == expected_keys
|
||
|
|
||
|
|
||
|
def test_play_relevant_keys_no_mapping():
    """Constructing a game without any key mapping raises MissingKeysToAction.

    The env is given a spec but no ``keys_to_action`` argument is passed to
    ``PlayableGame``.
    """
    env = DummyPlayEnv()
    env.spec = DummyEnvSpec("DummyPlayEnv")

    # The exception info object was never inspected, so it is not bound.
    with pytest.raises(MissingKeysToAction):
        PlayableGame(env)
|
||
|
|
||
|
|
||
|
def test_play_relevant_keys_with_env_attribute():
    """Env has a ``get_keys_to_action`` attribute and no mapping is passed."""
    env = DummyPlayEnv()
    env.get_keys_to_action = dummy_keys_to_action
    game = PlayableGame(env)
    # Use the module constants (ord("a") == 97, ord("d") == 100) rather than
    # bare numbers so this stays in sync with dummy_keys_to_action().
    assert game.relevant_keys == {RELEVANT_KEY_1, RELEVANT_KEY_2}
|
||
|
|
||
|
|
||
|
def test_video_size_no_zoom():
    """Without zoom, ``video_size`` equals the rendered frame's shape as a list."""
    dummy_env = DummyPlayEnv()
    game = PlayableGame(dummy_env, dummy_keys_to_action())
    frame_shape = dummy_env.render().shape
    assert game.video_size == list(frame_shape)
|
||
|
|
||
|
|
||
|
def test_video_size_zoom():
    """With zoom, ``video_size`` is each frame dimension scaled and truncated to int."""
    dummy_env = DummyPlayEnv()
    zoom = 2.2
    game = PlayableGame(dummy_env, dummy_keys_to_action(), zoom)
    frame_shape = dummy_env.render().shape
    expected_size = tuple(int(dim * zoom) for dim in frame_shape)
    assert game.video_size == expected_size
|
||
|
|
||
|
|
||
|
def test_keyboard_quit_event():
    """An Escape KEYDOWN event switches ``game.running`` from true to false."""
    env = DummyPlayEnv()
    game = PlayableGame(env, dummy_keys_to_action())
    # Not named ``event``: that would shadow the module-level ``pygame.event``
    # import inside this function.
    escape_keydown = Event(pygame.KEYDOWN, {"key": pygame.K_ESCAPE})
    assert game.running
    game.process_event(escape_keydown)
    assert not game.running
|
||
|
|
||
|
|
||
|
def test_pygame_quit_event():
    """A pygame QUIT event switches ``game.running`` from true to false."""
    env = DummyPlayEnv()
    game = PlayableGame(env, dummy_keys_to_action())
    # Not named ``event``: that would shadow the module-level ``pygame.event``
    # import inside this function.
    quit_event = Event(pygame.QUIT)
    assert game.running
    game.process_event(quit_event)
    assert not game.running
|
||
|
|
||
|
|
||
|
def test_keyboard_relevant_keydown_event():
    """A KEYDOWN for a mapped key is recorded in ``pressed_keys``."""
    game = PlayableGame(DummyPlayEnv(), dummy_keys_to_action())
    keydown = Event(pygame.KEYDOWN, {"key": RELEVANT_KEY_1})
    game.process_event(keydown)
    assert game.pressed_keys == [RELEVANT_KEY_1]
|
||
|
|
||
|
|
||
|
def test_keyboard_irrelevant_keydown_event():
    """A KEYDOWN for an unmapped key leaves ``pressed_keys`` empty."""
    game = PlayableGame(DummyPlayEnv(), dummy_keys_to_action())
    keydown = Event(pygame.KEYDOWN, {"key": IRRELEVANT_KEY})
    game.process_event(keydown)
    assert game.pressed_keys == []
|
||
|
|
||
|
|
||
|
def test_keyboard_keyup_event():
    """Pressing and then releasing a mapped key leaves ``pressed_keys`` empty."""
    game = PlayableGame(DummyPlayEnv(), dummy_keys_to_action())
    # Feed the press first, then the release of the same key.
    for event_type in (pygame.KEYDOWN, pygame.KEYUP):
        game.process_event(Event(event_type, {"key": RELEVANT_KEY_1}))
    assert game.pressed_keys == []
|
||
|
|
||
|
|
||
|
def test_play_loop():
    """Reward accumulated through ``play()`` matches stepping the env directly.

    A callback injects one queued pygame event per call; the queue ends with
    QUIT. The same number of steps is taken on a separate ``DummyPlayEnv`` to
    compute the expected cumulative reward.
    """
    # set of key events to inject into the play loop as callback
    callback_events = [
        Event(KEYDOWN, {"key": RELEVANT_KEY_1}),
        Event(KEYDOWN, {"key": RELEVANT_KEY_1}),
        Event(QUIT),
    ]
    # Captured before play() runs, because the callback consumes the list.
    expected_steps = len(callback_events)

    def callback(obs_t, obs_tp1, action, rew, done, info):
        # Post the next queued event, then pass the transition through unchanged.
        event.post(callback_events.pop(0))
        return obs_t, obs_tp1, action, rew, done, info

    env = DummyPlayEnv()
    cumulative_env_reward = 0
    # we run the same number of steps executed with play()
    for _ in range(expected_steps):
        _, rew, _, _ = env.step(None)
        cumulative_env_reward += rew

    env_play = DummyPlayEnv()
    status = PlayStatus(callback)
    play(env_play, callback=status.callback, keys_to_action=dummy_keys_to_action())

    assert status.cumulative_reward == cumulative_env_reward
|
||
|
|
||
|
|
||
|
def test_play_loop_real_env():
    """Final observation from ``play()`` on CartPole matches a manual rollout.

    Key events are injected through the callback; the same action sequence is
    applied to a second, identically seeded env and the last observations are
    compared element-wise.
    """
    seed = 42
    env_id = "CartPole-v1"

    # set of key events to inject into the play loop as callback
    press_release_order = (
        RELEVANT_KEY_1,
        RELEVANT_KEY_2,
        RELEVANT_KEY_1,
        RELEVANT_KEY_1,
        RELEVANT_KEY_2,
    )
    callback_events = []
    for key in press_release_order:
        callback_events.append(Event(KEYDOWN, {"key": key}))
        callback_events.append(Event(KEYUP, {"key": key}))
    callback_events.append(Event(QUIT))

    keydown_events = [queued for queued in callback_events if queued.type == KEYDOWN]

    def callback(obs_t, obs_tp1, action, rew, done, info):
        # Post at least one event; after releasing a key, keep posting new
        # events until one that is not a KEYUP has been posted.
        while True:
            pygame_event = callback_events.pop(0)
            event.post(pygame_event)
            if pygame_event.type != KEYUP:
                break

        return obs_t, obs_tp1, action, rew, done, info

    reference_env = gym.make(env_id)
    reference_env.reset(seed=seed)
    keys_to_action = dummy_keys_to_action()

    # first action is 0 because at the first iteration
    # we can not inject a callback event into play()
    reference_env.step(0)
    for keydown in keydown_events:
        reference_action = keys_to_action[(keydown.key,)]
        obs, _, _, _ = reference_env.step(reference_action)

    env_play = gym.make(env_id)
    status = PlayStatus(callback)
    play(env_play, callback=status.callback, keys_to_action=keys_to_action, seed=seed)

    assert (status.last_observation == obs).all()
|