Files
Gymnasium/tests/wrappers/test_autoreset.py
John Balis 05df86e104 autoreset wrapper (#2650)
* added autoreset wrapper and tests

* added basic inline documentation for autoreset wrapper

* changes to comply with flake8 style

* redid autoreset wrapper

* compliance with flake8

* added final_info to info

* removed unnecessary override from autoreset wrapper

* fixed ordering mistake

* fixed flake8 compliance

* improved clarity of inline documentation

* changes to address code review

* changed autoreset terminal state keys, added message to key overlap check assert statement, updated autoreset wrapper docstring
2022-03-25 13:20:02 -04:00

101 lines
2.8 KiB
Python

import pytest
from typing import Optional
import numpy as np
import gym
from gym.wrappers import AutoResetWrapper
class DummyResetEnv(gym.Env):
    """
    A dummy environment used to test the autoreset wrapper.

    ``reset()`` sets an internal counter to 0 and returns it as the
    observation. Each call to ``step()`` increments the counter and returns
    the new value, so successive steps after a reset yield 1, 2, 3, ...
    The third step after a reset (counter > 2) is terminal: it returns
    ``done=True`` and a reward of 1; every earlier step returns reward 0.
    Info dicts contain the same counter value under the key ``"count"``.

    NOTE(review): observations are integer arrays whose values exceed the
    declared ``observation_space`` bounds of [-1, 1]; the spaces look like
    placeholders, so don't validate this env's observations against them.
    """

    metadata = {}

    def __init__(self):
        self.action_space = gym.spaces.Box(low=np.array([-1.0]), high=np.array([1.0]))
        self.observation_space = gym.spaces.Box(
            low=np.array([-1.0]), high=np.array([1.0])
        )
        self.count = 0

    def step(self, action):
        """Increment the counter and return ``(obs, reward, done, info)``.

        The action is ignored. The episode is done once the counter exceeds 2,
        and only that terminal step yields a reward of 1.
        """
        self.count += 1
        done = self.count > 2
        return (
            np.array([self.count]),
            1 if done else 0,
            done,
            {"count": self.count},
        )

    def reset(
        self,
        *,
        seed: Optional[int] = None,
        return_info: bool = False,
        options: Optional[dict] = None,
    ):
        """Reset the counter to 0.

        Returns the observation ``np.array([0])``, or the pair
        ``(observation, {"count": 0})`` when ``return_info`` is truthy.
        ``options`` is accepted for API compatibility and ignored.
        """
        # Forward the seed so the base class seeds the env's RNG, as the gym
        # reset API expects, instead of silently discarding it.
        super().reset(seed=seed)
        self.count = 0
        observation = np.array([self.count])
        if return_info:
            return observation, {"count": self.count}
        return observation
def test_autoreset_reset_info():
    """Check that AutoResetWrapper.reset honours the ``return_info`` flag.

    With ``CartPole-v1`` wrapped, a plain ``reset()`` and an explicit
    ``reset(return_info=False)`` must each return just an observation that
    lies in the observation space. ``reset(return_info=True)`` must return an
    ``(observation, info)`` pair whose info is a dict.
    """
    wrapped = AutoResetWrapper(gym.make("CartPole-v1"))
    space = wrapped.observation_space

    # Observation-only variants: the default call, then return_info=False.
    for reset_kwargs in ({}, {"return_info": False}):
        assert space.contains(wrapped.reset(**reset_kwargs))

    observation, reset_info = wrapped.reset(return_info=True)
    assert space.contains(observation)
    assert isinstance(reset_info, dict)
def _check_step(env, action, expected_obs, expected_reward, expected_done):
    """Step ``env`` once, assert obs/reward/done, and return the info dict."""
    obs, reward, done, info = env.step(action)
    # Compare arrays element-wise; a bare ``obs == array`` inside ``assert``
    # only works by accident for single-element arrays.
    np.testing.assert_array_equal(obs, expected_obs)
    assert reward == expected_reward
    assert bool(done) is expected_done
    return info


def test_autoreset_autoreset():
    """Check that AutoResetWrapper resets the env as soon as an episode ends.

    DummyResetEnv ends an episode on the third step after a reset. On that
    step the wrapper must report ``done=True`` and the terminal step's reward,
    but return the observation and info of the freshly reset env. The
    terminal observation and info are stashed under ``"terminal_observation"``
    and ``"terminal_info"``. Subsequent steps continue the new episode.
    """
    env = AutoResetWrapper(DummyResetEnv())
    action = 1

    obs, info = env.reset(return_info=True)
    np.testing.assert_array_equal(obs, np.array([0]))
    assert info == {"count": 0}

    # First episode: two ordinary (non-terminal) steps.
    assert _check_step(env, action, np.array([1]), 0, False) == {"count": 1}
    assert _check_step(env, action, np.array([2]), 0, False) == {"count": 2}

    # Third step ends the episode: reward/done belong to the terminal step,
    # while obs/info come from the automatic reset.
    info = _check_step(env, action, np.array([0]), 1, True)
    # Check keys and the array-valued entry explicitly rather than comparing
    # whole dicts, since dict ``==`` on numpy array values is fragile.
    assert set(info) == {"count", "terminal_observation", "terminal_info"}
    assert info["count"] == 0
    np.testing.assert_array_equal(info["terminal_observation"], np.array([3]))
    assert info["terminal_info"] == {"count": 3}

    # Stepping after the autoreset continues the new episode from count 0.
    assert _check_step(env, action, np.array([1]), 0, False) == {"count": 1}
    assert _check_step(env, action, np.array([2]), 0, False) == {"count": 2}