2022-12-10 22:04:14 +00:00
|
|
|
"""Tests the functional api."""
|
2024-06-10 17:07:47 +01:00
|
|
|
|
2022-12-10 22:04:14 +00:00
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
from typing import Any
|
2022-11-18 22:25:33 +01:00
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
2024-08-15 14:49:05 +01:00
|
|
|
from gymnasium.experimental.functional import FuncEnv
|
2022-11-18 22:25:33 +01:00
|
|
|
|
|
|
|
|
2022-11-29 23:37:53 +00:00
|
|
|
class GenericTestFuncEnv(FuncEnv):
    """Generic testing functional environment.

    The state is a two-element ``float32`` array. Only the second element
    is ever modified (by ``transition``), and both ``reward`` and
    ``terminal`` look solely at whether that element is strictly positive.
    The ``rng`` and ``params`` arguments are accepted by every method but
    are not read by any of them.
    """

    def __init__(self, options: dict[str, Any] | None = None):
        """Constructor that allows generic options to be set on the environment.

        Args:
            options: Forwarded unchanged to ``FuncEnv.__init__``.
        """
        super().__init__(options)

    def initial(self, rng: Any, params=None) -> np.ndarray:
        """Testing initial function.

        Returns:
            A fresh ``float32`` array equal to ``[0, 0]``.
        """
        return np.array([0, 0], dtype=np.float32)

    def observation(self, state: np.ndarray, rng: Any, params=None) -> np.ndarray:
        """Testing observation function.

        Returns:
            ``state`` itself; the observation is the state, not a copy.
        """
        return state

    def transition(
        self, state: np.ndarray, action: int, rng: Any, params=None
    ) -> np.ndarray:
        """Testing transition function.

        Args:
            state: Current state array.
            action: Amount added to the second element of ``state``.
            rng: Unused; annotated ``Any`` to match the sibling methods.
            params: Unused.

        Returns:
            A new array ``state + [0, action]``; ``state`` is not mutated.
        """
        return state + np.array([0, action], dtype=np.float32)

    def reward(
        self,
        state: np.ndarray,
        action: int,
        next_state: np.ndarray,
        rng: Any,
        params=None,
    ) -> float:
        """Testing reward function.

        Returns:
            ``1.0`` when the second element of ``next_state`` is strictly
            positive, otherwise ``0.0``. ``state`` and ``action`` are ignored.
        """
        return 1.0 if next_state[1] > 0 else 0.0

    def terminal(self, state: np.ndarray, rng: Any, params=None) -> bool:
        """Testing terminal function.

        Returns:
            ``True`` when the second element of ``state`` is strictly
            positive, otherwise ``False``.
        """
        # Comparing a NumPy element yields a NumPy boolean scalar; coerce it so
        # the return value is the built-in ``bool`` that the annotation declares.
        return bool(state[1] > 0)
|
|
|
|
|
|
|
|
|
2022-12-10 22:04:14 +00:00
|
|
|
def test_functional_api():
    """Tests the core functional api specification using a generic testing environment.

    Walks the environment through a fixed sequence of actions and checks the
    shape, dtype and value of every state and observation, plus the reward
    and terminal flag produced at each step.
    """
    func_env = GenericTestFuncEnv()

    current = func_env.initial(None)
    first_obs = func_env.observation(current, None)

    # The initial state and its observation share shape and dtype.
    for array in (current, first_obs):
        assert array.shape == (2,)
        assert array.dtype == np.float32
    assert np.allclose(first_obs, current)

    action_sequence = [-1, -2, -5, 3, 5, 2]
    final_step = len(action_sequence) - 1
    for step, act in enumerate(action_sequence):
        successor = func_env.transition(current, act, None)
        expected_successor = current + np.array([0, act])
        assert successor.shape == (2,)
        assert successor.dtype == np.float32
        assert np.allclose(successor, expected_successor)

        seen = func_env.observation(successor, None)
        assert seen.shape == (2,)
        assert seen.dtype == np.float32
        assert np.allclose(seen, successor)

        expected_reward = 1.0 if successor[1] > 0 else 0.0
        assert func_env.reward(current, act, successor, None) == expected_reward

        # The running sum of actions only becomes positive on the final action,
        # so that is the only step expected to be terminal.
        is_terminal = func_env.terminal(successor, None)
        assert is_terminal == (step == final_step)

        current = successor
|