Gymnasium/tests/envs/functional/test_core.py

from typing import Any, Dict, Optional

import numpy as np

from gymnasium.functional import FuncEnv


class BasicTestEnv(FuncEnv):
    """Minimal functional environment: the state is a length-2 float32 vector and each
    action is added to its second component."""

    def __init__(self, options: Optional[Dict[str, Any]] = None):
        super().__init__(options)

    def initial(self, rng: Any) -> np.ndarray:
        return np.array([0, 0], dtype=np.float32)

    def observation(self, state: np.ndarray, rng: Any) -> np.ndarray:
        return state

    def transition(self, state: np.ndarray, action: int, rng: Any) -> np.ndarray:
        return state + np.array([0, action], dtype=np.float32)

    def reward(
        self, state: np.ndarray, action: int, next_state: np.ndarray, rng: Any
    ) -> float:
        return 1.0 if next_state[1] > 0 else 0.0

    def terminal(self, state: np.ndarray, rng: Any) -> bool:
        return state[1] > 0


def test_api():
    env = BasicTestEnv()

    # The initial state and its observation should match in shape, dtype and value.
    state = env.initial(None)
    obs = env.observation(state, None)
    assert state.shape == (2,)
    assert state.dtype == np.float32
    assert obs.shape == (2,)
    assert obs.dtype == np.float32
    assert np.allclose(obs, state)

    actions = [-1, -2, -5, 3, 5, 2]
    for i, action in enumerate(actions):
        # Each transition adds the action to the second component of the state.
        next_state = env.transition(state, action, None)
        assert next_state.shape == (2,)
        assert next_state.dtype == np.float32
        assert np.allclose(next_state, state + np.array([0, action]))

        observation = env.observation(next_state, None)
        assert observation.shape == (2,)
        assert observation.dtype == np.float32
        assert np.allclose(observation, next_state)

        reward = env.reward(state, action, next_state, None)
        assert reward == (1.0 if next_state[1] > 0 else 0.0)

        terminal = env.terminal(next_state, None)
        # The running sum only becomes positive after the final action.
        assert terminal == (i == 5)

        state = next_state
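

# A minimal rollout sketch, not part of the original test: it illustrates how the
# pure-functional pieces (initial -> observation -> transition -> reward -> terminal)
# compose into an environment loop. `test_rollout` is a hypothetical addition for
# illustration only; it uses nothing beyond the methods defined on BasicTestEnv above.
def test_rollout():
    env = BasicTestEnv()

    state = env.initial(None)
    total_reward = 0.0

    # Drive the state's second component positive; the episode then terminates.
    for action in [-1, 2, 3]:
        next_state = env.transition(state, action, None)
        total_reward += env.reward(state, action, next_state, None)
        if env.terminal(next_state, None):
            break
        state = next_state

    # After actions [-1, 2] the second component is 1 > 0, so the loop stops early
    # having collected a single reward of 1.0.
    assert total_reward == 1.0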