Files
Gymnasium/gym/wrappers/record_episode_statistics.py
Ariel Kwiatkowski 925823661d Add options to the signature of env.reset (#2515)
* First find/replace, now tests

* Fixes to the vector env

* Make seed keyword only in wrappers

* (try to) fix the bug with old environments using new wrappers (with the seed keyword)

* black

* Change **kwargs to options, try to make it work; black

* Add OrderEnforcing wrapper to wrapper exports
Add a test for compatibility with old (pybullet-like) envs

* Add OrderEnforcing wrapper to wrapper exports
Add a test for compatibility with old (pybullet-like) envs
black

* Update the env checker

* Update the env checker

* Update the env checker to use inspect (might fail tests, let's see)

* Allow the signature to include kwargs in env_checker

* Minor fix
2022-01-19 17:28:59 -05:00

56 lines
1.9 KiB
Python

import time
from collections import deque
from typing import Optional
import numpy as np
import gym
class RecordEpisodeStatistics(gym.Wrapper):
    """Track per-episode return and length, and report them when an episode ends.

    On every step the wrapper adds the reward to a running total and increments
    a step counter for each sub-environment. When a sub-environment reports
    ``done``, its info dict is shallow-copied and an ``"episode"`` entry is
    added containing:

    * ``"r"`` -- the cumulative reward of the finished episode,
    * ``"l"`` -- the number of steps of the finished episode,
    * ``"t"`` -- seconds elapsed since this wrapper was constructed, rounded
      to 6 decimal places.

    The most recent ``deque_size`` returns and lengths are also kept in
    ``return_queue`` and ``length_queue``. ``episode_count`` counts all
    finished episodes.

    Works with both single and vector environments. Vector mode is detected
    via the wrapped env's ``is_vector_env`` attribute (default ``False``).
    The number of sub-environments comes from ``num_envs`` (default ``1``).
    """

    def __init__(self, env, deque_size=100):
        """Wrap ``env``.

        Args:
            env: The environment to wrap.
            deque_size: Maximum number of finished-episode returns/lengths
                retained in ``return_queue`` and ``length_queue``.
        """
        super().__init__(env)
        self.num_envs = getattr(env, "num_envs", 1)
        # Reference point for the "t" field reported at the end of each episode.
        self.t0 = time.perf_counter()
        self.episode_count = 0
        # Per-sub-env accumulators. They are allocated in reset(), so step()
        # must be preceded by a reset() call.
        self.episode_returns = None
        self.episode_lengths = None
        self.return_queue = deque(maxlen=deque_size)
        self.length_queue = deque(maxlen=deque_size)
        self.is_vector_env = getattr(env, "is_vector_env", False)

    def reset(self, **kwargs):
        """Reset the wrapped env and zero the per-episode accumulators.

        All keyword arguments (e.g. ``seed``, ``options``) are forwarded
        unchanged to the wrapped env's ``reset``.

        Returns:
            Whatever the wrapped env's ``reset`` returns.
        """
        observations = super().reset(**kwargs)
        self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
        self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
        return observations

    def step(self, action):
        """Step the env, update statistics, and annotate infos of finished episodes.

        Returns:
            The wrapped env's ``(observations, rewards, dones, infos)``, with
            the same structure as received. The only change is that the info
            of every sub-environment whose episode ended is replaced by a copy
            carrying an ``"episode"`` entry. For vector envs, ``infos`` is
            returned as a list.
        """
        observations, rewards, dones, infos = super().step(action)
        self.episode_returns += rewards
        self.episode_lengths += 1
        if self.is_vector_env:
            # The vector env's infos may be an immutable sequence (e.g. a
            # tuple). Convert it to a list so entries can be replaced below.
            infos = list(infos)
        else:
            # Treat a single env as a batch of one so the loop below is uniform.
            infos = [infos]
            dones = [dones]
        for i, done in enumerate(dones):
            if not done:
                continue
            # Shallow copy so the wrapped env's own info dict is not mutated.
            infos[i] = infos[i].copy()
            episode_return = self.episode_returns[i]
            episode_length = self.episode_lengths[i]
            infos[i]["episode"] = {
                "r": episode_return,
                "l": episode_length,
                "t": round(time.perf_counter() - self.t0, 6),
            }
            self.return_queue.append(episode_return)
            self.length_queue.append(episode_length)
            self.episode_count += 1
            # Restart accumulation for this sub-env's next episode.
            self.episode_returns[i] = 0
            self.episode_lengths[i] = 0
        return (
            observations,
            rewards,
            dones if self.is_vector_env else dones[0],
            infos if self.is_vector_env else infos[0],
        )